diff --git a/docker_commands.sh b/docker_commands.sh index 10dba786e412adfba86712057f6b1e2d1303f6ac..1a73485cb06ea04f0c4c67f8f28ebe653e1c7025 100644 --- a/docker_commands.sh +++ b/docker_commands.sh @@ -1,3 +1,7 @@ +cd spanfinder/ +python -m sociolome.lome_webserver & +cd .. + rm -rfv /.cache/sfdata/.git mv -v /.cache/sfdata/* /app/ du -h -d 2 diff --git a/requirements.txt b/requirements.txt index 79227c99bdcd08e37e92e09e2b7c42eb9e1d0d30..311bcb43b21dccbfa1883261bcf44f7c75c7afce 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,12 +25,12 @@ concrete==4.15.1 jinja2==3.0.3 # added 2022-03-28 because of (new?) error itsdangerous==2.0.1 # idem -# # LOME partpip install pyopenssl -# allennlp==2.8.0 -# allennlp-models==2.8.0 -# transformers==4.12.5 -# numpy -# torch>=1.7.0,<1.8.0 -# tqdm -# overrides -# scipy +# LOME +allennlp==2.8.0 +allennlp-models==2.8.0 +transformers==4.12.5 +numpy +torch>=1.7.0,<1.8.0 +tqdm +overrides +scipy diff --git a/spanfinder/.vscode/launch.json b/spanfinder/.vscode/launch.json new file mode 100644 index 0000000000000000000000000000000000000000..92c0743da34fd40893474e88dd7c991312753d87 --- /dev/null +++ b/spanfinder/.vscode/launch.json @@ -0,0 +1,16 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. 
+ // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "Python: Module", + "type": "python", + "request": "launch", + "module": "sociolome.lome_webserver", + "cwd": "${workspaceFolder}", + "justMyCode": true + } + ] +} \ No newline at end of file diff --git a/spanfinder/README.md b/spanfinder/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f52918abd9d0bf43e0a5d93b31e841bea40ca777 --- /dev/null +++ b/spanfinder/README.md @@ -0,0 +1,33 @@ +# Span Finder (v0.0.2) + +## Installation + +Environment: + - python >= 3.7 + - pip + +To install the dependencies, execute + +``` shell script +pip install -r requirements.txt +pip uninstall -y dataclasses +``` + +Then install SFTP (Span Finding - Transductive Parsing) package: + +``` shell script +python setup.py install +``` + +## Prediction + +If you use SpanFinder only for inference, please read [this example](scripts/predict_span.py). + +## Demo + +A demo (combined with Patrick's coref model) is [here](https://nlp.jhu.edu/demos/lome). + +## Pre-Trained Models + +Some parameters trained on FrameNet can be found at CLSP grid: `/home/gqin2/public/release/sftp/0.0.2`. +The file `model.tar.gz` points to the best model so far, and files named after dates will not be updated anymore. 
diff --git a/spanfinder/config/ace/ace.jsonnet b/spanfinder/config/ace/ace.jsonnet new file mode 100644 index 0000000000000000000000000000000000000000..5fe909efd269144312e1c496a6b36737dd3abce2 --- /dev/null +++ b/spanfinder/config/ace/ace.jsonnet @@ -0,0 +1,131 @@ +local env = import "../env.jsonnet"; + +local dataset_path = env.str("DATA_PATH", "data/ace/events"); +local ontology_path = "data/ace/ontology.tsv"; + +local debug = false; + +# embedding +local label_dim = 64; +local pretrained_model = env.str("ENCODER", "roberta-large"); + +# module +local dropout = 0.2; +local bio_dim = 512; +local bio_layers = 2; +local span_typing_dims = [256, 256]; +local event_smoothing_factor = env.json("SMOOTHING", "0.0"); +local arg_smoothing_factor = env.json("SMOOTHING", "0.0"); +local layer_fix = 0; + +# training +local typing_loss_factor = 8.0; +local grad_acc = env.json("GRAD_ACC", "1"); +local max_training_tokens = 512; +local max_inference_tokens = 1024; +local lr = env.json("LR", "1e-3"); +local cuda_devices = env.json("CUDA_DEVICES", "[0]"); + +{ + dataset_reader: { + type: "concrete", + debug: debug, + pretrained_model: pretrained_model, + ignore_label: false, + [ if debug then "max_instances" ]: 128, + event_smoothing_factor: event_smoothing_factor, + arg_smoothing_factor: event_smoothing_factor, + }, + train_data_path: dataset_path + "/train.tar.gz", + validation_data_path: dataset_path + "/dev.tar.gz", + test_data_path: dataset_path + "/test.tar.gz", + + datasets_for_vocab_creation: ["train"], + + data_loader: { + batch_sampler: { + type: "max_tokens_sampler", + max_tokens: max_training_tokens, + sorting_keys: ['tokens'] + } + }, + + validation_data_loader: { + batch_sampler: { + type: "max_tokens_sampler", + max_tokens: max_inference_tokens, + sorting_keys: ['tokens'] + } + }, + + model: { + type: "span", + word_embedding: { + token_embedders: { + "pieces": { + type: "pretrained_transformer", + model_name: pretrained_model, + } + }, + }, + span_extractor: { + 
type: 'combo', + sub_extractors: [ + { + type: 'self_attentive', + }, + { + type: 'bidirectional_endpoint', + } + ] + }, + span_finder: { + type: "bio", + bio_encoder: { + type: "lstm", + hidden_size: bio_dim, + num_layers: bio_layers, + bidirectional: true, + dropout: dropout, + }, + no_label: false, + }, + span_typing: { + type: 'mlp', + hidden_dims: span_typing_dims, + }, + metrics: [{type: "srl"}], + + ontology_path: ontology_path, + typing_loss_factor: typing_loss_factor, + label_dim: label_dim, + max_decoding_spans: 128, + max_recursion_depth: 2, + debug: debug, + }, + + trainer: { + num_epochs: 128, + patience: null, + [if std.length(cuda_devices) == 1 then "cuda_device"]: cuda_devices[0], + validation_metric: "+arg-c_f", + num_gradient_accumulation_steps: grad_acc, + optimizer: { + type: "transformer", + base: { + type: "adam", + lr: lr, + }, + embeddings_lr: 0.0, + encoder_lr: 1e-5, + pooler_lr: 1e-5, + layer_fix: layer_fix, + } + }, + + cuda_devices:: cuda_devices, + [if std.length(cuda_devices) > 1 then "distributed"]: { + "cuda_devices": cuda_devices + }, + [if std.length(cuda_devices) == 1 then "evaluate_on_test"]: true, +} diff --git a/spanfinder/config/ace/ft.jsonnet b/spanfinder/config/ace/ft.jsonnet new file mode 100644 index 0000000000000000000000000000000000000000..99b5d6881ac5cd7228b038d93316fca59bcd3596 --- /dev/null +++ b/spanfinder/config/ace/ft.jsonnet @@ -0,0 +1,51 @@ +local env = import "../env.jsonnet"; +local base = import "ace.jsonnet"; + +local pretrained_path = env.str("PRETRAINED_PATH", "cache/ace/best"); +local lr = env.json("FT_LR", 5e-5); + +# training +local cuda_devices = base.cuda_devices; + +{ + dataset_reader: base.dataset_reader, + train_data_path: base.train_data_path, + validation_data_path: base.validation_data_path, + test_data_path: base.test_data_path, + datasets_for_vocab_creation: ["train"], + data_loader: base.data_loader, + validation_data_loader: base.validation_data_loader, + + model: { + type: "from_archive", + 
archive_file: pretrained_path + }, + vocabulary: { + type: "from_files", + directory: pretrained_path + "/vocabulary" + }, + + trainer: { + num_epochs: base.trainer.num_epochs, + patience: base.trainer.patience, + [if std.length(cuda_devices) == 1 then "cuda_device"]: cuda_devices[0], + validation_metric: "+arg-c_f", + num_gradient_accumulation_steps: base.trainer.num_gradient_accumulation_steps, + optimizer: { + type: "transformer", + base: { + type: "adam", + lr: lr, + }, + embeddings_lr: 0.0, + encoder_lr: 1e-5, + pooler_lr: 1e-5, + layer_fix: base.trainer.optimizer.layer_fix, + } + }, + + [if std.length(cuda_devices) > 1 then "distributed"]: { + "cuda_devices": cuda_devices + }, + [if std.length(cuda_devices) == 1 then "evaluate_on_test"]: true +} diff --git a/spanfinder/config/ace/pt.jsonnet b/spanfinder/config/ace/pt.jsonnet new file mode 100644 index 0000000000000000000000000000000000000000..def0418d72c8b2f9758bf11f0eb0e367ce983e56 --- /dev/null +++ b/spanfinder/config/ace/pt.jsonnet @@ -0,0 +1,69 @@ +local env = import "../env.jsonnet"; +local base = import "ace.jsonnet"; + +local fn_path = "data/framenet/full/full.jsonl"; +local mapping_path = "data/ace/framenet2ace/"; + +local debug = false; + +# training +local lr = env.json("PT_LR", "5e-5"); +local cuda_devices = base.cuda_devices; + +# mapping +local min_weight = env.json("MIN_WEIGHT", '0.0'); +local max_weight = env.json("MAX_WEIGHT", '5.0'); + +{ + dataset_reader: { + type: "semantic_role_labeling", + debug: debug, + pretrained_model: base.dataset_reader.pretrained_model, + ignore_label: false, + [ if debug then "max_instances" ]: 128, + event_smoothing_factor: base.dataset_reader.event_smoothing_factor, + arg_smoothing_factor: base.dataset_reader.arg_smoothing_factor, + ontology_mapping_path: mapping_path + '/ontology_mapping.json', + min_weight: min_weight, + max_weight: max_weight, + }, + validation_dataset_reader: base.dataset_reader, + train_data_path: fn_path, + validation_data_path: 
base.validation_data_path, + test_data_path: base.test_data_path, + vocabulary: { + type: "extend", + directory: mapping_path + "/vocabulary" + }, + + datasets_for_vocab_creation: ["train"], + + data_loader: base.data_loader, + validation_data_loader: base.validation_data_loader, + + model: base.model, + + trainer: { + num_epochs: base.trainer.num_epochs, + patience: base.trainer.patience, + [if std.length(cuda_devices) == 1 then "cuda_device"]: cuda_devices[0], + validation_metric: "+arg-c_f", + num_gradient_accumulation_steps: base.trainer.num_gradient_accumulation_steps, + optimizer: { + type: "transformer", + base: { + type: "adam", + lr: lr, + }, + embeddings_lr: 0.0, + encoder_lr: 1e-5, + pooler_lr: 1e-5, + layer_fix: base.trainer.optimizer.layer_fix, + } + }, + + [if std.length(cuda_devices) > 1 then "distributed"]: { + "cuda_devices": cuda_devices + }, + [if std.length(cuda_devices) == 1 then "evaluate_on_test"]: true +} diff --git a/spanfinder/config/ace/rt.jsonnet b/spanfinder/config/ace/rt.jsonnet new file mode 100644 index 0000000000000000000000000000000000000000..363a69058db990e8b9d2e85643a555113ba645e6 --- /dev/null +++ b/spanfinder/config/ace/rt.jsonnet @@ -0,0 +1,89 @@ +local env = import "../env.jsonnet"; +local base = import "ace.jsonnet"; + +local dataset_path = env.str("DATA_PATH", "data/ace/events"); + +local debug = false; + +# re-train +local pretrained_path = env.str("PRETRAINED_PATH", "cache/fn/best"); +local rt_lr = env.json("RT_LR", 5e-5); + +# module +local cuda_devices = base.cuda_devices; + +{ + dataset_reader: base.dataset_reader, + train_data_path: base.train_data_path, + validation_data_path: base.validation_data_path, + test_data_path: base.test_data_path, + + datasets_for_vocab_creation: ["train"], + + data_loader: base.data_loader, + validation_data_loader: base.validation_data_loader, + + model: { + type: "span", + word_embedding: { + "_pretrained": { + "archive_file": pretrained_path, + "module_path": "word_embedding", + 
"freeze": false, + } + }, + span_extractor: { + "_pretrained": { + "archive_file": pretrained_path, + "module_path": "_span_extractor", + "freeze": false, + } + }, + span_finder: { + "_pretrained": { + "archive_file": pretrained_path, + "module_path": "_span_finder", + "freeze": false, + } + }, + span_typing: { + type: 'mlp', + hidden_dims: base.model.span_typing.hidden_dims, + }, + metrics: [{type: "srl"}], + + typing_loss_factor: base.model.typing_loss_factor, + label_dim: base.model.label_dim, + max_decoding_spans: 128, + max_recursion_depth: 2, + debug: debug, + }, + + trainer: { + num_epochs: base.trainer.num_epochs, + patience: base.trainer.patience, + [if std.length(cuda_devices) == 1 then "cuda_device"]: cuda_devices[0], + validation_metric: "+arg-c_f", + num_gradient_accumulation_steps: base.trainer.num_gradient_accumulation_steps, + optimizer: { + type: "transformer", + base: { + type: "adam", + lr: base.trainer.optimizer.base.lr, + }, + embeddings_lr: 0.0, + encoder_lr: 1e-5, + pooler_lr: 1e-5, + layer_fix: base.trainer.optimizer.layer_fix, + parameter_groups: [ + [['_span_finder.*'], {'lr': rt_lr}], + [['_span_extractor.*'], {'lr': rt_lr}], + ] + } + }, + + [if std.length(cuda_devices) > 1 then "distributed"]: { + "cuda_devices": cuda_devices + }, + [if std.length(cuda_devices) == 1 then "evaluate_on_test"]: true +} diff --git a/spanfinder/config/basic/basic.jsonnet b/spanfinder/config/basic/basic.jsonnet new file mode 100644 index 0000000000000000000000000000000000000000..fd0ac1af66dd1e032aabad3c3a23705d69a9c125 --- /dev/null +++ b/spanfinder/config/basic/basic.jsonnet @@ -0,0 +1,132 @@ +local env = import "../env.jsonnet"; + +local dataset_path = "data/better/basic/sent/"; +local ontology_path = "data/better/ontology.tsv"; + +local debug = false; + +# reader +local pretrained_model = env.str("ENCODER", "xlm-roberta-large"); + +# model +local label_dim = env.json("LABEL_DIM", "64"); +local dropout = env.json("DROPOUT", "0.2"); +local bio_dim = 
env.json("BIO_DIM", "512"); +local bio_layers = env.json("BIO_LAYER", "2"); +local span_typing_dims = env.json("TYPING_DIMS", "[256, 256]"); +local typing_loss_factor = env.json("LOSS_FACTOR", "8.0"); + +# loader +local max_training_tokens = 512; +local max_inference_tokens = 1024; + +# training +local layer_fix = env.json("LAYER_FIX", "0"); +local grad_acc = env.json("GRAD_ACC", "1"); +local cuda_devices = env.json("CUDA_DEVICES", "[-1]"); +local patience = env.json("PATIENCE", "null"); + +{ + dataset_reader: { + type: "better", + eval_type: "basic", + debug: debug, + pretrained_model: pretrained_model, + ignore_label: false, + [ if debug then "max_instances" ]: 128, + }, + train_data_path: dataset_path + "/basic.eng-provided-72.0pct.train-70.0pct.d.bp.json", + validation_data_path: dataset_path + "/basic.eng-provided-72.0pct.analysis-15.0pct.ref.d.bp.json", + test_data_path: dataset_path + "/basic.eng-provided-72.0pct.devtest-15.0pct.ref.d.bp.json", + + datasets_for_vocab_creation: ["train"], + + data_loader: { + batch_sampler: { + type: "max_tokens_sampler", + max_tokens: max_training_tokens, + sorting_keys: ['tokens'] + } + }, + + validation_data_loader: { + batch_sampler: { + type: "max_tokens_sampler", + max_tokens: max_inference_tokens, + sorting_keys: ['tokens'] + } + }, + + model: { + type: "span", + word_embedding: { + token_embedders: { + "pieces": { + type: "pretrained_transformer", + model_name: pretrained_model, + } + }, + }, + span_extractor: { + type: 'combo', + sub_extractors: [ + { + type: 'self_attentive', + }, + { + type: 'bidirectional_endpoint', + } + ] + }, + span_finder: { + type: "bio", + bio_encoder: { + type: "lstm", + hidden_size: bio_dim, + num_layers: bio_layers, + bidirectional: true, + dropout: dropout, + }, + no_label: false, + }, + span_typing: { + type: 'mlp', + hidden_dims: span_typing_dims, + }, + metrics: [{type: "srl"}], + + typing_loss_factor: typing_loss_factor, + ontology_path: ontology_path, + label_dim: label_dim, + 
max_decoding_spans: 128, + max_recursion_depth: 2, + debug: debug, + }, + + trainer: { + num_epochs: 128, + patience: patience, + [if std.length(cuda_devices) == 1 then "cuda_device"]: cuda_devices[0], + validation_metric: "+em_f", + grad_norm: 10, + grad_clipping: 10, + num_gradient_accumulation_steps: grad_acc, + optimizer: { + type: "transformer", + base: { + type: "adam", + lr: 1e-3, + }, + embeddings_lr: 0.0, + encoder_lr: 1e-5, + pooler_lr: 1e-5, + layer_fix: layer_fix, + } + }, + + cuda_devices:: cuda_devices, + [if std.length(cuda_devices) > 1 then "distributed"]: { + "cuda_devices": cuda_devices + }, + [if std.length(cuda_devices) == 1 then "evaluate_on_test"]: true +} diff --git a/spanfinder/config/basic/ft.jsonnet b/spanfinder/config/basic/ft.jsonnet new file mode 100644 index 0000000000000000000000000000000000000000..cb7d70c502563f7e2fdfc3272d4f287f92ac571b --- /dev/null +++ b/spanfinder/config/basic/ft.jsonnet @@ -0,0 +1,51 @@ +local env = import "../env.jsonnet"; +local base = import "basic.jsonnet"; + +local pretrained_path = env.str("PRETRAINED_PATH", "cache/basic/best"); +local lr = env.json("FT_LR", 5e-5); + +# training +local cuda_devices = base.cuda_devices; + +{ + dataset_reader: base.dataset_reader, + train_data_path: base.train_data_path, + validation_data_path: base.validation_data_path, + test_data_path: base.test_data_path, + datasets_for_vocab_creation: ["train"], + data_loader: base.data_loader, + validation_data_loader: base.validation_data_loader, + + model: { + type: "from_archive", + archive_file: pretrained_path + }, + vocabulary: { + type: "from_files", + directory: pretrained_path + "/vocabulary" + }, + + trainer: { + num_epochs: base.trainer.num_epochs, + patience: base.trainer.patience, + [if std.length(cuda_devices) == 1 then "cuda_device"]: cuda_devices[0], + validation_metric: "+arg-c_f", + num_gradient_accumulation_steps: base.trainer.num_gradient_accumulation_steps, + optimizer: { + type: "transformer", + base: { + type: 
"adam", + lr: lr, + }, + embeddings_lr: 0.0, + encoder_lr: 1e-5, + pooler_lr: 1e-5, + layer_fix: base.trainer.optimizer.layer_fix, + } + }, + + [if std.length(cuda_devices) > 1 then "distributed"]: { + "cuda_devices": cuda_devices + }, + [if std.length(cuda_devices) == 1 then "evaluate_on_test"]: true +} diff --git a/spanfinder/config/basic/pt.jsonnet b/spanfinder/config/basic/pt.jsonnet new file mode 100644 index 0000000000000000000000000000000000000000..fefca17428c157eaec9fc34fe35777b0278ea3e0 --- /dev/null +++ b/spanfinder/config/basic/pt.jsonnet @@ -0,0 +1,67 @@ +local env = import "../env.jsonnet"; +local base = import "basic.jsonnet"; + +local fn_path = "data/framenet/full/full.jsonl"; +local mapping_path = "data/basic/framenet2better/"; + +local debug = false; + +# training +local lr = env.json("PT_LR", "5e-5"); +local cuda_devices = base.cuda_devices; + +# mapping +local min_weight = env.json("MIN_WEIGHT", '0.0'); +local max_weight = env.json("MAX_WEIGHT", '5.0'); + +{ + dataset_reader: { + type: "semantic_role_labeling", + debug: debug, + pretrained_model: base.dataset_reader.pretrained_model, + ignore_label: false, + [ if debug then "max_instances" ]: 128, + ontology_mapping_path: mapping_path + '/ontology_mapping.json', + min_weight: min_weight, + max_weight: max_weight, + }, + validation_dataset_reader: base.dataset_reader, + train_data_path: fn_path, + validation_data_path: base.validation_data_path, + test_data_path: base.test_data_path, + vocabulary: { + type: "extend", + directory: mapping_path + "/vocabulary" + }, + + datasets_for_vocab_creation: ["train"], + + data_loader: base.data_loader, + validation_data_loader: base.validation_data_loader, + + model: base.model, + + trainer: { + num_epochs: base.trainer.num_epochs, + patience: base.trainer.patience, + [if std.length(cuda_devices) == 1 then "cuda_device"]: cuda_devices[0], + validation_metric: "+arg-c_f", + num_gradient_accumulation_steps: base.trainer.num_gradient_accumulation_steps, + 
optimizer: { + type: "transformer", + base: { + type: "adam", + lr: lr, + }, + embeddings_lr: 0.0, + encoder_lr: 1e-5, + pooler_lr: 1e-5, + layer_fix: base.trainer.optimizer.layer_fix, + } + }, + + [if std.length(cuda_devices) > 1 then "distributed"]: { + "cuda_devices": cuda_devices + }, + [if std.length(cuda_devices) == 1 then "evaluate_on_test"]: true +} diff --git a/spanfinder/config/basic/rt.jsonnet b/spanfinder/config/basic/rt.jsonnet new file mode 100644 index 0000000000000000000000000000000000000000..768e475412a9b9085a6536887b4a96e51b5c0491 --- /dev/null +++ b/spanfinder/config/basic/rt.jsonnet @@ -0,0 +1,87 @@ +local env = import "../env.jsonnet"; +local base = import "basic.jsonnet"; + +local debug = false; + +# re-train +local pretrained_path = env.str("PRETRAINED_PATH", "cache/fn/best"); +local rt_lr = env.json("RT_LR", 5e-5); + +# module +local cuda_devices = base.cuda_devices; + +{ + dataset_reader: base.dataset_reader, + train_data_path: base.train_data_path, + validation_data_path: base.validation_data_path, + test_data_path: base.test_data_path, + + datasets_for_vocab_creation: ["train"], + + data_loader: base.data_loader, + validation_data_loader: base.validation_data_loader, + + model: { + type: "span", + word_embedding: { + "_pretrained": { + "archive_file": pretrained_path, + "module_path": "word_embedding", + "freeze": false, + } + }, + span_extractor: { + "_pretrained": { + "archive_file": pretrained_path, + "module_path": "_span_extractor", + "freeze": false, + } + }, + span_finder: { + "_pretrained": { + "archive_file": pretrained_path, + "module_path": "_span_finder", + "freeze": false, + } + }, + span_typing: { + type: 'mlp', + hidden_dims: base.model.span_typing.hidden_dims, + }, + metrics: [{type: "srl"}], + + typing_loss_factor: base.model.typing_loss_factor, + label_dim: base.model.label_dim, + max_decoding_spans: 128, + max_recursion_depth: 2, + debug: debug, + }, + + trainer: { + num_epochs: base.trainer.num_epochs, + patience: 
base.trainer.patience, + [if std.length(cuda_devices) == 1 then "cuda_device"]: cuda_devices[0], + validation_metric: "+arg-c_f", + num_gradient_accumulation_steps: base.trainer.num_gradient_accumulation_steps, + optimizer: { + type: "transformer", + base: { + type: "adam", + lr: base.trainer.optimizer.base.lr, + }, + embeddings_lr: 0.0, + encoder_lr: 1e-5, + pooler_lr: 1e-5, + layer_fix: base.trainer.optimizer.layer_fix, + parameter_groups: [ + [['_span_finder.*'], {'lr': rt_lr}], + [['_span_extractor.*'], {'lr': rt_lr}], + ] + } + }, + + [if std.length(cuda_devices) > 1 then "distributed"]: { + "cuda_devices": cuda_devices + }, + [if std.length(cuda_devices) == 1 then "evaluate_on_test"]: true +} diff --git a/spanfinder/config/env.jsonnet b/spanfinder/config/env.jsonnet new file mode 100644 index 0000000000000000000000000000000000000000..982dfc802eb666b8eecb0e0ceceac9ca2195dae3 --- /dev/null +++ b/spanfinder/config/env.jsonnet @@ -0,0 +1,4 @@ +{ + json: function(name, default) if std.extVar("LOGNAME")=="tuning" then std.parseJson(std.extVar(name)) else std.parseJson(default), + str: function(name, default) if std.extVar("LOGNAME")=="tuning" then std.extVar(name) else default +} \ No newline at end of file diff --git a/spanfinder/config/fn-evalita/evalita.framenet_xlmr.jsonnet b/spanfinder/config/fn-evalita/evalita.framenet_xlmr.jsonnet new file mode 100644 index 0000000000000000000000000000000000000000..c41948c3ac782b9549af7b7fdc7559c9b2975f1b --- /dev/null +++ b/spanfinder/config/fn-evalita/evalita.framenet_xlmr.jsonnet @@ -0,0 +1,141 @@ +local env = import "../env.jsonnet"; + +#local dataset_path = env.str("DATA_PATH", "data/framenet/full"); +local dataset_path = "/home/p289731/cloned/lome/preproc/evalita_jsonl"; +local ontology_path = "data/framenet/ontology.tsv"; + +local debug = false; + +# reader +local pretrained_model = "/data/p289731/cloned/lome-models/models/xlm-roberta-framenet/"; +local smoothing_factor = env.json("SMOOTHING", "0.1"); + +# model 
+local label_dim = env.json("LABEL_DIM", "64"); +local dropout = env.json("DROPOUT", "0.2"); +local bio_dim = env.json("BIO_DIM", "512"); +local bio_layers = env.json("BIO_LAYER", "2"); +local span_typing_dims = env.json("TYPING_DIMS", "[256, 256]"); +local typing_loss_factor = env.json("LOSS_FACTOR", "8.0"); + +# loader +local exemplar_ratio = env.json("EXEMPLAR_RATIO", "0.05"); +local max_training_tokens = 512; +local max_inference_tokens = 1024; + +# training +local layer_fix = env.json("LAYER_FIX", "0"); +local grad_acc = env.json("GRAD_ACC", "1"); +#local cuda_devices = env.json("CUDA_DEVICES", "[-1]"); +local cuda_devices = [0]; +local patience = 32; + +{ + dataset_reader: { + type: "semantic_role_labeling", + debug: debug, + pretrained_model: pretrained_model, + ignore_label: false, + [ if debug then "max_instances" ]: 128, + event_smoothing_factor: smoothing_factor, + arg_smoothing_factor: smoothing_factor, + }, + train_data_path: dataset_path + "/evalita_train.jsonl", + validation_data_path: dataset_path + "/evalita_dev.jsonl", + test_data_path: dataset_path + "/evalita_test.jsonl", + + datasets_for_vocab_creation: ["train"], + + data_loader: { + batch_sampler: { + type: "mix_sampler", + max_tokens: max_training_tokens, + sorting_keys: ['tokens'], + sampling_ratios: { + 'exemplar': 1.0, + 'full text': 0.0, + } + } + }, + + validation_data_loader: { + batch_sampler: { + type: "max_tokens_sampler", + max_tokens: max_inference_tokens, + sorting_keys: ['tokens'] + } + }, + + model: { + type: "span", + word_embedding: { + token_embedders: { + "pieces": { + type: "pretrained_transformer", + model_name: pretrained_model, + } + }, + }, + span_extractor: { + type: 'combo', + sub_extractors: [ + { + type: 'self_attentive', + }, + { + type: 'bidirectional_endpoint', + } + ] + }, + span_finder: { + type: "bio", + bio_encoder: { + type: "lstm", + hidden_size: bio_dim, + num_layers: bio_layers, + bidirectional: true, + dropout: dropout, + }, + no_label: false, + }, + 
span_typing: { + type: 'mlp', + hidden_dims: span_typing_dims, + }, + metrics: [{type: "srl"}], + + typing_loss_factor: typing_loss_factor, + ontology_path: null, + label_dim: label_dim, + max_decoding_spans: 128, + max_recursion_depth: 2, + debug: debug, + }, + + trainer: { + num_epochs: 128, + patience: patience, + [if std.length(cuda_devices) == 1 then "cuda_device"]: cuda_devices[0], + validation_metric: "+em_f", + grad_norm: 10, + grad_clipping: 10, + num_gradient_accumulation_steps: grad_acc, + optimizer: { + type: "transformer", + base: { + type: "adam", + lr: 1e-3, + }, + embeddings_lr: 0.0, + encoder_lr: 1e-5, + pooler_lr: 1e-5, + layer_fix: layer_fix, + } + }, + + cuda_devices:: cuda_devices, + [if std.length(cuda_devices) > 1 then "distributed"]: { + "cuda_devices": cuda_devices + }, + [if std.length(cuda_devices) == 1 then "evaluate_on_test"]: true +} diff --git a/spanfinder/config/fn-evalita/evalita.it_mono.jsonnet b/spanfinder/config/fn-evalita/evalita.it_mono.jsonnet new file mode 100644 index 0000000000000000000000000000000000000000..362d2907a4eb8882acb32c8f3ac8957386cd562e --- /dev/null +++ b/spanfinder/config/fn-evalita/evalita.it_mono.jsonnet @@ -0,0 +1,141 @@ +local env = import "../env.jsonnet"; + +#local dataset_path = env.str("DATA_PATH", "data/framenet/full"); +local dataset_path = "/home/p289731/cloned/lome/preproc/evalita_jsonl"; +local ontology_path = "data/framenet/ontology.tsv"; + +local debug = false; + +# reader +local pretrained_model = env.str("ENCODER", "Musixmatch/umberto-commoncrawl-cased-v1"); +local smoothing_factor = env.json("SMOOTHING", "0.1"); + +# model +local label_dim = env.json("LABEL_DIM", "64"); +local dropout = env.json("DROPOUT", "0.2"); +local bio_dim = env.json("BIO_DIM", "512"); +local bio_layers = env.json("BIO_LAYER", "2"); +local span_typing_dims = env.json("TYPING_DIMS", "[256, 256]"); +local typing_loss_factor = env.json("LOSS_FACTOR", "8.0"); + +# loader +local exemplar_ratio = env.json("EXEMPLAR_RATIO", 
"0.05"); +local max_training_tokens = 512; +local max_inference_tokens = 1024; + +# training +local layer_fix = env.json("LAYER_FIX", "0"); +local grad_acc = env.json("GRAD_ACC", "1"); +#local cuda_devices = env.json("CUDA_DEVICES", "[-1]"); +local cuda_devices = [0]; +local patience = 32; + +{ + dataset_reader: { + type: "semantic_role_labeling", + debug: debug, + pretrained_model: pretrained_model, + ignore_label: false, + [ if debug then "max_instances" ]: 128, + event_smoothing_factor: smoothing_factor, + arg_smoothing_factor: smoothing_factor, + }, + train_data_path: dataset_path + "/evalita_train.jsonl", + validation_data_path: dataset_path + "/evalita_dev.jsonl", + test_data_path: dataset_path + "/evalita_test.jsonl", + + datasets_for_vocab_creation: ["train"], + + data_loader: { + batch_sampler: { + type: "mix_sampler", + max_tokens: max_training_tokens, + sorting_keys: ['tokens'], + sampling_ratios: { + 'exemplar': 1.0, + 'full text': 0.0, + } + } + }, + + validation_data_loader: { + batch_sampler: { + type: "max_tokens_sampler", + max_tokens: max_inference_tokens, + sorting_keys: ['tokens'] + } + }, + + model: { + type: "span", + word_embedding: { + token_embedders: { + "pieces": { + type: "pretrained_transformer", + model_name: pretrained_model, + } + }, + }, + span_extractor: { + type: 'combo', + sub_extractors: [ + { + type: 'self_attentive', + }, + { + type: 'bidirectional_endpoint', + } + ] + }, + span_finder: { + type: "bio", + bio_encoder: { + type: "lstm", + hidden_size: bio_dim, + num_layers: bio_layers, + bidirectional: true, + dropout: dropout, + }, + no_label: false, + }, + span_typing: { + type: 'mlp', + hidden_dims: span_typing_dims, + }, + metrics: [{type: "srl"}], + + typing_loss_factor: typing_loss_factor, + ontology_path: null, + label_dim: label_dim, + max_decoding_spans: 128, + max_recursion_depth: 2, + debug: debug, + }, + + trainer: { + num_epochs: 128, + patience: patience, + [if std.length(cuda_devices) == 1 then "cuda_device"]: 
cuda_devices[0], + validation_metric: "+em_f", + grad_norm: 10, + grad_clipping: 10, + num_gradient_accumulation_steps: grad_acc, + optimizer: { + type: "transformer", + base: { + type: "adam", + lr: 1e-3, + }, + embeddings_lr: 0.0, + encoder_lr: 1e-5, + pooler_lr: 1e-5, + layer_fix: layer_fix, + } + }, + + cuda_devices:: cuda_devices, + [if std.length(cuda_devices) > 1 then "distributed"]: { + "cuda_devices": cuda_devices + }, + [if std.length(cuda_devices) == 1 then "evaluate_on_test"]: true +} diff --git a/spanfinder/config/fn-evalita/evalita.vanilla_xlmr.jsonnet b/spanfinder/config/fn-evalita/evalita.vanilla_xlmr.jsonnet new file mode 100644 index 0000000000000000000000000000000000000000..e89948dbca970299d2f644331d2e1716ef058d1b --- /dev/null +++ b/spanfinder/config/fn-evalita/evalita.vanilla_xlmr.jsonnet @@ -0,0 +1,141 @@ +local env = import "../env.jsonnet"; + +#local dataset_path = env.str("DATA_PATH", "data/framenet/full"); +local dataset_path = "/home/p289731/cloned/lome/preproc/evalita_jsonl"; +local ontology_path = "data/framenet/ontology.tsv"; + +local debug = false; + +# reader +local pretrained_model = env.str("ENCODER", "xlm-roberta-large"); +local smoothing_factor = env.json("SMOOTHING", "0.1"); + +# model +local label_dim = env.json("LABEL_DIM", "64"); +local dropout = env.json("DROPOUT", "0.2"); +local bio_dim = env.json("BIO_DIM", "512"); +local bio_layers = env.json("BIO_LAYER", "2"); +local span_typing_dims = env.json("TYPING_DIMS", "[256, 256]"); +local typing_loss_factor = env.json("LOSS_FACTOR", "8.0"); + +# loader +local exemplar_ratio = env.json("EXEMPLAR_RATIO", "0.05"); +local max_training_tokens = 512; +local max_inference_tokens = 1024; + +# training +local layer_fix = env.json("LAYER_FIX", "0"); +local grad_acc = env.json("GRAD_ACC", "1"); +#local cuda_devices = env.json("CUDA_DEVICES", "[-1]"); +local cuda_devices = [0]; +local patience = 32; + +{ + dataset_reader: { + type: "semantic_role_labeling", + debug: debug, + 
pretrained_model: pretrained_model, + ignore_label: false, + [ if debug then "max_instances" ]: 128, + event_smoothing_factor: smoothing_factor, + arg_smoothing_factor: smoothing_factor, + }, + train_data_path: dataset_path + "/evalita_train.jsonl", + validation_data_path: dataset_path + "/evalita_dev.jsonl", + test_data_path: dataset_path + "/evalita_test.jsonl", + + datasets_for_vocab_creation: ["train"], + + data_loader: { + batch_sampler: { + type: "mix_sampler", + max_tokens: max_training_tokens, + sorting_keys: ['tokens'], + sampling_ratios: { + 'exemplar': 1.0, + 'full text': 0.0, + } + } + }, + + validation_data_loader: { + batch_sampler: { + type: "max_tokens_sampler", + max_tokens: max_inference_tokens, + sorting_keys: ['tokens'] + } + }, + + model: { + type: "span", + word_embedding: { + token_embedders: { + "pieces": { + type: "pretrained_transformer", + model_name: pretrained_model, + } + }, + }, + span_extractor: { + type: 'combo', + sub_extractors: [ + { + type: 'self_attentive', + }, + { + type: 'bidirectional_endpoint', + } + ] + }, + span_finder: { + type: "bio", + bio_encoder: { + type: "lstm", + hidden_size: bio_dim, + num_layers: bio_layers, + bidirectional: true, + dropout: dropout, + }, + no_label: false, + }, + span_typing: { + type: 'mlp', + hidden_dims: span_typing_dims, + }, + metrics: [{type: "srl"}], + + typing_loss_factor: typing_loss_factor, + ontology_path: null, + label_dim: label_dim, + max_decoding_spans: 128, + max_recursion_depth: 2, + debug: debug, + }, + + trainer: { + num_epochs: 128, + patience: patience, + [if std.length(cuda_devices) == 1 then "cuda_device"]: cuda_devices[0], + validation_metric: "+em_f", + grad_norm: 10, + grad_clipping: 10, + num_gradient_accumulation_steps: grad_acc, + optimizer: { + type: "transformer", + base: { + type: "adam", + lr: 1e-3, + }, + embeddings_lr: 0.0, + encoder_lr: 1e-5, + pooler_lr: 1e-5, + layer_fix: layer_fix, + } + }, + + cuda_devices:: cuda_devices, + [if std.length(cuda_devices) > 
1 then "distributed"]: { + "cuda_devices": cuda_devices + }, + [if std.length(cuda_devices) == 1 then "evaluate_on_test"]: true +} diff --git a/spanfinder/config/fn-evalita/evalita_plus_fn.vanilla_xlmr.freeze.jsonnet b/spanfinder/config/fn-evalita/evalita_plus_fn.vanilla_xlmr.freeze.jsonnet new file mode 100644 index 0000000000000000000000000000000000000000..a0801d7c40f4ca5896a9c679aaa03bd4a46441ee --- /dev/null +++ b/spanfinder/config/fn-evalita/evalita_plus_fn.vanilla_xlmr.freeze.jsonnet @@ -0,0 +1,142 @@ +local env = import "../env.jsonnet"; + +#local dataset_path = env.str("DATA_PATH", "data/framenet/full"); +local dataset_path = "/home/p289731/cloned/lome/preproc/evalita_jsonl"; +local ontology_path = "data/framenet/ontology.tsv"; + +local debug = false; + +# reader +local pretrained_model = env.str("ENCODER", "xlm-roberta-large"); +local smoothing_factor = env.json("SMOOTHING", "0.1"); + +# model +local label_dim = env.json("LABEL_DIM", "64"); +local dropout = env.json("DROPOUT", "0.2"); +local bio_dim = env.json("BIO_DIM", "512"); +local bio_layers = env.json("BIO_LAYER", "2"); +local span_typing_dims = env.json("TYPING_DIMS", "[256, 256]"); +local typing_loss_factor = env.json("LOSS_FACTOR", "8.0"); + +# loader +local exemplar_ratio = env.json("EXEMPLAR_RATIO", "0.05"); +local max_training_tokens = 512; +local max_inference_tokens = 1024; + +# training +local layer_fix = env.json("LAYER_FIX", "0"); +local grad_acc = env.json("GRAD_ACC", "1"); +#local cuda_devices = env.json("CUDA_DEVICES", "[-1]"); +local cuda_devices = [0]; +local patience = 32; + +{ + dataset_reader: { + type: "semantic_role_labeling", + debug: debug, + pretrained_model: pretrained_model, + ignore_label: false, + [ if debug then "max_instances" ]: 128, + event_smoothing_factor: smoothing_factor, + arg_smoothing_factor: smoothing_factor, + }, + train_data_path: dataset_path + "/evalita_plus_fn_train.jsonl", + validation_data_path: dataset_path + "/evalita_dev.jsonl", + test_data_path: 
dataset_path + "/evalita_test.jsonl", + + datasets_for_vocab_creation: ["train"], + + data_loader: { + batch_sampler: { + type: "mix_sampler", + max_tokens: max_training_tokens, + sorting_keys: ['tokens'], + sampling_ratios: { + 'exemplar': 1.0, + 'full text': 0.0, + } + } + }, + + validation_data_loader: { + batch_sampler: { + type: "max_tokens_sampler", + max_tokens: max_inference_tokens, + sorting_keys: ['tokens'] + } + }, + + model: { + type: "span", + word_embedding: { + token_embedders: { + "pieces": { + type: "pretrained_transformer", + model_name: pretrained_model, + train_parameters: false + } + }, + }, + span_extractor: { + type: 'combo', + sub_extractors: [ + { + type: 'self_attentive', + }, + { + type: 'bidirectional_endpoint', + } + ] + }, + span_finder: { + type: "bio", + bio_encoder: { + type: "lstm", + hidden_size: bio_dim, + num_layers: bio_layers, + bidirectional: true, + dropout: dropout, + }, + no_label: false, + }, + span_typing: { + type: 'mlp', + hidden_dims: span_typing_dims, + }, + metrics: [{type: "srl"}], + + typing_loss_factor: typing_loss_factor, + ontology_path: null, + label_dim: label_dim, + max_decoding_spans: 128, + max_recursion_depth: 2, + debug: debug, + }, + + trainer: { + num_epochs: 128, + patience: patience, + [if std.length(cuda_devices) == 1 then "cuda_device"]: cuda_devices[0], + validation_metric: "+em_f", + grad_norm: 10, + grad_clipping: 10, + num_gradient_accumulation_steps: grad_acc, + optimizer: { + type: "transformer", + base: { + type: "adam", + lr: 1e-3, + }, + embeddings_lr: 0.0, + encoder_lr: 1e-5, + pooler_lr: 1e-5, + layer_fix: layer_fix, + } + }, + + cuda_devices:: cuda_devices, + [if std.length(cuda_devices) > 1 then "distributed"]: { + "cuda_devices": cuda_devices + }, + [if std.length(cuda_devices) == 1 then "evaluate_on_test"]: true +} diff --git a/spanfinder/config/fn-evalita/evalita_plus_fn.vanilla_xlmr.jsonnet b/spanfinder/config/fn-evalita/evalita_plus_fn.vanilla_xlmr.jsonnet new file mode 100644 
index 0000000000000000000000000000000000000000..741f1cf088e6a91e29198465c691bd503b042237 --- /dev/null +++ b/spanfinder/config/fn-evalita/evalita_plus_fn.vanilla_xlmr.jsonnet @@ -0,0 +1,141 @@ +local env = import "../env.jsonnet"; + +#local dataset_path = env.str("DATA_PATH", "data/framenet/full"); +local dataset_path = "/home/p289731/cloned/lome/preproc/evalita_jsonl"; +local ontology_path = "data/framenet/ontology.tsv"; + +local debug = false; + +# reader +local pretrained_model = env.str("ENCODER", "xlm-roberta-large"); +local smoothing_factor = env.json("SMOOTHING", "0.1"); + +# model +local label_dim = env.json("LABEL_DIM", "64"); +local dropout = env.json("DROPOUT", "0.2"); +local bio_dim = env.json("BIO_DIM", "512"); +local bio_layers = env.json("BIO_LAYER", "2"); +local span_typing_dims = env.json("TYPING_DIMS", "[256, 256]"); +local typing_loss_factor = env.json("LOSS_FACTOR", "8.0"); + +# loader +local exemplar_ratio = env.json("EXEMPLAR_RATIO", "0.05"); +local max_training_tokens = 512; +local max_inference_tokens = 1024; + +# training +local layer_fix = env.json("LAYER_FIX", "0"); +local grad_acc = env.json("GRAD_ACC", "1"); +#local cuda_devices = env.json("CUDA_DEVICES", "[-1]"); +local cuda_devices = [0]; +local patience = 32; + +{ + dataset_reader: { + type: "semantic_role_labeling", + debug: debug, + pretrained_model: pretrained_model, + ignore_label: false, + [ if debug then "max_instances" ]: 128, + event_smoothing_factor: smoothing_factor, + arg_smoothing_factor: smoothing_factor, + }, + train_data_path: dataset_path + "/evalita_plus_fn_train.jsonl", + validation_data_path: dataset_path + "/evalita_dev.jsonl", + test_data_path: dataset_path + "/evalita_test.jsonl", + + datasets_for_vocab_creation: ["train"], + + data_loader: { + batch_sampler: { + type: "mix_sampler", + max_tokens: max_training_tokens, + sorting_keys: ['tokens'], + sampling_ratios: { + 'exemplar': 1.0, + 'full text': 0.0, + } + } + }, + + validation_data_loader: { + 
batch_sampler: { + type: "max_tokens_sampler", + max_tokens: max_inference_tokens, + sorting_keys: ['tokens'] + } + }, + + model: { + type: "span", + word_embedding: { + token_embedders: { + "pieces": { + type: "pretrained_transformer", + model_name: pretrained_model, + } + }, + }, + span_extractor: { + type: 'combo', + sub_extractors: [ + { + type: 'self_attentive', + }, + { + type: 'bidirectional_endpoint', + } + ] + }, + span_finder: { + type: "bio", + bio_encoder: { + type: "lstm", + hidden_size: bio_dim, + num_layers: bio_layers, + bidirectional: true, + dropout: dropout, + }, + no_label: false, + }, + span_typing: { + type: 'mlp', + hidden_dims: span_typing_dims, + }, + metrics: [{type: "srl"}], + + typing_loss_factor: typing_loss_factor, + ontology_path: null, + label_dim: label_dim, + max_decoding_spans: 128, + max_recursion_depth: 2, + debug: debug, + }, + + trainer: { + num_epochs: 128, + patience: patience, + [if std.length(cuda_devices) == 1 then "cuda_device"]: cuda_devices[0], + validation_metric: "+em_f", + grad_norm: 10, + grad_clipping: 10, + num_gradient_accumulation_steps: grad_acc, + optimizer: { + type: "transformer", + base: { + type: "adam", + lr: 1e-3, + }, + embeddings_lr: 0.0, + encoder_lr: 1e-5, + pooler_lr: 1e-5, + layer_fix: layer_fix, + } + }, + + cuda_devices:: cuda_devices, + [if std.length(cuda_devices) > 1 then "distributed"]: { + "cuda_devices": cuda_devices + }, + [if std.length(cuda_devices) == 1 then "evaluate_on_test"]: true +} diff --git a/spanfinder/config/fn-kicktionary/kicktionary.concat_clipped.vanilla_xlmr.jsonnet b/spanfinder/config/fn-kicktionary/kicktionary.concat_clipped.vanilla_xlmr.jsonnet new file mode 100644 index 0000000000000000000000000000000000000000..70aff4834e3f936669cfc23a81a4f01c9a5ba7a5 --- /dev/null +++ b/spanfinder/config/fn-kicktionary/kicktionary.concat_clipped.vanilla_xlmr.jsonnet @@ -0,0 +1,141 @@ +local env = import "../env.jsonnet"; + +#local dataset_path = env.str("DATA_PATH", 
"data/framenet/full"); +local dataset_path = "/home/p289731/cloned/lome/preproc/kicktionary_jsonl"; +local ontology_path = "data/framenet/ontology.tsv"; + +local debug = false; + +# reader +local pretrained_model = env.str("ENCODER", "xlm-roberta-large"); +local smoothing_factor = env.json("SMOOTHING", "0.1"); + +# model +local label_dim = env.json("LABEL_DIM", "64"); +local dropout = env.json("DROPOUT", "0.2"); +local bio_dim = env.json("BIO_DIM", "512"); +local bio_layers = env.json("BIO_LAYER", "2"); +local span_typing_dims = env.json("TYPING_DIMS", "[256, 256]"); +local typing_loss_factor = env.json("LOSS_FACTOR", "8.0"); + +# loader +local exemplar_ratio = env.json("EXEMPLAR_RATIO", "0.05"); +local max_training_tokens = 512; +local max_inference_tokens = 1024; + +# training +local layer_fix = env.json("LAYER_FIX", "0"); +local grad_acc = env.json("GRAD_ACC", "1"); +#local cuda_devices = env.json("CUDA_DEVICES", "[-1]"); +local cuda_devices = [0]; +local patience = 32; + +{ + dataset_reader: { + type: "semantic_role_labeling", + debug: debug, + pretrained_model: pretrained_model, + ignore_label: false, + [ if debug then "max_instances" ]: 128, + event_smoothing_factor: smoothing_factor, + arg_smoothing_factor: smoothing_factor, + }, + train_data_path: dataset_path + "/kicktionary_exemplars_train.concat_clipped.jsonl", + validation_data_path: dataset_path + "/kicktionary_exemplars_dev.jsonl", + test_data_path: dataset_path + "/kicktionary_exemplars_test.jsonl", + + datasets_for_vocab_creation: ["train"], + + data_loader: { + batch_sampler: { + type: "mix_sampler", + max_tokens: max_training_tokens, + sorting_keys: ['tokens'], + sampling_ratios: { + 'exemplar': 1.0, + 'full text': 0.0, + } + } + }, + + validation_data_loader: { + batch_sampler: { + type: "max_tokens_sampler", + max_tokens: max_inference_tokens, + sorting_keys: ['tokens'] + } + }, + + model: { + type: "span", + word_embedding: { + token_embedders: { + "pieces": { + type: "pretrained_transformer", 
+ model_name: pretrained_model, + } + }, + }, + span_extractor: { + type: 'combo', + sub_extractors: [ + { + type: 'self_attentive', + }, + { + type: 'bidirectional_endpoint', + } + ] + }, + span_finder: { + type: "bio", + bio_encoder: { + type: "lstm", + hidden_size: bio_dim, + num_layers: bio_layers, + bidirectional: true, + dropout: dropout, + }, + no_label: false, + }, + span_typing: { + type: 'mlp', + hidden_dims: span_typing_dims, + }, + metrics: [{type: "srl"}], + + typing_loss_factor: typing_loss_factor, + ontology_path: null, + label_dim: label_dim, + max_decoding_spans: 128, + max_recursion_depth: 2, + debug: debug, + }, + + trainer: { + num_epochs: 128, + patience: patience, + [if std.length(cuda_devices) == 1 then "cuda_device"]: cuda_devices[0], + validation_metric: "+em_f", + grad_norm: 10, + grad_clipping: 10, + num_gradient_accumulation_steps: grad_acc, + optimizer: { + type: "transformer", + base: { + type: "adam", + lr: 1e-3, + }, + embeddings_lr: 0.0, + encoder_lr: 1e-5, + pooler_lr: 1e-5, + layer_fix: layer_fix, + } + }, + + cuda_devices:: cuda_devices, + [if std.length(cuda_devices) > 1 then "distributed"]: { + "cuda_devices": cuda_devices + }, + [if std.length(cuda_devices) == 1 then "evaluate_on_test"]: true +} diff --git a/spanfinder/config/fn-kicktionary/kicktionary.football_xlmr.jsonnet b/spanfinder/config/fn-kicktionary/kicktionary.football_xlmr.jsonnet new file mode 100644 index 0000000000000000000000000000000000000000..24ab5f4625d3646f74cf4abfa8314e9d61ebacfe --- /dev/null +++ b/spanfinder/config/fn-kicktionary/kicktionary.football_xlmr.jsonnet @@ -0,0 +1,141 @@ +local env = import "../env.jsonnet"; + +#local dataset_path = env.str("DATA_PATH", "data/framenet/full"); +local dataset_path = "/home/p289731/cloned/lome/preproc/kicktionary_jsonl"; +local ontology_path = "data/framenet/ontology.tsv"; + +local debug = false; + +# reader +local pretrained_model = env.str("ENCODER", 
"/data/p289731/cloned/lome-models/models/xlm-roberta-football/"); +local smoothing_factor = env.json("SMOOTHING", "0.1"); + +# model +local label_dim = env.json("LABEL_DIM", "64"); +local dropout = env.json("DROPOUT", "0.2"); +local bio_dim = env.json("BIO_DIM", "512"); +local bio_layers = env.json("BIO_LAYER", "2"); +local span_typing_dims = env.json("TYPING_DIMS", "[256, 256]"); +local typing_loss_factor = env.json("LOSS_FACTOR", "8.0"); + +# loader +local exemplar_ratio = env.json("EXEMPLAR_RATIO", "0.05"); +local max_training_tokens = 512; +local max_inference_tokens = 1024; + +# training +local layer_fix = env.json("LAYER_FIX", "0"); +local grad_acc = env.json("GRAD_ACC", "1"); +#local cuda_devices = env.json("CUDA_DEVICES", "[-1]"); +local cuda_devices = [0]; +local patience = 32; + +{ + dataset_reader: { + type: "semantic_role_labeling", + debug: debug, + pretrained_model: pretrained_model, + ignore_label: false, + [ if debug then "max_instances" ]: 128, + event_smoothing_factor: smoothing_factor, + arg_smoothing_factor: smoothing_factor, + }, + train_data_path: dataset_path + "/kicktionary_exemplars_train.jsonl", + validation_data_path: dataset_path + "/kicktionary_exemplars_dev.jsonl", + test_data_path: dataset_path + "/kicktionary_exemplars_test.jsonl", + + datasets_for_vocab_creation: ["train"], + + data_loader: { + batch_sampler: { + type: "mix_sampler", + max_tokens: max_training_tokens, + sorting_keys: ['tokens'], + sampling_ratios: { + 'exemplar': 1.0, + 'full text': 0.0, + } + } + }, + + validation_data_loader: { + batch_sampler: { + type: "max_tokens_sampler", + max_tokens: max_inference_tokens, + sorting_keys: ['tokens'] + } + }, + + model: { + type: "span", + word_embedding: { + token_embedders: { + "pieces": { + type: "pretrained_transformer", + model_name: pretrained_model, + } + }, + }, + span_extractor: { + type: 'combo', + sub_extractors: [ + { + type: 'self_attentive', + }, + { + type: 'bidirectional_endpoint', + } + ] + }, + span_finder: { 
+ type: "bio", + bio_encoder: { + type: "lstm", + hidden_size: bio_dim, + num_layers: bio_layers, + bidirectional: true, + dropout: dropout, + }, + no_label: false, + }, + span_typing: { + type: 'mlp', + hidden_dims: span_typing_dims, + }, + metrics: [{type: "srl"}], + + typing_loss_factor: typing_loss_factor, + ontology_path: null, + label_dim: label_dim, + max_decoding_spans: 128, + max_recursion_depth: 2, + debug: debug, + }, + + trainer: { + num_epochs: 128, + patience: patience, + [if std.length(cuda_devices) == 1 then "cuda_device"]: cuda_devices[0], + validation_metric: "+em_f", + grad_norm: 10, + grad_clipping: 10, + num_gradient_accumulation_steps: grad_acc, + optimizer: { + type: "transformer", + base: { + type: "adam", + lr: 1e-3, + }, + embeddings_lr: 0.0, + encoder_lr: 1e-5, + pooler_lr: 1e-5, + layer_fix: layer_fix, + } + }, + + cuda_devices:: cuda_devices, + [if std.length(cuda_devices) > 1 then "distributed"]: { + "cuda_devices": cuda_devices + }, + [if std.length(cuda_devices) == 1 then "evaluate_on_test"]: true +} diff --git a/spanfinder/config/fn-kicktionary/kicktionary.framenet_xlmr.jsonnet b/spanfinder/config/fn-kicktionary/kicktionary.framenet_xlmr.jsonnet new file mode 100644 index 0000000000000000000000000000000000000000..d8ca62ed81841888ea68fbca183f64a1e8594ef2 --- /dev/null +++ b/spanfinder/config/fn-kicktionary/kicktionary.framenet_xlmr.jsonnet @@ -0,0 +1,141 @@ +local env = import "../env.jsonnet"; + +#local dataset_path = env.str("DATA_PATH", "data/framenet/full"); +local dataset_path = "/home/p289731/cloned/lome/preproc/kicktionary_jsonl"; +local ontology_path = "data/framenet/ontology.tsv"; + +local debug = false; + +# reader +local pretrained_model = "/data/p289731/cloned/lome-models/models/xlm-roberta-framenet/"; +local smoothing_factor = env.json("SMOOTHING", "0.1"); + +# model +local label_dim = env.json("LABEL_DIM", "64"); +local dropout = env.json("DROPOUT", "0.2"); +local bio_dim = env.json("BIO_DIM", "512"); +local bio_layers 
= env.json("BIO_LAYER", "2"); +local span_typing_dims = env.json("TYPING_DIMS", "[256, 256]"); +local typing_loss_factor = env.json("LOSS_FACTOR", "8.0"); + +# loader +local exemplar_ratio = env.json("EXEMPLAR_RATIO", "0.05"); +local max_training_tokens = 512; +local max_inference_tokens = 1024; + +# training +local layer_fix = env.json("LAYER_FIX", "0"); +local grad_acc = env.json("GRAD_ACC", "1"); +#local cuda_devices = env.json("CUDA_DEVICES", "[-1]"); +local cuda_devices = [0]; +local patience = 32; + +{ + dataset_reader: { + type: "semantic_role_labeling", + debug: debug, + pretrained_model: pretrained_model, + ignore_label: false, + [ if debug then "max_instances" ]: 128, + event_smoothing_factor: smoothing_factor, + arg_smoothing_factor: smoothing_factor, + }, + train_data_path: dataset_path + "/kicktionary_exemplars_train.jsonl", + validation_data_path: dataset_path + "/kicktionary_exemplars_dev.jsonl", + test_data_path: dataset_path + "/kicktionary_exemplars_test.jsonl", + + datasets_for_vocab_creation: ["train"], + + data_loader: { + batch_sampler: { + type: "mix_sampler", + max_tokens: max_training_tokens, + sorting_keys: ['tokens'], + sampling_ratios: { + 'exemplar': 1.0, + 'full text': 0.0, + } + } + }, + + validation_data_loader: { + batch_sampler: { + type: "max_tokens_sampler", + max_tokens: max_inference_tokens, + sorting_keys: ['tokens'] + } + }, + + model: { + type: "span", + word_embedding: { + token_embedders: { + "pieces": { + type: "pretrained_transformer", + model_name: pretrained_model, + } + }, + }, + span_extractor: { + type: 'combo', + sub_extractors: [ + { + type: 'self_attentive', + }, + { + type: 'bidirectional_endpoint', + } + ] + }, + span_finder: { + type: "bio", + bio_encoder: { + type: "lstm", + hidden_size: bio_dim, + num_layers: bio_layers, + bidirectional: true, + dropout: dropout, + }, + no_label: false, + }, + span_typing: { + type: 'mlp', + hidden_dims: span_typing_dims, + }, + metrics: [{type: "srl"}], + + 
typing_loss_factor: typing_loss_factor, + ontology_path: null, + label_dim: label_dim, + max_decoding_spans: 128, + max_recursion_depth: 2, + debug: debug, + }, + + trainer: { + num_epochs: 128, + patience: patience, + [if std.length(cuda_devices) == 1 then "cuda_device"]: cuda_devices[0], + validation_metric: "+em_f", + grad_norm: 10, + grad_clipping: 10, + num_gradient_accumulation_steps: grad_acc, + optimizer: { + type: "transformer", + base: { + type: "adam", + lr: 1e-3, + }, + embeddings_lr: 0.0, + encoder_lr: 1e-5, + pooler_lr: 1e-5, + layer_fix: layer_fix, + } + }, + + cuda_devices:: cuda_devices, + [if std.length(cuda_devices) > 1 then "distributed"]: { + "cuda_devices": cuda_devices + }, + [if std.length(cuda_devices) == 1 then "evaluate_on_test"]: true +} diff --git a/spanfinder/config/fn-kicktionary/kicktionary.vanilla_xlmr.jsonnet b/spanfinder/config/fn-kicktionary/kicktionary.vanilla_xlmr.jsonnet new file mode 100644 index 0000000000000000000000000000000000000000..319863f0d9026411f796437f4c93b9ad20bc4049 --- /dev/null +++ b/spanfinder/config/fn-kicktionary/kicktionary.vanilla_xlmr.jsonnet @@ -0,0 +1,141 @@ +local env = import "../env.jsonnet"; + +#local dataset_path = env.str("DATA_PATH", "data/framenet/full"); +local dataset_path = "/home/p289731/cloned/lome/preproc/kicktionary_jsonl"; +local ontology_path = "data/framenet/ontology.tsv"; + +local debug = false; + +# reader +local pretrained_model = env.str("ENCODER", "xlm-roberta-large"); +local smoothing_factor = env.json("SMOOTHING", "0.1"); + +# model +local label_dim = env.json("LABEL_DIM", "64"); +local dropout = env.json("DROPOUT", "0.2"); +local bio_dim = env.json("BIO_DIM", "512"); +local bio_layers = env.json("BIO_LAYER", "2"); +local span_typing_dims = env.json("TYPING_DIMS", "[256, 256]"); +local typing_loss_factor = env.json("LOSS_FACTOR", "8.0"); + +# loader +local exemplar_ratio = env.json("EXEMPLAR_RATIO", "0.05"); +local max_training_tokens = 512; +local max_inference_tokens = 1024; + 
+# training +local layer_fix = env.json("LAYER_FIX", "0"); +local grad_acc = env.json("GRAD_ACC", "1"); +#local cuda_devices = env.json("CUDA_DEVICES", "[-1]"); +local cuda_devices = [0]; +local patience = 32; + +{ + dataset_reader: { + type: "semantic_role_labeling", + debug: debug, + pretrained_model: pretrained_model, + ignore_label: false, + [ if debug then "max_instances" ]: 128, + event_smoothing_factor: smoothing_factor, + arg_smoothing_factor: smoothing_factor, + }, + train_data_path: dataset_path + "/kicktionary_exemplars_train.jsonl", + validation_data_path: dataset_path + "/kicktionary_exemplars_dev.jsonl", + test_data_path: dataset_path + "/kicktionary_exemplars_test.jsonl", + + datasets_for_vocab_creation: ["train"], + + data_loader: { + batch_sampler: { + type: "mix_sampler", + max_tokens: max_training_tokens, + sorting_keys: ['tokens'], + sampling_ratios: { + 'exemplar': 1.0, + 'full text': 0.0, + } + } + }, + + validation_data_loader: { + batch_sampler: { + type: "max_tokens_sampler", + max_tokens: max_inference_tokens, + sorting_keys: ['tokens'] + } + }, + + model: { + type: "span", + word_embedding: { + token_embedders: { + "pieces": { + type: "pretrained_transformer", + model_name: pretrained_model, + } + }, + }, + span_extractor: { + type: 'combo', + sub_extractors: [ + { + type: 'self_attentive', + }, + { + type: 'bidirectional_endpoint', + } + ] + }, + span_finder: { + type: "bio", + bio_encoder: { + type: "lstm", + hidden_size: bio_dim, + num_layers: bio_layers, + bidirectional: true, + dropout: dropout, + }, + no_label: false, + }, + span_typing: { + type: 'mlp', + hidden_dims: span_typing_dims, + }, + metrics: [{type: "srl"}], + + typing_loss_factor: typing_loss_factor, + ontology_path: null, + label_dim: label_dim, + max_decoding_spans: 128, + max_recursion_depth: 2, + debug: debug, + }, + + trainer: { + num_epochs: 128, + patience: patience, + [if std.length(cuda_devices) == 1 then "cuda_device"]: cuda_devices[0], + validation_metric: 
"+em_f", + grad_norm: 10, + grad_clipping: 10, + num_gradient_accumulation_steps: grad_acc, + optimizer: { + type: "transformer", + base: { + type: "adam", + lr: 1e-3, + }, + embeddings_lr: 0.0, + encoder_lr: 1e-5, + pooler_lr: 1e-5, + layer_fix: layer_fix, + } + }, + + cuda_devices:: cuda_devices, + [if std.length(cuda_devices) > 1 then "distributed"]: { + "cuda_devices": cuda_devices + }, + [if std.length(cuda_devices) == 1 then "evaluate_on_test"]: true +} diff --git a/spanfinder/config/fn-sonar/sonar-a1.framenet_xlmr.jsonnet b/spanfinder/config/fn-sonar/sonar-a1.framenet_xlmr.jsonnet new file mode 100644 index 0000000000000000000000000000000000000000..06779889275fe79b642c6d7d4c7b74c95b511955 --- /dev/null +++ b/spanfinder/config/fn-sonar/sonar-a1.framenet_xlmr.jsonnet @@ -0,0 +1,141 @@ +local env = import "../env.jsonnet"; + +#local dataset_path = env.str("DATA_PATH", "data/framenet/full"); +local dataset_path = "/home/p289731/cloned/lome/preproc/sonar_jsonl"; +local ontology_path = "data/framenet/ontology.tsv"; + +local debug = false; + +# reader +local pretrained_model = "/data/p289731/cloned/lome-models/models/xlm-roberta-framenet/"; +local smoothing_factor = env.json("SMOOTHING", "0.1"); + +# model +local label_dim = env.json("LABEL_DIM", "64"); +local dropout = env.json("DROPOUT", "0.2"); +local bio_dim = env.json("BIO_DIM", "512"); +local bio_layers = env.json("BIO_LAYER", "2"); +local span_typing_dims = env.json("TYPING_DIMS", "[256, 256]"); +local typing_loss_factor = env.json("LOSS_FACTOR", "8.0"); + +# loader +local exemplar_ratio = env.json("EXEMPLAR_RATIO", "0.05"); +local max_training_tokens = 512; +local max_inference_tokens = 1024; + +# training +local layer_fix = env.json("LAYER_FIX", "0"); +local grad_acc = env.json("GRAD_ACC", "1"); +#local cuda_devices = env.json("CUDA_DEVICES", "[-1]"); +local cuda_devices = [0]; +local patience = 32; + +{ + dataset_reader: { + type: "semantic_role_labeling", + debug: debug, + pretrained_model: 
pretrained_model, + ignore_label: false, + [ if debug then "max_instances" ]: 128, + event_smoothing_factor: smoothing_factor, + arg_smoothing_factor: smoothing_factor, + }, + train_data_path: dataset_path + "/dutch-sonar-train-A1.jsonl", + validation_data_path: dataset_path + "/dutch-sonar-dev-A1.jsonl", + test_data_path: dataset_path + "/dutch-sonar-test-A1.jsonl", + + datasets_for_vocab_creation: ["train"], + + data_loader: { + batch_sampler: { + type: "mix_sampler", + max_tokens: max_training_tokens, + sorting_keys: ['tokens'], + sampling_ratios: { + 'exemplar': 1.0, + 'full text': 0.0, + } + } + }, + + validation_data_loader: { + batch_sampler: { + type: "max_tokens_sampler", + max_tokens: max_inference_tokens, + sorting_keys: ['tokens'] + } + }, + + model: { + type: "span", + word_embedding: { + token_embedders: { + "pieces": { + type: "pretrained_transformer", + model_name: pretrained_model, + } + }, + }, + span_extractor: { + type: 'combo', + sub_extractors: [ + { + type: 'self_attentive', + }, + { + type: 'bidirectional_endpoint', + } + ] + }, + span_finder: { + type: "bio", + bio_encoder: { + type: "lstm", + hidden_size: bio_dim, + num_layers: bio_layers, + bidirectional: true, + dropout: dropout, + }, + no_label: false, + }, + span_typing: { + type: 'mlp', + hidden_dims: span_typing_dims, + }, + metrics: [{type: "srl"}], + + typing_loss_factor: typing_loss_factor, + ontology_path: null, + label_dim: label_dim, + max_decoding_spans: 128, + max_recursion_depth: 2, + debug: debug, + }, + + trainer: { + num_epochs: 128, + patience: patience, + [if std.length(cuda_devices) == 1 then "cuda_device"]: cuda_devices[0], + validation_metric: "+em_f", + grad_norm: 10, + grad_clipping: 10, + num_gradient_accumulation_steps: grad_acc, + optimizer: { + type: "transformer", + base: { + type: "adam", + lr: 1e-3, + }, + embeddings_lr: 0.0, + encoder_lr: 1e-5, + pooler_lr: 1e-5, + layer_fix: layer_fix, + } + }, + + cuda_devices:: cuda_devices, + [if 
std.length(cuda_devices) > 1 then "distributed"]: { + "cuda_devices": cuda_devices + }, + [if std.length(cuda_devices) == 1 then "evaluate_on_test"]: true +} diff --git a/spanfinder/config/fn-sonar/sonar-a1.sonar_plus_fn.vanilla_xlmr.jsonnet b/spanfinder/config/fn-sonar/sonar-a1.sonar_plus_fn.vanilla_xlmr.jsonnet new file mode 100644 index 0000000000000000000000000000000000000000..a3fadf5a89d54aa03893b1b192ddf8a24fe89502 --- /dev/null +++ b/spanfinder/config/fn-sonar/sonar-a1.sonar_plus_fn.vanilla_xlmr.jsonnet @@ -0,0 +1,142 @@ +local env = import "../env.jsonnet"; + +#local dataset_path = env.str("DATA_PATH", "data/framenet/full"); +local dataset_path = "/home/p289731/cloned/lome/preproc/sonar_jsonl"; +local ontology_path = "data/framenet/ontology.tsv"; + +local debug = false; + +# reader +local pretrained_model = env.str("ENCODER", "xlm-roberta-large"); +local smoothing_factor = env.json("SMOOTHING", "0.1"); + +# model +local label_dim = env.json("LABEL_DIM", "64"); +local dropout = env.json("DROPOUT", "0.2"); +local bio_dim = env.json("BIO_DIM", "512"); +local bio_layers = env.json("BIO_LAYER", "2"); +local span_typing_dims = env.json("TYPING_DIMS", "[256, 256]"); +local typing_loss_factor = env.json("LOSS_FACTOR", "8.0"); + +# loader +local exemplar_ratio = env.json("EXEMPLAR_RATIO", "0.05"); +local max_training_tokens = 512; +local max_inference_tokens = 1024; + +# training +local layer_fix = env.json("LAYER_FIX", "0"); +local grad_acc = env.json("GRAD_ACC", "1"); +#local cuda_devices = env.json("CUDA_DEVICES", "[-1]"); +local cuda_devices = [0]; +local patience = 32; + +{ + dataset_reader: { + type: "semantic_role_labeling", + debug: debug, + pretrained_model: pretrained_model, + ignore_label: false, + [ if debug then "max_instances" ]: 128, + event_smoothing_factor: smoothing_factor, + arg_smoothing_factor: smoothing_factor, + }, + + train_data_path: dataset_path + "/dutch-sonar-train-A1.jsonl", + validation_data_path: dataset_path + 
"/dutch-sonar-dev-A1.jsonl", + test_data_path: dataset_path + "/dutch-sonar-test-A1.jsonl", + + datasets_for_vocab_creation: ["train"], + + data_loader: { + batch_sampler: { + type: "mix_sampler", + max_tokens: max_training_tokens, + sorting_keys: ['tokens'], + sampling_ratios: { + 'exemplar': 1.0, + 'full text': 0.0, + } + } + }, + + validation_data_loader: { + batch_sampler: { + type: "max_tokens_sampler", + max_tokens: max_inference_tokens, + sorting_keys: ['tokens'] + } + }, + + model: { + type: "span", + word_embedding: { + token_embedders: { + "pieces": { + type: "pretrained_transformer", + model_name: pretrained_model, + } + }, + }, + span_extractor: { + type: 'combo', + sub_extractors: [ + { + type: 'self_attentive', + }, + { + type: 'bidirectional_endpoint', + } + ] + }, + span_finder: { + type: "bio", + bio_encoder: { + type: "lstm", + hidden_size: bio_dim, + num_layers: bio_layers, + bidirectional: true, + dropout: dropout, + }, + no_label: false, + }, + span_typing: { + type: 'mlp', + hidden_dims: span_typing_dims, + }, + metrics: [{type: "srl"}], + + typing_loss_factor: typing_loss_factor, + ontology_path: null, + label_dim: label_dim, + max_decoding_spans: 128, + max_recursion_depth: 2, + debug: debug, + }, + + trainer: { + num_epochs: 128, + patience: patience, + [if std.length(cuda_devices) == 1 then "cuda_device"]: cuda_devices[0], + validation_metric: "+em_f", + grad_norm: 10, + grad_clipping: 10, + num_gradient_accumulation_steps: grad_acc, + optimizer: { + type: "transformer", + base: { + type: "adam", + lr: 1e-3, + }, + embeddings_lr: 0.0, + encoder_lr: 1e-5, + pooler_lr: 1e-5, + layer_fix: layer_fix, + } + }, + + cuda_devices:: cuda_devices, + [if std.length(cuda_devices) > 1 then "distributed"]: { + "cuda_devices": cuda_devices + }, + [if std.length(cuda_devices) == 1 then "evaluate_on_test"]: true +} diff --git a/spanfinder/config/fn-sonar/sonar-a1.vanilla_xlmr.jsonnet b/spanfinder/config/fn-sonar/sonar-a1.vanilla_xlmr.jsonnet new file mode 
100644 index 0000000000000000000000000000000000000000..18790880c3eceb4c2781f809d1b18ef6fde72642 --- /dev/null +++ b/spanfinder/config/fn-sonar/sonar-a1.vanilla_xlmr.jsonnet @@ -0,0 +1,141 @@ +local env = import "../env.jsonnet"; + +#local dataset_path = env.str("DATA_PATH", "data/framenet/full"); +local dataset_path = "/home/p289731/cloned/lome/preproc/sonar_jsonl"; +local ontology_path = "data/framenet/ontology.tsv"; + +local debug = false; + +# reader +local pretrained_model = env.str("ENCODER", "xlm-roberta-large"); +local smoothing_factor = env.json("SMOOTHING", "0.1"); + +# model +local label_dim = env.json("LABEL_DIM", "64"); +local dropout = env.json("DROPOUT", "0.2"); +local bio_dim = env.json("BIO_DIM", "512"); +local bio_layers = env.json("BIO_LAYER", "2"); +local span_typing_dims = env.json("TYPING_DIMS", "[256, 256]"); +local typing_loss_factor = env.json("LOSS_FACTOR", "8.0"); + +# loader +local exemplar_ratio = env.json("EXEMPLAR_RATIO", "0.05"); +local max_training_tokens = 512; +local max_inference_tokens = 1024; + +# training +local layer_fix = env.json("LAYER_FIX", "0"); +local grad_acc = env.json("GRAD_ACC", "1"); +#local cuda_devices = env.json("CUDA_DEVICES", "[-1]"); +local cuda_devices = [0]; +local patience = 32; + +{ + dataset_reader: { + type: "semantic_role_labeling", + debug: debug, + pretrained_model: pretrained_model, + ignore_label: false, + [ if debug then "max_instances" ]: 128, + event_smoothing_factor: smoothing_factor, + arg_smoothing_factor: smoothing_factor, + }, + train_data_path: dataset_path + "/dutch-sonar-train-A1.jsonl", + validation_data_path: dataset_path + "/dutch-sonar-dev-A1.jsonl", + test_data_path: dataset_path + "/dutch-sonar-test-A1.jsonl", + + datasets_for_vocab_creation: ["train"], + + data_loader: { + batch_sampler: { + type: "mix_sampler", + max_tokens: max_training_tokens, + sorting_keys: ['tokens'], + sampling_ratios: { + 'exemplar': 1.0, + 'full text': 0.0, + } + } + }, + + validation_data_loader: { + 
batch_sampler: { + type: "max_tokens_sampler", + max_tokens: max_inference_tokens, + sorting_keys: ['tokens'] + } + }, + + model: { + type: "span", + word_embedding: { + token_embedders: { + "pieces": { + type: "pretrained_transformer", + model_name: pretrained_model, + } + }, + }, + span_extractor: { + type: 'combo', + sub_extractors: [ + { + type: 'self_attentive', + }, + { + type: 'bidirectional_endpoint', + } + ] + }, + span_finder: { + type: "bio", + bio_encoder: { + type: "lstm", + hidden_size: bio_dim, + num_layers: bio_layers, + bidirectional: true, + dropout: dropout, + }, + no_label: false, + }, + span_typing: { + type: 'mlp', + hidden_dims: span_typing_dims, + }, + metrics: [{type: "srl"}], + + typing_loss_factor: typing_loss_factor, + ontology_path: null, + label_dim: label_dim, + max_decoding_spans: 128, + max_recursion_depth: 2, + debug: debug, + }, + + trainer: { + num_epochs: 128, + patience: patience, + [if std.length(cuda_devices) == 1 then "cuda_device"]: cuda_devices[0], + validation_metric: "+em_f", + grad_norm: 10, + grad_clipping: 10, + num_gradient_accumulation_steps: grad_acc, + optimizer: { + type: "transformer", + base: { + type: "adam", + lr: 1e-3, + }, + embeddings_lr: 0.0, + encoder_lr: 1e-5, + pooler_lr: 1e-5, + layer_fix: layer_fix, + } + }, + + cuda_devices:: cuda_devices, + [if std.length(cuda_devices) > 1 then "distributed"]: { + "cuda_devices": cuda_devices + }, + [if std.length(cuda_devices) == 1 then "evaluate_on_test"]: true +} diff --git a/spanfinder/config/fn-sonar/sonar-a2.framenet_xlmr.jsonnet b/spanfinder/config/fn-sonar/sonar-a2.framenet_xlmr.jsonnet new file mode 100644 index 0000000000000000000000000000000000000000..9b2cc107a46d204fe6f12b2aa0eca6ccde0243c1 --- /dev/null +++ b/spanfinder/config/fn-sonar/sonar-a2.framenet_xlmr.jsonnet @@ -0,0 +1,141 @@ +local env = import "../env.jsonnet"; + +#local dataset_path = env.str("DATA_PATH", "data/framenet/full"); +local dataset_path = 
"/home/p289731/cloned/lome/preproc/sonar_jsonl"; +local ontology_path = "data/framenet/ontology.tsv"; + +local debug = false; + +# reader +local pretrained_model = "/data/p289731/cloned/lome-models/models/xlm-roberta-framenet/"; +local smoothing_factor = env.json("SMOOTHING", "0.1"); + +# model +local label_dim = env.json("LABEL_DIM", "64"); +local dropout = env.json("DROPOUT", "0.2"); +local bio_dim = env.json("BIO_DIM", "512"); +local bio_layers = env.json("BIO_LAYER", "2"); +local span_typing_dims = env.json("TYPING_DIMS", "[256, 256]"); +local typing_loss_factor = env.json("LOSS_FACTOR", "8.0"); + +# loader +local exemplar_ratio = env.json("EXEMPLAR_RATIO", "0.05"); +local max_training_tokens = 512; +local max_inference_tokens = 1024; + +# training +local layer_fix = env.json("LAYER_FIX", "0"); +local grad_acc = env.json("GRAD_ACC", "1"); +#local cuda_devices = env.json("CUDA_DEVICES", "[-1]"); +local cuda_devices = [0]; +local patience = 32; + +{ + dataset_reader: { + type: "semantic_role_labeling", + debug: debug, + pretrained_model: pretrained_model, + ignore_label: false, + [ if debug then "max_instances" ]: 128, + event_smoothing_factor: smoothing_factor, + arg_smoothing_factor: smoothing_factor, + }, + train_data_path: dataset_path + "/dutch-sonar-train-A2.jsonl", + validation_data_path: dataset_path + "/dutch-sonar-dev-A2.jsonl", + test_data_path: dataset_path + "/dutch-sonar-test-A2.jsonl", + + datasets_for_vocab_creation: ["train"], + + data_loader: { + batch_sampler: { + type: "mix_sampler", + max_tokens: max_training_tokens, + sorting_keys: ['tokens'], + sampling_ratios: { + 'exemplar': 1.0, + 'full text': 0.0, + } + } + }, + + validation_data_loader: { + batch_sampler: { + type: "max_tokens_sampler", + max_tokens: max_inference_tokens, + sorting_keys: ['tokens'] + } + }, + + model: { + type: "span", + word_embedding: { + token_embedders: { + "pieces": { + type: "pretrained_transformer", + model_name: pretrained_model, + } + }, + }, + span_extractor: 
{ + type: 'combo', + sub_extractors: [ + { + type: 'self_attentive', + }, + { + type: 'bidirectional_endpoint', + } + ] + }, + span_finder: { + type: "bio", + bio_encoder: { + type: "lstm", + hidden_size: bio_dim, + num_layers: bio_layers, + bidirectional: true, + dropout: dropout, + }, + no_label: false, + }, + span_typing: { + type: 'mlp', + hidden_dims: span_typing_dims, + }, + metrics: [{type: "srl"}], + + typing_loss_factor: typing_loss_factor, + ontology_path: null, + label_dim: label_dim, + max_decoding_spans: 128, + max_recursion_depth: 2, + debug: debug, + }, + + trainer: { + num_epochs: 128, + patience: patience, + [if std.length(cuda_devices) == 1 then "cuda_device"]: cuda_devices[0], + validation_metric: "+em_f", + grad_norm: 10, + grad_clipping: 10, + num_gradient_accumulation_steps: grad_acc, + optimizer: { + type: "transformer", + base: { + type: "adam", + lr: 1e-3, + }, + embeddings_lr: 0.0, + encoder_lr: 1e-5, + pooler_lr: 1e-5, + layer_fix: layer_fix, + } + }, + + cuda_devices:: cuda_devices, + [if std.length(cuda_devices) > 1 then "distributed"]: { + "cuda_devices": cuda_devices + }, + [if std.length(cuda_devices) == 1 then "evaluate_on_test"]: true +} diff --git a/spanfinder/config/fn-sonar/sonar-a2.sonar_plus_fn.vanilla_xlmr.jsonnet b/spanfinder/config/fn-sonar/sonar-a2.sonar_plus_fn.vanilla_xlmr.jsonnet new file mode 100644 index 0000000000000000000000000000000000000000..720583740e6bdfc1b1cd2b3a6e1ff845dacc7af1 --- /dev/null +++ b/spanfinder/config/fn-sonar/sonar-a2.sonar_plus_fn.vanilla_xlmr.jsonnet @@ -0,0 +1,141 @@ +local env = import "../env.jsonnet"; + +#local dataset_path = env.str("DATA_PATH", "data/framenet/full"); +local dataset_path = "/home/p289731/cloned/lome/preproc/sonar_jsonl"; +local ontology_path = "data/framenet/ontology.tsv"; + +local debug = false; + +# reader +local pretrained_model = env.str("ENCODER", "xlm-roberta-large"); +local smoothing_factor = env.json("SMOOTHING", "0.1"); + +# model +local label_dim = 
env.json("LABEL_DIM", "64"); +local dropout = env.json("DROPOUT", "0.2"); +local bio_dim = env.json("BIO_DIM", "512"); +local bio_layers = env.json("BIO_LAYER", "2"); +local span_typing_dims = env.json("TYPING_DIMS", "[256, 256]"); +local typing_loss_factor = env.json("LOSS_FACTOR", "8.0"); + +# loader +local exemplar_ratio = env.json("EXEMPLAR_RATIO", "0.05"); +local max_training_tokens = 512; +local max_inference_tokens = 1024; + +# training +local layer_fix = env.json("LAYER_FIX", "0"); +local grad_acc = env.json("GRAD_ACC", "1"); +#local cuda_devices = env.json("CUDA_DEVICES", "[-1]"); +local cuda_devices = [0]; +local patience = 32; + +{ + dataset_reader: { + type: "semantic_role_labeling", + debug: debug, + pretrained_model: pretrained_model, + ignore_label: false, + [ if debug then "max_instances" ]: 128, + event_smoothing_factor: smoothing_factor, + arg_smoothing_factor: smoothing_factor, + }, + train_data_path: dataset_path + "/dutch-sonar-train-A2.jsonl", + validation_data_path: dataset_path + "/dutch-sonar-dev-A2.jsonl", + test_data_path: dataset_path + "/dutch-sonar-test-A2.jsonl", + + datasets_for_vocab_creation: ["train"], + + data_loader: { + batch_sampler: { + type: "mix_sampler", + max_tokens: max_training_tokens, + sorting_keys: ['tokens'], + sampling_ratios: { + 'exemplar': 1.0, + 'full text': 0.0, + } + } + }, + + validation_data_loader: { + batch_sampler: { + type: "max_tokens_sampler", + max_tokens: max_inference_tokens, + sorting_keys: ['tokens'] + } + }, + + model: { + type: "span", + word_embedding: { + token_embedders: { + "pieces": { + type: "pretrained_transformer", + model_name: pretrained_model, + } + }, + }, + span_extractor: { + type: 'combo', + sub_extractors: [ + { + type: 'self_attentive', + }, + { + type: 'bidirectional_endpoint', + } + ] + }, + span_finder: { + type: "bio", + bio_encoder: { + type: "lstm", + hidden_size: bio_dim, + num_layers: bio_layers, + bidirectional: true, + dropout: dropout, + }, + no_label: false, + }, + 
span_typing: { + type: 'mlp', + hidden_dims: span_typing_dims, + }, + metrics: [{type: "srl"}], + + typing_loss_factor: typing_loss_factor, + ontology_path: null, + label_dim: label_dim, + max_decoding_spans: 128, + max_recursion_depth: 2, + debug: debug, + }, + + trainer: { + num_epochs: 128, + patience: patience, + [if std.length(cuda_devices) == 1 then "cuda_device"]: cuda_devices[0], + validation_metric: "+em_f", + grad_norm: 10, + grad_clipping: 10, + num_gradient_accumulation_steps: grad_acc, + optimizer: { + type: "transformer", + base: { + type: "adam", + lr: 1e-3, + }, + embeddings_lr: 0.0, + encoder_lr: 1e-5, + pooler_lr: 1e-5, + layer_fix: layer_fix, + } + }, + + cuda_devices:: cuda_devices, + [if std.length(cuda_devices) > 1 then "distributed"]: { + "cuda_devices": cuda_devices + }, + [if std.length(cuda_devices) == 1 then "evaluate_on_test"]: true +} diff --git a/spanfinder/config/fn-sonar/sonar-a2.vanilla_xlmr.jsonnet b/spanfinder/config/fn-sonar/sonar-a2.vanilla_xlmr.jsonnet new file mode 100644 index 0000000000000000000000000000000000000000..720583740e6bdfc1b1cd2b3a6e1ff845dacc7af1 --- /dev/null +++ b/spanfinder/config/fn-sonar/sonar-a2.vanilla_xlmr.jsonnet @@ -0,0 +1,141 @@ +local env = import "../env.jsonnet"; + +#local dataset_path = env.str("DATA_PATH", "data/framenet/full"); +local dataset_path = "/home/p289731/cloned/lome/preproc/sonar_jsonl"; +local ontology_path = "data/framenet/ontology.tsv"; + +local debug = false; + +# reader +local pretrained_model = env.str("ENCODER", "xlm-roberta-large"); +local smoothing_factor = env.json("SMOOTHING", "0.1"); + +# model +local label_dim = env.json("LABEL_DIM", "64"); +local dropout = env.json("DROPOUT", "0.2"); +local bio_dim = env.json("BIO_DIM", "512"); +local bio_layers = env.json("BIO_LAYER", "2"); +local span_typing_dims = env.json("TYPING_DIMS", "[256, 256]"); +local typing_loss_factor = env.json("LOSS_FACTOR", "8.0"); + +# loader +local exemplar_ratio = env.json("EXEMPLAR_RATIO", "0.05"); 
+local max_training_tokens = 512; +local max_inference_tokens = 1024; + +# training +local layer_fix = env.json("LAYER_FIX", "0"); +local grad_acc = env.json("GRAD_ACC", "1"); +#local cuda_devices = env.json("CUDA_DEVICES", "[-1]"); +local cuda_devices = [0]; +local patience = 32; + +{ + dataset_reader: { + type: "semantic_role_labeling", + debug: debug, + pretrained_model: pretrained_model, + ignore_label: false, + [ if debug then "max_instances" ]: 128, + event_smoothing_factor: smoothing_factor, + arg_smoothing_factor: smoothing_factor, + }, + train_data_path: dataset_path + "/dutch-sonar-train-A2.jsonl", + validation_data_path: dataset_path + "/dutch-sonar-dev-A2.jsonl", + test_data_path: dataset_path + "/dutch-sonar-test-A2.jsonl", + + datasets_for_vocab_creation: ["train"], + + data_loader: { + batch_sampler: { + type: "mix_sampler", + max_tokens: max_training_tokens, + sorting_keys: ['tokens'], + sampling_ratios: { + 'exemplar': 1.0, + 'full text': 0.0, + } + } + }, + + validation_data_loader: { + batch_sampler: { + type: "max_tokens_sampler", + max_tokens: max_inference_tokens, + sorting_keys: ['tokens'] + } + }, + + model: { + type: "span", + word_embedding: { + token_embedders: { + "pieces": { + type: "pretrained_transformer", + model_name: pretrained_model, + } + }, + }, + span_extractor: { + type: 'combo', + sub_extractors: [ + { + type: 'self_attentive', + }, + { + type: 'bidirectional_endpoint', + } + ] + }, + span_finder: { + type: "bio", + bio_encoder: { + type: "lstm", + hidden_size: bio_dim, + num_layers: bio_layers, + bidirectional: true, + dropout: dropout, + }, + no_label: false, + }, + span_typing: { + type: 'mlp', + hidden_dims: span_typing_dims, + }, + metrics: [{type: "srl"}], + + typing_loss_factor: typing_loss_factor, + ontology_path: null, + label_dim: label_dim, + max_decoding_spans: 128, + max_recursion_depth: 2, + debug: debug, + }, + + trainer: { + num_epochs: 128, + patience: patience, + [if std.length(cuda_devices) == 1 then 
"cuda_device"]: cuda_devices[0], + validation_metric: "+em_f", + grad_norm: 10, + grad_clipping: 10, + num_gradient_accumulation_steps: grad_acc, + optimizer: { + type: "transformer", + base: { + type: "adam", + lr: 1e-3, + }, + embeddings_lr: 0.0, + encoder_lr: 1e-5, + pooler_lr: 1e-5, + layer_fix: layer_fix, + } + }, + + cuda_devices:: cuda_devices, + [if std.length(cuda_devices) > 1 then "distributed"]: { + "cuda_devices": cuda_devices + }, + [if std.length(cuda_devices) == 1 then "evaluate_on_test"]: true +} diff --git a/spanfinder/config/fn/fn.orig.jsonnet b/spanfinder/config/fn/fn.orig.jsonnet new file mode 100644 index 0000000000000000000000000000000000000000..a1e8c3fa9c8c25c2069d36977887dbf7e76a7e0c --- /dev/null +++ b/spanfinder/config/fn/fn.orig.jsonnet @@ -0,0 +1,139 @@ +local env = import "../env.jsonnet"; + +local dataset_path = env.str("DATA_PATH", "data/framenet/full"); +local ontology_path = "data/framenet/ontology.tsv"; + +local debug = false; + +# reader +local pretrained_model = env.str("ENCODER", "xlm-roberta-large"); +local smoothing_factor = env.json("SMOOTHING", "0.1"); + +# model +local label_dim = env.json("LABEL_DIM", "64"); +local dropout = env.json("DROPOUT", "0.2"); +local bio_dim = env.json("BIO_DIM", "512"); +local bio_layers = env.json("BIO_LAYER", "2"); +local span_typing_dims = env.json("TYPING_DIMS", "[256, 256]"); +local typing_loss_factor = env.json("LOSS_FACTOR", "8.0"); + +# loader +local exemplar_ratio = env.json("EXEMPLAR_RATIO", "0.05"); +local max_training_tokens = 512; +local max_inference_tokens = 1024; + +# training +local layer_fix = env.json("LAYER_FIX", "0"); +local grad_acc = env.json("GRAD_ACC", "1"); +local cuda_devices = env.json("CUDA_DEVICES", "[-1]"); +local patience = env.json("PATIENCE", "null"); + +{ + dataset_reader: { + type: "semantic_role_labeling", + debug: debug, + pretrained_model: pretrained_model, + ignore_label: false, + [ if debug then "max_instances" ]: 128, + event_smoothing_factor: 
smoothing_factor, + arg_smoothing_factor: smoothing_factor, + }, + train_data_path: dataset_path + "/train.jsonl", + validation_data_path: dataset_path + "/dev.jsonl", + test_data_path: dataset_path + "/test.jsonl", + + datasets_for_vocab_creation: ["train"], + + data_loader: { + batch_sampler: { + type: "mix_sampler", + max_tokens: max_training_tokens, + sorting_keys: ['tokens'], + sampling_ratios: { + 'exemplar': exemplar_ratio, + 'full text': 1.0, + } + } + }, + + validation_data_loader: { + batch_sampler: { + type: "max_tokens_sampler", + max_tokens: max_inference_tokens, + sorting_keys: ['tokens'] + } + }, + + model: { + type: "span", + word_embedding: { + token_embedders: { + "pieces": { + type: "pretrained_transformer", + model_name: pretrained_model, + } + }, + }, + span_extractor: { + type: 'combo', + sub_extractors: [ + { + type: 'self_attentive', + }, + { + type: 'bidirectional_endpoint', + } + ] + }, + span_finder: { + type: "bio", + bio_encoder: { + type: "lstm", + hidden_size: bio_dim, + num_layers: bio_layers, + bidirectional: true, + dropout: dropout, + }, + no_label: false, + }, + span_typing: { + type: 'mlp', + hidden_dims: span_typing_dims, + }, + metrics: [{type: "srl"}], + + typing_loss_factor: typing_loss_factor, + ontology_path: ontology_path, + label_dim: label_dim, + max_decoding_spans: 128, + max_recursion_depth: 2, + debug: debug, + }, + + trainer: { + num_epochs: 128, + patience: patience, + [if std.length(cuda_devices) == 1 then "cuda_device"]: cuda_devices[0], + validation_metric: "+em_f", + grad_norm: 10, + grad_clipping: 10, + num_gradient_accumulation_steps: grad_acc, + optimizer: { + type: "transformer", + base: { + type: "adam", + lr: 1e-3, + }, + embeddings_lr: 0.0, + encoder_lr: 1e-5, + pooler_lr: 1e-5, + layer_fix: layer_fix, + } + }, + + cuda_devices:: cuda_devices, + [if std.length(cuda_devices) > 1 then "distributed"]: { + "cuda_devices": cuda_devices + }, + [if std.length(cuda_devices) == 1 then "evaluate_on_test"]: true +} 
diff --git a/spanfinder/config/fn/fn.train-football.jsonnet b/spanfinder/config/fn/fn.train-football.jsonnet new file mode 100644 index 0000000000000000000000000000000000000000..0dd2bc3e66d227a1b931544831f70a1d3d5f5a3a --- /dev/null +++ b/spanfinder/config/fn/fn.train-football.jsonnet @@ -0,0 +1,142 @@ +local env = import "../env.jsonnet"; + +#local dataset_path = env.str("DATA_PATH", "data/framenet/full"); +local dataset_path = "/home/p289731/cloned/lome/preproc/framenet_jsonl/full"; +local ontology_path = "data/framenet/ontology.tsv"; + +local debug = false; + +# reader +#local pretrained_model = env.str("ENCODER", "xlm-roberta-large"); +local pretrained_model = env.str("ENCODER", "/data/p289731/cloned/lome-models/models/xlm-roberta-football/"); +local smoothing_factor = env.json("SMOOTHING", "0.1"); + +# model +local label_dim = env.json("LABEL_DIM", "64"); +local dropout = env.json("DROPOUT", "0.2"); +local bio_dim = env.json("BIO_DIM", "512"); +local bio_layers = env.json("BIO_LAYER", "2"); +local span_typing_dims = env.json("TYPING_DIMS", "[256, 256]"); +local typing_loss_factor = env.json("LOSS_FACTOR", "8.0"); + +# loader +local exemplar_ratio = env.json("EXEMPLAR_RATIO", "0.05"); +local max_training_tokens = 512; +local max_inference_tokens = 1024; + +# training +local layer_fix = env.json("LAYER_FIX", "0"); +local grad_acc = env.json("GRAD_ACC", "1"); +#local cuda_devices = env.json("CUDA_DEVICES", "[-1]"); +local cuda_devices = [0]; +local patience = 32; + +{ + dataset_reader: { + type: "semantic_role_labeling", + debug: debug, + pretrained_model: "xlm-roberta-large", + ignore_label: false, + [ if debug then "max_instances" ]: 128, + event_smoothing_factor: smoothing_factor, + arg_smoothing_factor: smoothing_factor, + }, + train_data_path: dataset_path + "/train.jsonl", + validation_data_path: dataset_path + "/dev.jsonl", + test_data_path: dataset_path + "/test.jsonl", + + datasets_for_vocab_creation: ["train"], + + data_loader: { + batch_sampler: { + 
type: "mix_sampler", + max_tokens: max_training_tokens, + sorting_keys: ['tokens'], + sampling_ratios: { + 'exemplar': exemplar_ratio, + 'full text': 1.0, + } + } + }, + + validation_data_loader: { + batch_sampler: { + type: "max_tokens_sampler", + max_tokens: max_inference_tokens, + sorting_keys: ['tokens'] + } + }, + + model: { + type: "span", + word_embedding: { + token_embedders: { + "pieces": { + type: "pretrained_transformer", + model_name: pretrained_model, + } + }, + }, + span_extractor: { + type: 'combo', + sub_extractors: [ + { + type: 'self_attentive', + }, + { + type: 'bidirectional_endpoint', + } + ] + }, + span_finder: { + type: "bio", + bio_encoder: { + type: "lstm", + hidden_size: bio_dim, + num_layers: bio_layers, + bidirectional: true, + dropout: dropout, + }, + no_label: false, + }, + span_typing: { + type: 'mlp', + hidden_dims: span_typing_dims, + }, + metrics: [{type: "srl"}], + + typing_loss_factor: typing_loss_factor, + ontology_path: null, + label_dim: label_dim, + max_decoding_spans: 128, + max_recursion_depth: 2, + debug: debug, + }, + + trainer: { + num_epochs: 128, + patience: patience, + [if std.length(cuda_devices) == 1 then "cuda_device"]: cuda_devices[0], + validation_metric: "+em_f", + grad_norm: 10, + grad_clipping: 10, + num_gradient_accumulation_steps: grad_acc, + optimizer: { + type: "transformer", + base: { + type: "adam", + lr: 1e-3, + }, + embeddings_lr: 0.0, + encoder_lr: 1e-5, + pooler_lr: 1e-5, + layer_fix: layer_fix, + } + }, + + cuda_devices:: cuda_devices, + [if std.length(cuda_devices) > 1 then "distributed"]: { + "cuda_devices": cuda_devices + }, + [if std.length(cuda_devices) == 1 then "evaluate_on_test"]: true +} diff --git a/spanfinder/config/fn/fn.train3.jsonnet b/spanfinder/config/fn/fn.train3.jsonnet new file mode 100644 index 0000000000000000000000000000000000000000..9666b0d3f33bfb9e66a4e6229a2accf5891e4d33 --- /dev/null +++ b/spanfinder/config/fn/fn.train3.jsonnet @@ -0,0 +1,141 @@ +local env = import 
"../env.jsonnet"; + +#local dataset_path = env.str("DATA_PATH", "data/framenet/full"); +local dataset_path = "/home/p289731/cloned/lome/preproc/framenet_jsonl/full"; +local ontology_path = "data/framenet/ontology.tsv"; + +local debug = false; + +# reader +local pretrained_model = env.str("ENCODER", "xlm-roberta-large"); +local smoothing_factor = env.json("SMOOTHING", "0.1"); + +# model +local label_dim = env.json("LABEL_DIM", "64"); +local dropout = env.json("DROPOUT", "0.2"); +local bio_dim = env.json("BIO_DIM", "512"); +local bio_layers = env.json("BIO_LAYER", "2"); +local span_typing_dims = env.json("TYPING_DIMS", "[256, 256]"); +local typing_loss_factor = env.json("LOSS_FACTOR", "8.0"); + +# loader +local exemplar_ratio = env.json("EXEMPLAR_RATIO", "0.05"); +local max_training_tokens = 512; +local max_inference_tokens = 1024; + +# training +local layer_fix = env.json("LAYER_FIX", "0"); +local grad_acc = env.json("GRAD_ACC", "1"); +#local cuda_devices = env.json("CUDA_DEVICES", "[-1]"); +local cuda_devices = [0]; +local patience = 32; + +{ + dataset_reader: { + type: "semantic_role_labeling", + debug: debug, + pretrained_model: pretrained_model, + ignore_label: false, + [ if debug then "max_instances" ]: 128, + event_smoothing_factor: smoothing_factor, + arg_smoothing_factor: smoothing_factor, + }, + train_data_path: dataset_path + "/train.jsonl", + validation_data_path: dataset_path + "/dev.jsonl", + test_data_path: dataset_path + "/test.jsonl", + + datasets_for_vocab_creation: ["train"], + + data_loader: { + batch_sampler: { + type: "mix_sampler", + max_tokens: max_training_tokens, + sorting_keys: ['tokens'], + sampling_ratios: { + 'exemplar': exemplar_ratio, + 'full text': 1.0, + } + } + }, + + validation_data_loader: { + batch_sampler: { + type: "max_tokens_sampler", + max_tokens: max_inference_tokens, + sorting_keys: ['tokens'] + } + }, + + model: { + type: "span", + word_embedding: { + token_embedders: { + "pieces": { + type: "pretrained_transformer", + 
model_name: pretrained_model, + } + }, + }, + span_extractor: { + type: 'combo', + sub_extractors: [ + { + type: 'self_attentive', + }, + { + type: 'bidirectional_endpoint', + } + ] + }, + span_finder: { + type: "bio", + bio_encoder: { + type: "lstm", + hidden_size: bio_dim, + num_layers: bio_layers, + bidirectional: true, + dropout: dropout, + }, + no_label: false, + }, + span_typing: { + type: 'mlp', + hidden_dims: span_typing_dims, + }, + metrics: [{type: "srl"}], + + typing_loss_factor: typing_loss_factor, + ontology_path: null, + label_dim: label_dim, + max_decoding_spans: 128, + max_recursion_depth: 2, + debug: debug, + }, + + trainer: { + num_epochs: 128, + patience: patience, + [if std.length(cuda_devices) == 1 then "cuda_device"]: cuda_devices[0], + validation_metric: "+em_f", + grad_norm: 10, + grad_clipping: 10, + num_gradient_accumulation_steps: grad_acc, + optimizer: { + type: "transformer", + base: { + type: "adam", + lr: 1e-3, + }, + embeddings_lr: 0.0, + encoder_lr: 1e-5, + pooler_lr: 1e-5, + layer_fix: layer_fix, + } + }, + + cuda_devices:: cuda_devices, + [if std.length(cuda_devices) > 1 then "distributed"]: { + "cuda_devices": cuda_devices + }, + [if std.length(cuda_devices) == 1 then "evaluate_on_test"]: true +} diff --git a/spanfinder/docs/data.md b/spanfinder/docs/data.md new file mode 100644 index 0000000000000000000000000000000000000000..a864e1b58d42c418e9a39b74eeb9fe5237e6f153 --- /dev/null +++ b/spanfinder/docs/data.md @@ -0,0 +1,68 @@ +# Data Format + +You can pass SpanFinder any formats of data, as long as you implement a dataset reader inherited from SpanReader. We also provide a Concrete dataset reader. Besides them, SpanFinder comes with its own JSON data format, which enables richer features for training and modeling. 
+ +The minimal example of the JSON is + +```JSON +{ + "meta": { + "fully_annotated": true + }, + "tokens": ["Bob", "attacks", "the", "building", "."], + "annotations": [ + { + "span": [1, 1], + "label": "Attack", + "children": [ + { + "span": [0, 0], + "label": "Assailant", + "children": [] + }, + { + "span": [2, 3], + "label": "Victim", + "children": [] + } + ] + }, + { + "span": [3, 3], + "label": "Buildings", + "children": [ + { + "span": [3, 3], + "label": "Building", + "children": [] + } + ] + } + ] +} +``` + +You can have nested spans with unlimited depth. + +## Meta-info for Semantic Role Labeling (SRL) + +```JSON +{ + "ontology": { + "event": ["Violence-Attack"], + "argument": ["Agent", "Patient"], + "link": [[0, 0], [0, 1]] + }, + "ontology_mapping": { + "event": { + "Attack": ["Violence-Attack", 0.8] + }, + "argument": { + "Assault": ["Agent", 0.95], + "Victim": ["patient", 0.9] + } + } +} +``` + +TODO: Guanghui needs to doc this. diff --git a/spanfinder/docs/mapping.md b/spanfinder/docs/mapping.md new file mode 100644 index 0000000000000000000000000000000000000000..81ac1fe10057250d66ee65bc3e3d7735c8d1412a --- /dev/null +++ b/spanfinder/docs/mapping.md @@ -0,0 +1,17 @@ +## Mapping + +If a file is passed to the predictor, +the predicted spans will be converted into a new ontology. +The file format should be + +`\t\t` + +If the predicted span is labeled as ``, +and its parent is labeled as ``, +it will be re-labeled as ``. +If no rules match, the span and all of its descendents will be ignored. + +The `` is optional. +If the parent label is `@@VIRTUAL_ROOT@@`, then this rule matches the first layer of spans. +In semantic parsing, it matches events. +If the parent label is `*`, it means it can match anything. 
diff --git a/spanfinder/docs/training.md b/spanfinder/docs/training.md new file mode 100644 index 0000000000000000000000000000000000000000..7a5285ce15e4066ae66d8004fcb9316a7f5f2a56 --- /dev/null +++ b/spanfinder/docs/training.md @@ -0,0 +1,65 @@ +# Training Span Finder + +## Metrics explanation + +By default, the following metrics will be used + +- em: (includes emp, emr, emf) Exact matching metric. A span is exactly matched iff its parent, boundaries, and label are all correctly predicted. Note that if a parent is not correctly predicted, all its children will be treated as false negatives. In other words, errors are propagated. +- sm: (includes smp, smr, smf) Span matching metric. Similar to EM but will not check the labels. If you observe high SM but low EM, then the typing system is not properly working. +- finder: (includes finder-p, finder-r, finder-f) A metric to measure how well the model can find spans. Different from SM, in this metric, the gold parent will be provided, so the errors will not be propagated. +- typing_acc: Span typing accuracy with gold parent and gold span boundaries. + + +Optional metrics that might be useful for SRL-style tasks. Put the following line + +`metrics: [{type: "srl", check_type: true}],` + +to the span model in the config file to turn on this feature. You will see the following two metrics: + +- trigger: (includes trigger-p, trigger-r, trigger-f) It measures how well the system can find the event triggers (or frames in FrameNet). If `check_type` is True, it also checks the event label. +- role: (includes role-p, role-r, role-f) It measures how well the system can find roles. Note that if the event/trigger is not found, all its children will be treated as false negatives. If `check_type` is True, it also checks the role label. + +## Ontology Constraint + +In some cases, certain spans can only be attached to specific spans.
+E.g., in SRL tasks, events can only be attached to the VirtualRoot, and arguments can only be attached to the events. +The constraints of FrameNet are harsher, where each frame has some specific frame elements. + +These constraints can be abstracted as a boolean square matrix whose columns and rows are span labels including VIRTUAL_ROOT. +Say it's `M`; label2 can be label1's child iff `M[label1, label2]` is True. + +You can specify the ontology constraint for SpanFinder with the `ontology_path` argument in the SpanModel class. +The format of this file is simple. Each line is one row of the `M` matrix: + +```parent_label child_label_1 child_label_2``` + +which means child1 and child2 can be attached to the parent. +Both `parent_label` and `child_label` are strings, and the separator between them should be `\t`, not ` `. +If a parent_label is missing from the file, by default all children will be attachable. +If this file is not provided, all labels can be attached to all labels. + +An example of this file can be found at CLSP grid: + +```/home/gqin2/data/framenet/ontology.tsv``` + +## Typing loss factor + +(This section might be updated soon -- Guanghui) + +The loss comes from two sources: the SpanFinding and SpanTyping modules. +SpanFinder uses a CRF and uses probability as loss, but SpanTyping uses cross entropy. +They're of different scales, so we have to re-scale them. +The formula is: + +`loss = finding_loss + typing_loss_factor * typing_loss` + +Empirically, Guanghui finds the optimal `typing_loss_factor` for the FrameNet system is 750. + +In theory, we should put the two losses in the same space. Guanghui is looking into this, and this might be solved in SpanFinder 0.0.2. + +## Optimizer + +A custom optimizer `transformer` is used for the span finder. +It allows you to specify a special learning rate for the transformer encoder and to fix the parameters of certain modules. +Empirically, fixing the embeddings (so only fine-tuning the encoder and pooler) and training with lr=1e-5 yields the best results for FrameNet.
+For usage and more details, see its class doc. diff --git a/spanfinder/evalita_scores.txt b/spanfinder/evalita_scores.txt new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/spanfinder/model.kicktionary.mod.tar.gz b/spanfinder/model.kicktionary.mod.tar.gz new file mode 100644 index 0000000000000000000000000000000000000000..478c5b80dc26de0511cb2195cdf03a391d4b2b2c --- /dev/null +++ b/spanfinder/model.kicktionary.mod.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:44922d9261fa39f226ffc36e0c7b29e7c17aed7b9f516ab32e69b7d2eeedfd11 +size 1785047888 diff --git a/spanfinder/model.mod.tar.gz b/spanfinder/model.mod.tar.gz new file mode 100644 index 0000000000000000000000000000000000000000..2245d8abfcc894d257ececd887827e2a5fead973 --- /dev/null +++ b/spanfinder/model.mod.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3f5be5aeef50b2f4840317b8196c51186f9f138a853dc1eb2da980b1947ceb23 +size 1795605184 diff --git a/spanfinder/requirements.txt b/spanfinder/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..302e88c61a34c84789583808e78253055853c9a2 --- /dev/null +++ b/spanfinder/requirements.txt @@ -0,0 +1,15 @@ +allennlp>=2.0.0 +allennlp-models>=2.0.0 +transformers>=4.0.0 # Why is huggingface so unstable? 
+numpy +torch>=1.7.0,<1.8.0 +tqdm +nltk +overrides +concrete +flask +scipy +https://github.com/explosion/spacy-models/releases/download/it_core_news_md-3.0.0/it_core_news_md-3.0.0-py3-none-any.whl +https://github.com/explosion/spacy-models/releases/download/en_core_web_md-3.0.0/en_core_web_md-3.0.0-py3-none-any.whl +https://github.com/explosion/spacy-models/releases/download/nl_core_news_md-3.0.0/nl_core_news_md-3.0.0-py3-none-any.whl +https://github.com/explosion/spacy-models/releases/download/xx_sent_ud_sm-3.0.0/xx_sent_ud_sm-3.0.0-py3-none-any.whl diff --git a/spanfinder/scripts/__pycache__/predict_concrete.cpython-37.pyc b/spanfinder/scripts/__pycache__/predict_concrete.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..360223c84ed9179fc7fc934d777871f7fa9a206d Binary files /dev/null and b/spanfinder/scripts/__pycache__/predict_concrete.cpython-37.pyc differ diff --git a/spanfinder/scripts/__pycache__/predict_concrete.cpython-38.pyc b/spanfinder/scripts/__pycache__/predict_concrete.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..02a4ae13622caf947ddbb2d0ea6daf8851483348 Binary files /dev/null and b/spanfinder/scripts/__pycache__/predict_concrete.cpython-38.pyc differ diff --git a/spanfinder/scripts/__pycache__/predict_concrete.cpython-39.pyc b/spanfinder/scripts/__pycache__/predict_concrete.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ad28e790ee3a7f97bc25a1e4ec9e91647c350fc Binary files /dev/null and b/spanfinder/scripts/__pycache__/predict_concrete.cpython-39.pyc differ diff --git a/spanfinder/scripts/__pycache__/predict_force.cpython-39.pyc b/spanfinder/scripts/__pycache__/predict_force.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..08cad1ac9c507e0efcc95e457e6828b54ceb92bf Binary files /dev/null and b/spanfinder/scripts/__pycache__/predict_force.cpython-39.pyc differ diff --git 
a/spanfinder/scripts/__pycache__/repl.cpython-39.pyc b/spanfinder/scripts/__pycache__/repl.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e0ffcb6aebe56c0f4bd9c989c5f399351c302259 Binary files /dev/null and b/spanfinder/scripts/__pycache__/repl.cpython-39.pyc differ diff --git a/spanfinder/scripts/aida_experiment/predict_aida.py b/spanfinder/scripts/aida_experiment/predict_aida.py new file mode 100644 index 0000000000000000000000000000000000000000..71cb112360fa64a4f812a367302fbd26f18faf74 --- /dev/null +++ b/spanfinder/scripts/aida_experiment/predict_aida.py @@ -0,0 +1,42 @@ +import json +import os +import copy +from collections import defaultdict +from argparse import ArgumentParser +from tqdm import tqdm +import random +from tqdm import tqdm +from scripts.predict_concrete import read_kairos + +from sftp import SpanPredictor + + +parser = ArgumentParser() +parser.add_argument('aida', type=str) +parser.add_argument('model', type=str) +parser.add_argument('dst', type=str) +parser.add_argument('--topk', type=int, default=10) +parser.add_argument('--device', type=int, default=0) +args = parser.parse_args() + +k = args.topk +corpus = json.load(open(args.aida)) +predictor = SpanPredictor.from_path(args.model, cuda_device=args.device) +idx2fn = predictor._model.vocab.get_index_to_token_vocabulary('span_label') +random.seed(42) +random.shuffle(corpus) + + +output_fp = open(args.dst, 'a') +for line in tqdm(corpus): + tokens, ann = line['tokens'], line['annotation'] + start, end, kairos_label = ann['start_idx'], ann['end_idx'], ann['label'] + prob_dist = predictor.force_decode(tokens, [(start, end)])[0] + topk_indices = prob_dist.argsort(descending=True)[:k] + prob = prob_dist[topk_indices].tolist() + frames = [(idx2fn[int(idx)], p) for idx, p in zip(topk_indices, prob)] + output_fp.write(json.dumps({ + 'tokens': tokens, + 'frames': frames, + 'kairos': kairos_label + }) + '\n') diff --git a/spanfinder/scripts/aida_experiment/read_aida.py 
"""Convert AIDA event annotations into token-level training examples.

Reads per-event folders of AIDA JSON documents, maps AIDA event types to
the KAIROS ontology, and emits {'tokens': [...], 'annotation': {...}}
records. Fixed: file handles are now closed via context managers; unused
`defaultdict` import removed.
"""
import json
import os
import copy
from argparse import ArgumentParser
from tqdm import tqdm


def extract_sentences(raw_doc):
    """Group token character spans by sentence.

    Returns (sentence_tokens, begin2sentence, end2sentence) where
    sentence_tokens is a list of [(start, end), token_list, event_list]
    and the two dicts map a character offset to (sentence_idx, token_idx).
    """
    sentence_tokens = list()  # [(start, end), list_tokens, event_list]
    for sent_boundary in raw_doc['_views']['_InitialView']['Sentence']:
        start, end = sent_boundary.get('begin', 0), sent_boundary.get('end')
        sentence_tokens.append([(start, end), list(), list()])
    begin2sentence, end2sentence = dict(), dict()
    for token in raw_doc['_views']['_InitialView']['Token']:
        start, end = token.get('begin', 0), token.get('end')
        added = False
        for sent_idx, (bound, tl, _) in enumerate(sentence_tokens):
            # A token belongs to the sentence whose character range contains it;
            # the asserts guarantee exactly one sentence claims each token.
            if start in range(*bound) and (end - 1) in range(*bound):
                assert not added
                begin2sentence[start] = (sent_idx, len(tl))
                end2sentence[end] = (sent_idx, len(tl))
                tl.append((start, end))
                added = True
        assert added
    return sentence_tokens, begin2sentence, end2sentence


def read_aida2kairos(mapping_path):
    """Read the tab-separated AIDA -> KAIROS type mapping into a dict."""
    mapping = dict()
    with open(mapping_path) as map_fp:  # fixed: close the file when done
        for line in map_fp.readlines():
            kairos, aida_list = line.replace('\n', '').replace(',', '').split('\t')
            for aida in aida_list.split():
                # 'x' / '?' mark unmapped entries in the sheet.
                if aida in 'x?':
                    continue
                if aida in mapping:
                    print('warning:', aida, 'already in the mapping, repeated.')
                mapping[aida] = kairos
    return mapping


def read_aida(corpus_path, mapping_path):
    """Load all AIDA annotations under corpus_path.

    Keeps only positive, within-sentence spans whose event type maps into
    the KAIROS ontology; prints counts of everything that was dropped.
    """
    print('reading aida data')
    n_negative, n_span_mismatch, n_diff = 0, 0, 0
    outputs = list()
    mapping = read_aida2kairos(mapping_path)
    for event_fn in tqdm(os.listdir(corpus_path)):
        event_name = event_fn.split('-')[0]
        if event_name not in mapping:
            print('warning:', event_name, 'not in the mapping.')
            continue
        event_name = mapping[event_name]

        for doc_name in os.listdir(os.path.join(corpus_path, event_fn)):
            if not doc_name.endswith('json'):
                continue
            with open(os.path.join(corpus_path, event_fn, doc_name)) as doc_fp:
                raw_doc = json.load(doc_fp)
            sentences, begin2sentence, end2sentence = extract_sentences(raw_doc)
            for fss_no, fss in raw_doc['_referenced_fss'].items():
                if fss_no == '1':
                    continue  # entry '1' holds the document text, not an annotation
                begin, end, is_negative = fss['begin'], fss['end'], fss['negative_example']
                if is_negative:
                    n_negative += 1
                    continue
                if begin not in begin2sentence or end not in end2sentence:
                    n_span_mismatch += 1
                    continue
                (b_idx_sent, b_idx_token), (e_idx_sent, e_idx_token) = begin2sentence[begin], end2sentence[end]
                if b_idx_sent != e_idx_sent:
                    n_diff += 1
                    continue
                sentences[b_idx_sent][2].append([b_idx_token, e_idx_token])

            text = raw_doc['_referenced_fss']['1']['sofaString']

            for _, tokens, events in sentences:
                tokens = [text[start:end] for start, end in tokens]
                for (start, end) in events:
                    outputs.append({
                        'tokens': copy.deepcopy(tokens),
                        'annotation': {
                            'start_idx': start,
                            'end_idx': end,
                            'label': event_name,
                        }
                    })

    print(f'Loaded {len(outputs)} annotations.')
    print(f'{n_negative} negative annotations are ignored.')
    print(f'{n_span_mismatch} mismatched annotations are ignored.')
    print(f'{n_diff} annotations across sentences are ignored.')

    return outputs


if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('aida', type=str)
    parser.add_argument('aida2kairos', type=str)
    parser.add_argument('dst', type=str)
    args = parser.parse_args()

    aida = read_aida(args.aida, args.aida2kairos)

    with open(args.dst, 'w') as dst_fp:
        json.dump(aida, dst_fp)
"""Measure exact-match / span-match precision of mapped predictions on AIDA.

Fixed: the total count now uses the actual batch length (the last batch is
usually smaller than `batch_size`, which deflated both precision numbers).
Also removed the duplicated tqdm import and unused imports.
"""
import json
import random
from argparse import ArgumentParser

from tqdm import tqdm

from scripts.predict_concrete import read_kairos
from sftp import SpanPredictor


parser = ArgumentParser()
parser.add_argument('aida', type=str)
parser.add_argument('model', type=str)
parser.add_argument('fn2kairos', type=str, default=None)
parser.add_argument('--device', type=int, default=3)
args = parser.parse_args()

corpus = json.load(open(args.aida))
mapping = read_kairos(args.fn2kairos)
predictor = SpanPredictor.from_path(args.model, cuda_device=args.device)
random.seed(42)
random.shuffle(corpus)
batch_size = 128


def batchify(a_list):
    """Yield successive chunks of at most `batch_size` items."""
    cur = list()
    for item in a_list:
        cur.append(item)
        if len(cur) == batch_size:
            yield cur
            cur = list()
    if len(cur) > 0:
        yield cur


batches = list(batchify(corpus))


n_total = n_pos = n_span_match = 0
for idx, lines in tqdm(enumerate(batches)):
    # fixed: count the items actually in this batch, not batch_size.
    n_total += len(lines)
    prediction_lines = predictor.predict_batch_sentences(
        [line['tokens'] for line in lines], max_tokens=1024, ontology_mapping=mapping
    )
    for preds, ann in zip(prediction_lines, lines):
        ann = ann['annotation']
        preds = preds['prediction']
        for pred in preds:
            if pred['start_idx'] == ann['start_idx'] and pred['end_idx'] == ann['end_idx']:
                n_span_match += 1
            # NOTE(review): label match is counted independently of span
            # match, as in the original -- confirm that is intended.
            if pred['label'] == ann['label']:
                n_pos += 1

    print(f'exact match precision: {n_pos * 100 / n_total:.3f}')
    print(f'span only precision: {n_span_match * 100 / n_total:.3f}')
"""Evaluate tie predictions against gold frames with exact / span-only match.

Usage: python eval_tie.py GOLD_JSONL PRED_JSONL
Fixed: each gold line is now parsed once (the original called json.loads
twice per line inside a dict comprehension); files are closed.
"""
import sys
import json
from pprint import pprint
from collections import defaultdict

from sftp.metrics.exact_match import ExactMatch


def evaluate():
    """Compute EM (label-sensitive) and SM (span-only) metrics and print both."""
    em = ExactMatch(True)
    sm = ExactMatch(False)
    gold_file, pred_file = sys.argv[1:]
    # fixed: parse each gold line once instead of twice.
    test_sentences = dict()
    with open(gold_file) as gold_fp:
        for line in gold_fp.readlines():
            sent = json.loads(line)
            test_sentences[sent['meta']['sentence ID']] = sent
    pred_sentences = defaultdict(list)
    with open(pred_file) as pred_fp:
        for line in pred_fp.readlines():
            one_pred = json.loads(line)
            pred_sentences[one_pred['meta']['sentence ID']].append(one_pred)
    for sent_id, gold_sent in test_sentences.items():
        pred_sent = pred_sentences.get(sent_id, [])
        # Flatten predictions into frame entries (parent 0) and FE entries
        # (parent = 1-based index of their frame).
        pred_frames, pred_fes = [], []
        for fr_idx, fr in enumerate(pred_sent):
            pred_frames.append({key: fr[key] for key in ["start_idx", "end_idx", "label"]})
            pred_frames[-1]['parent'] = 0
            for fe in fr['children']:
                pred_fes.append({key: fe[key] for key in ["start_idx", "end_idx", "label"]})
                pred_fes[-1]['parent'] = fr_idx+1
        pred_to_eval = pred_frames + pred_fes

        gold_frames, gold_fes = [], []
        for fr_idx, fr in enumerate(gold_sent['frame']):
            gold_frames.append({
                'start_idx': fr['target'][0], 'end_idx': fr['target'][-1], "label": fr['name'], 'parent': 0
            })
            for start_idx, end_idx, fe_name in fr['fe']:
                gold_fes.append({
                    "start_idx": start_idx, "end_idx": end_idx, "label": fe_name, "parent": fr_idx+1
                })
        gold_to_eval = gold_frames + gold_fes
        em(pred_to_eval, gold_to_eval)
        sm(pred_to_eval, gold_to_eval)

    print('EM')
    pprint(em.get_metric(True))
    print('SM')
    pprint(sm.get_metric(True))


if __name__ == '__main__':
    evaluate()
"""Rank FrameNet frames by label-embedding / MLP-weight similarity.

Writes raw JSON, similarity/ontology TSVs, and a KAIROS comparison sheet.
Fixed: (1) extract_relations now records seen names in `added`, so its
duplicate filter actually works; (2) the FrameNet frame list is cached so
nltk.download is not re-triggered on every all_frames() call.
"""
from argparse import ArgumentParser
from copy import deepcopy
import os
import json

from torch import nn
import torch
import nltk

from sftp import SpanPredictor


def shift_grid_cos_sim(mat: torch.Tensor):
    """Pairwise cosine similarity of the rows of `mat`, shifted into [0, 1]."""
    mat1 = mat.unsqueeze(0).expand(mat.shape[0], -1, -1)
    mat2 = mat.unsqueeze(1).expand(-1, mat.shape[0], -1)
    cos = nn.CosineSimilarity(2)
    sim = (cos(mat1, mat2) + 1) / 2
    return sim


_FRAMES = None  # cached FrameNet frame list


def all_frames():
    """Return all FrameNet frames, downloading the corpus on first use only."""
    global _FRAMES
    if _FRAMES is None:
        nltk.download('framenet_v17')
        _FRAMES = nltk.corpus.framenet.frames()
    return _FRAMES


def extract_relations(fr):
    """List (related_frame_name, relation_type) pairs for frame `fr`.

    Fixed: related names are added to `added` so each related frame is
    reported at most once (the original only ever filtered `fr` itself).
    """
    ret = list()
    added = {fr.name}
    for rel in fr.frameRelations:
        for key in ['subFrameName', 'superFrameName']:
            rel_fr_name = rel[key]
            if rel_fr_name in added:
                continue
            added.add(rel_fr_name)
            ret.append((rel_fr_name, key[:-4]))
    return ret


def run():
    parser = ArgumentParser()
    parser.add_argument('archive', metavar='ARCHIVE_PATH', type=str)
    parser.add_argument('dst', metavar='DESTINATION', type=str)
    parser.add_argument('kairos', metavar='KAIROS', type=str)
    parser.add_argument('--topk', metavar='TOPK', type=int, default=10)
    args = parser.parse_args()

    predictor = SpanPredictor.from_path(args.archive, cuda_device=-1)
    kairos_gold_mapping = json.load(open(args.kairos))

    label_emb = predictor._model._span_typing.label_emb.weight.clone().detach()
    idx2label = predictor._model.vocab.get_index_to_token_vocabulary('span_label')

    emb_sim = shift_grid_cos_sim(label_emb)
    fr2definition = {fr.name: (fr.URL, fr.definition) for fr in all_frames()}

    last_mlp = predictor._model._span_typing.MLPs[-1].weight.detach().clone()
    mlp_sim = shift_grid_cos_sim(last_mlp)

    def rank_frame(sim):
        """Turn a pairwise similarity matrix into per-frame ranked neighbor lists."""
        rank = sim.argsort(1, True)
        scores = sim.gather(1, rank)
        mapping = {
            fr.name: {
                'similarity': list(),
                'ontology': extract_relations(fr),
                'URL': fr.URL,
                'definition': fr.definition
            } for fr in all_frames()
        }
        for left_idx, (right_indices, match_scores) in enumerate(zip(rank, scores)):
            left_label = idx2label[left_idx]
            if left_label not in mapping:
                continue  # vocab entry that is not a FrameNet frame
            for right_idx, s in zip(right_indices, match_scores):
                right_label = idx2label[int(right_idx)]
                if right_label not in mapping or right_idx == left_idx:
                    continue
                mapping[left_label]['similarity'].append((right_label, float(s)))
        return mapping

    emb_map = rank_frame(emb_sim)
    mlp_map = rank_frame(mlp_sim)

    def dump(mapping, folder_path):
        """Write raw.json, similarity.tsv, ontology.tsv, kairos_sheet.tsv."""
        os.makedirs(folder_path, exist_ok=True)
        json.dump(mapping, open(os.path.join(folder_path, 'raw.json'), 'w'))
        sim_lines, onto_lines = list(), list()

        for fr, values in mapping.items():
            sim_line = [
                fr,
                values['definition'],
                values['URL'],
            ]
            onto_line = deepcopy(sim_line)
            for rel_fr_name, rel_type in values['ontology']:
                onto_line.append(f'{rel_fr_name} ({rel_type})')
            onto_lines.append('\t'.join(onto_line))
            if len(values['similarity']) > 0:
                for sim_fr_name, score in values['similarity'][:args.topk]:
                    sim_line.append(f'{sim_fr_name} ({score:.3f})')
                sim_lines.append('\t'.join(sim_line))

        with open(os.path.join(folder_path, 'similarity.tsv'), 'w') as fp:
            fp.write('\n'.join(sim_lines))
        with open(os.path.join(folder_path, 'ontology.tsv'), 'w') as fp:
            fp.write('\n'.join(onto_lines))

        kairos_dump = list()
        for kairos_event, kairos_content in kairos_gold_mapping.items():
            for gold_fr in kairos_content['framenet']:
                gold_fr = gold_fr['label']
                if gold_fr not in fr2definition:
                    continue
                kairos_dump.append([
                    'GOLD',
                    gold_fr,
                    kairos_event,
                    fr2definition[gold_fr][0],
                    fr2definition[gold_fr][1],
                    str(kairos_content['description']),
                    '1.00'
                ])
                for ass_fr, sim_score in mapping[gold_fr]['similarity'][:args.topk]:
                    kairos_dump.append([
                        '',
                        ass_fr,
                        kairos_event,
                        fr2definition[ass_fr][0],
                        fr2definition[ass_fr][1],
                        str(kairos_content['description']),
                        f'{sim_score:.2f}'
                    ])
        kairos_dump = list(map(lambda line: '\t'.join(line), kairos_dump))
        open(os.path.join(folder_path, 'kairos_sheet.tsv'), 'w').write('\n'.join(kairos_dump))

    dump(mlp_map, os.path.join(args.dst, 'mlp'))
    dump(emb_map, os.path.join(args.dst, 'emb'))


if __name__ == '__main__':
    run()
"""Map span-finder FrameNet predictions onto the KAIROS ontology.

(Original file: scripts/archive/kairos_mapping.py)
Fixed: corrected the assert message typo "Rouce" -> "Source"; input files
are closed via context managers.
"""
import argparse
import os
import json


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('map', metavar='MappingFile', type=str, help="Mapping JSON file.")
    parser.add_argument('src', metavar='SourceFile', type=str, help="Results of span finder.")
    parser.add_argument('dst', metavar='Destination', type=str, help="Output path.")
    args = parser.parse_args()
    assert os.path.exists(args.map), "Mapping file doesn't exist."
    assert os.path.exists(args.src), "Source file not found."

    with open(args.map) as map_fp:
        k_raw = json.load(map_fp)
    # Invert the KAIROS->FrameNet mapping into FrameNet->KAIROS.
    k_map = dict()
    for kairos_event, content in k_raw.items():
        for fr in content['framenet']:
            if fr['label'] in k_map:
                print("Duplicate frame: " + fr['label'])
            k_map[fr['label']] = kairos_event
    with open(args.src) as src_fp:
        inputs = list(map(json.loads, src_fp.readlines()))

    n_total = n_mapped = 0

    for line in inputs:
        new_frames = list()
        n_total += len(line['prediction'])
        for fr in line['prediction']:
            # Keep only frames that have a KAIROS equivalent.
            if fr['label'] in k_map:
                fr['label'] = k_map[fr['label']]
                new_frames.append(fr)
                n_mapped += 1
        line['prediction'] = new_frames

    with open(args.dst, 'w') as fp:
        fp.write('\n'.join(map(json.dumps, inputs)))

    print(f'Done. Among {n_total} frames, {n_mapped} are mapped to KAIROS ontology, others are omitted.')


if __name__ == '__main__':
    main()


# ---------------------------------------------------------------------------
# (Original file: scripts/archive/onto_test.py)
# Evaluates the naive frame identifier against the FrameNet test split.
import json
from tools.framenet.naive_identifier import FrameIdentifier

test_file_path = '/home/gqin2/data/framenet/full/test.jsonl'
test_sentences = [
    json.loads(line) for line in open(test_file_path)
]
test_set = []
for ann in test_sentences:
    for fr in ann['frame']:
        # NOTE(review): ann['text'] is sliced with token indices -- presumably
        # it is a token list, not a string; verify against the data format.
        test_set.append((fr['name'], ann['text'][fr['target'][0]: fr['target'][-1]+1], fr['lu']))

fi = FrameIdentifier()


tp = fp = fn = 0
fails = []
for frame, target_words, lu in test_set:
    pred = fi(target_words)
    if frame in pred:
        tp += 1
        fp += len(pred) - 1
    else:
        fp += len(pred)
        fn += 1
        fails.append((frame, target_words, pred, lu))

fails.sort(key=lambda x: x[0])
for frame, target_words, pred, lu in fails:
    print(frame, ' '.join(target_words), ' '.join(pred), lu, sep='\t')

print(f'tp={tp}, fp={fp}, fn={fn}')
print(f'precision={tp/(tp+fp)}')
print(f'recall={tp/(tp+fn)}')
"""Run the span model over BETTER-format documents and attach trigger spans."""
from typing import *
import torch
import json
import argparse
import os
from tqdm import tqdm

from sftp.predictor import SpanPredictor
from sftp.models import SpanModel
from sftp.data_reader import BetterDatasetReader


def predict_doc(predictor, json_path: str):
    """Predict triggers and arguments for every entry of one BETTER JSON doc.

    Returns the loaded document with a 'trigger span' list added per entry.
    """
    with open(json_path) as doc_fp:
        src = json.load(doc_fp)
    for doc_name, entry in tqdm(list(src['entries'].items())):
        pred = predictor.predict_json(entry)
        entry['trigger span'] = [
            {
                "span": [event['start_idx'], event['end_idx']],
                "argument": [
                    [arg['start_idx'], arg['end_idx']] for arg in event['children']
                ],
            }
            for event in pred['prediction']
        ]
    return src


if __name__ == '__main__':
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('-a', type=str, help='archive path')
    arg_parser.add_argument('-s', type=str, help='source path')
    arg_parser.add_argument('-d', type=str, help='destination path')
    arg_parser.add_argument('-c', type=int, default=0, help='cuda device')
    cli = arg_parser.parse_args()
    predictor_ = SpanPredictor.from_path(os.path.join(cli.a, 'model.tar.gz'), 'span', cuda_device=cli.c)
    tgt_path = os.path.join(cli.d, os.path.basename(cli.a))
    os.makedirs(tgt_path, exist_ok=True)
    for root, _, files in os.walk(cli.s):
        for fn in files:
            # Only *.json / *valid files hold documents to process.
            if fn.endswith('json') or fn.endswith('valid'):
                processed_json = predict_doc(predictor_, os.path.join(root, fn))
                with open(os.path.join(tgt_path, fn), 'w') as fp:
                    json.dump(processed_json, fp)
"""Run span finder over KAIROS Quizlet XML files and map frames to KAIROS.

Fixed: token ordering -- XML attribute values are strings, so sorting by
t.attrib['start_char'] compared offsets lexicographically ('10' < '2');
the key now casts to int for a correct numeric sort.
"""
import os
import argparse
from xml.etree import ElementTree
import copy
from operator import attrgetter
import json
import logging

from sftp import SpanPredictor


def predict_kairos(model_archive, source_folder, onto_map):
    """Predict spans for every segment of every XML document under source_folder.

    Frame labels are renamed via the FrameNet->KAIROS mapping in onto_map;
    frames without a KAIROS equivalent are dropped. Returns a list of
    per-segment prediction dicts.
    """
    xml_files = list()
    for root, _, files in os.walk(source_folder):
        for f in files:
            if f.endswith('.xml'):
                xml_files.append(os.path.join(root, f))
    logging.info(f'{len(xml_files)} files are found:')
    for fn in xml_files:
        logging.info(' - ' + fn)

    logging.info('Loading ontology from ' + onto_map)
    k_map = dict()
    for kairos_event, content in json.load(open(onto_map)).items():
        for fr in content['framenet']:
            if fr['label'] in k_map:
                logging.info("Duplicate frame: " + fr['label'])
            k_map[fr['label']] = kairos_event

    logging.info('Loading model from ' + model_archive + ' ...')
    predictor = SpanPredictor.from_path(model_archive)

    predictions = list()

    for fn in xml_files:
        logging.info('Now processing ' + os.path.basename(fn))
        tree = ElementTree.parse(fn).getroot()
        for doc in tree:
            doc_meta = copy.deepcopy(doc.attrib)
            text = list(doc)[0]
            for seg in text:
                seg_meta = copy.deepcopy(doc_meta)
                seg_meta['seg'] = copy.deepcopy(seg.attrib)
                tokens = [child for child in seg if child.tag == 'TOKEN']
                # fixed: compare offsets numerically, not as strings.
                tokens.sort(key=lambda t: int(t.attrib['start_char']))
                words = list(map(attrgetter('text'), tokens))
                one_pred = predictor.predict_sentence(words)
                one_pred['meta'] = seg_meta

                new_frames = list()
                for fr in one_pred['prediction']:
                    if fr['label'] in k_map:
                        fr['label'] = k_map[fr['label']]
                        new_frames.append(fr)
                one_pred['prediction'] = new_frames

                predictions.append(one_pred)

    logging.info('Finished Prediction.')

    return predictions


def do_task(input_dir, model_archive, onto_map):
    """
    This function is called by the KAIROS infrastructure code for each
    TASK1 input.
    """

    return predict_kairos(model_archive=model_archive,
                          source_folder=input_dir,
                          onto_map=onto_map)


def run():
    parser = argparse.ArgumentParser(description='Span Finder for KAIROS Quizlet4\n')
    parser.add_argument('model_archive', metavar='MODEL_ARCHIVE', type=str, help='Path to model archive file.')
    parser.add_argument('source_folder', metavar='SOURCE_FOLDER', type=str, help='Path to the folder that contains the XMLs.')
    parser.add_argument('onto_map', metavar='ONTO_MAP', type=str, help='Path to the ontology JSON.')
    parser.add_argument('destination', metavar='DESTINATION', type=str, help='Output path. (jsonl file path)')
    args = parser.parse_args()

    logging.basicConfig(level='INFO', format="%(asctime)s %(name)-12s %(levelname)-8s %(message)s")

    predictions = predict_kairos(model_archive=args.model_archive,
                                 source_folder=args.source_folder,
                                 onto_map=args.onto_map)

    logging.info('Saving to ' + args.destination + ' ...')
    os.makedirs(os.path.dirname(args.destination), exist_ok=True)
    with open(args.destination, 'w') as fp:
        fp.write('\n'.join(map(json.dumps, predictions)))

    logging.info('Done.')


if __name__ == '__main__':
    run()
"""Frame identification accuracy, optionally restricted by the LU ontology."""
import json
from argparse import ArgumentParser
from collections import defaultdict
import numpy as np

from tqdm import tqdm
from nltk.corpus import framenet as fn

from sftp import SpanPredictor


def run(model_path, data_path, device, use_ontology=False):
    """Score frame-ID accuracy of the model at `model_path` on `data_path`.

    With `use_ontology`, candidate frames are restricted to those licensed
    by each annotation's lexical unit; otherwise all FrameNet frames compete.
    """
    with open(data_path) as data_fp:
        data = [json.loads(line) for line in data_fp.readlines()]
    lu2frame = defaultdict(list)
    for lu in fn.lus():
        lu2frame[lu.name].append(lu.frame.name)
    predictor = SpanPredictor.from_path(model_path, cuda_device=device)
    frame2idx = predictor._model.vocab.get_token_to_index_vocabulary('span_label')
    all_frames = [fr.name for fr in fn.frames()]
    n_positive = n_total = 0
    with tqdm(total=len(data)) as bar:
        for sent in data:
            bar.update()
            for point in sent['annotations']:
                target = (point['span'][0], point['span'][-1])
                model_output = predictor.force_decode(
                    sent['tokens'], child_spans=[target]
                ).distribution[0]
                candidates = lu2frame[point['lu']] if use_ontology else all_frames
                # Score each candidate; frames unknown to the model keep -1.
                scores = [-1.0 for _ in candidates]
                for pos, name in enumerate(candidates):
                    if name in frame2idx:
                        scores[pos] = model_output[frame2idx[name]]
                if len(scores) > 0:
                    best = candidates[int(np.argmax(scores))]
                    if best == point['label']:
                        n_positive += 1
                n_total += 1
                bar.set_description(f'acc={n_positive/n_total*100:.3f}')
    print(f'acc={n_positive/n_total*100:.3f}')


if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('model', metavar="MODEL")
    parser.add_argument('data', metavar="DATA")
    parser.add_argument('-d', default=-1, type=int, help='Device')
    parser.add_argument('-o', action='store_true', help='Flag to use ontology.')
    args = parser.parse_args()
    run(args.model, args.data, args.d, args.o)
"""Predict SRL over Concrete files, optionally mapping labels to KAIROS."""
from argparse import ArgumentParser
from typing import *
import json
import logging

from sftp import SpanPredictor

logger = logging.getLogger('ConcretePredictor')


def read_kairos(ontology_mapping_path: Optional[str] = None):
    """Invert the legacy KAIROS->FrameNet mapping file into FrameNet->KAIROS.

    Returns None when no path is given.
    """
    # Legacy. For the old mapping file only.
    if ontology_mapping_path is None:
        return
    raw = json.load(open(ontology_mapping_path))
    fn2kairos = dict()
    for kairos_label, content in raw.items():
        for fn in content['framenet']:
            fn_label = fn['label']
            if fn_label in fn2kairos:
                logger.warning(f'"{fn_label}" is repeated in the ontology file.')
            fn2kairos[fn_label] = kairos_label
    return fn2kairos


def run(src, dst, model_path, ontology_mapping_path, device):
    """Load the model and predict over the Concrete corpus at `src` into `dst`."""
    mapping = SpanPredictor.read_ontology_mapping(ontology_mapping_path)
    predictor = SpanPredictor.from_path(model_path, cuda_device=device)
    predictor.predict_concrete(src, dst, ontology_mapping=mapping)


if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('src', type=str)
    parser.add_argument('dst', type=str)
    parser.add_argument('model', type=str)
    parser.add_argument('--map', type=str, default=None)
    parser.add_argument('--device', type=int, default=-1)
    args = parser.parse_args()
    run(args.src, args.dst, args.model, args.map, args.device)
"""Demo of force-decoding with a span-finder model."""
from sftp import SpanPredictor


def print_children(sentence, boundary, labels, _):
    """Print the sentence, then each labelled span on its own line."""
    print('Sentence:', ' '.join(sentence))
    for (start_idx, end_idx), lbl in zip(boundary, labels):
        span_text = ' '.join(sentence[start_idx:end_idx+1])
        print(span_text, ':', lbl)
    print('='*20)


def example():
    """Force-decode one sentence three ways: free, with a parent, with children."""
    print("Loading predictor...")
    predictor = SpanPredictor.from_path(
        # '/home/gqin2/public/release/sftp/0.0.2/framenet',
        "/data/p289731/cloned/lome-models/models/spanfinder/model.mod.tar.gz",
        cuda_device=-1
    )

    print("Predicting for sentence..")
    sentence = ['Tom', 'eats', 'an', 'apple', 'and', 'he', 'wakes', 'up', '.']
    # 1) fully free decoding
    p1 = predictor.force_decode(sentence)
    print_children(sentence, *p1)
    # 2) decoding with the parent span/label fixed
    p2 = predictor.force_decode(sentence, parent_span=(1, 1), parent_label='Ingestion')
    print_children(sentence, *p2)
    # 3) decoding with parent and child spans all fixed
    p3 = predictor.force_decode(sentence, child_spans=[(0, 0), (2, 3)], parent_span=(1, 1), parent_label='Ingestion')
    print_children(sentence, *p3)


if __name__ == '__main__':
    example()
# (Original file: scripts/predict_span.py) -- usage examples for SpanPredictor.
from sftp import SpanPredictor

# Load the model; cuda_device=-1 selects the CPU.
predictor = SpanPredictor.from_path(
    '/home/gqin2/public/release/sftp/0.0.2/framenet/model.tar.gz',  # MODIFY THIS
    cuda_device=-1,
)

# Example 1: a raw string is tokenized internally (SpacyTokenizer) and the
# tokens come back with the prediction.
input1 = "Bob saw Alice eating an apple."
print("Example 1 with input:", input1)
output1 = predictor.predict_sentence(input1)
output1.span.tree(output1.sentence)

# Example 2: pre-tokenized input is respected as-is.
input2 = ["Bob", "saw", "Alice", "eating", "an", "apple", "."]
print('-'*20+"\nExample 2 with input:", input2)
output2 = predictor.predict_sentence(input2)
output2.span.tree(output2.sentence)

# Example 3: batch prediction. The predictor batches internally; `max_tokens`
# caps the number of tokens per batch, and outputs keep the input order.
output3 = predictor.predict_batch_sentences([input1, input2], max_tokens=512)
print('-'*20+"\nExample 3 with both inputs:")
for i in range(2):
    output3[i].span.tree(output3[i].sentence)

# Example 4: for SRL, capping decoding depth (events only) saves ~13% time,
# and capping the number of spans can speed things up further.
predictor.economize(max_decoding_spans=20, max_recursion_depth=1)
output4 = predictor.predict_batch_sentences([input2], max_tokens=512)
print('-'*20+"\nExample 4 with input:", input2)
output4[0].span.tree(output4[0].sentence)


# ---------------------------------------------------------------------------
# (Original file: scripts/repl.py) -- interactive prediction loop.
repl_predictor = SpanPredictor.from_path(
    # '/home/gqin2/public/release/sftp/0.0.1/framenet/model.tar.gz',  # For CLSP grid
    "../model.mod.tar.gz",  # For BRTX
    cuda_device=-1  # Change this to 0 for CUDA:0
)

while True:
    sentence = input(">>>")
    result = repl_predictor.predict_sentence(sentence)
    for child in result.span:
        print(child)
        for sub_child in child:
            print("\t", sub_child)
# (Original file: spanfinder/setup.py) -- package metadata for sftp.
from setuptools import setup, find_packages


setup(
    name='sftp',
    version='0.0.2',
    author='Guanghui Qin',
    packages=find_packages(),
)


# ---------------------------------------------------------------------------
# (Original file: spanfinder/sftp/__init__.py) -- public package surface.
from .data_reader import (
    BetterDatasetReader, SRLDatasetReader
)
from .metrics import SRLMetric, BaseF, ExactMatch, FBetaMixMeasure
from .models import SpanModel
from .modules import (
    MLPSpanTyping, SpanTyping, SpanFinder, BIOSpanFinder
)
from .predictor import SpanPredictor
from .utils import Span
# Public exports of the sftp.data_reader subpackage.
from .batch_sampler import MixSampler
from .better_reader import BetterDatasetReader
from .span_reader import SpanReader
from .srl_reader import SRLDatasetReader
from .concrete_srl import concrete_doc, concrete_doc_tokenized, collect_concrete_srl
from .concrete_reader import ConcreteDatasetReader
files /dev/null and b/spanfinder/sftp/data_reader/__pycache__/better_reader.cpython-38.pyc differ diff --git a/spanfinder/sftp/data_reader/__pycache__/better_reader.cpython-39.pyc b/spanfinder/sftp/data_reader/__pycache__/better_reader.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e82cfb435ec1325b998b3badf7f02ae3621a367 Binary files /dev/null and b/spanfinder/sftp/data_reader/__pycache__/better_reader.cpython-39.pyc differ diff --git a/spanfinder/sftp/data_reader/__pycache__/concrete_reader.cpython-38.pyc b/spanfinder/sftp/data_reader/__pycache__/concrete_reader.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7d90c9ec9061c1a5ee0263e87f64d9c354af95d Binary files /dev/null and b/spanfinder/sftp/data_reader/__pycache__/concrete_reader.cpython-38.pyc differ diff --git a/spanfinder/sftp/data_reader/__pycache__/concrete_reader.cpython-39.pyc b/spanfinder/sftp/data_reader/__pycache__/concrete_reader.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f83648c16f675f415a79b5820bf29f37ede7dfd7 Binary files /dev/null and b/spanfinder/sftp/data_reader/__pycache__/concrete_reader.cpython-39.pyc differ diff --git a/spanfinder/sftp/data_reader/__pycache__/concrete_srl.cpython-37.pyc b/spanfinder/sftp/data_reader/__pycache__/concrete_srl.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f77376261bfa8060b73bb0c49a8f7b0f6c2b2c36 Binary files /dev/null and b/spanfinder/sftp/data_reader/__pycache__/concrete_srl.cpython-37.pyc differ diff --git a/spanfinder/sftp/data_reader/__pycache__/concrete_srl.cpython-38.pyc b/spanfinder/sftp/data_reader/__pycache__/concrete_srl.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..40b0da9195dc4a824b4560e5fc8ccb096b18e44f Binary files /dev/null and b/spanfinder/sftp/data_reader/__pycache__/concrete_srl.cpython-38.pyc differ diff --git 
a/spanfinder/sftp/data_reader/__pycache__/concrete_srl.cpython-39.pyc b/spanfinder/sftp/data_reader/__pycache__/concrete_srl.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c6dffa267f6f345e7493b6a10b10d9f142b74f26 Binary files /dev/null and b/spanfinder/sftp/data_reader/__pycache__/concrete_srl.cpython-39.pyc differ diff --git a/spanfinder/sftp/data_reader/__pycache__/span_reader.cpython-37.pyc b/spanfinder/sftp/data_reader/__pycache__/span_reader.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66827c27e692c23ae05c23e7362431a0664c0cbe Binary files /dev/null and b/spanfinder/sftp/data_reader/__pycache__/span_reader.cpython-37.pyc differ diff --git a/spanfinder/sftp/data_reader/__pycache__/span_reader.cpython-38.pyc b/spanfinder/sftp/data_reader/__pycache__/span_reader.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fab06ee057827283ddf49641af1bfc9266aa9134 Binary files /dev/null and b/spanfinder/sftp/data_reader/__pycache__/span_reader.cpython-38.pyc differ diff --git a/spanfinder/sftp/data_reader/__pycache__/span_reader.cpython-39.pyc b/spanfinder/sftp/data_reader/__pycache__/span_reader.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e6e930e2686b2144ea34c9c63a63b099a46a8eeb Binary files /dev/null and b/spanfinder/sftp/data_reader/__pycache__/span_reader.cpython-39.pyc differ diff --git a/spanfinder/sftp/data_reader/__pycache__/srl_reader.cpython-37.pyc b/spanfinder/sftp/data_reader/__pycache__/srl_reader.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e43e9bda2735d953d9ba3fc12d52245207f94836 Binary files /dev/null and b/spanfinder/sftp/data_reader/__pycache__/srl_reader.cpython-37.pyc differ diff --git a/spanfinder/sftp/data_reader/__pycache__/srl_reader.cpython-38.pyc b/spanfinder/sftp/data_reader/__pycache__/srl_reader.cpython-38.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..6ea8a259e92d0c0cd1f7f1948253fbfad1ef4325 Binary files /dev/null and b/spanfinder/sftp/data_reader/__pycache__/srl_reader.cpython-38.pyc differ diff --git a/spanfinder/sftp/data_reader/__pycache__/srl_reader.cpython-39.pyc b/spanfinder/sftp/data_reader/__pycache__/srl_reader.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..08256a7f4bdaaac3cebcd05fe10a3ddc874f3dbd Binary files /dev/null and b/spanfinder/sftp/data_reader/__pycache__/srl_reader.cpython-39.pyc differ diff --git a/spanfinder/sftp/data_reader/batch_sampler/__init__.py b/spanfinder/sftp/data_reader/batch_sampler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c7f773dff5885a94aa3558ed2fda8940dbab0ef0 --- /dev/null +++ b/spanfinder/sftp/data_reader/batch_sampler/__init__.py @@ -0,0 +1 @@ +from .mix_sampler import MixSampler diff --git a/spanfinder/sftp/data_reader/batch_sampler/__pycache__/__init__.cpython-37.pyc b/spanfinder/sftp/data_reader/batch_sampler/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b5941d99b51af8a2cc1bebaf3da8800a99b11dbe Binary files /dev/null and b/spanfinder/sftp/data_reader/batch_sampler/__pycache__/__init__.cpython-37.pyc differ diff --git a/spanfinder/sftp/data_reader/batch_sampler/__pycache__/__init__.cpython-38.pyc b/spanfinder/sftp/data_reader/batch_sampler/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a1f8237f8373562934228591ccc7c063d88cbf3c Binary files /dev/null and b/spanfinder/sftp/data_reader/batch_sampler/__pycache__/__init__.cpython-38.pyc differ diff --git a/spanfinder/sftp/data_reader/batch_sampler/__pycache__/__init__.cpython-39.pyc b/spanfinder/sftp/data_reader/batch_sampler/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6175a24c94ce5fd844e904a51b05f29c098212de Binary files /dev/null and 
b/spanfinder/sftp/data_reader/batch_sampler/__pycache__/__init__.cpython-39.pyc differ diff --git a/spanfinder/sftp/data_reader/batch_sampler/__pycache__/mix_sampler.cpython-37.pyc b/spanfinder/sftp/data_reader/batch_sampler/__pycache__/mix_sampler.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..762e5b54ec4fd929283eb84251d4a94010b6a56a Binary files /dev/null and b/spanfinder/sftp/data_reader/batch_sampler/__pycache__/mix_sampler.cpython-37.pyc differ diff --git a/spanfinder/sftp/data_reader/batch_sampler/__pycache__/mix_sampler.cpython-38.pyc b/spanfinder/sftp/data_reader/batch_sampler/__pycache__/mix_sampler.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f0474e3188c16f0c64551ef7b610197f0bbffdcc Binary files /dev/null and b/spanfinder/sftp/data_reader/batch_sampler/__pycache__/mix_sampler.cpython-38.pyc differ diff --git a/spanfinder/sftp/data_reader/batch_sampler/__pycache__/mix_sampler.cpython-39.pyc b/spanfinder/sftp/data_reader/batch_sampler/__pycache__/mix_sampler.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53ce3fb2c7cec9cbc161db8ffacaff48763a122e Binary files /dev/null and b/spanfinder/sftp/data_reader/batch_sampler/__pycache__/mix_sampler.cpython-39.pyc differ diff --git a/spanfinder/sftp/data_reader/batch_sampler/mix_sampler.py b/spanfinder/sftp/data_reader/batch_sampler/mix_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..7b26bf77a5cf6131943c25cf4e03a2fbd74db739 --- /dev/null +++ b/spanfinder/sftp/data_reader/batch_sampler/mix_sampler.py @@ -0,0 +1,50 @@ +import logging +import random +from typing import * + +from allennlp.data.samplers.batch_sampler import BatchSampler +from allennlp.data.samplers.max_tokens_batch_sampler import MaxTokensBatchSampler +from torch.utils import data + +logger = logging.getLogger('mix_sampler') + + +@BatchSampler.register('mix_sampler') +class MixSampler(MaxTokensBatchSampler): + def 
__init__( + self, + max_tokens: int, + sorting_keys: List[str] = None, + padding_noise: float = 0.1, + sampling_ratios: Optional[Dict[str, float]] = None, + ): + super().__init__(max_tokens, sorting_keys, padding_noise) + + self.sampling_ratios = sampling_ratios or dict() + + def __iter__(self): + indices, lengths = self._argsort_by_padding(self.data_source) + + original_num = len(indices) + instance_types = [ + ins.fields['meta'].metadata.get('type', 'default') if 'meta' in ins.fields else 'default' + for ins in self.data_source + ] + instance_thresholds = [ + self.sampling_ratios[ins_type] if ins_type in self.sampling_ratios else 1.0 for ins_type in instance_types + ] + for idx, threshold in enumerate(instance_thresholds): + if random.random() > threshold: + # Reject + list_idx = indices.index(idx) + del indices[list_idx], lengths[list_idx] + if original_num != len(indices): + logger.info(f'#instances reduced from {original_num} to {len(indices)}.') + + max_lengths = [max(length) for length in lengths] + group_iterator = self._lazy_groups_of_max_size(indices, max_lengths) + + batches = [list(group) for group in group_iterator] + random.shuffle(batches) + for batch in batches: + yield batch diff --git a/spanfinder/sftp/data_reader/better_reader.py b/spanfinder/sftp/data_reader/better_reader.py new file mode 100644 index 0000000000000000000000000000000000000000..578a564930d803a8ac1f09f0f73541f56b61372b --- /dev/null +++ b/spanfinder/sftp/data_reader/better_reader.py @@ -0,0 +1,286 @@ +import json +import logging +import os +from collections import defaultdict, namedtuple +from typing import * + +from allennlp.data.dataset_readers.dataset_reader import DatasetReader +from allennlp.data.instance import Instance + +from .span_reader import SpanReader +from ..utils import Span + +# logging.basicConfig(level=logging.DEBUG) + +# for v in logging.Logger.manager.loggerDict.values(): +# v.disabled = True + +logger = logging.getLogger(__name__) + +SpanTuple = 
namedtuple('Span', ['start', 'end']) + + +@DatasetReader.register('better') +class BetterDatasetReader(SpanReader): + def __init__( + self, + eval_type, + consolidation_strategy='first', + span_set_type='single', + max_argument_ss_size=1, + use_ref_events=False, + **extra + ): + super().__init__(**extra) + self.eval_type = eval_type + assert self.eval_type in ['abstract', 'basic'] + + self.consolidation_strategy = consolidation_strategy + self.unitary_spans = span_set_type == 'single' + # event anchors are always singleton spans + self.max_arg_spans = max_argument_ss_size + self.use_ref_events = use_ref_events + + self.n_overlap_arg = 0 + self.n_overlap_trigger = 0 + self.n_skip = 0 + self.n_too_long = 0 + + @staticmethod + def post_process_basic_span(predicted_span, basic_entry): + # Convert token offsets back to characters, also get the text spans as a sanity check + + # !!!!! + # SF outputs inclusive idxs + # char offsets are inc-exc + # token offsets are inc-inc + # !!!!! + + start_idx = predicted_span['start_idx'] # inc + end_idx = predicted_span['end_idx'] # inc + + char_start_idx = basic_entry['tok2char'][predicted_span['start_idx']][0] # inc + char_end_idx = basic_entry['tok2char'][predicted_span['end_idx']][-1] + 1 # exc + + span_text = basic_entry['segment-text'][char_start_idx:char_end_idx] # inc exc + span_text_tok = basic_entry['segment-text-tok'][start_idx:end_idx + 1] # inc exc + + span = {'string': span_text, + 'start': char_start_idx, + 'end': char_end_idx, + 'start-token': start_idx, + 'end-token': end_idx, + 'string-tok': span_text_tok, + 'label': predicted_span['label'], + 'predicted': True} + return span + + @staticmethod + def _get_shortest_span(spans): + # shortest_span_length = float('inf') + # shortest_span = None + # for span in spans: + # span_tokens = span['string-tok'] + # span_length = len(span_tokens) + # if span_length < shortest_span_length: + # shortest_span_length = span_length + # shortest_span = span + + # return shortest_span + 
return [s[-1] for s in sorted([(len(span['string']), ix, span) for ix, span in enumerate(spans)])] + + @staticmethod + def _get_first_span(spans): + spans = [(span['start'], -len(span['string']), ix, span) for ix, span in enumerate(spans)] + try: + return [s[-1] for s in sorted(spans)] + except: + breakpoint() + + @staticmethod + def _get_longest_span(spans): + return [s[-1] for s in sorted([(len(span['string']), ix, span) for ix, span in enumerate(spans)], reverse=True)] + + @staticmethod + def _subfinder(text, pattern): + # https://stackoverflow.com/a/12576755 + matches = [] + pattern_length = len(pattern) + for i, token in enumerate(text): + try: + if token == pattern[0] and text[i:i + pattern_length] == pattern: + matches.append(SpanTuple(start=i, end=i + pattern_length - 1)) # inclusive boundaries + except: + continue + return matches + + def consolidate_span_set(self, spans): + if self.consolidation_strategy == 'first': + spans = BetterDatasetReader._get_first_span(spans) + elif self.consolidation_strategy == 'shortest': + spans = BetterDatasetReader._get_shortest_span(spans) + elif self.consolidation_strategy == 'longest': + spans = BetterDatasetReader._get_longest_span(spans) + else: + raise NotImplementedError(f"{self.consolidation_strategy} does not exist") + + if self.unitary_spans: + spans = [spans[0]] + else: + spans = spans[:self.max_arg_spans] + + # TODO add some sanity checks here + + return spans + + def get_mention_spans(self, text: List[str], span_sets: Dict): + mention_spans = defaultdict(list) + for span_set_id in span_sets.keys(): + spans = span_sets[span_set_id]['spans'] + # span = BetterDatasetReader._get_shortest_span(spans) + # span = BetterDatasetReader._get_earliest_span(spans) + consolidated_spans = self.consolidate_span_set(spans) + # if len(spans) > 1: + # logging.info(f"Truncated a spanset from {len(spans)} spans to 1") + + if self.eval_type == 'abstract': + span = consolidated_spans[0] + span_tokens = span['string-tok'] + + 
span_indices = BetterDatasetReader._subfinder(text=text, pattern=span_tokens) + + if len(span_indices) > 1: + pass + + if len(span_indices) == 0: + continue + + mention_spans[span_set_id] = span_indices[0] + else: + # in basic, we already have token offsets in the right form + + # if not span['string-tok'] == text[span['start-token']:span['end-token'] + 1]: + # print(span, text[span['start-token']:span['end-token'] + 1]) + + # we should use these token offsets only! + for span in consolidated_spans: + mention_spans[span_set_id].append(SpanTuple(start=span['start-token'], end=span['end-token'])) + + return mention_spans + + def _read_single_file(self, file_path): + with open(file_path) as fp: + json_content = json.load(fp) + if 'entries' in json_content: + for doc_name, entry in json_content['entries'].items(): + instance = self.text_to_instance(entry, 'train' in file_path) + yield instance + else: # TODO why is this split in 2 cases? + for doc_name, entry in json_content.items(): + instance = self.text_to_instance(entry, True) + yield instance + + logger.warning(f'{self.n_overlap_arg} overlapped args detected!') + logger.warning(f'{self.n_overlap_trigger} overlapped triggers detected!') + logger.warning(f'{self.n_skip} skipped detected!') + logger.warning(f'{self.n_too_long} were skipped because they are too long!') + self.n_overlap_arg = self.n_skip = self.n_too_long = self.n_overlap_trigger = 0 + + def _read(self, file_path: str) -> Iterable[Instance]: + + if os.path.isdir(file_path): + for fn in os.listdir(file_path): + if not fn.endswith('.json'): + logger.info(f'Skipping {fn}') + continue + logger.info(f'Loading from {fn}') + yield from self._read_single_file(os.path.join(file_path, fn)) + else: + yield from self._read_single_file(file_path) + + def text_to_instance(self, entry, is_training=False): + word_tokens = entry['segment-text-tok'] + + # span sets have been trimmed to the earliest span mention + spans = self.get_mention_spans( + word_tokens, 
entry['annotation-sets'][f'{self.eval_type}-events']['span-sets'] + ) + + # idx of every token that is a part of an event trigger/anchor span + all_trigger_idxs = set() + + # actual inputs to the model + input_spans = [] + + self._local_child_overlap = 0 + self._local_child_total = 0 + + better_events = entry['annotation-sets'][f'{self.eval_type}-events']['events'] + + skipped_events = set() + # check for events that overlap other event's anchors, skip them later + for event_id, event in better_events.items(): + assert event['anchors'] in spans + + # take the first consolidated span for anchors + anchor_start, anchor_end = spans[event['anchors']][0] + + if any(ix in all_trigger_idxs for ix in range(anchor_start, anchor_end + 1)): + logger.warning( + f"Skipped {event_id} with anchor span {event['anchors']}, overlaps a previously found event trigger/anchor") + self.n_overlap_trigger += 1 + skipped_events.add(event_id) + continue + + all_trigger_idxs.update(range(anchor_start, anchor_end + 1)) # record the trigger + + for event_id, event in better_events.items(): + if event_id in skipped_events: + continue + + # arguments for just this event + local_arg_idxs = set() + # take the first consolidated span for anchors + anchor_start, anchor_end = spans[event['anchors']][0] + + event_span = Span(anchor_start, anchor_end, event['event-type'], True) + input_spans.append(event_span) + + def add_a_child(span_id, label): + # TODO this is a bad way to do this + assert span_id in spans + for child_span in spans[span_id]: + self._local_child_total += 1 + arg_start, arg_end = child_span + + if any(ix in local_arg_idxs for ix in range(arg_start, arg_end + 1)): + # logger.warn(f"Skipped argument {span_id}, overlaps a previously found argument") + # print(entry['annotation-sets'][f'{self.eval_type}-events']['span-sets'][span_id]) + self.n_overlap_arg += 1 + self._local_child_overlap += 1 + continue + + local_arg_idxs.update(range(arg_start, arg_end + 1)) + 
event_span.add_child(Span(arg_start, arg_end, label, False)) + + for agent in event['agents']: + add_a_child(agent, 'agent') + for patient in event['patients']: + add_a_child(patient, 'patient') + + if self.use_ref_events: + for ref_event in event['ref-events']: + if ref_event in skipped_events: + continue + ref_event_anchor_id = better_events[ref_event]['anchors'] + add_a_child(ref_event_anchor_id, 'ref-event') + + # if len(event['ref-events']) > 0: + # breakpoint() + + fields = self.prepare_inputs(word_tokens, spans=input_spans) + if self._local_child_overlap > 0: + logging.warning( + f"Skipped {self._local_child_overlap} / {self._local_child_total} argument spans due to overlaps") + return Instance(fields) + diff --git a/spanfinder/sftp/data_reader/concrete_reader.py b/spanfinder/sftp/data_reader/concrete_reader.py new file mode 100644 index 0000000000000000000000000000000000000000..0033d283bbeca217ae32a99c64796c5e3fef8f72 --- /dev/null +++ b/spanfinder/sftp/data_reader/concrete_reader.py @@ -0,0 +1,44 @@ +import logging +from collections import defaultdict +from typing import * +import os + +from allennlp.data.dataset_readers.dataset_reader import DatasetReader +from allennlp.data.instance import Instance +from concrete import SituationMention +from concrete.util import CommunicationReader + +from .span_reader import SpanReader +from .srl_reader import SRLDatasetReader +from .concrete_srl import collect_concrete_srl +from ..utils import Span, BIOSmoothing + +logger = logging.getLogger(__name__) + + +@DatasetReader.register('concrete') +class ConcreteDatasetReader(SRLDatasetReader): + def __init__( + self, + event_only: bool = False, + event_smoothing_factor: float = 0., + arg_smoothing_factor: float = 0., + **extra + ): + super().__init__(**extra) + self.event_only = event_only + self.event_only = event_only + self.event_smooth_factor = event_smoothing_factor + self.arg_smooth_factor = arg_smoothing_factor + + def _read(self, file_path: str) -> 
Iterable[Instance]: + if os.path.isdir(file_path): + for fn in os.listdir(file_path): + yield from self._read(os.path.join(file_path, fn)) + all_files = CommunicationReader(file_path) + for comm, fn in all_files: + sentences = collect_concrete_srl(comm) + for tokens, vr in sentences: + yield self.text_to_instance(tokens, vr) + logger.warning(f'{self.n_span_removed} spans were removed') + self.n_span_removed = 0 diff --git a/spanfinder/sftp/data_reader/concrete_srl.py b/spanfinder/sftp/data_reader/concrete_srl.py new file mode 100644 index 0000000000000000000000000000000000000000..0d87d24b671d6ed6244f977d7ac261e929e780cb --- /dev/null +++ b/spanfinder/sftp/data_reader/concrete_srl.py @@ -0,0 +1,169 @@ +from time import time +from typing import * +from collections import defaultdict + +from concrete import ( + Token, TokenList, TextSpan, MentionArgument, SituationMentionSet, SituationMention, TokenRefSequence, + Communication, EntityMention, EntityMentionSet, Entity, EntitySet, AnnotationMetadata, Sentence +) +from concrete.util import create_comm, AnalyticUUIDGeneratorFactory +from concrete.validate import validate_communication + +from ..utils import Span + + +def _process_sentence(sent, comm_sent, aug, char_idx_offset: int): + token_list = list() + for tok_idx, (start_idx, end_idx) in enumerate(sent['tokenization']): + token_list.append(Token( + tokenIndex=tok_idx, + text=sent['sentence'][start_idx:end_idx + 1], + textSpan=TextSpan( + start=start_idx + char_idx_offset, + ending=end_idx + char_idx_offset + 1 + ), + )) + comm_sent.tokenization.tokenList = TokenList(tokenList=token_list) + + sm_list, em_dict, entity_list = list(), dict(), list() + + annotation = sent['annotations'] if isinstance(sent['annotations'], Span) else Span.from_json(sent['annotations']) + for event in annotation: + char_start_idx = sent['tokenization'][event.start_idx][0] + char_end_idx = sent['tokenization'][event.end_idx][1] + sm = SituationMention( + uuid=next(aug), + 
text=sent['sentence'][char_start_idx: char_end_idx + 1], + situationType='EVENT', + situationKind=event.label, + argumentList=list(), + tokens=TokenRefSequence( + tokenIndexList=list(range(event.start_idx, event.end_idx + 1)), + tokenizationId=comm_sent.tokenization.uuid + ), + ) + + for arg in event: + em = em_dict.get((arg.start_idx, arg.end_idx + 1)) + if em is None: + char_start_idx = sent['tokenization'][arg.start_idx][0] + char_end_idx = sent['tokenization'][arg.end_idx][1] + em = EntityMention(next(aug), TokenRefSequence( + tokenIndexList=list(range(arg.start_idx, arg.end_idx + 1)), + tokenizationId=comm_sent.tokenization.uuid, + ), text=sent['sentence'][char_start_idx: char_end_idx + 1]) + entity_list.append(Entity(next(aug), id=em.text, mentionIdList=[em.uuid])) + em_dict[(arg.start_idx, arg.end_idx + 1)] = em + sm.argumentList.append(MentionArgument( + role=arg.label, + entityMentionId=em.uuid, + )) + + sm_list.append(sm) + + return sm_list, list(em_dict.values()), entity_list + + +def concrete_doc( + sentences: List[Dict[str, Any]], + doc_name: str = 'document', +) -> Communication: + """ + Data format: A list of sentences. Each sentence should be a dict of the following format: + { + "sentence": String. + "tokenization": A list of Tuple[int, int] for start and end indices. Both inclusive. + "annotations": A list of event dict, or Span object. + } + If it is dict, its format should be: + + Each event should be a dict of the following format: + { + "span": [start_idx, end_idx]: Integer. Both inclusive. + "label": String. + "children": A list of arguments. + } + Each argument should be a dict of the following format: + { + "span": [start_idx, end_idx]: Integer. Both inclusive. + "label": String. + } + + Note the "indices" above all refer to the indices of tokens, instead of characters. 
+ """ + comm = create_comm( + doc_name, + '\n'.join([sent['sentence'] for sent in sentences]), + ) + aug = AnalyticUUIDGeneratorFactory(comm).create() + situation_mention_set = SituationMentionSet(next(aug), AnnotationMetadata('Span Finder', time()), list()) + comm.situationMentionSetList = [situation_mention_set] + entity_mention_set = EntityMentionSet(next(aug), AnnotationMetadata('Span Finder', time()), list()) + comm.entityMentionSetList = [entity_mention_set] + entity_set = EntitySet( + next(aug), AnnotationMetadata('O(0) Coref Paser.', time()), list(), None, entity_mention_set.uuid + ) + comm.entitySetList = [entity_set] + assert len(sentences) == len(comm.sectionList[0].sentenceList) + + char_idx_offset = 0 + for sent, comm_sent in zip(sentences, comm.sectionList[0].sentenceList): + sm_list, em_list, entity_list = _process_sentence(sent, comm_sent, aug, char_idx_offset) + entity_set.entityList.extend(entity_list) + situation_mention_set.mentionList.extend(sm_list) + entity_mention_set.mentionList.extend(em_list) + char_idx_offset += len(sent['sentence']) + 1 + + validate_communication(comm) + return comm + + +def concrete_doc_tokenized( + sentences: List[List[str]], + spans: List[Span], + doc_name: str = "document", +): + """ + Similar to concrete_doc, but with tokenized words and spans. 
+ """ + inputs = list() + for sent, vr in zip(sentences, spans): + cur_start = 0 + tokenization = list() + for token in sent: + tokenization.append((cur_start, cur_start + len(token) - 1)) + cur_start += len(token) + 1 + inputs.append({ + "sentence": " ".join(sent), + "tokenization": tokenization, + "annotations": vr + }) + return concrete_doc(inputs, doc_name) + + +def collect_concrete_srl(comm: Communication) -> List[Tuple[List[str], Span]]: + # Mapping from to [, ] + sentences = defaultdict(lambda: [None, list()]) + for sec in comm.sectionList: + for sen in sec.sentenceList: + sentences[sen.uuid.uuidString][0] = sen + # Assume there's only ONE situation mention set + assert len(comm.situationMentionSetList) == 1 + # Assign each situation mention to the corresponding sentence + for men in comm.situationMentionSetList[0].mentionList: + if men.tokens is None: continue # For ACE relations + sentences[men.tokens.tokenization.sentence.uuid.uuidString][1].append(men) + ret = list() + for sen, mention_list in sentences.values(): + tokens = [t.text for t in sen.tokenization.tokenList.tokenList] + spans = list() + for mention in mention_list: + mention_tokens = sorted(mention.tokens.tokenIndexList) + event = Span(mention_tokens[0], mention_tokens[-1], mention.situationKind, True) + for men_arg in mention.argumentList: + arg_tokens = sorted(men_arg.entityMention.tokens.tokenIndexList) + event.add_child(Span(arg_tokens[0], arg_tokens[-1], men_arg.role, False)) + spans.append(event) + vr = Span.virtual_root(spans) + ret.append((tokens, vr)) + return ret diff --git a/spanfinder/sftp/data_reader/span_reader.py b/spanfinder/sftp/data_reader/span_reader.py new file mode 100644 index 0000000000000000000000000000000000000000..dc8cb73812351def3875f2efe8ef235f74b19c96 --- /dev/null +++ b/spanfinder/sftp/data_reader/span_reader.py @@ -0,0 +1,197 @@ +import logging +from abc import ABC +from typing import * + +import numpy as np +from allennlp.common.util import END_SYMBOL +from 
allennlp.data.dataset_readers.dataset_reader import DatasetReader +from allennlp.data.dataset_readers.dataset_utils.span_utils import bio_tags_to_spans +from allennlp.data.fields import * +from allennlp.data.token_indexers import PretrainedTransformerIndexer +from allennlp.data.tokenizers import PretrainedTransformerTokenizer, Token + +from ..utils import Span, BIOSmoothing, apply_bio_smoothing + +logger = logging.getLogger(__name__) + + +@DatasetReader.register('span') +class SpanReader(DatasetReader, ABC): + def __init__( + self, + pretrained_model: str, + max_length: int = 512, + ignore_label: bool = False, + debug: bool = False, + **extras + ) -> None: + """ + :param pretrained_model: The name of the pretrained model. E.g. xlm-roberta-large + :param max_length: Sequences longer than this limit will be truncated. + :param ignore_label: If True, label on spans will be anonymized. + :param debug: True to turn on debugging mode. + :param span_proposals: Needed for "enumeration" scheme, but not needed for "BIO". + If True, it will try to enumerate candidate spans in the sentence, which will then be fed into + a binary classifier (EnumSpanFinder). + Note: It might take time to propose spans. And better to use SpacyTokenizer if you want to call + constituency parser or dependency parser. + :param maximum_negative_spans: Necessary for EnumSpanFinder. + :param extras: Args to DatasetReader. 
+ """ + super().__init__(**extras) + self.word_indexer = { + 'pieces': PretrainedTransformerIndexer(pretrained_model, namespace='pieces') + } + + self._pretrained_model_name = pretrained_model + self.debug = debug + self.ignore_label = ignore_label + + self._pretrained_tokenizer = PretrainedTransformerTokenizer(pretrained_model) + self.max_length = max_length + self.n_span_removed = 0 + + def retokenize( + self, sentence: List[str], truncate: bool = True + ) -> Tuple[List[str], List[Optional[Tuple[int, int]]]]: + pieces, offsets = self._pretrained_tokenizer.intra_word_tokenize(sentence) + pieces = list(map(str, pieces)) + if truncate: + pieces = pieces[:self.max_length] + pieces[-1] = END_SYMBOL + return pieces, offsets + + def prepare_inputs( + self, + sentence: List[str], + spans: Optional[Union[List[Span], Span]] = None, + truncate: bool = True, + label_type: str = 'string', + ) -> Dict[str, Field]: + """ + Prepare inputs and auxiliary variables for span model. + :param sentence: A list of tokens. Do not pass in any special tokens, like BOS or EOS. + Necessary for both training and testing. + :param spans: Optional. For training, spans passed in will be considered as positive examples; the spans + that are automatically proposed and not in the positive set will be considered as negative examples. + Necessary for training. + :param truncate: If True, sequence will be truncated if it's longer than `self.max_training_length` + :param label_type: One of [string, list]. + + :return: Dict of AllenNLP fields. For detailed of explanation of every field, refer to the comments + below. For the shape of every field, check the module doc. 
+ Fields list: + - words + - span_labels + - span_boundary + - parent_indices + - parent_mask + - bio_seqs + - raw_sentence + - raw_spans + - proposed_spans + """ + fields = dict() + + pieces, offsets = self.retokenize(sentence, truncate) + fields['tokens'] = TextField(list(map(Token, pieces)), self.word_indexer) + raw_inputs = {'sentence': sentence, "pieces": pieces, 'offsets': offsets} + fields['raw_inputs'] = MetadataField(raw_inputs) + + if spans is None: + return fields + + vr = spans if isinstance(spans, Span) else Span.virtual_root(spans) + self.n_span_removed = vr.remove_overlapping() + raw_inputs['spans'] = vr + + vr = vr.re_index(offsets) + if truncate: + vr.truncate(self.max_length) + if self.ignore_label: + vr.ignore_labels() + + # (start_idx, end_idx) pairs. Left and right inclusive. + # The first span is the Virtual Root node. Shape [span, 2] + span_boundary = list() + # label on span. Shape [span] + span_labels = list() + # parent idx (span indexing space). Shape [span] + span_parent_indices = list() + # True for parents. Shape [span] + parent_mask = [False] * vr.n_nodes + # Key: parent idx (span indexing space). 
Value: child span idx + flatten_spans = list(vr.bfs()) + for span_idx, span in enumerate(vr.bfs()): + if span.is_parent: + parent_mask[span_idx] = True + # 0 is the virtual root + parent_idx = flatten_spans.index(span.parent) if span.parent else 0 + span_parent_indices.append(parent_idx) + span_boundary.append(span.boundary) + span_labels.append(span.label) + + bio_tag_list: List[List[str]] = list() + bio_configs: List[List[BIOSmoothing]] = list() + # Shape: [#parent, #token, 3] + bio_seqs: List[np.ndarray] = list() + # Parent index for every BIO seq + for parent_idx, parent in filter(lambda node: node[1].is_parent, enumerate(flatten_spans)): + bio_tags = ['O'] * len(pieces) + bio_tag_list.append(bio_tags) + bio_smooth: List[BIOSmoothing] = [parent.child_smooth.clone() for _ in pieces] + bio_configs.append(bio_smooth) + for child in parent: + assert all(bio_tags[bio_idx] == 'O' for bio_idx in range(child.start_idx, child.end_idx + 1)) + if child.smooth_weight is not None: + for i in range(child.start_idx, child.end_idx+1): + bio_smooth[i].weight = child.smooth_weight + bio_tags[child.start_idx] = 'B' + for word_idx in range(child.start_idx + 1, child.end_idx + 1): + bio_tags[word_idx] = 'I' + bio_seqs.append(apply_bio_smoothing(bio_smooth, bio_tags)) + + fields['span_boundary'] = ArrayField( + np.array(span_boundary), padding_value=0, dtype=np.int + ) + fields['parent_indices'] = ArrayField(np.array(span_parent_indices), 0, np.int) + if label_type == 'string': + fields['span_labels'] = ListField([LabelField(label, 'span_label') for label in span_labels]) + elif label_type == 'list': + fields['span_labels'] = ArrayField(np.array(span_labels)) + else: + raise NotImplementedError + fields['parent_mask'] = ArrayField(np.array(parent_mask), False, np.bool) + fields['bio_seqs'] = ArrayField(np.stack(bio_seqs)) + + self._sanity_check( + flatten_spans, pieces, bio_tag_list, parent_mask, span_boundary, span_labels, span_parent_indices + ) + + return fields + + @staticmethod 
+ def _sanity_check( + flatten_spans, words, bio_tag_list, parent_mask, span_boundary, span_labels, parent_indices, verbose=False + ): + # For debugging use. + assert len(parent_mask) == len(span_boundary) == len(span_labels) == len(parent_indices) + for (parent_idx, parent_span), bio_tags in zip( + filter(lambda x: x[1].is_parent, enumerate(flatten_spans)), bio_tag_list + ): + assert parent_mask[parent_idx] + parent_s, parent_e = span_boundary[parent_idx] + if verbose: + print('Parent: ', span_labels[parent_idx], 'Text: ', ' '.join(words[parent_s:parent_e+1])) + print(f'It contains {len(parent_span)} children.') + for child in parent_span: + child_idx = flatten_spans.index(child) + assert parent_indices[child_idx] == flatten_spans.index(parent_span) + if verbose: + child_s, child_e = span_boundary[child_idx] + print(' ', span_labels[child_idx], 'Text', words[child_s:child_e+1]) + + if verbose: + print(f'Child derived from BIO tags:') + for _, (start, end) in bio_tags_to_spans(bio_tags): + print(words[start:end+1]) diff --git a/spanfinder/sftp/data_reader/srl_reader.py b/spanfinder/sftp/data_reader/srl_reader.py new file mode 100644 index 0000000000000000000000000000000000000000..3a2816c615c856b57ce29c07b69c51e5f1ce05cb --- /dev/null +++ b/spanfinder/sftp/data_reader/srl_reader.py @@ -0,0 +1,107 @@ +import json +import logging +import random +from typing import * + +import numpy as np +from allennlp.data.dataset_readers.dataset_reader import DatasetReader +from allennlp.data.fields import MetadataField +from allennlp.data.instance import Instance + +from .span_reader import SpanReader +from ..utils import Span, VIRTUAL_ROOT, BIOSmoothing + +logger = logging.getLogger(__name__) + + +@DatasetReader.register('semantic_role_labeling') +class SRLDatasetReader(SpanReader): + def __init__( + self, + min_negative: int = 5, + negative_ratio: float = 1., + event_only: bool = False, + event_smoothing_factor: float = 0., + arg_smoothing_factor: float = 0., + # For Ontology 
Mapping + ontology_mapping_path: Optional[str] = None, + min_weight: float = 1e-2, + max_weight: float = 1.0, + **extra + ): + super().__init__(**extra) + self.min_negative = min_negative + self.negative_ratio = negative_ratio + self.event_only = event_only + self.event_smooth_factor = event_smoothing_factor + self.arg_smooth_factor = arg_smoothing_factor + self.ontology_mapping = None + if ontology_mapping_path is not None: + self.ontology_mapping = json.load(open(ontology_mapping_path)) + for k1 in ['event', 'argument']: + for k2, weights in self.ontology_mapping['mapping'][k1].items(): + weights = np.array(weights) + weights[weights < min_weight] = 0.0 + weights[weights > max_weight] = max_weight + self.ontology_mapping['mapping'][k1][k2] = weights + self.ontology_mapping['mapping'][k1] = { + k2: weights for k2, weights in self.ontology_mapping['mapping'][k1].items() if weights.sum() > 1e-5 + } + vr_label = [0.] * len(self.ontology_mapping['target']['label']) + vr_label[self.ontology_mapping['target']['label'].index(VIRTUAL_ROOT)] = 1.0 + self.ontology_mapping['mapping']['event'][VIRTUAL_ROOT] = np.array(vr_label) + + def _read(self, file_path: str) -> Iterable[Instance]: + all_lines = list(map(json.loads, open(file_path).readlines())) + if self.debug: + random.seed(1); random.shuffle(all_lines) + for line in all_lines: + ins = self.text_to_instance(**line) + if ins is not None: + yield ins + if self.n_span_removed > 0: + logger.warning(f'{self.n_span_removed} spans are removed.') + self.n_span_removed = 0 + + def apply_ontology_mapping(self, vr): + new_events = list() + event_map, arg_map = self.ontology_mapping['mapping']['event'], self.ontology_mapping['mapping']['argument'] + for event in vr: + if event.label not in event_map: continue + event.child_smooth.weight = event.smooth_weight = event_map[event.label].sum() + event = event.map_ontology(event_map, False, False) + new_events.append(event) + new_children = list() + for child in event: + if child.label 
not in arg_map: continue + child.child_smooth.weight = child.smooth_weight = arg_map[child.label].sum() + child = child.map_ontology(arg_map, False, False) + new_children.append(child) + event.remove_child() + for child in new_children: event.add_child(child) + new_vr = Span.virtual_root(new_events) + # For Virtual Root itself. + new_vr.map_ontology(self.ontology_mapping['mapping']['event'], True, False) + return new_vr + + def text_to_instance(self, tokens, annotations=None, meta=None) -> Optional[Instance]: + meta = meta or {'fully_annotated': True} + meta['fully_annotated'] = meta.get('fully_annotated', True) + vr = None + if annotations is not None: + vr = annotations if isinstance(annotations, Span) else Span.from_json(annotations) + vr = self.apply_ontology_mapping(vr) if self.ontology_mapping is not None else vr + # if len(vr) == 0: return # Ignore sentence with empty annotation + if self.event_smooth_factor != 0.0: + vr.child_smooth = BIOSmoothing(o_smooth=self.event_smooth_factor if meta['fully_annotated'] else -1) + if self.arg_smooth_factor != 0.0: + for event in vr: + event.child_smooth = BIOSmoothing(o_smooth=self.arg_smooth_factor) + if self.event_only: + for event in vr: + event.remove_child() + event.is_parent = False + + fields = self.prepare_inputs(tokens, vr, True, 'string' if self.ontology_mapping is None else 'list') + fields['meta'] = MetadataField(meta) + return Instance(fields) diff --git a/spanfinder/sftp/metrics/__init__.py b/spanfinder/sftp/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..92e5e3f749ec25523a7591dead6da95cc673172f --- /dev/null +++ b/spanfinder/sftp/metrics/__init__.py @@ -0,0 +1,4 @@ +from sftp.metrics.base_f import BaseF +from sftp.metrics.exact_match import ExactMatch +from sftp.metrics.fbeta_mix_measure import FBetaMixMeasure +from sftp.metrics.srl_metrics import SRLMetric diff --git a/spanfinder/sftp/metrics/__pycache__/__init__.cpython-38.pyc 
b/spanfinder/sftp/metrics/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e65288a2b042c3e700b00b04a937c4be6140cc4 Binary files /dev/null and b/spanfinder/sftp/metrics/__pycache__/__init__.cpython-38.pyc differ diff --git a/spanfinder/sftp/metrics/__pycache__/__init__.cpython-39.pyc b/spanfinder/sftp/metrics/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ee56b7660391c045e3d2dde99683f297bb1144e Binary files /dev/null and b/spanfinder/sftp/metrics/__pycache__/__init__.cpython-39.pyc differ diff --git a/spanfinder/sftp/metrics/__pycache__/base_f.cpython-38.pyc b/spanfinder/sftp/metrics/__pycache__/base_f.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1520495f9a353d63825e62639f93c6f7c2a6939d Binary files /dev/null and b/spanfinder/sftp/metrics/__pycache__/base_f.cpython-38.pyc differ diff --git a/spanfinder/sftp/metrics/__pycache__/base_f.cpython-39.pyc b/spanfinder/sftp/metrics/__pycache__/base_f.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b6cc2be278368b0914971d2b570fa85fbe6d1563 Binary files /dev/null and b/spanfinder/sftp/metrics/__pycache__/base_f.cpython-39.pyc differ diff --git a/spanfinder/sftp/metrics/__pycache__/exact_match.cpython-38.pyc b/spanfinder/sftp/metrics/__pycache__/exact_match.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..627a5bc0f5a3a26da4fc39cec7b9167ff4c452ac Binary files /dev/null and b/spanfinder/sftp/metrics/__pycache__/exact_match.cpython-38.pyc differ diff --git a/spanfinder/sftp/metrics/__pycache__/exact_match.cpython-39.pyc b/spanfinder/sftp/metrics/__pycache__/exact_match.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d7537509b94d006d1495345a5c39cc6ce2dc5cc Binary files /dev/null and b/spanfinder/sftp/metrics/__pycache__/exact_match.cpython-39.pyc differ diff --git 
a/spanfinder/sftp/metrics/__pycache__/fbeta_mix_measure.cpython-38.pyc b/spanfinder/sftp/metrics/__pycache__/fbeta_mix_measure.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b57069056d1635c2aca92420c299e282545079c6 Binary files /dev/null and b/spanfinder/sftp/metrics/__pycache__/fbeta_mix_measure.cpython-38.pyc differ diff --git a/spanfinder/sftp/metrics/__pycache__/fbeta_mix_measure.cpython-39.pyc b/spanfinder/sftp/metrics/__pycache__/fbeta_mix_measure.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ddd0144d98236c0b7eacff248341d8d5865f4864 Binary files /dev/null and b/spanfinder/sftp/metrics/__pycache__/fbeta_mix_measure.cpython-39.pyc differ diff --git a/spanfinder/sftp/metrics/__pycache__/srl_metrics.cpython-38.pyc b/spanfinder/sftp/metrics/__pycache__/srl_metrics.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d3c8c0bc18c01cf9310f45495354f100e5219b19 Binary files /dev/null and b/spanfinder/sftp/metrics/__pycache__/srl_metrics.cpython-38.pyc differ diff --git a/spanfinder/sftp/metrics/__pycache__/srl_metrics.cpython-39.pyc b/spanfinder/sftp/metrics/__pycache__/srl_metrics.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a256fb913ffb45cb43018be427b80b78a6fdc42 Binary files /dev/null and b/spanfinder/sftp/metrics/__pycache__/srl_metrics.cpython-39.pyc differ diff --git a/spanfinder/sftp/metrics/base_f.py b/spanfinder/sftp/metrics/base_f.py new file mode 100644 index 0000000000000000000000000000000000000000..aca78b5605fc7c47b0726e3aafd188eebeb9c1a7 --- /dev/null +++ b/spanfinder/sftp/metrics/base_f.py @@ -0,0 +1,27 @@ +from abc import ABC +from typing import * + +from allennlp.training.metrics import Metric + + +class BaseF(Metric, ABC): + def __init__(self, prefix: str): + self.tp = self.fp = self.fn = 0 + self.prefix = prefix + + def reset(self) -> None: + self.tp = self.fp = self.fn = 0 + + def get_metric( + self, 
reset: bool + ) -> Union[float, Tuple[float, ...], Dict[str, float], Dict[str, List[float]]]: + precision = self.tp * 100 / (self.tp + self.fp) if self.tp > 0 else 0. + recall = self.tp * 100 / (self.tp + self.fn) if self.tp > 0 else 0. + rst = { + f'{self.prefix}_p': precision, + f'{self.prefix}_r': recall, + f'{self.prefix}_f': 2 / (1 / precision + 1 / recall) if self.tp > 0 else 0. + } + if reset: + self.reset() + return rst diff --git a/spanfinder/sftp/metrics/exact_match.py b/spanfinder/sftp/metrics/exact_match.py new file mode 100644 index 0000000000000000000000000000000000000000..2d6596b29ee30969c981c52c1cf1fd3381ea53d1 --- /dev/null +++ b/spanfinder/sftp/metrics/exact_match.py @@ -0,0 +1,29 @@ +from allennlp.training.metrics import Metric +from overrides import overrides + +from .base_f import BaseF +from ..utils import Span + + +@Metric.register('exact_match') +class ExactMatch(BaseF): + def __init__(self, check_type: bool): + self.check_type = check_type + if check_type: + super(ExactMatch, self).__init__('em') + else: + super(ExactMatch, self).__init__('sm') + + @overrides + def __call__( + self, + prediction: Span, + gold: Span, + ): + tp = prediction.match(gold, self.check_type) - 1 + fp = prediction.n_nodes - tp - 1 + fn = gold.n_nodes - tp - 1 + assert tp >= 0 and fp >= 0 and fn >= 0 + self.tp += tp + self.fp += fp + self.fn += fn diff --git a/spanfinder/sftp/metrics/fbeta_mix_measure.py b/spanfinder/sftp/metrics/fbeta_mix_measure.py new file mode 100644 index 0000000000000000000000000000000000000000..9ef18b6f9db5b7698dc13646bf541e34a8291529 --- /dev/null +++ b/spanfinder/sftp/metrics/fbeta_mix_measure.py @@ -0,0 +1,34 @@ +from allennlp.training.metrics import FBetaMeasure, Metric + + +@Metric.register('fbeta_mix') +class FBetaMixMeasure(FBetaMeasure): + def __init__(self, null_idx, **kwargs): + super().__init__(**kwargs) + self.null_idx = null_idx + + def get_metric(self, reset: bool = False): + + tp = float(self._true_positive_sum.sum() - 
self._true_positive_sum[self.null_idx]) + total_pred = float(self._pred_sum.sum() - self._pred_sum[self.null_idx]) + total_gold = float(self._true_sum.sum() - self._true_sum[self.null_idx]) + + beta2 = self._beta ** 2 + p = 0. if total_pred == 0 else tp / total_pred + r = 0. if total_pred == 0 else tp / total_gold + f = 0. if p == 0. or r == 0. else ((1 + beta2) * p * r / (p * beta2 + r)) + + mix_f = { + 'p': p * 100, + 'r': r * 100, + 'f': f * 100 + } + + if reset: + self.reset() + + return mix_f + + def add_false_negative(self, labels): + for lab in labels: + self._true_sum[lab] += 1 diff --git a/spanfinder/sftp/metrics/srl_metrics.py b/spanfinder/sftp/metrics/srl_metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..e769237931e3ba692537948e44461990c1271b83 --- /dev/null +++ b/spanfinder/sftp/metrics/srl_metrics.py @@ -0,0 +1,138 @@ +from typing import * + +from allennlp.training.metrics import Metric +from overrides import overrides +import numpy as np +import logging + +from .base_f import BaseF +from ..utils import Span, max_match + +logger = logging.getLogger('srl_metric') + + +@Metric.register('srl') +class SRLMetric(Metric): + def __init__(self, check_type: Optional[bool] = None): + self.tri_i = BaseF('tri-i') + self.tri_c = BaseF('tri-c') + self.arg_i = BaseF('arg-i') + self.arg_c = BaseF('arg-c') + if check_type is not None: + logger.warning('Check type argument is deprecated.') + + def reset(self) -> None: + for metric in [self.tri_i, self.tri_c, self.arg_i, self.arg_c]: + metric.reset() + + def get_metric(self, reset: bool) -> Dict[str, Any]: + ret = dict() + for metric in [self.tri_i, self.tri_c, self.arg_i, self.arg_c]: + ret.update(metric.get_metric(reset)) + return ret + + @overrides + def __call__(self, prediction: Span, gold: Span): + self.with_label_event(prediction, gold) + self.without_label_event(prediction, gold) + self.tuple_eval(prediction, gold) + # self.with_label_arg(prediction, gold) + # 
self.without_label_arg(prediction, gold) + + def tuple_eval(self, prediction: Span, gold: Span): + def extract_tuples(vr: Span, parent_boundary: bool): + labeled, unlabeled = list(), list() + for event in vr: + for arg in event: + if parent_boundary: + labeled.append((event.boundary, event.label, arg.boundary, arg.label)) + unlabeled.append((event.boundary, event.label, arg.boundary)) + else: + labeled.append((event.label, arg.boundary, arg.label)) + unlabeled.append((event.label, arg.boundary)) + return labeled, unlabeled + + def equal_matrix(l1, l2): return np.array([[e1 == e2 for e2 in l2] for e1 in l1], dtype=np.int) + + pred_label, pred_unlabel = extract_tuples(prediction, False) + gold_label, gold_unlabel = extract_tuples(gold, False) + + if len(pred_label) == 0 or len(gold_label) == 0: + arg_c_tp = arg_i_tp = 0 + else: + label_bipartite = equal_matrix(pred_label, gold_label) + unlabel_bipartite = equal_matrix(pred_unlabel, gold_unlabel) + arg_c_tp, arg_i_tp = max_match(label_bipartite), max_match(unlabel_bipartite) + + arg_c_fp = prediction.n_nodes - len(prediction) - 1 - arg_c_tp + arg_c_fn = gold.n_nodes - len(gold) - 1 - arg_c_tp + arg_i_fp = prediction.n_nodes - len(prediction) - 1 - arg_i_tp + arg_i_fn = gold.n_nodes - len(gold) - 1 - arg_i_tp + + assert arg_i_tp >= 0 and arg_i_fn >= 0 and arg_i_fp >= 0 + self.arg_i.tp += arg_i_tp + self.arg_i.fp += arg_i_fp + self.arg_i.fn += arg_i_fn + + assert arg_c_tp >= 0 and arg_c_fn >= 0 and arg_c_fp >= 0 + self.arg_c.tp += arg_c_tp + self.arg_c.fp += arg_c_fp + self.arg_c.fn += arg_c_fn + + def with_label_event(self, prediction: Span, gold: Span): + trigger_tp = prediction.match(gold, True, 2) - 1 + trigger_fp = len(prediction) - trigger_tp + trigger_fn = len(gold) - trigger_tp + assert trigger_fp >= 0 and trigger_fn >= 0 and trigger_tp >= 0 + self.tri_c.tp += trigger_tp + self.tri_c.fp += trigger_fp + self.tri_c.fn += trigger_fn + + def with_label_arg(self, prediction: Span, gold: Span): + trigger_tp = 
prediction.match(gold, True, 2) - 1 + role_tp = prediction.match(gold, True, ignore_parent_boundary=True) - 1 - trigger_tp + role_fp = (prediction.n_nodes - 1 - len(prediction)) - role_tp + role_fn = (gold.n_nodes - 1 - len(gold)) - role_tp + assert role_fp >= 0 and role_fn >= 0 and role_tp >= 0 + self.arg_c.tp += role_tp + self.arg_c.fp += role_fp + self.arg_c.fn += role_fn + + def without_label_event(self, prediction: Span, gold: Span): + tri_i_tp = prediction.match(gold, False, 2) - 1 + tri_i_fp = len(prediction) - tri_i_tp + tri_i_fn = len(gold) - tri_i_tp + assert tri_i_tp >= 0 and tri_i_fp >= 0 and tri_i_fn >= 0 + self.tri_i.tp += tri_i_tp + self.tri_i.fp += tri_i_fp + self.tri_i.fn += tri_i_fn + + def without_label_arg(self, prediction: Span, gold: Span): + arg_i_tp = 0 + matched_pairs: List[Tuple[Span, Span]] = list() + n_gold_arg, n_pred_arg = gold.n_nodes - len(gold) - 1, prediction.n_nodes - len(prediction) - 1 + prediction, gold = prediction.clone(), gold.clone() + for p in prediction: + for g in gold: + if p.match(g, True, 1) == 1: + arg_i_tp += (p.match(g, False) - 1) + matched_pairs.append((p, g)) + break + for p, g in matched_pairs: + prediction.remove_child(p) + gold.remove_child(g) + + sub_matches = np.zeros([len(prediction), len(gold)], np.int) + for p_idx, p in enumerate(prediction): + for g_idx, g in enumerate(gold): + if p.label == g.label: + sub_matches[p_idx, g_idx] = p.match(g, False, -1, True) + arg_i_tp += max_match(sub_matches) + + arg_i_fp = n_pred_arg - arg_i_tp + arg_i_fn = n_gold_arg - arg_i_tp + assert arg_i_tp >= 0 and arg_i_fn >= 0 and arg_i_fp >= 0 + + self.arg_i.tp += arg_i_tp + self.arg_i.fp += arg_i_fp + self.arg_i.fn += arg_i_fn diff --git a/spanfinder/sftp/models/__init__.py b/spanfinder/sftp/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ab4f4c06fb993fd918c8e95bcd73fc3c2658d9ab --- /dev/null +++ b/spanfinder/sftp/models/__init__.py @@ -0,0 +1 @@ +from sftp.models.span_model import 
SpanModel diff --git a/spanfinder/sftp/models/__pycache__/__init__.cpython-38.pyc b/spanfinder/sftp/models/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..25b4132e96c8f41421827be64238ec4612cdb06e Binary files /dev/null and b/spanfinder/sftp/models/__pycache__/__init__.cpython-38.pyc differ diff --git a/spanfinder/sftp/models/__pycache__/__init__.cpython-39.pyc b/spanfinder/sftp/models/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8131e55ef5b973f06a5bbbf35cf6490b8e7b3741 Binary files /dev/null and b/spanfinder/sftp/models/__pycache__/__init__.cpython-39.pyc differ diff --git a/spanfinder/sftp/models/__pycache__/span_model.cpython-38.pyc b/spanfinder/sftp/models/__pycache__/span_model.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b9d05303831644051d6d249a62ef1a6f5c9cb53 Binary files /dev/null and b/spanfinder/sftp/models/__pycache__/span_model.cpython-38.pyc differ diff --git a/spanfinder/sftp/models/__pycache__/span_model.cpython-39.pyc b/spanfinder/sftp/models/__pycache__/span_model.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce7a1e9df385ea0d3722f60f314394c43a79f226 Binary files /dev/null and b/spanfinder/sftp/models/__pycache__/span_model.cpython-39.pyc differ diff --git a/spanfinder/sftp/models/span_model.py b/spanfinder/sftp/models/span_model.py new file mode 100644 index 0000000000000000000000000000000000000000..13c4376774e0f3cdaa9d3e964b0e849aa25200d1 --- /dev/null +++ b/spanfinder/sftp/models/span_model.py @@ -0,0 +1,362 @@ +import os +from typing import * + +import torch +from allennlp.common.from_params import Params, T, pop_and_construct_arg +from allennlp.data.vocabulary import Vocabulary, DEFAULT_PADDING_TOKEN, DEFAULT_OOV_TOKEN +from allennlp.models.model import Model +from allennlp.modules import TextFieldEmbedder +from 
allennlp.modules.seq2seq_encoders.pytorch_seq2seq_wrapper import Seq2SeqEncoder +from allennlp.modules.span_extractors import SpanExtractor +from allennlp.training.metrics import Metric + +from ..metrics import ExactMatch +from ..modules import SpanFinder, SpanTyping +from ..utils import num2mask, VIRTUAL_ROOT, Span, tensor2span + + +@Model.register("span") +class SpanModel(Model): + """ + Identify/Find spans; link them as a tree; label them. + """ + default_predictor = 'span' + + def __init__( + self, + vocab: Vocabulary, + + # Modules + word_embedding: TextFieldEmbedder, + span_extractor: SpanExtractor, + span_finder: SpanFinder, + span_typing: SpanTyping, + + # Config + typing_loss_factor: float = 1., + max_recursion_depth: int = -1, + max_decoding_spans: int = -1, + debug: bool = False, + + # Ontology Constraints + ontology_path: Optional[str] = None, + + # Metrics + metrics: Optional[List[Metric]] = None, + ) -> None: + """ + Note for jsonnet file: it doesn't strictly follow the init examples of every module for that we override + the from_params method. + You can either check the SpanModel.from_params or the example jsonnet file. + :param vocab: No need to specify. + ## Modules + :param word_embedding: Refer to the module doc. + :param span_extractor: Refer to the module doc. + :param span_finder: Refer to the module doc. + :param span_typing: Refer to the module doc. + ## Configs + :param typing_loss_factor: loss = span_finder_loss + span_typing_loss * typing_loss_factor + :param max_recursion_depth: Maximum tree depth for inference. E.g., 1 for shallow event typing, 2 for SRL, + -1 (unlimited) for dependency parsing. + :param max_decoding_spans: Maximum spans for inference. -1 for unlimited. + :param debug: Useless now. 
+ """ + self._pad_idx = vocab.get_token_index(DEFAULT_PADDING_TOKEN, 'token') + self._null_idx = vocab.get_token_index(DEFAULT_OOV_TOKEN, 'span_label') + super().__init__(vocab) + + self.word_embedding = word_embedding + self._span_finder = span_finder + self._span_extractor = span_extractor + self._span_typing = span_typing + + self.metrics = [ExactMatch(True), ExactMatch(False)] + if metrics is not None: + self.metrics.extend(metrics) + + if ontology_path is not None and os.path.exists(ontology_path): + self._span_typing.load_ontology(ontology_path, self.vocab) + + self._max_decoding_spans = max_decoding_spans + self._typing_loss_factor = typing_loss_factor + self._max_recursion_depth = max_recursion_depth + self.debug = debug + + def forward( + self, + tokens: Dict[str, Dict[str, torch.Tensor]], + + span_boundary: Optional[torch.Tensor] = None, + span_labels: Optional[torch.Tensor] = None, + parent_indices: Optional[torch.Tensor] = None, + parent_mask: Optional[torch.Tensor] = None, + + bio_seqs: Optional[torch.Tensor] = None, + raw_inputs: Optional[dict] = None, + meta: Optional[dict] = None, + + **extra + ) -> Dict[str, torch.Tensor]: + """ + For training, provide all blow. + For inference, it's enough to only provide words. + + :param tokens: Indexed input sentence. Shape: [batch, token] + + :param span_boundary: Start and end indices for every span. Note this includes both parent and + non-parent spans. Shape: [batch, span, 2]. For the last dim, [0] is start idx and [1] is end idx. + :param span_labels: Indexed label for spans, including parent and non-parent ones. Shape: [batch, span] + :param parent_indices: The parent span idx of every span. Shape: [batch, span] + :param parent_mask: True if this span is a parent. Shape: [batch, span] + + :param bio_seqs: Shape [batch, parent, token, 3] + :param raw_inputs + + :param meta: Meta information. Will be copied to the outputs. 
+ + :return: + - loss: training loss + - prediction: Predicted spans + - meta: Meta info copied from input + - inputs: Input sentences and spans (if exist) + """ + ret = {'inputs': raw_inputs, 'meta': meta or dict()} + + is_eval = span_labels is not None and not self.training # evaluation on dev set + is_test = span_labels is None # test on test set + # Shape [batch] + num_spans = (span_labels != -1).sum(1) if span_labels is not None else None + num_words = tokens['pieces']['mask'].sum(1) + # Shape [batch, word, token_dim] + token_vec = self.word_embedding(tokens) + + if span_labels is not None: + # Revise the padding value from -1 to 0 + span_labels[span_labels == -1] = 0 + + # Calculate Loss + if self.training or is_eval: + # Shape [batch, word, token_dim] + span_vec = self._span_extractor(token_vec, span_boundary) + finder_rst = self._span_finder( + token_vec, num2mask(num_words), span_vec, num2mask(num_spans), span_labels, parent_indices, + parent_mask, bio_seqs + ) + typing_rst = self._span_typing(span_vec, parent_indices, span_labels) + ret['loss'] = finder_rst['loss'] + typing_rst['loss'] * self._typing_loss_factor + + # Decoding + if is_eval or is_test: + pred_span_boundary, pred_span_labels, pred_parent_indices, pred_cursor, pred_label_confidence \ + = self.inference(num_words, token_vec, **extra) + prediction = self.post_process_pred( + pred_span_boundary, pred_span_labels, pred_parent_indices, pred_cursor, pred_label_confidence + ) + for pred, raw_in in zip(prediction, raw_inputs): + pred.re_index(raw_in['offsets'], True, True, True) + pred.remove_overlapping() + ret['prediction'] = prediction + if 'spans' in raw_inputs[0]: + for pred, raw_in in zip(prediction, raw_inputs): + gold = raw_in['spans'] + for metric in self.metrics: + metric(pred, gold) + + return ret + + def inference( + self, + num_words: torch.Tensor, + token_vec: torch.Tensor, + **auxiliaries + ): + n_batch = num_words.shape[0] + # The decoding results are preserved in the following 
tensors starting with `pred` + # During inference, we completely ignore the arguments defaulted None in the forward method. + # The span indexing space is shift to the decoding span space. (since we do not have gold span now) + # boundary indices of every predicted span + pred_span_boundary = num_words.new_zeros([n_batch, self._max_decoding_spans, 2]) + # labels (and corresponding confidence) for predicted spans + pred_span_labels = num_words.new_full( + [n_batch, self._max_decoding_spans], self.vocab.get_token_index(VIRTUAL_ROOT, 'span_label') + ) + pred_label_confidence = num_words.new_zeros([n_batch, self._max_decoding_spans]) + # label masked as True will be treated as parent in the next round + pred_parent_mask = num_words.new_zeros([n_batch, self._max_decoding_spans], dtype=torch.bool) + pred_parent_mask[:, 0] = True + # parent index (in the span indexing space) for every span + pred_parent_indices = num_words.new_zeros([n_batch, self._max_decoding_spans]) + # what index have we reached for every batch? + pred_cursor = num_words.new_ones([n_batch]) + + # Pass environment variables to handler. Extra variables will be ignored. + # So pass the union of variables that are needed by different modules. + span_find_handler = self._span_finder.inference_forward_handler( + token_vec, num2mask(num_words), self._span_extractor, **auxiliaries + ) + + # Every step here is one layer of the tree. It deals with all the parents for the last layer + # so there might be 0 to multiple parents for a batch for a single step. + for _ in range(self._max_recursion_depth): + cursor_before_find = pred_cursor.clone() + span_find_handler( + pred_span_boundary, pred_span_labels, pred_parent_mask, pred_parent_indices, pred_cursor + ) + # Labels of old spans are re-predicted. It doesn't matter since their results shouldn't change + # in theory. 
+ span_typing_ret = self._span_typing( + self._span_extractor(token_vec, pred_span_boundary), pred_parent_indices, pred_span_labels, True + ) + pred_span_labels = span_typing_ret['prediction'] + pred_label_confidence = span_typing_ret['label_confidence'] + pred_span_labels[:, 0] = self.vocab.get_token_index(VIRTUAL_ROOT, 'span_label') + pred_parent_mask = ( + num2mask(cursor_before_find, self._max_decoding_spans) ^ num2mask(pred_cursor, + self._max_decoding_spans) + ) + + # Break the inference loop if 1) all batches reach max span limit OR 2) no parent is predicted + # at last step OR 3) max recursion limit is reached (for loop condition) + if (pred_cursor == self._max_decoding_spans).all() or pred_parent_mask.sum() == 0: + break + + return pred_span_boundary, pred_span_labels, pred_parent_indices, pred_cursor, pred_label_confidence + + def one_step_prediction( + self, + tokens: Dict[str, Dict[str, torch.Tensor]], + parent_boundary: torch.Tensor, + parent_labels: torch.Tensor, + ): + """ + Single step prediction. Given parent span boundary indices, return the corresponding children spans + and their labels. + Restriction: Each sentence contain exactly 1 parent. + For efficient multi-layer prediction, i.e. given a root, predict the whole tree, + refer to the `forward' method. + :param tokens: See forward. + :param parent_boundary: Pairs of (start_idx, end_idx) for parents. Shape [batch, 2] + :param parent_labels: Labels for parents. Shape [batch] + Note: If `no_label' is on in span_finder module, this will be ignored. + :return: + children_boundary: (start_idx, end_idx) for every child span. Padded with (0, 0). + Shape [batch, children, 2] + children_labels: Label for every child span. Padded with null_idx. Shape [batch, children] + num_children: The number of children predicted for parent/batch. Shape [batch] + Tips: You can use num2mask method to convert this to bool tensor mask. 
+ """ + num_words = tokens['pieces']['mask'].sum(1) + # Shape [batch, word, token_dim] + token_vec = self.word_embedding(tokens) + n_batch = token_vec.shape[0] + + # The following variables assumes the parent is the 0-th span, and we let the model + # to extend the span list. + pred_span_boundary = num_words.new_zeros([n_batch, self._max_decoding_spans, 2]) + pred_span_boundary[:, 0] = parent_boundary + pred_span_labels = num_words.new_full([n_batch, self._max_decoding_spans], self._null_idx) + pred_span_labels[:, 0] = parent_labels + pred_parent_mask = num_words.new_zeros(pred_span_labels.shape, dtype=torch.bool) + pred_parent_mask[:, 0] = True + pred_parent_indices = num_words.new_zeros([n_batch, self._max_decoding_spans]) + # We start from idx 1 since 0 is the parents. + pred_cursor = num_words.new_ones([n_batch]) + + span_find_handler = self._span_finder.inference_forward_handler( + token_vec, num2mask(num_words), self._span_extractor + ) + span_find_handler( + pred_span_boundary, pred_span_labels, pred_parent_mask, pred_parent_indices, pred_cursor + ) + typing_out = self._span_typing( + self._span_extractor(token_vec, pred_span_boundary), pred_parent_indices, pred_span_labels, True + ) + pred_span_labels = typing_out['prediction'] + + # Now remove the parent + num_children = pred_cursor - 1 + max_children = int(num_children.max()) + children_boundary = pred_span_boundary[:, 1:max_children + 1] + children_labels = pred_span_labels[:, 1:max_children + 1] + children_distribution = typing_out['distribution'][:, 1:max_children + 1] + return children_boundary, children_labels, num_children, children_distribution + + def post_process_pred( + self, span_boundary, span_labels, parent_indices, num_spans, label_confidence + ) -> List[Span]: + pred_spans = tensor2span( + span_boundary, span_labels, parent_indices, num_spans, label_confidence, + self.vocab.get_index_to_token_vocabulary('span_label'), + label_ignore=[self._null_idx], + ) + return pred_spans + + def 
    @classmethod
    def from_params(
        cls: Type[T],
        params: Params,
        constructor_to_call: Callable[..., T] = None,
        constructor_to_inspect: Callable[..., T] = None,
        **extras,
    ) -> T:
        """
        Specify the dependency between modules. E.g. the input dim of a module might depend on the output dim
        of another module.

        Construction order matters: the word embedding fixes the token dim, which
        sizes the span extractor, whose output dim in turn sizes the span finder
        and span typing sub-modules. Each `params.pop(...)` consumes the config
        key, so later `super().from_params` only sees what is left.
        """
        vocab = extras['vocab']
        # Build the token embedder first; its output dim drives everything downstream.
        word_embedding = pop_and_construct_arg('SpanModel', 'word_embedding', TextFieldEmbedder, None, params, **extras)
        label_dim, token_emb_dim = params.pop('label_dim'), word_embedding.get_output_dim()
        span_extractor = pop_and_construct_arg(
            'SpanModel', 'span_extractor', SpanExtractor, None, params, input_dim=token_emb_dim, **extras
        )
        # One label embedding shared by the span finder and the span typing modules.
        label_embedding = torch.nn.Embedding(vocab.get_vocab_size('span_label'), label_dim)
        extras['label_emb'] = label_embedding

        if params.get('span_finder').get('type') == 'bio':
            # The BIO encoder consumes [parent span vec; token vec; parent label vec].
            # Both `input_size` and `input_dim` are supplied because different
            # Seq2SeqEncoder implementations name the constructor argument differently.
            bio_encoder = Seq2SeqEncoder.from_params(
                params['span_finder'].pop('bio_encoder'),
                input_size=span_extractor.get_output_dim() + token_emb_dim + label_dim,
                input_dim=span_extractor.get_output_dim() + token_emb_dim + label_dim,
                **extras
            )
            extras['span_finder'] = SpanFinder.from_params(
                params.pop('span_finder'), bio_encoder=bio_encoder, **extras
            )
        else:
            extras['span_finder'] = pop_and_construct_arg(
                'SpanModel', 'span_finder', SpanFinder, None, params, **extras
            )
        # Attach the shared label embedding regardless of which branch built the finder.
        extras['span_finder'].label_emb = label_embedding

        if params.get('span_typing').get('type') == 'mlp':
            # MLP typing input: parent span vec + child span vec + parent label vec.
            extras['span_typing'] = SpanTyping.from_params(
                params.pop('span_typing'),
                input_dim=span_extractor.get_output_dim() * 2 + label_dim,
                n_category=vocab.get_vocab_size('span_label'),
                # OOV and padding labels never contribute to the typing loss/metric.
                label_to_ignore=[
                    vocab.get_token_index(lti, 'span_label')
                    for lti in [DEFAULT_OOV_TOKEN, DEFAULT_PADDING_TOKEN]
                ],
                **extras
            )
        else:
            extras['span_typing'] = pop_and_construct_arg(
                'SpanModel', 'span_typing', SpanTyping, None, params, **extras
            )
        extras['span_typing'].label_emb = label_embedding

        return super().from_params(
            params,
            word_embedding=word_embedding,
            span_extractor=span_extractor,
            **extras
        )
import torch
from allennlp.modules.conditional_random_field import ConditionalRandomField
from allennlp.nn.util import logsumexp
from overrides import overrides


class SmoothCRF(ConditionalRandomField):
    """Conditional random field that also accepts soft (probabilistic) tag targets.

    With integer `tags` of shape [batch, token] this behaves exactly like the
    parent ConditionalRandomField. With float `tags` of shape [batch, token, tag]
    holding per-token tag distributions, the numerator of the log-likelihood
    marginalizes over tag sequences weighted by those distributions (i.e. label
    smoothing for CRFs).
    """

    @overrides
    def forward(self, inputs: torch.Tensor, tags: torch.Tensor, mask: torch.Tensor = None):
        """
        :param inputs: Emission logits. Shape [batch, token, tag]
        :param tags: Shape [batch, token] (hard labels) or [batch, token, tag] (soft labels)
        :param mask: Shape [batch, token]; True for real (non-padding) tokens
        :return: Sum over the batch of sequence log-likelihoods (scalar tensor)
        """
        if mask is None:
            mask = tags.new_ones(tags.shape, dtype=torch.bool)
        mask = mask.to(dtype=torch.bool)
        if tags.dim() == 2:
            # Hard labels: defer to the standard CRF log-likelihood.
            return super(SmoothCRF, self).forward(inputs, tags, mask)

        # Soft-label mode: log p = log(numerator) - log(partition).
        log_denominator = self._input_likelihood(inputs, mask)
        log_numerator = self._smooth_joint_likelihood(inputs, tags, mask)

        return torch.sum(log_numerator - log_denominator)

    def _smooth_joint_likelihood(
        self, logits: torch.Tensor, soft_tags: torch.Tensor, mask: torch.Tensor
    ) -> torch.Tensor:
        """Numerator of the smoothed likelihood.

        Runs the same forward algorithm as ConditionalRandomField._input_likelihood,
        but every step additionally adds log(soft_tags) so each tag path is weighted
        by the product of its per-token soft probabilities.

        :param logits: Shape [token-major after transpose below]; passed in as [batch, token, tag]
        :param soft_tags: Per-token tag distributions, shape [batch, token, tag]
        :param mask: Shape [batch, token]
        :return: Shape [batch] log-numerator per sequence
        """
        batch_size, sequence_length, num_tags = logits.size()

        # Clamp probabilities away from zero so the log below stays finite.
        epsilon = 1e-30
        soft_tags = soft_tags.clone()
        soft_tags[soft_tags < epsilon] = epsilon

        # Transpose batch size and sequence dimensions
        mask = mask.transpose(0, 1).contiguous()
        logits = logits.transpose(0, 1).contiguous()
        soft_tags = soft_tags.transpose(0, 1).contiguous()

        # Initial alpha is the (batch_size, num_tags) tensor of likelihoods combining the
        # transitions to the initial states, the logits for the first timestep, and the
        # log of each initial tag's soft weight.
        if self.include_start_end_transitions:
            alpha = self.start_transitions.view(1, num_tags) + logits[0] + soft_tags[0].log()
        else:
            # BUG FIX: this branch previously computed `logits[0] * soft_tags[0]`,
            # multiplying log-space emission scores by raw probabilities. All alpha
            # terms must live in log space, matching the branch above minus the
            # start transitions.
            alpha = logits[0] + soft_tags[0].log()

        # For each i we compute logits for the transitions from timestep i-1 to timestep i.
        # We do so in a (batch_size, num_tags, num_tags) tensor where the axes are
        # (instance, current_tag, next_tag)
        for i in range(1, sequence_length):
            # The emit scores are for time i ("next_tag") so we broadcast along the current_tag axis.
            emit_scores = logits[i].view(batch_size, 1, num_tags)
            # Transition scores are (current_tag, next_tag) so we broadcast along the instance axis.
            transition_scores = self.transitions.view(1, num_tags, num_tags)
            # Alpha is for the current_tag, so we broadcast along the next_tag axis.
            broadcast_alpha = alpha.view(batch_size, num_tags, 1)

            # Add all the scores together (including the soft weight of the next tag)
            # and logsumexp over the current_tag axis.
            inner = broadcast_alpha + emit_scores + transition_scores + soft_tags[i].log().unsqueeze(1)

            # In valid positions (mask == True) we want to take the logsumexp over the current_tag dimension
            # of `inner`. Otherwise (mask == False) we want to retain the previous alpha.
            alpha = logsumexp(inner, 1) * mask[i].view(batch_size, 1) + alpha * (
                ~mask[i]
            ).view(batch_size, 1)

        # Every sequence needs to end with a transition to the stop_tag.
        if self.include_start_end_transitions:
            stops = alpha + self.end_transitions.view(1, num_tags)
        else:
            stops = alpha

        # Finally we log_sum_exp along the num_tags dim, result is (batch_size,)
        return logsumexp(stops)
Binary files /dev/null and b/spanfinder/sftp/modules/span_extractor/__pycache__/combo.cpython-39.pyc differ diff --git a/spanfinder/sftp/modules/span_extractor/combo.py b/spanfinder/sftp/modules/span_extractor/combo.py new file mode 100644 index 0000000000000000000000000000000000000000..8d1f08d608fad0549a94f1b60e5df40b0536eff6 --- /dev/null +++ b/spanfinder/sftp/modules/span_extractor/combo.py @@ -0,0 +1,36 @@ +from typing import * + +import torch +from allennlp.modules.span_extractors import SpanExtractor + + +@SpanExtractor.register('combo') +class ComboSpanExtractor(SpanExtractor): + def __init__(self, input_dim: int, sub_extractors: List[SpanExtractor]): + super().__init__() + self.sub_extractors = sub_extractors + for i, sub in enumerate(sub_extractors): + self.add_module(f'SpanExtractor-{i+1}', sub) + self.input_dim = input_dim + + def get_input_dim(self) -> int: + return self.input_dim + + def get_output_dim(self) -> int: + return sum([sub.get_output_dim() for sub in self.sub_extractors]) + + def forward( + self, + sequence_tensor: torch.FloatTensor, + span_indices: torch.LongTensor, + sequence_mask: torch.BoolTensor = None, + span_indices_mask: torch.BoolTensor = None, + ): + outputs = [ + sub( + sequence_tensor=sequence_tensor, + span_indices=span_indices, + span_indices_mask=span_indices_mask + ) for sub in self.sub_extractors + ] + return torch.cat(outputs, dim=2) diff --git a/spanfinder/sftp/modules/span_finder/__init__.py b/spanfinder/sftp/modules/span_finder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a0988e4afe3439127a1b2df815f5088c3473193c --- /dev/null +++ b/spanfinder/sftp/modules/span_finder/__init__.py @@ -0,0 +1,2 @@ +from .bio_span_finder import BIOSpanFinder +from .span_finder import SpanFinder diff --git a/spanfinder/sftp/modules/span_finder/__pycache__/__init__.cpython-38.pyc b/spanfinder/sftp/modules/span_finder/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..3f206b1b9423a58aed87731a746cdad68d8340ed Binary files /dev/null and b/spanfinder/sftp/modules/span_finder/__pycache__/__init__.cpython-38.pyc differ diff --git a/spanfinder/sftp/modules/span_finder/__pycache__/__init__.cpython-39.pyc b/spanfinder/sftp/modules/span_finder/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ecaf981b51f4cc8e8780aa2905f675304efdf2bc Binary files /dev/null and b/spanfinder/sftp/modules/span_finder/__pycache__/__init__.cpython-39.pyc differ diff --git a/spanfinder/sftp/modules/span_finder/__pycache__/bio_span_finder.cpython-38.pyc b/spanfinder/sftp/modules/span_finder/__pycache__/bio_span_finder.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..70d31fdbc48574bfd9add627893634a73f84710d Binary files /dev/null and b/spanfinder/sftp/modules/span_finder/__pycache__/bio_span_finder.cpython-38.pyc differ diff --git a/spanfinder/sftp/modules/span_finder/__pycache__/bio_span_finder.cpython-39.pyc b/spanfinder/sftp/modules/span_finder/__pycache__/bio_span_finder.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8cc479231079daefa573a6a668b425a5e33aec3a Binary files /dev/null and b/spanfinder/sftp/modules/span_finder/__pycache__/bio_span_finder.cpython-39.pyc differ diff --git a/spanfinder/sftp/modules/span_finder/__pycache__/span_finder.cpython-38.pyc b/spanfinder/sftp/modules/span_finder/__pycache__/span_finder.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dae3e7a47368980e60837da013e7c4af7af96aee Binary files /dev/null and b/spanfinder/sftp/modules/span_finder/__pycache__/span_finder.cpython-38.pyc differ diff --git a/spanfinder/sftp/modules/span_finder/__pycache__/span_finder.cpython-39.pyc b/spanfinder/sftp/modules/span_finder/__pycache__/span_finder.cpython-39.pyc new file mode 100644 index 
from typing import *

import torch
from allennlp.data.dataset_readers.dataset_utils.span_utils import bio_tags_to_spans
from allennlp.modules.seq2seq_encoders import Seq2SeqEncoder
from allennlp.modules.span_extractors import SpanExtractor
from allennlp.training.metrics import FBetaMeasure

from ..smooth_crf import SmoothCRF
from .span_finder import SpanFinder
from ...utils import num2mask, mask2idx, BIO


@SpanFinder.register("bio")
class BIOSpanFinder(SpanFinder):
    """
    Train BIO representations for span finding.

    For each parent span, a sequence encoder reads [parent span vec; token vec;
    parent label vec] per token and a 3-way (B/I/O) CRF tags the tokens; tagged
    segments become the predicted child spans.
    """
    def __init__(
        self,
        bio_encoder: Seq2SeqEncoder,
        label_emb: torch.nn.Embedding,
        no_label: bool = True,
    ):
        """
        :param bio_encoder: Seq2seq encoder over per-token features for one parent.
        :param label_emb: Embedding for span labels (shared with the rest of the model).
        :param no_label: If True, label features are zeroed out (see SpanFinder).
        """
        super().__init__(no_label)
        self.bio_encoder = bio_encoder
        self.label_emb = label_emb

        # 3-way tag classifier (B, I, O) on top of the encoder output.
        self.classifier = torch.nn.Linear(bio_encoder.get_output_dim(), 3)
        self.crf = SmoothCRF(3)

        # Micro F1 over B and I tags only; O is treated as background.
        self.fb_measure = FBetaMeasure(1., 'micro', [BIO.index('B'), BIO.index('I')])

    def forward(
        self,
        token_vec: torch.Tensor,
        token_mask: torch.Tensor,
        span_vec: torch.Tensor,
        span_mask: Optional[torch.Tensor] = None,  # Do not need to provide
        span_labels: Optional[torch.Tensor] = None,  # Do not need to provide
        parent_indices: Optional[torch.Tensor] = None,  # Do not need to provide
        parent_mask: Optional[torch.Tensor] = None,
        bio_seqs: Optional[torch.Tensor] = None,
        prediction: bool = False,
        **extra
    ) -> Dict[str, torch.Tensor]:
        """
        See doc of SpanFinder.
        Possible extra variables:
            smoothing_factor
        :return:
            - loss (training mode)
            - prediction (prediction mode): BIO tag indices, shape [batch, parent, token]
        """
        ret = dict()
        # Soft labels are float distributions over tags; hard labels are int64 indices.
        is_soft = span_labels.dtype != torch.int64

        distinct_parent_indices, num_parents = mask2idx(parent_mask)
        n_batch, n_parent = distinct_parent_indices.shape
        n_token = token_vec.shape[1]
        # Shape [batch, parent, token_dim]
        parent_span_features = span_vec.gather(
            1, distinct_parent_indices.unsqueeze(2).expand(-1, -1, span_vec.shape[2])
        )
        # Soft labels: weighted mix of the embedding rows; hard labels: plain lookup.
        label_features = span_labels @ self.label_emb.weight if is_soft else self.label_emb(span_labels)
        if self._no_label:
            # NOTE(review): zero_() is in-place on the (possibly freshly computed)
            # label features so labels carry no signal when no_label is set.
            label_features = label_features.zero_()
        # Shape [batch, span, label_dim]
        parent_label_features = label_features.gather(
            1, distinct_parent_indices.unsqueeze(2).expand(-1, -1, label_features.shape[2])
        )
        # Shape [batch, parent, token, token_dim*2]
        encoder_inputs = torch.cat([
            parent_span_features.unsqueeze(2).expand(-1, -1, n_token, -1),
            token_vec.unsqueeze(1).expand(-1, n_parent, -1, -1),
            parent_label_features.unsqueeze(2).expand(-1, -1, n_token, -1),
        ], dim=3)
        # Fold (batch, parent) into one axis so the encoder sees plain sequences.
        encoder_inputs = encoder_inputs.reshape(n_batch * n_parent, n_token, -1)

        # Shape [batch, parent]. Considers batches may have fewer seqs.
        seq_mask = num2mask(num_parents)
        # Shape [batch, parent, token]. Also considers batches may have fewer tokens.
        token_mask = seq_mask.unsqueeze(2).expand(-1, -1, n_token) & token_mask.unsqueeze(1).expand(-1, n_parent, -1)

        class_in = self.bio_encoder(encoder_inputs, token_mask.flatten(0, 1))
        class_out = self.classifier(class_in).reshape(n_batch, n_parent, n_token, 3)

        if not prediction:
            # For training
            # We use `seq_mask` here because seq with length 0 is not acceptable.
            ret['loss'] = -self.crf(class_out[seq_mask], bio_seqs[seq_mask], token_mask[seq_mask])
            # bio_seqs may be soft; argmax over the tag axis recovers hard tags for the metric.
            self.fb_measure(class_out[seq_mask], bio_seqs[seq_mask].max(2).indices, token_mask[seq_mask])
        else:
            # For prediction
            features_for_decode = class_out.clone().detach()
            decoded = self.crf.viterbi_tags(features_for_decode.flatten(0, 1), token_mask.flatten(0, 1))
            # Pad each decoded path with O up to n_token so paths stack into one tensor.
            pred_tag = torch.tensor(
                [path + [BIO.index('O')] * (n_token - len(path)) for path, _ in decoded]
            )
            pred_tag = pred_tag.reshape(n_batch, n_parent, n_token)
            ret['prediction'] = pred_tag

        return ret

    @staticmethod
    def bio2boundary(seqs) -> Tuple[torch.Tensor, torch.Tensor]:
        """Convert (possibly nested) BIO tag sequences to span boundary tensors.

        :param seqs: Nested lists/tensors of BIO tag indices; arbitrary nesting depth.
        :return: (boundary, lens) where boundary has shape [..., max_span, 2]
            (inclusive start/end indices, zero-padded) and lens counts spans per sequence.
        """
        def recursive_construct_spans(seqs_):
            """
            Helper function for bio2boundary
            Recursively convert seqs of integers to boundary indices.
            Return boundary indices and corresponding lens
            """
            if isinstance(seqs_, torch.Tensor):
                if seqs_.device.type == 'cuda':
                    seqs_ = seqs_.to(device='cpu')
                seqs_ = seqs_.tolist()
            if isinstance(seqs_[0], int):
                # Base case: one flat tag sequence -> decode with allennlp's helper.
                seqs_ = [BIO[i] for i in seqs_]
                span_boundary_list = bio_tags_to_spans(seqs_)
                return torch.tensor([item[1] for item in span_boundary_list]), len(span_boundary_list)
            span_boundary = list()
            lens_ = list()
            for seq in seqs_:
                one_bou, one_len = recursive_construct_spans(seq)
                span_boundary.append(one_bou)
                lens_.append(one_len)
            if isinstance(lens_[0], int):
                lens_ = torch.tensor(lens_)
            else:
                lens_ = torch.stack(lens_)
            return span_boundary, lens_

        boundary_list, lens = recursive_construct_spans(seqs)
        max_span = int(lens.max())
        # Zero-padded destination tensor; ragged lists are copied in below.
        boundary = torch.zeros((*lens.shape, max_span, 2), dtype=torch.long)

        def recursive_copy(list_var, tensor_var):
            # Copy ragged nested lists of boundary tensors into the padded tensor in-place.
            if len(list_var) == 0:
                return
            if isinstance(list_var, torch.Tensor):
                tensor_var[:len(list_var)] = list_var
                return
            assert len(list_var) == len(tensor_var)
            for list_var_, tensor_var_ in zip(list_var, tensor_var):
                recursive_copy(list_var_, tensor_var_)

        recursive_copy(boundary_list, boundary)

        return boundary, lens

    def inference_forward_handler(
        self,
        token_vec: torch.Tensor,
        token_mask: torch.Tensor,
        span_extractor: SpanExtractor,
        **auxiliaries,
    ) -> Callable[[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], None]:
        """
        Refer to the doc of the SpanFinder for definition of this function.
        """

        def handler(
            span_boundary: torch.Tensor,
            span_labels: torch.Tensor,
            parent_mask: torch.Tensor,
            parent_indices: torch.Tensor,
            cursor: torch.tensor,
        ):
            """
            Refer to the doc of the SpanFinder for definition of this function.
            All updates to the argument tensors happen in-place.
            """
            max_decoding_span = span_boundary.shape[1]
            # Shape [batch, span, token_dim]
            span_vec = span_extractor(token_vec, span_boundary)
            # Shape [batch, parent]
            parent_indices_at_span, _ = mask2idx(parent_mask)
            pred_bio = self(
                token_vec, token_mask, span_vec, None, span_labels, None, parent_mask, prediction=True
            )['prediction']
            # Shape [batch, parent, span, 2]; Shape [batch, parent]
            pred_boundary, pred_num = self.bio2boundary(pred_bio)
            if pred_boundary.device != span_boundary.device:
                pred_boundary = pred_boundary.to(device=span_boundary.device)
                pred_num = pred_num.to(device=span_boundary.device)
            # Shape [batch, parent, span]
            pred_mask = num2mask(pred_num)

            # Parent Loop: append every predicted child span (per parent, per step)
            # to the shared buffers, advancing `cursor` and dropping spans once the
            # decoding budget (max_decoding_span) is exhausted.
            for pred_boundary_parent, pred_mask_parent, parent_indices_parent \
                    in zip(pred_boundary.unbind(1), pred_mask.unbind(1), parent_indices_at_span.unbind(1)):
                for pred_boundary_step, step_mask in zip(pred_boundary_parent.unbind(1), pred_mask_parent.unbind(1)):
                    step_mask &= cursor < max_decoding_span
                    parent_indices[step_mask] = parent_indices[step_mask].scatter(
                        1,
                        cursor[step_mask].unsqueeze(1),
                        parent_indices_parent[step_mask].unsqueeze(1)
                    )
                    span_boundary[step_mask] = span_boundary[step_mask].scatter(
                        1,
                        cursor[step_mask].reshape(-1, 1, 1).expand(-1, -1, 2),
                        pred_boundary_step[step_mask].unsqueeze(1)
                    )
                    cursor[step_mask] += 1

        return handler

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return span-finding P/R/F (percentages). Full P/R only on reset to
        avoid the cost of recomputing precision/recall every logging step."""
        score = self.fb_measure.get_metric(reset)
        if reset:
            return {
                'finder_p': score['precision'] * 100,
                'finder_r': score['recall'] * 100,
                'finder_f': score['fscore'] * 100,
            }
        else:
            return {'finder_f': score['fscore'] * 100}
from abc import ABC, abstractmethod
from typing import *

import torch
from allennlp.common import Registrable
from allennlp.modules.span_extractors import SpanExtractor


class SpanFinder(Registrable, ABC, torch.nn.Module):
    """
    Model the probability p(child_span | parent_span [, parent_label])
    It's optional to model parent_label, since in some cases we may want the parameters to be shared across
    different tasks, where we may have similar span semantics but different label space.
    """
    def __init__(
        self,
        no_label: bool = True,
    ):
        """
        :param no_label: If True, will not use input labels as features and use all 0 vector instead.
        """
        super().__init__()
        self._no_label = no_label

    @abstractmethod
    def forward(
        self,
        token_vec: torch.Tensor,
        token_mask: torch.Tensor,
        span_vec: torch.Tensor,
        span_mask: Optional[torch.Tensor] = None,  # Do not need to provide
        span_labels: Optional[torch.Tensor] = None,  # Do not need to provide
        parent_indices: Optional[torch.Tensor] = None,  # Do not need to provide
        parent_mask: Optional[torch.Tensor] = None,
        bio_seqs: Optional[torch.Tensor] = None,
        prediction: bool = False,
        **extra
    ) -> Dict[str, torch.Tensor]:
        """
        Return training loss and predictions.
        :param token_vec: Vector representation of tokens. Shape [batch, token, token_dim]
        :param token_mask: True for non-padding tokens.
        :param span_vec: Vector representation of spans. Shape [batch, span, token_dim]
        :param span_mask: True for non-padding spans. Shape [batch, span]
        :param span_labels: The labels of spans. Shape [batch, span]
        :param parent_indices: Parent indices of spans. Shape [batch, span]
        :param parent_mask: True for parent spans. Shape [batch, span]
        :param prediction: If True, no loss will be return & no metrics will be updated.
        :param bio_seqs: BIO sequences. Shape [batch, parent, token, 3]
        :return:
            loss: Training loss
            prediction: Shape [batch, span]. True for positive predictions.
        """
        raise NotImplementedError

    @abstractmethod
    def inference_forward_handler(
        self,
        token_vec: torch.Tensor,
        token_mask: torch.Tensor,
        span_extractor: SpanExtractor,
        **auxiliaries,
    ) -> Callable[[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], None]:
        """
        Pre-process some information and return a callable module for p(child_span | parent_span [,parent_label])
        :param token_vec: Vector representation of tokens. Shape [batch, token, token_dim]
        :param token_mask: True for non-padding tokens.
        :param span_extractor: The same module in model.
        :param auxiliaries: Environment variables. You can pass extra environment variables
            since the extras will be ignored.
        :return:
            A callable function in a closure.
            The arguments for the callable object are:
                - span_boundary: Shape [batch, span, 2]
                - span_labels: Shape [batch, span]
                - parent_mask: Shape [batch, span]
                - parent_indices: Shape [batch, span]
                - cursor: Shape [batch]
            No return values. Everything should be done inplace.
            Note the span indexing space has different meaning from training process. We don't have gold span list,
            so span here refers to the predicted spans.
        """
        raise NotImplementedError

    @abstractmethod
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        # Subclasses report their own span-finding metrics (e.g. P/R/F).
        raise NotImplementedError
from typing import *

import torch
from torch.nn import CrossEntropyLoss, KLDivLoss, LogSoftmax

from .span_typing import SpanTyping


@SpanTyping.register('mlp')
class MLPSpanTyping(SpanTyping):
    """
    An MLP implementation for Span Typing.
    """
    def __init__(
        self,
        input_dim: int,
        hidden_dims: List[int],
        label_emb: torch.nn.Embedding,
        n_category: int,
        label_to_ignore: Optional[List[int]] = None
    ):
        """
        :param input_dim: dim(parent_span) + dim(child_span) + dim(label_dim)
        :param hidden_dims: The dim of hidden layers of MLP.
        :param n_category: #labels
        :param label_emb: Embeds labels to vectors.
        :param label_to_ignore: Label indices excluded from loss/metrics (e.g. padding/OOV).
        """
        super().__init__(label_emb.num_embeddings, label_to_ignore, )
        # Layers are kept in a plain list and registered via add_module with the
        # names `MLP-0`, `MLP-1`, ... — do not switch to ModuleList: that would
        # rename state_dict keys and break loading of existing checkpoints.
        self.MLPs: List[torch.nn.Linear] = list()
        for i_mlp, output_dim in enumerate(hidden_dims + [n_category]):
            mlp = torch.nn.Linear(input_dim, output_dim, bias=True)
            self.MLPs.append(mlp)
            self.add_module(f'MLP-{i_mlp}', mlp)
            input_dim = output_dim

        # Embeds labels as features.
        self.label_emb = label_emb

    def forward(
        self,
        span_vec: torch.Tensor,
        parent_at_span: torch.Tensor,
        span_labels: Optional[torch.Tensor],
        prediction_only: bool = False,
    ) -> Dict[str, torch.Tensor]:
        """
        Inputs: All features for typing a child span.
        Process: Update the metric.
        Output: The loss of typing and predictions.
        :param span_vec: Span features. Shape [batch, span, token_dim]
        :param parent_at_span: Index of each span's parent. Shape [batch, span]
        :param span_labels: Hard label indices [batch, span] or soft distributions
            [batch, span, label].
        :param prediction_only: If True, skip loss and metric update.
        :return:
            loss: Loss for label prediction.
            prediction: Predicted labels.
            label_confidence / distribution: Confidence and full softmax distribution.
        """
        # Soft labels are float distributions; hard labels are int64 indices.
        is_soft = span_labels.dtype != torch.int64
        # Shape [batch, span, label_dim]
        label_vec = span_labels @ self.label_emb.weight if is_soft else self.label_emb(span_labels)
        n_batch, n_span, _ = label_vec.shape
        n_label, _ = self.ontology.shape
        # Shape [batch, span, label_dim]
        parent_label_features = label_vec.gather(1, parent_at_span.unsqueeze(2).expand_as(label_vec))
        # Shape [batch, span, token_dim]
        parent_span_features = span_vec.gather(1, parent_at_span.unsqueeze(2).expand_as(span_vec))
        # Shape [batch, span, token_dim]
        child_span_features = span_vec

        features = torch.cat([parent_label_features, parent_span_features, child_span_features], dim=2)
        # Shape [batch, span, label]
        for mlp in self.MLPs[:-1]:
            features = torch.relu(mlp(features))
        logits = self.MLPs[-1](features)

        # Keep the raw logits for the loss; mask a copy for prediction so the
        # ontology constraint does not affect gradient flow.
        logits_for_prediction = logits.clone()

        if not is_soft:
            # Shape [batch, span]
            parent_labels = span_labels.gather(1, parent_at_span)
            # Ontology row of each parent label says which child labels are legal;
            # forbid the rest by setting their logits to -inf.
            onto_mask = self.ontology.unsqueeze(0).expand(n_batch, -1, -1).gather(
                1, parent_labels.unsqueeze(2).expand(-1, -1, n_label)
            )
            logits_for_prediction[~onto_mask] = float('-inf')

        label_dist = torch.softmax(logits_for_prediction, 2)
        label_confidence, predictions = label_dist.max(2)
        ret = {'prediction': predictions, 'label_confidence': label_confidence, 'distribution': label_dist}
        if prediction_only:
            return ret

        # Clone before in-place edits below so the caller's tensor is untouched.
        span_labels = span_labels.clone()

        if is_soft:
            # Metric mask: ignore spans whose soft distribution sums to ~0 (padding).
            self.acc_metric(logits_for_prediction, span_labels.max(2)[1], ~span_labels.sum(2).isclose(torch.tensor(0.)))
            ret['loss'] = KLDivLoss(reduction='sum')(LogSoftmax(dim=2)(logits), span_labels)
        else:
            # -100 is CrossEntropyLoss's default ignore_index.
            for label_idx in self.label_to_ignore:
                span_labels[span_labels == label_idx] = -100
            self.acc_metric(logits_for_prediction, span_labels, span_labels != -100)
            ret['loss'] = CrossEntropyLoss(reduction='sum')(logits.flatten(0, 1), span_labels.flatten())

        return ret
from abc import ABC
from typing import *

import torch
from allennlp.common import Registrable
from allennlp.data.vocabulary import DEFAULT_OOV_TOKEN, Vocabulary
from allennlp.training.metrics import CategoricalAccuracy


class SpanTyping(Registrable, torch.nn.Module, ABC):
    """
    Models the probability p(child_label | child_span, parent_span, parent_label).
    """
    def __init__(
        self,
        n_label: int,
        label_to_ignore: Optional[List[int]] = None,
    ):
        """
        :param n_label: Size of the label vocabulary.
        :param label_to_ignore: Label indexes in this list will be ignored.
            Usually this should include NULL, PADDING and UNKNOWN.
        """
        super().__init__()
        self.label_to_ignore = label_to_ignore or list()
        self.acc_metric = CategoricalAccuracy()
        # Parent->child legality matrix; all-True until load_ontology restricts it.
        # Registered as a buffer so it follows the module across devices and is
        # saved in state_dict without being a trainable parameter.
        self.onto = torch.ones([n_label, n_label], dtype=torch.bool)
        self.register_buffer('ontology', self.onto)

    def load_ontology(self, path: str, vocab: Vocabulary):
        """Restrict the ontology from a TSV file: each line is `parent<TAB>child...`
        (label strings); unknown parents are skipped, unknown children filtered out."""
        unk_id = vocab.get_token_index(DEFAULT_OOV_TOKEN, 'span_label')
        for line in open(path).readlines():
            entities = [vocab.get_token_index(ent, 'span_label') for ent in line.replace('\n', '').split('\t')]
            parent, children = entities[0], entities[1:]
            if parent == unk_id:
                continue
            # Reset the parent's row, then allow only the listed children.
            self.onto[parent, :] = False
            children = list(filter(lambda x: x != unk_id, children))
            self.onto[parent, children] = True
        # Re-register so the buffer reflects the updated matrix.
        self.register_buffer('ontology', self.onto)

    def forward(
        self,
        span_vec: torch.Tensor,
        parent_at_span: torch.Tensor,
        span_labels: Optional[torch.Tensor],
        prediction_only: bool = False,
    ) -> Dict[str, torch.Tensor]:
        """
        Inputs: All features for typing a child span.
        Output: The loss of typing and predictions.
        :param span_vec: Shape [batch, span, token_dim]
        :param parent_at_span: Shape [batch, span]
        :param span_labels: Shape [batch, span]
        :param prediction_only: If True, no loss returned & metric will not be updated
        :return:
            loss: Loss for label prediction. (absent if prediction_only = True)
            prediction: Predicted labels.
        """
        raise NotImplementedError

    def get_metric(self, reset):
        # Typing accuracy as a percentage.
        return {
            "typing_acc": self.acc_metric.get_metric(reset) * 100
        }
mode 100644 index 0000000000000000000000000000000000000000..e84fb18f87fd57d51e1f800351ad065cc32b87e0 --- /dev/null +++ b/spanfinder/sftp/predictor/span_predictor.orig.py @@ -0,0 +1,362 @@ +import os +from time import time +from typing import * +import json + +import numpy as np +import torch +from allennlp.common.util import JsonDict, sanitize +from allennlp.data import DatasetReader, Instance +from allennlp.data.data_loaders import SimpleDataLoader +from allennlp.data.samplers import MaxTokensBatchSampler +from allennlp.data.tokenizers import SpacyTokenizer +from allennlp.models import Model +from allennlp.nn import util as nn_util +from allennlp.predictors import Predictor +from concrete import ( + MentionArgument, SituationMentionSet, SituationMention, TokenRefSequence, + EntityMention, EntityMentionSet, Entity, EntitySet, AnnotationMetadata, Communication +) +from concrete.util import CommunicationReader, AnalyticUUIDGeneratorFactory, CommunicationWriterZip +from concrete.validate import validate_communication + +from ..data_reader import concrete_doc, concrete_doc_tokenized +from ..utils import Span, re_index_span, VIRTUAL_ROOT + + +class PredictionReturn(NamedTuple): + span: Union[Span, dict, Communication] + sentence: List[str] + meta: Dict[str, Any] + + +class ForceDecodingReturn(NamedTuple): + span: np.ndarray + label: List[str] + distribution: np.ndarray + + +@Predictor.register('span') +class SpanPredictor(Predictor): + @staticmethod + def format_convert( + sentence: Union[List[str], List[List[str]]], + prediction: Union[Span, List[Span]], + output_format: str + ): + if output_format == 'span': + return prediction + elif output_format == 'json': + if isinstance(prediction, list): + return [SpanPredictor.format_convert(sent, pred, 'json') for sent, pred in zip(sentence, prediction)] + return prediction.to_json() + elif output_format == 'concrete': + if isinstance(prediction, Span): + sentence, prediction = [sentence], [prediction] + return 
concrete_doc_tokenized(sentence, prediction) + + def predict_concrete( + self, + concrete_path: str, + output_path: Optional[str] = None, + max_tokens: int = 2048, + ontology_mapping: Optional[Dict[str, str]] = None, + ): + os.makedirs(os.path.dirname(output_path), exist_ok=True) + writer = CommunicationWriterZip(output_path) + + for comm, fn in CommunicationReader(concrete_path): + assert len(comm.sectionList) == 1 + concrete_sentences = comm.sectionList[0].sentenceList + json_sentences = list() + for con_sent in concrete_sentences: + json_sentences.append( + [t.text for t in con_sent.tokenization.tokenList.tokenList] + ) + predictions = self.predict_batch_sentences(json_sentences, max_tokens, ontology_mapping=ontology_mapping) + + # Merge predictions into concrete + aug = AnalyticUUIDGeneratorFactory(comm).create() + situation_mention_set = SituationMentionSet(next(aug), AnnotationMetadata('Span Finder', time()), list()) + comm.situationMentionSetList = [situation_mention_set] + situation_mention_set.mentionList = sm_list = list() + entity_mention_set = EntityMentionSet(next(aug), AnnotationMetadata('Span Finder', time()), list()) + comm.entityMentionSetList = [entity_mention_set] + entity_mention_set.mentionList = em_list = list() + entity_set = EntitySet( + next(aug), AnnotationMetadata('Span Finder', time()), list(), None, entity_mention_set.uuid + ) + comm.entitySetList = [entity_set] + + em_dict = dict() + for con_sent, pred in zip(concrete_sentences, predictions): + for event in pred.span: + def raw_text_span(start_idx, end_idx, **_): + si_char = con_sent.tokenization.tokenList.tokenList[start_idx].textSpan.start + ei_char = con_sent.tokenization.tokenList.tokenList[end_idx].textSpan.ending + return comm.text[si_char:ei_char] + sm = SituationMention( + next(aug), + text=raw_text_span(event.start_idx, event.end_idx), + situationKind=event.label, + situationType='EVENT', + confidence=event.confidence, + argumentList=list(), + tokens=TokenRefSequence( + 
tokenIndexList=list(range(event.start_idx, event.end_idx+1)), + tokenizationId=con_sent.tokenization.uuid + ) + ) + + for arg in event: + em = em_dict.get((arg.start_idx, arg.end_idx + 1)) + if em is None: + em = EntityMention( + next(aug), + tokens=TokenRefSequence( + tokenIndexList=list(range(arg.start_idx, arg.end_idx+1)), + tokenizationId=con_sent.tokenization.uuid, + ), + text=raw_text_span(arg.start_idx, arg.end_idx) + ) + em_list.append(em) + entity_set.entityList.append(Entity(next(aug), id=em.text, mentionIdList=[em.uuid])) + em_dict[(arg.start_idx, arg.end_idx+1)] = em + sm.argumentList.append(MentionArgument( + role=arg.label, + entityMentionId=em.uuid, + confidence=arg.confidence + )) + sm_list.append(sm) + validate_communication(comm) + writer.write(comm, fn) + writer.close() + + def predict_sentence( + self, + sentence: Union[str, List[str]], + ontology_mapping: Optional[Dict[str, str]] = None, + output_format: str = 'span', + ) -> PredictionReturn: + """ + Predict spans on a single sentence (no batch). If not tokenized, will tokenize it with SpacyTokenizer. + :param sentence: If tokenized, should be a list of tokens in string. If not, should be a string. + :param ontology_mapping: + :param output_format: span, json or concrete. + """ + prediction = self.predict_json(self._prepare_sentence(sentence)) + prediction['prediction'] = self.format_convert( + prediction['sentence'], + Span.from_json(prediction['prediction']).map_ontology(ontology_mapping), + output_format + ) + return PredictionReturn(prediction['prediction'], prediction['sentence'], prediction.get('meta', dict())) + + def predict_batch_sentences( + self, + sentences: List[Union[List[str], str]], + max_tokens: int = 512, + ontology_mapping: Optional[Dict[str, str]] = None, + output_format: str = 'span', + ) -> List[PredictionReturn]: + """ + Predict spans on a batch of sentences. If not tokenized, will tokenize it with SpacyTokenizer. + :param sentences: A list of sentences. 
Refer to `predict_sentence`. + :param max_tokens: Maximum tokens in a batch. + :param ontology_mapping: If not None, will try to map the output from one ontology to another. + If the predicted frame is not in the mapping, the prediction will be ignored. + :param output_format: span, json or concrete. + :return: A list of predictions. + """ + sentences = list(map(self._prepare_sentence, sentences)) + for i_sent, sent in enumerate(sentences): + sent['meta'] = {"idx": i_sent} + instances = list(map(self._json_to_instance, sentences)) + outputs = list() + for ins_indices in MaxTokensBatchSampler(max_tokens, ["tokens"], 0.0).get_batch_indices(instances): + batch_ins = list( + SimpleDataLoader([instances[ins_idx] for ins_idx in ins_indices], len(ins_indices), vocab=self.vocab) + )[0] + batch_inputs = nn_util.move_to_device(batch_ins, device=self.cuda_device) + batch_outputs = self._model(**batch_inputs) + for meta, prediction, inputs in zip( + batch_outputs['meta'], batch_outputs['prediction'], batch_outputs['inputs'] + ): + prediction.map_ontology(ontology_mapping) + prediction = self.format_convert(inputs['sentence'], prediction, output_format) + outputs.append(PredictionReturn(prediction, inputs['sentence'], {"input_idx": meta['idx']})) + + outputs.sort(key=lambda x: x.meta['input_idx']) + return outputs + + def predict_instance(self, instance: Instance) -> JsonDict: + outputs = self._model.forward_on_instance(instance) + outputs = sanitize(outputs) + return { + 'prediction': outputs['prediction'], + 'sentence': outputs['inputs']['sentence'], + 'meta': outputs.get('meta', {}) + } + + def __init__( + self, + model: Model, + dataset_reader: DatasetReader, + frozen: bool = True, + ): + super(SpanPredictor, self).__init__(model=model, dataset_reader=dataset_reader, frozen=frozen) + self.spacy_tokenizer = SpacyTokenizer(language='en_core_web_sm') + + def economize( + self, + max_decoding_spans: Optional[int] = None, + max_recursion_depth: Optional[int] = None, + ): + if 
max_decoding_spans: + self._model._max_decoding_spans = max_decoding_spans + if max_recursion_depth: + self._model._max_recursion_depth = max_recursion_depth + + def _json_to_instance(self, json_dict: JsonDict) -> Instance: + return self._dataset_reader.text_to_instance(**json_dict) + + @staticmethod + def to_nested(prediction: List[dict]): + first_layer, idx2children = list(), dict() + for idx, pred in enumerate(prediction): + children = list() + pred['children'] = idx2children[idx+1] = children + if pred['parent'] == 0: + first_layer.append(pred) + else: + idx2children[pred['parent']].append(pred) + del pred['parent'] + return first_layer + + def _prepare_sentence(self, sentence: Union[str, List[str]]) -> Dict[str, List[str]]: + if isinstance(sentence, str): + while ' ' in sentence: + sentence = sentence.replace(' ', ' ') + sentence = sentence.replace(chr(65533), '') + if sentence == '': + sentence = [""] + sentence = list(map(str, self.spacy_tokenizer.tokenize(sentence))) + return {"tokens": sentence} + + @staticmethod + def json_to_concrete( + predictions: List[dict], + ): + sentences = list() + for pred in predictions: + tokenization, event = list(), list() + sent = {'text': ' '.join(pred['inputs']), 'tokenization': tokenization, 'event': event} + sentences.append(sent) + start_idx = 0 + for token in pred['inputs']: + tokenization.append((start_idx, len(token)-1+start_idx)) + start_idx += len(token) + 1 + for pred_event in pred['prediction']: + arg_list = list() + one_event = {'argument': arg_list} + event.append(one_event) + for key in ['start_idx', 'end_idx', 'label']: + one_event[key] = pred_event[key] + for pred_arg in pred_event['children']: + arg_list.append({key: pred_arg[key] for key in ['start_idx', 'end_idx', 'label']}) + + concrete_comm = concrete_doc(sentences) + return concrete_comm + + def force_decode( + self, + sentence: List[str], + parent_span: Tuple[int, int] = (-1, -1), + parent_label: str = VIRTUAL_ROOT, + child_spans: 
Optional[List[Tuple[int, int]]] = None, + ) -> ForceDecodingReturn: + """ + Force decoding. There are 2 modes: + 1. Given parent span and its label, find all it children (direct children, not including other descendents) + and type them. + 2. Given parent span, parent label, and children spans, type all children. + :param sentence: Tokens. + :param parent_span: [start_idx, end_idx], both inclusive. + :param parent_label: Parent label in string. + :param child_spans: Optional. If provided, will turn to mode 2; else mode 1. + :return: + - span: children spans. + - label: most probable labels of children. + - distribution: distribution over children labels. + """ + instance = self._dataset_reader.text_to_instance(self._prepare_sentence(sentence)['tokens']) + model_input = nn_util.move_to_device( + list(SimpleDataLoader([instance], 1, vocab=self.vocab))[0], device=self.cuda_device + ) + offsets = instance.fields['raw_inputs'].metadata['offsets'] + + with torch.no_grad(): + tokens = model_input['tokens'] + parent_span = re_index_span(parent_span, offsets) + if parent_span[1] >= self._dataset_reader.max_length: + return ForceDecodingReturn( + np.zeros([0, 2], dtype=np.int), + [], + np.zeros([0, self.vocab.get_vocab_size('span_label')], dtype=np.float64) + ) + if child_spans is not None: + token_vec = self._model.word_embedding(tokens) + child_pieces = [re_index_span(bdr, offsets) for bdr in child_spans] + child_pieces = list(filter(lambda x: x[1] < self._dataset_reader.max_length-1, child_pieces)) + span_tensor = torch.tensor( + [parent_span] + child_pieces, dtype=torch.int64, device=self.device + ).unsqueeze(0) + parent_indices = span_tensor.new_zeros(span_tensor.shape[0:2]) + span_labels = parent_indices.new_full( + parent_indices.shape, self._model.vocab.get_token_index(parent_label, 'span_label') + ) + span_vec = self._model._span_extractor(token_vec, span_tensor) + typing_out = self._model._span_typing(span_vec, parent_indices, span_labels) + distribution = 
typing_out['distribution'][0, 1:].cpu().numpy() + boundary = np.array(child_spans) + else: + parent_label_tensor = torch.tensor( + [self._model.vocab.get_token_index(parent_label, 'span_label')], device=self.device + ) + parent_boundary_tensor = torch.tensor([parent_span], device=self.device) + boundary, _, num_children, distribution = self._model.one_step_prediction( + tokens, parent_boundary_tensor, parent_label_tensor + ) + boundary, distribution = boundary[0].cpu().tolist(), distribution[0].cpu().numpy() + boundary = np.array([re_index_span(bdr, offsets, True) for bdr in boundary]) + + labels = [ + self.vocab.get_token_from_index(label_idx, 'span_label') for label_idx in distribution.argmax(1) + ] + return ForceDecodingReturn(boundary, labels, distribution) + + @property + def vocab(self): + return self._model.vocab + + @property + def device(self): + return self.cuda_device if self.cuda_device > -1 else 'cpu' + + @staticmethod + def read_ontology_mapping(file_path: str): + """ + Read the ontology mapping file. The file format can be read in docs. 
+ """ + if file_path is None: + return None + if file_path.endswith('.json'): + return json.load(open(file_path)) + mapping = dict() + for line in open(file_path).readlines(): + parent_label, original_label, new_label = line.replace('\n', '').split('\t') + if parent_label == '*': + mapping[original_label] = new_label + else: + mapping[(parent_label, original_label)] = new_label + return mapping diff --git a/spanfinder/sftp/predictor/span_predictor.py b/spanfinder/sftp/predictor/span_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..8bd3fd756c0777e9284c53d8ee797f59767a49be --- /dev/null +++ b/spanfinder/sftp/predictor/span_predictor.py @@ -0,0 +1,405 @@ +import os +from time import time +from typing import * +import json + +# # ---GFM add debugger +# import pdb +# # end--- + +import numpy as np +import torch +from allennlp.common.util import JsonDict, sanitize +from allennlp.data import DatasetReader, Instance +from allennlp.data.data_loaders import SimpleDataLoader +from allennlp.data.samplers import MaxTokensBatchSampler +from allennlp.data.tokenizers import SpacyTokenizer +from allennlp.models import Model +from allennlp.nn import util as nn_util +from allennlp.predictors import Predictor +from concrete import ( + MentionArgument, SituationMentionSet, SituationMention, TokenRefSequence, + EntityMention, EntityMentionSet, Entity, EntitySet, AnnotationMetadata, Communication +) +from concrete.util import CommunicationReader, AnalyticUUIDGeneratorFactory, CommunicationWriterZip +from concrete.validate import validate_communication + +from ..data_reader import concrete_doc, concrete_doc_tokenized +from ..utils import Span, re_index_span, VIRTUAL_ROOT + + +class PredictionReturn(NamedTuple): + span: Union[Span, dict, Communication] + sentence: List[str] + meta: Dict[str, Any] + + +class ForceDecodingReturn(NamedTuple): + span: np.ndarray + label: List[str] + distribution: np.ndarray + + +@Predictor.register('span') +class 
SpanPredictor(Predictor): + @staticmethod + def format_convert( + sentence: Union[List[str], List[List[str]]], + prediction: Union[Span, List[Span]], + output_format: str + ): + if output_format == 'span': + return prediction + elif output_format == 'json': + if isinstance(prediction, list): + return [SpanPredictor.format_convert(sent, pred, 'json') for sent, pred in zip(sentence, prediction)] + return prediction.to_json() + elif output_format == 'concrete': + if isinstance(prediction, Span): + sentence, prediction = [sentence], [prediction] + return concrete_doc_tokenized(sentence, prediction) + + def predict_concrete( + self, + concrete_path: str, + output_path: Optional[str] = None, + max_tokens: int = 2048, + ontology_mapping: Optional[Dict[str, str]] = None, + ): + os.makedirs(os.path.dirname(output_path), exist_ok=True) + writer = CommunicationWriterZip(output_path) + + print(concrete_path) + for comm, fn in CommunicationReader(concrete_path): + print(fn) + assert len(comm.sectionList) == 1 + concrete_sentences = comm.sectionList[0].sentenceList + json_sentences = list() + for con_sent in concrete_sentences: + json_sentences.append( + [t.text for t in con_sent.tokenization.tokenList.tokenList] + ) + predictions = self.predict_batch_sentences(json_sentences, max_tokens, ontology_mapping=ontology_mapping) + + # Merge predictions into concrete + aug = AnalyticUUIDGeneratorFactory(comm).create() + situation_mention_set = SituationMentionSet(next(aug), AnnotationMetadata('Span Finder', time()), list()) + comm.situationMentionSetList = [situation_mention_set] + situation_mention_set.mentionList = sm_list = list() + entity_mention_set = EntityMentionSet(next(aug), AnnotationMetadata('Span Finder', time()), list()) + comm.entityMentionSetList = [entity_mention_set] + entity_mention_set.mentionList = em_list = list() + entity_set = EntitySet( + next(aug), AnnotationMetadata('Span Finder', time()), list(), None, entity_mention_set.uuid + ) + comm.entitySetList = 
[entity_set] + + em_dict = dict() + for con_sent, pred in zip(concrete_sentences, predictions): + for event in pred.span: + def raw_text_span(start_idx, end_idx, **_): + si_char = con_sent.tokenization.tokenList.tokenList[start_idx].textSpan.start + ei_char = con_sent.tokenization.tokenList.tokenList[end_idx].textSpan.ending + return comm.text[si_char:ei_char] + + # ---GFM: added this to get around off-by-one errors (unclear why these arise) + event_start_idx = event.start_idx + event_end_idx = event.end_idx + if event_end_idx > len(con_sent.tokenization.tokenList.tokenList) - 1: + print("WARNING: invalid `event_end_idx` passed for sentence, adjusting to final token") + print("\tsentence:", con_sent.tokenization.tokenList) + print("event_end_idx:", event_end_idx) + print("length:", len(con_sent.tokenization.tokenList.tokenList)) + event_end_idx = len(con_sent.tokenization.tokenList.tokenList) - 1 + print("new event_end_idx:", event_end_idx) + print() + # end--- + + sm = SituationMention( + next(aug), + # ---GFM: added this to get around off-by-one errors (unclear why these arise) + text=raw_text_span(event_start_idx, event_end_idx), + # end--- + situationKind=event.label, + situationType='EVENT', + confidence=event.confidence, + argumentList=list(), + tokens=TokenRefSequence( + # ---GFM: added this to get around off-by-one errors (unclear why these arise) + tokenIndexList=list(range(event_start_idx, event_end_idx+1)), + # end--- + tokenizationId=con_sent.tokenization.uuid + ) + ) + + for arg in event: + # ---GFM: added this to get around off-by-one errors (unclear why these arise) + arg_start_idx = arg.start_idx + arg_end_idx = arg.end_idx + if arg_end_idx > len(con_sent.tokenization.tokenList.tokenList) - 1: + print("WARNING: invalid `arg_end_idx` passed for sentence, adjusting to final token") + print("\tsentence:", con_sent.tokenization.tokenList) + print("arg_end_idx:", arg_end_idx) + print("length:", len(con_sent.tokenization.tokenList.tokenList)) + 
arg_end_idx = len(con_sent.tokenization.tokenList.tokenList) - 1 + print("new arg_end_idx:", arg_end_idx) + print() + # end--- + + # ---GFM: replaced all arg.*_idx to arg_*_idx + em = em_dict.get((arg_start_idx, arg_end_idx + 1)) + if em is None: + em = EntityMention( + next(aug), + tokens=TokenRefSequence( + tokenIndexList=list(range(arg_start_idx, arg_end_idx+1)), + tokenizationId=con_sent.tokenization.uuid, + ), + text=raw_text_span(arg_start_idx, arg_end_idx) + ) + em_list.append(em) + entity_set.entityList.append(Entity(next(aug), id=em.text, mentionIdList=[em.uuid])) + em_dict[(arg_start_idx, arg_end_idx+1)] = em + sm.argumentList.append(MentionArgument( + role=arg.label, + entityMentionId=em.uuid, + confidence=arg.confidence + )) + # end--- + sm_list.append(sm) + validate_communication(comm) + writer.write(comm, fn) + writer.close() + + def predict_sentence( + self, + sentence: Union[str, List[str]], + ontology_mapping: Optional[Dict[str, str]] = None, + output_format: str = 'span', + ) -> PredictionReturn: + """ + Predict spans on a single sentence (no batch). If not tokenized, will tokenize it with SpacyTokenizer. + :param sentence: If tokenized, should be a list of tokens in string. If not, should be a string. + :param ontology_mapping: + :param output_format: span, json or concrete. + """ + prediction = self.predict_json(self._prepare_sentence(sentence)) + prediction['prediction'] = self.format_convert( + prediction['sentence'], + Span.from_json(prediction['prediction']).map_ontology(ontology_mapping), + output_format + ) + return PredictionReturn(prediction['prediction'], prediction['sentence'], prediction.get('meta', dict())) + + def predict_batch_sentences( + self, + sentences: List[Union[List[str], str]], + max_tokens: int = 512, + ontology_mapping: Optional[Dict[str, str]] = None, + output_format: str = 'span', + ) -> List[PredictionReturn]: + """ + Predict spans on a batch of sentences. If not tokenized, will tokenize it with SpacyTokenizer. 
+ :param sentences: A list of sentences. Refer to `predict_sentence`. + :param max_tokens: Maximum tokens in a batch. + :param ontology_mapping: If not None, will try to map the output from one ontology to another. + If the predicted frame is not in the mapping, the prediction will be ignored. + :param output_format: span, json or concrete. + :return: A list of predictions. + """ + sentences = list(map(self._prepare_sentence, sentences)) + for i_sent, sent in enumerate(sentences): + sent['meta'] = {"idx": i_sent} + instances = list(map(self._json_to_instance, sentences)) + outputs = list() + for ins_indices in MaxTokensBatchSampler(max_tokens, ["tokens"], 0.0).get_batch_indices(instances): + batch_ins = list( + SimpleDataLoader([instances[ins_idx] for ins_idx in ins_indices], len(ins_indices), vocab=self.vocab) + )[0] + batch_inputs = nn_util.move_to_device(batch_ins, device=self.cuda_device) + batch_outputs = self._model(**batch_inputs) + for meta, prediction, inputs in zip( + batch_outputs['meta'], batch_outputs['prediction'], batch_outputs['inputs'] + ): + prediction.map_ontology(ontology_mapping) + prediction = self.format_convert(inputs['sentence'], prediction, output_format) + outputs.append(PredictionReturn(prediction, inputs['sentence'], {"input_idx": meta['idx']})) + + outputs.sort(key=lambda x: x.meta['input_idx']) + return outputs + + def predict_instance(self, instance: Instance) -> JsonDict: + outputs = self._model.forward_on_instance(instance) + outputs = sanitize(outputs) + return { + 'prediction': outputs['prediction'], + 'sentence': outputs['inputs']['sentence'], + 'meta': outputs.get('meta', {}) + } + + def __init__( + self, + model: Model, + dataset_reader: DatasetReader, + frozen: bool = True, + ): + super(SpanPredictor, self).__init__(model=model, dataset_reader=dataset_reader, frozen=frozen) + self.spacy_tokenizer = SpacyTokenizer(language='en_core_web_sm') + + def economize( + self, + max_decoding_spans: Optional[int] = None, + 
max_recursion_depth: Optional[int] = None, + ): + if max_decoding_spans: + self._model._max_decoding_spans = max_decoding_spans + if max_recursion_depth: + self._model._max_recursion_depth = max_recursion_depth + + def _json_to_instance(self, json_dict: JsonDict) -> Instance: + return self._dataset_reader.text_to_instance(**json_dict) + + @staticmethod + def to_nested(prediction: List[dict]): + first_layer, idx2children = list(), dict() + for idx, pred in enumerate(prediction): + children = list() + pred['children'] = idx2children[idx+1] = children + if pred['parent'] == 0: + first_layer.append(pred) + else: + idx2children[pred['parent']].append(pred) + del pred['parent'] + return first_layer + + def _prepare_sentence(self, sentence: Union[str, List[str]]) -> Dict[str, List[str]]: + if isinstance(sentence, str): + while ' ' in sentence: + sentence = sentence.replace(' ', ' ') + sentence = sentence.replace(chr(65533), '') + if sentence == '': + sentence = [""] + sentence = list(map(str, self.spacy_tokenizer.tokenize(sentence))) + return {"tokens": sentence} + + @staticmethod + def json_to_concrete( + predictions: List[dict], + ): + sentences = list() + for pred in predictions: + tokenization, event = list(), list() + sent = {'text': ' '.join(pred['inputs']), 'tokenization': tokenization, 'event': event} + sentences.append(sent) + start_idx = 0 + for token in pred['inputs']: + tokenization.append((start_idx, len(token)-1+start_idx)) + start_idx += len(token) + 1 + for pred_event in pred['prediction']: + arg_list = list() + one_event = {'argument': arg_list} + event.append(one_event) + for key in ['start_idx', 'end_idx', 'label']: + one_event[key] = pred_event[key] + for pred_arg in pred_event['children']: + arg_list.append({key: pred_arg[key] for key in ['start_idx', 'end_idx', 'label']}) + + concrete_comm = concrete_doc(sentences) + return concrete_comm + + def force_decode( + self, + sentence: List[str], + parent_span: Tuple[int, int] = (-1, -1), + parent_label: 
str = VIRTUAL_ROOT, + child_spans: Optional[List[Tuple[int, int]]] = None, + ) -> ForceDecodingReturn: + """ + Force decoding. There are 2 modes: + 1. Given parent span and its label, find all it children (direct children, not including other descendents) + and type them. + 2. Given parent span, parent label, and children spans, type all children. + :param sentence: Tokens. + :param parent_span: [start_idx, end_idx], both inclusive. + :param parent_label: Parent label in string. + :param child_spans: Optional. If provided, will turn to mode 2; else mode 1. + :return: + - span: children spans. + - label: most probable labels of children. + - distribution: distribution over children labels. + """ + instance = self._dataset_reader.text_to_instance(self._prepare_sentence(sentence)['tokens']) + model_input = nn_util.move_to_device( + list(SimpleDataLoader([instance], 1, vocab=self.vocab))[0], device=self.cuda_device + ) + offsets = instance.fields['raw_inputs'].metadata['offsets'] + + with torch.no_grad(): + tokens = model_input['tokens'] + # --- edit GM --- + print(parent_span) + print(offsets) + parent_span = re_index_span(parent_span, offsets) + if parent_span[1] is None or parent_span[1] >= self._dataset_reader.max_length: + return ForceDecodingReturn( + np.zeros([0, 2], dtype=np.int), + [], + np.zeros([0, self.vocab.get_vocab_size('span_label')], dtype=np.float64) + ) + # --- END --- + if child_spans is not None: + token_vec = self._model.word_embedding(tokens) + child_pieces = [re_index_span(bdr, offsets) for bdr in child_spans] + child_pieces = list(filter(lambda x: x[1] < self._dataset_reader.max_length-1, child_pieces)) + span_tensor = torch.tensor( + [parent_span] + child_pieces, dtype=torch.int64, device=self.device + ).unsqueeze(0) + parent_indices = span_tensor.new_zeros(span_tensor.shape[0:2]) + span_labels = parent_indices.new_full( + parent_indices.shape, self._model.vocab.get_token_index(parent_label, 'span_label') + ) + span_vec = 
self._model._span_extractor(token_vec, span_tensor) + typing_out = self._model._span_typing(span_vec, parent_indices, span_labels) + distribution = typing_out['distribution'][0, 1:].cpu().numpy() + boundary = np.array(child_spans) + else: + parent_label_tensor = torch.tensor( + [self._model.vocab.get_token_index(parent_label, 'span_label')], device=self.device + ) + parent_boundary_tensor = torch.tensor([parent_span], device=self.device) + boundary, _, num_children, distribution = self._model.one_step_prediction( + tokens, parent_boundary_tensor, parent_label_tensor + ) + boundary, distribution = boundary[0].cpu().tolist(), distribution[0].cpu().numpy() + boundary = np.array([re_index_span(bdr, offsets, True) for bdr in boundary]) + + labels = [ + self.vocab.get_token_from_index(label_idx, 'span_label') for label_idx in distribution.argmax(1) + ] + return ForceDecodingReturn(boundary, labels, distribution) + + @property + def vocab(self): + return self._model.vocab + + @property + def device(self): + return self.cuda_device if self.cuda_device > -1 else 'cpu' + + @staticmethod + def read_ontology_mapping(file_path: str): + """ + Read the ontology mapping file. The file format can be read in docs. 
+ """ + if file_path is None: + return None + if file_path.endswith('.json'): + return json.load(open(file_path)) + mapping = dict() + for line in open(file_path).readlines(): + parent_label, original_label, new_label = line.replace('\n', '').split('\t') + if parent_label == '*': + mapping[original_label] = new_label + else: + mapping[(parent_label, original_label)] = new_label + return mapping diff --git a/spanfinder/sftp/training/__init__.py b/spanfinder/sftp/training/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/spanfinder/sftp/training/__pycache__/__init__.cpython-39.pyc b/spanfinder/sftp/training/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1e6eb5c5a08b1b63b5db3e02f3ff571988b7d71 Binary files /dev/null and b/spanfinder/sftp/training/__pycache__/__init__.cpython-39.pyc differ diff --git a/spanfinder/sftp/training/__pycache__/transformer_optimizer.cpython-39.pyc b/spanfinder/sftp/training/__pycache__/transformer_optimizer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a066d87f5a7185f2f973cec522023ecb1e94d6c8 Binary files /dev/null and b/spanfinder/sftp/training/__pycache__/transformer_optimizer.cpython-39.pyc differ diff --git a/spanfinder/sftp/training/transformer_optimizer.py b/spanfinder/sftp/training/transformer_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..dbc85620a372304f2bc14f9eea0ee311edaa030b --- /dev/null +++ b/spanfinder/sftp/training/transformer_optimizer.py @@ -0,0 +1,121 @@ +import logging +import re +from typing import * + +import torch +from allennlp.common.from_params import Params, T +from allennlp.training.optimizers import Optimizer + +logger = logging.getLogger('optim') + + +@Optimizer.register('transformer') +class TransformerOptimizer: + """ + Wrapper for AllenNLP optimizer. 
+ This is used to fine-tune the pretrained transformer with some layers fixed and different learning rate. + When some layers are fixed, the wrapper will set the `require_grad` flag as False, which could save + training time and optimize memory usage. + Plz contact Guanghui Qin for bugs. + Params: + base: base optimizer. + embeddings_lr: learning rate for embedding layer. Set as 0.0 to fix it. + encoder_lr: learning rate for encoder layer. Set as 0.0 to fix it. + pooler_lr: learning rate for pooler layer. Set as 0.0 to fix it. + layer_fix: the number of encoder layers that should be fixed. + + Example json config: + + 1. No-op. Do nothing (why do you use me?) + optimizer: { + type: "transformer", + base: { + type: "adam", + lr: 0.001 + } + } + + 2. Fix everything in the transformer. + optimizer: { + type: "transformer", + base: { + type: "adam", + lr: 0.001 + }, + embeddings_lr: 0.0, + encoder_lr: 0.0, + pooler_lr: 0.0 + } + + Or equivalently (suppose we have 24 layers) + + optimizer: { + type: "transformer", + base: { + type: "adam", + lr: 0.001 + }, + embeddings_lr: 0.0, + layer_fix: 24, + pooler_lr: 0.0 + } + + 3. 
Fix embeddings and the lower 12 encoder layers, set a small learning rate + for the other parts of the transformer + + optimizer: { + type: "transformer", + base: { + type: "adam", + lr: 0.001 + }, + embeddings_lr: 0.0, + layer_fix: 12, + encoder_lr: 1e-5, + pooler_lr: 1e-5 + } + """ + @classmethod + def from_params( + cls: Type[T], + params: Params, + model_parameters: List[Tuple[str, torch.nn.Parameter]], + **_ + ): + param_groups = list() + + def remove_param(keyword_): + nonlocal model_parameters + logger.info(f'Fix param with name matching {keyword_}.') + for name, param in model_parameters: + if keyword_ in name: + logger.debug(f'Fix param {name}.') + param.requires_grad_(False) + model_parameters = list(filter(lambda x: keyword_ not in x[0], model_parameters)) + + for i_layer in range(params.pop('layer_fix')): + remove_param('transformer_model.encoder.layer.{}.'.format(i_layer)) + + for specific_lr, keyword in ( + (params.pop('embeddings_lr', None), 'transformer_model.embeddings'), + (params.pop('encoder_lr', None), 'transformer_model.encoder.layer'), + (params.pop('pooler_lr', None), 'transformer_model.pooler'), + ): + if specific_lr is not None: + if specific_lr > 0.: + pattern = '.*' + keyword.replace('.', r'\.') + '.*' + if len([name for name, _ in model_parameters if re.match(pattern, name)]) > 0: + param_groups.append([[pattern], {'lr': specific_lr}]) + else: + logger.warning(f'{pattern} is set to use lr {specific_lr} but no param matches.') + else: + remove_param(keyword) + + if 'parameter_groups' in params: + for pg in params.pop('parameter_groups'): + param_groups.append([pg[0], pg[1].as_dict()]) + + return Optimizer.by_name(params.get('base').pop('type'))( + model_parameters=model_parameters, parameter_groups=param_groups, + **params.pop('base').as_flat_dict() + ) diff --git a/spanfinder/sftp/utils/__init__.py b/spanfinder/sftp/utils/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..6fcf2f6ef8dc30b1519e49031506bc932073aed6 --- /dev/null +++ b/spanfinder/sftp/utils/__init__.py @@ -0,0 +1,7 @@ +import sftp.utils.label_smoothing +from sftp.utils.common import VIRTUAL_ROOT, DEFAULT_SPAN, BIO +from sftp.utils.db_storage import Cache +from sftp.utils.functions import num2mask, mask2idx, numpy2torch, one_hot, max_match +from sftp.utils.span import Span, re_index_span +from sftp.utils.span_utils import tensor2span +from sftp.utils.bio_smoothing import BIOSmoothing, apply_bio_smoothing diff --git a/spanfinder/sftp/utils/__pycache__/__init__.cpython-37.pyc b/spanfinder/sftp/utils/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e61d477a372603c25fefe3bed224a2ee44e6b13 Binary files /dev/null and b/spanfinder/sftp/utils/__pycache__/__init__.cpython-37.pyc differ diff --git a/spanfinder/sftp/utils/__pycache__/__init__.cpython-38.pyc b/spanfinder/sftp/utils/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f405254cd4ee4f9e4ac2ea2a8a3ca13ce6d91c7 Binary files /dev/null and b/spanfinder/sftp/utils/__pycache__/__init__.cpython-38.pyc differ diff --git a/spanfinder/sftp/utils/__pycache__/__init__.cpython-39.pyc b/spanfinder/sftp/utils/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df9d325b55641aed22a131c09dfbace75e9b865a Binary files /dev/null and b/spanfinder/sftp/utils/__pycache__/__init__.cpython-39.pyc differ diff --git a/spanfinder/sftp/utils/__pycache__/bio_smoothing.cpython-37.pyc b/spanfinder/sftp/utils/__pycache__/bio_smoothing.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f72e95a002ee1f775d16bf1c74abf9707136986 Binary files /dev/null and b/spanfinder/sftp/utils/__pycache__/bio_smoothing.cpython-37.pyc differ diff --git a/spanfinder/sftp/utils/__pycache__/bio_smoothing.cpython-38.pyc 
b/spanfinder/sftp/utils/__pycache__/bio_smoothing.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2174e3905a1436053dbf4f47b2da2e669361e9d2 Binary files /dev/null and b/spanfinder/sftp/utils/__pycache__/bio_smoothing.cpython-38.pyc differ diff --git a/spanfinder/sftp/utils/__pycache__/bio_smoothing.cpython-39.pyc b/spanfinder/sftp/utils/__pycache__/bio_smoothing.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b5bdaba4d0a3610ce95c84f752f18655904ab1e7 Binary files /dev/null and b/spanfinder/sftp/utils/__pycache__/bio_smoothing.cpython-39.pyc differ diff --git a/spanfinder/sftp/utils/__pycache__/common.cpython-37.pyc b/spanfinder/sftp/utils/__pycache__/common.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49db4ded2763a71219da0d6a7e7a8efb74542474 Binary files /dev/null and b/spanfinder/sftp/utils/__pycache__/common.cpython-37.pyc differ diff --git a/spanfinder/sftp/utils/__pycache__/common.cpython-38.pyc b/spanfinder/sftp/utils/__pycache__/common.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d8c5a2ce6d088e8bfa218908d34422bf0ab7e223 Binary files /dev/null and b/spanfinder/sftp/utils/__pycache__/common.cpython-38.pyc differ diff --git a/spanfinder/sftp/utils/__pycache__/common.cpython-39.pyc b/spanfinder/sftp/utils/__pycache__/common.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4beadf49df597efe21bdd7d48ae7b6006326ee8 Binary files /dev/null and b/spanfinder/sftp/utils/__pycache__/common.cpython-39.pyc differ diff --git a/spanfinder/sftp/utils/__pycache__/db_storage.cpython-37.pyc b/spanfinder/sftp/utils/__pycache__/db_storage.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97c69872adcef63ae39158af28c5e2ab958f81aa Binary files /dev/null and b/spanfinder/sftp/utils/__pycache__/db_storage.cpython-37.pyc differ diff --git 
a/spanfinder/sftp/utils/__pycache__/db_storage.cpython-38.pyc b/spanfinder/sftp/utils/__pycache__/db_storage.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dcde185767146968d480d8d170b7623b83ae96d5 Binary files /dev/null and b/spanfinder/sftp/utils/__pycache__/db_storage.cpython-38.pyc differ diff --git a/spanfinder/sftp/utils/__pycache__/db_storage.cpython-39.pyc b/spanfinder/sftp/utils/__pycache__/db_storage.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..01628cae76d27ca5efedc6a59a6690e6fe16e9a5 Binary files /dev/null and b/spanfinder/sftp/utils/__pycache__/db_storage.cpython-39.pyc differ diff --git a/spanfinder/sftp/utils/__pycache__/functions.cpython-37.pyc b/spanfinder/sftp/utils/__pycache__/functions.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c2910a2cb9b78a0a769a520e86836e79ad4babd Binary files /dev/null and b/spanfinder/sftp/utils/__pycache__/functions.cpython-37.pyc differ diff --git a/spanfinder/sftp/utils/__pycache__/functions.cpython-38.pyc b/spanfinder/sftp/utils/__pycache__/functions.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fda30da3da2bb7c5c4026391335359f7dbb34fae Binary files /dev/null and b/spanfinder/sftp/utils/__pycache__/functions.cpython-38.pyc differ diff --git a/spanfinder/sftp/utils/__pycache__/functions.cpython-39.pyc b/spanfinder/sftp/utils/__pycache__/functions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e3da60e29dee68d0996f5621fd60cbd465f27c21 Binary files /dev/null and b/spanfinder/sftp/utils/__pycache__/functions.cpython-39.pyc differ diff --git a/spanfinder/sftp/utils/__pycache__/label_smoothing.cpython-37.pyc b/spanfinder/sftp/utils/__pycache__/label_smoothing.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..acad414bca5bcb7b1462a435356b5799fb4b5b82 Binary files /dev/null and 
b/spanfinder/sftp/utils/__pycache__/label_smoothing.cpython-37.pyc differ diff --git a/spanfinder/sftp/utils/__pycache__/label_smoothing.cpython-38.pyc b/spanfinder/sftp/utils/__pycache__/label_smoothing.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bfa6b77f1ed706c7ae7ab32a2c068fe56f533280 Binary files /dev/null and b/spanfinder/sftp/utils/__pycache__/label_smoothing.cpython-38.pyc differ diff --git a/spanfinder/sftp/utils/__pycache__/label_smoothing.cpython-39.pyc b/spanfinder/sftp/utils/__pycache__/label_smoothing.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6bc15407a3d0943d0fb8a3299fa04ab282d825b5 Binary files /dev/null and b/spanfinder/sftp/utils/__pycache__/label_smoothing.cpython-39.pyc differ diff --git a/spanfinder/sftp/utils/__pycache__/span.cpython-37.pyc b/spanfinder/sftp/utils/__pycache__/span.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b745e1b19208d43b75c95ed899e2932f4ed64d41 Binary files /dev/null and b/spanfinder/sftp/utils/__pycache__/span.cpython-37.pyc differ diff --git a/spanfinder/sftp/utils/__pycache__/span.cpython-38.pyc b/spanfinder/sftp/utils/__pycache__/span.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4f466f07181278347b54b3b667cf18afcc82d804 Binary files /dev/null and b/spanfinder/sftp/utils/__pycache__/span.cpython-38.pyc differ diff --git a/spanfinder/sftp/utils/__pycache__/span.cpython-39.pyc b/spanfinder/sftp/utils/__pycache__/span.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cfe06ad85f55f016b725aa39a08e1c4d26f41832 Binary files /dev/null and b/spanfinder/sftp/utils/__pycache__/span.cpython-39.pyc differ diff --git a/spanfinder/sftp/utils/__pycache__/span_utils.cpython-37.pyc b/spanfinder/sftp/utils/__pycache__/span_utils.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a5fb085e1c0b005098e18298e6cbbbfb69bc0e4 
Binary files /dev/null and b/spanfinder/sftp/utils/__pycache__/span_utils.cpython-37.pyc differ diff --git a/spanfinder/sftp/utils/__pycache__/span_utils.cpython-38.pyc b/spanfinder/sftp/utils/__pycache__/span_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e5fad0d978837cf2c7fda90f7be3457f28693791 Binary files /dev/null and b/spanfinder/sftp/utils/__pycache__/span_utils.cpython-38.pyc differ diff --git a/spanfinder/sftp/utils/__pycache__/span_utils.cpython-39.pyc b/spanfinder/sftp/utils/__pycache__/span_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..927abbc57f762f03add7ee2ed7c7d8e1b2b825dd Binary files /dev/null and b/spanfinder/sftp/utils/__pycache__/span_utils.cpython-39.pyc differ diff --git a/spanfinder/sftp/utils/bio_smoothing.py b/spanfinder/sftp/utils/bio_smoothing.py new file mode 100644 index 0000000000000000000000000000000000000000..785669a2cec478150d0aa5b973fa5631daecdde7 --- /dev/null +++ b/spanfinder/sftp/utils/bio_smoothing.py @@ -0,0 +1,62 @@ +from typing import * + +import numpy as np +from .common import BIO + + +class BIOSmoothing: + def __init__( + self, + b_smooth: float = 0.0, + i_smooth: float = 0.0, + o_smooth: float = 0.0, + weight: float = 1.0 + ): + self.smooth = [b_smooth, i_smooth, o_smooth] + self.weight = weight + + def apply_sequence(self, sequence: List[str]): + bio_tags = np.zeros([len(sequence), 3], np.float32) + for i, tag in enumerate(sequence): + bio_tags[i] = self.apply_tag(tag) + return bio_tags + + def apply_tag(self, tag: str): + j = BIO.index(tag) + ret = np.zeros([3], np.float32) + if self.smooth[j] >= 0.0: + # Smooth + ret[j] = 1.0 - self.smooth[j] + for j_ in set(range(3)) - {j}: + ret[j_] = self.smooth[j] / 2 + else: + # Marginalize + ret[:] = 1.0 + + return ret * self.weight + + def __repr__(self): + ret = f'' + + def clone(self): + return BIOSmoothing(*self.smooth, self.weight) + + +def apply_bio_smoothing( + config: 
Optional[Union[BIOSmoothing, List[BIOSmoothing]]], + bio_seq: List[str] +) -> np.ndarray: + if config is None: + config = BIOSmoothing() + if isinstance(config, BIOSmoothing): + return config.apply_sequence(bio_seq) + else: + assert len(bio_seq) == len(config) + return np.stack([cfg.apply_tag(tag) for cfg, tag in zip(config, bio_seq)]) diff --git a/spanfinder/sftp/utils/common.py b/spanfinder/sftp/utils/common.py new file mode 100644 index 0000000000000000000000000000000000000000..ae0b000a6522f0c7b52606278c23e0813f8e9b40 --- /dev/null +++ b/spanfinder/sftp/utils/common.py @@ -0,0 +1,3 @@ +DEFAULT_SPAN = '@@SPAN@@' +VIRTUAL_ROOT = '@@VIRTUAL_ROOT@@' +BIO = 'BIO' diff --git a/spanfinder/sftp/utils/db_storage.py b/spanfinder/sftp/utils/db_storage.py new file mode 100644 index 0000000000000000000000000000000000000000..dfdac7e63de39456a893446ae93f560a92b06c73 --- /dev/null +++ b/spanfinder/sftp/utils/db_storage.py @@ -0,0 +1,87 @@ +import pickle +import warnings + +import h5py +import numpy as np + + +class Cache: + def __init__(self, file: str, mode: str = 'a', overwrite=False): + self.db_file = h5py.File(file, mode=mode) + self.overwrite = overwrite + + @staticmethod + def _key(key): + if isinstance(key, str): + return key + elif isinstance(key, list): + ret = [] + for k in key: + ret.append(Cache._key(k)) + return ' '.join(ret) + else: + return str(key) + + @staticmethod + def _value(value: np.ndarray): + if isinstance(value, h5py.Dataset): + value: np.ndarray = value[()] + if value.dtype.name.startswith('bytes'): + value = pickle.loads(value) + return value + + def __getitem__(self, key): + key = self._key(key) + if key not in self: + raise KeyError + return self._value(self.db_file[key]) + + def __setitem__(self, key, value) -> None: + key = self._key(key) + if key in self: + del self.db_file[key] + if not isinstance(value, np.ndarray): + value = np.array(pickle.dumps(value)) + self.db_file[key] = value + + def __delitem__(self, key) -> None: + key = self._key(key) 
+ if key in self: + del self.db_file[key] + + def __len__(self) -> int: + return len(self.db_file) + + def close(self) -> None: + self.db_file.close() + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + self.close() + + def __contains__(self, item): + item = self._key(item) + return item in self.db_file + + def __enter__(self): + return self + + def __call__(self, function): + """ + The object of the class could also be used as a decorator. Provide an additional + argument `cache_id' when calling the function, and the results will be cached. + """ + + def wrapper(*args, **kwargs): + if 'cache_id' in kwargs: + cache_id = kwargs['cache_id'] + del kwargs['cache_id'] + if cache_id in self and not self.overwrite: + return self[cache_id] + rst = function(*args, **kwargs) + self[cache_id] = rst + return rst + else: + warnings.warn("`cache_id' argument not found. Cache is disabled.") + return function(*args, **kwargs) + + return wrapper diff --git a/spanfinder/sftp/utils/functions.py b/spanfinder/sftp/utils/functions.py new file mode 100644 index 0000000000000000000000000000000000000000..f2e15ebaf918847f78de1cb9fe6b11470552126e --- /dev/null +++ b/spanfinder/sftp/utils/functions.py @@ -0,0 +1,75 @@ +from typing import * + +import numpy as np +import torch +from scipy.optimize import linear_sum_assignment +from torch.nn.utils.rnn import pad_sequence + + +def num2mask( + nums: torch.Tensor, + max_length: Optional[int] = None +) -> torch.Tensor: + """ + E.g. input a tensor [2, 3, 4], return [[T T F F], [T T T F], [T T T T]] + :param nums: Shape [batch] + :param max_length: maximum length. if not provided, will choose the largest number from nums. + :return: 2D binary mask. 
+ """ + shape_backup = nums.shape + nums = nums.flatten() + max_length = max_length or int(nums.max()) + batch_size = len(nums) + range_nums = torch.arange(0, max_length, device=nums.device).unsqueeze(0).expand([batch_size, max_length]) + ret = (range_nums.T < nums).T + return ret.reshape(*shape_backup, max_length) + + +def mask2idx( + mask: torch.Tensor, + max_length: Optional[int] = None, + padding_value: int = 0, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + E.g. input a tensor [[T T F F], [T T T F], [F F F T]] with padding value -1, + return [[0, 1, -1], [0, 1, 2], [3, -1, -1]] + :param mask: Mask tensor. Boolean. Not necessarily to be 2D. + :param max_length: If provided, will truncate. + :param padding_value: Padding value. Default to 0. + :return: Index tensor. + """ + shape_prefix, mask_length = mask.shape[:-1], mask.shape[-1] + flat_mask = mask.flatten(0, -2) + index_list = [torch.arange(mask_length, device=mask.device)[one_mask] for one_mask in flat_mask.unbind(0)] + index_tensor = pad_sequence(index_list, batch_first=True, padding_value=padding_value) + if max_length is not None: + index_tensor = index_tensor[:, :max_length] + index_tensor = index_tensor.reshape(*shape_prefix, -1) + return index_tensor, mask.sum(-1) + + +def one_hot(tags: torch.Tensor, num_tags: Optional[int] = None) -> torch.Tensor: + num_tags = num_tags or int(tags.max()) + ret = tags.new_zeros(size=[*tags.shape, num_tags], dtype=torch.bool) + ret.scatter_(2, tags.unsqueeze(2), tags.new_ones([*tags.shape, 1], dtype=torch.bool)) + return ret + + +def numpy2torch( + dict_obj: dict +) -> dict: + """ + Convert list/np.ndarray data to torch.Tensor and add add a batch dim. 
+ """ + ret = dict() + for k, v in dict_obj.items(): + if isinstance(v, list) or isinstance(v, np.ndarray): + ret[k] = torch.tensor(v).unsqueeze(0) + else: + ret[k] = v + return ret + + +def max_match(mat: np.ndarray): + row_idx, col_idx = linear_sum_assignment(mat, True) + return mat[row_idx, col_idx].sum() diff --git a/spanfinder/sftp/utils/label_smoothing.py b/spanfinder/sftp/utils/label_smoothing.py new file mode 100644 index 0000000000000000000000000000000000000000..1b9162a0997428b1519b4c825887cf48532ab079 --- /dev/null +++ b/spanfinder/sftp/utils/label_smoothing.py @@ -0,0 +1,48 @@ +import torch +from torch import nn +from torch.nn import KLDivLoss +from torch.nn import LogSoftmax + + +class LabelSmoothingLoss(nn.Module): + def __init__(self, label_smoothing=0.0, unreliable_label=None, ignore_index=-100): + """ + If label_smoothing == 0.0, it is equivalent to xentropy + """ + assert 0.0 <= label_smoothing <= 1.0 + super(LabelSmoothingLoss, self).__init__() + + self.ignore_index = ignore_index + self.label_smoothing = label_smoothing + + self.loss_fn = KLDivLoss(reduction='batchmean') + self.unreliable_label = unreliable_label + self.max_gap = 100. 
+ self.log_softmax = LogSoftmax(1) + + def forward(self, output, target): + """ + output: logits + target: labels + """ + vocab_size = output.shape[1] + mask = (target != self.ignore_index) + output, target = output[mask], target[mask] + output = self.log_softmax(output) + + def get_smooth_prob(ls): + smoothing_value = ls / (vocab_size - 1) + prob = output.new_full((target.size(0), vocab_size), smoothing_value) + prob.scatter_(1, target.unsqueeze(1), 1 - ls) + return prob + + if self.unreliable_label is not None: + smoothed_prob = get_smooth_prob(self.label_smoothing) + hard_prob = get_smooth_prob(0.0) + unreliable_mask = (target == self.unreliable_label).to(torch.float) + model_prob = ((smoothed_prob.T * unreliable_mask) + (hard_prob.T * (1 - unreliable_mask))).T + else: + model_prob = get_smooth_prob(self.label_smoothing) + + loss = self.loss_fn(output, model_prob) + return loss diff --git a/spanfinder/sftp/utils/span.py b/spanfinder/sftp/utils/span.py new file mode 100644 index 0000000000000000000000000000000000000000..43a30844bf233395a0bfe475510acbf69118cacc --- /dev/null +++ b/spanfinder/sftp/utils/span.py @@ -0,0 +1,426 @@ +from typing import * + +import numpy as np + +from .common import VIRTUAL_ROOT, DEFAULT_SPAN +from .bio_smoothing import BIOSmoothing +from .functions import max_match + + +class Span: + """ + Span is a simple data structure for a span (not necessarily associated with text), along with its label, + children and possibly its parent and a confidence score. + + Basic usages (suppose span is a Span object): + 1. len(span) -- #children. + 2. span[i] -- i-th child. + 3. for s in span: ... -- iterate its children. + 4. for s in span.bfs: ... -- iterate its descendents. + 5. print(span) -- show its description. + 6. span.tree() -- print the whole tree. + + It provides some utilities: + 1. Re-indexing. BPE will change token indices, and the `re_index` method can convert normal tokens + BPE word piece indices, or vice versa. + 2. 
Span object and span dict (JSON format) are mutually convertible (by `to_json` and `from_json` methods). + 3. Recursively truncate spans up to a given length. (see `truncate` method) + 4. Recursively replace all labels with the default label. (see `ignore_labels` method) + 5. Recursively solve the span overlapping problem by removing children overlapped with others. + (see `remove_overlapping` method) + """ + def __init__( + self, + start_idx: int, + end_idx: int, + label: Union[str, int, list] = DEFAULT_SPAN, + is_parent: bool = False, + parent: Optional["Span"] = None, + confidence: Optional[float] = None, + ): + """ + Init function. Children should be added using the `add_children` method. + :param start_idx: Start index in a seq of tokens, inclusive. + :param end_idx: End index in a seq of tokens, inclusive. + :param label: Label. If not provided, will assign a default label. + Can be of various types: String, integer, or list of something. + :param is_parent: If True, will be treated as parent. This is important because in the training process of BIO + tagger, when a span has no children, we need to know if it's a parent with no children (so we should have + an training example with all O tags) or not (then the above example doesn't exist). + We follow a convention where if a span is not parent, then the key `children` shouldn't appear in its + JSON dict; if a span is parent but has no children, the key `children` in its JSON dict should appear + and be an empty list. + :param parent: A pointer to its parent. + :param confidence: Confidence value. + """ + self.start_idx, self.end_idx = start_idx, end_idx + self.label: Union[int, str, list] = label + self.is_parent = is_parent + self.parent = parent + self._children: List[Span] = list() + self.confidence = confidence + + # Following are for label smoothing. Leave default is you don't need smoothing. + # Logic: + # The label smoothing factors of (i.e. 
b_smooth, i_smooth, o_smooth) depend on the `child_span` of its parent. + # The re-weighting factor of a span also depends on the `child_span` of its parent, but can be overridden + # by its own `smoothing_weight` field if it's not None. + self.child_smooth: BIOSmoothing = BIOSmoothing() + self.smooth_weight: Optional[float] = None + + def add_child(self, span: "Span") -> "Span": + """ + Add a span to children list. Will link current span to child's parent pointer. + :param span: Child span. + """ + assert self.is_parent + span.parent = self + self._children.append(span) + return self + + def re_index( + self, + offsets: List[Optional[Tuple[int, int]]], + reverse: bool = False, + recursive: bool = True, + inplace: bool = False, + ) -> "Span": + """ + BPE will change token indices, and the `re_index` method can convert normal tokens BPE word piece indices, + or vice versa. + We assume Virtual Root has a boundary [-1, -1] before being mapped to the BPE space, and a boundary [0, 0] + after the re-indexing. We use [0, 0] because it's always the BOS token in BPE. + Mapping to BPE space is straight forward. The reverse mapping has special cases where the span might + contain BOS or EOS. Usually this is a parsing bug. We will map the BOS index to 0, and EOS index to -1. + :param offsets: Offsets. Defined by BPE tokenizer and resides in the SpanFinder outputs. + :param reverse: If True, map from the BPE space to original token space. + :param recursive: If True, will apply the re-indexing to its children. + :param inplace: Inplace? + :return: Re-indexed span. 
+ """ + span = self if inplace else self.clone() + + span.start_idx, span.end_idx = re_index_span(span.boundary, offsets, reverse) + if recursive: + new_children = list() + for child in span._children: + new_children.append(child.re_index(offsets, reverse, recursive, True)) + span._children = new_children + return span + + def truncate(self, max_length: int) -> bool: + """ + Discard spans whose end_idx exceeds the max_length (inclusive). + This is done recursively. + This is useful for some encoder like XLMR that has a limit on input length. (512 for XLMR large) + :param max_length: Max length. + :return: You don't need to care return value. + """ + if self.end_idx >= max_length: + return False + else: + self._children = list(filter(lambda x: x.truncate(max_length), self._children)) + return True + + @classmethod + def virtual_root(cls: "Span", spans: Optional[List["Span"]] = None) -> "Span": + """ + An official method to create a tree: Generate the first layer of spans by yourself, and pass them into this + method. + E.g., for SRL style task, generate a list of events, assign arguments to them as children. Then pass the + events to this method to have a virtual root which serves as a parent of events. + :param spans: 1st layer spans. + :return: Virtual root. + """ + vr = Span(-1, -1, VIRTUAL_ROOT, True) + if spans is not None: + vr._children = spans + for child in vr._children: + child.parent = vr + return vr + + def ignore_labels(self) -> None: + """ + Remove all labels. Make them placeholders. Inplace. + """ + self.label = DEFAULT_SPAN + for child in self._children: + child.ignore_labels() + + def clone(self) -> "Span": + """ + Clone a tree. + :return: Cloned tree. 
+ """ + span = Span(self.start_idx, self.end_idx, self.label, self.is_parent, self.parent, self.confidence) + span.child_smooth, span.smooth_weight = self.child_smooth, self.smooth_weight + for child in self._children: + span.add_child(child.clone()) + return span + + def bfs(self) -> Iterable["Span"]: + """ + Iterate over all descendents with BFS, including self. + :return: Spans. + """ + yield self + yield from self._bfs() + + def _bfs(self) -> List["Span"]: + """ + Helper function. + """ + for child in self._children: + yield child + for child in self._children: + yield from child._bfs() + + def remove_overlapping(self, recursive=True) -> int: + """ + Remove overlapped spans. If spans overlap, will pick the first one and discard the others, judged by start_idx. + :param recursive: Apply to all of the descendents? + :return: The number of spans that are removed. + """ + indices = set() + new_children = list() + removing = 0 + for child in self._children: + if len(set(range(child.start_idx, child.end_idx + 1)) & indices) > 0: + removing += 1 + continue + indices.update(set(range(child.start_idx, child.end_idx + 1))) + new_children.append(child) + if recursive: + removing += child.remove_overlapping(True) + self._children = new_children + return removing + + def describe(self, sentence: Optional[List[str]] = None) -> str: + """ + :param sentence: If provided, will replace the indices with real tokens for presentation. + :return: The description in a single line. + """ + if self.start_idx >= 0: + if sentence is None: + span = f'({self.start_idx}, {self.end_idx})' + else: + span = '(' + ' '.join(sentence[self.start_idx: self.end_idx + 1]) + ')' + if self.is_parent: + return f'' + else: + return f'[Span: {span}, {self.label}]' + else: + return f'' + + def __repr__(self) -> str: + return self.describe() + + @property + def n_nodes(self) -> int: + """ + :return: Number of descendents + self. 
+ """ + return sum([child.n_nodes for child in self._children], 1) + + @property + def boundary(self): + """ + :return: (start_idx, end_idx), both inclusive. + """ + return self.start_idx, self.end_idx + + def __iter__(self) -> Iterable["Span"]: + """ + Iterate over children. + """ + yield from self._children + + def __len__(self): + """ + :return: #children. + """ + return len(self._children) + + def __getitem__(self, idx: int): + """ + :return: The indexed child. + """ + return self._children[idx] + + def tree(self, sentence: Optional[List[str]] = None, printing: bool = True) -> str: + """ + A tree description of all descendents. Human readable. + :param sentence: If provided, will replace the indices with real tokens for presentation. + :param printing: If True, will print out. + :return: The description. + """ + ret = list() + ret.append(self.describe(sentence)) + for child in self._children: + child_lines = child.tree(sentence, False).split('\n') + for line in child_lines: + ret.append(' ' + line) + desc = '\n'.join(ret) + if printing: print(desc) + else: return desc + + def match( + self, + other: "Span", + match_label: bool = True, + depth: int = -1, + ignore_parent_boundary: bool = False, + ) -> int: + """ + Used for evaluation. Count how many spans two trees share. Two spans are considered to be identical + if their boundary, label, and parent match. + :param other: The other tree to compare. + :param match_label: If False, will ignore label. + :param depth: If specified as non-negative, will only search thru certain depth. + :param ignore_parent_boundary: If True, two children can be matched ignoring parent boundaries. + :return: #spans two tree share. + """ + if depth == 0: + return 0 + if self.label != other.label and match_label: + return 0 + if self.boundary == other.boundary: + n_match = 1 + elif ignore_parent_boundary: + # Parents fail, Children might match! 
+ n_match = 0 + else: + return 0 + + sub_matches = np.zeros([len(self), len(other)], dtype=np.int) + for self_idx, my_child in enumerate(self): + for other_idx, other_child in enumerate(other): + sub_matches[self_idx, other_idx] = my_child.match( + other_child, match_label, depth-1, ignore_parent_boundary + ) + if not ignore_parent_boundary: + for m in [sub_matches, sub_matches.T]: + for line in m: + assert (line > 0).sum() <= 1 + n_match += max_match(sub_matches) + return n_match + + def to_json(self) -> dict: + """ + To JSON dict format. See init. + """ + ret = { + "label": self.label, + "span": list(self.boundary), + } + if self.confidence is not None: + ret['confidence'] = self.confidence + if self.is_parent: + children = list() + for child in self._children: + children.append(child.to_json()) + ret['children'] = children + return ret + + @classmethod + def from_json(cls, span_json: Union[list, dict]) -> "Span": + """ + Load from JSON. See init. + """ + if isinstance(span_json, dict): + span = Span( + span_json['span'][0], span_json['span'][1], span_json.get('label', None), 'children' in span_json, + confidence=span_json.get('confidence', None) + ) + for child_dict in span_json.get('children', []): + span.add_child(Span.from_json(child_dict)) + else: + spans = [Span.from_json(child) for child in span_json] + span = Span.virtual_root(spans) + return span + + def map_ontology( + self, + ontology_mapping: Optional[dict] = None, + inplace: bool = True, + recursive: bool = True, + ) -> Optional["Span"]: + """ + Map labels to other things, like another ontology of soft labels. + :param ontology_mapping: Mapping dict. The key should be labels, and values can be anything. + Labels not in the dict will not be deleted. So be careful. + :param inplace: Inplace? + :param recursive: Apply to all descendents if True. + :return: The mapped tree. + """ + span = self if inplace else self.clone() + if ontology_mapping is None: + # Do nothing if mapping not provided. 
+ return span + + if recursive: + new_children = list() + for child in span: + new_child = child.map_ontology(ontology_mapping, False, True) + if new_child is not None: + new_children.append(new_child) + span._children = new_children + + if span.label != VIRTUAL_ROOT: + if span.parent is not None and (span.parent.label, span.label) in ontology_mapping: + span.label = ontology_mapping[(span.parent.label, span.label)] + elif span.label in ontology_mapping: + span.label = ontology_mapping[span.label] + else: + return + + return span + + def isolate(self) -> "Span": + """ + Generate a span that is identical to self but has no children or parent. + """ + return Span(self.start_idx, self.end_idx, self.label, self.is_parent, None, self.confidence) + + def remove_child(self, span: Optional["Span"] = None): + """ + Remove a child. If pass None, will reset the children list. + """ + if span is None: + self._children = list() + else: + del self._children[self._children.index(span)] + + +def re_index_span( + boundary: Tuple[int, int], offsets: List[Tuple[int, int]], reverse: bool = False +) -> Tuple[int, int]: + """ + Helper function. 
+ """ + if not reverse: + if boundary[0] == boundary[1] == -1: + # Virtual Root + start_idx = end_idx = 0 + else: + # --- edit GFM --- + bnd_start, bnd_end = boundary + if bnd_end >= len(offsets): + start_idx, end_idx = None, None + else: + start_idx = offsets[boundary[0]][0] + end_idx = offsets[boundary[1]][1] + # --- END --- + else: + if boundary[0] == boundary[1] == 0: + # Virtual Root + start_idx = end_idx = -1 + else: + start_within = [bo[0] <= boundary[0] <= bo[1] if bo is not None else False for bo in offsets] + end_within = [bo[0] <= boundary[1] <= bo[1] if bo is not None else False for bo in offsets] + assert sum(start_within) <= 1 and sum(end_within) <= 1 + start_idx = start_within.index(True) if sum(start_within) == 1 else 0 + end_idx = end_within.index(True) if sum(end_within) == 1 else len(offsets) + if start_idx > end_idx: + raise IndexError + return start_idx, end_idx diff --git a/spanfinder/sftp/utils/span_utils.py b/spanfinder/sftp/utils/span_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b9817ec2d956abdc2ff3566515208090a1e6eba3 --- /dev/null +++ b/spanfinder/sftp/utils/span_utils.py @@ -0,0 +1,57 @@ +from typing import * + +import torch + +from .span import Span + + +def _tensor2span_batch( + span_boundary: torch.Tensor, + span_labels: torch.Tensor, + parent_indices: torch.Tensor, + num_spans: torch.Tensor, + label_confidence: torch.Tensor, + idx2label: Dict[int, str], + label_ignore: List[int], +) -> Span: + spans = list() + for (start_idx, end_idx), parent_idx, label, label_conf in \ + list(zip(span_boundary, parent_indices, span_labels, label_confidence))[:int(num_spans)]: + if label not in label_ignore: + span = Span(int(start_idx), int(end_idx), idx2label[int(label)], True, confidence=float(label_conf)) + if int(parent_idx) < len(spans): + spans[int(parent_idx)].add_child(span) + spans.append(span) + return spans[0] + + +def tensor2span( + span_boundary: torch.Tensor, + span_labels: torch.Tensor, + 
parent_indices: torch.Tensor, + num_spans: torch.Tensor, + label_confidence: torch.Tensor, + idx2label: Dict[int, str], + label_ignore: Optional[List[int]] = None, +) -> List[Span]: + """ + Generate spans in dict from vectors. Refer to the model part for the meaning of these variables. + If idx_ignore is provided, some labels will be ignored. + :return: + """ + label_ignore = label_ignore or [] + if span_boundary.device.type != 'cpu': + span_boundary = span_boundary.to(device='cpu') + parent_indices = parent_indices.to(device='cpu') + span_labels = span_labels.to(device='cpu') + num_spans = num_spans.to(device='cpu') + label_confidence = label_confidence.to(device='cpu') + + ret = list() + for args in zip( + span_boundary.unbind(0), span_labels.unbind(0), parent_indices.unbind(0), num_spans.unbind(0), + label_confidence.unbind(0), + ): + ret.append(_tensor2span_batch(*args, label_ignore=label_ignore, idx2label=idx2label)) + + return ret diff --git a/spanfinder/sociolome/__pycache__/combine_models.cpython-39.pyc b/spanfinder/sociolome/__pycache__/combine_models.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f735da04886360826fc48db2dfc33865beff8af Binary files /dev/null and b/spanfinder/sociolome/__pycache__/combine_models.cpython-39.pyc differ diff --git a/spanfinder/sociolome/__pycache__/evalita_eval.cpython-37.pyc b/spanfinder/sociolome/__pycache__/evalita_eval.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eba5e130114634faa3ba3aa1dfff4a3c588afcdc Binary files /dev/null and b/spanfinder/sociolome/__pycache__/evalita_eval.cpython-37.pyc differ diff --git a/spanfinder/sociolome/__pycache__/evalita_eval.cpython-38.pyc b/spanfinder/sociolome/__pycache__/evalita_eval.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5d71f26ccfe3da38ffa47b3fdb95e37f60fef828 Binary files /dev/null and b/spanfinder/sociolome/__pycache__/evalita_eval.cpython-38.pyc differ diff --git 
a/spanfinder/sociolome/__pycache__/evalita_eval.cpython-39.pyc b/spanfinder/sociolome/__pycache__/evalita_eval.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1514be87586b80c302965ff872d9eeb2f3efa7d Binary files /dev/null and b/spanfinder/sociolome/__pycache__/evalita_eval.cpython-39.pyc differ diff --git a/spanfinder/sociolome/__pycache__/evalita_eval.cpython-39.pyc.139628289135072 b/spanfinder/sociolome/__pycache__/evalita_eval.cpython-39.pyc.139628289135072 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/spanfinder/sociolome/__pycache__/evalita_force_predict.cpython-39.pyc b/spanfinder/sociolome/__pycache__/evalita_force_predict.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..54491a16ad2f5e3b8d8e38fb30938aa4bd9ebffa Binary files /dev/null and b/spanfinder/sociolome/__pycache__/evalita_force_predict.cpython-39.pyc differ diff --git a/spanfinder/sociolome/__pycache__/lome_webserver.cpython-38.pyc b/spanfinder/sociolome/__pycache__/lome_webserver.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a7ea0764b03879df80f13fe24c702ad6a13e25c Binary files /dev/null and b/spanfinder/sociolome/__pycache__/lome_webserver.cpython-38.pyc differ diff --git a/spanfinder/sociolome/__pycache__/lome_webserver.cpython-39.pyc b/spanfinder/sociolome/__pycache__/lome_webserver.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b56b9a4fab84a75e144f204d27de20e5ac9e2aa1 Binary files /dev/null and b/spanfinder/sociolome/__pycache__/lome_webserver.cpython-39.pyc differ diff --git a/spanfinder/sociolome/combine_models.py b/spanfinder/sociolome/combine_models.py new file mode 100644 index 0000000000000000000000000000000000000000..da207170ca1628284c4855a6e1defc1c7f5b003b --- /dev/null +++ b/spanfinder/sociolome/combine_models.py @@ -0,0 +1,130 @@ +from typing import Any, Dict, List, 
Optional +import dataclasses +import glob +import os +import sys +import json + +import spacy +from spacy.language import Language + +from sftp import SpanPredictor + + +@dataclasses.dataclass +class FrameAnnotation: + tokens: List[str] = dataclasses.field(default_factory=list) + pos: List[str] = dataclasses.field(default_factory=list) + + +@dataclasses.dataclass +class MultiLabelAnnotation(FrameAnnotation): + frame_list: List[List[str]] = dataclasses.field(default_factory=list) + lu_list: List[Optional[str]] = dataclasses.field(default_factory=list) + + def to_txt(self): + for i, tok in enumerate(self.tokens): + yield f"{tok} {self.pos[i]} {'|'.join(self.frame_list[i]) or '_'} {self.lu_list[i] or '_'}" + + +def convert_to_seq_labels(sentence: List[str], structures: Dict[int, Dict[str, Any]]) -> List[List[str]]: + labels = [[] for _ in sentence] + + for struct_id, struct in structures.items(): + tgt_span = struct["target"] + frame = struct["frame"] + + for i in range(tgt_span[0], tgt_span[1] + 1): + labels[i].append(f"T:{frame}@{struct_id:02}") + for role in struct["roles"]: + role_span = role["boundary"] + role_label = role["label"] + for i in range(role_span[0], role_span[1] + 1): + prefix = "B" if i == role_span[0] else "I" + labels[i].append(f"{prefix}:{frame}:{role_label}@{struct_id:02}") + return labels + + +def predict_combined( + spacy_model: Language, + sentences: List[str], + tgt_predictor: SpanPredictor, + frm_predictor: SpanPredictor, + bnd_predictor: SpanPredictor, + arg_predictor: SpanPredictor, +) -> List[MultiLabelAnnotation]: + + annotations_out = [] + + for sent_idx, sent in enumerate(sentences): + + sent = sent.strip() + + print(f"Processing sent with idx={sent_idx}: {sent}") + + doc = spacy_model(sent) + sent_tokens = [t.text for t in doc] + + tgt_spans, _, _ = tgt_predictor.force_decode(sent_tokens) + + frame_structures = {} + + for i, span in enumerate(tgt_spans): + span = tuple(span) + _, fr_labels, _ = frm_predictor.force_decode(sent_tokens, 
child_spans=[span]) + frame = fr_labels[0] + if frame == "@@VIRTUAL_ROOT@@@": + continue + + boundaries, _, _ = bnd_predictor.force_decode(sent_tokens, parent_span=span, parent_label=frame) + _, arg_labels, _ = arg_predictor.force_decode(sent_tokens, parent_span=span, parent_label=frame, child_spans=boundaries) + + frame_structures[i] = { + "target": span, + "frame": frame, + "roles": [ + {"boundary": bnd, "label": label} + for bnd, label in zip(boundaries, arg_labels) + if label != "Target" + ] + } + annotations_out.append(MultiLabelAnnotation( + tokens=sent_tokens, + pos=[t.pos_ for t in doc], + frame_list=convert_to_seq_labels(sent_tokens, frame_structures), + lu_list=[None for _ in sent_tokens] + )) + return annotations_out + + +def main(input_folder): + + print("Loading spaCy model ...") + nlp = spacy.load("it_core_news_md") + + print("Loading predictors ...") + zs_predictor = SpanPredictor.from_path("/data/p289731/cloned/lome-models/models/spanfinder/model.mod.tar.gz", cuda_device=0) + ev_predictor = SpanPredictor.from_path("/scratch/p289731/lome-training-files/train-evalita-plus-fn-vanilla/model.tar.gz", cuda_device=0) + + + print("Reading input files ...") + for file in glob.glob(os.path.join(input_folder, "*.txt")): + print(file) + with open(file, encoding="utf-8") as f: + sentences = list(f) + + annotations = predict_combined(nlp, sentences, zs_predictor, ev_predictor, ev_predictor, ev_predictor) + + out_name = os.path.splitext(os.path.basename(file))[0] + with open(f"../../data-out/{out_name}.combined_zs_ev.tc_bilstm.txt", "w", encoding="utf-8") as f_out: + for ann in annotations: + for line in ann.to_txt(): + f_out.write(line + os.linesep) + f_out.write(os.linesep) + + with open(f"../../data-out/{out_name}.combined_zs_ev.tc_bilstm.json", "w", encoding="utf-8") as f_out: + json.dump([dataclasses.asdict(ann) for ann in annotations], f_out) + + +if __name__ == "__main__": + main(sys.argv[1]) diff --git a/spanfinder/sociolome/evalita_eval.py 
b/spanfinder/sociolome/evalita_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..f4c426e3418986f9410df1f5fc19332a2a5fa15e --- /dev/null +++ b/spanfinder/sociolome/evalita_eval.py @@ -0,0 +1,319 @@ +import json +from typing import List, Tuple + +import pandas as pd + +from sftp import SpanPredictor + + +def main(): + # data_file = "/home/p289731/cloned/lome/preproc/evalita_jsonl/evalita_dev.jsonl" + # data_file = "/home/p289731/cloned/lome/preproc/svm_challenge.jsonl" + data_file = "/home/p289731/cloned/lome/preproc/evalita_jsonl/evalita_test.jsonl" + models = [ + ( + "lome-en", + "/data/p289731/cloned/lome-models/models/spanfinder/model.mod.tar.gz", + ), + ( + "lome-it-best", + "/scratch/p289731/lome-training-files/train-evalita-plus-fn-vanilla/model.tar.gz", + ), + # ( + # "lome-it-freeze", + # "/data/p289731/cloned/lome/train-evalita-plus-fn-freeze/model.tar.gz", + # ), + # ( + # "lome-it-mono", + # "/data/p289731/cloned/lome/train-evalita-it_mono/model.tar.gz", + # ), + ] + + for (model_name, model_path) in models: + print("testing model: ", model_name) + predictor = SpanPredictor.from_path(model_path) + + print("=== FD (run 1) ===") + eval_frame_detection(data_file, predictor, model_name=model_name) + + for run in [1, 2]: + print(f"=== BD (run {run}) ===") + eval_boundary_detection(data_file, predictor, run=run) + + for run in [1, 2, 3]: + print(f"=== AC (run {run}) ===") + eval_argument_classification(data_file, predictor, run=run) + + +def predict_frame( + predictor: SpanPredictor, tokens: List[str], predicate_span: Tuple[int, int] +): + _, labels, _ = predictor.force_decode(tokens, child_spans=[predicate_span]) + return labels[0] + + +def eval_frame_detection(data_file, predictor, verbose=False, model_name="_"): + + true_pos = 0 + false_pos = 0 + + out = [] + + with open(data_file, encoding="utf-8") as f: + for sent_id, sent in enumerate(f): + sent_data = json.loads(sent) + + tokens = sent_data["tokens"] + annotation = 
sent_data["annotations"][0] + + predicate_span = tuple(annotation["span"]) + predicate = tokens[predicate_span[0] : predicate_span[1] + 1] + + frame_gold = annotation["label"] + frame_pred = predict_frame(predictor, tokens, predicate_span) + + if frame_pred == frame_gold: + true_pos += 1 + else: + false_pos += 1 + + out.append({ + "sentence": " ".join(tokens), + "predicate": predicate, + "frame_gold": frame_gold, + "frame_pred": frame_pred + }) + + if verbose: + print(f"Sentence #{sent_id:03}: {' '.join(tokens)}") + print(f"\tpredicate: {predicate}") + print(f"\t gold: {frame_gold}") + print(f"\tpredicted: {frame_pred}") + print() + + acc_score = true_pos / (true_pos + false_pos) + print("ACC =", acc_score) + + data_sect = "rai" if "svm_challenge" in data_file else "dev" if "dev" in data_file else "test" + + df_out = pd.DataFrame(out) + df_out.to_csv(f"frame_prediction_output_{model_name}_{data_sect}.csv") + + +def predict_boundaries(predictor: SpanPredictor, tokens, predicate_span, frame): + boundaries, labels, _ = predictor.force_decode( + tokens, parent_span=predicate_span, parent_label=frame + ) + out = [] + for bnd, lab in zip(boundaries, labels): + bnd = tuple(bnd) + if bnd == predicate_span and lab == "Target": + continue + out.append(bnd) + return out + + +def get_gold_boundaries(annotation, predicate_span): + return { + tuple(c["span"]) + for c in annotation["children"] + if not (tuple(c["span"]) == predicate_span and c["label"] == "Target") + } + + +def eval_boundary_detection(data_file, predictor, run=1, verbose=False): + + assert run in [1, 2] + + true_pos = 0 + false_pos = 0 + false_neg = 0 + + true_pos_tok = 0 + false_pos_tok = 0 + false_neg_tok = 0 + + with open(data_file, encoding="utf-8") as f: + for sent_id, sent in enumerate(f): + sent_data = json.loads(sent) + + tokens = sent_data["tokens"] + annotation = sent_data["annotations"][0] + + predicate_span = tuple(annotation["span"]) + predicate = tokens[predicate_span[0] : predicate_span[1] + 1] + + 
if run == 1: + frame = predict_frame(predictor, tokens, predicate_span) + else: + frame = annotation["label"] + + boundaries_gold = get_gold_boundaries(annotation, predicate_span) + boundaries_pred = set( + predict_boundaries(predictor, tokens, predicate_span, frame) + ) + + sent_true_pos = len(boundaries_gold & boundaries_pred) + sent_false_pos = len(boundaries_pred - boundaries_gold) + sent_false_neg = len(boundaries_gold - boundaries_pred) + true_pos += sent_true_pos + false_pos += sent_false_pos + false_neg += sent_false_neg + + boundary_toks_gold = { + tok_idx + for (start, stop) in boundaries_gold + for tok_idx in range(start, stop + 1) + } + boundary_toks_pred = { + tok_idx + for (start, stop) in boundaries_pred + for tok_idx in range(start, stop + 1) + } + sent_tok_true_pos = len(boundary_toks_gold & boundary_toks_pred) + sent_tok_false_pos = len(boundary_toks_pred - boundary_toks_gold) + sent_tok_false_neg = len(boundary_toks_gold - boundary_toks_pred) + true_pos_tok += sent_tok_true_pos + false_pos_tok += sent_tok_false_pos + false_neg_tok += sent_tok_false_neg + + if verbose: + print(f"Sentence #{sent_id:03}: {' '.join(tokens)}") + print(f"\tpredicate: {predicate}") + print(f"\t frame: {frame}") + print(f"\t gold: {boundaries_gold}") + print(f"\tpredicted: {boundaries_pred}") + print(f"\ttp={sent_true_pos}\tfp={sent_false_pos}\tfn={sent_false_neg}") + print( + f"\ttp_t={sent_tok_true_pos}\tfp_t={sent_tok_false_pos}\tfn_t={sent_tok_false_neg}" + ) + print() + + prec = true_pos / (true_pos + false_pos) + rec = true_pos / (true_pos + false_neg) + f1_score = 2 * ((prec * rec) / (prec + rec)) + + print(f"P/R/F=\n{prec}\t{rec}\t{f1_score}") + + tok_prec = true_pos_tok / (true_pos_tok + false_pos_tok) + tok_rec = true_pos_tok / (true_pos_tok + false_neg_tok) + tok_f1 = 2 * ((tok_prec * tok_rec) / (tok_prec + tok_rec)) + + print(f"Pt/Rt/Ft=\n{tok_prec}\t{tok_rec}\t{tok_f1}") + + +def predict_arguments( + predictor: SpanPredictor, tokens, predicate_span, frame, 
boundaries +): + boundaries = list(sorted(boundaries, key=lambda t: t[0])) + _, labels, _ = predictor.force_decode( + tokens, parent_span=predicate_span, parent_label=frame, child_spans=boundaries + ) + out = [] + for bnd, lab in zip(boundaries, labels): + if bnd == predicate_span and lab == "Target": + continue + out.append((bnd, lab)) + return out + + +def eval_argument_classification(data_file, predictor, run=1, verbose=False): + assert run in [1, 2, 3] + + true_pos = 0 + false_pos = 0 + false_neg = 0 + + true_pos_tok = 0 + false_pos_tok = 0 + false_neg_tok = 0 + + with open(data_file, encoding="utf-8") as f: + for sent_id, sent in enumerate(f): + sent_data = json.loads(sent) + + tokens = sent_data["tokens"] + annotation = sent_data["annotations"][0] + + predicate_span = tuple(annotation["span"]) + predicate = tokens[predicate_span[0] : predicate_span[1] + 1] + + # gold or predicted frames? + if run == 1: + frame = predict_frame(predictor, tokens, predicate_span) + else: + frame = annotation["label"] + + # gold or predicted argument boundaries? 
+ if run in [1, 2]: + boundaries = set( + predict_boundaries(predictor, tokens, predicate_span, frame) + ) + else: + boundaries = get_gold_boundaries(annotation, predicate_span) + + pred_arguments = predict_arguments( + predictor, tokens, predicate_span, frame, boundaries + ) + gold_arguments = { + (tuple(c["span"]), c["label"]) + for c in annotation["children"] + if not (tuple(c["span"]) == predicate_span and c["label"] == "Target") + } + + if verbose: + print(f"Sentence #{sent_id:03}: {' '.join(tokens)}") + print(f"\tpredicate: {predicate}") + print(f"\t frame: {frame}") + print(f"\t gold: {gold_arguments}") + print(f"\tpredicted: {pred_arguments}") + print() + + # -- full spans version + for g_bnd, g_label in gold_arguments: + # true positive: found the span and labeled it correctly + if (g_bnd, g_label) in pred_arguments: + true_pos += 1 + # false negative: missed this argument + else: + false_neg += 1 + for p_bnd, p_label in pred_arguments: + # all predictions that are not true positives are false positives + if (p_bnd, p_label) not in gold_arguments: + false_pos += 1 + + # -- token based + tok_gold_labels = { + (token, label) + for ((bnd_start, bnd_end), label) in gold_arguments + for token in range(bnd_start, bnd_end + 1) + } + tok_pred_labels = { + (token, label) + for ((bnd_start, bnd_end), label) in pred_arguments + for token in range(bnd_start, bnd_end + 1) + } + for g_tok, g_tok_label in tok_gold_labels: + if (g_tok, g_tok_label) in tok_pred_labels: + true_pos_tok += 1 + else: + false_neg_tok += 1 + for p_tok, p_tok_label in tok_pred_labels: + if (p_tok, p_tok_label) not in tok_gold_labels: + false_pos_tok += 1 + + prec = true_pos / (true_pos + false_pos) + rec = true_pos / (true_pos + false_neg) + f1_score = 2 * ((prec * rec) / (prec + rec)) + + print(f"P/R/F=\n{prec}\t{rec}\t{f1_score}") + + tok_prec = true_pos_tok / (true_pos_tok + false_pos_tok) + tok_rec = true_pos_tok / (true_pos_tok + false_neg_tok) + tok_f1 = 2 * ((tok_prec * tok_rec) / 
(tok_prec + tok_rec)) + + print(f"Pt/Rt/Ft=\n{tok_prec}\t{tok_rec}\t{tok_f1}") + + +if __name__ == "__main__": + main() diff --git a/spanfinder/sociolome/lome_webserver.py b/spanfinder/sociolome/lome_webserver.py new file mode 100644 index 0000000000000000000000000000000000000000..adb82a0d1c61cb643da94297679166af647576b5 --- /dev/null +++ b/spanfinder/sociolome/lome_webserver.py @@ -0,0 +1,116 @@ +from sftp import SpanPredictor +import spacy + +from flask import Flask, request, render_template, jsonify, redirect, abort, session + +import sys +import dataclasses +from typing import List, Optional, Dict, Any + + +# --- NLP code --- + +@dataclasses.dataclass +class FrameAnnotation: + tokens: List[str] = dataclasses.field(default_factory=list) + pos: List[str] = dataclasses.field(default_factory=list) + + +@dataclasses.dataclass +class MultiLabelAnnotation(FrameAnnotation): + frame_list: List[List[str]] = dataclasses.field(default_factory=list) + lu_list: List[Optional[str]] = dataclasses.field(default_factory=list) + + def to_txt(self): + for i, tok in enumerate(self.tokens): + yield f"{tok} {self.pos[i]} {'|'.join(self.frame_list[i]) or '_'} {self.lu_list[i] or '_'}" + + +# reused from "combine_predictions.py" (cloned/lome/src/spanfinder/sociolome) +def convert_to_seq_labels(sentence: List[str], structures: Dict[int, Dict[str, Any]]) -> List[List[str]]: + labels = [[] for _ in sentence] + + for struct_id, struct in structures.items(): + tgt_span = struct["target"] + frame = struct["frame"] + for i in range(tgt_span[0], tgt_span[1] + 1): + if i >= len(labels): + continue + labels[i].append(f"T:{frame}@{struct_id:02}") + for role in struct["roles"]: + role_span = role["boundary"] + role_label = role["label"] + for i in range(role_span[0], role_span[1] + 1): + if i >= len(labels): + continue + prefix = "B" if i == role_span[0] else "I" + labels[i].append(f"{prefix}:{frame}:{role_label}@{struct_id:02}") + return labels + +def make_prediction(sentence, spacy_model, 
predictor): + spacy_doc = spacy_model(sentence) + tokens = [t.text for t in spacy_doc] + tgt_spans, fr_labels, fr_probas = predictor.force_decode(tokens) + + frame_structures = {} + + for i, (tgt, frm, fr_proba) in enumerate(sorted(zip(tgt_spans, fr_labels, fr_probas), key=lambda t: t[0][0])): + if frm.startswith("@@"): + continue + if frm.upper() == frm: + continue + if fr_proba.max() != 1.0: + continue + + arg_spans, arg_labels, label_probas = predictor.force_decode(tokens, parent_span=tgt, parent_label=frm) + + frame_structures[i] = { + "target": tgt, + "frame": frm, + "roles": [ + {"boundary": bnd, "label": label} + for bnd, label, probas in zip(arg_spans, arg_labels, label_probas) + if label != "Target" and max(probas) == 1.0 + ] + } + + return MultiLabelAnnotation( + tokens=tokens, + pos=[t.pos_ for t in spacy_doc], + frame_list=convert_to_seq_labels(tokens, frame_structures), + lu_list=[None for _ in tokens] + ) + + +#predictor = SpanPredictor.from_path("model.kicktionary.mod.tar.gz") +predictor = SpanPredictor.from_path("model.mod.tar.gz") +nlp = spacy.load("it_core_news_md") + + +# --- FLASK code --- + + +app = Flask(__name__) + + +@app.route("/analyze") +def analyze(): + text = request.args.get("text") + analyses = [] + for sentence in text.split("\n"): + analyses.append(make_prediction(sentence, nlp, predictor)) + + return jsonify({ + "result": "OK", + "analyses": [dataclasses.asdict(an) for an in analyses] + }) + + + +if __name__ == "__main__": + if len(sys.argv) > 1: + host = sys.argv[1] + else: + host = "127.0.0.1" + + app.run(host=host, debug=False, port=9090) diff --git a/spanfinder/tools/demo/flask_backend.py b/spanfinder/tools/demo/flask_backend.py new file mode 100644 index 0000000000000000000000000000000000000000..c51604f1599683432b6ffa7f6b91404e5b9586e8 --- /dev/null +++ b/spanfinder/tools/demo/flask_backend.py @@ -0,0 +1,93 @@ +from argparse import ArgumentParser +from typing import * + +from flask import Flask +from flask import request + 
+from sftp import SpanPredictor, Span + +parser = ArgumentParser() +parser.add_argument('model', metavar='MODEL_PATH', type=str) +parser.add_argument('-p', metavar='PORT', type=int, default=7749) +parser.add_argument('-d', metavar='DEVICE', type=int, default=-1) +args = parser.parse_args() + +template = open('tools/demo/flask_template.html').read() +predictor = SpanPredictor.from_path(args.model, cuda_device=args.d) +app = Flask(__name__) +default_sentence = '因为 आरजू です vegan , هي купил soja .' + + +def visualized_prediction(inputs: List[str], prediction: Span, prefix=''): + spans = list() + span2event = [[] for _ in inputs] + for event_idx, event in enumerate(prediction): + for arg_idx, arg in enumerate(event): + for token_idx in range(arg.start_idx, arg.end_idx+1): + span2event[token_idx].append((event_idx, arg_idx)) + + for token_idx, token in enumerate(inputs): + class_labels = ' '.join( + ['token'] + [f'{prefix}-arg-{event_idx}-{arg_idx}' for event_idx, arg_idx in span2event[token_idx]] + ) + spans.append(f'{token} \n') + + for event_idx, event in enumerate(prediction): + spans[event.start_idx] = ( + f'' + '' + f'' + + spans[event.start_idx] + ) + spans[event.end_idx] += f'
{event.label}
' + arg_tips = [] + for arg_idx, arg in enumerate(event): + arg_tips.append(f'{arg.label}') + if len(arg_tips) > 0: + arg_tips = '
'.join(arg_tips) + spans[event.end_idx] += f'{arg_tips}\n' + spans[event.end_idx] += '\n
' + return( + '
\n' + + '\n'.join(spans) + '\n
' + ) + + +def structured_prediction(inputs, prediction): + ret = list() + for event in prediction: + event_text, event_label = ' '.join(inputs[event.start_idx: event.end_idx+1]), event.label + ret.append(f'
  • ' + f'{event_label}: {event_text}
  • ') + for arg in event: + arg_text = ' '.join(inputs[arg.start_idx: arg.end_idx+1]) + ret.append( + f'
  •     {arg.label}: {arg_text}
  • ' + ) + content = '\n'.join(ret) + return '\n
      \n' + content + '\n
    ' + + +@app.route('/') +def sftp(): + ret = template + tokens = request.args.get('sentence') + if tokens is not None: + ret = ret.replace('DEFAULT_SENTENCE', tokens) + sentences = tokens.split('\n') + model_outputs = predictor.predict_batch_sentences(sentences, max_tokens=512) + vis_pred, str_pred = list(), list() + for sent_idx, output in enumerate(model_outputs): + vis_pred.append(visualized_prediction(output.sentence, output.span, f'sent{sent_idx}')) + str_pred.append(structured_prediction(output.sentence, output.span)) + ret = ret.replace('VISUALIZED_PREDICTION', '
    '.join(vis_pred)) + ret = ret.replace('STRUCTURED_PREDICTION', '
    '.join(str_pred)) + else: + ret = ret.replace('DEFAULT_SENTENCE', default_sentence) + ret = ret.replace('VISUALIZED_PREDICTION', '') + ret = ret.replace('STRUCTURED_PREDICTION', '') + return ret + + +app.run(port=args.p) diff --git a/spanfinder/tools/demo/flask_template.html b/spanfinder/tools/demo/flask_template.html new file mode 100644 index 0000000000000000000000000000000000000000..081eb6c01a9ed5d1c252b9896f37e90a6cee42fb --- /dev/null +++ b/spanfinder/tools/demo/flask_template.html @@ -0,0 +1,688 @@ + + + + + + FrameNet Parser + + + + + + + + + + + +
    +
    +
    +
    +
    +
    +
    +
    + +
    +
    +
    + +
    +
    Visualized Output
    +
    + VISUALIZED_PREDICTION +
    +
    + +
    +
    Structured Output
    +
    + STRUCTURED_PREDICTION +
    +
    + +
    + + + diff --git a/spanfinder/tools/framenet/__init__.py b/spanfinder/tools/framenet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/spanfinder/tools/framenet/__pycache__/__init__.cpython-38.pyc b/spanfinder/tools/framenet/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6cd7e8fcd4c4aff0b727f199d759c24a7cb0a40c Binary files /dev/null and b/spanfinder/tools/framenet/__pycache__/__init__.cpython-38.pyc differ diff --git a/spanfinder/tools/framenet/__pycache__/__init__.cpython-39.pyc b/spanfinder/tools/framenet/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..304fc8b5b52f1e273b183f979be36761a553333e Binary files /dev/null and b/spanfinder/tools/framenet/__pycache__/__init__.cpython-39.pyc differ diff --git a/spanfinder/tools/framenet/__pycache__/concrete_fn.cpython-39.pyc b/spanfinder/tools/framenet/__pycache__/concrete_fn.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eacbfaa67e313afa7c56692ad3b4defa44261db1 Binary files /dev/null and b/spanfinder/tools/framenet/__pycache__/concrete_fn.cpython-39.pyc differ diff --git a/spanfinder/tools/framenet/__pycache__/fn_util.cpython-39.pyc b/spanfinder/tools/framenet/__pycache__/fn_util.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c1dfbd8a467af4f97d3ee5a890415105755eb37 Binary files /dev/null and b/spanfinder/tools/framenet/__pycache__/fn_util.cpython-39.pyc differ diff --git a/spanfinder/tools/framenet/__pycache__/gen_fn_data.cpython-39.pyc b/spanfinder/tools/framenet/__pycache__/gen_fn_data.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f9507fccb5a2ff0c1e19ff57a047b0328fea4af7 Binary files /dev/null and b/spanfinder/tools/framenet/__pycache__/gen_fn_data.cpython-39.pyc differ diff --git 
a/spanfinder/tools/framenet/__pycache__/nltk_framenet.cpython-39.pyc b/spanfinder/tools/framenet/__pycache__/nltk_framenet.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a4f2810e4cc968126d3571f1d9b6ab5de88d4f80 Binary files /dev/null and b/spanfinder/tools/framenet/__pycache__/nltk_framenet.cpython-39.pyc differ diff --git a/spanfinder/tools/framenet/__pycache__/retokenize_fn.cpython-39.pyc b/spanfinder/tools/framenet/__pycache__/retokenize_fn.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e6518471e7cc50a137a44ef4c369eefdc94ede3 Binary files /dev/null and b/spanfinder/tools/framenet/__pycache__/retokenize_fn.cpython-39.pyc differ diff --git a/spanfinder/tools/framenet/concrete_fn.py b/spanfinder/tools/framenet/concrete_fn.py new file mode 100644 index 0000000000000000000000000000000000000000..192bc77349028f39e8349004a132980843808fee --- /dev/null +++ b/spanfinder/tools/framenet/concrete_fn.py @@ -0,0 +1,117 @@ +import argparse +import os +from collections import defaultdict +from typing import Dict, Any + +from concrete.util import CommunicationWriterTGZ +from nltk.corpus import framenet, framenet15 +from tqdm import tqdm + +from sftp.data_reader.concrete_srl import concrete_doc +from tools.framenet.fn_util import framenet_split, Sentence as TokSentence + + +def process_sentence(sent) -> Dict[str, Any]: + ret = {'sentence': sent.text, 'tokenization': list(), 'annotations': list()} + tok_sent = TokSentence(sent.text) + for token in tok_sent.tokens: + ret['tokenization'].append((token.idx, token.idx_end-1)) + + def process_one_ann_set(ann_set): + ret['annotations'].append(event := {'label': ann_set.frame.name, 'children': (arg_list := list())}) + target_list = list() + for tar_start, tar_end in ann_set.Target: + target_list.extend( + list(range(tok_sent.span(tar_start, tar_end)[0], tok_sent.span(tar_start, tar_end)[1]+1)) + ) + target_list.sort() + event['span'] = (target_list[0], 
target_list[-1]) + + for fe_start, fe_end, fe_name in ann_set.FE[0]: + fe_start, fe_end = tok_sent.span(fe_start, fe_end) + arg_list.append({ + 'span': (fe_start, fe_end), + 'label': fe_name + }) + + if 'annotationSet' in sent: + for ann_item in sent.annotationSet: + if 'Target' not in ann_item: + continue + process_one_ann_set(ann_item) + if 'Target' in sent: + process_one_ann_set(sent) + + return ret + + +def process_doc(docs, dst_path: str): + writer = CommunicationWriterTGZ(dst_path) + for doc in tqdm(docs): + sentences = list() + for sent in doc.sentence: + sentences.append(process_sentence(sent)) + comm = concrete_doc(sentences, doc.filename) + writer.write(comm, comm.id + '.concrete') + writer.close() + + +def process_exemplar(dst_path, fn): + bar = tqdm() + raw_annotations = list() + print('Loading exemplars...') + try: + for ann_sent in fn.annotations(full_text=False): + if 'Target' not in ann_sent: + continue + bar.update() + raw_annotations.append(ann_sent) + except RuntimeError: + pass + finally: + bar.close() + + char_idx_offset = 0 + sentences = list() + for sent in raw_annotations: + sentences.append(process_sentence(sent)) + char_idx_offset += len(sent.text)+1 + + comm = concrete_doc(sentences, 'exemplar') + CommunicationWriterTGZ(dst_path).write(comm, 'exemplar.concrete') + + +def run(): + parser = argparse.ArgumentParser() + parser.add_argument( + 'dst', metavar='DESTINATION', type=str, + help='Destination folder path.' + ) + parser.add_argument( + '-v', metavar='VERSION', default='1.7', type=str, choices=['1.5', '1.7'], + help='Version of FrameNet. Either 1.5 or 1.7.' 
+ ) + args = parser.parse_args() + fn = framenet if args.v == '1.7' else framenet15 + os.makedirs(args.dst, exist_ok=True) + + doc_group = defaultdict(list) + for doc in fn.docs(): + if doc.filename in framenet_split['dev']: + doc_group['dev'].append(doc) + elif doc.filename in framenet_split['test']: + doc_group['test'].append(doc) + else: + doc_group['train'].append(doc) + + for sp in framenet_split: + print(f'Loaded {len(doc_group[sp])} docs for {sp}.') + + for sp in framenet_split: + process_doc(doc_group[sp], dst_path=os.path.join(args.dst, f'{sp}.tar.gz')) + + process_exemplar(os.path.join(args.dst, 'exemplar.tar.gz'), fn) + + +if __name__ == '__main__': + run() diff --git a/spanfinder/tools/framenet/fn_util.py b/spanfinder/tools/framenet/fn_util.py new file mode 100644 index 0000000000000000000000000000000000000000..656330468aa9b0cde9e44eb1a85e429c3dc37f7c --- /dev/null +++ b/spanfinder/tools/framenet/fn_util.py @@ -0,0 +1,145 @@ +from allennlp.data.tokenizers.spacy_tokenizer import SpacyTokenizer + +framenet_split = { + "train": [ + "LUCorpus-v0.3__CNN_AARONBROWN_ENG_20051101_215800.partial-NEW.xml", + "NTI__Iran_Chemical.xml", + "NTI__Taiwan_Introduction.xml", + "LUCorpus-v0.3__20000416_xin_eng-NEW.xml", + "NTI__NorthKorea_ChemicalOverview.xml", + "NTI__workAdvances.xml", + "C-4__C-4Text.xml", + "ANC__IntroOfDublin.xml", + "LUCorpus-v0.3__20000420_xin_eng-NEW.xml", + "NTI__BWTutorial_chapter1.xml", + "ANC__110CYL068.xml", + "LUCorpus-v0.3__artb_004_A1_E1_NEW.xml", + "NTI__Iran_Missile.xml", + "LUCorpus-v0.3__20000424_nyt-NEW.xml", + "LUCorpus-v0.3__wsj_1640.mrg-NEW.xml", + "ANC__110CYL070.xml", + "NTI__Iran_Introduction.xml", + "KBEval__lcch.xml", + "ANC__HistoryOfLasVegas.xml", + "LUCorpus-v0.3__wsj_2465.xml", + "KBEval__LCC-M.xml", + "LUCorpus-v0.3__artb_004_A1_E2_NEW.xml", + "LUCorpus-v0.3__AFGP-2002-600002-Trans.xml", + "LUCorpus-v0.3__602CZL285-1.xml", + "PropBank__LomaPrieta.xml", + "NTI__Iran_Biological.xml", + "NTI__Kazakhstan.xml", + 
"LUCorpus-v0.3__AFGP-2002-600045-Trans.xml", + "NTI__Iran_Nuclear.xml", + "ANC__EntrepreneurAsMadonna.xml", + "SemAnno__Text1.xml", + "ANC__HistoryOfJerusalem.xml", + "NTI__ChinaOverview.xml", + "PropBank__ElectionVictory.xml", + "NTI__Russia_Introduction.xml", + "NTI__SouthAfrica_Introduction.xml", + "LUCorpus-v0.3__20000419_apw_eng-NEW.xml", + "NTI__LibyaCountry1.xml", + "ANC__IntroJamaica.xml", + "QA__IranRelatedQuestions.xml", + "ANC__HistoryOfGreece.xml", + "NTI__NorthKorea_NuclearCapabilities.xml", + "PropBank__BellRinging.xml", + "PropBank__PolemicProgressiveEducation.xml", + "NTI__WMDNews_042106.xml", + "ANC__110CYL200.xml", + "LUCorpus-v0.3__CNN_ENG_20030614_173123.4-NEW-1.xml" + ], + + "dev": [ + "NTI__WMDNews_062606.xml", + "LUCorpus-v0.3__ENRON-pearson-email-25jul02.xml", + "KBEval__MIT.xml", + "ANC__110CYL072.xml", + "LUCorpus-v0.3__20000415_apw_eng-NEW.xml", + "Miscellaneous__Hijack.xml", + "PropBank__TicketSplitting.xml", + "NTI__NorthKorea_NuclearOverview.xml" + ], + + "test": [ + "NTI__NorthKorea_Introduction.xml", + "LUCorpus-v0.3__enron-thread-159550.xml", + "ANC__WhereToHongKong.xml", + "KBEval__atm.xml", + "ANC__112C-L013.xml", + "LUCorpus-v0.3__IZ-060316-01-Trans-1.xml", + "LUCorpus-v0.3__AFGP-2002-602187-Trans.xml", + "ANC__StephanopoulosCrimes.xml", + "ANC__110CYL069.xml", + "ANC__110CYL067.xml", + "ANC__IntroHongKong.xml", + "LUCorpus-v0.3__20000410_nyt-NEW.xml", + "KBEval__Brandeis.xml", + "KBEval__Stanford.xml", + "LUCorpus-v0.3__SNO-525.xml", + "PropBank__AetnaLifeAndCasualty.xml", + "Miscellaneous__Hound-Ch14.xml", + "NTI__Syria_NuclearOverview.xml", + "KBEval__cycorp.xml", + "KBEval__utd-icsi.xml", + "LUCorpus-v0.3__sw2025-ms98-a-trans.ascii-1-NEW.xml", + "Miscellaneous__SadatAssassination.xml", + "KBEval__parc.xml" + ] +} + +_spacy_tokenizer = SpacyTokenizer(language='en_core_web_sm', pos_tags=True) + + +class Sentence: + def __init__(self, text): + """ + Re-tokenize sentence. Map character indices to token indices. 
+ We assume the char and token span indices are left inclusive and right inclusive. + """ + self.tokens = _spacy_tokenizer.tokenize(text) + + @property + def pos(self): + return [t.pos_ for t in self.tokens] + + @property + def tag(self): + return [t.tag_ for t in self.tokens] + + @property + def starts(self): + return [t.idx for t in self.tokens] + + @property + def ends(self): + return [t.idx_end for t in self.tokens] + + def char2token(self, char_idx): + """ + If char_idx falls into the a token, return the index of this token. + Elif char_idx falls into the gap between 2 tokens, return the index of the previous token. + Elif char_idx is lower than the first token, return 0. + Elif return the index of the last token. + """ + if char_idx < self.starts[0]: + return 0 + if char_idx >= self.starts[-1]: + return len(self.tokens)-1 + for i_tok, start_idx in enumerate(self.starts): + if start_idx == char_idx: + return i_tok + if start_idx > char_idx: + return i_tok-1 + + def span(self, start, end): + # Left inclusive, right inclusive + assert end > start + start, end = self.char2token(start), self.char2token(end-1) + assert end >= start + return start, end + + def __repr__(self): + return self.tokens.__repr__() + diff --git a/spanfinder/tools/framenet/gen_fn_data.py b/spanfinder/tools/framenet/gen_fn_data.py new file mode 100644 index 0000000000000000000000000000000000000000..ad1c9534ca7e9f0596abf3ffe94da511a02f23d8 --- /dev/null +++ b/spanfinder/tools/framenet/gen_fn_data.py @@ -0,0 +1,37 @@ +import json +import os +from argparse import ArgumentParser + +from tools.framenet.retokenize_fn import load_nltk_exemplars, load_nltk_fully_annotated + + +def main(src_path, dst_path): + if src_path is not None: + full = json.load(open(os.path.join(src_path, 'full.17.json'))) + exe = json.load(open(os.path.join(src_path, 'exe.17.json'))) + else: + full = load_nltk_fully_annotated('1.7') + exe = load_nltk_exemplars('1.7') + train, dev, test = full['train'], full['dev'], 
full['test'] + + def dump(train_set, path): + os.makedirs(path, exist_ok=True) + for split, data_set in zip(['train', 'dev', 'test'], [train_set, dev, test]): + open(os.path.join(path, split+'.jsonl'), 'w').write('\n'.join(map(json.dumps, data_set))) + open(os.path.join(path, 'full.jsonl'), 'w').write('\n'.join(map(json.dumps, train_set+dev+test))) + + # Full text only + dump(train, os.path.join(dst_path, 'full')) + # Full test + exemplar + dump(train+exe, os.path.join(dst_path, 'full_exe')) + + +if __name__ == '__main__': + parser = ArgumentParser() + parser.add_argument('dst', metavar='destination') + parser.add_argument( + '-s', metavar='data', default=None, + help='Path to retokenized framenet. If not provided, will re-load.' + ) + cmd_args = parser.parse_args() + main(cmd_args.s, cmd_args.dst) diff --git a/spanfinder/tools/framenet/naive_identifier.py b/spanfinder/tools/framenet/naive_identifier.py new file mode 100644 index 0000000000000000000000000000000000000000..3520f95f1064f5ceab159888dd5e59441ff6804d --- /dev/null +++ b/spanfinder/tools/framenet/naive_identifier.py @@ -0,0 +1,108 @@ +from collections import defaultdict +from itertools import product +from typing import * + +import nltk +from nltk.corpus import framenet, framenet15 +from nltk.stem import WordNetLemmatizer + +lemmatizer = WordNetLemmatizer() + + +manual = { + '\'s': 'be', + '\'re': 'be', + '\'ve': 'have', + 've': 'have', + 'men': 'man', + 'saw': 'see', + 'could': 'can', + 'neighbour': 'neighbor', + 'felt': 'feel', + 'fell': 'fall', + 'little': 'a little', + 'have': 'have to', + 'raping': 'rape', + 'flavor': 'flavour', + 'ca': 'can', + 'bit': 'a bit', +} + + +def load_framenet_corpus(version): + if '1.5' in version: + nltk.download('framenet_v15') + return framenet15 + elif '1.7' in version: + nltk.download('framenet_v17') + return framenet + else: + raise NotImplementedError + + +def is_word(s: str): + return all([c.isalpha() or c in ' -\'' for c in s]) + + +def lu_to_frame(version: str): 
+ fn = load_framenet_corpus(version) + fn._bad_statuses = [] + map_no_pos = defaultdict(set) + lexicon_set = set() + for frame in fn.frames(): + for lu in frame.lexUnit: + assert lu.count('.') == 1 + lexicon, pos = lu.split('.') + lexicon = lexicon.lower() + lexicon = ' '.join(filter(lambda x: is_word(x), lexicon.split())) + if lexicon == '': + continue + map_no_pos[lexicon].add(frame.name) + lexicon_set.add(lexicon) + fn._bad_statuses = [] + return map_no_pos + + +class FrameIdentifier: + def __init__(self): + lf_map = lu_to_frame('1.7') + lf_map['there have'].add('Existence') + lf_map['there must'].add('Existence') + lf_map['be there'].add('Existence') + self.lf_map = dict(lf_map) + + def __call__(self, tokens: List[str]): + if len(tokens) == 1 and tokens[0].isnumeric(): + return ['Cardinal_numbers'] + if len(tokens) == 1 and tokens[0].endswith('th') and tokens[0][:-2].isnumeric(): + return ['Ordinal_numbers'] + tokens = [t.lower() for t in tokens] + frames = list() + + if not all([is_word(t) for t in tokens]): + return [] + + for i, token in enumerate(tokens): + t2s = [token] + for _pos in 'asrnv': + t2s.append(lemmatizer.lemmatize(token, _pos)) + for t_ in t2s: + if t_ in manual: + t2s.append(manual[t_]) + t2s = list(set(t2s)) + tokens[i] = t2s + + for t2s in tokens: + for t in t2s: + key = t + if key in self.lf_map: + for f in self.lf_map[key]: + frames.append(f) + for t1, t2 in zip(tokens, tokens[1:]): + for ts in product(t1, t2): + t = ' '.join(ts) + if t in self.lf_map: + for f in self.lf_map[t]: + frames.append(f) + + return list(set(frames)) diff --git a/spanfinder/tools/framenet/nltk_framenet.py b/spanfinder/tools/framenet/nltk_framenet.py new file mode 100644 index 0000000000000000000000000000000000000000..4672eb1259fde8a24f524a5a6a6f4389f39561a3 --- /dev/null +++ b/spanfinder/tools/framenet/nltk_framenet.py @@ -0,0 +1,3478 @@ + +# Modified version of nltk framenet +# Natural Language Toolkit: Framenet Corpus Reader +# +# Copyright (C) 2001-2020 NLTK 
Project +# Authors: Chuck Wooters , +# Nathan Schneider +# URL: +# For license information, see LICENSE.TXT + + +""" +Corpus reader for the FrameNet 1.7 lexicon and corpus. +""" + +import os +import re +import textwrap +import itertools +import sys +import types +from collections import defaultdict, OrderedDict +from operator import itemgetter +from itertools import zip_longest + +from pprint import pprint + +from nltk.corpus.reader import XMLCorpusReader, XMLCorpusView +from nltk.corpus.util import LazyCorpusLoader +from nltk.util import LazyConcatenation, LazyMap, LazyIteratorList + +__docformat__ = "epytext en" + + +def mimic_wrap(lines, wrap_at=65, **kwargs): + """ + Wrap the first of 'lines' with textwrap and the remaining lines at exactly the same + positions as the first. + """ + l0 = textwrap.fill(lines[0], wrap_at, drop_whitespace=False).split("\n") + yield l0 + + def _(line): + il0 = 0 + while line and il0 < len(l0) - 1: + yield line[: len(l0[il0])] + line = line[len(l0[il0]) :] + il0 += 1 + if line: # Remaining stuff on this line past the end of the mimicked line. + # So just textwrap this line. + for ln in textwrap.fill(line, wrap_at, drop_whitespace=False).split("\n"): + yield ln + + for l in lines[1:]: + yield list(_(l)) + + +def _pretty_longstring(defstr, prefix="", wrap_at=65): + + """ + Helper function for pretty-printing a long string. + + :param defstr: The string to be printed. + :type defstr: str + :return: A nicely formated string representation of the long string. + :rtype: str + """ + + outstr = "" + for line in textwrap.fill(defstr, wrap_at).split("\n"): + outstr += prefix + line + "\n" + return outstr + + +def _pretty_any(obj): + + """ + Helper function for pretty-printing any AttrDict object. + + :param obj: The obj to be printed. + :type obj: AttrDict + :return: A nicely formated string representation of the AttrDict object. 
+ :rtype: str + """ + + outstr = "" + for k in obj: + if isinstance(obj[k], str) and len(obj[k]) > 65: + outstr += "[{0}]\n".format(k) + outstr += "{0}".format(_pretty_longstring(obj[k], prefix=" ")) + outstr += "\n" + else: + outstr += "[{0}] {1}\n".format(k, obj[k]) + + return outstr + + +def _pretty_semtype(st): + + """ + Helper function for pretty-printing a semantic type. + + :param st: The semantic type to be printed. + :type st: AttrDict + :return: A nicely formated string representation of the semantic type. + :rtype: str + """ + + semkeys = st.keys() + if len(semkeys) == 1: + return "" + + outstr = "" + outstr += "semantic type ({0.ID}): {0.name}\n".format(st) + if "abbrev" in semkeys: + outstr += "[abbrev] {0}\n".format(st.abbrev) + if "definition" in semkeys: + outstr += "[definition]\n" + outstr += _pretty_longstring(st.definition, " ") + outstr += "[rootType] {0}({1})\n".format(st.rootType.name, st.rootType.ID) + if st.superType is None: + outstr += "[superType] \n" + else: + outstr += "[superType] {0}({1})\n".format(st.superType.name, st.superType.ID) + outstr += "[subTypes] {0} subtypes\n".format(len(st.subTypes)) + outstr += ( + " " + + ", ".join("{0}({1})".format(x.name, x.ID) for x in st.subTypes) + + "\n" * (len(st.subTypes) > 0) + ) + return outstr + + +def _pretty_frame_relation_type(freltyp): + + """ + Helper function for pretty-printing a frame relation type. + + :param freltyp: The frame relation type to be printed. + :type freltyp: AttrDict + :return: A nicely formated string representation of the frame relation type. + :rtype: str + """ + outstr = " {0.subFrameName}>".format( + freltyp + ) + return outstr + + +def _pretty_frame_relation(frel): + + """ + Helper function for pretty-printing a frame relation. + + :param frel: The frame relation to be printed. + :type frel: AttrDict + :return: A nicely formated string representation of the frame relation. 
+ :rtype: str + """ + outstr = "<{0.type.superFrameName}={0.superFrameName} -- {0.type.name} -> {0.type.subFrameName}={0.subFrameName}>".format( + frel + ) + return outstr + + +def _pretty_fe_relation(ferel): + + """ + Helper function for pretty-printing an FE relation. + + :param ferel: The FE relation to be printed. + :type ferel: AttrDict + :return: A nicely formated string representation of the FE relation. + :rtype: str + """ + outstr = "<{0.type.superFrameName}={0.frameRelation.superFrameName}.{0.superFEName} -- {0.type.name} -> {0.type.subFrameName}={0.frameRelation.subFrameName}.{0.subFEName}>".format( + ferel + ) + return outstr + + +def _pretty_lu(lu): + + """ + Helper function for pretty-printing a lexical unit. + + :param lu: The lu to be printed. + :type lu: AttrDict + :return: A nicely formated string representation of the lexical unit. + :rtype: str + """ + + lukeys = lu.keys() + outstr = "" + outstr += "lexical unit ({0.ID}): {0.name}\n\n".format(lu) + if "definition" in lukeys: + outstr += "[definition]\n" + outstr += _pretty_longstring(lu.definition, " ") + if "frame" in lukeys: + outstr += "\n[frame] {0}({1})\n".format(lu.frame.name, lu.frame.ID) + if "incorporatedFE" in lukeys: + outstr += "\n[incorporatedFE] {0}\n".format(lu.incorporatedFE) + if "POS" in lukeys: + outstr += "\n[POS] {0}\n".format(lu.POS) + if "status" in lukeys: + outstr += "\n[status] {0}\n".format(lu.status) + if "totalAnnotated" in lukeys: + outstr += "\n[totalAnnotated] {0} annotated examples\n".format( + lu.totalAnnotated + ) + if "lexemes" in lukeys: + outstr += "\n[lexemes] {0}\n".format( + " ".join("{0}/{1}".format(lex.name, lex.POS) for lex in lu.lexemes) + ) + if "semTypes" in lukeys: + outstr += "\n[semTypes] {0} semantic types\n".format(len(lu.semTypes)) + outstr += ( + " " * (len(lu.semTypes) > 0) + + ", ".join("{0}({1})".format(x.name, x.ID) for x in lu.semTypes) + + "\n" * (len(lu.semTypes) > 0) + ) + if "URL" in lukeys: + outstr += "\n[URL] {0}\n".format(lu.URL) 
+ if "subCorpus" in lukeys: + subc = [x.name for x in lu.subCorpus] + outstr += "\n[subCorpus] {0} subcorpora\n".format(len(lu.subCorpus)) + for line in textwrap.fill(", ".join(sorted(subc)), 60).split("\n"): + outstr += " {0}\n".format(line) + if "exemplars" in lukeys: + outstr += "\n[exemplars] {0} sentences across all subcorpora\n".format( + len(lu.exemplars) + ) + + return outstr + + +def _pretty_exemplars(exemplars, lu): + """ + Helper function for pretty-printing a list of exemplar sentences for a lexical unit. + + :param sent: The list of exemplar sentences to be printed. + :type sent: list(AttrDict) + :return: An index of the text of the exemplar sentences. + :rtype: str + """ + + outstr = "" + outstr += "exemplar sentences for {0.name} in {0.frame.name}:\n\n".format(lu) + for i, sent in enumerate(exemplars): + outstr += "[{0}] {1}\n".format(i, sent.text) + outstr += "\n" + return outstr + + +def _pretty_fulltext_sentences(sents): + """ + Helper function for pretty-printing a list of annotated sentences for a full-text document. + + :param sent: The list of sentences to be printed. + :type sent: list(AttrDict) + :return: An index of the text of the sentences. + :rtype: str + """ + + outstr = "" + outstr += "full-text document ({0.ID}) {0.name}:\n\n".format(sents) + outstr += "[corpid] {0.corpid}\n[corpname] {0.corpname}\n[description] {0.description}\n[URL] {0.URL}\n\n".format( + sents + ) + outstr += "[sentence]\n".format(sents) + for i, sent in enumerate(sents.sentence): + outstr += "[{0}] {1}\n".format(i, sent.text) + outstr += "\n" + return outstr + + +def _pretty_fulltext_sentence(sent): + """ + Helper function for pretty-printing an annotated sentence from a full-text document. + + :param sent: The sentence to be printed. + :type sent: list(AttrDict) + :return: The text of the sentence with annotation set indices on frame targets. 
+ :rtype: str + """ + + outstr = "" + outstr += "full-text sentence ({0.ID}) in {1}:\n\n".format( + sent, sent.doc.get("name", sent.doc.description) + ) + outstr += "\n[POS] {0} tags\n".format(len(sent.POS)) + outstr += "\n[POS_tagset] {0}\n\n".format(sent.POS_tagset) + outstr += "[text] + [annotationSet]\n\n" + outstr += sent._ascii() # -> _annotation_ascii() + outstr += "\n" + return outstr + + +def _pretty_pos(aset): + """ + Helper function for pretty-printing a sentence with its POS tags. + + :param aset: The POS annotation set of the sentence to be printed. + :type sent: list(AttrDict) + :return: The text of the sentence and its POS tags. + :rtype: str + """ + + outstr = "" + outstr += "POS annotation set ({0.ID}) {0.POS_tagset} in sentence {0.sent.ID}:\n\n".format( + aset + ) + + # list the target spans and their associated aset index + overt = sorted(aset.POS) + + sent = aset.sent + s0 = sent.text + s1 = "" + s2 = "" + i = 0 + adjust = 0 + for j, k, lbl in overt: + assert j >= i, ("Overlapping targets?", (j, k, lbl)) + s1 += " " * (j - i) + "-" * (k - j) + if len(lbl) > (k - j): + # add space in the sentence to make room for the annotation index + amt = len(lbl) - (k - j) + s0 = ( + s0[: k + adjust] + "~" * amt + s0[k + adjust :] + ) # '~' to prevent line wrapping + s1 = s1[: k + adjust] + " " * amt + s1[k + adjust :] + adjust += amt + s2 += " " * (j - i) + lbl.ljust(k - j) + i = k + + long_lines = [s0, s1, s2] + + outstr += "\n\n".join( + map("\n".join, zip_longest(*mimic_wrap(long_lines), fillvalue=" ")) + ).replace("~", " ") + outstr += "\n" + return outstr + + +def _pretty_annotation(sent, aset_level=False): + """ + Helper function for pretty-printing an exemplar sentence for a lexical unit. + + :param sent: An annotation set or exemplar sentence to be printed. + :param aset_level: If True, 'sent' is actually an annotation set within a sentence. 
+ :type sent: AttrDict + :return: A nicely formated string representation of the exemplar sentence + with its target, frame, and FE annotations. + :rtype: str + """ + + sentkeys = sent.keys() + outstr = "annotation set" if aset_level else "exemplar sentence" + outstr += " ({0.ID}):\n".format(sent) + if aset_level: # TODO: any UNANN exemplars? + outstr += "\n[status] {0}\n".format(sent.status) + for k in ("corpID", "docID", "paragNo", "sentNo", "aPos"): + if k in sentkeys: + outstr += "[{0}] {1}\n".format(k, sent[k]) + outstr += ( + "\n[LU] ({0.ID}) {0.name} in {0.frame.name}\n".format(sent.LU) + if sent.LU + else "\n[LU] Not found!" + ) + outstr += "\n[frame] ({0.ID}) {0.name}\n".format( + sent.frame + ) # redundant with above, but .frame is convenient + if not aset_level: + outstr += "\n[annotationSet] {0} annotation sets\n".format( + len(sent.annotationSet) + ) + outstr += "\n[POS] {0} tags\n".format(len(sent.POS)) + outstr += "\n[POS_tagset] {0}\n".format(sent.POS_tagset) + outstr += "\n[GF] {0} relation{1}\n".format( + len(sent.GF), "s" if len(sent.GF) != 1 else "" + ) + outstr += "\n[PT] {0} phrase{1}\n".format( + len(sent.PT), "s" if len(sent.PT) != 1 else "" + ) + """ + Special Layers + -------------- + + The 'NER' layer contains, for some of the data, named entity labels. + + The 'WSL' (word status layer) contains, for some of the data, + spans which should not in principle be considered targets (NT). + + The 'Other' layer records relative clause constructions (Rel=relativizer, Ant=antecedent), + pleonastic 'it' (Null), and existential 'there' (Exist). + On occasion they are duplicated by accident (e.g., annotationSet 1467275 in lu6700.xml). 
+ + The 'Sent' layer appears to contain labels that the annotator has flagged the + sentence with for their convenience: values include + 'sense1', 'sense2', 'sense3', etc.; + 'Blend', 'Canonical', 'Idiom', 'Metaphor', 'Special-Sent', + 'keepS', 'deleteS', 'reexamine' + (sometimes they are duplicated for no apparent reason). + + The POS-specific layers may contain the following kinds of spans: + Asp (aspectual particle), Non-Asp (non-aspectual particle), + Cop (copula), Supp (support), Ctrlr (controller), + Gov (governor), X. Gov and X always cooccur. + + >>> from nltk.corpus import framenet as fn + >>> def f(luRE, lyr, ignore=set()): + ... for i,ex in enumerate(fn.exemplars(luRE)): + ... if lyr in ex and ex[lyr] and set(zip(*ex[lyr])[2]) - ignore: + ... print(i,ex[lyr]) + + - Verb: Asp, Non-Asp + - Noun: Cop, Supp, Ctrlr, Gov, X + - Adj: Cop, Supp, Ctrlr, Gov, X + - Prep: Cop, Supp, Ctrlr + - Adv: Ctrlr + - Scon: (none) + - Art: (none) + """ + for lyr in ("NER", "WSL", "Other", "Sent"): + if lyr in sent and sent[lyr]: + outstr += "\n[{0}] {1} entr{2}\n".format( + lyr, len(sent[lyr]), "ies" if len(sent[lyr]) != 1 else "y" + ) + outstr += "\n[text] + [Target] + [FE]" + # POS-specific layers: syntactically important words that are neither the target + # nor the FEs. Include these along with the first FE layer but with '^' underlining. + for lyr in ("Verb", "Noun", "Adj", "Adv", "Prep", "Scon", "Art"): + if lyr in sent and sent[lyr]: + outstr += " + [{0}]".format(lyr) + if "FE2" in sentkeys: + outstr += " + [FE2]" + if "FE3" in sentkeys: + outstr += " + [FE3]" + outstr += "\n\n" + outstr += sent._ascii() # -> _annotation_ascii() + outstr += "\n" + + return outstr + + +def _annotation_ascii(sent): + """ + Given a sentence or FE annotation set, construct the width-limited string showing + an ASCII visualization of the sentence's annotations, calling either + _annotation_ascii_frames() or _annotation_ascii_FEs() as appropriate. 
+ This will be attached as a method to appropriate AttrDict instances + and called in the full pretty-printing of the instance. + """ + if sent._type == "fulltext_sentence" or ( + "annotationSet" in sent and len(sent.annotationSet) > 2 + ): + # a full-text sentence OR sentence with multiple targets. + # (multiple targets = >2 annotation sets, because the first annotation set is POS.) + return _annotation_ascii_frames(sent) + else: # an FE annotation set, or an LU sentence with 1 target + return _annotation_ascii_FEs(sent) + + +def _annotation_ascii_frames(sent): + """ + ASCII string rendering of the sentence along with its targets and frame names. + Called for all full-text sentences, as well as the few LU sentences with multiple + targets (e.g., fn.lu(6412).exemplars[82] has two want.v targets). + Line-wrapped to limit the display width. + """ + # list the target spans and their associated aset index + overt = [] + for a, aset in enumerate(sent.annotationSet[1:]): + for j, k in aset.Target: + indexS = "[{0}]".format(a + 1) + if aset.status == "UNANN" or aset.LU.status == "Problem": + indexS += " " + if aset.status == "UNANN": + indexS += ( + "!" + ) # warning indicator that there is a frame annotation but no FE annotation + if aset.LU.status == "Problem": + indexS += ( + "?" + ) # warning indicator that there is a missing LU definition (because the LU has Problem status) + overt.append((j, k, aset.LU.frame.name, indexS)) + overt = sorted(overt) + + duplicates = set() + for o, (j, k, fname, asetIndex) in enumerate(overt): + if o > 0 and j <= overt[o - 1][1]: + # multiple annotation sets on the same target + # (e.g. due to a coordination construction or multiple annotators) + if ( + overt[o - 1][:2] == (j, k) and overt[o - 1][2] == fname + ): # same target, same frame + # splice indices together + combinedIndex = ( + overt[o - 1][3] + asetIndex + ) # e.g., '[1][2]', '[1]! [2]' + combinedIndex = combinedIndex.replace(" !", "! ").replace(" ?", "? 
") + overt[o - 1] = overt[o - 1][:3] + (combinedIndex,) + duplicates.add(o) + else: # different frames, same or overlapping targets + s = sent.text + for j, k, fname, asetIndex in overt: + s += "\n" + asetIndex + " " + sent.text[j:k] + " :: " + fname + s += "\n(Unable to display sentence with targets marked inline due to overlap)" + return s + for o in reversed(sorted(duplicates)): + del overt[o] + + s0 = sent.text + s1 = "" + s11 = "" + s2 = "" + i = 0 + adjust = 0 + fAbbrevs = OrderedDict() + for j, k, fname, asetIndex in overt: + if not j >= i: + assert j >= i, ( + "Overlapping targets?" + + ( + " UNANN" + if any(aset.status == "UNANN" for aset in sent.annotationSet[1:]) + else "" + ), + (j, k, asetIndex), + ) + s1 += " " * (j - i) + "*" * (k - j) + short = fname[: k - j] + if (k - j) < len(fname): + r = 0 + while short in fAbbrevs: + if fAbbrevs[short] == fname: + break + r += 1 + short = fname[: k - j - 1] + str(r) + else: # short not in fAbbrevs + fAbbrevs[short] = fname + s11 += " " * (j - i) + short.ljust(k - j) + if len(asetIndex) > (k - j): + # add space in the sentence to make room for the annotation index + amt = len(asetIndex) - (k - j) + s0 = ( + s0[: k + adjust] + "~" * amt + s0[k + adjust :] + ) # '~' to prevent line wrapping + s1 = s1[: k + adjust] + " " * amt + s1[k + adjust :] + s11 = s11[: k + adjust] + " " * amt + s11[k + adjust :] + adjust += amt + s2 += " " * (j - i) + asetIndex.ljust(k - j) + i = k + + long_lines = [s0, s1, s11, s2] + + outstr = "\n\n".join( + map("\n".join, zip_longest(*mimic_wrap(long_lines), fillvalue=" ")) + ).replace("~", " ") + outstr += "\n" + if fAbbrevs: + outstr += " (" + ", ".join("=".join(pair) for pair in fAbbrevs.items()) + ")" + assert len(fAbbrevs) == len(dict(fAbbrevs)), "Abbreviation clash" + + return outstr + + +def _annotation_ascii_FE_layer(overt, ni, feAbbrevs): + """Helper for _annotation_ascii_FEs().""" + s1 = "" + s2 = "" + i = 0 + for j, k, fename in overt: + s1 += " " * (j - i) + ("^" if 
fename.islower() else "-") * (k - j) + short = fename[: k - j] + if len(fename) > len(short): + r = 0 + while short in feAbbrevs: + if feAbbrevs[short] == fename: + break + r += 1 + short = fename[: k - j - 1] + str(r) + else: # short not in feAbbrevs + feAbbrevs[short] = fename + s2 += " " * (j - i) + short.ljust(k - j) + i = k + + sNI = "" + if ni: + sNI += " [" + ", ".join(":".join(x) for x in sorted(ni.items())) + "]" + return [s1, s2, sNI] + + +def _annotation_ascii_FEs(sent): + """ + ASCII string rendering of the sentence along with a single target and its FEs. + Secondary and tertiary FE layers are included if present. + 'sent' can be an FE annotation set or an LU sentence with a single target. + Line-wrapped to limit the display width. + """ + feAbbrevs = OrderedDict() + posspec = [] # POS-specific layer spans (e.g., Supp[ort], Cop[ula]) + posspec_separate = False + for lyr in ("Verb", "Noun", "Adj", "Adv", "Prep", "Scon", "Art"): + if lyr in sent and sent[lyr]: + for a, b, lbl in sent[lyr]: + if ( + lbl == "X" + ): # skip this, which covers an entire phrase typically containing the target and all its FEs + # (but do display the Gov) + continue + if any(1 for x, y, felbl in sent.FE[0] if x <= a < y or a <= x < b): + # overlap between one of the POS-specific layers and first FE layer + posspec_separate = ( + True + ) # show POS-specific layers on a separate line + posspec.append( + (a, b, lbl.lower().replace("-", "")) + ) # lowercase Cop=>cop, Non-Asp=>nonasp, etc. 
to distinguish from FE names + if posspec_separate: + POSSPEC = _annotation_ascii_FE_layer(posspec, {}, feAbbrevs) + FE1 = _annotation_ascii_FE_layer( + sorted(sent.FE[0] + (posspec if not posspec_separate else [])), + sent.FE[1], + feAbbrevs, + ) + FE2 = FE3 = None + if "FE2" in sent: + FE2 = _annotation_ascii_FE_layer(sent.FE2[0], sent.FE2[1], feAbbrevs) + if "FE3" in sent: + FE3 = _annotation_ascii_FE_layer(sent.FE3[0], sent.FE3[1], feAbbrevs) + + for i, j in sent.Target: + FE1span, FE1name, FE1exp = FE1 + if len(FE1span) < j: + FE1span += " " * (j - len(FE1span)) + if len(FE1name) < j: + FE1name += " " * (j - len(FE1name)) + FE1[1] = FE1name + FE1[0] = ( + FE1span[:i] + FE1span[i:j].replace(" ", "*").replace("-", "=") + FE1span[j:] + ) + long_lines = [sent.text] + if posspec_separate: + long_lines.extend(POSSPEC[:2]) + long_lines.extend([FE1[0], FE1[1] + FE1[2]]) # lines with no length limit + if FE2: + long_lines.extend([FE2[0], FE2[1] + FE2[2]]) + if FE3: + long_lines.extend([FE3[0], FE3[1] + FE3[2]]) + long_lines.append("") + outstr = "\n".join( + map("\n".join, zip_longest(*mimic_wrap(long_lines), fillvalue=" ")) + ) + if feAbbrevs: + outstr += "(" + ", ".join("=".join(pair) for pair in feAbbrevs.items()) + ")" + assert len(feAbbrevs) == len(dict(feAbbrevs)), "Abbreviation clash" + outstr += "\n" + + return outstr + + +def _pretty_fe(fe): + + """ + Helper function for pretty-printing a frame element. + + :param fe: The frame element to be printed. + :type fe: AttrDict + :return: A nicely formated string representation of the frame element. 
+ :rtype: str + """ + fekeys = fe.keys() + outstr = "" + outstr += "frame element ({0.ID}): {0.name}\n of {1.name}({1.ID})\n".format( + fe, fe.frame + ) + if "definition" in fekeys: + outstr += "[definition]\n" + outstr += _pretty_longstring(fe.definition, " ") + if "abbrev" in fekeys: + outstr += "[abbrev] {0}\n".format(fe.abbrev) + if "coreType" in fekeys: + outstr += "[coreType] {0}\n".format(fe.coreType) + if "requiresFE" in fekeys: + outstr += "[requiresFE] " + if fe.requiresFE is None: + outstr += "\n" + else: + outstr += "{0}({1})\n".format(fe.requiresFE.name, fe.requiresFE.ID) + if "excludesFE" in fekeys: + outstr += "[excludesFE] " + if fe.excludesFE is None: + outstr += "\n" + else: + outstr += "{0}({1})\n".format(fe.excludesFE.name, fe.excludesFE.ID) + if "semType" in fekeys: + outstr += "[semType] " + if fe.semType is None: + outstr += "\n" + else: + outstr += "\n " + "{0}({1})".format(fe.semType.name, fe.semType.ID) + "\n" + + return outstr + + +def _pretty_frame(frame): + + """ + Helper function for pretty-printing a frame. + + :param frame: The frame to be printed. + :type frame: AttrDict + :return: A nicely formated string representation of the frame. 
+ :rtype: str + """ + + outstr = "" + outstr += "frame ({0.ID}): {0.name}\n\n".format(frame) + outstr += "[URL] {0}\n\n".format(frame.URL) + outstr += "[definition]\n" + outstr += _pretty_longstring(frame.definition, " ") + "\n" + + outstr += "[semTypes] {0} semantic types\n".format(len(frame.semTypes)) + outstr += ( + " " * (len(frame.semTypes) > 0) + + ", ".join("{0}({1})".format(x.name, x.ID) for x in frame.semTypes) + + "\n" * (len(frame.semTypes) > 0) + ) + + outstr += "\n[frameRelations] {0} frame relations\n".format( + len(frame.frameRelations) + ) + outstr += " " + "\n ".join(repr(frel) for frel in frame.frameRelations) + "\n" + + outstr += "\n[lexUnit] {0} lexical units\n".format(len(frame.lexUnit)) + lustrs = [] + for luName, lu in sorted(frame.lexUnit.items()): + tmpstr = "{0} ({1})".format(luName, lu.ID) + lustrs.append(tmpstr) + outstr += "{0}\n".format(_pretty_longstring(", ".join(lustrs), prefix=" ")) + + outstr += "\n[FE] {0} frame elements\n".format(len(frame.FE)) + fes = {} + for feName, fe in sorted(frame.FE.items()): + try: + fes[fe.coreType].append("{0} ({1})".format(feName, fe.ID)) + except KeyError: + fes[fe.coreType] = [] + fes[fe.coreType].append("{0} ({1})".format(feName, fe.ID)) + for ct in sorted( + fes.keys(), + key=lambda ct2: [ + "Core", + "Core-Unexpressed", + "Peripheral", + "Extra-Thematic", + ].index(ct2), + ): + outstr += "{0:>16}: {1}\n".format(ct, ", ".join(sorted(fes[ct]))) + + outstr += "\n[FEcoreSets] {0} frame element core sets\n".format( + len(frame.FEcoreSets) + ) + outstr += ( + " " + + "\n ".join( + ", ".join([x.name for x in coreSet]) for coreSet in frame.FEcoreSets + ) + + "\n" + ) + + return outstr + + +class FramenetError(Exception): + + """An exception class for framenet-related errors.""" + + +class AttrDict(dict): + + """A class that wraps a dict and allows accessing the keys of the + dict as if they were attributes. 
Taken from here: + http://stackoverflow.com/a/14620633/8879 + + >>> foo = {'a':1, 'b':2, 'c':3} + >>> bar = AttrDict(foo) + >>> pprint(dict(bar)) + {'a': 1, 'b': 2, 'c': 3} + >>> bar.b + 2 + >>> bar.d = 4 + >>> pprint(dict(bar)) + {'a': 1, 'b': 2, 'c': 3, 'd': 4} + """ + + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + # self.__dict__ = self + + def __setattr__(self, name, value): + self[name] = value + + def __getattr__(self, name): + if name == "_short_repr": + return self._short_repr + return self[name] + + def __getitem__(self, name): + v = super(AttrDict, self).__getitem__(name) + if isinstance(v, Future): + return v._data() + return v + + def _short_repr(self): + if "_type" in self: + if self["_type"].endswith("relation"): + return self.__repr__() + try: + return "<{0} ID={1} name={2}>".format( + self["_type"], self["ID"], self["name"] + ) + except KeyError: + try: # no ID--e.g., for _type=lusubcorpus + return "<{0} name={1}>".format(self["_type"], self["name"]) + except KeyError: # no name--e.g., for _type=lusentence + return "<{0} ID={1}>".format(self["_type"], self["ID"]) + else: + return self.__repr__() + + def _str(self): + outstr = "" + + if "_type" not in self: + outstr = _pretty_any(self) + elif self["_type"] == "frame": + outstr = _pretty_frame(self) + elif self["_type"] == "fe": + outstr = _pretty_fe(self) + elif self["_type"] == "lu": + outstr = _pretty_lu(self) + elif self["_type"] == "luexemplars": # list of ALL exemplars for LU + outstr = _pretty_exemplars(self, self[0].LU) + elif ( + self["_type"] == "fulltext_annotation" + ): # list of all sentences for full-text doc + outstr = _pretty_fulltext_sentences(self) + elif self["_type"] == "lusentence": + outstr = _pretty_annotation(self) + elif self["_type"] == "fulltext_sentence": + outstr = _pretty_fulltext_sentence(self) + elif self["_type"] in ("luannotationset", "fulltext_annotationset"): + outstr = _pretty_annotation(self, aset_level=True) + elif 
self["_type"] == "posannotationset": + outstr = _pretty_pos(self) + elif self["_type"] == "semtype": + outstr = _pretty_semtype(self) + elif self["_type"] == "framerelationtype": + outstr = _pretty_frame_relation_type(self) + elif self["_type"] == "framerelation": + outstr = _pretty_frame_relation(self) + elif self["_type"] == "ferelation": + outstr = _pretty_fe_relation(self) + else: + outstr = _pretty_any(self) + + # ensure result is unicode string prior to applying the + # decorator (because non-ASCII characters + # could in principle occur in the data and would trigger an encoding error when + # passed as arguments to str.format()). + # assert isinstance(outstr, unicode) # not in Python 3.2 + return outstr + + def __str__(self): + return self._str() + + def __repr__(self): + return self.__str__() + + +class SpecialList(list): + """ + A list subclass which adds a '_type' attribute for special printing + (similar to an AttrDict, though this is NOT an AttrDict subclass). + """ + + def __init__(self, typ, *args, **kwargs): + super(SpecialList, self).__init__(*args, **kwargs) + self._type = typ + + def _str(self): + outstr = "" + + assert self._type + if len(self) == 0: + outstr = "[]" + elif self._type == "luexemplars": # list of ALL exemplars for LU + outstr = _pretty_exemplars(self, self[0].LU) + else: + assert False, self._type + return outstr + + def __str__(self): + return self._str() + + def __repr__(self): + return self.__str__() + + +class Future(object): + """ + Wraps and acts as a proxy for a value to be loaded lazily (on demand). 
+ Adapted from https://gist.github.com/sergey-miryanov/2935416 + """ + + def __init__(self, loader, *args, **kwargs): + """ + :param loader: when called with no arguments, returns the value to be stored + :type loader: callable + """ + super(Future, self).__init__(*args, **kwargs) + self._loader = loader + self._d = None + + def _data(self): + if callable(self._loader): + self._d = self._loader() + self._loader = None # the data is now cached + return self._d + + def __nonzero__(self): + return bool(self._data()) + + def __len__(self): + return len(self._data()) + + def __setitem__(self, key, value): + return self._data().__setitem__(key, value) + + def __getitem__(self, key): + return self._data().__getitem__(key) + + def __getattr__(self, key): + return self._data().__getattr__(key) + + def __str__(self): + return self._data().__str__() + + def __repr__(self): + return self._data().__repr__() + + +class PrettyDict(AttrDict): + """ + Displays an abbreviated repr of values where possible. + Inherits from AttrDict, so a callable value will + be lazily converted to an actual value. + """ + + def __init__(self, *args, **kwargs): + _BREAK_LINES = kwargs.pop("breakLines", False) + super(PrettyDict, self).__init__(*args, **kwargs) + dict.__setattr__(self, "_BREAK_LINES", _BREAK_LINES) + + def __repr__(self): + parts = [] + for k, v in sorted(self.items()): + kv = repr(k) + ": " + try: + kv += v._short_repr() + except AttributeError: + kv += repr(v) + parts.append(kv) + return "{" + (",\n " if self._BREAK_LINES else ", ").join(parts) + "}" + + +class PrettyList(list): + """ + Displays an abbreviated repr of only the first several elements, not the whole list. 
+ """ + + # from nltk.util + def __init__(self, *args, **kwargs): + self._MAX_REPR_SIZE = kwargs.pop("maxReprSize", 60) + self._BREAK_LINES = kwargs.pop("breakLines", False) + super(PrettyList, self).__init__(*args, **kwargs) + + def __repr__(self): + """ + Return a string representation for this corpus view that is + similar to a list's representation; but if it would be more + than 60 characters long, it is truncated. + """ + pieces = [] + length = 5 + + for elt in self: + pieces.append( + elt._short_repr() + ) # key difference from inherited version: call to _short_repr() + length += len(pieces[-1]) + 2 + if self._MAX_REPR_SIZE and length > self._MAX_REPR_SIZE and len(pieces) > 2: + return "[%s, ...]" % str( + ",\n " if self._BREAK_LINES else ", " + ).join(pieces[:-1]) + return "[%s]" % str(",\n " if self._BREAK_LINES else ", ").join(pieces) + + +class PrettyLazyMap(LazyMap): + """ + Displays an abbreviated repr of only the first several elements, not the whole list. + """ + + # from nltk.util + _MAX_REPR_SIZE = 60 + + def __repr__(self): + """ + Return a string representation for this corpus view that is + similar to a list's representation; but if it would be more + than 60 characters long, it is truncated. + """ + pieces = [] + length = 5 + for elt in self: + pieces.append( + elt._short_repr() + ) # key difference from inherited version: call to _short_repr() + length += len(pieces[-1]) + 2 + if length > self._MAX_REPR_SIZE and len(pieces) > 2: + return "[%s, ...]" % str(", ").join(pieces[:-1]) + return "[%s]" % str(", ").join(pieces) + + +class PrettyLazyIteratorList(LazyIteratorList): + """ + Displays an abbreviated repr of only the first several elements, not the whole list. + """ + + # from nltk.util + _MAX_REPR_SIZE = 60 + + def __repr__(self): + """ + Return a string representation for this corpus view that is + similar to a list's representation; but if it would be more + than 60 characters long, it is truncated. 
+ """ + pieces = [] + length = 5 + for elt in self: + pieces.append( + elt._short_repr() + ) # key difference from inherited version: call to _short_repr() + length += len(pieces[-1]) + 2 + if length > self._MAX_REPR_SIZE and len(pieces) > 2: + return "[%s, ...]" % str(", ").join(pieces[:-1]) + return "[%s]" % str(", ").join(pieces) + + +class PrettyLazyConcatenation(LazyConcatenation): + """ + Displays an abbreviated repr of only the first several elements, not the whole list. + """ + + # from nltk.util + _MAX_REPR_SIZE = 60 + + def __repr__(self): + """ + Return a string representation for this corpus view that is + similar to a list's representation; but if it would be more + than 60 characters long, it is truncated. + """ + pieces = [] + length = 5 + for elt in self: + pieces.append( + elt._short_repr() + ) # key difference from inherited version: call to _short_repr() + length += len(pieces[-1]) + 2 + if length > self._MAX_REPR_SIZE and len(pieces) > 2: + return "[%s, ...]" % str(", ").join(pieces[:-1]) + return "[%s]" % str(", ").join(pieces) + + def __add__(self, other): + """Return a list concatenating self with other.""" + return PrettyLazyIteratorList(itertools.chain(self, other)) + + def __radd__(self, other): + """Return a list concatenating other with self.""" + return PrettyLazyIteratorList(itertools.chain(other, self)) + + +class FramenetCorpusReader(XMLCorpusReader): + """A corpus reader for the Framenet Corpus. + + >>> from nltk.corpus import framenet as fn + >>> fn.lu(3238).frame.lexUnit['glint.v'] is fn.lu(3238) + True + >>> fn.frame_by_name('Replacing') is fn.lus('replace.v')[0].frame + True + >>> fn.lus('prejudice.n')[0].frame.frameRelations == fn.frame_relations('Partiality') + True + """ + + _bad_statuses = ["Problem"] + """ + When loading LUs for a frame, those whose status is in this list will be ignored. + Due to caching, if user code modifies this, it should do so before loading any data. 
+ 'Problem' should always be listed for FrameNet 1.5, as these LUs are not included + in the XML index. + """ + + _warnings = False + + def warnings(self, v): + """Enable or disable warnings of data integrity issues as they are encountered. + If v is truthy, warnings will be enabled. + + (This is a function rather than just an attribute/property to ensure that if + enabling warnings is the first action taken, the corpus reader is instantiated first.) + """ + self._warnings = v + + def __init__(self, root, fileids): + XMLCorpusReader.__init__(self, root, fileids) + + # framenet corpus sub dirs + # sub dir containing the xml files for frames + self._frame_dir = "frame" + # sub dir containing the xml files for lexical units + self._lu_dir = "lu" + # sub dir containing the xml files for fulltext annotation files + self._fulltext_dir = "fulltext" + + # location of latest development version of FrameNet + self._fnweb_url = "https://framenet2.icsi.berkeley.edu/fnReports/data" + + # Indexes used for faster look-ups + self._frame_idx = None + self._cached_frames = {} # name -> ID + self._lu_idx = None + self._fulltext_idx = None + self._semtypes = None + self._freltyp_idx = None # frame relation types (Inheritance, Using, etc.) + self._frel_idx = None # frame-to-frame relation instances + self._ferel_idx = None # FE-to-FE relation instances + self._frel_f_idx = None # frame-to-frame relations associated with each frame + + def help(self, attrname=None): + """Display help information summarizing the main methods.""" + + if attrname is not None: + return help(self.__getattribute__(attrname)) + + # No need to mention frame_by_name() or frame_by_id(), + # as it's easier to just call frame(). + # Also not mentioning lu_basic(). + + msg = """ +Citation: Nathan Schneider and Chuck Wooters (2017), +"The NLTK FrameNet API: Designing for Discoverability with a Rich Linguistic Resource". +Proceedings of EMNLP: System Demonstrations. 
https://arxiv.org/abs/1703.07438 + +Use the following methods to access data in FrameNet. +Provide a method name to `help()` for more information. + +FRAMES +====== + +frame() to look up a frame by its exact name or ID +frames() to get frames matching a name pattern +frames_by_lemma() to get frames containing an LU matching a name pattern +frame_ids_and_names() to get a mapping from frame IDs to names + +FRAME ELEMENTS +============== + +fes() to get frame elements (a.k.a. roles) matching a name pattern, optionally constrained + by a frame name pattern + +LEXICAL UNITS +============= + +lu() to look up an LU by its ID +lus() to get lexical units matching a name pattern, optionally constrained by frame +lu_ids_and_names() to get a mapping from LU IDs to names + +RELATIONS +========= + +frame_relation_types() to get the different kinds of frame-to-frame relations + (Inheritance, Subframe, Using, etc.). +frame_relations() to get the relation instances, optionally constrained by + frame(s) or relation type +fe_relations() to get the frame element pairs belonging to a frame-to-frame relation + +SEMANTIC TYPES +============== + +semtypes() to get the different kinds of semantic types that can be applied to + FEs, LUs, and entire frames +semtype() to look up a particular semtype by name, ID, or abbreviation +semtype_inherits() to check whether two semantic types have a subtype-supertype + relationship in the semtype hierarchy +propagate_semtypes() to apply inference rules that distribute semtypes over relations + between FEs + +ANNOTATIONS +=========== + +annotations() to get annotation sets, in which a token in a sentence is annotated + with a lexical unit in a frame, along with its frame elements and their syntactic properties; + can be constrained by LU name pattern and limited to lexicographic exemplars or full-text. + Sentences of full-text annotation can have multiple annotation sets. 
+sents() to get annotated sentences illustrating one or more lexical units +exemplars() to get sentences of lexicographic annotation, most of which have + just 1 annotation set; can be constrained by LU name pattern, frame, and overt FE(s) +doc() to look up a document of full-text annotation by its ID +docs() to get documents of full-text annotation that match a name pattern +docs_metadata() to get metadata about all full-text documents without loading them +ft_sents() to iterate over sentences of full-text annotation + +UTILITIES +========= + +buildindexes() loads metadata about all frames, LUs, etc. into memory to avoid + delay when one is accessed for the first time. It does not load annotations. +readme() gives the text of the FrameNet README file +warnings(True) to display corpus consistency warnings when loading data + """ + print(msg) + + def _buildframeindex(self): + # The total number of Frames in Framenet is fairly small (~1200) so + # this index should not be very large + if not self._frel_idx: + self._buildrelationindex() # always load frame relations before frames, + # otherwise weird ordering effects might result in incomplete information + self._frame_idx = {} + for f in XMLCorpusView( + self.abspath("frameIndex.xml"), "frameIndex/frame", self._handle_elt + ): + self._frame_idx[f["ID"]] = f + + def _buildcorpusindex(self): + # The total number of fulltext annotated documents in Framenet + # is fairly small (~90) so this index should not be very large + self._fulltext_idx = {} + for doclist in XMLCorpusView( + self.abspath("fulltextIndex.xml"), + "fulltextIndex/corpus", + self._handle_fulltextindex_elt, + ): + for doc in doclist: + self._fulltext_idx[doc.ID] = doc + + def _buildluindex(self): + # The number of LUs in Framenet is about 13,000 so this index + # should not be very large + self._lu_idx = {} + for lu in XMLCorpusView( + self.abspath("luIndex.xml"), "luIndex/lu", self._handle_elt + ): + self._lu_idx[ + lu["ID"] + ] = lu # populate with LU 
index entries. if any of these + # are looked up they will be replaced by full LU objects. + + def _buildrelationindex(self): + # print('building relation index...', file=sys.stderr) + freltypes = PrettyList( + x + for x in XMLCorpusView( + self.abspath("frRelation.xml"), + "frameRelations/frameRelationType", + self._handle_framerelationtype_elt, + ) + ) + self._freltyp_idx = {} + self._frel_idx = {} + self._frel_f_idx = defaultdict(set) + self._ferel_idx = {} + + for freltyp in freltypes: + self._freltyp_idx[freltyp.ID] = freltyp + for frel in freltyp.frameRelations: + supF = frel.superFrame = frel[freltyp.superFrameName] = Future( + (lambda fID: lambda: self.frame_by_id(fID))(frel.supID) + ) + subF = frel.subFrame = frel[freltyp.subFrameName] = Future( + (lambda fID: lambda: self.frame_by_id(fID))(frel.subID) + ) + self._frel_idx[frel.ID] = frel + self._frel_f_idx[frel.supID].add(frel.ID) + self._frel_f_idx[frel.subID].add(frel.ID) + for ferel in frel.feRelations: + ferel.superFrame = supF + ferel.subFrame = subF + ferel.superFE = Future( + (lambda fer: lambda: fer.superFrame.FE[fer.superFEName])(ferel) + ) + ferel.subFE = Future( + (lambda fer: lambda: fer.subFrame.FE[fer.subFEName])(ferel) + ) + self._ferel_idx[ferel.ID] = ferel + # print('...done building relation index', file=sys.stderr) + + def _warn(self, *message, **kwargs): + if self._warnings: + kwargs.setdefault("file", sys.stderr) + print(*message, **kwargs) + + def readme(self): + """ + Return the contents of the corpus README.txt (or README) file. + """ + try: + return self.open("README.txt").read() + except IOError: + return self.open("README").read() + + def buildindexes(self): + """ + Build the internal indexes to make look-ups faster. 
+ """ + # Frames + self._buildframeindex() + # LUs + self._buildluindex() + # Fulltext annotation corpora index + self._buildcorpusindex() + # frame and FE relations + self._buildrelationindex() + + def doc(self, fn_docid): + """ + Returns the annotated document whose id number is + ``fn_docid``. This id number can be obtained by calling the + Documents() function. + + The dict that is returned from this function will contain the + following keys: + + - '_type' : 'fulltextannotation' + - 'sentence' : a list of sentences in the document + - Each item in the list is a dict containing the following keys: + - 'ID' : the ID number of the sentence + - '_type' : 'sentence' + - 'text' : the text of the sentence + - 'paragNo' : the paragraph number + - 'sentNo' : the sentence number + - 'docID' : the document ID number + - 'corpID' : the corpus ID number + - 'aPos' : the annotation position + - 'annotationSet' : a list of annotation layers for the sentence + - Each item in the list is a dict containing the following keys: + - 'ID' : the ID number of the annotation set + - '_type' : 'annotationset' + - 'status' : either 'MANUAL' or 'UNANN' + - 'luName' : (only if status is 'MANUAL') + - 'luID' : (only if status is 'MANUAL') + - 'frameID' : (only if status is 'MANUAL') + - 'frameName': (only if status is 'MANUAL') + - 'layer' : a list of labels for the layer + - Each item in the layer is a dict containing the + following keys: + - '_type': 'layer' + - 'rank' + - 'name' + - 'label' : a list of labels in the layer + - Each item is a dict containing the following keys: + - 'start' + - 'end' + - 'name' + - 'feID' (optional) + + :param fn_docid: The Framenet id number of the document + :type fn_docid: int + :return: Information about the annotated document + :rtype: dict + """ + try: + xmlfname = self._fulltext_idx[fn_docid].filename + except TypeError: # happens when self._fulltext_idx == None + # build the index + self._buildcorpusindex() + xmlfname = 
self._fulltext_idx[fn_docid].filename + except KeyError: # probably means that fn_docid was not in the index + raise FramenetError("Unknown document id: {0}".format(fn_docid)) + + # construct the path name for the xml file containing the document info + locpath = os.path.join("{0}".format(self._root), self._fulltext_dir, xmlfname) + + # Grab the top-level xml element containing the fulltext annotation + elt = XMLCorpusView(locpath, "fullTextAnnotation")[0] + info = self._handle_fulltextannotation_elt(elt) + # add metadata + for k, v in self._fulltext_idx[fn_docid].items(): + info[k] = v + return info + + def frame_by_id(self, fn_fid, ignorekeys=[]): + """ + Get the details for the specified Frame using the frame's id + number. + + Usage examples: + + >>> from nltk.corpus import framenet as fn + >>> f = fn.frame_by_id(256) + >>> f.ID + 256 + >>> f.name + 'Medical_specialties' + >>> f.definition + "This frame includes words that name ..." + + :param fn_fid: The Framenet id number of the frame + :type fn_fid: int + :param ignorekeys: The keys to ignore. These keys will not be + included in the output. (optional) + :type ignorekeys: list(str) + :return: Information about a frame + :rtype: dict + + Also see the ``frame()`` function for details about what is + contained in the dict that is returned. + """ + + # get the name of the frame with this id number + try: + fentry = self._frame_idx[fn_fid] + if "_type" in fentry: + return fentry # full frame object is cached + name = fentry["name"] + except TypeError: + self._buildframeindex() + name = self._frame_idx[fn_fid]["name"] + except KeyError: + raise FramenetError("Unknown frame id: {0}".format(fn_fid)) + + return self.frame_by_name(name, ignorekeys, check_cache=False) + + def frame_by_name(self, fn_fname, ignorekeys=[], check_cache=True): + """ + Get the details for the specified Frame using the frame's name. 
+ + Usage examples: + + >>> from nltk.corpus import framenet as fn + >>> f = fn.frame_by_name('Medical_specialties') + >>> f.ID + 256 + >>> f.name + 'Medical_specialties' + >>> f.definition + "This frame includes words that name ..." + + :param fn_fname: The name of the frame + :type fn_fname: str + :param ignorekeys: The keys to ignore. These keys will not be + included in the output. (optional) + :type ignorekeys: list(str) + :return: Information about a frame + :rtype: dict + + Also see the ``frame()`` function for details about what is + contained in the dict that is returned. + """ + + if check_cache and fn_fname in self._cached_frames: + return self._frame_idx[self._cached_frames[fn_fname]] + elif not self._frame_idx: + self._buildframeindex() + + # construct the path name for the xml file containing the Frame info + locpath = os.path.join( + "{0}".format(self._root), self._frame_dir, fn_fname + ".xml" + ) + # print(locpath, file=sys.stderr) + # Grab the xml for the frame + try: + elt = XMLCorpusView(locpath, "frame")[0] + except IOError: + raise FramenetError("Unknown frame: {0}".format(fn_fname)) + + fentry = self._handle_frame_elt(elt, ignorekeys) + assert fentry + + fentry.URL = self._fnweb_url + "/" + self._frame_dir + "/" + fn_fname + ".xml" + + # INFERENCE RULE: propagate lexical semtypes from the frame to all its LUs + for st in fentry.semTypes: + if st.rootType.name == "Lexical_type": + for lu in fentry.lexUnit.values(): + if not any( + x is st for x in lu.semTypes + ): # identity containment check + lu.semTypes.append(st) + + self._frame_idx[fentry.ID] = fentry + self._cached_frames[fentry.name] = fentry.ID + """ + # now set up callables to resolve the LU pointers lazily. + # (could also do this here--caching avoids infinite recursion.) 
+ for luName,luinfo in fentry.lexUnit.items(): + fentry.lexUnit[luName] = (lambda luID: Future(lambda: self.lu(luID)))(luinfo.ID) + """ + return fentry + + def frame(self, fn_fid_or_fname, ignorekeys=[]): + """ + Get the details for the specified Frame using the frame's name + or id number. + + Usage examples: + + >>> from nltk.corpus import framenet as fn + >>> f = fn.frame(256) + >>> f.name + 'Medical_specialties' + >>> f = fn.frame('Medical_specialties') + >>> f.ID + 256 + >>> # ensure non-ASCII character in definition doesn't trigger an encoding error: + >>> fn.frame('Imposing_obligation') + frame (1494): Imposing_obligation... + + The dict that is returned from this function will contain the + following information about the Frame: + + - 'name' : the name of the Frame (e.g. 'Birth', 'Apply_heat', etc.) + - 'definition' : textual definition of the Frame + - 'ID' : the internal ID number of the Frame + - 'semTypes' : a list of semantic types for this frame + - Each item in the list is a dict containing the following keys: + - 'name' : can be used with the semtype() function + - 'ID' : can be used with the semtype() function + + - 'lexUnit' : a dict containing all of the LUs for this frame. + The keys in this dict are the names of the LUs and + the value for each key is itself a dict containing + info about the LU (see the lu() function for more info.) + + - 'FE' : a dict containing the Frame Elements that are part of this frame + The keys in this dict are the names of the FEs (e.g. 'Body_system') + and the values are dicts containing the following keys + - 'definition' : The definition of the FE + - 'name' : The name of the FE e.g. 'Body_system' + - 'ID' : The id number + - '_type' : 'fe' + - 'abbrev' : Abbreviation e.g. 'bod' + - 'coreType' : one of "Core", "Peripheral", or "Extra-Thematic" + - 'semType' : if not None, a dict with the following two keys: + - 'name' : name of the semantic type. 
can be used with + the semtype() function + - 'ID' : id number of the semantic type. can be used with + the semtype() function + - 'requiresFE' : if not None, a dict with the following two keys: + - 'name' : the name of another FE in this frame + - 'ID' : the id of the other FE in this frame + - 'excludesFE' : if not None, a dict with the following two keys: + - 'name' : the name of another FE in this frame + - 'ID' : the id of the other FE in this frame + + - 'frameRelation' : a list of objects describing frame relations + - 'FEcoreSets' : a list of Frame Element core sets for this frame + - Each item in the list is a list of FE objects + + :param fn_fid_or_fname: The Framenet name or id number of the frame + :type fn_fid_or_fname: int or str + :param ignorekeys: The keys to ignore. These keys will not be + included in the output. (optional) + :type ignorekeys: list(str) + :return: Information about a frame + :rtype: dict + """ + + # get the frame info by name or id number + if isinstance(fn_fid_or_fname, str): + f = self.frame_by_name(fn_fid_or_fname, ignorekeys) + else: + f = self.frame_by_id(fn_fid_or_fname, ignorekeys) + + return f + + def frames_by_lemma(self, pat): + """ + Returns a list of all frames that contain LUs in which the + ``name`` attribute of the LU matches the given regular expression + ``pat``. Note that LU names are composed of "lemma.POS", where + the "lemma" part can be made up of either a single lexeme + (e.g. 'run') or multiple lexemes (e.g. 'a little'). + + Note: if you are going to be doing a lot of this type of + searching, you'd want to build an index that maps from lemmas to + frames because each time frames_by_lemma() is called, it has to + search through ALL of the frame XML files in the db. 
+ + >>> from nltk.corpus import framenet as fn + >>> from nltk.corpus.reader.framenet import PrettyList + >>> PrettyList(sorted(fn.frames_by_lemma(r'(?i)a little'), key=itemgetter('ID'))) # doctest: +ELLIPSIS + [, ] + + :return: A list of frame objects. + :rtype: list(AttrDict) + """ + return PrettyList( + f + for f in self.frames() + if any(re.search(pat, luName) for luName in f.lexUnit) + ) + + def lu_basic(self, fn_luid): + """ + Returns basic information about the LU whose id is + ``fn_luid``. This is basically just a wrapper around the + ``lu()`` function with "subCorpus" info excluded. + + >>> from nltk.corpus import framenet as fn + >>> lu = PrettyDict(fn.lu_basic(256), breakLines=True) + >>> # ellipses account for differences between FN 1.5 and 1.7 + >>> lu # doctest: +ELLIPSIS + {'ID': 256, + 'POS': 'V', + 'URL': 'https://framenet2.icsi.berkeley.edu/fnReports/data/lu/lu256.xml', + '_type': 'lu', + 'cBy': ..., + 'cDate': '02/08/2001 01:27:50 PST Thu', + 'definition': 'COD: be aware of beforehand; predict.', + 'definitionMarkup': 'COD: be aware of beforehand; predict.', + 'frame': , + 'lemmaID': 15082, + 'lexemes': [{'POS': 'V', 'breakBefore': 'false', 'headword': 'false', 'name': 'foresee', 'order': 1}], + 'name': 'foresee.v', + 'semTypes': [], + 'sentenceCount': {'annotated': ..., 'total': ...}, + 'status': 'FN1_Sent'} + + :param fn_luid: The id number of the desired LU + :type fn_luid: int + :return: Basic information about the lexical unit + :rtype: dict + """ + return self.lu(fn_luid, ignorekeys=["subCorpus", "exemplars"]) + + def lu(self, fn_luid, ignorekeys=[], luName=None, frameID=None, frameName=None): + """ + Access a lexical unit by its ID. luName, frameID, and frameName are used + only in the event that the LU does not have a file in the database + (which is the case for LUs with "Problem" status); in this case, + a placeholder LU is created which just contains its name, ID, and frame. 
+ + + Usage examples: + + >>> from nltk.corpus import framenet as fn + >>> fn.lu(256).name + 'foresee.v' + >>> fn.lu(256).definition + 'COD: be aware of beforehand; predict.' + >>> fn.lu(256).frame.name + 'Expectation' + >>> pprint(list(map(PrettyDict, fn.lu(256).lexemes))) + [{'POS': 'V', 'breakBefore': 'false', 'headword': 'false', 'name': 'foresee', 'order': 1}] + + >>> fn.lu(227).exemplars[23] + exemplar sentence (352962): + [sentNo] 0 + [aPos] 59699508 + + [LU] (227) guess.v in Coming_to_believe + + [frame] (23) Coming_to_believe + + [annotationSet] 2 annotation sets + + [POS] 18 tags + + [POS_tagset] BNC + + [GF] 3 relations + + [PT] 3 phrases + + [Other] 1 entry + + [text] + [Target] + [FE] + + When he was inside the house , Culley noticed the characteristic + ------------------ + Content + + he would n't have guessed at . + -- ******* -- + Co C1 [Evidence:INI] + (Co=Cognizer, C1=Content) + + + + The dict that is returned from this function will contain most of the + following information about the LU. Note that some LUs do not contain + all of these pieces of information - particularly 'totalAnnotated' and + 'incorporatedFE' may be missing in some LUs: + + - 'name' : the name of the LU (e.g. 'merger.n') + - 'definition' : textual definition of the LU + - 'ID' : the internal ID number of the LU + - '_type' : 'lu' + - 'status' : e.g. 'Created' + - 'frame' : Frame that this LU belongs to + - 'POS' : the part of speech of this LU (e.g. 'N') + - 'totalAnnotated' : total number of examples annotated with this LU + - 'incorporatedFE' : FE that incorporates this LU (e.g. 'Ailment') + - 'sentenceCount' : a dict with the following two keys: + - 'annotated': number of sentences annotated with this LU + - 'total' : total number of sentences with this LU + + - 'lexemes' : a list of dicts describing the lemma of this LU. + Each dict in the list contains these keys: + - 'POS' : part of speech e.g. 'N' + - 'name' : either single-lexeme e.g. 'merger' or + multi-lexeme e.g. 
'a little' + - 'order': the order of the lexeme in the lemma (starting from 1) + - 'headword': a boolean ('true' or 'false') + - 'breakBefore': Can this lexeme be separated from the previous lexeme? + Consider: "take over.v" as in: + Germany took over the Netherlands in 2 days. + Germany took the Netherlands over in 2 days. + In this case, 'breakBefore' would be "true" for the lexeme + "over". Contrast this with "take after.v" as in: + Mary takes after her grandmother. + *Mary takes her grandmother after. + In this case, 'breakBefore' would be "false" for the lexeme "after" + + - 'lemmaID' : Can be used to connect lemmas in different LUs + - 'semTypes' : a list of semantic type objects for this LU + - 'subCorpus' : a list of subcorpora + - Each item in the list is a dict containing the following keys: + - 'name' : + - 'sentence' : a list of sentences in the subcorpus + - each item in the list is a dict with the following keys: + - 'ID': + - 'sentNo': + - 'text': the text of the sentence + - 'aPos': + - 'annotationSet': a list of annotation sets + - each item in the list is a dict with the following keys: + - 'ID': + - 'status': + - 'layer': a list of layers + - each layer is a dict containing the following keys: + - 'name': layer name (e.g. 'BNC') + - 'rank': + - 'label': a list of labels for the layer + - each label is a dict containing the following keys: + - 'start': start pos of label in sentence 'text' (0-based) + - 'end': end pos of label in sentence 'text' (0-based) + - 'name': name of label (e.g. 'NN1') + + Under the hood, this implementation looks up the lexical unit information + in the *frame* definition file. That file does not contain + corpus annotations, so the LU files will be accessed on demand if those are + needed. In principle, valence patterns could be loaded here too, + though these are not currently supported. + + :param fn_luid: The id number of the lexical unit + :type fn_luid: int + :param ignorekeys: The keys to ignore. 
These keys will not be + included in the output. (optional) + :type ignorekeys: list(str) + :return: All information about the lexical unit + :rtype: dict + """ + # look for this LU in cache + if not self._lu_idx: + self._buildluindex() + OOV = object() + luinfo = self._lu_idx.get(fn_luid, OOV) + if luinfo is OOV: + # LU not in the index. We create a placeholder by falling back to + # luName, frameID, and frameName. However, this will not be listed + # among the LUs for its frame. + self._warn( + "LU ID not found: {0} ({1}) in {2} ({3})".format( + luName, fn_luid, frameName, frameID + ) + ) + luinfo = AttrDict( + { + "_type": "lu", + "ID": fn_luid, + "name": luName, + "frameID": frameID, + "status": "Problem", + } + ) + f = self.frame_by_id(luinfo.frameID) + assert f.name == frameName, (f.name, frameName) + luinfo["frame"] = f + self._lu_idx[fn_luid] = luinfo + elif "_type" not in luinfo: + # we only have an index entry for the LU. loading the frame will replace this. + f = self.frame_by_id(luinfo.frameID) + luinfo = self._lu_idx[fn_luid] + if ignorekeys: + return AttrDict( + dict((k, v) for k, v in luinfo.items() if k not in ignorekeys) + ) + + return luinfo + + def _lu_file(self, lu, ignorekeys=[]): + """ + Augment the LU information that was loaded from the frame file + with additional information from the LU file. 
+ """ + fn_luid = lu.ID + + fname = "lu{0}.xml".format(fn_luid) + locpath = os.path.join("{0}".format(self._root), self._lu_dir, fname) + # print(locpath, file=sys.stderr) + if not self._lu_idx: + self._buildluindex() + + try: + elt = XMLCorpusView(locpath, "lexUnit")[0] + except IOError: + raise FramenetError("Unknown LU id: {0}".format(fn_luid)) + + lu2 = self._handle_lexunit_elt(elt, ignorekeys) + lu.URL = self._fnweb_url + "/" + self._lu_dir + "/" + fname + lu.subCorpus = lu2.subCorpus + lu.exemplars = SpecialList( + "luexemplars", [sent for subc in lu.subCorpus for sent in subc.sentence] + ) + for sent in lu.exemplars: + sent["LU"] = lu + sent["frame"] = lu.frame + for aset in sent.annotationSet: + aset["LU"] = lu + aset["frame"] = lu.frame + + return lu + + def _loadsemtypes(self): + """Create the semantic types index.""" + self._semtypes = AttrDict() + semtypeXML = [ + x + for x in XMLCorpusView( + self.abspath("semTypes.xml"), + "semTypes/semType", + self._handle_semtype_elt, + ) + ] + for st in semtypeXML: + n = st["name"] + a = st["abbrev"] + i = st["ID"] + # Both name and abbrev should be able to retrieve the + # ID. The ID will retrieve the semantic type dict itself. + self._semtypes[n] = i + self._semtypes[a] = i + self._semtypes[i] = st + # now that all individual semtype XML is loaded, we can link them together + roots = [] + for st in self.semtypes(): + if st.superType: + st.superType = self.semtype(st.superType.supID) + st.superType.subTypes.append(st) + else: + if st not in roots: + roots.append(st) + st.rootType = st + queue = list(roots) + assert queue + while queue: + st = queue.pop(0) + for child in st.subTypes: + child.rootType = st.rootType + queue.append(child) + # self.propagate_semtypes() # apply inferencing over FE relations + + def propagate_semtypes(self): + """ + Apply inference rules to distribute semtypes over relations between FEs. + For FrameNet 1.5, this results in 1011 semtypes being propagated. 
+ (Not done by default because it requires loading all frame files, + which takes several seconds. If this needed to be fast, it could be rewritten + to traverse the neighboring relations on demand for each FE semtype.) + + >>> from nltk.corpus import framenet as fn + >>> x = sum(1 for f in fn.frames() for fe in f.FE.values() if fe.semType) + >>> fn.propagate_semtypes() + >>> y = sum(1 for f in fn.frames() for fe in f.FE.values() if fe.semType) + >>> y-x > 1000 + True + """ + if not self._semtypes: + self._loadsemtypes() + if not self._ferel_idx: + self._buildrelationindex() + changed = True + i = 0 + nPropagations = 0 + while changed: + # make a pass and see if anything needs to be propagated + i += 1 + changed = False + for ferel in self.fe_relations(): + superST = ferel.superFE.semType + subST = ferel.subFE.semType + try: + if superST and superST is not subST: + # propagate downward + assert subST is None or self.semtype_inherits(subST, superST), ( + superST.name, + ferel, + subST.name, + ) + if subST is None: + ferel.subFE.semType = subST = superST + changed = True + nPropagations += 1 + if ( + ferel.type.name in ["Perspective_on", "Subframe", "Precedes"] + and subST + and subST is not superST + ): + # propagate upward + assert superST is None, (superST.name, ferel, subST.name) + ferel.superFE.semType = superST = subST + changed = True + nPropagations += 1 + except AssertionError as ex: + # bug in the data! 
ignore + # print(ex, file=sys.stderr) + continue + # print(i, nPropagations, file=sys.stderr) + + def semtype(self, key): + """ + >>> from nltk.corpus import framenet as fn + >>> fn.semtype(233).name + 'Temperature' + >>> fn.semtype(233).abbrev + 'Temp' + >>> fn.semtype('Temperature').ID + 233 + + :param key: The name, abbreviation, or id number of the semantic type + :type key: string or int + :return: Information about a semantic type + :rtype: dict + """ + if isinstance(key, int): + stid = key + else: + try: + stid = self._semtypes[key] + except TypeError: + self._loadsemtypes() + stid = self._semtypes[key] + + try: + st = self._semtypes[stid] + except TypeError: + self._loadsemtypes() + st = self._semtypes[stid] + + return st + + def semtype_inherits(self, st, superST): + if not isinstance(st, dict): + st = self.semtype(st) + if not isinstance(superST, dict): + superST = self.semtype(superST) + par = st.superType + while par: + if par is superST: + return True + par = par.superType + return False + + def frames(self, name=None): + """ + Obtain details for a specific frame. + + >>> from nltk.corpus import framenet as fn + >>> len(fn.frames()) in (1019, 1221) # FN 1.5 and 1.7, resp. + True + >>> x = PrettyList(fn.frames(r'(?i)crim'), maxReprSize=0, breakLines=True) + >>> x.sort(key=itemgetter('ID')) + >>> x + [, + , + , + ] + + A brief intro to Frames (excerpted from "FrameNet II: Extended + Theory and Practice" by Ruppenhofer et. al., 2010): + + A Frame is a script-like conceptual structure that describes a + particular type of situation, object, or event along with the + participants and props that are needed for that Frame. For + example, the "Apply_heat" frame describes a common situation + involving a Cook, some Food, and a Heating_Instrument, and is + evoked by words such as bake, blanch, boil, broil, brown, + simmer, steam, etc. + + We call the roles of a Frame "frame elements" (FEs) and the + frame-evoking words are called "lexical units" (LUs). 
+ + FrameNet includes relations between Frames. Several types of + relations are defined, of which the most important are: + + - Inheritance: An IS-A relation. The child frame is a subtype + of the parent frame, and each FE in the parent is bound to + a corresponding FE in the child. An example is the + "Revenge" frame which inherits from the + "Rewards_and_punishments" frame. + + - Using: The child frame presupposes the parent frame as + background, e.g the "Speed" frame "uses" (or presupposes) + the "Motion" frame; however, not all parent FEs need to be + bound to child FEs. + + - Subframe: The child frame is a subevent of a complex event + represented by the parent, e.g. the "Criminal_process" frame + has subframes of "Arrest", "Arraignment", "Trial", and + "Sentencing". + + - Perspective_on: The child frame provides a particular + perspective on an un-perspectivized parent frame. A pair of + examples consists of the "Hiring" and "Get_a_job" frames, + which perspectivize the "Employment_start" frame from the + Employer's and the Employee's point of view, respectively. + + :param name: A regular expression pattern used to match against + Frame names. If 'name' is None, then a list of all + Framenet Frames will be returned. + :type name: str + :return: A list of matching Frames (or all Frames). + :rtype: list(AttrDict) + """ + try: + fIDs = list(self._frame_idx.keys()) + except AttributeError: + self._buildframeindex() + fIDs = list(self._frame_idx.keys()) + + if name is not None: + return PrettyList( + self.frame(fID) for fID, finfo in self.frame_ids_and_names(name).items() + ) + else: + return PrettyLazyMap(self.frame, fIDs) + + def frame_ids_and_names(self, name=None): + """ + Uses the frame index, which is much faster than looking up each frame definition + if only the names and IDs are needed. 
+ """ + if not self._frame_idx: + self._buildframeindex() + return dict( + (fID, finfo.name) + for fID, finfo in self._frame_idx.items() + if name is None or re.search(name, finfo.name) is not None + ) + + def fes(self, name=None, frame=None): + """ + Lists frame element objects. If 'name' is provided, this is treated as + a case-insensitive regular expression to filter by frame name. + (Case-insensitivity is because casing of frame element names is not always + consistent across frames.) Specify 'frame' to filter by a frame name pattern, + ID, or object. + + >>> from nltk.corpus import framenet as fn + >>> fn.fes('Noise_maker') + [] + >>> sorted([(fe.frame.name,fe.name) for fe in fn.fes('sound')]) + [('Cause_to_make_noise', 'Sound_maker'), ('Make_noise', 'Sound'), + ('Make_noise', 'Sound_source'), ('Sound_movement', 'Location_of_sound_source'), + ('Sound_movement', 'Sound'), ('Sound_movement', 'Sound_source'), + ('Sounds', 'Component_sound'), ('Sounds', 'Location_of_sound_source'), + ('Sounds', 'Sound_source'), ('Vocalizations', 'Location_of_sound_source'), + ('Vocalizations', 'Sound_source')] + >>> sorted([(fe.frame.name,fe.name) for fe in fn.fes('sound',r'(?i)make_noise')]) + [('Cause_to_make_noise', 'Sound_maker'), + ('Make_noise', 'Sound'), + ('Make_noise', 'Sound_source')] + >>> sorted(set(fe.name for fe in fn.fes('^sound'))) + ['Sound', 'Sound_maker', 'Sound_source'] + >>> len(fn.fes('^sound$')) + 2 + + :param name: A regular expression pattern used to match against + frame element names. If 'name' is None, then a list of all + frame elements will be returned. + :type name: str + :return: A list of matching frame elements + :rtype: list(AttrDict) + """ + # what frames are we searching in? 
+ if frame is not None: + if isinstance(frame, int): + frames = [self.frame(frame)] + elif isinstance(frame, str): + frames = self.frames(frame) + else: + frames = [frame] + else: + frames = self.frames() + + return PrettyList( + fe + for f in frames + for fename, fe in f.FE.items() + if name is None or re.search(name, fename, re.I) + ) + + def lus(self, name=None, frame=None): + """ + Obtain details for lexical units. + Optionally restrict by lexical unit name pattern, and/or to a certain frame + or frames whose name matches a pattern. + + >>> from nltk.corpus import framenet as fn + >>> len(fn.lus()) in (11829, 13572) # FN 1.5 and 1.7, resp. + True + >>> PrettyList(sorted(fn.lus(r'(?i)a little'), key=itemgetter('ID')), maxReprSize=0, breakLines=True) + [, + , + ] + >>> PrettyList(sorted(fn.lus(r'interest', r'(?i)stimulus'), key=itemgetter('ID'))) + [, ] + + A brief intro to Lexical Units (excerpted from "FrameNet II: + Extended Theory and Practice" by Ruppenhofer et. al., 2010): + + A lexical unit (LU) is a pairing of a word with a meaning. For + example, the "Apply_heat" Frame describes a common situation + involving a Cook, some Food, and a Heating Instrument, and is + _evoked_ by words such as bake, blanch, boil, broil, brown, + simmer, steam, etc. These frame-evoking words are the LUs in the + Apply_heat frame. Each sense of a polysemous word is a different + LU. + + We have used the word "word" in talking about LUs. The reality + is actually rather complex. When we say that the word "bake" is + polysemous, we mean that the lemma "bake.v" (which has the + word-forms "bake", "bakes", "baked", and "baking") is linked to + three different frames: + + - Apply_heat: "Michelle baked the potatoes for 45 minutes." + + - Cooking_creation: "Michelle baked her mother a cake for her birthday." + + - Absorb_heat: "The potatoes have to bake for more than 30 minutes." + + These constitute three different LUs, with different + definitions. 
+ + Multiword expressions such as "given name" and hyphenated words + like "shut-eye" can also be LUs. Idiomatic phrases such as + "middle of nowhere" and "give the slip (to)" are also defined as + LUs in the appropriate frames ("Isolated_places" and "Evading", + respectively), and their internal structure is not analyzed. + + Framenet provides multiple annotated examples of each sense of a + word (i.e. each LU). Moreover, the set of examples + (approximately 20 per LU) illustrates all of the combinatorial + possibilities of the lexical unit. + + Each LU is linked to a Frame, and hence to the other words which + evoke that Frame. This makes the FrameNet database similar to a + thesaurus, grouping together semantically similar words. + + In the simplest case, frame-evoking words are verbs such as + "fried" in: + + "Matilde fried the catfish in a heavy iron skillet." + + Sometimes event nouns may evoke a Frame. For example, + "reduction" evokes "Cause_change_of_scalar_position" in: + + "...the reduction of debt levels to $665 million from $2.6 billion." + + Adjectives may also evoke a Frame. For example, "asleep" may + evoke the "Sleep" frame as in: + + "They were asleep for hours." + + Many common nouns, such as artifacts like "hat" or "tower", + typically serve as dependents rather than clearly evoking their + own frames. + + :param name: A regular expression pattern used to search the LU + names. Note that LU names take the form of a dotted + string (e.g. "run.v" or "a little.adv") in which a + lemma preceeds the "." and a POS follows the + dot. The lemma may be composed of a single lexeme + (e.g. "run") or of multiple lexemes (e.g. "a + little"). If 'name' is not given, then all LUs will + be returned. 
+ + The valid POSes are: + + v - verb + n - noun + a - adjective + adv - adverb + prep - preposition + num - numbers + intj - interjection + art - article + c - conjunction + scon - subordinating conjunction + + :type name: str + :type frame: str or int or frame + :return: A list of selected (or all) lexical units + :rtype: list of LU objects (dicts). See the lu() function for info + about the specifics of LU objects. + + """ + if not self._lu_idx: + self._buildluindex() + + if name is not None: # match LUs, then restrict by frame + result = PrettyList( + self.lu(luID) for luID, luName in self.lu_ids_and_names(name).items() + ) + if frame is not None: + if isinstance(frame, int): + frameIDs = {frame} + elif isinstance(frame, str): + frameIDs = {f.ID for f in self.frames(frame)} + else: + frameIDs = {frame.ID} + result = PrettyList(lu for lu in result if lu.frame.ID in frameIDs) + elif frame is not None: # all LUs in matching frames + if isinstance(frame, int): + frames = [self.frame(frame)] + elif isinstance(frame, str): + frames = self.frames(frame) + else: + frames = [frame] + result = PrettyLazyIteratorList( + iter(LazyConcatenation(list(f.lexUnit.values()) for f in frames)) + ) + else: # all LUs + luIDs = [ + luID + for luID, lu in self._lu_idx.items() + if lu.status not in self._bad_statuses + ] + result = PrettyLazyMap(self.lu, luIDs) + return result + + def lu_ids_and_names(self, name=None): + """ + Uses the LU index, which is much faster than looking up each LU definition + if only the names and IDs are needed. + """ + if not self._lu_idx: + self._buildluindex() + return { + luID: luinfo.name + for luID, luinfo in self._lu_idx.items() + if luinfo.status not in self._bad_statuses + and (name is None or re.search(name, luinfo.name) is not None) + } + + def docs_metadata(self, name=None): + """ + Return an index of the annotated documents in Framenet. 
+ + Details for a specific annotated document can be obtained using this + class's doc() function and pass it the value of the 'ID' field. + + >>> from nltk.corpus import framenet as fn + >>> len(fn.docs()) in (78, 107) # FN 1.5 and 1.7, resp. + True + >>> set([x.corpname for x in fn.docs_metadata()])>=set(['ANC', 'KBEval', \ + 'LUCorpus-v0.3', 'Miscellaneous', 'NTI', 'PropBank']) + True + + :param name: A regular expression pattern used to search the + file name of each annotated document. The document's + file name contains the name of the corpus that the + document is from, followed by two underscores "__" + followed by the document name. So, for example, the + file name "LUCorpus-v0.3__20000410_nyt-NEW.xml" is + from the corpus named "LUCorpus-v0.3" and the + document name is "20000410_nyt-NEW.xml". + :type name: str + :return: A list of selected (or all) annotated documents + :rtype: list of dicts, where each dict object contains the following + keys: + + - 'name' + - 'ID' + - 'corpid' + - 'corpname' + - 'description' + - 'filename' + """ + try: + ftlist = PrettyList(self._fulltext_idx.values()) + except AttributeError: + self._buildcorpusindex() + ftlist = PrettyList(self._fulltext_idx.values()) + + if name is None: + return ftlist + else: + return PrettyList( + x for x in ftlist if re.search(name, x["filename"]) is not None + ) + + def docs(self, name=None): + """ + Return a list of the annotated full-text documents in FrameNet, + optionally filtered by a regex to be matched against the document name. + """ + return PrettyLazyMap((lambda x: self.doc(x.ID)), self.docs_metadata(name)) + + def sents(self, exemplars=True, full_text=True): + """ + Annotated sentences matching the specified criteria. 
+ """ + if exemplars: + if full_text: + return self.exemplars() + self.ft_sents() + else: + return self.exemplars() + elif full_text: + return self.ft_sents() + + def annotations(self, luNamePattern=None, exemplars=True, full_text=True): + """ + Frame annotation sets matching the specified criteria. + """ + + if exemplars: + epart = PrettyLazyIteratorList( + sent.frameAnnotation for sent in self.exemplars(luNamePattern) + ) + else: + epart = [] + + if full_text: + if luNamePattern is not None: + matchedLUIDs = set(self.lu_ids_and_names(luNamePattern).keys()) + ftpart = PrettyLazyIteratorList( + aset + for sent in self.ft_sents() + for aset in sent.annotationSet[1:] + if luNamePattern is None or aset.get("luID", "CXN_ASET") in matchedLUIDs + ) + else: + ftpart = [] + + if exemplars: + if full_text: + return epart + ftpart + else: + return epart + elif full_text: + return ftpart + + def exemplars(self, luNamePattern=None, frame=None, fe=None, fe2=None): + """ + Lexicographic exemplar sentences, optionally filtered by LU name and/or 1-2 FEs that + are realized overtly. 'frame' may be a name pattern, frame ID, or frame instance. + 'fe' may be a name pattern or FE instance; if specified, 'fe2' may also + be specified to retrieve sentences with both overt FEs (in either order). + """ + if fe is None and fe2 is not None: + raise FramenetError("exemplars(..., fe=None, fe2=) is not allowed") + elif fe is not None and fe2 is not None: + if not isinstance(fe2, str): + if isinstance(fe, str): + # fe2 is specific to a particular frame. swap fe and fe2 so fe is always used to determine the frame. 
+ fe, fe2 = fe2, fe + elif fe.frame is not fe2.frame: # ensure frames match + raise FramenetError( + "exemplars() call with inconsistent `fe` and `fe2` specification (frames must match)" + ) + if frame is None and fe is not None and not isinstance(fe, str): + frame = fe.frame + + # narrow down to frames matching criteria + + lusByFrame = defaultdict( + list + ) # frame name -> matching LUs, if luNamePattern is specified + if frame is not None or luNamePattern is not None: + if frame is None or isinstance(frame, str): + if luNamePattern is not None: + frames = set() + for lu in self.lus(luNamePattern, frame=frame): + frames.add(lu.frame.ID) + lusByFrame[lu.frame.name].append(lu) + frames = LazyMap(self.frame, list(frames)) + else: + frames = self.frames(frame) + else: + if isinstance(frame, int): + frames = [self.frame(frame)] + else: # frame object + frames = [frame] + + if luNamePattern is not None: + lusByFrame = {frame.name: self.lus(luNamePattern, frame=frame)} + + if fe is not None: # narrow to frames that define this FE + if isinstance(fe, str): + frames = PrettyLazyIteratorList( + f + for f in frames + if fe in f.FE + or any(re.search(fe, ffe, re.I) for ffe in f.FE.keys()) + ) + else: + if fe.frame not in frames: + raise FramenetError( + "exemplars() call with inconsistent `frame` and `fe` specification" + ) + frames = [fe.frame] + + if fe2 is not None: # narrow to frames that ALSO define this FE + if isinstance(fe2, str): + frames = PrettyLazyIteratorList( + f + for f in frames + if fe2 in f.FE + or any(re.search(fe2, ffe, re.I) for ffe in f.FE.keys()) + ) + # else we already narrowed it to a single frame + else: # frame, luNamePattern are None. 
            # fe, fe2 are None or strings
            if fe is not None:
                frames = {ffe.frame.ID for ffe in self.fes(fe)}
                if fe2 is not None:
                    frames2 = {ffe.frame.ID for ffe in self.fes(fe2)}
                    # keep only frames that define both requested FEs
                    frames = frames & frames2
                frames = LazyMap(self.frame, list(frames))
            else:
                frames = self.frames()

        # we've narrowed down 'frames'
        # now get exemplars for relevant LUs in those frames

        def _matching_exs():
            # Generator over exemplar sentences of the LUs in 'frames' that
            # overtly realize the requested FE(s), if any were requested.
            for f in frames:
                fes = fes2 = None  # FEs of interest
                if fe is not None:
                    fes = (
                        {ffe for ffe in f.FE.keys() if re.search(fe, ffe, re.I)}
                        if isinstance(fe, str)
                        else {fe.name}
                    )
                if fe2 is not None:
                    fes2 = (
                        {ffe for ffe in f.FE.keys() if re.search(fe2, ffe, re.I)}
                        if isinstance(fe2, str)
                        else {fe2.name}
                    )

                for lu in (
                    lusByFrame[f.name]
                    if luNamePattern is not None
                    else f.lexUnit.values()
                ):
                    for ex in lu.exemplars:
                        if (fes is None or self._exemplar_of_fes(ex, fes)) and (
                            fes2 is None or self._exemplar_of_fes(ex, fes2)
                        ):
                            yield ex

        return PrettyLazyIteratorList(_matching_exs())

    def _exemplar_of_fes(self, ex, fes=None):
        """
        Given an exemplar sentence and a set of FE names, return the subset of FE names
        that are realized overtly in the sentence on the FE, FE2, or FE3 layer.

        If 'fes' is None, returns all overt FE names.
        """
        # ex.FE[0] is the list of (start, end, name) triples on the FE layer;
        # zipping and taking index 2 extracts the FE names.
        overtNames = set(list(zip(*ex.FE[0]))[2]) if ex.FE[0] else set()
        if "FE2" in ex:
            overtNames |= set(list(zip(*ex.FE2[0]))[2]) if ex.FE2[0] else set()
        if "FE3" in ex:
            overtNames |= set(list(zip(*ex.FE3[0]))[2]) if ex.FE3[0] else set()
        return overtNames & fes if fes is not None else overtNames

    def ft_sents(self, docNamePattern=None):
        """
        Full-text annotation sentences, optionally filtered by document name.
        """
        return PrettyLazyIteratorList(
            sent for d in self.docs(docNamePattern) for sent in d.sentence
        )

    def frame_relation_types(self):
        """
        Obtain a list of frame relation types.
+ + >>> from nltk.corpus import framenet as fn + >>> frts = sorted(fn.frame_relation_types(), key=itemgetter('ID')) + >>> isinstance(frts, list) + True + >>> len(frts) in (9, 10) # FN 1.5 and 1.7, resp. + True + >>> PrettyDict(frts[0], breakLines=True) + {'ID': 1, + '_type': 'framerelationtype', + 'frameRelations': [ Child=Change_of_consistency>, Child=Rotting>, ...], + 'name': 'Inheritance', + 'subFrameName': 'Child', + 'superFrameName': 'Parent'} + + :return: A list of all of the frame relation types in framenet + :rtype: list(dict) + """ + if not self._freltyp_idx: + self._buildrelationindex() + return self._freltyp_idx.values() + + def frame_relations(self, frame=None, frame2=None, type=None): + """ + :param frame: (optional) frame object, name, or ID; only relations involving + this frame will be returned + :param frame2: (optional; 'frame' must be a different frame) only show relations + between the two specified frames, in either direction + :param type: (optional) frame relation type (name or object); show only relations + of this type + :type frame: int or str or AttrDict + :return: A list of all of the frame relations in framenet + :rtype: list(dict) + + >>> from nltk.corpus import framenet as fn + >>> frels = fn.frame_relations() + >>> isinstance(frels, list) + True + >>> len(frels) in (1676, 2070) # FN 1.5 and 1.7, resp. + True + >>> PrettyList(fn.frame_relations('Cooking_creation'), maxReprSize=0, breakLines=True) + [ Child=Cooking_creation>, + Child=Cooking_creation>, + ReferringEntry=Cooking_creation>] + >>> PrettyList(fn.frame_relations(274), breakLines=True) + [ Child=Dodging>, + Child=Evading>, ...] + >>> PrettyList(fn.frame_relations(fn.frame('Cooking_creation')), breakLines=True) + [ Child=Cooking_creation>, + Child=Cooking_creation>, ...] 
+ >>> PrettyList(fn.frame_relations('Cooking_creation', type='Inheritance')) + [ Child=Cooking_creation>] + >>> PrettyList(fn.frame_relations('Cooking_creation', 'Apply_heat'), breakLines=True) + [ Child=Cooking_creation>, + ReferringEntry=Cooking_creation>] + """ + relation_type = type + + if not self._frel_idx: + self._buildrelationindex() + + rels = None + + if relation_type is not None: + if not isinstance(relation_type, dict): + type = [rt for rt in self.frame_relation_types() if rt.name == type][0] + assert isinstance(type, dict) + + # lookup by 'frame' + if frame is not None: + if isinstance(frame, dict) and "frameRelations" in frame: + rels = PrettyList(frame.frameRelations) + else: + if not isinstance(frame, int): + if isinstance(frame, dict): + frame = frame.ID + else: + frame = self.frame_by_name(frame).ID + rels = [self._frel_idx[frelID] for frelID in self._frel_f_idx[frame]] + + # filter by 'type' + if type is not None: + rels = [rel for rel in rels if rel.type is type] + elif type is not None: + # lookup by 'type' + rels = type.frameRelations + else: + rels = self._frel_idx.values() + + # filter by 'frame2' + if frame2 is not None: + if frame is None: + raise FramenetError( + "frame_relations(frame=None, frame2=) is not allowed" + ) + if not isinstance(frame2, int): + if isinstance(frame2, dict): + frame2 = frame2.ID + else: + frame2 = self.frame_by_name(frame2).ID + if frame == frame2: + raise FramenetError( + "The two frame arguments to frame_relations() must be different frames" + ) + rels = [ + rel + for rel in rels + if rel.superFrame.ID == frame2 or rel.subFrame.ID == frame2 + ] + + return PrettyList( + sorted( + rels, + key=lambda frel: (frel.type.ID, frel.superFrameName, frel.subFrameName), + ) + ) + + def fe_relations(self): + """ + Obtain a list of frame element relations. 
+ + >>> from nltk.corpus import framenet as fn + >>> ferels = fn.fe_relations() + >>> isinstance(ferels, list) + True + >>> len(ferels) in (10020, 12393) # FN 1.5 and 1.7, resp. + True + >>> PrettyDict(ferels[0], breakLines=True) + {'ID': 14642, + '_type': 'ferelation', + 'frameRelation': Child=Lively_place>, + 'subFE': , + 'subFEName': 'Degree', + 'subFrame': , + 'subID': 11370, + 'supID': 2271, + 'superFE': , + 'superFEName': 'Degree', + 'superFrame': , + 'type': } + + :return: A list of all of the frame element relations in framenet + :rtype: list(dict) + """ + if not self._ferel_idx: + self._buildrelationindex() + return PrettyList( + sorted( + self._ferel_idx.values(), + key=lambda ferel: ( + ferel.type.ID, + ferel.frameRelation.superFrameName, + ferel.superFEName, + ferel.frameRelation.subFrameName, + ferel.subFEName, + ), + ) + ) + + def semtypes(self): + """ + Obtain a list of semantic types. + + >>> from nltk.corpus import framenet as fn + >>> stypes = fn.semtypes() + >>> len(stypes) in (73, 109) # FN 1.5 and 1.7, resp. + True + >>> sorted(stypes[0].keys()) + ['ID', '_type', 'abbrev', 'definition', 'definitionMarkup', 'name', 'rootType', 'subTypes', 'superType'] + + :return: A list of all of the semantic types in framenet + :rtype: list(dict) + """ + if not self._semtypes: + self._loadsemtypes() + return PrettyList( + self._semtypes[i] for i in self._semtypes if isinstance(i, int) + ) + + def _load_xml_attributes(self, d, elt): + """ + Extracts a subset of the attributes from the given element and + returns them in a dictionary. + + :param d: A dictionary in which to store the attributes. 
+ :type d: dict + :param elt: An ElementTree Element + :type elt: Element + :return: Returns the input dict ``d`` possibly including attributes from ``elt`` + :rtype: dict + """ + + d = type(d)(d) + + try: + attr_dict = elt.attrib + except AttributeError: + return d + + if attr_dict is None: + return d + + # Ignore these attributes when loading attributes from an xml node + ignore_attrs = [ #'cBy', 'cDate', 'mDate', # <-- annotation metadata that could be of interest + "xsi", + "schemaLocation", + "xmlns", + "bgColor", + "fgColor", + ] + + for attr in attr_dict: + + if any(attr.endswith(x) for x in ignore_attrs): + continue + + val = attr_dict[attr] + if val.isdigit(): + d[attr] = int(val) + else: + d[attr] = val + + return d + + def _strip_tags(self, data): + """ + Gets rid of all tags and newline characters from the given input + + :return: A cleaned-up version of the input string + :rtype: str + """ + + try: + """ + # Look for boundary issues in markup. (Sometimes FEs are pluralized in definitions.) 
+ m = re.search(r'\w[<][^/]|[<][/][^>]+[>](s\w|[a-rt-z0-9])', data) + if m: + print('Markup boundary:', data[max(0,m.start(0)-10):m.end(0)+10].replace('\n',' '), file=sys.stderr) + """ + + data = data.replace("", "") + data = data.replace("", "") + data = re.sub('', "", data) + data = data.replace("", "") + data = data.replace("", "") + data = data.replace("", "") + data = data.replace("", "") + data = data.replace("", "") + data = data.replace("", "") + data = data.replace("", "") + data = data.replace("", "'") + data = data.replace("", "'") + data = data.replace("", "") + data = data.replace("", "") + data = data.replace("", "") + data = data.replace("", "") + + # Get rid of and tags + data = data.replace("", "") + data = data.replace("", "") + + data = data.replace("\n", " ") + except AttributeError: + pass + + return data + + def _handle_elt(self, elt, tagspec=None): + """Extracts and returns the attributes of the given element""" + return self._load_xml_attributes(AttrDict(), elt) + + def _handle_fulltextindex_elt(self, elt, tagspec=None): + """ + Extracts corpus/document info from the fulltextIndex.xml file. + + Note that this function "flattens" the information contained + in each of the "corpus" elements, so that each "document" + element will contain attributes for the corpus and + corpusid. Also, each of the "document" items will contain a + new attribute called "filename" that is the base file name of + the xml file for the document in the "fulltext" subdir of the + Framenet corpus. 
+ """ + ftinfo = self._load_xml_attributes(AttrDict(), elt) + corpname = ftinfo.name + corpid = ftinfo.ID + retlist = [] + for sub in elt: + if sub.tag.endswith("document"): + doc = self._load_xml_attributes(AttrDict(), sub) + if "name" in doc: + docname = doc.name + else: + docname = doc.description + doc.filename = "{0}__{1}.xml".format(corpname, docname) + doc.URL = ( + self._fnweb_url + "/" + self._fulltext_dir + "/" + doc.filename + ) + doc.corpname = corpname + doc.corpid = corpid + retlist.append(doc) + + return retlist + + def _handle_frame_elt(self, elt, ignorekeys=[]): + """Load the info for a Frame from a frame xml file""" + frinfo = self._load_xml_attributes(AttrDict(), elt) + + frinfo["_type"] = "frame" + frinfo["definition"] = "" + frinfo["definitionMarkup"] = "" + frinfo["FE"] = PrettyDict() + frinfo["FEcoreSets"] = [] + frinfo["lexUnit"] = PrettyDict() + frinfo["semTypes"] = [] + for k in ignorekeys: + if k in frinfo: + del frinfo[k] + + for sub in elt: + if sub.tag.endswith("definition") and "definition" not in ignorekeys: + frinfo["definitionMarkup"] = sub.text + frinfo["definition"] = self._strip_tags(sub.text) + elif sub.tag.endswith("FE") and "FE" not in ignorekeys: + feinfo = self._handle_fe_elt(sub) + frinfo["FE"][feinfo.name] = feinfo + feinfo["frame"] = frinfo # backpointer + elif sub.tag.endswith("FEcoreSet") and "FEcoreSet" not in ignorekeys: + coreset = self._handle_fecoreset_elt(sub) + # assumes all FEs have been loaded before coresets + frinfo["FEcoreSets"].append( + PrettyList(frinfo["FE"][fe.name] for fe in coreset) + ) + elif sub.tag.endswith("lexUnit") and "lexUnit" not in ignorekeys: + luentry = self._handle_framelexunit_elt(sub) + if luentry["status"] in self._bad_statuses: + # problematic LU entry; ignore it + continue + luentry["frame"] = frinfo + luentry["URL"] = ( + self._fnweb_url + + "/" + + self._lu_dir + + "/" + + "lu{0}.xml".format(luentry["ID"]) + ) + luentry["subCorpus"] = Future( + (lambda lu: lambda: 
self._lu_file(lu).subCorpus)(luentry) + ) + luentry["exemplars"] = Future( + (lambda lu: lambda: self._lu_file(lu).exemplars)(luentry) + ) + frinfo["lexUnit"][luentry.name] = luentry + if not self._lu_idx: + self._buildluindex() + self._lu_idx[luentry.ID] = luentry + elif sub.tag.endswith("semType") and "semTypes" not in ignorekeys: + semtypeinfo = self._load_xml_attributes(AttrDict(), sub) + frinfo["semTypes"].append(self.semtype(semtypeinfo.ID)) + + frinfo["frameRelations"] = self.frame_relations(frame=frinfo) + + # resolve 'requires' and 'excludes' links between FEs of this frame + for fe in frinfo.FE.values(): + if fe.requiresFE: + name, ID = fe.requiresFE.name, fe.requiresFE.ID + fe.requiresFE = frinfo.FE[name] + assert fe.requiresFE.ID == ID + if fe.excludesFE: + name, ID = fe.excludesFE.name, fe.excludesFE.ID + fe.excludesFE = frinfo.FE[name] + assert fe.excludesFE.ID == ID + + return frinfo + + def _handle_fecoreset_elt(self, elt): + """Load fe coreset info from xml.""" + info = self._load_xml_attributes(AttrDict(), elt) + tmp = [] + for sub in elt: + tmp.append(self._load_xml_attributes(AttrDict(), sub)) + + return tmp + + def _handle_framerelationtype_elt(self, elt, *args): + """Load frame-relation element and its child fe-relation elements from frRelation.xml.""" + info = self._load_xml_attributes(AttrDict(), elt) + info["_type"] = "framerelationtype" + info["frameRelations"] = PrettyList() + + for sub in elt: + if sub.tag.endswith("frameRelation"): + frel = self._handle_framerelation_elt(sub) + frel["type"] = info # backpointer + for ferel in frel.feRelations: + ferel["type"] = info + info["frameRelations"].append(frel) + + return info + + def _handle_framerelation_elt(self, elt): + """Load frame-relation element and its child fe-relation elements from frRelation.xml.""" + info = self._load_xml_attributes(AttrDict(), elt) + assert info["superFrameName"] != info["subFrameName"], (elt, info) + info["_type"] = "framerelation" + info["feRelations"] = 
PrettyList() + + for sub in elt: + if sub.tag.endswith("FERelation"): + ferel = self._handle_elt(sub) + ferel["_type"] = "ferelation" + ferel["frameRelation"] = info # backpointer + info["feRelations"].append(ferel) + + return info + + def _handle_fulltextannotation_elt(self, elt): + """Load full annotation info for a document from its xml + file. The main element (fullTextAnnotation) contains a 'header' + element (which we ignore here) and a bunch of 'sentence' + elements.""" + info = AttrDict() + info["_type"] = "fulltext_annotation" + info["sentence"] = [] + + for sub in elt: + if sub.tag.endswith("header"): + continue # not used + elif sub.tag.endswith("sentence"): + try: + s = self._handle_fulltext_sentence_elt(sub) + s.doc = info + info["sentence"].append(s) + except FramenetError as e: + print(e) + + return info + + def _handle_fulltext_sentence_elt(self, elt): + """Load information from the given 'sentence' element. Each + 'sentence' element contains a "text" and "annotationSet" sub + elements.""" + info = self._load_xml_attributes(AttrDict(), elt) + info["_type"] = "fulltext_sentence" + info["annotationSet"] = [] + info["targets"] = [] + target_spans = set() + info["_ascii"] = types.MethodType( + _annotation_ascii, info + ) # attach a method for this instance + info["text"] = "" + + for sub in elt: + if sub.tag.endswith("text"): + info["text"] = self._strip_tags(sub.text) + elif sub.tag.endswith("annotationSet"): + a = self._handle_fulltextannotationset_elt( + sub, is_pos=(len(info["annotationSet"]) == 0) + ) + if "cxnID" in a: # ignoring construction annotations for now + continue + a.sent = info + a.text = info.text + info["annotationSet"].append(a) + if "Target" in a: + for tspan in a.Target: + if tspan in target_spans: + self._warn( + 'Duplicate target span "{0}"'.format( + info.text[slice(*tspan)] + ), + tspan, + "in sentence", + info["ID"], + info.text, + ) + # this can happen in cases like "chemical and biological weapons" + # being annotated as 
"chemical weapons" and "biological weapons" + else: + target_spans.add(tspan) + info["targets"].append((a.Target, a.luName, a.frameName)) + + assert info["annotationSet"][0].status == "UNANN" + info["POS"] = info["annotationSet"][0].POS + info["POS_tagset"] = info["annotationSet"][0].POS_tagset + return info + + def _handle_fulltextannotationset_elt(self, elt, is_pos=False): + """Load information from the given 'annotationSet' element. Each + 'annotationSet' contains several "layer" elements.""" + + info = self._handle_luannotationset_elt(elt, is_pos=is_pos) + if not is_pos: + info["_type"] = "fulltext_annotationset" + if "cxnID" not in info: # ignoring construction annotations for now + info["LU"] = self.lu( + info.luID, + luName=info.luName, + frameID=info.frameID, + frameName=info.frameName, + ) + info["frame"] = info.LU.frame + return info + + def _handle_fulltextlayer_elt(self, elt): + """Load information from the given 'layer' element. Each + 'layer' contains several "label" elements.""" + info = self._load_xml_attributes(AttrDict(), elt) + info["_type"] = "layer" + info["label"] = [] + + for sub in elt: + if sub.tag.endswith("label"): + l = self._load_xml_attributes(AttrDict(), sub) + info["label"].append(l) + + return info + + def _handle_framelexunit_elt(self, elt): + """Load the lexical unit info from an xml element in a frame's xml file.""" + luinfo = AttrDict() + luinfo["_type"] = "lu" + luinfo = self._load_xml_attributes(luinfo, elt) + luinfo["definition"] = "" + luinfo["definitionMarkup"] = "" + luinfo["sentenceCount"] = PrettyDict() + luinfo["lexemes"] = PrettyList() # multiword LUs have multiple lexemes + luinfo["semTypes"] = PrettyList() # an LU can have multiple semtypes + + for sub in elt: + if sub.tag.endswith("definition"): + luinfo["definitionMarkup"] = sub.text + luinfo["definition"] = self._strip_tags(sub.text) + elif sub.tag.endswith("sentenceCount"): + luinfo["sentenceCount"] = self._load_xml_attributes(PrettyDict(), sub) + elif 
sub.tag.endswith("lexeme"): + lexemeinfo = self._load_xml_attributes(PrettyDict(), sub) + if not isinstance(lexemeinfo.name, str): + # some lexeme names are ints by default: e.g., + # thousand.num has lexeme with name="1000" + lexemeinfo.name = str(lexemeinfo.name) + luinfo["lexemes"].append(lexemeinfo) + elif sub.tag.endswith("semType"): + semtypeinfo = self._load_xml_attributes(PrettyDict(), sub) + luinfo["semTypes"].append(self.semtype(semtypeinfo.ID)) + + # sort lexemes by 'order' attribute + # otherwise, e.g., 'write down.v' may have lexemes in wrong order + luinfo["lexemes"].sort(key=lambda x: x.order) + + return luinfo + + def _handle_lexunit_elt(self, elt, ignorekeys): + """ + Load full info for a lexical unit from its xml file. + This should only be called when accessing corpus annotations + (which are not included in frame files). + """ + luinfo = self._load_xml_attributes(AttrDict(), elt) + luinfo["_type"] = "lu" + luinfo["definition"] = "" + luinfo["definitionMarkup"] = "" + luinfo["subCorpus"] = PrettyList() + luinfo["lexemes"] = PrettyList() # multiword LUs have multiple lexemes + luinfo["semTypes"] = PrettyList() # an LU can have multiple semtypes + for k in ignorekeys: + if k in luinfo: + del luinfo[k] + + for sub in elt: + if sub.tag.endswith("header"): + continue # not used + elif sub.tag.endswith("valences"): + continue # not used + elif sub.tag.endswith("definition") and "definition" not in ignorekeys: + luinfo["definitionMarkup"] = sub.text + luinfo["definition"] = self._strip_tags(sub.text) + elif sub.tag.endswith("subCorpus") and "subCorpus" not in ignorekeys: + sc = self._handle_lusubcorpus_elt(sub) + if sc is not None: + luinfo["subCorpus"].append(sc) + elif sub.tag.endswith("lexeme") and "lexeme" not in ignorekeys: + luinfo["lexemes"].append(self._load_xml_attributes(PrettyDict(), sub)) + elif sub.tag.endswith("semType") and "semType" not in ignorekeys: + semtypeinfo = self._load_xml_attributes(AttrDict(), sub) + 
luinfo["semTypes"].append(self.semtype(semtypeinfo.ID)) + + return luinfo + + def _handle_lusubcorpus_elt(self, elt): + """Load a subcorpus of a lexical unit from the given xml.""" + sc = AttrDict() + try: + sc["name"] = elt.get("name") + except AttributeError: + return None + sc["_type"] = "lusubcorpus" + sc["sentence"] = [] + + for sub in elt: + if sub.tag.endswith("sentence"): + s = self._handle_lusentence_elt(sub) + if s is not None: + sc["sentence"].append(s) + + return sc + + def _handle_lusentence_elt(self, elt): + """Load a sentence from a subcorpus of an LU from xml.""" + info = self._load_xml_attributes(AttrDict(), elt) + info["_type"] = "lusentence" + info["annotationSet"] = [] + info["_ascii"] = types.MethodType( + _annotation_ascii, info + ) # attach a method for this instance + for sub in elt: + if sub.tag.endswith("text"): + info["text"] = self._strip_tags(sub.text) + elif sub.tag.endswith("annotationSet"): + annset = self._handle_luannotationset_elt( + sub, is_pos=(len(info["annotationSet"]) == 0) + ) + if annset is not None: + assert annset.status == "UNANN" or "FE" in annset, annset + if annset.status != "UNANN": + info["frameAnnotation"] = annset + # copy layer info up to current level + for k in ( + "Target", + "FE", + "FE2", + "FE3", + "GF", + "PT", + "POS", + "POS_tagset", + "Other", + "Sent", + "Verb", + "Noun", + "Adj", + "Adv", + "Prep", + "Scon", + "Art", + ): + if k in annset: + info[k] = annset[k] + info["annotationSet"].append(annset) + annset["sent"] = info + annset["text"] = info.text + return info + + def _handle_luannotationset_elt(self, elt, is_pos=False): + """Load an annotation set from a sentence in an subcorpus of an LU""" + info = self._load_xml_attributes(AttrDict(), elt) + info["_type"] = "posannotationset" if is_pos else "luannotationset" + info["layer"] = [] + info["_ascii"] = types.MethodType( + _annotation_ascii, info + ) # attach a method for this instance + + if "cxnID" in info: # ignoring construction annotations for 
now. + return info + + for sub in elt: + if sub.tag.endswith("layer"): + l = self._handle_lulayer_elt(sub) + if l is not None: + overt = [] + ni = {} # null instantiations + + info["layer"].append(l) + for lbl in l.label: + if "start" in lbl: + thespan = (lbl.start, lbl.end + 1, lbl.name) + if l.name not in ( + "Sent", + "Other", + ): # 'Sent' and 'Other' layers sometimes contain accidental duplicate spans + assert thespan not in overt, (info.ID, l.name, thespan) + overt.append(thespan) + else: # null instantiation + if lbl.name in ni: + self._warn( + "FE with multiple NI entries:", + lbl.name, + ni[lbl.name], + lbl.itype, + ) + else: + ni[lbl.name] = lbl.itype + overt = sorted(overt) + + if l.name == "Target": + if not overt: + self._warn( + "Skipping empty Target layer in annotation set ID={0}".format( + info.ID + ) + ) + continue + assert all(lblname == "Target" for i, j, lblname in overt) + if "Target" in info: + self._warn( + "Annotation set {0} has multiple Target layers".format( + info.ID + ) + ) + else: + info["Target"] = [(i, j) for (i, j, _) in overt] + elif l.name == "FE": + if l.rank == 1: + assert "FE" not in info + info["FE"] = (overt, ni) + # assert False,info + else: + # sometimes there are 3 FE layers! e.g. 
Change_position_on_a_scale.fall.v + assert 2 <= l.rank <= 3, l.rank + k = "FE" + str(l.rank) + assert k not in info + info[k] = (overt, ni) + elif l.name in ("GF", "PT"): + assert l.rank == 1 + info[l.name] = overt + elif l.name in ("BNC", "PENN"): + assert l.rank == 1 + info["POS"] = overt + info["POS_tagset"] = l.name + else: + if is_pos: + if l.name not in ("NER", "WSL"): + self._warn( + "Unexpected layer in sentence annotationset:", + l.name, + ) + else: + if l.name not in ( + "Sent", + "Verb", + "Noun", + "Adj", + "Adv", + "Prep", + "Scon", + "Art", + "Other", + ): + self._warn( + "Unexpected layer in frame annotationset:", l.name + ) + info[l.name] = overt + if not is_pos and "cxnID" not in info: + if "Target" not in info: + self._warn("Missing target in annotation set ID={0}".format(info.ID)) + assert "FE" in info + if "FE3" in info: + assert "FE2" in info + + return info + + def _handle_lulayer_elt(self, elt): + """Load a layer from an annotation set""" + layer = self._load_xml_attributes(AttrDict(), elt) + layer["_type"] = "lulayer" + layer["label"] = [] + + for sub in elt: + if sub.tag.endswith("label"): + l = self._load_xml_attributes(AttrDict(), sub) + if l is not None: + layer["label"].append(l) + return layer + + def _handle_fe_elt(self, elt): + feinfo = self._load_xml_attributes(AttrDict(), elt) + feinfo["_type"] = "fe" + feinfo["definition"] = "" + feinfo["definitionMarkup"] = "" + feinfo["semType"] = None + feinfo["requiresFE"] = None + feinfo["excludesFE"] = None + for sub in elt: + if sub.tag.endswith("definition"): + feinfo["definitionMarkup"] = sub.text + feinfo["definition"] = self._strip_tags(sub.text) + elif sub.tag.endswith("semType"): + stinfo = self._load_xml_attributes(AttrDict(), sub) + feinfo["semType"] = self.semtype(stinfo.ID) + elif sub.tag.endswith("requiresFE"): + feinfo["requiresFE"] = self._load_xml_attributes(AttrDict(), sub) + elif sub.tag.endswith("excludesFE"): + feinfo["excludesFE"] = self._load_xml_attributes(AttrDict(), 
sub) + + return feinfo + + def _handle_semtype_elt(self, elt, tagspec=None): + semt = self._load_xml_attributes(AttrDict(), elt) + semt["_type"] = "semtype" + semt["superType"] = None + semt["subTypes"] = PrettyList() + for sub in elt: + if sub.text is not None: + semt["definitionMarkup"] = sub.text + semt["definition"] = self._strip_tags(sub.text) + else: + supertypeinfo = self._load_xml_attributes(AttrDict(), sub) + semt["superType"] = supertypeinfo + # the supertype may not have been loaded yet + + return semt + + +# +# Demo +# +def demo(): + from nltk.corpus import framenet as fn + + # + # It is not necessary to explicitly build the indexes by calling + # buildindexes(). We do this here just for demo purposes. If the + # indexes are not built explicitely, they will be built as needed. + # + print("Building the indexes...") + fn.buildindexes() + + # + # Get some statistics about the corpus + # + print("Number of Frames:", len(fn.frames())) + print("Number of Lexical Units:", len(fn.lus())) + print("Number of annotated documents:", len(fn.docs())) + print() + + # + # Frames + # + print( + 'getting frames whose name matches the (case insensitive) regex: "(?i)medical"' + ) + medframes = fn.frames(r"(?i)medical") + print('Found {0} Frames whose name matches "(?i)medical":'.format(len(medframes))) + print([(f.name, f.ID) for f in medframes]) + + # + # store the first frame in the list of frames + # + tmp_id = medframes[0].ID + m_frame = fn.frame(tmp_id) # reads all info for the frame + + # + # get the frame relations + # + print( + '\nNumber of frame relations for the "{0}" ({1}) frame:'.format( + m_frame.name, m_frame.ID + ), + len(m_frame.frameRelations), + ) + for fr in m_frame.frameRelations: + print(" ", fr) + + # + # get the names of the Frame Elements + # + print( + '\nNumber of Frame Elements in the "{0}" frame:'.format(m_frame.name), + len(m_frame.FE), + ) + print(" ", [x for x in m_frame.FE]) + + # + # get the names of the "Core" Frame Elements + # + 
print('\nThe "core" Frame Elements in the "{0}" frame:'.format(m_frame.name)) + print(" ", [x.name for x in m_frame.FE.values() if x.coreType == "Core"]) + + # + # get all of the Lexical Units that are incorporated in the + # 'Ailment' FE of the 'Medical_conditions' frame (id=239) + # + print('\nAll Lexical Units that are incorporated in the "Ailment" FE:') + m_frame = fn.frame(239) + ailment_lus = [ + x + for x in m_frame.lexUnit.values() + if "incorporatedFE" in x and x.incorporatedFE == "Ailment" + ] + print(" ", [x.name for x in ailment_lus]) + + # + # get all of the Lexical Units for the frame + # + print( + '\nNumber of Lexical Units in the "{0}" frame:'.format(m_frame.name), + len(m_frame.lexUnit), + ) + print(" ", [x.name for x in m_frame.lexUnit.values()][:5], "...") + + # + # get basic info on the second LU in the frame + # + tmp_id = m_frame.lexUnit["ailment.n"].ID # grab the id of the specified LU + luinfo = fn.lu_basic(tmp_id) # get basic info on the LU + print("\nInformation on the LU: {0}".format(luinfo.name)) + pprint(luinfo) + + # + # Get a list of all of the corpora used for fulltext annotation + # + print("\nNames of all of the corpora used for fulltext annotation:") + allcorpora = set(x.corpname for x in fn.docs_metadata()) + pprint(list(allcorpora)) + + # + # Get the names of the annotated documents in the first corpus + # + firstcorp = list(allcorpora)[0] + firstcorp_docs = fn.docs(firstcorp) + print('\nNames of the annotated documents in the "{0}" corpus:'.format(firstcorp)) + pprint([x.filename for x in firstcorp_docs]) + + # + # Search for frames containing LUs whose name attribute matches a + # regexp pattern. + # + # Note: if you were going to be doing a lot of this type of + # searching, you'd want to build an index that maps from + # lemmas to frames because each time frames_by_lemma() is + # called, it has to search through ALL of the frame XML files + # in the db. 
+ print( + '\nSearching for all Frames that have a lemma that matches the regexp: "^run.v$":' + ) + pprint(fn.frames_by_lemma(r"^run.v$")) + + +if __name__ == "__main__": + demo() + +framenet15 = LazyCorpusLoader( + "framenet_v15", + FramenetCorpusReader, + [ + "frRelation.xml", + "frameIndex.xml", + "fulltextIndex.xml", + "luIndex.xml", + "semTypes.xml", + ], +) +framenet = LazyCorpusLoader( + "framenet_v17", + FramenetCorpusReader, + [ + "frRelation.xml", + "frameIndex.xml", + "fulltextIndex.xml", + "luIndex.xml", + "semTypes.xml", + ], +) diff --git a/spanfinder/tools/framenet/retokenize_fn.py b/spanfinder/tools/framenet/retokenize_fn.py new file mode 100644 index 0000000000000000000000000000000000000000..02a6085b4f03cb1a852fcae548e1923d141c87d3 --- /dev/null +++ b/spanfinder/tools/framenet/retokenize_fn.py @@ -0,0 +1,188 @@ +import gzip +import json +import os +import logging +from argparse import ArgumentParser +from itertools import accumulate + +import nltk +import numpy as np +from tools.framenet.nltk_framenet import framenet, framenet15 +from tqdm import tqdm + +from tools.framenet.fn_util import framenet_split, Sentence + +logger = logging.getLogger('fn') + + +def _load_raw(version): + if version == '1.5': + nltk.download('framenet_v15') + return framenet15 + else: + nltk.download('framenet_v17') + return framenet + + +def one_frame(sentence, ann): + frame_info = {'label': ann.frame.name} + target_list = list() + for start, end in ann.Target: + start, end = sentence.span(start, end) + target_list.extend(list(range(start, end+1))) + assert len(target_list) > 0 + frame_info['span'] = [sorted(target_list)[0], sorted(target_list)[-1]] + frame_info['lu'] = ann.LU.name + frame_info['children'] = fes = list() + for start, end, fe_name in ann.FE[0]: + start, end = sentence.span(start, end) + fes.append({'span': [start, end], 'label': fe_name}) + return frame_info + + +def load_nltk_exemplars(version, exclude_ann_ids=None): + exclude_ann_ids = exclude_ann_ids or 
def load_nltk_exemplars(version, exclude_ann_ids=None):
    """Load exemplar (lexicographic) sentences from NLTK's FrameNet.

    :param version: FrameNet version string, '1.5' or '1.7'.
    :param exclude_ann_ids: iterable of annotation-set IDs to skip
        (typically those already covered by the full-text data).
    :return: list of example dicts with 'tokens', 'annotations' and 'meta'.
    """
    # Fix: use a set for O(1) membership tests (the original kept a list and
    # probed it once per exemplar); any iterable from the caller still works.
    exclude_ann_ids = set(exclude_ann_ids or ())
    fn = _load_raw(version)
    egs = list()
    bar = tqdm()
    skipped = 0
    try:
        for eg in fn.annotations(full_text=False):
            if 'Target' not in eg.keys():
                # A bug of nltk
                continue
            if eg.ID in exclude_ann_ids:
                skipped += 1
                continue
            try:
                sentence = Sentence(eg.text)
                egs.append({
                    'tokens': list(map(str, sentence.tokens)), 'annotations': [one_frame(sentence, eg)],
                    'meta': {
                        'fully_annotated': False,
                        'source': f'framenet_v{version}',
                        'with_fe': True,
                        'type': 'exemplar',
                        'ann_ids': [eg.ID],
                    }
                })
                bar.update()
            except Exception:
                # Fix: was a bare `except:` that also swallowed
                # KeyboardInterrupt/SystemExit. Malformed exemplars are
                # deliberately skipped (best-effort loading).
                pass
    except Exception:
        # Fix: same bare-except issue. The NLTK annotation iterator itself
        # can raise on corrupt entries; loading stops at that point.
        pass
    bar.close()
    logger.info(f'Loaded {len(egs)} sentences for framenet v{version} from exemplars. (skipped {skipped} sentences)')
    return egs


def load_nltk_fully_annotated(version):
    """Load FrameNet full-text documents, bucketed into train/dev/test.

    :param version: FrameNet version string, '1.5' or '1.7'.
    :return: dict mapping each split name in ``framenet_split`` to a list of
        example dicts ('tokens', 'annotations', 'meta').
    """
    fn = _load_raw(version)

    splits = list(framenet_split.keys())
    all_containers = {split: [] for split in splits}
    for doc in tqdm(fn.docs()):
        # Documents not listed in any split default to 'train'.
        container = all_containers['train']
        for sp in splits:
            if doc.filename in framenet_split[sp]:
                container = all_containers[sp]

        for sent in doc.sentence:
            sentence = Sentence(sent.text)
            all_frames = list()
            ann_ids = []
            for ann in sent.annotationSet:
                # The first annotation set is POS-only; frames live in the rest.
                if ann._type == 'posannotationset':
                    continue
                assert ann._type == 'fulltext_annotationset'
                if 'Target' not in ann.keys():
                    logger.warning('Target not found.')
                    continue
                if 'ID' in ann:
                    ann_ids.append(ann['ID'])
                frame_info = one_frame(sentence, ann)
                all_frames.append(frame_info)
            eg_dict = {
                'tokens': list(map(str, sentence.tokens)), 'annotations': all_frames,
                'meta': {
                    'source': f'framenet_v{version}',
                    'fully_annotated': True,
                    'with_fe': True,
                    'type': 'full text',
                    'sentence ID': sent.ID,
                    'doc': doc.filename,
                    'ann_ids': ann_ids
                }
            }
            container.append(eg_dict)

    for sp in splits:
        logger.info(f'Load {len(all_containers[sp])} for {sp}.')
    return all_containers


def load_expanded_fn(path):
    """Load the 'expanded FrameNet' paraphrase data.

    Deliberately disabled upstream; the original body after the raise was
    unreachable dead code and has been removed.
    """
    raise NotImplementedError
+ lines = compressed.read().decode() + instances = list() + lines = lines.split('\n') + for line in tqdm(lines): + if len(line) != 0: + instances.append(json.loads(line)) + logger.info(f'{len(instances)} lines loaded.') + + dataset = list() + for instance in tqdm(instances, desc='Processing expanded framenet...'): + for output in instance['outputs']: + ins_dict = dict() + ins_dict['meta'] = { + 'source': 'expanded framenet', + 'type': 'paraphrase', + 'exemplar_id': instance['exemplar_id'], + 'annoset_id': instance['annoset_id'] + } + words = output['output_string'] + text = ' '.join(words) + length_offsets = [0] + list(accumulate(map(len, words))) + start_idx, end_idx = output['output_trigger_offset'] + start_idx = length_offsets[start_idx] + start_idx + end_idx = length_offsets[end_idx] + end_idx - 2 + sentence = Sentence(text) + ins_dict['text'] = sentence.tokens + ins_dict['pos'] = sentence.pos + ins_dict['tag'] = sentence.tag + ins_dict['frame'] = [{ + 'name': instance['frame_name'], + 'target': list(range(sentence.span(start_idx, end_idx)[0], sentence.span(start_idx, end_idx)[1]+1)), + 'lu': output['output_trigger'], + 'fe': [] + }] + ins_dict['score'] = { + 'pbr': np.exp(-output['pbr_score']), + 'aligner': output['aligner_score'], + } + ins_dict['with_fe'] = False + ins_dict['fully_annotated'] = False + dataset.append(ins_dict) + logger.info(f'{len(dataset)} sentences loaded.') + return dataset + + +if __name__ == '__main__': + logging.basicConfig(level='INFO') + arg_parser = ArgumentParser() + arg_parser.add_argument('output', type=str) + arg_parser.add_argument('-v', type=str, default='1.7') + cmd_args = arg_parser.parse_args() + full = load_nltk_fully_annotated(cmd_args.v) + full_ann_ids = list() + for split in ['train', 'dev', 'test']: + for sent in full[split]: + full_ann_ids.extend(sent['meta']['ann_ids']) + exe = load_nltk_exemplars(cmd_args.v, full_ann_ids) + os.makedirs(cmd_args.output, exist_ok=True) + with open(os.path.join(cmd_args.output, 'full.' 
# Annotation-set IDs known to crash the downstream fttosem.pl script.
buggy_annotations = [6538624, 6550700, 6547918, 6521702, 6541530, 6774318, 4531088, 4531238]


def iterate_docs(folder_path, tmp_path):
    """Yield namespace-stripped, bug-filtered copies of full-text XML docs.

    Reads every ``*.xml`` under ``<folder_path>/fulltext``, strips XML
    namespaces from tags, removes at most one known-buggy annotation per
    sentence, writes the cleaned tree to ``tmp_path`` and yields its path.

    :param folder_path: FrameNet corpus root (contains a 'fulltext' dir).
    :param tmp_path: directory for the cleaned copies (created if missing).
    """
    os.makedirs(tmp_path, exist_ok=True)
    ft_path = os.path.join(folder_path, 'fulltext')
    all_docs = [name for name in os.listdir(ft_path) if name.endswith('.xml')]
    for doc_name in all_docs:
        # Fix: the original `open(...).read()` never closed the file handle.
        with open(os.path.join(ft_path, doc_name)) as doc_file:
            it = et.iterparse(StringIO(doc_file.read()))
        for _, el in it:
            # Drop '{namespace}' prefixes so plain tag names match below.
            prefix, has_namespace, postfix = el.tag.partition('}')
            if has_namespace:
                el.tag = postfix
        root = it.root
        for sentence in root:
            for annotation in sentence:
                if "ID" in annotation.attrib and int(annotation.attrib['ID']) in buggy_annotations:
                    print('Delete one buggy annotation from', doc_name)
                    sentence.remove(annotation)
                    # NOTE: pre-existing behavior — only the first buggy
                    # annotation per sentence is removed.
                    break
        dump_path = os.path.join(tmp_path, doc_name)
        et.ElementTree(root).write(dump_path)
        yield dump_path


def process_doc(script_folder, doc_path):
    """Run the fttosem.pl perl script on one document (in a child process).

    NOTE(review): doc_path is interpolated into a shell command unescaped;
    safe only for trusted, locally-generated paths.
    """
    os.chdir(script_folder)
    print(f'processing {doc_path}...')
    cmd = f'perl fttosem.pl {doc_path}'
    print(cmd)
    os.system(cmd)
    print('Done')
    return True


def main():
    """Clean the full-text XML docs and run fttosem.pl on each in parallel."""
    parser = argparse.ArgumentParser('Run fttosem perl examples.')
    parser.add_argument('-s', help='script folder', type=str, required=True)
    parser.add_argument('-p', help='path to corpora', type=str)
    parser.add_argument('-o', help='output path', type=str, default='/tmp/framenet')
    args = parser.parse_args()
    script_folder = args.s
    corpora_folder = args.p
    if corpora_folder is None:
        # Fall back to NLTK's own download location.
        nltk.download('framenet')
        corpora_folder = os.path.join(nltk.data.path[0], 'corpora', 'framenet_v17')
    os.chdir(script_folder)
    fns = list(iterate_docs(corpora_folder, args.o))
    print(f'{len(fns)} documents detected.')

    processes = list()
    for fn in fns:
        print('Processing', fn)
        process = mp.Process(target=process_doc, args=(script_folder, fn))
        process.start()
        processes.append(process)
    for process in processes:
        # Give each doc up to 8 minutes; stragglers are reported below.
        process.join(timeout=480)

    rst = os.listdir(args.o)
    rst = list(filter(lambda x: x.endswith('.sem'), rst))
    print(f'{len(rst)} Done.')
    rst = [fn[:-4] for fn in rst]
    fns = [fn[:-4] for fn in fns]
    print('Unfinished docs:')
    for fn in set(fns) - set(rst):
        print(fn)
parser.add_argument('-o', help='output path', type=str, default='/tmp/framenet') + args = parser.parse_args() + script_folder = args.s + corpora_folder = args.p + if corpora_folder is None: + nltk.download('framenet') + corpora_folder = os.path.join(nltk.data.path[0], 'corpora', 'framenet_v17') + os.chdir(script_folder) + fns = list(iterate_docs(corpora_folder, args.o)) + print(f'{len(fns)} documents detected.') + + processes = list() + for fn in fns: + print('Processing', fn) + process = mp.Process(target=process_doc, args=(script_folder, fn)) + process.start() + processes.append(process) + for process in processes: + process.join(timeout=480) + + rst = os.listdir(args.o) + rst = list(filter(lambda x: x.endswith('.sem'), rst)) + print(f'{len(rst)} Done.') + rst = [fn[:-4] for fn in rst] + fns = [fn[:-4] for fn in fns] + print('Unfinished docs:') + for fn in set(fns) - set(rst): + print(fn) + + +if __name__ == '__main__': + main() diff --git a/spanfinder/tools/ontology_mapping/__init__.py b/spanfinder/tools/ontology_mapping/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/spanfinder/tools/ontology_mapping/__pycache__/__init__.cpython-38.pyc b/spanfinder/tools/ontology_mapping/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f3beabcb526750f8e5e55211e3e9f6f5d4f2f320 Binary files /dev/null and b/spanfinder/tools/ontology_mapping/__pycache__/__init__.cpython-38.pyc differ diff --git a/spanfinder/tools/ontology_mapping/__pycache__/__init__.cpython-39.pyc b/spanfinder/tools/ontology_mapping/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..18a1b60eb2470f67edf05f9711ee2d128fce2220 Binary files /dev/null and b/spanfinder/tools/ontology_mapping/__pycache__/__init__.cpython-39.pyc differ diff --git a/spanfinder/tools/ontology_mapping/force_map.py 
b/spanfinder/tools/ontology_mapping/force_map.py new file mode 100644 index 0000000000000000000000000000000000000000..fb691b1d23c8a26d840fee02bec774637d278815 --- /dev/null +++ b/spanfinder/tools/ontology_mapping/force_map.py @@ -0,0 +1,143 @@ +import json +import os +from collections import defaultdict +from typing import * + +import numpy as np +from allennlp.data import Vocabulary +from tqdm import tqdm + +from sftp import SpanPredictor, Span +from sftp.utils import VIRTUAL_ROOT + + +def read_framenet(path: str): + ret = list() + for line in map(json.loads, open(path).readlines()): + ret.append((line['tokens'], Span.from_json(line['annotations']))) + return ret + + +def co_occur( + predictor: SpanPredictor, + sentences: List[Tuple[List[str], Span]], + event_list: List[str], + arg_list: List[str], +): + idx2label = predictor.vocab.get_index_to_token_vocabulary('span_label') + event_count = np.zeros([len(event_list), len(idx2label)], np.float64) + arg_count = np.zeros([len(arg_list), len(idx2label)], np.float64) + for sent, vr in tqdm(sentences): + # For events + _, _, event_dist = predictor.force_decode(sent, child_spans=[event.boundary for event in vr]) + for event, dist in zip(vr, event_dist): + event_count[event_list.index(event.label)] += dist + # For args + for event, one_event_dist in zip(vr, event_dist): + parent_label = idx2label[int(one_event_dist.argmax())] + arg_spans = [child.boundary for child in event] + _, _, arg_dist = predictor.force_decode( + sent, event.boundary, parent_label, arg_spans + ) + for arg, dist in zip(event, arg_dist): + arg_count[arg_list.index(arg.label)] += dist + return event_count, arg_count + + +def create_vocab(events, args): + vocab = Vocabulary() + vocab.add_token_to_namespace(VIRTUAL_ROOT, 'span_label') + for event in events: + vocab.add_token_to_namespace(event, 'span_label') + for arg in args: + vocab.add_token_to_namespace(arg, 'span_label') + return vocab + + +def count_data(annotations: Iterable[Span]): + event_cnt, 
def count_data(annotations: Iterable["Span"]):
    """Count event and argument label frequencies over annotated sentences.

    :param annotations: iterable of virtual-root spans; each root's children
        are events, and each event's children are arguments.
    :return: (event_count, arg_count) as plain dicts of label -> frequency.
    """
    event_cnt, arg_cnt = defaultdict(int), defaultdict(int)
    for sent in annotations:
        for event in sent:
            event_cnt[event.label] += 1
            for arg in event:
                arg_cnt[arg.label] += 1
    return dict(event_cnt), dict(arg_cnt)


def gen_mapping(
    src_label: List[str], src_count: Dict[str, int],
    tgt_onto: List[str], tgt_label: List[str],
    cooccur_count: np.ndarray
):
    """
    :param src_label: Src label list, including events and args.
    :param src_count: Src label count, event or arg.
    :param tgt_onto: Target ontology label list, only event or arg.
    :param tgt_label: Full target label list (indexes the columns of the
        output distribution).
    :param cooccur_count: Co-occurrence counting table of shape
        [len(tgt_onto), len(src_label)].
    :return: Mapping dict: src label -> distribution over tgt_label.
    """
    # One-hot matrix projecting ontology labels onto the full label list.
    # Fix: `np.float` was deprecated in NumPy 1.20 and removed in 1.24;
    # np.float64 is the exact type the old alias resolved to.
    onto2label = np.zeros([len(tgt_onto), len(tgt_label)], dtype=np.float64)
    for onto_idx, onto_tag in enumerate(tgt_onto):
        onto2label[onto_idx, tgt_label.index(onto_tag)] = 1.0
    ret = dict()
    for src_tag, src_freq in src_count.items():
        # Labels absent from src_label have no co-occurrence column; skip.
        if src_tag in src_label:
            src_idx = src_label.index(src_tag)
            ret[src_tag] = list((cooccur_count[:, src_idx] / src_freq) @ onto2label)
    return ret


def ontology_map(
    model_path,
    src_data: List[Tuple[List[str], "Span"]],
    tgt_data: List[Tuple[List[str], "Span"]],
    device: int,
    dst_path: str,
    meta: Optional[dict] = None,
) -> None:
    """Build a src->tgt ontology mapping via force-decoding and dump it.

    Writes 'ontology_mapping.json', 'ontology.tsv' and a vocabulary folder
    into ``dst_path``.

    :param model_path: path to a trained SpanPredictor archive.
    :param src_data: (tokens, root-span) pairs in the source ontology.
    :param tgt_data: (tokens, root-span) pairs in the target ontology.
    :param device: CUDA device id, or -1 for CPU.
    :param dst_path: output directory (created if missing).
    :param meta: optional metadata stored under the 'meta' key.
    """
    ret = {'meta': meta or {}}
    data = {'src': {}, 'tgt': {}}
    for name, datasets in [['src', src_data], ['tgt', tgt_data]]:
        d = data[name]
        d['sentences'], d['annotations'] = zip(*datasets)
        d['event_cnt'], d['arg_cnt'] = count_data(d['annotations'])
        d['event'], d['arg'] = list(d['event_cnt']), list(d['arg_cnt'])

    predictor = SpanPredictor.from_path(model_path, cuda_device=device)
    tgt_vocab = create_vocab(data['tgt']['event'], data['tgt']['arg'])
    for name, vocab in [['src', predictor.vocab], ['tgt', tgt_vocab]]:
        data[name]['label'] = [
            vocab.get_index_to_token_vocabulary('span_label')[i] for i in range(vocab.get_vocab_size('span_label'))
        ]

    data['event'], data['arg'] = co_occur(
        predictor, tgt_data, data['tgt']['event'], data['tgt']['arg']
    )
    mapping = {}
    for layer in ['event', 'arg']:
        mapping[layer] = gen_mapping(
            data['src']['label'], data['src'][layer+'_cnt'], data['tgt'][layer], data['tgt']['label'], data[layer]
        )

    for key, name in [['source', 'src'], ['target', 'tgt']]:
        ret[key] = {
            'label': data[name]['label'],
            'event': data[name]['event'],
            'argument': data[name]['arg']
        }
    ret['mapping'] = {
        'event': mapping['event'],
        'argument': mapping['arg']
    }

    os.makedirs(dst_path, exist_ok=True)
    with open(os.path.join(dst_path, 'ontology_mapping.json'), 'w') as fp:
        json.dump(ret, fp)
    with open(os.path.join(dst_path, 'ontology.tsv'), 'w') as fp:
        to_dump = list()
        to_dump.append('\t'.join([VIRTUAL_ROOT] + ret['target']['event']))
        for event in ret['target']['event']:
            to_dump.append('\t'.join([event] + ret['target']['argument']))
        fp.write('\n'.join(to_dump))
    tgt_vocab.save_to_files(os.path.join(dst_path, 'vocabulary'))
data['tgt']['event'], data['tgt']['arg'] + ) + mapping = {} + for layer in ['event', 'arg']: + mapping[layer] = gen_mapping( + data['src']['label'], data['src'][layer+'_cnt'], data['tgt'][layer], data['tgt']['label'], data[layer] + ) + + for key, name in [['source', 'src'], ['target', 'tgt']]: + ret[key] = { + 'label': data[name]['label'], + 'event': data[name]['event'], + 'argument': data[name]['arg'] + } + ret['mapping'] = { + 'event': mapping['event'], + 'argument': mapping['arg'] + } + + os.makedirs(dst_path, exist_ok=True) + with open(os.path.join(dst_path, 'ontology_mapping.json'), 'w') as fp: + json.dump(ret, fp) + with open(os.path.join(dst_path, 'ontology.tsv'), 'w') as fp: + to_dump = list() + to_dump.append('\t'.join([VIRTUAL_ROOT] + ret['target']['event'])) + for event in ret['target']['event']: + to_dump.append('\t'.join([event] + ret['target']['argument'])) + fp.write('\n'.join(to_dump)) + tgt_vocab.save_to_files(os.path.join(dst_path, 'vocabulary')) diff --git a/spanfinder/tools/ontology_mapping/framenet2ab.py b/spanfinder/tools/ontology_mapping/framenet2ab.py new file mode 100644 index 0000000000000000000000000000000000000000..bfb24ed89deaef0cf1aebb44d47c99521b89d6c7 --- /dev/null +++ b/spanfinder/tools/ontology_mapping/framenet2ab.py @@ -0,0 +1,47 @@ +from argparse import ArgumentParser +import hashlib +import os + +from sftp.data_reader import BetterDatasetReader, ConcreteDatasetReader +from tools.ontology_mapping.force_map import ontology_map, read_framenet + + +def read_ace_better(reader, data_path): + sentences = list() + for ins in reader.read(data_path): + sentences.append(tuple(ins.fields['raw_inputs'].metadata[key] for key in ['sentence', 'spans'])) + return sentences + + +def run(model_path, src_data_path, tgt_data_path, device, dst_path): + if model_path.endswith('.tar.gz'): + model_md5 = hashlib.md5(open(model_path, 'rb').read()).hexdigest() + else: + model_md5 = hashlib.md5(open(os.path.join(model_path, 'model.tar.gz'), 
'rb').read()).hexdigest() + print('model md5: ', model_md5) + if 'better' in tgt_data_path.lower(): + reader = BetterDatasetReader(eval_type='basic', pretrained_model='roberta-large', ignore_label=False) + elif 'ace' in tgt_data_path.lower(): + reader = ConcreteDatasetReader(ignore_unlabeled_sentence=True, pretrained_model='roberta-large') + else: + raise NotImplementedError + meta = { + 'model': {'path': model_path, 'md5': model_md5}, + 'src_data_path': src_data_path, + 'tgt_data_path': tgt_data_path + } + # event_list and arg_list are target ontology + # label_list is source ontology (i.e. FrameNet) + src_data, tgt_data = read_framenet(src_data_path), read_ace_better(reader, tgt_data_path) + ontology_map(model_path, src_data, tgt_data, device, dst_path, meta) + + +if __name__ == '__main__': + parser = ArgumentParser() + parser.add_argument('model', metavar='MODEL_PATH') + parser.add_argument('src', metavar='SRC_DATA_PATH') + parser.add_argument('tgt', metavar='TGT_DATA_PATH') + parser.add_argument('dst', metavar='DESTINATION_PATH') + parser.add_argument('-d', type=int, help='device', default=-1) + cmd_args = parser.parse_args() + run(cmd_args.model, cmd_args.src, cmd_args.tgt, cmd_args.d, cmd_args.dst)