Spaces:
Build error
Build error
from functools import partial | |
import numpy as np | |
def gather(a, index):
    """Pick one entry per row along the second axis: ``a[i, index[i]]`` for every i.

    Args:
        a: Array whose first axis is iterated row by row.
        index: Per-row positions into ``a``'s second axis (presumably a 1-D
            integer array with one entry per row of ``a`` — TODO confirm).

    Returns:
        For a 1-D ``index`` of length ``len(a)``, an array of shape
        ``(len(a),) + a.shape[2:]``.
    """
    rows = np.arange(np.shape(a)[0])
    # Pair row i with column index[i] via NumPy advanced indexing.
    return a[rows, index]
import raven_utils.group as group | |
def def_init_image(shape=(10, 64, 64, 3), mode="uniform", min=0, max=1):
    """Create a NumPy array of ``shape`` filled according to ``mode``.

    Args:
        shape: Shape of the returned array.
        mode: Fill strategy, checked in this order:
            - ``"normal"`` / ``"n"``: Gaussian samples, with ``min`` passed as
              the mean and ``max`` as the standard deviation.
            - ``"zero"`` / ``0``: all zeros.
            - ``"one"`` / ``1``: all ones.
            - ``"int"`` or any other ``int``: random integers in ``[min, max)``.
            - anything else: uniform samples in ``[min, max)``.
        min, max: Distribution parameters (meaning depends on ``mode``).
            These names shadow the builtins but are kept for API compatibility.

    Returns:
        np.ndarray of the requested shape.
    """
    low, high = min, max
    if mode in ("normal", "n"):
        return np.random.normal(low, high, shape)
    if mode in ("zero", 0):
        return np.zeros(shape)
    if mode in ("one", 1):
        return np.ones(shape)
    if mode == "int" or isinstance(mode, int):
        return np.random.randint(low, high, shape)
    # Fallback for "uniform" and any unrecognised mode.
    return np.random.uniform(low, high, shape)


# Same factory, preconfigured with shape (16, 8, 80, 80, 1).
init_image = partial(def_init_image, shape=(16, 8, 80, 80, 1))
def get_val_index(no=group.NO, base=3, add_end=False, step=2000):
    """Return evenly spaced indices ``base, base + step, ..., base + (no - 1) * step``.

    Args:
        no: Number of indices to generate. Defaults to ``group.NO``, which is
            evaluated once, when this function is defined.
        base: Offset added to every generated index.
        add_end: If True, append the end value ``no * step`` (note: ``base`` is
            not added to it).
        step: Spacing between consecutive indices (2000, the previously
            hard-coded value, by default).

    Returns:
        1-D integer np.ndarray of length ``no`` (``no + 1`` with ``add_end``).
    """
    indexes = np.arange(no) * step + base
    if add_end:
        # ``no * step`` is a scalar; np.concatenate rejects 0-d inputs
        # ("zero-dimensional arrays cannot be concatenated"), so append instead.
        indexes = np.append(indexes, no * step)
    return indexes
def get_matrix(inputs, index):
    """Concatenate the first 8 entries along axis 1 with the entry picked by ``index[:, 0]``.

    Args:
        inputs: Array with a batch axis first and at least 2 axes; axis 1 is
            presumably the panel axis of a Raven-style matrix (TODO confirm).
        index: Array indexable as ``index[:, 0]``; column 0 gives, per row,
            the position along axis 1 of ``inputs`` to append.

    Returns:
        Array whose axis 1 holds ``inputs[:, :8]`` followed by the selected
        entry (length 9 when ``inputs`` has at least 8 entries on axis 1).
    """
    context = inputs[:, :8]
    chosen = gather(inputs, index[:, 0])
    # Restore a length-1 axis 1 on the selected entries so they concatenate.
    return np.concatenate([context, np.expand_dims(chosen, 1)], axis=1)
def get_matrix_from_data(x):
    """Call ``get_matrix`` with ``x["inputs"]`` and ``x["index"]``.

    Args:
        x: Mapping that provides ``"inputs"`` and ``"index"`` entries.

    Returns:
        Whatever ``get_matrix(x["inputs"], x["index"])`` returns.
    """
    return get_matrix(x["inputs"], x["index"])
def compare_from_result(result, data):
    """Report, per prediction, whether it matches the ground-truth answer.

    Args:
        result: Mapping with ``'predict'`` and ``'predict_mask'`` entries.
        data: Wrapped object; ``data.data.data`` is unwrapped to a mapping
            whose ``'target'`` and ``'index'`` entries each expose ``.data``.

    Returns:
        Boolean array, one entry per element of ``predict``. An entry is True
        where ``rv.decode.compare`` reports a match for every element along
        the last axis.
    """
    import raven_utils as rv

    data = data.data.data
    # Select, per row, the target entry at column 0 of the index. This uses
    # this module's ``gather``; the original called ``D.gather`` with ``D``
    # never defined, which raised NameError.
    answer = gather(data['target'].data, data['index'].data[:, 0])
    predict = result['predict']
    predict_mask = result['predict_mask']
    # Predictions may cover only a prefix of the batch; align answers to them.
    return np.all(rv.decode.compare(answer[:len(predict)], predict, predict_mask), axis=-1)