gossminn commited on
Commit
6680682
·
0 Parent(s):

First version

Browse files
Files changed (49) hide show
  1. .gitattributes +2 -0
  2. .gitignore +210 -0
  3. deploy.py +3 -0
  4. fillmorle/app.py +524 -0
  5. model.mod.tar.gz +3 -0
  6. requirements.txt +19 -0
  7. setup.py +9 -0
  8. sftp/__init__.py +10 -0
  9. sftp/data_reader/__init__.py +6 -0
  10. sftp/data_reader/batch_sampler/__init__.py +1 -0
  11. sftp/data_reader/batch_sampler/mix_sampler.py +50 -0
  12. sftp/data_reader/better_reader.py +286 -0
  13. sftp/data_reader/concrete_reader.py +44 -0
  14. sftp/data_reader/concrete_srl.py +169 -0
  15. sftp/data_reader/span_reader.py +197 -0
  16. sftp/data_reader/srl_reader.py +107 -0
  17. sftp/metrics/__init__.py +4 -0
  18. sftp/metrics/base_f.py +27 -0
  19. sftp/metrics/exact_match.py +29 -0
  20. sftp/metrics/fbeta_mix_measure.py +34 -0
  21. sftp/metrics/srl_metrics.py +138 -0
  22. sftp/models/__init__.py +1 -0
  23. sftp/models/span_model.py +362 -0
  24. sftp/modules/__init__.py +4 -0
  25. sftp/modules/smooth_crf.py +77 -0
  26. sftp/modules/span_extractor/__init__.py +1 -0
  27. sftp/modules/span_extractor/combo.py +36 -0
  28. sftp/modules/span_finder/__init__.py +2 -0
  29. sftp/modules/span_finder/bio_span_finder.py +216 -0
  30. sftp/modules/span_finder/span_finder.py +87 -0
  31. sftp/modules/span_typing/__init__.py +2 -0
  32. sftp/modules/span_typing/mlp_span_typing.py +99 -0
  33. sftp/modules/span_typing/span_typing.py +64 -0
  34. sftp/predictor/__init__.py +1 -0
  35. sftp/predictor/span_predictor.orig.py +362 -0
  36. sftp/predictor/span_predictor.py +401 -0
  37. sftp/training/__init__.py +0 -0
  38. sftp/training/transformer_optimizer.py +121 -0
  39. sftp/utils/__init__.py +7 -0
  40. sftp/utils/bio_smoothing.py +62 -0
  41. sftp/utils/common.py +3 -0
  42. sftp/utils/db_storage.py +87 -0
  43. sftp/utils/functions.py +75 -0
  44. sftp/utils/label_smoothing.py +48 -0
  45. sftp/utils/span.py +420 -0
  46. sftp/utils/span_utils.py +57 -0
  47. sociolome/combine_models.py +130 -0
  48. sociolome/evalita_eval.py +319 -0
  49. sociolome/lome_wrapper.py +83 -0
