# raven / models.py
# Author: Jakub Kwiatkowski
# Last commit: 758ad23 ("Change properties.")
# (Copied from the repository web view: raw | history | blame, 735 bytes.)
import tensorflow as tf
from config_utils import tf_gpu
from data_utils import DataSetFromFolder
# Configure TensorFlow before anything else uses it. tf_gpu() comes from the
# project's config_utils; its exact effect is not visible in this file.
# NumPy-style behavior is enabled on TF tensors, preferring float32 over
# float64 when dtypes are inferred.
tf_gpu()
tf.experimental.numpy.experimental_enable_numpy_behavior(prefer_float32=True)

# NOTE(review): these imports come after the TF setup above, presumably on
# purpose, so that the configuration runs before these libraries touch
# TensorFlow. Confirm before hoisting them to the top of the file.
from huggingface_hub import from_pretrained_keras
from datasets import load_dataset

# Hugging Face Hub repository that both the dataset and the Keras model are
# loaded from. A sibling dataset repo with a "_properties" suffix is loaded
# alongside it.
repo = "jkwiatkowski/raven"
# Single split name shared by both dataset loads, so `data` and `properties`
# always come from the same split.
SPLIT = "val"

data = load_dataset(repo, split=SPLIT)
model = from_pretrained_keras(repo)
properties = load_dataset(repo + "_properties", split=SPLIT)

# Not referenced in this file. Presumably an image-index offset read by
# modules that import this one — verify against callers before changing.
START_IMAGE = 12000