|
|
|
|
|
|
|
|
|
""" |
|
This script computes smatch score between two AMRs. |
|
For detailed description of smatch, see http://www.isi.edu/natural-language/amr/smatch-13.pdf |
|
|
|
""" |
|
|
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
try:
    # Installed as a package: bind the submodule to the bare name ``amr``,
    # which is the name the rest of this module uses (``amr.AMR``).
    # A plain ``import smatch.amr`` would only bind ``smatch``.
    import smatch.amr as amr
except ImportError:
    # Fallback: amr.py is importable as a top-level module.
    import amr
|
import os |
|
import random |
|
import sys |
|
|
|
|
|
# Number of hill-climbing rounds get_best_match() runs per AMR pair.
# main() overwrites it with the "-r" restart count plus one, and
# get_amr_match() overwrites it when a ``limit`` is passed.
iteration_num = 5

# Verbose output switches, both off by default; main() turns them on from the
# "-v" and "--vv" options.
verbose = False
veryVerbose = False

# When True, score_amr_pairs() yields one score for all AMR pairs together;
# when False it yields one score per AMR pair.
single_score = True

# When True, main() prints precision and recall in addition to the F-score.
pr_flag = False

# Stream that error messages are printed to.
ERROR_LOG = sys.stderr

# Stream that verbose/debug messages are printed to.
DEBUG_LOG = sys.stderr

# Cache of already-scored node mappings.
# key: the node mapping as a tuple; value: its triple match number.
# It is cleared after each AMR pair has been scored.
match_triple_dict = {}
|
|
|
|
|
def build_arg_parser():
    """
    Build an argument parser using argparse. Use it when python version is 2.7 or later.
    Returns:
        parser: an argparse.ArgumentParser describing the smatch command line

    """
    # Imported here so the function also works when this module is imported
    # by other code (the module-level import only runs in script mode).
    import argparse

    try:
        amr_file = argparse.FileType('r', encoding="utf-8")
    except TypeError:
        # Python 2.7's FileType does not accept an ``encoding`` argument.
        amr_file = argparse.FileType('r')

    parser = argparse.ArgumentParser(description="Smatch calculator -- arguments")
    parser.add_argument('-f', nargs=2, required=True, type=amr_file,
                        help='Two files containing AMR pairs. AMRs in each file are separated by a single blank line')
    parser.add_argument('-r', type=int, default=4, help='Restart number (Default:4)')
    parser.add_argument('--significant', type=int, default=2, help='significant digits to output (default: 2)')
    parser.add_argument('-v', action='store_true', help='Verbose output (Default:false)')
    parser.add_argument('--vv', action='store_true', help='Very Verbose output (Default:false)')
    # The two help fragments are joined by implicit concatenation, so the
    # separating space has to be part of one of them.
    parser.add_argument('--ms', action='store_true', default=False,
                        help='Output multiple scores (one AMR pair a score) '
                             'instead of a single document-level smatch score (Default: false)')
    parser.add_argument('--pr', action='store_true', default=False,
                        help="Output precision and recall as well as the f-score. Default: false")
    parser.add_argument('--justinstance', action='store_true', default=False,
                        help="just pay attention to matching instances")
    parser.add_argument('--justattribute', action='store_true', default=False,
                        help="just pay attention to matching attributes")
    parser.add_argument('--justrelation', action='store_true', default=False,
                        help="just pay attention to matching relations")

    return parser
|
|
|
|
|
def build_arg_parser2():
    """
    Build an argument parser using optparse. Use it when python version is 2.5 or 2.6.
    Returns:
        parser: an optparse.OptionParser describing the smatch command line.
                "-f" is not enforced by the parser; the caller has to check it.

    """
    # Imported here so the function does not depend on the script entry point
    # having imported optparse into the module namespace.
    import optparse

    usage_str = "Smatch calculator -- arguments"
    parser = optparse.OptionParser(usage=usage_str)
    parser.add_option("-f", "--files", nargs=2, dest="f", type="string",
                      help='Two files containing AMR pairs. AMRs in each file are '
                           'separated by a single blank line. This option is required.')
    parser.add_option("-r", "--restart", dest="r", type="int", help='Restart number (Default: 4)')
    parser.add_option('--significant', dest="significant", type="int", default=2,
                      help='significant digits to output (default: 2)')
    parser.add_option("-v", "--verbose", action='store_true', dest="v", help='Verbose output (Default:False)')
    parser.add_option("--vv", "--veryverbose", action='store_true', dest="vv",
                      help='Very Verbose output (Default:False)')
    parser.add_option("--ms", "--multiple_score", action='store_true', dest="ms",
                      help='Output multiple scores (one AMR pair a score) instead of '
                           'a single document-level smatch score (Default: False)')
    parser.add_option('--pr', "--precision_recall", action='store_true', dest="pr",
                      help="Output precision and recall as well as the f-score. Default: false")
    parser.add_option('--justinstance', action='store_true', default=False,
                      help="just pay attention to matching instances")
    parser.add_option('--justattribute', action='store_true', default=False,
                      help="just pay attention to matching attributes")
    parser.add_option('--justrelation', action='store_true', default=False,
                      help="just pay attention to matching relations")
    # ``vv`` gets an explicit False so every switch has a boolean default,
    # matching what the help text advertises.
    parser.set_defaults(r=4, v=False, vv=False, ms=False, pr=False)
    return parser
|
|
|
|
|
def get_best_match(instance1, attribute1, relation1,
                   instance2, attribute2, relation2,
                   prefix1, prefix2, doinstance=True, doattribute=True, dorelation=True):
    """
    Find the node mapping with the largest triple match number by running
    ``iteration_num`` rounds of hill-climbing.
    Arguments:
        instance1: instance triples of AMR 1 ("instance", node name, node value)
        attribute1: attribute triples of AMR 1 (attribute name, node name, attribute value)
        relation1: relation triples of AMR 1 (relation name, node 1 name, node 2 name)
        instance2: instance triples of AMR 2 ("instance", node name, node value)
        attribute2: attribute triples of AMR 2 (attribute name, node name, attribute value)
        relation2: relation triples of AMR 2 (relation name, node 1 name, node 2 name)
        prefix1: prefix label for AMR 1
        prefix2: prefix label for AMR 2
        doinstance, doattribute, dorelation: forwarded to compute_pool to
            select which triple types take part in the matching
    Returns:
        best_mapping: the node mapping that results in the highest triple matching number
        best_match_num: the highest triple matching number

    """
    # Candidate targets for every AMR 1 node plus the weight of each node pair.
    candidate_mappings, weight_dict = compute_pool(
        instance1, attribute1, relation1,
        instance2, attribute2, relation2,
        prefix1, prefix2,
        doinstance=doinstance, doattribute=doattribute, dorelation=dorelation)
    if veryVerbose:
        print("Candidate mappings:", file=DEBUG_LOG)
        print(candidate_mappings, file=DEBUG_LOG)
        print("Weight dictionary", file=DEBUG_LOG)
        print(weight_dict, file=DEBUG_LOG)

    # Entry i holds the AMR 2 node index mapped to node i of AMR 1 (-1: none).
    best_mapping = [-1] * len(instance1)
    best_match_num = 0
    for round_no in range(iteration_num):
        if veryVerbose:
            print("Iteration", round_no, file=DEBUG_LOG)
        if round_no:
            # every later round restarts from a random mapping
            cur_mapping = random_init_mapping(candidate_mappings)
        else:
            # the first round starts from the concept-based mapping
            cur_mapping = smart_init_mapping(candidate_mappings, instance1, instance2)

        match_num = compute_match(cur_mapping, weight_dict)
        if veryVerbose:
            print("Node mapping at start", cur_mapping, file=DEBUG_LOG)
            print("Triple match number at start:", match_num, file=DEBUG_LOG)

        # Climb until no move or swap improves the match number.
        while True:
            gain, new_mapping = get_best_gain(cur_mapping, candidate_mappings, weight_dict,
                                              len(instance2), match_num)
            if veryVerbose:
                print("Gain after the hill-climbing", gain, file=DEBUG_LOG)
            if gain <= 0:
                break
            match_num += gain
            cur_mapping = new_mapping[:]
            if veryVerbose:
                print("Update triple match number to:", match_num, file=DEBUG_LOG)
                print("Current mapping:", cur_mapping, file=DEBUG_LOG)

        # Keep the best local optimum seen over all rounds.
        if match_num > best_match_num:
            best_mapping = cur_mapping[:]
            best_match_num = match_num
    return best_mapping, best_match_num
|
|
|
|
|
def normalize(item):
    """
    Prepare an item for comparison: strip trailing "¦" characters, lowercase
    the result, then strip trailing underscores.
    """
    without_bars = item.rstrip("¦")
    lowered = without_bars.lower()
    return lowered.rstrip('_')
