|
|
|
import os
import pathlib
from multiprocessing import cpu_count
from os.path import join
from typing import Any

import numpy as np
import torch
import yaml
|
|
|
class BaseConfig: |
|
"""Base class for managing and validating configurations.""" |
|
|
|
    numpy_dtype_mapping = {1: np.int8,
                           2: np.int16,
                           4: np.int32,
                           8: np.int64}
|
|
|
def __init__(self): |
|
super().__init__() |
|
|
|
    def cast_to_expected_type(self, parameter_class: str, parameter_name: str, value: Any) -> Any:
|
""" |
|
Cast the given value to the expected type. |
|
|
|
:param parameter_class: The class/category of the parameter. |
|
:type parameter_class: str |
|
:param parameter_name: The name of the parameter. |
|
:type parameter_name: str |
|
:param value: The value to be casted. |
|
:type value: any |
|
:return: Value casted to the expected type. |
|
:rtype: any |
|
:raises ValueError: If casting fails. |
|
""" |
|
expected_type = self.parameters[parameter_class][parameter_name]['type'] |
|
|
|
if expected_type in ["integer", "int"]: |
|
try: |
|
return int(value) |
|
except ValueError: |
|
raise ValueError(f"Failed to cast value '{value}' to integer for parameter '{parameter_name}' in class '{parameter_class}'.") |
|
elif expected_type == "float": |
|
try: |
|
return float(value) |
|
except ValueError: |
|
raise ValueError(f"Failed to cast value '{value}' to float for parameter '{parameter_name}' in class '{parameter_class}'.") |
|
elif expected_type in ["string", "str"]: |
|
return str(value) |
|
elif expected_type in ["boolean", "bool"]: |
|
if isinstance(value, bool): |
|
return value |
|
elif str(value).lower() == "true": |
|
return True |
|
elif str(value).lower() == "false": |
|
return False |
|
else: |
|
raise ValueError(f"Failed to cast value '{value}' to boolean for parameter '{parameter_name}' in class '{parameter_class}'.") |
|
elif expected_type == "type": |
|
|
|
|
|
return value |
|
elif expected_type == "list": |
|
if isinstance(value, list): |
|
return value |
|
else: |
|
raise ValueError(f"Failed to validate value '{value}' as a list for parameter '{parameter_name}' in class '{parameter_class}'.") |
|
elif expected_type == "tuple": |
|
if isinstance(value, tuple): |
|
return value |
|
else: |
|
raise ValueError(f"Failed to validate value '{value}' as a tuple for parameter '{parameter_name}' in class '{parameter_class}'.") |
|
elif expected_type == "set": |
|
if isinstance(value, set): |
|
return value |
|
else: |
|
raise ValueError(f"Failed to validate value '{value}' as a set for parameter '{parameter_name}' in class '{parameter_class}'.") |
|
elif expected_type == "dict": |
|
if isinstance(value, dict): |
|
return value |
|
else: |
|
raise ValueError(f"Failed to validate value '{value}' as a dict for parameter '{parameter_name}' in class '{parameter_class}'.") |
|
else: |
|
raise ValueError(f"Unknown expected type '{expected_type}' for parameter '{parameter_name}' in class '{parameter_class}'.") |
|
|
|
|
|
|
|
    def get_parameter(self, parameter_class: str, parameter_name: str) -> Any:
|
""" |
|
Retrieve the default value of a specified parameter. |
|
|
|
:param parameter_class: The class/category of the parameter (e.g., 'segmentation'). |
|
:type parameter_class: str |
|
:param parameter_name: The name of the parameter. |
|
:type parameter_name: str |
|
        :return: Default value of the parameter, cast to the expected type.
        :rtype: Any
|
""" |
|
default_value = self.parameters[parameter_class][parameter_name]['default'] |
|
return self.cast_to_expected_type(parameter_class, parameter_name, default_value) |
|
|
|
|
|
|
|
    def validate_type(self, parameter_class: str, parameter_name: str, value: Any) -> bool:
|
""" |
|
Validate the type of a given value against the expected type. |
|
|
|
:param parameter_class: The class/category of the parameter. |
|
:type parameter_class: str |
|
:param parameter_name: The name of the parameter. |
|
:type parameter_name: str |
|
:param value: The value to be validated. |
|
        :type value: Any
|
:return: True if the value is of the expected type, otherwise False. |
|
:rtype: bool |
|
""" |
|
expected_type = self.parameters[parameter_class][parameter_name]['type'] |
|
|
|
if expected_type == "integer" and not isinstance(value, int): |
|
return False |
|
elif expected_type == "float" and not isinstance(value, float): |
|
return False |
|
elif expected_type == "string" and not isinstance(value, str): |
|
return False |
|
else: |
|
return True |
|
|
|
    def validate_value(self, parameter_class: str, parameter_name: str, value: Any) -> bool:
|
""" |
|
Validate the value of a parameter against its constraints. |
|
|
|
:param parameter_class: The class/category of the parameter. |
|
:type parameter_class: str |
|
:param parameter_name: The name of the parameter. |
|
:type parameter_name: str |
|
:param value: The value to be validated. |
|
        :type value: Any
|
:return: True if the value meets the constraints, otherwise False. |
|
:rtype: bool |
|
""" |
|
constraints = self.parameters[parameter_class][parameter_name].get('constraints', {}) |
|
|
|
if 'options' in constraints and value not in constraints['options']: |
|
return False |
|
if 'min' in constraints and value < constraints['min']: |
|
return False |
|
if 'max' in constraints and value > constraints['max']: |
|
return False |
|
return True |
|
|
|
|
|
    def validate(self, parameter_class: str, parameter_name: str, value: Any):
|
""" |
|
Validate both the type and value of a parameter. |
|
|
|
:param parameter_class: The class/category of the parameter. |
|
:type parameter_class: str |
|
:param parameter_name: The name of the parameter. |
|
:type parameter_name: str |
|
:param value: The value to be validated. |
|
        :type value: Any
|
:raises TypeError: If the value is not of the expected type. |
|
:raises ValueError: If the value does not meet the parameter's constraints. |
|
""" |
|
if not self.validate_type(parameter_class, parameter_name, value): |
|
raise TypeError(f"Invalid type for {parameter_name} for parameter class '{parameter_class}'. Expected {self.parameters[parameter_class][parameter_name]['type']}.") |
|
|
|
if not self.validate_value(parameter_class, parameter_name, value): |
|
raise ValueError(f"Invalid value for {parameter_name} for parameter class '{parameter_class}'. Constraints: {self.parameters[parameter_class][parameter_name].get('constraints', {})}.") |
|
|
|
def describe(self, parameter_class: str, parameter_name: str) -> str: |
|
""" |
|
Retrieve the description of a parameter. |
|
|
|
:param parameter_class: The class/category of the parameter. |
|
:type parameter_class: str |
|
:param parameter_name: The name of the parameter. |
|
:type parameter_name: str |
|
:return: Description of the parameter. |
|
:rtype: str |
|
""" |
|
return self.parameters[parameter_class][parameter_name]['description'] |
|
|
|
|
|
|
|
class SeqConfig(BaseConfig): |
|
"""Class to manage and validate sequence processing configurations.""" |
|
|
|
def __init__(self): |
|
super().__init__() |
|
self.default_seq_config_file = self._get_default_sequence_processing_config_file() |
|
with open(self.default_seq_config_file, 'r') as file: |
|
self.parameters = yaml.safe_load(file) |
|
|
|
|
|
        # The maximum allowed tokenization shift is tied to the default k-mer
        # size: the stride between consecutive k-mers must stay below k.
        self.parameters['tokenization']['shift']['constraints']['max'] = self.parameters['tokenization']['kmer']['default'] - 1
|
|
|
|
|
self.get_and_set_segmentation_parameters() |
|
self.get_and_set_tokenization_parameters() |
|
self.get_and_set_computational_parameters() |
|
|
|
def _get_default_sequence_processing_config_file(self) -> str: |
|
""" |
|
Retrieve the default sequence processing configuration file. |
|
|
|
:return: Path to the configuration file. |
|
:rtype: str |
|
""" |
|
        current_path = pathlib.Path(__file__).parent
        self.current_path = current_path
        prokbert_seq_config_file = join(current_path, 'configs', 'sequence_processing.yaml')

        # The SEQ_CONFIG_FILE environment variable overrides the bundled default.
        try:
            prokbert_seq_config_file = os.environ['SEQ_CONFIG_FILE']
        except KeyError:
            print(f"SEQ_CONFIG_FILE environment variable has not been set. Using default value: {prokbert_seq_config_file}")
        return prokbert_seq_config_file
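
    # For reference, the loader above expects the YAML file to map parameter
    # classes to parameter definitions. An illustrative (hypothetical) excerpt
    # matching the keys the code reads ('type', 'default', 'description',
    # 'constraints'):
    #
    #   tokenization:
    #     kmer:
    #       type: integer
    #       default: 6
    #       description: k-mer size used for tokenization
    #       constraints:
    #         min: 2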
|
|
|
|
|
def get_and_set_segmentation_parameters(self, parameters: dict = {}) -> dict: |
|
""" |
|
Retrieve and validate the provided parameters for segmentation. |
|
|
|
:param parameters: A dictionary of parameters to be validated. |
|
:type parameters: dict |
|
:return: A dictionary of validated segmentation parameters. |
|
:rtype: dict |
|
:raises ValueError: If an invalid segmentation parameter is provided. |
|
""" |
|
segmentation_params = {k: self.get_parameter('segmentation', k) for k in self.parameters['segmentation']} |
|
|
|
for param, param_value in parameters.items(): |
|
if param not in segmentation_params: |
|
raise ValueError(f"The provided {param} is an INVALID segmentation parameter! The valid parameters are: {list(segmentation_params.keys())}") |
|
self.validate('segmentation', param, param_value) |
|
segmentation_params[param] = param_value |
|
self.segmentation_params = segmentation_params |
|
|
|
|
|
return segmentation_params |
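
    # Override sketch: defaults are loaded first, then caller-supplied values
    # are validated and merged (the parameter name 'min_length' is hypothetical):
    #
    #   seq_cfg = SeqConfig()
    #   params = seq_cfg.get_and_set_segmentation_parameters({'min_length': 80})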
|
|
|
|
|
    def get_and_set_tokenization_parameters(self, parameters: dict = {}) -> dict:
        """
        Retrieve and validate the provided parameters for tokenization.

        :param parameters: A dictionary of parameters to be validated.
        :type parameters: dict
        :return: A dictionary of validated tokenization parameters.
        :rtype: dict
        :raises ValueError: If an invalid tokenization parameter is provided.
        """
        tokenization_params = {k: self.get_parameter('tokenization', k) for k in self.parameters['tokenization']}
|
for param, param_value in parameters.items(): |
|
if param not in tokenization_params: |
|
raise ValueError(f"The provided {param} is an INVALID tokenization parameter! The valid parameters are: {list(tokenization_params.keys())}") |
|
self.validate('tokenization', param, param_value) |
|
tokenization_params[param] = param_value |
|
|
|
|
|
        # Resolve the vocabulary file: 'auto' selects the bundled vocabulary
        # that matches the configured k-mer size.
        vocabfile = tokenization_params['vocabfile']
        act_kmer = tokenization_params['kmer']
        if vocabfile == 'auto':
            vocabfile_path = join(self.current_path, 'data/prokbert_vocabs/', f'prokbert-base-dna{act_kmer}', 'vocab.txt')
            tokenization_params['vocabfile'] = vocabfile_path
        else:
            vocabfile_path = vocabfile
        with open(vocabfile_path) as vocabfile_in:
            vocabmap = {line.strip(): i for i, line in enumerate(vocabfile_in)}
        tokenization_params['vocabmap'] = vocabmap
|
|
|
|
|
self.tokenization_params = tokenization_params |
|
return tokenization_params |
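
    # With the default vocabfile='auto' and, say, kmer=6, the vocabulary is
    # resolved to <package_dir>/data/prokbert_vocabs/prokbert-base-dna6/vocab.txt,
    # and tokenization_params['vocabmap'] maps each vocabulary entry to its
    # line index in that file.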
|
|
|
def get_and_set_computational_parameters(self, parameters: dict = {}) -> dict: |
|
""" Reading and validating the computational paramters |
|
""" |
|
|
|
computational_params = {k: self.get_parameter('computation', k) for k in self.parameters['computation']} |
|
core_count = cpu_count() |
|
|
|
        # A value of -1 means "use all available CPU cores".
        if computational_params['cpu_cores_for_segmentation'] == -1:
            computational_params['cpu_cores_for_segmentation'] = core_count

        if computational_params['cpu_cores_for_tokenization'] == -1:
            computational_params['cpu_cores_for_tokenization'] = core_count
|
|
|
|
|
|
|
for param, param_value in parameters.items(): |
|
if param not in computational_params: |
|
raise ValueError(f"The provided {param} is an INVALID computation parameter! The valid parameters are: {list(computational_params.keys())}") |
|
self.validate('computation', param, param_value) |
|
computational_params[param] = param_value |
|
|
|
        np_tokentype = SeqConfig.numpy_dtype_mapping[computational_params['numpy_token_integer_prec_byte']]
|
computational_params['np_tokentype'] = np_tokentype |
|
self.computational_params = computational_params |
|
return computational_params |
|
|
|
|
|
def get_maximum_segment_length_from_token_count_from_params(self): |
|
"""Calculating the maximum length of the segment from the token count """ |
|
max_token_counts = self.tokenization_params['token_limit'] |
|
shift = self.tokenization_params['shift'] |
|
kmer = self.tokenization_params['kmer'] |
|
return self.get_maximum_segment_length_from_token_count(max_token_counts, shift, kmer) |
|
|
|
    def get_maximum_token_count_from_max_length_from_params(self):
        """Calculating the maximum token count from the maximum segment length."""
|
|
|
|
|
max_segment_length = self.tokenization_params['max_segment_length'] |
|
shift = self.tokenization_params['shift'] |
|
kmer = self.tokenization_params['kmer'] |
|
max_token_count = self.get_maximum_token_count_from_max_length(max_segment_length, shift, kmer) |
|
|
|
return max_token_count |
|
|
|
    @staticmethod
    def get_maximum_segment_length_from_token_count(max_token_counts, shift, kmer):
        """Calculates the longest segment that the given token budget can cover."""
        # The offset of 3 presumably reserves room for special tokens
        # (e.g. sequence start/end markers) in the token budget.
        max_segment_length = (max_token_counts - 3) * shift + kmer
        return max_segment_length
|
|
|
    @staticmethod
    def get_maximum_token_count_from_max_length(max_segment_length, shift, kmer):
        """Calculates how many tokens are needed to cover a segment of the given length."""
        max_token_count = int(np.ceil((max_segment_length - kmer) / shift + 3))
        return max_token_count
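
    # Worked example for the two helpers above (values chosen for illustration):
    # with max_token_counts=512, shift=1 and kmer=6, the covered length is
    # (512 - 3) * 1 + 6 = 515; the inverse, ceil((515 - 6) / 1 + 3) = 512,
    # recovers the token count.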
|
|
|
class ProkBERTConfig(BaseConfig): |
|
"""Class to manage and validate pretraining configurations.""" |
|
|
|
    torch_dtype_mapping = {1: torch.uint8,
                           2: torch.int16,
                           4: torch.int32,
                           8: torch.int64}
|
|
|
def __init__(self): |
|
super().__init__() |
|
|
|
self.default_pretrain_config_file = self._get_default_pretrain_config_file() |
|
with open(self.default_pretrain_config_file, 'r') as file: |
|
self.parameters = yaml.safe_load(file) |
|
|
|
|
|
self.data_collator_params = self.get_set_parameters('data_collator') |
|
self.model_params = self.get_set_parameters('model') |
|
self.dataset_params = self.get_set_parameters('dataset') |
|
self.pretraining_params = self.get_set_parameters('pretraining') |
|
|
|
|
|
self.def_seq_config = SeqConfig() |
|
self.segmentation_params = self.def_seq_config.get_and_set_segmentation_parameters(self.parameters['segmentation']) |
|
self.tokenization_params = self.def_seq_config.get_and_set_tokenization_parameters(self.parameters['tokenization']) |
|
self.computation_params = self.def_seq_config.get_and_set_computational_parameters(self.parameters['computation']) |
|
|
|
self.default_torchtype = ProkBERTConfig.torch_dtype_mapping[self.computation_params['numpy_token_integer_prec_byte']] |
|
|
|
def _get_default_pretrain_config_file(self) -> str: |
|
""" |
|
Retrieve the default pretraining configuration file. |
|
|
|
:return: Path to the configuration file. |
|
:rtype: str |
|
""" |
|
current_path = pathlib.Path(__file__).parent |
|
pretrain_config_file = join(current_path, 'configs', 'pretraining.yaml') |
|
|
|
        # The PRETRAIN_CONFIG_FILE environment variable overrides the bundled default.
        try:
            pretrain_config_file = os.environ['PRETRAIN_CONFIG_FILE']
|
except KeyError: |
|
|
|
print(f"PRETRAIN_CONFIG_FILE environment variable has not been set. Using default value: {pretrain_config_file}") |
|
return pretrain_config_file |
|
|
|
def get_set_parameters(self, parameter_class: str, parameters: dict = {}) -> dict: |
|
""" |
|
Retrieve and validate the provided parameters for a given parameter class. |
|
|
|
:param parameter_class: The class/category of the parameter (e.g., 'data_collator'). |
|
:type parameter_class: str |
|
:param parameters: A dictionary of parameters to be validated. |
|
:type parameters: dict |
|
:return: A dictionary of validated parameters. |
|
:rtype: dict |
|
:raises ValueError: If an invalid parameter is provided. |
|
""" |
|
class_params = {k: self.get_parameter(parameter_class, k) for k in self.parameters[parameter_class]} |
|
|
|
|
|
for param, param_value in class_params.items(): |
|
|
|
self.validate(parameter_class, param, param_value) |
|
|
|
|
|
for param, param_value in parameters.items(): |
|
if param not in class_params: |
|
raise ValueError(f"The provided {param} is an INVALID {parameter_class} parameter! The valid parameters are: {list(class_params.keys())}") |
|
self.validate(parameter_class, param, param_value) |
|
class_params[param] = param_value |
|
|
|
return class_params |
|
|
|
def get_and_set_model_parameters(self, parameters: dict = {}) -> dict: |
|
""" Setting the model parameters """ |
|
|
|
self.model_params = self.get_set_parameters('model', parameters) |
|
|
|
return self.model_params |
|
|
|
def get_and_set_dataset_parameters(self, parameters: dict = {}) -> dict: |
|
""" Setting the dataset parameters """ |
|
|
|
self.dataset_params = self.get_set_parameters('dataset', parameters) |
|
|
|
return self.dataset_params |
|
|
|
def get_and_set_pretraining_parameters(self, parameters: dict = {}) -> dict: |
|
""" Setting the model parameters """ |
|
self.pretraining_params = self.get_set_parameters('pretraining', parameters) |
|
|
|
return self.pretraining_params |
|
|
|
|
|
def get_and_set_datacollator_parameters(self, parameters: dict = {}) -> dict: |
|
""" Setting the model parameters """ |
|
self.data_collator_params = self.get_set_parameters('data_collator', parameters) |
|
return self.data_collator_params |
|
|
|
    def get_and_set_segmentation_parameters(self, parameters: dict = {}) -> dict:
        """ Setting the segmentation parameters via the wrapped SeqConfig """
        self.segmentation_params = self.def_seq_config.get_and_set_segmentation_parameters(parameters)
        return self.segmentation_params

    def get_and_set_tokenization_parameters(self, parameters: dict = {}) -> dict:
        """ Setting the tokenization parameters via the wrapped SeqConfig """
        self.tokenization_params = self.def_seq_config.get_and_set_tokenization_parameters(parameters)
        return self.tokenization_params

    def get_and_set_computation_params(self, parameters: dict = {}) -> dict:
        """ Setting the computation parameters via the wrapped SeqConfig """
        self.computation_params = self.def_seq_config.get_and_set_computational_parameters(parameters)
        return self.computation_params
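

if __name__ == '__main__':
    # Minimal smoke-test sketch (assumes the bundled YAML configs and
    # vocabularies ship next to this module; set SEQ_CONFIG_FILE or
    # PRETRAIN_CONFIG_FILE to use other configuration files).
    seq_cfg = SeqConfig()
    print(seq_cfg.get_maximum_segment_length_from_token_count_from_params())
    pretrain_cfg = ProkBERTConfig()
    print(pretrain_cfg.default_torchtype)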
|
|