File size: 770 Bytes
14a3421
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import numpy as np


def weighted_random_sample(items: np.ndarray, weights: np.ndarray, n: int) -> np.ndarray:
    """
    Draw ``n`` distinct entries of ``items`` at random, weighted by ``weights``.

    Behaves like repeated ``np.random.choice(..., p=weights)`` calls, except
    that an entry is removed from the pool as soon as it is drawn and the
    remaining weights are renormalised, so no position of ``items`` can appear
    twice in the result.

    Args:
        items (np.ndarray): Pool to sample from; entries are taken along the
            first axis. Any array-like is accepted.
        weights (np.ndarray): One non-negative, finite weight per entry of
            ``items``. They do not have to sum to 1; only their ratios matter.
            Entries with weight 0 are never drawn.
        n (int): Number of entries to draw. Values <= 0 give an empty result.

    Returns:
        np.ndarray: The ``n`` drawn entries, in the order they were drawn.

    Raises:
        ValueError: If ``weights`` does not hold exactly one value per item,
            contains negative or non-finite values, or if ``n`` is larger than
            the number of entries with a non-zero weight.
    """
    items = np.asarray(items)
    if n <= 0:
        # Nothing to draw: return an empty slice that keeps dtype/trailing shape.
        return items[:0]

    weights = np.asarray(weights, dtype=float)
    indices = np.arange(len(items))

    if weights.shape != indices.shape:
        raise ValueError(
            f"weights must be 1-D with one value per item "
            f"(expected shape {indices.shape}, got {weights.shape})"
        )
    if not np.all(np.isfinite(weights)) or np.any(weights < 0):
        raise ValueError("weights must be finite and non-negative")

    # Only entries with a positive weight can ever be picked, so this is the
    # true upper bound on how many distinct draws are possible.
    drawable = int(np.count_nonzero(weights))
    if n > drawable:
        raise ValueError(
            f"cannot draw {n} distinct items: only {drawable} have a non-zero weight"
        )

    # np.random.choice requires probabilities that sum to 1.
    weights = weights / weights.sum()
    out_indices = []

    for _ in range(n):
        chosen_index = np.random.choice(indices, p=weights)
        out_indices.append(chosen_index)

        # Remove the drawn entry from the pool so it cannot be drawn again.
        mask = indices != chosen_index
        indices = indices[mask]
        weights = weights[mask]

        # Renormalise what is left; the sum is only zero after the last
        # permitted draw, when no further sampling happens.
        remaining = weights.sum()
        if remaining > 0:
            weights = weights / remaining

    return items[out_indices]