|
|
|
|
|
def compute_pool(instance1, attribute1, relation1,
                 instance2, attribute2, relation2,
                 prefix1, prefix2, doinstance=True, doattribute=True, dorelation=True):
    """
    Build the candidate node mappings between two AMRs and the weight (number
    of matching triples) contributed by each node pair.

    Arguments:
        instance1: instance triples of AMR 1
        attribute1: attribute triples of AMR 1 (attribute name, node name, attribute value)
        relation1: relation triples of AMR 1 (relation name, node 1 name, node 2 name)
        instance2: instance triples of AMR 2
        attribute2: attribute triples of AMR 2 (attribute name, node name, attribute value)
        relation2: relation triples of AMR 2 (relation name, node 1 name, node 2 name)
        prefix1: prefix label for AMR 1
        prefix2: prefix label for AMR 2
        doinstance: include instance triples when True
        doattribute: include attribute triples when True
        dorelation: include relation triples when True
    Returns:
        candidate_mapping: a list with one set per instance triple of AMR 1.
                           The ith set holds the node indices (in AMR 2) the
                           ith node (in AMR 1) can map to with a non-zero match.
        weight_dict: maps a node pair (index in AMR 1, index in AMR 2) to a
                     dictionary. Key -1 holds the matches the pair yields on
                     its own (instance/attribute triples and relations whose
                     two node pairs coincide); every other key is a second
                     node pair, and its value is the number of relation
                     triples matched when both pairs are mapped together.

    """
    candidate_mapping = []
    weight_dict = {}
    # Node names are the prefix followed by the node index.
    skip1 = len(prefix1)
    skip2 = len(prefix2)

    for inst1 in instance1:
        # one candidate set per node of AMR 1, even when instances are ignored
        candidate_mapping.append(set())
        if not doinstance:
            continue
        for inst2 in instance2:
            if not (normalize(inst1[0]) == normalize(inst2[0]) and
                    normalize(inst1[2]) == normalize(inst2[2])):
                continue
            idx1 = int(inst1[1][skip1:])
            idx2 = int(inst2[1][skip2:])
            candidate_mapping[idx1].add(idx2)
            pair = (idx1, idx2)
            if pair not in weight_dict:
                weight_dict[pair] = {-1: 1}
            else:
                weight_dict[pair][-1] += 1

    if doattribute:
        for attr1 in attribute1:
            for attr2 in attribute2:
                if not (normalize(attr1[0]) == normalize(attr2[0])
                        and normalize(attr1[2]) == normalize(attr2[2])):
                    continue
                idx1 = int(attr1[1][skip1:])
                idx2 = int(attr2[1][skip2:])
                candidate_mapping[idx1].add(idx2)
                pair = (idx1, idx2)
                if pair not in weight_dict:
                    weight_dict[pair] = {-1: 1}
                else:
                    weight_dict[pair][-1] += 1

    if dorelation:
        for rel1 in relation1:
            for rel2 in relation2:
                if normalize(rel1[0]) != normalize(rel2[0]):
                    continue
                head1 = int(rel1[1][skip1:])
                head2 = int(rel2[1][skip2:])
                tail1 = int(rel1[2][skip1:])
                tail2 = int(rel2[2][skip2:])
                candidate_mapping[head1].add(head2)
                candidate_mapping[tail1].add(tail2)
                head_pair = (head1, head2)
                tail_pair = (tail1, tail2)

                if head_pair == tail_pair:
                    # Both ends are the same node pair: count it once under -1.
                    if head_pair in weight_dict:
                        weight_dict[head_pair][-1] += 1
                    else:
                        weight_dict[head_pair] = {-1: 1}
                    continue

                # Order the two pairs by their AMR 1 node index.
                if head1 > tail1:
                    first, second = tail_pair, head_pair
                else:
                    first, second = head_pair, tail_pair
                # Record the relation under both pairs so either can find it.
                for owner, partner in ((first, second), (second, first)):
                    if owner in weight_dict:
                        links = weight_dict[owner]
                        if partner in links:
                            links[partner] += 1
                        else:
                            links[partner] = 1
                    else:
                        weight_dict[owner] = {-1: 0, partner: 1}
    return candidate_mapping, weight_dict
|
|
|
|
|
def smart_init_mapping(candidate_mapping, instance1, instance2):
    """
    Build a starting node mapping that prefers candidates whose concept value
    is identical (smart initialization).
    Arguments:
        candidate_mapping: candidate node match list
        instance1: instance triples of AMR 1
        instance2: instance triples of AMR 2
    Returns:
        a list whose ith entry is the AMR 2 node index chosen for node i of
        AMR 1, or -1 when no node was assigned

    """
    random.seed()
    used = set()
    result = []
    # AMR 1 nodes that have candidates but no free same-concept candidate
    deferred = []
    for idx, options in enumerate(candidate_mapping):
        if not options:
            # nothing this node could map to
            result.append(-1)
            continue
        concept = instance1[idx][2]
        for target in options:
            # take the first unused candidate carrying the same concept value
            if concept == instance2[target][2] and target not in used:
                used.add(target)
                result.append(target)
                break
        else:
            deferred.append(idx)
            result.append(-1)

    # Deferred nodes get a random candidate that is still unused, if any.
    for idx in deferred:
        pool = list(candidate_mapping[idx])
        while pool:
            pos = random.randint(0, len(pool) - 1)
            pick = pool[pos]
            if pick not in used:
                used.add(pick)
                result[idx] = pick
                break
            pool.pop(pos)
    return result
