File size: 3,985 Bytes
a8f44f8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 |
from typing import List, Tuple, Callable
import spacy
from spacy.tokens import Doc, Span
from thinc.types import Floats2d, Ints1d, Ragged, cast
from thinc.api import Model, Linear, chain, Logistic
@spacy.registry.architectures("rel_model.v1")
def create_relation_model(
    create_instance_tensor: Model[List[Doc], Floats2d],
    classification_layer: Model[Floats2d, Floats2d],
) -> Model[List[Doc], Floats2d]:
    """Compose the instance-tensor layer and the classification layer.

    The output of ``create_instance_tensor`` is fed into
    ``classification_layer``. The ``"get_instances"`` attribute of the
    instance-tensor layer is copied onto the composed model so callers can
    reach it from the top-level model.

    Args:
        create_instance_tensor: Layer mapping docs to a 2d float array; must
            carry a ``"get_instances"`` entry in its ``attrs``.
        classification_layer: Layer mapping that 2d array to the 2d output.

    Returns:
        The chained model.
    """
    # Calling `chain` directly is what the ">>" operator binding resolves to.
    relation_model = chain(create_instance_tensor, classification_layer)
    instance_getter = create_instance_tensor.attrs["get_instances"]
    relation_model.attrs["get_instances"] = instance_getter
    return relation_model
@spacy.registry.architectures("rel_classification_layer.v1")
def create_classification_layer(
    nO: int = None, nI: int = None
) -> Model[Floats2d, Floats2d]:
    """Build the classification head: a ``Linear`` layer followed by ``Logistic``.

    Args:
        nO: Output width forwarded to ``Linear``.
        nI: Input width forwarded to ``Linear``.
            NOTE(review): ``None`` presumably leaves the dimension to be
            inferred when the model is initialized; confirm against thinc.

    Returns:
        The chained ``Linear`` -> ``Logistic`` model.
    """
    projection = Linear(nO=nO, nI=nI)
    activation = Logistic()
    return chain(projection, activation)
@spacy.registry.misc("rel_instance_generator.v1")
def create_instances(max_length: int) -> Callable[[Doc], List[Tuple[Span, Span]]]:
    """Create a function that enumerates candidate entity pairs of a doc.

    Args:
        max_length: Largest allowed distance between the ``start`` offsets
            of the two entities in a pair. A falsy value (``0`` or ``None``)
            makes the returned function produce no pairs at all.

    Returns:
        A callable mapping a ``Doc`` to a list of ``(ent1, ent2)`` tuples.
    """

    def get_instances(doc: Doc) -> List[Tuple[Span, Span]]:
        """Return every ordered pair of distinct entities within range.

        Both ``(a, b)`` and ``(b, a)`` are emitted, since each entity is
        paired with every other entity in ``doc.ents``.
        """
        return [
            (head, tail)
            for head in doc.ents
            for tail in doc.ents
            if head != tail
            and max_length
            and abs(tail.start - head.start) <= max_length
        ]

    return get_instances
@spacy.registry.architectures("rel_instance_tensor.v1")
def create_tensors(
    tok2vec: Model[List[Doc], List[Floats2d]],
    pooling: Model[Ragged, Floats2d],
    get_instances: Callable[[Doc], List[Tuple[Span, Span]]],
) -> Model[List[Doc], Floats2d]:
    """Wrap ``tok2vec`` and ``pooling`` into the ``"instance_tensors"`` model.

    Args:
        tok2vec: Layer producing one 2d array of token vectors per doc.
        pooling: Layer reducing a ``Ragged`` of token vectors to one row
            per sequence.
        get_instances: Callable listing the entity pairs of a doc; stored
            in the model's ``attrs`` under ``"get_instances"``.

    Returns:
        A ``Model`` whose forward function is ``instance_forward`` and whose
        init function is ``instance_init``, with both sub-layers registered
        as layers and as named refs.
    """
    sublayers = [tok2vec, pooling]
    named_refs = {"tok2vec": tok2vec, "pooling": pooling}
    extra_attrs = {"get_instances": get_instances}
    instance_model = Model(
        "instance_tensors",
        instance_forward,
        init=instance_init,
        layers=sublayers,
        refs=named_refs,
        attrs=extra_attrs,
    )
    return instance_model
