File size: 1,041 Bytes
05922fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from sftp import SpanPredictor


def print_children(sentence, boundary, labels, _):
    """Print a sentence followed by every labelled span found in it.

    Args:
        sentence: Sequence of token strings.
        boundary: Iterable of ``(start, end)`` token-index pairs. ``end`` is
            inclusive, so ``(2, 3)`` covers tokens 2 and 3.
        labels: Labels aligned one-to-one with ``boundary``. Extra items on
            either side are ignored, as with ``zip``.
        _: Ignored. It lets a whole prediction tuple be splatted in with
            ``print_children(sentence, *prediction)``.
    """
    print('Sentence:', ' '.join(sentence))
    for span, label in zip(boundary, labels):
        first, last = span
        # +1 because the span end index is inclusive.
        tokens = sentence[first:last + 1]
        print(' '.join(tokens), ':', label)
    print('=' * 20)


def example(
    model_path="/data/p289731/cloned/lome-models/models/spanfinder/model.mod.tar.gz",
    cuda_device=-1,
):
    """Load a SpanFinder predictor and print three force-decoded parses of a toy sentence.

    Args:
        model_path: Model archive/directory passed to ``SpanPredictor.from_path``.
            The default is the machine-specific path this script originally
            hard-coded. Pass your own path when running on another machine.
            The script also once referenced the alternative model directory
            ``/home/gqin2/public/release/sftp/0.0.2/framenet``.
        cuda_device: Device index forwarded to ``SpanPredictor.from_path``.
            The default ``-1`` presumably means CPU, following the AllenNLP
            convention. TODO confirm against ``sftp``.
    """
    print("Loading predictor...")
    predictor = SpanPredictor.from_path(model_path, cuda_device=cuda_device)

    print("Predicting for sentence..")
    sentence = ['Tom', 'eats', 'an', 'apple', 'and', 'he', 'wakes', 'up', '.']

    # 1) No constraints: decode over the whole sentence.
    p1 = predictor.force_decode(sentence)
    print_children(sentence, *p1)

    # 2) Fix the parent to span (1, 1) -- the token "eats" -- labelled 'Ingestion'.
    p2 = predictor.force_decode(sentence, parent_span=(1, 1), parent_label='Ingestion')
    print_children(sentence, *p2)

    # 3) Same parent, but also fix the child spans (0, 0) "Tom" and (2, 3) "an apple".
    #    Presumably only their labels are predicted here. TODO confirm against sftp.
    p3 = predictor.force_decode(
        sentence,
        child_spans=[(0, 0), (2, 3)],
        parent_span=(1, 1),
        parent_label='Ingestion',
    )
    print_children(sentence, *p3)


if __name__ == "__main__":
    # Run the demo only when executed as a script, never on import.
    example()