Spaces:
Build error
Build error
gossminn
committed on
Commit
·
6680682
0
Parent(s):
First version
Browse files- .gitattributes +2 -0
- .gitignore +210 -0
- deploy.py +3 -0
- fillmorle/app.py +524 -0
- model.mod.tar.gz +3 -0
- requirements.txt +19 -0
- setup.py +9 -0
- sftp/__init__.py +10 -0
- sftp/data_reader/__init__.py +6 -0
- sftp/data_reader/batch_sampler/__init__.py +1 -0
- sftp/data_reader/batch_sampler/mix_sampler.py +50 -0
- sftp/data_reader/better_reader.py +286 -0
- sftp/data_reader/concrete_reader.py +44 -0
- sftp/data_reader/concrete_srl.py +169 -0
- sftp/data_reader/span_reader.py +197 -0
- sftp/data_reader/srl_reader.py +107 -0
- sftp/metrics/__init__.py +4 -0
- sftp/metrics/base_f.py +27 -0
- sftp/metrics/exact_match.py +29 -0
- sftp/metrics/fbeta_mix_measure.py +34 -0
- sftp/metrics/srl_metrics.py +138 -0
- sftp/models/__init__.py +1 -0
- sftp/models/span_model.py +362 -0
- sftp/modules/__init__.py +4 -0
- sftp/modules/smooth_crf.py +77 -0
- sftp/modules/span_extractor/__init__.py +1 -0
- sftp/modules/span_extractor/combo.py +36 -0
- sftp/modules/span_finder/__init__.py +2 -0
- sftp/modules/span_finder/bio_span_finder.py +216 -0
- sftp/modules/span_finder/span_finder.py +87 -0
- sftp/modules/span_typing/__init__.py +2 -0
- sftp/modules/span_typing/mlp_span_typing.py +99 -0
- sftp/modules/span_typing/span_typing.py +64 -0
- sftp/predictor/__init__.py +1 -0
- sftp/predictor/span_predictor.orig.py +362 -0
- sftp/predictor/span_predictor.py +401 -0
- sftp/training/__init__.py +0 -0
- sftp/training/transformer_optimizer.py +121 -0
- sftp/utils/__init__.py +7 -0
- sftp/utils/bio_smoothing.py +62 -0
- sftp/utils/common.py +3 -0
- sftp/utils/db_storage.py +87 -0
- sftp/utils/functions.py +75 -0
- sftp/utils/label_smoothing.py +48 -0
- sftp/utils/span.py +420 -0
- sftp/utils/span_utils.py +57 -0
- sociolome/combine_models.py +130 -0
- sociolome/evalita_eval.py +319 -0
- sociolome/lome_wrapper.py +83 -0
.gitattributes
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
spanfinder/model.mod.tar.gz filter=lfs diff=lfs merge=lfs -text
|
2 |
+
model.mod.tar.gz filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Created by .ignore support plugin (hsz.mobi)
|
2 |
+
### JetBrains template
|
3 |
+
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
|
4 |
+
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
|
5 |
+
|
6 |
+
# User-specific stuff
|
7 |
+
.idea/
|
8 |
+
data
|
9 |
+
cache
|
10 |
+
|
11 |
+
# Gradle and Maven with auto-import
|
12 |
+
# When using Gradle or Maven with auto-import, you should exclude module files,
|
13 |
+
# since they will be recreated, and may cause churn. Uncomment if using
|
14 |
+
# auto-import.
|
15 |
+
# .idea/artifacts
|
16 |
+
# .idea/compiler.xml
|
17 |
+
# .idea/jarRepositories.xml
|
18 |
+
# .idea/modules.xml
|
19 |
+
# .idea/*.iml
|
20 |
+
# .idea/modules
|
21 |
+
# *.iml
|
22 |
+
# *.ipr
|
23 |
+
|
24 |
+
# CMake
|
25 |
+
cmake-build-*/
|
26 |
+
|
27 |
+
# Mongo Explorer plugin
|
28 |
+
.idea/**/mongoSettings.xml
|
29 |
+
|
30 |
+
# File-based project format
|
31 |
+
*.iws
|
32 |
+
|
33 |
+
# IntelliJ
|
34 |
+
out/
|
35 |
+
|
36 |
+
# mpeltonen/sbt-idea plugin
|
37 |
+
.idea_modules/
|
38 |
+
|
39 |
+
# JIRA plugin
|
40 |
+
atlassian-ide-plugin.xml
|
41 |
+
|
42 |
+
# Cursive Clojure plugin
|
43 |
+
.idea/replstate.xml
|
44 |
+
|
45 |
+
# Crashlytics plugin (for Android Studio and IntelliJ)
|
46 |
+
com_crashlytics_export_strings.xml
|
47 |
+
crashlytics.properties
|
48 |
+
crashlytics-build.properties
|
49 |
+
fabric.properties
|
50 |
+
|
51 |
+
# Editor-based Rest Client
|
52 |
+
.idea/httpRequests
|
53 |
+
|
54 |
+
# Android studio 3.1+ serialized cache file
|
55 |
+
.idea/caches/build_file_checksums.ser
|
56 |
+
|
57 |
+
### JupyterNotebooks template
|
58 |
+
# gitignore template for Jupyter Notebooks
|
59 |
+
# website: http://jupyter.org/
|
60 |
+
|
61 |
+
.ipynb_checkpoints
|
62 |
+
*/.ipynb_checkpoints/*
|
63 |
+
|
64 |
+
# IPython
|
65 |
+
profile_default/
|
66 |
+
ipython_config.py
|
67 |
+
|
68 |
+
# Remove previous ipynb_checkpoints
|
69 |
+
# git rm -r .ipynb_checkpoints/
|
70 |
+
|
71 |
+
### Python template
|
72 |
+
# Byte-compiled / optimized / DLL files
|
73 |
+
__pycache__/
|
74 |
+
*.py[cod]
|
75 |
+
*.class
|
76 |
+
|
77 |
+
# C extensions
|
78 |
+
*.so
|
79 |
+
|
80 |
+
# Distribution / packaging
|
81 |
+
.Python
|
82 |
+
build/
|
83 |
+
develop-eggs/
|
84 |
+
dist/
|
85 |
+
downloads/
|
86 |
+
eggs/
|
87 |
+
.eggs/
|
88 |
+
lib/
|
89 |
+
lib64/
|
90 |
+
parts/
|
91 |
+
sdist/
|
92 |
+
var/
|
93 |
+
wheels/
|
94 |
+
share/python-wheels/
|
95 |
+
*.egg-info/
|
96 |
+
.installed.cfg
|
97 |
+
*.egg
|
98 |
+
MANIFEST
|
99 |
+
|
100 |
+
# PyInstaller
|
101 |
+
# Usually these files are written by a python script from a template
|
102 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
103 |
+
*.manifest
|
104 |
+
*.spec
|
105 |
+
|
106 |
+
# Installer logs
|
107 |
+
pip-log.txt
|
108 |
+
pip-delete-this-directory.txt
|
109 |
+
|
110 |
+
# Unit test / coverage reports
|
111 |
+
htmlcov/
|
112 |
+
.tox/
|
113 |
+
.nox/
|
114 |
+
.coverage
|
115 |
+
.coverage.*
|
116 |
+
.cache
|
117 |
+
nosetests.xml
|
118 |
+
coverage.xml
|
119 |
+
*.cover
|
120 |
+
*.py,cover
|
121 |
+
.hypothesis/
|
122 |
+
.pytest_cache/
|
123 |
+
cover/
|
124 |
+
|
125 |
+
# Translations
|
126 |
+
*.mo
|
127 |
+
*.pot
|
128 |
+
|
129 |
+
# Django stuff:
|
130 |
+
*.log
|
131 |
+
local_settings.py
|
132 |
+
db.sqlite3
|
133 |
+
db.sqlite3-journal
|
134 |
+
|
135 |
+
# Flask stuff:
|
136 |
+
instance/
|
137 |
+
.webassets-cache
|
138 |
+
|
139 |
+
# Scrapy stuff:
|
140 |
+
.scrapy
|
141 |
+
|
142 |
+
# Sphinx documentation
|
143 |
+
docs/_build/
|
144 |
+
|
145 |
+
# PyBuilder
|
146 |
+
.pybuilder/
|
147 |
+
target/
|
148 |
+
|
149 |
+
# Jupyter Notebook
|
150 |
+
.ipynb_checkpoints
|
151 |
+
|
152 |
+
# IPython
|
153 |
+
profile_default/
|
154 |
+
ipython_config.py
|
155 |
+
|
156 |
+
# pyenv
|
157 |
+
# For a library or package, you might want to ignore these files since the code is
|
158 |
+
# intended to run in multiple environments; otherwise, check them in:
|
159 |
+
# .python-version
|
160 |
+
|
161 |
+
# pipenv
|
162 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
163 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
164 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
165 |
+
# install all needed dependencies.
|
166 |
+
#Pipfile.lock
|
167 |
+
|
168 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
169 |
+
__pypackages__/
|
170 |
+
|
171 |
+
# Celery stuff
|
172 |
+
celerybeat-schedule
|
173 |
+
celerybeat.pid
|
174 |
+
|
175 |
+
# SageMath parsed files
|
176 |
+
*.sage.py
|
177 |
+
|
178 |
+
# Environments
|
179 |
+
.env
|
180 |
+
.venv
|
181 |
+
env/
|
182 |
+
venv/
|
183 |
+
ENV/
|
184 |
+
env.bak/
|
185 |
+
venv.bak/
|
186 |
+
|
187 |
+
# Spyder project settings
|
188 |
+
.spyderproject
|
189 |
+
.spyproject
|
190 |
+
|
191 |
+
# Rope project settings
|
192 |
+
.ropeproject
|
193 |
+
|
194 |
+
# mkdocs documentation
|
195 |
+
/site
|
196 |
+
|
197 |
+
# mypy
|
198 |
+
.mypy_cache/
|
199 |
+
.dmypy.json
|
200 |
+
dmypy.json
|
201 |
+
|
202 |
+
# Pyre type checker
|
203 |
+
.pyre/
|
204 |
+
|
205 |
+
# pytype static type analyzer
|
206 |
+
.pytype/
|
207 |
+
|
208 |
+
# Cython debug symbols
|
209 |
+
cython_debug/
|
210 |
+
|
deploy.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
"""Deployment entry point: execute the FillmorLe app module as a script."""
import runpy

# Run ``fillmorle.app`` as if it were invoked as ``__main__``; ``alter_sys``
# lets runpy update ``sys.argv[0]`` and ``sys.modules`` for the duration.
runpy.run_module(mod_name="fillmorle.app", run_name="__main__", alter_sys=True)
|
fillmorle/app.py
ADDED
@@ -0,0 +1,524 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from itertools import product
|
2 |
+
import random
|
3 |
+
from turtle import hideturtle
|
4 |
+
import requests
|
5 |
+
import json
|
6 |
+
import lxml.etree as ET
|
7 |
+
|
8 |
+
import gensim
|
9 |
+
import pandas as pd
|
10 |
+
|
11 |
+
import nltk
|
12 |
+
# from nltk.corpus import framenet as fn
|
13 |
+
# --- circumvent threading issues with FrameNet
# Instead of the shared ``nltk.corpus.framenet`` object, build a dedicated
# corpus reader from the located FrameNet 1.7 data directory.
fn_root = nltk.data.find("corpora/framenet_v17")
print(fn_root)
fn_files = [
    "frRelation.xml",
    "frameIndex.xml",
    "fulltextIndex.xml",
    "luIndex.xml",
    "semTypes.xml",
]
fn = nltk.corpus.reader.framenet.FramenetCorpusReader(fn_root, fn_files)
# ---
|
19 |
+
|
20 |
+
import streamlit as st
|
21 |
+
|
22 |
+
from sociolome import lome_wrapper
|
23 |
+
|
24 |
+
|
25 |
+
def similarity(gensim_m, frame_1, frame_2):
    """Return the embedding similarity of two frames, or None if one is unknown.

    Both frames are looked up in ``gensim_m`` under the key ``fn_<frame>``.
    The value returned is ``1 - gensim_m.distance(key_1, key_2)``.
    """
    key_1, key_2 = f"fn_{frame_1}", f"fn_{frame_2}"
    if key_1 in gensim_m and key_2 in gensim_m:
        return 1 - gensim_m.distance(key_1, key_2)
    return None
|
29 |
+
|
30 |
+
|
31 |
+
def rank(gensim_m, frame_1, frame_2, topn=1200):
    """Return the position of ``frame_2`` among the neighbours of ``frame_1``.

    Args:
        gensim_m: embedding model supporting ``in`` and ``most_similar``.
        frame_1: frame whose nearest neighbours are listed.
        frame_2: frame to locate in that neighbour list.
        topn: how many neighbours to request (defaults to the original 1200).

    Returns:
        0 when both names are equal, the 1-based neighbour position when
        ``frame_2`` is found, and -1 when it is not found or when ``frame_1``
        has no embedding.
    """
    key_1 = f"fn_{frame_1}"
    key_2 = f"fn_{frame_2}"

    if key_1 == key_2:
        return 0

    # ``most_similar`` raises KeyError for unknown keys; report "not ranked"
    # instead, mirroring how ``similarity`` returns None for unknown frames.
    if key_1 not in gensim_m:
        return -1

    neighbours = gensim_m.most_similar(key_1, topn=topn)
    for position, (word, _) in enumerate(neighbours, start=1):
        if word == key_2:
            return position
    return -1
|
42 |
+
|
43 |
+
|
44 |
+
def format_frame_description(frame_def_xml):
    """Flatten a frame-definition XML element into plain text.

    The root text and the text/tail of each child are concatenated in
    document order. Text of ``fen`` children is upper-cased. Processing stops
    at the first ``ex`` child (its text and tail are dropped). Finally the
    words "frames"/"frame" are replaced with "stories"/"story".

    Args:
        frame_def_xml: an ElementTree-style element (``text``, ``tag``,
            ``tail`` attributes, iterable over children).

    Returns:
        The flattened description string.
    """
    parts = [frame_def_xml.text] if frame_def_xml.text else []
    for elem in frame_def_xml:
        if elem.tag == "ex":
            break
        # Guard on ``elem.text`` for every tag: an empty <fen/> has text None,
        # and calling ``.upper()`` on it would raise AttributeError.
        if elem.text:
            parts.append(elem.text.upper() if elem.tag == "fen" else elem.text)
        if elem.tail:
            parts.append(elem.tail)
    return "".join(parts).replace("frames", "stories").replace("frame", "story")
|
56 |
+
|
57 |
+
|
58 |
+
def get_frame_definition(frame_info):
    """Return a plain-text description built from ``frame_info.definitionMarkup``.

    The first sentence of the markup is tried first, with a closing
    ``</def-root>`` appended so the truncated markup can parse. If that is not
    well-formed XML, the complete markup is parsed instead.
    """
    markup = frame_info.definitionMarkup
    try:
        # try extracting just the first sentence
        first_sentence = nltk.sent_tokenize(markup)[0]
        root = ET.fromstring(first_sentence + "</def-root>")
    except ET.XMLSyntaxError:
        # otherwise, use the full definition
        root = ET.fromstring(markup)
    return format_frame_description(root)
|
67 |
+
|
68 |
+
|
69 |
+
def get_random_example(frame_info):
    """Pick a random exemplar sentence (longer than 30 characters) for a frame.

    Args:
        frame_info: frame record exposing ``FE`` (name -> info with
            ``coreType``) and ``lexUnit`` (name -> info with ``exemplars``).

    Returns:
        A dict with keys ``text``, ``target_lu``, ``target_idx`` and
        ``core_fes``, or None when the frame has no suitable exemplar.
    """
    # Compute the core FE names once instead of once per annotated role.
    core_fe_names = {
        fe for fe, fe_info in frame_info.FE.items() if fe_info.coreType == "Core"
    }

    # NOTE(review): the unpacking order (role, start, end) is kept from the
    # original; NLTK's FrameNet reader may yield (start, end, name) triples
    # instead, in which case ``core_fes`` would always be empty -- confirm.
    exemplars = [
        {
            "text": exemplar.text,
            "target_lu": lu_name,
            "target_idx": list(exemplar["Target"][0]),
            "core_fes": {
                role: exemplar.text[start_idx:end_idx]
                for role, start_idx, end_idx in exemplar.FE[0]
                if role in core_fe_names
            },
        }
        for lu_name, lu_info in frame_info["lexUnit"].items()
        for exemplar in lu_info.exemplars
        if len(exemplar.text) > 30
    ]
    return random.choice(exemplars) if exemplars else None
|
87 |
+
|
88 |
+
def make_hint(gensim_m, target, current_closest):
    """Find an example sentence for a frame one step closer to ``target``.

    The neighbours of ``target`` are listed (top 1200), and the list is walked
    from the position just before ``current_closest`` towards the most similar
    neighbour. The first frame with more than 10 lexical units and a usable
    exemplar is returned.

    Args:
        gensim_m: embedding model supporting ``most_similar``.
        target: name of the secret frame.
        current_closest: name of the player's best guess so far.

    Returns:
        ``(neighbour_key, exemplar)`` where ``neighbour_key`` keeps its
        ``fn_`` prefix, or None when the guess already equals the target or no
        closer frame has a usable exemplar.
    """
    if target == current_closest:
        return None

    most_similar = gensim_m.most_similar(f"fn_{target}", topn=1200)
    neighbour_keys = [word for word, _ in most_similar]
    try:
        current_position = neighbour_keys.index(f"fn_{current_closest}")
    except ValueError:
        # The current guess is not among the listed neighbours (for example a
        # "far away" guess), so start walking from the far end of the list
        # instead of crashing.
        current_position = len(most_similar)

    while current_position > 0:
        next_closest, _ = most_similar[current_position - 1]
        info = fn.frame(next_closest.replace("fn_", ""))
        if len(info.lexUnit) > 10:
            exemplar = get_random_example(info)
            if exemplar:
                return next_closest, exemplar
        current_position -= 1

    return None
|
106 |
+
|
107 |
+
|
108 |
+
def get_typical_exemplar(frame_info):
    """Return the most "typical" exemplar sentence of a frame, or None.

    Typical means as short as possible with as many core frame elements as
    possible: each exemplar is scored ``len(text) - 25 * len(core_fes)`` and
    the lowest score wins (the first one on ties).

    Args:
        frame_info: frame record exposing ``FE`` (name -> info with
            ``coreType``) and ``lexUnit`` (name -> info with ``exemplars``).

    Returns:
        A dict with keys ``text``, ``target_lu``, ``target_idx`` and
        ``core_fes``, or None when the frame has no exemplars.
    """
    # Compute the core FE names once instead of once per annotated role.
    core_fe_names = {
        fe for fe, fe_info in frame_info.FE.items() if fe_info.coreType == "Core"
    }

    # NOTE(review): the unpacking order (role, start, end) is kept from the
    # original; NLTK's FrameNet reader may yield (start, end, name) triples
    # instead, in which case ``core_fes`` would always be empty -- confirm.
    exemplars = [
        {
            "text": exemplar.text,
            "target_lu": lu_name,
            "target_idx": list(exemplar["Target"][0]),
            "core_fes": {
                role: exemplar.text[start_idx:end_idx]
                for role, start_idx, end_idx in exemplar.FE[0]
                if role in core_fe_names
            },
        }
        for lu_name, lu_info in frame_info["lexUnit"].items()
        for exemplar in lu_info.exemplars
    ]

    return min(
        exemplars,
        key=lambda exa: len(exa["text"]) - 25 * len(exa["core_fes"]),
        default=None,
    )
|
131 |
+
|
132 |
+
|
133 |
+
def find_all_inheriting_frames(frame_name):
    """Return the names of all frames inheriting from ``frame_name``.

    Both direct and transitive inheritors are included, and each name appears
    exactly once. The original implementation extended the result list while
    iterating over it, which re-expanded descendants and produced duplicates.

    Args:
        frame_name: name of the ancestor frame.

    Returns:
        A list of frame names, direct children first, then deeper levels.
    """
    inheritors = []
    seen = {frame_name}
    frontier = [frame_name]
    while frontier:
        next_frontier = []
        for parent in frontier:
            for rel in fn.frame(parent).frameRelations:
                # Only keep relations in which ``parent`` is the super frame.
                if rel.type.name != "Inheritance" or rel.superFrame.name != parent:
                    continue
                child = rel.subFrame.name
                if child not in seen:
                    seen.add(child)
                    inheritors.append(child)
                    next_frontier.append(child)
        frontier = next_frontier
    return inheritors
|
140 |
+
|
141 |
+
|
142 |
+
def has_enough_lus(frame, n=10):
    """Tell whether the named frame has more than ``n`` lexical units."""
    lexical_units = fn.frame(frame).lexUnit
    return n < len(lexical_units)
|
144 |
+
|
145 |
+
|
146 |
+
def choose_secret_frames():
    """Randomly pick the secret (event frame, entity frame) pair.

    Candidates are the frames inheriting from "Event" and from "Entity" that
    pass ``has_enough_lus``. Each side is drawn independently, which gives the
    same uniform distribution over pairs as choosing from the full cross
    product without building that list.

    Returns:
        A ``(event_frame_name, entity_frame_name)`` tuple.

    Raises:
        IndexError: if either candidate list is empty.
    """
    event_frames = [frm for frm in find_all_inheriting_frames("Event") if has_enough_lus(frm)]
    entity_frames = [frm for frm in find_all_inheriting_frames("Entity") if has_enough_lus(frm)]
    return random.choice(event_frames), random.choice(entity_frames)
|
150 |
+
|
151 |
+
|
152 |
+
def get_frame_info(frames):
    """Collect ``(frame name, frame info, typical exemplar)`` for each frame.

    Frames whose lookup raises ``FileNotFoundError`` are silently left out.
    """
    collected = []
    for name in frames:
        try:
            info = fn.frame(name)
            example = get_typical_exemplar(info)
        except FileNotFoundError:
            continue
        collected.append((name, info, example))
    return collected
|
162 |
+
|
163 |
+
|
164 |
+
def get_frame_feedback(frames_and_info, gensim_m, secret_event, secret_entity):
    """Build one feedback record per evoked frame.

    Each record holds the frame name, its similarity (scaled by 100, or None
    when unavailable) and rank (or "far away" when the rank is -1) relative to
    both secret frames, up to five lexical unit names, and the text of the
    typical sentence (or None).
    """

    def _as_percentage(score):
        # Falsy scores (None, and also an exact 0) are reported as None.
        return score * 100 if score else None

    def _rank_label(position):
        return "far away" if position == -1 else position

    records = []
    for evoked_frame, frame_info, typical_sentence in frames_and_info:
        first_lexunits = list(frame_info.lexUnit.keys())[:5]
        score_event = similarity(gensim_m, secret_event, evoked_frame)
        rank_event = rank(gensim_m, secret_event, evoked_frame)
        score_entity = similarity(gensim_m, secret_entity, evoked_frame)
        rank_entity = rank(gensim_m, secret_entity, evoked_frame)
        sentence_text = typical_sentence['text'] if typical_sentence else None

        records.append({
            "frame": evoked_frame,
            "similarity_1": _as_percentage(score_event),
            "rank_1": _rank_label(rank_event),
            "similarity_2": _as_percentage(score_entity),
            "rank_2": _rank_label(rank_entity),
            "typical_words": first_lexunits,
            "typical_sentence": sentence_text,
        })
    return records
|
187 |
+
|
188 |
+
|
189 |
+
def run_game_cli(debug=True):
    """Play FillmorLe interactively on the command line.

    Two secret frames are chosen, then the player repeatedly enters sentences
    (or "HINT"). Each sentence is analysed by the parser service on
    ``http://127.0.0.1:9090/analyze`` and the evoked frames are compared with
    the secret frames. The loop ends once one sentence evokes both of them.

    Args:
        debug: when true, print the secret frames before the game starts.
    """
    secret_event, secret_entity = choose_secret_frames()

    if debug:
        print(f"Shhhhhh you're not supposed to know, but the secret frames are {secret_event} and {secret_entity}")
        print("--------\n\n\n\n")

    print("Welcome to FillmorLe!")
    print("Words are not just words: behind every word, a story is hidden that appears in our imagination when we hear the word.")
    print()
    print("In this game, your job is to activate TWO SECRET STORIES by writing sentences.")
    print("There will be new secret stories every day -- the first story is always about an EVENT (something that happens in the world) and the second one about an ENTITY (a thing or concept).")
    print("Every time you write a sentence, I will tell you which stories are hidden below the surface, and how close these stories are to the secret stories.")
    print("Once you write a sentence that has both of the secret stories in it, you win. Good luck and be creative!")

    gensim_m = gensim.models.word2vec.KeyedVectors.load_word2vec_format("data/frame_embeddings.w2v.txt")

    num_guesses = 0
    guesses_event = []   # (frame name, similarity to the secret event)
    guesses_entity = []  # (frame name, similarity to the secret entity)

    while True:
        num_guesses += 1
        closest_to_event = sorted(guesses_event, key=lambda g: g[1], reverse=True)[:5]
        closest_to_entity = sorted(guesses_entity, key=lambda g: g[1], reverse=True)[:5]
        closest_to_event_txt = ", ".join([f"{frm.upper()} ({sim:.2f})" for frm, sim in closest_to_event])
        closest_to_entity_txt = ", ".join([f"{frm.upper()} ({sim:.2f})" for frm, sim in closest_to_entity])

        print()
        print(f"==== Guess #{num_guesses} ====")
        # Guesses are (frame, similarity) pairs, so compare against the frame
        # names rather than against the tuples themselves.
        if secret_event in {frm for frm, _ in guesses_event}:
            print("You already guessed SECRET STORY #1: ", secret_event.upper())
        elif closest_to_event:
            print("Best guesses (SECRET STORY #1):", closest_to_event_txt)

        if secret_entity in {frm for frm, _ in guesses_entity}:
            print("You already guessed SECRET STORY #2: ", secret_entity.upper())
        elif closest_to_entity:
            print("Best guesses (SECRET STORY #2):", closest_to_entity_txt)

        sentence = input("Enter a sentence or type 'HINT' if you're stuck >>>> ").strip()

        if sentence == "HINT":
            hint_target = None
            while not hint_target:
                hint_choice = input("For which story do you want a hint? Type '1' or '2' >>>> ").strip()
                if hint_choice == "1":
                    hint_target = secret_event
                    hint_current = closest_to_event[0][0] if closest_to_event else "Event"
                elif hint_choice == "2":
                    hint_target = secret_entity
                    hint_current = closest_to_entity[0][0] if closest_to_entity else "Entity"
                else:
                    print("Please type '1' or '2'.")

            if hint_current == hint_target:
                print("You don't need a hint for this story! Maybe you want a hint for the other one?")
                continue

            hint = make_hint(gensim_m, hint_target, hint_current)
            if hint is None:
                print("Sorry, you're already too close to give you a hint!")
            else:
                _, hint_example = hint
                hint_tgt_idx = hint_example["target_idx"]
                hint_example_redacted = hint_example["text"][:hint_tgt_idx[0]] + "******" + hint_example["text"][hint_tgt_idx[1]:]
                print(f"Your hint sentence is: «{hint_example_redacted}»")
                print(f"PRO TIP 1: the '******' hide a secret word. Guess the word and you will find a story that takes your one step closer to find SECRET STORY #{hint_choice}")
                print("PRO TIP 2: if you don't get the hint, just ask for a new one! You can do this as often as you want.")
            print("\n\n")
            continue

        r = requests.get("http://127.0.0.1:9090/analyze", params={"text": sentence})
        lome_data = json.loads(r.text)
        frames = set()
        for token_items in lome_data["analyses"][0]["frame_list"]:
            for item in token_items:
                if item.startswith("T:"):
                    evoked_frame = item.split("@")[0].replace("T:", "")
                    frames.add(evoked_frame)

        frames_and_info = get_frame_info(frames)
        # The original call omitted the model and the secret frames (TypeError).
        frame_feedback = get_frame_feedback(frames_and_info, gensim_m, secret_event, secret_entity)

        for i, feedback in enumerate(frame_feedback):

            print(f"STORY {i}: {feedback['frame'].upper()}")
            if feedback["typical_sentence"]:
                print(f"\ttypical context: «{feedback['typical_sentence']}»")
            print("\ttypical words:", ", ".join(feedback["typical_words"]), "...")

            # Record each guess under the frame this feedback belongs to (not
            # the leftover loop variable from parsing), and handle the two
            # similarities independently because either one may be None.
            sim_event = feedback["similarity_1"]
            sim_entity = feedback["similarity_2"]
            if sim_event:
                guesses_event.append((feedback["frame"], sim_event))
                print(f"\tsimilarity to SECRET STORY #1: {sim_event:.2f}")
            if sim_entity:
                guesses_entity.append((feedback["frame"], sim_entity))
                print(f"\tsimilarity to SECRET STORY #2: {sim_entity:.2f}")
            if not (sim_event or sim_entity):
                print("similarity: unknown")
            print()

        if not frames_and_info:
            print("I don't know any of the stories in your sentence. Try entering another sentence.")

        elif secret_event in frames and secret_entity in frames:
            print(f"YOU WIN! You made a sentence with both of the SECRET STORIES: {secret_event.upper()} and {secret_entity.upper()}.\nYou won the game in {num_guesses} guesses, great job!")
            break

        elif secret_event in frames:
            print(f"Great, you guessed SECRET STORY #1! It was {secret_event.upper()}!")
            print("To win, make a sentence with this story and SECRET STORY #2 hidden in it.")

        elif secret_entity in frames:
            print(f"Great, you guessed SECRET STORY #2! It was {secret_entity.upper()}!")
            print("To win, make a sentence with this story and SECRET STORY #1 hidden in it.")
|
303 |
+
|
304 |
+
|
305 |
+
# dummy version
|
306 |
+
# def analyze_sentence(sentence):
|
307 |
+
# return sentence.split()
|
308 |
+
|
309 |
+
def analyze_sentence(sentence):
    """Return the set of frame names evoked in ``sentence``.

    The sentence is analysed with ``lome_wrapper.analyze``; in the first
    analysis, every label starting with ``T:`` contributes the part before
    ``@`` with the ``T:`` marker removed.
    """
    analysis = lome_wrapper.analyze(sentence)
    frame_list = analysis["analyses"][0]["frame_list"]
    return {
        label.split("@")[0].replace("T:", "")
        for token_labels in frame_list
        for label in token_labels
        if label.startswith("T:")
    }
|
318 |
+
|
319 |
+
|
320 |
+
|
321 |
+
def make_frame_feedback_msg(frame_feedback):
    """Render feedback records as a bulleted, newline-joined text message.

    Every record yields a story line and a typical-words line, optionally a
    typical-context line, and then either both similarity lines or a single
    "similarity: unknown" line (decided by the truthiness of ``similarity_1``).
    """
    lines = []
    for idx, item in enumerate(frame_feedback):
        lines.append(f"* STORY {idx}: *{item['frame'].upper()}*")
        lines.append("\t* typical words: *{}* ...".format(" ".join(item["typical_words"])))

        context = item["typical_sentence"]
        if context:
            lines.append(f"\t* typical context: «{context}»")

        if not item["similarity_1"]:
            lines.append("\t* similarity: unknown")
            continue
        lines.append(f"\t* similarity to SECRET STORY #1: {item['similarity_1']:.2f}")
        lines.append(f"\t* similarity to SECRET STORY #2: {item['similarity_2']:.2f}")
    return "\n".join(lines)
|
335 |
+
|
336 |
+
|
337 |
+
def format_hint_sentence(hint_example):
    """Return the example text with its target span masked by ``******``.

    ``hint_example["target_idx"]`` gives the start and end offsets of the span
    to hide; the result is stripped of surrounding whitespace.
    """
    span = hint_example["target_idx"]
    before = hint_example["text"][:span[0]]
    after = hint_example["text"][span[1]:]
    return f"{before}******{after}".strip()
|
341 |
+
|
342 |
+
|
343 |
+
def play_turn():
    """Streamlit callback: process the sentence currently in the text input.

    Reads and clears ``st.session_state["cur_sentence"]``, then either reveals
    the secret frames ("show me the frames"), prepares hint sentences
    ("HINT"), or analyses the sentence and updates the guesses, the feedback
    and the game-over flags in the session state.
    """
    # remove text from input
    sentence = st.session_state["cur_sentence"]
    st.session_state["cur_sentence"] = ""

    # Ignore blank submissions: they should neither reach the parser nor
    # count as a guess.
    command = sentence.strip()
    if not command:
        return

    # get previous game state
    game_state = st.session_state["game_state"]
    secret_event, secret_entity = game_state["secret_event"], game_state["secret_entity"]
    guesses_event, guesses_entity = game_state["guesses_event"], game_state["guesses_entity"]

    # reset hints
    st.session_state["hints"] = [None, None]

    # reveal correct frames
    if command.lower() == "show me the frames":
        st.warning(f"The correct frames are: {secret_event.upper()} and {secret_entity.upper()}")

    # process hints
    elif command == "HINT":
        ranked_event = sorted(guesses_event, key=lambda t: t[1], reverse=True)
        ranked_entity = sorted(guesses_entity, key=lambda t: t[1], reverse=True)
        # Without any guess yet, fall back to the top-level frame names.
        best_guess_event = ranked_event[0][0] if ranked_event else "Event"
        best_guess_entity = ranked_entity[0][0] if ranked_entity else "Entity"

        event_hint = make_hint(st.session_state["gensim_model"], secret_event, best_guess_event)
        entity_hint = make_hint(st.session_state["gensim_model"], secret_entity, best_guess_entity)

        if event_hint:
            st.session_state["hints"][0] = format_hint_sentence(event_hint[1])
        if entity_hint:
            st.session_state["hints"][1] = format_hint_sentence(entity_hint[1])

    else:
        frames = analyze_sentence(sentence)
        frames_and_info = get_frame_info(frames)
        frame_feedback = get_frame_feedback(frames_and_info, st.session_state["gensim_model"], secret_event, secret_entity)

        # update game state post analysis
        game_state["num_guesses"] += 1
        for fdb in frame_feedback:
            if fdb["similarity_1"]:
                guesses_event.add((fdb["frame"], fdb["similarity_1"], fdb["rank_1"]))
                guesses_entity.add((fdb["frame"], fdb["similarity_2"], fdb["rank_2"]))

        st.session_state["frame_feedback"] = frame_feedback
        if secret_event in frames and secret_entity in frames:
            st.session_state["game_over"] = True
            st.session_state["guesses_to_win"] = game_state["num_guesses"]
|
392 |
+
|
393 |
+
def display_guess_status():
    """Show the best guesses so far for both secret stories, side by side.

    Guesses are read from the session's game state and listed by decreasing
    similarity. A column is only filled when it has guesses, and an info box
    is added once the secret story itself is among them.
    """
    game_state = st.session_state["game_state"]

    def _by_similarity(guess):
        return guess[1]

    guesses_entity = sorted(game_state["guesses_entity"], key=_by_similarity, reverse=True)
    guesses_event = sorted(game_state["guesses_event"], key=_by_similarity, reverse=True)

    if not (guesses_event or guesses_entity):
        return

    st.header("Best guesses")
    event_col, entity_col = st.columns(2)

    panels = (
        (
            event_col,
            guesses_event,
            "Secret Story #1",
            "secret_event",
            "Great, you guessed the Event story! In order to win, make a sentence containing both the secret stories.",
        ),
        (
            entity_col,
            guesses_entity,
            "Secret Story #2",
            "secret_entity",
            "Great, you guessed the Thing story! In order to win, make a sentence containing both the secret stories.",
        ),
    )
    for column, guesses, title, secret_key, success_msg in panels:
        if not guesses:
            continue
        with column:
            st.subheader(title)
            st.table(pd.DataFrame(guesses, columns=["Story", "Similarity", "Steps To Go"]))
            if game_state[secret_key] in [name for name, _, _ in guesses]:
                st.info(success_msg)
|
414 |
+
|
415 |
+
|
416 |
+
def format_feedback(frame_feedback):
    """Convert feedback records into rows for the feedback table.

    Args:
        frame_feedback: iterable of dicts with keys ``frame``,
            ``similarity_1``, ``similarity_2``, ``typical_sentence`` and
            ``typical_words``.

    Returns:
        A list of dicts keyed by the table column titles. Similarities are
        formatted with two decimals, or shown as "unknown" when they are None
        (formatting None with ``:.2f`` would raise TypeError).
    """

    def _fmt_similarity(score):
        return "unknown" if score is None else f"{score:.2f}"

    return [
        {
            "Story": fdb["frame"],
            "Similarity (Event)": _fmt_similarity(fdb["similarity_1"]),
            "Similarity (Thing)": _fmt_similarity(fdb["similarity_2"]),
            "Typical Context": fdb["typical_sentence"],
            "Typical Words": " ".join(fdb["typical_words"]),
        }
        for fdb in frame_feedback
    ]
|
427 |
+
|
428 |
+
|
429 |
+
def display_introduction():
    """Render the "Why this game?" and "How does it work?" explanation."""
    st.subheader("Why this game?")
    st.markdown(
        "Words are not just words: behind every word, a _mini-story_ (also known as \"frame\") is hidden\n"
        "that appears in our imagination when we hear the word. For example, when we hear the word\n"
        "\"talking\" we can imagine a mini-story that involves several people who are interacting\n"
        "with each other. Or, if we hear the word \"cookie\", we might think of someone eating a cookie."
    )

    st.subheader("How does it work?")
    st.markdown(
        "* In this game, there are two secret mini-stories, and it's your job to figure out which ones!"
        "\n"
        "* The first mini-story is about an _Event_ (something that happens in the world, like a thunderstorm, "
        # Trailing space added: adjacent literals are concatenated, and
        # without it the text rendered as "a treeor something abstract".
        "people talking, someone eating pasta), and the other one is a _Thing_ (a concrete thing like a tree "
        "or something abstract like 'love')."
        "\n"
        "* How to guess the stories? Well, just type a sentence, and we'll tell you which mini-stories are "
        "hidden in the sentence. For each of the stories, we'll tell you how close they are to the secret ones."
        "\n"
        "* Once you type a sentence with both of the secret mini-stories, you win!"
    )
|
452 |
+
|
453 |
+
|
454 |
+
|
455 |
+
def display_hints():
    """Show the hint sentences stored in the session state, if there are any."""
    event_hint, entity_hint = st.session_state["hints"]
    if not (event_hint or entity_hint):
        return

    st.header("Hints")
    st.info("So you need some help? Here you get your hint sentences! Guess the hidden word, use it in a sentence, and we'll help you get one step closer.")

    # Each hint is rendered as a markdown block quote in italics.
    if event_hint:
        st.markdown(f"**Event Hint**:\n>_{event_hint}_")
    if entity_hint:
        st.markdown(f"**Thing Hint**:\n>_{entity_hint}_")
|
465 |
+
|
466 |
+
def display_frame_feedback():
    """Show the "Feedback" table for the most recent guess, if there is one.

    Reads ``st.session_state["frame_feedback"]``; when it is falsy nothing is
    rendered. Otherwise the feedback is passed through ``format_feedback`` and
    displayed as a table.
    """
    latest_feedback = st.session_state["frame_feedback"]
    if not latest_feedback:
        return

    st.header("Feedback")
    st.text("Your sentence contains the following stories: ")
    rows = format_feedback(latest_feedback)
    st.table(pd.DataFrame(rows))
|
473 |
+
|
474 |
+
|
475 |
+
def run_game_st(debug=True):
    """Entry point of the Streamlit app: initialise state once, then draw the page.

    On the first run of a session (``st.session_state["initialized"]`` missing
    or falsy) the secret frames are chosen, the frame embeddings are loaded
    from ``data/frame_embeddings.w2v.txt`` and all session keys used by the
    page are created. On every run the page is then rendered from session
    state.

    :param debug: Currently unused by this function; kept so existing callers
        that pass it keep working.
    """
    if not st.session_state.get("initialized", False):
        secret_event, secret_entity = choose_secret_frames()
        gensim_m = gensim.models.word2vec.KeyedVectors.load_word2vec_format("data/frame_embeddings.w2v.txt")

        game_state = {
            "secret_event": secret_event,
            "secret_entity": secret_entity,
            "num_guesses": 0,
            "guesses_event": set(),
            "guesses_entity": set(),
        }

        st.session_state["initialized"] = True
        st.session_state["show_introduction"] = False
        st.session_state["game_over"] = False
        st.session_state["guesses_to_win"] = -1
        st.session_state["game_state"] = game_state
        st.session_state["gensim_model"] = gensim_m
        st.session_state["frame_feedback"] = None
        st.session_state["hints"] = [None, None]
    # On later runs everything needed below is read directly from
    # st.session_state, so no local copies are re-created here.

    header = st.container()
    with header:
        st.title("FillmorLe")
        # The checkbox is bound to the session key, so its value survives reruns.
        st.checkbox("Show explanation?", key="show_introduction")
        if st.session_state["show_introduction"]:
            display_introduction()

    st.header(f"Guess #{st.session_state['game_state']['num_guesses'] + 1}")
    # play_turn is invoked by Streamlit whenever the text input changes.
    st.text_input("Enter a sentence or type 'HINT' if you're stuck", key="cur_sentence", on_change=play_turn)

    if st.session_state["game_over"]:
        guesses_to_win = st.session_state["guesses_to_win"]
        guess_word = "guess" if guesses_to_win == 1 else "guesses"
        st.success(f"You won in {guesses_to_win} {guess_word}!")

    display_hints()
    display_frame_feedback()
    display_guess_status()
|
521 |
+
|
522 |
+
|
523 |
+
if __name__ == "__main__":
|
524 |
+
run_game_st()
|
model.mod.tar.gz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3f5be5aeef50b2f4840317b8196c51186f9f138a853dc1eb2da980b1947ceb23
|
3 |
+
size 1795605184
|
requirements.txt
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
allennlp>=2.0.0
|
2 |
+
allennlp-models>=2.0.0
|
3 |
+
transformers>=4.0.0 # Why is huggingface so unstable?
|
4 |
+
numpy
|
5 |
+
torch>=1.7.0,<1.8.0
|
6 |
+
tqdm
|
7 |
+
nltk
|
8 |
+
overrides
|
9 |
+
concrete
|
10 |
+
flask
|
11 |
+
scipy
|
12 |
+
requests
|
13 |
+
lxml
|
14 |
+
gensim
|
15 |
+
streamlit
|
16 |
+
https://github.com/explosion/spacy-models/releases/download/it_core_news_md-3.0.0/it_core_news_md-3.0.0-py3-none-any.whl
|
17 |
+
https://github.com/explosion/spacy-models/releases/download/en_core_web_md-3.0.0/en_core_web_md-3.0.0-py3-none-any.whl
|
18 |
+
https://github.com/explosion/spacy-models/releases/download/nl_core_news_md-3.0.0/nl_core_news_md-3.0.0-py3-none-any.whl
|
19 |
+
https://github.com/explosion/spacy-models/releases/download/xx_sent_ud_sm-3.0.0/xx_sent_ud_sm-3.0.0-py3-none-any.whl
|
setup.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from setuptools import setup, find_packages
|
2 |
+
|
3 |
+
|
4 |
+
setup(
|
5 |
+
name='sftp',
|
6 |
+
version='0.0.2',
|
7 |
+
author='Guanghui Qin',
|
8 |
+
packages=find_packages(),
|
9 |
+
)
|
sftp/__init__.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .data_reader import (
|
2 |
+
BetterDatasetReader, SRLDatasetReader
|
3 |
+
)
|
4 |
+
from .metrics import SRLMetric, BaseF, ExactMatch, FBetaMixMeasure
|
5 |
+
from .models import SpanModel
|
6 |
+
from .modules import (
|
7 |
+
MLPSpanTyping, SpanTyping, SpanFinder, BIOSpanFinder
|
8 |
+
)
|
9 |
+
from .predictor import SpanPredictor
|
10 |
+
from .utils import Span
|
sftp/data_reader/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .batch_sampler import MixSampler
|
2 |
+
from .better_reader import BetterDatasetReader
|
3 |
+
from .span_reader import SpanReader
|
4 |
+
from .srl_reader import SRLDatasetReader
|
5 |
+
from .concrete_srl import concrete_doc, concrete_doc_tokenized, collect_concrete_srl
|
6 |
+
from .concrete_reader import ConcreteDatasetReader
|
sftp/data_reader/batch_sampler/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .mix_sampler import MixSampler
|
sftp/data_reader/batch_sampler/mix_sampler.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import random
|
3 |
+
from typing import *
|
4 |
+
|
5 |
+
from allennlp.data.samplers.batch_sampler import BatchSampler
|
6 |
+
from allennlp.data.samplers.max_tokens_batch_sampler import MaxTokensBatchSampler
|
7 |
+
from torch.utils import data
|
8 |
+
|
9 |
+
logger = logging.getLogger('mix_sampler')
|
10 |
+
|
11 |
+
|
12 |
+
@BatchSampler.register('mix_sampler')
class MixSampler(MaxTokensBatchSampler):
    """Max-tokens batch sampler that randomly sub-samples instances by type.

    Each instance's type is read from ``ins.fields['meta'].metadata['type']``
    (``'default'`` when absent). An instance of type ``t`` is kept with
    probability ``sampling_ratios[t]``; types without an entry are always kept.
    """

    def __init__(
            self,
            max_tokens: int,
            sorting_keys: Optional[List[str]] = None,
            padding_noise: float = 0.1,
            sampling_ratios: Optional[Dict[str, float]] = None,
    ):
        """
        :param max_tokens: Forwarded to ``MaxTokensBatchSampler``.
        :param sorting_keys: Forwarded to ``MaxTokensBatchSampler``.
        :param padding_noise: Forwarded to ``MaxTokensBatchSampler``.
        :param sampling_ratios: Mapping from instance type to keep-probability.
        """
        super().__init__(max_tokens, sorting_keys, padding_noise)

        self.sampling_ratios = sampling_ratios or dict()

    def __iter__(self):
        """Yield batches (lists of instance indices) in shuffled order."""
        indices, lengths = self._argsort_by_padding(self.data_source)

        original_num = len(indices)
        instance_types = [
            ins.fields['meta'].metadata.get('type', 'default') if 'meta' in ins.fields else 'default'
            for ins in self.data_source
        ]

        # Draw one random number per instance, in dataset order, and remember
        # which dataset positions are rejected. Filtering once afterwards
        # avoids a list.index() + del per rejected instance (quadratic).
        rejected = set()
        for idx, ins_type in enumerate(instance_types):
            if random.random() > self.sampling_ratios.get(ins_type, 1.0):
                rejected.add(idx)

        if rejected:
            kept = [(index, length) for index, length in zip(indices, lengths) if index not in rejected]
            indices = [index for index, _ in kept]
            lengths = [length for _, length in kept]

        if original_num != len(indices):
            logger.info(f'#instances reduced from {original_num} to {len(indices)}.')

        max_lengths = [max(length) for length in lengths]
        group_iterator = self._lazy_groups_of_max_size(indices, max_lengths)

        batches = [list(group) for group in group_iterator]
        random.shuffle(batches)
        for batch in batches:
            yield batch
|
sftp/data_reader/better_reader.py
ADDED
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
from collections import defaultdict, namedtuple
|
5 |
+
from typing import *
|
6 |
+
|
7 |
+
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
|
8 |
+
from allennlp.data.instance import Instance
|
9 |
+
|
10 |
+
from .span_reader import SpanReader
|
11 |
+
from ..utils import Span
|
12 |
+
|
13 |
+
# logging.basicConfig(level=logging.DEBUG)
|
14 |
+
|
15 |
+
# for v in logging.Logger.manager.loggerDict.values():
|
16 |
+
# v.disabled = True
|
17 |
+
|
18 |
+
logger = logging.getLogger(__name__)
|
19 |
+
|
20 |
+
SpanTuple = namedtuple('Span', ['start', 'end'])
|
21 |
+
|
22 |
+
|
23 |
+
@DatasetReader.register('better')
class BetterDatasetReader(SpanReader):
    """Dataset reader registered as ``'better'``.

    Reads JSON files whose entries carry ``'segment-text-tok'`` and an
    ``'annotation-sets'`` mapping with ``'<eval_type>-events'`` holding
    ``'span-sets'`` and ``'events'``. Every event becomes a parent ``Span``
    (its anchor) with agent / patient (and optionally ref-event) children.
    """

    def __init__(
            self,
            eval_type,
            consolidation_strategy='first',
            span_set_type='single',
            max_argument_ss_size=1,
            use_ref_events=False,
            **extra
    ):
        """
        :param eval_type: Either ``'abstract'`` or ``'basic'``; selects the annotation set.
        :param consolidation_strategy: ``'first'``, ``'shortest'`` or ``'longest'``; how the
            spans inside one span set are ordered before truncation.
        :param span_set_type: ``'single'`` keeps only the first span of each span set.
        :param max_argument_ss_size: Maximum number of spans kept per span set when
            ``span_set_type`` is not ``'single'``.
        :param use_ref_events: If True, anchors of referenced events are added as
            ``'ref-event'`` children.
        :param extra: Forwarded to ``SpanReader``.
        """
        super().__init__(**extra)
        self.eval_type = eval_type
        assert self.eval_type in ['abstract', 'basic']

        self.consolidation_strategy = consolidation_strategy
        self.unitary_spans = span_set_type == 'single'
        # event anchors are always singleton spans
        self.max_arg_spans = max_argument_ss_size
        self.use_ref_events = use_ref_events

        # Counters reported (and reset) after each file is read.
        self.n_overlap_arg = 0
        self.n_overlap_trigger = 0
        self.n_skip = 0
        self.n_too_long = 0

    @staticmethod
    def post_process_basic_span(predicted_span, basic_entry):
        """Convert a predicted token span back to character offsets.

        ``predicted_span`` supplies inclusive token indices (``'start_idx'``,
        ``'end_idx'``) and a ``'label'``. The returned dict uses
        inclusive-exclusive character offsets (``'start'``, ``'end'``) and
        inclusive token offsets (``'start-token'``, ``'end-token'``), and also
        carries the covered text as a sanity check.
        """
        start_idx = predicted_span['start_idx']  # inc
        end_idx = predicted_span['end_idx']  # inc

        char_start_idx = basic_entry['tok2char'][start_idx][0]  # inc
        char_end_idx = basic_entry['tok2char'][end_idx][-1] + 1  # exc

        span_text = basic_entry['segment-text'][char_start_idx:char_end_idx]  # inc exc
        span_text_tok = basic_entry['segment-text-tok'][start_idx:end_idx + 1]  # inc exc

        return {
            'string': span_text,
            'start': char_start_idx,
            'end': char_end_idx,
            'start-token': start_idx,
            'end-token': end_idx,
            'string-tok': span_text_tok,
            'label': predicted_span['label'],
            'predicted': True,
        }

    @staticmethod
    def _get_shortest_span(spans):
        """Return all spans ordered by string length, shortest first (ties: input order)."""
        ranked = sorted(enumerate(spans), key=lambda pair: (len(pair[1]['string']), pair[0]))
        return [span for _, span in ranked]

    @staticmethod
    def _get_first_span(spans):
        """Return all spans ordered by start offset; ties go to the longer string, then input order."""
        # The input position is part of the key, so the span dicts themselves
        # are never compared and sorting cannot fail on them.
        ranked = sorted(
            enumerate(spans),
            key=lambda pair: (pair[1]['start'], -len(pair[1]['string']), pair[0]),
        )
        return [span for _, span in ranked]

    @staticmethod
    def _get_longest_span(spans):
        """Return all spans ordered by string length, longest first (ties: later input first)."""
        ranked = sorted(
            enumerate(spans),
            key=lambda pair: (len(pair[1]['string']), pair[0]),
            reverse=True,
        )
        return [span for _, span in ranked]

    @staticmethod
    def _subfinder(text, pattern):
        """Find every occurrence of the token list ``pattern`` inside ``text``.

        Returns a list of ``SpanTuple`` with inclusive boundaries; an empty
        pattern has no occurrences.
        """
        # https://stackoverflow.com/a/12576755
        if not pattern:
            return []
        pattern_length = len(pattern)
        first_token = pattern[0]
        return [
            SpanTuple(start=i, end=i + pattern_length - 1)  # inclusive boundaries
            for i, token in enumerate(text)
            if token == first_token and text[i:i + pattern_length] == pattern
        ]

    def consolidate_span_set(self, spans):
        """Order a span set by the configured strategy and truncate it.

        :raises NotImplementedError: if ``consolidation_strategy`` is unknown.
        """
        if self.consolidation_strategy == 'first':
            spans = BetterDatasetReader._get_first_span(spans)
        elif self.consolidation_strategy == 'shortest':
            spans = BetterDatasetReader._get_shortest_span(spans)
        elif self.consolidation_strategy == 'longest':
            spans = BetterDatasetReader._get_longest_span(spans)
        else:
            raise NotImplementedError(f"{self.consolidation_strategy} does not exist")

        if self.unitary_spans:
            spans = [spans[0]]
        else:
            spans = spans[:self.max_arg_spans]

        # TODO add some sanity checks here

        return spans

    def get_mention_spans(self, text: List[str], span_sets: Dict):
        """Map every span-set id to a list of token-level ``SpanTuple``.

        In ``'abstract'`` mode the token offsets are recovered by searching the
        first consolidated span's tokens in ``text``; span sets that cannot be
        located are left out of the result. In ``'basic'`` mode the
        ``'start-token'`` / ``'end-token'`` offsets from the annotation are used.
        """
        mention_spans = defaultdict(list)
        for span_set_id in span_sets.keys():
            spans = span_sets[span_set_id]['spans']
            consolidated_spans = self.consolidate_span_set(spans)

            if self.eval_type == 'abstract':
                span = consolidated_spans[0]
                span_tokens = span['string-tok']

                span_indices = BetterDatasetReader._subfinder(text=text, pattern=span_tokens)

                if len(span_indices) == 0:
                    continue

                # Stored as a one-element list: text_to_instance indexes and
                # iterates these values as lists of spans in both modes.
                mention_spans[span_set_id] = [span_indices[0]]
            else:
                # in basic, we already have token offsets in the right form
                # we should use these token offsets only!
                for span in consolidated_spans:
                    mention_spans[span_set_id].append(SpanTuple(start=span['start-token'], end=span['end-token']))

        return mention_spans

    def _read_single_file(self, file_path):
        """Yield one instance per entry of a single JSON file, then log and reset the counters."""
        with open(file_path, encoding='utf-8') as fp:
            json_content = json.load(fp)
        if 'entries' in json_content:
            for doc_name, entry in json_content['entries'].items():
                instance = self.text_to_instance(entry, 'train' in file_path)
                yield instance
        else:  # TODO why is this split in 2 cases?
            for doc_name, entry in json_content.items():
                instance = self.text_to_instance(entry, True)
                yield instance

        logger.warning(f'{self.n_overlap_arg} overlapped args detected!')
        logger.warning(f'{self.n_overlap_trigger} overlapped triggers detected!')
        logger.warning(f'{self.n_skip} skipped detected!')
        logger.warning(f'{self.n_too_long} were skipped because they are too long!')
        self.n_overlap_arg = self.n_skip = self.n_too_long = self.n_overlap_trigger = 0

    def _read(self, file_path: str) -> Iterable[Instance]:
        """Read one JSON file, or every ``*.json`` file directly inside a directory."""
        if os.path.isdir(file_path):
            for fn in os.listdir(file_path):
                if not fn.endswith('.json'):
                    logger.info(f'Skipping {fn}')
                    continue
                logger.info(f'Loading from {fn}')
                yield from self._read_single_file(os.path.join(file_path, fn))
        else:
            yield from self._read_single_file(file_path)

    def text_to_instance(self, entry, is_training=False):
        """Build an ``Instance`` from one annotated entry.

        Events whose anchor overlaps an earlier event's anchor are skipped;
        within one event, arguments overlapping an earlier argument of that
        event are skipped.

        :param entry: One document entry (see the class docstring for the keys read).
        :param is_training: Not used by this method.
        """
        word_tokens = entry['segment-text-tok']

        # span sets have been trimmed to the earliest span mention
        spans = self.get_mention_spans(
            word_tokens, entry['annotation-sets'][f'{self.eval_type}-events']['span-sets']
        )

        # idx of every token that is a part of an event trigger/anchor span
        all_trigger_idxs = set()

        # actual inputs to the model
        input_spans = []

        self._local_child_overlap = 0
        self._local_child_total = 0

        better_events = entry['annotation-sets'][f'{self.eval_type}-events']['events']

        skipped_events = set()
        # check for events that overlap other event's anchors, skip them later
        for event_id, event in better_events.items():
            assert event['anchors'] in spans

            # take the first consolidated span for anchors
            anchor_start, anchor_end = spans[event['anchors']][0]

            if any(ix in all_trigger_idxs for ix in range(anchor_start, anchor_end + 1)):
                logger.warning(
                    f"Skipped {event_id} with anchor span {event['anchors']}, overlaps a previously found event trigger/anchor")
                self.n_overlap_trigger += 1
                skipped_events.add(event_id)
                continue

            all_trigger_idxs.update(range(anchor_start, anchor_end + 1))  # record the trigger

        for event_id, event in better_events.items():
            if event_id in skipped_events:
                continue

            # arguments for just this event
            local_arg_idxs = set()
            # take the first consolidated span for anchors
            anchor_start, anchor_end = spans[event['anchors']][0]

            event_span = Span(anchor_start, anchor_end, event['event-type'], True)
            input_spans.append(event_span)

            def add_a_child(span_id, label):
                # TODO this is a bad way to do this
                assert span_id in spans
                for child_span in spans[span_id]:
                    self._local_child_total += 1
                    arg_start, arg_end = child_span

                    if any(ix in local_arg_idxs for ix in range(arg_start, arg_end + 1)):
                        self.n_overlap_arg += 1
                        self._local_child_overlap += 1
                        continue

                    local_arg_idxs.update(range(arg_start, arg_end + 1))
                    event_span.add_child(Span(arg_start, arg_end, label, False))

            for agent in event['agents']:
                add_a_child(agent, 'agent')
            for patient in event['patients']:
                add_a_child(patient, 'patient')

            if self.use_ref_events:
                for ref_event in event['ref-events']:
                    if ref_event in skipped_events:
                        continue
                    ref_event_anchor_id = better_events[ref_event]['anchors']
                    add_a_child(ref_event_anchor_id, 'ref-event')

        fields = self.prepare_inputs(word_tokens, spans=input_spans)
        if self._local_child_overlap > 0:
            logger.warning(
                f"Skipped {self._local_child_overlap} / {self._local_child_total} argument spans due to overlaps")
        return Instance(fields)
|
286 |
+
|
sftp/data_reader/concrete_reader.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from collections import defaultdict
|
3 |
+
from typing import *
|
4 |
+
import os
|
5 |
+
|
6 |
+
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
|
7 |
+
from allennlp.data.instance import Instance
|
8 |
+
from concrete import SituationMention
|
9 |
+
from concrete.util import CommunicationReader
|
10 |
+
|
11 |
+
from .span_reader import SpanReader
|
12 |
+
from .srl_reader import SRLDatasetReader
|
13 |
+
from .concrete_srl import collect_concrete_srl
|
14 |
+
from ..utils import Span, BIOSmoothing
|
15 |
+
|
16 |
+
logger = logging.getLogger(__name__)
|
17 |
+
|
18 |
+
|
19 |
+
@DatasetReader.register('concrete')
class ConcreteDatasetReader(SRLDatasetReader):
    """Dataset reader registered as ``'concrete'``; reads Concrete communications."""

    def __init__(
            self,
            event_only: bool = False,
            event_smoothing_factor: float = 0.,
            arg_smoothing_factor: float = 0.,
            **extra
    ):
        """
        :param event_only: Stored as ``self.event_only``.
        :param event_smoothing_factor: Stored as ``self.event_smooth_factor``.
        :param arg_smoothing_factor: Stored as ``self.arg_smooth_factor``.
        :param extra: Forwarded to ``SRLDatasetReader``.
        """
        super().__init__(**extra)
        self.event_only = event_only
        self.event_smooth_factor = event_smoothing_factor
        self.arg_smooth_factor = arg_smoothing_factor

    def _read(self, file_path: str) -> Iterable[Instance]:
        """Yield one instance per sentence found under ``file_path``.

        A directory is walked recursively; any other path is handed to
        ``CommunicationReader``. After each file the number of removed spans
        is logged and the counter reset.
        """
        if os.path.isdir(file_path):
            for fn in os.listdir(file_path):
                yield from self._read(os.path.join(file_path, fn))
            # The directory itself is not a communication file; without this
            # return it would be passed to CommunicationReader below.
            return

        all_files = CommunicationReader(file_path)
        for comm, fn in all_files:
            sentences = collect_concrete_srl(comm)
            for tokens, vr in sentences:
                yield self.text_to_instance(tokens, vr)
        logger.warning(f'{self.n_span_removed} spans were removed')
        self.n_span_removed = 0
|
sftp/data_reader/concrete_srl.py
ADDED
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from time import time
|
2 |
+
from typing import *
|
3 |
+
from collections import defaultdict
|
4 |
+
|
5 |
+
from concrete import (
|
6 |
+
Token, TokenList, TextSpan, MentionArgument, SituationMentionSet, SituationMention, TokenRefSequence,
|
7 |
+
Communication, EntityMention, EntityMentionSet, Entity, EntitySet, AnnotationMetadata, Sentence
|
8 |
+
)
|
9 |
+
from concrete.util import create_comm, AnalyticUUIDGeneratorFactory
|
10 |
+
from concrete.validate import validate_communication
|
11 |
+
|
12 |
+
from ..utils import Span
|
13 |
+
|
14 |
+
|
15 |
+
def _process_sentence(sent, comm_sent, aug, char_idx_offset: int):
    """Fill one Concrete sentence with tokens and build its mentions.

    :param sent: Dict with ``'sentence'`` (str), ``'tokenization'`` (inclusive
        ``(start, end)`` character pairs per token) and ``'annotations'``
        (a ``Span`` or its JSON form).
    :param comm_sent: The Concrete sentence whose token list is overwritten.
    :param aug: UUID generator; ``next(aug)`` is called once per created object.
    :param char_idx_offset: Added to token character offsets to make them
        document-level.
    :return: ``(situation mentions, entity mentions, entities)``.
    """
    comm_sent.tokenization.tokenList = TokenList(tokenList=[
        Token(
            tokenIndex=position,
            text=sent['sentence'][first_char:last_char + 1],
            textSpan=TextSpan(
                start=first_char + char_idx_offset,
                ending=last_char + char_idx_offset + 1
            ),
        )
        for position, (first_char, last_char) in enumerate(sent['tokenization'])
    ])

    def covered_text(first_token, last_token):
        # Surface string from the first character of first_token through the
        # last character of last_token (token boundaries are inclusive).
        begin = sent['tokenization'][first_token][0]
        finish = sent['tokenization'][last_token][1]
        return sent['sentence'][begin: finish + 1]

    situation_mentions = list()
    mentions_by_span = dict()
    entities = list()

    root = sent['annotations']
    if not isinstance(root, Span):
        root = Span.from_json(root)

    for event in root:
        event_text = covered_text(event.start_idx, event.end_idx)
        situation = SituationMention(
            uuid=next(aug),
            text=event_text,
            situationType='EVENT',
            situationKind=event.label,
            argumentList=list(),
            tokens=TokenRefSequence(
                tokenIndexList=list(range(event.start_idx, event.end_idx + 1)),
                tokenizationId=comm_sent.tokenization.uuid
            ),
        )

        for arg in event:
            span_key = (arg.start_idx, arg.end_idx + 1)
            mention = mentions_by_span.get(span_key)
            if mention is None:
                # First time this token span is used as an argument: create
                # its entity mention and a singleton entity for it.
                arg_text = covered_text(arg.start_idx, arg.end_idx)
                mention = EntityMention(next(aug), TokenRefSequence(
                    tokenIndexList=list(range(arg.start_idx, arg.end_idx + 1)),
                    tokenizationId=comm_sent.tokenization.uuid,
                ), text=arg_text)
                entities.append(Entity(next(aug), id=mention.text, mentionIdList=[mention.uuid]))
                mentions_by_span[span_key] = mention
            situation.argumentList.append(MentionArgument(
                role=arg.label,
                entityMentionId=mention.uuid,
            ))

        situation_mentions.append(situation)

    return situation_mentions, list(mentions_by_span.values()), entities
|
65 |
+
|
66 |
+
|
67 |
+
def concrete_doc(
        sentences: List[Dict[str, Any]],
        doc_name: str = 'document',
) -> Communication:
    """
    Build a validated Concrete ``Communication`` from annotated sentences.

    Each element of ``sentences`` is a dict of the form:
    {
        "sentence": String.
        "tokenization": A list of Tuple[int, int], the start and end character
            indices of every token. Both inclusive.
        "annotations": A Span object, or its JSON form (a list of event dicts).
    }
    In the JSON form, each event is a dict:
    {
        "span": [start_idx, end_idx]: Integers. Both inclusive.
        "label": String.
        "children": A list of arguments.
    }
    and each argument is a dict:
    {
        "span": [start_idx, end_idx]: Integers. Both inclusive.
        "label": String.
    }

    Note that the span indices refer to tokens, not characters.
    """
    document_text = '\n'.join([sent['sentence'] for sent in sentences])
    comm = create_comm(doc_name, document_text)
    aug = AnalyticUUIDGeneratorFactory(comm).create()

    # The three annotation containers are created in this order so that the
    # UUID generator is consumed in a fixed sequence.
    situation_mention_set = SituationMentionSet(next(aug), AnnotationMetadata('Span Finder', time()), list())
    comm.situationMentionSetList = [situation_mention_set]
    entity_mention_set = EntityMentionSet(next(aug), AnnotationMetadata('Span Finder', time()), list())
    comm.entityMentionSetList = [entity_mention_set]
    entity_set = EntitySet(
        next(aug), AnnotationMetadata('O(0) Coref Paser.', time()), list(), None, entity_mention_set.uuid
    )
    comm.entitySetList = [entity_set]
    assert len(sentences) == len(comm.sectionList[0].sentenceList)

    # Sentences were joined with '\n', so each one starts one character after
    # the end of the previous one.
    sentence_start = 0
    for sent, comm_sent in zip(sentences, comm.sectionList[0].sentenceList):
        new_situations, new_mentions, new_entities = _process_sentence(sent, comm_sent, aug, sentence_start)
        entity_set.entityList.extend(new_entities)
        situation_mention_set.mentionList.extend(new_situations)
        entity_mention_set.mentionList.extend(new_mentions)
        sentence_start += len(sent['sentence']) + 1

    validate_communication(comm)
    return comm
|
119 |
+
|
120 |
+
|
121 |
+
def concrete_doc_tokenized(
        sentences: List[List[str]],
        spans: List[Span],
        doc_name: str = "document",
):
    """
    Variant of concrete_doc for pre-tokenized input.

    Every sentence is a list of words that are joined with single spaces; the
    matching entry of ``spans`` is used as that sentence's annotations.
    """
    def word_boundaries(words):
        # Inclusive (start, end) character offsets of each word in " ".join(words).
        boundaries = list()
        position = 0
        for word in words:
            width = len(word)
            boundaries.append((position, position + width - 1))
            position += width + 1
        return boundaries

    inputs = [
        {
            "sentence": " ".join(words),
            "tokenization": word_boundaries(words),
            "annotations": annotation
        }
        for words, annotation in zip(sentences, spans)
    ]
    return concrete_doc(inputs, doc_name)
|
142 |
+
|
143 |
+
|
144 |
+
def collect_concrete_srl(comm: Communication) -> List[Tuple[List[str], Span]]:
    """Extract ``(tokens, virtual-root Span)`` for every sentence of a communication.

    Situation mentions become event spans; their arguments become child spans.
    The communication must contain exactly one situation mention set.
    """
    # <sentence uuid> -> [<Concrete sentence>, <situation mentions in that sentence>]
    per_sentence = defaultdict(lambda: [None, list()])
    for section in comm.sectionList:
        for sentence in section.sentenceList:
            per_sentence[sentence.uuid.uuidString][0] = sentence

    # Assume there's only ONE situation mention set
    assert len(comm.situationMentionSetList) == 1

    # Attach every situation mention to the sentence it belongs to
    for mention in comm.situationMentionSetList[0].mentionList:
        if mention.tokens is None:
            continue  # For ACE relations
        owner_uuid = mention.tokens.tokenization.sentence.uuid.uuidString
        per_sentence[owner_uuid][1].append(mention)

    def as_event(mention):
        # One event Span per situation mention, with one child per argument.
        token_ids = sorted(mention.tokens.tokenIndexList)
        event = Span(token_ids[0], token_ids[-1], mention.situationKind, True)
        for argument in mention.argumentList:
            arg_ids = sorted(argument.entityMention.tokens.tokenIndexList)
            event.add_child(Span(arg_ids[0], arg_ids[-1], argument.role, False))
        return event

    collected = list()
    for sentence, mentions in per_sentence.values():
        words = [token.text for token in sentence.tokenization.tokenList.tokenList]
        events = [as_event(mention) for mention in mentions]
        collected.append((words, Span.virtual_root(events)))
    return collected
|
sftp/data_reader/span_reader.py
ADDED
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from abc import ABC
|
3 |
+
from typing import *
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
from allennlp.common.util import END_SYMBOL
|
7 |
+
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
|
8 |
+
from allennlp.data.dataset_readers.dataset_utils.span_utils import bio_tags_to_spans
|
9 |
+
from allennlp.data.fields import *
|
10 |
+
from allennlp.data.token_indexers import PretrainedTransformerIndexer
|
11 |
+
from allennlp.data.tokenizers import PretrainedTransformerTokenizer, Token
|
12 |
+
|
13 |
+
from ..utils import Span, BIOSmoothing, apply_bio_smoothing
|
14 |
+
|
15 |
+
logger = logging.getLogger(__name__)
|
16 |
+
|
17 |
+
|
18 |
+
@DatasetReader.register('span')
class SpanReader(DatasetReader, ABC):
    """
    Base reader that converts a tokenized sentence, plus an optional tree of spans,
    into the AllenNLP fields built by `prepare_inputs`.
    """

    def __init__(
            self,
            pretrained_model: str,
            max_length: int = 512,
            ignore_label: bool = False,
            debug: bool = False,
            **extras
    ) -> None:
        """
        :param pretrained_model: The name of the pretrained model. E.g. xlm-roberta-large
            It is used for both the word-piece indexer and the tokenizer.
        :param max_length: Piece sequences longer than this limit will be truncated.
        :param ignore_label: If True, `ignore_labels()` is called on the span tree in `prepare_inputs`.
        :param debug: Stored as `self.debug`; this class does not read it itself.
        :param extras: Args to DatasetReader.
        """
        super().__init__(**extras)
        self.word_indexer = {
            'pieces': PretrainedTransformerIndexer(pretrained_model, namespace='pieces')
        }

        self._pretrained_model_name = pretrained_model
        self.debug = debug
        self.ignore_label = ignore_label

        self._pretrained_tokenizer = PretrainedTransformerTokenizer(pretrained_model)
        self.max_length = max_length
        # Result of the most recent `remove_overlapping()` call made in `prepare_inputs`.
        self.n_span_removed = 0

    def retokenize(
            self, sentence: List[str], truncate: bool = True
    ) -> Tuple[List[str], List[Optional[Tuple[int, int]]]]:
        """
        Split a list of words into word pieces with the pretrained tokenizer.

        :param sentence: A list of tokens (words).
        :param truncate: If True, keep at most `self.max_length` pieces and overwrite the last kept
            piece with END_SYMBOL.
        :return: (pieces, offsets). `pieces` are plain strings; `offsets` is returned exactly as
            produced by `intra_word_tokenize` and is NOT truncated.
        """
        pieces, offsets = self._pretrained_tokenizer.intra_word_tokenize(sentence)
        pieces = list(map(str, pieces))
        if truncate:
            pieces = pieces[:self.max_length]
            # NOTE(review): the last piece is overwritten even when nothing was cut off;
            # confirm this is intended for sequences shorter than `max_length`.
            pieces[-1] = END_SYMBOL
        return pieces, offsets

    def prepare_inputs(
            self,
            sentence: List[str],
            spans: Optional[Union[List[Span], Span]] = None,
            truncate: bool = True,
            label_type: str = 'string',
    ) -> Dict[str, Field]:
        """
        Prepare inputs and auxiliary variables for span model.

        :param sentence: A list of tokens. Do not pass in any special tokens, like BOS or EOS.
            Necessary for both training and testing.
        :param spans: Optional. Either a single root Span or a list of spans that will be attached to
            a virtual root. Necessary for training.
        :param truncate: If True, sequence will be truncated if it's longer than `self.max_length`.
        :param label_type: One of [string, list]. `string` wraps every label in a LabelField;
            `list` stores the labels in a single ArrayField.

        :return: Dict of AllenNLP fields.
            Always present:
                - tokens: TextField over the word pieces.
                - raw_inputs: MetadataField with keys `sentence`, `pieces`, `offsets` and, when spans
                    are given, `spans`.
            Only when `spans` is given:
                - span_boundary
                - parent_indices
                - span_labels
                - parent_mask
                - bio_seqs
        :raises NotImplementedError: If `label_type` is neither `string` nor `list`.
        """
        fields = dict()

        pieces, offsets = self.retokenize(sentence, truncate)
        fields['tokens'] = TextField(list(map(Token, pieces)), self.word_indexer)
        raw_inputs = {'sentence': sentence, "pieces": pieces, 'offsets': offsets}
        fields['raw_inputs'] = MetadataField(raw_inputs)

        if spans is None:
            return fields

        vr = spans if isinstance(spans, Span) else Span.virtual_root(spans)
        self.n_span_removed = vr.remove_overlapping()
        # Stored before re-indexing; presumably `re_index` returns a new tree so this one keeps
        # the word-level indices -- TODO confirm against Span.re_index.
        raw_inputs['spans'] = vr

        vr = vr.re_index(offsets)
        if truncate:
            vr.truncate(self.max_length)
        if self.ignore_label:
            vr.ignore_labels()

        # (start_idx, end_idx) pairs. Left and right inclusive.
        # The first span is the Virtual Root node. Shape [span, 2]
        span_boundary = list()
        # label on span. Shape [span]
        span_labels = list()
        # parent idx (span indexing space). Shape [span]
        span_parent_indices = list()
        # True for parents. Shape [span]
        parent_mask = [False] * vr.n_nodes
        # Spans in BFS order; a span's position in this list is its span index.
        flatten_spans = list(vr.bfs())
        for span_idx, span in enumerate(flatten_spans):
            if span.is_parent:
                parent_mask[span_idx] = True
            # 0 is the virtual root
            parent_idx = flatten_spans.index(span.parent) if span.parent else 0
            span_parent_indices.append(parent_idx)
            span_boundary.append(span.boundary)
            span_labels.append(span.label)

        # One BIO tag sequence per parent span; kept only for the sanity check below.
        bio_tag_list: List[List[str]] = list()
        # Shape: [#parent, #token, 3]
        bio_seqs: List[np.ndarray] = list()
        for parent in flatten_spans:
            if not parent.is_parent:
                continue
            bio_tags = ['O'] * len(pieces)
            bio_tag_list.append(bio_tags)
            bio_smooth: List[BIOSmoothing] = [parent.child_smooth.clone() for _ in pieces]
            for child in parent:
                # Children of one parent must not overlap, otherwise the tags would be overwritten.
                assert all(bio_tags[bio_idx] == 'O' for bio_idx in range(child.start_idx, child.end_idx + 1))
                if child.smooth_weight is not None:
                    for i in range(child.start_idx, child.end_idx + 1):
                        bio_smooth[i].weight = child.smooth_weight
                bio_tags[child.start_idx] = 'B'
                for word_idx in range(child.start_idx + 1, child.end_idx + 1):
                    bio_tags[word_idx] = 'I'
            bio_seqs.append(apply_bio_smoothing(bio_smooth, bio_tags))

        # `np.int` / `np.bool` were plain aliases of the builtins and no longer exist in recent
        # NumPy releases, so the builtins are used directly.
        fields['span_boundary'] = ArrayField(
            np.array(span_boundary), padding_value=0, dtype=int
        )
        fields['parent_indices'] = ArrayField(np.array(span_parent_indices), 0, int)
        if label_type == 'string':
            fields['span_labels'] = ListField([LabelField(label, 'span_label') for label in span_labels])
        elif label_type == 'list':
            fields['span_labels'] = ArrayField(np.array(span_labels))
        else:
            raise NotImplementedError(f'Unsupported label_type: {label_type!r}')
        fields['parent_mask'] = ArrayField(np.array(parent_mask), False, bool)
        fields['bio_seqs'] = ArrayField(np.stack(bio_seqs))

        self._sanity_check(
            flatten_spans, pieces, bio_tag_list, parent_mask, span_boundary, span_labels, span_parent_indices
        )

        return fields

    @staticmethod
    def _sanity_check(
            flatten_spans, words, bio_tag_list, parent_mask, span_boundary, span_labels, parent_indices, verbose=False
    ):
        """
        For debugging use. Assert that the parallel lists built in `prepare_inputs` agree with the
        span tree; with `verbose=True` also print every parent, its children, and the spans decoded
        back from its BIO tags.
        """
        assert len(parent_mask) == len(span_boundary) == len(span_labels) == len(parent_indices)
        # `bio_tag_list` holds one entry per parent span, in the same order as the filtered spans.
        for (parent_idx, parent_span), bio_tags in zip(
                filter(lambda x: x[1].is_parent, enumerate(flatten_spans)), bio_tag_list
        ):
            assert parent_mask[parent_idx]
            parent_s, parent_e = span_boundary[parent_idx]
            if verbose:
                print('Parent: ', span_labels[parent_idx], 'Text: ', ' '.join(words[parent_s:parent_e+1]))
                print(f'It contains {len(parent_span)} children.')
            for child in parent_span:
                child_idx = flatten_spans.index(child)
                assert parent_indices[child_idx] == flatten_spans.index(parent_span)
                if verbose:
                    child_s, child_e = span_boundary[child_idx]
                    print(' ', span_labels[child_idx], 'Text', words[child_s:child_e+1])

            if verbose:
                print(f'Child derived from BIO tags:')
                for _, (start, end) in bio_tags_to_spans(bio_tags):
                    print(words[start:end+1])
|
sftp/data_reader/srl_reader.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import logging
|
3 |
+
import random
|
4 |
+
from typing import *
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
|
8 |
+
from allennlp.data.fields import MetadataField
|
9 |
+
from allennlp.data.instance import Instance
|
10 |
+
|
11 |
+
from .span_reader import SpanReader
|
12 |
+
from ..utils import Span, VIRTUAL_ROOT, BIOSmoothing
|
13 |
+
|
14 |
+
logger = logging.getLogger(__name__)
|
15 |
+
|
16 |
+
|
17 |
+
@DatasetReader.register('semantic_role_labeling')
class SRLDatasetReader(SpanReader):
    """
    Reader for JSON-lines files in which every line holds the keyword arguments of
    `text_to_instance` (`tokens`, optional `annotations`, optional `meta`).
    """

    def __init__(
            self,
            min_negative: int = 5,
            negative_ratio: float = 1.,
            event_only: bool = False,
            event_smoothing_factor: float = 0.,
            arg_smoothing_factor: float = 0.,
            # For Ontology Mapping
            ontology_mapping_path: Optional[str] = None,
            min_weight: float = 1e-2,
            max_weight: float = 1.0,
            **extra
    ):
        """
        :param min_negative: Stored as `self.min_negative`; not read in this class.
        :param negative_ratio: Stored as `self.negative_ratio`; not read in this class.
        :param event_only: If True, the children of every event are removed in `text_to_instance`.
        :param event_smoothing_factor: If non-zero, used as `o_smooth` of the root's child smoothing.
        :param arg_smoothing_factor: If non-zero, used as `o_smooth` of every event's child smoothing.
        :param ontology_mapping_path: Optional path to a JSON file with a `mapping` entry (holding
            `event` and `argument` sub-dicts of weight lists) and a `target` entry with a `label` list.
        :param min_weight: Mapping weights below this value are set to 0.
        :param max_weight: Mapping weights above this value are capped to it.
        :param extra: Args to SpanReader.
        """
        super().__init__(**extra)
        self.min_negative = min_negative
        self.negative_ratio = negative_ratio
        self.event_only = event_only
        self.event_smooth_factor = event_smoothing_factor
        self.arg_smooth_factor = arg_smoothing_factor
        self.ontology_mapping = None
        if ontology_mapping_path is not None:
            # Close the file deterministically instead of leaving it to the garbage collector.
            with open(ontology_mapping_path) as mapping_file:
                self.ontology_mapping = json.load(mapping_file)
            mapping = self.ontology_mapping['mapping']
            for k1 in ['event', 'argument']:
                kept = dict()
                for k2, weights in mapping[k1].items():
                    weights = np.array(weights)
                    weights[weights < min_weight] = 0.0
                    weights[weights > max_weight] = max_weight
                    # Drop source labels whose weights are (numerically) all zero after clipping.
                    if weights.sum() > 1e-5:
                        kept[k2] = weights
                mapping[k1] = kept
            # The virtual root always maps onto itself with weight 1.
            target_labels = self.ontology_mapping['target']['label']
            vr_label = [0.] * len(target_labels)
            vr_label[target_labels.index(VIRTUAL_ROOT)] = 1.0
            mapping['event'][VIRTUAL_ROOT] = np.array(vr_label)

    def _read(self, file_path: str) -> Iterable[Instance]:
        """
        Yield one Instance per non-empty line of the JSON-lines file at `file_path`.

        In debug mode the lines are shuffled with a fixed seed (this reseeds the global `random`
        generator).
        """
        with open(file_path) as data_file:
            # Blank lines (e.g. a trailing newline) are skipped instead of crashing json.loads.
            all_lines = [json.loads(line) for line in data_file if line.strip()]
        if self.debug:
            random.seed(1)
            random.shuffle(all_lines)
        for line in all_lines:
            ins = self.text_to_instance(**line)
            if ins is not None:
                yield ins
        # NOTE(review): `prepare_inputs` overwrites (does not accumulate) `n_span_removed`, so this
        # reports the count of the last processed instance only -- confirm whether a running total
        # or a per-instance warning was intended.
        if self.n_span_removed > 0:
            logger.warning(f'{self.n_span_removed} spans are removed.')
            self.n_span_removed = 0

    def apply_ontology_mapping(self, vr):
        """
        Map the labels of a two-level span tree (events and their arguments) through the loaded
        ontology mapping and return a new virtual root.

        Events / arguments whose label is missing from the mapping are dropped. The sum of the
        mapped weights is written to `smooth_weight` and `child_smooth.weight` of every kept span.
        """
        new_events = list()
        event_map, arg_map = self.ontology_mapping['mapping']['event'], self.ontology_mapping['mapping']['argument']
        for event in vr:
            if event.label not in event_map:
                continue
            event.child_smooth.weight = event.smooth_weight = event_map[event.label].sum()
            event = event.map_ontology(event_map, False, False)
            new_events.append(event)
            new_children = list()
            for child in event:
                if child.label not in arg_map:
                    continue
                child.child_smooth.weight = child.smooth_weight = arg_map[child.label].sum()
                child = child.map_ontology(arg_map, False, False)
                new_children.append(child)
            # Replace the event's children with the mapped (and filtered) ones.
            event.remove_child()
            for child in new_children:
                event.add_child(child)
        new_vr = Span.virtual_root(new_events)
        # For Virtual Root itself.
        new_vr.map_ontology(self.ontology_mapping['mapping']['event'], True, False)
        return new_vr

    def text_to_instance(self, tokens, annotations=None, meta=None) -> Optional[Instance]:
        """
        Build an Instance from one example.

        :param tokens: The sentence as a list of tokens, passed on to `prepare_inputs`.
        :param annotations: Optional. A Span, or a JSON object accepted by `Span.from_json`.
        :param meta: Optional dict. `fully_annotated` defaults to True; the dict is updated in place
            and stored in the `meta` field.
        :return: An Instance with the fields of `prepare_inputs` plus `meta`.
        """
        meta = meta or {'fully_annotated': True}
        meta['fully_annotated'] = meta.get('fully_annotated', True)
        vr = None
        if annotations is not None:
            vr = annotations if isinstance(annotations, Span) else Span.from_json(annotations)
            vr = self.apply_ontology_mapping(vr) if self.ontology_mapping is not None else vr
            if self.event_smooth_factor != 0.0:
                # Partially annotated sentences get o_smooth=-1 instead of the configured factor.
                vr.child_smooth = BIOSmoothing(o_smooth=self.event_smooth_factor if meta['fully_annotated'] else -1)
            if self.arg_smooth_factor != 0.0:
                for event in vr:
                    event.child_smooth = BIOSmoothing(o_smooth=self.arg_smooth_factor)
            if self.event_only:
                for event in vr:
                    event.remove_child()
                    event.is_parent = False

        # With an ontology mapping the labels are weight arrays, hence the 'list' label type.
        fields = self.prepare_inputs(tokens, vr, True, 'string' if self.ontology_mapping is None else 'list')
        fields['meta'] = MetadataField(meta)
        return Instance(fields)
|
sftp/metrics/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from sftp.metrics.base_f import BaseF
|
2 |
+
from sftp.metrics.exact_match import ExactMatch
|
3 |
+
from sftp.metrics.fbeta_mix_measure import FBetaMixMeasure
|
4 |
+
from sftp.metrics.srl_metrics import SRLMetric
|
sftp/metrics/base_f.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABC
|
2 |
+
from typing import *
|
3 |
+
|
4 |
+
from allennlp.training.metrics import Metric
|
5 |
+
|
6 |
+
|
7 |
+
class BaseF(Metric, ABC):
    """
    Keeps running true-positive / false-positive / false-negative counts and reports
    precision, recall and F1 (as percentages) under `<prefix>_p`, `<prefix>_r`, `<prefix>_f`.
    """

    def __init__(self, prefix: str):
        self.tp = 0
        self.fp = 0
        self.fn = 0
        self.prefix = prefix

    def reset(self) -> None:
        """Set all three counters back to zero."""
        self.tp = 0
        self.fp = 0
        self.fn = 0

    def get_metric(
            self, reset: bool
    ) -> Union[float, Tuple[float, ...], Dict[str, float], Dict[str, List[float]]]:
        """
        :param reset: If True, the counters are cleared after the scores are computed.
        :return: Dict with the three prefixed scores; all of them are 0 when there is no true positive.
        """
        if self.tp > 0:
            precision = self.tp * 100 / (self.tp + self.fp)
            recall = self.tp * 100 / (self.tp + self.fn)
            # Harmonic mean of precision and recall.
            f_score = 2 / (1 / precision + 1 / recall)
        else:
            precision = 0.
            recall = 0.
            f_score = 0.
        scores = {
            f'{self.prefix}_p': precision,
            f'{self.prefix}_r': recall,
            f'{self.prefix}_f': f_score,
        }
        if reset:
            self.reset()
        return scores
|
sftp/metrics/exact_match.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from allennlp.training.metrics import Metric
|
2 |
+
from overrides import overrides
|
3 |
+
|
4 |
+
from .base_f import BaseF
|
5 |
+
from ..utils import Span
|
6 |
+
|
7 |
+
|
8 |
+
@Metric.register('exact_match')
class ExactMatch(BaseF):
    """
    Span-tree matching metric. Scores are reported with prefix `em` when labels are checked
    and `sm` when they are not.
    """

    def __init__(self, check_type: bool):
        self.check_type = check_type
        prefix = 'em' if check_type else 'sm'
        super(ExactMatch, self).__init__(prefix)

    @overrides
    def __call__(
            self,
            prediction: Span,
            gold: Span,
    ):
        """
        Add the counts of one (prediction, gold) pair to the running totals.

        One node is subtracted from every count -- presumably the virtual root, which always
        matches (TODO confirm against Span.match).
        """
        matched = prediction.match(gold, self.check_type) - 1
        spurious = prediction.n_nodes - matched - 1
        missed = gold.n_nodes - matched - 1
        assert matched >= 0 and spurious >= 0 and missed >= 0
        self.tp += matched
        self.fp += spurious
        self.fn += missed
|
sftp/metrics/fbeta_mix_measure.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from allennlp.training.metrics import FBetaMeasure, Metric
|
2 |
+
|
3 |
+
|
4 |
+
@Metric.register('fbeta_mix')
class FBetaMixMeasure(FBetaMeasure):
    """
    F-beta measure pooled over all classes except one "null" class.

    The per-class sums (`_true_positive_sum`, `_pred_sum`, `_true_sum`) and `_beta` are not
    defined here; they are expected to be provided by `FBetaMeasure`.
    """

    def __init__(self, null_idx, **kwargs):
        """
        :param null_idx: Index of the class that is excluded from the pooled scores.
        :param kwargs: Args to FBetaMeasure.
        """
        super().__init__(**kwargs)
        self.null_idx = null_idx

    def get_metric(self, reset: bool = False):
        """
        :param reset: If True, the accumulated statistics are cleared after computing the scores.
        :return: Dict with keys `p`, `r`, `f` (percentages) pooled over all non-null classes.
        """
        tp = float(self._true_positive_sum.sum() - self._true_positive_sum[self.null_idx])
        total_pred = float(self._pred_sum.sum() - self._pred_sum[self.null_idx])
        total_gold = float(self._true_sum.sum() - self._true_sum[self.null_idx])

        beta2 = self._beta ** 2
        p = 0. if total_pred == 0 else tp / total_pred
        # Recall must be guarded by the gold count: guarding by the prediction count divides by
        # zero when there are non-null predictions but no non-null gold items.
        r = 0. if total_gold == 0 else tp / total_gold
        f = 0. if p == 0. or r == 0. else ((1 + beta2) * p * r / (p * beta2 + r))

        mix_f = {
            'p': p * 100,
            'r': r * 100,
            'f': f * 100
        }

        if reset:
            self.reset()

        return mix_f

    def add_false_negative(self, labels):
        """Count every label in `labels` as one additional gold item of that class."""
        for lab in labels:
            self._true_sum[lab] += 1
|
sftp/metrics/srl_metrics.py
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import *
|
2 |
+
|
3 |
+
from allennlp.training.metrics import Metric
|
4 |
+
from overrides import overrides
|
5 |
+
import numpy as np
|
6 |
+
import logging
|
7 |
+
|
8 |
+
from .base_f import BaseF
|
9 |
+
from ..utils import Span, max_match
|
10 |
+
|
11 |
+
logger = logging.getLogger('srl_metric')
|
12 |
+
|
13 |
+
|
14 |
+
@Metric.register('srl')
class SRLMetric(Metric):
    """
    Trigger and argument scores for two-level span trees (root -> events -> arguments).

    Four `BaseF` counters are kept: `tri-i` / `tri-c` for triggers without / with label check,
    and `arg-i` / `arg-c` for arguments without / with label check.
    """

    def __init__(self, check_type: Optional[bool] = None):
        """
        :param check_type: Deprecated and ignored; a warning is logged when it is given.
        """
        self.tri_i = BaseF('tri-i')
        self.tri_c = BaseF('tri-c')
        self.arg_i = BaseF('arg-i')
        self.arg_c = BaseF('arg-c')
        if check_type is not None:
            logger.warning('Check type argument is deprecated.')

    def reset(self) -> None:
        """Reset all four sub-metrics."""
        for metric in [self.tri_i, self.tri_c, self.arg_i, self.arg_c]:
            metric.reset()

    def get_metric(self, reset: bool) -> Dict[str, Any]:
        """
        :param reset: Passed on to every sub-metric.
        :return: The union of the score dicts of the four sub-metrics.
        """
        ret = dict()
        for metric in [self.tri_i, self.tri_c, self.arg_i, self.arg_c]:
            ret.update(metric.get_metric(reset))
        return ret

    @overrides
    def __call__(self, prediction: Span, gold: Span):
        """Update all counters with one (prediction, gold) pair of span trees."""
        self.with_label_event(prediction, gold)
        self.without_label_event(prediction, gold)
        # Argument counters (arg_i / arg_c) are filled by `tuple_eval`; `with_label_arg` and
        # `without_label_arg` update the same counters and are therefore not called here.
        self.tuple_eval(prediction, gold)

    def tuple_eval(self, prediction: Span, gold: Span):
        """
        Score arguments as tuples and update `arg_i` / `arg_c`.

        Labeled tuples are (event label, argument boundary, argument label); unlabeled tuples drop
        the argument label. The result of `max_match` over the pairwise-equality matrix is used as
        the true-positive count.
        """
        def extract_tuples(vr: Span, parent_boundary: bool):
            labeled, unlabeled = list(), list()
            for event in vr:
                for arg in event:
                    if parent_boundary:
                        labeled.append((event.boundary, event.label, arg.boundary, arg.label))
                        unlabeled.append((event.boundary, event.label, arg.boundary))
                    else:
                        labeled.append((event.label, arg.boundary, arg.label))
                        unlabeled.append((event.label, arg.boundary))
            return labeled, unlabeled

        def equal_matrix(l1, l2):
            # `np.int` was an alias of the builtin `int` and no longer exists in recent NumPy.
            return np.array([[e1 == e2 for e2 in l2] for e1 in l1], dtype=int)

        pred_label, pred_unlabel = extract_tuples(prediction, False)
        gold_label, gold_unlabel = extract_tuples(gold, False)

        if len(pred_label) == 0 or len(gold_label) == 0:
            arg_c_tp = arg_i_tp = 0
        else:
            label_bipartite = equal_matrix(pred_label, gold_label)
            unlabel_bipartite = equal_matrix(pred_unlabel, gold_unlabel)
            arg_c_tp, arg_i_tp = max_match(label_bipartite), max_match(unlabel_bipartite)

        # Argument count = all nodes minus the first-level children minus one more node
        # (presumably the root itself -- TODO confirm against Span.n_nodes).
        arg_c_fp = prediction.n_nodes - len(prediction) - 1 - arg_c_tp
        arg_c_fn = gold.n_nodes - len(gold) - 1 - arg_c_tp
        arg_i_fp = prediction.n_nodes - len(prediction) - 1 - arg_i_tp
        arg_i_fn = gold.n_nodes - len(gold) - 1 - arg_i_tp

        assert arg_i_tp >= 0 and arg_i_fn >= 0 and arg_i_fp >= 0
        self.arg_i.tp += arg_i_tp
        self.arg_i.fp += arg_i_fp
        self.arg_i.fn += arg_i_fn

        assert arg_c_tp >= 0 and arg_c_fn >= 0 and arg_c_fp >= 0
        self.arg_c.tp += arg_c_tp
        self.arg_c.fp += arg_c_fp
        self.arg_c.fn += arg_c_fn

    def with_label_event(self, prediction: Span, gold: Span):
        """Update `tri_c` using `prediction.match(gold, True, 2)`."""
        trigger_tp = prediction.match(gold, True, 2) - 1
        trigger_fp = len(prediction) - trigger_tp
        trigger_fn = len(gold) - trigger_tp
        assert trigger_fp >= 0 and trigger_fn >= 0 and trigger_tp >= 0
        self.tri_c.tp += trigger_tp
        self.tri_c.fp += trigger_fp
        self.tri_c.fn += trigger_fn

    def with_label_arg(self, prediction: Span, gold: Span):
        """Alternative update of `arg_c` based on `Span.match`; not called from `__call__`."""
        trigger_tp = prediction.match(gold, True, 2) - 1
        role_tp = prediction.match(gold, True, ignore_parent_boundary=True) - 1 - trigger_tp
        role_fp = (prediction.n_nodes - 1 - len(prediction)) - role_tp
        role_fn = (gold.n_nodes - 1 - len(gold)) - role_tp
        assert role_fp >= 0 and role_fn >= 0 and role_tp >= 0
        self.arg_c.tp += role_tp
        self.arg_c.fp += role_fp
        self.arg_c.fn += role_fn

    def without_label_event(self, prediction: Span, gold: Span):
        """Update `tri_i` using `prediction.match(gold, False, 2)`."""
        tri_i_tp = prediction.match(gold, False, 2) - 1
        tri_i_fp = len(prediction) - tri_i_tp
        tri_i_fn = len(gold) - tri_i_tp
        assert tri_i_tp >= 0 and tri_i_fp >= 0 and tri_i_fn >= 0
        self.tri_i.tp += tri_i_tp
        self.tri_i.fp += tri_i_fp
        self.tri_i.fn += tri_i_fn

    def without_label_arg(self, prediction: Span, gold: Span):
        """Alternative update of `arg_i` based on `Span.match`; not called from `__call__`."""
        arg_i_tp = 0
        matched_pairs: List[Tuple[Span, Span]] = list()
        n_gold_arg, n_pred_arg = gold.n_nodes - len(gold) - 1, prediction.n_nodes - len(prediction) - 1
        # Work on copies because matched events are removed from the trees below.
        prediction, gold = prediction.clone(), gold.clone()
        for p in prediction:
            for g in gold:
                if p.match(g, True, 1) == 1:
                    arg_i_tp += (p.match(g, False) - 1)
                    matched_pairs.append((p, g))
                    break
        for p, g in matched_pairs:
            prediction.remove_child(p)
            gold.remove_child(g)

        # Remaining events: pair up those with equal labels through `max_match`.
        # `np.int` was an alias of the builtin `int` and no longer exists in recent NumPy.
        sub_matches = np.zeros([len(prediction), len(gold)], int)
        for p_idx, p in enumerate(prediction):
            for g_idx, g in enumerate(gold):
                if p.label == g.label:
                    sub_matches[p_idx, g_idx] = p.match(g, False, -1, True)
        arg_i_tp += max_match(sub_matches)

        arg_i_fp = n_pred_arg - arg_i_tp
        arg_i_fn = n_gold_arg - arg_i_tp
        assert arg_i_tp >= 0 and arg_i_fn >= 0 and arg_i_fp >= 0

        self.arg_i.tp += arg_i_tp
        self.arg_i.fp += arg_i_fp
        self.arg_i.fn += arg_i_fn
|
sftp/models/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from sftp.models.span_model import SpanModel
|
sftp/models/span_model.py
ADDED
@@ -0,0 +1,362 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import *
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from allennlp.common.from_params import Params, T, pop_and_construct_arg
|
6 |
+
from allennlp.data.vocabulary import Vocabulary, DEFAULT_PADDING_TOKEN, DEFAULT_OOV_TOKEN
|
7 |
+
from allennlp.models.model import Model
|
8 |
+
from allennlp.modules import TextFieldEmbedder
|
9 |
+
from allennlp.modules.seq2seq_encoders.pytorch_seq2seq_wrapper import Seq2SeqEncoder
|
10 |
+
from allennlp.modules.span_extractors import SpanExtractor
|
11 |
+
from allennlp.training.metrics import Metric
|
12 |
+
|
13 |
+
from ..metrics import ExactMatch
|
14 |
+
from ..modules import SpanFinder, SpanTyping
|
15 |
+
from ..utils import num2mask, VIRTUAL_ROOT, Span, tensor2span
|
16 |
+
|
17 |
+
|
18 |
+
@Model.register("span")
|
19 |
+
class SpanModel(Model):
|
20 |
+
"""
|
21 |
+
Identify/Find spans; link them as a tree; label them.
|
22 |
+
"""
|
23 |
+
default_predictor = 'span'
|
24 |
+
|
25 |
+
def __init__(
        self,
        vocab: Vocabulary,

        # Modules
        word_embedding: TextFieldEmbedder,
        span_extractor: SpanExtractor,
        span_finder: SpanFinder,
        span_typing: SpanTyping,

        # Config
        typing_loss_factor: float = 1.,
        max_recursion_depth: int = -1,
        max_decoding_spans: int = -1,
        debug: bool = False,

        # Ontology Constraints
        ontology_path: Optional[str] = None,

        # Metrics
        metrics: Optional[List[Metric]] = None,
) -> None:
    """
    Note for jsonnet file: it doesn't strictly follow the init examples of every module for that we override
    the from_params method.
    You can either check the SpanModel.from_params or the example jsonnet file.
    :param vocab: No need to specify.
    ## Modules
    :param word_embedding: Refer to the module doc.
    :param span_extractor: Refer to the module doc.
    :param span_finder: Refer to the module doc.
    :param span_typing: Refer to the module doc.
    ## Configs
    :param typing_loss_factor: loss = span_finder_loss + span_typing_loss * typing_loss_factor
    :param max_recursion_depth: Maximum tree depth for inference. E.g., 1 for shallow event typing, 2 for SRL,
        -1 (unlimited) for dependency parsing.
    :param max_decoding_spans: Maximum spans for inference. -1 for unlimited.
    :param debug: Only stored as `self.debug`.
    ## Ontology Constraints
    :param ontology_path: If given and the path exists, it is passed to `span_typing.load_ontology`.
    ## Metrics
    :param metrics: Extra metrics, appended after the two built-in ExactMatch metrics.
    """
    # Looked up before the parent constructor runs, as in the original ordering.
    self._pad_idx = vocab.get_token_index(DEFAULT_PADDING_TOKEN, 'token')
    self._null_idx = vocab.get_token_index(DEFAULT_OOV_TOKEN, 'span_label')
    super().__init__(vocab)

    # Sub-modules
    self.word_embedding = word_embedding
    self._span_finder = span_finder
    self._span_extractor = span_extractor
    self._span_typing = span_typing

    # Decoding / loss configuration
    self._max_decoding_spans = max_decoding_spans
    self._typing_loss_factor = typing_loss_factor
    self._max_recursion_depth = max_recursion_depth
    self.debug = debug

    # Labeled and unlabeled exact match are always tracked; user metrics follow.
    builtin_metrics = [ExactMatch(True), ExactMatch(False)]
    self.metrics = builtin_metrics + (list(metrics) if metrics is not None else [])

    if ontology_path is not None:
        if os.path.exists(ontology_path):
            self._span_typing.load_ontology(ontology_path, self.vocab)
|
84 |
+
|
85 |
+
def forward(
        self,
        tokens: Dict[str, Dict[str, torch.Tensor]],

        span_boundary: Optional[torch.Tensor] = None,
        span_labels: Optional[torch.Tensor] = None,
        parent_indices: Optional[torch.Tensor] = None,
        parent_mask: Optional[torch.Tensor] = None,

        bio_seqs: Optional[torch.Tensor] = None,
        raw_inputs: Optional[dict] = None,
        meta: Optional[dict] = None,

        **extra
) -> Dict[str, torch.Tensor]:
    """
    Compute the training loss and/or decode span trees.

    For training, provide all below.
    For inference, `tokens` and `raw_inputs` are required: decoding reads the `offsets` entry of
    every item in `raw_inputs` to re-index the predicted spans.

    :param tokens: Indexed input sentence. Shape: [batch, token]

    :param span_boundary: Start and end indices for every span. Note this includes both parent and
        non-parent spans. Shape: [batch, span, 2]. For the last dim, [0] is start idx and [1] is end idx.
    :param span_labels: Indexed label for spans, including parent and non-parent ones. Shape: [batch, span]
        Padding entries equal to -1 are overwritten with 0 in place.
    :param parent_indices: The parent span idx of every span. Shape: [batch, span]
    :param parent_mask: True if this span is a parent. Shape: [batch, span]

    :param bio_seqs: Shape [batch, parent, token, 3]
    :param raw_inputs: Per-example raw inputs; iterated in parallel with the predictions. Items are
        read for `offsets` and, when present, `spans` (the gold tree used for the metrics).

    :param meta: Meta information. Will be copied to the outputs.
    :param extra: Passed on to `self.inference`.

    :return:
        - loss: training loss (only when training or evaluating with gold labels)
        - prediction: Predicted spans (only when not training, or when no labels are given)
        - meta: Meta info copied from input
        - inputs: Input sentences and spans (if exist)
    """
    ret = {'inputs': raw_inputs, 'meta': meta or dict()}

    is_eval = span_labels is not None and not self.training  # evaluation on dev set
    is_test = span_labels is None  # test on test set
    # Shape [batch]
    # Must be computed before the -1 padding in `span_labels` is overwritten below.
    num_spans = (span_labels != -1).sum(1) if span_labels is not None else None
    num_words = tokens['pieces']['mask'].sum(1)
    # Shape [batch, word, token_dim]
    token_vec = self.word_embedding(tokens)

    if span_labels is not None:
        # Revise the padding value from -1 to 0
        span_labels[span_labels == -1] = 0

    # Calculate Loss
    if self.training or is_eval:
        # NOTE(review): presumably one vector per span, i.e. [batch, span, span_dim] -- confirm
        # against the span extractor.
        span_vec = self._span_extractor(token_vec, span_boundary)
        finder_rst = self._span_finder(
            token_vec, num2mask(num_words), span_vec, num2mask(num_spans), span_labels, parent_indices,
            parent_mask, bio_seqs
        )
        typing_rst = self._span_typing(span_vec, parent_indices, span_labels)
        ret['loss'] = finder_rst['loss'] + typing_rst['loss'] * self._typing_loss_factor

    # Decoding
    if is_eval or is_test:
        pred_span_boundary, pred_span_labels, pred_parent_indices, pred_cursor, pred_label_confidence \
            = self.inference(num_words, token_vec, **extra)
        prediction = self.post_process_pred(
            pred_span_boundary, pred_span_labels, pred_parent_indices, pred_cursor, pred_label_confidence
        )
        # Re-index every predicted tree with the offsets of its own example, then drop overlaps.
        for pred, raw_in in zip(prediction, raw_inputs):
            pred.re_index(raw_in['offsets'], True, True, True)
            pred.remove_overlapping()
        ret['prediction'] = prediction
        # Gold spans are only present when the reader was given annotations; only the first
        # example is inspected to decide whether the metrics are updated.
        if 'spans' in raw_inputs[0]:
            for pred, raw_in in zip(prediction, raw_inputs):
                gold = raw_in['spans']
                for metric in self.metrics:
                    metric(pred, gold)

    return ret
|
166 |
+
|
167 |
+
def inference(
        self,
        num_words: torch.Tensor,
        token_vec: torch.Tensor,
        **auxiliaries
):
    """
    Decode a tree of spans one layer at a time, starting from a virtual root stored in slot 0.

    Every iteration asks the span finder for the children of the spans that were added in the
    previous iteration, then re-types all spans collected so far.

    :param num_words: Number of words of every sentence. Shape [batch]
    :param token_vec: Vector representation of tokens. Shape [batch, word, token_dim]
    :param auxiliaries: Environment variables forwarded to the span finder's inference handler.
    :return: A 5-tuple of
        span boundary, Shape [batch, max_decoding_spans, 2];
        span labels, Shape [batch, max_decoding_spans];
        parent indices (in the decoded-span indexing space), Shape [batch, max_decoding_spans];
        cursor, i.e. the number of slots used per sentence (root included), Shape [batch];
        label confidence, Shape [batch, max_decoding_spans].
    """
    batch_size = num_words.shape[0]
    span_budget = self._max_decoding_spans
    root_label = self.vocab.get_token_index(VIRTUAL_ROOT, 'span_label')
    table_shape = [batch_size, span_budget]

    # Decoding state. There is no gold span list at inference time, so every "span index"
    # below refers to a slot of these tables.
    pred_span_boundary = num_words.new_zeros(table_shape + [2])
    pred_span_labels = num_words.new_full(table_shape, root_label)
    pred_label_confidence = num_words.new_zeros(table_shape)
    pred_parent_indices = num_words.new_zeros(table_shape)
    # Slots flagged True are expanded (treated as parents) in the next round; start with the root.
    pred_parent_mask = num_words.new_zeros(table_shape, dtype=torch.bool)
    pred_parent_mask[:, 0] = True
    # Next free slot of every sentence; slot 0 is taken by the root.
    pred_cursor = num_words.new_ones([batch_size])

    # The handler ignores variables it does not need, so the union of what the different
    # modules require can be passed.
    find_children = self._span_finder.inference_forward_handler(
        token_vec, num2mask(num_words), self._span_extractor, **auxiliaries
    )

    # One iteration == one layer of the tree; a sentence may have zero or several parents per layer.
    for _ in range(self._max_recursion_depth):
        cursor_snapshot = pred_cursor.clone()
        find_children(
            pred_span_boundary, pred_span_labels, pred_parent_mask, pred_parent_indices, pred_cursor
        )
        # Old spans get re-typed as well; their labels are not expected to change.
        span_repr = self._span_extractor(token_vec, pred_span_boundary)
        typing_out = self._span_typing(span_repr, pred_parent_indices, pred_span_labels, True)
        pred_span_labels = typing_out['prediction']
        pred_label_confidence = typing_out['label_confidence']
        pred_span_labels[:, 0] = root_label
        # Slots filled during this round are exactly those between the old and the new cursor.
        pred_parent_mask = num2mask(cursor_snapshot, span_budget) ^ num2mask(pred_cursor, span_budget)

        # Stop early when every sentence ran out of slots or when no new span was produced;
        # the recursion limit is enforced by the loop itself.
        if (pred_cursor == span_budget).all() or pred_parent_mask.sum() == 0:
            break

    return pred_span_boundary, pred_span_labels, pred_parent_indices, pred_cursor, pred_label_confidence
|
224 |
+
|
225 |
+
def one_step_prediction(
        self,
        tokens: Dict[str, Dict[str, torch.Tensor]],
        parent_boundary: torch.Tensor,
        parent_labels: torch.Tensor,
):
    """
    Predict the children (and their labels) of one given parent span per sentence.

    Restriction: every sentence contains exactly 1 parent.
    To predict a whole tree from its root efficiently, use the `forward' method instead.

    :param tokens: See forward.
    :param parent_boundary: Pairs of (start_idx, end_idx) for parents. Shape [batch, 2]
    :param parent_labels: Labels for parents. Shape [batch]
        Note: If `no_label' is on in span_finder module, this will be ignored.
    :return: A 4-tuple of
        children_boundary: (start_idx, end_idx) for every child span. Shape [batch, children, 2]
        children_labels: Label for every child span. Shape [batch, children]
        num_children: The number of children predicted for parent/batch. Shape [batch]
            Tips: You can use num2mask method to convert this to bool tensor mask.
        children_distribution: The `distribution' output of the span typing module, restricted to
            the child slots.
        `children' is the largest number of children in the batch; entries past a sentence's own
        num_children are padding and should be masked out.
    """
    num_words = tokens['pieces']['mask'].sum(1)
    # Shape [batch, word, token_dim]
    token_vec = self.word_embedding(tokens)
    slot_shape = [token_vec.shape[0], self._max_decoding_spans]

    # Slot 0 holds the given parent; the span finder appends children from slot 1 onwards.
    pred_span_boundary = num_words.new_zeros(slot_shape + [2])
    pred_span_boundary[:, 0] = parent_boundary
    pred_span_labels = num_words.new_full(slot_shape, self._null_idx)
    pred_span_labels[:, 0] = parent_labels
    pred_parent_mask = num_words.new_zeros(pred_span_labels.shape, dtype=torch.bool)
    pred_parent_mask[:, 0] = True
    pred_parent_indices = num_words.new_zeros(slot_shape)
    pred_cursor = num_words.new_ones([slot_shape[0]])

    find_children = self._span_finder.inference_forward_handler(
        token_vec, num2mask(num_words), self._span_extractor
    )
    # Fills boundary / parent index tables and advances the cursor in place.
    find_children(
        pred_span_boundary, pred_span_labels, pred_parent_mask, pred_parent_indices, pred_cursor
    )
    span_repr = self._span_extractor(token_vec, pred_span_boundary)
    typing_out = self._span_typing(span_repr, pred_parent_indices, pred_span_labels, True)
    pred_span_labels = typing_out['prediction']

    # Drop the parent slot and trim to the widest sentence of the batch.
    num_children = pred_cursor - 1
    child_slots = slice(1, int(num_children.max()) + 1)
    children_boundary = pred_span_boundary[:, child_slots]
    children_labels = pred_span_labels[:, child_slots]
    children_distribution = typing_out['distribution'][:, child_slots]
    return children_boundary, children_labels, num_children, children_distribution
|
283 |
+
|
284 |
+
def post_process_pred(
        self, span_boundary, span_labels, parent_indices, num_spans, label_confidence
) -> List[Span]:
    """
    Convert the decoded tensors of a batch into Span objects via `tensor2span`.

    Labels are mapped back to strings with the 'span_label' vocabulary namespace, and the
    null label index is passed to `tensor2span` as the label to ignore.
    """
    idx2label = self.vocab.get_index_to_token_vocabulary('span_label')
    return tensor2span(
        span_boundary, span_labels, parent_indices, num_spans, label_confidence,
        idx2label,
        label_ignore=[self._null_idx],
    )
|
293 |
+
|
294 |
+
def get_metrics(self, reset: bool = False) -> Dict[str, float]:
    """
    Collect the metrics of the model and of its span finder / span typing sub-modules.

    :param reset: Forwarded to every metric. The model-level metrics in `self.metrics` are only
        queried (and reported) when this is True.
    :return: A single flat dict merging all reported scores.
    """
    collected: Dict[str, float] = {}
    # Model-level metrics are skipped entirely unless a reset is requested.
    for span_metric in (self.metrics if reset else ()):
        collected.update(span_metric.get_metric(reset))
    # Note the differing method names: `get_metrics` on the finder, `get_metric` on the typing module.
    collected.update(self._span_finder.get_metrics(reset))
    collected.update(self._span_typing.get_metric(reset))
    return collected
|
302 |
+
|
303 |
+
@classmethod
def from_params(
        cls: Type[T],
        params: Params,
        constructor_to_call: Callable[..., T] = None,
        constructor_to_inspect: Callable[..., T] = None,
        **extras,
) -> T:
    """
    Build the model while wiring up the dimensions that sub-modules depend on.

    The input size of one module may be determined by the output size of another one (e.g. the
    BIO encoder consumes span, token and label features), so the sub-modules are constructed
    here, in dependency order, and handed to the regular constructor through `extras`.
    A single label embedding is created and shared by the span finder and the span typing module.
    """
    vocab = extras['vocab']
    word_embedding = pop_and_construct_arg(
        'SpanModel', 'word_embedding', TextFieldEmbedder, None, params, **extras
    )
    label_dim = params.pop('label_dim')
    token_emb_dim = word_embedding.get_output_dim()
    span_extractor = pop_and_construct_arg(
        'SpanModel', 'span_extractor', SpanExtractor, None, params, input_dim=token_emb_dim, **extras
    )
    span_dim = span_extractor.get_output_dim()
    n_label = vocab.get_vocab_size('span_label')
    label_embedding = torch.nn.Embedding(n_label, label_dim)
    extras['label_emb'] = label_embedding

    # --- span finder ---
    if params.get('span_finder').get('type') == 'bio':
        # The BIO encoder reads [parent span ; token ; parent label] features.
        finder_in_dim = span_dim + token_emb_dim + label_dim
        # Both spellings of the input size are supplied so that either keyword can be picked up
        # by the configured encoder type.
        bio_encoder = Seq2SeqEncoder.from_params(
            params['span_finder'].pop('bio_encoder'),
            input_size=finder_in_dim,
            input_dim=finder_in_dim,
            **extras
        )
        span_finder = SpanFinder.from_params(
            params.pop('span_finder'), bio_encoder=bio_encoder, **extras
        )
    else:
        span_finder = pop_and_construct_arg(
            'SpanModel', 'span_finder', SpanFinder, None, params, **extras
        )
        span_finder.label_emb = label_embedding
    extras['span_finder'] = span_finder

    # --- span typing ---
    if params.get('span_typing').get('type') == 'mlp':
        ignored_labels = [
            vocab.get_token_index(lti, 'span_label')
            for lti in [DEFAULT_OOV_TOKEN, DEFAULT_PADDING_TOKEN]
        ]
        # The MLP reads [parent label ; parent span ; child span] features.
        span_typing = SpanTyping.from_params(
            params.pop('span_typing'),
            input_dim=span_dim * 2 + label_dim,
            n_category=n_label,
            label_to_ignore=ignored_labels,
            **extras
        )
    else:
        span_typing = pop_and_construct_arg(
            'SpanModel', 'span_typing', SpanTyping, None, params, **extras
        )
        span_typing.label_emb = label_embedding
    extras['span_typing'] = span_typing

    return super().from_params(
        params,
        word_embedding=word_embedding,
        span_extractor=span_extractor,
        **extras
    )
|
sftp/modules/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .span_extractor import ComboSpanExtractor
|
2 |
+
from .span_finder import SpanFinder, BIOSpanFinder
|
3 |
+
from .span_typing import MLPSpanTyping, SpanTyping
|
4 |
+
from .smooth_crf import SmoothCRF
|
sftp/modules/smooth_crf.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from allennlp.modules.conditional_random_field import ConditionalRandomField
|
3 |
+
from allennlp.nn.util import logsumexp
|
4 |
+
from overrides import overrides
|
5 |
+
|
6 |
+
|
7 |
+
class SmoothCRF(ConditionalRandomField):
    """
    A CRF whose training signal may be a soft tag distribution instead of hard tag ids.

    With hard tags (`tags` of shape [batch, token]) this behaves exactly like the parent class.
    With soft tags (`tags` of shape [batch, token, tag]) the numerator of the likelihood is a
    forward pass in which every tag at every step is weighted by its target probability.
    """

    @overrides
    def forward(self, inputs: torch.Tensor, tags: torch.Tensor, mask: torch.Tensor = None):
        """
        Compute the log-likelihood of `tags`, summed over the batch.

        :param inputs: Shape [batch, token, tag]
        :param tags: Shape [batch, token] or [batch, token, tag]
        :param mask: Shape [batch, token]
        :return: Scalar tensor. In smooth mode this is sum(log_numerator - log_denominator).
        """
        if mask is None:
            # The mask is always [batch, token]; soft tags carry an extra trailing tag axis
            # that must not leak into the mask shape.
            mask = tags.new_ones(tags.shape[:2], dtype=torch.bool)
        mask = mask.to(dtype=torch.bool)
        if tags.dim() == 2:
            return super(SmoothCRF, self).forward(inputs, tags, mask)

        # smooth mode
        log_denominator = self._input_likelihood(inputs, mask)
        log_numerator = self._smooth_joint_likelihood(inputs, tags, mask)

        return torch.sum(log_numerator - log_denominator)

    def _smooth_joint_likelihood(
            self, logits: torch.Tensor, soft_tags: torch.Tensor, mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Forward algorithm in log space with every tag weighted by its soft target probability.

        :param logits: Shape [batch, token, tag]
        :param soft_tags: Shape [batch, token, tag]. Target weights; values below 1e-30 are
            clamped so that their log stays finite. The input tensor itself is not modified.
        :param mask: Shape [batch, token]. Bool.
        :return: Shape [batch].
        """
        batch_size, sequence_length, num_tags = logits.size()

        epsilon = 1e-30
        soft_tags = soft_tags.clone()
        soft_tags[soft_tags < epsilon] = epsilon

        # Transpose batch size and sequence dimensions
        mask = mask.transpose(0, 1).contiguous()
        logits = logits.transpose(0, 1).contiguous()
        soft_tags = soft_tags.transpose(0, 1).contiguous()

        # Initial alpha is the (batch_size, num_tags) tensor of likelihoods combining the
        # transitions to the initial states and the logits for the first timestep.
        # The soft-tag weight enters in log space, exactly as in every later timestep.
        alpha = logits[0] + soft_tags[0].log()
        if self.include_start_end_transitions:
            alpha = self.start_transitions.view(1, num_tags) + alpha

        # For each i we compute logits for the transitions from timestep i-1 to timestep i.
        # We do so in a (batch_size, num_tags, num_tags) tensor where the axes are
        # (instance, current_tag, next_tag)
        for i in range(1, sequence_length):
            # The emit scores are for time i ("next_tag") so we broadcast along the current_tag axis.
            emit_scores = logits[i].view(batch_size, 1, num_tags)
            # Transition scores are (current_tag, next_tag) so we broadcast along the instance axis.
            transition_scores = self.transitions.view(1, num_tags, num_tags)
            # Alpha is for the current_tag, so we broadcast along the next_tag axis.
            broadcast_alpha = alpha.view(batch_size, num_tags, 1)

            # Add all the scores together and logexp over the current_tag axis.
            inner = broadcast_alpha + emit_scores + transition_scores + soft_tags[i].log().unsqueeze(1)

            # In valid positions (mask == True) we want to take the logsumexp over the current_tag
            # dimension of `inner`. Otherwise (mask == False) we want to retain the previous alpha.
            alpha = logsumexp(inner, 1) * mask[i].view(batch_size, 1) + alpha * (
                ~mask[i]
            ).view(batch_size, 1)

        # Every sequence needs to end with a transition to the stop_tag.
        if self.include_start_end_transitions:
            stops = alpha + self.end_transitions.view(1, num_tags)
        else:
            stops = alpha

        # Finally we log_sum_exp along the num_tags dim, result is (batch_size,)
        return logsumexp(stops)
|
sftp/modules/span_extractor/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .combo import ComboSpanExtractor
|
sftp/modules/span_extractor/combo.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import *
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from allennlp.modules.span_extractors import SpanExtractor
|
5 |
+
|
6 |
+
|
7 |
+
@SpanExtractor.register('combo')
class ComboSpanExtractor(SpanExtractor):
    """
    Run several span extractors on the same input and concatenate their outputs feature-wise.
    """
    def __init__(self, input_dim: int, sub_extractors: List[SpanExtractor]):
        """
        :param input_dim: Dimension of the token vectors fed to every sub-extractor.
        :param sub_extractors: The extractors to combine; their outputs are concatenated in order.
        """
        super().__init__()
        self.sub_extractors = sub_extractors
        # A plain list does not register its items as sub-modules, so each one is registered by
        # hand. These names end up in the state dict keys; keep them stable for saved weights.
        for i, sub in enumerate(sub_extractors):
            self.add_module(f'SpanExtractor-{i+1}', sub)
        self.input_dim = input_dim

    def get_input_dim(self) -> int:
        """Return the dimension of the expected token vectors."""
        return self.input_dim

    def get_output_dim(self) -> int:
        """Return the total output dimension, i.e. the sum over all sub-extractors."""
        return sum(sub.get_output_dim() for sub in self.sub_extractors)

    def forward(
            self,
            sequence_tensor: torch.FloatTensor,
            span_indices: torch.LongTensor,
            sequence_mask: torch.BoolTensor = None,
            span_indices_mask: torch.BoolTensor = None,
    ):
        """
        Apply every sub-extractor and concatenate the span representations.

        :param sequence_tensor: Token vectors, handed unchanged to every sub-extractor.
        :param span_indices: Span boundaries, handed unchanged to every sub-extractor.
        :param sequence_mask: Optional token mask. It is passed through to the sub-extractors
            (it used to be accepted but dropped); leaving it as None behaves as before.
        :param span_indices_mask: Optional span mask, handed unchanged to every sub-extractor.
        :return: Concatenation of all sub-extractor outputs along dim 2.
        """
        outputs = [
            sub(
                sequence_tensor=sequence_tensor,
                span_indices=span_indices,
                sequence_mask=sequence_mask,
                span_indices_mask=span_indices_mask
            ) for sub in self.sub_extractors
        ]
        return torch.cat(outputs, dim=2)
|
sftp/modules/span_finder/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .bio_span_finder import BIOSpanFinder
|
2 |
+
from .span_finder import SpanFinder
|
sftp/modules/span_finder/bio_span_finder.py
ADDED
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import *
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from allennlp.data.dataset_readers.dataset_utils.span_utils import bio_tags_to_spans
|
5 |
+
from allennlp.modules.seq2seq_encoders import Seq2SeqEncoder
|
6 |
+
from allennlp.modules.span_extractors import SpanExtractor
|
7 |
+
from allennlp.training.metrics import FBetaMeasure
|
8 |
+
|
9 |
+
from ..smooth_crf import SmoothCRF
|
10 |
+
from .span_finder import SpanFinder
|
11 |
+
from ...utils import num2mask, mask2idx, BIO
|
12 |
+
|
13 |
+
|
14 |
+
@SpanFinder.register("bio")
class BIOSpanFinder(SpanFinder):
    """
    Train BIO representations for span finding.

    For every parent span, all tokens of the sentence are tagged with B/I/O by a sequence encoder
    followed by a linear layer and a CRF; the tagged spans are the children of that parent.
    """
    def __init__(
            self,
            bio_encoder: Seq2SeqEncoder,
            label_emb: torch.nn.Embedding,
            no_label: bool = True,
    ):
        """
        :param bio_encoder: Encoder applied to the per-token features of every (sentence, parent) pair.
        :param label_emb: Embeds labels to vectors.
        :param no_label: If True, will not use input labels as features and use all 0 vector instead.
        """
        super().__init__(no_label)
        self.bio_encoder = bio_encoder
        self.label_emb = label_emb

        # 3 == number of BIO tags
        self.classifier = torch.nn.Linear(bio_encoder.get_output_dim(), 3)
        self.crf = SmoothCRF(3)

        # Micro F1 restricted to the B and I tags.
        self.fb_measure = FBetaMeasure(1., 'micro', [BIO.index('B'), BIO.index('I')])

    def forward(
            self,
            token_vec: torch.Tensor,
            token_mask: torch.Tensor,
            span_vec: torch.Tensor,
            span_mask: Optional[torch.Tensor] = None,  # Do not need to provide
            span_labels: Optional[torch.Tensor] = None,  # Do not need to provide
            parent_indices: Optional[torch.Tensor] = None,  # Do not need to provide
            parent_mask: Optional[torch.Tensor] = None,
            bio_seqs: Optional[torch.Tensor] = None,
            prediction: bool = False,
            **extra
    ) -> Dict[str, torch.Tensor]:
        """
        See doc of SpanFinder.

        :param span_labels: Hard labels (int64, Shape [batch, span]) or soft labels (any other
            dtype, multiplied with the label embedding matrix). May be None when `no_label' is on,
            since the label features are replaced by zeros in that case anyway.
        :param parent_mask: True for parent spans. Shape [batch, span]. Required.
        :param bio_seqs: Required for training (prediction=False).
        :param extra: Accepted for interface compatibility; not used in this method.
        :return:
            - loss (only if prediction is False): negative CRF log-likelihood.
            - prediction (only if prediction is True): tag indices, Shape [batch, parent, token],
              padded with the index of 'O'. Built from Python lists, so it is not moved to the
              device of the inputs.
        :raises ValueError: If span_labels is None although labels are used as features.
        """
        ret = dict()

        distinct_parent_indices, num_parents = mask2idx(parent_mask)
        n_batch, n_parent = distinct_parent_indices.shape
        n_token = token_vec.shape[1]
        # Shape [batch, parent, token_dim]
        parent_span_features = span_vec.gather(
            1, distinct_parent_indices.unsqueeze(2).expand(-1, -1, span_vec.shape[2])
        )
        if span_labels is None:
            if not self._no_label:
                raise ValueError('span_labels must be provided when no_label is False.')
            # Labels are not used as features: an all-zero stand-in of the right shape suffices.
            label_features = self.label_emb.weight.new_zeros(
                (n_batch, span_vec.shape[1], self.label_emb.embedding_dim)
            )
        else:
            is_soft = span_labels.dtype != torch.int64
            label_features = span_labels @ self.label_emb.weight if is_soft else self.label_emb(span_labels)
            if self._no_label:
                label_features = label_features.zero_()
        # Shape [batch, parent, label_dim]
        parent_label_features = label_features.gather(
            1, distinct_parent_indices.unsqueeze(2).expand(-1, -1, label_features.shape[2])
        )
        # Shape [batch, parent, token, span_dim + token_dim + label_dim]
        encoder_inputs = torch.cat([
            parent_span_features.unsqueeze(2).expand(-1, -1, n_token, -1),
            token_vec.unsqueeze(1).expand(-1, n_parent, -1, -1),
            parent_label_features.unsqueeze(2).expand(-1, -1, n_token, -1),
        ], dim=3)
        # Every (sentence, parent) pair becomes one sequence for the encoder.
        encoder_inputs = encoder_inputs.reshape(n_batch * n_parent, n_token, -1)

        # Shape [batch, parent]. Considers batches may have fewer seqs.
        seq_mask = num2mask(num_parents)
        # Shape [batch, parent, token]. Also considers batches may have fewer tokens.
        token_mask = seq_mask.unsqueeze(2).expand(-1, -1, n_token) & token_mask.unsqueeze(1).expand(-1, n_parent, -1)

        class_in = self.bio_encoder(encoder_inputs, token_mask.flatten(0, 1))
        class_out = self.classifier(class_in).reshape(n_batch, n_parent, n_token, 3)

        if not prediction:
            # For training
            # We use `seq_mask` here because seq with length 0 is not acceptable.
            ret['loss'] = -self.crf(class_out[seq_mask], bio_seqs[seq_mask], token_mask[seq_mask])
            self.fb_measure(class_out[seq_mask], bio_seqs[seq_mask].max(2).indices, token_mask[seq_mask])
        else:
            # For prediction
            features_for_decode = class_out.clone().detach()
            decoded = self.crf.viterbi_tags(features_for_decode.flatten(0, 1), token_mask.flatten(0, 1))
            # Viterbi paths only cover unmasked tokens; pad them with 'O' up to n_token.
            pred_tag = torch.tensor(
                [path + [BIO.index('O')] * (n_token - len(path)) for path, _ in decoded]
            )
            pred_tag = pred_tag.reshape(n_batch, n_parent, n_token)
            ret['prediction'] = pred_tag

        return ret

    @staticmethod
    def bio2boundary(seqs) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Convert (arbitrarily nested) BIO tag sequences to span boundaries.

        :param seqs: A tensor or nested list whose innermost axis is a sequence of tag indices.
        :return:
            boundary: Long tensor of shape (*lens.shape, max_span, 2), zero padded, on CPU.
            lens: Number of spans found in every sequence.
        """
        def recursive_construct_spans(seqs_):
            """
            Helper function for bio2boundary
            Recursively convert seqs of integers to boundary indices.
            Return boundary indices and corresponding lens
            """
            if isinstance(seqs_, torch.Tensor):
                if seqs_.device.type == 'cuda':
                    seqs_ = seqs_.to(device='cpu')
                seqs_ = seqs_.tolist()
            if isinstance(seqs_[0], int):
                # Leaf: a single tag sequence.
                seqs_ = [BIO[i] for i in seqs_]
                span_boundary_list = bio_tags_to_spans(seqs_)
                # item[1] is the boundary pair of the span.
                return torch.tensor([item[1] for item in span_boundary_list]), len(span_boundary_list)
            span_boundary = list()
            lens_ = list()
            for seq in seqs_:
                one_bou, one_len = recursive_construct_spans(seq)
                span_boundary.append(one_bou)
                lens_.append(one_len)
            if isinstance(lens_[0], int):
                lens_ = torch.tensor(lens_)
            else:
                lens_ = torch.stack(lens_)
            return span_boundary, lens_

        boundary_list, lens = recursive_construct_spans(seqs)
        max_span = int(lens.max())
        boundary = torch.zeros((*lens.shape, max_span, 2), dtype=torch.long)

        def recursive_copy(list_var, tensor_var):
            """Copy the ragged nested list of boundary tensors into the padded tensor, in place."""
            if len(list_var) == 0:
                return
            if isinstance(list_var, torch.Tensor):
                tensor_var[:len(list_var)] = list_var
                return
            assert len(list_var) == len(tensor_var)
            for list_var_, tensor_var_ in zip(list_var, tensor_var):
                recursive_copy(list_var_, tensor_var_)

        recursive_copy(boundary_list, boundary)

        return boundary, lens

    def inference_forward_handler(
            self,
            token_vec: torch.Tensor,
            token_mask: torch.Tensor,
            span_extractor: SpanExtractor,
            **auxiliaries,
    ) -> Callable[[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], None]:
        """
        Refer to the doc of the SpanFinder for definition of this function.
        """

        def handler(
                span_boundary: torch.Tensor,
                span_labels: torch.Tensor,
                parent_mask: torch.Tensor,
                parent_indices: torch.Tensor,
                cursor: torch.Tensor,
        ):
            """
            Refer to the doc of the SpanFinder for definition of this function.

            Appends the children predicted for all spans flagged in `parent_mask` to
            `span_boundary` / `parent_indices` and advances `cursor`, all in place.
            """
            max_decoding_span = span_boundary.shape[1]
            # Shape [batch, span, token_dim]
            span_vec = span_extractor(token_vec, span_boundary)
            # Shape [batch, parent]
            parent_indices_at_span, _ = mask2idx(parent_mask)
            pred_bio = self(
                token_vec, token_mask, span_vec, None, span_labels, None, parent_mask, prediction=True
            )['prediction']
            # Shape [batch, parent, span, 2]; Shape [batch, parent]
            pred_boundary, pred_num = self.bio2boundary(pred_bio)
            if pred_boundary.device != span_boundary.device:
                pred_boundary = pred_boundary.to(device=span_boundary.device)
                pred_num = pred_num.to(device=span_boundary.device)
            # Shape [batch, parent, span]
            pred_mask = num2mask(pred_num)

            # Parent Loop
            for pred_boundary_parent, pred_mask_parent, parent_indices_parent \
                    in zip(pred_boundary.unbind(1), pred_mask.unbind(1), parent_indices_at_span.unbind(1)):
                # Child loop: the k-th predicted child of this parent, for all sentences at once.
                for pred_boundary_step, step_mask in zip(pred_boundary_parent.unbind(1), pred_mask_parent.unbind(1)):
                    # Sentences that already used up all their slots cannot take more spans.
                    step_mask &= cursor < max_decoding_span
                    # Write the parent index and the boundary at the cursor slot of every
                    # selected sentence, then move the cursor to the next free slot.
                    parent_indices[step_mask] = parent_indices[step_mask].scatter(
                        1,
                        cursor[step_mask].unsqueeze(1),
                        parent_indices_parent[step_mask].unsqueeze(1)
                    )
                    span_boundary[step_mask] = span_boundary[step_mask].scatter(
                        1,
                        cursor[step_mask].reshape(-1, 1, 1).expand(-1, -1, 2),
                        pred_boundary_step[step_mask].unsqueeze(1)
                    )
                    cursor[step_mask] += 1

        return handler

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Report the tag-level F-beta scores as percentages.

        :param reset: Forwarded to the metric. Precision and recall are only reported when True.
        """
        score = self.fb_measure.get_metric(reset)
        if reset:
            return {
                'finder_p': score['precision'] * 100,
                'finder_r': score['recall'] * 100,
                'finder_f': score['fscore'] * 100,
            }
        else:
            return {'finder_f': score['fscore'] * 100}
|
sftp/modules/span_finder/span_finder.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABC, abstractmethod
|
2 |
+
from typing import *
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from allennlp.common import Registrable
|
6 |
+
from allennlp.modules.span_extractors import SpanExtractor
|
7 |
+
|
8 |
+
|
9 |
+
class SpanFinder(Registrable, ABC, torch.nn.Module):
    """
    Abstract base for modules modelling p(child_span | parent_span [, parent_label]).

    Conditioning on the parent label is optional: parameters may be shared across tasks whose
    spans have similar semantics but whose label spaces differ.
    """
    def __init__(
            self,
            no_label: bool = True,
    ):
        """
        :param no_label: If True, input labels are not used as features; an all 0 vector is used instead.
        """
        super().__init__()
        self._no_label = no_label

    @abstractmethod
    def forward(
            self,
            token_vec: torch.Tensor,
            token_mask: torch.Tensor,
            span_vec: torch.Tensor,
            span_mask: Optional[torch.Tensor] = None,  # Do not need to provide
            span_labels: Optional[torch.Tensor] = None,  # Do not need to provide
            parent_indices: Optional[torch.Tensor] = None,  # Do not need to provide
            parent_mask: Optional[torch.Tensor] = None,
            bio_seqs: Optional[torch.Tensor] = None,
            prediction: bool = False,
            **extra
    ) -> Dict[str, torch.Tensor]:
        """
        Compute the training loss and/or the predictions of the span finder.

        :param token_vec: Token representations. Shape [batch, token, token_dim]
        :param token_mask: True at non-padding tokens.
        :param span_vec: Span representations. Shape [batch, span, token_dim]
        :param span_mask: True at non-padding spans. Shape [batch, span]
        :param span_labels: Labels of the spans. Shape [batch, span]
        :param parent_indices: Index of the parent of every span. Shape [batch, span]
        :param parent_mask: True at spans that act as parents. Shape [batch, span]
        :param bio_seqs: BIO sequences. Shape [batch, parent, token, 3]
        :param prediction: If True, no loss is returned and no metric is updated.
        :return: A dict that may contain
            loss: Training loss
            prediction: The predictions; the exact layout is defined by the implementation.
        """
        raise NotImplementedError

    @abstractmethod
    def inference_forward_handler(
            self,
            token_vec: torch.Tensor,
            token_mask: torch.Tensor,
            span_extractor: SpanExtractor,
            **auxiliaries,
    ) -> Callable[[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], None]:
        """
        Prepare a callable that performs one decoding step of
        p(child_span | parent_span [, parent_label]).

        :param token_vec: Token representations. Shape [batch, token, token_dim]
        :param token_mask: True at non-padding tokens.
        :param span_extractor: The span extractor owned by the model.
        :param auxiliaries: Environment variables. Callers may pass more than is needed;
            unused ones are ignored.
        :return:
            A closure taking, in this order:
                - span_boundary: Shape [batch, span, 2]
                - span_labels: Shape [batch, span]
                - parent_mask: Shape [batch, span]
                - parent_indices: Shape [batch, span]
                - cursor: Shape [batch]
            It returns nothing; all updates are made in place.
            Unlike in training, there is no gold span list here, so the span axis indexes the
            spans predicted so far.
        """
        raise NotImplementedError

    @abstractmethod
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return the metrics tracked by the span finder, keyed by name."""
        raise NotImplementedError
|
sftp/modules/span_typing/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .mlp_span_typing import MLPSpanTyping
|
2 |
+
from .span_typing import SpanTyping
|
sftp/modules/span_typing/mlp_span_typing.py
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import *
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch.nn import CrossEntropyLoss, KLDivLoss, LogSoftmax
|
5 |
+
|
6 |
+
from .span_typing import SpanTyping
|
7 |
+
|
8 |
+
|
9 |
+
@SpanTyping.register('mlp')
|
10 |
+
class MLPSpanTyping(SpanTyping):
|
11 |
+
"""
|
12 |
+
An MLP implementation for Span Typing.
|
13 |
+
"""
|
14 |
+
def __init__(
        self,
        input_dim: int,
        hidden_dims: List[int],
        label_emb: torch.nn.Embedding,
        n_category: int,
        label_to_ignore: Optional[List[int]] = None
):
    """
    Build a stack of linear layers: input_dim -> hidden_dims[0] -> ... -> n_category.

    :param input_dim: dim(parent_span) + dim(child_span) + dim(label_dim)
    :param hidden_dims: The dim of hidden layers of MLP.
    :param label_emb: Embeds labels to vectors.
    :param n_category: #labels
    :param label_to_ignore: Label indices handed on to the SpanTyping base class.
    """
    super().__init__(label_emb.num_embeddings, label_to_ignore, )
    # Consecutive entries are the (in, out) sizes of the successive layers.
    layer_widths = [input_dim] + hidden_dims + [n_category]
    self.MLPs: List[torch.nn.Linear] = []
    for depth, (fan_in, fan_out) in enumerate(zip(layer_widths[:-1], layer_widths[1:])):
        layer = torch.nn.Linear(fan_in, fan_out, bias=True)
        # A plain list does not register sub-modules, hence the explicit registration.
        self.add_module(f'MLP-{depth}', layer)
        self.MLPs.append(layer)

    # Embeds labels as features.
    self.label_emb = label_emb
|
38 |
+
|
39 |
+
def forward(
|
40 |
+
self,
|
41 |
+
span_vec: torch.Tensor,
|
42 |
+
parent_at_span: torch.Tensor,
|
43 |
+
span_labels: Optional[torch.Tensor],
|
44 |
+
prediction_only: bool = False,
|
45 |
+
) -> Dict[str, torch.Tensor]:
|
46 |
+
"""
|
47 |
+
Inputs: All features for typing a child span.
|
48 |
+
Process: Update the metric.
|
49 |
+
Output: The loss of typing and predictions.
|
50 |
+
:return:
|
51 |
+
loss: Loss for label prediction.
|
52 |
+
prediction: Predicted labels.
|
53 |
+
"""
|
54 |
+
is_soft = span_labels.dtype != torch.int64
|
55 |
+
# Shape [batch, span, label_dim]
|
56 |
+
label_vec = span_labels @ self.label_emb.weight if is_soft else self.label_emb(span_labels)
|
57 |
+
n_batch, n_span, _ = label_vec.shape
|
58 |
+
n_label, _ = self.ontology.shape
|
59 |
+
# Shape [batch, span, label_dim]
|
60 |
+
parent_label_features = label_vec.gather(1, parent_at_span.unsqueeze(2).expand_as(label_vec))
|
61 |
+
# Shape [batch, span, token_dim]
|
62 |
+
parent_span_features = span_vec.gather(1, parent_at_span.unsqueeze(2).expand_as(span_vec))
|
63 |
+
# Shape [batch, span, token_dim]
|
64 |
+
child_span_features = span_vec
|
65 |
+
|
66 |
+
features = torch.cat([parent_label_features, parent_span_features, child_span_features], dim=2)
|
67 |
+
# Shape [batch, span, label]
|
68 |
+
for mlp in self.MLPs[:-1]:
|
69 |
+
features = torch.relu(mlp(features))
|
70 |
+
logits = self.MLPs[-1](features)
|
71 |
+
|
72 |
+
logits_for_prediction = logits.clone()
|
73 |
+
|
74 |
+
if not is_soft:
|
75 |
+
# Shape [batch, span]
|
76 |
+
parent_labels = span_labels.gather(1, parent_at_span)
|
77 |
+
onto_mask = self.ontology.unsqueeze(0).expand(n_batch, -1, -1).gather(
|
78 |
+
1, parent_labels.unsqueeze(2).expand(-1, -1, n_label)
|
79 |
+
)
|
80 |
+
logits_for_prediction[~onto_mask] = float('-inf')
|
81 |
+
|
82 |
+
label_dist = torch.softmax(logits_for_prediction, 2)
|
83 |
+
label_confidence, predictions = label_dist.max(2)
|
84 |
+
ret = {'prediction': predictions, 'label_confidence': label_confidence, 'distribution': label_dist}
|
85 |
+
if prediction_only:
|
86 |
+
return ret
|
87 |
+
|
88 |
+
span_labels = span_labels.clone()
|
89 |
+
|
90 |
+
if is_soft:
|
91 |
+
self.acc_metric(logits_for_prediction, span_labels.max(2)[1], ~span_labels.sum(2).isclose(torch.tensor(0.)))
|
92 |
+
ret['loss'] = KLDivLoss(reduction='sum')(LogSoftmax(dim=2)(logits), span_labels)
|
93 |
+
else:
|
94 |
+
for label_idx in self.label_to_ignore:
|
95 |
+
span_labels[span_labels == label_idx] = -100
|
96 |
+
self.acc_metric(logits_for_prediction, span_labels, span_labels != -100)
|
97 |
+
ret['loss'] = CrossEntropyLoss(reduction='sum')(logits.flatten(0, 1), span_labels.flatten())
|
98 |
+
|
99 |
+
return ret
|
sftp/modules/span_typing/span_typing.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABC
|
2 |
+
from typing import *
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from allennlp.common import Registrable
|
6 |
+
from allennlp.data.vocabulary import DEFAULT_OOV_TOKEN, Vocabulary
|
7 |
+
from allennlp.training.metrics import CategoricalAccuracy
|
8 |
+
|
9 |
+
|
10 |
+
class SpanTyping(Registrable, torch.nn.Module, ABC):
    """
    Models the probability p(child_label | child_span, parent_span, parent_label).
    """
    def __init__(
            self,
            n_label: int,
            label_to_ignore: Optional[List[int]] = None,
    ):
        """
        :param n_label: Number of labels; the ontology is an [n_label, n_label] boolean matrix.
        :param label_to_ignore: Label indexes in this list will be ignored.
            Usually this should include NULL, PADDING and UNKNOWN.
        """
        super().__init__()
        self.label_to_ignore = label_to_ignore or list()
        self.acc_metric = CategoricalAccuracy()
        # onto[parent, child] is True when `child` is allowed under `parent`.
        # Until an ontology is loaded every combination is allowed.
        self.onto = torch.ones([n_label, n_label], dtype=torch.bool)
        self.register_buffer('ontology', self.onto)

    def load_ontology(self, path: str, vocab: Vocabulary):
        """
        Restrict the ontology matrix according to a tab-separated file.

        Each line is `parent<TAB>child<TAB>child...`. For every parent known to the vocabulary,
        its row is cleared and only the listed children (those known to the vocabulary) are
        allowed. Lines whose parent maps to the OOV index are skipped.

        :param path: Path of the ontology file.
        :param vocab: Vocabulary whose 'span_label' namespace maps labels to indexes.
        """
        unk_id = vocab.get_token_index(DEFAULT_OOV_TOKEN, 'span_label')
        # The file is opened in a `with` block so the handle is closed even if parsing fails.
        with open(path) as ontology_file:
            for line in ontology_file:
                entities = [
                    vocab.get_token_index(ent, 'span_label') for ent in line.replace('\n', '').split('\t')
                ]
                parent, children = entities[0], entities[1:]
                if parent == unk_id:
                    continue
                self.onto[parent, :] = False
                children = [child for child in children if child != unk_id]
                self.onto[parent, children] = True
        self.register_buffer('ontology', self.onto)

    def forward(
            self,
            span_vec: torch.Tensor,
            parent_at_span: torch.Tensor,
            span_labels: Optional[torch.Tensor],
            prediction_only: bool = False,
    ) -> Dict[str, torch.Tensor]:
        """
        Inputs: All features for typing a child span.
        Output: The loss of typing and predictions.
        :param span_vec: Shape [batch, span, token_dim]
        :param parent_at_span: Shape [batch, span]
        :param span_labels: Shape [batch, span]
        :param prediction_only: If True, no loss returned & metric will not be updated
        :return:
            loss: Loss for label prediction. (absent of pred_only = True)
            prediction: Predicted labels.
        """
        raise NotImplementedError

    def get_metric(self, reset):
        """
        :param reset: Forwarded to the accuracy metric's `get_metric`.
        :return: A dict with the typing accuracy scaled to a percentage.
        """
        return {
            "typing_acc": self.acc_metric.get_metric(reset) * 100
        }
|
sftp/predictor/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .span_predictor import SpanPredictor
|
sftp/predictor/span_predictor.orig.py
ADDED
@@ -0,0 +1,362 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from time import time
|
3 |
+
from typing import *
|
4 |
+
import json
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from allennlp.common.util import JsonDict, sanitize
|
9 |
+
from allennlp.data import DatasetReader, Instance
|
10 |
+
from allennlp.data.data_loaders import SimpleDataLoader
|
11 |
+
from allennlp.data.samplers import MaxTokensBatchSampler
|
12 |
+
from allennlp.data.tokenizers import SpacyTokenizer
|
13 |
+
from allennlp.models import Model
|
14 |
+
from allennlp.nn import util as nn_util
|
15 |
+
from allennlp.predictors import Predictor
|
16 |
+
from concrete import (
|
17 |
+
MentionArgument, SituationMentionSet, SituationMention, TokenRefSequence,
|
18 |
+
EntityMention, EntityMentionSet, Entity, EntitySet, AnnotationMetadata, Communication
|
19 |
+
)
|
20 |
+
from concrete.util import CommunicationReader, AnalyticUUIDGeneratorFactory, CommunicationWriterZip
|
21 |
+
from concrete.validate import validate_communication
|
22 |
+
|
23 |
+
from ..data_reader import concrete_doc, concrete_doc_tokenized
|
24 |
+
from ..utils import Span, re_index_span, VIRTUAL_ROOT
|
25 |
+
|
26 |
+
|
27 |
+
# Result of predicting on one sentence.
class PredictionReturn(NamedTuple):
|
28 |
+
# The prediction itself, in the requested output format (Span, JSON dict or Concrete Communication).
span: Union[Span, dict, Communication]
|
29 |
+
# Tokens of the sentence the prediction was made on.
sentence: List[str]
|
30 |
+
# Free-form metadata attached to the prediction.
meta: Dict[str, Any]
|
31 |
+
|
32 |
+
|
33 |
+
# Result of force decoding the children of one parent span.
class ForceDecodingReturn(NamedTuple):
|
34 |
+
# Boundaries of the child spans, one row per child.
span: np.ndarray
|
35 |
+
# Most probable label of each child, as strings.
label: List[str]
|
36 |
+
# Distribution over labels for each child.
distribution: np.ndarray
|
37 |
+
|
38 |
+
|
39 |
+
@Predictor.register('span')
class SpanPredictor(Predictor):
    """
    Predictor that runs the span model on raw or tokenized sentences and on Concrete
    communications, and converts the predictions between output formats.
    """

    @staticmethod
    def format_convert(
            sentence: Union[List[str], List[List[str]]],
            prediction: Union[Span, List[Span]],
            output_format: str
    ):
        """
        Convert a prediction (or a list of predictions) into the requested format.

        :param sentence: Tokens of one sentence, or a list of such token lists.
        :param prediction: One Span, or a list of Spans aligned with `sentence`.
        :param output_format: span, json or concrete.
        :return: The converted prediction. Any other `output_format` yields None.
        """
        if output_format == 'span':
            return prediction
        elif output_format == 'json':
            if isinstance(prediction, list):
                return [SpanPredictor.format_convert(sent, pred, 'json') for sent, pred in zip(sentence, prediction)]
            return prediction.to_json()
        elif output_format == 'concrete':
            # A single prediction is wrapped so the converter always receives lists.
            if isinstance(prediction, Span):
                sentence, prediction = [sentence], [prediction]
            return concrete_doc_tokenized(sentence, prediction)

    def predict_concrete(
            self,
            concrete_path: str,
            output_path: Optional[str] = None,
            max_tokens: int = 2048,
            ontology_mapping: Optional[Dict[str, str]] = None,
    ):
        """
        Predict on every communication read from `concrete_path` and write the annotated
        communications to a zip archive at `output_path`.

        Each communication must contain exactly one section. Its situation mention set,
        entity mention set and entity set lists are replaced by the new predictions.

        :param concrete_path: Source of the input communications.
        :param output_path: Path of the zip archive to write.
        :param max_tokens: Maximum tokens in a batch.
        :param ontology_mapping: Forwarded to `predict_batch_sentences`.
        """
        # `os.makedirs('')` raises, so only create the directory when the path has one.
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        writer = CommunicationWriterZip(output_path)

        # The archive is closed even if prediction or validation fails part-way.
        try:
            for comm, fn in CommunicationReader(concrete_path):
                assert len(comm.sectionList) == 1
                concrete_sentences = comm.sectionList[0].sentenceList
                json_sentences = [
                    [t.text for t in con_sent.tokenization.tokenList.tokenList]
                    for con_sent in concrete_sentences
                ]
                predictions = self.predict_batch_sentences(
                    json_sentences, max_tokens, ontology_mapping=ontology_mapping
                )

                # Merge predictions into concrete
                aug = AnalyticUUIDGeneratorFactory(comm).create()
                situation_mention_set = SituationMentionSet(
                    next(aug), AnnotationMetadata('Span Finder', time()), list()
                )
                comm.situationMentionSetList = [situation_mention_set]
                situation_mention_set.mentionList = sm_list = list()
                entity_mention_set = EntityMentionSet(next(aug), AnnotationMetadata('Span Finder', time()), list())
                comm.entityMentionSetList = [entity_mention_set]
                entity_mention_set.mentionList = em_list = list()
                entity_set = EntitySet(
                    next(aug), AnnotationMetadata('Span Finder', time()), list(), None, entity_mention_set.uuid
                )
                comm.entitySetList = [entity_set]

                for con_sent, pred in zip(concrete_sentences, predictions):
                    token_list = con_sent.tokenization.tokenList.tokenList
                    # Token indexes are sentence-local, so the mention cache must not be shared
                    # across sentences: equal indexes elsewhere refer to different text.
                    em_dict = dict()

                    def raw_text_span(start_idx, end_idx, **_):
                        si_char = token_list[start_idx].textSpan.start
                        ei_char = token_list[end_idx].textSpan.ending
                        return comm.text[si_char:ei_char]

                    for event in pred.span:
                        sm = SituationMention(
                            next(aug),
                            text=raw_text_span(event.start_idx, event.end_idx),
                            situationKind=event.label,
                            situationType='EVENT',
                            confidence=event.confidence,
                            argumentList=list(),
                            tokens=TokenRefSequence(
                                tokenIndexList=list(range(event.start_idx, event.end_idx+1)),
                                tokenizationId=con_sent.tokenization.uuid
                            )
                        )

                        for arg in event:
                            em_key = (arg.start_idx, arg.end_idx + 1)
                            em = em_dict.get(em_key)
                            if em is None:
                                em = EntityMention(
                                    next(aug),
                                    tokens=TokenRefSequence(
                                        tokenIndexList=list(range(arg.start_idx, arg.end_idx+1)),
                                        tokenizationId=con_sent.tokenization.uuid,
                                    ),
                                    text=raw_text_span(arg.start_idx, arg.end_idx)
                                )
                                em_list.append(em)
                                entity_set.entityList.append(Entity(next(aug), id=em.text, mentionIdList=[em.uuid]))
                                em_dict[em_key] = em
                            sm.argumentList.append(MentionArgument(
                                role=arg.label,
                                entityMentionId=em.uuid,
                                confidence=arg.confidence
                            ))
                        sm_list.append(sm)
                validate_communication(comm)
                writer.write(comm, fn)
        finally:
            writer.close()

    def predict_sentence(
            self,
            sentence: Union[str, List[str]],
            ontology_mapping: Optional[Dict[str, str]] = None,
            output_format: str = 'span',
    ) -> PredictionReturn:
        """
        Predict spans on a single sentence (no batch). If not tokenized, will tokenize it with SpacyTokenizer.
        :param sentence: If tokenized, should be a list of tokens in string. If not, should be a string.
        :param ontology_mapping: Passed to `Span.map_ontology` on the decoded prediction.
        :param output_format: span, json or concrete.
        """
        prediction = self.predict_json(self._prepare_sentence(sentence))
        prediction['prediction'] = self.format_convert(
            prediction['sentence'],
            Span.from_json(prediction['prediction']).map_ontology(ontology_mapping),
            output_format
        )
        return PredictionReturn(prediction['prediction'], prediction['sentence'], prediction.get('meta', dict()))

    def predict_batch_sentences(
            self,
            sentences: List[Union[List[str], str]],
            max_tokens: int = 512,
            ontology_mapping: Optional[Dict[str, str]] = None,
            output_format: str = 'span',
    ) -> List[PredictionReturn]:
        """
        Predict spans on a batch of sentences. If not tokenized, will tokenize it with SpacyTokenizer.
        :param sentences: A list of sentences. Refer to `predict_sentence`.
        :param max_tokens: Maximum tokens in a batch.
        :param ontology_mapping: If not None, will try to map the output from one ontology to another.
            If the predicted frame is not in the mapping, the prediction will be ignored.
        :param output_format: span, json or concrete.
        :return: A list of predictions.
        """
        sentences = list(map(self._prepare_sentence, sentences))
        # Remember the input position: the sampler may reorder sentences into batches.
        for i_sent, sent in enumerate(sentences):
            sent['meta'] = {"idx": i_sent}
        instances = list(map(self._json_to_instance, sentences))
        outputs = list()
        for ins_indices in MaxTokensBatchSampler(max_tokens, ["tokens"], 0.0).get_batch_indices(instances):
            batch_ins = list(
                SimpleDataLoader([instances[ins_idx] for ins_idx in ins_indices], len(ins_indices), vocab=self.vocab)
            )[0]
            batch_inputs = nn_util.move_to_device(batch_ins, device=self.cuda_device)
            # Inference only: no autograd graph is needed (same as in `force_decode`).
            with torch.no_grad():
                batch_outputs = self._model(**batch_inputs)
            for meta, prediction, inputs in zip(
                    batch_outputs['meta'], batch_outputs['prediction'], batch_outputs['inputs']
            ):
                prediction.map_ontology(ontology_mapping)
                prediction = self.format_convert(inputs['sentence'], prediction, output_format)
                outputs.append(PredictionReturn(prediction, inputs['sentence'], {"input_idx": meta['idx']}))

        # Restore the order in which the sentences were given.
        outputs.sort(key=lambda x: x.meta['input_idx'])
        return outputs

    def predict_instance(self, instance: Instance) -> JsonDict:
        """
        Run the model on one instance and keep the prediction, the sentence and the metadata.
        """
        outputs = self._model.forward_on_instance(instance)
        outputs = sanitize(outputs)
        return {
            'prediction': outputs['prediction'],
            'sentence': outputs['inputs']['sentence'],
            'meta': outputs.get('meta', {})
        }

    def __init__(
            self,
            model: Model,
            dataset_reader: DatasetReader,
            frozen: bool = True,
    ):
        super(SpanPredictor, self).__init__(model=model, dataset_reader=dataset_reader, frozen=frozen)
        # Used by `_prepare_sentence` for sentences given as raw strings.
        self.spacy_tokenizer = SpacyTokenizer(language='en_core_web_sm')

    def economize(
            self,
            max_decoding_spans: Optional[int] = None,
            max_recursion_depth: Optional[int] = None,
    ):
        """
        Override the model's decoding limits. Arguments left as None (or 0) are not applied.
        """
        if max_decoding_spans:
            self._model._max_decoding_spans = max_decoding_spans
        if max_recursion_depth:
            self._model._max_recursion_depth = max_recursion_depth

    def _json_to_instance(self, json_dict: JsonDict) -> Instance:
        return self._dataset_reader.text_to_instance(**json_dict)

    @staticmethod
    def to_nested(prediction: List[dict]):
        """
        Turn a flat list of predictions into a forest, in place.

        Every dict gets a 'children' list and loses its 'parent' key. 'parent' is the 1-based
        position of the parent in `prediction`; 0 marks a top-level prediction. A parent must
        appear before its children.

        :return: The top-level predictions.
        """
        first_layer, idx2children = list(), dict()
        for idx, pred in enumerate(prediction):
            children = list()
            pred['children'] = idx2children[idx+1] = children
            if pred['parent'] == 0:
                first_layer.append(pred)
            else:
                idx2children[pred['parent']].append(pred)
            del pred['parent']
        return first_layer

    def _prepare_sentence(self, sentence: Union[str, List[str]]) -> Dict[str, List[str]]:
        """
        Normalize a sentence into the JSON input of the dataset reader.

        A string is cleaned (runs of spaces collapsed, U+FFFD removed) and tokenized with the
        spaCy tokenizer; a token list is passed through unchanged.
        """
        if isinstance(sentence, str):
            # Collapse runs of spaces into a single space.
            while '  ' in sentence:
                sentence = sentence.replace('  ', ' ')
            sentence = sentence.replace(chr(65533), '')
            if sentence == '':
                # Nothing to tokenize: return one empty token instead of handing the
                # tokenizer a list, which it is not meant to receive.
                return {"tokens": [""]}
            sentence = list(map(str, self.spacy_tokenizer.tokenize(sentence)))
        return {"tokens": sentence}

    @staticmethod
    def json_to_concrete(
            predictions: List[dict],
    ):
        """
        Build a Concrete communication from nested JSON predictions.

        Each prediction needs 'inputs' (tokens) and 'prediction' (events with 'children').
        Character offsets are computed as if the tokens were joined by single spaces.
        """
        sentences = list()
        for pred in predictions:
            tokenization, event = list(), list()
            sent = {'text': ' '.join(pred['inputs']), 'tokenization': tokenization, 'event': event}
            sentences.append(sent)
            start_idx = 0
            for token in pred['inputs']:
                # Inclusive character range of the token in the joined text.
                tokenization.append((start_idx, len(token)-1+start_idx))
                start_idx += len(token) + 1
            for pred_event in pred['prediction']:
                arg_list = list()
                one_event = {'argument': arg_list}
                event.append(one_event)
                for key in ['start_idx', 'end_idx', 'label']:
                    one_event[key] = pred_event[key]
                for pred_arg in pred_event['children']:
                    arg_list.append({key: pred_arg[key] for key in ['start_idx', 'end_idx', 'label']})

        concrete_comm = concrete_doc(sentences)
        return concrete_comm

    def force_decode(
            self,
            sentence: List[str],
            parent_span: Tuple[int, int] = (-1, -1),
            parent_label: str = VIRTUAL_ROOT,
            child_spans: Optional[List[Tuple[int, int]]] = None,
    ) -> ForceDecodingReturn:
        """
        Force decoding. There are 2 modes:
        1. Given parent span and its label, find all it children (direct children, not including other descendents)
            and type them.
        2. Given parent span, parent label, and children spans, type all children.
        :param sentence: Tokens.
        :param parent_span: [start_idx, end_idx], both inclusive.
        :param parent_label: Parent label in string.
        :param child_spans: Optional. If provided, will turn to mode 2; else mode 1.
        :return:
            - span: children spans.
            - label: most probable labels of children.
            - distribution: distribution over children labels.
        """
        instance = self._dataset_reader.text_to_instance(self._prepare_sentence(sentence)['tokens'])
        model_input = nn_util.move_to_device(
            list(SimpleDataLoader([instance], 1, vocab=self.vocab))[0], device=self.cuda_device
        )
        offsets = instance.fields['raw_inputs'].metadata['offsets']

        with torch.no_grad():
            tokens = model_input['tokens']
            parent_span = re_index_span(parent_span, offsets)
            if parent_span[1] >= self._dataset_reader.max_length:
                # The parent lies beyond the reader's length limit: nothing to decode.
                # `np.int` was removed from NumPy; the builtin `int` is the same dtype.
                return ForceDecodingReturn(
                    np.zeros([0, 2], dtype=int),
                    [],
                    np.zeros([0, self.vocab.get_vocab_size('span_label')], dtype=np.float64)
                )
            if child_spans is not None:
                token_vec = self._model.word_embedding(tokens)
                child_pieces = [re_index_span(bdr, offsets) for bdr in child_spans]
                child_pieces = list(filter(lambda x: x[1] < self._dataset_reader.max_length-1, child_pieces))
                # Row 0 is the parent; every span (including row 0) points at row 0 as its parent.
                span_tensor = torch.tensor(
                    [parent_span] + child_pieces, dtype=torch.int64, device=self.device
                ).unsqueeze(0)
                parent_indices = span_tensor.new_zeros(span_tensor.shape[0:2])
                span_labels = parent_indices.new_full(
                    parent_indices.shape, self._model.vocab.get_token_index(parent_label, 'span_label')
                )
                span_vec = self._model._span_extractor(token_vec, span_tensor)
                typing_out = self._model._span_typing(span_vec, parent_indices, span_labels)
                # Drop row 0 (the parent) so only the children remain.
                distribution = typing_out['distribution'][0, 1:].cpu().numpy()
                # NOTE(review): `boundary` holds all of `child_spans` while `distribution` only
                # covers the children that passed the length filter above -- confirm whether the
                # two can get out of step for over-long inputs.
                boundary = np.array(child_spans)
            else:
                parent_label_tensor = torch.tensor(
                    [self._model.vocab.get_token_index(parent_label, 'span_label')], device=self.device
                )
                parent_boundary_tensor = torch.tensor([parent_span], device=self.device)
                boundary, _, num_children, distribution = self._model.one_step_prediction(
                    tokens, parent_boundary_tensor, parent_label_tensor
                )
                boundary, distribution = boundary[0].cpu().tolist(), distribution[0].cpu().numpy()
                boundary = np.array([re_index_span(bdr, offsets, True) for bdr in boundary])

        labels = [
            self.vocab.get_token_from_index(label_idx, 'span_label') for label_idx in distribution.argmax(1)
        ]
        return ForceDecodingReturn(boundary, labels, distribution)

    @property
    def vocab(self):
        return self._model.vocab

    @property
    def device(self):
        # A negative `cuda_device` means the model runs on the CPU.
        return self.cuda_device if self.cuda_device > -1 else 'cpu'

    @staticmethod
    def read_ontology_mapping(file_path: str):
        """
        Read the ontology mapping file. The file format can be read in docs.

        A `.json` file is loaded as is. Any other file is read as tab-separated lines of
        `parent_label<TAB>original_label<TAB>new_label`; a parent of `*` maps the original
        label regardless of parent, otherwise the key is the (parent, original) pair.

        :param file_path: Path of the mapping file, or None.
        :return: The mapping, or None if `file_path` is None.
        """
        if file_path is None:
            return None
        # Files are opened in `with` blocks so the handles are always closed.
        if file_path.endswith('.json'):
            with open(file_path) as mapping_file:
                return json.load(mapping_file)
        mapping = dict()
        with open(file_path) as mapping_file:
            for line in mapping_file:
                parent_label, original_label, new_label = line.replace('\n', '').split('\t')
                if parent_label == '*':
                    mapping[original_label] = new_label
                else:
                    mapping[(parent_label, original_label)] = new_label
        return mapping
|
sftp/predictor/span_predictor.py
ADDED
@@ -0,0 +1,401 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from time import time
|
3 |
+
from typing import *
|
4 |
+
import json
|
5 |
+
|
6 |
+
# # ---GFM add debugger
|
7 |
+
# import pdb
|
8 |
+
# # end---
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
from allennlp.common.util import JsonDict, sanitize
|
13 |
+
from allennlp.data import DatasetReader, Instance
|
14 |
+
from allennlp.data.data_loaders import SimpleDataLoader
|
15 |
+
from allennlp.data.samplers import MaxTokensBatchSampler
|
16 |
+
from allennlp.data.tokenizers import SpacyTokenizer
|
17 |
+
from allennlp.models import Model
|
18 |
+
from allennlp.nn import util as nn_util
|
19 |
+
from allennlp.predictors import Predictor
|
20 |
+
from concrete import (
|
21 |
+
MentionArgument, SituationMentionSet, SituationMention, TokenRefSequence,
|
22 |
+
EntityMention, EntityMentionSet, Entity, EntitySet, AnnotationMetadata, Communication
|
23 |
+
)
|
24 |
+
from concrete.util import CommunicationReader, AnalyticUUIDGeneratorFactory, CommunicationWriterZip
|
25 |
+
from concrete.validate import validate_communication
|
26 |
+
|
27 |
+
from ..data_reader import concrete_doc, concrete_doc_tokenized
|
28 |
+
from ..utils import Span, re_index_span, VIRTUAL_ROOT
|
29 |
+
|
30 |
+
|
31 |
+
# Result of predicting on one sentence.
class PredictionReturn(NamedTuple):
|
32 |
+
# The prediction itself, in the requested output format (Span, JSON dict or Concrete Communication).
span: Union[Span, dict, Communication]
|
33 |
+
# Tokens of the sentence the prediction was made on.
sentence: List[str]
|
34 |
+
# Free-form metadata attached to the prediction.
meta: Dict[str, Any]
|
35 |
+
|
36 |
+
|
37 |
+
# Result of force decoding the children of one parent span.
class ForceDecodingReturn(NamedTuple):
|
38 |
+
# Boundaries of the child spans (numpy array).
span: np.ndarray
|
39 |
+
# One label string per child.
label: List[str]
|
40 |
+
# Label distribution (numpy array).
distribution: np.ndarray
|
41 |
+
|
42 |
+
|
43 |
+
@Predictor.register('span')
|
44 |
+
class SpanPredictor(Predictor):
|
45 |
+
@staticmethod
|
46 |
+
def format_convert(
|
47 |
+
sentence: Union[List[str], List[List[str]]],
|
48 |
+
prediction: Union[Span, List[Span]],
|
49 |
+
output_format: str
|
50 |
+
):
|
51 |
+
if output_format == 'span':
|
52 |
+
return prediction
|
53 |
+
elif output_format == 'json':
|
54 |
+
if isinstance(prediction, list):
|
55 |
+
return [SpanPredictor.format_convert(sent, pred, 'json') for sent, pred in zip(sentence, prediction)]
|
56 |
+
return prediction.to_json()
|
57 |
+
elif output_format == 'concrete':
|
58 |
+
if isinstance(prediction, Span):
|
59 |
+
sentence, prediction = [sentence], [prediction]
|
60 |
+
return concrete_doc_tokenized(sentence, prediction)
|
61 |
+
|
62 |
+
def predict_concrete(
        self,
        concrete_path: str,
        output_path: Optional[str] = None,
        max_tokens: int = 2048,
        ontology_mapping: Optional[Dict[str, str]] = None,
):
    """
    Predict on every communication read from `concrete_path` and write the annotated
    communications to a zip archive at `output_path`.

    Each communication must contain exactly one section. Its situation mention set,
    entity mention set and entity set lists are replaced by the new predictions.
    Span end indexes that point past the last token of a sentence are clamped to the
    last token, with a warning printed to stdout.

    :param concrete_path: Source of the input communications.
    :param output_path: Path of the zip archive to write.
    :param max_tokens: Maximum tokens in a batch.
    :param ontology_mapping: Forwarded to `predict_batch_sentences`.
    """
    def clamp_end_idx(name, end_idx, con_sent):
        # Workaround for end indexes that overshoot the sentence (cause unknown):
        # report the problem and fall back to the index of the final token.
        n_tokens = len(con_sent.tokenization.tokenList.tokenList)
        if end_idx > n_tokens - 1:
            print(f"WARNING: invalid `{name}` passed for sentence, adjusting to final token")
            print("\tsentence:", con_sent.tokenization.tokenList)
            print(f"{name}:", end_idx)
            print("length:", n_tokens)
            end_idx = n_tokens - 1
            print(f"new {name}:", end_idx)
            print()
        return end_idx

    # `os.makedirs('')` raises, so only create the directory when the path has one.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    writer = CommunicationWriterZip(output_path)

    print(concrete_path)
    # The archive is closed even if prediction or validation fails part-way.
    try:
        for comm, fn in CommunicationReader(concrete_path):
            print(fn)
            assert len(comm.sectionList) == 1
            concrete_sentences = comm.sectionList[0].sentenceList
            json_sentences = [
                [t.text for t in con_sent.tokenization.tokenList.tokenList]
                for con_sent in concrete_sentences
            ]
            predictions = self.predict_batch_sentences(
                json_sentences, max_tokens, ontology_mapping=ontology_mapping
            )

            # Merge predictions into concrete
            aug = AnalyticUUIDGeneratorFactory(comm).create()
            situation_mention_set = SituationMentionSet(
                next(aug), AnnotationMetadata('Span Finder', time()), list()
            )
            comm.situationMentionSetList = [situation_mention_set]
            situation_mention_set.mentionList = sm_list = list()
            entity_mention_set = EntityMentionSet(next(aug), AnnotationMetadata('Span Finder', time()), list())
            comm.entityMentionSetList = [entity_mention_set]
            entity_mention_set.mentionList = em_list = list()
            entity_set = EntitySet(
                next(aug), AnnotationMetadata('Span Finder', time()), list(), None, entity_mention_set.uuid
            )
            comm.entitySetList = [entity_set]

            for con_sent, pred in zip(concrete_sentences, predictions):
                token_list = con_sent.tokenization.tokenList.tokenList
                # Token indexes are sentence-local, so the mention cache must not be shared
                # across sentences: equal indexes elsewhere refer to different text.
                em_dict = dict()

                def raw_text_span(start_idx, end_idx, **_):
                    si_char = token_list[start_idx].textSpan.start
                    ei_char = token_list[end_idx].textSpan.ending
                    return comm.text[si_char:ei_char]

                for event in pred.span:
                    event_start_idx = event.start_idx
                    event_end_idx = clamp_end_idx('event_end_idx', event.end_idx, con_sent)

                    sm = SituationMention(
                        next(aug),
                        text=raw_text_span(event_start_idx, event_end_idx),
                        situationKind=event.label,
                        situationType='EVENT',
                        confidence=event.confidence,
                        argumentList=list(),
                        tokens=TokenRefSequence(
                            tokenIndexList=list(range(event_start_idx, event_end_idx+1)),
                            tokenizationId=con_sent.tokenization.uuid
                        )
                    )

                    for arg in event:
                        arg_start_idx = arg.start_idx
                        arg_end_idx = clamp_end_idx('arg_end_idx', arg.end_idx, con_sent)

                        em_key = (arg_start_idx, arg_end_idx + 1)
                        em = em_dict.get(em_key)
                        if em is None:
                            em = EntityMention(
                                next(aug),
                                tokens=TokenRefSequence(
                                    tokenIndexList=list(range(arg_start_idx, arg_end_idx+1)),
                                    tokenizationId=con_sent.tokenization.uuid,
                                ),
                                text=raw_text_span(arg_start_idx, arg_end_idx)
                            )
                            em_list.append(em)
                            entity_set.entityList.append(Entity(next(aug), id=em.text, mentionIdList=[em.uuid]))
                            em_dict[em_key] = em
                        sm.argumentList.append(MentionArgument(
                            role=arg.label,
                            entityMentionId=em.uuid,
                            confidence=arg.confidence
                        ))
                    sm_list.append(sm)
            validate_communication(comm)
            writer.write(comm, fn)
    finally:
        writer.close()
|
173 |
+
|
174 |
+
def predict_sentence(
        self,
        sentence: Union[str, List[str]],
        ontology_mapping: Optional[Dict[str, str]] = None,
        output_format: str = 'span',
) -> PredictionReturn:
    """
    Run the model on one sentence (no batching) and convert the result.

    :param sentence: Either a raw string or a list of token strings. It is passed through
        `_prepare_sentence` before prediction.
    :param ontology_mapping: Handed to `map_ontology` on the decoded prediction.
    :param output_format: span, json or concrete.
    :return: A `PredictionReturn` holding the converted prediction, the sentence, and the
        `meta` entry of the raw output (an empty dict when the key is absent).
    """
    raw = self.predict_json(self._prepare_sentence(sentence))
    tokens = raw['sentence']
    # Decode the JSON prediction back into a Span tree before converting the format.
    span_tree = Span.from_json(raw['prediction']).map_ontology(ontology_mapping)
    raw['prediction'] = self.format_convert(tokens, span_tree, output_format)
    return PredictionReturn(raw['prediction'], tokens, raw.get('meta', dict()))
|
193 |
+
|
194 |
+
def predict_batch_sentences(
        self,
        sentences: List[Union[List[str], str]],
        max_tokens: int = 512,
        ontology_mapping: Optional[Dict[str, str]] = None,
        output_format: str = 'span',
) -> List[PredictionReturn]:
    """
    Predict spans for many sentences, batching them by token count.

    :param sentences: A list of sentences. Each one is prepared as in `predict_sentence`.
    :param max_tokens: Maximum tokens in a batch.
    :param ontology_mapping: If not None, will try to map the output from one ontology to another.
        If the predicted frame is not in the mapping, the prediction will be ignored.
    :param output_format: span, json or concrete.
    :return: A list of predictions, sorted back into the order of `sentences`.
    """
    prepared = [self._prepare_sentence(one_sentence) for one_sentence in sentences]
    # Remember each sentence's input position; the sampler may reorder instances.
    for position, item in enumerate(prepared):
        item['meta'] = {"idx": position}
    instances = [self._json_to_instance(item) for item in prepared]

    sampler = MaxTokensBatchSampler(max_tokens, ["tokens"], 0.0)
    results = list()
    for batch_indices in sampler.get_batch_indices(instances):
        batch_instances = [instances[ins_idx] for ins_idx in batch_indices]
        # A loader whose batch size equals the number of instances yields exactly one batch.
        loader = SimpleDataLoader(batch_instances, len(batch_indices), vocab=self.vocab)
        batch_tensors = nn_util.move_to_device(list(loader)[0], device=self.cuda_device)
        model_out = self._model(**batch_tensors)
        triples = zip(model_out['meta'], model_out['prediction'], model_out['inputs'])
        for meta, span_tree, inputs in triples:
            span_tree.map_ontology(ontology_mapping)
            converted = self.format_convert(inputs['sentence'], span_tree, output_format)
            results.append(PredictionReturn(converted, inputs['sentence'], {"input_idx": meta['idx']}))

    results.sort(key=lambda ret: ret.meta['input_idx'])
    return results
|
230 |
+
|
231 |
+
def predict_instance(self, instance: Instance) -> JsonDict:
    """
    Run the model on a single instance and return a reduced, sanitized output.

    :param instance: The instance passed to the model's `forward_on_instance`.
    :return: A dict with the keys `prediction`, `sentence` (taken from the model's
        `inputs` entry) and `meta` (an empty dict when the model emits none).
    """
    sanitized = sanitize(self._model.forward_on_instance(instance))
    result = dict()
    result['prediction'] = sanitized['prediction']
    result['sentence'] = sanitized['inputs']['sentence']
    result['meta'] = sanitized.get('meta', {})
    return result
|
239 |
+
|
240 |
+
def __init__(
        self,
        model: Model,
        dataset_reader: DatasetReader,
        frozen: bool = True,
):
    """
    Set up the predictor and the tokenizer used for un-tokenized input.

    :param model: Forwarded to the parent constructor.
    :param dataset_reader: Forwarded to the parent constructor.
    :param frozen: Forwarded to the parent constructor.
    """
    super(SpanPredictor, self).__init__(
        model=model,
        dataset_reader=dataset_reader,
        frozen=frozen,
    )
    # Used by `_prepare_sentence` when a sentence arrives as a raw string.
    tokenizer = SpacyTokenizer(language='en_core_web_sm')
    self.spacy_tokenizer = tokenizer
|
248 |
+
|
249 |
+
def economize(
        self,
        max_decoding_spans: Optional[int] = None,
        max_recursion_depth: Optional[int] = None,
):
    """
    Override the model's decoding limits.

    A limit is only written when the given value is truthy, so `None` and `0`
    both leave the model's current setting untouched.

    :param max_decoding_spans: New value for the model's `_max_decoding_spans`.
    :param max_recursion_depth: New value for the model's `_max_recursion_depth`.
    """
    overrides = (
        ('_max_decoding_spans', max_decoding_spans),
        ('_max_recursion_depth', max_recursion_depth),
    )
    for attribute, limit in overrides:
        if limit:
            setattr(self._model, attribute, limit)
|
258 |
+
|
259 |
+
def _json_to_instance(self, json_dict: JsonDict) -> Instance:
    """
    Build an instance from a JSON dict.

    The dict's entries are passed as keyword arguments to the dataset reader's
    `text_to_instance`.
    """
    reader = self._dataset_reader
    return reader.text_to_instance(**json_dict)
|
261 |
+
|
262 |
+
@staticmethod
|
263 |
+
def to_nested(prediction: List[dict]):
|
264 |
+
first_layer, idx2children = list(), dict()
|
265 |
+
for idx, pred in enumerate(prediction):
|
266 |
+
children = list()
|
267 |
+
pred['children'] = idx2children[idx+1] = children
|
268 |
+
if pred['parent'] == 0:
|
269 |
+
first_layer.append(pred)
|
270 |
+
else:
|
271 |
+
idx2children[pred['parent']].append(pred)
|
272 |
+
del pred['parent']
|
273 |
+
return first_layer
|
274 |
+
|
275 |
+
def _prepare_sentence(self, sentence: Union[str, List[str]]) -> Dict[str, List[str]]:
|
276 |
+
if isinstance(sentence, str):
|
277 |
+
while ' ' in sentence:
|
278 |
+
sentence = sentence.replace(' ', ' ')
|
279 |
+
sentence = sentence.replace(chr(65533), '')
|
280 |
+
if sentence == '':
|
281 |
+
sentence = [""]
|
282 |
+
sentence = list(map(str, self.spacy_tokenizer.tokenize(sentence)))
|
283 |
+
return {"tokens": sentence}
|
284 |
+
|
285 |
+
@staticmethod
def json_to_concrete(
        predictions: List[dict],
):
    """
    Convert JSON-style predictions into a Concrete communication via `concrete_doc`.

    For every prediction the tokens in `inputs` are joined with single spaces, and each
    token gets an inclusive (start, end) character span in that joined text. Each
    predicted event keeps its `start_idx`, `end_idx` and `label`, and so does each of
    its children, which become the event's arguments.

    :param predictions: Dicts with the keys `inputs` (tokens) and `prediction` (events).
    :return: Whatever `concrete_doc` returns for the assembled sentences.
    """
    kept_keys = ['start_idx', 'end_idx', 'label']
    sentences = list()
    for pred in predictions:
        tokens = pred['inputs']
        text = ' '.join(tokens)

        # Character offsets assume exactly one space between consecutive tokens.
        char_spans, cursor = list(), 0
        for token in tokens:
            char_spans.append((cursor, len(token) - 1 + cursor))
            cursor += len(token) + 1

        events = list()
        for pred_event in pred['prediction']:
            one_event = {'argument': list()}
            for key in kept_keys:
                one_event[key] = pred_event[key]
            one_event['argument'].extend(
                {key: pred_arg[key] for key in kept_keys} for pred_arg in pred_event['children']
            )
            events.append(one_event)

        sentences.append({'text': text, 'tokenization': char_spans, 'event': events})

    return concrete_doc(sentences)
|
309 |
+
|
310 |
+
def force_decode(
        self,
        sentence: List[str],
        parent_span: Tuple[int, int] = (-1, -1),
        parent_label: str = VIRTUAL_ROOT,
        child_spans: Optional[List[Tuple[int, int]]] = None,
) -> ForceDecodingReturn:
    """
    Force decoding. There are 2 modes:
    1. Given parent span and its label, find all its children (direct children, not including other
        descendants) and type them.
    2. Given parent span, parent label, and children spans, type all children.

    In mode 2, child spans whose re-indexed end reaches `max_length - 1` of the dataset reader
    are dropped. The returned spans, labels and distribution all describe the children that
    remain, so they always have the same length.

    :param sentence: Tokens.
    :param parent_span: [start_idx, end_idx], both inclusive.
    :param parent_label: Parent label in string.
    :param child_spans: Optional. If provided, will turn to mode 2; else mode 1.
    :return:
        - span: children spans.
        - label: most probable labels of children.
        - distribution: distribution over children labels.
        All three are empty when the re-indexed parent span ends at or beyond `max_length`.
    """
    instance = self._dataset_reader.text_to_instance(self._prepare_sentence(sentence)['tokens'])
    model_input = nn_util.move_to_device(
        list(SimpleDataLoader([instance], 1, vocab=self.vocab))[0], device=self.cuda_device
    )
    offsets = instance.fields['raw_inputs'].metadata['offsets']
    max_length = self._dataset_reader.max_length

    with torch.no_grad():
        tokens = model_input['tokens']
        parent_span = re_index_span(parent_span, offsets)
        if parent_span[1] >= max_length:
            return ForceDecodingReturn(
                # `np.int` was removed in NumPy 1.24; the builtin `int` is its exact replacement.
                np.zeros([0, 2], dtype=int),
                [],
                np.zeros([0, self.vocab.get_vocab_size('span_label')], dtype=np.float64)
            )
        if child_spans is not None:
            token_vec = self._model.word_embedding(tokens)
            # Keep each original span paired with its re-indexed version, so the returned
            # boundaries stay aligned with the rows of the distribution after filtering.
            kept_spans, child_pieces = list(), list()
            for bdr in child_spans:
                piece = re_index_span(bdr, offsets)
                if piece[1] < max_length - 1:
                    kept_spans.append(bdr)
                    child_pieces.append(piece)
            span_tensor = torch.tensor(
                [parent_span] + child_pieces, dtype=torch.int64, device=self.device
            ).unsqueeze(0)
            # Row 0 is the parent; every span points at it and carries the parent label.
            parent_indices = span_tensor.new_zeros(span_tensor.shape[0:2])
            span_labels = parent_indices.new_full(
                parent_indices.shape, self._model.vocab.get_token_index(parent_label, 'span_label')
            )
            span_vec = self._model._span_extractor(token_vec, span_tensor)
            typing_out = self._model._span_typing(span_vec, parent_indices, span_labels)
            # Skip row 0 (the parent itself).
            distribution = typing_out['distribution'][0, 1:].cpu().numpy()
            boundary = np.array(kept_spans, dtype=int).reshape(-1, 2)
        else:
            parent_label_tensor = torch.tensor(
                [self._model.vocab.get_token_index(parent_label, 'span_label')], device=self.device
            )
            parent_boundary_tensor = torch.tensor([parent_span], device=self.device)
            boundary, _, num_children, distribution = self._model.one_step_prediction(
                tokens, parent_boundary_tensor, parent_label_tensor
            )
            boundary, distribution = boundary[0].cpu().tolist(), distribution[0].cpu().numpy()
            # Map the predicted boundaries back with the reverse flag of `re_index_span` set.
            boundary = np.array([re_index_span(bdr, offsets, True) for bdr in boundary])

    labels = [
        self.vocab.get_token_from_index(label_idx, 'span_label') for label_idx in distribution.argmax(1)
    ]
    return ForceDecodingReturn(boundary, labels, distribution)
|
376 |
+
|
377 |
+
@property
def vocab(self):
    """The vocabulary object held by the wrapped model."""
    model = self._model
    return model.vocab
|
380 |
+
|
381 |
+
@property
def device(self):
    """Return `self.cuda_device` when it is greater than -1, otherwise the string 'cpu'."""
    if self.cuda_device > -1:
        return self.cuda_device
    return 'cpu'
|
384 |
+
|
385 |
+
@staticmethod
|
386 |
+
def read_ontology_mapping(file_path: str):
|
387 |
+
"""
|
388 |
+
Read the ontology mapping file. The file format can be read in docs.
|
389 |
+
"""
|
390 |
+
if file_path is None:
|
391 |
+
return None
|
392 |
+
if file_path.endswith('.json'):
|
393 |
+
return json.load(open(file_path))
|
394 |
+
mapping = dict()
|
395 |
+
for line in open(file_path).readlines():
|
396 |
+
parent_label, original_label, new_label = line.replace('\n', '').split('\t')
|
397 |
+
if parent_label == '*':
|
398 |
+
mapping[original_label] = new_label
|
399 |
+
else:
|
400 |
+
mapping[(parent_label, original_label)] = new_label
|
401 |
+
return mapping
|
sftp/training/__init__.py
ADDED
File without changes
|
sftp/training/transformer_optimizer.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import re
|
3 |
+
from typing import *
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from allennlp.common.from_params import Params, T
|
7 |
+
from allennlp.training.optimizers import Optimizer
|
8 |
+
|
9 |
+
logger = logging.getLogger('optim')
|
10 |
+
|
11 |
+
|
12 |
+
@Optimizer.register('transformer')
class TransformerOptimizer:
    """
    Wrapper for AllenNLP optimizer.
    This is used to fine-tune the pretrained transformer with some layers fixed and different learning rate.
    When some layers are fixed, the wrapper will set the `requires_grad` flag as False, which could save
    training time and optimize memory usage.
    Please contact Guanghui Qin for bugs.
    Params:
        base: base optimizer.
        embeddings_lr: learning rate for embedding layer. Set as 0.0 to fix it.
        encoder_lr: learning rate for encoder layer. Set as 0.0 to fix it.
        pooler_lr: learning rate for pooler layer. Set as 0.0 to fix it.
        layer_fix: the number of encoder layers that should be fixed. Optional, defaults to 0.

    Example json config:

    1. No-op. Do nothing (why do you use me?)
        optimizer: {
            type: "transformer",
            base: {
                type: "adam",
                lr: 0.001
            }
        }

    2. Fix everything in the transformer.
        optimizer: {
            type: "transformer",
            base: {
                type: "adam",
                lr: 0.001
            },
            embeddings_lr: 0.0,
            encoder_lr: 0.0,
            pooler_lr: 0.0
        }

        Or equivalently (suppose we have 24 layers)

        optimizer: {
            type: "transformer",
            base: {
                type: "adam",
                lr: 0.001
            },
            embeddings_lr: 0.0,
            layer_fix: 24,
            pooler_lr: 0.0
        }

    3. Fix embeddings and the lower 12 encoder layers, set a small learning rate
       for the other parts of the transformer

        optimizer: {
            type: "transformer",
            base: {
                type: "adam",
                lr: 0.001
            },
            embeddings_lr: 0.0,
            layer_fix: 12,
            encoder_lr: 1e-5,
            pooler_lr: 1e-5
        }
    """
    @classmethod
    def from_params(
            cls: Type[T],
            params: Params,
            model_parameters: List[Tuple[str, torch.nn.Parameter]],
            **_
    ):
        """
        Build the base optimizer with frozen parameters removed and per-part learning rates set.

        :param params: Optimizer config; see the class docstring for the recognized keys.
        :param model_parameters: (name, parameter) pairs of the model.
        :return: The optimizer registered under the `type` of the `base` config.
        """
        param_groups = list()

        def remove_param(keyword_):
            # Freeze every parameter whose name contains `keyword_` and drop it from the
            # list handed to the base optimizer.
            nonlocal model_parameters
            logger.info(f'Fix param with name matching {keyword_}.')
            for name, param in model_parameters:
                if keyword_ in name:
                    logger.debug(f'Fix param {name}.')
                    param.requires_grad_(False)
            model_parameters = list(filter(lambda x: keyword_ not in x[0], model_parameters))

        # `layer_fix` is optional (the documented examples 1 and 2 omit it), so default to 0.
        # The trailing dot keeps e.g. layer 1 from also matching layers 10-19.
        for i_layer in range(params.pop('layer_fix', 0)):
            remove_param('transformer_model.encoder.layer.{}.'.format(i_layer))

        for specific_lr, keyword in (
                (params.pop('embeddings_lr', None), 'transformer_model.embeddings'),
                (params.pop('encoder_lr', None), 'transformer_model.encoder.layer'),
                (params.pop('pooler_lr', None), 'transformer_model.pooler'),
        ):
            if specific_lr is not None:
                if specific_lr > 0.:
                    pattern = '.*' + keyword.replace('.', r'\.') + '.*'
                    if len([name for name, _ in model_parameters if re.match(pattern, name)]) > 0:
                        param_groups.append([[pattern], {'lr': specific_lr}])
                    else:
                        logger.warning(f'{pattern} is set to use lr {specific_lr} but no param matches.')
                else:
                    # A learning rate of zero (or below) means "freeze this part".
                    remove_param(keyword)

        # User-supplied groups are appended after the ones generated above.
        if 'parameter_groups' in params:
            for pg in params.pop('parameter_groups'):
                param_groups.append([pg[0], pg[1].as_dict()])

        return Optimizer.by_name(params.get('base').pop('type'))(
            model_parameters=model_parameters, parameter_groups=param_groups,
            **params.pop('base').as_flat_dict()
        )
|
sftp/utils/__init__.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sftp.utils.label_smoothing
|
2 |
+
from sftp.utils.common import VIRTUAL_ROOT, DEFAULT_SPAN, BIO
|
3 |
+
from sftp.utils.db_storage import Cache
|
4 |
+
from sftp.utils.functions import num2mask, mask2idx, numpy2torch, one_hot, max_match
|
5 |
+
from sftp.utils.span import Span, re_index_span
|
6 |
+
from sftp.utils.span_utils import tensor2span
|
7 |
+
from sftp.utils.bio_smoothing import BIOSmoothing, apply_bio_smoothing
|
sftp/utils/bio_smoothing.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import *
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
from .common import BIO
|
5 |
+
|
6 |
+
|
7 |
+
class BIOSmoothing:
    """
    Label-smoothing configuration for BIO tags.

    Holds one smoothing factor per tag (in the order B, I, O) and a weight that scales
    every produced distribution. A negative factor marks the tag as marginalized.
    """
    def __init__(
            self,
            b_smooth: float = 0.0,
            i_smooth: float = 0.0,
            o_smooth: float = 0.0,
            weight: float = 1.0
    ):
        self.weight = weight
        self.smooth = [b_smooth, i_smooth, o_smooth]

    def apply_sequence(self, sequence: List[str]):
        """Return a float32 array of shape [len(sequence), 3], one `apply_tag` row per tag."""
        smoothed = np.zeros([len(sequence), 3], np.float32)
        for row, tag in zip(smoothed, sequence):
            row[:] = self.apply_tag(tag)
        return smoothed

    def apply_tag(self, tag: str):
        """
        Return the weighted target over the three tags for one gold tag.

        With a non-negative factor `s`, the gold tag gets `1 - s` and the other two tags
        get `s / 2` each. With a negative factor all three entries are 1. The result is
        multiplied by `self.weight`.
        """
        gold = BIO.index(tag)
        factor = self.smooth[gold]
        if factor >= 0.0:
            # Smooth
            target = np.full([3], factor / 2, np.float32)
            target[gold] = 1.0 - factor
        else:
            # Marginalize
            target = np.ones([3], np.float32)
        return target * self.weight

    def __repr__(self):
        pieces = [f'<W={self.weight:.2f}']
        for tag, factor in zip(BIO, self.smooth):
            if factor == 0.0:
                continue
            if factor < 0:
                pieces.append(f' [marginalize {tag}]')
            else:
                pieces.append(f' [smooth {tag} by {factor:.2f}]')
        pieces.append('>')
        return ''.join(pieces)

    def clone(self):
        """Return a new `BIOSmoothing` with the same factors and weight."""
        b_smooth, i_smooth, o_smooth = self.smooth
        return BIOSmoothing(b_smooth, i_smooth, o_smooth, self.weight)
|
50 |
+
|
51 |
+
|
52 |
+
def apply_bio_smoothing(
        config: Optional[Union[BIOSmoothing, List[BIOSmoothing]]],
        bio_seq: List[str]
) -> np.ndarray:
    """
    Turn a BIO tag sequence into smoothed targets.

    :param config: None (a default `BIOSmoothing` is used), one `BIOSmoothing` applied to
        every tag, or one `BIOSmoothing` per tag.
    :param bio_seq: The tag sequence.
    :return: The stacked per-tag targets.
    """
    smoothing = BIOSmoothing() if config is None else config
    if not isinstance(smoothing, BIOSmoothing):
        # Per-tag configs must line up one-to-one with the tags.
        assert len(bio_seq) == len(smoothing)
        return np.stack([one_config.apply_tag(tag) for one_config, tag in zip(smoothing, bio_seq)])
    return smoothing.apply_sequence(bio_seq)
|
sftp/utils/common.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
DEFAULT_SPAN = '@@SPAN@@'
|
2 |
+
VIRTUAL_ROOT = '@@VIRTUAL_ROOT@@'
|
3 |
+
BIO = 'BIO'
|
sftp/utils/db_storage.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pickle
|
2 |
+
import warnings
|
3 |
+
|
4 |
+
import h5py
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
|
8 |
+
class Cache:
    """
    Key/value store on top of an `h5py.File`.

    Keys are normalized to strings by `_key`. NumPy arrays are stored as they are; any
    other value is pickled first. Instances work as context managers and as decorators
    (see `__call__`).
    """
    def __init__(self, file: str, mode: str = 'a', overwrite=False):
        """
        :param file: Path passed to `h5py.File`.
        :param mode: File mode passed to `h5py.File`.
        :param overwrite: When True, the decorator recomputes and overwrites cached results.
        """
        self.db_file = h5py.File(file, mode=mode)
        self.overwrite = overwrite

    @staticmethod
    def _key(key):
        """Normalize a key: strings pass through, lists are joined with spaces
        (recursively), anything else goes through `str`."""
        if isinstance(key, str):
            return key
        if isinstance(key, list):
            return ' '.join(Cache._key(k) for k in key)
        return str(key)

    @staticmethod
    def _value(value: np.ndarray):
        """Read a stored value back, unpickling it when it was stored as bytes."""
        if isinstance(value, h5py.Dataset):
            value = value[()]
        if value.dtype.name.startswith('bytes'):
            # NOTE(review): unpickling runs arbitrary code; only open cache files
            # that come from a trusted source.
            value = pickle.loads(value)
        return value

    def __getitem__(self, key):
        key = self._key(key)
        if key not in self:
            # Report which key was missing.
            raise KeyError(key)
        return self._value(self.db_file[key])

    def __setitem__(self, key, value) -> None:
        key = self._key(key)
        if key in self:
            # Delete first: the existing entry is replaced, not updated in place.
            del self.db_file[key]
        if not isinstance(value, np.ndarray):
            value = np.array(pickle.dumps(value))
        self.db_file[key] = value

    def __delitem__(self, key) -> None:
        # Deleting a missing key is silently ignored.
        key = self._key(key)
        if key in self:
            del self.db_file[key]

    def __len__(self) -> int:
        return len(self.db_file)

    def close(self) -> None:
        self.db_file.close()

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.close()

    def __contains__(self, item):
        item = self._key(item)
        return item in self.db_file

    def __enter__(self):
        return self

    def __call__(self, function):
        """
        The object of the class could also be used as a decorator. Provide an additional
        argument `cache_id' when calling the function, and the results will be cached.
        """
        # Imported locally so the block does not depend on a module-level import.
        from functools import wraps

        @wraps(function)
        def wrapper(*args, **kwargs):
            if 'cache_id' not in kwargs:
                warnings.warn("`cache_id' argument not found. Cache is disabled.")
                return function(*args, **kwargs)
            # `cache_id` is consumed here and never reaches the wrapped function.
            cache_id = kwargs.pop('cache_id')
            if cache_id in self and not self.overwrite:
                return self[cache_id]
            rst = function(*args, **kwargs)
            self[cache_id] = rst
            return rst

        return wrapper
|
sftp/utils/functions.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import *
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from scipy.optimize import linear_sum_assignment
|
6 |
+
from torch.nn.utils.rnn import pad_sequence
|
7 |
+
|
8 |
+
|
9 |
+
def num2mask(
        nums: torch.Tensor,
        max_length: Optional[int] = None
) -> torch.Tensor:
    """
    Turn lengths into a boolean prefix mask.

    E.g. input a tensor [2, 3, 4], return [[T T F F], [T T T F], [T T T T]]
    :param nums: Tensor of lengths, of any shape.
    :param max_length: Size of the mask dimension. If not provided (or 0), the largest number in nums is used.
    :return: Boolean mask with the shape of nums plus one trailing dimension.
    """
    original_shape = nums.shape
    lengths = nums.flatten()
    width = max_length or int(lengths.max())
    # Broadcast a [1, width] row of positions against a [batch, 1] column of lengths.
    positions = torch.arange(0, width, device=lengths.device).unsqueeze(0)
    mask = positions < lengths.unsqueeze(1)
    return mask.reshape(*original_shape, width)
|
26 |
+
|
27 |
+
|
28 |
+
def mask2idx(
        mask: torch.Tensor,
        max_length: Optional[int] = None,
        padding_value: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    List the positions of the True entries along the last dimension.

    E.g. input a tensor [[T T F F], [T T T F], [F F F T]] with padding value -1,
    return [[0, 1, -1], [0, 1, 2], [3, -1, -1]]
    :param mask: Boolean mask tensor with at least two dimensions.
    :param max_length: If provided, keep at most this many indices per row.
    :param padding_value: Fills the rows that have fewer True entries. Default to 0.
    :return: The padded index tensor, and the number of True entries per row.
    """
    *leading_dims, width = mask.shape
    rows = mask.flatten(0, -2)
    positions = torch.arange(width, device=mask.device)
    # Boolean indexing keeps the positions at which each row is True.
    per_row = [positions[row] for row in rows.unbind(0)]
    padded = pad_sequence(per_row, batch_first=True, padding_value=padding_value)
    if max_length is not None:
        padded = padded[:, :max_length]
    return padded.reshape(*leading_dims, -1), mask.sum(-1)
|
49 |
+
|
50 |
+
|
51 |
+
def one_hot(tags: torch.Tensor, num_tags: Optional[int] = None) -> torch.Tensor:
    """
    One-hot encode an integer tag tensor along a new trailing dimension.

    :param tags: Integer (int64) tensor of tag indices, of any shape.
    :param num_tags: Size of the one-hot dimension. If not provided, it is the largest
        tag plus one, so that every tag in `tags` has a slot.
    :return: Boolean tensor with the shape of `tags` plus a trailing dimension of size `num_tags`.
    """
    if num_tags is None:
        # Tags are zero-based, so the largest tag `k` needs `k + 1` slots.
        num_tags = int(tags.max()) + 1 if tags.numel() > 0 else 0
    ret = tags.new_zeros(size=[*tags.shape, num_tags], dtype=torch.bool)
    # Scatter along the last dimension so that tags of any rank are supported.
    ret.scatter_(-1, tags.unsqueeze(-1), tags.new_ones([*tags.shape, 1], dtype=torch.bool))
    return ret
|
56 |
+
|
57 |
+
|
58 |
+
def numpy2torch(
        dict_obj: dict
) -> dict:
    """
    Return a copy of the dict in which every list or np.ndarray value is converted to a
    torch.Tensor with a leading batch dimension of size 1. Other values are kept as they are.
    """
    return {
        key: torch.tensor(value).unsqueeze(0) if isinstance(value, (list, np.ndarray)) else value
        for key, value in dict_obj.items()
    }
|
71 |
+
|
72 |
+
|
73 |
+
def max_match(mat: np.ndarray):
    """Return the total weight of the maximum-weight assignment between rows and columns of `mat`."""
    rows, cols = linear_sum_assignment(mat, maximize=True)
    matched = mat[rows, cols]
    return matched.sum()
|
sftp/utils/label_smoothing.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
from torch.nn import KLDivLoss
|
4 |
+
from torch.nn import LogSoftmax
|
5 |
+
|
6 |
+
|
7 |
+
class LabelSmoothingLoss(nn.Module):
    """
    KL-divergence loss against label-smoothed targets.

    The true class receives `1 - label_smoothing` and the remaining mass is spread evenly
    over the other classes. When `unreliable_label` is set, only targets equal to that
    label are smoothed; all other targets stay one-hot.
    """
    def __init__(self, label_smoothing=0.0, unreliable_label=None, ignore_index=-100):
        """
        If label_smoothing == 0.0, it is equivalent to xentropy

        :param label_smoothing: Probability mass moved away from the true class, in [0, 1].
        :param unreliable_label: If not None, the only label whose targets are smoothed.
        :param ignore_index: Targets with this value are dropped before the loss is computed.
        """
        assert 0.0 <= label_smoothing <= 1.0
        super().__init__()

        self.label_smoothing = label_smoothing
        self.unreliable_label = unreliable_label
        self.ignore_index = ignore_index

        self.max_gap = 100.
        self.log_softmax = LogSoftmax(1)
        self.loss_fn = KLDivLoss(reduction='batchmean')

    def forward(self, output, target):
        """
        output: logits
        target: labels
        """
        num_classes = output.shape[1]
        keep = (target != self.ignore_index)
        kept_target = target[keep]
        log_probs = self.log_softmax(output[keep])

        def build_target(smoothing):
            # Every class starts at the evenly shared smoothing mass; the true class is
            # then overwritten with what remains.
            shared_mass = smoothing / (num_classes - 1)
            distribution = log_probs.new_full((kept_target.size(0), num_classes), shared_mass)
            distribution.scatter_(1, kept_target.unsqueeze(1), 1 - smoothing)
            return distribution

        if self.unreliable_label is None:
            model_prob = build_target(self.label_smoothing)
        else:
            smoothed_prob = build_target(self.label_smoothing)
            hard_prob = build_target(0.0)
            # Rows with the unreliable label take the smoothed target, all others the hard one.
            is_unreliable = (kept_target == self.unreliable_label).unsqueeze(1)
            model_prob = torch.where(is_unreliable, smoothed_prob, hard_prob)

        return self.loss_fn(log_probs, model_prob)
|
sftp/utils/span.py
ADDED
@@ -0,0 +1,420 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import *
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
from .common import VIRTUAL_ROOT, DEFAULT_SPAN
|
6 |
+
from .bio_smoothing import BIOSmoothing
|
7 |
+
from .functions import max_match
|
8 |
+
|
9 |
+
|
10 |
+
class Span:
|
11 |
+
"""
|
12 |
+
Span is a simple data structure for a span (not necessarily associated with text), along with its label,
|
13 |
+
children and possibly its parent and a confidence score.
|
14 |
+
|
15 |
+
Basic usages (suppose span is a Span object):
|
16 |
+
1. len(span) -- #children.
|
17 |
+
2. span[i] -- i-th child.
|
18 |
+
3. for s in span: ... -- iterate its children.
|
19 |
+
4. for s in span.bfs: ... -- iterate its descendents.
|
20 |
+
5. print(span) -- show its description.
|
21 |
+
6. span.tree() -- print the whole tree.
|
22 |
+
|
23 |
+
It provides some utilities:
|
24 |
+
1. Re-indexing. BPE will change token indices, and the `re_index` method can convert normal tokens
|
25 |
+
BPE word piece indices, or vice versa.
|
26 |
+
2. Span object and span dict (JSON format) are mutually convertible (by `to_json` and `from_json` methods).
|
27 |
+
3. Recursively truncate spans up to a given length. (see `truncate` method)
|
28 |
+
4. Recursively replace all labels with the default label. (see `ignore_labels` method)
|
29 |
+
5. Recursively solve the span overlapping problem by removing children overlapped with others.
|
30 |
+
(see `remove_overlapping` method)
|
31 |
+
"""
|
32 |
+
def __init__(
|
33 |
+
self,
|
34 |
+
start_idx: int,
|
35 |
+
end_idx: int,
|
36 |
+
label: Union[str, int, list] = DEFAULT_SPAN,
|
37 |
+
is_parent: bool = False,
|
38 |
+
parent: Optional["Span"] = None,
|
39 |
+
confidence: Optional[float] = None,
|
40 |
+
):
|
41 |
+
"""
|
42 |
+
Init function. Children should be added using the `add_children` method.
|
43 |
+
:param start_idx: Start index in a seq of tokens, inclusive.
|
44 |
+
:param end_idx: End index in a seq of tokens, inclusive.
|
45 |
+
:param label: Label. If not provided, will assign a default label.
|
46 |
+
Can be of various types: String, integer, or list of something.
|
47 |
+
:param is_parent: If True, will be treated as parent. This is important because in the training process of BIO
|
48 |
+
tagger, when a span has no children, we need to know if it's a parent with no children (so we should have
|
49 |
+
an training example with all O tags) or not (then the above example doesn't exist).
|
50 |
+
We follow a convention where if a span is not parent, then the key `children` shouldn't appear in its
|
51 |
+
JSON dict; if a span is parent but has no children, the key `children` in its JSON dict should appear
|
52 |
+
and be an empty list.
|
53 |
+
:param parent: A pointer to its parent.
|
54 |
+
:param confidence: Confidence value.
|
55 |
+
"""
|
56 |
+
self.start_idx, self.end_idx = start_idx, end_idx
|
57 |
+
self.label: Union[int, str, list] = label
|
58 |
+
self.is_parent = is_parent
|
59 |
+
self.parent = parent
|
60 |
+
self._children: List[Span] = list()
|
61 |
+
self.confidence = confidence
|
62 |
+
|
63 |
+
# Following are for label smoothing. Leave default is you don't need smoothing.
|
64 |
+
# Logic:
|
65 |
+
# The label smoothing factors of (i.e. b_smooth, i_smooth, o_smooth) depend on the `child_span` of its parent.
|
66 |
+
# The re-weighting factor of a span also depends on the `child_span` of its parent, but can be overridden
|
67 |
+
# by its own `smoothing_weight` field if it's not None.
|
68 |
+
self.child_smooth: BIOSmoothing = BIOSmoothing()
|
69 |
+
self.smooth_weight: Optional[float] = None
|
70 |
+
|
71 |
+
def add_child(self, span: "Span") -> "Span":
    """
    Attach a span underneath this one and make this span its parent.
    Only spans flagged with `is_parent` may receive children.
    :param span: The span to attach as a child.
    :return: This span (not the child), so calls can be chained.
    """
    assert self.is_parent
    # Link upwards first, then register the child in this span's list.
    span.parent = self
    self._children += [span]
    return self
|
80 |
+
|
81 |
+
def re_index(
        self,
        offsets: List[Optional[Tuple[int, int]]],
        reverse: bool = False,
        recursive: bool = True,
        inplace: bool = False,
) -> "Span":
    """
    Convert this span's boundary between the original token space and the BPE word piece space.
    The boundary conversion itself is delegated to `re_index_span`.
    :param offsets: Offsets. Defined by BPE tokenizer and resides in the SpanFinder outputs.
    :param reverse: If True, map from the BPE space back to the original token space.
    :param recursive: If True, the children are re-indexed as well.
    :param inplace: If True, modify this span; otherwise work on a clone and leave this span untouched.
    :return: The re-indexed span (this span when `inplace`, a clone otherwise).
    """
    target = self if inplace else self.clone()

    new_start, new_end = re_index_span(target.boundary, offsets, reverse)
    target.start_idx, target.end_idx = new_start, new_end
    if recursive:
        # Children already belong to `target` (either the original or its clone), so they are always
        # re-indexed in place.
        target._children = [
            child.re_index(offsets, reverse, recursive, True) for child in target._children
        ]
    return target
|
110 |
+
|
111 |
+
def truncate(self, max_length: int) -> bool:
    """
    Recursively drop descendants whose end_idx is at or beyond `max_length`.
    Useful for encoders with a limited input length, such as XLMR (512 for XLMR large).
    :param max_length: Max length. A span is kept only if its end_idx is strictly smaller.
    :return: False if this span itself is out of range (its parent then drops it), True otherwise.
        Callers of the root usually don't need the return value.
    """
    if self.end_idx >= max_length:
        return False
    # Keep only the children that report themselves as still in range.
    self._children = [child for child in self._children if child.truncate(max_length)]
    return True
|
124 |
+
|
125 |
+
@classmethod
def virtual_root(cls: "Span", spans: Optional[List["Span"]] = None) -> "Span":
    """
    The official way to build a tree: create the first layer of spans yourself and hand them to this method.
    E.g., for an SRL style task, build the events, attach their arguments as children, then pass the events
    here to obtain a virtual root that acts as the parent of all events.
    The given list is stored as-is (not copied) as the root's children list.
    :param spans: 1st layer spans. If None, the root has no children.
    :return: Virtual root, a parent span with boundary (-1, -1).
    """
    root = Span(-1, -1, VIRTUAL_ROOT, True)
    if spans is None:
        return root
    root._children = spans
    for top_level in spans:
        top_level.parent = root
    return root
|
141 |
+
|
142 |
+
def ignore_labels(self) -> None:
    """
    Replace the label of this span and of every descendant with the placeholder label. Inplace.
    """
    self.label = DEFAULT_SPAN
    for sub_span in self._children:
        sub_span.ignore_labels()
|
149 |
+
|
150 |
+
def clone(self) -> "Span":
    """
    Clone the tree rooted at this span.
    Descendants are cloned recursively. The label, `child_smooth` and `smooth_weight` objects are shared with
    the original, and the clone of this span keeps pointing at the original's parent.
    :return: Cloned tree.
    """
    duplicate = Span(
        self.start_idx, self.end_idx, self.label, self.is_parent, self.parent, self.confidence
    )
    duplicate.child_smooth = self.child_smooth
    duplicate.smooth_weight = self.smooth_weight
    for child in self._children:
        # add_child re-points the cloned child's parent to the duplicate.
        duplicate.add_child(child.clone())
    return duplicate
|
160 |
+
|
161 |
+
def bfs(self) -> Iterable["Span"]:
    """
    Iterate over this span followed by all of its descendants, in the order produced by `_bfs`.
    :return: Spans.
    """
    yield self
    for descendant in self._bfs():
        yield descendant
|
168 |
+
|
169 |
+
def _bfs(self) -> Iterable["Span"]:
    """
    Helper generator for `bfs`: yields all descendants of this span, excluding the span itself.
    All direct children are yielded first, then the descendants of each child in turn.
    NOTE(review): for trees deeper than two levels below this span this is not a strict level-order
    traversal, since one child's whole subtree is exhausted before the next child's children are
    yielded. Confirm whether callers depend on the current order before changing it.
    """
    for child in self._children:
        yield child
    for child in self._children:
        yield from child._bfs()
|
177 |
+
|
178 |
+
def remove_overlapping(self, recursive=True) -> int:
    """
    Remove overlapping children. Children are visited in their stored order; a child is discarded when it
    shares at least one token index with a child that was already kept.
    :param recursive: Apply to the descendants of the kept children as well?
    :return: The number of spans that are removed. A discarded child counts as one, regardless of how many
        descendants it had.
    """
    occupied = set()
    kept = list()
    n_removed = 0
    for candidate in self._children:
        covered = range(candidate.start_idx, candidate.end_idx + 1)
        if occupied.isdisjoint(covered):
            occupied.update(covered)
            kept.append(candidate)
            if recursive:
                n_removed += candidate.remove_overlapping(True)
        else:
            n_removed += 1
    self._children = kept
    return n_removed
|
197 |
+
|
198 |
+
def describe(self, sentence: Optional[List[str]] = None) -> str:
    """
    One-line, human readable summary of this span.
    A span with a negative start index is summarised by its number of descendants only.
    :param sentence: If provided, will replace the indices with real tokens for presentation.
    :return: The description in a single line.
    """
    if self.start_idx < 0:
        return f'<Span Annotation: {self.n_nodes - 1} descendents>'

    if sentence is None:
        span = f'({self.start_idx}, {self.end_idx})'
    else:
        span = '(' + ' '.join(sentence[self.start_idx: self.end_idx + 1]) + ')'

    if self.is_parent:
        return f'<Span: {span}, {self.label}, {len(self._children)} children>'
    return f'[Span: {span}, {self.label}]'
|
214 |
+
|
215 |
+
def __repr__(self) -> str:
    """
    :return: The index-based one-line description produced by `describe()` (no sentence tokens).
    """
    return self.describe()
|
217 |
+
|
218 |
+
@property
def n_nodes(self) -> int:
    """
    :return: Size of the tree rooted at this span, i.e. number of descendents + 1 for the span itself.
    """
    return 1 + sum(child.n_nodes for child in self._children)
|
224 |
+
|
225 |
+
@property
def boundary(self) -> Tuple[int, int]:
    """
    :return: (start_idx, end_idx) as a tuple; both indices are inclusive.
    """
    return self.start_idx, self.end_idx
|
231 |
+
|
232 |
+
def __iter__(self) -> Iterable["Span"]:
    """
    Iterate over the direct children only (not deeper descendants), in their stored order.
    """
    yield from self._children
|
237 |
+
|
238 |
+
def __len__(self) -> int:
    """
    :return: Number of direct children (not the size of the whole subtree; see `n_nodes` for that).
    """
    return len(self._children)
|
243 |
+
|
244 |
+
def __getitem__(self, idx: int) -> "Span":
    """
    :param idx: Position in the list of direct children.
    :return: The indexed child.
    """
    return self._children[idx]
|
249 |
+
|
250 |
+
def tree(self, sentence: Optional[List[str]] = None, printing: bool = True) -> str:
    """
    Human readable, multi-line description of this span and all of its descendents. Every level of depth
    is indented by one more leading prefix.
    :param sentence: If provided, will replace the indices with real tokens for presentation.
    :param printing: If True, the description is printed and nothing is returned.
    :return: The description if `printing` is False, otherwise None.
    """
    lines = [self.describe(sentence)]
    for child in self._children:
        # Render the child's subtree without printing, then shift every line of it to the right.
        for sub_line in child.tree(sentence, False).split('\n'):
            lines.append(' ' + sub_line)
    rendered = '\n'.join(lines)
    if not printing:
        return rendered
    print(rendered)
|
266 |
+
|
267 |
+
def match(
        self,
        other: "Span",
        match_label: bool = True,
        depth: int = -1,
        ignore_parent_boundary: bool = False,
) -> int:
    """
    Used for evaluation. Count how many spans two trees share. Two spans are considered to be identical
    if their boundary, label, and parent match.
    :param other: The other tree to compare.
    :param match_label: If False, will ignore label.
    :param depth: If specified as non-negative, will only search thru certain depth. A value of 0 returns 0
        immediately; every level of recursion passes `depth - 1`, so a negative value never reaches 0 and
        the search is unbounded.
    :param ignore_parent_boundary: If True, two children can be matched ignoring parent boundaries.
    :return: #spans two tree share.
    """
    if depth == 0:
        return 0
    if self.label != other.label and match_label:
        return 0
    if self.boundary == other.boundary:
        n_match = 1
    elif ignore_parent_boundary:
        # Parents fail, Children might match!
        n_match = 0
    else:
        return 0

    # Pairwise match counts between my children (rows) and the other span's children (columns).
    # The builtin `int` is used as dtype: the `np.int` alias was deprecated in NumPy 1.20 and removed in
    # NumPy 1.24, where accessing it raises AttributeError. It was always just an alias of `int`.
    sub_matches = np.zeros([len(self), len(other)], dtype=int)
    for self_idx, my_child in enumerate(self):
        for other_idx, other_child in enumerate(other):
            sub_matches[self_idx, other_idx] = my_child.match(
                other_child, match_label, depth-1, ignore_parent_boundary
            )
    if not ignore_parent_boundary:
        # With strict boundary matching, a child may match at most one child on the other side.
        for m in [sub_matches, sub_matches.T]:
            for line in m:
                assert (line > 0).sum() <= 1
    n_match += max_match(sub_matches)
    return n_match
|
307 |
+
|
308 |
+
def to_json(self) -> dict:
    """
    Serialise the tree rooted at this span into a JSON dict.
    The dict always has `label` and `span`; `confidence` is included only when it is set, and `children`
    only when this span is a parent (possibly as an empty list).
    """
    payload = {"label": self.label, "span": list(self.boundary)}
    if self.confidence is not None:
        payload['confidence'] = self.confidence
    if self.is_parent:
        payload['children'] = [child.to_json() for child in self._children]
    return payload
|
324 |
+
|
325 |
+
@classmethod
def from_json(cls, span_json: Union[list, dict]) -> "Span":
    """
    Load a tree from JSON.
    A dict describes a single span (with optional `label`, `confidence` and `children` keys); the span is
    treated as a parent iff the key `children` is present. Any other input is treated as a list of
    first-layer spans and is wrapped in a virtual root.
    """
    if not isinstance(span_json, dict):
        first_layer = [Span.from_json(item) for item in span_json]
        return Span.virtual_root(first_layer)

    boundary = span_json['span']
    node = Span(
        boundary[0], boundary[1], span_json.get('label', None), 'children' in span_json,
        confidence=span_json.get('confidence', None)
    )
    for child_json in span_json.get('children', []):
        node.add_child(Span.from_json(child_json))
    return node
|
341 |
+
|
342 |
+
def map_ontology(
        self,
        ontology_mapping: Optional[dict] = None,
        inplace: bool = True,
        recursive: bool = True,
) -> Optional["Span"]:
    """
    Map labels to other things, like another ontology of soft labels.
    :param ontology_mapping: Mapping dict. Keys are either plain labels or (parent label, label) tuples;
        the tuple key takes precedence when both are present. Values can be anything.
        Be careful: a span whose label has no entry in the dict is dropped. This method returns None
        for it, and the parent leaves it out of its children list. The virtual root is never relabelled.
    :param inplace: Inplace? If False, a clone of this span is modified and returned.
    :param recursive: Apply to all descendents if True. Children are always replaced by mapped clones.
    :return: The mapped tree, or None if this span's label is not covered by the mapping.
    """
    span = self if inplace else self.clone()
    if ontology_mapping is None:
        # Do nothing if mapping not provided.
        return span

    if recursive:
        new_children = list()
        for child in span:
            # Children are mapped before this span's own label is changed, so the (parent label, label)
            # lookup below sees the parent's label as it was before mapping.
            new_child = child.map_ontology(ontology_mapping, False, True)
            if new_child is not None:
                new_children.append(new_child)
        span._children = new_children

    if span.label != VIRTUAL_ROOT:
        if span.parent is not None and (span.parent.label, span.label) in ontology_mapping:
            span.label = ontology_mapping[(span.parent.label, span.label)]
        elif span.label in ontology_mapping:
            span.label = ontology_mapping[span.label]
        else:
            # Unmapped label: signal to the caller that this span should be discarded.
            return

    return span
|
378 |
+
|
379 |
+
def isolate(self) -> "Span":
    """
    Generate a span that is identical to self but has no children or parent.
    Boundary, label, `is_parent` and confidence are carried over; `child_smooth` and `smooth_weight` are
    not copied, so the new span gets the constructor defaults for them.
    """
    return Span(self.start_idx, self.end_idx, self.label, self.is_parent, None, self.confidence)
|
384 |
+
|
385 |
+
def remove_child(self, span: Optional["Span"] = None):
    """
    Remove one child, or all of them.
    :param span: The child to remove (its first occurrence in the children list). If None, the children
        list is reset to a fresh empty list. Raises ValueError if the span is not a child.
    """
    if span is not None:
        self._children.pop(self._children.index(span))
        return
    self._children = list()
|
393 |
+
|
394 |
+
|
395 |
+
def re_index_span(
        boundary: Tuple[int, int], offsets: List[Tuple[int, int]], reverse: bool = False
) -> Tuple[int, int]:
    """
    Helper function that converts a single (start, end) boundary, both inclusive.
    Forward (reverse=False): `offsets[i]` is read as the (first, last) target indices of source token i;
    the boundary (-1, -1) is mapped to (0, 0).
    Reverse (reverse=True): each index is mapped to the position of the offset pair that contains it;
    `None` entries in `offsets` never match. The boundary (0, 0) is mapped to (-1, -1). A start index that
    is not covered by any pair falls back to 0, an uncovered end index falls back to `len(offsets)`.
    :raises IndexError: If the resulting start is larger than the resulting end.
    """
    left, right = boundary[0], boundary[1]

    if not reverse:
        if left == right == -1:
            # Virtual Root
            return 0, 0
        new_start = offsets[left][0]
        new_end = offsets[right][1]
    else:
        if left == right == 0:
            # Virtual Root
            return -1, -1
        start_hits = [
            pos for pos, piece in enumerate(offsets)
            if piece is not None and piece[0] <= left <= piece[1]
        ]
        end_hits = [
            pos for pos, piece in enumerate(offsets)
            if piece is not None and piece[0] <= right <= piece[1]
        ]
        assert len(start_hits) <= 1 and len(end_hits) <= 1
        new_start = start_hits[0] if start_hits else 0
        new_end = end_hits[0] if end_hits else len(offsets)

    if new_start > new_end:
        raise IndexError
    return new_start, new_end
|
sftp/utils/span_utils.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import *
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
from .span import Span
|
6 |
+
|
7 |
+
|
8 |
+
def _tensor2span_batch(
        span_boundary: torch.Tensor,
        span_labels: torch.Tensor,
        parent_indices: torch.Tensor,
        num_spans: torch.Tensor,
        label_confidence: torch.Tensor,
        idx2label: Dict[int, str],
        label_ignore: List[int],
) -> Span:
    """
    Build the span tree of a single sentence from its tensors. Only the first `num_spans` rows are read.
    Every kept span is created as a parent and is attached to the already created span at position
    `parent_idx`, if such a position exists.
    NOTE(review): spans with an ignored label are not appended to the list, which shifts the positions that
    later `parent_idx` values refer to. Confirm this is intended.
    :return: The first span that was kept (presumably the virtual root; verify against the model output).
        Raises IndexError if no span was kept.
    """
    spans = list()
    rows = zip(span_boundary, parent_indices, span_labels, label_confidence)
    for (start_idx, end_idx), parent_idx, label, label_conf in list(rows)[:int(num_spans)]:
        if label in label_ignore:
            continue
        span = Span(int(start_idx), int(end_idx), idx2label[int(label)], True, confidence=float(label_conf))
        parent_pos = int(parent_idx)
        if parent_pos < len(spans):
            spans[parent_pos].add_child(span)
        spans.append(span)
    return spans[0]
|
26 |
+
|
27 |
+
|
28 |
+
def tensor2span(
        span_boundary: torch.Tensor,
        span_labels: torch.Tensor,
        parent_indices: torch.Tensor,
        num_spans: torch.Tensor,
        label_confidence: torch.Tensor,
        idx2label: Dict[int, str],
        label_ignore: Optional[List[int]] = None,
) -> List[Span]:
    """
    Generate span trees from batched tensors, one tree per slice along the first dimension.
    Refer to the model part for the meaning of these variables.
    If `label_ignore` is provided, spans with those label indices are skipped.
    All tensors are moved to CPU first if `span_boundary` lives on another device.
    :return: One root span per batch element.
    """
    label_ignore = label_ignore or []
    if span_boundary.device.type != 'cpu':
        span_boundary, parent_indices, span_labels, num_spans, label_confidence = [
            tensor.to(device='cpu')
            for tensor in (span_boundary, parent_indices, span_labels, num_spans, label_confidence)
        ]

    per_sentence = zip(
        span_boundary.unbind(0), span_labels.unbind(0), parent_indices.unbind(0), num_spans.unbind(0),
        label_confidence.unbind(0),
    )
    return [
        _tensor2span_batch(*args, label_ignore=label_ignore, idx2label=idx2label)
        for args in per_sentence
    ]
|
sociolome/combine_models.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Dict, List, Optional
|
2 |
+
import dataclasses
|
3 |
+
import glob
|
4 |
+
import os
|
5 |
+
import sys
|
6 |
+
import json
|
7 |
+
|
8 |
+
import spacy
|
9 |
+
from spacy.language import Language
|
10 |
+
|
11 |
+
from sftp import SpanPredictor
|
12 |
+
|
13 |
+
|
14 |
+
@dataclasses.dataclass
class FrameAnnotation:
    """Token-level container for one sentence: the tokens and one POS tag per token."""
    # Tokens of the sentence.
    tokens: List[str] = dataclasses.field(default_factory=list)
    # POS tags; `to_txt` in the subclass indexes this in parallel with `tokens`.
    pos: List[str] = dataclasses.field(default_factory=list)
|
18 |
+
|
19 |
+
|
20 |
+
@dataclasses.dataclass
class MultiLabelAnnotation(FrameAnnotation):
    """Sentence annotation in which every token can carry several frame/role labels."""
    # For each token, the list of labels attached to it (may be empty).
    frame_list: List[List[str]] = dataclasses.field(default_factory=list)
    # For each token, an optional lexical unit.
    lu_list: List[Optional[str]] = dataclasses.field(default_factory=list)

    def to_txt(self):
        """
        Yield one line per token: token, POS tag, the labels joined with `|`, and the lexical unit.
        Empty label lists and missing lexical units are written as `_`.
        """
        for idx, token in enumerate(self.tokens):
            pos_tag = self.pos[idx]
            frames = '|'.join(self.frame_list[idx]) or '_'
            lexical_unit = self.lu_list[idx] or '_'
            yield f"{token} {pos_tag} {frames} {lexical_unit}"
|
28 |
+
|
29 |
+
|
30 |
+
def convert_to_seq_labels(sentence: List[str], structures: Dict[int, Dict[str, Any]]) -> List[List[str]]:
    """
    Flatten frame structures into per-token label lists.
    Every token of a target span receives `T:<frame>@<id>`; every token of a role span receives
    `B:<frame>:<role>@<id>` (first token of the role) or `I:<frame>:<role>@<id>` (remaining tokens).
    `<id>` is the structure id, zero-padded to two digits.
    :param sentence: Tokens of the sentence; only its length is used.
    :param structures: Structure id -> dict with keys "target" (inclusive span), "frame" and "roles"
        (a list of dicts with keys "boundary" and "label").
    :return: One list of labels per token, in the order the structures were processed.
    """
    per_token = [[] for _ in sentence]

    for struct_id, struct in structures.items():
        target, frame = struct["target"], struct["frame"]

        for tok_idx in range(target[0], target[1] + 1):
            per_token[tok_idx].append(f"T:{frame}@{struct_id:02}")

        for role in struct["roles"]:
            boundary, role_label = role["boundary"], role["label"]
            first, last = boundary[0], boundary[1]
            for tok_idx in range(first, last + 1):
                if tok_idx == first:
                    prefix = "B"
                else:
                    prefix = "I"
                per_token[tok_idx].append(f"{prefix}:{frame}:{role_label}@{struct_id:02}")
    return per_token
|
46 |
+
|
47 |
+
|
48 |
+
def predict_combined(
    spacy_model: Language,
    sentences: List[str],
    tgt_predictor: SpanPredictor,
    frm_predictor: SpanPredictor,
    bnd_predictor: SpanPredictor,
    arg_predictor: SpanPredictor,
) -> List[MultiLabelAnnotation]:
    """
    Annotate raw sentences with a four-step pipeline in which every step can use its own predictor:
    target spans -> frame of each target -> role boundaries -> role labels.
    Each step is a `force_decode` call, which returns a 3-tuple; the first item is used as the spans and
    the second as the labels (the third is ignored here).
    :param spacy_model: spaCy pipeline used for tokenisation and POS tags.
    :param sentences: Raw sentences; surrounding whitespace is stripped before processing.
    :param tgt_predictor: Predictor used to find the target spans.
    :param frm_predictor: Predictor used to label each target span with a frame.
    :param bnd_predictor: Predictor used to find the role boundaries, given target span and frame.
    :param arg_predictor: Predictor used to label the role boundaries.
    :return: One annotation per input sentence.
    """

    annotations_out = []

    for sent_idx, sent in enumerate(sentences):

        sent = sent.strip()

        print(f"Processing sent with idx={sent_idx}: {sent}")

        doc = spacy_model(sent)
        sent_tokens = [t.text for t in doc]

        tgt_spans, _, _ = tgt_predictor.force_decode(sent_tokens)

        # Keyed by the index of the target span within this sentence.
        frame_structures = {}

        for i, span in enumerate(tgt_spans):
            span = tuple(span)
            _, fr_labels, _ = frm_predictor.force_decode(sent_tokens, child_spans=[span])
            frame = fr_labels[0]
            # Skip targets that received the virtual-root label instead of a frame.
            # NOTE(review): this literal ends in three '@' while it starts with two; verify that it
            # matches the project's VIRTUAL_ROOT constant, otherwise this check never fires.
            if frame == "@@VIRTUAL_ROOT@@@":
                continue

            boundaries, _, _ = bnd_predictor.force_decode(sent_tokens, parent_span=span, parent_label=frame)
            _, arg_labels, _ = arg_predictor.force_decode(sent_tokens, parent_span=span, parent_label=frame, child_spans=boundaries)

            frame_structures[i] = {
                "target": span,
                "frame": frame,
                # The "Target" pseudo-role is dropped; the target is already stored above.
                "roles": [
                    {"boundary": bnd, "label": label}
                    for bnd, label in zip(boundaries, arg_labels)
                    if label != "Target"
                ]
            }
        annotations_out.append(MultiLabelAnnotation(
            tokens=sent_tokens,
            pos=[t.pos_ for t in doc],
            frame_list=convert_to_seq_labels(sent_tokens, frame_structures),
            lu_list=[None for _ in sent_tokens]
        ))
    return annotations_out
|
98 |
+
|
99 |
+
|
100 |
+
def main(input_folder):
    """
    Annotate every `*.txt` file in `input_folder` (one sentence per line) and write the result twice into
    `../../data-out/`: as a token-per-line text file and as a JSON dump of the annotations.
    The zero-shot model is used for target detection, the EVALITA model for all other steps.
    :param input_folder: Folder that contains the input `.txt` files.
    """

    print("Loading spaCy model ...")
    nlp = spacy.load("it_core_news_md")

    print("Loading predictors ...")
    zs_predictor = SpanPredictor.from_path("/data/p289731/cloned/lome-models/models/spanfinder/model.mod.tar.gz", cuda_device=0)
    ev_predictor = SpanPredictor.from_path("/scratch/p289731/lome-training-files/train-evalita-plus-fn-vanilla/model.tar.gz", cuda_device=0)

    print("Reading input files ...")
    for file in glob.glob(os.path.join(input_folder, "*.txt")):
        print(file)
        with open(file, encoding="utf-8") as f:
            sentences = list(f)

        annotations = predict_combined(nlp, sentences, zs_predictor, ev_predictor, ev_predictor, ev_predictor)

        out_name = os.path.splitext(os.path.basename(file))[0]
        with open(f"../../data-out/{out_name}.combined_zs_ev.tc_bilstm.txt", "w", encoding="utf-8") as f_out:
            # Write "\n" rather than os.linesep: the file is opened in text mode, which already translates
            # "\n" to the platform line ending. Writing os.linesep would produce "\r\r\n" on Windows.
            for ann in annotations:
                for line in ann.to_txt():
                    f_out.write(line + "\n")
                # Blank line between sentences.
                f_out.write("\n")

        with open(f"../../data-out/{out_name}.combined_zs_ev.tc_bilstm.json", "w", encoding="utf-8") as f_out:
            json.dump([dataclasses.asdict(ann) for ann in annotations], f_out)
|
127 |
+
|
128 |
+
|
129 |
+
# Script entry point: the first command line argument is the folder with the input .txt files.
if __name__ == "__main__":
    main(sys.argv[1])
|
sociolome/evalita_eval.py
ADDED
@@ -0,0 +1,319 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from typing import List, Tuple
|
3 |
+
|
4 |
+
import pandas as pd
|
5 |
+
|
6 |
+
from sftp import SpanPredictor
|
7 |
+
|
8 |
+
|
9 |
+
def main():
    """
    Run the whole evaluation suite for every configured model on one data file:
    frame detection (FD), boundary detection (BD, runs 1-2) and argument classification (AC, runs 1-3).
    Data file and model archives are hard-coded paths; results are printed to stdout.
    """
    # data_file = "/home/p289731/cloned/lome/preproc/evalita_jsonl/evalita_dev.jsonl"
    # data_file = "/home/p289731/cloned/lome/preproc/svm_challenge.jsonl"
    data_file = "/home/p289731/cloned/lome/preproc/evalita_jsonl/evalita_test.jsonl"
    # (model name used in output file names, path to the model archive)
    models = [
        (
            "lome-en",
            "/data/p289731/cloned/lome-models/models/spanfinder/model.mod.tar.gz",
        ),
        (
            "lome-it-best",
            "/scratch/p289731/lome-training-files/train-evalita-plus-fn-vanilla/model.tar.gz",
        ),
        # (
        #     "lome-it-freeze",
        #     "/data/p289731/cloned/lome/train-evalita-plus-fn-freeze/model.tar.gz",
        # ),
        # (
        #     "lome-it-mono",
        #     "/data/p289731/cloned/lome/train-evalita-it_mono/model.tar.gz",
        # ),
    ]

    for (model_name, model_path) in models:
        print("testing model: ", model_name)
        predictor = SpanPredictor.from_path(model_path)

        print("=== FD (run 1) ===")
        eval_frame_detection(data_file, predictor, model_name=model_name)

        # The run number selects how much gold information each evaluation is given.
        for run in [1, 2]:
            print(f"=== BD (run {run}) ===")
            eval_boundary_detection(data_file, predictor, run=run)

        for run in [1, 2, 3]:
            print(f"=== AC (run {run}) ===")
            eval_argument_classification(data_file, predictor, run=run)
|
46 |
+
|
47 |
+
|
48 |
+
def predict_frame(
    predictor: SpanPredictor, tokens: List[str], predicate_span: Tuple[int, int]
):
    """
    Predict the frame label for a given predicate span.
    :param predictor: Predictor whose `force_decode` is called with the predicate as the only child span.
    :param tokens: Tokens of the sentence.
    :param predicate_span: (start, end) of the predicate.
    :return: The first label returned by the predictor.
    """
    decoded = predictor.force_decode(tokens, child_spans=[predicate_span])
    _, predicted_labels, _ = decoded
    return predicted_labels[0]
|
53 |
+
|
54 |
+
|
55 |
+
def eval_frame_detection(data_file, predictor, verbose=False, model_name="_"):
    """
    Evaluate frame prediction with gold predicate spans and print the accuracy.
    Only the first annotation of every sentence is evaluated. The per-sentence predictions are written to
    `frame_prediction_output_<model_name>_<section>.csv` in the current working directory, where the
    section ("rai", "dev" or "test") is derived from the name of the data file.
    :param data_file: Path to a JSON-lines file; every line needs the keys "tokens" and "annotations".
    :param predictor: Predictor passed on to `predict_frame`.
    :param verbose: If True, print gold and predicted frame for every sentence.
    :param model_name: Used in the name of the output CSV file.
    """

    true_pos = 0
    # Counts every wrong prediction (there is no separate false-negative count for this task).
    false_pos = 0

    out = []

    with open(data_file, encoding="utf-8") as f:
        for sent_id, sent in enumerate(f):
            sent_data = json.loads(sent)

            tokens = sent_data["tokens"]
            annotation = sent_data["annotations"][0]

            predicate_span = tuple(annotation["span"])
            predicate = tokens[predicate_span[0] : predicate_span[1] + 1]

            frame_gold = annotation["label"]
            frame_pred = predict_frame(predictor, tokens, predicate_span)

            if frame_pred == frame_gold:
                true_pos += 1
            else:
                false_pos += 1

            out.append({
                "sentence": " ".join(tokens),
                "predicate": predicate,
                "frame_gold": frame_gold,
                "frame_pred": frame_pred
            })

            if verbose:
                print(f"Sentence #{sent_id:03}: {' '.join(tokens)}")
                print(f"\tpredicate: {predicate}")
                print(f"\t gold: {frame_gold}")
                print(f"\tpredicted: {frame_pred}")
                print()

    # An empty data file would otherwise divide by zero; report 0.0 in that case.
    n_total = true_pos + false_pos
    acc_score = true_pos / n_total if n_total else 0.0
    print("ACC =", acc_score)

    data_sect = "rai" if "svm_challenge" in data_file else "dev" if "dev" in data_file else "test"

    df_out = pd.DataFrame(out)
    df_out.to_csv(f"frame_prediction_output_{model_name}_{data_sect}.csv")
|
101 |
+
|
102 |
+
|
103 |
+
def predict_boundaries(predictor: SpanPredictor, tokens, predicate_span, frame):
    """
    Predict the argument boundaries for a predicate with a known frame.
    :param predictor: Predictor whose `force_decode` is called with the predicate as parent span.
    :param tokens: Tokens of the sentence.
    :param predicate_span: (start, end) tuple of the predicate.
    :param frame: Frame label of the predicate.
    :return: List of boundaries as tuples, without the span that coincides with the predicate and is
        labelled "Target".
    """
    boundaries, labels, _ = predictor.force_decode(
        tokens, parent_span=predicate_span, parent_label=frame
    )
    kept = []
    for raw_boundary, label in zip(boundaries, labels):
        candidate = tuple(raw_boundary)
        is_target = candidate == predicate_span and label == "Target"
        if not is_target:
            kept.append(candidate)
    return kept
|
114 |
+
|
115 |
+
|
116 |
+
def get_gold_boundaries(annotation, predicate_span):
    """
    Collect the gold argument boundaries of an annotation.
    :param annotation: Dict with a "children" list; every child has the keys "span" and "label".
    :param predicate_span: (start, end) tuple of the predicate.
    :return: Set of child spans as tuples, without the child that coincides with the predicate and is
        labelled "Target".
    """
    gold = set()
    for child in annotation["children"]:
        child_span = tuple(child["span"])
        if child_span == predicate_span and child["label"] == "Target":
            continue
        gold.add(child_span)
    return gold
|
122 |
+
|
123 |
+
|
124 |
+
def _precision_recall_f1(true_pos, false_pos, false_neg):
    """Return (precision, recall, F1) from raw counts, using 0.0 wherever a denominator is zero."""
    n_predicted = true_pos + false_pos
    n_gold = true_pos + false_neg
    prec = true_pos / n_predicted if n_predicted else 0.0
    rec = true_pos / n_gold if n_gold else 0.0
    f1_score = 2 * ((prec * rec) / (prec + rec)) if prec + rec else 0.0
    return prec, rec, f1_score


def eval_boundary_detection(data_file, predictor, run=1, verbose=False):
    """
    Evaluate argument boundary detection and print precision/recall/F1, both for exact span matches and
    on token level (tokens covered by any argument span).
    Only the first annotation of every sentence is evaluated; the "Target" span is excluded on both sides.
    A score whose denominator is zero (e.g. nothing predicted, or no correct prediction at all) is reported
    as 0.0 instead of raising ZeroDivisionError.
    :param data_file: Path to a JSON-lines file; every line needs the keys "tokens" and "annotations".
    :param predictor: Predictor used for frames (run 1) and boundaries.
    :param run: 1 = use predicted frames, 2 = use gold frames.
    :param verbose: If True, print the details of every sentence.
    """

    assert run in [1, 2]

    true_pos = 0
    false_pos = 0
    false_neg = 0

    true_pos_tok = 0
    false_pos_tok = 0
    false_neg_tok = 0

    with open(data_file, encoding="utf-8") as f:
        for sent_id, sent in enumerate(f):
            sent_data = json.loads(sent)

            tokens = sent_data["tokens"]
            annotation = sent_data["annotations"][0]

            predicate_span = tuple(annotation["span"])
            predicate = tokens[predicate_span[0] : predicate_span[1] + 1]

            # gold or predicted frames?
            if run == 1:
                frame = predict_frame(predictor, tokens, predicate_span)
            else:
                frame = annotation["label"]

            boundaries_gold = get_gold_boundaries(annotation, predicate_span)
            boundaries_pred = set(
                predict_boundaries(predictor, tokens, predicate_span, frame)
            )

            # -- exact span matches
            sent_true_pos = len(boundaries_gold & boundaries_pred)
            sent_false_pos = len(boundaries_pred - boundaries_gold)
            sent_false_neg = len(boundaries_gold - boundaries_pred)
            true_pos += sent_true_pos
            false_pos += sent_false_pos
            false_neg += sent_false_neg

            # -- token level: compare the sets of token indices covered by any span
            boundary_toks_gold = {
                tok_idx
                for (start, stop) in boundaries_gold
                for tok_idx in range(start, stop + 1)
            }
            boundary_toks_pred = {
                tok_idx
                for (start, stop) in boundaries_pred
                for tok_idx in range(start, stop + 1)
            }
            sent_tok_true_pos = len(boundary_toks_gold & boundary_toks_pred)
            sent_tok_false_pos = len(boundary_toks_pred - boundary_toks_gold)
            sent_tok_false_neg = len(boundary_toks_gold - boundary_toks_pred)
            true_pos_tok += sent_tok_true_pos
            false_pos_tok += sent_tok_false_pos
            false_neg_tok += sent_tok_false_neg

            if verbose:
                print(f"Sentence #{sent_id:03}: {' '.join(tokens)}")
                print(f"\tpredicate: {predicate}")
                print(f"\t frame: {frame}")
                print(f"\t gold: {boundaries_gold}")
                print(f"\tpredicted: {boundaries_pred}")
                print(f"\ttp={sent_true_pos}\tfp={sent_false_pos}\tfn={sent_false_neg}")
                print(
                    f"\ttp_t={sent_tok_true_pos}\tfp_t={sent_tok_false_pos}\tfn_t={sent_tok_false_neg}"
                )
                print()

    prec, rec, f1_score = _precision_recall_f1(true_pos, false_pos, false_neg)

    print(f"P/R/F=\n{prec}\t{rec}\t{f1_score}")

    tok_prec, tok_rec, tok_f1 = _precision_recall_f1(true_pos_tok, false_pos_tok, false_neg_tok)

    print(f"Pt/Rt/Ft=\n{tok_prec}\t{tok_rec}\t{tok_f1}")
|
203 |
+
|
204 |
+
|
205 |
+
def predict_arguments(
    predictor: SpanPredictor, tokens, predicate_span, frame, boundaries
):
    """
    Predict a role label for each of the given argument boundaries.
    :param predictor: Predictor whose `force_decode` is called with the boundaries as child spans.
    :param tokens: Tokens of the sentence.
    :param predicate_span: (start, end) tuple of the predicate.
    :param frame: Frame label of the predicate.
    :param boundaries: Argument boundaries; they are sorted by start index before decoding.
    :return: List of (boundary, label) pairs in sorted order, without the pair whose boundary coincides
        with the predicate and whose label is "Target".
    """
    ordered = sorted(boundaries, key=lambda t: t[0])
    _, labels, _ = predictor.force_decode(
        tokens, parent_span=predicate_span, parent_label=frame, child_spans=ordered
    )
    return [
        (bnd, lab)
        for bnd, lab in zip(ordered, labels)
        if not (bnd == predicate_span and lab == "Target")
    ]
|
218 |
+
|
219 |
+
|
220 |
+
def _precision_recall_f1(true_pos, false_pos, false_neg):
    """Return ``(precision, recall, f1)``, using 0.0 where a ratio is undefined."""
    predicted = true_pos + false_pos
    expected = true_pos + false_neg
    prec = true_pos / predicted if predicted else 0.0
    rec = true_pos / expected if expected else 0.0
    f1_score = 2 * ((prec * rec) / (prec + rec)) if (prec + rec) else 0.0
    return prec, rec, f1_score


def eval_argument_classification(data_file, predictor, run=1, verbose=False):
    """Evaluate argument (role) classification and print P/R/F scores.

    Each line of ``data_file`` is parsed as a JSON object; only its first
    annotation is evaluated. Scores are printed twice: once over whole
    ``(span, label)`` arguments and once over ``(token index, label)`` pairs.

    Args:
        data_file: path to a UTF-8 file with one JSON object per line.
        predictor: model handle passed on to the prediction helpers.
        run: evaluation setting, one of 1, 2 or 3.
            1 -- predicted frame, predicted argument boundaries;
            2 -- gold frame, predicted argument boundaries;
            3 -- gold frame, gold argument boundaries.
        verbose: when true, print the gold and predicted arguments of every
            sentence.

    Returns:
        None. Results are written to standard output.

    Degenerate counts (empty file, no predictions, no gold arguments, or no
    true positives) are reported as 0.0 instead of raising
    ``ZeroDivisionError``.
    """
    assert run in [1, 2, 3]

    true_pos = 0
    false_pos = 0
    false_neg = 0

    true_pos_tok = 0
    false_pos_tok = 0
    false_neg_tok = 0

    with open(data_file, encoding="utf-8") as f:
        for sent_id, sent in enumerate(f):
            sent_data = json.loads(sent)

            tokens = sent_data["tokens"]
            annotation = sent_data["annotations"][0]

            predicate_span = tuple(annotation["span"])
            # Spans are inclusive on both ends, hence the +1.
            predicate = tokens[predicate_span[0] : predicate_span[1] + 1]

            # gold or predicted frames?
            if run == 1:
                frame = predict_frame(predictor, tokens, predicate_span)
            else:
                frame = annotation["label"]

            # gold or predicted argument boundaries?
            if run in [1, 2]:
                boundaries = set(
                    predict_boundaries(predictor, tokens, predicate_span, frame)
                )
            else:
                boundaries = get_gold_boundaries(annotation, predicate_span)

            pred_arguments = predict_arguments(
                predictor, tokens, predicate_span, frame, boundaries
            )
            # The predicate's own "Target" child is not an argument.
            gold_arguments = {
                (tuple(c["span"]), c["label"])
                for c in annotation["children"]
                if not (tuple(c["span"]) == predicate_span and c["label"] == "Target")
            }

            if verbose:
                print(f"Sentence #{sent_id:03}: {' '.join(tokens)}")
                print(f"\tpredicate: {predicate}")
                print(f"\t frame: {frame}")
                print(f"\t gold: {gold_arguments}")
                print(f"\tpredicted: {pred_arguments}")
                print()

            # -- full spans version
            for gold_arg in gold_arguments:
                # true positive: found the span and labeled it correctly
                if gold_arg in pred_arguments:
                    true_pos += 1
                # false negative: missed this argument
                else:
                    false_neg += 1
            for p_bnd, p_label in pred_arguments:
                # all predictions that are not true positives are false positives
                if (p_bnd, p_label) not in gold_arguments:
                    false_pos += 1

            # -- token based
            tok_gold_labels = {
                (token, label)
                for ((bnd_start, bnd_end), label) in gold_arguments
                for token in range(bnd_start, bnd_end + 1)
            }
            tok_pred_labels = {
                (token, label)
                for ((bnd_start, bnd_end), label) in pred_arguments
                for token in range(bnd_start, bnd_end + 1)
            }
            # Both sides are sets here, so the counts are plain set sizes.
            true_pos_tok += len(tok_gold_labels & tok_pred_labels)
            false_neg_tok += len(tok_gold_labels - tok_pred_labels)
            false_pos_tok += len(tok_pred_labels - tok_gold_labels)

    prec, rec, f1_score = _precision_recall_f1(true_pos, false_pos, false_neg)

    print(f"P/R/F=\n{prec}\t{rec}\t{f1_score}")

    tok_prec, tok_rec, tok_f1 = _precision_recall_f1(
        true_pos_tok, false_pos_tok, false_neg_tok
    )

    print(f"Pt/Rt/Ft=\n{tok_prec}\t{tok_rec}\t{tok_f1}")
|
316 |
+
|
317 |
+
|
318 |
+
if __name__ == "__main__":
|
319 |
+
main()
|
sociolome/lome_wrapper.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from sftp import SpanPredictor
|
2 |
+
import spacy
|
3 |
+
|
4 |
+
import sys
|
5 |
+
import dataclasses
|
6 |
+
from typing import List, Optional, Dict, Any
|
7 |
+
|
8 |
+
|
9 |
+
# Both objects below are created when this module is imported and are the
# ones analyze() passes to make_prediction().
# NOTE(review): "model.mod.tar.gz" is a relative path -- confirm how
# SpanPredictor.from_path resolves it (working directory vs. package root).
predictor = SpanPredictor.from_path("model.mod.tar.gz")
# spaCy pipeline used by make_prediction() to split a sentence into tokens.
nlp = spacy.load("xx_sent_ud_sm")
|
11 |
+
|
12 |
+
|
13 |
+
@dataclasses.dataclass
class FrameAnnotation:
    """Base record holding the per-token data of one analysed sentence.

    Both fields default to a fresh empty list per instance.
    """

    # Token strings of the sentence.
    tokens: List[str] = dataclasses.field(default_factory=list)
    # One tag string per token, indexed in parallel with ``tokens``.
    pos: List[str] = dataclasses.field(default_factory=list)
|
17 |
+
|
18 |
+
|
19 |
+
@dataclasses.dataclass
class MultiLabelAnnotation(FrameAnnotation):
    """Sentence annotation in which a token may carry several frame labels."""

    # For every token, the (possibly empty) list of frame/role labels.
    frame_list: List[List[str]] = dataclasses.field(default_factory=list)
    # For every token, an optional lexical-unit string.
    lu_list: List[Optional[str]] = dataclasses.field(default_factory=list)

    def to_txt(self):
        """Yield one space-separated line per token.

        Each line has the form ``token pos labels lu``: the labels are joined
        with ``|``, and ``_`` stands in for an empty label list or a missing
        lexical unit. The per-token lists are indexed by token position, so
        a list shorter than ``tokens`` raises ``IndexError``.
        """
        for idx, token in enumerate(self.tokens):
            pos_tag = self.pos[idx]
            frames = "|".join(self.frame_list[idx]) or "_"
            lexical_unit = self.lu_list[idx] or "_"
            yield f"{token} {pos_tag} {frames} {lexical_unit}"
|
27 |
+
|
28 |
+
|
29 |
+
# reused from "combine_predictions.py" (cloned/lome/src/spanfinder/sociolome)
|
30 |
+
def convert_to_seq_labels(sentence: List[str], structures: Dict[int, Dict[str, Any]]) -> List[List[str]]:
    """Flatten frame structures into per-token label lists.

    Args:
        sentence: the tokens of the sentence; only its length matters here.
        structures: mapping from a structure id to a dict with the keys
            ``"target"`` (inclusive ``(start, end)`` span), ``"frame"`` and
            ``"roles"`` (dicts with ``"boundary"`` and ``"label"``).

    Returns:
        One list per token. Target tokens get ``T:<frame>@<id>``; role tokens
        get ``B:<frame>:<role>@<id>`` on the first token of the role span and
        ``I:<frame>:<role>@<id>`` on the rest. Ids are zero-padded to two
        digits.
    """
    token_labels: List[List[str]] = [[] for _ in sentence]

    for frame_id, structure in structures.items():
        target = structure["target"]
        frame_name = structure["frame"]

        # Spans are inclusive on both ends, hence the +1.
        for position in range(target[0], target[1] + 1):
            token_labels[position].append(f"T:{frame_name}@{frame_id:02}")

        for role in structure["roles"]:
            boundary = role["boundary"]
            role_name = role["label"]
            role_start = boundary[0]
            for position in range(role_start, boundary[1] + 1):
                bio = "I" if position != role_start else "B"
                token_labels[position].append(
                    f"{bio}:{frame_name}:{role_name}@{frame_id:02}"
                )

    return token_labels
|
46 |
+
|
47 |
+
def make_prediction(sentence, spacy_model, predictor):
    """Run frame and argument prediction on a single sentence.

    Args:
        sentence: raw sentence text.
        spacy_model: callable spaCy pipeline used for tokenization; the
            resulting tokens' ``text`` and ``pos_`` attributes are read.
        predictor: object exposing ``force_decode``. It is called once on the
            bare tokens to get target spans and frame labels, then once per
            target to get that frame's argument spans and labels.

    Returns:
        A ``MultiLabelAnnotation`` with the tokens, their POS strings, the
        per-token frame/role labels, and a ``lu_list`` of ``None`` values.
        Input that produces no tokens (e.g. a blank line) yields an
        annotation with empty lists and never reaches the predictor.
    """
    spacy_doc = spacy_model(sentence)
    tokens = [t.text for t in spacy_doc]

    frame_structures = {}

    # analyze() splits on "\n", so blank lines arrive here as empty
    # sentences. There is nothing to decode for them, so the model is only
    # invoked when at least one token exists.
    if tokens:
        tgt_spans, fr_labels, _ = predictor.force_decode(tokens)

        # Number the frames in order of their target's start position.
        targets = sorted(zip(tgt_spans, fr_labels), key=lambda t: t[0][0])
        for i, (tgt, frm) in enumerate(targets):
            arg_spans, arg_labels, _ = predictor.force_decode(
                tokens, parent_span=tgt, parent_label=frm
            )

            frame_structures[i] = {
                "target": tgt,
                "frame": frm,
                "roles": [
                    {"boundary": bnd, "label": label}
                    for bnd, label in zip(arg_spans, arg_labels)
                    # "Target" is already recorded under "target" above.
                    if label != "Target"
                ],
            }

    return MultiLabelAnnotation(
        tokens=tokens,
        pos=[t.pos_ for t in spacy_doc],
        frame_list=convert_to_seq_labels(tokens, frame_structures),
        lu_list=[None for _ in tokens],
    )
|
73 |
+
|
74 |
+
|
75 |
+
def analyze(text):
    """Analyse every line of ``text`` and return the results as plain dicts.

    The text is split on ``"\\n"`` and each piece is passed to
    ``make_prediction`` together with the module-level ``nlp`` and
    ``predictor`` objects.

    Returns:
        A dict with ``"result"`` set to ``"OK"`` and ``"analyses"`` holding
        one ``dataclasses.asdict`` conversion per input line, in input order.
    """
    annotations = [
        make_prediction(line, nlp, predictor) for line in text.split("\n")
    ]
    as_dicts = [dataclasses.asdict(annotation) for annotation in annotations]
    return {"result": "OK", "analyses": as_dicts}
|