|
import statistics |
|
import sys |
|
from dataclasses import dataclass |
|
from typing import List, Union |
|
|
|
import torch |
|
from numpy.typing import NDArray |
|
|
|
from .type_aliases import DEVICE_TYPE, ENCODER_DEVICE_TYPE, NumSentencesType, EmbeddingSlicesType |
|
|
|
|
|
def get_gpu(gpu: DEVICE_TYPE) -> ENCODER_DEVICE_TYPE:
    """
    Resolve a user-facing device specification into concrete device identifier(s).

    Integer results are CUDA device indices (0 means CUDA device 0); the string
    result is always "cpu".

    Args:
        gpu (Union[bool, str, int, List[Union[bool, str, int]]]): Device specification:
            - bool: True resolves to 0 when CUDA is available, otherwise "cpu".
              False always resolves to "cpu".
            - str: One of "cpu", "gpu" or "cuda" (case-insensitive). "gpu"/"cuda"
              resolve to 0 when CUDA is available, otherwise "cpu".
            - int: A GPU index in the range [0, torch.cuda.device_count()).
              Resolves to the index itself when CUDA is available.
            - list: Every element is resolved with the rules above. Repeated
              integer devices are kept only once (first occurrence wins);
              "cpu" entries are never de-duplicated.

    Returns:
        Union[str, int, List[Union[str, int]]]:
            - For a non-list input: the resolved device ("cpu" or an int index).
            - For a list input: the list of resolved devices, except that a list
              resolving to exactly one device is unwrapped to that single device.
              An empty input list yields an empty list.

    Raises:
        ValueError: If a string is not one of ["cpu", "gpu", "cuda"].
        ValueError: If an integer is negative or not smaller than the number of
            devices reported by torch.cuda.device_count(). When that count is 0,
            every integer index is rejected.
        ValueError: If an item is not a bool, str or int.
    """
    gpu_available = torch.cuda.is_available()
    gpu_count = torch.cuda.device_count()
    correct_strs = ["cpu", "gpu", "cuda"]

    def _get_single_device(gpu_item):
        """Resolve one bool/str/int specification to "cpu" or a CUDA index."""
        # bool has to be tested before int: bool is a subclass of int, so
        # True/False would otherwise be treated as the indices 1/0.
        if isinstance(gpu_item, bool):
            return 0 if gpu_item and gpu_available else "cpu"

        if isinstance(gpu_item, str):
            if gpu_item.lower() not in correct_strs:
                raise ValueError(f"Wrong gpu type: {gpu_item}. Should be one of {correct_strs}")
            return 0 if (gpu_item.lower() != "cpu") and gpu_available else "cpu"

        if isinstance(gpu_item, int):
            # Reject negative indices as well as indices past the last device;
            # a negative value is never a valid CUDA device index.
            if not 0 <= gpu_item < gpu_count:
                raise ValueError(
                    f"There are {gpu_count} GPUs available. Provide a valid GPU index. You provided: {gpu_item}"
                )
            return gpu_item if gpu_available else "cpu"

        raise ValueError(f"Invalid gpu type: {type(gpu_item)}. Must be bool, str, or int.")

    if not isinstance(gpu, list):
        return _get_single_device(gpu)

    seen_indices = set()
    result = []
    for item in gpu:
        device = _get_single_device(item)
        if isinstance(device, int):
            # The same CUDA device requested twice is only listed once.
            if device in seen_indices:
                continue
            seen_indices.add(device)
        result.append(device)

    # A list that resolves to a single device is returned as that bare device.
    return result[0] if len(result) == 1 else result
|
|
|
|
|
def slice_embeddings(embeddings: NDArray, num_sentences: NumSentencesType) -> EmbeddingSlicesType:
    """
    Split `embeddings` into consecutive chunks whose lengths are given by `num_sentences`.

    Chunks are taken front to back: each chunk starts where the previous one
    ended, and the running offset carries over from one nested group to the next.

    Args:
        - embeddings (np.ndarray): The embeddings to split; only slicing
          (`embeddings[a:b]`) is applied to it.
        - num_sentences (Union[List[int], List[List[int]]]):
            - List[int]: length of each chunk; the result is a flat list of chunks.
            - List[List[int]]: groups of chunk lengths; the result mirrors the
              grouping, i.e. one list of chunks per inner list.

    Returns:
        - List[np.ndarray] for a flat `num_sentences`, or List[List[np.ndarray]]
          for a nested one. An empty list yields an empty list.

    Raises:
        - TypeError: If `num_sentences` is neither a list of ints nor a list of
          lists of ints.

    Example Usage:

    ```python
    embeddings = np.random.rand(10, 5)

    slice_embeddings(embeddings, [3, 2, 5])
    # -> [embeddings[:3], embeddings[3:5], embeddings[5:10]]

    slice_embeddings(embeddings, [[2, 1], [3, 4]])
    # -> [[embeddings[:2], embeddings[2:3]], [embeddings[3:6], embeddings[6:10]]]

    slice_embeddings(embeddings, "invalid")  # Raises a TypeError
    ```
    """

    def _is_int_list(candidate) -> bool:
        """Return True when `candidate` is a list containing only ints."""
        return isinstance(candidate, list) and all(isinstance(entry, int) for entry in candidate)

    def _take_chunks(offset: int, sizes: List[int]):
        """
        Cut one chunk per entry of `sizes`, starting at `offset`.

        Returns the list of chunks together with the offset just past the last
        chunk, so that a following call can continue from there.
        """
        chunks = []
        for size in sizes:
            stop = offset + size
            chunks.append(embeddings[offset:stop])
            offset = stop
        return chunks, offset

    # Flat specification: a single run of chunks starting at the beginning.
    if _is_int_list(num_sentences):
        flat_chunks, _ = _take_chunks(0, num_sentences)
        return flat_chunks

    # Nested specification: one run of chunks per group, sharing one cursor.
    if isinstance(num_sentences, list) and all(_is_int_list(group) for group in num_sentences):
        grouped_chunks = []
        cursor = 0
        for group_sizes in num_sentences:
            group_chunks, cursor = _take_chunks(cursor, group_sizes)
            grouped_chunks.append(group_chunks)
        return grouped_chunks

    raise TypeError(f"Incorrect Type for {num_sentences=}")
