import statistics
import sys
from dataclasses import dataclass
from typing import List, Union

import torch
from numpy.typing import NDArray

from type_aliases import DEVICE_TYPE, ENCODER_DEVICE_TYPE, NumSentencesType, EmbeddingSlicesType


def get_gpu(gpu: DEVICE_TYPE) -> ENCODER_DEVICE_TYPE:
    """
    Determine the correct GPU device based on the provided input.
    In the following, output 0 means CUDA device 0.

    Args:
        gpu (Union[bool, str, int, List[Union[str, int]]]): Input specifying the GPU device(s):
            - bool: If True, returns 0 if CUDA is available, otherwise returns "cpu".
            - str: Can be "cpu", "gpu", or "cuda" (case-insensitive). Returns 0 if CUDA is available
              and the input is not "cpu", otherwise returns "cpu".
            - int: Must be a valid GPU index, i.e. 0 <= index < torch.cuda.device_count().
              Returns the index if CUDA is available, otherwise returns "cpu".
            - List[Union[str, int]]: List containing combinations of the str/int. Processes each
              element and returns a list of corresponding results. Duplicate GPU indices are
              dropped (the first occurrence is kept); "cpu" entries are kept as they are.

    Returns:
        Union[str, int, List[Union[str, int]]]: Depending on the input type:
            - str: Returns "cpu" if no GPU is available or the input is "cpu".
            - int: Returns the GPU index if valid and CUDA is available.
            - List[Union[str, int]]: Returns a list of strings and/or integers based on the input list.

    Raises:
        ValueError: If the input gpu type is not recognized (not bool, str, int, or a list of those).
        ValueError: If a string input is not one of ["cpu", "gpu", "cuda"].
        ValueError: If an integer input is negative or not smaller than torch.cuda.device_count().
            In particular, every integer input raises when device_count() is 0.

    Notes:
        - This function checks CUDA availability using torch.cuda.is_available() and counts
          available GPUs using torch.cuda.device_count().
        - Case insensitivity is maintained for string inputs ("cpu", "gpu", "cuda").
        - The function ensures robust error handling for invalid input types or out-of-range indices.
    """
    # Query CUDA once so every element of a list input is resolved against the same snapshot.
    gpu_available = torch.cuda.is_available()
    gpu_count = torch.cuda.device_count()
    correct_strs = ["cpu", "gpu", "cuda"]

    def _get_single_device(gpu_item):
        # bool must be tested before int: bool is a subclass of int in Python.
        if isinstance(gpu_item, bool):
            return 0 if gpu_item and gpu_available else "cpu"
        elif isinstance(gpu_item, str):
            name = gpu_item.lower()
            if name not in correct_strs:
                raise ValueError(f"Wrong gpu type: {gpu_item}. Should be one of {correct_strs}")
            return 0 if name != "cpu" and gpu_available else "cpu"
        elif isinstance(gpu_item, int):
            # Negative values are rejected as well: they are never valid CUDA device ordinals.
            if not 0 <= gpu_item < gpu_count:
                raise ValueError(
                    f"There are {gpu_count} GPUs available. Provide a valid GPU index. You provided: {gpu_item}"
                )
            return gpu_item if gpu_available else "cpu"
        else:
            raise ValueError(f"Invalid gpu type: {type(gpu_item)}. Must be bool, str, or int.")

    if isinstance(gpu, list):
        seen_indices = set()
        result = []
        for item in gpu:
            device = _get_single_device(item)
            if isinstance(device, int):
                # Keep only the first occurrence of each GPU index.
                if device in seen_indices:
                    continue
                seen_indices.add(device)
            result.append(device)
        return result
    return _get_single_device(gpu)


