Spaces:
Build error
Build error
import numpy as np | |
from data_utils.image import draw_images | |
from ml_utils import il | |
import raven_utils as rv | |
from raven_utils.uitls import get_matrix | |
from tensorflow.keras.models import load_model | |
from raven_utils.draw import render_from_model | |
import models | |
import ast | |
def load_example(index=0):
    """Render one puzzle matrix and a description of it.

    Args:
        index: Either a dataset position (an int, or its string form) or a
            nested-list literal (as a string or list) describing a custom
            matrix. Blank/falsy input falls back to position 0. Negative
            positions count from the end of ``models.data``.

    Returns:
        ``(result, desc)``. ``result`` is ``draw_images`` applied to the first
        nine entries of the matrix with ``row=3``, then tiled with
        ``reps=(1, 1, 3)``. ``desc`` is ``"Custom matrix"`` for literal input,
        otherwise the extracted rules joined with ``<br />`` separators.
    """
    raw = str(index).strip()
    # ast.literal_eval("") raises SyntaxError, so blank input (e.g. an empty
    # text box) maps to the default position instead of crashing.
    index = ast.literal_eval(raw) if raw else 0
    if il(index):
        example = rv.draw.render_panels(np.array(index))
        desc = "Custom matrix"
    else:
        index = int(index if index else 0)
        if index < 0:
            # data[-1:0] is an empty slice even though properties[-1] is valid;
            # normalise so every lookup below addresses the same row.
            index += len(models.data)
        rules = rv.draw.extract_rules(models.properties[index])
        desc = "<br /><br />".join(["<br />".join(rule) for rule in rules])
        # NOTE(review): the +8 offset on the stored index is kept as-is;
        # presumably it maps the answer index into the panel sequence — confirm.
        example = get_matrix(
            models.data[index:index + 1],
            models.indexes[index:index + 1, None] + 8,
        )
    result = np.tile(draw_images(example[:9], row=3), reps=(1, 1, 3))
    return result, desc
def load_model_(name):
    """Load a Keras model into ``models.model`` and report success.

    Args:
        name: ``"Transformer"`` selects a fixed checkpoint directory; any other
            value is handed to ``load_model`` verbatim as the path.

    Returns:
        The status string ``"Success loading: <name>"``.
    """
    # NOTE(review): machine-specific absolute path; it only resolves on the
    # original author's filesystem.
    transformer_checkpoint = "/home/jkwiatkowski/all/best/rav/full_trans/6e8e6bad403e4171ad10daa1a518ba09"
    checkpoint = transformer_checkpoint if name == "Transformer" else name
    models.model = load_model(checkpoint)
    return f"Success loading: {name}"
def run_nn(index=0):
    """Run ``models.model`` on one puzzle and return the rendered prediction.

    Args:
        index: Either a dataset position (an int, or its string form) or a
            nested-list literal describing a custom matrix. Blank/falsy input
            falls back to ``models.START_IMAGE``. Negative positions count
            from the end of ``models.data``.

    Returns:
        ``render_from_model(...)[0, ..., None]`` tiled with ``reps=(1, 1, 3)``.
    """
    raw = str(index).strip()
    # ast.literal_eval("") raises SyntaxError; treat blank input as "use default".
    index = ast.literal_eval(raw) if raw else 0
    if il(index):
        data = rv.draw.render_panels(np.array(index))
        # NOTE(review): appending the first 7 panels presumably pads the custom
        # matrix to the 16-panel layout used by 'labels'/'target' — confirm.
        data = np.concatenate([data, data[:7]])[None]
    else:
        index = int(index if index else models.START_IMAGE)
        if index < 0:
            # data[-1:0] would be an empty batch; normalise negative positions.
            index += len(models.data)
        data = models.data[index:index + 1]
    # Only 'inputs' carries real data; the other entries are zero placeholders
    # with the shapes/dtypes the model is fed.
    batch = {
        'inputs': data,
        'index': np.zeros(shape=(1, 1), dtype="uint8"),
        'labels': np.zeros(shape=(1, 16, 113), dtype="int8"),
        'target': np.zeros(shape=(1, 16, 113), dtype="int8"),
    }
    rendered = render_from_model(batch, models.model)
    return np.tile(rendered[0, ..., None], reps=(1, 1, 3))
def _step_index(index, delta):
    """Parse *index* and return it shifted by *delta*.

    Anything that does not parse to an ``int`` (blank text, non-literal text,
    a custom-matrix list literal, a float) restarts from
    ``models.START_IMAGE`` before the shift is applied.
    """
    raw = str(index).strip()
    try:
        parsed = ast.literal_eval(raw) if raw else None
    except (ValueError, SyntaxError):
        # Unparsable text is treated like any other non-int input.
        parsed = None
    if not isinstance(parsed, int):
        parsed = models.START_IMAGE
    return int(parsed) + delta


def next_(index=0):
    """Advance to the next dataset position.

    Returns:
        ``(new_index, result, desc)``, where ``(result, desc)`` is
        ``load_example(new_index)``.
    """
    new_index = _step_index(index, 1)
    return (new_index,) + load_example(new_index)


def prev_(index=0):
    """Go back to the previous dataset position.

    Returns:
        ``(new_index, result, desc)``, where ``(result, desc)`` is
        ``load_example(new_index)``.
    """
    new_index = _step_index(index, -1)
    return (new_index,) + load_example(new_index)
if __name__ == '__main__':
    # Smoke test: render example 5, then run the model on the same puzzle.
    # run_nn expects an index or literal, not the rendered image; passing the
    # numpy image made ast.literal_eval fail on str(array).
    # NOTE(review): run_nn reads models.model, so this assumes a model has
    # already been loaded (e.g. via load_model_) — confirm.
    example_index = 5
    image, _ = load_example(example_index)
    prediction = run_nn(example_index)