|
|
|
|
|
def is_nested_list_of_type(lst_obj, element_type, depth: int) -> bool:
    """
    Tell whether `lst_obj` is a list nested exactly `depth` levels deep whose
    innermost values are all instances of `element_type`.

    Args:
        - lst_obj: The object to inspect (a list, or a bare value when depth is 0).
        - element_type: Type (or tuple of types) accepted by `isinstance` that
          every innermost value must match.
        - depth (int): Number of list levels expected around the values.
          Must be non-negative.

    Returns:
        - bool: True when the structure matches, False otherwise. An empty list
          matches at any positive depth because it has no elements to violate
          the check.

    Raises:
        - ValueError: If depth is negative.

    Example:
    ```python
    is_nested_list_of_type("test", str, 0)                    # True
    is_nested_list_of_type([1, 2, 3], str, 0)                 # False
    is_nested_list_of_type(["apple", "banana"], str, 1)       # True
    is_nested_list_of_type([[1, 2], [3, 4]], int, 2)          # True
    is_nested_list_of_type([[1, 2], ["a", "b"]], int, 2)      # False
    is_nested_list_of_type([[[1], [2]], [[3], [4]]], int, 3)  # True
    ```
    """
    # Depth 0: we have reached the level where the values themselves live.
    if depth == 0:
        return isinstance(lst_obj, element_type)

    if not depth > 0:
        raise ValueError("Depth can't be negative")

    # At every remaining level the object itself has to be a list ...
    if not isinstance(lst_obj, list):
        return False

    # ... and each of its members has to satisfy the check one level down.
    for member in lst_obj:
        if not is_nested_list_of_type(member, element_type, depth - 1):
            return False
    return True
|
|
|
|
|
def flatten_list(nested_list: list) -> list:
    """
    Flatten a nested list of any depth into a single flat list.

    Only `list` instances are descended into; every other value (including
    tuples and strings) is copied to the output unchanged. Elements keep their
    original left-to-right order.

    The traversal uses an explicit stack instead of recursion, so the nesting
    depth is not bounded by the interpreter's recursion limit.

    Parameters:
        nested_list (list): The nested list to flatten.

    Returns:
        list: A new flat list containing all non-list elements of `nested_list`.

    Raises:
        ValueError: If a list contains itself (directly or indirectly), since
            such a structure has no finite flattening.
    """
    flat_list = []

    # ids of the lists currently being traversed, used to detect cycles.
    open_ids = {id(nested_list)}
    # Each frame holds a list and a partially consumed iterator over it, so
    # that iteration resumes where it left off after a sub-list is finished.
    stack = [(nested_list, iter(nested_list))]

    while stack:
        current, pending = stack[-1]
        for item in pending:
            if isinstance(item, list):
                if id(item) in open_ids:
                    raise ValueError("Cannot flatten a self-referential list")
                open_ids.add(id(item))
                stack.append((item, iter(item)))
                break  # descend into the sub-list before continuing here
            flat_list.append(item)
        else:
            # Iterator exhausted: this list is done, return to its parent.
            open_ids.discard(id(current))
            stack.pop()

    return flat_list
|
|
|
|
|
def compute_f1(p: float, r: float, eps=sys.float_info.epsilon) -> float:
    """
    Compute the F1 value (harmonic mean) of a precision/recall pair.

    `eps` is added to the denominator so that p == r == 0 yields 0.0 instead
    of raising ZeroDivisionError.

    :param p: Precision Value
    :param r: Recall Value
    :param eps: Epsilon Value added to the denominator (defaults to the machine epsilon)
    :return: 2 * p * r / (p + r + eps)
    """
    numerator = 2 * p * r
    denominator = p + r + eps
    return numerator / denominator
|
|
|
|
|
@dataclass
class Scores:
    """
    Container for evaluation scores: a precision, per-reference recalls, and
    the F1 value derived from them.

    Attributes:
        - precision (float): The precision score for the evaluation.
        - recall (List[float]): Recall scores, one per reference.
        - f1 (float): F1 score computed from `precision` and the mean of
          `recall`. It is set in `__post_init__` and is not a constructor
          argument or a dataclass field.
    """

    precision: float
    recall: List[float]

    def __post_init__(self):
        # Collapse the per-reference recalls to their mean before combining
        # them with the precision.
        mean_recall = statistics.fmean(self.recall)
        self.f1: float = compute_f1(self.precision, mean_recall)
|
|