File size: 145,450 Bytes
e8c4ed3
1
{"metadata":{"kernelspec":{"language":"python","display_name":"Python 3","name":"python3"},"language_info":{"name":"python","version":"3.7.10","mimetype":"text/x-python","codemirror_mode":{"name":"ipython","version":3},"pygments_lexer":"ipython3","nbconvert_exporter":"python","file_extension":".py"}},"nbformat_minor":4,"nbformat":4,"cells":[{"cell_type":"markdown","source":"## THINGS TO DO??\n\n### Existing NeRF\n1. Rewrite the model into FLAX (otherwise cant work with Flax-CLIP?) (Haiku seems like functional, but Flax is OOP?)\n2. Rewrite training loop into FLAX (does FLAX provide abstraction for writing training loop?? current training loop is pretty low-level)\n3. Refactor the notebook code --> module code\n    - render scene, visualizing, animation ... etc.\n4. Dataloading for our concerned dataset\n5. consolidate all controllabe params in a class (e.g. `Config`)\n\n\n### NeRF --> DietNeRF\n1. Change sampling to 8 samples only\n2. Add CLIP into the training loop for a new loss function\n3. Check if DietNeRF can get comparable result to NeRF\n\n### Optional?\n1. Understand what the hell each operations are doing? (e.g. `get_rays` ... etc)\n2. Add W&B for visualization?\n3. Super large scale NeRF(ssssss) training --> get huge samples of scene for POC\n4. Optimize bottleneck operations by `jax.vmap`, `jax.pmap`","metadata":{}},{"cell_type":"code","source":"# enable showing live \"loss plot\" inside notebook\n!pip install livelossplot","metadata":{"execution":{"iopub.status.busy":"2021-07-04T06:25:20.105372Z","iopub.execute_input":"2021-07-04T06:25:20.105956Z","iopub.status.idle":"2021-07-04T06:25:27.213206Z","shell.execute_reply.started":"2021-07-04T06:25:20.105901Z","shell.execute_reply":"2021-07-04T06:25:27.211937Z"},"trusted":true},"execution_count":63,"outputs":[{"name":"stdout","text":"Requirement already satisfied: livelossplot in /opt/conda/lib/python3.7/site-packages (0.5.4)\nRequirement already satisfied: matplotlib in /opt/conda/lib/python3.7/site-packages (from livelossplot) (3.4.1)\nRequirement already satisfied: ipython in /opt/conda/lib/python3.7/site-packages (from livelossplot) (7.22.0)\nRequirement already satisfied: bokeh in /opt/conda/lib/python3.7/site-packages (from livelossplot) (2.3.1)\nRequirement already satisfied: typing-extensions>=3.7.4 in /opt/conda/lib/python3.7/site-packages (from bokeh->livelossplot) (3.7.4.3)\nRequirement already satisfied: packaging>=16.8 in /opt/conda/lib/python3.7/site-packages (from bokeh->livelossplot) (20.9)\nRequirement already satisfied: PyYAML>=3.10 in /opt/conda/lib/python3.7/site-packages (from bokeh->livelossplot) (5.3.1)\nRequirement already satisfied: pillow>=7.1.0 in /opt/conda/lib/python3.7/site-packages (from bokeh->livelossplot) (7.2.0)\nRequirement already satisfied: tornado>=5.1 in /opt/conda/lib/python3.7/site-packages (from bokeh->livelossplot) (6.1)\nRequirement already satisfied: numpy>=1.11.3 in /opt/conda/lib/python3.7/site-packages (from bokeh->livelossplot) (1.19.5)\nRequirement already satisfied: python-dateutil>=2.1 in /opt/conda/lib/python3.7/site-packages (from bokeh->livelossplot) (2.8.1)\nRequirement already satisfied: Jinja2>=2.7 in /opt/conda/lib/python3.7/site-packages (from bokeh->livelossplot) (2.11.3)\nRequirement already satisfied: MarkupSafe>=0.23 in /opt/conda/lib/python3.7/site-packages (from Jinja2>=2.7->bokeh->livelossplot) (1.1.1)\nRequirement already satisfied: pyparsing>=2.0.2 in /opt/conda/lib/python3.7/site-packages (from packaging>=16.8->bokeh->livelossplot) (2.4.7)\nRequirement already satisfied: six>=1.5 in /opt/conda/lib/python3.7/site-packages (from python-dateutil>=2.1->bokeh->livelossplot) (1.15.0)\nRequirement already satisfied: prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0 in /opt/conda/lib/python3.7/site-packages (from ipython->livelossplot) (3.0.18)\nRequirement already satisfied: decorator in /opt/conda/lib/python3.7/site-packages (from ipython->livelossplot) (4.4.2)\nRequirement already satisfied: pygments in /opt/conda/lib/python3.7/site-packages (from ipython->livelossplot) (2.8.1)\nRequirement already satisfied: pickleshare in /opt/conda/lib/python3.7/site-packages (from ipython->livelossplot) (0.7.5)\nRequirement already satisfied: jedi>=0.16 in /opt/conda/lib/python3.7/site-packages (from ipython->livelossplot) (0.18.0)\nRequirement already satisfied: pexpect>4.3 in /opt/conda/lib/python3.7/site-packages (from ipython->livelossplot) (4.8.0)\nRequirement already satisfied: traitlets>=4.2 in /opt/conda/lib/python3.7/site-packages (from ipython->livelossplot) (5.0.5)\nRequirement already satisfied: setuptools>=18.5 in /opt/conda/lib/python3.7/site-packages (from ipython->livelossplot) (49.6.0.post20210108)\nRequirement already satisfied: backcall in /opt/conda/lib/python3.7/site-packages (from ipython->livelossplot) (0.2.0)\nRequirement already satisfied: parso<0.9.0,>=0.8.0 in /opt/conda/lib/python3.7/site-packages (from jedi>=0.16->ipython->livelossplot) (0.8.1)\nRequirement already satisfied: ptyprocess>=0.5 in /opt/conda/lib/python3.7/site-packages (from pexpect>4.3->ipython->livelossplot) (0.7.0)\nRequirement already satisfied: wcwidth in /opt/conda/lib/python3.7/site-packages (from prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0->ipython->livelossplot) (0.2.5)\nRequirement already satisfied: ipython-genutils in /opt/conda/lib/python3.7/site-packages (from traitlets>=4.2->ipython->livelossplot) (0.2.0)\nRequirement already satisfied: kiwisolver>=1.0.1 in /opt/conda/lib/python3.7/site-packages (from matplotlib->livelossplot) (1.3.1)\nRequirement already satisfied: cycler>=0.10 in /opt/conda/lib/python3.7/site-packages (from matplotlib->livelossplot) (0.10.0)\n","output_type":"stream"}]},{"cell_type":"code","source":"%%capture\n!conda install -y -c conda-forge jax jaxlib flax optax datasets transformers\n!conda install -y importlib-metadata\n!pip install -U dm-haiku","metadata":{"_uuid":"8f2839f25d086af736a60e9eeb907d3b93b6e0e5","_cell_guid":"b1076dfc-b9ad-4769-8c92-a6c4dae69d19","execution":{"iopub.status.busy":"2021-07-04T06:25:27.215245Z","iopub.execute_input":"2021-07-04T06:25:27.215562Z","iopub.status.idle":"2021-07-04T06:27:11.574413Z","shell.execute_reply.started":"2021-07-04T06:25:27.215530Z","shell.execute_reply":"2021-07-04T06:27:11.572485Z"},"trusted":true},"execution_count":64,"outputs":[]},{"cell_type":"markdown","source":"","metadata":{}},{"cell_type":"code","source":"# TPU setup\nimport os\nif 'TPU_NAME' in os.environ:\n    import requests\n    if 'TPU_DRIVER_MODE' not in globals():\n        url = 'http:' + os.environ['TPU_NAME'].split(':')[1] + ':8475/requestversion/tpu_driver_nightly'\n        resp = requests.post(url)\n        TPU_DRIVER_MODE = 1\n\n    from jax.config import config\n    config.FLAGS.jax_xla_backend = \"tpu_driver\"\n    config.FLAGS.jax_backend_target = os.environ['TPU_NAME']\n    print('Registered TPU:', config.FLAGS.jax_backend_target)\nelse:\n    print('No TPU detected. Can be changed under \"Runtime/Change runtime type\".')\n\n# Module check\nimport jax\nimport flax\nimport haiku as hk\n\nfor _m in (jax, flax, hk):\n    print(f'{_m.__name__}: {_m.__version__}')\njax.local_devices()","metadata":{"execution":{"iopub.status.busy":"2021-07-04T06:27:11.578625Z","iopub.execute_input":"2021-07-04T06:27:11.579247Z","iopub.status.idle":"2021-07-04T06:27:11.600702Z","shell.execute_reply.started":"2021-07-04T06:27:11.579197Z","shell.execute_reply":"2021-07-04T06:27:11.598533Z"},"trusted":true},"execution_count":65,"outputs":[{"name":"stdout","text":"No TPU detected. Can be changed under \"Runtime/Change runtime type\".\njax: 0.2.16\nflax: 0.3.4\nhaiku: 0.0.4\n","output_type":"stream"},{"execution_count":65,"output_type":"execute_result","data":{"text/plain":"[CpuDevice(id=0)]"},"metadata":{}}]},{"cell_type":"code","source":"from functools import partial\n\nimport jax\nfrom jax import random, grad, jit, vmap, flatten_util, nn\nfrom jax.experimental import optimizers  # change due to version difference\nfrom jax.config import config\nimport jax.numpy as np\n\nimport haiku as hk\nfrom haiku._src import utils\n\nfrom livelossplot import PlotLosses\nimport matplotlib.pyplot as plt\nfrom tqdm.notebook import tqdm as tqdm\nimport cv2\nimport imageio\nimport glob\nfrom IPython.display import clear_output\nimport pickle\nfrom skimage.metrics import structural_similarity as ssim_fn\n\nrng = jax.random.PRNGKey(42)","metadata":{"execution":{"iopub.status.busy":"2021-07-04T06:27:11.603322Z","iopub.execute_input":"2021-07-04T06:27:11.603678Z","iopub.status.idle":"2021-07-04T06:27:11.619997Z","shell.execute_reply.started":"2021-07-04T06:27:11.603646Z","shell.execute_reply":"2021-07-04T06:27:11.618654Z"},"trusted":true},"execution_count":66,"outputs":[]},{"cell_type":"code","source":"#ls ../input/pull-phototourism-images/sacre_coeur/dense/images\nDATASET = 'sacre'\nposedir = f'../input/phototourism/phototourism/sacre' # Directory condtains [bds.npy, c2w_mats.npy, kinv_mats.npy, res_mats.npy]\nimgdir = f'../input/pull-phototourism-images/sacre_coeur/dense/images' # Directory of images","metadata":{"execution":{"iopub.status.busy":"2021-07-04T06:27:11.621386Z","iopub.execute_input":"2021-07-04T06:27:11.621825Z","iopub.status.idle":"2021-07-04T06:27:11.630402Z","shell.execute_reply.started":"2021-07-04T06:27:11.621792Z","shell.execute_reply":"2021-07-04T06:27:11.629270Z"},"trusted":true},"execution_count":67,"outputs":[]},{"cell_type":"markdown","source":"### 1. Helper Functions for Loading Data","metadata":{}},{"cell_type":"code","source":"posedata = {}\nfor f in os.listdir(posedir):\n    if '.npy' not in f:\n        continue\n    z = np.load(os.path.join(posedir, f))\n    posedata[f.split('.')[0]] = z\nprint('Pose data loaded - ', posedata.keys())\n\nimgfiles = sorted(glob.glob(imgdir + '/*.jpg'))\nprint(f'{len(imgfiles)} images')","metadata":{"execution":{"iopub.status.busy":"2021-07-04T06:27:11.633311Z","iopub.execute_input":"2021-07-04T06:27:11.633646Z","iopub.status.idle":"2021-07-04T06:27:11.676300Z","shell.execute_reply.started":"2021-07-04T06:27:11.633614Z","shell.execute_reply":"2021-07-04T06:27:11.674797Z"},"trusted":true},"execution_count":68,"outputs":[{"name":"stdout","text":"Pose data loaded -  dict_keys(['kinv_mats', 'res_mats', 'c2w_mats', 'bds'])\n1179 images\n","output_type":"stream"}]},{"cell_type":"code","source":"@jit\ndef get_rays(c2w, kinv, i, j):\n#     i, j = np.meshgrid(np.arange(W), np.arange(H), indexing='xy')\n    pixco = np.stack([i, j, np.ones_like(i)], -1)\n    dirs = pixco @ kinv.T\n#     dirs = np.stack([(i-W*.5)/focal, -(j-H*.5)/focal, -np.ones_like(i)], -1)\n    rays_d = np.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1)\n    rays_o = np.broadcast_to(c2w[:3,-1], rays_d.shape)\n    return np.stack([rays_o, rays_d], 0)\n\n\ndef normalize(x):\n    return x / np.linalg.norm(x)\n\n\ndef viewmatrix(z, up, pos):\n    vec2 = normalize(z)\n    vec1_avg = up\n    vec0 = normalize(np.cross(vec1_avg, vec2))\n    vec1 = normalize(np.cross(vec2, vec0))\n    m = np.stack([vec0, vec1, vec2, pos], 1)\n    return m\n\n\ndef ptstocam(pts, c2w):\n    tt = np.matmul(c2w[:3,:3].T, (pts-c2w[:3,3])[...,np.newaxis])[...,0]\n    return tt\n\n\ndef poses_avg(poses):\n    center = poses[:, :3, 3].mean(0)\n    vec2 = normalize(poses[:, :3, 2].sum(0))\n    up = poses[:, :3, 1].sum(0)\n    return viewmatrix(vec2, up, center)\n\n\ndef render_path_spiral(c2w, up, rads, focal, zrate, rots, N):\n    \"\"\"\n    enumerate list of poses around a spiral\n    used for test set visualization\n    \"\"\"\n    render_poses = []\n    rads = np.array(list(rads) + [1.])\n    for theta in np.linspace(0., 2. * np.pi * rots, N+1)[:-1]:\n        c = np.dot(c2w[:3,:4], np.array([np.cos(theta), -np.sin(theta), -np.sin(theta*zrate), 1.]) * rads) \n        z = normalize(c - np.dot(c2w[:3,:4], np.array([0,0,-focal, 1.])))\n        render_poses.append(viewmatrix(z, up, c))\n    return render_poses","metadata":{"execution":{"iopub.status.busy":"2021-07-04T06:27:11.678295Z","iopub.execute_input":"2021-07-04T06:27:11.678750Z","iopub.status.idle":"2021-07-04T06:27:11.699145Z","shell.execute_reply.started":"2021-07-04T06:27:11.678692Z","shell.execute_reply":"2021-07-04T06:27:11.698223Z"},"trusted":true},"execution_count":69,"outputs":[]},{"cell_type":"code","source":"def get_example(img_idx, split='train', downsample=4):\n    sc = .05\n    \n    # first 20 are test, next 5 are validation, the rest are training:\n    # https://github.com/tancik/learnit/issues/3\n    if 'train' in split:\n        img_idx = img_idx + 25\n    if 'val' in split:\n        img_idx = img_idx + 20\n        \n    # uint8 --> float\n    img = imageio.imread(imgfiles[img_idx])[...,:3]/255.\n    \n    # WHAT DO THESE MATRICES MEAN???\n    # (4, 4)\n    c2w = posedata['c2w_mats'][img_idx]\n    # (3, 3)\n    kinv = posedata['kinv_mats'][img_idx]\n    c2w = np.concatenate([c2w[:3,:3], c2w[:3,3:4]*sc], -1)\n    # (2, )\n    bds = posedata['bds'][img_idx] * np.array([.9, 1.2]) * sc\n    H, W = img.shape[:2]\n    \n    # (0, 4, 8, ..., H)\n    # WHAT ARE THE PURPOSES OF THIS MATRIX???\n    i, j = np.meshgrid(np.arange(0,W,downsample), np.arange(0,H,downsample), indexing='xy')\n    \n    test_images = img[j, i]\n    test_rays = get_rays(c2w, kinv, i, j)\n    return test_images, test_rays, bds","metadata":{"execution":{"iopub.status.busy":"2021-07-04T06:27:11.702928Z","iopub.execute_input":"2021-07-04T06:27:11.703761Z","iopub.status.idle":"2021-07-04T06:27:11.717429Z","shell.execute_reply.started":"2021-07-04T06:27:11.703650Z","shell.execute_reply":"2021-07-04T06:27:11.715895Z"},"trusted":true},"execution_count":70,"outputs":[]},{"cell_type":"markdown","source":"### 2. NeRF Renderer","metadata":{}},{"cell_type":"code","source":"def render_rays(\n        rnd_input, model, params, \n        bvals, rays, near, far, \n        N_samples, rand=False, allret=False\n    ):\n    rays_o, rays_d = rays\n\n    # Compute 3D query points\n    z_vals = np.linspace(near, far, N_samples) \n    if rand:\n        z_vals += random.uniform(rnd_input, shape=list(rays_o.shape[:-1]) + [N_samples]) * (far-near)/N_samples\n    # r(t) = o + t*d\n    pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None]\n    \n    # Run network\n    pts_flat = np.reshape(pts, [-1,3])\n    if bvals is not None:\n        pts_flat = np.concatenate([np.sin(pts_flat @ bvals.T), \n                                np.cos(pts_flat @ bvals.T)], axis=-1)\n        \n    raw = model.apply(params, pts_flat)\n    raw = np.reshape(raw, list(pts.shape[:-1]) + [4])\n    \n    # Compute opacities and colors\n    rgb, sigma_a = raw[...,:3], raw[...,3]\n    sigma_a = jax.nn.relu(sigma_a)\n    rgb = jax.nn.sigmoid(rgb) \n    \n    # Do volume rendering\n    dists = np.concatenate([z_vals[..., 1:] - z_vals[..., :-1], np.broadcast_to([1e10], z_vals[...,:1].shape)], -1) \n    alpha = 1. - np.exp(-sigma_a * dists)\n    trans = np.minimum(1., 1. - alpha + 1e-10)\n    trans = np.concatenate([np.ones_like(trans[...,:1]), trans[...,:-1]], -1)  \n    weights = alpha * np.cumprod(trans, -1)\n    \n    rgb_map = np.sum(weights[...,None] * rgb, -2) \n    acc_map = np.sum(weights, -1)\n    \n    if False:\n        rgb_map = rgb_map + (1.-acc_map[..., None])\n    \n    if not allret:\n        return rgb_map\n    \n    depth_map = np.sum(weights * z_vals, -1) \n\n    return rgb_map, depth_map, acc_map\n\n\ndef render_fn_inner(rnd_input, model, params, bvals, rays, near, far, rand, allret, N_samples):\n    return render_rays(rnd_input, model, params, bvals, rays, near, far, \n                       N_samples=N_samples, rand=rand, allret=allret)\n\n# optimize render_fn_inner by JIT (func in, func out)\nrender_fn_inner = jit(render_fn_inner, static_argnums=(1, 7, 8, 9))\n\n\ndef render_fn(rnd_input, model, params, bvals, rays, near, far, N_samples, rand):\n    chunk = 5\n    for i in range(0, rays.shape[1], chunk):\n        out = render_fn_inner(rnd_input, model, params, bvals, rays[:,i:i+chunk], near, far, rand, True, N_samples)\n        if i==0:\n            rets = out\n        else:\n            rets = [np.concatenate([a, b], 0) for a, b in zip(rets, out)]\n    return rets","metadata":{"execution":{"iopub.status.busy":"2021-07-04T06:27:11.719345Z","iopub.execute_input":"2021-07-04T06:27:11.719675Z","iopub.status.idle":"2021-07-04T06:27:11.741107Z","shell.execute_reply.started":"2021-07-04T06:27:11.719638Z","shell.execute_reply":"2021-07-04T06:27:11.740228Z"},"trusted":true},"execution_count":71,"outputs":[]},{"cell_type":"markdown","source":"### 3. NeRF Model Architecture","metadata":{}},{"cell_type":"code","source":"class Model(hk.Module):\n    def __init__(self):\n        super().__init__()\n        self.width = 256\n        self.depth = 6\n        self.use_viewdirs = False\n                \n    def __call__(self, coords, view_dirs=None):\n        sh = coords.shape\n        if self.use_viewdirs:\n            viewdirs = None\n            viewdirs = np.repeat(viewdirs[...,None,:], coords.shape[-2], axis=-2)\n            viewdirs /= np.linalg.norm(viewdirs, axis=-1, keepdims=True)\n            viewdirs = np.reshape(viewdirs, (-1,3))\n            viewdirs = hk.Linear(output_size=self.width//2)(viewdirs)\n            viewdirs = jax.nn.relu(viewdirs)\n        coords = np.reshape(coords, [-1,3])\n        \n        # positional encoding\n        x = np.concatenate([np.concatenate([np.sin(coords*(2**i)), np.cos(coords*(2**i))], axis=-1) for i in np.linspace(0,8,20)], axis=-1)\n\n        for _ in range(self.depth-1):\n            x = hk.Linear(output_size=self.width)(x)\n            x = jax.nn.relu(x)\n            \n        if self.use_viewdirs:\n            density = hk.Linear(output_size=1)(x)\n            x = np.concatenate([x,viewdirs], axis=-1)\n            x = hk.Linear(output_size=self.width)(x)\n            x = jax.nn.relu(x)\n            rgb = hk.Linear(output_size=3)(x)\n            out = np.concatenate([density, rgb], axis=-1)\n        else:\n            out = hk.Linear(output_size=4)(x)\n        out = np.reshape(out, list(sh[:-1]) + [4])\n        return out","metadata":{"execution":{"iopub.status.busy":"2021-07-04T06:27:11.742202Z","iopub.execute_input":"2021-07-04T06:27:11.742739Z","iopub.status.idle":"2021-07-04T06:27:11.761243Z","shell.execute_reply.started":"2021-07-04T06:27:11.742684Z","shell.execute_reply":"2021-07-04T06:27:11.760296Z"},"trusted":true},"execution_count":72,"outputs":[]},{"cell_type":"markdown","source":"### 4. Training Loop","metadata":{}},{"cell_type":"code","source":"batch_size = 64\nN_samples = 128\ninner_step_size = 1\n\nmodel = hk.without_apply_rng(hk.transform(lambda x, y=None: Model()(x, y)))\n\nmse_fn = jit(lambda x, y: np.mean((x - y)**2))\npsnr_fn = jit(lambda x, y: -10 * np.log10(mse_fn(x, y)))\n\n@jit\ndef single_step(rng, image, rays, params, bds):\n    def sgd(param, update):\n        return param - inner_step_size * update\n    \n    rng, rng_inputs = jax.random.split(rng)\n    def loss_model(params):\n        g = render_rays(rng_inputs, model, params, None, rays, bds[0], bds[1], N_samples, rand=True)\n        return mse_fn(g, image)\n    \n    model_loss, grad = jax.value_and_grad(loss_model)(params)\n    new_params = jax.tree_multimap(sgd, params, grad)\n    return rng, new_params, model_loss\n\ndef update_network_weights(rng, images, rays, params, inner_steps, bds):\n    for _ in range(inner_steps):\n        rng, rng_input = random.split(rng)\n        idx = random.randint(rng_input, shape=(batch_size,), minval=0, maxval=images.shape[0])\n        image_sub = images[idx,:]\n        rays_sub = rays[:,idx,:]\n        \n        rng, params, loss = single_step(rng, image_sub, rays_sub, params, bds)\n    return rng, params, loss","metadata":{"execution":{"iopub.status.busy":"2021-07-04T06:27:11.762522Z","iopub.execute_input":"2021-07-04T06:27:11.763042Z","iopub.status.idle":"2021-07-04T06:27:11.781650Z","shell.execute_reply.started":"2021-07-04T06:27:11.763004Z","shell.execute_reply":"2021-07-04T06:27:11.780574Z"},"trusted":true},"execution_count":73,"outputs":[]},{"cell_type":"code","source":"plt_groups = {'Train PSNR':[], 'Test PSNR':[]}\nplotlosses_model = PlotLosses(groups=plt_groups)","metadata":{"execution":{"iopub.status.busy":"2021-07-04T06:27:11.782906Z","iopub.execute_input":"2021-07-04T06:27:11.783199Z","iopub.status.idle":"2021-07-04T06:27:11.797320Z","shell.execute_reply.started":"2021-07-04T06:27:11.783172Z","shell.execute_reply":"2021-07-04T06:27:11.796476Z"},"trusted":true},"execution_count":74,"outputs":[]},{"cell_type":"code","source":"max_iters = 150000\n\ninner_update_steps = 64\nlr = 5e-4\n\nexp_name = f'{DATASET}_ius_{inner_update_steps}_ilr_{inner_step_size}_olr_{lr}_bs_{batch_size}'\nexp_dir = f'checkpoint/phototourism_checkpoints/{exp_name}/'\n\nif not os.path.exists(exp_dir):\n    os.makedirs(exp_dir)\n\nparams = model.init(rng, np.ones((1,3)))\n\nopt = optimizers.adam(lr)\nopt_state = opt.init_fun(params)\n\ntest_inner_steps = 64\n\n\ndef update_model(rng, params, opt_state, image, rays, bds):\n    rng, new_params, model_loss = update_network_weights(rng, image, rays, params, inner_update_steps, bds)\n    \n    def calc_grad(params, new_params):\n        return params - new_params\n    model_grad = jax.tree_multimap(calc_grad, params, new_params)\n    \n    updates, opt_state = opt.update(model_grad, opt_state)\n    params = optimizers.apply_updates(params, updates)\n    return rng, params, opt_state, model_loss\n\n@jit\ndef update_model_single(rng, params, opt_state, image, rays, bds):\n    rng, new_params, model_loss = single_step(rng, image, rays, params, bds)\n    \n    def calc_grad(params, new_params):\n        return params - new_params\n    model_grad = jax.tree_multimap(calc_grad, params, new_params)\n    \n    updates, opt_state = opt.update(model_grad, opt_state)\n    params = optimizers.apply_updates(params, updates)\n    return rng, params, opt_state, model_loss\n\n\n\nplt_groups['Train PSNR'].append(exp_name+f'_train')\nplt_groups['Test PSNR'].append(exp_name+f'_test')\nstep = 0\n\ntrain_psnrs = []\nrng = jax.random.PRNGKey(0)\n\ntrain_steps = []\ntrain_psnrs_all = []\ntest_steps = []\ntest_psnrs_all = []\nfor step in tqdm(range(max_iters)):\n    try:\n        rng, rng_input = jax.random.split(rng)\n        img_idx = random.randint(rng_input, shape=(), minval=0, maxval=len(imgfiles)-25)        \n        images, rays, bds = get_example(img_idx, downsample=1)\n    except:\n        print('data loading error')\n        raise\n        continue\n        \n\n    images = np.reshape(images, (-1,3))\n    rays = np.reshape(rays, (2,-1,3))\n\n    if inner_update_steps == 1:\n        rng, rng_input = random.split(rng)\n        idx = random.randint(rng_input, shape=(batch_size,), minval=0, maxval=images.shape[0])\n        rng, params, opt_state, loss = update_model_single(rng, params, opt_state, \n                                                           images[idx,:], rays[:,idx,:], bds)\n    else:\n        rng, params, opt_state, loss = update_model(rng, params, opt_state, \n                                                    images, rays, bds)\n    train_psnrs.append(-10 * np.log10(loss))\n    \n    # track model loss\n    if step % 250 == 0:\n        plotlosses_model.update({exp_name+'_train':np.mean(np.array(train_psnrs))}, current_step=step)\n        train_steps.append(step)\n        train_psnrs_all.append(np.mean(np.array(train_psnrs)))\n        train_psnrs = []\n        \n    # run validation\n    if step % 500 == 0 and step != 0:\n        test_psnr = []\n        for ti in range(5):\n            test_images, test_rays, bds = get_example(ti, split='val', downsample=2)\n\n            test_images, test_holdout_images = np.split(test_images, [test_images.shape[1]//2], axis=1)\n            test_rays, test_holdout_rays = np.split(test_rays, [test_rays.shape[2]//2], axis=2)\n\n            test_images_flat = np.reshape(test_images, (-1,3))\n            test_rays = np.reshape(test_rays, (2,-1,3))\n\n            rng, test_params, test_inner_loss = update_network_weights(rng, test_images_flat, test_rays, params, test_inner_steps, bds)\n\n            test_result = np.clip(render_fn(rng, model, test_params, None, test_holdout_rays, bds[0], bds[1], N_samples, rand=False)[0], 0, 1)\n            test_psnr.append(psnr_fn(test_holdout_images, test_result))\n        test_psnr = np.mean(np.array(test_psnr))\n\n        test_steps.append(step)\n        test_psnrs_all.append(test_psnr)\n        \n        plotlosses_model.update({exp_name+'_test':test_psnr}, current_step=step)\n        plotlosses_model.send()\n\n        plt.figure(figsize=(15,5))   \n        plt.subplot(1,3, 1)\n        plt.imshow(test_images)\n        plt.subplot(1,3, 2)\n        plt.imshow(test_holdout_images)\n        plt.subplot(1,3, 3)\n        plt.imshow(test_result)\n        plt.show()\n        \n    # save model checkpoint + render sample view on test set for model check\n    if step % 10000 == 0 and step != 0:\n        test_images, test_rays, bds = get_example(0, split='test')\n        test_images_flat = np.reshape(test_images, (-1,3))\n        test_rays = np.reshape(test_rays, (2,-1,3))\n        rng, test_params_1, test_inner_loss = update_network_weights(rng, test_images_flat, test_rays, params, test_inner_steps, bds)\n\n        test_images, test_rays, bds = get_example(1, split='test')\n        test_images_flat = np.reshape(test_images, (-1,3))\n        test_rays = np.reshape(test_rays, (2,-1,3))\n        rng, test_params_2, test_inner_loss = update_network_weights(rng, test_images_flat, test_rays, params, test_inner_steps, bds)\n        \n        poses = posedata['c2w_mats']\n        c2w = poses_avg(poses)\n        focal = .8\n        render_poses = render_path_spiral(c2w, c2w[:3,1], [.1, .1, .05], focal, zrate=.5, rots=2, N=120)\n        \n        bds = np.array([5., 25.]) * .05\n        H = 128\n        W = H*3//2\n        f = H * 1.\n        kinv = np.array(\n            [1./f, 0, -W*.5/f,\n             0, -1./f, H*.5/f,\n             0, 0, -1.]\n        ).reshape([3,3])\n        i, j = np.meshgrid(np.arange(0,W), np.arange(0,H), indexing='xy')\n        renders = []\n        for p, c2w in enumerate(tqdm(render_poses)):\n            rays = get_rays(c2w, kinv, i, j)\n            interp = p / len(render_poses)\n            interp_params = jax.tree_multimap(lambda x, y: y*p/len(render_poses) + x*(1-p/len(render_poses)), test_params_1, test_params_2)\n            result = render_fn(rng, model, interp_params, None, rays, bds[0], bds[1], N_samples, rand=False)[0]\n            renders.append(result)\n        \n        renders = (np.clip(np.array(renders), 0, 1)*255).astype(np.uint8)\n        imageio.mimwrite(f'{exp_dir}render_sprial_{step}.mp4', renders, fps=30, quality=8)\n        \n        plt.plot(train_steps, train_psnrs_all)\n        plt.savefig(f'{exp_dir}train_curve_{step}.png')\n        \n        plt.plot(test_steps, test_psnrs_all)\n        plt.savefig(f'{exp_dir}test_curve_{step}.png')\n        \n        with open(f'{exp_dir}checkpount_{step}.pkl', 'wb') as file:\n            pickle.dump(params, file)","metadata":{"execution":{"iopub.status.busy":"2021-07-04T06:27:11.798560Z","iopub.execute_input":"2021-07-04T06:27:11.799038Z","iopub.status.idle":"2021-07-04T06:27:11.942383Z","shell.execute_reply.started":"2021-07-04T06:27:11.798994Z","shell.execute_reply":"2021-07-04T06:27:11.940947Z"},"trusted":true},"execution_count":75,"outputs":[{"traceback":["\u001b[0;31m---------------------------------------------------------------------------\u001b[0m","\u001b[0;31mAttributeError\u001b[0m                            Traceback (most recent call last)","\u001b[0;32m<ipython-input-75-e3130bb95dc5>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     13\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     14\u001b[0m \u001b[0mopt\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0moptimizers\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madam\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 15\u001b[0;31m \u001b[0mopt_state\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mopt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minit_fun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     16\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     17\u001b[0m \u001b[0mtest_inner_steps\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m64\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;31mAttributeError\u001b[0m: 'Optimizer' object has no attribute 'init_fun'"],"ename":"AttributeError","evalue":"'Optimizer' object has no attribute 'init_fun'","output_type":"error"}]},{"cell_type":"markdown","source":"### 1. PLAYGROUND\n- optimizers API: https://jax.readthedocs.io/en/latest/_modules/jax/experimental/optimizers.html#adam","metadata":{}},{"cell_type":"code","source":"from flax import linen as nn","metadata":{"execution":{"iopub.status.busy":"2021-07-04T06:27:17.457842Z","iopub.execute_input":"2021-07-04T06:27:17.458453Z","iopub.status.idle":"2021-07-04T06:27:17.462662Z","shell.execute_reply.started":"2021-07-04T06:27:17.458415Z","shell.execute_reply":"2021-07-04T06:27:17.461847Z"},"trusted":true},"execution_count":76,"outputs":[]},{"cell_type":"code","source":"nn.Dense?","metadata":{"execution":{"iopub.status.busy":"2021-07-04T06:28:32.651658Z","iopub.execute_input":"2021-07-04T06:28:32.652070Z","iopub.status.idle":"2021-07-04T06:28:32.728399Z","shell.execute_reply.started":"2021-07-04T06:28:32.652032Z","shell.execute_reply":"2021-07-04T06:28:32.726998Z"},"trusted":true},"execution_count":79,"outputs":[{"output_type":"display_data","data":{"text/plain":"\u001b[0;31mInit signature:\u001b[0m\n\u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mDense\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\u001b[0m\n\u001b[0;34m\u001b[0m    \u001b[0mfeatures\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n\u001b[0;34m\u001b[0m    \u001b[0muse_bias\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mbool\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n\u001b[0;34m\u001b[0m    \u001b[0mdtype\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mAny\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m<\u001b[0m\u001b[0;32mclass\u001b[0m \u001b[0;34m'jax._src.numpy.lax_numpy.float32'\u001b[0m\u001b[0;34m>\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n\u001b[0;34m\u001b[0m    \u001b[0mprecision\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mAny\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n\u001b[0;34m\u001b[0m    \u001b[0mkernel_init\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mCallable\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mAny\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mIterable\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mAny\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mAny\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m<\u001b[0m\u001b[0mfunction\u001b[0m \u001b[0mvariance_scaling\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m<\u001b[0m\u001b[0mlocals\u001b[0m\u001b[0;34m>\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minit\u001b[0m \u001b[0mat\u001b[0m \u001b[0;36m0x7f66d9aac050\u001b[0m\u001b[0;34m>\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n\u001b[0;34m\u001b[0m    \u001b[0mbias_init\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mCallable\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mAny\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mIterable\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mAny\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mAny\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m<\u001b[0m\u001b[0mfunction\u001b[0m \u001b[0mzeros\u001b[0m \u001b[0mat\u001b[0m \u001b[0;36m0x7f66e16b1050\u001b[0m\u001b[0;34m>\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n\u001b[0;34m\u001b[0m    \u001b[0mparent\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mType\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mForwardRef\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Module'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mType\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mForwardRef\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Scope'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mType\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mForwardRef\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'_Sentinel'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mNoneType\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m<\u001b[0m\u001b[0mflax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinen\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_Sentinel\u001b[0m \u001b[0mobject\u001b[0m \u001b[0mat\u001b[0m \u001b[0;36m0x7f66d9abf8d0\u001b[0m\u001b[0;34m>\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n\u001b[0;34m\u001b[0m    \u001b[0mname\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mstr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n\u001b[0;34m\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;31mDocstring:\u001b[0m     \nA linear transformation applied over the last dimension of the input.\n\nAttributes:\n  features: the number of output features.\n  use_bias: whether to add a bias to the output (default: True).\n  dtype: the dtype of the computation (default: float32).\n  precision: numerical precision of the computation see `jax.lax.Precision`\n    for details.\n  kernel_init: initializer function for the weight matrix.\n  bias_init: initializer function for the bias.\n\u001b[0;31mFile:\u001b[0m           /opt/conda/lib/python3.7/site-packages/flax/linen/linear.py\n\u001b[0;31mType:\u001b[0m           type\n\u001b[0;31mSubclasses:\u001b[0m     \n"},"metadata":{}}]},{"cell_type":"code","source":"class ModelFlax(nn.Module):\n    width = 256\n    depth = 6\n    \n    @nn.compact\n    def __call__(self, coords):\n        sh = coords.shape\n        coords = np.reshape(coords, [-1,3])\n        \n        # positional encoding\n        x = np.concatenate([np.concatenate([np.sin(coords*(2**i)), np.cos(coords*(2**i))], axis=-1) for i in np.linspace(0,8,20)], axis=-1)\n\n        for idx in range(self.depth-1):\n            #x = hk.Linear(output_size=self.width)(x)\n            x = nn.Dense(self.depth, name=f'fc{idx}')(x)\n            x = nn.relu(x)\n\n        #out = hk.Linear(output_size=4)(x)\n        out = nn.Dense(4, name='fc_last')(x)\n        out = np.reshape(out, list(sh[:-1]) + [4])\n        return out","metadata":{"execution":{"iopub.status.busy":"2021-07-04T06:37:54.690404Z","iopub.execute_input":"2021-07-04T06:37:54.690877Z","iopub.status.idle":"2021-07-04T06:37:54.702764Z","shell.execute_reply.started":"2021-07-04T06:37:54.690841Z","shell.execute_reply":"2021-07-04T06:37:54.701461Z"},"trusted":true},"execution_count":95,"outputs":[]},{"cell_type":"code","source":"model = ModelFlax()\nkey1, key2 = random.split(jax.random.PRNGKey(0))\ndummy_x = random.normal(key1, (1, 3))\nparams = model.init(key2, dummy_x)","metadata":{"execution":{"iopub.status.busy":"2021-07-04T06:37:55.558848Z","iopub.execute_input":"2021-07-04T06:37:55.559294Z","iopub.status.idle":"2021-07-04T06:37:56.881656Z","shell.execute_reply.started":"2021-07-04T06:37:55.559253Z","shell.execute_reply":"2021-07-04T06:37:56.878424Z"},"trusted":true},"execution_count":96,"outputs":[]},{"cell_type":"code","source":"dummy_x = \nModelFlax.init","metadata":{"execution":{"iopub.status.busy":"2021-07-04T06:33:01.022645Z","iopub.execute_input":"2021-07-04T06:33:01.023024Z","iopub.status.idle":"2021-07-04T06:33:01.032804Z","shell.execute_reply.started":"2021-07-04T06:33:01.022993Z","shell.execute_reply":"2021-07-04T06:33:01.031934Z"},"trusted":true},"execution_count":85,"outputs":[{"execution_count":85,"output_type":"execute_result","data":{"text/plain":"<function flax.linen.module.Module.init(self, rngs: Union[Any, Dict[str, Any]], *args, method: Union[Callable[..., Any], NoneType] = None, mutable: Union[bool, str, Container[str], ForwardRef('DenyList')] = DenyList(deny='intermediates'), **kwargs) -> flax.core.frozen_dict.FrozenDict[str, typing.Mapping[str, typing.Any]]>"},"metadata":{}}]},{"cell_type":"code","source":"model = hk.without_apply_rng(hk.transform(lambda x, y=None: ModelHaiku()(x, y)))\nparams = model.init(rng, np.ones((1,3)))\n","metadata":{},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"class ModelHaiku(hk.Module):\n    def __init__(self):\n        super().__init__()\n        self.width = 256\n        self.depth = 6\n                \n    def __call__(self, coords):\n        sh = coords.shape\n        coords = np.reshape(coords, [-1,3])\n        \n        # positional encoding\n        x = np.concatenate([np.concatenate([np.sin(coords*(2**i)), np.cos(coords*(2**i))], axis=-1) for i in np.linspace(0,8,20)], axis=-1)\n\n        for _ in range(self.depth-1):\n            x = hk.Linear(output_size=self.width)(x)\n            x = jax.nn.relu(x)\n\n        out = hk.Linear(output_size=4)(x)\n        out = np.reshape(out, list(sh[:-1]) + [4])\n        return out","metadata":{"execution":{"iopub.status.busy":"2021-07-04T06:27:20.392396Z","iopub.execute_input":"2021-07-04T06:27:20.393211Z","iopub.status.idle":"2021-07-04T06:27:20.403900Z","shell.execute_reply.started":"2021-07-04T06:27:20.393166Z","shell.execute_reply":"2021-07-04T06:27:20.402592Z"},"trusted":true},"execution_count":77,"outputs":[]},{"cell_type":"code","source":"## OLD JAX + HAIKU\ndef single_step(rng, image, rays, params, bds):\n    def sgd(param, update):\n        return param - inner_step_size * update\n    \n    rng, rng_inputs = jax.random.split(rng)\n    def loss_model(params):\n        g = render_rays(rng_inputs, model, params, None, rays, bds[0], bds[1], N_samples, rand=True)\n        return mse_fn(g, image)\n    \n    model_loss, grad = jax.value_and_grad(loss_model)(params)\n    new_params = jax.tree_multimap(sgd, params, grad)\n    return rng, new_params, model_loss\n\n\nmodel = hk.without_apply_rng(hk.transform(lambda x, y=None: ModelHaiku()(x, y)))\nparams = model.init(rng, np.ones((1,3)))\nopt = optimizers.adam(lr)\nopt_state = opt.init_fun(params)\nupdates, opt_state = opt.update(model_grad, opt_state)\nparams = optimizers.apply_updates(params, updates)","metadata":{"execution":{"iopub.status.busy":"2021-07-04T06:27:21.016922Z","iopub.execute_input":"2021-07-04T06:27:21.017295Z","iopub.status.idle":"2021-07-04T06:27:21.065052Z","shell.execute_reply.started":"2021-07-04T06:27:21.017263Z","shell.execute_reply":"2021-07-04T06:27:21.062768Z"},"trusted":true},"execution_count":78,"outputs":[{"traceback":["\u001b[0;31m---------------------------------------------------------------------------\u001b[0m","\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)","\u001b[0;32m<ipython-input-78-72b1606120c8>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     15\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     16\u001b[0m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhk\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwithout_apply_rng\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhk\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtransform\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mlambda\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mModelHaiku\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 17\u001b[0;31m \u001b[0mparams\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrng\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mones\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     18\u001b[0m \u001b[0mopt\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0moptimizers\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madam\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     19\u001b[0m \u001b[0mopt_state\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mopt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minit_fun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/opt/conda/lib/python3.7/site-packages/haiku/_src/transform.py\u001b[0m in \u001b[0;36minit_fn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    109\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    110\u001b[0m   \u001b[0;32mdef\u001b[0m \u001b[0minit_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 111\u001b[0;31m     \u001b[0mparams\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstate\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    112\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mstate\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    113\u001b[0m       raise ValueError(\"If your transformed function uses `hk.{get,set}_state` \"\n","\u001b[0;32m/opt/conda/lib/python3.7/site-packages/haiku/_src/transform.py\u001b[0m in \u001b[0;36minit_fn\u001b[0;34m(rng, *args, **kwargs)\u001b[0m\n\u001b[1;32m    295\u001b[0m     \u001b[0mrng\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mto_prng_sequence\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrng\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0merr_msg\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mINIT_RNG_ERROR\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    296\u001b[0m     \u001b[0;32mwith\u001b[0m \u001b[0mbase\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnew_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrng\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mrng\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mctx\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 297\u001b[0;31m       \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    298\u001b[0m     \u001b[0;32mreturn\u001b[0m \u001b[0mctx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcollect_params\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mctx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcollect_initial_state\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    299\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m<ipython-input-78-72b1606120c8>\u001b[0m in \u001b[0;36m<lambda>\u001b[0;34m(x, y)\u001b[0m\n\u001b[1;32m     14\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     15\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 16\u001b[0;31m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhk\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwithout_apply_rng\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhk\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtransform\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mlambda\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mModelHaiku\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     17\u001b[0m \u001b[0mparams\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrng\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mones\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     18\u001b[0m \u001b[0mopt\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0moptimizers\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madam\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/opt/conda/lib/python3.7/site-packages/haiku/_src/module.py\u001b[0m in \u001b[0;36mwrapped\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m    426\u001b[0m         \u001b[0mf\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mstateful\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnamed_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlocal_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    427\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 428\u001b[0;31m       \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    429\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    430\u001b[0m       \u001b[0;31m# Module names are set in the constructor. If `f` is the constructor then\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/opt/conda/lib/python3.7/site-packages/haiku/_src/module.py\u001b[0m in \u001b[0;36mrun_interceptors\u001b[0;34m(bound_method, method_name, self, *args, **kwargs)\u001b[0m\n\u001b[1;32m    277\u001b[0m   \u001b[0;34m\"\"\"Runs any method interceptors or the original method.\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    278\u001b[0m   \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0minterceptor_stack\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 279\u001b[0;31m     \u001b[0;32mreturn\u001b[0m \u001b[0mbound_method\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    280\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    281\u001b[0m   ctx = MethodContext(module=self,\n","\u001b[0;31mTypeError\u001b[0m: __call__() takes 2 positional arguments but 3 were given"],"ename":"TypeError","evalue":"__call__() takes 2 positional arguments but 3 were given","output_type":"error"}]},{"cell_type":"code","source":"## NEW JAX + HAIKU\nlr = 1e-3\nnum_steps = 3\n\nmodel = hk.without_apply_rng(hk.transform(lambda x: ModelHaiku()(x)))\nparams = model.init(rng, np.ones((1,3)))\nopt_init, opt_update, get_params = optimizers.adam(lr)\nopt_state = opt_init(params)\n\n\ndef single_step_v2(step, rng, image, rays, bds, opt_state):\n    def loss_model(params):\n        g = render_rays(rng_inputs, model, params,\n                        None, rays, bds[0], bds[1], \n                        N_samples, rand=True)\n        return mse_fn(g, image)\n    rng, rng_inputs = jax.random.split(rng)\n    value, grads = jax.value_and_grad(loss_model)(get_params(opt_state))\n    opt_state = opt_update(step, grads, opt_state)\n    return value, opt_state","metadata":{"execution":{"iopub.status.busy":"2021-07-04T06:27:11.949773Z","iopub.status.idle":"2021-07-04T06:27:11.950196Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"rng = jax.random.PRNGKey(0)\n\nfor istep in range(num_steps):\n    rng, rng_input = jax.random.split(rng)\n    img_idx = random.randint(rng_input, shape=(), minval=0, maxval=len(imgfiles)-25)\n    images, rays, bds = get_example(img_idx, downsample=1)\n    images = np.reshape(images, (-1,3))\n    rays = np.reshape(rays, (2,-1,3))\n    rng, rng_input = random.split(rng)\n    idx = random.randint(rng_input, shape=(batch_size,), minval=0, maxval=images.shape[0])\n    loss, opt_state = single_step_v2(istep, rng, images[idx,:], rays[:,idx,:], bds, opt_state)","metadata":{"execution":{"iopub.status.busy":"2021-07-04T06:27:11.951115Z","iopub.status.idle":"2021-07-04T06:27:11.951524Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"## NEW JAX + FLAX!\nlr = 1e-3\nnum_steps = 3\n\n# model = hk.without_apply_rng(hk.transform(lambda x: ModelHaiku()(x)))\n# params = model.init(rng, np.ones((1,3)))\n\nmodel = ModelFlax()\nkey1, key2 = random.split(jax.random.PRNGKey(0))\ndummy_x = random.normal(key1, (1, 3))\nparams = model.init(key2, dummy_x)\n\nopt_init, opt_update, get_params = optimizers.adam(lr)\nopt_state = opt_init(params)\n\n\ndef single_step_v2(step, rng, image, rays, bds, opt_state):\n    def loss_model(params):\n        g = render_rays(rng_inputs, model, params,\n                        None, rays, bds[0], bds[1], \n                        N_samples, rand=True)\n        return mse_fn(g, image)\n    rng, rng_inputs = jax.random.split(rng)\n    value, grads = jax.value_and_grad(loss_model)(get_params(opt_state))\n    opt_state = opt_update(step, grads, opt_state)\n    return value, opt_state","metadata":{"execution":{"iopub.status.busy":"2021-07-04T06:39:14.752065Z","iopub.execute_input":"2021-07-04T06:39:14.752535Z","iopub.status.idle":"2021-07-04T06:39:14.838259Z","shell.execute_reply.started":"2021-07-04T06:39:14.752490Z","shell.execute_reply":"2021-07-04T06:39:14.837384Z"},"trusted":true},"execution_count":100,"outputs":[]},{"cell_type":"code","source":"opt_state","metadata":{"execution":{"iopub.status.busy":"2021-07-04T06:39:18.950560Z","iopub.execute_input":"2021-07-04T06:39:18.951230Z","iopub.status.idle":"2021-07-04T06:39:18.998534Z","shell.execute_reply.started":"2021-07-04T06:39:18.951187Z","shell.execute_reply":"2021-07-04T06:39:18.997174Z"},"collapsed":true,"jupyter":{"outputs_hidden":true},"trusted":true},"execution_count":101,"outputs":[{"execution_count":101,"output_type":"execute_result","data":{"text/plain":"OptimizerState(packed_state=([DeviceArray([0., 0., 0., 0., 0., 0.], dtype=float32), DeviceArray([0., 0., 0., 0., 0., 0.], dtype=float32), DeviceArray([0., 0., 0., 0., 0., 0.], dtype=float32)], [DeviceArray([[ 0.12091248,  0.03934956,  0.09956338, -0.03747796,\n              -0.04041335,  0.02077009],\n             [ 0.05760278, -0.0443754 ,  0.05104939, -0.14502399,\n               0.0776381 ,  0.1672416 ],\n             [ 0.11950634, -0.01816947,  0.01361674,  0.03426957,\n              -0.04762232, -0.01502447],\n             [-0.07393871,  0.04445767,  0.02080443,  0.00191198,\n               0.06556749,  0.04923652],\n             [ 0.0131189 ,  0.02381099, -0.01838387,  0.12599164,\n              -0.06083291, -0.05262698],\n             [-0.1097906 ,  0.09892927, -0.13584226,  0.03626684,\n               0.02823985, -0.10594288],\n             [ 0.0871003 , -0.09209569, -0.0871902 ,  0.0108161 ,\n              -0.05293119, -0.06841837],\n             [ 0.04326692, -0.13521153,  0.20045498, -0.10092928,\n               0.16232044, -0.05342549],\n             [ 0.0112839 ,  0.11713059,  0.04803858,  0.06951734,\n               0.0195748 ,  0.06725274],\n             [ 0.03751812, -0.11262533, -0.01647667, -0.03335583,\n              -0.13196915,  0.01376943],\n             [-0.04743343,  0.06179082,  0.04837865,  0.07638828,\n              -0.09461062, -0.19005549],\n             [ 0.08255602, -0.13530254,  0.06071713, -0.04774788,\n              -0.03129638,  0.0693047 ],\n             [ 0.04571046, -0.03062774,  0.0852575 , -0.10863639,\n               0.04058903,  0.04844888],\n             [-0.04248823, -0.05062192, -0.01092482, -0.02707409,\n              -0.13995013, -0.10880554],\n             [-0.11670938,  0.10568915, -0.16793151, -0.01668308,\n              -0.06549685,  0.04782696],\n             [ 0.06555521,  0.15952255, -0.02444034, -0.14221518,\n               0.00795548,  0.00058907],\n             [ 0.06769255, -0.10001128,  0.01939742, -0.16975056,\n               0.08136977,  0.05386211],\n             [-0.08528083,  0.13880436, -0.12550238,  0.07701986,\n              -0.04409956, -0.13809827],\n             [-0.04213446, -0.1537718 ,  0.07943241, -0.17456312,\n               0.15739301, -0.02497932],\n             [ 0.05776352, -0.06588574,  0.07887202,  0.11499051,\n              -0.03748027, -0.01763058],\n             [ 0.00971414, -0.02444758, -0.04719648, -0.04523701,\n               0.11060363,  0.07518162],\n             [ 0.17120014, -0.02850895,  0.03745553,  0.01270946,\n              -0.03092966,  0.04966141],\n             [-0.0771074 ,  0.00305039, -0.1426532 , -0.00744342,\n              -0.08475614,  0.07027218],\n             [-0.1148666 ,  0.0939114 , -0.20275104,  0.05386702,\n              -0.03729506, -0.09220184],\n             [-0.08134414, -0.11703359, -0.0546924 ,  0.07569566,\n              -0.03793426,  0.14144461],\n             [ 0.00797677,  0.09448453,  0.06531649,  0.00926861,\n               0.07197642,  0.047566  ],\n             [ 0.01645362, -0.01201729,  0.11911698, -0.05343648,\n               0.0041722 , -0.04309576],\n             [ 0.07480554, -0.15580009,  0.06402453,  0.14318122,\n               0.05626502, -0.05839388],\n             [-0.07621136,  0.08332268,  0.10026689, -0.09124082,\n               0.0228574 , -0.02363835],\n             [ 0.0955759 , -0.18413888,  0.0389773 ,  0.04142505,\n               0.06638285, -0.0562977 ],\n             [ 0.06386134,  0.02720288,  0.00071549,  0.11486595,\n               0.11699079,  0.0286552 ],\n             [-0.08813303, -0.10803533, -0.05704414,  0.09067526,\n              -0.05155005, -0.07570118],\n             [-0.03661836, -0.01282759, -0.11591093, -0.15659498,\n               0.19731197, -0.08506151],\n             [ 0.04415998, -0.16286144, -0.00652541,  0.1599725 ,\n               0.0077349 ,  0.12481923],\n             [ 0.15215434, -0.00402213,  0.01797182,  0.13577195,\n               0.07287586,  0.07491174],\n             [ 0.14486   , -0.03073188,  0.0474685 , -0.08027478,\n               0.08148502, -0.15101376],\n             [ 0.08231065, -0.04066302,  0.00606986, -0.05671879,\n              -0.03834728,  0.03751288],\n             [ 0.05405528, -0.12826225,  0.13363709,  0.06226205,\n               0.1394112 , -0.00298018],\n             [-0.03758067,  0.17012277,  0.11220542,  0.09133592,\n              -0.07800542, -0.06042796],\n             [-0.05766695,  0.06635129,  0.01898533,  0.02079507,\n              -0.04781189,  0.01584896],\n             [ 0.12000843,  0.16353324, -0.09740052,  0.10700001,\n               0.10240921,  0.1396231 ],\n             [ 0.13632071, -0.09265424, -0.03013684,  0.01697874,\n              -0.0997249 ,  0.04033921],\n             [-0.06027675, -0.08162045,  0.0226552 ,  0.03845311,\n               0.15729518,  0.0500581 ],\n             [ 0.00548866, -0.10702777,  0.1363892 ,  0.0008856 ,\n               0.04104868,  0.00985958],\n             [-0.0602491 , -0.0130259 , -0.06303586, -0.15844673,\n               0.06022697, -0.16976918],\n             [-0.03691809, -0.02184887, -0.00489147, -0.07071362,\n               0.01866116,  0.01260312],\n             [ 0.09207677, -0.0602571 ,  0.06722448,  0.03493843,\n               0.08584649,  0.04340586],\n             [ 0.08255793, -0.07338887,  0.04268095, -0.04148448,\n               0.02685298,  0.05295678],\n             [-0.05688781,  0.0053965 ,  0.05855345,  0.05604774,\n               0.11289045,  0.0337536 ],\n             [-0.03694104, -0.04049438,  0.11850937,  0.17993085,\n               0.03200574,  0.10255481],\n             [-0.0245451 , -0.1777949 ,  0.02939972, -0.09398863,\n               0.18782113, -0.1360264 ],\n             [-0.15326034, -0.03229709, -0.0839939 ,  0.09428229,\n               0.0673226 , -0.1048438 ],\n             [ 0.03231962, -0.19225913,  0.05750696, -0.12309711,\n              -0.0318327 ,  0.05704029],\n             [ 0.07545964,  0.17980802,  0.14941423,  0.03874794,\n               0.03335564,  0.13724373],\n             [ 0.13273065,  0.06468265,  0.01514143, -0.03719363,\n               0.11870434,  0.02415203],\n             [-0.10081078, -0.01774029,  0.04868399,  0.02396935,\n               0.0357847 , -0.07437636],\n             [ 0.06230159, -0.08518022, -0.13052619, -0.0298177 ,\n              -0.06633851, -0.15825692],\n             [ 0.07715177, -0.05978196,  0.05070818, -0.11113419,\n               0.10514048, -0.15949102],\n             [-0.06095733, -0.03677897, -0.00935357, -0.12114109,\n               0.05016354,  0.07169061],\n             [-0.10700291,  0.00947172, -0.12596504, -0.11015648,\n              -0.10353563, -0.03361167],\n             [ 0.0635042 , -0.04348218, -0.10771246,  0.18733755,\n               0.17300211, -0.05960783],\n             [ 0.09768728,  0.08765952, -0.03345449, -0.08652668,\n               0.06814633, -0.04087353],\n             [-0.11300852, -0.02791403, -0.03409228, -0.12401181,\n               0.08971363,  0.0175717 ],\n             [ 0.0301298 , -0.06322609, -0.04609533,  0.04999686,\n              -0.06116783, -0.06802417],\n             [-0.11089841,  0.07835439,  0.06409204, -0.11715482,\n              -0.03738579,  0.20215245],\n             [-0.14400601, -0.09568423,  0.00715016,  0.10031349,\n              -0.11718752, -0.0624514 ],\n             [ 0.02698691, -0.09550516,  0.0097266 ,  0.04360985,\n               0.13089257,  0.07150275],\n             [ 0.00952834,  0.17542061,  0.1686825 , -0.04919006,\n              -0.2015965 , -0.13389407],\n             [ 0.01955793, -0.04790407,  0.08299995, -0.04268588,\n               0.01052913,  0.02939646],\n             [-0.0022939 , -0.17001595,  0.01649912,  0.00899479,\n              -0.04769804, -0.03968188],\n             [ 0.0478185 ,  0.06605493,  0.02228497, -0.1544845 ,\n               0.1604688 ,  0.0804143 ],\n             [ 0.01103472,  0.04315194, -0.01691535,  0.07182854,\n              -0.05721959, -0.00783077],\n             [ 0.00145147, -0.05753785,  0.14071816,  0.01597903,\n               0.01853715,  0.18673742],\n             [ 0.10080386, -0.03879032, -0.00257544,  0.00165138,\n               0.01999713,  0.07541166],\n             [ 0.17055124, -0.08377179, -0.07121725, -0.1030369 ,\n               0.03812077,  0.08947916],\n             [ 0.14384921, -0.00645927, -0.09719041, -0.17143993,\n               0.0341155 , -0.15642157],\n             [-0.14672083, -0.18627107,  0.09514766,  0.1655882 ,\n               0.07048287, -0.04141303],\n             [-0.11038305,  0.01266968, -0.13012522,  0.00432192,\n               0.00661376, -0.12586735],\n             [ 0.04380185, -0.04043896,  0.04071361, -0.1095993 ,\n               0.02999796, -0.09467401],\n             [-0.02747682,  0.01572298,  0.17412968, -0.01638444,\n              -0.07010888,  0.09199664],\n             [ 0.14056218,  0.03970063,  0.02137369, -0.00897393,\n              -0.00536025, -0.00209912],\n             [ 0.10748938, -0.05072176, -0.10428328,  0.01942103,\n               0.15795839,  0.14172362],\n             [-0.19731991, -0.02555625, -0.15654613,  0.10174139,\n              -0.10126485,  0.07406058],\n             [ 0.0051182 , -0.04685609, -0.19060855, -0.00625047,\n              -0.15289049,  0.18272123],\n             [-0.19823857, -0.02669578,  0.0014769 , -0.01741426,\n               0.10098754, -0.10494607],\n             [ 0.1671394 ,  0.11931118, -0.0711092 ,  0.06434302,\n               0.05375719, -0.11997687],\n             [ 0.08855402,  0.10628866, -0.14936283,  0.118646  ,\n               0.07264227, -0.08480131],\n             [ 0.02289091, -0.17247637,  0.03616032, -0.08103012,\n              -0.11150467,  0.0383321 ],\n             [ 0.01550049, -0.03067526, -0.0012579 , -0.09759074,\n               0.034106  ,  0.06797335],\n             [ 0.12155779, -0.14713879,  0.05044763, -0.15492183,\n               0.10530532, -0.10844996],\n             [-0.0395255 , -0.00755156, -0.03988497,  0.0533135 ,\n              -0.19197764, -0.14226356],\n             [-0.01343509,  0.00661296,  0.16345409, -0.06798036,\n               0.05056882, -0.05704824],\n             [ 0.11013287, -0.12514994, -0.02816967,  0.07051858,\n               0.0587835 ,  0.09308016],\n             [ 0.0209928 ,  0.17353372,  0.16734958,  0.04956442,\n               0.00093992,  0.00340455],\n             [-0.09287037, -0.12340804,  0.03708031, -0.04078699,\n              -0.18511374,  0.13498192],\n             [ 0.15583818,  0.20263459, -0.07443506, -0.11255585,\n              -0.07247806,  0.0619547 ],\n             [ 0.10758242, -0.00078523, -0.02034423,  0.09368771,\n              -0.1336972 , -0.02547406],\n             [-0.11606696, -0.0244782 ,  0.09383436, -0.09345954,\n               0.00139141,  0.05733661],\n             [-0.18299046, -0.01726742,  0.07100719, -0.04187343,\n               0.17501818,  0.08150984],\n             [-0.06030202, -0.15978408,  0.00370273,  0.03474844,\n               0.02906018, -0.05044353],\n             [-0.10054277,  0.01139852,  0.03567186, -0.08161601,\n               0.00288773, -0.20045884],\n             [ 0.04335637,  0.00288314, -0.09950867,  0.01127051,\n               0.02589588,  0.05039444],\n             [-0.20394403,  0.12326696, -0.1643659 , -0.0723019 ,\n              -0.09912831,  0.04431969],\n             [ 0.16497536, -0.16872397, -0.04751958, -0.11541699,\n               0.18788281,  0.08790423],\n             [ 0.13208571, -0.08465189, -0.02066516, -0.10026854,\n              -0.04670858,  0.03192634],\n             [-0.00228765, -0.00040981, -0.02937566, -0.0937202 ,\n              -0.194368  ,  0.06886023],\n             [-0.00191272,  0.03282534,  0.12399007,  0.05509276,\n               0.18744221, -0.00469013],\n             [ 0.08384778, -0.04454415,  0.02690851,  0.03728205,\n              -0.08604282,  0.02411885],\n             [-0.15516514,  0.09445769,  0.00470266, -0.12434138,\n              -0.04008782, -0.19726186],\n             [-0.17874427,  0.0048454 , -0.00387846, -0.08311484,\n               0.05458146,  0.16586116],\n             [ 0.08855603, -0.07168159,  0.02372633, -0.19512348,\n               0.02156699,  0.17569917],\n             [-0.03421487,  0.01564328, -0.09506329, -0.06229905,\n               0.01248883, -0.11405812],\n             [ 0.08093395,  0.10542554, -0.07732578, -0.04079597,\n              -0.02625835, -0.07219328],\n             [-0.12371734, -0.05892467,  0.1537514 ,  0.02370239,\n              -0.09075558,  0.01606907],\n             [-0.00653294, -0.07476543, -0.01906146, -0.00472945,\n              -0.09184022,  0.15842594],\n             [-0.08626378, -0.1070536 ,  0.03994581,  0.06633023,\n               0.05852774, -0.04346308],\n             [-0.03864894, -0.157292  ,  0.0475421 ,  0.1494269 ,\n              -0.18904321,  0.09033928],\n             [-0.01302009,  0.01030955, -0.15401638,  0.07121224,\n              -0.15460083, -0.06733674],\n             [-0.12611501,  0.00686998,  0.05387065, -0.00674056,\n              -0.00749963, -0.12718946],\n             [ 0.03201933,  0.19906111,  0.20728587,  0.10627076,\n              -0.14703809, -0.0194511 ]], dtype=float32), DeviceArray([[0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.]], dtype=float32), DeviceArray([[0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.]], dtype=float32)], [DeviceArray([0., 0., 0., 0., 0., 0.], dtype=float32), DeviceArray([0., 0., 0., 0., 0., 0.], dtype=float32), DeviceArray([0., 0., 0., 0., 0., 0.], dtype=float32)], [DeviceArray([[ 0.19187102, -0.08696733, -0.2890502 ,  0.625623  ,\n               0.5421405 ,  0.4058794 ],\n             [-0.32072243,  0.21094166,  0.12682122, -0.25806317,\n               0.7562236 ,  0.26565063],\n             [-0.31112608,  0.64321524, -0.6304671 , -0.2540195 ,\n              -0.34511346,  0.8914479 ],\n             [ 0.0170626 ,  0.08604698, -0.3607038 , -0.46026543,\n               0.05988517,  0.16992189],\n             [-0.02520008, -0.5196175 ,  0.40215695, -0.7760595 ,\n              -0.3977024 ,  0.32397616],\n             [-0.47431922,  0.28736246, -0.82317984, -0.62767595,\n              -0.6443609 , -0.45300403]], dtype=float32), DeviceArray([[0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.]], dtype=float32), DeviceArray([[0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.]], dtype=float32)], [DeviceArray([0., 0., 0., 0., 0., 0.], dtype=float32), DeviceArray([0., 0., 0., 0., 0., 0.], dtype=float32), DeviceArray([0., 0., 0., 0., 0., 0.], dtype=float32)], [DeviceArray([[ 7.09373713e-01,  5.58225453e-01,  1.05266884e-01,\n               8.47791284e-02,  4.02058363e-01,  1.03561655e-01],\n             [ 1.63567051e-01,  3.24267261e-02, -6.43938184e-01,\n               8.08098242e-02,  6.05995297e-01, -6.22118413e-01],\n             [-1.64262623e-01, -2.33843014e-01, -1.84899673e-01,\n              -1.95914432e-01,  2.11128980e-01,  2.52425253e-01],\n             [ 6.72683001e-01, -2.49966606e-01, -1.87426805e-01,\n               2.40193129e-01, -5.11264145e-01, -4.63818789e-01],\n             [-2.92124599e-03,  4.82702762e-01, -6.04833551e-02,\n              -8.24802756e-01, -3.40679497e-01,  1.23547213e-02],\n             [ 1.46633625e-01, -7.92004820e-03, -2.78792083e-01,\n              -3.65616666e-04,  8.00321922e-02,  6.98535085e-01]],            dtype=float32), DeviceArray([[0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.]], dtype=float32), DeviceArray([[0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.]], dtype=float32)], [DeviceArray([0., 0., 0., 0., 0., 0.], dtype=float32), DeviceArray([0., 0., 0., 0., 0., 0.], dtype=float32), DeviceArray([0., 0., 0., 0., 0., 0.], dtype=float32)], [DeviceArray([[-0.30406743,  0.42582163, -0.0433044 ,  0.4696062 ,\n               0.162392  , -0.45115194],\n             [-0.02224123,  0.26482254, -0.3821817 ,  0.8167653 ,\n               0.0565933 , -0.27328548],\n             [-0.3441138 ,  0.3751544 ,  0.05917086,  0.4884692 ,\n               0.04338139,  0.2134114 ],\n             [-0.3870027 , -0.3817356 ,  0.31533888,  0.10569409,\n              -0.00611069, -0.19789281],\n             [ 0.12731338,  0.3734065 , -0.10944682,  0.5042247 ,\n              -0.26706368, -0.17124975],\n             [ 0.35682514,  0.1869185 ,  0.23269117, -0.8522264 ,\n              -0.20571367, -0.33262372]], dtype=float32), DeviceArray([[0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.]], dtype=float32), DeviceArray([[0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.]], dtype=float32)], [DeviceArray([0., 0., 0., 0., 0., 0.], dtype=float32), DeviceArray([0., 0., 0., 0., 0., 0.], dtype=float32), DeviceArray([0., 0., 0., 0., 0., 0.], dtype=float32)], [DeviceArray([[ 0.4786232 ,  0.07740223,  0.23075195,  0.24710925,\n              -0.00430446,  0.09901427],\n             [-0.7429878 ,  0.7190723 ,  0.18975034, -0.26684535,\n               0.57330495,  0.10749137],\n             [-0.5879746 ,  0.22385383,  0.5853615 ,  0.15528546,\n               0.563753  , -0.3713637 ],\n             [-0.13496533,  0.30754456, -0.9095284 , -0.00883   ,\n              -0.14046153,  0.64349794],\n             [-0.47050482, -0.0958715 , -0.3475464 ,  0.02285524,\n              -0.23184529, -0.55695194],\n             [ 0.46003628, -0.26375806,  0.1943689 ,  0.04007643,\n              -0.29916492,  0.16186029]], dtype=float32), DeviceArray([[0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.]], dtype=float32), DeviceArray([[0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.],\n             [0., 0., 0., 0., 0., 0.]], dtype=float32)], [DeviceArray([0., 0., 0., 0.], dtype=float32), DeviceArray([0., 0., 0., 0.], dtype=float32), DeviceArray([0., 0., 0., 0.], dtype=float32)], [DeviceArray([[-0.02883793, -0.06974681,  0.24037988, -0.45696107],\n             [-0.66506857, -0.35821897, -0.34971425,  0.35783553],\n             [ 0.66095954, -0.03526216,  0.7372601 , -0.4655306 ],\n             [-0.11639789, -0.82979643, -0.01575499, -0.20149845],\n             [ 0.23644614, -0.3937142 , -0.61596936, -0.27844813],\n             [ 0.9190065 ,  0.37393552,  0.05566693, -0.7466602 ]],            dtype=float32), DeviceArray([[0., 0., 0., 0.],\n             [0., 0., 0., 0.],\n             [0., 0., 0., 0.],\n             [0., 0., 0., 0.],\n             [0., 0., 0., 0.],\n             [0., 0., 0., 0.]], dtype=float32), DeviceArray([[0., 0., 0., 0.],\n             [0., 0., 0., 0.],\n             [0., 0., 0., 0.],\n             [0., 0., 0., 0.],\n             [0., 0., 0., 0.],\n             [0., 0., 0., 0.]], dtype=float32)]), tree_def=PyTreeDef(CustomNode(<class 'flax.core.frozen_dict.FrozenDict'>[()], [{'params': {'fc0': {'bias': *, 'kernel': *}, 'fc1': {'bias': *, 'kernel': *}, 'fc2': {'bias': *, 'kernel': *}, 'fc3': {'bias': *, 'kernel': *}, 'fc4': {'bias': *, 'kernel': *}, 'fc_last': {'bias': *, 'kernel': *}}}])), subtree_defs=(PyTreeDef((*, *, *)), PyTreeDef((*, *, *)), PyTreeDef((*, *, *)), PyTreeDef((*, *, *)), PyTreeDef((*, *, *)), PyTreeDef((*, *, *)), PyTreeDef((*, *, *)), PyTreeDef((*, *, *)), PyTreeDef((*, *, *)), PyTreeDef((*, *, *)), PyTreeDef((*, *, *)), PyTreeDef((*, *, *))))"},"metadata":{}}]},{"cell_type":"code","source":"rng = jax.random.PRNGKey(0)\n\nfor istep in range(num_steps):\n    rng, rng_input = jax.random.split(rng)\n    img_idx = random.randint(rng_input, shape=(), minval=0, maxval=len(imgfiles)-25)\n    images, rays, bds = get_example(img_idx, downsample=1)\n    images = np.reshape(images, (-1,3))\n    rays = np.reshape(rays, (2,-1,3))\n    rng, rng_input = random.split(rng)\n    idx = random.randint(rng_input, shape=(batch_size,), minval=0, maxval=images.shape[0])\n    loss, opt_state = single_step_v2(istep, rng, images[idx,:], rays[:,idx,:], bds, opt_state)","metadata":{"execution":{"iopub.status.busy":"2021-07-04T06:39:32.502213Z","iopub.execute_input":"2021-07-04T06:39:32.502642Z","iopub.status.idle":"2021-07-04T06:39:33.647639Z","shell.execute_reply.started":"2021-07-04T06:39:32.502606Z","shell.execute_reply":"2021-07-04T06:39:33.646620Z"},"trusted":true},"execution_count":102,"outputs":[]},{"cell_type":"code","source":"opt_state","metadata":{"execution":{"iopub.status.busy":"2021-07-04T06:39:35.648041Z","iopub.execute_input":"2021-07-04T06:39:35.648460Z","iopub.status.idle":"2021-07-04T06:39:35.708276Z","shell.execute_reply.started":"2021-07-04T06:39:35.648417Z","shell.execute_reply":"2021-07-04T06:39:35.706959Z"},"collapsed":true,"jupyter":{"outputs_hidden":true},"trusted":true},"execution_count":103,"outputs":[{"execution_count":103,"output_type":"execute_result","data":{"text/plain":"OptimizerState(packed_state=([DeviceArray([-0.00045145, -0.00143954, -0.00263059, -0.00230609,\n              0.00044354,  0.0007933 ], dtype=float32), DeviceArray([ 7.4164245e-07,  9.0824096e-06,  1.9418590e-06,\n              5.1692371e-07,  3.2076082e-06, -8.0520622e-06],            dtype=float32), DeviceArray([3.7262568e-14, 6.9343763e-12, 2.2007938e-13, 4.5521316e-14,\n             4.6548199e-12, 9.6004038e-12], dtype=float32)], [DeviceArray([[ 1.19515203e-01,  3.79956812e-02,  1.00867495e-01,\n              -3.65395024e-02, -3.95988673e-02,  2.29738709e-02],\n             [ 5.77125698e-02, -4.45387252e-02,  4.90781777e-02,\n              -1.46169439e-01,  7.75758103e-02,  1.68276310e-01],\n             [ 1.20900065e-01, -1.95997935e-02,  1.10316165e-02,\n               3.20435129e-02, -4.71294001e-02, -1.43090524e-02],\n             [-7.43934736e-02,  4.30180728e-02,  1.81723312e-02,\n              -3.99712473e-04,  6.60072789e-02,  5.00307828e-02],\n             [ 1.26672154e-02,  2.23715417e-02, -2.10146792e-02,\n               1.23676464e-01, -6.03894331e-02, -5.18337041e-02],\n             [-1.10233277e-01,  9.74887088e-02, -1.38467059e-01,\n               3.39527391e-02,  2.86765136e-02, -1.05141982e-01],\n             [ 8.55626091e-02, -9.34591144e-02, -8.58496949e-02,\n               1.18011115e-02, -5.21217994e-02, -6.61588833e-02],\n             [ 4.34344709e-02, -1.35381013e-01,  1.98374435e-01,\n              -1.02123044e-01,  1.62237614e-01, -5.23203351e-02],\n             [ 1.26959765e-02,  1.15699135e-01,  4.54383790e-02,\n               6.72767386e-02,  2.00665258e-02,  6.79813847e-02],\n             [ 3.70608568e-02, -1.14064977e-01, -1.91099793e-02,\n              -3.56719680e-02, -1.31532311e-01,  1.45644369e-02],\n             [-4.78853099e-02,  6.03514388e-02,  4.57476750e-02,\n               7.40659088e-02, -9.41671953e-02, -1.89262241e-01],\n             [ 8.21234509e-02, -1.36744007e-01,  5.80971800e-02,\n              -5.00687174e-02, -3.08657531e-02,  7.01127574e-02],\n             [ 4.40107286e-02, -3.19981165e-02,  8.66279900e-02,\n              -1.07607514e-01,  4.13880087e-02,  5.07526249e-02],\n             [-4.22691666e-02, -5.07959910e-02, -1.31016057e-02,\n              -2.83072051e-02, -1.40049130e-01, -1.07643224e-01],\n             [-1.15292378e-01,  1.04256600e-01, -1.70544341e-01,\n              -1.89380869e-02, -6.50074780e-02,  4.85670343e-02],\n             [ 6.50937259e-02,  1.58082798e-01, -2.70758662e-02,\n              -1.44539312e-01,  8.38710368e-03,  1.38538575e-03],\n             [ 6.72403350e-02, -1.01450533e-01,  1.67661272e-02,\n              -1.72085881e-01,  8.18130970e-02,  5.46553172e-02],\n             [-8.56901556e-02,  1.37361005e-01, -1.28114462e-01,\n               7.46859163e-02, -4.36814055e-02, -1.37273654e-01],\n             [-4.40381952e-02, -1.55146882e-01,  8.08292851e-02,\n              -1.73489004e-01,  1.58172131e-01, -2.26417482e-02],\n             [ 5.80274314e-02, -6.60630912e-02,  7.66124576e-02,\n               1.13725699e-01, -3.75916548e-02, -1.64222773e-02],\n             [ 1.09880436e-02, -2.58814096e-02, -4.98207957e-02,\n              -4.75093797e-02,  1.11088477e-01,  7.59330243e-02],\n             [ 1.70731843e-01, -2.99488902e-02,  3.48158926e-02,\n               1.03710145e-02, -3.05071846e-02,  5.04600182e-02],\n             [-7.75601938e-02,  1.61137013e-03, -1.45285055e-01,\n              -9.80186835e-03, -8.43129978e-02,  7.10653290e-02],\n             [-1.15224957e-01,  9.24633294e-02, -2.05374062e-01,\n               5.15054911e-02, -3.69068421e-02, -9.13216770e-02],\n             [-8.35028067e-02, -1.18411213e-01, -5.32682054e-02,\n               7.68247694e-02, -3.71921435e-02,  1.43807575e-01],\n             [ 8.27940181e-03,  9.43050608e-02,  6.29868209e-02,\n               7.97823165e-03,  7.18559921e-02,  4.88115028e-02],\n             [ 1.66943446e-02, -1.34530757e-02,  1.16480984e-01,\n              -5.57347909e-02,  4.64809174e-03, -4.23300602e-02],\n             [ 7.43269697e-02, -1.57240435e-01,  6.13770597e-02,\n               1.40817121e-01,  5.66717610e-02, -5.75914197e-02],\n             [-7.66651481e-02,  8.18840712e-02,  9.76340249e-02,\n              -9.36404243e-02,  2.33002231e-02, -2.28453223e-02],\n             [ 9.53125134e-02, -1.85607061e-01,  3.80081609e-02,\n               3.89923938e-02,  6.66604489e-02, -5.37296869e-02],\n             [ 6.14767075e-02,  2.58252937e-02,  2.17701704e-03,\n               1.16079114e-01,  1.17664330e-01,  3.10384557e-02],\n             [-8.77963677e-02, -1.08215772e-01, -5.94331436e-02,\n               8.93637240e-02, -5.16765937e-02, -7.44252503e-02],\n             [-3.69373262e-02, -1.42672677e-02, -1.18558399e-01,\n              -1.58940330e-01,  1.97768405e-01, -8.42699409e-02],\n             [ 4.36676294e-02, -1.64302766e-01, -9.18847229e-03,\n               1.57562688e-01,  8.11561663e-03,  1.25627786e-01],\n             [ 1.51698932e-01, -5.45999967e-03,  1.53371440e-02,\n               1.33300990e-01,  7.33180791e-02,  7.57046118e-02],\n             [ 1.44726440e-01, -3.08332909e-02,  4.71221283e-02,\n              -8.30720440e-02,  8.04841295e-02, -1.49666801e-01],\n             [ 7.98062310e-02, -4.20358926e-02,  7.60007044e-03,\n              -5.53373694e-02, -3.77982035e-02,  3.99156027e-02],\n             [ 5.44234477e-02, -1.28442332e-01,  1.31196484e-01,\n               6.09315857e-02,  1.39281034e-01, -1.68151711e-03],\n             [-3.80312428e-02,  1.68671548e-01,  1.09560236e-01,\n               8.88865143e-02, -7.76072890e-02, -5.95312268e-02],\n             [-5.81748076e-02,  6.49072379e-02,  1.62887760e-02,\n               1.83060467e-02, -4.74709012e-02,  1.66655071e-02],\n             [ 1.19550623e-01,  1.62096694e-01, -1.00038491e-01,\n               1.04416534e-01,  1.02850296e-01,  1.40415817e-01],\n             [ 1.34725481e-01, -9.12372246e-02, -2.81765349e-02,\n               1.86341237e-02, -1.00317590e-01,  4.11429219e-02],\n             [-6.28308132e-02, -8.29493925e-02,  2.43374854e-02,\n               4.02909331e-02,  1.57642350e-01,  5.24831302e-02],\n             [ 5.88842249e-03, -1.07205682e-01,  1.33899868e-01,\n              -4.64881101e-04,  4.09163609e-02,  1.11551182e-02],\n             [-6.04856983e-02, -1.40968747e-02, -6.44786209e-02,\n              -1.61199972e-01,  5.91667481e-02, -1.68288842e-01],\n             [-3.74374464e-02, -2.33025309e-02, -7.66711216e-03,\n              -7.33268559e-02,  1.89521369e-02,  1.34219714e-02],\n             [ 9.16161388e-02, -6.16912395e-02,  6.45804331e-02,\n               3.22281457e-02,  8.62854049e-02,  4.41986509e-02],\n             [ 8.09125751e-02, -7.19564557e-02,  4.51078340e-02,\n              -3.91592197e-02,  2.63338201e-02,  5.22068962e-02],\n             [-5.94698898e-02,  5.49785653e-03,  6.05785921e-02,\n               5.87637201e-02,  1.12998255e-01,  3.62044908e-02],\n             [-3.65077220e-02, -4.06673774e-02,  1.15966111e-01,\n               1.78552181e-01,  3.18685621e-02,  1.03732467e-01],\n             [-2.45138742e-02, -1.76385358e-01,  2.93305293e-02,\n              -9.44927782e-02,  1.87303945e-01, -1.34810388e-01],\n             [-1.53767079e-01, -3.38131189e-02, -8.66767094e-02,\n               9.15280879e-02,  6.75857961e-02, -1.04090497e-01],\n             [ 3.18573974e-02, -1.93688884e-01,  5.48515618e-02,\n              -1.25846893e-01, -3.13979946e-02,  5.78340515e-02],\n             [ 7.54026771e-02,  1.81285068e-01,  1.52018696e-01,\n               4.14612181e-02,  3.37518901e-02,  1.34640053e-01],\n             [ 1.30115807e-01,  6.48052245e-02,  1.76690277e-02,\n              -3.51731479e-02,  1.17213652e-01,  2.66280770e-02],\n             [-1.00345813e-01, -1.79041419e-02,  4.60862778e-02,\n               2.25334708e-02,  3.56246531e-02, -7.34259188e-02],\n             [ 6.08137362e-02, -8.37168917e-02, -1.28189042e-01,\n              -2.70780958e-02, -6.59908950e-02, -1.59563407e-01],\n             [ 7.68346786e-02, -5.99950068e-02,  4.89747114e-02,\n              -1.13964356e-01,  1.05598472e-01, -1.60014555e-01],\n             [-6.14193045e-02, -3.82006206e-02, -1.20309927e-02,\n              -1.23834461e-01,  5.05906045e-02,  7.24875405e-02],\n             [-1.06973559e-01,  8.08707159e-03, -1.23331331e-01,\n              -1.08063765e-01, -1.03185058e-01, -3.48316692e-02],\n             [ 6.08417243e-02, -4.33301553e-02, -1.04966782e-01,\n               1.88854754e-01,  1.71531528e-01, -5.71143031e-02],\n             [ 9.81569365e-02,  8.75111893e-02, -3.55093963e-02,\n              -8.81588683e-02,  6.78883269e-02, -3.95975187e-02],\n             [-1.14663109e-01, -2.92997491e-02, -3.15966122e-02,\n              -1.22380674e-01,  8.99034813e-02,  1.63802188e-02],\n             [ 3.16538401e-02, -6.33720905e-02, -4.77643684e-02,\n               4.71230410e-02, -5.95972277e-02, -6.85580820e-02],\n             [-1.11370780e-01,  7.69485533e-02,  6.13692366e-02,\n              -1.19794346e-01, -3.69678587e-02,  2.02951998e-01],\n             [-1.42703339e-01, -9.72430259e-02,  5.04018134e-03,\n               9.75762010e-02, -1.17566228e-01, -6.01027608e-02],\n             [ 2.52913795e-02, -9.40641239e-02,  1.24967648e-02,\n               4.50270660e-02,  1.32457763e-01,  7.11266324e-02],\n             [ 9.89071187e-03,  1.75297096e-01,  1.67444736e-01,\n              -5.20650372e-02, -2.02201605e-01, -1.32647291e-01],\n             [ 1.95834171e-02, -4.79863398e-02,  8.09220150e-02,\n              -4.42066044e-02,  1.03196939e-02,  3.05320267e-02],\n             [-7.14446884e-04, -1.70157880e-01,  1.55897681e-02,\n               6.09404640e-03, -4.61501181e-02, -4.01810147e-02],\n             [ 4.73080166e-02,  6.60517737e-02,  1.94629580e-02,\n              -1.57093167e-01,  1.60925955e-01,  8.05986673e-02],\n             [ 1.32242423e-02,  4.45444472e-02, -1.91424321e-02,\n               7.04322681e-02, -5.56334481e-02, -8.34930688e-03],\n             [-1.23342010e-03, -5.61622642e-02,  1.43492743e-01,\n               1.76758990e-02,  2.00436749e-02,  1.86359212e-01],\n             [ 1.00787982e-01, -3.74127626e-02, -3.40652955e-03,\n              -9.75472911e-04,  1.91068314e-02,  7.66138658e-02],\n             [ 1.72837481e-01, -8.51704031e-02, -7.22087920e-02,\n              -1.04399674e-01,  3.81661765e-02,  9.03742909e-02],\n             [ 1.43618628e-01, -6.64510718e-03, -9.69889387e-02,\n              -1.72765687e-01,  3.55073884e-02, -1.57048166e-01],\n             [-1.47305995e-01, -1.86104029e-01,  9.22193155e-02,\n               1.63363367e-01,  7.20543638e-02, -4.27316390e-02],\n             [-1.12094469e-01,  1.28196077e-02, -1.28213912e-01,\n               6.21486828e-03,  5.30892797e-03, -1.25342995e-01],\n             [ 4.10727039e-02, -3.89916673e-02,  4.34093885e-02,\n              -1.07143447e-01,  3.15333344e-02, -9.51224566e-02],\n             [-2.78361049e-02,  1.71400849e-02,  1.73911244e-01,\n              -1.72006208e-02, -7.06825852e-02,  9.31587592e-02],\n             [ 1.39339387e-01,  4.11312655e-02,  2.05742978e-02,\n              -7.74497818e-03, -5.22259623e-03, -2.02417956e-03],\n             [ 1.09130338e-01, -5.21309339e-02, -1.04567930e-01,\n               1.72220338e-02,  1.58100620e-01,  1.40555933e-01],\n             [-1.99592412e-01, -2.40519606e-02, -1.57991603e-01,\n               1.02015443e-01, -9.98662040e-02,  7.14631677e-02],\n             [ 5.20380307e-03, -4.82758358e-02, -1.92286909e-01,\n              -7.96826743e-03, -1.54290631e-01,  1.85175255e-01],\n             [-1.98595673e-01, -2.65713446e-02,  1.93740590e-03,\n              -1.98984463e-02,  1.01096861e-01, -1.02575213e-01],\n             [ 1.66928127e-01,  1.20740876e-01, -7.15681240e-02,\n               6.38891086e-02,  5.51854298e-02, -1.22506432e-01],\n             [ 8.78814012e-02,  1.07822761e-01, -1.46974564e-01,\n               1.20047025e-01,  7.21886829e-02, -8.54675919e-02],\n             [ 2.29258146e-02, -1.72665089e-01,  3.57917808e-02,\n              -8.31400752e-02, -1.10140853e-01,  3.77841331e-02],\n             [ 1.36882365e-02, -3.07335649e-02, -1.62586360e-03,\n              -9.77346078e-02,  3.38865221e-02,  6.68704137e-02],\n             [ 1.18828140e-01, -1.47041306e-01,  5.20589352e-02,\n              -1.53026074e-01,  1.03887565e-01, -1.08588472e-01],\n             [-3.90928835e-02, -6.18510228e-03, -4.05132100e-02,\n               5.17346971e-02, -1.92173362e-01, -1.43100709e-01],\n             [-1.33428611e-02,  8.00306816e-03,  1.62952662e-01,\n              -6.57213926e-02,  5.07434793e-02, -5.92899099e-02],\n             [ 1.08613402e-01, -1.24999523e-01, -2.64579542e-02,\n               7.11960718e-02,  5.89826144e-02,  9.41445306e-02],\n             [ 2.08997540e-02,  1.74974009e-01,  1.66671813e-01,\n               4.79338653e-02,  8.78979554e-05,  2.67938152e-03],\n             [-9.45570320e-02, -1.23530865e-01,  3.68198715e-02,\n              -3.99812609e-02, -1.85414851e-01,  1.36107326e-01],\n             [ 1.53190598e-01,  2.04031289e-01, -7.17400387e-02,\n              -1.09849349e-01, -7.10915029e-02,  6.11329414e-02],\n             [ 1.07632719e-01, -6.84627448e-04, -2.09605265e-02,\n               9.12068710e-02, -1.35149017e-01, -2.50568949e-02],\n             [-1.18208840e-01, -2.58779433e-02,  9.36405584e-02,\n              -9.29577574e-02,  1.66403898e-03,  5.81755228e-02],\n             [-1.82792187e-01, -1.71266459e-02,  7.08144158e-02,\n              -4.03121598e-02,  1.73490494e-01,  8.20606723e-02],\n             [-5.80182523e-02, -1.58356905e-01,  1.55453954e-03,\n               3.33175063e-02,  3.05566993e-02, -5.09770252e-02],\n             [-9.79462788e-02,  1.12941926e-02,  3.29827517e-02,\n              -8.34175646e-02,  1.58567377e-03, -1.99254751e-01],\n             [ 4.25345935e-02,  2.99139717e-03, -9.77599397e-02,\n               1.25275636e-02,  2.49177702e-02,  5.07045388e-02],\n             [-2.04405174e-01,  1.21796027e-01, -1.64488241e-01,\n              -7.41155148e-02, -9.78912339e-02,  4.51627970e-02],\n             [ 1.64618686e-01, -1.68567866e-01, -5.03356345e-02,\n              -1.17939301e-01,  1.86417699e-01,  9.03297067e-02],\n             [ 1.34851336e-01, -8.32303539e-02, -2.33603716e-02,\n              -1.03158645e-01, -4.51769456e-02,  3.13593112e-02],\n             [-4.59918985e-03, -2.72297795e-04, -2.66443919e-02,\n              -9.28490087e-02, -1.95913151e-01,  7.13221431e-02],\n             [ 8.59930180e-04,  3.29423733e-02,  1.22036733e-01,\n               5.30130863e-02,  1.85941681e-01, -4.31221491e-03],\n             [ 8.20801780e-02, -4.31287326e-02,  2.93064527e-02,\n               3.96889783e-02, -8.63647982e-02,  2.53119059e-02],\n             [-1.53039083e-01,  9.30199102e-02,  3.22056469e-03,\n              -1.24914192e-01, -4.15421948e-02, -1.94892928e-01],\n             [-1.81372344e-01,  4.68530506e-03, -1.14344596e-03,\n              -8.26741755e-02,  5.44348992e-02,  1.65079758e-01],\n             [ 9.12437066e-02, -7.02710375e-02,  2.39162557e-02,\n              -1.94856316e-01,  2.31222976e-02,  1.75257191e-01],\n             [-3.32147554e-02,  1.42533733e-02, -9.78560820e-02,\n              -6.50117993e-02,  1.09827500e-02, -1.13609552e-01],\n             [ 7.81201050e-02,  1.05575085e-01, -7.44353980e-02,\n              -3.78257334e-02, -2.77477149e-02, -6.97395205e-02],\n             [-1.26582652e-01, -6.02407455e-02,  1.56459987e-01,\n               2.62732077e-02, -9.10715163e-02,  1.86322927e-02],\n             [-6.05976023e-03, -7.36052021e-02, -1.94375180e-02,\n              -6.14037318e-03, -9.34375674e-02,  1.59062594e-01],\n             [-8.38212222e-02, -1.07256345e-01,  3.92923839e-02,\n               6.44588396e-02,  5.82834482e-02, -4.43415083e-02],\n             [-3.70768644e-02, -1.57141834e-01,  4.57319804e-02,\n               1.47603378e-01, -1.90470919e-01,  8.95815566e-02],\n             [-1.41519820e-02,  1.02229277e-02, -1.52806088e-01,\n               7.36428723e-02, -1.53182924e-01, -6.77200109e-02],\n             [-1.24036811e-01,  6.76118815e-03,  5.35392836e-02,\n              -7.78794521e-03, -6.01940323e-03, -1.29643857e-01],\n             [ 3.38975042e-02,  1.99190766e-01,  2.09763810e-01,\n               1.06645174e-01, -1.48502052e-01, -1.69419963e-02]],            dtype=float32), DeviceArray([[-5.89450355e-09,  1.10020665e-06,  9.08776769e-08,\n              -3.22272911e-08, -3.63153418e-08, -1.08581332e-06],\n             [-1.78353936e-08, -2.19694940e-07,  7.57201164e-08,\n              -9.52954338e-08, -9.58301172e-08,  4.68956898e-07],\n             [ 1.39825218e-07,  3.54142844e-06,  6.78116862e-07,\n               1.57926394e-07,  1.00080001e-06, -3.30655712e-06],\n             [ 7.58904832e-07,  9.00165196e-06,  1.91860772e-06,\n               5.19962953e-07,  3.23478207e-06, -8.01583883e-06],\n             [ 7.41148881e-07,  9.07372851e-06,  1.93827600e-06,\n               5.23033350e-07,  3.20584559e-06, -8.05012587e-06],\n             [ 7.31363230e-07,  8.35094306e-06,  1.82408917e-06,\n               4.92660149e-07,  3.01473915e-06, -7.34671039e-06],\n             [-6.14870554e-09,  1.46902812e-06,  1.20125407e-07,\n              -4.29506990e-08, -4.43112640e-08, -1.45433694e-06],\n             [-2.39274023e-08, -2.94163669e-07,  1.01273521e-07,\n              -1.27209304e-07, -1.28118330e-07,  6.27614895e-07],\n             [ 1.87638690e-07,  4.63625565e-06,  8.93171091e-07,\n               2.08798184e-07,  1.31301329e-06, -4.32453771e-06],\n             [ 7.72408725e-07,  8.93784818e-06,  1.90030983e-06,\n               5.22363166e-07,  3.25587530e-06, -7.98687142e-06],\n             [ 7.40764676e-07,  9.06684909e-06,  1.93543951e-06,\n               5.27857878e-07,  3.20444451e-06, -8.04859064e-06],\n             [ 7.22921300e-07,  7.78584490e-06,  1.73262708e-06,\n               4.73720689e-07,  2.86539102e-06, -6.80177573e-06],\n             [-4.08872758e-09,  1.95721077e-06,  1.57139056e-07,\n              -5.70308281e-08, -4.90771583e-08, -1.94843437e-06],\n             [-3.21491669e-08, -3.93888541e-07,  1.35335583e-07,\n              -1.69406022e-07, -1.71084295e-07,  8.39658753e-07],\n             [ 2.52089052e-07,  5.95938036e-06,  1.16103763e-06,\n               2.73283263e-07,  1.69430473e-06, -5.54831695e-06],\n             [ 7.96222594e-07,  8.82393761e-06,  1.86779926e-06,\n               5.26646488e-07,  3.29270620e-06, -7.93437903e-06],\n             [ 7.40091991e-07,  9.05452816e-06,  1.93036249e-06,\n               5.36466530e-07,  3.20192976e-06, -8.04584488e-06],\n             [ 7.07165782e-07,  6.80477569e-06,  1.57277213e-06,\n               4.40408513e-07,  2.60538445e-06, -5.85572616e-06],\n             [ 4.29782787e-09,  2.59740773e-06,  2.01628808e-07,\n              -7.52185372e-08, -4.15186321e-08, -2.61137825e-06],\n             [-4.33095266e-08, -5.27453949e-07,  1.80578581e-07,\n              -2.24631520e-07, -2.27975477e-07,  1.12264138e-06],\n             [ 3.38904442e-07,  7.40378027e-06,  1.47343928e-06,\n               3.51184525e-07,  2.11997803e-06, -6.86727390e-06],\n             [ 8.37665198e-07,  8.62122488e-06,  1.81044447e-06,\n               5.34261801e-07,  3.35561526e-06, -7.83850282e-06],\n             [ 7.38935341e-07,  9.03247383e-06,  1.92128755e-06,\n               5.51770427e-07,  3.19740730e-06, -8.04093725e-06],\n             [ 6.77055141e-07,  5.14594376e-06,  1.29904197e-06,\n               3.82699909e-07,  2.16362355e-06, -4.25635881e-06],\n             [ 2.84734671e-08,  3.42271323e-06,  2.49386289e-07,\n              -9.79840351e-08,  7.15942861e-10, -3.50134815e-06],\n             [-5.85989852e-08, -7.06387311e-07,  2.40287903e-07,\n              -2.95550933e-07, -3.02619014e-07,  1.49929872e-06],\n             [ 4.54334412e-07,  8.61640001e-06,  1.78864775e-06,\n               4.36340798e-07,  2.50029757e-06, -7.92649553e-06],\n             [ 9.08039112e-07,  8.26257929e-06,  1.71055194e-06,\n               5.47710044e-07,  3.45858484e-06, -7.66110679e-06],\n             [ 7.37018695e-07,  8.99303996e-06,  1.90510173e-06,\n               5.78793617e-07,  3.18924913e-06, -8.03218063e-06],\n             [ 6.17813100e-07,  2.47749927e-06,  8.47588126e-07,\n               2.85440819e-07,  1.44720775e-06, -1.68546239e-06],\n             [ 8.95815475e-08,  4.45288197e-06,  2.86464058e-07,\n              -1.24693514e-07,  1.28736474e-07, -4.69444240e-06],\n             [-7.98249005e-08, -9.46196735e-07,  3.18169668e-07,\n              -3.83382172e-07, -3.98885334e-07,  1.99820715e-06],\n             [ 5.99228542e-07,  8.76413287e-06,  1.99479950e-06,\n               5.08822666e-07,  2.61243690e-06, -7.89131718e-06],\n             [ 1.02213471e-06,  7.63469507e-06,  1.54058000e-06,\n               5.71160342e-07,  3.61285970e-06, -7.32635317e-06],\n             [ 7.34072842e-07,  8.92267326e-06,  1.87635169e-06,\n               6.25928294e-07,  3.17445006e-06, -8.01660462e-06],\n             [ 4.98578515e-07, -1.41110866e-06,  1.53603480e-07,\n               1.29574914e-07,  3.90075598e-07,  2.04918433e-06],\n             [ 2.30933054e-07,  5.65936080e-06,  2.77691186e-07,\n              -1.51568671e-07,  4.48849903e-07, -6.27888267e-06],\n             [-1.09736106e-07, -1.26779742e-06,  4.17566611e-07,\n              -4.84438999e-07, -5.18943466e-07,  2.65304629e-06],\n             [ 7.44775662e-07,  6.37876201e-06,  1.86980697e-06,\n               5.22824337e-07,  2.04094908e-06, -5.28121518e-06],\n             [ 1.19083279e-06,  6.55613849e-06,  1.26347391e-06,\n               6.11066071e-07,  3.79858466e-06, -6.67799895e-06],\n             [ 7.30283148e-07,  8.79757408e-06,  1.82565918e-06,\n               7.06297158e-07,  3.14732688e-06, -7.98898145e-06],\n             [ 2.62097444e-07, -5.95620531e-06, -7.75638114e-07,\n              -9.77013954e-08, -8.62348543e-07,  6.34855905e-06],\n             [ 5.26827762e-07,  6.88779892e-06,  1.49560947e-07,\n              -1.67012857e-07,  1.14538557e-06, -8.30594763e-06],\n             [-1.52088518e-07, -1.69945258e-06,  5.39248504e-07,\n              -5.82383393e-07, -6.58500312e-07,  3.49752895e-06],\n             [ 7.63828155e-07,  1.78920914e-07,  1.12526754e-06,\n               3.98805383e-07,  3.93012499e-07,  1.20588595e-06],\n             [ 1.39363715e-06,  4.76657942e-06,  8.46491048e-07,\n               6.75670947e-07,  3.87500313e-06, -5.39017401e-06],\n             [ 7.27829558e-07,  8.57662417e-06,  1.73745775e-06,\n               8.37548328e-07,  3.09655707e-06, -7.93998333e-06],\n             [-1.52975304e-07, -8.48403215e-06, -1.69186501e-06,\n              -3.71675071e-07, -1.51615791e-06,  8.40440225e-06],\n             [ 1.05302070e-06,  7.71379564e-06, -2.14575977e-07,\n              -1.42296983e-07,  2.40270538e-06, -1.05793461e-05],\n             [-2.09548517e-07, -2.27917667e-06,  6.76103184e-07,\n              -6.33503987e-07, -7.95059464e-07,  4.54841620e-06],\n             [ 3.60403476e-07, -7.10164477e-06, -2.60149477e-07,\n               5.20091490e-08, -1.39066969e-06,  8.56870793e-06],\n             [ 1.51450240e-06,  1.98259454e-06,  3.09881301e-07,\n               7.69029270e-07,  3.36692756e-06, -2.82522319e-06],\n             [ 7.35002118e-07,  8.19092566e-06,  1.58768319e-06,\n               1.03413015e-06,  2.99751468e-06, -7.85149132e-06],\n             [-6.16005138e-07, -3.96131054e-06, -1.96598421e-06,\n              -5.70482428e-07,  9.77685346e-08,  2.75703906e-06],\n             [ 1.71396732e-06,  7.26016560e-06, -8.91745856e-07,\n              -1.98548022e-08,  3.90416744e-06, -1.20029299e-05],\n             [-2.72548334e-07, -3.05649314e-06,  8.02113561e-07,\n              -5.48290473e-07, -8.62919308e-07,  5.75925469e-06],\n             [-4.82296855e-07, -4.37789458e-06, -1.52266500e-06,\n              -4.78503694e-07,  1.08891720e-06,  4.00277941e-06],\n             [ 1.29106547e-06, -1.83849784e-06, -1.84190526e-07,\n               8.65605216e-07,  1.19882077e-06,  1.95082043e-06],\n             [ 7.73455099e-07,  7.53151107e-06,  1.34458696e-06,\n               1.27597616e-06,  2.78944913e-06, -7.67971960e-06],\n             [-3.73475871e-07,  6.33434820e-06, -1.04433207e-06,\n              -4.15061123e-07,  3.21438347e-06, -9.17595844e-06],\n             [ 1.91155050e-06,  4.24808513e-06, -1.62878257e-06,\n               2.64808847e-07,  3.52671350e-06, -9.45029115e-06],\n             [-2.84500231e-07, -4.08763117e-06,  8.54584982e-07,\n              -2.06916638e-07, -7.20116418e-07,  6.92017920e-06],\n             [-5.86249698e-08,  7.22480490e-06, -1.56320039e-06,\n              -6.33804234e-07,  6.07791662e-06, -1.14542972e-05],\n             [ 6.27224949e-07, -5.80912047e-06, -3.65499204e-07,\n               8.37135588e-07, -3.71978172e-06,  8.91575291e-06],\n             [ 8.75814635e-07,  6.44476586e-06,  9.82818392e-07,\n               1.42759950e-06,  2.30631395e-06, -7.28690748e-06],\n             [ 7.91641469e-07,  1.26343366e-06,  6.55591862e-07,\n               4.05362528e-07, -2.94489564e-06, -8.59181227e-07],\n             [ 1.12571422e-06, -1.86561249e-06, -1.68733015e-06,\n               5.70262159e-07, -2.80565246e-06,  8.75313731e-07],\n             [-9.61898081e-08, -5.40087012e-06,  7.23625078e-07,\n               3.54689064e-07, -2.08309700e-07,  7.55339488e-06],\n             [ 1.02153695e-06, -5.78017762e-06,  4.76029385e-07,\n               6.66310996e-07, -8.07298056e-06,  9.26087887e-06],\n             [ 5.83943233e-07, -7.10525410e-06, -4.68334918e-07,\n               3.76263955e-07, -8.86384896e-06,  1.29387236e-05],\n             [ 1.01500041e-06,  4.76382820e-06,  5.33049501e-07,\n               1.14710679e-06,  1.11709198e-06, -6.21370691e-06],\n             [-6.31725527e-07, -4.20461538e-06,  1.27212684e-06,\n               5.97546659e-07, -4.44876014e-06,  7.99392728e-06],\n             [ 1.27447595e-06, -7.42738530e-06, -1.20815002e-06,\n               8.20677712e-08, -1.29561167e-05,  1.26983887e-05],\n             [ 3.91787097e-07, -6.86304975e-06,  3.15919408e-07,\n               4.78149786e-07,  3.33411720e-07,  7.24351867e-06],\n             [-7.31582759e-07,  7.05648517e-06, -5.99249006e-07,\n              -3.65444635e-07,  6.58787803e-06, -1.00339112e-05],\n             [ 1.59395915e-06, -1.62996571e-06, -1.54220902e-06,\n              -4.67488718e-07, -5.08205630e-06,  3.86462261e-06],\n             [ 9.31843601e-07,  2.42041165e-06,  1.82153542e-07,\n               2.29804400e-07, -1.47138360e-06, -3.32397576e-06],\n             [-1.14857016e-07,  1.20156915e-06, -3.16910530e-07,\n              -7.66403843e-07,  1.39314898e-06, -2.99687235e-06],\n             [ 4.07920425e-07, -2.35097605e-06, -1.48380832e-06,\n              -1.10702365e-06, -8.48389391e-06,  6.41825045e-06],\n             [ 6.99555414e-07, -7.84252188e-06, -1.77876316e-07,\n              -5.36324023e-07, -5.12962743e-07,  6.85968735e-06],\n             [-4.43854134e-07, -1.19246545e-06,  1.71287593e-06,\n               6.58903559e-07,  3.85870317e-06, -1.14134968e-06],\n             [ 7.25131372e-07,  6.89900753e-06, -2.64183041e-06,\n               2.39554709e-07,  9.28923691e-06, -1.29141199e-05],\n             [ 3.82862510e-07, -4.08812895e-07,  2.96095322e-07,\n               5.14074827e-08, -4.79162736e-06,  2.19754088e-06],\n             [ 1.79455199e-07,  7.84833173e-06, -1.68732555e-07,\n              -1.58592755e-07,  1.22959445e-05, -1.30488133e-05],\n             [ 4.53154144e-07,  6.44779493e-06, -1.15898138e-06,\n               5.08535436e-07,  1.44323803e-05, -1.60667241e-05],\n             [ 2.64337359e-07, -6.91162859e-06, -3.40541533e-08,\n               2.95334388e-08, -1.31283150e-06,  6.42984196e-06],\n             [ 1.02554668e-06, -1.51715199e-06, -9.67659730e-07,\n               9.35940420e-07, -2.22751964e-06,  1.58044168e-06],\n             [ 2.94390293e-06, -1.08298093e-06, -1.40403404e-06,\n               8.85823624e-07, -4.19718754e-06,  3.36121474e-07],\n             [-1.27575959e-08, -3.47877312e-06,  9.04482306e-07,\n               1.38244604e-06, -5.05446133e-06,  7.17476451e-06],\n             [ 5.22484129e-07,  6.79360710e-06,  2.00455361e-07,\n               2.04385771e-08,  1.04309420e-05, -1.53854107e-05],\n             [-2.32110438e-07, -6.25013035e-06,  4.16154535e-07,\n              -4.29712628e-07, -8.74119905e-06,  1.33317380e-05],\n             [-1.72272749e-07, -2.85427996e-06,  9.17918385e-07,\n              -1.74608715e-07,  2.61466812e-06,  3.20356548e-07],\n             [-3.98590316e-07,  5.72777572e-06,  5.16953378e-08,\n              -2.30283320e-07,  9.72292401e-06, -1.01911537e-05],\n             [ 1.38026326e-06, -1.69508246e-06, -1.50022231e-06,\n              -3.24397718e-08,  2.64211110e-07,  1.73726607e-06],\n             [-2.57760775e-07, -6.49922958e-06,  1.07226242e-06,\n              -3.65794733e-07, -4.52577478e-06,  1.01539963e-05],\n             [ 1.58273042e-06, -3.64774905e-06, -1.09204620e-06,\n              -1.51300151e-06, -2.30509136e-06,  3.07344430e-06],\n             [ 5.98161819e-07,  6.63248647e-06,  1.05101094e-07,\n               4.95582071e-07,  5.25333189e-06, -1.07178275e-05],\n             [ 1.04858748e-06,  3.97416079e-06,  1.47306741e-06,\n               6.30566092e-07,  7.42636030e-06, -1.04624205e-05],\n             [-1.46166485e-06,  4.88427168e-06,  1.82387862e-06,\n               4.49339410e-07,  3.97414988e-06, -7.57285170e-06],\n             [-3.32953022e-07, -1.86433351e-06,  3.30018054e-07,\n              -7.58851115e-07, -3.58492161e-06,  5.42662883e-06],\n             [-1.41443184e-06, -5.73096713e-06,  7.34330115e-07,\n               1.69057145e-07,  9.80789991e-07,  7.30678676e-06],\n             [ 3.14442957e-07,  4.99395901e-06, -1.61032949e-06,\n               4.36894254e-07,  2.91608603e-06, -8.05610216e-06],\n             [ 1.56075430e-06,  3.49192919e-06, -3.83678184e-07,\n               7.36173774e-07, -1.48261358e-06, -2.87649664e-06],\n             [ 1.34733887e-06,  6.41233601e-06,  1.71123509e-06,\n               6.50678146e-07,  1.06393973e-05, -1.43433608e-05],\n             [-9.12618304e-07, -2.33978312e-06,  1.95254324e-06,\n               2.26706632e-07, -9.17776742e-06,  7.15697888e-06],\n             [ 7.15795125e-07,  6.30058730e-06, -9.77086074e-07,\n              -5.05545984e-07,  8.58714247e-06, -1.14812055e-05],\n             [-6.94097253e-07,  2.25981785e-06, -3.21830100e-08,\n               6.29661429e-07,  1.17689888e-05, -8.83968369e-06],\n             [ 4.45760975e-08, -7.20550042e-06, -7.32303988e-07,\n              -7.33645322e-07, -5.21131278e-06,  9.82466827e-06],\n             [-4.49712587e-07,  4.51746837e-06,  1.75300130e-07,\n              -1.04324442e-06,  9.77069340e-06, -7.88865873e-06],\n             [ 2.80753170e-06, -3.08079393e-06, -1.75416540e-06,\n               4.89795980e-07, -1.04929122e-05,  5.72613908e-06],\n             [-8.24183246e-07, -7.84531130e-06, -3.16042104e-07,\n              -5.25608527e-07, -1.09352886e-05,  1.42115850e-05],\n             [ 4.62080095e-07,  5.67532106e-06,  1.53317001e-06,\n               4.47872253e-07,  1.42140643e-05, -1.33855892e-05],\n             [ 1.61264234e-06,  5.61368461e-06, -2.16226613e-06,\n              -8.19383104e-07,  1.04630763e-05, -1.41299424e-05],\n             [ 5.60000785e-07,  1.69940745e-06, -2.29975785e-06,\n              -7.62130583e-07, -1.57408442e-06, -6.34730554e-07],\n             [ 1.29794250e-06, -2.95222549e-07, -1.44694434e-06,\n              -2.92114066e-07,  2.25114718e-06, -1.65837469e-06],\n             [-5.77794765e-07, -1.83486031e-06,  6.42491500e-08,\n              -7.89675028e-08, -3.92448237e-06,  5.13863097e-06],\n             [ 5.93625884e-07,  4.19205116e-06,  2.56499561e-07,\n              -8.98320280e-08,  1.01652313e-05, -1.07823216e-05],\n             [-3.87973358e-07, -7.36338461e-06, -2.81240659e-07,\n              -1.06455991e-06, -9.47545050e-06,  1.28567899e-05],\n             [-1.35564761e-07, -6.94711571e-06, -6.66803885e-07,\n               3.93332300e-07, -1.52133589e-05,  1.52696957e-05],\n             [-2.00091137e-07,  5.19159494e-06, -2.73080673e-06,\n              -1.00092529e-06,  3.84539180e-06, -8.02862451e-06]],            dtype=float32), DeviceArray([[2.12274364e-17, 1.16329784e-13, 1.48670884e-15,\n              1.35739756e-16, 9.56699235e-15, 1.11627599e-13],\n             [6.40458023e-16, 6.39256697e-15, 4.84248471e-16,\n              1.84344291e-15, 1.08346568e-15, 2.29769630e-14],\n             [3.40560926e-15, 1.07508685e-12, 2.48534239e-14,\n              4.89585848e-15, 6.50937041e-13, 1.80882588e-12],\n             [3.88176023e-14, 6.81023814e-12, 2.14583681e-13,\n              4.50841011e-14, 4.61792323e-12, 9.48895563e-12],\n             [3.71696048e-14, 6.92284988e-12, 2.19224080e-13,\n              4.49529228e-14, 4.64760695e-12, 9.59453610e-12],\n             [3.79026780e-14, 5.84377295e-12, 1.99977784e-13,\n              4.02084236e-14, 3.93055640e-12, 7.83793516e-12],\n             [3.70780632e-17, 2.07499789e-13, 2.63444463e-15,\n              2.40250016e-16, 1.69488328e-14, 1.99782234e-13],\n             [1.14560846e-15, 1.14565738e-14, 8.66404591e-16,\n              3.28540384e-15, 1.93570364e-15, 4.11436090e-14],\n             [5.81243199e-15, 1.84130163e-12, 4.30244998e-14,\n              8.45466190e-15, 1.10854143e-12, 3.08115556e-12],\n             [4.00559238e-14, 6.71291451e-12, 2.10304511e-13,\n              4.47433872e-14, 4.58902447e-12, 9.40107887e-12],\n             [3.70970378e-14, 6.91372003e-12, 2.18548175e-13,\n              4.45151864e-14, 4.64187759e-12, 9.58986709e-12],\n             [3.86563679e-14, 5.06533530e-12, 1.85505406e-13,\n              3.63224499e-14, 3.41559268e-12, 6.60209344e-12],\n             [6.49781208e-17, 3.68652727e-13, 4.62562956e-15,\n              4.20889163e-16, 2.97549122e-14, 3.57081498e-13],\n             [2.04563115e-15, 2.05276673e-14, 1.54775556e-15,\n              5.82812158e-15, 3.44899576e-15, 7.36078543e-14],\n             [9.54868627e-15, 3.03839267e-12, 7.24336498e-14,\n              1.41671249e-14, 1.81037043e-12, 5.03313970e-12],\n             [4.22862564e-14, 6.54081607e-12, 2.02801507e-13,\n              4.41454497e-14, 4.53797113e-12, 9.24457212e-12],\n             [3.69694374e-14, 6.89738240e-12, 2.17341038e-13,\n              4.37585115e-14, 4.63159068e-12, 9.58147450e-12],\n             [4.04843429e-14, 3.84650558e-12, 1.62578257e-13,\n              3.00309501e-14, 2.61262019e-12, 4.71493634e-12],\n             [1.20754685e-16, 6.50295464e-13, 7.98987799e-15,\n              7.23699974e-16, 5.14270294e-14, 6.36687589e-13],\n             [3.64142302e-15, 3.67666814e-14, 2.75727436e-15,\n              1.02526325e-14, 6.11574517e-15, 1.31477484e-13],\n             [1.46883712e-14, 4.67848763e-12, 1.15957062e-13,\n              2.24656465e-14, 2.73455130e-12, 7.60119051e-12],\n             [4.63087718e-14, 6.23982116e-12, 1.89877451e-13,\n              4.31139804e-14, 4.44875777e-12, 8.96729305e-12],\n             [3.67484838e-14, 6.86818267e-12, 2.15191919e-13,\n              4.24904049e-14, 4.61308899e-12, 9.56633904e-12],\n             [4.49987032e-14, 2.16799517e-12, 1.30208127e-13,\n              2.07182362e-14, 1.51246463e-12, 2.26127819e-12],\n             [2.78306360e-16, 1.13239799e-12, 1.34045929e-14,\n              1.20215998e-15, 8.66234303e-14, 1.13014014e-12],\n             [6.44631502e-15, 6.58058813e-14, 4.88761855e-15,\n              1.77658300e-14, 1.07509504e-14, 2.34171406e-13],\n             [2.03687097e-14, 6.30603642e-12, 1.69426539e-13,\n              3.21455373e-14, 3.55296543e-12, 9.84842329e-12],\n             [5.35468310e-14, 5.72386843e-12, 1.68327483e-13,\n              4.13906614e-14, 4.29568187e-12, 8.48082889e-12],\n             [3.63768430e-14, 6.81610931e-12, 2.11386084e-13,\n              4.04917663e-14, 4.57973243e-12, 9.53891566e-12],\n             [5.56252838e-14, 4.73318704e-13, 9.52897361e-14,\n              9.29542342e-15, 3.95390941e-13, 2.13340549e-13],\n             [9.35515915e-16, 1.92628942e-12, 2.13631145e-14,\n              1.87049989e-15, 1.40416541e-13, 1.98888128e-12],\n             [1.13000971e-14, 1.17632683e-13, 8.58685664e-15,\n              2.99529146e-14, 1.86081517e-14, 4.14942034e-13],\n             [2.54627692e-14, 6.45356294e-12, 2.09706479e-13,\n              3.77183670e-14, 3.37198347e-12, 9.18802274e-12],\n             [6.63559106e-14, 4.87193662e-12, 1.34495591e-13,\n              3.86816332e-14, 4.04036960e-12, 7.64292189e-12],\n             [3.57858613e-14, 6.72362123e-12, 2.04711804e-13,\n              3.77390651e-14, 4.51934757e-12, 9.48883940e-12],\n             [7.67252134e-14, 2.10885983e-13, 8.39426320e-14,\n              8.32536031e-16, 9.29799586e-14, 1.21318949e-12],\n             [4.09693785e-15, 3.13923020e-12, 3.11801129e-14,\n              2.55459483e-15, 2.18342678e-13, 3.44059503e-12],\n             [1.94676869e-14, 2.09800981e-13, 1.48453502e-14,\n              4.80272831e-14, 3.13235528e-14, 7.28540546e-13],\n             [3.77764057e-14, 3.30581699e-12, 1.94231187e-13,\n              2.93069334e-14, 1.42408145e-12, 3.41124416e-12],\n             [8.76719080e-14, 3.56295679e-12, 8.71891670e-14,\n              3.49178287e-14, 3.63008733e-12, 6.25143600e-12],\n             [3.49498804e-14, 6.56057370e-12, 1.93211332e-13,\n              3.51876494e-14, 4.40939733e-12, 9.39621471e-12],\n             [1.02543537e-13, 3.16919191e-12, 1.39068810e-13,\n              4.77598257e-15, 1.33027931e-12, 7.54485623e-12],\n             [1.81420312e-14, 4.72240042e-12, 3.97291317e-14,\n              2.61201519e-15, 3.41889630e-13, 5.73694521e-12],\n             [3.25354367e-14, 3.72662648e-13, 2.49319388e-14,\n              7.01126033e-14, 5.01605729e-14, 1.25834445e-12],\n             [9.15384169e-14, 1.18718401e-14, 1.31269900e-13,\n              8.53779903e-15, 6.22487291e-14, 7.74764096e-13],\n             [1.16988599e-13, 1.83090609e-12, 3.52818970e-14,\n              3.10080466e-14, 2.99111929e-12, 4.11762421e-12],\n             [3.40759864e-14, 6.27691692e-12, 1.74014394e-13,\n              3.67971407e-14, 4.20774656e-12, 9.22151851e-12],\n             [9.36341662e-14, 6.17177056e-12, 2.42255733e-13,\n              1.96068307e-14, 1.85940824e-12, 1.10231806e-11],\n             [6.66848982e-14, 6.08192923e-12, 4.49372459e-14,\n              1.16560627e-15, 6.08672292e-13, 8.77273774e-12],\n             [5.15875049e-14, 6.57081757e-13, 3.97287658e-14,\n              8.53035768e-14, 7.34619342e-14, 2.11091864e-12],\n             [1.77921778e-13, 4.60536470e-12, 1.63127460e-13,\n              4.93569105e-15, 1.53785545e-12, 1.22531889e-11],\n             [1.38032950e-13, 2.67059243e-13, 6.61249526e-15,\n              2.99198566e-14, 2.00648786e-12, 1.46605915e-12],\n             [3.40820376e-14, 5.79501724e-12, 1.43775454e-13,\n              5.25594108e-14, 3.83661939e-12, 8.88347902e-12],\n             [3.49875056e-14, 1.18555860e-12, 2.20762848e-13,\n              1.80839671e-14, 3.41978647e-15, 6.57213704e-13],\n             [1.63876562e-13, 5.63548036e-12, 6.40144201e-14,\n              1.46602970e-15, 1.11053614e-12, 1.07020252e-11],\n             [7.46754411e-14, 1.14307695e-12, 5.75161053e-14,\n              7.05393589e-14, 9.19605036e-14, 3.36333457e-12],\n             [1.23368235e-13, 1.50348213e-12, 2.65094641e-13,\n              1.14635084e-14, 3.09794779e-13, 1.62587337e-12],\n             [1.22416630e-13, 4.93287648e-13, 3.06068816e-14,\n              3.44190889e-14, 7.19466099e-13, 2.48978681e-13],\n             [3.73807871e-14, 5.01048465e-12, 1.00987252e-13,\n              9.43880933e-14, 3.16495354e-12, 8.21338744e-12],\n             [6.02890472e-14, 3.88085484e-12, 7.06748231e-14,\n              1.47200977e-14, 2.72828288e-12, 1.24074327e-11],\n             [1.90752335e-13, 2.11876241e-12, 1.29959222e-13,\n              2.35470077e-14, 9.55504664e-13, 6.53365122e-12],\n             [9.20150864e-14, 1.93947137e-12, 6.97596007e-14,\n              2.04688646e-14, 9.09242300e-14, 4.90398824e-12],\n             [2.05808169e-15, 5.04891658e-12, 1.81128468e-13,\n              4.12131673e-14, 4.92340066e-12, 1.45131598e-11],\n             [1.14951370e-13, 3.87743960e-12, 6.47504308e-14,\n              3.23113051e-14, 8.24780348e-13, 5.31654278e-12],\n             [4.77995227e-14, 3.82752250e-12, 5.15998987e-14,\n              1.38433712e-13, 2.04503905e-12, 6.88613749e-12],\n             [1.20617207e-13, 9.71850096e-14, 1.23844877e-13,\n              8.59225394e-15, 2.65607089e-12, 7.33658333e-14],\n             [1.25853551e-13, 2.87566656e-13, 1.53882739e-13,\n              7.24696250e-14, 4.75091646e-13, 5.21490674e-14],\n             [8.54805728e-14, 3.13701647e-12, 6.11158734e-14,\n              6.12061180e-15, 9.36028327e-14, 6.21522712e-12],\n             [9.97266302e-14, 3.38041726e-12, 8.78909846e-14,\n              4.59185607e-14, 9.32172829e-12, 9.51437887e-12],\n             [1.11211326e-13, 5.71903028e-12, 3.75462838e-14,\n              6.80075215e-15, 4.95495355e-12, 1.29487220e-11],\n             [6.28567565e-14, 2.27643035e-12, 1.36256668e-14,\n              8.73037604e-14, 6.21325582e-13, 4.45343328e-12],\n             [8.68259390e-14, 1.67289891e-12, 2.82116669e-13,\n              4.10208773e-14, 1.13631045e-12, 4.27477975e-12],\n             [1.16415571e-13, 5.41346828e-12, 7.09196360e-14,\n              1.98493260e-14, 1.17692210e-11, 1.43913518e-11],\n             [5.79508026e-14, 4.60745981e-12, 3.32559569e-14,\n              1.47494152e-14, 2.09347107e-13, 6.65009194e-12],\n             [5.57104818e-14, 4.67987671e-12, 5.37830922e-14,\n              1.96269104e-14, 4.35868095e-12, 1.08111844e-11],\n             [1.90275326e-13, 3.52346949e-13, 1.84113873e-13,\n              4.51164205e-14, 2.45487159e-12, 1.54192078e-12],\n             [5.02207698e-14, 7.26078973e-13, 1.37768617e-15,\n              6.04221212e-15, 1.28035766e-13, 1.15397243e-12],\n             [1.14944539e-13, 1.71613063e-13, 1.97055154e-13,\n              5.51698817e-14, 2.50777264e-13, 5.95267379e-13],\n             [9.44072599e-15, 4.50501344e-13, 1.11838151e-13,\n              7.85740288e-14, 4.68349881e-12, 3.17486879e-12],\n             [3.92286166e-14, 5.50404297e-12, 2.46053618e-14,\n              3.53296697e-14, 3.73009322e-13, 6.25861082e-12],\n             [2.28043174e-14, 1.19612599e-13, 2.32938397e-13,\n              4.81390271e-14, 1.67360819e-12, 1.21460079e-13],\n             [1.97282864e-13, 4.34959620e-12, 7.03280153e-13,\n              5.91790072e-15, 9.76983945e-12, 1.74895289e-11],\n             [1.88748299e-14, 1.15825657e-14, 4.84613722e-15,\n              3.96470373e-15, 2.14706681e-12, 3.81047705e-13],\n             [1.86247908e-14, 5.47145575e-12, 3.08681981e-14,\n              1.95181209e-14, 1.42232745e-11, 1.55279123e-11],\n             [1.47319358e-14, 4.47306948e-12, 8.58855630e-14,\n              1.98945576e-14, 2.14078043e-11, 2.42164240e-11],\n             [6.41345650e-15, 4.12341949e-12, 1.32246051e-14,\n              7.68807083e-15, 1.46188305e-13, 3.94888679e-12],\n             [2.92089188e-13, 1.49109485e-13, 8.26589231e-14,\n              1.19922667e-13, 2.40425475e-12, 5.56159827e-13],\n             [7.85756659e-13, 1.58019810e-13, 2.32656478e-13,\n              6.38192976e-14, 1.73062985e-12, 9.43298581e-14],\n             [6.03300233e-14, 1.22240499e-12, 5.52497704e-14,\n              1.59290726e-13, 3.79895160e-12, 5.32288666e-12],\n             [1.43017024e-14, 4.70550465e-12, 6.48424728e-14,\n              2.99140697e-15, 9.73449793e-12, 2.35858295e-11],\n             [3.81577582e-13, 3.85160046e-12, 2.91226461e-13,\n              5.28811681e-14, 1.03909190e-11, 2.15018003e-11],\n             [2.76621267e-15, 7.68774096e-13, 5.11400614e-14,\n              2.16139451e-15, 8.68445290e-13, 2.29089564e-14],\n             [2.64575335e-13, 3.82360289e-12, 2.98081546e-13,\n              4.35312000e-15, 1.30124167e-11, 1.05589808e-11],\n             [1.56517445e-13, 2.37491694e-13, 2.87678898e-13,\n              9.62362073e-15, 2.40470458e-13, 6.35600784e-13],\n             [5.63482773e-14, 4.52433013e-12, 8.64711270e-14,\n              4.89750215e-14, 4.26681867e-12, 1.21432421e-11],\n             [1.50155035e-13, 1.24448694e-12, 6.17540619e-14,\n              1.69198436e-13, 5.02130401e-13, 1.41153918e-12],\n             [5.85862739e-14, 4.50152329e-12, 1.89461836e-13,\n              1.48181367e-14, 2.23643239e-12, 9.65132141e-12],\n             [9.49037669e-14, 1.46891234e-12, 1.66119072e-13,\n              4.53771745e-14, 1.01376850e-11, 1.27502496e-11],\n             [1.80288076e-13, 2.70028552e-12, 2.77970220e-13,\n              8.52179926e-14, 1.04769871e-12, 3.74790823e-12],\n             [1.92593015e-14, 2.97783635e-13, 3.80525079e-14,\n              9.18365793e-14, 9.21701517e-13, 2.00620987e-12],\n             [1.35302888e-13, 3.38529552e-12, 3.35281458e-14,\n              2.97686024e-15, 1.25292341e-13, 6.21967538e-12],\n             [1.60756960e-14, 2.59312580e-12, 1.69482890e-13,\n              2.38856837e-14, 8.49733641e-13, 6.02466211e-12],\n             [1.51545362e-13, 9.35817071e-13, 1.04080278e-13,\n              5.19388237e-14, 3.67659597e-13, 1.10649088e-12],\n             [1.22576903e-13, 4.87884429e-12, 1.61775365e-13,\n              2.85257183e-14, 8.86557147e-12, 1.69115971e-11],\n             [4.03154615e-14, 4.77810771e-13, 3.02998349e-13,\n              2.23939799e-15, 5.53366988e-12, 3.11412354e-12],\n             [6.89227524e-14, 4.43917516e-12, 6.79811838e-14,\n              3.84796226e-14, 4.68237515e-12, 1.07565293e-11],\n             [2.56983596e-14, 5.48996178e-13, 3.19121221e-14,\n              3.39858045e-14, 9.86410259e-12, 6.85449180e-12],\n             [7.42128662e-14, 4.66550193e-12, 5.68850726e-14,\n              3.48152463e-14, 6.22827484e-12, 1.31156874e-11],\n             [1.60594719e-14, 1.71755524e-12, 1.06847158e-14,\n              1.22288952e-13, 7.70055200e-12, 5.33613865e-12],\n             [4.63322523e-13, 1.14702551e-12, 1.57314564e-13,\n              2.66102041e-14, 1.26338957e-11, 4.69107826e-12],\n             [4.32601410e-14, 5.60937147e-12, 1.38292283e-14,\n              2.03118992e-14, 7.40712960e-12, 1.63549868e-11],\n             [2.93910370e-14, 3.07944230e-12, 1.11994588e-13,\n              1.30717258e-14, 1.41837774e-11, 1.42507377e-11],\n             [1.22383210e-13, 3.66307982e-12, 2.05759360e-13,\n              2.79863480e-14, 8.02877122e-12, 1.67750067e-11],\n             [1.39484880e-14, 2.84715610e-13, 3.10231766e-13,\n              4.56258803e-14, 5.54326496e-13, 2.24314382e-14],\n             [1.79185510e-13, 2.01523046e-14, 2.22208496e-13,\n              2.34033899e-14, 2.82779713e-13, 1.60465798e-13],\n             [2.10014907e-14, 4.73251429e-13, 1.02406812e-13,\n              2.46776781e-14, 2.52360285e-12, 2.99892576e-12],\n             [1.51100188e-13, 2.05120122e-12, 8.49248616e-15,\n              1.26177179e-14, 8.98622583e-12, 1.23984642e-11],\n             [3.71367366e-14, 5.48269026e-12, 4.56275913e-14,\n              8.22980464e-14, 8.02551775e-12, 1.46907157e-11],\n             [3.55192357e-14, 4.99774224e-12, 5.69814039e-14,\n              3.79094881e-14, 1.74064999e-11, 1.94485643e-11],\n             [4.41634644e-14, 2.94970342e-12, 4.22736390e-13,\n              7.28969565e-14, 1.15712117e-12, 4.70383498e-12]],            dtype=float32)], [DeviceArray([ 3.0663330e-05, -6.0212583e-04,  2.3759302e-04,\n              2.5868230e-03, -2.6333719e-03, -6.5538799e-04],            dtype=float32), DeviceArray([ 1.8194061e-06,  3.8155737e-05,  7.0559526e-07,\n             -8.0754571e-06,  7.7572095e-06,  2.2165239e-05],            dtype=float32), DeviceArray([1.7689330e-12, 8.7941432e-11, 8.0609998e-14, 4.3580421e-12,\n             3.2801665e-12, 8.3839498e-11], dtype=float32)], [DeviceArray([[ 0.19147131, -0.08741151, -0.2899834 ,  0.62835646,\n               0.53938615,  0.40327066],\n             [-0.32089502,  0.20948973,  0.12690711, -0.25749823,\n               0.7558288 ,  0.26582175],\n             [-0.31148365,  0.6439382 , -0.63071334, -0.25127968,\n              -0.34781596,  0.8903316 ],\n             [ 0.01745997,  0.08365437, -0.36128715, -0.45783097,\n               0.05735766,  0.17224273],\n             [-0.02606061, -0.5212192 ,  0.4008201 , -0.77488416,\n              -0.3991631 ,  0.32255793],\n             [-0.47373995,  0.28477976, -0.8236459 , -0.6276811 ,\n              -0.6463785 , -0.45341846]], dtype=float32), DeviceArray([[ 2.0765101e-06,  5.2552758e-07,  1.5801023e-08,\n              -2.0083421e-06,  1.3381332e-06,  1.4539712e-06],\n             [-8.0421991e-10,  9.4330670e-07,  8.0152070e-08,\n              -5.2841326e-08, -2.0777220e-08,  3.6099043e-06],\n             [ 1.0393588e-06, -3.8457793e-07,  1.0858516e-09,\n              -1.2779847e-06,  4.7307202e-07,  4.9989808e-07],\n             [-1.4016753e-06,  6.8030022e-06,  1.8208079e-08,\n              -1.4558456e-08,  1.2815885e-06, -5.2809555e-06],\n             [ 3.7267771e-09,  8.3228333e-06,  4.8335327e-08,\n              -1.2781840e-07,  4.3923883e-08,  4.3804215e-05],\n             [-1.6840630e-08,  1.5559388e-05,  4.6703321e-09,\n              -7.2126383e-08,  7.0047875e-08,  3.1726337e-05]],            dtype=float32), DeviceArray([[3.1861910e-13, 1.9180499e-14, 2.2070947e-17, 2.0498165e-13,\n              8.7324645e-14, 1.1507772e-13],\n             [3.9682659e-19, 7.1927083e-14, 7.9724488e-16, 2.8024347e-16,\n              1.8258127e-16, 1.6315354e-12],\n             [8.7954194e-14, 4.7628832e-14, 1.1790737e-19, 8.0634026e-14,\n              1.1355760e-14, 3.2381348e-14],\n             [1.5629323e-12, 2.7163892e-12, 3.3153412e-17, 8.8922520e-18,\n              9.2106539e-14, 1.9771992e-12],\n             [7.5559045e-18, 3.8389882e-12, 1.7638762e-16, 1.5556227e-15,\n              1.1637820e-16, 1.7210415e-10],\n             [2.8360685e-17, 1.3631652e-11, 2.1812002e-18, 5.9393463e-16,\n              4.5445975e-16, 9.1455538e-11]], dtype=float32)], [DeviceArray([-0.00038439, -0.00253191,  0.        , -0.00024855,\n             -0.00149335, -0.0005279 ], dtype=float32), DeviceArray([ 9.0186906e-05,  2.0501162e-05,  0.0000000e+00,\n             -1.5605739e-05,  7.8567376e-05,  1.0919726e-05],            dtype=float32), DeviceArray([7.7366902e-10, 2.8706129e-11, 0.0000000e+00, 4.0496065e-11,\n             4.4863854e-10, 5.6050446e-11], dtype=float32)], [DeviceArray([[ 7.09596097e-01,  5.57229042e-01,  1.05266884e-01,\n               8.37738737e-02,  4.01628464e-01,  1.04380406e-01],\n             [ 1.63253590e-01,  2.98599415e-02, -6.43938184e-01,\n               8.21307227e-02,  6.04431272e-01, -6.21634185e-01],\n             [-1.62911832e-01, -2.34109700e-01, -1.84899673e-01,\n              -1.95914432e-01,  2.12329701e-01,  2.51049638e-01],\n             [ 6.75464392e-01, -2.52610683e-01, -1.87426805e-01,\n               2.40000457e-01, -5.11981368e-01, -4.66484815e-01],\n             [-3.82236729e-04,  4.80104297e-01, -6.04833551e-02,\n              -8.24540913e-01, -3.41470331e-01,  1.19573558e-02],\n             [ 1.48141757e-01, -1.05470764e-02, -2.78792083e-01,\n               1.00884447e-03,  7.85028264e-02,  6.96714282e-01]],            dtype=float32), DeviceArray([[-3.7734679e-08,  2.0914214e-07,  0.0000000e+00,\n               4.4612163e-09, -1.8412505e-08, -1.6384351e-07],\n             [ 2.5431862e-07,  5.8102722e-07,  0.0000000e+00,\n              -3.9158030e-07,  6.5379754e-07, -3.2795475e-07],\n             [-2.0847846e-08, -5.5632698e-09,  0.0000000e+00,\n               0.0000000e+00, -7.4284507e-09,  4.5505992e-08],\n             [-4.6340386e-07,  1.2311507e-06,  0.0000000e+00,\n              -8.9473851e-09, -4.2540904e-10,  3.4614908e-07],\n             [-4.7590757e-07,  1.1430946e-06,  0.0000000e+00,\n              -1.0665863e-09,  1.6385513e-08, -2.5949966e-08],\n             [ 6.1280770e-07,  2.1922931e-06,  0.0000000e+00,\n              -1.4440599e-07,  1.7412744e-06,  9.7419240e-07]],            dtype=float32), DeviceArray([[4.89823272e-16, 2.44132429e-15, 0.00000000e+00,\n              2.45463613e-18, 1.79699840e-16, 1.53430765e-15],\n             [4.67324764e-15, 2.39591468e-14, 0.00000000e+00,\n              1.44993897e-14, 2.56212982e-14, 6.71791317e-15],\n             [2.65992102e-17, 1.07641758e-17, 0.00000000e+00,\n              0.00000000e+00, 3.05696052e-18, 1.45802733e-16],\n             [1.06954859e-14, 8.00853592e-14, 0.00000000e+00,\n              1.86719016e-17, 3.72830896e-16, 7.56701373e-15],\n             [1.88198032e-14, 7.23305422e-14, 0.00000000e+00,\n              1.08007125e-19, 3.75352571e-15, 4.23270495e-15],\n             [9.34799587e-14, 2.79318913e-13, 0.00000000e+00,\n              1.72723126e-15, 1.98988395e-13, 9.09553601e-14]],            dtype=float32)], [DeviceArray([-0.00074821, -0.00081796,  0.00117048, -0.00178991,\n              0.00265362,  0.        ], dtype=float32), DeviceArray([ 1.1074917e-07,  2.4995950e-04,  1.8515976e-04,\n              3.9137376e-06, -6.2360908e-05,  0.0000000e+00],            dtype=float32), DeviceArray([2.8900531e-11, 7.9435143e-09, 4.1745047e-09, 3.8028711e-10,\n             2.4104238e-10, 0.0000000e+00], dtype=float32)], [DeviceArray([[-0.3056272 ,  0.4231755 , -0.04128612,  0.46694696,\n               0.16514003, -0.45115194],\n             [-0.02365104,  0.2639698 , -0.37985128,  0.8142715 ,\n               0.05928962, -0.27328548],\n             [-0.3441138 ,  0.3751544 ,  0.05917086,  0.4884692 ,\n               0.04338139,  0.2134114 ],\n             [-0.38789976, -0.38161513,  0.31562155,  0.10442398,\n              -0.00480278, -0.19789281],\n             [ 0.12581037,  0.37353298, -0.10928751,  0.503088  ,\n              -0.2655051 , -0.17124975],\n             [ 0.35595462,  0.18437892,  0.23258427, -0.85487086,\n              -0.2029693 , -0.33262372]], dtype=float32), DeviceArray([[ 1.4422523e-07,  1.0348479e-06, -2.1897273e-07,\n               2.1539340e-06, -9.7019038e-06,  0.0000000e+00],\n             [ 1.7801361e-07, -2.5086075e-07, -1.3055396e-07,\n               5.2816296e-07, -1.3365701e-06,  0.0000000e+00],\n             [ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00,\n               0.0000000e+00,  0.0000000e+00,  0.0000000e+00],\n             [ 1.8810088e-08,  8.0061959e-08,  5.1003372e-08,\n               3.1219592e-08, -2.3568070e-08,  0.0000000e+00],\n             [ 1.6152674e-07,  1.0853166e-06,  1.2019419e-06,\n               2.8108428e-07, -1.2651979e-07,  0.0000000e+00],\n             [ 3.1606952e-08,  2.8423892e-06,  5.0167625e-07,\n               1.1964460e-06, -5.3287968e-06,  0.0000000e+00]],            dtype=float32), DeviceArray([[1.1840859e-15, 7.9572782e-14, 3.1349770e-14, 2.4425308e-13,\n              4.6538823e-12, 0.0000000e+00],\n             [2.5218905e-15, 6.5746011e-15, 9.1246296e-16, 1.5102157e-14,\n              9.8801487e-14, 0.0000000e+00],\n             [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,\n              0.0000000e+00, 0.0000000e+00],\n             [3.2490592e-17, 8.4389236e-16, 5.7708000e-16, 1.1674825e-16,\n              3.7473287e-17, 0.0000000e+00],\n             [1.6931966e-15, 1.3198515e-13, 1.7659864e-13, 2.0507665e-14,\n              9.0556134e-16, 0.0000000e+00],\n             [2.1929211e-15, 7.6067006e-13, 2.3270353e-14, 7.4921501e-14,\n              1.4142405e-12, 0.0000000e+00]], dtype=float32)], [DeviceArray([ 0.        ,  0.00157121, -0.00141369, -0.00135325,\n             -0.00251697, -0.00212993], dtype=float32), DeviceArray([ 0.00000000e+00,  1.29414868e-04, -5.93711302e-05,\n              1.11967165e-04,  3.14280886e-04,  8.22721340e-05],            dtype=float32), DeviceArray([0.0000000e+00, 2.1965585e-08, 1.1964050e-09, 1.3787350e-09,\n             6.9818973e-09, 2.0284541e-09], dtype=float32)], [DeviceArray([[ 0.4786232 ,  0.07762085,  0.23054789,  0.24556062,\n              -0.00573604,  0.09755293],\n             [-0.7429878 ,  0.72179085,  0.18699315, -0.2683911 ,\n               0.57062393,  0.10477231],\n             [-0.5879746 ,  0.22517885,  0.58405346,  0.15380381,\n               0.5612985 , -0.374023  ],\n             [-0.13496533,  0.31007463, -0.91230136, -0.00988471,\n              -0.14309123,  0.64083695],\n             [-0.47050482, -0.09354866, -0.3502471 ,  0.02155269,\n              -0.23451555, -0.5596603 ],\n             [ 0.46003628, -0.26375806,  0.1943689 ,  0.04007643,\n              -0.29916492,  0.16186029]], dtype=float32), DeviceArray([[ 0.0000000e+00,  4.6203587e-07, -4.9911466e-07,\n               8.1376520e-08,  7.8805937e-07,  4.8135485e-07],\n             [ 0.0000000e+00, -3.8448088e-06,  2.5146437e-06,\n               8.5733035e-08,  4.4749559e-06,  9.3717990e-06],\n             [ 0.0000000e+00,  7.2909307e-07, -4.5917994e-07,\n               1.7394352e-07,  1.1992919e-06,  5.3257600e-07],\n             [ 0.0000000e+00, -7.2328521e-07,  3.5359469e-07,\n               5.5574421e-09,  1.1884179e-06,  2.0294008e-06],\n             [ 0.0000000e+00, -3.7733372e-07,  2.5784374e-07,\n               1.0839074e-07,  3.5430759e-07,  9.1234665e-07],\n             [ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00,\n               0.0000000e+00,  0.0000000e+00,  0.0000000e+00]],            dtype=float32), DeviceArray([[0.00000000e+00, 3.25262346e-14, 3.59916768e-14,\n              3.68739144e-16, 5.14987562e-14, 1.75674311e-14],\n             [0.00000000e+00, 7.44021056e-13, 3.08127899e-13,\n              4.13313856e-16, 1.09750945e-12, 4.47324686e-12],\n             [0.00000000e+00, 1.40464599e-13, 5.10655903e-14,\n              2.17324683e-15, 1.08902335e-13, 1.50668662e-14],\n             [0.00000000e+00, 3.73186691e-14, 6.29945264e-15,\n              3.80917003e-18, 7.84495827e-14, 2.20110930e-13],\n             [0.00000000e+00, 1.75216490e-14, 3.24122409e-15,\n              1.44898835e-15, 6.33568871e-15, 4.17144380e-14],\n             [0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n              0.00000000e+00, 0.00000000e+00, 0.00000000e+00]],            dtype=float32)], [DeviceArray([0.00180455, 0.00180623, 0.00180595, 0.00259012], dtype=float32), DeviceArray([-0.00044176, -0.00043326, -0.00036839, -0.00124654], dtype=float32), DeviceArray([1.95022345e-08, 1.87587244e-08, 1.35600775e-08,\n             8.94127368e-08], dtype=float32)], [DeviceArray([[-0.02883793, -0.06974681,  0.24037988, -0.45696107],\n             [-0.66388506, -0.35703114, -0.34853068,  0.36051708],\n             [ 0.661668  , -0.03455266,  0.73796755, -0.4630023 ],\n             [-0.11637773, -0.829776  , -0.0157351 , -0.2001823 ],\n             [ 0.2375262 , -0.39262995, -0.61489004, -0.27574617],\n             [ 0.9198022 ,  0.37472868,  0.05519994, -0.7439775 ]],            dtype=float32), DeviceArray([[ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00,\n               0.0000000e+00],\n             [-2.2636502e-06, -2.0661346e-06, -1.2783025e-06,\n              -1.4003868e-05],\n             [-5.6834824e-07, -5.8828175e-07, -5.8287685e-07,\n              -1.2936071e-06],\n             [-2.1617128e-11, -2.1906333e-11, -2.1303336e-11,\n              -1.0811599e-07],\n             [-8.0856603e-07, -6.9933321e-07, -3.0290084e-07,\n              -7.9441652e-06],\n             [-1.5500837e-07, -8.2882934e-08,  1.3715700e-07,\n              -2.3562773e-06]], dtype=float32), DeviceArray([[0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n             [5.1146001e-13, 4.2600835e-13, 1.6287263e-13, 1.0453084e-11],\n             [3.2286870e-14, 3.4591704e-14, 3.3959343e-14, 9.9721216e-14],\n             [5.6415981e-23, 5.7961612e-23, 5.4847972e-23, 1.3727235e-15],\n             [6.5160449e-14, 4.8715518e-14, 9.0941515e-15, 3.3455365e-12],\n             [2.3930796e-15, 6.8169646e-16, 1.8897782e-15, 2.9122923e-13]],            dtype=float32)]), tree_def=PyTreeDef(CustomNode(<class 'flax.core.frozen_dict.FrozenDict'>[()], [{'params': {'fc0': {'bias': *, 'kernel': *}, 'fc1': {'bias': *, 'kernel': *}, 'fc2': {'bias': *, 'kernel': *}, 'fc3': {'bias': *, 'kernel': *}, 'fc4': {'bias': *, 'kernel': *}, 'fc_last': {'bias': *, 'kernel': *}}}])), subtree_defs=(PyTreeDef((*, *, *)), PyTreeDef((*, *, *)), PyTreeDef((*, *, *)), PyTreeDef((*, *, *)), PyTreeDef((*, *, *)), PyTreeDef((*, *, *)), PyTreeDef((*, *, *)), PyTreeDef((*, *, *)), PyTreeDef((*, *, *)), PyTreeDef((*, *, *)), PyTreeDef((*, *, *)), PyTreeDef((*, *, *))))"},"metadata":{}}]},{"cell_type":"code","source":"# access param by layer name\nget_params(opt_state)['params']['fc0']","metadata":{},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"model.get_variable(name='fc_last')","metadata":{"execution":{"iopub.status.busy":"2021-07-04T06:40:47.867354Z","iopub.execute_input":"2021-07-04T06:40:47.868012Z","iopub.status.idle":"2021-07-04T06:40:47.892970Z","shell.execute_reply.started":"2021-07-04T06:40:47.867958Z","shell.execute_reply":"2021-07-04T06:40:47.891635Z"},"trusted":true},"execution_count":109,"outputs":[{"traceback":["\u001b[0;31m---------------------------------------------------------------------------\u001b[0m","\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)","\u001b[0;32m<ipython-input-109-eab0686523b2>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_variable\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'fc_last'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m","\u001b[0;31mTypeError\u001b[0m: get_variable() missing 1 required positional argument: 'col'"],"ename":"TypeError","evalue":"get_variable() missing 1 required positional argument: 'col'","output_type":"error"}]},{"cell_type":"code","source":"","metadata":{},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"","metadata":{},"execution_count":null,"outputs":[]}]}