from pprint import pprint

from relik.inference.annotator import Relik
from relik.inference.data.objects import (
    AnnotationType,
    RelikOutput,
    Span,
    TaskType,
    Triples,
)
from relik.reader.data.relik_reader_data import RelikDataset
from relik.reader.pytorch_modules.span import RelikReaderForSpanExtraction
from relik.reader.utils.strong_matching_eval import StrongMatching
from relik.retriever import GoldenRetriever
from relik.retriever.indexers.document import DocumentStore
from relik.retriever.indexers.inmemory import InMemoryDocumentIndex
def load_model(dataset_path=None, token_batch_size=None, output_path=None):
    """Build the ReLiK span-extraction pipeline and evaluate it on a dataset.

    Loads a ``GoldenRetriever`` (question encoder + in-memory document
    index), indexes the documents, loads a span-extraction reader, wires
    them into a ``Relik`` annotator, runs prediction over *dataset_path*,
    prints strong-matching evaluation metrics, and optionally dumps the
    predicted samples as JSON lines.

    Args:
        dataset_path: Path of the dataset to run prediction on.
            NOTE(review): the original code referenced this name without
            defining it anywhere; it is now an explicit parameter.
        token_batch_size: Batch size (in tokens) forwarded to
            ``Relik.predict``. Same note as above.
        output_path: If not ``None``, write one JSON line per predicted
            sample to this file.

    Returns:
        The evaluation dict produced by ``StrongMatching`` (the original
        implicitly returned ``None``; returning the metrics is additive).
    """
    # NOTE(review): checkpoint/index paths are hard-coded to a local
    # checkout; consider promoting them to parameters or configuration.
    retriever = GoldenRetriever(
        question_encoder="/home/carlos/amr-parsing-master/sentence-similarity/retriever/wandb/wandb/latest-run/files/retriever/question_encoder",
        document_index=InMemoryDocumentIndex(
            documents=DocumentStore.from_file(
                "/home/carlos/amr-parsing-master/sentence-similarity/retriever/wandb/wandb/latest-run/files/retriever/document_index/documents.jsonl"
            ),
            metadata_fields=["definition"],
            separator=' <def> ',
            device="cuda",
        ),
        # Fixed: was misspelled `devide="cuda"`, so the intended device
        # argument never reached the retriever.
        device="cuda",
    )
    retriever.index()
    reader = RelikReaderForSpanExtraction(
        "/home/carlos/amr-parsing-master/sentence-similarity/relik-main/experiments/relik-reader-deberta-small-io/2024-04-26/12-56-49/wandb/run-20240426_125654-vfznbu4r/files/hf_model/hf_model",
        dataset_kwargs={"use_nme": True},
    )
    # NOTE(review): document_index_device="cpu" here conflicts with the
    # device="cuda" given to the InMemoryDocumentIndex above — confirm
    # which one the pipeline should honor.
    relik = Relik(
        reader=reader,
        retriever=retriever,
        window_size="none",
        top_k=100,
        task="span",
        device="cuda",
        document_index_device="cpu",
    )
    # NOTE(review): the original called `relik()` with no arguments here and
    # built an unused `val_dataset` via undefined names (`hydra`, `cfg`,
    # `to_absolute_path`); both were dead/broken leftovers and were removed.
    predicted_samples = relik.predict(
        dataset_path, token_batch_size=token_batch_size
    )
    eval_dict = StrongMatching()(predicted_samples)
    pprint(eval_dict)
    if output_path is not None:
        with open(output_path, "w") as f:
            for sample in predicted_samples:
                f.write(sample.to_jsons() + "\n")
    return eval_dict