|
from relik.retriever import GoldenRetriever |
|
|
|
from relik.retriever.indexers.inmemory import InMemoryDocumentIndex |
|
from relik.retriever.indexers.document import DocumentStore |
|
from relik.retriever import GoldenRetriever |
|
from relik.reader.pytorch_modules.span import RelikReaderForSpanExtraction |
|
from relik.reader.utils.strong_matching_eval import StrongMatching |
|
from relik.reader.data.relik_reader_data import RelikDataset |
|
|
|
from relik.inference.annotator import Relik |
|
from relik.inference.data.objects import ( |
|
AnnotationType, |
|
RelikOutput, |
|
Span, |
|
TaskType, |
|
Triples, |
|
) |
|
|
|
def load_model(
    question_encoder_path="/home/carlos/amr-parsing-master/sentence-similarity/retriever/wandb/wandb/latest-run/files/retriever/question_encoder",
    documents_path="/home/carlos/amr-parsing-master/sentence-similarity/retriever/wandb/wandb/latest-run/files/retriever/document_index/documents.jsonl",
    reader_path="/home/carlos/amr-parsing-master/sentence-similarity/relik-main/experiments/relik-reader-deberta-small-io/2024-04-26/12-56-49/wandb/run-20240426_125654-vfznbu4r/files/hf_model/hf_model",
    dataset_path=None,
    token_batch_size=2048,
    output_path=None,
):
    """Build a Relik span-extraction pipeline and optionally evaluate it.

    Constructs a GoldenRetriever (with an in-memory document index built from
    ``documents_path``), indexes it, loads a span-extraction reader, and wires
    both into a :class:`Relik` annotator. If ``dataset_path`` is given, runs
    prediction on it, prints StrongMatching evaluation metrics, and optionally
    dumps the predicted samples as JSON lines to ``output_path``.

    Args:
        question_encoder_path: Path to the trained question encoder checkpoint.
        documents_path: Path to the JSONL document store for the index.
        reader_path: Path to the HF reader model directory.
        dataset_path: Optional path to an evaluation dataset; when ``None``,
            no prediction/evaluation is performed.
        token_batch_size: Token batch size passed to ``relik.predict``.
        output_path: Optional file to write predicted samples to (JSONL).

    Returns:
        The assembled :class:`Relik` annotator.
    """
    # Local import: only needed when printing eval metrics, and keeps the
    # top-of-file import block untouched.
    from pprint import pprint

    retriever = GoldenRetriever(
        question_encoder=question_encoder_path,
        document_index=InMemoryDocumentIndex(
            documents=DocumentStore.from_file(documents_path),
            metadata_fields=["definition"],
            separator=' <def> ',
            device="cuda",
        ),
        # BUG FIX: was misspelled ``devide="cuda"``, so the retriever never
        # received its device argument.
        device="cuda",
    )
    retriever.index()

    reader = RelikReaderForSpanExtraction(
        reader_path,
        dataset_kwargs={"use_nme": True},
    )

    relik = Relik(
        reader=reader,
        retriever=retriever,
        window_size="none",
        top_k=100,
        task="span",
        device="cuda",
        document_index_device="cpu",
    )
    # NOTE(review): the original called ``relik()`` with no arguments here,
    # which raises because the annotator requires input text; removed. It also
    # instantiated a val dataset via undefined ``hydra``/``cfg`` names and
    # never used the result; that dead, broken code is removed as well.

    if dataset_path is not None:
        predicted_samples = relik.predict(
            dataset_path, token_batch_size=token_batch_size
        )

        eval_dict = StrongMatching()(predicted_samples)
        pprint(eval_dict)

        if output_path is not None:
            with open(output_path, "w") as f:
                for sample in predicted_samples:
                    f.write(sample.to_jsons() + "\n")

    return relik
|
|