import argparse
import logging
import os
import shutil
import subprocess
from pathlib import Path
from tempfile import TemporaryDirectory

from tqdm.contrib.concurrent import process_map


def batch(iterable, n=1):
    """Yield successive chunks of at most n items from iterable."""
    length = len(iterable)
    for ndx in range(0, length, n):
        yield iterable[ndx : min(ndx + n, length)]

def main():
    logging.basicConfig(level=logging.INFO)
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", type=str, default="gnorm2", help="mode to run in (gnorm2, gnormplus)")
    parser.add_argument("input_dir", type=str, help="directory containing files to process")
    parser.add_argument("output_dir", type=str, help="directory to write processed files to")
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--max_workers", type=int, default=os.cpu_count() - 4)
    parser.add_argument("--ignore_errors", action="store_true", help="ignore errors")
    args = parser.parse_args()

    input_dir = Path(args.input_dir)
    # Compare by file name only and skip directories that rglob also yields;
    # the script assumes input files sit directly under input_dir.
    input_files = set(f.name for f in input_dir.rglob("*") if f.is_file())
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    output_files = set(f.name for f in output_dir.rglob("*") if f.is_file())

    logging.info(f"Found {len(input_files)} input files")
    logging.info(f"Found {len(output_files)} output files")

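    # Process only files that don't yet have output, so reruns resume where they stopped.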
    input_files = input_files - output_files

    logging.info(f"Processing {len(input_files)} files")

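    # Sort by on-disk size (smallest first) so each batch groups similarly sized files.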
    input_files = sorted(input_files, key=lambda file: (input_dir / file).stat().st_size)

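    # process_map runs run_batch over the batches in a pool of worker processes;
    # the repeated lists supply the fixed arguments positionally to every call.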
    input_files_batches = list(batch(list(input_files), args.batch_size))
    process_map(
        run_batch,
        input_files_batches,
        [input_dir] * len(input_files_batches),
        [output_dir] * len(input_files_batches),
        [args.mode] * len(input_files_batches),
        [args.ignore_errors] * len(input_files_batches),
        max_workers=args.max_workers,
        chunksize=1,
    )


def run_batch(input_files_batch, input_dir, output_dir, mode, ignore_errors):
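    """Stage one batch of files through the selected pipeline inside scratch directories."""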
    with TemporaryDirectory() as temp_dir_SR, TemporaryDirectory() as temp_dir_GNR, \
            TemporaryDirectory() as temp_dir_SA, TemporaryDirectory() as input_temp_dir, \
            TemporaryDirectory() as output_temp_dir:
        input_temp_dir = Path(input_temp_dir)
        output_temp_dir = Path(output_temp_dir)
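        # Stage the batch into a private temp dir so concurrent workers never share inputs.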
        for file in input_files_batch:
            logging.info(f"cp {input_dir / file} {input_temp_dir}")
            shutil.copy(input_dir / file, input_temp_dir)

        if mode == "gnorm2":
            command_SR = f"java -Xmx32G -Xms16G -jar GNormPlus.jar {input_temp_dir} {temp_dir_SR} setup.SR.txt"
            command_GNR_SA = (
                f"python GeneNER_SpeAss_run.py -i {temp_dir_SR} -r {temp_dir_GNR} -a {temp_dir_SA}"
                " -n gnorm_trained_models/geneNER/GeneNER-Bioformer.h5"
                " -s gnorm_trained_models/SpeAss/SpeAss-Bioformer.h5"
            )
            command_GN = f"java -Xmx32G -Xms16G -jar GNormPlus.jar {temp_dir_SA} {output_temp_dir} setup.GN.txt"
            commands = [command_SR, command_GNR_SA, command_GN]
        elif mode == "gnormplus":
            commands = [f"java -Xmx32G -Xms16G -jar GNormPlus.jar {input_temp_dir} {output_temp_dir} setup.txt"]
        else:
            raise ValueError(f"Invalid mode: {mode}")

        for command in commands:
            try:
                logging.info(command)
                # With shell=True the command must be a single string, not a list.
                subprocess.run(command, check=True, shell=True)
            except subprocess.CalledProcessError as e:
                logging.exception(f"Error running command: {command}")
                if e.returncode == 137:
                    # Exit 137 = 128 + SIGKILL(9); the stage was most likely OOM-killed.
                    logging.error("Process killed due to memory limit")
                    return
                if ignore_errors:
                    logging.error("Ignoring error")
                    return
                raise

        # Copy results back from the scratch directory; skip any subdirectories.
        output_paths = [p for p in output_temp_dir.rglob("*") if p.is_file()]
        for output_path in output_paths:
            logging.info(f"cp {output_path} {output_dir}")
            shutil.copy(output_path, output_dir)


if __name__ == "__main__":
    main()