from pprint import pprint

import hydra
from hydra.utils import to_absolute_path

from relik.retriever import GoldenRetriever
from relik.retriever.indexers.inmemory import InMemoryDocumentIndex
from relik.retriever.indexers.document import DocumentStore
from relik.reader.pytorch_modules.span import RelikReaderForSpanExtraction
from relik.reader.utils.strong_matching_eval import StrongMatching
from relik.reader.data.relik_reader_data import RelikDataset
from relik.inference.annotator import Relik


def load_model(cfg, token_batch_size, output_path=None):
    """Build the ReLiK span pipeline, evaluate it on the validation set
    described by the Hydra config ``cfg``, and optionally write the
    predicted samples to ``output_path``."""
    # Dense retriever: trained question encoder plus an in-memory index over
    # the candidate documents; definitions are appended via ``metadata_fields``.
    retriever = GoldenRetriever(
        question_encoder="/home/carlos/amr-parsing-master/sentence-similarity/retriever/wandb/wandb/latest-run/files/retriever/question_encoder",
        document_index=InMemoryDocumentIndex(
            documents=DocumentStore.from_file(
                "/home/carlos/amr-parsing-master/sentence-similarity/retriever/wandb/wandb/latest-run/files/retriever/document_index/documents.jsonl"
            ),
            metadata_fields=["definition"],
            separator=" <def> ",
            device="cuda",
        ),
        device="cuda",
    )
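    # Pre-compute an embedding for every document so retrieval can run
    # against the in-memory index.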
    retriever.index()

    # Span-extraction reader loaded from a local fine-tuned DeBERTa checkpoint;
    # ``use_nme`` adds the NME (no-matching-entity) candidate to the input.
    reader = RelikReaderForSpanExtraction(
        "/home/carlos/amr-parsing-master/sentence-similarity/relik-main/experiments/relik-reader-deberta-small-io/2024-04-26/12-56-49/wandb/run-20240426_125654-vfznbu4r/files/hf_model/hf_model",
        dataset_kwargs={"use_nme": True},
    )

    # Full retriever + reader pipeline over whole documents (no windowing).
    relik = Relik(
        reader=reader,
        retriever=retriever,
        window_size="none",
        top_k=100,
        task="span",
        device="cuda",
        document_index_device="cpu",
    )

    # Resolve the validation set path from the Hydra config, instantiate the
    # dataset, and run inference over it.
    dataset_path = to_absolute_path(cfg.data.val_dataset_path)
    val_dataset: RelikDataset = hydra.utils.instantiate(
        cfg.data.val_dataset,
        dataset_path=dataset_path,
    )

    predicted_samples = relik.predict(
        dataset_path, token_batch_size=token_batch_size
    )

    # Strong matching: a prediction counts only if both the span boundaries
    # and the predicted label exactly match the gold annotation.
    eval_dict = StrongMatching()(predicted_samples)
    pprint(eval_dict)

    # Optionally dump the predicted samples as JSON lines.
    if output_path is not None:
        with open(output_path, "w") as f:
            for sample in predicted_samples:
                f.write(sample.to_jsons() + "\n")
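

# Minimal entry-point sketch (an assumption, not part of the original script)
# showing how ``load_model`` might be driven by Hydra. The config directory
# ("conf"), config name ("eval"), token batch size, and output file name are
# all hypothetical values; the config is expected to provide the
# ``data.val_dataset`` and ``data.val_dataset_path`` keys used above.
@hydra.main(config_path="conf", config_name="eval", version_base="1.3")
def main(cfg):
    load_model(cfg, token_batch_size=2048, output_path="predictions.jsonl")


if __name__ == "__main__":
    main()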