# sociolome/scripts/predict_force.py
# Gosse Minnema
# Initial commit
# 05922fb
from sftp import SpanPredictor
def print_children(sentence, boundary, labels, _):
    """Pretty-print a tokenized sentence and the labelled spans found in it.

    Args:
        sentence: Sequence of token strings.
        boundary: Iterable of ``(start, end)`` token-index pairs; ``end`` is
            treated as inclusive (the slice below adds 1).
        labels: Labels paired positionally with ``boundary``; pairing stops
            at the shorter of the two (``zip`` semantics).
        _: Ignored. Presumably a third value returned by
            ``force_decode`` (e.g. scores) so callers can splat its result
            with ``*`` — TODO confirm.
    """
    print('Sentence:', ' '.join(sentence))
    for span, label in zip(boundary, labels):
        first, last = span
        # ``last`` is inclusive, hence the +1 on the slice end.
        tokens = sentence[first:last + 1]
        print(' '.join(tokens), ':', label)
    # Visual separator between consecutive predictions.
    print('=' * 20)
def example(
    model_path="/data/p289731/cloned/lome-models/models/spanfinder/model.mod.tar.gz",
    cuda_device=-1,
):
    """Demonstrate ``SpanPredictor.force_decode`` on a toy sentence.

    Loads a span-finder model and runs three forced decodes on the same
    tokenized sentence. Each result is printed with ``print_children``:

    1. no constraints;
    2. a given parent span ``(1, 1)`` labelled ``'Ingestion'`` ("eats",
       assuming inclusive end indices as in ``print_children``);
    3. the same parent plus fixed child spans ``(0, 0)`` and ``(2, 3)``
       ("Tom" and "an apple" under the same assumption).

    Args:
        model_path: Model archive passed to ``SpanPredictor.from_path``.
            The default is the original author's local path and will not
            exist on other machines, so pass your own. An alternative path
            previously used here was
            ``/home/gqin2/public/release/sftp/0.0.2/framenet``.
        cuda_device: Device index forwarded to ``from_path``. ``-1`` is
            the original value (presumably CPU — confirm against sftp).
    """
    print("Loading predictor...")
    predictor = SpanPredictor.from_path(model_path, cuda_device=cuda_device)

    print("Predicting for sentence..")
    sentence = ['Tom', 'eats', 'an', 'apple', 'and', 'he', 'wakes', 'up', '.']

    # 1) Unconstrained decode.
    p1 = predictor.force_decode(sentence)
    print_children(sentence, *p1)

    # 2) Decode children of a fixed, labelled parent span.
    p2 = predictor.force_decode(sentence, parent_span=(1, 1), parent_label='Ingestion')
    print_children(sentence, *p2)

    # 3) Additionally fix the child spans; presumably only their labels
    #    are predicted — TODO confirm against force_decode.
    p3 = predictor.force_decode(
        sentence,
        child_spans=[(0, 0), (2, 3)],
        parent_span=(1, 1),
        parent_label='Ingestion',
    )
    print_children(sentence, *p3)
if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    example()