def slice_embeddings(embeddings: NDArray, num_sentences: NumSentencesType) -> EmbeddingSlicesType:
    """
    Split stacked embeddings into consecutive slices along the first axis.

    Args:
        embeddings: Array whose leading axis is sliced in order.
        num_sentences: Either a list of ints (one slice per int, of that length) or a list of
            lists of ints (one list of slices per inner list). In the nested case the running
            offset carries over between inner lists; it is not reset to 0.

    Returns:
        A list of slices for a flat ``num_sentences``, or a list of lists of slices for a nested one.

    Raises:
        TypeError: If ``num_sentences`` is neither a list of ints nor a list of lists of ints.

    Note:
        The counts are not checked against ``len(embeddings)``. NOTE(review): counts that run past
        the end of the array yield short or empty slices silently; confirm callers guarantee a match.
    """

    def _slice_embeddings(s_idx: int, n_sentences: List[int]):
        # Returns the slices plus the offset just past the last slice, so calls can be chained.
        _result = []
        for count in n_sentences:
            _result.append(embeddings[s_idx:s_idx + count])
            s_idx += count
        return _result, s_idx

    if is_nested_list_of_type(num_sentences, int, 1):
        result, _ = _slice_embeddings(0, num_sentences)
        return result
    elif is_nested_list_of_type(num_sentences, int, 2):
        nested_result = []
        start_idx = 0
        for nested_num_sentences in num_sentences:
            embedding_slice, start_idx = _slice_embeddings(start_idx, nested_num_sentences)
            nested_result.append(embedding_slice)
        return nested_result
    else:
        raise TypeError(f"Incorrect Type for {num_sentences=}")


def is_nested_list_of_type(lst_obj, element_type, depth: int) -> bool:
    """
    Check whether ``lst_obj`` is a list nested exactly ``depth`` levels deep with leaves of ``element_type``.

    ``depth == 0`` checks ``lst_obj`` itself against ``element_type``; ``depth == 1`` means a list of
    ``element_type``; ``depth == 2`` a list of such lists, and so on. Empty lists satisfy any
    depth >= 1. Leaves are checked with ``isinstance``, so e.g. bools count as ints.

    Raises:
        ValueError: If ``depth`` is negative.
    """
    if depth == 0:
        return isinstance(lst_obj, element_type)
    elif depth > 0:
        return isinstance(lst_obj, list) and all(
            is_nested_list_of_type(item, element_type, depth - 1) for item in lst_obj
        )
    else:
        raise ValueError("Depth can't be negative")


def flatten_list(nested_list: list) -> list:
    """
    Recursively flattens a nested list of any depth.

    Only ``list`` instances are descended into; any other item (including tuples) is kept as-is.

    Parameters:
        nested_list (list): The nested list to flatten.

    Returns:
        list: A flat list containing all the elements of the nested list, in depth-first order.
    """
    flat_list = []
    for item in nested_list:
        if isinstance(item, list):
            flat_list.extend(flatten_list(item))
        else:
            flat_list.append(item)
    return flat_list


def compute_f1(p: float, r: float, eps=sys.float_info.epsilon) -> float:
    """
    Computes F1 value
    :param p: Precision Value
    :param r: Recall Value
    :param eps: Epsilon Value added to the denominator so that p == r == 0 yields 0 instead of
        dividing by zero
    :return: The harmonic mean 2 * p * r / (p + r + eps)
    """
    f1 = 2 * p * r / (p + r + eps)
    return f1


@dataclass
class Scores:
    """
    Precision and per-item recall values with a derived F1 score.

    ``f1`` is computed once in ``__post_init__`` from ``precision`` and the arithmetic mean of
    ``recall``. It is a plain instance attribute, not a dataclass field, so it does not appear in
    the generated ``__init__``, ``__repr__`` or ``__eq__``, and it is not recomputed if
    ``precision`` or ``recall`` are modified afterwards.
    """

    precision: float
    recall: List[float]

    def __post_init__(self):
        # statistics.fmean raises statistics.StatisticsError when ``recall`` is empty.
        self.f1: float = compute_f1(self.precision, statistics.fmean(self.recall))