"""NumPy helpers for creating test arrays, generating index grids and
assembling / comparing batched matrices used together with ``raven_utils``."""
from functools import partial

import numpy as np

import raven_utils.group as group


def gather(a, index):
    """Select one entry per row: ``a[i, index[i]]`` for every row ``i``.

    Args:
        a: Array-like supporting NumPy fancy indexing; axis 0 is treated as the
            row (batch) axis.
        index: Integer indices into axis 1 of ``a``, broadcast against
            ``np.arange(len(a))`` (typically one index per row).

    Returns:
        The selected entries; for a 1-D ``index`` of length ``len(a)`` the
        result has axis 1 of ``a`` removed.
    """
    rows = np.arange(np.shape(a)[0])
    return a[rows, index]


def def_init_image(shape=(10, 64, 64, 3), mode="uniform", min=0, max=1):
    """Create an array of ``shape`` filled according to ``mode``.

    Modes (checked in this order):
        ``"normal"`` / ``"n"``: ``np.random.normal(min, max, shape)`` -- note
            that ``min`` is used as the mean and ``max`` as the standard
            deviation here.
        ``"zero"`` / ``0``: all zeros.
        ``"one"`` / ``1``: all ones.
        ``"int"`` / any other ``int``: ``np.random.randint(min, max, shape)``;
            ``max`` is exclusive, so the defaults (0, 1) produce all zeros.
        anything else: ``np.random.uniform(min, max, shape)``.

    ``min`` and ``max`` shadow the builtins but are kept for caller
    compatibility.
    """
    if mode == "normal" or mode == "n":
        return np.random.normal(min, max, shape)
    if mode == "zero" or mode == 0:
        return np.zeros(shape)
    if mode == "one" or mode == 1:
        return np.ones(shape)
    if mode == "int" or isinstance(mode, int):
        return np.random.randint(min, max, shape)
    return np.random.uniform(min, max, shape)


# Preset of ``def_init_image`` with a fixed 5-D shape.
init_image = partial(def_init_image, shape=(16, 8, 80, 80, 1))


def get_val_index(no=group.NO, base=3, add_end=False, step=2000):
    """Return the evenly spaced indices ``base + k * step`` for ``k`` in ``range(no)``.

    Args:
        no: Number of indices to generate; defaults to ``group.NO``.
        base: Offset added to every index.
        add_end: If true, also append ``no * step`` (without ``base``) as a
            final element.
        step: Spacing between consecutive indices (previously hard-coded to
            2000, which remains the default).

    Returns:
        1-D NumPy array of length ``no`` (``no + 1`` when ``add_end``).
    """
    indexes = np.arange(no) * step + base
    if add_end:
        # np.concatenate rejects 0-d inputs, so the scalar end marker must be
        # wrapped in a sequence (the previous code raised ValueError here).
        indexes = np.concatenate([indexes, [no * step]])
    return indexes


def get_matrix(inputs, index):
    """Return ``inputs[:, :8]`` with one selected entry per row appended on axis 1.

    For each row ``i`` the appended entry is ``inputs[i, index[i, 0]]``, so the
    result has 9 entries along axis 1.

    NOTE(review): presumably assembles 8 context items plus the chosen item
    into one matrix -- confirm against callers.

    Args:
        inputs: Array with at least 2 dimensions; axis 0 is the batch axis.
        index: 2-D integer array; only column 0 is used.
    """
    selected = gather(inputs, index[:, 0])[:, None]  # re-insert axis 1 for concat
    return np.concatenate([inputs[:, :8], selected], axis=1)


def get_matrix_from_data(x):
    """Call ``get_matrix`` with ``x["inputs"]`` and ``x["index"]``."""
    inputs = x["inputs"]
    index = x["index"]
    return get_matrix(inputs, index)


def compare_from_result(result, data):
    """Check predictions in ``result`` against the targets selected from ``data``.

    The targets are picked per row via ``gather`` using column 0 of the index
    array. They are then compared to ``result["predict"]`` with
    ``raven_utils.decode.compare``, using ``result["predict_mask"]``.

    Returns:
        ``np.all(..., axis=-1)`` of the comparison, i.e. per-row agreement over
        the last axis.
    """
    # NOTE(review): the triple ``.data`` unwrap (and ``.data`` on each field)
    # is kept from the original; the wrapper type is not visible here.
    data = data.data.data
    # Previously ``D.gather``, but ``D`` is never defined or imported in this
    # module (NameError on every call). The module's ``gather`` presumably
    # performs the intended per-row selection -- confirm.
    answer = gather(data['target'].data, data['index'].data[:, 0])
    import raven_utils as rv

    predict = result['predict']
    predict_mask = result['predict_mask']
    # Predictions may cover only a prefix of the batch; align targets to them.
    return np.all(
        rv.decode.compare(answer[:len(predict)], predict, predict_mask),
        axis=-1,
    )