"""Thinc model architectures for a spaCy relation extraction component."""
from typing import List, Tuple, Callable, Optional

import spacy
from spacy.tokens import Doc, Span
from thinc.types import Floats2d, Ints1d, Ragged, cast
from thinc.api import Model, Linear, chain, Logistic


@spacy.registry.architectures("rel_model.v1")
def create_relation_model(
    create_instance_tensor: Model[List[Doc], Floats2d],
    classification_layer: Model[Floats2d, Floats2d],
) -> Model[List[Doc], Floats2d]:
    """Chain the instance tensor layer and the classification layer into the
    full relation extraction model, and expose the instance generator so the
    pipeline component can reuse it."""
    with Model.define_operators({">>": chain}):
        model = create_instance_tensor >> classification_layer
        model.attrs["get_instances"] = create_instance_tensor.attrs["get_instances"]
    return model
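

# Illustrative wiring only (not this project's actual config.cfg): roughly how
# the registered functions in this file are referenced from a spaCy training
# config. Section names and hyperparameter values below are placeholders.
#
# [components.relation_extractor.model]
# @architectures = "rel_model.v1"
#
# [components.relation_extractor.model.create_instance_tensor]
# @architectures = "rel_instance_tensor.v1"
#
# [components.relation_extractor.model.create_instance_tensor.tok2vec]
# @architectures = "spacy.HashEmbedCNN.v2"
# ...
#
# [components.relation_extractor.model.create_instance_tensor.pooling]
# @layers = "reduce_mean.v1"
#
# [components.relation_extractor.model.create_instance_tensor.get_instances]
# @misc = "rel_instance_generator.v1"
# max_length = 100
#
# [components.relation_extractor.model.classification_layer]
# @architectures = "rel_classification_layer.v1"
# nO = null
# nI = null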


@spacy.registry.architectures("rel_classification_layer.v1")
def create_classification_layer(
    nO: Optional[int] = None, nI: Optional[int] = None
) -> Model[Floats2d, Floats2d]:
    """Score each candidate instance: a linear layer followed by a logistic
    activation, giving every relation label an independent probability.
    nO is the number of relation labels; it is typically left None here and
    inferred from the gold labels when the pipeline is initialized."""
    with Model.define_operators({">>": chain}):
        return Linear(nO=nO, nI=nI) >> Logistic()


@spacy.registry.misc("rel_instance_generator.v1")
def create_instances(max_length: int) -> Callable[[Doc], List[Tuple[Span, Span]]]:
    """Return a function that generates candidate relation instances: all
    ordered pairs of distinct entities whose start offsets are at most
    max_length tokens apart."""

    def get_instances(doc: Doc) -> List[Tuple[Span, Span]]:
        instances = []
        for ent1 in doc.ents:
            for ent2 in doc.ents:
                if ent1 != ent2:
                    if max_length and abs(ent2.start - ent1.start) <= max_length:
                        instances.append((ent1, ent2))
        return instances

    return get_instances
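

# For example (illustrative): in a doc whose entities are (PERSON, ORG, GPE),
# get_instances yields the ordered pairs (PERSON, ORG), (ORG, PERSON),
# (PERSON, GPE), and so on, as long as each pair is within range -- both
# directions are kept, since relations are not assumed to be symmetric.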


@spacy.registry.architectures("rel_instance_tensor.v1")
def create_tensors(
    tok2vec: Model[List[Doc], List[Floats2d]],
    pooling: Model[Ragged, Floats2d],
    get_instances: Callable[[Doc], List[Tuple[Span, Span]]],
) -> Model[List[Doc], Floats2d]:
    """Build the layer that turns a batch of Docs into one feature row per
    candidate entity pair, using the given tok2vec and pooling sub-layers."""
    return Model(
        "instance_tensors",
        instance_forward,
        layers=[tok2vec, pooling],
        refs={"tok2vec": tok2vec, "pooling": pooling},
        attrs={"get_instances": get_instances},
        init=instance_init,
    )
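

# Shape walk-through for the forward pass below (widths are illustrative,
# assuming a tok2vec width of 96): for a doc with 2 candidate pairs there are
# 4 entity occurrences, whose token vectors are pooled into one row per
# occurrence -> (4, 96); the two rows of each pair are then concatenated into
# a single instance row -> (2, 192), which the classification layer scores.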


def instance_forward(
    model: Model[List[Doc], Floats2d], docs: List[Doc], is_train: bool
) -> Tuple[Floats2d, Callable]:
    pooling = model.get_ref("pooling")
    tok2vec = model.get_ref("tok2vec")
    get_instances = model.attrs["get_instances"]
    all_instances = [get_instances(doc) for doc in docs]
    tokvecs, bp_tokvecs = tok2vec(docs, is_train)

    # Gather the token vectors of every entity occurrence in every candidate pair.
    ents = []
    lengths = []

    for instances, tokvec in zip(all_instances, tokvecs):
        token_indices = []
        for instance in instances:
            for ent in instance:
                token_indices.extend(range(ent.start, ent.end))
                lengths.append(ent.end - ent.start)
        ents.append(tokvec[token_indices])
    lengths = cast(Ints1d, model.ops.asarray(lengths, dtype="int32"))
    entities = Ragged(model.ops.flatten(ents), lengths)
    # Pool each entity's token vectors into a single row per entity occurrence ...
    pooled, bp_pooled = pooling(entities, is_train)

    # ... then concatenate the two entity rows of each pair into one instance row:
    # (2 * n_instances, width) -> (n_instances, 2 * width).
    relations = model.ops.reshape2f(pooled, -1, pooled.shape[1] * 2)

    def backprop(d_relations: Floats2d) -> List[Doc]:
        # Undo the pair concatenation: (n_instances, 2 * width) -> (2 * n_instances, width).
        d_pooled = model.ops.reshape2f(d_relations, d_relations.shape[0] * 2, -1)
        d_ents = bp_pooled(d_pooled).data
        d_tokvecs = []
        ent_index = 0
        for doc_nr, instances in enumerate(all_instances):
            shape = tokvecs[doc_nr].shape
            d_tokvec = model.ops.alloc2f(*shape)
            count_occ = model.ops.alloc2f(*shape)
            for instance in instances:
                for ent in instance:
                    d_tokvec[ent.start : ent.end] += d_ents[ent_index]
                    count_occ[ent.start : ent.end] += 1
                    ent_index += ent.end - ent.start
            # Average the accumulated gradient over how often each token was used,
            # with a small epsilon to avoid dividing by zero for unused tokens.
            d_tokvec /= count_occ + 1e-11
            d_tokvecs.append(d_tokvec)

        d_docs = bp_tokvecs(d_tokvecs)
        return d_docs

    return relations, backprop


def instance_init(
    model: Model, X: Optional[List[Doc]] = None, Y: Optional[Floats2d] = None
) -> Model:
    """Initialize the tok2vec sub-layer with sample documents, if provided."""
    tok2vec = model.get_ref("tok2vec")
    if X is not None:
        tok2vec.initialize(X)
    return model
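

if __name__ == "__main__":
    # Minimal smoke test -- a sketch, not part of the original component.
    # Assumptions: spaCy's registered "spacy.HashEmbedCNN.v2" tok2vec and
    # thinc's reduce_mean pooling; the sentence, entity spans, widths and
    # nO=2 (two relation labels) are made up for illustration.
    from thinc.api import reduce_mean

    tok2vec = spacy.registry.architectures.get("spacy.HashEmbedCNN.v2")(
        width=96,
        depth=2,
        embed_size=2000,
        window_size=1,
        maxout_pieces=3,
        subword_features=True,
        pretrained_vectors=None,
    )
    model = create_relation_model(
        create_instance_tensor=create_tensors(
            tok2vec=tok2vec,
            pooling=reduce_mean(),
            get_instances=create_instances(max_length=100),
        ),
        classification_layer=create_classification_layer(nO=2, nI=None),
    )

    nlp = spacy.blank("en")
    doc = nlp("Alice works at Acme.")
    doc.ents = [Span(doc, 0, 1, label="PERSON"), Span(doc, 3, 4, label="ORG")]

    model.initialize(X=[doc])
    scores = model.predict([doc])
    print(scores.shape)  # (n_candidate_pairs, nO) -> (2, 2) for this example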