|
|
|
|
|
def random_init_mapping(candidate_mapping):
    """
    Build a random starting node mapping.
    Args:
        candidate_mapping: candidate node match list
    Returns:
        a list whose ith entry is a randomly chosen, not yet used candidate
        for node i of AMR 1, or -1 when none is available

    """
    random.seed()
    used = set()
    result = []
    for options in candidate_mapping:
        pool = list(options)
        chosen = -1
        # Draw candidates at random, discarding the ones already assigned.
        while pool:
            pos = random.randint(0, len(pool) - 1)
            pick = pool[pos]
            if pick in used:
                pool.pop(pos)
                continue
            used.add(pick)
            chosen = pick
            break
        result.append(chosen)
    return result
|
|
|
|
|
def compute_match(mapping, weight_dict):
    """
    Compute the triple match number of a node mapping from weight_dict,
    using match_triple_dict as a cache.
    Args:
        mapping: a list of node indices in AMR 2. The ith element (value j)
                 means node i in AMR 1 maps to node j in AMR 2; -1 means the
                 node is unmapped.
        weight_dict: weight dictionary
    Returns:
        matching triple number

    """
    if veryVerbose:
        print("Computing match for mapping", file=DEBUG_LOG)
        print(mapping, file=DEBUG_LOG)
    cache_key = tuple(mapping)
    if cache_key in match_triple_dict:
        if veryVerbose:
            print("saved value", match_triple_dict[cache_key], file=DEBUG_LOG)
        return match_triple_dict[cache_key]

    total = 0
    for src, dst in enumerate(mapping):
        if dst == -1:
            # node src of AMR 1 is not mapped
            continue
        pair = (src, dst)
        if pair not in weight_dict:
            continue
        if veryVerbose:
            print("node_pair", pair, file=DEBUG_LOG)
        for partner, weight in weight_dict[pair].items():
            if partner == -1:
                # matches this pair yields on its own
                total += weight
                if veryVerbose:
                    print("instance/attribute match", weight, file=DEBUG_LOG)
            elif partner[0] < src:
                # Each relation is stored under both of its node pairs; it is
                # counted only from the pair with the smaller AMR 1 index.
                continue
            elif mapping[partner[0]] == partner[1]:
                total += weight
                if veryVerbose:
                    print("relation match with", partner, weight, file=DEBUG_LOG)
    if veryVerbose:
        print("match computing complete, result:", total, file=DEBUG_LOG)

    # remember the result for later lookups
    match_triple_dict[cache_key] = total
    return total
|
|
|
|
|
def move_gain(mapping, node_id, old_id, new_id, weight_dict, match_num):
    """
    Compute how much the triple match number changes when one AMR 1 node is
    remapped, and cache the match number of the resulting mapping.
    Arguments:
        mapping: current node mapping
        node_id: remapped node in AMR 1
        old_id: original node id in AMR 2 to which node_id is mapped
        new_id: new node in AMR 2 to which node_id is mapped
        weight_dict: weight dictionary
        match_num: the original triple matching number
    Returns:
        the triple match gain number (might be negative)

    """
    moved = mapping[:]
    moved[node_id] = new_id
    cache_key = tuple(moved)
    if cache_key in match_triple_dict:
        # this mapping was scored before
        return match_triple_dict[cache_key] - match_num

    gain = 0
    # matches gained through the new pair, judged against the new mapping
    for partner, weight in weight_dict.get((node_id, new_id), {}).items():
        if partner == -1:
            gain += weight
        elif moved[partner[0]] == partner[1]:
            gain += weight
    # matches lost with the old pair, judged against the current mapping
    for partner, weight in weight_dict.get((node_id, old_id), {}).items():
        if partner == -1:
            gain -= weight
        elif mapping[partner[0]] == partner[1]:
            gain -= weight

    match_triple_dict[cache_key] = match_num + gain
    return gain
