"""Configuration utilities for sequence processing and ProkBERT pretraining."""
import yaml
import pathlib
from os.path import join
import os
import numpy as np
import torch
from multiprocessing import cpu_count

class BaseConfig:
    """Base class for managing and validating configurations."""

    # Maps an integer precision (in bytes) to the corresponding NumPy dtype.
    numpy_dtype_mapping = {1: np.int8,
                           2: np.int16,
                           4: np.int32,
                           8: np.int64}

    def __init__(self):
        super().__init__()

    def cast_to_expected_type(self, parameter_class: str, parameter_name: str, value: any) -> any:
        """
        Cast the given value to the expected type.

        :param parameter_class: The class/category of the parameter.
        :type parameter_class: str
        :param parameter_name: The name of the parameter.
        :type parameter_name: str
        :param value: The value to be cast.
        :type value: any
        :return: Value cast to the expected type.
        :rtype: any
        :raises ValueError: If casting fails.
        """
        expected_type = self.parameters[parameter_class][parameter_name]['type']

        if expected_type in ["integer", "int"]:
            try:
                return int(value)
            except ValueError:
                raise ValueError(f"Failed to cast value '{value}' to integer for parameter '{parameter_name}' in class '{parameter_class}'.")
        elif expected_type == "float":
            try:
                return float(value)
            except ValueError:
                raise ValueError(f"Failed to cast value '{value}' to float for parameter '{parameter_name}' in class '{parameter_class}'.")
        elif expected_type in ["string", "str"]:
            return str(value)
        elif expected_type in ["boolean", "bool"]:
            if isinstance(value, bool):
                return value
            elif str(value).lower() == "true":
                return True
            elif str(value).lower() == "false":
                return False
            else:
                raise ValueError(f"Failed to cast value '{value}' to boolean for parameter '{parameter_name}' in class '{parameter_class}'.")
        elif expected_type == "type":
            # Returned as-is; the configuration is assumed to provide a valid Python type.
            return value
        elif expected_type == "list":
            if isinstance(value, list):
                return value
            else:
                raise ValueError(f"Failed to validate value '{value}' as a list for parameter '{parameter_name}' in class '{parameter_class}'.")
        elif expected_type == "tuple":
            if isinstance(value, tuple):
                return value
            else:
                raise ValueError(f"Failed to validate value '{value}' as a tuple for parameter '{parameter_name}' in class '{parameter_class}'.")
        elif expected_type == "set":
            if isinstance(value, set):
                return value
            else:
                raise ValueError(f"Failed to validate value '{value}' as a set for parameter '{parameter_name}' in class '{parameter_class}'.")
        elif expected_type == "dict":
            if isinstance(value, dict):
                return value
            else:
                raise ValueError(f"Failed to validate value '{value}' as a dict for parameter '{parameter_name}' in class '{parameter_class}'.")
        else:
            raise ValueError(f"Unknown expected type '{expected_type}' for parameter '{parameter_name}' in class '{parameter_class}'.")



    def get_parameter(self, parameter_class: str, parameter_name: str) -> any:
        """
        Retrieve the default value of a specified parameter.

        :param parameter_class: The class/category of the parameter (e.g., 'segmentation').
        :type parameter_class: str
        :param parameter_name: The name of the parameter.
        :type parameter_name: str
        :return: Default value of the parameter, cast to the expected type.
        :rtype: any
        """
        default_value = self.parameters[parameter_class][parameter_name]['default']
        return self.cast_to_expected_type(parameter_class, parameter_name, default_value)
    

    
    def validate_type(self, parameter_class: str, parameter_name: str, value: any) -> bool:
        """
        Validate the type of a given value against the expected type.

        :param parameter_class: The class/category of the parameter.
        :type parameter_class: str
        :param parameter_name: The name of the parameter.
        :type parameter_name: str
        :param value: The value to be validated.
        :type value: any
        :return: True if the value matches the expected type, otherwise False.
                 Only scalar types (integer, float, string) are checked; other
                 types are accepted as-is.
        :rtype: bool
        """
        expected_type = self.parameters[parameter_class][parameter_name]['type']

        if expected_type in ("integer", "int") and not isinstance(value, int):
            return False
        elif expected_type == "float" and not isinstance(value, float):
            return False
        elif expected_type in ("string", "str") and not isinstance(value, str):
            return False
        else:
            return True
    
    def validate_value(self, parameter_class: str, parameter_name: str, value: any) -> bool:
        """
        Validate the value of a parameter against its constraints.

        :param parameter_class: The class/category of the parameter.
        :type parameter_class: str
        :param parameter_name: The name of the parameter.
        :type parameter_name: str
        :param value: The value to be validated.
        :type value: any
        :return: True if the value meets the constraints, otherwise False.
        :rtype: bool
        """
        constraints = self.parameters[parameter_class][parameter_name].get('constraints', {})
        
        if 'options' in constraints and value not in constraints['options']:
            return False
        if 'min' in constraints and value < constraints['min']:
            return False
        if 'max' in constraints and value > constraints['max']:
            return False
        return True
    

    def validate(self, parameter_class: str, parameter_name: str, value: any):
        """
        Validate both the type and value of a parameter.

        :param parameter_class: The class/category of the parameter.
        :type parameter_class: str
        :param parameter_name: The name of the parameter.
        :type parameter_name: str
        :param value: The value to be validated.
        :type value: any
        :raises TypeError: If the value is not of the expected type.
        :raises ValueError: If the value does not meet the parameter's constraints.
        """
        if not self.validate_type(parameter_class, parameter_name, value):
            raise TypeError(f"Invalid type for {parameter_name} in parameter class '{parameter_class}'. Expected {self.parameters[parameter_class][parameter_name]['type']}.")

        if not self.validate_value(parameter_class, parameter_name, value):
            raise ValueError(f"Invalid value for {parameter_name} in parameter class '{parameter_class}'. Constraints: {self.parameters[parameter_class][parameter_name].get('constraints', {})}.")

    def describe(self, parameter_class: str, parameter_name: str) -> str:
        """
        Retrieve the description of a parameter.

        :param parameter_class: The class/category of the parameter.
        :type parameter_class: str
        :param parameter_name: The name of the parameter.
        :type parameter_name: str
        :return: Description of the parameter.
        :rtype: str
        """
        return self.parameters[parameter_class][parameter_name]['description']



class SeqConfig(BaseConfig):
    """Class to manage and validate sequence processing configurations."""

    def __init__(self):
        super().__init__()
        self.default_seq_config_file = self._get_default_sequence_processing_config_file()
        with open(self.default_seq_config_file, 'r') as file:
            self.parameters = yaml.safe_load(file)

        # Post-processing: the maximum allowed shift is bounded by the default k-mer size.
        self.parameters['tokenization']['shift']['constraints']['max'] = self.parameters['tokenization']['kmer']['default'] - 1
        # TODO: if the k-mer parameter is updated later, this constraint should be recomputed.

        self.get_and_set_segmentation_parameters()
        self.get_and_set_tokenization_parameters()
        self.get_and_set_computational_parameters()

    def _get_default_sequence_processing_config_file(self) -> str:
        """
        Retrieve the default sequence processing configuration file.

        :return: Path to the configuration file.
        :rtype: str
        """
        current_path = pathlib.Path(__file__).parent
        prokbert_seq_config_file = join(current_path, 'configs', 'sequence_processing.yaml')
        self.current_path = current_path

        try:
            # Attempt to read the environment variable
            prokbert_seq_config_file = os.environ['SEQ_CONFIG_FILE']
        except KeyError:
            # Handle the case when the environment variable is not found
            print("SEQ_CONFIG_FILE environment variable has not been set. Using default value: {0}".format(prokbert_seq_config_file))
        return prokbert_seq_config_file
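
    # The configuration file can be overridden without code changes, e.g. from a shell:
    #   export SEQ_CONFIG_FILE=/path/to/custom_sequence_processing.yaml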

    
    def get_and_set_segmentation_parameters(self, parameters: dict = {}) -> dict:
        """
        Retrieve and validate the provided parameters for segmentation.

        :param parameters: A dictionary of parameters to be validated.
        :type parameters: dict
        :return: A dictionary of validated segmentation parameters.
        :rtype: dict
        :raises ValueError: If an invalid segmentation parameter is provided.
        """
        segmentation_params = {k: self.get_parameter('segmentation', k) for k in self.parameters['segmentation']}

        for param, param_value in parameters.items():
            if param not in segmentation_params:
                raise ValueError(f"The provided {param} is an INVALID segmentation parameter! The valid parameters are: {list(segmentation_params.keys())}")
            self.validate('segmentation', param, param_value)
            segmentation_params[param] = param_value
        self.segmentation_params = segmentation_params

        return segmentation_params


    def get_and_set_tokenization_parameters(self, parameters: dict = {}) -> dict:
        """
        Retrieve and validate the provided parameters for tokenization.

        If the k-mer size is changed, dependent parameters (such as the shift)
        must stay consistent, so every provided parameter is validated.

        :param parameters: A dictionary of parameters to be validated.
        :type parameters: dict
        :return: A dictionary of validated tokenization parameters.
        :rtype: dict
        :raises ValueError: If an invalid tokenization parameter is provided.
        """

        tokenization_params = {k: self.get_parameter('tokenization', k) for k in self.parameters['tokenization']}
        for param, param_value in parameters.items():
            if param not in tokenization_params:
                raise ValueError(f"The provided {param} is an INVALID tokenization parameter! The valid parameters are: {list(tokenization_params.keys())}")
            self.validate('tokenization', param, param_value)
            tokenization_params[param] = param_value

        # Load and check the vocabulary file. Each line is expected to hold one
        # token; the line order defines the token IDs.
        vocabfile = tokenization_params['vocabfile']
        act_kmer = tokenization_params['kmer']
        if vocabfile == 'auto':
            vocabfile_path = join(self.current_path, 'data/prokbert_vocabs/', f'prokbert-base-dna{act_kmer}', 'vocab.txt')
            tokenization_params['vocabfile'] = vocabfile_path
        else:
            vocabfile_path = vocabfile
        with open(vocabfile_path) as vocabfile_in:
            vocabmap = {line.strip(): i for i, line in enumerate(vocabfile_in)}
        tokenization_params['vocabmap'] = vocabmap

        self.tokenization_params = tokenization_params
        return tokenization_params

    def get_and_set_computational_parameters(self, parameters: dict = {}) -> dict:
        """Read and validate the computational parameters."""

        computational_params = {k: self.get_parameter('computation', k) for k in self.parameters['computation']}
        core_count = cpu_count()

        # A value of -1 means "use all available CPU cores".
        if computational_params['cpu_cores_for_segmentation'] == -1:
            computational_params['cpu_cores_for_segmentation'] = core_count

        if computational_params['cpu_cores_for_tokenization'] == -1:
            computational_params['cpu_cores_for_tokenization'] = core_count

        for param, param_value in parameters.items():
            if param not in computational_params:
                raise ValueError(f"The provided {param} is an INVALID computation parameter! The valid parameters are: {list(computational_params.keys())}")
            self.validate('computation', param, param_value)
            computational_params[param] = param_value

        np_tokentype = SeqConfig.numpy_dtype_mapping[computational_params['numpy_token_integer_prec_byte']]
        computational_params['np_tokentype'] = np_tokentype
        self.computational_params = computational_params
        return computational_params


    def get_maximum_segment_length_from_token_count_from_params(self):
        """Calculate the maximum segment length from the configured token count."""
        max_token_counts = self.tokenization_params['token_limit']
        shift = self.tokenization_params['shift']
        kmer = self.tokenization_params['kmer']
        return self.get_maximum_segment_length_from_token_count(max_token_counts, shift, kmer)

    def get_maximum_token_count_from_max_length_from_params(self):
        """Calculate the maximum token count from the configured maximum segment length."""
        max_segment_length = self.tokenization_params['max_segment_length']
        shift = self.tokenization_params['shift']
        kmer = self.tokenization_params['kmer']
        max_token_count = self.get_maximum_token_count_from_max_length(max_segment_length, shift, kmer)

        return max_token_count

    @staticmethod
    def get_maximum_segment_length_from_token_count(max_token_counts, shift, kmer):
        """Calculate the longest segment (in bases) that the given token count can cover."""
        # The offset of 3 reserves token positions (presumably for special tokens);
        # each remaining token advances the k-mer window by `shift` bases.
        max_segment_length = (max_token_counts - 3) * shift + kmer
        return max_segment_length

    @staticmethod
    def get_maximum_token_count_from_max_length(max_segment_length, shift, kmer):
        """Calculate how many tokens are needed to cover a segment of the given length."""
        max_token_count = int(np.ceil((max_segment_length - kmer) / shift + 3))
        return max_token_count
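
    # Worked example (illustrative values): with kmer=6, shift=1 and a 512-token limit,
    #   get_maximum_segment_length_from_token_count(512, 1, 6) == (512 - 3) * 1 + 6 == 515
    #   get_maximum_token_count_from_max_length(515, 1, 6) == ceil((515 - 6) / 1 + 3) == 512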

class ProkBERTConfig(BaseConfig):
    """Class to manage and validate pretraining configurations."""

    # Maps an integer precision (in bytes) to the corresponding torch dtype.
    torch_dtype_mapping = {1: torch.uint8,
                           2: torch.int16,
                           4: torch.int32,
                           8: torch.int64}

    def __init__(self):
        super().__init__()

        self.default_pretrain_config_file = self._get_default_pretrain_config_file()
        with open(self.default_pretrain_config_file, 'r') as file:
            self.parameters = yaml.safe_load(file)
            
        # Load and validate each parameter set
        self.data_collator_params = self.get_set_parameters('data_collator')
        self.model_params = self.get_set_parameters('model')
        self.dataset_params = self.get_set_parameters('dataset')
        self.pretraining_params = self.get_set_parameters('pretraining')
        # Load the sequence-processing (sequtils) parameter groups as well

        self.def_seq_config = SeqConfig()
        self.segmentation_params = self.def_seq_config.get_and_set_segmentation_parameters(self.parameters['segmentation'])
        self.tokenization_params = self.def_seq_config.get_and_set_tokenization_parameters(self.parameters['tokenization'])
        self.computation_params = self.def_seq_config.get_and_set_computational_parameters(self.parameters['computation'])

        self.default_torchtype = ProkBERTConfig.torch_dtype_mapping[self.computation_params['numpy_token_integer_prec_byte']]

    def _get_default_pretrain_config_file(self) -> str:
        """
        Retrieve the default pretraining configuration file.

        :return: Path to the configuration file.
        :rtype: str
        """
        current_path = pathlib.Path(__file__).parent
        pretrain_config_file = join(current_path, 'configs', 'pretraining.yaml')

        try:
            # Attempt to read the environment variable
            pretrain_config_file = os.environ['PRETRAIN_CONFIG_FILE']
        except KeyError:
            # Handle the case when the environment variable is not found
            print(f"PRETRAIN_CONFIG_FILE environment variable has not been set. Using default value: {pretrain_config_file}")
        return pretrain_config_file
    
    def get_set_parameters(self, parameter_class: str, parameters: dict = {}) -> dict:
        """
        Retrieve and validate the provided parameters for a given parameter class.

        :param parameter_class: The class/category of the parameter (e.g., 'data_collator').
        :type parameter_class: str
        :param parameters: A dictionary of parameters to be validated.
        :type parameters: dict
        :return: A dictionary of validated parameters.
        :rtype: dict
        :raises ValueError: If an invalid parameter is provided.
        """
        class_params = {k: self.get_parameter(parameter_class, k) for k in self.parameters[parameter_class]}

        # First, validate the default class parameters themselves.
        for param, param_value in class_params.items():
            self.validate(parameter_class, param, param_value)

        for param, param_value in parameters.items():
            if param not in class_params:
                raise ValueError(f"The provided {param} is an INVALID {parameter_class} parameter! The valid parameters are: {list(class_params.keys())}")
            self.validate(parameter_class, param, param_value)
            class_params[param] = param_value

        return class_params
    
    def get_and_set_model_parameters(self, parameters: dict = {}) -> dict:
        """Set the model parameters."""
        self.model_params = self.get_set_parameters('model', parameters)
        return self.model_params

    def get_and_set_dataset_parameters(self, parameters: dict = {}) -> dict:
        """Set the dataset parameters."""
        self.dataset_params = self.get_set_parameters('dataset', parameters)
        return self.dataset_params

    def get_and_set_pretraining_parameters(self, parameters: dict = {}) -> dict:
        """Set the pretraining parameters."""
        self.pretraining_params = self.get_set_parameters('pretraining', parameters)
        return self.pretraining_params

    def get_and_set_datacollator_parameters(self, parameters: dict = {}) -> dict:
        """Set the data collator parameters."""
        self.data_collator_params = self.get_set_parameters('data_collator', parameters)
        return self.data_collator_params

    def get_and_set_segmentation_parameters(self, parameters: dict = {}) -> dict:
        """Set the segmentation parameters via the underlying SeqConfig."""
        self.segmentation_params = self.def_seq_config.get_and_set_segmentation_parameters(parameters)
        return self.segmentation_params

    def get_and_set_tokenization_parameters(self, parameters: dict = {}) -> dict:
        """Set the tokenization parameters via the underlying SeqConfig."""
        self.tokenization_params = self.def_seq_config.get_and_set_tokenization_parameters(parameters)
        return self.tokenization_params

    def get_and_set_computation_params(self, parameters: dict = {}) -> dict:
        """Set the computational parameters via the underlying SeqConfig."""
        self.computation_params = self.def_seq_config.get_and_set_computational_parameters(parameters)
        return self.computation_params
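
# Minimal usage sketch (illustrative only; assumes the packaged default YAML
# configs and vocabulary files are available, and that the chosen values
# satisfy the configured constraints).
if __name__ == '__main__':
    seq_config = SeqConfig()
    tok_params = seq_config.get_and_set_tokenization_parameters({'kmer': 6, 'shift': 1})
    print('Vocabulary size:', len(tok_params['vocabmap']))
    print('Max segment length:',
          seq_config.get_maximum_segment_length_from_token_count_from_params())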