|
import multiprocessing as mp |
|
import sys |
|
from operator import itemgetter |
|
|
|
import numpy as np |
|
|
|
import score.core |
|
from score.smatch import smatch |
|
from score.ucca import identify |
|
|
|
counter = 0 |
|
|
|
def reindex(i): |
|
return -2 - i |
|
|
|
def get_or_update(index, key): |
|
return index.setdefault(key, len(index)) |
|
|
|
class InternalGraph(): |
|
|
|
def __init__(self, graph, index): |
|
self.node2id = dict() |
|
self.id2node = dict() |
|
self.nodes = [] |
|
self.edges = [] |
|
for i, node in enumerate(graph.nodes): |
|
self.node2id[node] = i |
|
self.id2node[i] = node |
|
self.nodes.append(i) |
|
for edge in graph.edges: |
|
src = graph.find_node(edge.src) |
|
src = self.node2id[src] |
|
tgt = graph.find_node(edge.tgt) |
|
tgt = self.node2id[tgt] |
|
self.edges.append((src, tgt, edge.lab)) |
|
if edge.attributes: |
|
for prop, val in zip(edge.attributes, edge.values): |
|
self.edges.append((src, tgt, ("E", prop, val))) |
|
|
|
|
|
|
|
|
|
if index is None: |
|
index = dict() |
|
for i, node in enumerate(graph.nodes): |
|
|
|
j = get_or_update(index, ("L", node.label)) |
|
self.edges.append((i, reindex(j), None)) |
|
|
|
if node.is_top: |
|
j = get_or_update(index, ("T")) |
|
self.edges.append((i, reindex(j), None)) |
|
|
|
if node.anchors is not None: |
|
anchor = score.core.anchor(node); |
|
if graph.input: |
|
anchor = score.core.explode(graph.input, anchor); |
|
else: |
|
anchor = tuple(anchor); |
|
j = get_or_update(index, ("A", anchor)) |
|
self.edges.append((i, reindex(j), None)) |
|
|
|
if node.properties: |
|
for prop, val in zip(node.properties, node.values): |
|
j = get_or_update(index, ("P", prop, val)) |
|
self.edges.append((i, reindex(j), None)) |
|
|
|
def initial_node_correspondences(graph1, graph2, |
|
identities1, identities2, |
|
bilexical): |
|
|
|
|
|
|
|
|
|
shape = (len(graph1.nodes), len(graph2.nodes) + 1) |
|
rewards = np.zeros(shape, dtype=np.int); |
|
edges = np.zeros(shape, dtype=np.int); |
|
anchors = np.zeros(shape, dtype=np.int); |
|
|
|
|
|
|
|
|
|
|
|
if bilexical: |
|
queue = None; |
|
else: |
|
queue = []; |
|
|
|
for i, node1 in enumerate(graph1.nodes): |
|
for j, node2 in enumerate(graph2.nodes + [None]): |
|
rewards[i, j], _, _, _ = node1.compare(node2); |
|
if node2 is not None: |
|
|
|
|
|
|
|
|
|
src_edges_x = [ len([ 1 for e1 in graph1.edges if e1.src == node1.id and e1.lab == e2.lab ]) |
|
for e2 in graph2.edges if e2.src == node2.id ] |
|
tgt_edges_x = [ len([ 1 for e1 in graph1.edges if e1.tgt == node1.id and e1.lab == e2.lab ]) |
|
for e2 in graph2.edges if e2.tgt == node2.id ] |
|
edges[i, j] += sum(src_edges_x) + sum(tgt_edges_x) |
|
|
|
|
|
|
|
|
|
if identities1 and identities2: |
|
anchors[i, j] += len(identities1[node1.id] & |
|
identities2[node2.id]) |
|
if queue is not None: |
|
queue.append((rewards[i, j], edges[i, j], anchors[i, j], |
|
i, j if node2 is not None else None)); |
|
|
|
|
|
|
|
|
|
|
|
|
|
rewards *= 1000; |
|
anchors *= 10; |
|
rewards += edges + anchors; |
|
|
|
if queue is None: |
|
pairs = levenshtein(graph1, graph2); |
|
else: |
|
pairs = []; |
|
sources = set(); |
|
targets = set(); |
|
for _, _, _, i, j in sorted(queue, key = itemgetter(0, 2, 1), |
|
reverse = True): |
|
if i not in sources and j not in targets: |
|
pairs.append((i, j)); |
|
sources.add(i); |
|
if j is not None: targets.add(j); |
|
|
|
return pairs, rewards; |
|
|
|
def levenshtein(graph1, graph2): |
|
m = len(graph1.nodes) |
|
n = len(graph2.nodes) |
|
d = {(i,j): float('-inf') for i in range(m+1) for j in range(n+1)} |
|
p = {(i,j): None for i in range(m+1) for j in range(n+1)} |
|
d[(0,0)] = 0 |
|
for i in range(1, m+1): |
|
d[(i,0)] = 0 |
|
p[(i,0)] = ((i-1,0), None) |
|
for j in range(1, n+1): |
|
d[(0,j)] = 0 |
|
p[(0,j)] = ((0,j-1), None) |
|
for j, node2 in enumerate(graph2.nodes, 1): |
|
for i, node1 in enumerate(graph1.nodes, 1): |
|
best_d = float('-inf') |
|
|
|
cand_d = d[(i-1,j-0)] |
|
if cand_d > best_d: |
|
best_d = cand_d |
|
best_p = ((i-1,j-0), None) |
|
|
|
cand_d = d[(i-0,j-1)] |
|
if cand_d > best_d: |
|
best_d = cand_d |
|
best_p = ((i-0,j-1), None) |
|
|
|
cand_d = d[(i-1,j-1)] + node1.compare(node2)[2] |
|
if cand_d > best_d: |
|
best_d = cand_d |
|
best_p = ((i-1,j-1), (i-1, j-1)) |
|
d[(i,j)] = best_d |
|
p[(i,j)] = best_p |
|
|
|
pairs = {i: None for i in range(len(graph1.nodes))} |
|
def backtrace(idx): |
|
ptr = p[idx] |
|
if ptr is None: |
|
pass |
|
else: |
|
next_idx, pair = ptr |
|
if pair is not None: |
|
i, j = pair |
|
pairs[i] = j |
|
backtrace(next_idx) |
|
backtrace((m, n)) |
|
return sorted(pairs.items()) |
|
|
|
|
|
|
|
|
|
def make_edge_candidates(graph1, graph2): |
|
candidates = dict() |
|
for raw_edge1 in graph1.edges: |
|
src1, tgt1, lab1 = raw_edge1 |
|
if raw_edge1 not in candidates: |
|
edge1_candidates = set() |
|
else: |
|
edge1_candidates = candidates[raw_edge1] |
|
for raw_edge2 in graph2.edges: |
|
src2, tgt2, lab2 = raw_edge2 |
|
edge2 = (src2, tgt2) |
|
if tgt1 < 0: |
|
|
|
|
|
if tgt2 == tgt1 and lab1 == lab2: |
|
edge1_candidates.add(edge2) |
|
elif tgt2 >= 0 and lab1 == lab2: |
|
|
|
|
|
edge1_candidates.add(edge2) |
|
if edge1_candidates: |
|
candidates[raw_edge1] = edge1_candidates |
|
return candidates |
|
|
|
|
|
|
|
|
|
def update_edge_candidates(edge_candidates, i, j): |
|
new_candidates = edge_candidates.copy() |
|
for edge1, edge1_candidates in edge_candidates.items(): |
|
if i == edge1[0] or i == edge1[1]: |
|
|
|
|
|
|
|
|
|
|
|
src1, tgt1, _ = edge1 |
|
edge1_candidates = {(src2, tgt2) for src2, tgt2 in edge1_candidates |
|
if src1 == i and src2 == j or tgt1 == i and tgt2 == j} |
|
if edge1_candidates: |
|
new_candidates[edge1] = edge1_candidates |
|
else: |
|
new_candidates.pop(edge1) |
|
return new_candidates, len(new_candidates) |
|
|
|
def splits(xs): |
|
|
|
for i, x in enumerate(xs): |
|
yield x, xs[:i] + xs[i+1:] |
|
|
|
yield -1, xs |
|
|
|
def sorted_splits(i, xs, rewards, pairs, bilexical): |
|
for _i, _j in pairs: |
|
if i == _i: j = _j if _j is not None else -1 |
|
if bilexical: |
|
sorted_xs = sorted(xs, key=lambda x: (-abs(x-i), rewards.item((i, x)), -x), reverse=True) |
|
else: |
|
sorted_xs = sorted(xs, key=lambda x: (rewards.item((i, x)), -x), reverse=True) |
|
if j in sorted_xs or j < 0: |
|
if j >= 0: sorted_xs.remove(j) |
|
sorted_xs = [j] + sorted_xs |
|
yield from splits(sorted_xs) |
|
|
|
|
|
|
|
|
|
def identities(g, s): |
|
|
|
|
|
|
|
if g.framework == "ucca" and g.input \ |
|
and s.framework == "ucca" and s.input: |
|
g_identities = dict() |
|
s_identities = dict() |
|
g_dominated = dict() |
|
s_dominated = dict() |
|
for node in g.nodes: |
|
g_identities, g_dominated = \ |
|
identify(g, node.id, g_identities, g_dominated) |
|
g_identities = {key: score.core.explode(g.input, value) |
|
for key, value in g_identities.items()} |
|
for node in s.nodes: |
|
s_identities, s_dominated = \ |
|
identify(s, node.id, s_identities, s_dominated) |
|
s_identities = {key: score.core.explode(s.input, value) |
|
for key, value in s_identities.items()} |
|
else: |
|
g_identities = s_identities = g_dominated = s_dominated = None |
|
return g_identities, s_identities, g_dominated, s_dominated |
|
|
|
def domination_conflict(graph1, graph2, cv, i, j, dominated1, dominated2): |
|
if not dominated1 or not dominated2 or i < 0 or j < 0: |
|
return False |
|
dominated_i = dominated1[graph1.id2node[i].id] |
|
dominated_j = dominated2[graph2.id2node[j].id] |
|
|
|
if bool(dominated_i) != bool(dominated_j): |
|
return True |
|
for _i, _j in cv.items(): |
|
if _i >= 0 and _j >= 0 and \ |
|
graph1.id2node[_i].id in dominated_i and \ |
|
graph2.id2node[_j].id not in dominated_j: |
|
return True |
|
return False |
|
|
|
|
|
|
|
|
|
def correspondences(graph1, graph2, pairs, rewards, limit=None, trace=0, |
|
dominated1=None, dominated2=None, bilexical = False): |
|
global counter |
|
index = dict() |
|
graph1 = InternalGraph(graph1, index) |
|
graph2 = InternalGraph(graph2, index) |
|
cv = dict() |
|
ce = make_edge_candidates(graph1, graph2) |
|
|
|
source_todo = [pair[0] for pair in pairs] |
|
todo = [(cv, ce, source_todo, sorted_splits( |
|
source_todo[0], graph2.nodes, rewards, pairs, bilexical))] |
|
n_matched = 0 |
|
while todo and (limit is None or counter <= limit): |
|
cv, ce, source_todo, untried = todo[-1] |
|
i = source_todo[0] |
|
try: |
|
j, new_untried = next(untried) |
|
if cv: |
|
if bilexical: |
|
max_j = max((_j for _i, _j in cv.items() if _i < i), default=-1) |
|
if 0 <= j < max_j + 1: |
|
continue |
|
elif domination_conflict(graph1, graph2, cv, i, j, dominated1, dominated2): |
|
continue |
|
counter += 1 |
|
if trace > 2: print("({}:{}) ".format(i, j), end="", file = sys.stderr) |
|
new_cv = dict(cv) |
|
new_cv[i] = j |
|
new_ce, new_potential = update_edge_candidates(ce, i, j) |
|
if new_potential > n_matched: |
|
new_source_todo = source_todo[1:] |
|
if new_source_todo: |
|
if trace > 2: print("> ", end="", file = sys.stderr) |
|
todo.append((new_cv, new_ce, new_source_todo, |
|
sorted_splits(new_source_todo[0], |
|
new_untried, rewards, |
|
pairs, bilexical))) |
|
else: |
|
if trace > 2: print(file = sys.stderr) |
|
yield new_cv, new_ce |
|
n_matched = new_potential |
|
except StopIteration: |
|
if trace > 2: print("< ", file = sys.stderr) |
|
todo.pop() |
|
|
|
def is_valid(correspondence): |
|
return all(len(x) <= 1 for x in correspondence.values()) |
|
|
|
def is_injective(correspondence): |
|
seen = set() |
|
for xs in correspondence.values(): |
|
for x in xs: |
|
if x in seen: |
|
return False |
|
else: |
|
seen.add(x) |
|
return True |
|
|
|
def schedule(g, s, rrhc_limit, mces_limit, trace, errors): |
|
global counter; |
|
try: |
|
counter = 0; |
|
g_identities, s_identities, g_dominated, s_dominated \ |
|
= identities(g, s); |
|
bilexical = g.flavor == 0 or g.framework in {"dm", "psd", "pas", "ccd"}; |
|
pairs, rewards \ |
|
= initial_node_correspondences(g, s, |
|
g_identities, s_identities, |
|
bilexical); |
|
if errors is not None and g.framework not in errors: errors[g.framework] = dict(); |
|
if trace > 1: |
|
print("\n\ngraph #{} ({}; {}; {})" |
|
"".format(g.id, g.language(), g.flavor, g.framework), |
|
file = sys.stderr); |
|
print("number of gold nodes: {}".format(len(g.nodes)), |
|
file = sys.stderr); |
|
print("number of system nodes: {}".format(len(s.nodes)), |
|
file = sys.stderr); |
|
print("number of edges: {}".format(len(g.edges)), |
|
file = sys.stderr); |
|
if trace > 2: |
|
print("rewards and pairs:\n{}\n{}\n" |
|
"".format(rewards, sorted(pairs)), |
|
file = sys.stderr); |
|
smatches = 0; |
|
if g.framework in {"eds", "amr"} and rrhc_limit > 0: |
|
smatches, _, _, mapping \ |
|
= smatch(g, s, rrhc_limit, |
|
{"tops", "labels", "properties", "anchors", |
|
"edges", "attributes"}, |
|
0, False); |
|
mapping = [(i, j if j >= 0 else None) |
|
for i, j in enumerate(mapping)]; |
|
tops, labels, properties, anchors, edges, attributes \ |
|
= g.score(s, mapping); |
|
all = tops["c"] + labels["c"] + properties["c"] \ |
|
+ anchors["c"] + edges["c"] + attributes["c"]; |
|
status = "{}".format(smatches); |
|
if smatches > all: |
|
status = "{} vs. {}".format(smatches, all); |
|
smatches = all; |
|
if trace > 1: |
|
print("pairs {} smatch [{}]: {}" |
|
"".format("from" if set(pairs) != set(mapping) else "by", |
|
status, sorted(mapping)), |
|
file = sys.stderr); |
|
if set(pairs) != set(mapping): pairs = mapping; |
|
matches, best_cv, best_ce = 0, {}, {}; |
|
if g.nodes and mces_limit > 0: |
|
for i, (cv, ce) in \ |
|
enumerate(correspondences(g, s, pairs, rewards, |
|
mces_limit, trace, |
|
dominated1 = g_dominated, |
|
dominated2 = s_dominated, |
|
bilexical = bilexical)): |
|
|
|
|
|
n = sum(map(len, ce.values())); |
|
if n > matches: |
|
if trace > 1: |
|
print("\n[{}] solution #{}; matches: {}" |
|
"".format(counter, i, n), file = sys.stderr); |
|
matches, best_cv, best_ce = n, cv, ce; |
|
tops, labels, properties, anchors, edges, attributes \ |
|
= g.score(s, best_cv or pairs, errors); |
|
|
|
if trace > 1: |
|
if smatches and matches != smatches: |
|
print("delta to smatch: {}" |
|
"".format(matches - smatches), file = sys.stderr); |
|
print("[{}] edges in correspondence: {}" |
|
"".format(counter, matches), file = sys.stderr) |
|
print("tops: {}\nlabels: {}\nproperties: {}\nanchors: {}" |
|
"\nedges: {}\nattributes: {}" |
|
"".format(tops, labels, properties, anchors, |
|
edges, attributes), file = sys.stderr); |
|
if trace > 2: |
|
print(best_cv, file = sys.stderr) |
|
print(best_ce, file = sys.stderr) |
|
return g.id, g, s, tops, labels, properties, anchors, \ |
|
edges, attributes, matches, counter, None; |
|
|
|
except Exception as e: |
|
|
|
|
|
|
|
raise e; |
|
return g.id, g, s, None, None, None, None, None, None, None, None, e; |
|
|
|
def evaluate(gold, system, format = "json", |
|
limits = None, |
|
cores = 0, trace = 0, errors = None, quiet = False): |
|
def update(total, counts): |
|
for key in ("g", "s", "c"): |
|
total[key] += counts[key]; |
|
|
|
def finalize(counts): |
|
p, r, f = score.core.fscore(counts["g"], counts["s"], counts["c"]); |
|
counts.update({"p": p, "r": r, "f": f}); |
|
|
|
if limits is None: |
|
limits = {"rrhc": 20, "mces": 500000} |
|
rrhc_limit = mces_limit = None; |
|
if isinstance(limits, dict): |
|
if "rrhc" in limits: rrhc_limit = limits["rrhc"]; |
|
if "mces" in limits: mces_limit = limits["mces"]; |
|
if rrhc_limit is None or rrhc_limit < 0: rrhc_limit = 20; |
|
if mces_limit is None or mces_limit < 0: mces_limit = 500000; |
|
if trace > 1: |
|
print("RRHC limit: {}; MCES limit: {}".format(rrhc_limit, mces_limit), |
|
file = sys.stderr); |
|
total_matches = total_steps = 0; |
|
total_pairs = 0; |
|
total_empty = 0; |
|
total_inexact = 0; |
|
total_tops = {"g": 0, "s": 0, "c": 0} |
|
total_labels = {"g": 0, "s": 0, "c": 0} |
|
total_properties = {"g": 0, "s": 0, "c": 0} |
|
total_anchors = {"g": 0, "s": 0, "c": 0} |
|
total_edges = {"g": 0, "s": 0, "c": 0} |
|
total_attributes = {"g": 0, "s": 0, "c": 0} |
|
scores = dict() if trace else None; |
|
if cores > 1: |
|
if trace > 1: |
|
print("mces.evaluate(): using {} cores".format(cores), |
|
file = sys.stderr); |
|
with mp.Pool(cores) as pool: |
|
results = pool.starmap(schedule, |
|
((g, s, rrhc_limit, mces_limit, |
|
trace, errors) |
|
for g, s |
|
in score.core.intersect(gold, |
|
system, |
|
quiet = quiet))); |
|
else: |
|
results = (schedule(g, s, rrhc_limit, mces_limit, trace, errors) |
|
for g, s in score.core.intersect(gold, system)); |
|
|
|
for id, g, s, tops, labels, properties, anchors, \ |
|
edges, attributes, matches, steps, error \ |
|
in results: |
|
framework = g.framework if g.framework else "none"; |
|
if scores is not None and framework not in scores: scores[framework] = dict(); |
|
if s.nodes is None or len(s.nodes) == 0: |
|
total_empty += 1; |
|
if error is None: |
|
total_matches += matches; |
|
total_steps += steps; |
|
update(total_tops, tops); |
|
update(total_labels, labels); |
|
update(total_properties, properties); |
|
update(total_anchors, anchors); |
|
update(total_edges, edges); |
|
update(total_attributes, attributes); |
|
total_pairs += 1; |
|
if mces_limit == 0 or steps > mces_limit: total_inexact += 1; |
|
|
|
if trace and s.nodes is not None and len(s.nodes) != 0: |
|
if id in scores[framework]: |
|
print("mces.evaluate(): duplicate {} graph identifier: {}" |
|
"".format(framework, id), file = sys.stderr); |
|
scores[framework][id] \ |
|
= {"tops": tops, "labels": labels, |
|
"properties": properties, "anchors": anchors, |
|
"edges": edges, "attributes": attributes, |
|
"exact": not (mces_limit == 0 or steps > mces_limit), |
|
"steps": steps}; |
|
else: |
|
print("mces.evaluate(): exception in {} graph #{}:\n{}" |
|
"".format(framework, id, error)); |
|
if trace: |
|
scores[framework][id] = {"error": repr(error)}; |
|
|
|
total_all = {"g": 0, "s": 0, "c": 0}; |
|
for counts in [total_tops, total_labels, total_properties, total_anchors, |
|
total_edges, total_attributes]: |
|
update(total_all, counts); |
|
finalize(counts); |
|
finalize(total_all); |
|
result = {"n": total_pairs, "null": total_empty, |
|
"exact": total_pairs - total_inexact, |
|
"tops": total_tops, "labels": total_labels, |
|
"properties": total_properties, "anchors": total_anchors, |
|
"edges": total_edges, "attributes": total_attributes, |
|
"all": total_all}; |
|
if trace: result["scores"] = scores; |
|
return result; |
|
|