|
|
|
|
|
def swap_gain(mapping, node_id1, mapping_id1, node_id2, mapping_id2, weight_dict, match_num):
    """
    Compute how much the triple match number changes when two AMR 1 nodes
    exchange their AMR 2 targets, and cache the match number of the result.
    Arguments:
        mapping: current node mapping list
        node_id1: node 1 index in AMR 1
        mapping_id1: the node index in AMR 2 node 1 maps to (in the current mapping)
        node_id2: node 2 index in AMR 1
        mapping_id2: the node index in AMR 2 node 2 maps to (in the current mapping)
        weight_dict: weight dictionary
        match_num: the original matching triple number
    Returns:
        the gain number (might be negative)

    """
    swapped = mapping[:]
    swapped[node_id1] = mapping_id2
    swapped[node_id2] = mapping_id1
    cache_key = tuple(swapped)
    if cache_key in match_triple_dict:
        # this mapping was scored before
        return match_triple_dict[cache_key] - match_num

    # A relation between the two swapped nodes is stored under both of their
    # node pairs.  It is counted from node 1's pair and skipped from node 2's
    # pair, whichever of the two indices is larger, so it is never counted
    # twice.
    gain = 0
    gain += _swap_pair_weight(weight_dict, (node_id1, mapping_id2), swapped, None)
    gain += _swap_pair_weight(weight_dict, (node_id2, mapping_id1), swapped, node_id1)
    gain -= _swap_pair_weight(weight_dict, (node_id1, mapping_id1), mapping, None)
    gain -= _swap_pair_weight(weight_dict, (node_id2, mapping_id2), mapping, node_id1)

    match_triple_dict[cache_key] = match_num + gain
    return gain


def _swap_pair_weight(weight_dict, node_pair, mapping, skip_node):
    """Sum the weights node_pair contributes under mapping, ignoring partner
    pairs whose AMR 1 node is skip_node (None skips nothing)."""
    total = 0
    for partner, weight in weight_dict.get(node_pair, {}).items():
        if partner == -1:
            total += weight
        elif skip_node is not None and partner[0] == skip_node:
            continue
        elif mapping[partner[0]] == partner[1]:
            total += weight
    return total
|
|
|
|
|
def get_best_gain(mapping, candidate_mappings, weight_dict, instance_len, cur_match_num):
    """
    One hill-climbing step: try every single-node move and every pairwise
    swap and apply the one with the largest positive gain.
    Arguments:
        mapping: current node mapping
        candidate_mappings: the candidates mapping list
        weight_dict: the weight dictionary
        instance_len: the number of the nodes in AMR 2
        cur_match_num: current triple match number
    Returns:
        largest_gain: the best gain found (0 when nothing improves)
        cur_mapping: a copy of mapping with the best move/swap applied
                     (an unchanged copy when nothing improves)

    """
    largest_gain = 0
    # True when the best operation so far is a swap, False for a move
    best_is_swap = True
    # For a move: AMR 1 node and its new AMR 2 node.
    # For a swap: the two AMR 1 nodes exchanging targets.
    pick_a = None
    pick_b = None

    # AMR 2 nodes that no AMR 1 node is currently mapped to
    free_nodes = set(range(instance_len))
    for taken in mapping:
        free_nodes.discard(taken)

    # Moves: remap one AMR 1 node to a free candidate node of AMR 2.
    for src, cur_dst in enumerate(mapping):
        for free_dst in free_nodes:
            if free_dst not in candidate_mappings[src]:
                continue
            if veryVerbose:
                print("Remap node", src, "from ", cur_dst, "to", free_dst, file=DEBUG_LOG)
            mv_gain = move_gain(mapping, src, cur_dst, free_dst, weight_dict, cur_match_num)
            if veryVerbose:
                print("Move gain:", mv_gain, file=DEBUG_LOG)
                # cross-check the incremental gain against a full recount
                new_mapping = mapping[:]
                new_mapping[src] = free_dst
                new_match_num = compute_match(new_mapping, weight_dict)
                if new_match_num != cur_match_num + mv_gain:
                    print(mapping, new_mapping, file=ERROR_LOG)
                    print("Inconsistency in computing: move gain", cur_match_num, mv_gain, new_match_num,
                          file=ERROR_LOG)
            if mv_gain > largest_gain:
                largest_gain = mv_gain
                pick_a = src
                pick_b = free_dst
                best_is_swap = False

    # Swaps: exchange the targets of every pair of AMR 1 nodes.
    node_count = len(mapping)
    for left, left_dst in enumerate(mapping):
        for right in range(left + 1, node_count):
            right_dst = mapping[right]
            if veryVerbose:
                print("Swap node", left, "and", right, file=DEBUG_LOG)
                print("Before swapping:", left, "-", left_dst, ",", right, "-", right_dst, file=DEBUG_LOG)
                print(mapping, file=DEBUG_LOG)
                print("After swapping:", left, "-", right_dst, ",", right, "-", left_dst, file=DEBUG_LOG)
            sw_gain = swap_gain(mapping, left, left_dst, right, right_dst, weight_dict, cur_match_num)
            if veryVerbose:
                print("Swap gain:", sw_gain, file=DEBUG_LOG)
                # cross-check the incremental gain against a full recount
                new_mapping = mapping[:]
                new_mapping[left] = right_dst
                new_mapping[right] = left_dst
                print(new_mapping, file=DEBUG_LOG)
                new_match_num = compute_match(new_mapping, weight_dict)
                if new_match_num != cur_match_num + sw_gain:
                    print(mapping, new_mapping, file=ERROR_LOG)
                    print("Inconsistency in computing: swap gain", cur_match_num, sw_gain, new_match_num,
                          file=ERROR_LOG)
            if sw_gain > largest_gain:
                largest_gain = sw_gain
                pick_a = left
                pick_b = right
                best_is_swap = True

    # Apply the winning operation to a copy of the mapping.
    cur_mapping = mapping[:]
    if pick_a is None:
        if veryVerbose:
            print("no move/swap gain found", file=DEBUG_LOG)
    elif best_is_swap:
        if veryVerbose:
            print("Use swap gain", file=DEBUG_LOG)
        cur_mapping[pick_a], cur_mapping[pick_b] = cur_mapping[pick_b], cur_mapping[pick_a]
    else:
        if veryVerbose:
            print("Use move gain", file=DEBUG_LOG)
        cur_mapping[pick_a] = pick_b
    if veryVerbose:
        print("Original mapping", mapping, file=DEBUG_LOG)
        print("Current mapping", cur_mapping, file=DEBUG_LOG)
    return largest_gain, cur_mapping
