"""Run GNormPlus / GNorm2 over a directory of files in parallel batches.

Files already present in the output directory are skipped, so the script can be
re-run to resume an interrupted job.
"""

import argparse
import logging
import os
import shutil
import subprocess
from pathlib import Path
from tempfile import TemporaryDirectory

from tqdm.contrib.concurrent import process_map


def batch(iterable, n=1):
    """Yield successive chunks of at most n items from iterable."""
    l = len(iterable)
    for ndx in range(0, l, n):
        yield iterable[ndx : min(ndx + n, l)]


def main():
    logging.basicConfig(level=logging.INFO)

    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", type=str, default="gnorm2", help="mode to run in (gnorm2, gnormplus)")
    parser.add_argument("input_dir", type=str, help="directory containing files to process")
    parser.add_argument("output_dir", type=str, help="directory to write processed files to")
    parser.add_argument("--batch_size", type=int, default=8)
    # Leave a few cores free for the Java/TensorFlow subprocesses, but never default below 1.
    parser.add_argument("--max_workers", type=int, default=max(1, (os.cpu_count() or 1) - 4))
    parser.add_argument("--ignore_errors", action="store_true", help="ignore errors")
    args = parser.parse_args()

    # Collect input and already-processed output file names so finished files are skipped.
    input_dir = Path(args.input_dir)
    input_files = set(file.name for file in input_dir.rglob("*"))

    output_dir = Path(args.output_dir)
    output_dir.mkdir(exist_ok=True)
    output_files = set(file.name for file in output_dir.rglob("*"))

    logging.info(f"Found {len(input_files)} input files")
    logging.info(f"Found {len(output_files)} output files")

    input_files = input_files - output_files
    logging.info(f"Processing {len(input_files)} files")

    # Sort by file size so each batch groups files of similar size.
    input_files = sorted(input_files, key=lambda file: (input_dir / file).stat().st_size)
    input_files_batches = list(batch(list(input_files), args.batch_size))

    process_map(
        run_batch,
        input_files_batches,
        [input_dir] * len(input_files_batches),
        [output_dir] * len(input_files_batches),
        [args.mode] * len(input_files_batches),
        [args.ignore_errors] * len(input_files_batches),
        max_workers=args.max_workers,
        chunksize=1,
    )


def run_batch(input_files_batch, input_dir, output_dir, mode, ignore_errors):
    # Stage the batch into temporary directories so each pipeline step only sees this batch's files.
    with TemporaryDirectory() as temp_dir_SR, TemporaryDirectory() as temp_dir_GNR, TemporaryDirectory() as temp_dir_SA, TemporaryDirectory() as input_temp_dir, TemporaryDirectory() as output_temp_dir:
        input_temp_dir = Path(input_temp_dir)
        output_temp_dir = Path(output_temp_dir)
        for file in input_files_batch:
            logging.info(f"cp {input_dir / file} {input_temp_dir}")
            shutil.copy(input_dir / file, input_temp_dir)

        if mode == "gnorm2":
            # GNorm2 pipeline: species recognition -> gene NER + species assignment -> gene normalization.
            command_SR = (
                f"java -Xmx32G -Xms16G -jar GNormPlus.jar {str(input_temp_dir)} {str(temp_dir_SR)} setup.SR.txt"
            )
            command_GNR_SA = (
                f"python GeneNER_SpeAss_run.py -i {str(temp_dir_SR)} -r {str(temp_dir_GNR)} -a {str(temp_dir_SA)}"
                " -n gnorm_trained_models/geneNER/GeneNER-Bioformer.h5"
                " -s gnorm_trained_models/SpeAss/SpeAss-Bioformer.h5"
            )
            command_GN = (
                f"java -Xmx32G -Xms16G -jar GNormPlus.jar {str(temp_dir_SA)} {str(output_temp_dir)} setup.GN.txt"
            )
            commands = [command_SR, command_GNR_SA, command_GN]
        elif mode == "gnormplus":
            # Original GNormPlus runs as a single step.
            commands = [
                f"java -Xmx32G -Xms16G -jar GNormPlus.jar {str(input_temp_dir)} {str(output_temp_dir)} setup.txt"
            ]
        else:
            raise ValueError(f"Invalid mode: {mode}")

        for command in commands:
            try:
                logging.info(command)
                subprocess.run(command, check=True, shell=True)
            except subprocess.CalledProcessError as e:
                logging.exception(f"Error running command: {command}")
                if e.returncode == 137:
                    # Exit status 137 means the process was killed (typically by the OOM killer).
                    logging.error("Process killed due to memory limit")
                    return
                elif ignore_errors:
                    logging.error("Ignoring error")
                    return
                raise e

        # Copy the batch's results from the temporary output directory to the real one.
        for output_path in output_temp_dir.rglob("*"):
            logging.info(f"cp {output_path} {output_dir}")
            shutil.copy(output_path, output_dir)


if __name__ == "__main__":
    main()