File size: 887 Bytes
45311fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# -*- coding: utf-8 -*-
# @Time    : 2021/12/2 5:41 p.m.
# @Author  : JianingWang
# @File    : sampler.py

import numpy as np
from typing import Optional

"""
random sampling for each label
"""
def random_sampling(raw_datasets, num_examples_per_label: Optional[int]=16):
    """Randomly pick up to ``num_examples_per_label`` example indices per label.

    Args:
        raw_datasets: Object indexable by the key ``"label"``, yielding an
            iterable of hashable labels, one per example
            (e.g. ``[0, 1, 0, 0, ...]``).
        num_examples_per_label: Number of examples to draw for each label.
            If a label has fewer examples than requested, all of its
            examples are selected instead of raising. ``None`` selects
            every example of every label.

    Returns:
        A flat list of example indices (positions in ``raw_datasets["label"]``).
        Indices are grouped by label in order of each label's first
        appearance; within a group the order is random. No index is repeated.
    """
    label_list = raw_datasets["label"]  # [0, 1, 0, 0, ...]

    # Group example indices by label; dict insertion order keeps the labels
    # in first-seen order.
    label_dict = dict()
    for ei, label in enumerate(label_list):
        label_dict.setdefault(label, []).append(ei)

    # Randomly sample k examples of each class without replacement.
    few_example_ids = list()
    for eid_list in label_dict.values():
        available = len(eid_list)
        # Sampling without replacement cannot draw more items than exist, so
        # cap the request at the class size (None means "no cap").
        if num_examples_per_label is None:
            k = available
        else:
            k = min(num_examples_per_label, available)
        idxs = np.random.choice(available, size=k, replace=False)
        few_example_ids.extend(eid_list[i] for i in idxs)
    return few_example_ids