raven / raven_utils /uitls.py
Jakub Kwiatkowski
Refactor.
38f87b5
raw
history blame
1.43 kB
from functools import partial
import numpy as np
def gather(a, index):
    """Pick one element per row of ``a``.

    For every position ``i`` along the first axis, selects ``a[i, index[i]]``
    using NumPy advanced (integer-array) indexing.

    Args:
        a: Array indexed as ``a[rows, index]``; its first-axis length sets
            how many rows are gathered.
        index: Per-row positions into the second axis of ``a`` (presumably a
            1-D integer array with one entry per row — confirm with callers).

    Returns:
        ``a[np.arange(np.shape(a)[0]), index]``.
    """
    row_ids = np.arange(np.shape(a)[0])
    return a[row_ids, index]
import raven_utils.group as group
def def_init_image(shape=(10, 64, 64, 3), mode="uniform", min=0, max=1):
    """Create a NumPy array of ``shape`` filled according to ``mode``.

    ``mode`` is checked in this order:

    * ``"normal"`` / ``"n"``: ``np.random.normal(min, max, shape)``. Here
      ``min`` is passed as the mean and ``max`` as the standard deviation.
    * ``"zero"`` / ``0``: ``np.zeros(shape)``.
    * ``"one"`` / ``1``: ``np.ones(shape)``.
    * ``"int"`` or any other ``int``: ``np.random.randint(min, max, shape)``.
      ``max`` is exclusive, so the default bounds yield all zeros.
    * anything else: ``np.random.uniform(min, max, shape)``.

    ``min``/``max`` shadow the builtins but are kept for caller compatibility.
    """
    if mode in ("normal", "n"):
        return np.random.normal(min, max, shape)
    if mode in ("zero", 0):
        return np.zeros(shape)
    if mode in ("one", 1):
        return np.ones(shape)
    # Integer-valued modes that were not 0/1 (or the literal "int") draw ints.
    is_int_mode = mode == "int" or isinstance(mode, int)
    if not is_int_mode:
        return np.random.uniform(min, max, shape)
    return np.random.randint(min, max, shape)
# Preconfigured initializer: ``def_init_image`` with ``shape`` defaulting to
# (16, 8, 80, 80, 1). Callers may still override ``shape`` by keyword (a
# positional ``shape`` would collide with the bound keyword and raise).
init_image = partial(
    def_init_image,
    shape=(16, 8, 80, 80, 1),
)
def get_val_index(no=group.NO, base=3, add_end=False):
    """Return one index per 2000-wide segment, offset by ``base``.

    The result is ``np.arange(no) * 2000 + base``. The default ``no=group.NO``
    suggests that each segment corresponds to one group; confirm with callers.

    Args:
        no: Number of segments / indexes to generate. Defaults to ``group.NO``.
        base: Offset added to the start of every segment.
        add_end: If true, append the end boundary ``no * 2000`` as a final
            element.

    Returns:
        1-D array of length ``no`` (``no + 1`` when ``add_end`` is true).
    """
    indexes = np.arange(no) * 2000 + base
    if add_end:
        # ``no * 2000`` is a scalar and np.concatenate rejects 0-d inputs,
        # so wrap it in a list to make it a 1-element sequence.
        indexes = np.concatenate([indexes, [no * 2000]])
    return indexes
def get_matrix(inputs, index):
    """Join the first 8 entries of ``inputs`` with the entry picked by ``index``.

    For every row, takes ``inputs[:, :8]`` and appends the element chosen by
    ``gather(inputs, index[:, 0])`` as one extra entry along axis 1.

    This presumably expects ``inputs`` to be (batch, panels, ...), with the
    first 8 panels as context and ``index[:, 0]`` selecting the answer panel.
    TODO confirm against callers.
    """
    context = inputs[:, :8]
    chosen = gather(inputs, index[:, 0])
    # Re-insert axis 1 so the picked entry concatenates as a single panel.
    chosen = chosen[:, None]
    return np.concatenate([context, chosen], axis=1)
def get_matrix_from_data(x):
    """Apply ``get_matrix`` to the ``"inputs"`` and ``"index"`` entries of ``x``.

    ``x`` is anything supporting ``[]`` lookup with those two keys.
    ``"inputs"`` is looked up before ``"index"``.
    """
    return get_matrix(inputs=x["inputs"], index=x["index"])
def compare_from_result(result, data):
    """Compare predictions in ``result`` with the answers selected from ``data``.

    Args:
        result: Mapping with ``"predict"`` and ``"predict_mask"`` entries,
            which are passed on to ``raven_utils.decode.compare``.
        data: Object whose ``.data.data`` is a mapping with ``"target"`` and
            ``"index"`` entries, each exposing a ``.data`` array. This
            structure is inferred from the attribute accesses below; confirm
            it against the caller.

    Returns:
        ``np.all(..., axis=-1)`` applied to ``rv.decode.compare`` on the first
        ``len(predict)`` answers.
    """
    data = data.data.data
    # Pick each sample's answer from ``target`` using the first column of
    # ``index``. ``D`` is never defined in this module, so use the local
    # ``gather`` helper, which does exactly this row-wise selection.
    answer = gather(data['target'].data, data['index'].data[:, 0])
    # Local import kept as in the original. It presumably avoids a circular
    # import, and it assumes the package exposes ``decode``; confirm both.
    import raven_utils as rv
    predict = result['predict']
    predict_mask = result['predict_mask']
    return np.all(rv.decode.compare(answer[:len(predict)], predict, predict_mask), axis=-1)