|
|
|
|
|
def print_alignment(mapping, instance1, instance2):
    """
    Render a node mapping as a readable alignment string.
    Args:
        mapping: current node mapping list
        instance1: nodes of AMR 1
        instance2: nodes of AMR 2
    Returns:
        space-separated "name(value)-name(value)" items, one per AMR 1 node;
        an unmapped node (-1) is rendered as "name(value)-Null"

    """
    def describe(triple):
        # node name followed by its concept value in parentheses
        return triple[1] + "(" + triple[2] + ")"

    pieces = []
    for node1, target in zip(instance1, mapping):
        left = describe(node1)
        right = "Null" if target == -1 else describe(instance2[target])
        pieces.append(left + "-" + right)
    return " ".join(pieces)
|
|
|
|
|
def compute_f(match_num, test_num, gold_num):
    """
    Turn a matching triple count and the two triple totals into precision,
    recall and f-score.
    Args:
        match_num: matching triple number
        test_num: triple number of AMR 1 (test file)
        gold_num: triple number of AMR 2 (gold file)
    Returns:
        precision: match_num/test_num
        recall: match_num/gold_num
        f_score: 2*precision*recall/(precision+recall)
        All three are 0.0 when test_num or gold_num is 0; f_score is 0.0
        when precision + recall is 0.
    """
    if test_num == 0 or gold_num == 0:
        return 0.00, 0.00, 0.00
    precision = float(match_num) / float(test_num)
    recall = float(match_num) / float(gold_num)
    total = precision + recall
    if total == 0:
        # nothing matched: report zero instead of dividing by zero
        if veryVerbose:
            print("F-score:", "0.0", file=DEBUG_LOG)
        return precision, recall, 0.00
    f_score = 2 * precision * recall / total
    if veryVerbose:
        print("F-score:", f_score, file=DEBUG_LOG)
    return precision, recall, f_score
|
|
|
|
|
def generate_amr_lines(f1, f2):
    """
    Read AMRs from both inputs in lockstep and yield them pairwise
    :param f1: file handle (or any iterable of strings) to read AMR 1 lines from
    :param f2: file handle (or any iterable of strings) to read AMR 2 lines from
    :return: generator of cur_amr1, cur_amr2 pairs: one-line AMR strings.
             It stops as soon as either input has no AMR left; when only one
             of them is exhausted an error is printed to ERROR_LOG first.
    """
    while True:
        cur_amr1 = amr.AMR.get_amr_line(f1)
        cur_amr2 = amr.AMR.get_amr_line(f2)
        if cur_amr1 and cur_amr2:
            yield cur_amr1, cur_amr2
            continue
        # At least one input is exhausted; report a length mismatch, if any.
        if cur_amr2:
            print("Error: File 1 has less AMRs than file 2", file=ERROR_LOG)
            print("Ignoring remaining AMRs", file=ERROR_LOG)
        elif cur_amr1:
            print("Error: File 2 has less AMRs than file 1", file=ERROR_LOG)
            print("Ignoring remaining AMRs", file=ERROR_LOG)
        return