def instance_forward(
    model: Model[List[Doc], Floats2d], docs: List[Doc], is_train: bool
) -> Tuple[Floats2d, Callable]:
    """Forward pass producing one feature row per candidate entity pair.

    Steps:
      1. List the entity pairs ("instances") of every doc via the model's
         ``"get_instances"`` attribute.
      2. Run the ``"tok2vec"`` ref to get per-doc token vectors.
      3. Gather the token vectors of every entity of every instance and
         pool them per entity with the ``"pooling"`` ref.
      4. Reshape the pooled rows so that each consecutive pair of rows
         (the two entities of an instance) is concatenated into one row.

    Args:
        model: The ``"instance_tensors"`` model holding the refs and attrs.
        docs: Batch of documents.
        is_train: Forwarded to the sub-layers.

    Returns:
        ``(relations, backprop)`` where ``relations`` has one row per
        instance and ``backprop`` maps the gradient of ``relations`` back
        through pooling and tok2vec.
    """
    pooling = model.get_ref("pooling")
    tok2vec = model.get_ref("tok2vec")
    get_instances = model.attrs["get_instances"]
    all_instances = [get_instances(doc) for doc in docs]
    tokvecs, bp_tokvecs = tok2vec(docs, is_train)

    ents = []
    lengths = []
    for instances, tokvec in zip(all_instances, tokvecs):
        token_indices = []
        for instance in instances:
            for ent in instance:
                token_indices.extend(range(ent.start, ent.end))
                lengths.append(ent.end - ent.start)
        # One gathered block per doc; rows are ordered instance by instance,
        # entity by entity, matching the order of `lengths`.
        ents.append(tokvec[token_indices])
    lengths = cast(Ints1d, model.ops.asarray(lengths, dtype="int32"))
    entities = Ragged(model.ops.flatten(ents), lengths)
    pooled, bp_pooled = pooling(entities, is_train)

    # Reshape so that pairs of rows are concatenated
    relations = model.ops.reshape2f(pooled, -1, pooled.shape[1] * 2)

    def backprop(d_relations: Floats2d) -> List[Doc]:
        """Propagate the gradient of ``relations`` back to the tok2vec layer."""
        # Undo the pair concatenation: two pooled rows per relation row.
        d_pooled = model.ops.reshape2f(d_relations, d_relations.shape[0] * 2, -1)
        # Per-token gradients, laid out in the same order as the gathered
        # token vectors in the forward pass.
        d_ents = bp_pooled(d_pooled).data
        d_tokvecs = []
        ent_index = 0
        for instances, tokvec in zip(all_instances, tokvecs):
            shape = tokvec.shape
            d_tokvec = model.ops.alloc2f(*shape)
            count_occ = model.ops.alloc2f(*shape)
            for instance in instances:
                for ent in instance:
                    n_tokens = ent.end - ent.start
                    # Add the gradient rows belonging to this entity's own
                    # tokens. Using only the first row (d_ents[ent_index])
                    # is correct only when every token of the span receives
                    # the same gradient; the slice is right for any pooling.
                    d_tokvec[ent.start : ent.end] += d_ents[
                        ent_index : ent_index + n_tokens
                    ]
                    count_occ[ent.start : ent.end] += 1
                    ent_index += n_tokens
            # Average over the number of times each token was used; the
            # small constant avoids dividing by zero for unused tokens.
            d_tokvec /= count_occ + 1e-11
            d_tokvecs.append(d_tokvec)
        d_docs = bp_tokvecs(d_tokvecs)
        return d_docs

    return relations, backprop
def instance_init(model: Model, X: List[Doc] = None, Y: Floats2d = None) -> Model:
    """Initialize the model's ``"tok2vec"`` ref with sample input, if given.

    Args:
        model: Model exposing a ``"tok2vec"`` ref.
        X: Optional sample docs passed to ``tok2vec.initialize``.
        Y: Accepted but not used by this function.

    Returns:
        The same ``model`` instance.
    """
    # Look the ref up first so a missing ref fails even without sample data.
    tok2vec_layer = model.get_ref("tok2vec")
    if X is None:
        return model
    tok2vec_layer.initialize(X)
    return model
|