|
import argparse |
|
|
|
import torch |
|
from reader.data.relik_reader_sample import load_relik_reader_samples |
|
|
|
from relik.reader.pytorch_modules.hf.modeling_relik import ( |
|
RelikReaderConfig, |
|
RelikReaderREModel, |
|
) |
|
from relik.reader.relik_reader_re import RelikReaderForTripletExtraction |
|
from relik.reader.utils.relation_matching_eval import StrongMatching |
|
|
|
# Maps the raw NYT relation identifiers (Freebase-style paths) found in the
# dataset to the short natural-language labels used as relation candidates
# and as gold relation names during evaluation.
#
# NOTE(review): the label for "/people/ethnicity/geographic_distribution"
# contains what looks like a typo ("distributi6on"). It is kept verbatim
# because the string is passed to the model as a candidate; confirm against
# the labels used at training time before correcting it.
dict_nyt = {
    "/people/person/nationality": "nationality",
    "/sports/sports_team/location": "sports team location",
    "/location/country/administrative_divisions": "administrative divisions",
    "/business/company/major_shareholders": "shareholders",
    "/people/ethnicity/people": "ethnicity",
    "/people/ethnicity/geographic_distribution": "geographic distributi6on",
    "/business/company_shareholder/major_shareholder_of": "major shareholder",
    "/location/location/contains": "location",
    "/business/company/founders": "founders",
    "/business/person/company": "company",
    "/business/company/advisors": "advisor",
    "/people/deceased_person/place_of_death": "place of death",
    "/business/company/industry": "industry",
    "/people/person/ethnicity": "ethnic background",
    "/people/person/place_of_birth": "place of birth",
    "/location/administrative_division/country": "country of an administration division",
    "/people/person/place_lived": "place lived",
    "/sports/sports_team_location/teams": "sports team",
    "/people/person/children": "child",
    "/people/person/religion": "religion",
    "/location/neighborhood/neighborhood_of": "neighborhood",
    "/location/country/capital": "capital",
    "/business/company/place_founded": "company founded location",
    "/people/person/profession": "occupation",
}
|
|
|
|
|
def eval(model_path, data_path, is_eval, output_path=None, device=None):
    """Run a RelikReader relation-extraction model over a dataset file.

    NOTE: the name shadows the builtin ``eval``; it is kept because it is the
    public entry point of this script.

    Args:
        model_path: Either a ``.ckpt`` checkpoint file (a dict with
            ``hyper_parameters`` and ``state_dict`` entries) or anything
            accepted by ``RelikReaderREModel.from_pretrained``.
        data_path: Path passed to ``load_relik_reader_samples``.
        is_eval: When true, compute ``StrongMatching`` metrics on the
            predictions and print each as ``test_<name> <value>``.
        output_path: When not ``None``, predictions are written to this file,
            one ``sample.to_jsons()`` string per line.
        device: Device handed to ``RelikReaderForTripletExtraction``. When
            ``None`` (default) ``"cuda"`` is used if available, else ``"cpu"``.

    Raises:
        KeyError: If a candidate or gold relation name in the data is not a
            key of ``dict_nyt``.
    """
    if device is None:
        # Previously hard-coded to "cuda"; fall back so CPU-only hosts work.
        device = "cuda" if torch.cuda.is_available() else "cpu"

    if model_path.endswith(".ckpt"):
        # Checkpoint file: rebuild tokenizer, config and model from the
        # hyper-parameters stored next to the weights.
        from transformers import AutoTokenizer

        from relik.reader.utils.special_symbols import get_special_symbols_re

        # map_location="cpu" lets checkpoints saved on a GPU be loaded on any
        # host; the weights are copied into the model below either way.
        model_dict = torch.load(model_path, map_location="cpu")
        hyper_parameters = model_dict["hyper_parameters"]

        additional_special_symbols = hyper_parameters["additional_special_symbols"]
        special_symbols = get_special_symbols_re(additional_special_symbols - 1)

        tokenizer = AutoTokenizer.from_pretrained(
            hyper_parameters["transformer_model"],
            additional_special_tokens=special_symbols,
            add_prefix_space=True,
        )
        config_model = RelikReaderConfig(
            hyper_parameters["transformer_model"],
            len(special_symbols),
            training=False,
        )
        model = RelikReaderREModel(config_model)

        # Checkpoint keys carry a "relik_reader_re_model." prefix that the
        # bare model's parameter names do not have.
        model_dict["state_dict"] = {
            k.replace("relik_reader_re_model.", ""): v
            for k, v in model_dict["state_dict"].items()
        }
        # strict=False: keys that do not match the model are not an error.
        model.load_state_dict(model_dict["state_dict"], strict=False)

        reader = RelikReaderForTripletExtraction(
            model, training=False, device=device, tokenizer=tokenizer
        )
    else:
        model = RelikReaderREModel.from_pretrained(model_path)
        reader = RelikReaderForTripletExtraction(model, training=False, device=device)

    samples = list(load_relik_reader_samples(data_path))

    # Replace raw NYT relation identifiers with their natural-language labels,
    # both in the candidate list and in the gold triplets.
    for sample in samples:
        sample.candidates = [dict_nyt[cand] for cand in sample.candidates]
        sample.triplets = [
            {
                "subject": triplet["subject"],
                "relation": {
                    "name": dict_nyt[triplet["relation"]["name"]],
                    "type": triplet["relation"]["type"],
                },
                "object": triplet["object"],
            }
            for triplet in sample.triplets
        ]

    predicted_samples = reader.read(samples=samples, progress_bar=True)

    if is_eval:
        strong_matching_metric = StrongMatching()
        # Materialise so the predictions can be both scored and written out.
        predicted_samples = list(predicted_samples)
        for k, v in strong_matching_metric(predicted_samples).items():
            print(f"test_{k}", v)

    if output_path is not None:
        # Explicit UTF-8 so the output does not depend on the platform locale.
        with open(output_path, "w", encoding="utf-8") as f:
            for sample in predicted_samples:
                f.write(sample.to_jsons() + "\n")
|
|
|
|
|
def main():
    """Parse command-line arguments and run :func:`eval` with them."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path",
        type=str,
        default="/home/huguetcabot/alby-re/relik/relik/reader/models/relik_re_reader_base",
        help="Path to a .ckpt checkpoint or to a pretrained model.",
    )
    parser.add_argument(
        "--data_path",
        type=str,
        default="/home/huguetcabot/alby-re/relik/relik/reader/data/testa.jsonl",
        help="Path to the dataset file to run predictions on.",
    )
    # "--is_eval" is accepted as an alias so the flag matches the underscore
    # spelling of the other options; both populate ``args.is_eval``.
    parser.add_argument(
        "--is-eval",
        "--is_eval",
        action="store_true",
        help="Compute and print evaluation metrics.",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default=None,
        help="If given, write predictions to this file, one JSON string per line.",
    )
    args = parser.parse_args()
    eval(args.model_path, args.data_path, args.is_eval, args.output_path)


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|