|
|
|
|
|
def get_amr_match(cur_amr1, cur_amr2, sent_num=1, justinstance=False, justattribute=False, justrelation=False,
                  limit=None,
                  instance1=None, attributes1=None, relation1=None, prefix1=None,
                  instance2=None, attributes2=None, relation2=None, prefix2=None):
    """
    Compute the best triple match between one pair of AMRs.
    When both cur_amr1 and cur_amr2 are non-empty they are parsed, their
    nodes are renamed with the prefixes "a" and "b", and the triples are taken
    from the parsed AMRs. Otherwise the triples and prefixes passed in as
    arguments are used as they are.
    Arguments:
        cur_amr1: one-line AMR string of AMR 1 (or a false value)
        cur_amr2: one-line AMR string of AMR 2 (or a false value)
        sent_num: pair number, only used in verbose output
        justinstance: count instance triples only
        justattribute: count attribute triples only
        justrelation: count relation triples only
        limit: when not None, overwrites the module-level iteration_num
               (the new value stays in effect after this call)
        instance1, attributes1, relation1, prefix1: triples and node prefix
               of AMR 1, used when the AMR strings are not both given
        instance2, attributes2, relation2, prefix2: same for AMR 2
    Returns:
        (best_match_num, test_triple_num, gold_triple_num) when both AMR
        strings were given, otherwise
        (best_match_num, test_triple_num, gold_triple_num, best_mapping)
    Note:
        A parse exception is only reported to ERROR_LOG; the unpacking of the
        parsed pair that follows then fails with a ValueError.
    """
    global iteration_num
    if limit is not None:
        iteration_num = limit

    from_text = bool(cur_amr1 and cur_amr2)
    if from_text:
        parsed = []
        for which, line in ((1, cur_amr1), (2, cur_amr2)):
            try:
                parsed.append(amr.AMR.parse_AMR_line(line))
            except Exception as e:
                print("Error in parsing amr %d: %s" % (which, line), file=ERROR_LOG)
                print("Please check if the AMR is ill-formatted. Ignoring remaining AMRs", file=ERROR_LOG)
                print("Error message: %s" % e, file=ERROR_LOG)
        amr1, amr2 = parsed
        prefix1, prefix2 = "a", "b"
        # Rename the nodes so that node names are prefix + index.
        amr1.rename_node(prefix1)
        amr2.rename_node(prefix2)
        instance1, attributes1, relation1 = amr1.get_triples()
        instance2, attributes2, relation2 = amr2.get_triples()

    if verbose:
        print("AMR pair", sent_num, file=DEBUG_LOG)
        print("============================================", file=DEBUG_LOG)
        print("AMR 1 (one-line):", cur_amr1, file=DEBUG_LOG)
        print("AMR 2 (one-line):", cur_amr2, file=DEBUG_LOG)
        print("Instance triples of AMR 1:", len(instance1), file=DEBUG_LOG)
        print(instance1, file=DEBUG_LOG)
        print("Attribute triples of AMR 1:", len(attributes1), file=DEBUG_LOG)
        print(attributes1, file=DEBUG_LOG)
        print("Relation triples of AMR 1:", len(relation1), file=DEBUG_LOG)
        print(relation1, file=DEBUG_LOG)
        print("Instance triples of AMR 2:", len(instance2), file=DEBUG_LOG)
        print(instance2, file=DEBUG_LOG)
        print("Attribute triples of AMR 2:", len(attributes2), file=DEBUG_LOG)
        print(attributes2, file=DEBUG_LOG)
        print("Relation triples of AMR 2:", len(relation2), file=DEBUG_LOG)
        print(relation2, file=DEBUG_LOG)

    # Each "just..." switch disables the two other triple types.
    doinstance = not justattribute and not justrelation
    doattribute = not justinstance and not justrelation
    dorelation = not justinstance and not justattribute
    best_mapping, best_match_num = get_best_match(instance1, attributes1, relation1,
                                                  instance2, attributes2, relation2,
                                                  prefix1, prefix2, doinstance=doinstance,
                                                  doattribute=doattribute, dorelation=dorelation)
    if verbose:
        print("best match number", best_match_num, file=DEBUG_LOG)
        print("best node mapping", best_mapping, file=DEBUG_LOG)
        print("Best node mapping alignment:", print_alignment(best_mapping, instance1, instance2), file=DEBUG_LOG)

    # The triple totals cover only the triple types that were matched.
    if justinstance:
        test_triple_num, gold_triple_num = len(instance1), len(instance2)
    elif justattribute:
        test_triple_num, gold_triple_num = len(attributes1), len(attributes2)
    elif justrelation:
        test_triple_num, gold_triple_num = len(relation1), len(relation2)
    else:
        test_triple_num = len(instance1) + len(attributes1) + len(relation1)
        gold_triple_num = len(instance2) + len(attributes2) + len(relation2)

    # Cached match numbers are only valid for this AMR pair.
    match_triple_dict.clear()
    if from_text:
        return best_match_num, test_triple_num, gold_triple_num
    return best_match_num, test_triple_num, gold_triple_num, best_mapping
