nbansal commited on
Commit
bcb2272
1 Parent(s): a249916

Add doc strings

Browse files
Files changed (1) hide show
  1. utils.py +74 -0
utils.py CHANGED
@@ -79,7 +79,50 @@ def get_gpu(gpu: DEVICE_TYPE) -> ENCODER_DEVICE_TYPE:
79
 
80
 
81
  def slice_embeddings(embeddings: NDArray, num_sentences: NumSentencesType) -> EmbeddingSlicesType:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  def _slice_embeddings(s_idx: int, n_sentences: List[int]):
 
 
 
 
 
 
 
 
 
 
83
  _result = []
84
  for count in n_sentences:
85
  _result.append(embeddings[s_idx:s_idx + count])
@@ -107,6 +150,37 @@ def slice_embeddings(embeddings: NDArray, num_sentences: NumSentencesType) -> Em
107
 
108
 
109
  def is_nested_list_of_type(lst_obj, element_type, depth: int) -> bool:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  if depth == 0:
111
  return isinstance(lst_obj, element_type)
112
  elif depth > 0:
 
79
 
80
 
81
  def slice_embeddings(embeddings: NDArray, num_sentences: NumSentencesType) -> EmbeddingSlicesType:
82
+ """
83
+ Slice embeddings into segments based on the provided number of sentences per segment.
84
+
85
+ Args:
86
+ - embeddings (np.ndarray): The array of embeddings to be sliced.
87
+ - num_sentences (Union[List[int], List[List[int]]]):
88
+ - If a list of integers: Specifies the number of embeddings to take in each slice.
89
+ - If a list of lists of integers: Specifies multiple nested levels of slicing.
90
+
91
+ Returns:
92
+ - List[np.ndarray]: A list of numpy arrays where each array represents a slice of embeddings.
93
+
94
+ Raises:
95
+ - TypeError: If `num_sentences` is not of type List[int] or List[List[int]].
96
+
97
+ Example Usage:
98
+
99
+ ```python
100
+ embeddings = np.random.rand(10, 5)
101
+ num_sentences = [3, 2, 5]
102
+ result = slice_embeddings(embeddings, num_sentences)
103
+ # `result` will be a list of numpy arrays:
104
+ # [embeddings[:3], embeddings[3:5], embeddings[5:]]
105
+
106
+ num_sentences_nested = [[2, 1], [3, 4]]
107
+ result_nested = slice_embeddings(embeddings, num_sentences_nested)
108
+ # `result_nested` will be a nested list of numpy arrays:
109
+ # [[embeddings[:2], embeddings[2:3]], [embeddings[3:6], embeddings[6:]]]
110
+
111
+ slice_embeddings(embeddings, "invalid") # Raises a TypeError
112
+ ```
113
+ """
114
+
115
  def _slice_embeddings(s_idx: int, n_sentences: List[int]):
116
+ """
117
+ Helper function to slice embeddings starting from index `s_idx`.
118
+
119
+ Args:
120
+ - s_idx (int): Starting index for slicing.
121
+ - n_sentences (List[int]): List specifying number of sentences in each slice.
122
+
123
+ Returns:
124
+ - Tuple[List[np.ndarray], int]: A tuple containing a list of sliced embeddings and the next starting index.
125
+ """
126
  _result = []
127
  for count in n_sentences:
128
  _result.append(embeddings[s_idx:s_idx + count])
 
150
 
151
 
152
  def is_nested_list_of_type(lst_obj, element_type, depth: int) -> bool:
153
+ """
154
+ Check if the given object is a nested list of a specific type up to a specified depth.
155
+
156
+ Args:
157
+ - lst_obj: The object to check, expected to be a list or a single element.
158
+ - element_type: The type that each element in the nested list should match.
159
+ - depth (int): The depth of nesting to check. Must be non-negative.
160
+
161
+ Returns:
162
+ - bool: True if lst_obj is a nested list of the specified type up to the given depth, False otherwise.
163
+
164
+ Raises:
165
+ - ValueError: If depth is negative.
166
+
167
+ Example:
168
+ ```python
169
+ # Test cases
170
+ is_nested_list_of_type("test", str, 0) # Returns True
171
+ is_nested_list_of_type([1, 2, 3], str, 0) # Returns False
172
+ is_nested_list_of_type(["apple", "banana"], str, 1) # Returns True
173
+ is_nested_list_of_type([[1, 2], [3, 4]], int, 2) # Returns True
174
+ is_nested_list_of_type([[1, 2], ["a", "b"]], int, 2) # Returns False
175
+ is_nested_list_of_type([[[1], [2]], [[3], [4]]], int, 3) # Returns True
176
+ ```
177
+
178
+ Explanation:
179
+ - The function checks if `lst_obj` is a nested list of elements of type `element_type` up to `depth` levels deep.
180
+ - If `depth` is 0, it checks if `lst_obj` itself is of type `element_type`.
181
+ - If `depth` is greater than 0, it recursively checks each level of nesting to ensure all elements match `element_type`.
182
+ - Raises a `ValueError` if `depth` is negative, as depth must be a non-negative integer.
183
+ """
184
  if depth == 0:
185
  return isinstance(lst_obj, element_type)
186
  elif depth > 0: