DeepLearning101's picture
Upload 21 files
45311fe
raw
history blame
887 Bytes
# -*- coding: utf-8 -*-
# @Time : 2021/12/2 5:41 p.m.
# @Author : JianingWang
# @File : sampler.py
import numpy as np
from typing import Optional
"""
random sampling for each label
"""
def random_sampling(raw_datasets, num_examples_per_label: Optional[int]=16):
    """Randomly pick up to ``num_examples_per_label`` example indices per label.

    Sampling is done without replacement through NumPy's global random state
    (``np.random.choice``), so call ``np.random.seed`` beforehand for
    reproducible results.

    Args:
        raw_datasets: Object indexable by ``"label"`` that yields an iterable
            of hashable labels, one per example (e.g. ``[0, 1, 0, 0, ...]``).
        num_examples_per_label: Maximum number of examples to draw for each
            label. A label with fewer examples contributes all of them
            instead of raising. ``None`` keeps every example of every label.

    Returns:
        A list of example indices (positions in ``raw_datasets["label"]``),
        grouped by label in order of each label's first appearance. Within a
        label the indices appear in the randomly drawn order.
    """
    label_list = raw_datasets["label"]  # [0, 1, 0, 0, ...]

    # Record the example indices belonging to each label.
    label_dict = dict()
    for ei, label in enumerate(label_list):
        label_dict.setdefault(label, []).append(ei)

    # Randomly sample k examples of each class.
    few_example_ids = list()
    for eid_list in label_dict.values():
        num_available = len(eid_list)
        if num_examples_per_label is None:
            num_to_sample = num_available
        else:
            # Sampling without replacement cannot draw more items than exist,
            # so cap the request for under-represented labels.
            num_to_sample = min(num_examples_per_label, num_available)
        idxs = np.random.choice(num_available, size=num_to_sample, replace=False)
        few_example_ids.extend(eid_list[i] for i in idxs)
    return few_example_ids