|
|
|
|
|
def score_amr_pairs(f1, f2, justinstance=False, justattribute=False, justrelation=False):
    """
    Score one pair of AMR lines at a time from each file handle
    :param f1: file handle (or any iterable of strings) to read AMR 1 lines from
    :param f2: file handle (or any iterable of strings) to read AMR 2 lines from
    :param justinstance: just pay attention to matching instances
    :param justattribute: just pay attention to matching attributes
    :param justrelation: just pay attention to matching relations
    :return: generator of (precision, recall, f_score) tuples: one per AMR
             pair when single_score is False, otherwise a single tuple
             computed from the totals over all pairs
    """
    # running totals over all AMR pairs
    totals = [0, 0, 0]
    pairs = generate_amr_lines(f1, f2)
    for sent_num, (cur_amr1, cur_amr2) in enumerate(pairs, start=1):
        counts = get_amr_match(cur_amr1, cur_amr2,
                               sent_num=sent_num,
                               justinstance=justinstance,
                               justattribute=justattribute,
                               justrelation=justrelation)
        best_match_num, test_triple_num, gold_triple_num = counts
        totals[0] += best_match_num
        totals[1] += test_triple_num
        totals[2] += gold_triple_num
        # drop the cached match numbers before the next AMR pair
        match_triple_dict.clear()
        if not single_score:
            # per-pair scoring: emit this pair's score right away
            yield compute_f(best_match_num, test_triple_num, gold_triple_num)
    total_match_num, total_test_num, total_gold_num = totals
    if verbose:
        print("Total match number, total triple number in AMR 1, and total triple number in AMR 2:", file=DEBUG_LOG)
        print(total_match_num, total_test_num, total_gold_num, file=DEBUG_LOG)
        print("---------------------------------------------------------------------------------", file=DEBUG_LOG)
    if single_score:
        # document-level scoring: one score from the accumulated totals
        yield compute_f(total_match_num, total_test_num, total_gold_num)
|
|
|
|
|
def main(arguments):
    """
    Main function of smatch score calculation
    Arguments:
        arguments: parsed command-line options. The attributes read are
                   f (two open AMR files), r, significant, ms, v, vv, pr,
                   justinstance, justattribute and justrelation.
    The module-level switches are set from the options, the scores are
    printed to standard output, and both input files are closed.
    """
    global verbose
    global veryVerbose
    global iteration_num
    global single_score
    global pr_flag

    # one smart-initialised round plus ``r`` random restarts
    iteration_num = arguments.r + 1
    if arguments.ms:
        single_score = False
    if arguments.v:
        verbose = True
    if arguments.vv:
        veryVerbose = True
    if arguments.pr:
        pr_flag = True

    # Format string such as "%.2f", built from the requested digit count.
    floatdisplay = "%%.%df" % arguments.significant
    # Read the files from the ``arguments`` parameter rather than from a
    # global ``args``, so main() also works when called from other modules.
    amr_files = arguments.f
    try:
        for (precision, recall, best_f_score) in score_amr_pairs(amr_files[0], amr_files[1],
                                                                 justinstance=arguments.justinstance,
                                                                 justattribute=arguments.justattribute,
                                                                 justrelation=arguments.justrelation):
            if pr_flag:
                print("Precision: " + floatdisplay % precision)
                print("Recall: " + floatdisplay % recall)
            print("F-score: " + floatdisplay % best_f_score)
    finally:
        # close the inputs even when scoring fails
        amr_files[0].close()
        amr_files[1].close()
|
|
|
|
|
if __name__ == "__main__":
    parser = None
    args = None
    # Python 2.5/2.6 has no argparse, so fall back to optparse there.
    if sys.version_info[0] == 2 and sys.version_info[1] < 7:
        import optparse

        if len(sys.argv) == 1:
            print("No argument given. Please run smatch.py -h to see the argument description.", file=ERROR_LOG)
            sys.exit(1)
        parser = build_arg_parser2()
        (args, opts) = parser.parse_args()
        file_handle = []
        # optparse cannot mark an option as required, so check "-f" by hand.
        if args.f is None:
            print("smatch.py requires -f option to indicate two files "
                  "containing AMR as input. Please run smatch.py -h to "
                  "see the argument description.", file=ERROR_LOG)
            sys.exit(1)
        # "-f" is declared with nargs=2, so two paths are expected here.
        assert (len(args.f) == 2)
        for file_path in args.f:
            if not os.path.exists(file_path):
                # report the path that is actually missing
                print("Given file", file_path, "does not exist", file=ERROR_LOG)
                sys.exit(1)
            file_handle.append(open(file_path))
        # hand main() open files, as the argparse branch does
        args.f = tuple(file_handle)
    else:
        import argparse

        parser = build_arg_parser()
        args = parser.parse_args()
    main(args)
|
|