"""Run a RelikReader span-extraction model over a dataset.

Optionally prints ``StrongMatching`` evaluation metrics and writes, for each
predicted sample, its gold ``window_labels`` and ``predicted_window_labels``
separated by a tab.
"""

import argparse
import json  # noqa: F401  NOTE(review): unused here; kept in case other tooling relies on it
from pprint import pprint
from typing import Optional

import hydra  # noqa: F401  NOTE(review): unused here
from omegaconf import DictConfig  # noqa: F401  NOTE(review): unused here

from relik.reader.data.relik_reader_sample import load_relik_reader_samples
from relik.reader.pytorch_modules.span import RelikReaderForSpanExtraction
from relik.reader.relik_reader_core import RelikReaderCoreModel  # noqa: F401
from relik.reader.relik_reader_predictor import RelikReaderPredictor  # noqa: F401
from relik.reader.utils.strong_matching_eval import StrongMatching


def predict(
    model_path: str,
    dataset_path: str,
    token_batch_size: int,
    is_eval: bool,
    output_path: Optional[str],
    device: str = "cuda",
) -> None:
    """Annotate the samples in ``dataset_path`` with the reader at ``model_path``.

    Args:
        model_path: Model location passed to ``RelikReaderForSpanExtraction``.
        dataset_path: File loaded with ``load_relik_reader_samples``.
        token_batch_size: Accepted for interface compatibility; currently not
            forwarded to the reader (see note below).
        is_eval: When true, compute ``StrongMatching`` metrics on the
            predictions and pretty-print them.
        output_path: When not ``None``, write one line per predicted sample:
            ``str(window_labels) + "\\t" + str(predicted_window_labels)``.
        device: Device the reader is loaded on. Defaults to ``"cuda"``, which
            matches the previous hard-coded value.
    """
    relik_reader = RelikReaderForSpanExtraction(
        model_path, training=False, device=device
    )
    samples = list(load_relik_reader_samples(dataset_path))

    # NOTE(review): token_batch_size is not passed to ``read``; confirm the
    # reader API accepts it before wiring it through.
    # Materialize the predictions: both the evaluation and the output writer
    # iterate over them, and a lazy iterable would be exhausted by the first.
    predicted_samples = list(
        relik_reader.read(samples=samples, progress_bar=True)
    )

    # Previously ``if True:``, which ignored the is_eval flag entirely.
    if is_eval:
        eval_dict = StrongMatching()(predicted_samples)
        pprint(eval_dict)

    if output_path is not None:
        with open(output_path, "w", encoding="utf-8") as f:
            # Stream one line per sample instead of building one large string.
            for sample in predicted_samples:
                record = sample.to_jsons()
                f.write(
                    str(record["window_labels"])
                    + "\t"
                    + str(record["predicted_window_labels"])
                    + "\n"
                )


def parse_arg() -> argparse.Namespace:
    """Parse the command-line arguments for :func:`predict`."""
    parser = argparse.ArgumentParser(
        description="Run a RelikReader span-extraction model over a dataset."
    )
    parser.add_argument(
        "--model-path",
        required=True,
        help="Model location passed to RelikReaderForSpanExtraction.",
    )
    parser.add_argument(
        "--dataset-path", "-i", required=True, help="Input dataset file."
    )
    parser.add_argument(
        "--is-eval",
        action="store_true",
        help="Compute and print StrongMatching metrics on the predictions.",
    )
    parser.add_argument(
        "--output-path",
        "-o",
        help="Optional file for tab-separated gold/predicted window labels.",
    )
    # type=int so a value given on the command line is not left as a string.
    parser.add_argument("--token-batch-size", type=int, default=4096)
    parser.add_argument(
        "--device", default="cuda", help="Device to load the reader on."
    )
    return parser.parse_args()


def main() -> None:
    """Command-line entry point."""
    args = parse_arg()
    predict(
        args.model_path,
        args.dataset_path,
        token_batch_size=args.token_batch_size,
        is_eval=args.is_eval,
        output_path=args.output_path,
        device=args.device,
    )


if __name__ == "__main__":
    main()