.gitattributes ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ spanfinder/model.mod.tar.gz filter=lfs diff=lfs merge=lfs -text
2
+ model.mod.tar.gz filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Created by .ignore support plugin (hsz.mobi)
2
+ ### JetBrains template
3
+ # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
4
+ # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
5
+
6
+ # User-specific stuff
7
+ .idea/
8
+ data
9
+ cache
10
+
11
+ # Gradle and Maven with auto-import
12
+ # When using Gradle or Maven with auto-import, you should exclude module files,
13
+ # since they will be recreated, and may cause churn. Uncomment if using
14
+ # auto-import.
15
+ # .idea/artifacts
16
+ # .idea/compiler.xml
17
+ # .idea/jarRepositories.xml
18
+ # .idea/modules.xml
19
+ # .idea/*.iml
20
+ # .idea/modules
21
+ # *.iml
22
+ # *.ipr
23
+
24
+ # CMake
25
+ cmake-build-*/
26
+
27
+ # Mongo Explorer plugin
28
+ .idea/**/mongoSettings.xml
29
+
30
+ # File-based project format
31
+ *.iws
32
+
33
+ # IntelliJ
34
+ out/
35
+
36
+ # mpeltonen/sbt-idea plugin
37
+ .idea_modules/
38
+
39
+ # JIRA plugin
40
+ atlassian-ide-plugin.xml
41
+
42
+ # Cursive Clojure plugin
43
+ .idea/replstate.xml
44
+
45
+ # Crashlytics plugin (for Android Studio and IntelliJ)
46
+ com_crashlytics_export_strings.xml
47
+ crashlytics.properties
48
+ crashlytics-build.properties
49
+ fabric.properties
50
+
51
+ # Editor-based Rest Client
52
+ .idea/httpRequests
53
+
54
+ # Android studio 3.1+ serialized cache file
55
+ .idea/caches/build_file_checksums.ser
56
+
57
+ ### JupyterNotebooks template
58
+ # gitignore template for Jupyter Notebooks
59
+ # website: http://jupyter.org/
60
+
61
+ .ipynb_checkpoints
62
+ */.ipynb_checkpoints/*
63
+
64
+ # IPython
65
+ profile_default/
66
+ ipython_config.py
67
+
68
+ # Remove previous ipynb_checkpoints
69
+ # git rm -r .ipynb_checkpoints/
70
+
71
+ ### Python template
72
+ # Byte-compiled / optimized / DLL files
73
+ __pycache__/
74
+ *.py[cod]
75
+ *.class
76
+
77
+ # C extensions
78
+ *.so
79
+
80
+ # Distribution / packaging
81
+ .Python
82
+ build/
83
+ develop-eggs/
84
+ dist/
85
+ downloads/
86
+ eggs/
87
+ .eggs/
88
+ lib/
89
+ lib64/
90
+ parts/
91
+ sdist/
92
+ var/
93
+ wheels/
94
+ share/python-wheels/
95
+ *.egg-info/
96
+ .installed.cfg
97
+ *.egg
98
+ MANIFEST
99
+
100
+ # PyInstaller
101
+ # Usually these files are written by a python script from a template
102
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
103
+ *.manifest
104
+ *.spec
105
+
106
+ # Installer logs
107
+ pip-log.txt
108
+ pip-delete-this-directory.txt
109
+
110
+ # Unit test / coverage reports
111
+ htmlcov/
112
+ .tox/
113
+ .nox/
114
+ .coverage
115
+ .coverage.*
116
+ .cache
117
+ nosetests.xml
118
+ coverage.xml
119
+ *.cover
120
+ *.py,cover
121
+ .hypothesis/
122
+ .pytest_cache/
123
+ cover/
124
+
125
+ # Translations
126
+ *.mo
127
+ *.pot
128
+
129
+ # Django stuff:
130
+ *.log
131
+ local_settings.py
132
+ db.sqlite3
133
+ db.sqlite3-journal
134
+
135
+ # Flask stuff:
136
+ instance/
137
+ .webassets-cache
138
+
139
+ # Scrapy stuff:
140
+ .scrapy
141
+
142
+ # Sphinx documentation
143
+ docs/_build/
144
+
145
+ # PyBuilder
146
+ .pybuilder/
147
+ target/
148
+
149
+ # Jupyter Notebook
150
+ .ipynb_checkpoints
151
+
152
+ # IPython
153
+ profile_default/
154
+ ipython_config.py
155
+
156
+ # pyenv
157
+ # For a library or package, you might want to ignore these files since the code is
158
+ # intended to run in multiple environments; otherwise, check them in:
159
+ # .python-version
160
+
161
+ # pipenv
162
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
163
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
164
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
165
+ # install all needed dependencies.
166
+ #Pipfile.lock
167
+
168
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
169
+ __pypackages__/
170
+
171
+ # Celery stuff
172
+ celerybeat-schedule
173
+ celerybeat.pid
174
+
175
+ # SageMath parsed files
176
+ *.sage.py
177
+
178
+ # Environments
179
+ .env
180
+ .venv
181
+ env/
182
+ venv/
183
+ ENV/
184
+ env.bak/
185
+ venv.bak/
186
+
187
+ # Spyder project settings
188
+ .spyderproject
189
+ .spyproject
190
+
191
+ # Rope project settings
192
+ .ropeproject
193
+
194
+ # mkdocs documentation
195
+ /site
196
+
197
+ # mypy
198
+ .mypy_cache/
199
+ .dmypy.json
200
+ dmypy.json
201
+
202
+ # Pyre type checker
203
+ .pyre/
204
+
205
+ # pytype static type analyzer
206
+ .pytype/
207
+
208
+ # Cython debug symbols
209
+ cython_debug/
210
+
deploy.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ import runpy
2
+
3
+ runpy.run_module("fillmorle.app", run_name="__main__", alter_sys=True)
fillmorle/app.py ADDED
@@ -0,0 +1,524 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from itertools import product
2
+ import random
3
+ from turtle import hideturtle
4
+ import requests
5
+ import json
6
+ import lxml.etree as ET
7
+
8
+ import gensim
9
+ import pandas as pd
10
+
11
+ import nltk
12
+ # from nltk.corpus import framenet as fn
13
+ # --- circumvent threading issues with FrameNet
14
+ fn_root = nltk.data.find("{}/{}".format("corpora", "framenet_v17"))
15
+ print(fn_root)
16
+ fn_files = ["frRelation.xml", "frameIndex.xml", "fulltextIndex.xml", "luIndex.xml", "semTypes.xml"]
17
+ fn = nltk.corpus.reader.framenet.FramenetCorpusReader(fn_root, fn_files)
18
+ # ---
19
+
20
+ import streamlit as st
21
+
22
+ from sociolome import lome_wrapper
23
+
24
+
25
+ def similarity(gensim_m, frame_1, frame_2):
26
+ if f"fn_{frame_1}" not in gensim_m or f"fn_{frame_2}" not in gensim_m:
27
+ return None
28
+ return 1 - gensim_m.distance(f"fn_{frame_1}", f"fn_{frame_2}")
29
+
30
+
31
+ def rank(gensim_m, frame_1, frame_2):
32
+ frame_1 = f"fn_{frame_1}"
33
+ frame_2 = f"fn_{frame_2}"
34
+
35
+ if frame_1 == frame_2:
36
+ return 0
37
+
38
+ for i, (word, _) in enumerate(gensim_m.most_similar(frame_1, topn=1200)):
39
+ if word == frame_2:
40
+ return i + 1
41
+ return -1
42
+
43
+
44
+ def format_frame_description(frame_def_xml):
45
+ frame_def_fmt = [frame_def_xml.text] if frame_def_xml.text else []
46
+ for elem in frame_def_xml:
47
+ if elem.tag == "ex":
48
+ break
49
+ elif elem.tag == "fen":
50
+ frame_def_fmt.append(elem.text.upper())
51
+ elif elem.text:
52
+ frame_def_fmt.append(elem.text)
53
+ if elem.tail:
54
+ frame_def_fmt.append(elem.tail)
55
+ return "".join(frame_def_fmt).replace("frames", "stories").replace("frame", "story")
56
+
57
+
58
+ def get_frame_definition(frame_info):
59
+ try:
60
+ # try extracting just the first sentence
61
+ definition_first_sent = nltk.sent_tokenize(frame_info.definitionMarkup)[0] + "</def-root>"
62
+ frame_def_xml = ET.fromstring(definition_first_sent)
63
+ except ET.XMLSyntaxError:
64
+ # otherwise, use the full definition
65
+ frame_def_xml = ET.fromstring(frame_info.definitionMarkup)
66
+ return format_frame_description(frame_def_xml)
67
+
68
+
69
+ def get_random_example(frame_info):
70
+ exemplars = [
71
+ {
72
+ "text": exemplar.text,
73
+ "target_lu": lu_name,
74
+ "target_idx": list(exemplar["Target"][0]),
75
+ "core_fes": {
76
+ role: exemplar.text[start_idx:end_idx]
77
+ for role, start_idx, end_idx in exemplar.FE[0]
78
+ if role in [fe for fe, fe_info in frame_info.FE.items() if fe_info.coreType == "Core"]
79
+ }
80
+ }
81
+ for lu_name, lu_info in frame_info["lexUnit"].items()
82
+ for exemplar in lu_info.exemplars if len(exemplar.text) > 30
83
+ ]
84
+ if exemplars:
85
+ return random.choice(exemplars)
86
+ return None
87
+
88
+ def make_hint(gensim_m, target, current_closest):
89
+
90
+ if target == current_closest:
91
+ return None
92
+
93
+ most_similar = gensim_m.most_similar(f"fn_{target}", topn=1200)
94
+ current_position = [word for word, _ in most_similar].index(f"fn_{current_closest}")
95
+
96
+ while current_position > 0:
97
+ next_closest, _ = most_similar[current_position - 1]
98
+ info = fn.frame(next_closest.replace("fn_", ""))
99
+ if len(info.lexUnit) > 10:
100
+ exemplar = get_random_example(info)
101
+ if exemplar:
102
+ return next_closest, exemplar
103
+ current_position -= 1
104
+
105
+ return None
106
+
107
+
108
+ def get_typical_exemplar(frame_info):
109
+ exemplars = [
110
+ {
111
+ "text": exemplar.text,
112
+ "target_lu": lu_name,
113
+ "target_idx": list(exemplar["Target"][0]),
114
+ "core_fes": {
115
+ role: exemplar.text[start_idx:end_idx]
116
+ for role, start_idx, end_idx in exemplar.FE[0]
117
+ if role in [fe for fe, fe_info in frame_info.FE.items() if fe_info.coreType == "Core"]
118
+ }
119
+ }
120
+ for lu_name, lu_info in frame_info["lexUnit"].items()
121
+ for exemplar in lu_info.exemplars
122
+ ]
123
+
124
+ # try to find a "typical" exemplar --- typical -> as short as possible, as many FEs as possible
125
+ exa_typicality_scores = [(exa, len(exa["text"]) - 25 * len(exa["core_fes"])) for exa in exemplars]
126
+ if exa_typicality_scores:
127
+ typical_exemplar = min(exa_typicality_scores, key=lambda t: t[1])[0]
128
+ else:
129
+ typical_exemplar = None
130
+ return typical_exemplar
131
+
132
+
133
+ def find_all_inheriting_frames(frame_name):
134
+ frame_info = fn.frame(frame_name)
135
+ inheritance_rels = [rel for rel in frame_info.frameRelations if rel.type.name == "Inheritance" and rel.superFrame.name == frame_name]
136
+ inheritors = [rel.subFrame.name for rel in inheritance_rels]
137
+ for inh in inheritors:
138
+ inheritors.extend(find_all_inheriting_frames(inh))
139
+ return inheritors
140
+
141
+
142
+ def has_enough_lus(frame, n=10):
143
+ return len(fn.frame(frame).lexUnit) > n
144
+
145
+
146
+ def choose_secret_frames():
147
+ event_frames = [frm for frm in find_all_inheriting_frames("Event") if has_enough_lus(frm)]
148
+ entity_frames = [frm for frm in find_all_inheriting_frames("Entity") if has_enough_lus(frm)]
149
+ return random.choice(list(product(event_frames, entity_frames)))
150
+
151
+
152
+ def get_frame_info(frames):
153
+ frames_and_info = []
154
+ for evoked_frame in frames:
155
+ try:
156
+ frame_info = fn.frame(evoked_frame)
157
+ typical_sentence = get_typical_exemplar(frame_info)
158
+ frames_and_info.append((evoked_frame, frame_info, typical_sentence))
159
+ except FileNotFoundError:
160
+ continue
161
+ return frames_and_info
162
+
163
+
164
+ def get_frame_feedback(frames_and_info, gensim_m, secret_event, secret_entity):
165
+ frame_feedback = []
166
+ for evoked_frame, frame_info, typical_sentence in frames_and_info:
167
+ lexunits = list(frame_info.lexUnit.keys())[:5]
168
+ similarity_score_1 = similarity(gensim_m, secret_event, evoked_frame)
169
+ similarity_rank_1 = rank(gensim_m, secret_event, evoked_frame)
170
+ similarity_score_2 = similarity(gensim_m, secret_entity, evoked_frame)
171
+ similarity_rank_2 = rank(gensim_m, secret_entity, evoked_frame)
172
+ if typical_sentence:
173
+ typical_sentence_txt = typical_sentence['text']
174
+ else:
175
+ typical_sentence_txt = None
176
+
177
+ frame_feedback.append({
178
+ "frame": evoked_frame,
179
+ "similarity_1": similarity_score_1 * 100 if similarity_score_1 else None,
180
+ "rank_1": similarity_rank_1 if similarity_rank_1 != -1 else "far away",
181
+ "similarity_2": similarity_score_2 * 100 if similarity_score_2 else None,
182
+ "rank_2": similarity_rank_2 if similarity_rank_2 != -1 else "far away",
183
+ "typical_words": lexunits,
184
+ "typical_sentence": typical_sentence_txt
185
+ })
186
+ return frame_feedback
187
+
188
+
189
+ def run_game_cli(debug=True):
190
+
191
+ secret_event, secret_entity = choose_secret_frames()
192
+
193
+ if debug:
194
+ print(f"Shhhhhh you're not supposed to know, but the secret frames are {secret_event} and {secret_entity}")
195
+ print("--------\n\n\n\n")
196
+
197
+ print("Welcome to FillmorLe!")
198
+ print("Words are not just words: behind every word, a story is hidden that appears in our imagination when we hear the word.")
199
+ print()
200
+ print("In this game, your job is to activate TWO SECRET STORIES by writing sentences.")
201
+ print("There will be new secret stories every day -- the first story is always about an EVENT (something that happens in the world) and the second one about an ENTITY (a thing or concept).")
202
+ print("Every time you write a sentence, I will tell you which stories are hidden below the surface, and how close these stories are to the secret stories.")
203
+ print("Once you write a sentence that has both of the secret stories in it, you win. Good luck and be creative!")
204
+
205
+ gensim_m = gensim.models.word2vec.KeyedVectors.load_word2vec_format("data/frame_embeddings.w2v.txt")
206
+
207
+ num_guesses = 0
208
+ guesses_event = []
209
+ guesses_entity = []
210
+
211
+ while True:
212
+ num_guesses += 1
213
+ closest_to_event = sorted(guesses_event, key=lambda g: g[1], reverse=True)[:5]
214
+ closest_to_entity = sorted(guesses_entity, key=lambda g: g[1], reverse=True)[:5]
215
+ closest_to_event_txt = ", ".join([f"{frm.upper()} ({sim:.2f})" for frm, sim in closest_to_event])
216
+ closest_to_entity_txt = ", ".join([f"{frm.upper()} ({sim:.2f})" for frm, sim in closest_to_entity])
217
+
218
+ print()
219
+ print(f"==== Guess #{num_guesses} ====")
220
+ if secret_event in guesses_event:
221
+ print("You already guessed SECRET STORY #1: ", secret_event.upper())
222
+ elif closest_to_event:
223
+ print(f"Best guesses (SECRET STORY #1):", closest_to_event_txt)
224
+
225
+ if secret_entity in guesses_entity:
226
+ print("You already guessed SECRET STORY #1: ", secret_entity.upper())
227
+ elif closest_to_entity:
228
+ print(f"Best guesses (SECRET STORY #2):", closest_to_entity_txt)
229
+
230
+ sentence = input("Enter a sentence or type 'HINT' if you're stuck >>>> ").strip()
231
+
232
+ if sentence == "HINT":
233
+ hint_target = None
234
+ while not hint_target:
235
+ hint_choice = input("For which story do you want a hint? Type '1' or '2' >>>> ").strip()
236
+ if hint_choice == "1":
237
+ hint_target = secret_event
238
+ hint_current = closest_to_event[0][0] if closest_to_event else "Event"
239
+ elif hint_choice == "2":
240
+ hint_target = secret_entity
241
+ hint_current = closest_to_entity[0][0] if closest_to_entity else "Entity"
242
+ else:
243
+ print("Please type '1' or '2'.")
244
+
245
+ if hint_current == hint_target:
246
+ print("You don't need a hint for this story! Maybe you want a hint for the other one?")
247
+ continue
248
+
249
+ hint = make_hint(gensim_m, hint_target, hint_current)
250
+ if hint is None:
251
+ print("Sorry, you're already too close to give you a hint!")
252
+ else:
253
+ _, hint_example = hint
254
+ hint_tgt_idx = hint_example["target_idx"]
255
+ hint_example_redacted = hint_example["text"][:hint_tgt_idx[0]] + "******" + hint_example["text"][hint_tgt_idx[1]:]
256
+ print(f"Your hint sentence is: «{hint_example_redacted}»")
257
+ print(f"PRO TIP 1: the '******' hide a secret word. Guess the word and you will find a story that takes your one step closer to find SECRET STORY #{hint_choice}")
258
+ print(f"PRO TIP 2: if you don't get the hint, just ask for a new one! You can do this as often as you want.")
259
+ print("\n\n")
260
+ continue
261
+
262
+ r = requests.get("http://127.0.0.1:9090/analyze", params={"text": sentence})
263
+ lome_data = json.loads(r.text)
264
+ frames = set()
265
+ for token_items in lome_data["analyses"][0]["frame_list"]:
266
+ for item in token_items:
267
+ if item.startswith("T:"):
268
+ evoked_frame = item.split("@")[0].replace("T:", "")
269
+ frames.add(evoked_frame)
270
+
271
+ frames_and_info = get_frame_info(frames)
272
+ frame_feedback = get_frame_feedback(frames_and_info)
273
+
274
+ for i, feedback in enumerate(frame_feedback):
275
+
276
+ print(f"STORY {i}: {feedback['frame'].upper()}")
277
+ if feedback["typical_sentence"]:
278
+ print(f"\ttypical context: «{feedback['typical_sentence']}»")
279
+ print("\ttypical words:", ", ".join(feedback["typical_words"]), "...")
280
+ if feedback["similarity_1"]:
281
+ guesses_event.append((evoked_frame, feedback["similarity_1"]))
282
+ guesses_entity.append((evoked_frame, feedback["similarity_2"]))
283
+ print(f"\tsimilarity to SECRET STORY #1: {feedback['similarity_1']:.2f}")
284
+ print(f"\tsimilarity to SECRET STORY #2: {feedback['similarity_2']:.2f}")
285
+ else:
286
+ print("similarity: unknown")
287
+ print()
288
+
289
+ if not frames_and_info:
290
+ print("I don't know any of the stories in your sentence. Try entering another sentence.")
291
+
292
+ elif secret_event in frames and secret_entity in frames:
293
+ print(f"YOU WIN! You made a sentence with both of the SECRET STORIES: {secret_event.upper()} and {secret_entity.upper()}.\nYou won the game in {num_guesses} guesses, great job!")
294
+ break
295
+
296
+ elif secret_event in frames:
297
+ print(f"Great, you guessed SECRET STORY #1! It was {secret_event.upper()}!")
298
+ print("To win, make a sentence with this story and SECRET STORY #2 hidden in it.")
299
+
300
+ elif secret_entity in frames:
301
+ print(f"Great, you guessed SECRET STORY #2! It was {secret_entity.upper()}!")
302
+ print("To win, make a sentence with this story and SECRET STORY #1 hidden in it.")
303
+
304
+
305
+ # dummy version
306
+ # def analyze_sentence(sentence):
307
+ # return sentence.split()
308
+
309
+ def analyze_sentence(sentence):
310
+ lome_data = lome_wrapper.analyze(sentence)
311
+ frames = set()
312
+ for token_items in lome_data["analyses"][0]["frame_list"]:
313
+ for item in token_items:
314
+ if item.startswith("T:"):
315
+ evoked_frame = item.split("@")[0].replace("T:", "")
316
+ frames.add(evoked_frame)
317
+ return frames
318
+
319
+
320
+
321
+ def make_frame_feedback_msg(frame_feedback):
322
+ feedback_msg = []
323
+ for i, feedback in enumerate(frame_feedback):
324
+ feedback_msg.append(f"* STORY {i}: *{feedback['frame'].upper()}*")
325
+ feedback_msg.append("\t* typical words: *" + " ".join(feedback["typical_words"]) + "* ...")
326
+ if feedback["typical_sentence"]:
327
+ feedback_msg.append(f"\t* typical context: «{feedback['typical_sentence']}»")
328
+
329
+ if feedback["similarity_1"]:
330
+ feedback_msg.append(f"\t* similarity to SECRET STORY #1: {feedback['similarity_1']:.2f}")
331
+ feedback_msg.append(f"\t* similarity to SECRET STORY #2: {feedback['similarity_2']:.2f}")
332
+ else:
333
+ feedback_msg.append(f"\t* similarity: unknown")
334
+ return "\n".join(feedback_msg)
335
+
336
+
337
+ def format_hint_sentence(hint_example):
338
+ hint_tgt_idx = hint_example["target_idx"]
339
+ hint_example_redacted = hint_example["text"][:hint_tgt_idx[0]] + "******" + hint_example["text"][hint_tgt_idx[1]:]
340
+ return hint_example_redacted.strip()
341
+
342
+
343
+ def play_turn():
344
+ # remove text from input
345
+ sentence = st.session_state["cur_sentence"]
346
+ st.session_state["cur_sentence"] = ""
347
+
348
+ # get previous game state
349
+ game_state = st.session_state["game_state"]
350
+ secret_event, secret_entity = game_state["secret_event"], game_state["secret_entity"]
351
+ guesses_event, guesses_entity = game_state["guesses_event"], game_state["guesses_entity"]
352
+
353
+ # reset hints
354
+ st.session_state["hints"] = [None, None]
355
+
356
+ # reveal correct frames
357
+ if sentence.strip().lower() == "show me the frames":
358
+ st.warning(f"The correct frames are: {secret_event.upper()} and {secret_entity.upper()}")
359
+
360
+ # process hints
361
+ elif sentence.strip() == "HINT":
362
+ guesses_event = sorted(game_state["guesses_event"], key=lambda t: t[1], reverse=True)
363
+ guesses_entity = sorted(game_state["guesses_entity"], key=lambda t: t[1], reverse=True)
364
+ best_guess_event = guesses_event[0][0] if guesses_event else "Event"
365
+ best_guess_entity = guesses_entity[0][0] if guesses_entity else "Entity"
366
+
367
+ event_hint = make_hint(st.session_state["gensim_model"], secret_event, best_guess_event)
368
+ entity_hint = make_hint(st.session_state["gensim_model"], secret_entity, best_guess_entity)
369
+
370
+ if event_hint:
371
+ st.session_state["hints"][0] = format_hint_sentence(event_hint[1])
372
+ if entity_hint:
373
+ st.session_state["hints"][1] = format_hint_sentence(entity_hint[1])
374
+
375
+
376
+ else:
377
+ frames = analyze_sentence(sentence)
378
+ frames_and_info = get_frame_info(frames)
379
+ frame_feedback = get_frame_feedback(frames_and_info, st.session_state["gensim_model"], secret_event, secret_entity)
380
+
381
+ # update game state post analysis
382
+ game_state["num_guesses"] += 1
383
+ for fdb in frame_feedback:
384
+ if fdb["similarity_1"]:
385
+ guesses_event.add((fdb["frame"], fdb["similarity_1"], fdb["rank_1"]))
386
+ guesses_entity.add((fdb["frame"], fdb["similarity_2"], fdb["rank_2"]))
387
+
388
+ st.session_state["frame_feedback"] = frame_feedback
389
+ if secret_event in frames and secret_entity in frames:
390
+ st.session_state["game_over"] = True
391
+ st.session_state["guesses_to_win"] = game_state["num_guesses"]
392
+
393
+ def display_guess_status():
394
+ game_state = st.session_state["game_state"]
395
+ guesses_entity = sorted(game_state["guesses_entity"], key=lambda t: t[1], reverse=True)
396
+ guesses_event = sorted(game_state["guesses_event"], key=lambda t: t[1], reverse=True)
397
+
398
+ if guesses_event or guesses_entity:
399
+ st.header("Best guesses")
400
+
401
+ event_col, entity_col = st.columns(2)
402
+ if guesses_event:
403
+ with event_col:
404
+ st.subheader("Secret Story #1")
405
+ st.table(pd.DataFrame(guesses_event, columns=["Story", "Similarity", "Steps To Go"]))
406
+ if game_state["secret_event"] in [g for g, _, _ in guesses_event]:
407
+ st.info("Great, you guessed the Event story! In order to win, make a sentence containing both the secret stories.")
408
+ if guesses_entity:
409
+ with entity_col:
410
+ st.subheader("Secret Story #2")
411
+ st.table(pd.DataFrame(guesses_entity, columns=["Story", "Similarity", "Steps To Go"]))
412
+ if game_state["secret_entity"] in [g for g, _, _ in guesses_entity]:
413
+ st.info("Great, you guessed the Thing story! In order to win, make a sentence containing both the secret stories.")
414
+
415
+
416
+ def format_feedback(frame_feedback):
417
+ out = []
418
+ for fdb in frame_feedback:
419
+ out.append({
420
+ "Story": fdb["frame"],
421
+ "Similarity (Event)": f"{fdb['similarity_1']:.2f}",
422
+ "Similarity (Thing)": f"{fdb['similarity_2']:.2f}",
423
+ "Typical Context": fdb["typical_sentence"],
424
+ "Typical Words": " ".join(fdb["typical_words"])
425
+ })
426
+ return out
427
+
428
+
429
+ def display_introduction():
430
+ st.subheader("Why this game?")
431
+ st.markdown(
432
+ """
433
+ Words are not just words: behind every word, a _mini-story_ (also known as "frame") is hidden
434
+ that appears in our imagination when we hear the word. For example, when we hear the word
435
+ "talking" we can imagine a mini-story that involves several people who are interacting
436
+ with each other. Or, if we hear the word "cookie", we might think of someone eating a cookie.
437
+ """.strip())
438
+
439
+ st.subheader("How does it work?")
440
+ st.markdown(
441
+ "* In this game, there are two secret mini-stories, and it's your job to figure out which ones!"
442
+ "\n"
443
+ "* The first mini-story is about an _Event_ (something that happens in the world, like a thunderstorm, "
444
+ "people talking, someone eating pasta), and the other one is a _Thing_ (a concrete thing like a tree"
445
+ "or something abstract like 'love')."
446
+ "\n"
447
+ "* How to guess the stories? Well, just type a sentence, and we'll tell you which mini-stories are "
448
+ "hidden in the sentence. For each of the stories, we'll tell you how close they are to the secret ones."
449
+ "\n"
450
+ "* Once you type a sentence with both of the secret mini-stories, you win!"
451
+ )
452
+
453
+
454
+
455
+ def display_hints():
456
+ event_hint, entity_hint = st.session_state["hints"]
457
+ if event_hint or entity_hint:
458
+ st.header("Hints")
459
+ st.info("So you need some help? Here you get your hint sentences! Guess the hidden word, use it in a sentence, and we'll help you get one step closer.")
460
+
461
+ if event_hint:
462
+ st.markdown(f"**Event Hint**:\n>_{event_hint}_")
463
+ if entity_hint:
464
+ st.markdown(f"**Thing Hint**:\n>_{entity_hint}_")
465
+
466
+ def display_frame_feedback():
467
+ frame_feedback = st.session_state["frame_feedback"]
468
+ if frame_feedback:
469
+ st.header("Feedback")
470
+ st.text("Your sentence contains the following stories: ")
471
+ feedback_df = format_feedback(frame_feedback)
472
+ st.table(pd.DataFrame(feedback_df))
473
+
474
+
475
+ def run_game_st(debug=True):
476
+
477
+ if not st.session_state.get("initialized", False):
478
+
479
+ secret_event, secret_entity = choose_secret_frames()
480
+ gensim_m = gensim.models.word2vec.KeyedVectors.load_word2vec_format("data/frame_embeddings.w2v.txt")
481
+
482
+ game_state = {
483
+ "secret_event": secret_event,
484
+ "secret_entity": secret_entity,
485
+ "num_guesses": 0,
486
+ "guesses_event": set(),
487
+ "guesses_entity": set(),
488
+ }
489
+
490
+ st.session_state["initialized"] = True
491
+ st.session_state["show_introduction"] = False
492
+ st.session_state["game_over"] = False
493
+ st.session_state["guesses_to_win"] = -1
494
+ st.session_state["game_state"] = game_state
495
+ st.session_state["gensim_model"] = gensim_m
496
+ st.session_state["frame_feedback"] = None
497
+ st.session_state["hints"] = [None, None]
498
+
499
+ else:
500
+ gensim_m = st.session_state["gensim_model"]
501
+ game_state = st.session_state["game_state"]
502
+
503
+ secret_event, secret_entity = game_state["secret_event"], game_state["secret_entity"]
504
+
505
+ header = st.container()
506
+ with header:
507
+ st.title("FillmorLe")
508
+ st.checkbox("Show explanation?", key="show_introduction")
509
+ if st.session_state["show_introduction"]:
510
+ display_introduction()
511
+
512
+ st.header(f"Guess #{st.session_state['game_state']['num_guesses'] + 1}")
513
+ st.text_input("Enter a sentence or type 'HINT' if you're stuck", key="cur_sentence", on_change=play_turn)
514
+
515
+ if st.session_state["game_over"]:
516
+ st.success(f"You won in {st.session_state['guesses_to_win']}!")
517
+
518
+ display_hints()
519
+ display_frame_feedback()
520
+ display_guess_status()
521
+
522
+
523
+ if __name__ == "__main__":
524
+ run_game_st()
model.mod.tar.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3f5be5aeef50b2f4840317b8196c51186f9f138a853dc1eb2da980b1947ceb23
3
+ size 1795605184
requirements.txt ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ allennlp>=2.0.0
2
+ allennlp-models>=2.0.0
3
+ transformers>=4.0.0 # Why is huggingface so unstable?
4
+ numpy
5
+ torch>=1.7.0,<1.8.0
6
+ tqdm
7
+ nltk
8
+ overrides
9
+ concrete
10
+ flask
11
+ scipy
12
+ requests
13
+ lxml
14
+ gensim
15
+ streamlit
16
+ https://github.com/explosion/spacy-models/releases/download/it_core_news_md-3.0.0/it_core_news_md-3.0.0-py3-none-any.whl
17
+ https://github.com/explosion/spacy-models/releases/download/en_core_web_md-3.0.0/en_core_web_md-3.0.0-py3-none-any.whl
18
+ https://github.com/explosion/spacy-models/releases/download/nl_core_news_md-3.0.0/nl_core_news_md-3.0.0-py3-none-any.whl
19
+ https://github.com/explosion/spacy-models/releases/download/xx_sent_ud_sm-3.0.0/xx_sent_ud_sm-3.0.0-py3-none-any.whl
setup.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from setuptools import setup, find_packages
2
+
3
+
4
+ setup(
5
+ name='sftp',
6
+ version='0.0.2',
7
+ author='Guanghui Qin',
8
+ packages=find_packages(),
9
+ )
sftp/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from .data_reader import (
2
+ BetterDatasetReader, SRLDatasetReader
3
+ )
4
+ from .metrics import SRLMetric, BaseF, ExactMatch, FBetaMixMeasure
5
+ from .models import SpanModel
6
+ from .modules import (
7
+ MLPSpanTyping, SpanTyping, SpanFinder, BIOSpanFinder
8
+ )
9
+ from .predictor import SpanPredictor
10
+ from .utils import Span
sftp/data_reader/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from .batch_sampler import MixSampler
2
+ from .better_reader import BetterDatasetReader
3
+ from .span_reader import SpanReader
4
+ from .srl_reader import SRLDatasetReader
5
+ from .concrete_srl import concrete_doc, concrete_doc_tokenized, collect_concrete_srl
6
+ from .concrete_reader import ConcreteDatasetReader
sftp/data_reader/batch_sampler/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .mix_sampler import MixSampler
sftp/data_reader/batch_sampler/mix_sampler.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import random
3
+ from typing import *
4
+
5
+ from allennlp.data.samplers.batch_sampler import BatchSampler
6
+ from allennlp.data.samplers.max_tokens_batch_sampler import MaxTokensBatchSampler
7
+ from torch.utils import data
8
+
9
+ logger = logging.getLogger('mix_sampler')
10
+
11
+
12
+ @BatchSampler.register('mix_sampler')
13
+ class MixSampler(MaxTokensBatchSampler):
14
+ def __init__(
15
+ self,
16
+ max_tokens: int,
17
+ sorting_keys: List[str] = None,
18
+ padding_noise: float = 0.1,
19
+ sampling_ratios: Optional[Dict[str, float]] = None,
20
+ ):
21
+ super().__init__(max_tokens, sorting_keys, padding_noise)
22
+
23
+ self.sampling_ratios = sampling_ratios or dict()
24
+
25
+ def __iter__(self):
26
+ indices, lengths = self._argsort_by_padding(self.data_source)
27
+
28
+ original_num = len(indices)
29
+ instance_types = [
30
+ ins.fields['meta'].metadata.get('type', 'default') if 'meta' in ins.fields else 'default'
31
+ for ins in self.data_source
32
+ ]
33
+ instance_thresholds = [
34
+ self.sampling_ratios[ins_type] if ins_type in self.sampling_ratios else 1.0 for ins_type in instance_types
35
+ ]
36
+ for idx, threshold in enumerate(instance_thresholds):
37
+ if random.random() > threshold:
38
+ # Reject
39
+ list_idx = indices.index(idx)
40
+ del indices[list_idx], lengths[list_idx]
41
+ if original_num != len(indices):
42
+ logger.info(f'#instances reduced from {original_num} to {len(indices)}.')
43
+
44
+ max_lengths = [max(length) for length in lengths]
45
+ group_iterator = self._lazy_groups_of_max_size(indices, max_lengths)
46
+
47
+ batches = [list(group) for group in group_iterator]
48
+ random.shuffle(batches)
49
+ for batch in batches:
50
+ yield batch
sftp/data_reader/better_reader.py ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import os
4
+ from collections import defaultdict, namedtuple
5
+ from typing import *
6
+
7
+ from allennlp.data.dataset_readers.dataset_reader import DatasetReader
8
+ from allennlp.data.instance import Instance
9
+
10
+ from .span_reader import SpanReader
11
+ from ..utils import Span
12
+
13
+ # logging.basicConfig(level=logging.DEBUG)
14
+
15
+ # for v in logging.Logger.manager.loggerDict.values():
16
+ # v.disabled = True
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+ SpanTuple = namedtuple('Span', ['start', 'end'])
21
+
22
+
23
+ @DatasetReader.register('better')
24
+ class BetterDatasetReader(SpanReader):
25
+ def __init__(
26
+ self,
27
+ eval_type,
28
+ consolidation_strategy='first',
29
+ span_set_type='single',
30
+ max_argument_ss_size=1,
31
+ use_ref_events=False,
32
+ **extra
33
+ ):
34
+ super().__init__(**extra)
35
+ self.eval_type = eval_type
36
+ assert self.eval_type in ['abstract', 'basic']
37
+
38
+ self.consolidation_strategy = consolidation_strategy
39
+ self.unitary_spans = span_set_type == 'single'
40
+ # event anchors are always singleton spans
41
+ self.max_arg_spans = max_argument_ss_size
42
+ self.use_ref_events = use_ref_events
43
+
44
+ self.n_overlap_arg = 0
45
+ self.n_overlap_trigger = 0
46
+ self.n_skip = 0
47
+ self.n_too_long = 0
48
+
49
+ @staticmethod
50
+ def post_process_basic_span(predicted_span, basic_entry):
51
+ # Convert token offsets back to characters, also get the text spans as a sanity check
52
+
53
+ # !!!!!
54
+ # SF outputs inclusive idxs
55
+ # char offsets are inc-exc
56
+ # token offsets are inc-inc
57
+ # !!!!!
58
+
59
+ start_idx = predicted_span['start_idx'] # inc
60
+ end_idx = predicted_span['end_idx'] # inc
61
+
62
+ char_start_idx = basic_entry['tok2char'][predicted_span['start_idx']][0] # inc
63
+ char_end_idx = basic_entry['tok2char'][predicted_span['end_idx']][-1] + 1 # exc
64
+
65
+ span_text = basic_entry['segment-text'][char_start_idx:char_end_idx] # inc exc
66
+ span_text_tok = basic_entry['segment-text-tok'][start_idx:end_idx + 1] # inc exc
67
+
68
+ span = {'string': span_text,
69
+ 'start': char_start_idx,
70
+ 'end': char_end_idx,
71
+ 'start-token': start_idx,
72
+ 'end-token': end_idx,
73
+ 'string-tok': span_text_tok,
74
+ 'label': predicted_span['label'],
75
+ 'predicted': True}
76
+ return span
77
+
78
+ @staticmethod
79
+ def _get_shortest_span(spans):
80
+ # shortest_span_length = float('inf')
81
+ # shortest_span = None
82
+ # for span in spans:
83
+ # span_tokens = span['string-tok']
84
+ # span_length = len(span_tokens)
85
+ # if span_length < shortest_span_length:
86
+ # shortest_span_length = span_length
87
+ # shortest_span = span
88
+
89
+ # return shortest_span
90
+ return [s[-1] for s in sorted([(len(span['string']), ix, span) for ix, span in enumerate(spans)])]
91
+
92
+ @staticmethod
93
+ def _get_first_span(spans):
94
+ spans = [(span['start'], -len(span['string']), ix, span) for ix, span in enumerate(spans)]
95
+ try:
96
+ return [s[-1] for s in sorted(spans)]
97
+ except:
98
+ breakpoint()
99
+
100
+ @staticmethod
101
+ def _get_longest_span(spans):
102
+ return [s[-1] for s in sorted([(len(span['string']), ix, span) for ix, span in enumerate(spans)], reverse=True)]
103
+
104
+ @staticmethod
105
+ def _subfinder(text, pattern):
106
+ # https://stackoverflow.com/a/12576755
107
+ matches = []
108
+ pattern_length = len(pattern)
109
+ for i, token in enumerate(text):
110
+ try:
111
+ if token == pattern[0] and text[i:i + pattern_length] == pattern:
112
+ matches.append(SpanTuple(start=i, end=i + pattern_length - 1)) # inclusive boundaries
113
+ except:
114
+ continue
115
+ return matches
116
+
117
+ def consolidate_span_set(self, spans):
118
+ if self.consolidation_strategy == 'first':
119
+ spans = BetterDatasetReader._get_first_span(spans)
120
+ elif self.consolidation_strategy == 'shortest':
121
+ spans = BetterDatasetReader._get_shortest_span(spans)
122
+ elif self.consolidation_strategy == 'longest':
123
+ spans = BetterDatasetReader._get_longest_span(spans)
124
+ else:
125
+ raise NotImplementedError(f"{self.consolidation_strategy} does not exist")
126
+
127
+ if self.unitary_spans:
128
+ spans = [spans[0]]
129
+ else:
130
+ spans = spans[:self.max_arg_spans]
131
+
132
+ # TODO add some sanity checks here
133
+
134
+ return spans
135
+
136
+ def get_mention_spans(self, text: List[str], span_sets: Dict):
137
+ mention_spans = defaultdict(list)
138
+ for span_set_id in span_sets.keys():
139
+ spans = span_sets[span_set_id]['spans']
140
+ # span = BetterDatasetReader._get_shortest_span(spans)
141
+ # span = BetterDatasetReader._get_earliest_span(spans)
142
+ consolidated_spans = self.consolidate_span_set(spans)
143
+ # if len(spans) > 1:
144
+ # logging.info(f"Truncated a spanset from {len(spans)} spans to 1")
145
+
146
+ if self.eval_type == 'abstract':
147
+ span = consolidated_spans[0]
148
+ span_tokens = span['string-tok']
149
+
150
+ span_indices = BetterDatasetReader._subfinder(text=text, pattern=span_tokens)
151
+
152
+ if len(span_indices) > 1:
153
+ pass
154
+
155
+ if len(span_indices) == 0:
156
+ continue
157
+
158
+ mention_spans[span_set_id] = span_indices[0]
159
+ else:
160
+ # in basic, we already have token offsets in the right form
161
+
162
+ # if not span['string-tok'] == text[span['start-token']:span['end-token'] + 1]:
163
+ # print(span, text[span['start-token']:span['end-token'] + 1])
164
+
165
+ # we should use these token offsets only!
166
+ for span in consolidated_spans:
167
+ mention_spans[span_set_id].append(SpanTuple(start=span['start-token'], end=span['end-token']))
168
+
169
+ return mention_spans
170
+
171
+ def _read_single_file(self, file_path):
172
+ with open(file_path) as fp:
173
+ json_content = json.load(fp)
174
+ if 'entries' in json_content:
175
+ for doc_name, entry in json_content['entries'].items():
176
+ instance = self.text_to_instance(entry, 'train' in file_path)
177
+ yield instance
178
+ else: # TODO why is this split in 2 cases?
179
+ for doc_name, entry in json_content.items():
180
+ instance = self.text_to_instance(entry, True)
181
+ yield instance
182
+
183
+ logger.warning(f'{self.n_overlap_arg} overlapped args detected!')
184
+ logger.warning(f'{self.n_overlap_trigger} overlapped triggers detected!')
185
+ logger.warning(f'{self.n_skip} skipped detected!')
186
+ logger.warning(f'{self.n_too_long} were skipped because they are too long!')
187
+ self.n_overlap_arg = self.n_skip = self.n_too_long = self.n_overlap_trigger = 0
188
+
189
+ def _read(self, file_path: str) -> Iterable[Instance]:
190
+
191
+ if os.path.isdir(file_path):
192
+ for fn in os.listdir(file_path):
193
+ if not fn.endswith('.json'):
194
+ logger.info(f'Skipping {fn}')
195
+ continue
196
+ logger.info(f'Loading from {fn}')
197
+ yield from self._read_single_file(os.path.join(file_path, fn))
198
+ else:
199
+ yield from self._read_single_file(file_path)
200
+
201
+ def text_to_instance(self, entry, is_training=False):
202
+ word_tokens = entry['segment-text-tok']
203
+
204
+ # span sets have been trimmed to the earliest span mention
205
+ spans = self.get_mention_spans(
206
+ word_tokens, entry['annotation-sets'][f'{self.eval_type}-events']['span-sets']
207
+ )
208
+
209
+ # idx of every token that is a part of an event trigger/anchor span
210
+ all_trigger_idxs = set()
211
+
212
+ # actual inputs to the model
213
+ input_spans = []
214
+
215
+ self._local_child_overlap = 0
216
+ self._local_child_total = 0
217
+
218
+ better_events = entry['annotation-sets'][f'{self.eval_type}-events']['events']
219
+
220
+ skipped_events = set()
221
+ # check for events that overlap other event's anchors, skip them later
222
+ for event_id, event in better_events.items():
223
+ assert event['anchors'] in spans
224
+
225
+ # take the first consolidated span for anchors
226
+ anchor_start, anchor_end = spans[event['anchors']][0]
227
+
228
+ if any(ix in all_trigger_idxs for ix in range(anchor_start, anchor_end + 1)):
229
+ logger.warning(
230
+ f"Skipped {event_id} with anchor span {event['anchors']}, overlaps a previously found event trigger/anchor")
231
+ self.n_overlap_trigger += 1
232
+ skipped_events.add(event_id)
233
+ continue
234
+
235
+ all_trigger_idxs.update(range(anchor_start, anchor_end + 1)) # record the trigger
236
+
237
+ for event_id, event in better_events.items():
238
+ if event_id in skipped_events:
239
+ continue
240
+
241
+ # arguments for just this event
242
+ local_arg_idxs = set()
243
+ # take the first consolidated span for anchors
244
+ anchor_start, anchor_end = spans[event['anchors']][0]
245
+
246
+ event_span = Span(anchor_start, anchor_end, event['event-type'], True)
247
+ input_spans.append(event_span)
248
+
249
+ def add_a_child(span_id, label):
250
+ # TODO this is a bad way to do this
251
+ assert span_id in spans
252
+ for child_span in spans[span_id]:
253
+ self._local_child_total += 1
254
+ arg_start, arg_end = child_span
255
+
256
+ if any(ix in local_arg_idxs for ix in range(arg_start, arg_end + 1)):
257
+ # logger.warn(f"Skipped argument {span_id}, overlaps a previously found argument")
258
+ # print(entry['annotation-sets'][f'{self.eval_type}-events']['span-sets'][span_id])
259
+ self.n_overlap_arg += 1
260
+ self._local_child_overlap += 1
261
+ continue
262
+
263
+ local_arg_idxs.update(range(arg_start, arg_end + 1))
264
+ event_span.add_child(Span(arg_start, arg_end, label, False))
265
+
266
+ for agent in event['agents']:
267
+ add_a_child(agent, 'agent')
268
+ for patient in event['patients']:
269
+ add_a_child(patient, 'patient')
270
+
271
+ if self.use_ref_events:
272
+ for ref_event in event['ref-events']:
273
+ if ref_event in skipped_events:
274
+ continue
275
+ ref_event_anchor_id = better_events[ref_event]['anchors']
276
+ add_a_child(ref_event_anchor_id, 'ref-event')
277
+
278
+ # if len(event['ref-events']) > 0:
279
+ # breakpoint()
280
+
281
+ fields = self.prepare_inputs(word_tokens, spans=input_spans)
282
+ if self._local_child_overlap > 0:
283
+ logging.warning(
284
+ f"Skipped {self._local_child_overlap} / {self._local_child_total} argument spans due to overlaps")
285
+ return Instance(fields)
286
+
sftp/data_reader/concrete_reader.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from collections import defaultdict
3
+ from typing import *
4
+ import os
5
+
6
+ from allennlp.data.dataset_readers.dataset_reader import DatasetReader
7
+ from allennlp.data.instance import Instance
8
+ from concrete import SituationMention
9
+ from concrete.util import CommunicationReader
10
+
11
+ from .span_reader import SpanReader
12
+ from .srl_reader import SRLDatasetReader
13
+ from .concrete_srl import collect_concrete_srl
14
+ from ..utils import Span, BIOSmoothing
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ @DatasetReader.register('concrete')
20
+ class ConcreteDatasetReader(SRLDatasetReader):
21
+ def __init__(
22
+ self,
23
+ event_only: bool = False,
24
+ event_smoothing_factor: float = 0.,
25
+ arg_smoothing_factor: float = 0.,
26
+ **extra
27
+ ):
28
+ super().__init__(**extra)
29
+ self.event_only = event_only
30
+ self.event_only = event_only
31
+ self.event_smooth_factor = event_smoothing_factor
32
+ self.arg_smooth_factor = arg_smoothing_factor
33
+
34
+ def _read(self, file_path: str) -> Iterable[Instance]:
35
+ if os.path.isdir(file_path):
36
+ for fn in os.listdir(file_path):
37
+ yield from self._read(os.path.join(file_path, fn))
38
+ all_files = CommunicationReader(file_path)
39
+ for comm, fn in all_files:
40
+ sentences = collect_concrete_srl(comm)
41
+ for tokens, vr in sentences:
42
+ yield self.text_to_instance(tokens, vr)
43
+ logger.warning(f'{self.n_span_removed} spans were removed')
44
+ self.n_span_removed = 0
sftp/data_reader/concrete_srl.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from time import time
2
+ from typing import *
3
+ from collections import defaultdict
4
+
5
+ from concrete import (
6
+ Token, TokenList, TextSpan, MentionArgument, SituationMentionSet, SituationMention, TokenRefSequence,
7
+ Communication, EntityMention, EntityMentionSet, Entity, EntitySet, AnnotationMetadata, Sentence
8
+ )
9
+ from concrete.util import create_comm, AnalyticUUIDGeneratorFactory
10
+ from concrete.validate import validate_communication
11
+
12
+ from ..utils import Span
13
+
14
+
15
+ def _process_sentence(sent, comm_sent, aug, char_idx_offset: int):
16
+ token_list = list()
17
+ for tok_idx, (start_idx, end_idx) in enumerate(sent['tokenization']):
18
+ token_list.append(Token(
19
+ tokenIndex=tok_idx,
20
+ text=sent['sentence'][start_idx:end_idx + 1],
21
+ textSpan=TextSpan(
22
+ start=start_idx + char_idx_offset,
23
+ ending=end_idx + char_idx_offset + 1
24
+ ),
25
+ ))
26
+ comm_sent.tokenization.tokenList = TokenList(tokenList=token_list)
27
+
28
+ sm_list, em_dict, entity_list = list(), dict(), list()
29
+
30
+ annotation = sent['annotations'] if isinstance(sent['annotations'], Span) else Span.from_json(sent['annotations'])
31
+ for event in annotation:
32
+ char_start_idx = sent['tokenization'][event.start_idx][0]
33
+ char_end_idx = sent['tokenization'][event.end_idx][1]
34
+ sm = SituationMention(
35
+ uuid=next(aug),
36
+ text=sent['sentence'][char_start_idx: char_end_idx + 1],
37
+ situationType='EVENT',
38
+ situationKind=event.label,
39
+ argumentList=list(),
40
+ tokens=TokenRefSequence(
41
+ tokenIndexList=list(range(event.start_idx, event.end_idx + 1)),
42
+ tokenizationId=comm_sent.tokenization.uuid
43
+ ),
44
+ )
45
+
46
+ for arg in event:
47
+ em = em_dict.get((arg.start_idx, arg.end_idx + 1))
48
+ if em is None:
49
+ char_start_idx = sent['tokenization'][arg.start_idx][0]
50
+ char_end_idx = sent['tokenization'][arg.end_idx][1]
51
+ em = EntityMention(next(aug), TokenRefSequence(
52
+ tokenIndexList=list(range(arg.start_idx, arg.end_idx + 1)),
53
+ tokenizationId=comm_sent.tokenization.uuid,
54
+ ), text=sent['sentence'][char_start_idx: char_end_idx + 1])
55
+ entity_list.append(Entity(next(aug), id=em.text, mentionIdList=[em.uuid]))
56
+ em_dict[(arg.start_idx, arg.end_idx + 1)] = em
57
+ sm.argumentList.append(MentionArgument(
58
+ role=arg.label,
59
+ entityMentionId=em.uuid,
60
+ ))
61
+
62
+ sm_list.append(sm)
63
+
64
+ return sm_list, list(em_dict.values()), entity_list
65
+
66
+
67
+ def concrete_doc(
68
+ sentences: List[Dict[str, Any]],
69
+ doc_name: str = 'document',
70
+ ) -> Communication:
71
+ """
72
+ Data format: A list of sentences. Each sentence should be a dict of the following format:
73
+ {
74
+ "sentence": String.
75
+ "tokenization": A list of Tuple[int, int] for start and end indices. Both inclusive.
76
+ "annotations": A list of event dict, or Span object.
77
+ }
78
+ If it is dict, its format should be:
79
+
80
+ Each event should be a dict of the following format:
81
+ {
82
+ "span": [start_idx, end_idx]: Integer. Both inclusive.
83
+ "label": String.
84
+ "children": A list of arguments.
85
+ }
86
+ Each argument should be a dict of the following format:
87
+ {
88
+ "span": [start_idx, end_idx]: Integer. Both inclusive.
89
+ "label": String.
90
+ }
91
+
92
+ Note the "indices" above all refer to the indices of tokens, instead of characters.
93
+ """
94
+ comm = create_comm(
95
+ doc_name,
96
+ '\n'.join([sent['sentence'] for sent in sentences]),
97
+ )
98
+ aug = AnalyticUUIDGeneratorFactory(comm).create()
99
+ situation_mention_set = SituationMentionSet(next(aug), AnnotationMetadata('Span Finder', time()), list())
100
+ comm.situationMentionSetList = [situation_mention_set]
101
+ entity_mention_set = EntityMentionSet(next(aug), AnnotationMetadata('Span Finder', time()), list())
102
+ comm.entityMentionSetList = [entity_mention_set]
103
+ entity_set = EntitySet(
104
+ next(aug), AnnotationMetadata('O(0) Coref Paser.', time()), list(), None, entity_mention_set.uuid
105
+ )
106
+ comm.entitySetList = [entity_set]
107
+ assert len(sentences) == len(comm.sectionList[0].sentenceList)
108
+
109
+ char_idx_offset = 0
110
+ for sent, comm_sent in zip(sentences, comm.sectionList[0].sentenceList):
111
+ sm_list, em_list, entity_list = _process_sentence(sent, comm_sent, aug, char_idx_offset)
112
+ entity_set.entityList.extend(entity_list)
113
+ situation_mention_set.mentionList.extend(sm_list)
114
+ entity_mention_set.mentionList.extend(em_list)
115
+ char_idx_offset += len(sent['sentence']) + 1
116
+
117
+ validate_communication(comm)
118
+ return comm
119
+
120
+
121
+ def concrete_doc_tokenized(
122
+ sentences: List[List[str]],
123
+ spans: List[Span],
124
+ doc_name: str = "document",
125
+ ):
126
+ """
127
+ Similar to concrete_doc, but with tokenized words and spans.
128
+ """
129
+ inputs = list()
130
+ for sent, vr in zip(sentences, spans):
131
+ cur_start = 0
132
+ tokenization = list()
133
+ for token in sent:
134
+ tokenization.append((cur_start, cur_start + len(token) - 1))
135
+ cur_start += len(token) + 1
136
+ inputs.append({
137
+ "sentence": " ".join(sent),
138
+ "tokenization": tokenization,
139
+ "annotations": vr
140
+ })
141
+ return concrete_doc(inputs, doc_name)
142
+
143
+
144
+ def collect_concrete_srl(comm: Communication) -> List[Tuple[List[str], Span]]:
145
+ # Mapping from <sentence uuid> to [<ConcreteSentence>, <Associated situation mentions>]
146
+ sentences = defaultdict(lambda: [None, list()])
147
+ for sec in comm.sectionList:
148
+ for sen in sec.sentenceList:
149
+ sentences[sen.uuid.uuidString][0] = sen
150
+ # Assume there's only ONE situation mention set
151
+ assert len(comm.situationMentionSetList) == 1
152
+ # Assign each situation mention to the corresponding sentence
153
+ for men in comm.situationMentionSetList[0].mentionList:
154
+ if men.tokens is None: continue # For ACE relations
155
+ sentences[men.tokens.tokenization.sentence.uuid.uuidString][1].append(men)
156
+ ret = list()
157
+ for sen, mention_list in sentences.values():
158
+ tokens = [t.text for t in sen.tokenization.tokenList.tokenList]
159
+ spans = list()
160
+ for mention in mention_list:
161
+ mention_tokens = sorted(mention.tokens.tokenIndexList)
162
+ event = Span(mention_tokens[0], mention_tokens[-1], mention.situationKind, True)
163
+ for men_arg in mention.argumentList:
164
+ arg_tokens = sorted(men_arg.entityMention.tokens.tokenIndexList)
165
+ event.add_child(Span(arg_tokens[0], arg_tokens[-1], men_arg.role, False))
166
+ spans.append(event)
167
+ vr = Span.virtual_root(spans)
168
+ ret.append((tokens, vr))
169
+ return ret
sftp/data_reader/span_reader.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from abc import ABC
3
+ from typing import *
4
+
5
+ import numpy as np
6
+ from allennlp.common.util import END_SYMBOL
7
+ from allennlp.data.dataset_readers.dataset_reader import DatasetReader
8
+ from allennlp.data.dataset_readers.dataset_utils.span_utils import bio_tags_to_spans
9
+ from allennlp.data.fields import *
10
+ from allennlp.data.token_indexers import PretrainedTransformerIndexer
11
+ from allennlp.data.tokenizers import PretrainedTransformerTokenizer, Token
12
+
13
+ from ..utils import Span, BIOSmoothing, apply_bio_smoothing
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ @DatasetReader.register('span')
19
+ class SpanReader(DatasetReader, ABC):
20
+ def __init__(
21
+ self,
22
+ pretrained_model: str,
23
+ max_length: int = 512,
24
+ ignore_label: bool = False,
25
+ debug: bool = False,
26
+ **extras
27
+ ) -> None:
28
+ """
29
+ :param pretrained_model: The name of the pretrained model. E.g. xlm-roberta-large
30
+ :param max_length: Sequences longer than this limit will be truncated.
31
+ :param ignore_label: If True, label on spans will be anonymized.
32
+ :param debug: True to turn on debugging mode.
33
+ :param span_proposals: Needed for "enumeration" scheme, but not needed for "BIO".
34
+ If True, it will try to enumerate candidate spans in the sentence, which will then be fed into
35
+ a binary classifier (EnumSpanFinder).
36
+ Note: It might take time to propose spans. And better to use SpacyTokenizer if you want to call
37
+ constituency parser or dependency parser.
38
+ :param maximum_negative_spans: Necessary for EnumSpanFinder.
39
+ :param extras: Args to DatasetReader.
40
+ """
41
+ super().__init__(**extras)
42
+ self.word_indexer = {
43
+ 'pieces': PretrainedTransformerIndexer(pretrained_model, namespace='pieces')
44
+ }
45
+
46
+ self._pretrained_model_name = pretrained_model
47
+ self.debug = debug
48
+ self.ignore_label = ignore_label
49
+
50
+ self._pretrained_tokenizer = PretrainedTransformerTokenizer(pretrained_model)
51
+ self.max_length = max_length
52
+ self.n_span_removed = 0
53
+
54
+ def retokenize(
55
+ self, sentence: List[str], truncate: bool = True
56
+ ) -> Tuple[List[str], List[Optional[Tuple[int, int]]]]:
57
+ pieces, offsets = self._pretrained_tokenizer.intra_word_tokenize(sentence)
58
+ pieces = list(map(str, pieces))
59
+ if truncate:
60
+ pieces = pieces[:self.max_length]
61
+ pieces[-1] = END_SYMBOL
62
+ return pieces, offsets
63
+
64
+ def prepare_inputs(
65
+ self,
66
+ sentence: List[str],
67
+ spans: Optional[Union[List[Span], Span]] = None,
68
+ truncate: bool = True,
69
+ label_type: str = 'string',
70
+ ) -> Dict[str, Field]:
71
+ """
72
+ Prepare inputs and auxiliary variables for span model.
73
+ :param sentence: A list of tokens. Do not pass in any special tokens, like BOS or EOS.
74
+ Necessary for both training and testing.
75
+ :param spans: Optional. For training, spans passed in will be considered as positive examples; the spans
76
+ that are automatically proposed and not in the positive set will be considered as negative examples.
77
+ Necessary for training.
78
+ :param truncate: If True, sequence will be truncated if it's longer than `self.max_training_length`
79
+ :param label_type: One of [string, list].
80
+
81
+ :return: Dict of AllenNLP fields. For detailed of explanation of every field, refer to the comments
82
+ below. For the shape of every field, check the module doc.
83
+ Fields list:
84
+ - words
85
+ - span_labels
86
+ - span_boundary
87
+ - parent_indices
88
+ - parent_mask
89
+ - bio_seqs
90
+ - raw_sentence
91
+ - raw_spans
92
+ - proposed_spans
93
+ """
94
+ fields = dict()
95
+
96
+ pieces, offsets = self.retokenize(sentence, truncate)
97
+ fields['tokens'] = TextField(list(map(Token, pieces)), self.word_indexer)
98
+ raw_inputs = {'sentence': sentence, "pieces": pieces, 'offsets': offsets}
99
+ fields['raw_inputs'] = MetadataField(raw_inputs)
100
+
101
+ if spans is None:
102
+ return fields
103
+
104
+ vr = spans if isinstance(spans, Span) else Span.virtual_root(spans)
105
+ self.n_span_removed = vr.remove_overlapping()
106
+ raw_inputs['spans'] = vr
107
+
108
+ vr = vr.re_index(offsets)
109
+ if truncate:
110
+ vr.truncate(self.max_length)
111
+ if self.ignore_label:
112
+ vr.ignore_labels()
113
+
114
+ # (start_idx, end_idx) pairs. Left and right inclusive.
115
+ # The first span is the Virtual Root node. Shape [span, 2]
116
+ span_boundary = list()
117
+ # label on span. Shape [span]
118
+ span_labels = list()
119
+ # parent idx (span indexing space). Shape [span]
120
+ span_parent_indices = list()
121
+ # True for parents. Shape [span]
122
+ parent_mask = [False] * vr.n_nodes
123
+ # Key: parent idx (span indexing space). Value: child span idx
124
+ flatten_spans = list(vr.bfs())
125
+ for span_idx, span in enumerate(vr.bfs()):
126
+ if span.is_parent:
127
+ parent_mask[span_idx] = True
128
+ # 0 is the virtual root
129
+ parent_idx = flatten_spans.index(span.parent) if span.parent else 0
130
+ span_parent_indices.append(parent_idx)
131
+ span_boundary.append(span.boundary)
132
+ span_labels.append(span.label)
133
+
134
+ bio_tag_list: List[List[str]] = list()
135
+ bio_configs: List[List[BIOSmoothing]] = list()
136
+ # Shape: [#parent, #token, 3]
137
+ bio_seqs: List[np.ndarray] = list()
138
+ # Parent index for every BIO seq
139
+ for parent_idx, parent in filter(lambda node: node[1].is_parent, enumerate(flatten_spans)):
140
+ bio_tags = ['O'] * len(pieces)
141
+ bio_tag_list.append(bio_tags)
142
+ bio_smooth: List[BIOSmoothing] = [parent.child_smooth.clone() for _ in pieces]
143
+ bio_configs.append(bio_smooth)
144
+ for child in parent:
145
+ assert all(bio_tags[bio_idx] == 'O' for bio_idx in range(child.start_idx, child.end_idx + 1))
146
+ if child.smooth_weight is not None:
147
+ for i in range(child.start_idx, child.end_idx+1):
148
+ bio_smooth[i].weight = child.smooth_weight
149
+ bio_tags[child.start_idx] = 'B'
150
+ for word_idx in range(child.start_idx + 1, child.end_idx + 1):
151
+ bio_tags[word_idx] = 'I'
152
+ bio_seqs.append(apply_bio_smoothing(bio_smooth, bio_tags))
153
+
154
+ fields['span_boundary'] = ArrayField(
155
+ np.array(span_boundary), padding_value=0, dtype=np.int
156
+ )
157
+ fields['parent_indices'] = ArrayField(np.array(span_parent_indices), 0, np.int)
158
+ if label_type == 'string':
159
+ fields['span_labels'] = ListField([LabelField(label, 'span_label') for label in span_labels])
160
+ elif label_type == 'list':
161
+ fields['span_labels'] = ArrayField(np.array(span_labels))
162
+ else:
163
+ raise NotImplementedError
164
+ fields['parent_mask'] = ArrayField(np.array(parent_mask), False, np.bool)
165
+ fields['bio_seqs'] = ArrayField(np.stack(bio_seqs))
166
+
167
+ self._sanity_check(
168
+ flatten_spans, pieces, bio_tag_list, parent_mask, span_boundary, span_labels, span_parent_indices
169
+ )
170
+
171
+ return fields
172
+
173
+ @staticmethod
174
+ def _sanity_check(
175
+ flatten_spans, words, bio_tag_list, parent_mask, span_boundary, span_labels, parent_indices, verbose=False
176
+ ):
177
+ # For debugging use.
178
+ assert len(parent_mask) == len(span_boundary) == len(span_labels) == len(parent_indices)
179
+ for (parent_idx, parent_span), bio_tags in zip(
180
+ filter(lambda x: x[1].is_parent, enumerate(flatten_spans)), bio_tag_list
181
+ ):
182
+ assert parent_mask[parent_idx]
183
+ parent_s, parent_e = span_boundary[parent_idx]
184
+ if verbose:
185
+ print('Parent: ', span_labels[parent_idx], 'Text: ', ' '.join(words[parent_s:parent_e+1]))
186
+ print(f'It contains {len(parent_span)} children.')
187
+ for child in parent_span:
188
+ child_idx = flatten_spans.index(child)
189
+ assert parent_indices[child_idx] == flatten_spans.index(parent_span)
190
+ if verbose:
191
+ child_s, child_e = span_boundary[child_idx]
192
+ print(' ', span_labels[child_idx], 'Text', words[child_s:child_e+1])
193
+
194
+ if verbose:
195
+ print(f'Child derived from BIO tags:')
196
+ for _, (start, end) in bio_tags_to_spans(bio_tags):
197
+ print(words[start:end+1])
sftp/data_reader/srl_reader.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import random
4
+ from typing import *
5
+
6
+ import numpy as np
7
+ from allennlp.data.dataset_readers.dataset_reader import DatasetReader
8
+ from allennlp.data.fields import MetadataField
9
+ from allennlp.data.instance import Instance
10
+
11
+ from .span_reader import SpanReader
12
+ from ..utils import Span, VIRTUAL_ROOT, BIOSmoothing
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ @DatasetReader.register('semantic_role_labeling')
18
+ class SRLDatasetReader(SpanReader):
19
+ def __init__(
20
+ self,
21
+ min_negative: int = 5,
22
+ negative_ratio: float = 1.,
23
+ event_only: bool = False,
24
+ event_smoothing_factor: float = 0.,
25
+ arg_smoothing_factor: float = 0.,
26
+ # For Ontology Mapping
27
+ ontology_mapping_path: Optional[str] = None,
28
+ min_weight: float = 1e-2,
29
+ max_weight: float = 1.0,
30
+ **extra
31
+ ):
32
+ super().__init__(**extra)
33
+ self.min_negative = min_negative
34
+ self.negative_ratio = negative_ratio
35
+ self.event_only = event_only
36
+ self.event_smooth_factor = event_smoothing_factor
37
+ self.arg_smooth_factor = arg_smoothing_factor
38
+ self.ontology_mapping = None
39
+ if ontology_mapping_path is not None:
40
+ self.ontology_mapping = json.load(open(ontology_mapping_path))
41
+ for k1 in ['event', 'argument']:
42
+ for k2, weights in self.ontology_mapping['mapping'][k1].items():
43
+ weights = np.array(weights)
44
+ weights[weights < min_weight] = 0.0
45
+ weights[weights > max_weight] = max_weight
46
+ self.ontology_mapping['mapping'][k1][k2] = weights
47
+ self.ontology_mapping['mapping'][k1] = {
48
+ k2: weights for k2, weights in self.ontology_mapping['mapping'][k1].items() if weights.sum() > 1e-5
49
+ }
50
+ vr_label = [0.] * len(self.ontology_mapping['target']['label'])
51
+ vr_label[self.ontology_mapping['target']['label'].index(VIRTUAL_ROOT)] = 1.0
52
+ self.ontology_mapping['mapping']['event'][VIRTUAL_ROOT] = np.array(vr_label)
53
+
54
+ def _read(self, file_path: str) -> Iterable[Instance]:
55
+ all_lines = list(map(json.loads, open(file_path).readlines()))
56
+ if self.debug:
57
+ random.seed(1); random.shuffle(all_lines)
58
+ for line in all_lines:
59
+ ins = self.text_to_instance(**line)
60
+ if ins is not None:
61
+ yield ins
62
+ if self.n_span_removed > 0:
63
+ logger.warning(f'{self.n_span_removed} spans are removed.')
64
+ self.n_span_removed = 0
65
+
66
+ def apply_ontology_mapping(self, vr):
67
+ new_events = list()
68
+ event_map, arg_map = self.ontology_mapping['mapping']['event'], self.ontology_mapping['mapping']['argument']
69
+ for event in vr:
70
+ if event.label not in event_map: continue
71
+ event.child_smooth.weight = event.smooth_weight = event_map[event.label].sum()
72
+ event = event.map_ontology(event_map, False, False)
73
+ new_events.append(event)
74
+ new_children = list()
75
+ for child in event:
76
+ if child.label not in arg_map: continue
77
+ child.child_smooth.weight = child.smooth_weight = arg_map[child.label].sum()
78
+ child = child.map_ontology(arg_map, False, False)
79
+ new_children.append(child)
80
+ event.remove_child()
81
+ for child in new_children: event.add_child(child)
82
+ new_vr = Span.virtual_root(new_events)
83
+ # For Virtual Root itself.
84
+ new_vr.map_ontology(self.ontology_mapping['mapping']['event'], True, False)
85
+ return new_vr
86
+
87
+ def text_to_instance(self, tokens, annotations=None, meta=None) -> Optional[Instance]:
88
+ meta = meta or {'fully_annotated': True}
89
+ meta['fully_annotated'] = meta.get('fully_annotated', True)
90
+ vr = None
91
+ if annotations is not None:
92
+ vr = annotations if isinstance(annotations, Span) else Span.from_json(annotations)
93
+ vr = self.apply_ontology_mapping(vr) if self.ontology_mapping is not None else vr
94
+ # if len(vr) == 0: return # Ignore sentence with empty annotation
95
+ if self.event_smooth_factor != 0.0:
96
+ vr.child_smooth = BIOSmoothing(o_smooth=self.event_smooth_factor if meta['fully_annotated'] else -1)
97
+ if self.arg_smooth_factor != 0.0:
98
+ for event in vr:
99
+ event.child_smooth = BIOSmoothing(o_smooth=self.arg_smooth_factor)
100
+ if self.event_only:
101
+ for event in vr:
102
+ event.remove_child()
103
+ event.is_parent = False
104
+
105
+ fields = self.prepare_inputs(tokens, vr, True, 'string' if self.ontology_mapping is None else 'list')
106
+ fields['meta'] = MetadataField(meta)
107
+ return Instance(fields)
sftp/metrics/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from sftp.metrics.base_f import BaseF
2
+ from sftp.metrics.exact_match import ExactMatch
3
+ from sftp.metrics.fbeta_mix_measure import FBetaMixMeasure
4
+ from sftp.metrics.srl_metrics import SRLMetric
sftp/metrics/base_f.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC
2
+ from typing import *
3
+
4
+ from allennlp.training.metrics import Metric
5
+
6
+
7
+ class BaseF(Metric, ABC):
8
+ def __init__(self, prefix: str):
9
+ self.tp = self.fp = self.fn = 0
10
+ self.prefix = prefix
11
+
12
+ def reset(self) -> None:
13
+ self.tp = self.fp = self.fn = 0
14
+
15
+ def get_metric(
16
+ self, reset: bool
17
+ ) -> Union[float, Tuple[float, ...], Dict[str, float], Dict[str, List[float]]]:
18
+ precision = self.tp * 100 / (self.tp + self.fp) if self.tp > 0 else 0.
19
+ recall = self.tp * 100 / (self.tp + self.fn) if self.tp > 0 else 0.
20
+ rst = {
21
+ f'{self.prefix}_p': precision,
22
+ f'{self.prefix}_r': recall,
23
+ f'{self.prefix}_f': 2 / (1 / precision + 1 / recall) if self.tp > 0 else 0.
24
+ }
25
+ if reset:
26
+ self.reset()
27
+ return rst
sftp/metrics/exact_match.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from allennlp.training.metrics import Metric
2
+ from overrides import overrides
3
+
4
+ from .base_f import BaseF
5
+ from ..utils import Span
6
+
7
+
8
+ @Metric.register('exact_match')
9
+ class ExactMatch(BaseF):
10
+ def __init__(self, check_type: bool):
11
+ self.check_type = check_type
12
+ if check_type:
13
+ super(ExactMatch, self).__init__('em')
14
+ else:
15
+ super(ExactMatch, self).__init__('sm')
16
+
17
+ @overrides
18
+ def __call__(
19
+ self,
20
+ prediction: Span,
21
+ gold: Span,
22
+ ):
23
+ tp = prediction.match(gold, self.check_type) - 1
24
+ fp = prediction.n_nodes - tp - 1
25
+ fn = gold.n_nodes - tp - 1
26
+ assert tp >= 0 and fp >= 0 and fn >= 0
27
+ self.tp += tp
28
+ self.fp += fp
29
+ self.fn += fn
sftp/metrics/fbeta_mix_measure.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from allennlp.training.metrics import FBetaMeasure, Metric
2
+
3
+
4
+ @Metric.register('fbeta_mix')
5
+ class FBetaMixMeasure(FBetaMeasure):
6
+ def __init__(self, null_idx, **kwargs):
7
+ super().__init__(**kwargs)
8
+ self.null_idx = null_idx
9
+
10
+ def get_metric(self, reset: bool = False):
11
+
12
+ tp = float(self._true_positive_sum.sum() - self._true_positive_sum[self.null_idx])
13
+ total_pred = float(self._pred_sum.sum() - self._pred_sum[self.null_idx])
14
+ total_gold = float(self._true_sum.sum() - self._true_sum[self.null_idx])
15
+
16
+ beta2 = self._beta ** 2
17
+ p = 0. if total_pred == 0 else tp / total_pred
18
+ r = 0. if total_pred == 0 else tp / total_gold
19
+ f = 0. if p == 0. or r == 0. else ((1 + beta2) * p * r / (p * beta2 + r))
20
+
21
+ mix_f = {
22
+ 'p': p * 100,
23
+ 'r': r * 100,
24
+ 'f': f * 100
25
+ }
26
+
27
+ if reset:
28
+ self.reset()
29
+
30
+ return mix_f
31
+
32
+ def add_false_negative(self, labels):
33
+ for lab in labels:
34
+ self._true_sum[lab] += 1
sftp/metrics/srl_metrics.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+
3
+ from allennlp.training.metrics import Metric
4
+ from overrides import overrides
5
+ import numpy as np
6
+ import logging
7
+
8
+ from .base_f import BaseF
9
+ from ..utils import Span, max_match
10
+
11
+ logger = logging.getLogger('srl_metric')
12
+
13
+
14
+ @Metric.register('srl')
15
+ class SRLMetric(Metric):
16
+ def __init__(self, check_type: Optional[bool] = None):
17
+ self.tri_i = BaseF('tri-i')
18
+ self.tri_c = BaseF('tri-c')
19
+ self.arg_i = BaseF('arg-i')
20
+ self.arg_c = BaseF('arg-c')
21
+ if check_type is not None:
22
+ logger.warning('Check type argument is deprecated.')
23
+
24
+ def reset(self) -> None:
25
+ for metric in [self.tri_i, self.tri_c, self.arg_i, self.arg_c]:
26
+ metric.reset()
27
+
28
+ def get_metric(self, reset: bool) -> Dict[str, Any]:
29
+ ret = dict()
30
+ for metric in [self.tri_i, self.tri_c, self.arg_i, self.arg_c]:
31
+ ret.update(metric.get_metric(reset))
32
+ return ret
33
+
34
+ @overrides
35
+ def __call__(self, prediction: Span, gold: Span):
36
+ self.with_label_event(prediction, gold)
37
+ self.without_label_event(prediction, gold)
38
+ self.tuple_eval(prediction, gold)
39
+ # self.with_label_arg(prediction, gold)
40
+ # self.without_label_arg(prediction, gold)
41
+
42
+ def tuple_eval(self, prediction: Span, gold: Span):
43
+ def extract_tuples(vr: Span, parent_boundary: bool):
44
+ labeled, unlabeled = list(), list()
45
+ for event in vr:
46
+ for arg in event:
47
+ if parent_boundary:
48
+ labeled.append((event.boundary, event.label, arg.boundary, arg.label))
49
+ unlabeled.append((event.boundary, event.label, arg.boundary))
50
+ else:
51
+ labeled.append((event.label, arg.boundary, arg.label))
52
+ unlabeled.append((event.label, arg.boundary))
53
+ return labeled, unlabeled
54
+
55
+ def equal_matrix(l1, l2): return np.array([[e1 == e2 for e2 in l2] for e1 in l1], dtype=np.int)
56
+
57
+ pred_label, pred_unlabel = extract_tuples(prediction, False)
58
+ gold_label, gold_unlabel = extract_tuples(gold, False)
59
+
60
+ if len(pred_label) == 0 or len(gold_label) == 0:
61
+ arg_c_tp = arg_i_tp = 0
62
+ else:
63
+ label_bipartite = equal_matrix(pred_label, gold_label)
64
+ unlabel_bipartite = equal_matrix(pred_unlabel, gold_unlabel)
65
+ arg_c_tp, arg_i_tp = max_match(label_bipartite), max_match(unlabel_bipartite)
66
+
67
+ arg_c_fp = prediction.n_nodes - len(prediction) - 1 - arg_c_tp
68
+ arg_c_fn = gold.n_nodes - len(gold) - 1 - arg_c_tp
69
+ arg_i_fp = prediction.n_nodes - len(prediction) - 1 - arg_i_tp
70
+ arg_i_fn = gold.n_nodes - len(gold) - 1 - arg_i_tp
71
+
72
+ assert arg_i_tp >= 0 and arg_i_fn >= 0 and arg_i_fp >= 0
73
+ self.arg_i.tp += arg_i_tp
74
+ self.arg_i.fp += arg_i_fp
75
+ self.arg_i.fn += arg_i_fn
76
+
77
+ assert arg_c_tp >= 0 and arg_c_fn >= 0 and arg_c_fp >= 0
78
+ self.arg_c.tp += arg_c_tp
79
+ self.arg_c.fp += arg_c_fp
80
+ self.arg_c.fn += arg_c_fn
81
+
82
+ def with_label_event(self, prediction: Span, gold: Span):
83
+ trigger_tp = prediction.match(gold, True, 2) - 1
84
+ trigger_fp = len(prediction) - trigger_tp
85
+ trigger_fn = len(gold) - trigger_tp
86
+ assert trigger_fp >= 0 and trigger_fn >= 0 and trigger_tp >= 0
87
+ self.tri_c.tp += trigger_tp
88
+ self.tri_c.fp += trigger_fp
89
+ self.tri_c.fn += trigger_fn
90
+
91
+ def with_label_arg(self, prediction: Span, gold: Span):
92
+ trigger_tp = prediction.match(gold, True, 2) - 1
93
+ role_tp = prediction.match(gold, True, ignore_parent_boundary=True) - 1 - trigger_tp
94
+ role_fp = (prediction.n_nodes - 1 - len(prediction)) - role_tp
95
+ role_fn = (gold.n_nodes - 1 - len(gold)) - role_tp
96
+ assert role_fp >= 0 and role_fn >= 0 and role_tp >= 0
97
+ self.arg_c.tp += role_tp
98
+ self.arg_c.fp += role_fp
99
+ self.arg_c.fn += role_fn
100
+
101
+ def without_label_event(self, prediction: Span, gold: Span):
102
+ tri_i_tp = prediction.match(gold, False, 2) - 1
103
+ tri_i_fp = len(prediction) - tri_i_tp
104
+ tri_i_fn = len(gold) - tri_i_tp
105
+ assert tri_i_tp >= 0 and tri_i_fp >= 0 and tri_i_fn >= 0
106
+ self.tri_i.tp += tri_i_tp
107
+ self.tri_i.fp += tri_i_fp
108
+ self.tri_i.fn += tri_i_fn
109
+
110
+ def without_label_arg(self, prediction: Span, gold: Span):
111
+ arg_i_tp = 0
112
+ matched_pairs: List[Tuple[Span, Span]] = list()
113
+ n_gold_arg, n_pred_arg = gold.n_nodes - len(gold) - 1, prediction.n_nodes - len(prediction) - 1
114
+ prediction, gold = prediction.clone(), gold.clone()
115
+ for p in prediction:
116
+ for g in gold:
117
+ if p.match(g, True, 1) == 1:
118
+ arg_i_tp += (p.match(g, False) - 1)
119
+ matched_pairs.append((p, g))
120
+ break
121
+ for p, g in matched_pairs:
122
+ prediction.remove_child(p)
123
+ gold.remove_child(g)
124
+
125
+ sub_matches = np.zeros([len(prediction), len(gold)], np.int)
126
+ for p_idx, p in enumerate(prediction):
127
+ for g_idx, g in enumerate(gold):
128
+ if p.label == g.label:
129
+ sub_matches[p_idx, g_idx] = p.match(g, False, -1, True)
130
+ arg_i_tp += max_match(sub_matches)
131
+
132
+ arg_i_fp = n_pred_arg - arg_i_tp
133
+ arg_i_fn = n_gold_arg - arg_i_tp
134
+ assert arg_i_tp >= 0 and arg_i_fn >= 0 and arg_i_fp >= 0
135
+
136
+ self.arg_i.tp += arg_i_tp
137
+ self.arg_i.fp += arg_i_fp
138
+ self.arg_i.fn += arg_i_fn
sftp/models/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from sftp.models.span_model import SpanModel
sftp/models/span_model.py ADDED
@@ -0,0 +1,362 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import *
3
+
4
+ import torch
5
+ from allennlp.common.from_params import Params, T, pop_and_construct_arg
6
+ from allennlp.data.vocabulary import Vocabulary, DEFAULT_PADDING_TOKEN, DEFAULT_OOV_TOKEN
7
+ from allennlp.models.model import Model
8
+ from allennlp.modules import TextFieldEmbedder
9
+ from allennlp.modules.seq2seq_encoders.pytorch_seq2seq_wrapper import Seq2SeqEncoder
10
+ from allennlp.modules.span_extractors import SpanExtractor
11
+ from allennlp.training.metrics import Metric
12
+
13
+ from ..metrics import ExactMatch
14
+ from ..modules import SpanFinder, SpanTyping
15
+ from ..utils import num2mask, VIRTUAL_ROOT, Span, tensor2span
16
+
17
+
18
+ @Model.register("span")
19
+ class SpanModel(Model):
20
+ """
21
+ Identify/Find spans; link them as a tree; label them.
22
+ """
23
+ default_predictor = 'span'
24
+
25
+ def __init__(
26
+ self,
27
+ vocab: Vocabulary,
28
+
29
+ # Modules
30
+ word_embedding: TextFieldEmbedder,
31
+ span_extractor: SpanExtractor,
32
+ span_finder: SpanFinder,
33
+ span_typing: SpanTyping,
34
+
35
+ # Config
36
+ typing_loss_factor: float = 1.,
37
+ max_recursion_depth: int = -1,
38
+ max_decoding_spans: int = -1,
39
+ debug: bool = False,
40
+
41
+ # Ontology Constraints
42
+ ontology_path: Optional[str] = None,
43
+
44
+ # Metrics
45
+ metrics: Optional[List[Metric]] = None,
46
+ ) -> None:
47
+ """
48
+ Note for jsonnet file: it doesn't strictly follow the init examples of every module for that we override
49
+ the from_params method.
50
+ You can either check the SpanModel.from_params or the example jsonnet file.
51
+ :param vocab: No need to specify.
52
+ ## Modules
53
+ :param word_embedding: Refer to the module doc.
54
+ :param span_extractor: Refer to the module doc.
55
+ :param span_finder: Refer to the module doc.
56
+ :param span_typing: Refer to the module doc.
57
+ ## Configs
58
+ :param typing_loss_factor: loss = span_finder_loss + span_typing_loss * typing_loss_factor
59
+ :param max_recursion_depth: Maximum tree depth for inference. E.g., 1 for shallow event typing, 2 for SRL,
60
+ -1 (unlimited) for dependency parsing.
61
+ :param max_decoding_spans: Maximum spans for inference. -1 for unlimited.
62
+ :param debug: Useless now.
63
+ """
64
+ self._pad_idx = vocab.get_token_index(DEFAULT_PADDING_TOKEN, 'token')
65
+ self._null_idx = vocab.get_token_index(DEFAULT_OOV_TOKEN, 'span_label')
66
+ super().__init__(vocab)
67
+
68
+ self.word_embedding = word_embedding
69
+ self._span_finder = span_finder
70
+ self._span_extractor = span_extractor
71
+ self._span_typing = span_typing
72
+
73
+ self.metrics = [ExactMatch(True), ExactMatch(False)]
74
+ if metrics is not None:
75
+ self.metrics.extend(metrics)
76
+
77
+ if ontology_path is not None and os.path.exists(ontology_path):
78
+ self._span_typing.load_ontology(ontology_path, self.vocab)
79
+
80
+ self._max_decoding_spans = max_decoding_spans
81
+ self._typing_loss_factor = typing_loss_factor
82
+ self._max_recursion_depth = max_recursion_depth
83
+ self.debug = debug
84
+
85
+ def forward(
86
+ self,
87
+ tokens: Dict[str, Dict[str, torch.Tensor]],
88
+
89
+ span_boundary: Optional[torch.Tensor] = None,
90
+ span_labels: Optional[torch.Tensor] = None,
91
+ parent_indices: Optional[torch.Tensor] = None,
92
+ parent_mask: Optional[torch.Tensor] = None,
93
+
94
+ bio_seqs: Optional[torch.Tensor] = None,
95
+ raw_inputs: Optional[dict] = None,
96
+ meta: Optional[dict] = None,
97
+
98
+ **extra
99
+ ) -> Dict[str, torch.Tensor]:
100
+ """
101
+ For training, provide all blow.
102
+ For inference, it's enough to only provide words.
103
+
104
+ :param tokens: Indexed input sentence. Shape: [batch, token]
105
+
106
+ :param span_boundary: Start and end indices for every span. Note this includes both parent and
107
+ non-parent spans. Shape: [batch, span, 2]. For the last dim, [0] is start idx and [1] is end idx.
108
+ :param span_labels: Indexed label for spans, including parent and non-parent ones. Shape: [batch, span]
109
+ :param parent_indices: The parent span idx of every span. Shape: [batch, span]
110
+ :param parent_mask: True if this span is a parent. Shape: [batch, span]
111
+
112
+ :param bio_seqs: Shape [batch, parent, token, 3]
113
+ :param raw_inputs
114
+
115
+ :param meta: Meta information. Will be copied to the outputs.
116
+
117
+ :return:
118
+ - loss: training loss
119
+ - prediction: Predicted spans
120
+ - meta: Meta info copied from input
121
+ - inputs: Input sentences and spans (if exist)
122
+ """
123
+ ret = {'inputs': raw_inputs, 'meta': meta or dict()}
124
+
125
+ is_eval = span_labels is not None and not self.training # evaluation on dev set
126
+ is_test = span_labels is None # test on test set
127
+ # Shape [batch]
128
+ num_spans = (span_labels != -1).sum(1) if span_labels is not None else None
129
+ num_words = tokens['pieces']['mask'].sum(1)
130
+ # Shape [batch, word, token_dim]
131
+ token_vec = self.word_embedding(tokens)
132
+
133
+ if span_labels is not None:
134
+ # Revise the padding value from -1 to 0
135
+ span_labels[span_labels == -1] = 0
136
+
137
+ # Calculate Loss
138
+ if self.training or is_eval:
139
+ # Shape [batch, word, token_dim]
140
+ span_vec = self._span_extractor(token_vec, span_boundary)
141
+ finder_rst = self._span_finder(
142
+ token_vec, num2mask(num_words), span_vec, num2mask(num_spans), span_labels, parent_indices,
143
+ parent_mask, bio_seqs
144
+ )
145
+ typing_rst = self._span_typing(span_vec, parent_indices, span_labels)
146
+ ret['loss'] = finder_rst['loss'] + typing_rst['loss'] * self._typing_loss_factor
147
+
148
+ # Decoding
149
+ if is_eval or is_test:
150
+ pred_span_boundary, pred_span_labels, pred_parent_indices, pred_cursor, pred_label_confidence \
151
+ = self.inference(num_words, token_vec, **extra)
152
+ prediction = self.post_process_pred(
153
+ pred_span_boundary, pred_span_labels, pred_parent_indices, pred_cursor, pred_label_confidence
154
+ )
155
+ for pred, raw_in in zip(prediction, raw_inputs):
156
+ pred.re_index(raw_in['offsets'], True, True, True)
157
+ pred.remove_overlapping()
158
+ ret['prediction'] = prediction
159
+ if 'spans' in raw_inputs[0]:
160
+ for pred, raw_in in zip(prediction, raw_inputs):
161
+ gold = raw_in['spans']
162
+ for metric in self.metrics:
163
+ metric(pred, gold)
164
+
165
+ return ret
166
+
167
+ def inference(
168
+ self,
169
+ num_words: torch.Tensor,
170
+ token_vec: torch.Tensor,
171
+ **auxiliaries
172
+ ):
173
+ n_batch = num_words.shape[0]
174
+ # The decoding results are preserved in the following tensors starting with `pred`
175
+ # During inference, we completely ignore the arguments defaulted None in the forward method.
176
+ # The span indexing space is shift to the decoding span space. (since we do not have gold span now)
177
+ # boundary indices of every predicted span
178
+ pred_span_boundary = num_words.new_zeros([n_batch, self._max_decoding_spans, 2])
179
+ # labels (and corresponding confidence) for predicted spans
180
+ pred_span_labels = num_words.new_full(
181
+ [n_batch, self._max_decoding_spans], self.vocab.get_token_index(VIRTUAL_ROOT, 'span_label')
182
+ )
183
+ pred_label_confidence = num_words.new_zeros([n_batch, self._max_decoding_spans])
184
+ # label masked as True will be treated as parent in the next round
185
+ pred_parent_mask = num_words.new_zeros([n_batch, self._max_decoding_spans], dtype=torch.bool)
186
+ pred_parent_mask[:, 0] = True
187
+ # parent index (in the span indexing space) for every span
188
+ pred_parent_indices = num_words.new_zeros([n_batch, self._max_decoding_spans])
189
+ # what index have we reached for every batch?
190
+ pred_cursor = num_words.new_ones([n_batch])
191
+
192
+ # Pass environment variables to handler. Extra variables will be ignored.
193
+ # So pass the union of variables that are needed by different modules.
194
+ span_find_handler = self._span_finder.inference_forward_handler(
195
+ token_vec, num2mask(num_words), self._span_extractor, **auxiliaries
196
+ )
197
+
198
+ # Every step here is one layer of the tree. It deals with all the parents for the last layer
199
+ # so there might be 0 to multiple parents for a batch for a single step.
200
+ for _ in range(self._max_recursion_depth):
201
+ cursor_before_find = pred_cursor.clone()
202
+ span_find_handler(
203
+ pred_span_boundary, pred_span_labels, pred_parent_mask, pred_parent_indices, pred_cursor
204
+ )
205
+ # Labels of old spans are re-predicted. It doesn't matter since their results shouldn't change
206
+ # in theory.
207
+ span_typing_ret = self._span_typing(
208
+ self._span_extractor(token_vec, pred_span_boundary), pred_parent_indices, pred_span_labels, True
209
+ )
210
+ pred_span_labels = span_typing_ret['prediction']
211
+ pred_label_confidence = span_typing_ret['label_confidence']
212
+ pred_span_labels[:, 0] = self.vocab.get_token_index(VIRTUAL_ROOT, 'span_label')
213
+ pred_parent_mask = (
214
+ num2mask(cursor_before_find, self._max_decoding_spans) ^ num2mask(pred_cursor,
215
+ self._max_decoding_spans)
216
+ )
217
+
218
+ # Break the inference loop if 1) all batches reach max span limit OR 2) no parent is predicted
219
+ # at last step OR 3) max recursion limit is reached (for loop condition)
220
+ if (pred_cursor == self._max_decoding_spans).all() or pred_parent_mask.sum() == 0:
221
+ break
222
+
223
+ return pred_span_boundary, pred_span_labels, pred_parent_indices, pred_cursor, pred_label_confidence
224
+
225
+ def one_step_prediction(
226
+ self,
227
+ tokens: Dict[str, Dict[str, torch.Tensor]],
228
+ parent_boundary: torch.Tensor,
229
+ parent_labels: torch.Tensor,
230
+ ):
231
+ """
232
+ Single step prediction. Given parent span boundary indices, return the corresponding children spans
233
+ and their labels.
234
+ Restriction: Each sentence contain exactly 1 parent.
235
+ For efficient multi-layer prediction, i.e. given a root, predict the whole tree,
236
+ refer to the `forward' method.
237
+ :param tokens: See forward.
238
+ :param parent_boundary: Pairs of (start_idx, end_idx) for parents. Shape [batch, 2]
239
+ :param parent_labels: Labels for parents. Shape [batch]
240
+ Note: If `no_label' is on in span_finder module, this will be ignored.
241
+ :return:
242
+ children_boundary: (start_idx, end_idx) for every child span. Padded with (0, 0).
243
+ Shape [batch, children, 2]
244
+ children_labels: Label for every child span. Padded with null_idx. Shape [batch, children]
245
+ num_children: The number of children predicted for parent/batch. Shape [batch]
246
+ Tips: You can use num2mask method to convert this to bool tensor mask.
247
+ """
248
+ num_words = tokens['pieces']['mask'].sum(1)
249
+ # Shape [batch, word, token_dim]
250
+ token_vec = self.word_embedding(tokens)
251
+ n_batch = token_vec.shape[0]
252
+
253
+ # The following variables assumes the parent is the 0-th span, and we let the model
254
+ # to extend the span list.
255
+ pred_span_boundary = num_words.new_zeros([n_batch, self._max_decoding_spans, 2])
256
+ pred_span_boundary[:, 0] = parent_boundary
257
+ pred_span_labels = num_words.new_full([n_batch, self._max_decoding_spans], self._null_idx)
258
+ pred_span_labels[:, 0] = parent_labels
259
+ pred_parent_mask = num_words.new_zeros(pred_span_labels.shape, dtype=torch.bool)
260
+ pred_parent_mask[:, 0] = True
261
+ pred_parent_indices = num_words.new_zeros([n_batch, self._max_decoding_spans])
262
+ # We start from idx 1 since 0 is the parents.
263
+ pred_cursor = num_words.new_ones([n_batch])
264
+
265
+ span_find_handler = self._span_finder.inference_forward_handler(
266
+ token_vec, num2mask(num_words), self._span_extractor
267
+ )
268
+ span_find_handler(
269
+ pred_span_boundary, pred_span_labels, pred_parent_mask, pred_parent_indices, pred_cursor
270
+ )
271
+ typing_out = self._span_typing(
272
+ self._span_extractor(token_vec, pred_span_boundary), pred_parent_indices, pred_span_labels, True
273
+ )
274
+ pred_span_labels = typing_out['prediction']
275
+
276
+ # Now remove the parent
277
+ num_children = pred_cursor - 1
278
+ max_children = int(num_children.max())
279
+ children_boundary = pred_span_boundary[:, 1:max_children + 1]
280
+ children_labels = pred_span_labels[:, 1:max_children + 1]
281
+ children_distribution = typing_out['distribution'][:, 1:max_children + 1]
282
+ return children_boundary, children_labels, num_children, children_distribution
283
+
284
+ def post_process_pred(
285
+ self, span_boundary, span_labels, parent_indices, num_spans, label_confidence
286
+ ) -> List[Span]:
287
+ pred_spans = tensor2span(
288
+ span_boundary, span_labels, parent_indices, num_spans, label_confidence,
289
+ self.vocab.get_index_to_token_vocabulary('span_label'),
290
+ label_ignore=[self._null_idx],
291
+ )
292
+ return pred_spans
293
+
294
+ def get_metrics(self, reset: bool = False) -> Dict[str, float]:
295
+ ret = dict()
296
+ if reset:
297
+ for metric in self.metrics:
298
+ ret.update(metric.get_metric(reset))
299
+ ret.update(self._span_finder.get_metrics(reset))
300
+ ret.update(self._span_typing.get_metric(reset))
301
+ return ret
302
+
303
+ @classmethod
304
+ def from_params(
305
+ cls: Type[T],
306
+ params: Params,
307
+ constructor_to_call: Callable[..., T] = None,
308
+ constructor_to_inspect: Callable[..., T] = None,
309
+ **extras,
310
+ ) -> T:
311
+ """
312
+ Specify the dependency between modules. E.g. the input dim of a module might depend on the output dim
313
+ of another module.
314
+ """
315
+ vocab = extras['vocab']
316
+ word_embedding = pop_and_construct_arg('SpanModel', 'word_embedding', TextFieldEmbedder, None, params, **extras)
317
+ label_dim, token_emb_dim = params.pop('label_dim'), word_embedding.get_output_dim()
318
+ span_extractor = pop_and_construct_arg(
319
+ 'SpanModel', 'span_extractor', SpanExtractor, None, params, input_dim=token_emb_dim, **extras
320
+ )
321
+ label_embedding = torch.nn.Embedding(vocab.get_vocab_size('span_label'), label_dim)
322
+ extras['label_emb'] = label_embedding
323
+
324
+ if params.get('span_finder').get('type') == 'bio':
325
+ bio_encoder = Seq2SeqEncoder.from_params(
326
+ params['span_finder'].pop('bio_encoder'),
327
+ input_size=span_extractor.get_output_dim() + token_emb_dim + label_dim,
328
+ input_dim=span_extractor.get_output_dim() + token_emb_dim + label_dim,
329
+ **extras
330
+ )
331
+ extras['span_finder'] = SpanFinder.from_params(
332
+ params.pop('span_finder'), bio_encoder=bio_encoder, **extras
333
+ )
334
+ else:
335
+ extras['span_finder'] = pop_and_construct_arg(
336
+ 'SpanModel', 'span_finder', SpanFinder, None, params, **extras
337
+ )
338
+ extras['span_finder'].label_emb = label_embedding
339
+
340
+ if params.get('span_typing').get('type') == 'mlp':
341
+ extras['span_typing'] = SpanTyping.from_params(
342
+ params.pop('span_typing'),
343
+ input_dim=span_extractor.get_output_dim() * 2 + label_dim,
344
+ n_category=vocab.get_vocab_size('span_label'),
345
+ label_to_ignore=[
346
+ vocab.get_token_index(lti, 'span_label')
347
+ for lti in [DEFAULT_OOV_TOKEN, DEFAULT_PADDING_TOKEN]
348
+ ],
349
+ **extras
350
+ )
351
+ else:
352
+ extras['span_typing'] = pop_and_construct_arg(
353
+ 'SpanModel', 'span_typing', SpanTyping, None, params, **extras
354
+ )
355
+ extras['span_typing'].label_emb = label_embedding
356
+
357
+ return super().from_params(
358
+ params,
359
+ word_embedding=word_embedding,
360
+ span_extractor=span_extractor,
361
+ **extras
362
+ )
sftp/modules/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .span_extractor import ComboSpanExtractor
2
+ from .span_finder import SpanFinder, BIOSpanFinder
3
+ from .span_typing import MLPSpanTyping, SpanTyping
4
+ from .smooth_crf import SmoothCRF
sftp/modules/smooth_crf.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from allennlp.modules.conditional_random_field import ConditionalRandomField
3
+ from allennlp.nn.util import logsumexp
4
+ from overrides import overrides
5
+
6
+
7
+ class SmoothCRF(ConditionalRandomField):
8
+ @overrides
9
+ def forward(self, inputs: torch.Tensor, tags: torch.Tensor, mask: torch.Tensor = None):
10
+ """
11
+
12
+ :param inputs: Shape [batch, token, tag]
13
+ :param tags: Shape [batch, token] or [batch, token, tag]
14
+ :param mask: Shape [batch, token]
15
+ :return:
16
+ """
17
+ if mask is None:
18
+ mask = tags.new_ones(tags.shape, dtype=torch.bool)
19
+ mask = mask.to(dtype=torch.bool)
20
+ if tags.dim() == 2:
21
+ return super(SmoothCRF, self).forward(inputs, tags, mask)
22
+
23
+ # smooth mode
24
+ log_denominator = self._input_likelihood(inputs, mask)
25
+ log_numerator = self._smooth_joint_likelihood(inputs, tags, mask)
26
+
27
+ return torch.sum(log_numerator - log_denominator)
28
+
29
+ def _smooth_joint_likelihood(
30
+ self, logits: torch.Tensor, soft_tags: torch.Tensor, mask: torch.Tensor
31
+ ) -> torch.Tensor:
32
+ batch_size, sequence_length, num_tags = logits.size()
33
+
34
+ epsilon = 1e-30
35
+ soft_tags = soft_tags.clone()
36
+ soft_tags[soft_tags < epsilon] = epsilon
37
+
38
+ # Transpose batch size and sequence dimensions
39
+ mask = mask.transpose(0, 1).contiguous()
40
+ logits = logits.transpose(0, 1).contiguous()
41
+ soft_tags = soft_tags.transpose(0, 1).contiguous()
42
+
43
+ # Initial alpha is the (batch_size, num_tags) tensor of likelihoods combining the
44
+ # transitions to the initial states and the logits for the first timestep.
45
+ if self.include_start_end_transitions:
46
+ alpha = self.start_transitions.view(1, num_tags) + logits[0] + soft_tags[0].log()
47
+ else:
48
+ alpha = logits[0] * soft_tags[0]
49
+
50
+ # For each i we compute logits for the transitions from timestep i-1 to timestep i.
51
+ # We do so in a (batch_size, num_tags, num_tags) tensor where the axes are
52
+ # (instance, current_tag, next_tag)
53
+ for i in range(1, sequence_length):
54
+ # The emit scores are for time i ("next_tag") so we broadcast along the current_tag axis.
55
+ emit_scores = logits[i].view(batch_size, 1, num_tags)
56
+ # Transition scores are (current_tag, next_tag) so we broadcast along the instance axis.
57
+ transition_scores = self.transitions.view(1, num_tags, num_tags)
58
+ # Alpha is for the current_tag, so we broadcast along the next_tag axis.
59
+ broadcast_alpha = alpha.view(batch_size, num_tags, 1)
60
+
61
+ # Add all the scores together and logexp over the current_tag axis.
62
+ inner = broadcast_alpha + emit_scores + transition_scores + soft_tags[i].log().unsqueeze(1)
63
+
64
+ # In valid positions (mask == True) we want to take the logsumexp over the current_tag dimension
65
+ # of `inner`. Otherwise (mask == False) we want to retain the previous alpha.
66
+ alpha = logsumexp(inner, 1) * mask[i].view(batch_size, 1) + alpha * (
67
+ ~mask[i]
68
+ ).view(batch_size, 1)
69
+
70
+ # Every sequence needs to end with a transition to the stop_tag.
71
+ if self.include_start_end_transitions:
72
+ stops = alpha + self.end_transitions.view(1, num_tags)
73
+ else:
74
+ stops = alpha
75
+
76
+ # Finally we log_sum_exp along the num_tags dim, result is (batch_size,)
77
+ return logsumexp(stops)
sftp/modules/span_extractor/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .combo import ComboSpanExtractor
sftp/modules/span_extractor/combo.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+
3
+ import torch
4
+ from allennlp.modules.span_extractors import SpanExtractor
5
+
6
+
7
+ @SpanExtractor.register('combo')
8
+ class ComboSpanExtractor(SpanExtractor):
9
+ def __init__(self, input_dim: int, sub_extractors: List[SpanExtractor]):
10
+ super().__init__()
11
+ self.sub_extractors = sub_extractors
12
+ for i, sub in enumerate(sub_extractors):
13
+ self.add_module(f'SpanExtractor-{i+1}', sub)
14
+ self.input_dim = input_dim
15
+
16
+ def get_input_dim(self) -> int:
17
+ return self.input_dim
18
+
19
+ def get_output_dim(self) -> int:
20
+ return sum([sub.get_output_dim() for sub in self.sub_extractors])
21
+
22
+ def forward(
23
+ self,
24
+ sequence_tensor: torch.FloatTensor,
25
+ span_indices: torch.LongTensor,
26
+ sequence_mask: torch.BoolTensor = None,
27
+ span_indices_mask: torch.BoolTensor = None,
28
+ ):
29
+ outputs = [
30
+ sub(
31
+ sequence_tensor=sequence_tensor,
32
+ span_indices=span_indices,
33
+ span_indices_mask=span_indices_mask
34
+ ) for sub in self.sub_extractors
35
+ ]
36
+ return torch.cat(outputs, dim=2)
sftp/modules/span_finder/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .bio_span_finder import BIOSpanFinder
2
+ from .span_finder import SpanFinder
sftp/modules/span_finder/bio_span_finder.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+
3
+ import torch
4
+ from allennlp.data.dataset_readers.dataset_utils.span_utils import bio_tags_to_spans
5
+ from allennlp.modules.seq2seq_encoders import Seq2SeqEncoder
6
+ from allennlp.modules.span_extractors import SpanExtractor
7
+ from allennlp.training.metrics import FBetaMeasure
8
+
9
+ from ..smooth_crf import SmoothCRF
10
+ from .span_finder import SpanFinder
11
+ from ...utils import num2mask, mask2idx, BIO
12
+
13
+
14
+ @SpanFinder.register("bio")
15
+ class BIOSpanFinder(SpanFinder):
16
+ """
17
+ Train BIO representations for span finding.
18
+ """
19
+ def __init__(
20
+ self,
21
+ bio_encoder: Seq2SeqEncoder,
22
+ label_emb: torch.nn.Embedding,
23
+ no_label: bool = True,
24
+ ):
25
+ super().__init__(no_label)
26
+ self.bio_encoder = bio_encoder
27
+ self.label_emb = label_emb
28
+
29
+ self.classifier = torch.nn.Linear(bio_encoder.get_output_dim(), 3)
30
+ self.crf = SmoothCRF(3)
31
+
32
+ self.fb_measure = FBetaMeasure(1., 'micro', [BIO.index('B'), BIO.index('I')])
33
+
34
+ def forward(
35
+ self,
36
+ token_vec: torch.Tensor,
37
+ token_mask: torch.Tensor,
38
+ span_vec: torch.Tensor,
39
+ span_mask: Optional[torch.Tensor] = None, # Do not need to provide
40
+ span_labels: Optional[torch.Tensor] = None, # Do not need to provide
41
+ parent_indices: Optional[torch.Tensor] = None, # Do not need to provide
42
+ parent_mask: Optional[torch.Tensor] = None,
43
+ bio_seqs: Optional[torch.Tensor] = None,
44
+ prediction: bool = False,
45
+ **extra
46
+ ) -> Dict[str, torch.Tensor]:
47
+ """
48
+ See doc of SpanFinder.
49
+ Possible extra variables:
50
+ smoothing_factor
51
+ :return:
52
+ - loss
53
+ - prediction
54
+ """
55
+ ret = dict()
56
+ is_soft = span_labels.dtype != torch.int64
57
+
58
+ distinct_parent_indices, num_parents = mask2idx(parent_mask)
59
+ n_batch, n_parent = distinct_parent_indices.shape
60
+ n_token = token_vec.shape[1]
61
+ # Shape [batch, parent, token_dim]
62
+ parent_span_features = span_vec.gather(
63
+ 1, distinct_parent_indices.unsqueeze(2).expand(-1, -1, span_vec.shape[2])
64
+ )
65
+ label_features = span_labels @ self.label_emb.weight if is_soft else self.label_emb(span_labels)
66
+ if self._no_label:
67
+ label_features = label_features.zero_()
68
+ # Shape [batch, span, label_dim]
69
+ parent_label_features = label_features.gather(
70
+ 1, distinct_parent_indices.unsqueeze(2).expand(-1, -1, label_features.shape[2])
71
+ )
72
+ # Shape [batch, parent, token, token_dim*2]
73
+ encoder_inputs = torch.cat([
74
+ parent_span_features.unsqueeze(2).expand(-1, -1, n_token, -1),
75
+ token_vec.unsqueeze(1).expand(-1, n_parent, -1, -1),
76
+ parent_label_features.unsqueeze(2).expand(-1, -1, n_token, -1),
77
+ ], dim=3)
78
+ encoder_inputs = encoder_inputs.reshape(n_batch * n_parent, n_token, -1)
79
+
80
+ # Shape [batch, parent]. Considers batches may have fewer seqs.
81
+ seq_mask = num2mask(num_parents)
82
+ # Shape [batch, parent, token]. Also considers batches may have fewer tokens.
83
+ token_mask = seq_mask.unsqueeze(2).expand(-1, -1, n_token) & token_mask.unsqueeze(1).expand(-1, n_parent, -1)
84
+
85
+ class_in = self.bio_encoder(encoder_inputs, token_mask.flatten(0, 1))
86
+ class_out = self.classifier(class_in).reshape(n_batch, n_parent, n_token, 3)
87
+
88
+ if not prediction:
89
+ # For training
90
+ # We use `seq_mask` here because seq with length 0 is not acceptable.
91
+ ret['loss'] = -self.crf(class_out[seq_mask], bio_seqs[seq_mask], token_mask[seq_mask])
92
+ self.fb_measure(class_out[seq_mask], bio_seqs[seq_mask].max(2).indices, token_mask[seq_mask])
93
+ else:
94
+ # For prediction
95
+ features_for_decode = class_out.clone().detach()
96
+ decoded = self.crf.viterbi_tags(features_for_decode.flatten(0, 1), token_mask.flatten(0, 1))
97
+ pred_tag = torch.tensor(
98
+ [path + [BIO.index('O')] * (n_token - len(path)) for path, _ in decoded]
99
+ )
100
+ pred_tag = pred_tag.reshape(n_batch, n_parent, n_token)
101
+ ret['prediction'] = pred_tag
102
+
103
+ return ret
104
+
105
+ @staticmethod
106
+ def bio2boundary(seqs) -> Tuple[torch.Tensor, torch.Tensor]:
107
+ def recursive_construct_spans(seqs_):
108
+ """
109
+ Helper function for bio2boundary
110
+ Recursively convert seqs of integers to boundary indices.
111
+ Return boundary indices and corresponding lens
112
+ """
113
+ if isinstance(seqs_, torch.Tensor):
114
+ if seqs_.device.type == 'cuda':
115
+ seqs_ = seqs_.to(device='cpu')
116
+ seqs_ = seqs_.tolist()
117
+ if isinstance(seqs_[0], int):
118
+ seqs_ = [BIO[i] for i in seqs_]
119
+ span_boundary_list = bio_tags_to_spans(seqs_)
120
+ return torch.tensor([item[1] for item in span_boundary_list]), len(span_boundary_list)
121
+ span_boundary = list()
122
+ lens_ = list()
123
+ for seq in seqs_:
124
+ one_bou, one_len = recursive_construct_spans(seq)
125
+ span_boundary.append(one_bou)
126
+ lens_.append(one_len)
127
+ if isinstance(lens_[0], int):
128
+ lens_ = torch.tensor(lens_)
129
+ else:
130
+ lens_ = torch.stack(lens_)
131
+ return span_boundary, lens_
132
+
133
+ boundary_list, lens = recursive_construct_spans(seqs)
134
+ max_span = int(lens.max())
135
+ boundary = torch.zeros((*lens.shape, max_span, 2), dtype=torch.long)
136
+
137
+ def recursive_copy(list_var, tensor_var):
138
+ if len(list_var) == 0:
139
+ return
140
+ if isinstance(list_var, torch.Tensor):
141
+ tensor_var[:len(list_var)] = list_var
142
+ return
143
+ assert len(list_var) == len(tensor_var)
144
+ for list_var_, tensor_var_ in zip(list_var, tensor_var):
145
+ recursive_copy(list_var_, tensor_var_)
146
+
147
+ recursive_copy(boundary_list, boundary)
148
+
149
+ return boundary, lens
150
+
151
+ def inference_forward_handler(
152
+ self,
153
+ token_vec: torch.Tensor,
154
+ token_mask: torch.Tensor,
155
+ span_extractor: SpanExtractor,
156
+ **auxiliaries,
157
+ ) -> Callable[[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], None]:
158
+ """
159
+ Refer to the doc of the SpanFinder for definition of this function.
160
+ """
161
+
162
+ def handler(
163
+ span_boundary: torch.Tensor,
164
+ span_labels: torch.Tensor,
165
+ parent_mask: torch.Tensor,
166
+ parent_indices: torch.Tensor,
167
+ cursor: torch.tensor,
168
+ ):
169
+ """
170
+ Refer to the doc of the SpanFinder for definition of this function.
171
+ """
172
+ max_decoding_span = span_boundary.shape[1]
173
+ # Shape [batch, span, token_dim]
174
+ span_vec = span_extractor(token_vec, span_boundary)
175
+ # Shape [batch, parent]
176
+ parent_indices_at_span, _ = mask2idx(parent_mask)
177
+ pred_bio = self(
178
+ token_vec, token_mask, span_vec, None, span_labels, None, parent_mask, prediction=True
179
+ )['prediction']
180
+ # Shape [batch, parent, span, 2]; Shape [batch, parent]
181
+ pred_boundary, pred_num = self.bio2boundary(pred_bio)
182
+ if pred_boundary.device != span_boundary.device:
183
+ pred_boundary = pred_boundary.to(device=span_boundary.device)
184
+ pred_num = pred_num.to(device=span_boundary.device)
185
+ # Shape [batch, parent, span]
186
+ pred_mask = num2mask(pred_num)
187
+
188
+ # Parent Loop
189
+ for pred_boundary_parent, pred_mask_parent, parent_indices_parent \
190
+ in zip(pred_boundary.unbind(1), pred_mask.unbind(1), parent_indices_at_span.unbind(1)):
191
+ for pred_boundary_step, step_mask in zip(pred_boundary_parent.unbind(1), pred_mask_parent.unbind(1)):
192
+ step_mask &= cursor < max_decoding_span
193
+ parent_indices[step_mask] = parent_indices[step_mask].scatter(
194
+ 1,
195
+ cursor[step_mask].unsqueeze(1),
196
+ parent_indices_parent[step_mask].unsqueeze(1)
197
+ )
198
+ span_boundary[step_mask] = span_boundary[step_mask].scatter(
199
+ 1,
200
+ cursor[step_mask].reshape(-1, 1, 1).expand(-1, -1, 2),
201
+ pred_boundary_step[step_mask].unsqueeze(1)
202
+ )
203
+ cursor[step_mask] += 1
204
+
205
+ return handler
206
+
207
+ def get_metrics(self, reset: bool = False) -> Dict[str, float]:
208
+ score = self.fb_measure.get_metric(reset)
209
+ if reset:
210
+ return {
211
+ 'finder_p': score['precision'] * 100,
212
+ 'finder_r': score['recall'] * 100,
213
+ 'finder_f': score['fscore'] * 100,
214
+ }
215
+ else:
216
+ return {'finder_f': score['fscore'] * 100}
sftp/modules/span_finder/span_finder.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from typing import *
3
+
4
+ import torch
5
+ from allennlp.common import Registrable
6
+ from allennlp.modules.span_extractors import SpanExtractor
7
+
8
+
9
+ class SpanFinder(Registrable, ABC, torch.nn.Module):
10
+ """
11
+ Model the probability p(child_span | parent_span [, parent_label])
12
+ It's optional to model parent_label, since in some cases we may want the parameters to be shared across
13
+ different tasks, where we may have similar span semantics but different label space.
14
+ """
15
+ def __init__(
16
+ self,
17
+ no_label: bool = True,
18
+ ):
19
+ """
20
+ :param no_label: If True, will not use input labels as features and use all 0 vector instead.
21
+ """
22
+ super().__init__()
23
+ self._no_label = no_label
24
+
25
+ @abstractmethod
26
+ def forward(
27
+ self,
28
+ token_vec: torch.Tensor,
29
+ token_mask: torch.Tensor,
30
+ span_vec: torch.Tensor,
31
+ span_mask: Optional[torch.Tensor] = None, # Do not need to provide
32
+ span_labels: Optional[torch.Tensor] = None, # Do not need to provide
33
+ parent_indices: Optional[torch.Tensor] = None, # Do not need to provide
34
+ parent_mask: Optional[torch.Tensor] = None,
35
+ bio_seqs: Optional[torch.Tensor] = None,
36
+ prediction: bool = False,
37
+ **extra
38
+ ) -> Dict[str, torch.Tensor]:
39
+ """
40
+ Return training loss and predictions.
41
+ :param token_vec: Vector representation of tokens. Shape [batch, token ,token_dim]
42
+ :param token_mask: True for non-padding tokens.
43
+ :param span_vec: Vector representation of spans. Shape [batch, span, token_dim]
44
+ :param span_mask: True for non-padding spans. Shape [batch, span]
45
+ :param span_labels: The labels of spans. Shape [batch, span]
46
+ :param parent_indices: Parent indices of spans. Shape [batch, span]
47
+ :param parent_mask: True for parent spans. Shape [batch, span]
48
+ :param prediction: If True, no loss will be return & no metrics will be updated.
49
+ :param bio_seqs: BIO sequences. Shape [batch, parent, token, 3]
50
+ :return:
51
+ loss: Training loss
52
+ prediction: Shape [batch, span]. True for positive predictions.
53
+ """
54
+ raise NotImplementedError
55
+
56
+ @abstractmethod
57
+ def inference_forward_handler(
58
+ self,
59
+ token_vec: torch.Tensor,
60
+ token_mask: torch.Tensor,
61
+ span_extractor: SpanExtractor,
62
+ **auxiliaries,
63
+ ) -> Callable[[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], None]:
64
+ """
65
+ Pre-process some information and return a callable module for p(child_span | parent_span [,parent_label])
66
+ :param token_vec: Vector representation of tokens. Shape [batch, token ,token_dim]
67
+ :param token_mask: True for non-padding tokens.
68
+ :param span_extractor: The same module in model.
69
+ :param auxiliaries: Environment variables. You can pass extra environment variables
70
+ since the extras will be ignored.
71
+ :return:
72
+ A callable function in a closure.
73
+ The arguments for the callable object are:
74
+ - span_boundary: Shape [batch, span, 2]
75
+ - span_labels: Shape [batch, span]
76
+ - parent_mask: Shape [batch, span]
77
+ - parent_indices: Shape [batch, span]
78
+ - cursor: Shape [batch]
79
+ No return values. Everything should be done inplace.
80
+ Note the span indexing space has different meaning from training process. We don't have gold span list,
81
+ so span here refers to the predicted spans.
82
+ """
83
+ raise NotImplementedError
84
+
85
+ @abstractmethod
86
+ def get_metrics(self, reset: bool = False) -> Dict[str, float]:
87
+ raise NotImplementedError
sftp/modules/span_typing/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .mlp_span_typing import MLPSpanTyping
2
+ from .span_typing import SpanTyping
sftp/modules/span_typing/mlp_span_typing.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+
3
+ import torch
4
+ from torch.nn import CrossEntropyLoss, KLDivLoss, LogSoftmax
5
+
6
+ from .span_typing import SpanTyping
7
+
8
+
9
+ @SpanTyping.register('mlp')
10
+ class MLPSpanTyping(SpanTyping):
11
+ """
12
+ An MLP implementation for Span Typing.
13
+ """
14
+ def __init__(
15
+ self,
16
+ input_dim: int,
17
+ hidden_dims: List[int],
18
+ label_emb: torch.nn.Embedding,
19
+ n_category: int,
20
+ label_to_ignore: Optional[List[int]] = None
21
+ ):
22
+ """
23
+ :param input_dim: dim(parent_span) + dim(child_span) + dim(label_dim)
24
+ :param hidden_dims: The dim of hidden layers of MLP.
25
+ :param n_category: #labels
26
+ :param label_emb: Embeds labels to vectors.
27
+ """
28
+ super().__init__(label_emb.num_embeddings, label_to_ignore, )
29
+ self.MLPs: List[torch.nn.Linear] = list()
30
+ for i_mlp, output_dim in enumerate(hidden_dims + [n_category]):
31
+ mlp = torch.nn.Linear(input_dim, output_dim, bias=True)
32
+ self.MLPs.append(mlp)
33
+ self.add_module(f'MLP-{i_mlp}', mlp)
34
+ input_dim = output_dim
35
+
36
+ # Embeds labels as features.
37
+ self.label_emb = label_emb
38
+
39
+ def forward(
40
+ self,
41
+ span_vec: torch.Tensor,
42
+ parent_at_span: torch.Tensor,
43
+ span_labels: Optional[torch.Tensor],
44
+ prediction_only: bool = False,
45
+ ) -> Dict[str, torch.Tensor]:
46
+ """
47
+ Inputs: All features for typing a child span.
48
+ Process: Update the metric.
49
+ Output: The loss of typing and predictions.
50
+ :return:
51
+ loss: Loss for label prediction.
52
+ prediction: Predicted labels.
53
+ """
54
+ is_soft = span_labels.dtype != torch.int64
55
+ # Shape [batch, span, label_dim]
56
+ label_vec = span_labels @ self.label_emb.weight if is_soft else self.label_emb(span_labels)
57
+ n_batch, n_span, _ = label_vec.shape
58
+ n_label, _ = self.ontology.shape
59
+ # Shape [batch, span, label_dim]
60
+ parent_label_features = label_vec.gather(1, parent_at_span.unsqueeze(2).expand_as(label_vec))
61
+ # Shape [batch, span, token_dim]
62
+ parent_span_features = span_vec.gather(1, parent_at_span.unsqueeze(2).expand_as(span_vec))
63
+ # Shape [batch, span, token_dim]
64
+ child_span_features = span_vec
65
+
66
+ features = torch.cat([parent_label_features, parent_span_features, child_span_features], dim=2)
67
+ # Shape [batch, span, label]
68
+ for mlp in self.MLPs[:-1]:
69
+ features = torch.relu(mlp(features))
70
+ logits = self.MLPs[-1](features)
71
+
72
+ logits_for_prediction = logits.clone()
73
+
74
+ if not is_soft:
75
+ # Shape [batch, span]
76
+ parent_labels = span_labels.gather(1, parent_at_span)
77
+ onto_mask = self.ontology.unsqueeze(0).expand(n_batch, -1, -1).gather(
78
+ 1, parent_labels.unsqueeze(2).expand(-1, -1, n_label)
79
+ )
80
+ logits_for_prediction[~onto_mask] = float('-inf')
81
+
82
+ label_dist = torch.softmax(logits_for_prediction, 2)
83
+ label_confidence, predictions = label_dist.max(2)
84
+ ret = {'prediction': predictions, 'label_confidence': label_confidence, 'distribution': label_dist}
85
+ if prediction_only:
86
+ return ret
87
+
88
+ span_labels = span_labels.clone()
89
+
90
+ if is_soft:
91
+ self.acc_metric(logits_for_prediction, span_labels.max(2)[1], ~span_labels.sum(2).isclose(torch.tensor(0.)))
92
+ ret['loss'] = KLDivLoss(reduction='sum')(LogSoftmax(dim=2)(logits), span_labels)
93
+ else:
94
+ for label_idx in self.label_to_ignore:
95
+ span_labels[span_labels == label_idx] = -100
96
+ self.acc_metric(logits_for_prediction, span_labels, span_labels != -100)
97
+ ret['loss'] = CrossEntropyLoss(reduction='sum')(logits.flatten(0, 1), span_labels.flatten())
98
+
99
+ return ret
sftp/modules/span_typing/span_typing.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC
2
+ from typing import *
3
+
4
+ import torch
5
+ from allennlp.common import Registrable
6
+ from allennlp.data.vocabulary import DEFAULT_OOV_TOKEN, Vocabulary
7
+ from allennlp.training.metrics import CategoricalAccuracy
8
+
9
+
10
+ class SpanTyping(Registrable, torch.nn.Module, ABC):
11
+ """
12
+ Models the probability p(child_label | child_span, parent_span, parent_label).
13
+ """
14
+ def __init__(
15
+ self,
16
+ n_label: int,
17
+ label_to_ignore: Optional[List[int]] = None,
18
+ ):
19
+ """
20
+ :param label_to_ignore: Label indexes in this list will be ignored.
21
+ Usually this should include NULL, PADDING and UNKNOWN.
22
+ """
23
+ super().__init__()
24
+ self.label_to_ignore = label_to_ignore or list()
25
+ self.acc_metric = CategoricalAccuracy()
26
+ self.onto = torch.ones([n_label, n_label], dtype=torch.bool)
27
+ self.register_buffer('ontology', self.onto)
28
+
29
+ def load_ontology(self, path: str, vocab: Vocabulary):
30
+ unk_id = vocab.get_token_index(DEFAULT_OOV_TOKEN, 'span_label')
31
+ for line in open(path).readlines():
32
+ entities = [vocab.get_token_index(ent, 'span_label') for ent in line.replace('\n', '').split('\t')]
33
+ parent, children = entities[0], entities[1:]
34
+ if parent == unk_id:
35
+ continue
36
+ self.onto[parent, :] = False
37
+ children = list(filter(lambda x: x != unk_id, children))
38
+ self.onto[parent, children] = True
39
+ self.register_buffer('ontology', self.onto)
40
+
41
+ def forward(
42
+ self,
43
+ span_vec: torch.Tensor,
44
+ parent_at_span: torch.Tensor,
45
+ span_labels: Optional[torch.Tensor],
46
+ prediction_only: bool = False,
47
+ ) -> Dict[str, torch.Tensor]:
48
+ """
49
+ Inputs: All features for typing a child span.
50
+ Output: The loss of typing and predictions.
51
+ :param span_vec: Shape [batch, span, token_dim]
52
+ :param parent_at_span: Shape [batch, span]
53
+ :param span_labels: Shape [batch, span]
54
+ :param prediction_only: If True, no loss returned & metric will not be updated
55
+ :return:
56
+ loss: Loss for label prediction. (absent of pred_only = True)
57
+ prediction: Predicted labels.
58
+ """
59
+ raise NotImplementedError
60
+
61
+ def get_metric(self, reset):
62
+ return{
63
+ "typing_acc": self.acc_metric.get_metric(reset) * 100
64
+ }
sftp/predictor/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .span_predictor import SpanPredictor
sftp/predictor/span_predictor.orig.py ADDED
@@ -0,0 +1,362 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from time import time
3
+ from typing import *
4
+ import json
5
+
6
+ import numpy as np
7
+ import torch
8
+ from allennlp.common.util import JsonDict, sanitize
9
+ from allennlp.data import DatasetReader, Instance
10
+ from allennlp.data.data_loaders import SimpleDataLoader
11
+ from allennlp.data.samplers import MaxTokensBatchSampler
12
+ from allennlp.data.tokenizers import SpacyTokenizer
13
+ from allennlp.models import Model
14
+ from allennlp.nn import util as nn_util
15
+ from allennlp.predictors import Predictor
16
+ from concrete import (
17
+ MentionArgument, SituationMentionSet, SituationMention, TokenRefSequence,
18
+ EntityMention, EntityMentionSet, Entity, EntitySet, AnnotationMetadata, Communication
19
+ )
20
+ from concrete.util import CommunicationReader, AnalyticUUIDGeneratorFactory, CommunicationWriterZip
21
+ from concrete.validate import validate_communication
22
+
23
+ from ..data_reader import concrete_doc, concrete_doc_tokenized
24
+ from ..utils import Span, re_index_span, VIRTUAL_ROOT
25
+
26
+
27
+ class PredictionReturn(NamedTuple):
28
+ span: Union[Span, dict, Communication]
29
+ sentence: List[str]
30
+ meta: Dict[str, Any]
31
+
32
+
33
+ class ForceDecodingReturn(NamedTuple):
34
+ span: np.ndarray
35
+ label: List[str]
36
+ distribution: np.ndarray
37
+
38
+
39
+ @Predictor.register('span')
40
+ class SpanPredictor(Predictor):
41
+ @staticmethod
42
+ def format_convert(
43
+ sentence: Union[List[str], List[List[str]]],
44
+ prediction: Union[Span, List[Span]],
45
+ output_format: str
46
+ ):
47
+ if output_format == 'span':
48
+ return prediction
49
+ elif output_format == 'json':
50
+ if isinstance(prediction, list):
51
+ return [SpanPredictor.format_convert(sent, pred, 'json') for sent, pred in zip(sentence, prediction)]
52
+ return prediction.to_json()
53
+ elif output_format == 'concrete':
54
+ if isinstance(prediction, Span):
55
+ sentence, prediction = [sentence], [prediction]
56
+ return concrete_doc_tokenized(sentence, prediction)
57
+
58
+ def predict_concrete(
59
+ self,
60
+ concrete_path: str,
61
+ output_path: Optional[str] = None,
62
+ max_tokens: int = 2048,
63
+ ontology_mapping: Optional[Dict[str, str]] = None,
64
+ ):
65
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
66
+ writer = CommunicationWriterZip(output_path)
67
+
68
+ for comm, fn in CommunicationReader(concrete_path):
69
+ assert len(comm.sectionList) == 1
70
+ concrete_sentences = comm.sectionList[0].sentenceList
71
+ json_sentences = list()
72
+ for con_sent in concrete_sentences:
73
+ json_sentences.append(
74
+ [t.text for t in con_sent.tokenization.tokenList.tokenList]
75
+ )
76
+ predictions = self.predict_batch_sentences(json_sentences, max_tokens, ontology_mapping=ontology_mapping)
77
+
78
+ # Merge predictions into concrete
79
+ aug = AnalyticUUIDGeneratorFactory(comm).create()
80
+ situation_mention_set = SituationMentionSet(next(aug), AnnotationMetadata('Span Finder', time()), list())
81
+ comm.situationMentionSetList = [situation_mention_set]
82
+ situation_mention_set.mentionList = sm_list = list()
83
+ entity_mention_set = EntityMentionSet(next(aug), AnnotationMetadata('Span Finder', time()), list())
84
+ comm.entityMentionSetList = [entity_mention_set]
85
+ entity_mention_set.mentionList = em_list = list()
86
+ entity_set = EntitySet(
87
+ next(aug), AnnotationMetadata('Span Finder', time()), list(), None, entity_mention_set.uuid
88
+ )
89
+ comm.entitySetList = [entity_set]
90
+
91
+ em_dict = dict()
92
+ for con_sent, pred in zip(concrete_sentences, predictions):
93
+ for event in pred.span:
94
+ def raw_text_span(start_idx, end_idx, **_):
95
+ si_char = con_sent.tokenization.tokenList.tokenList[start_idx].textSpan.start
96
+ ei_char = con_sent.tokenization.tokenList.tokenList[end_idx].textSpan.ending
97
+ return comm.text[si_char:ei_char]
98
+ sm = SituationMention(
99
+ next(aug),
100
+ text=raw_text_span(event.start_idx, event.end_idx),
101
+ situationKind=event.label,
102
+ situationType='EVENT',
103
+ confidence=event.confidence,
104
+ argumentList=list(),
105
+ tokens=TokenRefSequence(
106
+ tokenIndexList=list(range(event.start_idx, event.end_idx+1)),
107
+ tokenizationId=con_sent.tokenization.uuid
108
+ )
109
+ )
110
+
111
+ for arg in event:
112
+ em = em_dict.get((arg.start_idx, arg.end_idx + 1))
113
+ if em is None:
114
+ em = EntityMention(
115
+ next(aug),
116
+ tokens=TokenRefSequence(
117
+ tokenIndexList=list(range(arg.start_idx, arg.end_idx+1)),
118
+ tokenizationId=con_sent.tokenization.uuid,
119
+ ),
120
+ text=raw_text_span(arg.start_idx, arg.end_idx)
121
+ )
122
+ em_list.append(em)
123
+ entity_set.entityList.append(Entity(next(aug), id=em.text, mentionIdList=[em.uuid]))
124
+ em_dict[(arg.start_idx, arg.end_idx+1)] = em
125
+ sm.argumentList.append(MentionArgument(
126
+ role=arg.label,
127
+ entityMentionId=em.uuid,
128
+ confidence=arg.confidence
129
+ ))
130
+ sm_list.append(sm)
131
+ validate_communication(comm)
132
+ writer.write(comm, fn)
133
+ writer.close()
134
+
135
+ def predict_sentence(
136
+ self,
137
+ sentence: Union[str, List[str]],
138
+ ontology_mapping: Optional[Dict[str, str]] = None,
139
+ output_format: str = 'span',
140
+ ) -> PredictionReturn:
141
+ """
142
+ Predict spans on a single sentence (no batch). If not tokenized, will tokenize it with SpacyTokenizer.
143
+ :param sentence: If tokenized, should be a list of tokens in string. If not, should be a string.
144
+ :param ontology_mapping:
145
+ :param output_format: span, json or concrete.
146
+ """
147
+ prediction = self.predict_json(self._prepare_sentence(sentence))
148
+ prediction['prediction'] = self.format_convert(
149
+ prediction['sentence'],
150
+ Span.from_json(prediction['prediction']).map_ontology(ontology_mapping),
151
+ output_format
152
+ )
153
+ return PredictionReturn(prediction['prediction'], prediction['sentence'], prediction.get('meta', dict()))
154
+
155
+ def predict_batch_sentences(
156
+ self,
157
+ sentences: List[Union[List[str], str]],
158
+ max_tokens: int = 512,
159
+ ontology_mapping: Optional[Dict[str, str]] = None,
160
+ output_format: str = 'span',
161
+ ) -> List[PredictionReturn]:
162
+ """
163
+ Predict spans on a batch of sentences. If not tokenized, will tokenize it with SpacyTokenizer.
164
+ :param sentences: A list of sentences. Refer to `predict_sentence`.
165
+ :param max_tokens: Maximum tokens in a batch.
166
+ :param ontology_mapping: If not None, will try to map the output from one ontology to another.
167
+ If the predicted frame is not in the mapping, the prediction will be ignored.
168
+ :param output_format: span, json or concrete.
169
+ :return: A list of predictions.
170
+ """
171
+ sentences = list(map(self._prepare_sentence, sentences))
172
+ for i_sent, sent in enumerate(sentences):
173
+ sent['meta'] = {"idx": i_sent}
174
+ instances = list(map(self._json_to_instance, sentences))
175
+ outputs = list()
176
+ for ins_indices in MaxTokensBatchSampler(max_tokens, ["tokens"], 0.0).get_batch_indices(instances):
177
+ batch_ins = list(
178
+ SimpleDataLoader([instances[ins_idx] for ins_idx in ins_indices], len(ins_indices), vocab=self.vocab)
179
+ )[0]
180
+ batch_inputs = nn_util.move_to_device(batch_ins, device=self.cuda_device)
181
+ batch_outputs = self._model(**batch_inputs)
182
+ for meta, prediction, inputs in zip(
183
+ batch_outputs['meta'], batch_outputs['prediction'], batch_outputs['inputs']
184
+ ):
185
+ prediction.map_ontology(ontology_mapping)
186
+ prediction = self.format_convert(inputs['sentence'], prediction, output_format)
187
+ outputs.append(PredictionReturn(prediction, inputs['sentence'], {"input_idx": meta['idx']}))
188
+
189
+ outputs.sort(key=lambda x: x.meta['input_idx'])
190
+ return outputs
191
+
192
+ def predict_instance(self, instance: Instance) -> JsonDict:
193
+ outputs = self._model.forward_on_instance(instance)
194
+ outputs = sanitize(outputs)
195
+ return {
196
+ 'prediction': outputs['prediction'],
197
+ 'sentence': outputs['inputs']['sentence'],
198
+ 'meta': outputs.get('meta', {})
199
+ }
200
+
201
+ def __init__(
202
+ self,
203
+ model: Model,
204
+ dataset_reader: DatasetReader,
205
+ frozen: bool = True,
206
+ ):
207
+ super(SpanPredictor, self).__init__(model=model, dataset_reader=dataset_reader, frozen=frozen)
208
+ self.spacy_tokenizer = SpacyTokenizer(language='en_core_web_sm')
209
+
210
+ def economize(
211
+ self,
212
+ max_decoding_spans: Optional[int] = None,
213
+ max_recursion_depth: Optional[int] = None,
214
+ ):
215
+ if max_decoding_spans:
216
+ self._model._max_decoding_spans = max_decoding_spans
217
+ if max_recursion_depth:
218
+ self._model._max_recursion_depth = max_recursion_depth
219
+
220
+ def _json_to_instance(self, json_dict: JsonDict) -> Instance:
221
+ return self._dataset_reader.text_to_instance(**json_dict)
222
+
223
+ @staticmethod
224
+ def to_nested(prediction: List[dict]):
225
+ first_layer, idx2children = list(), dict()
226
+ for idx, pred in enumerate(prediction):
227
+ children = list()
228
+ pred['children'] = idx2children[idx+1] = children
229
+ if pred['parent'] == 0:
230
+ first_layer.append(pred)
231
+ else:
232
+ idx2children[pred['parent']].append(pred)
233
+ del pred['parent']
234
+ return first_layer
235
+
236
+ def _prepare_sentence(self, sentence: Union[str, List[str]]) -> Dict[str, List[str]]:
237
+ if isinstance(sentence, str):
238
+ while ' ' in sentence:
239
+ sentence = sentence.replace(' ', ' ')
240
+ sentence = sentence.replace(chr(65533), '')
241
+ if sentence == '':
242
+ sentence = [""]
243
+ sentence = list(map(str, self.spacy_tokenizer.tokenize(sentence)))
244
+ return {"tokens": sentence}
245
+
246
+ @staticmethod
247
+ def json_to_concrete(
248
+ predictions: List[dict],
249
+ ):
250
+ sentences = list()
251
+ for pred in predictions:
252
+ tokenization, event = list(), list()
253
+ sent = {'text': ' '.join(pred['inputs']), 'tokenization': tokenization, 'event': event}
254
+ sentences.append(sent)
255
+ start_idx = 0
256
+ for token in pred['inputs']:
257
+ tokenization.append((start_idx, len(token)-1+start_idx))
258
+ start_idx += len(token) + 1
259
+ for pred_event in pred['prediction']:
260
+ arg_list = list()
261
+ one_event = {'argument': arg_list}
262
+ event.append(one_event)
263
+ for key in ['start_idx', 'end_idx', 'label']:
264
+ one_event[key] = pred_event[key]
265
+ for pred_arg in pred_event['children']:
266
+ arg_list.append({key: pred_arg[key] for key in ['start_idx', 'end_idx', 'label']})
267
+
268
+ concrete_comm = concrete_doc(sentences)
269
+ return concrete_comm
270
+
271
+ def force_decode(
272
+ self,
273
+ sentence: List[str],
274
+ parent_span: Tuple[int, int] = (-1, -1),
275
+ parent_label: str = VIRTUAL_ROOT,
276
+ child_spans: Optional[List[Tuple[int, int]]] = None,
277
+ ) -> ForceDecodingReturn:
278
+ """
279
+ Force decoding. There are 2 modes:
280
+ 1. Given parent span and its label, find all it children (direct children, not including other descendents)
281
+ and type them.
282
+ 2. Given parent span, parent label, and children spans, type all children.
283
+ :param sentence: Tokens.
284
+ :param parent_span: [start_idx, end_idx], both inclusive.
285
+ :param parent_label: Parent label in string.
286
+ :param child_spans: Optional. If provided, will turn to mode 2; else mode 1.
287
+ :return:
288
+ - span: children spans.
289
+ - label: most probable labels of children.
290
+ - distribution: distribution over children labels.
291
+ """
292
+ instance = self._dataset_reader.text_to_instance(self._prepare_sentence(sentence)['tokens'])
293
+ model_input = nn_util.move_to_device(
294
+ list(SimpleDataLoader([instance], 1, vocab=self.vocab))[0], device=self.cuda_device
295
+ )
296
+ offsets = instance.fields['raw_inputs'].metadata['offsets']
297
+
298
+ with torch.no_grad():
299
+ tokens = model_input['tokens']
300
+ parent_span = re_index_span(parent_span, offsets)
301
+ if parent_span[1] >= self._dataset_reader.max_length:
302
+ return ForceDecodingReturn(
303
+ np.zeros([0, 2], dtype=np.int),
304
+ [],
305
+ np.zeros([0, self.vocab.get_vocab_size('span_label')], dtype=np.float64)
306
+ )
307
+ if child_spans is not None:
308
+ token_vec = self._model.word_embedding(tokens)
309
+ child_pieces = [re_index_span(bdr, offsets) for bdr in child_spans]
310
+ child_pieces = list(filter(lambda x: x[1] < self._dataset_reader.max_length-1, child_pieces))
311
+ span_tensor = torch.tensor(
312
+ [parent_span] + child_pieces, dtype=torch.int64, device=self.device
313
+ ).unsqueeze(0)
314
+ parent_indices = span_tensor.new_zeros(span_tensor.shape[0:2])
315
+ span_labels = parent_indices.new_full(
316
+ parent_indices.shape, self._model.vocab.get_token_index(parent_label, 'span_label')
317
+ )
318
+ span_vec = self._model._span_extractor(token_vec, span_tensor)
319
+ typing_out = self._model._span_typing(span_vec, parent_indices, span_labels)
320
+ distribution = typing_out['distribution'][0, 1:].cpu().numpy()
321
+ boundary = np.array(child_spans)
322
+ else:
323
+ parent_label_tensor = torch.tensor(
324
+ [self._model.vocab.get_token_index(parent_label, 'span_label')], device=self.device
325
+ )
326
+ parent_boundary_tensor = torch.tensor([parent_span], device=self.device)
327
+ boundary, _, num_children, distribution = self._model.one_step_prediction(
328
+ tokens, parent_boundary_tensor, parent_label_tensor
329
+ )
330
+ boundary, distribution = boundary[0].cpu().tolist(), distribution[0].cpu().numpy()
331
+ boundary = np.array([re_index_span(bdr, offsets, True) for bdr in boundary])
332
+
333
+ labels = [
334
+ self.vocab.get_token_from_index(label_idx, 'span_label') for label_idx in distribution.argmax(1)
335
+ ]
336
+ return ForceDecodingReturn(boundary, labels, distribution)
337
+
338
+ @property
339
+ def vocab(self):
340
+ return self._model.vocab
341
+
342
+ @property
343
+ def device(self):
344
+ return self.cuda_device if self.cuda_device > -1 else 'cpu'
345
+
346
+ @staticmethod
347
+ def read_ontology_mapping(file_path: str):
348
+ """
349
+ Read the ontology mapping file. The file format can be read in docs.
350
+ """
351
+ if file_path is None:
352
+ return None
353
+ if file_path.endswith('.json'):
354
+ return json.load(open(file_path))
355
+ mapping = dict()
356
+ for line in open(file_path).readlines():
357
+ parent_label, original_label, new_label = line.replace('\n', '').split('\t')
358
+ if parent_label == '*':
359
+ mapping[original_label] = new_label
360
+ else:
361
+ mapping[(parent_label, original_label)] = new_label
362
+ return mapping
sftp/predictor/span_predictor.py ADDED
@@ -0,0 +1,401 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from time import time
3
+ from typing import *
4
+ import json
5
+
6
+ # # ---GFM add debugger
7
+ # import pdb
8
+ # # end---
9
+
10
+ import numpy as np
11
+ import torch
12
+ from allennlp.common.util import JsonDict, sanitize
13
+ from allennlp.data import DatasetReader, Instance
14
+ from allennlp.data.data_loaders import SimpleDataLoader
15
+ from allennlp.data.samplers import MaxTokensBatchSampler
16
+ from allennlp.data.tokenizers import SpacyTokenizer
17
+ from allennlp.models import Model
18
+ from allennlp.nn import util as nn_util
19
+ from allennlp.predictors import Predictor
20
+ from concrete import (
21
+ MentionArgument, SituationMentionSet, SituationMention, TokenRefSequence,
22
+ EntityMention, EntityMentionSet, Entity, EntitySet, AnnotationMetadata, Communication
23
+ )
24
+ from concrete.util import CommunicationReader, AnalyticUUIDGeneratorFactory, CommunicationWriterZip
25
+ from concrete.validate import validate_communication
26
+
27
+ from ..data_reader import concrete_doc, concrete_doc_tokenized
28
+ from ..utils import Span, re_index_span, VIRTUAL_ROOT
29
+
30
+
31
+ class PredictionReturn(NamedTuple):
32
+ span: Union[Span, dict, Communication]
33
+ sentence: List[str]
34
+ meta: Dict[str, Any]
35
+
36
+
37
+ class ForceDecodingReturn(NamedTuple):
38
+ span: np.ndarray
39
+ label: List[str]
40
+ distribution: np.ndarray
41
+
42
+
43
+ @Predictor.register('span')
44
+ class SpanPredictor(Predictor):
45
+ @staticmethod
46
+ def format_convert(
47
+ sentence: Union[List[str], List[List[str]]],
48
+ prediction: Union[Span, List[Span]],
49
+ output_format: str
50
+ ):
51
+ if output_format == 'span':
52
+ return prediction
53
+ elif output_format == 'json':
54
+ if isinstance(prediction, list):
55
+ return [SpanPredictor.format_convert(sent, pred, 'json') for sent, pred in zip(sentence, prediction)]
56
+ return prediction.to_json()
57
+ elif output_format == 'concrete':
58
+ if isinstance(prediction, Span):
59
+ sentence, prediction = [sentence], [prediction]
60
+ return concrete_doc_tokenized(sentence, prediction)
61
+
62
+ def predict_concrete(
63
+ self,
64
+ concrete_path: str,
65
+ output_path: Optional[str] = None,
66
+ max_tokens: int = 2048,
67
+ ontology_mapping: Optional[Dict[str, str]] = None,
68
+ ):
69
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
70
+ writer = CommunicationWriterZip(output_path)
71
+
72
+ print(concrete_path)
73
+ for comm, fn in CommunicationReader(concrete_path):
74
+ print(fn)
75
+ assert len(comm.sectionList) == 1
76
+ concrete_sentences = comm.sectionList[0].sentenceList
77
+ json_sentences = list()
78
+ for con_sent in concrete_sentences:
79
+ json_sentences.append(
80
+ [t.text for t in con_sent.tokenization.tokenList.tokenList]
81
+ )
82
+ predictions = self.predict_batch_sentences(json_sentences, max_tokens, ontology_mapping=ontology_mapping)
83
+
84
+ # Merge predictions into concrete
85
+ aug = AnalyticUUIDGeneratorFactory(comm).create()
86
+ situation_mention_set = SituationMentionSet(next(aug), AnnotationMetadata('Span Finder', time()), list())
87
+ comm.situationMentionSetList = [situation_mention_set]
88
+ situation_mention_set.mentionList = sm_list = list()
89
+ entity_mention_set = EntityMentionSet(next(aug), AnnotationMetadata('Span Finder', time()), list())
90
+ comm.entityMentionSetList = [entity_mention_set]
91
+ entity_mention_set.mentionList = em_list = list()
92
+ entity_set = EntitySet(
93
+ next(aug), AnnotationMetadata('Span Finder', time()), list(), None, entity_mention_set.uuid
94
+ )
95
+ comm.entitySetList = [entity_set]
96
+
97
+ em_dict = dict()
98
+ for con_sent, pred in zip(concrete_sentences, predictions):
99
+ for event in pred.span:
100
+ def raw_text_span(start_idx, end_idx, **_):
101
+ si_char = con_sent.tokenization.tokenList.tokenList[start_idx].textSpan.start
102
+ ei_char = con_sent.tokenization.tokenList.tokenList[end_idx].textSpan.ending
103
+ return comm.text[si_char:ei_char]
104
+
105
+ # ---GFM: added this to get around off-by-one errors (unclear why these arise)
106
+ event_start_idx = event.start_idx
107
+ event_end_idx = event.end_idx
108
+ if event_end_idx > len(con_sent.tokenization.tokenList.tokenList) - 1:
109
+ print("WARNING: invalid `event_end_idx` passed for sentence, adjusting to final token")
110
+ print("\tsentence:", con_sent.tokenization.tokenList)
111
+ print("event_end_idx:", event_end_idx)
112
+ print("length:", len(con_sent.tokenization.tokenList.tokenList))
113
+ event_end_idx = len(con_sent.tokenization.tokenList.tokenList) - 1
114
+ print("new event_end_idx:", event_end_idx)
115
+ print()
116
+ # end---
117
+
118
+ sm = SituationMention(
119
+ next(aug),
120
+ # ---GFM: added this to get around off-by-one errors (unclear why these arise)
121
+ text=raw_text_span(event_start_idx, event_end_idx),
122
+ # end---
123
+ situationKind=event.label,
124
+ situationType='EVENT',
125
+ confidence=event.confidence,
126
+ argumentList=list(),
127
+ tokens=TokenRefSequence(
128
+ # ---GFM: added this to get around off-by-one errors (unclear why these arise)
129
+ tokenIndexList=list(range(event_start_idx, event_end_idx+1)),
130
+ # end---
131
+ tokenizationId=con_sent.tokenization.uuid
132
+ )
133
+ )
134
+
135
+ for arg in event:
136
+ # ---GFM: added this to get around off-by-one errors (unclear why these arise)
137
+ arg_start_idx = arg.start_idx
138
+ arg_end_idx = arg.end_idx
139
+ if arg_end_idx > len(con_sent.tokenization.tokenList.tokenList) - 1:
140
+ print("WARNING: invalid `arg_end_idx` passed for sentence, adjusting to final token")
141
+ print("\tsentence:", con_sent.tokenization.tokenList)
142
+ print("arg_end_idx:", arg_end_idx)
143
+ print("length:", len(con_sent.tokenization.tokenList.tokenList))
144
+ arg_end_idx = len(con_sent.tokenization.tokenList.tokenList) - 1
145
+ print("new arg_end_idx:", arg_end_idx)
146
+ print()
147
+ # end---
148
+
149
+ # ---GFM: replaced all arg.*_idx to arg_*_idx
150
+ em = em_dict.get((arg_start_idx, arg_end_idx + 1))
151
+ if em is None:
152
+ em = EntityMention(
153
+ next(aug),
154
+ tokens=TokenRefSequence(
155
+ tokenIndexList=list(range(arg_start_idx, arg_end_idx+1)),
156
+ tokenizationId=con_sent.tokenization.uuid,
157
+ ),
158
+ text=raw_text_span(arg_start_idx, arg_end_idx)
159
+ )
160
+ em_list.append(em)
161
+ entity_set.entityList.append(Entity(next(aug), id=em.text, mentionIdList=[em.uuid]))
162
+ em_dict[(arg_start_idx, arg_end_idx+1)] = em
163
+ sm.argumentList.append(MentionArgument(
164
+ role=arg.label,
165
+ entityMentionId=em.uuid,
166
+ confidence=arg.confidence
167
+ ))
168
+ # end---
169
+ sm_list.append(sm)
170
+ validate_communication(comm)
171
+ writer.write(comm, fn)
172
+ writer.close()
173
+
174
+ def predict_sentence(
175
+ self,
176
+ sentence: Union[str, List[str]],
177
+ ontology_mapping: Optional[Dict[str, str]] = None,
178
+ output_format: str = 'span',
179
+ ) -> PredictionReturn:
180
+ """
181
+ Predict spans on a single sentence (no batch). If not tokenized, will tokenize it with SpacyTokenizer.
182
+ :param sentence: If tokenized, should be a list of tokens in string. If not, should be a string.
183
+ :param ontology_mapping:
184
+ :param output_format: span, json or concrete.
185
+ """
186
+ prediction = self.predict_json(self._prepare_sentence(sentence))
187
+ prediction['prediction'] = self.format_convert(
188
+ prediction['sentence'],
189
+ Span.from_json(prediction['prediction']).map_ontology(ontology_mapping),
190
+ output_format
191
+ )
192
+ return PredictionReturn(prediction['prediction'], prediction['sentence'], prediction.get('meta', dict()))
193
+
194
+ def predict_batch_sentences(
195
+ self,
196
+ sentences: List[Union[List[str], str]],
197
+ max_tokens: int = 512,
198
+ ontology_mapping: Optional[Dict[str, str]] = None,
199
+ output_format: str = 'span',
200
+ ) -> List[PredictionReturn]:
201
+ """
202
+ Predict spans on a batch of sentences. If not tokenized, will tokenize it with SpacyTokenizer.
203
+ :param sentences: A list of sentences. Refer to `predict_sentence`.
204
+ :param max_tokens: Maximum tokens in a batch.
205
+ :param ontology_mapping: If not None, will try to map the output from one ontology to another.
206
+ If the predicted frame is not in the mapping, the prediction will be ignored.
207
+ :param output_format: span, json or concrete.
208
+ :return: A list of predictions.
209
+ """
210
+ sentences = list(map(self._prepare_sentence, sentences))
211
+ for i_sent, sent in enumerate(sentences):
212
+ sent['meta'] = {"idx": i_sent}
213
+ instances = list(map(self._json_to_instance, sentences))
214
+ outputs = list()
215
+ for ins_indices in MaxTokensBatchSampler(max_tokens, ["tokens"], 0.0).get_batch_indices(instances):
216
+ batch_ins = list(
217
+ SimpleDataLoader([instances[ins_idx] for ins_idx in ins_indices], len(ins_indices), vocab=self.vocab)
218
+ )[0]
219
+ batch_inputs = nn_util.move_to_device(batch_ins, device=self.cuda_device)
220
+ batch_outputs = self._model(**batch_inputs)
221
+ for meta, prediction, inputs in zip(
222
+ batch_outputs['meta'], batch_outputs['prediction'], batch_outputs['inputs']
223
+ ):
224
+ prediction.map_ontology(ontology_mapping)
225
+ prediction = self.format_convert(inputs['sentence'], prediction, output_format)
226
+ outputs.append(PredictionReturn(prediction, inputs['sentence'], {"input_idx": meta['idx']}))
227
+
228
+ outputs.sort(key=lambda x: x.meta['input_idx'])
229
+ return outputs
230
+
231
+ def predict_instance(self, instance: Instance) -> JsonDict:
232
+ outputs = self._model.forward_on_instance(instance)
233
+ outputs = sanitize(outputs)
234
+ return {
235
+ 'prediction': outputs['prediction'],
236
+ 'sentence': outputs['inputs']['sentence'],
237
+ 'meta': outputs.get('meta', {})
238
+ }
239
+
240
+ def __init__(
241
+ self,
242
+ model: Model,
243
+ dataset_reader: DatasetReader,
244
+ frozen: bool = True,
245
+ ):
246
+ super(SpanPredictor, self).__init__(model=model, dataset_reader=dataset_reader, frozen=frozen)
247
+ self.spacy_tokenizer = SpacyTokenizer(language='en_core_web_sm')
248
+
249
+ def economize(
250
+ self,
251
+ max_decoding_spans: Optional[int] = None,
252
+ max_recursion_depth: Optional[int] = None,
253
+ ):
254
+ if max_decoding_spans:
255
+ self._model._max_decoding_spans = max_decoding_spans
256
+ if max_recursion_depth:
257
+ self._model._max_recursion_depth = max_recursion_depth
258
+
259
+ def _json_to_instance(self, json_dict: JsonDict) -> Instance:
260
+ return self._dataset_reader.text_to_instance(**json_dict)
261
+
262
+ @staticmethod
263
+ def to_nested(prediction: List[dict]):
264
+ first_layer, idx2children = list(), dict()
265
+ for idx, pred in enumerate(prediction):
266
+ children = list()
267
+ pred['children'] = idx2children[idx+1] = children
268
+ if pred['parent'] == 0:
269
+ first_layer.append(pred)
270
+ else:
271
+ idx2children[pred['parent']].append(pred)
272
+ del pred['parent']
273
+ return first_layer
274
+
275
+ def _prepare_sentence(self, sentence: Union[str, List[str]]) -> Dict[str, List[str]]:
276
+ if isinstance(sentence, str):
277
+ while ' ' in sentence:
278
+ sentence = sentence.replace(' ', ' ')
279
+ sentence = sentence.replace(chr(65533), '')
280
+ if sentence == '':
281
+ sentence = [""]
282
+ sentence = list(map(str, self.spacy_tokenizer.tokenize(sentence)))
283
+ return {"tokens": sentence}
284
+
285
+ @staticmethod
286
+ def json_to_concrete(
287
+ predictions: List[dict],
288
+ ):
289
+ sentences = list()
290
+ for pred in predictions:
291
+ tokenization, event = list(), list()
292
+ sent = {'text': ' '.join(pred['inputs']), 'tokenization': tokenization, 'event': event}
293
+ sentences.append(sent)
294
+ start_idx = 0
295
+ for token in pred['inputs']:
296
+ tokenization.append((start_idx, len(token)-1+start_idx))
297
+ start_idx += len(token) + 1
298
+ for pred_event in pred['prediction']:
299
+ arg_list = list()
300
+ one_event = {'argument': arg_list}
301
+ event.append(one_event)
302
+ for key in ['start_idx', 'end_idx', 'label']:
303
+ one_event[key] = pred_event[key]
304
+ for pred_arg in pred_event['children']:
305
+ arg_list.append({key: pred_arg[key] for key in ['start_idx', 'end_idx', 'label']})
306
+
307
+ concrete_comm = concrete_doc(sentences)
308
+ return concrete_comm
309
+
310
+ def force_decode(
311
+ self,
312
+ sentence: List[str],
313
+ parent_span: Tuple[int, int] = (-1, -1),
314
+ parent_label: str = VIRTUAL_ROOT,
315
+ child_spans: Optional[List[Tuple[int, int]]] = None,
316
+ ) -> ForceDecodingReturn:
317
+ """
318
+ Force decoding. There are 2 modes:
319
+ 1. Given parent span and its label, find all it children (direct children, not including other descendents)
320
+ and type them.
321
+ 2. Given parent span, parent label, and children spans, type all children.
322
+ :param sentence: Tokens.
323
+ :param parent_span: [start_idx, end_idx], both inclusive.
324
+ :param parent_label: Parent label in string.
325
+ :param child_spans: Optional. If provided, will turn to mode 2; else mode 1.
326
+ :return:
327
+ - span: children spans.
328
+ - label: most probable labels of children.
329
+ - distribution: distribution over children labels.
330
+ """
331
+ instance = self._dataset_reader.text_to_instance(self._prepare_sentence(sentence)['tokens'])
332
+ model_input = nn_util.move_to_device(
333
+ list(SimpleDataLoader([instance], 1, vocab=self.vocab))[0], device=self.cuda_device
334
+ )
335
+ offsets = instance.fields['raw_inputs'].metadata['offsets']
336
+
337
+ with torch.no_grad():
338
+ tokens = model_input['tokens']
339
+ parent_span = re_index_span(parent_span, offsets)
340
+ if parent_span[1] >= self._dataset_reader.max_length:
341
+ return ForceDecodingReturn(
342
+ np.zeros([0, 2], dtype=np.int),
343
+ [],
344
+ np.zeros([0, self.vocab.get_vocab_size('span_label')], dtype=np.float64)
345
+ )
346
+ if child_spans is not None:
347
+ token_vec = self._model.word_embedding(tokens)
348
+ child_pieces = [re_index_span(bdr, offsets) for bdr in child_spans]
349
+ child_pieces = list(filter(lambda x: x[1] < self._dataset_reader.max_length-1, child_pieces))
350
+ span_tensor = torch.tensor(
351
+ [parent_span] + child_pieces, dtype=torch.int64, device=self.device
352
+ ).unsqueeze(0)
353
+ parent_indices = span_tensor.new_zeros(span_tensor.shape[0:2])
354
+ span_labels = parent_indices.new_full(
355
+ parent_indices.shape, self._model.vocab.get_token_index(parent_label, 'span_label')
356
+ )
357
+ span_vec = self._model._span_extractor(token_vec, span_tensor)
358
+ typing_out = self._model._span_typing(span_vec, parent_indices, span_labels)
359
+ distribution = typing_out['distribution'][0, 1:].cpu().numpy()
360
+ boundary = np.array(child_spans)
361
+ else:
362
+ parent_label_tensor = torch.tensor(
363
+ [self._model.vocab.get_token_index(parent_label, 'span_label')], device=self.device
364
+ )
365
+ parent_boundary_tensor = torch.tensor([parent_span], device=self.device)
366
+ boundary, _, num_children, distribution = self._model.one_step_prediction(
367
+ tokens, parent_boundary_tensor, parent_label_tensor
368
+ )
369
+ boundary, distribution = boundary[0].cpu().tolist(), distribution[0].cpu().numpy()
370
+ boundary = np.array([re_index_span(bdr, offsets, True) for bdr in boundary])
371
+
372
+ labels = [
373
+ self.vocab.get_token_from_index(label_idx, 'span_label') for label_idx in distribution.argmax(1)
374
+ ]
375
+ return ForceDecodingReturn(boundary, labels, distribution)
376
+
377
+ @property
378
+ def vocab(self):
379
+ return self._model.vocab
380
+
381
+ @property
382
+ def device(self):
383
+ return self.cuda_device if self.cuda_device > -1 else 'cpu'
384
+
385
+ @staticmethod
386
+ def read_ontology_mapping(file_path: str):
387
+ """
388
+ Read the ontology mapping file. The file format can be read in docs.
389
+ """
390
+ if file_path is None:
391
+ return None
392
+ if file_path.endswith('.json'):
393
+ return json.load(open(file_path))
394
+ mapping = dict()
395
+ for line in open(file_path).readlines():
396
+ parent_label, original_label, new_label = line.replace('\n', '').split('\t')
397
+ if parent_label == '*':
398
+ mapping[original_label] = new_label
399
+ else:
400
+ mapping[(parent_label, original_label)] = new_label
401
+ return mapping
sftp/training/__init__.py ADDED
File without changes
sftp/training/transformer_optimizer.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import re
3
+ from typing import *
4
+
5
+ import torch
6
+ from allennlp.common.from_params import Params, T
7
+ from allennlp.training.optimizers import Optimizer
8
+
9
+ logger = logging.getLogger('optim')
10
+
11
+
12
+ @Optimizer.register('transformer')
13
+ class TransformerOptimizer:
14
+ """
15
+ Wrapper for AllenNLP optimizer.
16
+ This is used to fine-tune the pretrained transformer with some layers fixed and different learning rate.
17
+ When some layers are fixed, the wrapper will set the `require_grad` flag as False, which could save
18
+ training time and optimize memory usage.
19
+ Plz contact Guanghui Qin for bugs.
20
+ Params:
21
+ base: base optimizer.
22
+ embeddings_lr: learning rate for embedding layer. Set as 0.0 to fix it.
23
+ encoder_lr: learning rate for encoder layer. Set as 0.0 to fix it.
24
+ pooler_lr: learning rate for pooler layer. Set as 0.0 to fix it.
25
+ layer_fix: the number of encoder layers that should be fixed.
26
+
27
+ Example json config:
28
+
29
+ 1. No-op. Do nothing (why do you use me?)
30
+ optimizer: {
31
+ type: "transformer",
32
+ base: {
33
+ type: "adam",
34
+ lr: 0.001
35
+ }
36
+ }
37
+
38
+ 2. Fix everything in the transformer.
39
+ optimizer: {
40
+ type: "transformer",
41
+ base: {
42
+ type: "adam",
43
+ lr: 0.001
44
+ },
45
+ embeddings_lr: 0.0,
46
+ encoder_lr: 0.0,
47
+ pooler_lr: 0.0
48
+ }
49
+
50
+ Or equivalently (suppose we have 24 layers)
51
+
52
+ optimizer: {
53
+ type: "transformer",
54
+ base: {
55
+ type: "adam",
56
+ lr: 0.001
57
+ },
58
+ embeddings_lr: 0.0,
59
+ layer_fix: 24,
60
+ pooler_lr: 0.0
61
+ }
62
+
63
+ 3. Fix embeddings and the lower 12 encoder layers, set a small learning rate
64
+ for the other parts of the transformer
65
+
66
+ optimizer: {
67
+ type: "transformer",
68
+ base: {
69
+ type: "adam",
70
+ lr: 0.001
71
+ },
72
+ embeddings_lr: 0.0,
73
+ layer_fix: 12,
74
+ encoder_lr: 1e-5,
75
+ pooler_lr: 1e-5
76
+ }
77
+ """
78
+ @classmethod
79
+ def from_params(
80
+ cls: Type[T],
81
+ params: Params,
82
+ model_parameters: List[Tuple[str, torch.nn.Parameter]],
83
+ **_
84
+ ):
85
+ param_groups = list()
86
+
87
+ def remove_param(keyword_):
88
+ nonlocal model_parameters
89
+ logger.info(f'Fix param with name matching {keyword_}.')
90
+ for name, param in model_parameters:
91
+ if keyword_ in name:
92
+ logger.debug(f'Fix param {name}.')
93
+ param.requires_grad_(False)
94
+ model_parameters = list(filter(lambda x: keyword_ not in x[0], model_parameters))
95
+
96
+ for i_layer in range(params.pop('layer_fix')):
97
+ remove_param('transformer_model.encoder.layer.{}.'.format(i_layer))
98
+
99
+ for specific_lr, keyword in (
100
+ (params.pop('embeddings_lr', None), 'transformer_model.embeddings'),
101
+ (params.pop('encoder_lr', None), 'transformer_model.encoder.layer'),
102
+ (params.pop('pooler_lr', None), 'transformer_model.pooler'),
103
+ ):
104
+ if specific_lr is not None:
105
+ if specific_lr > 0.:
106
+ pattern = '.*' + keyword.replace('.', r'\.') + '.*'
107
+ if len([name for name, _ in model_parameters if re.match(pattern, name)]) > 0:
108
+ param_groups.append([[pattern], {'lr': specific_lr}])
109
+ else:
110
+ logger.warning(f'{pattern} is set to use lr {specific_lr} but no param matches.')
111
+ else:
112
+ remove_param(keyword)
113
+
114
+ if 'parameter_groups' in params:
115
+ for pg in params.pop('parameter_groups'):
116
+ param_groups.append([pg[0], pg[1].as_dict()])
117
+
118
+ return Optimizer.by_name(params.get('base').pop('type'))(
119
+ model_parameters=model_parameters, parameter_groups=param_groups,
120
+ **params.pop('base').as_flat_dict()
121
+ )
sftp/utils/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ import sftp.utils.label_smoothing
2
+ from sftp.utils.common import VIRTUAL_ROOT, DEFAULT_SPAN, BIO
3
+ from sftp.utils.db_storage import Cache
4
+ from sftp.utils.functions import num2mask, mask2idx, numpy2torch, one_hot, max_match
5
+ from sftp.utils.span import Span, re_index_span
6
+ from sftp.utils.span_utils import tensor2span
7
+ from sftp.utils.bio_smoothing import BIOSmoothing, apply_bio_smoothing
sftp/utils/bio_smoothing.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+
3
+ import numpy as np
4
+ from .common import BIO
5
+
6
+
7
+ class BIOSmoothing:
8
+ def __init__(
9
+ self,
10
+ b_smooth: float = 0.0,
11
+ i_smooth: float = 0.0,
12
+ o_smooth: float = 0.0,
13
+ weight: float = 1.0
14
+ ):
15
+ self.smooth = [b_smooth, i_smooth, o_smooth]
16
+ self.weight = weight
17
+
18
+ def apply_sequence(self, sequence: List[str]):
19
+ bio_tags = np.zeros([len(sequence), 3], np.float32)
20
+ for i, tag in enumerate(sequence):
21
+ bio_tags[i] = self.apply_tag(tag)
22
+ return bio_tags
23
+
24
+ def apply_tag(self, tag: str):
25
+ j = BIO.index(tag)
26
+ ret = np.zeros([3], np.float32)
27
+ if self.smooth[j] >= 0.0:
28
+ # Smooth
29
+ ret[j] = 1.0 - self.smooth[j]
30
+ for j_ in set(range(3)) - {j}:
31
+ ret[j_] = self.smooth[j] / 2
32
+ else:
33
+ # Marginalize
34
+ ret[:] = 1.0
35
+
36
+ return ret * self.weight
37
+
38
+ def __repr__(self):
39
+ ret = f'<W={self.weight:.2f}'
40
+ for j, tag in enumerate(BIO):
41
+ if self.smooth[j] != 0.0:
42
+ if self.smooth[j] < 0:
43
+ ret += f' [marginalize {tag}]'
44
+ else:
45
+ ret += f' [smooth {tag} by {self.smooth[j]:.2f}]'
46
+ return ret + '>'
47
+
48
+ def clone(self):
49
+ return BIOSmoothing(*self.smooth, self.weight)
50
+
51
+
52
+ def apply_bio_smoothing(
53
+ config: Optional[Union[BIOSmoothing, List[BIOSmoothing]]],
54
+ bio_seq: List[str]
55
+ ) -> np.ndarray:
56
+ if config is None:
57
+ config = BIOSmoothing()
58
+ if isinstance(config, BIOSmoothing):
59
+ return config.apply_sequence(bio_seq)
60
+ else:
61
+ assert len(bio_seq) == len(config)
62
+ return np.stack([cfg.apply_tag(tag) for cfg, tag in zip(config, bio_seq)])
sftp/utils/common.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ DEFAULT_SPAN = '@@SPAN@@'
2
+ VIRTUAL_ROOT = '@@VIRTUAL_ROOT@@'
3
+ BIO = 'BIO'
sftp/utils/db_storage.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ import warnings
3
+
4
+ import h5py
5
+ import numpy as np
6
+
7
+
8
+ class Cache:
9
+ def __init__(self, file: str, mode: str = 'a', overwrite=False):
10
+ self.db_file = h5py.File(file, mode=mode)
11
+ self.overwrite = overwrite
12
+
13
+ @staticmethod
14
+ def _key(key):
15
+ if isinstance(key, str):
16
+ return key
17
+ elif isinstance(key, list):
18
+ ret = []
19
+ for k in key:
20
+ ret.append(Cache._key(k))
21
+ return ' '.join(ret)
22
+ else:
23
+ return str(key)
24
+
25
+ @staticmethod
26
+ def _value(value: np.ndarray):
27
+ if isinstance(value, h5py.Dataset):
28
+ value: np.ndarray = value[()]
29
+ if value.dtype.name.startswith('bytes'):
30
+ value = pickle.loads(value)
31
+ return value
32
+
33
+ def __getitem__(self, key):
34
+ key = self._key(key)
35
+ if key not in self:
36
+ raise KeyError
37
+ return self._value(self.db_file[key])
38
+
39
+ def __setitem__(self, key, value) -> None:
40
+ key = self._key(key)
41
+ if key in self:
42
+ del self.db_file[key]
43
+ if not isinstance(value, np.ndarray):
44
+ value = np.array(pickle.dumps(value))
45
+ self.db_file[key] = value
46
+
47
+ def __delitem__(self, key) -> None:
48
+ key = self._key(key)
49
+ if key in self:
50
+ del self.db_file[key]
51
+
52
+ def __len__(self) -> int:
53
+ return len(self.db_file)
54
+
55
+ def close(self) -> None:
56
+ self.db_file.close()
57
+
58
+ def __exit__(self, exc_type, exc_val, exc_tb) -> None:
59
+ self.close()
60
+
61
+ def __contains__(self, item):
62
+ item = self._key(item)
63
+ return item in self.db_file
64
+
65
+ def __enter__(self):
66
+ return self
67
+
68
+ def __call__(self, function):
69
+ """
70
+ The object of the class could also be used as a decorator. Provide an additional
71
+ argument `cache_id' when calling the function, and the results will be cached.
72
+ """
73
+
74
+ def wrapper(*args, **kwargs):
75
+ if 'cache_id' in kwargs:
76
+ cache_id = kwargs['cache_id']
77
+ del kwargs['cache_id']
78
+ if cache_id in self and not self.overwrite:
79
+ return self[cache_id]
80
+ rst = function(*args, **kwargs)
81
+ self[cache_id] = rst
82
+ return rst
83
+ else:
84
+ warnings.warn("`cache_id' argument not found. Cache is disabled.")
85
+ return function(*args, **kwargs)
86
+
87
+ return wrapper
sftp/utils/functions.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+
3
+ import numpy as np
4
+ import torch
5
+ from scipy.optimize import linear_sum_assignment
6
+ from torch.nn.utils.rnn import pad_sequence
7
+
8
+
9
+ def num2mask(
10
+ nums: torch.Tensor,
11
+ max_length: Optional[int] = None
12
+ ) -> torch.Tensor:
13
+ """
14
+ E.g. input a tensor [2, 3, 4], return [[T T F F], [T T T F], [T T T T]]
15
+ :param nums: Shape [batch]
16
+ :param max_length: maximum length. if not provided, will choose the largest number from nums.
17
+ :return: 2D binary mask.
18
+ """
19
+ shape_backup = nums.shape
20
+ nums = nums.flatten()
21
+ max_length = max_length or int(nums.max())
22
+ batch_size = len(nums)
23
+ range_nums = torch.arange(0, max_length, device=nums.device).unsqueeze(0).expand([batch_size, max_length])
24
+ ret = (range_nums.T < nums).T
25
+ return ret.reshape(*shape_backup, max_length)
26
+
27
+
28
+ def mask2idx(
29
+ mask: torch.Tensor,
30
+ max_length: Optional[int] = None,
31
+ padding_value: int = 0,
32
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
33
+ """
34
+ E.g. input a tensor [[T T F F], [T T T F], [F F F T]] with padding value -1,
35
+ return [[0, 1, -1], [0, 1, 2], [3, -1, -1]]
36
+ :param mask: Mask tensor. Boolean. Not necessarily to be 2D.
37
+ :param max_length: If provided, will truncate.
38
+ :param padding_value: Padding value. Default to 0.
39
+ :return: Index tensor.
40
+ """
41
+ shape_prefix, mask_length = mask.shape[:-1], mask.shape[-1]
42
+ flat_mask = mask.flatten(0, -2)
43
+ index_list = [torch.arange(mask_length, device=mask.device)[one_mask] for one_mask in flat_mask.unbind(0)]
44
+ index_tensor = pad_sequence(index_list, batch_first=True, padding_value=padding_value)
45
+ if max_length is not None:
46
+ index_tensor = index_tensor[:, :max_length]
47
+ index_tensor = index_tensor.reshape(*shape_prefix, -1)
48
+ return index_tensor, mask.sum(-1)
49
+
50
+
51
+ def one_hot(tags: torch.Tensor, num_tags: Optional[int] = None) -> torch.Tensor:
52
+ num_tags = num_tags or int(tags.max())
53
+ ret = tags.new_zeros(size=[*tags.shape, num_tags], dtype=torch.bool)
54
+ ret.scatter_(2, tags.unsqueeze(2), tags.new_ones([*tags.shape, 1], dtype=torch.bool))
55
+ return ret
56
+
57
+
58
+ def numpy2torch(
59
+ dict_obj: dict
60
+ ) -> dict:
61
+ """
62
+ Convert list/np.ndarray data to torch.Tensor and add add a batch dim.
63
+ """
64
+ ret = dict()
65
+ for k, v in dict_obj.items():
66
+ if isinstance(v, list) or isinstance(v, np.ndarray):
67
+ ret[k] = torch.tensor(v).unsqueeze(0)
68
+ else:
69
+ ret[k] = v
70
+ return ret
71
+
72
+
73
+ def max_match(mat: np.ndarray):
74
+ row_idx, col_idx = linear_sum_assignment(mat, True)
75
+ return mat[row_idx, col_idx].sum()
sftp/utils/label_smoothing.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import KLDivLoss
4
+ from torch.nn import LogSoftmax
5
+
6
+
7
+ class LabelSmoothingLoss(nn.Module):
8
+ def __init__(self, label_smoothing=0.0, unreliable_label=None, ignore_index=-100):
9
+ """
10
+ If label_smoothing == 0.0, it is equivalent to xentropy
11
+ """
12
+ assert 0.0 <= label_smoothing <= 1.0
13
+ super(LabelSmoothingLoss, self).__init__()
14
+
15
+ self.ignore_index = ignore_index
16
+ self.label_smoothing = label_smoothing
17
+
18
+ self.loss_fn = KLDivLoss(reduction='batchmean')
19
+ self.unreliable_label = unreliable_label
20
+ self.max_gap = 100.
21
+ self.log_softmax = LogSoftmax(1)
22
+
23
+ def forward(self, output, target):
24
+ """
25
+ output: logits
26
+ target: labels
27
+ """
28
+ vocab_size = output.shape[1]
29
+ mask = (target != self.ignore_index)
30
+ output, target = output[mask], target[mask]
31
+ output = self.log_softmax(output)
32
+
33
+ def get_smooth_prob(ls):
34
+ smoothing_value = ls / (vocab_size - 1)
35
+ prob = output.new_full((target.size(0), vocab_size), smoothing_value)
36
+ prob.scatter_(1, target.unsqueeze(1), 1 - ls)
37
+ return prob
38
+
39
+ if self.unreliable_label is not None:
40
+ smoothed_prob = get_smooth_prob(self.label_smoothing)
41
+ hard_prob = get_smooth_prob(0.0)
42
+ unreliable_mask = (target == self.unreliable_label).to(torch.float)
43
+ model_prob = ((smoothed_prob.T * unreliable_mask) + (hard_prob.T * (1 - unreliable_mask))).T
44
+ else:
45
+ model_prob = get_smooth_prob(self.label_smoothing)
46
+
47
+ loss = self.loss_fn(output, model_prob)
48
+ return loss
sftp/utils/span.py ADDED
@@ -0,0 +1,420 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+
3
+ import numpy as np
4
+
5
+ from .common import VIRTUAL_ROOT, DEFAULT_SPAN
6
+ from .bio_smoothing import BIOSmoothing
7
+ from .functions import max_match
8
+
9
+
10
+ class Span:
11
+ """
12
+ Span is a simple data structure for a span (not necessarily associated with text), along with its label,
13
+ children and possibly its parent and a confidence score.
14
+
15
+ Basic usages (suppose span is a Span object):
16
+ 1. len(span) -- #children.
17
+ 2. span[i] -- i-th child.
18
+ 3. for s in span: ... -- iterate its children.
19
+ 4. for s in span.bfs: ... -- iterate its descendents.
20
+ 5. print(span) -- show its description.
21
+ 6. span.tree() -- print the whole tree.
22
+
23
+ It provides some utilities:
24
+ 1. Re-indexing. BPE will change token indices, and the `re_index` method can convert normal tokens
25
+ BPE word piece indices, or vice versa.
26
+ 2. Span object and span dict (JSON format) are mutually convertible (by `to_json` and `from_json` methods).
27
+ 3. Recursively truncate spans up to a given length. (see `truncate` method)
28
+ 4. Recursively replace all labels with the default label. (see `ignore_labels` method)
29
+ 5. Recursively solve the span overlapping problem by removing children overlapped with others.
30
+ (see `remove_overlapping` method)
31
+ """
32
+ def __init__(
33
+ self,
34
+ start_idx: int,
35
+ end_idx: int,
36
+ label: Union[str, int, list] = DEFAULT_SPAN,
37
+ is_parent: bool = False,
38
+ parent: Optional["Span"] = None,
39
+ confidence: Optional[float] = None,
40
+ ):
41
+ """
42
+ Init function. Children should be added using the `add_children` method.
43
+ :param start_idx: Start index in a seq of tokens, inclusive.
44
+ :param end_idx: End index in a seq of tokens, inclusive.
45
+ :param label: Label. If not provided, will assign a default label.
46
+ Can be of various types: String, integer, or list of something.
47
+ :param is_parent: If True, will be treated as parent. This is important because in the training process of BIO
48
+ tagger, when a span has no children, we need to know if it's a parent with no children (so we should have
49
+ an training example with all O tags) or not (then the above example doesn't exist).
50
+ We follow a convention where if a span is not parent, then the key `children` shouldn't appear in its
51
+ JSON dict; if a span is parent but has no children, the key `children` in its JSON dict should appear
52
+ and be an empty list.
53
+ :param parent: A pointer to its parent.
54
+ :param confidence: Confidence value.
55
+ """
56
+ self.start_idx, self.end_idx = start_idx, end_idx
57
+ self.label: Union[int, str, list] = label
58
+ self.is_parent = is_parent
59
+ self.parent = parent
60
+ self._children: List[Span] = list()
61
+ self.confidence = confidence
62
+
63
+ # Following are for label smoothing. Leave default is you don't need smoothing.
64
+ # Logic:
65
+ # The label smoothing factors of (i.e. b_smooth, i_smooth, o_smooth) depend on the `child_span` of its parent.
66
+ # The re-weighting factor of a span also depends on the `child_span` of its parent, but can be overridden
67
+ # by its own `smoothing_weight` field if it's not None.
68
+ self.child_smooth: BIOSmoothing = BIOSmoothing()
69
+ self.smooth_weight: Optional[float] = None
70
+
71
+ def add_child(self, span: "Span") -> "Span":
72
+ """
73
+ Add a span to children list. Will link current span to child's parent pointer.
74
+ :param span: Child span.
75
+ """
76
+ assert self.is_parent
77
+ span.parent = self
78
+ self._children.append(span)
79
+ return self
80
+
81
+ def re_index(
82
+ self,
83
+ offsets: List[Optional[Tuple[int, int]]],
84
+ reverse: bool = False,
85
+ recursive: bool = True,
86
+ inplace: bool = False,
87
+ ) -> "Span":
88
+ """
89
+ BPE will change token indices, and the `re_index` method can convert normal tokens BPE word piece indices,
90
+ or vice versa.
91
+ We assume Virtual Root has a boundary [-1, -1] before being mapped to the BPE space, and a boundary [0, 0]
92
+ after the re-indexing. We use [0, 0] because it's always the BOS token in BPE.
93
+ Mapping to BPE space is straight forward. The reverse mapping has special cases where the span might
94
+ contain BOS or EOS. Usually this is a parsing bug. We will map the BOS index to 0, and EOS index to -1.
95
+ :param offsets: Offsets. Defined by BPE tokenizer and resides in the SpanFinder outputs.
96
+ :param reverse: If True, map from the BPE space to original token space.
97
+ :param recursive: If True, will apply the re-indexing to its children.
98
+ :param inplace: Inplace?
99
+ :return: Re-indexed span.
100
+ """
101
+ span = self if inplace else self.clone()
102
+
103
+ span.start_idx, span.end_idx = re_index_span(span.boundary, offsets, reverse)
104
+ if recursive:
105
+ new_children = list()
106
+ for child in span._children:
107
+ new_children.append(child.re_index(offsets, reverse, recursive, True))
108
+ span._children = new_children
109
+ return span
110
+
111
+ def truncate(self, max_length: int) -> bool:
112
+ """
113
+ Discard spans whose end_idx exceeds the max_length (inclusive).
114
+ This is done recursively.
115
+ This is useful for some encoder like XLMR that has a limit on input length. (512 for XLMR large)
116
+ :param max_length: Max length.
117
+ :return: You don't need to care return value.
118
+ """
119
+ if self.end_idx >= max_length:
120
+ return False
121
+ else:
122
+ self._children = list(filter(lambda x: x.truncate(max_length), self._children))
123
+ return True
124
+
125
+ @classmethod
126
+ def virtual_root(cls: "Span", spans: Optional[List["Span"]] = None) -> "Span":
127
+ """
128
+ An official method to create a tree: Generate the first layer of spans by yourself, and pass them into this
129
+ method.
130
+ E.g., for SRL style task, generate a list of events, assign arguments to them as children. Then pass the
131
+ events to this method to have a virtual root which serves as a parent of events.
132
+ :param spans: 1st layer spans.
133
+ :return: Virtual root.
134
+ """
135
+ vr = Span(-1, -1, VIRTUAL_ROOT, True)
136
+ if spans is not None:
137
+ vr._children = spans
138
+ for child in vr._children:
139
+ child.parent = vr
140
+ return vr
141
+
142
+ def ignore_labels(self) -> None:
143
+ """
144
+ Remove all labels. Make them placeholders. Inplace.
145
+ """
146
+ self.label = DEFAULT_SPAN
147
+ for child in self._children:
148
+ child.ignore_labels()
149
+
150
+ def clone(self) -> "Span":
151
+ """
152
+ Clone a tree.
153
+ :return: Cloned tree.
154
+ """
155
+ span = Span(self.start_idx, self.end_idx, self.label, self.is_parent, self.parent, self.confidence)
156
+ span.child_smooth, span.smooth_weight = self.child_smooth, self.smooth_weight
157
+ for child in self._children:
158
+ span.add_child(child.clone())
159
+ return span
160
+
161
+ def bfs(self) -> Iterable["Span"]:
162
+ """
163
+ Iterate over all descendents with BFS, including self.
164
+ :return: Spans.
165
+ """
166
+ yield self
167
+ yield from self._bfs()
168
+
169
+ def _bfs(self) -> List["Span"]:
170
+ """
171
+ Helper function.
172
+ """
173
+ for child in self._children:
174
+ yield child
175
+ for child in self._children:
176
+ yield from child._bfs()
177
+
178
+ def remove_overlapping(self, recursive=True) -> int:
179
+ """
180
+ Remove overlapped spans. If spans overlap, will pick the first one and discard the others, judged by start_idx.
181
+ :param recursive: Apply to all of the descendents?
182
+ :return: The number of spans that are removed.
183
+ """
184
+ indices = set()
185
+ new_children = list()
186
+ removing = 0
187
+ for child in self._children:
188
+ if len(set(range(child.start_idx, child.end_idx + 1)) & indices) > 0:
189
+ removing += 1
190
+ continue
191
+ indices.update(set(range(child.start_idx, child.end_idx + 1)))
192
+ new_children.append(child)
193
+ if recursive:
194
+ removing += child.remove_overlapping(True)
195
+ self._children = new_children
196
+ return removing
197
+
198
+ def describe(self, sentence: Optional[List[str]] = None) -> str:
199
+ """
200
+ :param sentence: If provided, will replace the indices with real tokens for presentation.
201
+ :return: The description in a single line.
202
+ """
203
+ if self.start_idx >= 0:
204
+ if sentence is None:
205
+ span = f'({self.start_idx}, {self.end_idx})'
206
+ else:
207
+ span = '(' + ' '.join(sentence[self.start_idx: self.end_idx + 1]) + ')'
208
+ if self.is_parent:
209
+ return f'<Span: {span}, {self.label}, {len(self._children)} children>'
210
+ else:
211
+ return f'[Span: {span}, {self.label}]'
212
+ else:
213
+ return f'<Span Annotation: {self.n_nodes - 1} descendents>'
214
+
215
+ def __repr__(self) -> str:
216
+ return self.describe()
217
+
218
+ @property
219
+ def n_nodes(self) -> int:
220
+ """
221
+ :return: Number of descendents + self.
222
+ """
223
+ return sum([child.n_nodes for child in self._children], 1)
224
+
225
+ @property
226
+ def boundary(self):
227
+ """
228
+ :return: (start_idx, end_idx), both inclusive.
229
+ """
230
+ return self.start_idx, self.end_idx
231
+
232
+ def __iter__(self) -> Iterable["Span"]:
233
+ """
234
+ Iterate over children.
235
+ """
236
+ yield from self._children
237
+
238
+ def __len__(self):
239
+ """
240
+ :return: #children.
241
+ """
242
+ return len(self._children)
243
+
244
+ def __getitem__(self, idx: int):
245
+ """
246
+ :return: The indexed child.
247
+ """
248
+ return self._children[idx]
249
+
250
+ def tree(self, sentence: Optional[List[str]] = None, printing: bool = True) -> str:
251
+ """
252
+ A tree description of all descendents. Human readable.
253
+ :param sentence: If provided, will replace the indices with real tokens for presentation.
254
+ :param printing: If True, will print out.
255
+ :return: The description.
256
+ """
257
+ ret = list()
258
+ ret.append(self.describe(sentence))
259
+ for child in self._children:
260
+ child_lines = child.tree(sentence, False).split('\n')
261
+ for line in child_lines:
262
+ ret.append(' ' + line)
263
+ desc = '\n'.join(ret)
264
+ if printing: print(desc)
265
+ else: return desc
266
+
267
+ def match(
268
+ self,
269
+ other: "Span",
270
+ match_label: bool = True,
271
+ depth: int = -1,
272
+ ignore_parent_boundary: bool = False,
273
+ ) -> int:
274
+ """
275
+ Used for evaluation. Count how many spans two trees share. Two spans are considered to be identical
276
+ if their boundary, label, and parent match.
277
+ :param other: The other tree to compare.
278
+ :param match_label: If False, will ignore label.
279
+ :param depth: If specified as non-negative, will only search thru certain depth.
280
+ :param ignore_parent_boundary: If True, two children can be matched ignoring parent boundaries.
281
+ :return: #spans two tree share.
282
+ """
283
+ if depth == 0:
284
+ return 0
285
+ if self.label != other.label and match_label:
286
+ return 0
287
+ if self.boundary == other.boundary:
288
+ n_match = 1
289
+ elif ignore_parent_boundary:
290
+ # Parents fail, Children might match!
291
+ n_match = 0
292
+ else:
293
+ return 0
294
+
295
+ sub_matches = np.zeros([len(self), len(other)], dtype=np.int)
296
+ for self_idx, my_child in enumerate(self):
297
+ for other_idx, other_child in enumerate(other):
298
+ sub_matches[self_idx, other_idx] = my_child.match(
299
+ other_child, match_label, depth-1, ignore_parent_boundary
300
+ )
301
+ if not ignore_parent_boundary:
302
+ for m in [sub_matches, sub_matches.T]:
303
+ for line in m:
304
+ assert (line > 0).sum() <= 1
305
+ n_match += max_match(sub_matches)
306
+ return n_match
307
+
308
+ def to_json(self) -> dict:
309
+ """
310
+ To JSON dict format. See init.
311
+ """
312
+ ret = {
313
+ "label": self.label,
314
+ "span": list(self.boundary),
315
+ }
316
+ if self.confidence is not None:
317
+ ret['confidence'] = self.confidence
318
+ if self.is_parent:
319
+ children = list()
320
+ for child in self._children:
321
+ children.append(child.to_json())
322
+ ret['children'] = children
323
+ return ret
324
+
325
+ @classmethod
326
+ def from_json(cls, span_json: Union[list, dict]) -> "Span":
327
+ """
328
+ Load from JSON. See init.
329
+ """
330
+ if isinstance(span_json, dict):
331
+ span = Span(
332
+ span_json['span'][0], span_json['span'][1], span_json.get('label', None), 'children' in span_json,
333
+ confidence=span_json.get('confidence', None)
334
+ )
335
+ for child_dict in span_json.get('children', []):
336
+ span.add_child(Span.from_json(child_dict))
337
+ else:
338
+ spans = [Span.from_json(child) for child in span_json]
339
+ span = Span.virtual_root(spans)
340
+ return span
341
+
342
+ def map_ontology(
343
+ self,
344
+ ontology_mapping: Optional[dict] = None,
345
+ inplace: bool = True,
346
+ recursive: bool = True,
347
+ ) -> Optional["Span"]:
348
+ """
349
+ Map labels to other things, like another ontology of soft labels.
350
+ :param ontology_mapping: Mapping dict. The key should be labels, and values can be anything.
351
+ Labels not in the dict will not be deleted. So be careful.
352
+ :param inplace: Inplace?
353
+ :param recursive: Apply to all descendents if True.
354
+ :return: The mapped tree.
355
+ """
356
+ span = self if inplace else self.clone()
357
+ if ontology_mapping is None:
358
+ # Do nothing if mapping not provided.
359
+ return span
360
+
361
+ if recursive:
362
+ new_children = list()
363
+ for child in span:
364
+ new_child = child.map_ontology(ontology_mapping, False, True)
365
+ if new_child is not None:
366
+ new_children.append(new_child)
367
+ span._children = new_children
368
+
369
+ if span.label != VIRTUAL_ROOT:
370
+ if span.parent is not None and (span.parent.label, span.label) in ontology_mapping:
371
+ span.label = ontology_mapping[(span.parent.label, span.label)]
372
+ elif span.label in ontology_mapping:
373
+ span.label = ontology_mapping[span.label]
374
+ else:
375
+ return
376
+
377
+ return span
378
+
379
+ def isolate(self) -> "Span":
380
+ """
381
+ Generate a span that is identical to self but has no children or parent.
382
+ """
383
+ return Span(self.start_idx, self.end_idx, self.label, self.is_parent, None, self.confidence)
384
+
385
+ def remove_child(self, span: Optional["Span"] = None):
386
+ """
387
+ Remove a child. If pass None, will reset the children list.
388
+ """
389
+ if span is None:
390
+ self._children = list()
391
+ else:
392
+ del self._children[self._children.index(span)]
393
+
394
+
395
+ def re_index_span(
396
+ boundary: Tuple[int, int], offsets: List[Tuple[int, int]], reverse: bool = False
397
+ ) -> Tuple[int, int]:
398
+ """
399
+ Helper function.
400
+ """
401
+ if not reverse:
402
+ if boundary[0] == boundary[1] == -1:
403
+ # Virtual Root
404
+ start_idx = end_idx = 0
405
+ else:
406
+ start_idx = offsets[boundary[0]][0]
407
+ end_idx = offsets[boundary[1]][1]
408
+ else:
409
+ if boundary[0] == boundary[1] == 0:
410
+ # Virtual Root
411
+ start_idx = end_idx = -1
412
+ else:
413
+ start_within = [bo[0] <= boundary[0] <= bo[1] if bo is not None else False for bo in offsets]
414
+ end_within = [bo[0] <= boundary[1] <= bo[1] if bo is not None else False for bo in offsets]
415
+ assert sum(start_within) <= 1 and sum(end_within) <= 1
416
+ start_idx = start_within.index(True) if sum(start_within) == 1 else 0
417
+ end_idx = end_within.index(True) if sum(end_within) == 1 else len(offsets)
418
+ if start_idx > end_idx:
419
+ raise IndexError
420
+ return start_idx, end_idx
sftp/utils/span_utils.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+
3
+ import torch
4
+
5
+ from .span import Span
6
+
7
+
8
+ def _tensor2span_batch(
9
+ span_boundary: torch.Tensor,
10
+ span_labels: torch.Tensor,
11
+ parent_indices: torch.Tensor,
12
+ num_spans: torch.Tensor,
13
+ label_confidence: torch.Tensor,
14
+ idx2label: Dict[int, str],
15
+ label_ignore: List[int],
16
+ ) -> Span:
17
+ spans = list()
18
+ for (start_idx, end_idx), parent_idx, label, label_conf in \
19
+ list(zip(span_boundary, parent_indices, span_labels, label_confidence))[:int(num_spans)]:
20
+ if label not in label_ignore:
21
+ span = Span(int(start_idx), int(end_idx), idx2label[int(label)], True, confidence=float(label_conf))
22
+ if int(parent_idx) < len(spans):
23
+ spans[int(parent_idx)].add_child(span)
24
+ spans.append(span)
25
+ return spans[0]
26
+
27
+
28
+ def tensor2span(
29
+ span_boundary: torch.Tensor,
30
+ span_labels: torch.Tensor,
31
+ parent_indices: torch.Tensor,
32
+ num_spans: torch.Tensor,
33
+ label_confidence: torch.Tensor,
34
+ idx2label: Dict[int, str],
35
+ label_ignore: Optional[List[int]] = None,
36
+ ) -> List[Span]:
37
+ """
38
+ Generate spans in dict from vectors. Refer to the model part for the meaning of these variables.
39
+ If idx_ignore is provided, some labels will be ignored.
40
+ :return:
41
+ """
42
+ label_ignore = label_ignore or []
43
+ if span_boundary.device.type != 'cpu':
44
+ span_boundary = span_boundary.to(device='cpu')
45
+ parent_indices = parent_indices.to(device='cpu')
46
+ span_labels = span_labels.to(device='cpu')
47
+ num_spans = num_spans.to(device='cpu')
48
+ label_confidence = label_confidence.to(device='cpu')
49
+
50
+ ret = list()
51
+ for args in zip(
52
+ span_boundary.unbind(0), span_labels.unbind(0), parent_indices.unbind(0), num_spans.unbind(0),
53
+ label_confidence.unbind(0),
54
+ ):
55
+ ret.append(_tensor2span_batch(*args, label_ignore=label_ignore, idx2label=idx2label))
56
+
57
+ return ret
sociolome/combine_models.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List, Optional
2
+ import dataclasses
3
+ import glob
4
+ import os
5
+ import sys
6
+ import json
7
+
8
+ import spacy
9
+ from spacy.language import Language
10
+
11
+ from sftp import SpanPredictor
12
+
13
+
14
+ @dataclasses.dataclass
15
+ class FrameAnnotation:
16
+ tokens: List[str] = dataclasses.field(default_factory=list)
17
+ pos: List[str] = dataclasses.field(default_factory=list)
18
+
19
+
20
+ @dataclasses.dataclass
21
+ class MultiLabelAnnotation(FrameAnnotation):
22
+ frame_list: List[List[str]] = dataclasses.field(default_factory=list)
23
+ lu_list: List[Optional[str]] = dataclasses.field(default_factory=list)
24
+
25
+ def to_txt(self):
26
+ for i, tok in enumerate(self.tokens):
27
+ yield f"{tok} {self.pos[i]} {'|'.join(self.frame_list[i]) or '_'} {self.lu_list[i] or '_'}"
28
+
29
+
30
+ def convert_to_seq_labels(sentence: List[str], structures: Dict[int, Dict[str, Any]]) -> List[List[str]]:
31
+ labels = [[] for _ in sentence]
32
+
33
+ for struct_id, struct in structures.items():
34
+ tgt_span = struct["target"]
35
+ frame = struct["frame"]
36
+
37
+ for i in range(tgt_span[0], tgt_span[1] + 1):
38
+ labels[i].append(f"T:{frame}@{struct_id:02}")
39
+ for role in struct["roles"]:
40
+ role_span = role["boundary"]
41
+ role_label = role["label"]
42
+ for i in range(role_span[0], role_span[1] + 1):
43
+ prefix = "B" if i == role_span[0] else "I"
44
+ labels[i].append(f"{prefix}:{frame}:{role_label}@{struct_id:02}")
45
+ return labels
46
+
47
+
48
+ def predict_combined(
49
+ spacy_model: Language,
50
+ sentences: List[str],
51
+ tgt_predictor: SpanPredictor,
52
+ frm_predictor: SpanPredictor,
53
+ bnd_predictor: SpanPredictor,
54
+ arg_predictor: SpanPredictor,
55
+ ) -> List[MultiLabelAnnotation]:
56
+
57
+ annotations_out = []
58
+
59
+ for sent_idx, sent in enumerate(sentences):
60
+
61
+ sent = sent.strip()
62
+
63
+ print(f"Processing sent with idx={sent_idx}: {sent}")
64
+
65
+ doc = spacy_model(sent)
66
+ sent_tokens = [t.text for t in doc]
67
+
68
+ tgt_spans, _, _ = tgt_predictor.force_decode(sent_tokens)
69
+
70
+ frame_structures = {}
71
+
72
+ for i, span in enumerate(tgt_spans):
73
+ span = tuple(span)
74
+ _, fr_labels, _ = frm_predictor.force_decode(sent_tokens, child_spans=[span])
75
+ frame = fr_labels[0]
76
+ if frame == "@@VIRTUAL_ROOT@@@":
77
+ continue
78
+
79
+ boundaries, _, _ = bnd_predictor.force_decode(sent_tokens, parent_span=span, parent_label=frame)
80
+ _, arg_labels, _ = arg_predictor.force_decode(sent_tokens, parent_span=span, parent_label=frame, child_spans=boundaries)
81
+
82
+ frame_structures[i] = {
83
+ "target": span,
84
+ "frame": frame,
85
+ "roles": [
86
+ {"boundary": bnd, "label": label}
87
+ for bnd, label in zip(boundaries, arg_labels)
88
+ if label != "Target"
89
+ ]
90
+ }
91
+ annotations_out.append(MultiLabelAnnotation(
92
+ tokens=sent_tokens,
93
+ pos=[t.pos_ for t in doc],
94
+ frame_list=convert_to_seq_labels(sent_tokens, frame_structures),
95
+ lu_list=[None for _ in sent_tokens]
96
+ ))
97
+ return annotations_out
98
+
99
+
100
+ def main(input_folder):
101
+
102
+ print("Loading spaCy model ...")
103
+ nlp = spacy.load("it_core_news_md")
104
+
105
+ print("Loading predictors ...")
106
+ zs_predictor = SpanPredictor.from_path("/data/p289731/cloned/lome-models/models/spanfinder/model.mod.tar.gz", cuda_device=0)
107
+ ev_predictor = SpanPredictor.from_path("/scratch/p289731/lome-training-files/train-evalita-plus-fn-vanilla/model.tar.gz", cuda_device=0)
108
+
109
+
110
+ print("Reading input files ...")
111
+ for file in glob.glob(os.path.join(input_folder, "*.txt")):
112
+ print(file)
113
+ with open(file, encoding="utf-8") as f:
114
+ sentences = list(f)
115
+
116
+ annotations = predict_combined(nlp, sentences, zs_predictor, ev_predictor, ev_predictor, ev_predictor)
117
+
118
+ out_name = os.path.splitext(os.path.basename(file))[0]
119
+ with open(f"../../data-out/{out_name}.combined_zs_ev.tc_bilstm.txt", "w", encoding="utf-8") as f_out:
120
+ for ann in annotations:
121
+ for line in ann.to_txt():
122
+ f_out.write(line + os.linesep)
123
+ f_out.write(os.linesep)
124
+
125
+ with open(f"../../data-out/{out_name}.combined_zs_ev.tc_bilstm.json", "w", encoding="utf-8") as f_out:
126
+ json.dump([dataclasses.asdict(ann) for ann in annotations], f_out)
127
+
128
+
129
+ if __name__ == "__main__":
130
+ main(sys.argv[1])
sociolome/evalita_eval.py ADDED
@@ -0,0 +1,319 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from typing import List, Tuple
3
+
4
+ import pandas as pd
5
+
6
+ from sftp import SpanPredictor
7
+
8
+
9
+ def main():
10
+ # data_file = "/home/p289731/cloned/lome/preproc/evalita_jsonl/evalita_dev.jsonl"
11
+ # data_file = "/home/p289731/cloned/lome/preproc/svm_challenge.jsonl"
12
+ data_file = "/home/p289731/cloned/lome/preproc/evalita_jsonl/evalita_test.jsonl"
13
+ models = [
14
+ (
15
+ "lome-en",
16
+ "/data/p289731/cloned/lome-models/models/spanfinder/model.mod.tar.gz",
17
+ ),
18
+ (
19
+ "lome-it-best",
20
+ "/scratch/p289731/lome-training-files/train-evalita-plus-fn-vanilla/model.tar.gz",
21
+ ),
22
+ # (
23
+ # "lome-it-freeze",
24
+ # "/data/p289731/cloned/lome/train-evalita-plus-fn-freeze/model.tar.gz",
25
+ # ),
26
+ # (
27
+ # "lome-it-mono",
28
+ # "/data/p289731/cloned/lome/train-evalita-it_mono/model.tar.gz",
29
+ # ),
30
+ ]
31
+
32
+ for (model_name, model_path) in models:
33
+ print("testing model: ", model_name)
34
+ predictor = SpanPredictor.from_path(model_path)
35
+
36
+ print("=== FD (run 1) ===")
37
+ eval_frame_detection(data_file, predictor, model_name=model_name)
38
+
39
+ for run in [1, 2]:
40
+ print(f"=== BD (run {run}) ===")
41
+ eval_boundary_detection(data_file, predictor, run=run)
42
+
43
+ for run in [1, 2, 3]:
44
+ print(f"=== AC (run {run}) ===")
45
+ eval_argument_classification(data_file, predictor, run=run)
46
+
47
+
48
+ def predict_frame(
49
+ predictor: SpanPredictor, tokens: List[str], predicate_span: Tuple[int, int]
50
+ ):
51
+ _, labels, _ = predictor.force_decode(tokens, child_spans=[predicate_span])
52
+ return labels[0]
53
+
54
+
55
+ def eval_frame_detection(data_file, predictor, verbose=False, model_name="_"):
56
+
57
+ true_pos = 0
58
+ false_pos = 0
59
+
60
+ out = []
61
+
62
+ with open(data_file, encoding="utf-8") as f:
63
+ for sent_id, sent in enumerate(f):
64
+ sent_data = json.loads(sent)
65
+
66
+ tokens = sent_data["tokens"]
67
+ annotation = sent_data["annotations"][0]
68
+
69
+ predicate_span = tuple(annotation["span"])
70
+ predicate = tokens[predicate_span[0] : predicate_span[1] + 1]
71
+
72
+ frame_gold = annotation["label"]
73
+ frame_pred = predict_frame(predictor, tokens, predicate_span)
74
+
75
+ if frame_pred == frame_gold:
76
+ true_pos += 1
77
+ else:
78
+ false_pos += 1
79
+
80
+ out.append({
81
+ "sentence": " ".join(tokens),
82
+ "predicate": predicate,
83
+ "frame_gold": frame_gold,
84
+ "frame_pred": frame_pred
85
+ })
86
+
87
+ if verbose:
88
+ print(f"Sentence #{sent_id:03}: {' '.join(tokens)}")
89
+ print(f"\tpredicate: {predicate}")
90
+ print(f"\t gold: {frame_gold}")
91
+ print(f"\tpredicted: {frame_pred}")
92
+ print()
93
+
94
+ acc_score = true_pos / (true_pos + false_pos)
95
+ print("ACC =", acc_score)
96
+
97
+ data_sect = "rai" if "svm_challenge" in data_file else "dev" if "dev" in data_file else "test"
98
+
99
+ df_out = pd.DataFrame(out)
100
+ df_out.to_csv(f"frame_prediction_output_{model_name}_{data_sect}.csv")
101
+
102
+
103
+ def predict_boundaries(predictor: SpanPredictor, tokens, predicate_span, frame):
104
+ boundaries, labels, _ = predictor.force_decode(
105
+ tokens, parent_span=predicate_span, parent_label=frame
106
+ )
107
+ out = []
108
+ for bnd, lab in zip(boundaries, labels):
109
+ bnd = tuple(bnd)
110
+ if bnd == predicate_span and lab == "Target":
111
+ continue
112
+ out.append(bnd)
113
+ return out
114
+
115
+
116
+ def get_gold_boundaries(annotation, predicate_span):
117
+ return {
118
+ tuple(c["span"])
119
+ for c in annotation["children"]
120
+ if not (tuple(c["span"]) == predicate_span and c["label"] == "Target")
121
+ }
122
+
123
+
124
+ def eval_boundary_detection(data_file, predictor, run=1, verbose=False):
125
+
126
+ assert run in [1, 2]
127
+
128
+ true_pos = 0
129
+ false_pos = 0
130
+ false_neg = 0
131
+
132
+ true_pos_tok = 0
133
+ false_pos_tok = 0
134
+ false_neg_tok = 0
135
+
136
+ with open(data_file, encoding="utf-8") as f:
137
+ for sent_id, sent in enumerate(f):
138
+ sent_data = json.loads(sent)
139
+
140
+ tokens = sent_data["tokens"]
141
+ annotation = sent_data["annotations"][0]
142
+
143
+ predicate_span = tuple(annotation["span"])
144
+ predicate = tokens[predicate_span[0] : predicate_span[1] + 1]
145
+
146
+ if run == 1:
147
+ frame = predict_frame(predictor, tokens, predicate_span)
148
+ else:
149
+ frame = annotation["label"]
150
+
151
+ boundaries_gold = get_gold_boundaries(annotation, predicate_span)
152
+ boundaries_pred = set(
153
+ predict_boundaries(predictor, tokens, predicate_span, frame)
154
+ )
155
+
156
+ sent_true_pos = len(boundaries_gold & boundaries_pred)
157
+ sent_false_pos = len(boundaries_pred - boundaries_gold)
158
+ sent_false_neg = len(boundaries_gold - boundaries_pred)
159
+ true_pos += sent_true_pos
160
+ false_pos += sent_false_pos
161
+ false_neg += sent_false_neg
162
+
163
+ boundary_toks_gold = {
164
+ tok_idx
165
+ for (start, stop) in boundaries_gold
166
+ for tok_idx in range(start, stop + 1)
167
+ }
168
+ boundary_toks_pred = {
169
+ tok_idx
170
+ for (start, stop) in boundaries_pred
171
+ for tok_idx in range(start, stop + 1)
172
+ }
173
+ sent_tok_true_pos = len(boundary_toks_gold & boundary_toks_pred)
174
+ sent_tok_false_pos = len(boundary_toks_pred - boundary_toks_gold)
175
+ sent_tok_false_neg = len(boundary_toks_gold - boundary_toks_pred)
176
+ true_pos_tok += sent_tok_true_pos
177
+ false_pos_tok += sent_tok_false_pos
178
+ false_neg_tok += sent_tok_false_neg
179
+
180
+ if verbose:
181
+ print(f"Sentence #{sent_id:03}: {' '.join(tokens)}")
182
+ print(f"\tpredicate: {predicate}")
183
+ print(f"\t frame: {frame}")
184
+ print(f"\t gold: {boundaries_gold}")
185
+ print(f"\tpredicted: {boundaries_pred}")
186
+ print(f"\ttp={sent_true_pos}\tfp={sent_false_pos}\tfn={sent_false_neg}")
187
+ print(
188
+ f"\ttp_t={sent_tok_true_pos}\tfp_t={sent_tok_false_pos}\tfn_t={sent_tok_false_neg}"
189
+ )
190
+ print()
191
+
192
+ prec = true_pos / (true_pos + false_pos)
193
+ rec = true_pos / (true_pos + false_neg)
194
+ f1_score = 2 * ((prec * rec) / (prec + rec))
195
+
196
+ print(f"P/R/F=\n{prec}\t{rec}\t{f1_score}")
197
+
198
+ tok_prec = true_pos_tok / (true_pos_tok + false_pos_tok)
199
+ tok_rec = true_pos_tok / (true_pos_tok + false_neg_tok)
200
+ tok_f1 = 2 * ((tok_prec * tok_rec) / (tok_prec + tok_rec))
201
+
202
+ print(f"Pt/Rt/Ft=\n{tok_prec}\t{tok_rec}\t{tok_f1}")
203
+
204
+
205
+ def predict_arguments(
206
+ predictor: SpanPredictor, tokens, predicate_span, frame, boundaries
207
+ ):
208
+ boundaries = list(sorted(boundaries, key=lambda t: t[0]))
209
+ _, labels, _ = predictor.force_decode(
210
+ tokens, parent_span=predicate_span, parent_label=frame, child_spans=boundaries
211
+ )
212
+ out = []
213
+ for bnd, lab in zip(boundaries, labels):
214
+ if bnd == predicate_span and lab == "Target":
215
+ continue
216
+ out.append((bnd, lab))
217
+ return out
218
+
219
+
220
+ def eval_argument_classification(data_file, predictor, run=1, verbose=False):
221
+ assert run in [1, 2, 3]
222
+
223
+ true_pos = 0
224
+ false_pos = 0
225
+ false_neg = 0
226
+
227
+ true_pos_tok = 0
228
+ false_pos_tok = 0
229
+ false_neg_tok = 0
230
+
231
+ with open(data_file, encoding="utf-8") as f:
232
+ for sent_id, sent in enumerate(f):
233
+ sent_data = json.loads(sent)
234
+
235
+ tokens = sent_data["tokens"]
236
+ annotation = sent_data["annotations"][0]
237
+
238
+ predicate_span = tuple(annotation["span"])
239
+ predicate = tokens[predicate_span[0] : predicate_span[1] + 1]
240
+
241
+ # gold or predicted frames?
242
+ if run == 1:
243
+ frame = predict_frame(predictor, tokens, predicate_span)
244
+ else:
245
+ frame = annotation["label"]
246
+
247
+ # gold or predicted argument boundaries?
248
+ if run in [1, 2]:
249
+ boundaries = set(
250
+ predict_boundaries(predictor, tokens, predicate_span, frame)
251
+ )
252
+ else:
253
+ boundaries = get_gold_boundaries(annotation, predicate_span)
254
+
255
+ pred_arguments = predict_arguments(
256
+ predictor, tokens, predicate_span, frame, boundaries
257
+ )
258
+ gold_arguments = {
259
+ (tuple(c["span"]), c["label"])
260
+ for c in annotation["children"]
261
+ if not (tuple(c["span"]) == predicate_span and c["label"] == "Target")
262
+ }
263
+
264
+ if verbose:
265
+ print(f"Sentence #{sent_id:03}: {' '.join(tokens)}")
266
+ print(f"\tpredicate: {predicate}")
267
+ print(f"\t frame: {frame}")
268
+ print(f"\t gold: {gold_arguments}")
269
+ print(f"\tpredicted: {pred_arguments}")
270
+ print()
271
+
272
+ # -- full spans version
273
+ for g_bnd, g_label in gold_arguments:
274
+ # true positive: found the span and labeled it correctly
275
+ if (g_bnd, g_label) in pred_arguments:
276
+ true_pos += 1
277
+ # false negative: missed this argument
278
+ else:
279
+ false_neg += 1
280
+ for p_bnd, p_label in pred_arguments:
281
+ # all predictions that are not true positives are false positives
282
+ if (p_bnd, p_label) not in gold_arguments:
283
+ false_pos += 1
284
+
285
+ # -- token based
286
+ tok_gold_labels = {
287
+ (token, label)
288
+ for ((bnd_start, bnd_end), label) in gold_arguments
289
+ for token in range(bnd_start, bnd_end + 1)
290
+ }
291
+ tok_pred_labels = {
292
+ (token, label)
293
+ for ((bnd_start, bnd_end), label) in pred_arguments
294
+ for token in range(bnd_start, bnd_end + 1)
295
+ }
296
+ for g_tok, g_tok_label in tok_gold_labels:
297
+ if (g_tok, g_tok_label) in tok_pred_labels:
298
+ true_pos_tok += 1
299
+ else:
300
+ false_neg_tok += 1
301
+ for p_tok, p_tok_label in tok_pred_labels:
302
+ if (p_tok, p_tok_label) not in tok_gold_labels:
303
+ false_pos_tok += 1
304
+
305
+ prec = true_pos / (true_pos + false_pos)
306
+ rec = true_pos / (true_pos + false_neg)
307
+ f1_score = 2 * ((prec * rec) / (prec + rec))
308
+
309
+ print(f"P/R/F=\n{prec}\t{rec}\t{f1_score}")
310
+
311
+ tok_prec = true_pos_tok / (true_pos_tok + false_pos_tok)
312
+ tok_rec = true_pos_tok / (true_pos_tok + false_neg_tok)
313
+ tok_f1 = 2 * ((tok_prec * tok_rec) / (tok_prec + tok_rec))
314
+
315
+ print(f"Pt/Rt/Ft=\n{tok_prec}\t{tok_rec}\t{tok_f1}")
316
+
317
+
318
+ if __name__ == "__main__":
319
+ main()
sociolome/lome_wrapper.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sftp import SpanPredictor
2
+ import spacy
3
+
4
+ import sys
5
+ import dataclasses
6
+ from typing import List, Optional, Dict, Any
7
+
8
+
9
+ predictor = SpanPredictor.from_path("model.mod.tar.gz")
10
+ nlp = spacy.load("xx_sent_ud_sm")
11
+
12
+
13
+ @dataclasses.dataclass
14
+ class FrameAnnotation:
15
+ tokens: List[str] = dataclasses.field(default_factory=list)
16
+ pos: List[str] = dataclasses.field(default_factory=list)
17
+
18
+
19
+ @dataclasses.dataclass
20
+ class MultiLabelAnnotation(FrameAnnotation):
21
+ frame_list: List[List[str]] = dataclasses.field(default_factory=list)
22
+ lu_list: List[Optional[str]] = dataclasses.field(default_factory=list)
23
+
24
+ def to_txt(self):
25
+ for i, tok in enumerate(self.tokens):
26
+ yield f"{tok} {self.pos[i]} {'|'.join(self.frame_list[i]) or '_'} {self.lu_list[i] or '_'}"
27
+
28
+
29
+ # reused from "combine_predictions.py" (cloned/lome/src/spanfinder/sociolome)
30
+ def convert_to_seq_labels(sentence: List[str], structures: Dict[int, Dict[str, Any]]) -> List[List[str]]:
31
+ labels = [[] for _ in sentence]
32
+
33
+ for struct_id, struct in structures.items():
34
+ tgt_span = struct["target"]
35
+ frame = struct["frame"]
36
+
37
+ for i in range(tgt_span[0], tgt_span[1] + 1):
38
+ labels[i].append(f"T:{frame}@{struct_id:02}")
39
+ for role in struct["roles"]:
40
+ role_span = role["boundary"]
41
+ role_label = role["label"]
42
+ for i in range(role_span[0], role_span[1] + 1):
43
+ prefix = "B" if i == role_span[0] else "I"
44
+ labels[i].append(f"{prefix}:{frame}:{role_label}@{struct_id:02}")
45
+ return labels
46
+
47
+ def make_prediction(sentence, spacy_model, predictor):
48
+ spacy_doc = spacy_model(sentence)
49
+ tokens = [t.text for t in spacy_doc]
50
+ tgt_spans, fr_labels, _ = predictor.force_decode(tokens)
51
+
52
+ frame_structures = {}
53
+
54
+ for i, (tgt, frm) in enumerate(sorted(zip(tgt_spans, fr_labels), key=lambda t: t[0][0])):
55
+ arg_spans, arg_labels, _ = predictor.force_decode(tokens, parent_span=tgt, parent_label=frm)
56
+
57
+ frame_structures[i] = {
58
+ "target": tgt,
59
+ "frame": frm,
60
+ "roles": [
61
+ {"boundary": bnd, "label": label}
62
+ for bnd, label in zip(arg_spans, arg_labels)
63
+ if label != "Target"
64
+ ]
65
+ }
66
+
67
+ return MultiLabelAnnotation(
68
+ tokens=tokens,
69
+ pos=[t.pos_ for t in spacy_doc],
70
+ frame_list=convert_to_seq_labels(tokens, frame_structures),
71
+ lu_list=[None for _ in tokens]
72
+ )
73
+
74
+
75
+ def analyze(text):
76
+ analyses = []
77
+ for sentence in text.split("\n"):
78
+ analyses.append(make_prediction(sentence, nlp, predictor))
79
+
80
+ return {
81
+ "result": "OK",
82
+ "analyses": [dataclasses.asdict(an) for an in analyses]
83
+ }