{"metadata":{"kernelspec":{"language":"python","display_name":"Python 3","name":"python3"},"language_info":{"name":"python","version":"3.7.10","mimetype":"text/x-python","codemirror_mode":{"name":"ipython","version":3},"pygments_lexer":"ipython3","nbconvert_exporter":"python","file_extension":".py"}},"nbformat_minor":4,"nbformat":4,"cells":[{"cell_type":"markdown","source":"## THINGS TO DO\n\n### Existing NeRF\n1. Rewrite the model in Flax (otherwise it can't work with Flax-CLIP?) (Haiku seems functional, whereas Flax is OOP?)\n2. Rewrite the training loop in Flax (does Flax provide an abstraction for writing training loops? The current training loop is pretty low-level.)\n3. Refactor the notebook code into module code\n - rendering scenes, visualization, animation, etc.\n4. Data loading for our target dataset\n5. Consolidate all controllable params in a class (e.g. `Config`)\n\n\n### NeRF --> DietNeRF\n1. Change sampling to 8 samples only\n2. Add CLIP into the training loop for a new loss function\n3. Check whether DietNeRF can get results comparable to NeRF\n\n### Optional\n1. Understand what each operation is doing (e.g. `get_rays`, etc.)\n2. Add W&B for visualization?\n3. Very large-scale NeRF training (many scenes) --> get a huge sample of scenes for a POC\n4. 
Optimize bottleneck operations by `jax.vmap`, `jax.pmap`","metadata":{}},{"cell_type":"code","source":"# enable showing live \"loss plot\" inside notebook\n!pip install livelossplot","metadata":{"execution":{"iopub.status.busy":"2021-07-04T06:25:20.105372Z","iopub.execute_input":"2021-07-04T06:25:20.105956Z","iopub.status.idle":"2021-07-04T06:25:27.213206Z","shell.execute_reply.started":"2021-07-04T06:25:20.105901Z","shell.execute_reply":"2021-07-04T06:25:27.211937Z"},"trusted":true},"execution_count":63,"outputs":[{"name":"stdout","text":"Requirement already satisfied: livelossplot in /opt/conda/lib/python3.7/site-packages (0.5.4)\nRequirement already satisfied: matplotlib in /opt/conda/lib/python3.7/site-packages (from livelossplot) (3.4.1)\nRequirement already satisfied: ipython in /opt/conda/lib/python3.7/site-packages (from livelossplot) (7.22.0)\nRequirement already satisfied: bokeh in /opt/conda/lib/python3.7/site-packages (from livelossplot) (2.3.1)\nRequirement already satisfied: typing-extensions>=3.7.4 in /opt/conda/lib/python3.7/site-packages (from bokeh->livelossplot) (3.7.4.3)\nRequirement already satisfied: packaging>=16.8 in /opt/conda/lib/python3.7/site-packages (from bokeh->livelossplot) (20.9)\nRequirement already satisfied: PyYAML>=3.10 in /opt/conda/lib/python3.7/site-packages (from bokeh->livelossplot) (5.3.1)\nRequirement already satisfied: pillow>=7.1.0 in /opt/conda/lib/python3.7/site-packages (from bokeh->livelossplot) (7.2.0)\nRequirement already satisfied: tornado>=5.1 in /opt/conda/lib/python3.7/site-packages (from bokeh->livelossplot) (6.1)\nRequirement already satisfied: numpy>=1.11.3 in /opt/conda/lib/python3.7/site-packages (from bokeh->livelossplot) (1.19.5)\nRequirement already satisfied: python-dateutil>=2.1 in /opt/conda/lib/python3.7/site-packages (from bokeh->livelossplot) (2.8.1)\nRequirement already satisfied: Jinja2>=2.7 in /opt/conda/lib/python3.7/site-packages (from bokeh->livelossplot) (2.11.3)\nRequirement already 
satisfied: MarkupSafe>=0.23 in /opt/conda/lib/python3.7/site-packages (from Jinja2>=2.7->bokeh->livelossplot) (1.1.1)\nRequirement already satisfied: pyparsing>=2.0.2 in /opt/conda/lib/python3.7/site-packages (from packaging>=16.8->bokeh->livelossplot) (2.4.7)\nRequirement already satisfied: six>=1.5 in /opt/conda/lib/python3.7/site-packages (from python-dateutil>=2.1->bokeh->livelossplot) (1.15.0)\nRequirement already satisfied: prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0 in /opt/conda/lib/python3.7/site-packages (from ipython->livelossplot) (3.0.18)\nRequirement already satisfied: decorator in /opt/conda/lib/python3.7/site-packages (from ipython->livelossplot) (4.4.2)\nRequirement already satisfied: pygments in /opt/conda/lib/python3.7/site-packages (from ipython->livelossplot) (2.8.1)\nRequirement already satisfied: pickleshare in /opt/conda/lib/python3.7/site-packages (from ipython->livelossplot) (0.7.5)\nRequirement already satisfied: jedi>=0.16 in /opt/conda/lib/python3.7/site-packages (from ipython->livelossplot) (0.18.0)\nRequirement already satisfied: pexpect>4.3 in /opt/conda/lib/python3.7/site-packages (from ipython->livelossplot) (4.8.0)\nRequirement already satisfied: traitlets>=4.2 in /opt/conda/lib/python3.7/site-packages (from ipython->livelossplot) (5.0.5)\nRequirement already satisfied: setuptools>=18.5 in /opt/conda/lib/python3.7/site-packages (from ipython->livelossplot) (49.6.0.post20210108)\nRequirement already satisfied: backcall in /opt/conda/lib/python3.7/site-packages (from ipython->livelossplot) (0.2.0)\nRequirement already satisfied: parso<0.9.0,>=0.8.0 in /opt/conda/lib/python3.7/site-packages (from jedi>=0.16->ipython->livelossplot) (0.8.1)\nRequirement already satisfied: ptyprocess>=0.5 in /opt/conda/lib/python3.7/site-packages (from pexpect>4.3->ipython->livelossplot) (0.7.0)\nRequirement already satisfied: wcwidth in /opt/conda/lib/python3.7/site-packages (from prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0->ipython->livelossplot) 
(0.2.5)\nRequirement already satisfied: ipython-genutils in /opt/conda/lib/python3.7/site-packages (from traitlets>=4.2->ipython->livelossplot) (0.2.0)\nRequirement already satisfied: kiwisolver>=1.0.1 in /opt/conda/lib/python3.7/site-packages (from matplotlib->livelossplot) (1.3.1)\nRequirement already satisfied: cycler>=0.10 in /opt/conda/lib/python3.7/site-packages (from matplotlib->livelossplot) (0.10.0)\n","output_type":"stream"}]},{"cell_type":"code","source":"%%capture\n!conda install -y -c conda-forge jax jaxlib flax optax datasets transformers\n!conda install -y importlib-metadata\n!pip install -U dm-haiku","metadata":{"_uuid":"8f2839f25d086af736a60e9eeb907d3b93b6e0e5","_cell_guid":"b1076dfc-b9ad-4769-8c92-a6c4dae69d19","execution":{"iopub.status.busy":"2021-07-04T06:25:27.215245Z","iopub.execute_input":"2021-07-04T06:25:27.215562Z","iopub.status.idle":"2021-07-04T06:27:11.574413Z","shell.execute_reply.started":"2021-07-04T06:25:27.215530Z","shell.execute_reply":"2021-07-04T06:27:11.572485Z"},"trusted":true},"execution_count":64,"outputs":[]},{"cell_type":"markdown","source":"","metadata":{}},{"cell_type":"code","source":"# TPU setup\nimport os\nif 'TPU_NAME' in os.environ:\n import requests\n if 'TPU_DRIVER_MODE' not in globals():\n url = 'http:' + os.environ['TPU_NAME'].split(':')[1] + ':8475/requestversion/tpu_driver_nightly'\n resp = requests.post(url)\n TPU_DRIVER_MODE = 1\n\n from jax.config import config\n config.FLAGS.jax_xla_backend = \"tpu_driver\"\n config.FLAGS.jax_backend_target = os.environ['TPU_NAME']\n print('Registered TPU:', config.FLAGS.jax_backend_target)\nelse:\n print('No TPU detected. 
Can be changed under \"Runtime/Change runtime type\".')\n\n# Module check\nimport jax\nimport flax\nimport haiku as hk\n\nfor _m in (jax, flax, hk):\n print(f'{_m.__name__}: {_m.__version__}')\njax.local_devices()","metadata":{"execution":{"iopub.status.busy":"2021-07-04T06:27:11.578625Z","iopub.execute_input":"2021-07-04T06:27:11.579247Z","iopub.status.idle":"2021-07-04T06:27:11.600702Z","shell.execute_reply.started":"2021-07-04T06:27:11.579197Z","shell.execute_reply":"2021-07-04T06:27:11.598533Z"},"trusted":true},"execution_count":65,"outputs":[{"name":"stdout","text":"No TPU detected. Can be changed under \"Runtime/Change runtime type\".\njax: 0.2.16\nflax: 0.3.4\nhaiku: 0.0.4\n","output_type":"stream"},{"execution_count":65,"output_type":"execute_result","data":{"text/plain":"[CpuDevice(id=0)]"},"metadata":{}}]},{"cell_type":"code","source":"from functools import partial\n\nimport jax\nfrom jax import random, grad, jit, vmap, flatten_util, nn\nfrom jax.experimental import optimizers # change due to version difference\nfrom jax.config import config\nimport jax.numpy as np\n\nimport haiku as hk\nfrom haiku._src import utils\n\nfrom livelossplot import PlotLosses\nimport matplotlib.pyplot as plt\nfrom tqdm.notebook import tqdm as tqdm\nimport cv2\nimport imageio\nimport glob\nfrom IPython.display import clear_output\nimport pickle\nfrom skimage.metrics import structural_similarity as ssim_fn\n\nrng = jax.random.PRNGKey(42)","metadata":{"execution":{"iopub.status.busy":"2021-07-04T06:27:11.603322Z","iopub.execute_input":"2021-07-04T06:27:11.603678Z","iopub.status.idle":"2021-07-04T06:27:11.619997Z","shell.execute_reply.started":"2021-07-04T06:27:11.603646Z","shell.execute_reply":"2021-07-04T06:27:11.618654Z"},"trusted":true},"execution_count":66,"outputs":[]},{"cell_type":"code","source":"#ls ../input/pull-phototourism-images/sacre_coeur/dense/images\nDATASET = 'sacre'\nposedir = f'../input/phototourism/phototourism/sacre' # Directory condtains [bds.npy, 
c2w_mats.npy, kinv_mats.npy, res_mats.npy]\nimgdir = f'../input/pull-phototourism-images/sacre_coeur/dense/images' # Directory of images","metadata":{"execution":{"iopub.status.busy":"2021-07-04T06:27:11.621386Z","iopub.execute_input":"2021-07-04T06:27:11.621825Z","iopub.status.idle":"2021-07-04T06:27:11.630402Z","shell.execute_reply.started":"2021-07-04T06:27:11.621792Z","shell.execute_reply":"2021-07-04T06:27:11.629270Z"},"trusted":true},"execution_count":67,"outputs":[]},{"cell_type":"markdown","source":"### 1. Helper Functions for Loading Data","metadata":{}},{"cell_type":"code","source":"posedata = {}\nfor f in os.listdir(posedir):\n if '.npy' not in f:\n continue\n z = np.load(os.path.join(posedir, f))\n posedata[f.split('.')[0]] = z\nprint('Pose data loaded - ', posedata.keys())\n\nimgfiles = sorted(glob.glob(imgdir + '/*.jpg'))\nprint(f'{len(imgfiles)} images')","metadata":{"execution":{"iopub.status.busy":"2021-07-04T06:27:11.633311Z","iopub.execute_input":"2021-07-04T06:27:11.633646Z","iopub.status.idle":"2021-07-04T06:27:11.676300Z","shell.execute_reply.started":"2021-07-04T06:27:11.633614Z","shell.execute_reply":"2021-07-04T06:27:11.674797Z"},"trusted":true},"execution_count":68,"outputs":[{"name":"stdout","text":"Pose data loaded - dict_keys(['kinv_mats', 'res_mats', 'c2w_mats', 'bds'])\n1179 images\n","output_type":"stream"}]},{"cell_type":"code","source":"@jit\ndef get_rays(c2w, kinv, i, j):\n# i, j = np.meshgrid(np.arange(W), np.arange(H), indexing='xy')\n pixco = np.stack([i, j, np.ones_like(i)], -1)\n dirs = pixco @ kinv.T\n# dirs = np.stack([(i-W*.5)/focal, -(j-H*.5)/focal, -np.ones_like(i)], -1)\n rays_d = np.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1)\n rays_o = np.broadcast_to(c2w[:3,-1], rays_d.shape)\n return np.stack([rays_o, rays_d], 0)\n\n\ndef normalize(x):\n return x / np.linalg.norm(x)\n\n\ndef viewmatrix(z, up, pos):\n vec2 = normalize(z)\n vec1_avg = up\n vec0 = normalize(np.cross(vec1_avg, vec2))\n vec1 = normalize(np.cross(vec2, 
vec0))\n m = np.stack([vec0, vec1, vec2, pos], 1)\n return m\n\n\ndef ptstocam(pts, c2w):\n tt = np.matmul(c2w[:3,:3].T, (pts-c2w[:3,3])[...,np.newaxis])[...,0]\n return tt\n\n\ndef poses_avg(poses):\n center = poses[:, :3, 3].mean(0)\n vec2 = normalize(poses[:, :3, 2].sum(0))\n up = poses[:, :3, 1].sum(0)\n return viewmatrix(vec2, up, center)\n\n\ndef render_path_spiral(c2w, up, rads, focal, zrate, rots, N):\n \"\"\"\n enumerate list of poses around a spiral\n used for test set visualization\n \"\"\"\n render_poses = []\n rads = np.array(list(rads) + [1.])\n for theta in np.linspace(0., 2. * np.pi * rots, N+1)[:-1]:\n c = np.dot(c2w[:3,:4], np.array([np.cos(theta), -np.sin(theta), -np.sin(theta*zrate), 1.]) * rads) \n z = normalize(c - np.dot(c2w[:3,:4], np.array([0,0,-focal, 1.])))\n render_poses.append(viewmatrix(z, up, c))\n return render_poses","metadata":{"execution":{"iopub.status.busy":"2021-07-04T06:27:11.678295Z","iopub.execute_input":"2021-07-04T06:27:11.678750Z","iopub.status.idle":"2021-07-04T06:27:11.699145Z","shell.execute_reply.started":"2021-07-04T06:27:11.678692Z","shell.execute_reply":"2021-07-04T06:27:11.698223Z"},"trusted":true},"execution_count":69,"outputs":[]},{"cell_type":"code","source":"def get_example(img_idx, split='train', downsample=4):\n sc = .05\n \n # first 20 are test, next 5 are validation, the rest are training:\n # https://github.com/tancik/learnit/issues/3\n if 'train' in split:\n img_idx = img_idx + 25\n if 'val' in split:\n img_idx = img_idx + 20\n \n # uint8 --> float\n img = imageio.imread(imgfiles[img_idx])[...,:3]/255.\n \n # WHAT DO THESE MATRICES MEAN???\n # (4, 4)\n c2w = posedata['c2w_mats'][img_idx]\n # (3, 3)\n kinv = posedata['kinv_mats'][img_idx]\n c2w = np.concatenate([c2w[:3,:3], c2w[:3,3:4]*sc], -1)\n # (2, )\n bds = posedata['bds'][img_idx] * np.array([.9, 1.2]) * sc\n H, W = img.shape[:2]\n \n # (0, 4, 8, ..., H)\n # WHAT ARE THE PURPOSES OF THIS MATRIX???\n i, j = np.meshgrid(np.arange(0,W,downsample), 
np.arange(0,H,downsample), indexing='xy')\n \n test_images = img[j, i]\n test_rays = get_rays(c2w, kinv, i, j)\n return test_images, test_rays, bds","metadata":{"execution":{"iopub.status.busy":"2021-07-04T06:27:11.702928Z","iopub.execute_input":"2021-07-04T06:27:11.703761Z","iopub.status.idle":"2021-07-04T06:27:11.717429Z","shell.execute_reply.started":"2021-07-04T06:27:11.703650Z","shell.execute_reply":"2021-07-04T06:27:11.715895Z"},"trusted":true},"execution_count":70,"outputs":[]},{"cell_type":"markdown","source":"### 2. NeRF Renderer","metadata":{}},{"cell_type":"code","source":"def render_rays(\n rnd_input, model, params, \n bvals, rays, near, far, \n N_samples, rand=False, allret=False\n ):\n rays_o, rays_d = rays\n\n # Compute 3D query points\n z_vals = np.linspace(near, far, N_samples) \n if rand:\n z_vals += random.uniform(rnd_input, shape=list(rays_o.shape[:-1]) + [N_samples]) * (far-near)/N_samples\n # r(t) = o + t*d\n pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None]\n \n # Run network\n pts_flat = np.reshape(pts, [-1,3])\n if bvals is not None:\n pts_flat = np.concatenate([np.sin(pts_flat @ bvals.T), \n np.cos(pts_flat @ bvals.T)], axis=-1)\n \n raw = model.apply(params, pts_flat)\n raw = np.reshape(raw, list(pts.shape[:-1]) + [4])\n \n # Compute opacities and colors\n rgb, sigma_a = raw[...,:3], raw[...,3]\n sigma_a = jax.nn.relu(sigma_a)\n rgb = jax.nn.sigmoid(rgb) \n \n # Do volume rendering\n dists = np.concatenate([z_vals[..., 1:] - z_vals[..., :-1], np.broadcast_to([1e10], z_vals[...,:1].shape)], -1) \n alpha = 1. - np.exp(-sigma_a * dists)\n trans = np.minimum(1., 1. 
- alpha + 1e-10)\n trans = np.concatenate([np.ones_like(trans[...,:1]), trans[...,:-1]], -1) \n weights = alpha * np.cumprod(trans, -1)\n \n rgb_map = np.sum(weights[...,None] * rgb, -2) \n acc_map = np.sum(weights, -1)\n \n if False:\n rgb_map = rgb_map + (1.-acc_map[..., None])\n \n if not allret:\n return rgb_map\n \n depth_map = np.sum(weights * z_vals, -1) \n\n return rgb_map, depth_map, acc_map\n\n\ndef render_fn_inner(rnd_input, model, params, bvals, rays, near, far, rand, allret, N_samples):\n return render_rays(rnd_input, model, params, bvals, rays, near, far, \n N_samples=N_samples, rand=rand, allret=allret)\n\n# optimize render_fn_inner by JIT (func in, func out)\nrender_fn_inner = jit(render_fn_inner, static_argnums=(1, 7, 8, 9))\n\n\ndef render_fn(rnd_input, model, params, bvals, rays, near, far, N_samples, rand):\n chunk = 5\n for i in range(0, rays.shape[1], chunk):\n out = render_fn_inner(rnd_input, model, params, bvals, rays[:,i:i+chunk], near, far, rand, True, N_samples)\n if i==0:\n rets = out\n else:\n rets = [np.concatenate([a, b], 0) for a, b in zip(rets, out)]\n return rets","metadata":{"execution":{"iopub.status.busy":"2021-07-04T06:27:11.719345Z","iopub.execute_input":"2021-07-04T06:27:11.719675Z","iopub.status.idle":"2021-07-04T06:27:11.741107Z","shell.execute_reply.started":"2021-07-04T06:27:11.719638Z","shell.execute_reply":"2021-07-04T06:27:11.740228Z"},"trusted":true},"execution_count":71,"outputs":[]},{"cell_type":"markdown","source":"### 3. 
NeRF Model Architecture","metadata":{}},{"cell_type":"code","source":"class Model(hk.Module):\n def __init__(self):\n super().__init__()\n self.width = 256\n self.depth = 6\n self.use_viewdirs = False\n \n def __call__(self, coords, view_dirs=None):\n sh = coords.shape\n if self.use_viewdirs:\n viewdirs = None\n viewdirs = np.repeat(viewdirs[...,None,:], coords.shape[-2], axis=-2)\n viewdirs /= np.linalg.norm(viewdirs, axis=-1, keepdims=True)\n viewdirs = np.reshape(viewdirs, (-1,3))\n viewdirs = hk.Linear(output_size=self.width//2)(viewdirs)\n viewdirs = jax.nn.relu(viewdirs)\n coords = np.reshape(coords, [-1,3])\n \n # positional encoding\n x = np.concatenate([np.concatenate([np.sin(coords*(2**i)), np.cos(coords*(2**i))], axis=-1) for i in np.linspace(0,8,20)], axis=-1)\n\n for _ in range(self.depth-1):\n x = hk.Linear(output_size=self.width)(x)\n x = jax.nn.relu(x)\n \n if self.use_viewdirs:\n density = hk.Linear(output_size=1)(x)\n x = np.concatenate([x,viewdirs], axis=-1)\n x = hk.Linear(output_size=self.width)(x)\n x = jax.nn.relu(x)\n rgb = hk.Linear(output_size=3)(x)\n out = np.concatenate([density, rgb], axis=-1)\n else:\n out = hk.Linear(output_size=4)(x)\n out = np.reshape(out, list(sh[:-1]) + [4])\n return out","metadata":{"execution":{"iopub.status.busy":"2021-07-04T06:27:11.742202Z","iopub.execute_input":"2021-07-04T06:27:11.742739Z","iopub.status.idle":"2021-07-04T06:27:11.761243Z","shell.execute_reply.started":"2021-07-04T06:27:11.742684Z","shell.execute_reply":"2021-07-04T06:27:11.760296Z"},"trusted":true},"execution_count":72,"outputs":[]},{"cell_type":"markdown","source":"### 4. 
Training Loop","metadata":{}},{"cell_type":"code","source":"batch_size = 64\nN_samples = 128\ninner_step_size = 1\n\nmodel = hk.without_apply_rng(hk.transform(lambda x, y=None: Model()(x, y)))\n\nmse_fn = jit(lambda x, y: np.mean((x - y)**2))\npsnr_fn = jit(lambda x, y: -10 * np.log10(mse_fn(x, y)))\n\n@jit\ndef single_step(rng, image, rays, params, bds):\n def sgd(param, update):\n return param - inner_step_size * update\n \n rng, rng_inputs = jax.random.split(rng)\n def loss_model(params):\n g = render_rays(rng_inputs, model, params, None, rays, bds[0], bds[1], N_samples, rand=True)\n return mse_fn(g, image)\n \n model_loss, grad = jax.value_and_grad(loss_model)(params)\n new_params = jax.tree_multimap(sgd, params, grad)\n return rng, new_params, model_loss\n\ndef update_network_weights(rng, images, rays, params, inner_steps, bds):\n for _ in range(inner_steps):\n rng, rng_input = random.split(rng)\n idx = random.randint(rng_input, shape=(batch_size,), minval=0, maxval=images.shape[0])\n image_sub = images[idx,:]\n rays_sub = rays[:,idx,:]\n \n rng, params, loss = single_step(rng, image_sub, rays_sub, params, bds)\n return rng, params, loss","metadata":{"execution":{"iopub.status.busy":"2021-07-04T06:27:11.762522Z","iopub.execute_input":"2021-07-04T06:27:11.763042Z","iopub.status.idle":"2021-07-04T06:27:11.781650Z","shell.execute_reply.started":"2021-07-04T06:27:11.763004Z","shell.execute_reply":"2021-07-04T06:27:11.780574Z"},"trusted":true},"execution_count":73,"outputs":[]},{"cell_type":"code","source":"plt_groups = {'Train PSNR':[], 'Test PSNR':[]}\nplotlosses_model = 
PlotLosses(groups=plt_groups)","metadata":{"execution":{"iopub.status.busy":"2021-07-04T06:27:11.782906Z","iopub.execute_input":"2021-07-04T06:27:11.783199Z","iopub.status.idle":"2021-07-04T06:27:11.797320Z","shell.execute_reply.started":"2021-07-04T06:27:11.783172Z","shell.execute_reply":"2021-07-04T06:27:11.796476Z"},"trusted":true},"execution_count":74,"outputs":[]},{"cell_type":"code","source":"# jax.experimental.optimizers.adam returns an Optimizer(init_fn, update_fn, params_fn) namedtuple,\n# which has no init_fun / update / apply_updates (hence the AttributeError). The calls below\n# follow the optax API, so use optax (installed in the setup cell) as the outer optimizer.\nimport optax\n\nmax_iters = 150000\n\ninner_update_steps = 64\nlr = 5e-4\n\nexp_name = f'{DATASET}_ius_{inner_update_steps}_ilr_{inner_step_size}_olr_{lr}_bs_{batch_size}'\nexp_dir = f'checkpoint/phototourism_checkpoints/{exp_name}/'\n\nif not os.path.exists(exp_dir):\n os.makedirs(exp_dir)\n\nparams = model.init(rng, np.ones((1,3)))\n\nopt = optax.adam(lr)\nopt_state = opt.init(params)\n\ntest_inner_steps = 64\n\n\n# Outer (meta) step: adapt a copy of the weights with `inner_update_steps` SGD steps\n# (update_network_weights), then pass (params - adapted_params) to the outer optimizer as the gradient.\ndef update_model(rng, params, opt_state, image, rays, bds):\n rng, new_params, model_loss = update_network_weights(rng, image, rays, params, inner_update_steps, bds)\n \n def calc_grad(params, new_params):\n return params - new_params\n model_grad = jax.tree_multimap(calc_grad, params, new_params)\n \n updates, opt_state = opt.update(model_grad, opt_state)\n params = optax.apply_updates(params, updates)\n return rng, params, opt_state, model_loss\n\n# Same outer step with a single inner SGD step on an already-sampled ray batch\n# (the training loop uses this when inner_update_steps == 1).\n@jit\ndef update_model_single(rng, params, opt_state, image, rays, bds):\n rng, new_params, model_loss = single_step(rng, image, rays, params, bds)\n \n def calc_grad(params, new_params):\n return params - new_params\n model_grad = jax.tree_multimap(calc_grad, params, new_params)\n \n updates, opt_state = opt.update(model_grad, opt_state)\n params = optax.apply_updates(params, updates)\n return rng, params, opt_state, model_loss\n\n\n\nplt_groups['Train PSNR'].append(exp_name+f'_train')\nplt_groups['Test PSNR'].append(exp_name+f'_test')\nstep = 0\n\ntrain_psnrs = []\nrng = jax.random.PRNGKey(0)\n\ntrain_steps = []\ntrain_psnrs_all = []\ntest_steps = []\ntest_psnrs_all = []\nfor step in tqdm(range(max_iters)):\n 
try:\n rng, rng_input = jax.random.split(rng)\n img_idx = random.randint(rng_input, shape=(), minval=0, maxval=len(imgfiles)-25) \n images, rays, bds = get_example(img_idx, downsample=1)\n except:\n print('data loading error')\n raise\n continue\n \n\n images = np.reshape(images, (-1,3))\n rays = np.reshape(rays, (2,-1,3))\n\n if inner_update_steps == 1:\n rng, rng_input = random.split(rng)\n idx = random.randint(rng_input, shape=(batch_size,), minval=0, maxval=images.shape[0])\n rng, params, opt_state, loss = update_model_single(rng, params, opt_state, \n images[idx,:], rays[:,idx,:], bds)\n else:\n rng, params, opt_state, loss = update_model(rng, params, opt_state, \n images, rays, bds)\n train_psnrs.append(-10 * np.log10(loss))\n \n # track model loss\n if step % 250 == 0:\n plotlosses_model.update({exp_name+'_train':np.mean(np.array(train_psnrs))}, current_step=step)\n train_steps.append(step)\n train_psnrs_all.append(np.mean(np.array(train_psnrs)))\n train_psnrs = []\n \n # run validation\n if step % 500 == 0 and step != 0:\n test_psnr = []\n for ti in range(5):\n test_images, test_rays, bds = get_example(ti, split='val', downsample=2)\n\n test_images, test_holdout_images = np.split(test_images, [test_images.shape[1]//2], axis=1)\n test_rays, test_holdout_rays = np.split(test_rays, [test_rays.shape[2]//2], axis=2)\n\n test_images_flat = np.reshape(test_images, (-1,3))\n test_rays = np.reshape(test_rays, (2,-1,3))\n\n rng, test_params, test_inner_loss = update_network_weights(rng, test_images_flat, test_rays, params, test_inner_steps, bds)\n\n test_result = np.clip(render_fn(rng, model, test_params, None, test_holdout_rays, bds[0], bds[1], N_samples, rand=False)[0], 0, 1)\n test_psnr.append(psnr_fn(test_holdout_images, test_result))\n test_psnr = np.mean(np.array(test_psnr))\n\n test_steps.append(step)\n test_psnrs_all.append(test_psnr)\n \n plotlosses_model.update({exp_name+'_test':test_psnr}, current_step=step)\n plotlosses_model.send()\n\n 
plt.figure(figsize=(15,5)) \n plt.subplot(1,3, 1)\n plt.imshow(test_images)\n plt.subplot(1,3, 2)\n plt.imshow(test_holdout_images)\n plt.subplot(1,3, 3)\n plt.imshow(test_result)\n plt.show()\n \n # save model checkpoint + render sample view on test set for model check\n if step % 10000 == 0 and step != 0:\n test_images, test_rays, bds = get_example(0, split='test')\n test_images_flat = np.reshape(test_images, (-1,3))\n test_rays = np.reshape(test_rays, (2,-1,3))\n rng, test_params_1, test_inner_loss = update_network_weights(rng, test_images_flat, test_rays, params, test_inner_steps, bds)\n\n test_images, test_rays, bds = get_example(1, split='test')\n test_images_flat = np.reshape(test_images, (-1,3))\n test_rays = np.reshape(test_rays, (2,-1,3))\n rng, test_params_2, test_inner_loss = update_network_weights(rng, test_images_flat, test_rays, params, test_inner_steps, bds)\n \n poses = posedata['c2w_mats']\n c2w = poses_avg(poses)\n focal = .8\n render_poses = render_path_spiral(c2w, c2w[:3,1], [.1, .1, .05], focal, zrate=.5, rots=2, N=120)\n \n bds = np.array([5., 25.]) * .05\n H = 128\n W = H*3//2\n f = H * 1.\n kinv = np.array(\n [1./f, 0, -W*.5/f,\n 0, -1./f, H*.5/f,\n 0, 0, -1.]\n ).reshape([3,3])\n i, j = np.meshgrid(np.arange(0,W), np.arange(0,H), indexing='xy')\n renders = []\n for p, c2w in enumerate(tqdm(render_poses)):\n rays = get_rays(c2w, kinv, i, j)\n interp = p / len(render_poses)\n interp_params = jax.tree_multimap(lambda x, y: y*p/len(render_poses) + x*(1-p/len(render_poses)), test_params_1, test_params_2)\n result = render_fn(rng, model, interp_params, None, rays, bds[0], bds[1], N_samples, rand=False)[0]\n renders.append(result)\n \n renders = (np.clip(np.array(renders), 0, 1)*255).astype(np.uint8)\n imageio.mimwrite(f'{exp_dir}render_sprial_{step}.mp4', renders, fps=30, quality=8)\n \n plt.plot(train_steps, train_psnrs_all)\n plt.savefig(f'{exp_dir}train_curve_{step}.png')\n \n plt.plot(test_steps, test_psnrs_all)\n 
plt.savefig(f'{exp_dir}test_curve_{step}.png')\n \n with open(f'{exp_dir}checkpount_{step}.pkl', 'wb') as file:\n pickle.dump(params, file)","metadata":{"execution":{"iopub.status.busy":"2021-07-04T06:27:11.798560Z","iopub.execute_input":"2021-07-04T06:27:11.799038Z","iopub.status.idle":"2021-07-04T06:27:11.942383Z","shell.execute_reply.started":"2021-07-04T06:27:11.798994Z","shell.execute_reply":"2021-07-04T06:27:11.940947Z"},"trusted":true},"execution_count":75,"outputs":[{"traceback":["\u001b[0;31m---------------------------------------------------------------------------\u001b[0m","\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)","\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0mopt\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0moptimizers\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madam\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 15\u001b[0;31m \u001b[0mopt_state\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mopt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minit_fun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 16\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[0mtest_inner_steps\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m64\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;31mAttributeError\u001b[0m: 'Optimizer' object has no attribute 'init_fun'"],"ename":"AttributeError","evalue":"'Optimizer' object has no attribute 'init_fun'","output_type":"error"}]},{"cell_type":"markdown","source":"### 1. 
PLAYGROUND\n- optimizers API: https://jax.readthedocs.io/en/latest/_modules/jax/experimental/optimizers.html#adam","metadata":{}},{"cell_type":"code","source":"from flax import linen as nn","metadata":{"execution":{"iopub.status.busy":"2021-07-04T06:27:17.457842Z","iopub.execute_input":"2021-07-04T06:27:17.458453Z","iopub.status.idle":"2021-07-04T06:27:17.462662Z","shell.execute_reply.started":"2021-07-04T06:27:17.458415Z","shell.execute_reply":"2021-07-04T06:27:17.461847Z"},"trusted":true},"execution_count":76,"outputs":[]},{"cell_type":"code","source":"nn.Dense?","metadata":{"execution":{"iopub.status.busy":"2021-07-04T06:28:32.651658Z","iopub.execute_input":"2021-07-04T06:28:32.652070Z","iopub.status.idle":"2021-07-04T06:28:32.728399Z","shell.execute_reply.started":"2021-07-04T06:28:32.652032Z","shell.execute_reply":"2021-07-04T06:28:32.726998Z"},"trusted":true},"execution_count":79,"outputs":[{"output_type":"display_data","data":{"text/plain":"\u001b[0;31mInit signature:\u001b[0m\n\u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mDense\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\u001b[0m\n\u001b[0;34m\u001b[0m \u001b[0mfeatures\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n\u001b[0;34m\u001b[0m \u001b[0muse_bias\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mbool\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n\u001b[0;34m\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mAny\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m<\u001b[0m\u001b[0;32mclass\u001b[0m \u001b[0;34m'jax._src.numpy.lax_numpy.float32'\u001b[0m\u001b[0;34m>\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n\u001b[0;34m\u001b[0m \u001b[0mprecision\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mAny\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n\u001b[0;34m\u001b[0m \u001b[0mkernel_init\u001b[0m\u001b[0;34m:\u001b[0m 
\u001b[0mCallable\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mAny\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mIterable\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mAny\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mAny\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m<\u001b[0m\u001b[0mfunction\u001b[0m \u001b[0mvariance_scaling\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m<\u001b[0m\u001b[0mlocals\u001b[0m\u001b[0;34m>\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minit\u001b[0m \u001b[0mat\u001b[0m \u001b[0;36m0x7f66d9aac050\u001b[0m\u001b[0;34m>\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n\u001b[0;34m\u001b[0m \u001b[0mbias_init\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mCallable\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mAny\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mIterable\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mAny\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mAny\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m<\u001b[0m\u001b[0mfunction\u001b[0m \u001b[0mzeros\u001b[0m \u001b[0mat\u001b[0m \u001b[0;36m0x7f66e16b1050\u001b[0m\u001b[0;34m>\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n\u001b[0;34m\u001b[0m \u001b[0mparent\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mType\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mForwardRef\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Module'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mType\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mForwardRef\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Scope'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mType\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mForwardRef\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'_Sentinel'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mNoneType\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m<\u001b[0m\u001b[0mflax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinen\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_Sentinel\u001b[0m \u001b[0mobject\u001b[0m \u001b[0mat\u001b[0m \u001b[0;36m0x7f66d9abf8d0\u001b[0m\u001b[0;34m>\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n\u001b[0;34m\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mstr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n\u001b[0;34m\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;31mDocstring:\u001b[0m \nA linear transformation applied over the last dimension of the input.\n\nAttributes:\n features: the number of output features.\n use_bias: whether to add a bias to the output (default: True).\n dtype: the dtype of the computation (default: float32).\n precision: numerical precision of the computation see `jax.lax.Precision`\n for details.\n kernel_init: initializer function for the weight matrix.\n bias_init: initializer function for the bias.\n\u001b[0;31mFile:\u001b[0m /opt/conda/lib/python3.7/site-packages/flax/linen/linear.py\n\u001b[0;31mType:\u001b[0m type\n\u001b[0;31mSubclasses:\u001b[0m \n"},"metadata":{}}]},{"cell_type":"code","source":"class ModelFlax(nn.Module):\n width = 256\n depth = 6\n \n @nn.compact\n def __call__(self, coords):\n sh = coords.shape\n coords = np.reshape(coords, [-1,3])\n \n # positional encoding\n x = np.concatenate([np.concatenate([np.sin(coords*(2**i)), np.cos(coords*(2**i))], axis=-1) for i in np.linspace(0,8,20)], axis=-1)\n\n for idx in range(self.depth-1):\n #x = 
hk.Linear(output_size=self.width)(x)\n x = nn.Dense(self.depth, name=f'fc{idx}')(x)\n x = nn.relu(x)\n\n #out = hk.Linear(output_size=4)(x)\n out = nn.Dense(4, name='fc_last')(x)\n out = np.reshape(out, list(sh[:-1]) + [4])\n return out","metadata":{"execution":{"iopub.status.busy":"2021-07-04T06:37:54.690404Z","iopub.execute_input":"2021-07-04T06:37:54.690877Z","iopub.status.idle":"2021-07-04T06:37:54.702764Z","shell.execute_reply.started":"2021-07-04T06:37:54.690841Z","shell.execute_reply":"2021-07-04T06:37:54.701461Z"},"trusted":true},"execution_count":95,"outputs":[]},{"cell_type":"code","source":"model = ModelFlax()\nkey1, key2 = random.split(jax.random.PRNGKey(0))\ndummy_x = random.normal(key1, (1, 3))\nparams = model.init(key2, dummy_x)","metadata":{"execution":{"iopub.status.busy":"2021-07-04T06:37:55.558848Z","iopub.execute_input":"2021-07-04T06:37:55.559294Z","iopub.status.idle":"2021-07-04T06:37:56.881656Z","shell.execute_reply.started":"2021-07-04T06:37:55.559253Z","shell.execute_reply":"2021-07-04T06:37:56.878424Z"},"trusted":true},"execution_count":96,"outputs":[]},{"cell_type":"code","source":"dummy_x = \nModelFlax.init","metadata":{"execution":{"iopub.status.busy":"2021-07-04T06:33:01.022645Z","iopub.execute_input":"2021-07-04T06:33:01.023024Z","iopub.status.idle":"2021-07-04T06:33:01.032804Z","shell.execute_reply.started":"2021-07-04T06:33:01.022993Z","shell.execute_reply":"2021-07-04T06:33:01.031934Z"},"trusted":true},"execution_count":85,"outputs":[{"execution_count":85,"output_type":"execute_result","data":{"text/plain":" flax.core.frozen_dict.FrozenDict[str, typing.Mapping[str, typing.Any]]>"},"metadata":{}}]},{"cell_type":"code","source":"model = hk.without_apply_rng(hk.transform(lambda x, y=None: ModelHaiku()(x, y)))\nparams = model.init(rng, np.ones((1,3)))\n","metadata":{},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"class ModelHaiku(hk.Module):\n def __init__(self):\n super().__init__()\n self.width = 256\n self.depth 
= 6\n \n def __call__(self, coords):\n sh = coords.shape\n coords = np.reshape(coords, [-1,3])\n \n # positional encoding\n x = np.concatenate([np.concatenate([np.sin(coords*(2**i)), np.cos(coords*(2**i))], axis=-1) for i in np.linspace(0,8,20)], axis=-1)\n\n for _ in range(self.depth-1):\n x = hk.Linear(output_size=self.width)(x)\n x = jax.nn.relu(x)\n\n out = hk.Linear(output_size=4)(x)\n out = np.reshape(out, list(sh[:-1]) + [4])\n return out","metadata":{"execution":{"iopub.status.busy":"2021-07-04T06:27:20.392396Z","iopub.execute_input":"2021-07-04T06:27:20.393211Z","iopub.status.idle":"2021-07-04T06:27:20.403900Z","shell.execute_reply.started":"2021-07-04T06:27:20.393166Z","shell.execute_reply":"2021-07-04T06:27:20.402592Z"},"trusted":true},"execution_count":77,"outputs":[]},{"cell_type":"code","source":"## OLD JAX + HAIKU\ndef single_step(rng, image, rays, params, bds):\n def sgd(param, update):\n return param - inner_step_size * update\n \n rng, rng_inputs = jax.random.split(rng)\n def loss_model(params):\n g = render_rays(rng_inputs, model, params, None, rays, bds[0], bds[1], N_samples, rand=True)\n return mse_fn(g, image)\n \n model_loss, grad = jax.value_and_grad(loss_model)(params)\n new_params = jax.tree_multimap(sgd, params, grad)\n return rng, new_params, model_loss\n\n\nmodel = hk.without_apply_rng(hk.transform(lambda x, y=None: ModelHaiku()(x, y)))\nparams = model.init(rng, np.ones((1,3)))\nopt = optimizers.adam(lr)\nopt_state = opt.init_fun(params)\nupdates, opt_state = opt.update(model_grad, opt_state)\nparams = optimizers.apply_updates(params, 
updates)","metadata":{"execution":{"iopub.status.busy":"2021-07-04T06:27:21.016922Z","iopub.execute_input":"2021-07-04T06:27:21.017295Z","iopub.status.idle":"2021-07-04T06:27:21.065052Z","shell.execute_reply.started":"2021-07-04T06:27:21.017263Z","shell.execute_reply":"2021-07-04T06:27:21.062768Z"},"trusted":true},"execution_count":78,"outputs":[{"traceback":["\u001b[0;31m---------------------------------------------------------------------------\u001b[0m","\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)","\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhk\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwithout_apply_rng\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhk\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtransform\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mlambda\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mModelHaiku\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 17\u001b[0;31m \u001b[0mparams\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrng\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mones\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 18\u001b[0m \u001b[0mopt\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0moptimizers\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madam\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[0mopt_state\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mopt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minit_fun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/opt/conda/lib/python3.7/site-packages/haiku/_src/transform.py\u001b[0m in \u001b[0;36minit_fn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 109\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 110\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0minit_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 111\u001b[0;31m \u001b[0mparams\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstate\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 112\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mstate\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 113\u001b[0m raise ValueError(\"If your transformed function uses `hk.{get,set}_state` \"\n","\u001b[0;32m/opt/conda/lib/python3.7/site-packages/haiku/_src/transform.py\u001b[0m in \u001b[0;36minit_fn\u001b[0;34m(rng, *args, **kwargs)\u001b[0m\n\u001b[1;32m 295\u001b[0m \u001b[0mrng\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mto_prng_sequence\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrng\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0merr_msg\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mINIT_RNG_ERROR\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 296\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mbase\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnew_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrng\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mrng\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mctx\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 297\u001b[0;31m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 298\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mctx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcollect_params\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mctx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcollect_initial_state\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 299\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m(x, y)\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 16\u001b[0;31m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhk\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwithout_apply_rng\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhk\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtransform\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mlambda\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mModelHaiku\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 17\u001b[0m \u001b[0mparams\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrng\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mones\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 18\u001b[0m \u001b[0mopt\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0moptimizers\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madam\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/opt/conda/lib/python3.7/site-packages/haiku/_src/module.py\u001b[0m in \u001b[0;36mwrapped\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 426\u001b[0m \u001b[0mf\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mstateful\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnamed_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlocal_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 427\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 428\u001b[0;31m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 429\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 430\u001b[0m \u001b[0;31m# Module names are set in the constructor. 
If `f` is the constructor then\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/opt/conda/lib/python3.7/site-packages/haiku/_src/module.py\u001b[0m in \u001b[0;36mrun_interceptors\u001b[0;34m(bound_method, method_name, self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 277\u001b[0m \u001b[0;34m\"\"\"Runs any method interceptors or the original method.\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 278\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0minterceptor_stack\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 279\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mbound_method\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 280\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 281\u001b[0m ctx = MethodContext(module=self,\n","\u001b[0;31mTypeError\u001b[0m: __call__() takes 2 positional arguments but 3 were given"],"ename":"TypeError","evalue":"__call__() takes 2 positional arguments but 3 were given","output_type":"error"}]},{"cell_type":"code","source":"## NEW JAX + HAIKU\nlr = 1e-3\nnum_steps = 3\n\nmodel = hk.without_apply_rng(hk.transform(lambda x: ModelHaiku()(x)))\nparams = model.init(rng, np.ones((1,3)))\nopt_init, opt_update, get_params = optimizers.adam(lr)\nopt_state = opt_init(params)\n\n\ndef single_step_v2(step, rng, image, rays, bds, opt_state):\n def loss_model(params):\n g = render_rays(rng_inputs, model, params,\n None, rays, bds[0], bds[1], \n N_samples, rand=True)\n return mse_fn(g, image)\n rng, rng_inputs = jax.random.split(rng)\n value, grads = jax.value_and_grad(loss_model)(get_params(opt_state))\n opt_state = opt_update(step, grads, opt_state)\n return value, 
opt_state","metadata":{"execution":{"iopub.status.busy":"2021-07-04T06:27:11.949773Z","iopub.status.idle":"2021-07-04T06:27:11.950196Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"rng = jax.random.PRNGKey(0)\n\nfor istep in range(num_steps):\n rng, rng_input = jax.random.split(rng)\n img_idx = random.randint(rng_input, shape=(), minval=0, maxval=len(imgfiles)-25)\n images, rays, bds = get_example(img_idx, downsample=1)\n images = np.reshape(images, (-1,3))\n rays = np.reshape(rays, (2,-1,3))\n rng, rng_input = random.split(rng)\n idx = random.randint(rng_input, shape=(batch_size,), minval=0, maxval=images.shape[0])\n loss, opt_state = single_step_v2(istep, rng, images[idx,:], rays[:,idx,:], bds, opt_state)","metadata":{"execution":{"iopub.status.busy":"2021-07-04T06:27:11.951115Z","iopub.status.idle":"2021-07-04T06:27:11.951524Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"## NEW JAX + FLAX!\nlr = 1e-3\nnum_steps = 3\n\n# model = hk.without_apply_rng(hk.transform(lambda x: ModelHaiku()(x)))\n# params = model.init(rng, np.ones((1,3)))\n\nmodel = ModelFlax()\nkey1, key2 = random.split(jax.random.PRNGKey(0))\ndummy_x = random.normal(key1, (1, 3))\nparams = model.init(key2, dummy_x)\n\nopt_init, opt_update, get_params = optimizers.adam(lr)\nopt_state = opt_init(params)\n\n\ndef single_step_v2(step, rng, image, rays, bds, opt_state):\n def loss_model(params):\n g = render_rays(rng_inputs, model, params,\n None, rays, bds[0], bds[1], \n N_samples, rand=True)\n return mse_fn(g, image)\n rng, rng_inputs = jax.random.split(rng)\n value, grads = jax.value_and_grad(loss_model)(get_params(opt_state))\n opt_state = opt_update(step, grads, opt_state)\n return value, 
opt_state","metadata":{"execution":{"iopub.status.busy":"2021-07-04T06:39:14.752065Z","iopub.execute_input":"2021-07-04T06:39:14.752535Z","iopub.status.idle":"2021-07-04T06:39:14.838259Z","shell.execute_reply.started":"2021-07-04T06:39:14.752490Z","shell.execute_reply":"2021-07-04T06:39:14.837384Z"},"trusted":true},"execution_count":100,"outputs":[]},{"cell_type":"code","source":"opt_state","metadata":{"execution":{"iopub.status.busy":"2021-07-04T06:39:18.950560Z","iopub.execute_input":"2021-07-04T06:39:18.951230Z","iopub.status.idle":"2021-07-04T06:39:18.998534Z","shell.execute_reply.started":"2021-07-04T06:39:18.951187Z","shell.execute_reply":"2021-07-04T06:39:18.997174Z"},"collapsed":true,"jupyter":{"outputs_hidden":true},"trusted":true},"execution_count":101,"outputs":[{"execution_count":101,"output_type":"execute_result","data":{"text/plain":"OptimizerState(packed_state=([DeviceArray([0., 0., 0., 0., 0., 0.], dtype=float32), DeviceArray([0., 0., 0., 0., 0., 0.], dtype=float32), DeviceArray([0., 0., 0., 0., 0., 0.], dtype=float32)], [DeviceArray([[ 0.12091248, 0.03934956, 0.09956338, -0.03747796,\n -0.04041335, 0.02077009],\n [ 0.05760278, -0.0443754 , 0.05104939, -0.14502399,\n 0.0776381 , 0.1672416 ],\n [ 0.11950634, -0.01816947, 0.01361674, 0.03426957,\n -0.04762232, -0.01502447],\n [-0.07393871, 0.04445767, 0.02080443, 0.00191198,\n 0.06556749, 0.04923652],\n [ 0.0131189 , 0.02381099, -0.01838387, 0.12599164,\n -0.06083291, -0.05262698],\n [-0.1097906 , 0.09892927, -0.13584226, 0.03626684,\n 0.02823985, -0.10594288],\n [ 0.0871003 , -0.09209569, -0.0871902 , 0.0108161 ,\n -0.05293119, -0.06841837],\n [ 0.04326692, -0.13521153, 0.20045498, -0.10092928,\n 0.16232044, -0.05342549],\n [ 0.0112839 , 0.11713059, 0.04803858, 0.06951734,\n 0.0195748 , 0.06725274],\n [ 0.03751812, -0.11262533, -0.01647667, -0.03335583,\n -0.13196915, 0.01376943],\n [-0.04743343, 0.06179082, 0.04837865, 0.07638828,\n -0.09461062, -0.19005549],\n [ 0.08255602, -0.13530254, 0.06071713, 
-0.04774788,\n -0.03129638, 0.0693047 ],\n [ 0.04571046, -0.03062774, 0.0852575 , -0.10863639,\n 0.04058903, 0.04844888],\n [-0.04248823, -0.05062192, -0.01092482, -0.02707409,\n -0.13995013, -0.10880554],\n [-0.11670938, 0.10568915, -0.16793151, -0.01668308,\n -0.06549685, 0.04782696],\n [ 0.06555521, 0.15952255, -0.02444034, -0.14221518,\n 0.00795548, 0.00058907],\n [ 0.06769255, -0.10001128, 0.01939742, -0.16975056,\n 0.08136977, 0.05386211],\n [-0.08528083, 0.13880436, -0.12550238, 0.07701986,\n -0.04409956, -0.13809827],\n [-0.04213446, -0.1537718 , 0.07943241, -0.17456312,\n 0.15739301, -0.02497932],\n [ 0.05776352, -0.06588574, 0.07887202, 0.11499051,\n -0.03748027, -0.01763058],\n [ 0.00971414, -0.02444758, -0.04719648, -0.04523701,\n 0.11060363, 0.07518162],\n [ 0.17120014, -0.02850895, 0.03745553, 0.01270946,\n -0.03092966, 0.04966141],\n [-0.0771074 , 0.00305039, -0.1426532 , -0.00744342,\n -0.08475614, 0.07027218],\n [-0.1148666 , 0.0939114 , -0.20275104, 0.05386702,\n -0.03729506, -0.09220184],\n [-0.08134414, -0.11703359, -0.0546924 , 0.07569566,\n -0.03793426, 0.14144461],\n [ 0.00797677, 0.09448453, 0.06531649, 0.00926861,\n 0.07197642, 0.047566 ],\n [ 0.01645362, -0.01201729, 0.11911698, -0.05343648,\n 0.0041722 , -0.04309576],\n [ 0.07480554, -0.15580009, 0.06402453, 0.14318122,\n 0.05626502, -0.05839388],\n [-0.07621136, 0.08332268, 0.10026689, -0.09124082,\n 0.0228574 , -0.02363835],\n [ 0.0955759 , -0.18413888, 0.0389773 , 0.04142505,\n 0.06638285, -0.0562977 ],\n [ 0.06386134, 0.02720288, 0.00071549, 0.11486595,\n 0.11699079, 0.0286552 ],\n [-0.08813303, -0.10803533, -0.05704414, 0.09067526,\n -0.05155005, -0.07570118],\n [-0.03661836, -0.01282759, -0.11591093, -0.15659498,\n 0.19731197, -0.08506151],\n [ 0.04415998, -0.16286144, -0.00652541, 0.1599725 ,\n 0.0077349 , 0.12481923],\n [ 0.15215434, -0.00402213, 0.01797182, 0.13577195,\n 0.07287586, 0.07491174],\n [ 0.14486 , -0.03073188, 0.0474685 , -0.08027478,\n 0.08148502, -0.15101376],\n [ 
0.08231065, -0.04066302, 0.00606986, -0.05671879,\n -0.03834728, 0.03751288],\n [ 0.05405528, -0.12826225, 0.13363709, 0.06226205,\n 0.1394112 , -0.00298018],\n [-0.03758067, 0.17012277, 0.11220542, 0.09133592,\n -0.07800542, -0.06042796],\n [-0.05766695, 0.06635129, 0.01898533, 0.02079507,\n -0.04781189, 0.01584896],\n [ 0.12000843, 0.16353324, -0.09740052, 0.10700001,\n 0.10240921, 0.1396231 ],\n [ 0.13632071, -0.09265424, -0.03013684, 0.01697874,\n -0.0997249 , 0.04033921],\n [-0.06027675, -0.08162045, 0.0226552 , 0.03845311,\n 0.15729518, 0.0500581 ],\n [ 0.00548866, -0.10702777, 0.1363892 , 0.0008856 ,\n 0.04104868, 0.00985958],\n [-0.0602491 , -0.0130259 , -0.06303586, -0.15844673,\n 0.06022697, -0.16976918],\n [-0.03691809, -0.02184887, -0.00489147, -0.07071362,\n 0.01866116, 0.01260312],\n [ 0.09207677, -0.0602571 , 0.06722448, 0.03493843,\n 0.08584649, 0.04340586],\n [ 0.08255793, -0.07338887, 0.04268095, -0.04148448,\n 0.02685298, 0.05295678],\n [-0.05688781, 0.0053965 , 0.05855345, 0.05604774,\n 0.11289045, 0.0337536 ],\n [-0.03694104, -0.04049438, 0.11850937, 0.17993085,\n 0.03200574, 0.10255481],\n [-0.0245451 , -0.1777949 , 0.02939972, -0.09398863,\n 0.18782113, -0.1360264 ],\n [-0.15326034, -0.03229709, -0.0839939 , 0.09428229,\n 0.0673226 , -0.1048438 ],\n [ 0.03231962, -0.19225913, 0.05750696, -0.12309711,\n -0.0318327 , 0.05704029],\n [ 0.07545964, 0.17980802, 0.14941423, 0.03874794,\n 0.03335564, 0.13724373],\n [ 0.13273065, 0.06468265, 0.01514143, -0.03719363,\n 0.11870434, 0.02415203],\n [-0.10081078, -0.01774029, 0.04868399, 0.02396935,\n 0.0357847 , -0.07437636],\n [ 0.06230159, -0.08518022, -0.13052619, -0.0298177 ,\n -0.06633851, -0.15825692],\n [ 0.07715177, -0.05978196, 0.05070818, -0.11113419,\n 0.10514048, -0.15949102],\n [-0.06095733, -0.03677897, -0.00935357, -0.12114109,\n 0.05016354, 0.07169061],\n [-0.10700291, 0.00947172, -0.12596504, -0.11015648,\n -0.10353563, -0.03361167],\n [ 0.0635042 , -0.04348218, -0.10771246, 0.18733755,\n 
0.17300211, -0.05960783],\n [ 0.09768728, 0.08765952, -0.03345449, -0.08652668,\n 0.06814633, -0.04087353],\n [-0.11300852, -0.02791403, -0.03409228, -0.12401181,\n 0.08971363, 0.0175717 ],\n [ 0.0301298 , -0.06322609, -0.04609533, 0.04999686,\n -0.06116783, -0.06802417],\n [-0.11089841, 0.07835439, 0.06409204, -0.11715482,\n -0.03738579, 0.20215245],\n [-0.14400601, -0.09568423, 0.00715016, 0.10031349,\n -0.11718752, -0.0624514 ],\n [ 0.02698691, -0.09550516, 0.0097266 , 0.04360985,\n 0.13089257, 0.07150275],\n [ 0.00952834, 0.17542061, 0.1686825 , -0.04919006,\n -0.2015965 , -0.13389407],\n [ 0.01955793, -0.04790407, 0.08299995, -0.04268588,\n 0.01052913, 0.02939646],\n [-0.0022939 , -0.17001595, 0.01649912, 0.00899479,\n -0.04769804, -0.03968188],\n [ 0.0478185 , 0.06605493, 0.02228497, -0.1544845 ,\n 0.1604688 , 0.0804143 ],\n [ 0.01103472, 0.04315194, -0.01691535, 0.07182854,\n -0.05721959, -0.00783077],\n [ 0.00145147, -0.05753785, 0.14071816, 0.01597903,\n 0.01853715, 0.18673742],\n [ 0.10080386, -0.03879032, -0.00257544, 0.00165138,\n 0.01999713, 0.07541166],\n [ 0.17055124, -0.08377179, -0.07121725, -0.1030369 ,\n 0.03812077, 0.08947916],\n [ 0.14384921, -0.00645927, -0.09719041, -0.17143993,\n 0.0341155 , -0.15642157],\n [-0.14672083, -0.18627107, 0.09514766, 0.1655882 ,\n 0.07048287, -0.04141303],\n [-0.11038305, 0.01266968, -0.13012522, 0.00432192,\n 0.00661376, -0.12586735],\n [ 0.04380185, -0.04043896, 0.04071361, -0.1095993 ,\n 0.02999796, -0.09467401],\n [-0.02747682, 0.01572298, 0.17412968, -0.01638444,\n -0.07010888, 0.09199664],\n [ 0.14056218, 0.03970063, 0.02137369, -0.00897393,\n -0.00536025, -0.00209912],\n [ 0.10748938, -0.05072176, -0.10428328, 0.01942103,\n 0.15795839, 0.14172362],\n [-0.19731991, -0.02555625, -0.15654613, 0.10174139,\n -0.10126485, 0.07406058],\n [ 0.0051182 , -0.04685609, -0.19060855, -0.00625047,\n -0.15289049, 0.18272123],\n [-0.19823857, -0.02669578, 0.0014769 , -0.01741426,\n 0.10098754, -0.10494607],\n [ 0.1671394 , 
0.11931118, -0.0711092 , 0.06434302,\n 0.05375719, -0.11997687],\n [ 0.08855402, 0.10628866, -0.14936283, 0.118646 ,\n 0.07264227, -0.08480131],\n [ 0.02289091, -0.17247637, 0.03616032, -0.08103012,\n -0.11150467, 0.0383321 ],\n [ 0.01550049, -0.03067526, -0.0012579 , -0.09759074,\n 0.034106 , 0.06797335],\n [ 0.12155779, -0.14713879, 0.05044763, -0.15492183,\n 0.10530532, -0.10844996],\n [-0.0395255 , -0.00755156, -0.03988497, 0.0533135 ,\n -0.19197764, -0.14226356],\n [-0.01343509, 0.00661296, 0.16345409, -0.06798036,\n 0.05056882, -0.05704824],\n [ 0.11013287, -0.12514994, -0.02816967, 0.07051858,\n 0.0587835 , 0.09308016],\n [ 0.0209928 , 0.17353372, 0.16734958, 0.04956442,\n 0.00093992, 0.00340455],\n [-0.09287037, -0.12340804, 0.03708031, -0.04078699,\n -0.18511374, 0.13498192],\n [ 0.15583818, 0.20263459, -0.07443506, -0.11255585,\n -0.07247806, 0.0619547 ],\n [ 0.10758242, -0.00078523, -0.02034423, 0.09368771,\n -0.1336972 , -0.02547406],\n [-0.11606696, -0.0244782 , 0.09383436, -0.09345954,\n 0.00139141, 0.05733661],\n [-0.18299046, -0.01726742, 0.07100719, -0.04187343,\n 0.17501818, 0.08150984],\n [-0.06030202, -0.15978408, 0.00370273, 0.03474844,\n 0.02906018, -0.05044353],\n [-0.10054277, 0.01139852, 0.03567186, -0.08161601,\n 0.00288773, -0.20045884],\n [ 0.04335637, 0.00288314, -0.09950867, 0.01127051,\n 0.02589588, 0.05039444],\n [-0.20394403, 0.12326696, -0.1643659 , -0.0723019 ,\n -0.09912831, 0.04431969],\n [ 0.16497536, -0.16872397, -0.04751958, -0.11541699,\n 0.18788281, 0.08790423],\n [ 0.13208571, -0.08465189, -0.02066516, -0.10026854,\n -0.04670858, 0.03192634],\n [-0.00228765, -0.00040981, -0.02937566, -0.0937202 ,\n -0.194368 , 0.06886023],\n [-0.00191272, 0.03282534, 0.12399007, 0.05509276,\n 0.18744221, -0.00469013],\n [ 0.08384778, -0.04454415, 0.02690851, 0.03728205,\n -0.08604282, 0.02411885],\n [-0.15516514, 0.09445769, 0.00470266, -0.12434138,\n -0.04008782, -0.19726186],\n [-0.17874427, 0.0048454 , -0.00387846, -0.08311484,\n 
0.05458146, 0.16586116],\n [ 0.08855603, -0.07168159, 0.02372633, -0.19512348,\n 0.02156699, 0.17569917],\n [-0.03421487, 0.01564328, -0.09506329, -0.06229905,\n 0.01248883, -0.11405812],\n [ 0.08093395, 0.10542554, -0.07732578, -0.04079597,\n -0.02625835, -0.07219328],\n [-0.12371734, -0.05892467, 0.1537514 , 0.02370239,\n -0.09075558, 0.01606907],\n [-0.00653294, -0.07476543, -0.01906146, -0.00472945,\n -0.09184022, 0.15842594],\n [-0.08626378, -0.1070536 , 0.03994581, 0.06633023,\n 0.05852774, -0.04346308],\n [-0.03864894, -0.157292 , 0.0475421 , 0.1494269 ,\n -0.18904321, 0.09033928],\n [-0.01302009, 0.01030955, -0.15401638, 0.07121224,\n -0.15460083, -0.06733674],\n [-0.12611501, 0.00686998, 0.05387065, -0.00674056,\n -0.00749963, -0.12718946],\n [ 0.03201933, 0.19906111, 0.20728587, 0.10627076,\n -0.14703809, -0.0194511 ]], dtype=float32), DeviceArray([[0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 
0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 
0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.]], dtype=float32), DeviceArray([[0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 
0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.]], dtype=float32)], [DeviceArray([0., 0., 0., 0., 0., 0.], dtype=float32), DeviceArray([0., 0., 0., 0., 0., 0.], dtype=float32), DeviceArray([0., 0., 0., 0., 0., 0.], dtype=float32)], [DeviceArray([[ 0.19187102, -0.08696733, -0.2890502 , 0.625623 ,\n 0.5421405 , 0.4058794 ],\n [-0.32072243, 0.21094166, 0.12682122, -0.25806317,\n 0.7562236 , 0.26565063],\n [-0.31112608, 0.64321524, 
-0.6304671 , -0.2540195 ,\n -0.34511346, 0.8914479 ],\n [ 0.0170626 , 0.08604698, -0.3607038 , -0.46026543,\n 0.05988517, 0.16992189],\n [-0.02520008, -0.5196175 , 0.40215695, -0.7760595 ,\n -0.3977024 , 0.32397616],\n [-0.47431922, 0.28736246, -0.82317984, -0.62767595,\n -0.6443609 , -0.45300403]], dtype=float32), DeviceArray([[0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.]], dtype=float32), DeviceArray([[0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.]], dtype=float32)], [DeviceArray([0., 0., 0., 0., 0., 0.], dtype=float32), DeviceArray([0., 0., 0., 0., 0., 0.], dtype=float32), DeviceArray([0., 0., 0., 0., 0., 0.], dtype=float32)], [DeviceArray([[ 7.09373713e-01, 5.58225453e-01, 1.05266884e-01,\n 8.47791284e-02, 4.02058363e-01, 1.03561655e-01],\n [ 1.63567051e-01, 3.24267261e-02, -6.43938184e-01,\n 8.08098242e-02, 6.05995297e-01, -6.22118413e-01],\n [-1.64262623e-01, -2.33843014e-01, -1.84899673e-01,\n -1.95914432e-01, 2.11128980e-01, 2.52425253e-01],\n [ 6.72683001e-01, -2.49966606e-01, -1.87426805e-01,\n 2.40193129e-01, -5.11264145e-01, -4.63818789e-01],\n [-2.92124599e-03, 4.82702762e-01, -6.04833551e-02,\n -8.24802756e-01, -3.40679497e-01, 1.23547213e-02],\n [ 1.46633625e-01, -7.92004820e-03, -2.78792083e-01,\n -3.65616666e-04, 8.00321922e-02, 6.98535085e-01]], dtype=float32), DeviceArray([[0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.]], dtype=float32), DeviceArray([[0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.]], dtype=float32)], [DeviceArray([0., 0., 0., 0., 0., 0.], dtype=float32), 
DeviceArray([0., 0., 0., 0., 0., 0.], dtype=float32), DeviceArray([0., 0., 0., 0., 0., 0.], dtype=float32)], [DeviceArray([[-0.30406743, 0.42582163, -0.0433044 , 0.4696062 ,\n 0.162392 , -0.45115194],\n [-0.02224123, 0.26482254, -0.3821817 , 0.8167653 ,\n 0.0565933 , -0.27328548],\n [-0.3441138 , 0.3751544 , 0.05917086, 0.4884692 ,\n 0.04338139, 0.2134114 ],\n [-0.3870027 , -0.3817356 , 0.31533888, 0.10569409,\n -0.00611069, -0.19789281],\n [ 0.12731338, 0.3734065 , -0.10944682, 0.5042247 ,\n -0.26706368, -0.17124975],\n [ 0.35682514, 0.1869185 , 0.23269117, -0.8522264 ,\n -0.20571367, -0.33262372]], dtype=float32), DeviceArray([[0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.]], dtype=float32), DeviceArray([[0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.]], dtype=float32)], [DeviceArray([0., 0., 0., 0., 0., 0.], dtype=float32), DeviceArray([0., 0., 0., 0., 0., 0.], dtype=float32), DeviceArray([0., 0., 0., 0., 0., 0.], dtype=float32)], [DeviceArray([[ 0.4786232 , 0.07740223, 0.23075195, 0.24710925,\n -0.00430446, 0.09901427],\n [-0.7429878 , 0.7190723 , 0.18975034, -0.26684535,\n 0.57330495, 0.10749137],\n [-0.5879746 , 0.22385383, 0.5853615 , 0.15528546,\n 0.563753 , -0.3713637 ],\n [-0.13496533, 0.30754456, -0.9095284 , -0.00883 ,\n -0.14046153, 0.64349794],\n [-0.47050482, -0.0958715 , -0.3475464 , 0.02285524,\n -0.23184529, -0.55695194],\n [ 0.46003628, -0.26375806, 0.1943689 , 0.04007643,\n -0.29916492, 0.16186029]], dtype=float32), DeviceArray([[0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.]], dtype=float32), DeviceArray([[0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 
0., 0., 0.],\n [0., 0., 0., 0., 0., 0.],\n [0., 0., 0., 0., 0., 0.]], dtype=float32)], [DeviceArray([0., 0., 0., 0.], dtype=float32), DeviceArray([0., 0., 0., 0.], dtype=float32), DeviceArray([0., 0., 0., 0.], dtype=float32)], [DeviceArray([[-0.02883793, -0.06974681, 0.24037988, -0.45696107],\n [-0.66506857, -0.35821897, -0.34971425, 0.35783553],\n [ 0.66095954, -0.03526216, 0.7372601 , -0.4655306 ],\n [-0.11639789, -0.82979643, -0.01575499, -0.20149845],\n [ 0.23644614, -0.3937142 , -0.61596936, -0.27844813],\n [ 0.9190065 , 0.37393552, 0.05566693, -0.7466602 ]], dtype=float32), DeviceArray([[0., 0., 0., 0.],\n [0., 0., 0., 0.],\n [0., 0., 0., 0.],\n [0., 0., 0., 0.],\n [0., 0., 0., 0.],\n [0., 0., 0., 0.]], dtype=float32), DeviceArray([[0., 0., 0., 0.],\n [0., 0., 0., 0.],\n [0., 0., 0., 0.],\n [0., 0., 0., 0.],\n [0., 0., 0., 0.],\n [0., 0., 0., 0.]], dtype=float32)]), tree_def=PyTreeDef(CustomNode([()], [{'params': {'fc0': {'bias': *, 'kernel': *}, 'fc1': {'bias': *, 'kernel': *}, 'fc2': {'bias': *, 'kernel': *}, 'fc3': {'bias': *, 'kernel': *}, 'fc4': {'bias': *, 'kernel': *}, 'fc_last': {'bias': *, 'kernel': *}}}])), subtree_defs=(PyTreeDef((*, *, *)), PyTreeDef((*, *, *)), PyTreeDef((*, *, *)), PyTreeDef((*, *, *)), PyTreeDef((*, *, *)), PyTreeDef((*, *, *)), PyTreeDef((*, *, *)), PyTreeDef((*, *, *)), PyTreeDef((*, *, *)), PyTreeDef((*, *, *)), PyTreeDef((*, *, *)), PyTreeDef((*, *, *))))"},"metadata":{}}]},{"cell_type":"code","source":"rng = jax.random.PRNGKey(0)\n\nfor istep in range(num_steps):\n rng, rng_input = jax.random.split(rng)\n img_idx = random.randint(rng_input, shape=(), minval=0, maxval=len(imgfiles)-25)\n images, rays, bds = get_example(img_idx, downsample=1)\n images = np.reshape(images, (-1,3))\n rays = np.reshape(rays, (2,-1,3))\n rng, rng_input = random.split(rng)\n idx = random.randint(rng_input, shape=(batch_size,), minval=0, maxval=images.shape[0])\n loss, opt_state = single_step_v2(istep, rng, images[idx,:], rays[:,idx,:], bds, 
opt_state)","metadata":{"execution":{"iopub.status.busy":"2021-07-04T06:39:32.502213Z","iopub.execute_input":"2021-07-04T06:39:32.502642Z","iopub.status.idle":"2021-07-04T06:39:33.647639Z","shell.execute_reply.started":"2021-07-04T06:39:32.502606Z","shell.execute_reply":"2021-07-04T06:39:33.646620Z"},"trusted":true},"execution_count":102,"outputs":[]},{"cell_type":"code","source":"opt_state","metadata":{"execution":{"iopub.status.busy":"2021-07-04T06:39:35.648041Z","iopub.execute_input":"2021-07-04T06:39:35.648460Z","iopub.status.idle":"2021-07-04T06:39:35.708276Z","shell.execute_reply.started":"2021-07-04T06:39:35.648417Z","shell.execute_reply":"2021-07-04T06:39:35.706959Z"},"collapsed":true,"jupyter":{"outputs_hidden":true},"trusted":true},"execution_count":103,"outputs":[{"execution_count":103,"output_type":"execute_result","data":{"text/plain":"OptimizerState(packed_state=([DeviceArray([-0.00045145, -0.00143954, -0.00263059, -0.00230609,\n 0.00044354, 0.0007933 ], dtype=float32), DeviceArray([ 7.4164245e-07, 9.0824096e-06, 1.9418590e-06,\n 5.1692371e-07, 3.2076082e-06, -8.0520622e-06], dtype=float32), DeviceArray([3.7262568e-14, 6.9343763e-12, 2.2007938e-13, 4.5521316e-14,\n 4.6548199e-12, 9.6004038e-12], dtype=float32)], [DeviceArray([[ 1.19515203e-01, 3.79956812e-02, 1.00867495e-01,\n -3.65395024e-02, -3.95988673e-02, 2.29738709e-02],\n [ 5.77125698e-02, -4.45387252e-02, 4.90781777e-02,\n -1.46169439e-01, 7.75758103e-02, 1.68276310e-01],\n [ 1.20900065e-01, -1.95997935e-02, 1.10316165e-02,\n 3.20435129e-02, -4.71294001e-02, -1.43090524e-02],\n [-7.43934736e-02, 4.30180728e-02, 1.81723312e-02,\n -3.99712473e-04, 6.60072789e-02, 5.00307828e-02],\n [ 1.26672154e-02, 2.23715417e-02, -2.10146792e-02,\n 1.23676464e-01, -6.03894331e-02, -5.18337041e-02],\n [-1.10233277e-01, 9.74887088e-02, -1.38467059e-01,\n 3.39527391e-02, 2.86765136e-02, -1.05141982e-01],\n [ 8.55626091e-02, -9.34591144e-02, -8.58496949e-02,\n 1.18011115e-02, -5.21217994e-02, -6.61588833e-02],\n [ 
4.34344709e-02, -1.35381013e-01, 1.98374435e-01,\n -1.02123044e-01, 1.62237614e-01, -5.23203351e-02],\n [ 1.26959765e-02, 1.15699135e-01, 4.54383790e-02,\n 6.72767386e-02, 2.00665258e-02, 6.79813847e-02],\n [ 3.70608568e-02, -1.14064977e-01, -1.91099793e-02,\n -3.56719680e-02, -1.31532311e-01, 1.45644369e-02],\n [-4.78853099e-02, 6.03514388e-02, 4.57476750e-02,\n 7.40659088e-02, -9.41671953e-02, -1.89262241e-01],\n [ 8.21234509e-02, -1.36744007e-01, 5.80971800e-02,\n -5.00687174e-02, -3.08657531e-02, 7.01127574e-02],\n [ 4.40107286e-02, -3.19981165e-02, 8.66279900e-02,\n -1.07607514e-01, 4.13880087e-02, 5.07526249e-02],\n [-4.22691666e-02, -5.07959910e-02, -1.31016057e-02,\n -2.83072051e-02, -1.40049130e-01, -1.07643224e-01],\n [-1.15292378e-01, 1.04256600e-01, -1.70544341e-01,\n -1.89380869e-02, -6.50074780e-02, 4.85670343e-02],\n [ 6.50937259e-02, 1.58082798e-01, -2.70758662e-02,\n -1.44539312e-01, 8.38710368e-03, 1.38538575e-03],\n [ 6.72403350e-02, -1.01450533e-01, 1.67661272e-02,\n -1.72085881e-01, 8.18130970e-02, 5.46553172e-02],\n [-8.56901556e-02, 1.37361005e-01, -1.28114462e-01,\n 7.46859163e-02, -4.36814055e-02, -1.37273654e-01],\n [-4.40381952e-02, -1.55146882e-01, 8.08292851e-02,\n -1.73489004e-01, 1.58172131e-01, -2.26417482e-02],\n [ 5.80274314e-02, -6.60630912e-02, 7.66124576e-02,\n 1.13725699e-01, -3.75916548e-02, -1.64222773e-02],\n [ 1.09880436e-02, -2.58814096e-02, -4.98207957e-02,\n -4.75093797e-02, 1.11088477e-01, 7.59330243e-02],\n [ 1.70731843e-01, -2.99488902e-02, 3.48158926e-02,\n 1.03710145e-02, -3.05071846e-02, 5.04600182e-02],\n [-7.75601938e-02, 1.61137013e-03, -1.45285055e-01,\n -9.80186835e-03, -8.43129978e-02, 7.10653290e-02],\n [-1.15224957e-01, 9.24633294e-02, -2.05374062e-01,\n 5.15054911e-02, -3.69068421e-02, -9.13216770e-02],\n [-8.35028067e-02, -1.18411213e-01, -5.32682054e-02,\n 7.68247694e-02, -3.71921435e-02, 1.43807575e-01],\n [ 8.27940181e-03, 9.43050608e-02, 6.29868209e-02,\n 7.97823165e-03, 7.18559921e-02, 
4.88115028e-02],\n [ 1.66943446e-02, -1.34530757e-02, 1.16480984e-01,\n -5.57347909e-02, 4.64809174e-03, -4.23300602e-02],\n [ 7.43269697e-02, -1.57240435e-01, 6.13770597e-02,\n 1.40817121e-01, 5.66717610e-02, -5.75914197e-02],\n [-7.66651481e-02, 8.18840712e-02, 9.76340249e-02,\n -9.36404243e-02, 2.33002231e-02, -2.28453223e-02],\n [ 9.53125134e-02, -1.85607061e-01, 3.80081609e-02,\n 3.89923938e-02, 6.66604489e-02, -5.37296869e-02],\n [ 6.14767075e-02, 2.58252937e-02, 2.17701704e-03,\n 1.16079114e-01, 1.17664330e-01, 3.10384557e-02],\n [-8.77963677e-02, -1.08215772e-01, -5.94331436e-02,\n 8.93637240e-02, -5.16765937e-02, -7.44252503e-02],\n [-3.69373262e-02, -1.42672677e-02, -1.18558399e-01,\n -1.58940330e-01, 1.97768405e-01, -8.42699409e-02],\n [ 4.36676294e-02, -1.64302766e-01, -9.18847229e-03,\n 1.57562688e-01, 8.11561663e-03, 1.25627786e-01],\n [ 1.51698932e-01, -5.45999967e-03, 1.53371440e-02,\n 1.33300990e-01, 7.33180791e-02, 7.57046118e-02],\n [ 1.44726440e-01, -3.08332909e-02, 4.71221283e-02,\n -8.30720440e-02, 8.04841295e-02, -1.49666801e-01],\n [ 7.98062310e-02, -4.20358926e-02, 7.60007044e-03,\n -5.53373694e-02, -3.77982035e-02, 3.99156027e-02],\n [ 5.44234477e-02, -1.28442332e-01, 1.31196484e-01,\n 6.09315857e-02, 1.39281034e-01, -1.68151711e-03],\n [-3.80312428e-02, 1.68671548e-01, 1.09560236e-01,\n 8.88865143e-02, -7.76072890e-02, -5.95312268e-02],\n [-5.81748076e-02, 6.49072379e-02, 1.62887760e-02,\n 1.83060467e-02, -4.74709012e-02, 1.66655071e-02],\n [ 1.19550623e-01, 1.62096694e-01, -1.00038491e-01,\n 1.04416534e-01, 1.02850296e-01, 1.40415817e-01],\n [ 1.34725481e-01, -9.12372246e-02, -2.81765349e-02,\n 1.86341237e-02, -1.00317590e-01, 4.11429219e-02],\n [-6.28308132e-02, -8.29493925e-02, 2.43374854e-02,\n 4.02909331e-02, 1.57642350e-01, 5.24831302e-02],\n [ 5.88842249e-03, -1.07205682e-01, 1.33899868e-01,\n -4.64881101e-04, 4.09163609e-02, 1.11551182e-02],\n [-6.04856983e-02, -1.40968747e-02, -6.44786209e-02,\n -1.61199972e-01, 5.91667481e-02, 
-1.68288842e-01],\n [-3.74374464e-02, -2.33025309e-02, -7.66711216e-03,\n -7.33268559e-02, 1.89521369e-02, 1.34219714e-02],\n [ 9.16161388e-02, -6.16912395e-02, 6.45804331e-02,\n 3.22281457e-02, 8.62854049e-02, 4.41986509e-02],\n [ 8.09125751e-02, -7.19564557e-02, 4.51078340e-02,\n -3.91592197e-02, 2.63338201e-02, 5.22068962e-02],\n [-5.94698898e-02, 5.49785653e-03, 6.05785921e-02,\n 5.87637201e-02, 1.12998255e-01, 3.62044908e-02],\n [-3.65077220e-02, -4.06673774e-02, 1.15966111e-01,\n 1.78552181e-01, 3.18685621e-02, 1.03732467e-01],\n [-2.45138742e-02, -1.76385358e-01, 2.93305293e-02,\n -9.44927782e-02, 1.87303945e-01, -1.34810388e-01],\n [-1.53767079e-01, -3.38131189e-02, -8.66767094e-02,\n 9.15280879e-02, 6.75857961e-02, -1.04090497e-01],\n [ 3.18573974e-02, -1.93688884e-01, 5.48515618e-02,\n -1.25846893e-01, -3.13979946e-02, 5.78340515e-02],\n [ 7.54026771e-02, 1.81285068e-01, 1.52018696e-01,\n 4.14612181e-02, 3.37518901e-02, 1.34640053e-01],\n [ 1.30115807e-01, 6.48052245e-02, 1.76690277e-02,\n -3.51731479e-02, 1.17213652e-01, 2.66280770e-02],\n [-1.00345813e-01, -1.79041419e-02, 4.60862778e-02,\n 2.25334708e-02, 3.56246531e-02, -7.34259188e-02],\n [ 6.08137362e-02, -8.37168917e-02, -1.28189042e-01,\n -2.70780958e-02, -6.59908950e-02, -1.59563407e-01],\n [ 7.68346786e-02, -5.99950068e-02, 4.89747114e-02,\n -1.13964356e-01, 1.05598472e-01, -1.60014555e-01],\n [-6.14193045e-02, -3.82006206e-02, -1.20309927e-02,\n -1.23834461e-01, 5.05906045e-02, 7.24875405e-02],\n [-1.06973559e-01, 8.08707159e-03, -1.23331331e-01,\n -1.08063765e-01, -1.03185058e-01, -3.48316692e-02],\n [ 6.08417243e-02, -4.33301553e-02, -1.04966782e-01,\n 1.88854754e-01, 1.71531528e-01, -5.71143031e-02],\n [ 9.81569365e-02, 8.75111893e-02, -3.55093963e-02,\n -8.81588683e-02, 6.78883269e-02, -3.95975187e-02],\n [-1.14663109e-01, -2.92997491e-02, -3.15966122e-02,\n -1.22380674e-01, 8.99034813e-02, 1.63802188e-02],\n [ 3.16538401e-02, -6.33720905e-02, -4.77643684e-02,\n 4.71230410e-02, 
-5.95972277e-02, -6.85580820e-02],\n [-1.11370780e-01, 7.69485533e-02, 6.13692366e-02,\n -1.19794346e-01, -3.69678587e-02, 2.02951998e-01],\n [-1.42703339e-01, -9.72430259e-02, 5.04018134e-03,\n 9.75762010e-02, -1.17566228e-01, -6.01027608e-02],\n [ 2.52913795e-02, -9.40641239e-02, 1.24967648e-02,\n 4.50270660e-02, 1.32457763e-01, 7.11266324e-02],\n [ 9.89071187e-03, 1.75297096e-01, 1.67444736e-01,\n -5.20650372e-02, -2.02201605e-01, -1.32647291e-01],\n [ 1.95834171e-02, -4.79863398e-02, 8.09220150e-02,\n -4.42066044e-02, 1.03196939e-02, 3.05320267e-02],\n [-7.14446884e-04, -1.70157880e-01, 1.55897681e-02,\n 6.09404640e-03, -4.61501181e-02, -4.01810147e-02],\n [ 4.73080166e-02, 6.60517737e-02, 1.94629580e-02,\n -1.57093167e-01, 1.60925955e-01, 8.05986673e-02],\n [ 1.32242423e-02, 4.45444472e-02, -1.91424321e-02,\n 7.04322681e-02, -5.56334481e-02, -8.34930688e-03],\n [-1.23342010e-03, -5.61622642e-02, 1.43492743e-01,\n 1.76758990e-02, 2.00436749e-02, 1.86359212e-01],\n [ 1.00787982e-01, -3.74127626e-02, -3.40652955e-03,\n -9.75472911e-04, 1.91068314e-02, 7.66138658e-02],\n [ 1.72837481e-01, -8.51704031e-02, -7.22087920e-02,\n -1.04399674e-01, 3.81661765e-02, 9.03742909e-02],\n [ 1.43618628e-01, -6.64510718e-03, -9.69889387e-02,\n -1.72765687e-01, 3.55073884e-02, -1.57048166e-01],\n [-1.47305995e-01, -1.86104029e-01, 9.22193155e-02,\n 1.63363367e-01, 7.20543638e-02, -4.27316390e-02],\n [-1.12094469e-01, 1.28196077e-02, -1.28213912e-01,\n 6.21486828e-03, 5.30892797e-03, -1.25342995e-01],\n [ 4.10727039e-02, -3.89916673e-02, 4.34093885e-02,\n -1.07143447e-01, 3.15333344e-02, -9.51224566e-02],\n [-2.78361049e-02, 1.71400849e-02, 1.73911244e-01,\n -1.72006208e-02, -7.06825852e-02, 9.31587592e-02],\n [ 1.39339387e-01, 4.11312655e-02, 2.05742978e-02,\n -7.74497818e-03, -5.22259623e-03, -2.02417956e-03],\n [ 1.09130338e-01, -5.21309339e-02, -1.04567930e-01,\n 1.72220338e-02, 1.58100620e-01, 1.40555933e-01],\n [-1.99592412e-01, -2.40519606e-02, -1.57991603e-01,\n 
1.02015443e-01, -9.98662040e-02, 7.14631677e-02],\n [ 5.20380307e-03, -4.82758358e-02, -1.92286909e-01,\n -7.96826743e-03, -1.54290631e-01, 1.85175255e-01],\n [-1.98595673e-01, -2.65713446e-02, 1.93740590e-03,\n -1.98984463e-02, 1.01096861e-01, -1.02575213e-01],\n [ 1.66928127e-01, 1.20740876e-01, -7.15681240e-02,\n 6.38891086e-02, 5.51854298e-02, -1.22506432e-01],\n [ 8.78814012e-02, 1.07822761e-01, -1.46974564e-01,\n 1.20047025e-01, 7.21886829e-02, -8.54675919e-02],\n [ 2.29258146e-02, -1.72665089e-01, 3.57917808e-02,\n -8.31400752e-02, -1.10140853e-01, 3.77841331e-02],\n [ 1.36882365e-02, -3.07335649e-02, -1.62586360e-03,\n -9.77346078e-02, 3.38865221e-02, 6.68704137e-02],\n [ 1.18828140e-01, -1.47041306e-01, 5.20589352e-02,\n -1.53026074e-01, 1.03887565e-01, -1.08588472e-01],\n [-3.90928835e-02, -6.18510228e-03, -4.05132100e-02,\n 5.17346971e-02, -1.92173362e-01, -1.43100709e-01],\n [-1.33428611e-02, 8.00306816e-03, 1.62952662e-01,\n -6.57213926e-02, 5.07434793e-02, -5.92899099e-02],\n [ 1.08613402e-01, -1.24999523e-01, -2.64579542e-02,\n 7.11960718e-02, 5.89826144e-02, 9.41445306e-02],\n [ 2.08997540e-02, 1.74974009e-01, 1.66671813e-01,\n 4.79338653e-02, 8.78979554e-05, 2.67938152e-03],\n [-9.45570320e-02, -1.23530865e-01, 3.68198715e-02,\n -3.99812609e-02, -1.85414851e-01, 1.36107326e-01],\n [ 1.53190598e-01, 2.04031289e-01, -7.17400387e-02,\n -1.09849349e-01, -7.10915029e-02, 6.11329414e-02],\n [ 1.07632719e-01, -6.84627448e-04, -2.09605265e-02,\n 9.12068710e-02, -1.35149017e-01, -2.50568949e-02],\n [-1.18208840e-01, -2.58779433e-02, 9.36405584e-02,\n -9.29577574e-02, 1.66403898e-03, 5.81755228e-02],\n [-1.82792187e-01, -1.71266459e-02, 7.08144158e-02,\n -4.03121598e-02, 1.73490494e-01, 8.20606723e-02],\n [-5.80182523e-02, -1.58356905e-01, 1.55453954e-03,\n 3.33175063e-02, 3.05566993e-02, -5.09770252e-02],\n [-9.79462788e-02, 1.12941926e-02, 3.29827517e-02,\n -8.34175646e-02, 1.58567377e-03, -1.99254751e-01],\n [ 4.25345935e-02, 2.99139717e-03, 
-9.77599397e-02,\n 1.25275636e-02, 2.49177702e-02, 5.07045388e-02],\n [-2.04405174e-01, 1.21796027e-01, -1.64488241e-01,\n -7.41155148e-02, -9.78912339e-02, 4.51627970e-02],\n [ 1.64618686e-01, -1.68567866e-01, -5.03356345e-02,\n -1.17939301e-01, 1.86417699e-01, 9.03297067e-02],\n [ 1.34851336e-01, -8.32303539e-02, -2.33603716e-02,\n -1.03158645e-01, -4.51769456e-02, 3.13593112e-02],\n [-4.59918985e-03, -2.72297795e-04, -2.66443919e-02,\n -9.28490087e-02, -1.95913151e-01, 7.13221431e-02],\n [ 8.59930180e-04, 3.29423733e-02, 1.22036733e-01,\n 5.30130863e-02, 1.85941681e-01, -4.31221491e-03],\n [ 8.20801780e-02, -4.31287326e-02, 2.93064527e-02,\n 3.96889783e-02, -8.63647982e-02, 2.53119059e-02],\n [-1.53039083e-01, 9.30199102e-02, 3.22056469e-03,\n -1.24914192e-01, -4.15421948e-02, -1.94892928e-01],\n [-1.81372344e-01, 4.68530506e-03, -1.14344596e-03,\n -8.26741755e-02, 5.44348992e-02, 1.65079758e-01],\n [ 9.12437066e-02, -7.02710375e-02, 2.39162557e-02,\n -1.94856316e-01, 2.31222976e-02, 1.75257191e-01],\n [-3.32147554e-02, 1.42533733e-02, -9.78560820e-02,\n -6.50117993e-02, 1.09827500e-02, -1.13609552e-01],\n [ 7.81201050e-02, 1.05575085e-01, -7.44353980e-02,\n -3.78257334e-02, -2.77477149e-02, -6.97395205e-02],\n [-1.26582652e-01, -6.02407455e-02, 1.56459987e-01,\n 2.62732077e-02, -9.10715163e-02, 1.86322927e-02],\n [-6.05976023e-03, -7.36052021e-02, -1.94375180e-02,\n -6.14037318e-03, -9.34375674e-02, 1.59062594e-01],\n [-8.38212222e-02, -1.07256345e-01, 3.92923839e-02,\n 6.44588396e-02, 5.82834482e-02, -4.43415083e-02],\n [-3.70768644e-02, -1.57141834e-01, 4.57319804e-02,\n 1.47603378e-01, -1.90470919e-01, 8.95815566e-02],\n [-1.41519820e-02, 1.02229277e-02, -1.52806088e-01,\n 7.36428723e-02, -1.53182924e-01, -6.77200109e-02],\n [-1.24036811e-01, 6.76118815e-03, 5.35392836e-02,\n -7.78794521e-03, -6.01940323e-03, -1.29643857e-01],\n [ 3.38975042e-02, 1.99190766e-01, 2.09763810e-01,\n 1.06645174e-01, -1.48502052e-01, -1.69419963e-02]], dtype=float32), 
DeviceArray([[-5.89450355e-09, 1.10020665e-06, 9.08776769e-08,\n -3.22272911e-08, -3.63153418e-08, -1.08581332e-06],\n [-1.78353936e-08, -2.19694940e-07, 7.57201164e-08,\n -9.52954338e-08, -9.58301172e-08, 4.68956898e-07],\n [ 1.39825218e-07, 3.54142844e-06, 6.78116862e-07,\n 1.57926394e-07, 1.00080001e-06, -3.30655712e-06],\n [ 7.58904832e-07, 9.00165196e-06, 1.91860772e-06,\n 5.19962953e-07, 3.23478207e-06, -8.01583883e-06],\n [ 7.41148881e-07, 9.07372851e-06, 1.93827600e-06,\n 5.23033350e-07, 3.20584559e-06, -8.05012587e-06],\n [ 7.31363230e-07, 8.35094306e-06, 1.82408917e-06,\n 4.92660149e-07, 3.01473915e-06, -7.34671039e-06],\n [-6.14870554e-09, 1.46902812e-06, 1.20125407e-07,\n -4.29506990e-08, -4.43112640e-08, -1.45433694e-06],\n [-2.39274023e-08, -2.94163669e-07, 1.01273521e-07,\n -1.27209304e-07, -1.28118330e-07, 6.27614895e-07],\n [ 1.87638690e-07, 4.63625565e-06, 8.93171091e-07,\n 2.08798184e-07, 1.31301329e-06, -4.32453771e-06],\n [ 7.72408725e-07, 8.93784818e-06, 1.90030983e-06,\n 5.22363166e-07, 3.25587530e-06, -7.98687142e-06],\n [ 7.40764676e-07, 9.06684909e-06, 1.93543951e-06,\n 5.27857878e-07, 3.20444451e-06, -8.04859064e-06],\n [ 7.22921300e-07, 7.78584490e-06, 1.73262708e-06,\n 4.73720689e-07, 2.86539102e-06, -6.80177573e-06],\n [-4.08872758e-09, 1.95721077e-06, 1.57139056e-07,\n -5.70308281e-08, -4.90771583e-08, -1.94843437e-06],\n [-3.21491669e-08, -3.93888541e-07, 1.35335583e-07,\n -1.69406022e-07, -1.71084295e-07, 8.39658753e-07],\n [ 2.52089052e-07, 5.95938036e-06, 1.16103763e-06,\n 2.73283263e-07, 1.69430473e-06, -5.54831695e-06],\n [ 7.96222594e-07, 8.82393761e-06, 1.86779926e-06,\n 5.26646488e-07, 3.29270620e-06, -7.93437903e-06],\n [ 7.40091991e-07, 9.05452816e-06, 1.93036249e-06,\n 5.36466530e-07, 3.20192976e-06, -8.04584488e-06],\n [ 7.07165782e-07, 6.80477569e-06, 1.57277213e-06,\n 4.40408513e-07, 2.60538445e-06, -5.85572616e-06],\n [ 4.29782787e-09, 2.59740773e-06, 2.01628808e-07,\n -7.52185372e-08, -4.15186321e-08, 
-2.61137825e-06],\n [-4.33095266e-08, -5.27453949e-07, 1.80578581e-07,\n -2.24631520e-07, -2.27975477e-07, 1.12264138e-06],\n [ 3.38904442e-07, 7.40378027e-06, 1.47343928e-06,\n 3.51184525e-07, 2.11997803e-06, -6.86727390e-06],\n [ 8.37665198e-07, 8.62122488e-06, 1.81044447e-06,\n 5.34261801e-07, 3.35561526e-06, -7.83850282e-06],\n [ 7.38935341e-07, 9.03247383e-06, 1.92128755e-06,\n 5.51770427e-07, 3.19740730e-06, -8.04093725e-06],\n [ 6.77055141e-07, 5.14594376e-06, 1.29904197e-06,\n 3.82699909e-07, 2.16362355e-06, -4.25635881e-06],\n [ 2.84734671e-08, 3.42271323e-06, 2.49386289e-07,\n -9.79840351e-08, 7.15942861e-10, -3.50134815e-06],\n [-5.85989852e-08, -7.06387311e-07, 2.40287903e-07,\n -2.95550933e-07, -3.02619014e-07, 1.49929872e-06],\n [ 4.54334412e-07, 8.61640001e-06, 1.78864775e-06,\n 4.36340798e-07, 2.50029757e-06, -7.92649553e-06],\n [ 9.08039112e-07, 8.26257929e-06, 1.71055194e-06,\n 5.47710044e-07, 3.45858484e-06, -7.66110679e-06],\n [ 7.37018695e-07, 8.99303996e-06, 1.90510173e-06,\n 5.78793617e-07, 3.18924913e-06, -8.03218063e-06],\n [ 6.17813100e-07, 2.47749927e-06, 8.47588126e-07,\n 2.85440819e-07, 1.44720775e-06, -1.68546239e-06],\n [ 8.95815475e-08, 4.45288197e-06, 2.86464058e-07,\n -1.24693514e-07, 1.28736474e-07, -4.69444240e-06],\n [-7.98249005e-08, -9.46196735e-07, 3.18169668e-07,\n -3.83382172e-07, -3.98885334e-07, 1.99820715e-06],\n [ 5.99228542e-07, 8.76413287e-06, 1.99479950e-06,\n 5.08822666e-07, 2.61243690e-06, -7.89131718e-06],\n [ 1.02213471e-06, 7.63469507e-06, 1.54058000e-06,\n 5.71160342e-07, 3.61285970e-06, -7.32635317e-06],\n [ 7.34072842e-07, 8.92267326e-06, 1.87635169e-06,\n 6.25928294e-07, 3.17445006e-06, -8.01660462e-06],\n [ 4.98578515e-07, -1.41110866e-06, 1.53603480e-07,\n 1.29574914e-07, 3.90075598e-07, 2.04918433e-06],\n [ 2.30933054e-07, 5.65936080e-06, 2.77691186e-07,\n -1.51568671e-07, 4.48849903e-07, -6.27888267e-06],\n [-1.09736106e-07, -1.26779742e-06, 4.17566611e-07,\n -4.84438999e-07, -5.18943466e-07, 
2.65304629e-06],\n [ 7.44775662e-07, 6.37876201e-06, 1.86980697e-06,\n 5.22824337e-07, 2.04094908e-06, -5.28121518e-06],\n [ 1.19083279e-06, 6.55613849e-06, 1.26347391e-06,\n 6.11066071e-07, 3.79858466e-06, -6.67799895e-06],\n [ 7.30283148e-07, 8.79757408e-06, 1.82565918e-06,\n 7.06297158e-07, 3.14732688e-06, -7.98898145e-06],\n [ 2.62097444e-07, -5.95620531e-06, -7.75638114e-07,\n -9.77013954e-08, -8.62348543e-07, 6.34855905e-06],\n [ 5.26827762e-07, 6.88779892e-06, 1.49560947e-07,\n -1.67012857e-07, 1.14538557e-06, -8.30594763e-06],\n [-1.52088518e-07, -1.69945258e-06, 5.39248504e-07,\n -5.82383393e-07, -6.58500312e-07, 3.49752895e-06],\n [ 7.63828155e-07, 1.78920914e-07, 1.12526754e-06,\n 3.98805383e-07, 3.93012499e-07, 1.20588595e-06],\n [ 1.39363715e-06, 4.76657942e-06, 8.46491048e-07,\n 6.75670947e-07, 3.87500313e-06, -5.39017401e-06],\n [ 7.27829558e-07, 8.57662417e-06, 1.73745775e-06,\n 8.37548328e-07, 3.09655707e-06, -7.93998333e-06],\n [-1.52975304e-07, -8.48403215e-06, -1.69186501e-06,\n -3.71675071e-07, -1.51615791e-06, 8.40440225e-06],\n [ 1.05302070e-06, 7.71379564e-06, -2.14575977e-07,\n -1.42296983e-07, 2.40270538e-06, -1.05793461e-05],\n [-2.09548517e-07, -2.27917667e-06, 6.76103184e-07,\n -6.33503987e-07, -7.95059464e-07, 4.54841620e-06],\n [ 3.60403476e-07, -7.10164477e-06, -2.60149477e-07,\n 5.20091490e-08, -1.39066969e-06, 8.56870793e-06],\n [ 1.51450240e-06, 1.98259454e-06, 3.09881301e-07,\n 7.69029270e-07, 3.36692756e-06, -2.82522319e-06],\n [ 7.35002118e-07, 8.19092566e-06, 1.58768319e-06,\n 1.03413015e-06, 2.99751468e-06, -7.85149132e-06],\n [-6.16005138e-07, -3.96131054e-06, -1.96598421e-06,\n -5.70482428e-07, 9.77685346e-08, 2.75703906e-06],\n [ 1.71396732e-06, 7.26016560e-06, -8.91745856e-07,\n -1.98548022e-08, 3.90416744e-06, -1.20029299e-05],\n [-2.72548334e-07, -3.05649314e-06, 8.02113561e-07,\n -5.48290473e-07, -8.62919308e-07, 5.75925469e-06],\n [-4.82296855e-07, -4.37789458e-06, -1.52266500e-06,\n -4.78503694e-07, 1.08891720e-06, 
4.00277941e-06],\n [ 1.29106547e-06, -1.83849784e-06, -1.84190526e-07,\n 8.65605216e-07, 1.19882077e-06, 1.95082043e-06],\n [ 7.73455099e-07, 7.53151107e-06, 1.34458696e-06,\n 1.27597616e-06, 2.78944913e-06, -7.67971960e-06],\n [-3.73475871e-07, 6.33434820e-06, -1.04433207e-06,\n -4.15061123e-07, 3.21438347e-06, -9.17595844e-06],\n [ 1.91155050e-06, 4.24808513e-06, -1.62878257e-06,\n 2.64808847e-07, 3.52671350e-06, -9.45029115e-06],\n [-2.84500231e-07, -4.08763117e-06, 8.54584982e-07,\n -2.06916638e-07, -7.20116418e-07, 6.92017920e-06],\n [-5.86249698e-08, 7.22480490e-06, -1.56320039e-06,\n -6.33804234e-07, 6.07791662e-06, -1.14542972e-05],\n [ 6.27224949e-07, -5.80912047e-06, -3.65499204e-07,\n 8.37135588e-07, -3.71978172e-06, 8.91575291e-06],\n [ 8.75814635e-07, 6.44476586e-06, 9.82818392e-07,\n 1.42759950e-06, 2.30631395e-06, -7.28690748e-06],\n [ 7.91641469e-07, 1.26343366e-06, 6.55591862e-07,\n 4.05362528e-07, -2.94489564e-06, -8.59181227e-07],\n [ 1.12571422e-06, -1.86561249e-06, -1.68733015e-06,\n 5.70262159e-07, -2.80565246e-06, 8.75313731e-07],\n [-9.61898081e-08, -5.40087012e-06, 7.23625078e-07,\n 3.54689064e-07, -2.08309700e-07, 7.55339488e-06],\n [ 1.02153695e-06, -5.78017762e-06, 4.76029385e-07,\n 6.66310996e-07, -8.07298056e-06, 9.26087887e-06],\n [ 5.83943233e-07, -7.10525410e-06, -4.68334918e-07,\n 3.76263955e-07, -8.86384896e-06, 1.29387236e-05],\n [ 1.01500041e-06, 4.76382820e-06, 5.33049501e-07,\n 1.14710679e-06, 1.11709198e-06, -6.21370691e-06],\n [-6.31725527e-07, -4.20461538e-06, 1.27212684e-06,\n 5.97546659e-07, -4.44876014e-06, 7.99392728e-06],\n [ 1.27447595e-06, -7.42738530e-06, -1.20815002e-06,\n 8.20677712e-08, -1.29561167e-05, 1.26983887e-05],\n [ 3.91787097e-07, -6.86304975e-06, 3.15919408e-07,\n 4.78149786e-07, 3.33411720e-07, 7.24351867e-06],\n [-7.31582759e-07, 7.05648517e-06, -5.99249006e-07,\n -3.65444635e-07, 6.58787803e-06, -1.00339112e-05],\n [ 1.59395915e-06, -1.62996571e-06, -1.54220902e-06,\n -4.67488718e-07, 
-5.08205630e-06, 3.86462261e-06],\n [ 9.31843601e-07, 2.42041165e-06, 1.82153542e-07,\n 2.29804400e-07, -1.47138360e-06, -3.32397576e-06],\n [-1.14857016e-07, 1.20156915e-06, -3.16910530e-07,\n -7.66403843e-07, 1.39314898e-06, -2.99687235e-06],\n [ 4.07920425e-07, -2.35097605e-06, -1.48380832e-06,\n -1.10702365e-06, -8.48389391e-06, 6.41825045e-06],\n [ 6.99555414e-07, -7.84252188e-06, -1.77876316e-07,\n -5.36324023e-07, -5.12962743e-07, 6.85968735e-06],\n [-4.43854134e-07, -1.19246545e-06, 1.71287593e-06,\n 6.58903559e-07, 3.85870317e-06, -1.14134968e-06],\n [ 7.25131372e-07, 6.89900753e-06, -2.64183041e-06,\n 2.39554709e-07, 9.28923691e-06, -1.29141199e-05],\n [ 3.82862510e-07, -4.08812895e-07, 2.96095322e-07,\n 5.14074827e-08, -4.79162736e-06, 2.19754088e-06],\n [ 1.79455199e-07, 7.84833173e-06, -1.68732555e-07,\n -1.58592755e-07, 1.22959445e-05, -1.30488133e-05],\n [ 4.53154144e-07, 6.44779493e-06, -1.15898138e-06,\n 5.08535436e-07, 1.44323803e-05, -1.60667241e-05],\n [ 2.64337359e-07, -6.91162859e-06, -3.40541533e-08,\n 2.95334388e-08, -1.31283150e-06, 6.42984196e-06],\n [ 1.02554668e-06, -1.51715199e-06, -9.67659730e-07,\n 9.35940420e-07, -2.22751964e-06, 1.58044168e-06],\n [ 2.94390293e-06, -1.08298093e-06, -1.40403404e-06,\n 8.85823624e-07, -4.19718754e-06, 3.36121474e-07],\n [-1.27575959e-08, -3.47877312e-06, 9.04482306e-07,\n 1.38244604e-06, -5.05446133e-06, 7.17476451e-06],\n [ 5.22484129e-07, 6.79360710e-06, 2.00455361e-07,\n 2.04385771e-08, 1.04309420e-05, -1.53854107e-05],\n [-2.32110438e-07, -6.25013035e-06, 4.16154535e-07,\n -4.29712628e-07, -8.74119905e-06, 1.33317380e-05],\n [-1.72272749e-07, -2.85427996e-06, 9.17918385e-07,\n -1.74608715e-07, 2.61466812e-06, 3.20356548e-07],\n [-3.98590316e-07, 5.72777572e-06, 5.16953378e-08,\n -2.30283320e-07, 9.72292401e-06, -1.01911537e-05],\n [ 1.38026326e-06, -1.69508246e-06, -1.50022231e-06,\n -3.24397718e-08, 2.64211110e-07, 1.73726607e-06],\n [-2.57760775e-07, -6.49922958e-06, 1.07226242e-06,\n 
-3.65794733e-07, -4.52577478e-06, 1.01539963e-05],\n [ 1.58273042e-06, -3.64774905e-06, -1.09204620e-06,\n -1.51300151e-06, -2.30509136e-06, 3.07344430e-06],\n [ 5.98161819e-07, 6.63248647e-06, 1.05101094e-07,\n 4.95582071e-07, 5.25333189e-06, -1.07178275e-05],\n [ 1.04858748e-06, 3.97416079e-06, 1.47306741e-06,\n 6.30566092e-07, 7.42636030e-06, -1.04624205e-05],\n [-1.46166485e-06, 4.88427168e-06, 1.82387862e-06,\n 4.49339410e-07, 3.97414988e-06, -7.57285170e-06],\n [-3.32953022e-07, -1.86433351e-06, 3.30018054e-07,\n -7.58851115e-07, -3.58492161e-06, 5.42662883e-06],\n [-1.41443184e-06, -5.73096713e-06, 7.34330115e-07,\n 1.69057145e-07, 9.80789991e-07, 7.30678676e-06],\n [ 3.14442957e-07, 4.99395901e-06, -1.61032949e-06,\n 4.36894254e-07, 2.91608603e-06, -8.05610216e-06],\n [ 1.56075430e-06, 3.49192919e-06, -3.83678184e-07,\n 7.36173774e-07, -1.48261358e-06, -2.87649664e-06],\n [ 1.34733887e-06, 6.41233601e-06, 1.71123509e-06,\n 6.50678146e-07, 1.06393973e-05, -1.43433608e-05],\n [-9.12618304e-07, -2.33978312e-06, 1.95254324e-06,\n 2.26706632e-07, -9.17776742e-06, 7.15697888e-06],\n [ 7.15795125e-07, 6.30058730e-06, -9.77086074e-07,\n -5.05545984e-07, 8.58714247e-06, -1.14812055e-05],\n [-6.94097253e-07, 2.25981785e-06, -3.21830100e-08,\n 6.29661429e-07, 1.17689888e-05, -8.83968369e-06],\n [ 4.45760975e-08, -7.20550042e-06, -7.32303988e-07,\n -7.33645322e-07, -5.21131278e-06, 9.82466827e-06],\n [-4.49712587e-07, 4.51746837e-06, 1.75300130e-07,\n -1.04324442e-06, 9.77069340e-06, -7.88865873e-06],\n [ 2.80753170e-06, -3.08079393e-06, -1.75416540e-06,\n 4.89795980e-07, -1.04929122e-05, 5.72613908e-06],\n [-8.24183246e-07, -7.84531130e-06, -3.16042104e-07,\n -5.25608527e-07, -1.09352886e-05, 1.42115850e-05],\n [ 4.62080095e-07, 5.67532106e-06, 1.53317001e-06,\n 4.47872253e-07, 1.42140643e-05, -1.33855892e-05],\n [ 1.61264234e-06, 5.61368461e-06, -2.16226613e-06,\n -8.19383104e-07, 1.04630763e-05, -1.41299424e-05],\n [ 5.60000785e-07, 1.69940745e-06, 
-2.29975785e-06,\n -7.62130583e-07, -1.57408442e-06, -6.34730554e-07],\n [ 1.29794250e-06, -2.95222549e-07, -1.44694434e-06,\n -2.92114066e-07, 2.25114718e-06, -1.65837469e-06],\n [-5.77794765e-07, -1.83486031e-06, 6.42491500e-08,\n -7.89675028e-08, -3.92448237e-06, 5.13863097e-06],\n [ 5.93625884e-07, 4.19205116e-06, 2.56499561e-07,\n -8.98320280e-08, 1.01652313e-05, -1.07823216e-05],\n [-3.87973358e-07, -7.36338461e-06, -2.81240659e-07,\n -1.06455991e-06, -9.47545050e-06, 1.28567899e-05],\n [-1.35564761e-07, -6.94711571e-06, -6.66803885e-07,\n 3.93332300e-07, -1.52133589e-05, 1.52696957e-05],\n [-2.00091137e-07, 5.19159494e-06, -2.73080673e-06,\n -1.00092529e-06, 3.84539180e-06, -8.02862451e-06]], dtype=float32), DeviceArray([[2.12274364e-17, 1.16329784e-13, 1.48670884e-15,\n 1.35739756e-16, 9.56699235e-15, 1.11627599e-13],\n [6.40458023e-16, 6.39256697e-15, 4.84248471e-16,\n 1.84344291e-15, 1.08346568e-15, 2.29769630e-14],\n [3.40560926e-15, 1.07508685e-12, 2.48534239e-14,\n 4.89585848e-15, 6.50937041e-13, 1.80882588e-12],\n [3.88176023e-14, 6.81023814e-12, 2.14583681e-13,\n 4.50841011e-14, 4.61792323e-12, 9.48895563e-12],\n [3.71696048e-14, 6.92284988e-12, 2.19224080e-13,\n 4.49529228e-14, 4.64760695e-12, 9.59453610e-12],\n [3.79026780e-14, 5.84377295e-12, 1.99977784e-13,\n 4.02084236e-14, 3.93055640e-12, 7.83793516e-12],\n [3.70780632e-17, 2.07499789e-13, 2.63444463e-15,\n 2.40250016e-16, 1.69488328e-14, 1.99782234e-13],\n [1.14560846e-15, 1.14565738e-14, 8.66404591e-16,\n 3.28540384e-15, 1.93570364e-15, 4.11436090e-14],\n [5.81243199e-15, 1.84130163e-12, 4.30244998e-14,\n 8.45466190e-15, 1.10854143e-12, 3.08115556e-12],\n [4.00559238e-14, 6.71291451e-12, 2.10304511e-13,\n 4.47433872e-14, 4.58902447e-12, 9.40107887e-12],\n [3.70970378e-14, 6.91372003e-12, 2.18548175e-13,\n 4.45151864e-14, 4.64187759e-12, 9.58986709e-12],\n [3.86563679e-14, 5.06533530e-12, 1.85505406e-13,\n 3.63224499e-14, 3.41559268e-12, 6.60209344e-12],\n [6.49781208e-17, 3.68652727e-13, 
4.62562956e-15,\n 4.20889163e-16, 2.97549122e-14, 3.57081498e-13],\n [2.04563115e-15, 2.05276673e-14, 1.54775556e-15,\n 5.82812158e-15, 3.44899576e-15, 7.36078543e-14],\n [9.54868627e-15, 3.03839267e-12, 7.24336498e-14,\n 1.41671249e-14, 1.81037043e-12, 5.03313970e-12],\n [4.22862564e-14, 6.54081607e-12, 2.02801507e-13,\n 4.41454497e-14, 4.53797113e-12, 9.24457212e-12],\n [3.69694374e-14, 6.89738240e-12, 2.17341038e-13,\n 4.37585115e-14, 4.63159068e-12, 9.58147450e-12],\n [4.04843429e-14, 3.84650558e-12, 1.62578257e-13,\n 3.00309501e-14, 2.61262019e-12, 4.71493634e-12],\n [1.20754685e-16, 6.50295464e-13, 7.98987799e-15,\n 7.23699974e-16, 5.14270294e-14, 6.36687589e-13],\n [3.64142302e-15, 3.67666814e-14, 2.75727436e-15,\n 1.02526325e-14, 6.11574517e-15, 1.31477484e-13],\n [1.46883712e-14, 4.67848763e-12, 1.15957062e-13,\n 2.24656465e-14, 2.73455130e-12, 7.60119051e-12],\n [4.63087718e-14, 6.23982116e-12, 1.89877451e-13,\n 4.31139804e-14, 4.44875777e-12, 8.96729305e-12],\n [3.67484838e-14, 6.86818267e-12, 2.15191919e-13,\n 4.24904049e-14, 4.61308899e-12, 9.56633904e-12],\n [4.49987032e-14, 2.16799517e-12, 1.30208127e-13,\n 2.07182362e-14, 1.51246463e-12, 2.26127819e-12],\n [2.78306360e-16, 1.13239799e-12, 1.34045929e-14,\n 1.20215998e-15, 8.66234303e-14, 1.13014014e-12],\n [6.44631502e-15, 6.58058813e-14, 4.88761855e-15,\n 1.77658300e-14, 1.07509504e-14, 2.34171406e-13],\n [2.03687097e-14, 6.30603642e-12, 1.69426539e-13,\n 3.21455373e-14, 3.55296543e-12, 9.84842329e-12],\n [5.35468310e-14, 5.72386843e-12, 1.68327483e-13,\n 4.13906614e-14, 4.29568187e-12, 8.48082889e-12],\n [3.63768430e-14, 6.81610931e-12, 2.11386084e-13,\n 4.04917663e-14, 4.57973243e-12, 9.53891566e-12],\n [5.56252838e-14, 4.73318704e-13, 9.52897361e-14,\n 9.29542342e-15, 3.95390941e-13, 2.13340549e-13],\n [9.35515915e-16, 1.92628942e-12, 2.13631145e-14,\n 1.87049989e-15, 1.40416541e-13, 1.98888128e-12],\n [1.13000971e-14, 1.17632683e-13, 8.58685664e-15,\n 2.99529146e-14, 1.86081517e-14, 
4.14942034e-13],\n [2.54627692e-14, 6.45356294e-12, 2.09706479e-13,\n 3.77183670e-14, 3.37198347e-12, 9.18802274e-12],\n [6.63559106e-14, 4.87193662e-12, 1.34495591e-13,\n 3.86816332e-14, 4.04036960e-12, 7.64292189e-12],\n [3.57858613e-14, 6.72362123e-12, 2.04711804e-13,\n 3.77390651e-14, 4.51934757e-12, 9.48883940e-12],\n [7.67252134e-14, 2.10885983e-13, 8.39426320e-14,\n 8.32536031e-16, 9.29799586e-14, 1.21318949e-12],\n [4.09693785e-15, 3.13923020e-12, 3.11801129e-14,\n 2.55459483e-15, 2.18342678e-13, 3.44059503e-12],\n [1.94676869e-14, 2.09800981e-13, 1.48453502e-14,\n 4.80272831e-14, 3.13235528e-14, 7.28540546e-13],\n [3.77764057e-14, 3.30581699e-12, 1.94231187e-13,\n 2.93069334e-14, 1.42408145e-12, 3.41124416e-12],\n [8.76719080e-14, 3.56295679e-12, 8.71891670e-14,\n 3.49178287e-14, 3.63008733e-12, 6.25143600e-12],\n [3.49498804e-14, 6.56057370e-12, 1.93211332e-13,\n 3.51876494e-14, 4.40939733e-12, 9.39621471e-12],\n [1.02543537e-13, 3.16919191e-12, 1.39068810e-13,\n 4.77598257e-15, 1.33027931e-12, 7.54485623e-12],\n [1.81420312e-14, 4.72240042e-12, 3.97291317e-14,\n 2.61201519e-15, 3.41889630e-13, 5.73694521e-12],\n [3.25354367e-14, 3.72662648e-13, 2.49319388e-14,\n 7.01126033e-14, 5.01605729e-14, 1.25834445e-12],\n [9.15384169e-14, 1.18718401e-14, 1.31269900e-13,\n 8.53779903e-15, 6.22487291e-14, 7.74764096e-13],\n [1.16988599e-13, 1.83090609e-12, 3.52818970e-14,\n 3.10080466e-14, 2.99111929e-12, 4.11762421e-12],\n [3.40759864e-14, 6.27691692e-12, 1.74014394e-13,\n 3.67971407e-14, 4.20774656e-12, 9.22151851e-12],\n [9.36341662e-14, 6.17177056e-12, 2.42255733e-13,\n 1.96068307e-14, 1.85940824e-12, 1.10231806e-11],\n [6.66848982e-14, 6.08192923e-12, 4.49372459e-14,\n 1.16560627e-15, 6.08672292e-13, 8.77273774e-12],\n [5.15875049e-14, 6.57081757e-13, 3.97287658e-14,\n 8.53035768e-14, 7.34619342e-14, 2.11091864e-12],\n [1.77921778e-13, 4.60536470e-12, 1.63127460e-13,\n 4.93569105e-15, 1.53785545e-12, 1.22531889e-11],\n [1.38032950e-13, 2.67059243e-13, 
6.61249526e-15,\n 2.99198566e-14, 2.00648786e-12, 1.46605915e-12],\n [3.40820376e-14, 5.79501724e-12, 1.43775454e-13,\n 5.25594108e-14, 3.83661939e-12, 8.88347902e-12],\n [3.49875056e-14, 1.18555860e-12, 2.20762848e-13,\n 1.80839671e-14, 3.41978647e-15, 6.57213704e-13],\n [1.63876562e-13, 5.63548036e-12, 6.40144201e-14,\n 1.46602970e-15, 1.11053614e-12, 1.07020252e-11],\n [7.46754411e-14, 1.14307695e-12, 5.75161053e-14,\n 7.05393589e-14, 9.19605036e-14, 3.36333457e-12],\n [1.23368235e-13, 1.50348213e-12, 2.65094641e-13,\n 1.14635084e-14, 3.09794779e-13, 1.62587337e-12],\n [1.22416630e-13, 4.93287648e-13, 3.06068816e-14,\n 3.44190889e-14, 7.19466099e-13, 2.48978681e-13],\n [3.73807871e-14, 5.01048465e-12, 1.00987252e-13,\n 9.43880933e-14, 3.16495354e-12, 8.21338744e-12],\n [6.02890472e-14, 3.88085484e-12, 7.06748231e-14,\n 1.47200977e-14, 2.72828288e-12, 1.24074327e-11],\n [1.90752335e-13, 2.11876241e-12, 1.29959222e-13,\n 2.35470077e-14, 9.55504664e-13, 6.53365122e-12],\n [9.20150864e-14, 1.93947137e-12, 6.97596007e-14,\n 2.04688646e-14, 9.09242300e-14, 4.90398824e-12],\n [2.05808169e-15, 5.04891658e-12, 1.81128468e-13,\n 4.12131673e-14, 4.92340066e-12, 1.45131598e-11],\n [1.14951370e-13, 3.87743960e-12, 6.47504308e-14,\n 3.23113051e-14, 8.24780348e-13, 5.31654278e-12],\n [4.77995227e-14, 3.82752250e-12, 5.15998987e-14,\n 1.38433712e-13, 2.04503905e-12, 6.88613749e-12],\n [1.20617207e-13, 9.71850096e-14, 1.23844877e-13,\n 8.59225394e-15, 2.65607089e-12, 7.33658333e-14],\n [1.25853551e-13, 2.87566656e-13, 1.53882739e-13,\n 7.24696250e-14, 4.75091646e-13, 5.21490674e-14],\n [8.54805728e-14, 3.13701647e-12, 6.11158734e-14,\n 6.12061180e-15, 9.36028327e-14, 6.21522712e-12],\n [9.97266302e-14, 3.38041726e-12, 8.78909846e-14,\n 4.59185607e-14, 9.32172829e-12, 9.51437887e-12],\n [1.11211326e-13, 5.71903028e-12, 3.75462838e-14,\n 6.80075215e-15, 4.95495355e-12, 1.29487220e-11],\n [6.28567565e-14, 2.27643035e-12, 1.36256668e-14,\n 8.73037604e-14, 6.21325582e-13, 
4.45343328e-12],\n [8.68259390e-14, 1.67289891e-12, 2.82116669e-13,\n 4.10208773e-14, 1.13631045e-12, 4.27477975e-12],\n [1.16415571e-13, 5.41346828e-12, 7.09196360e-14,\n 1.98493260e-14, 1.17692210e-11, 1.43913518e-11],\n [5.79508026e-14, 4.60745981e-12, 3.32559569e-14,\n 1.47494152e-14, 2.09347107e-13, 6.65009194e-12],\n [5.57104818e-14, 4.67987671e-12, 5.37830922e-14,\n 1.96269104e-14, 4.35868095e-12, 1.08111844e-11],\n [1.90275326e-13, 3.52346949e-13, 1.84113873e-13,\n 4.51164205e-14, 2.45487159e-12, 1.54192078e-12],\n [5.02207698e-14, 7.26078973e-13, 1.37768617e-15,\n 6.04221212e-15, 1.28035766e-13, 1.15397243e-12],\n [1.14944539e-13, 1.71613063e-13, 1.97055154e-13,\n 5.51698817e-14, 2.50777264e-13, 5.95267379e-13],\n [9.44072599e-15, 4.50501344e-13, 1.11838151e-13,\n 7.85740288e-14, 4.68349881e-12, 3.17486879e-12],\n [3.92286166e-14, 5.50404297e-12, 2.46053618e-14,\n 3.53296697e-14, 3.73009322e-13, 6.25861082e-12],\n [2.28043174e-14, 1.19612599e-13, 2.32938397e-13,\n 4.81390271e-14, 1.67360819e-12, 1.21460079e-13],\n [1.97282864e-13, 4.34959620e-12, 7.03280153e-13,\n 5.91790072e-15, 9.76983945e-12, 1.74895289e-11],\n [1.88748299e-14, 1.15825657e-14, 4.84613722e-15,\n 3.96470373e-15, 2.14706681e-12, 3.81047705e-13],\n [1.86247908e-14, 5.47145575e-12, 3.08681981e-14,\n 1.95181209e-14, 1.42232745e-11, 1.55279123e-11],\n [1.47319358e-14, 4.47306948e-12, 8.58855630e-14,\n 1.98945576e-14, 2.14078043e-11, 2.42164240e-11],\n [6.41345650e-15, 4.12341949e-12, 1.32246051e-14,\n 7.68807083e-15, 1.46188305e-13, 3.94888679e-12],\n [2.92089188e-13, 1.49109485e-13, 8.26589231e-14,\n 1.19922667e-13, 2.40425475e-12, 5.56159827e-13],\n [7.85756659e-13, 1.58019810e-13, 2.32656478e-13,\n 6.38192976e-14, 1.73062985e-12, 9.43298581e-14],\n [6.03300233e-14, 1.22240499e-12, 5.52497704e-14,\n 1.59290726e-13, 3.79895160e-12, 5.32288666e-12],\n [1.43017024e-14, 4.70550465e-12, 6.48424728e-14,\n 2.99140697e-15, 9.73449793e-12, 2.35858295e-11],\n [3.81577582e-13, 3.85160046e-12, 
2.91226461e-13,\n 5.28811681e-14, 1.03909190e-11, 2.15018003e-11],\n [2.76621267e-15, 7.68774096e-13, 5.11400614e-14,\n 2.16139451e-15, 8.68445290e-13, 2.29089564e-14],\n [2.64575335e-13, 3.82360289e-12, 2.98081546e-13,\n 4.35312000e-15, 1.30124167e-11, 1.05589808e-11],\n [1.56517445e-13, 2.37491694e-13, 2.87678898e-13,\n 9.62362073e-15, 2.40470458e-13, 6.35600784e-13],\n [5.63482773e-14, 4.52433013e-12, 8.64711270e-14,\n 4.89750215e-14, 4.26681867e-12, 1.21432421e-11],\n [1.50155035e-13, 1.24448694e-12, 6.17540619e-14,\n 1.69198436e-13, 5.02130401e-13, 1.41153918e-12],\n [5.85862739e-14, 4.50152329e-12, 1.89461836e-13,\n 1.48181367e-14, 2.23643239e-12, 9.65132141e-12],\n [9.49037669e-14, 1.46891234e-12, 1.66119072e-13,\n 4.53771745e-14, 1.01376850e-11, 1.27502496e-11],\n [1.80288076e-13, 2.70028552e-12, 2.77970220e-13,\n 8.52179926e-14, 1.04769871e-12, 3.74790823e-12],\n [1.92593015e-14, 2.97783635e-13, 3.80525079e-14,\n 9.18365793e-14, 9.21701517e-13, 2.00620987e-12],\n [1.35302888e-13, 3.38529552e-12, 3.35281458e-14,\n 2.97686024e-15, 1.25292341e-13, 6.21967538e-12],\n [1.60756960e-14, 2.59312580e-12, 1.69482890e-13,\n 2.38856837e-14, 8.49733641e-13, 6.02466211e-12],\n [1.51545362e-13, 9.35817071e-13, 1.04080278e-13,\n 5.19388237e-14, 3.67659597e-13, 1.10649088e-12],\n [1.22576903e-13, 4.87884429e-12, 1.61775365e-13,\n 2.85257183e-14, 8.86557147e-12, 1.69115971e-11],\n [4.03154615e-14, 4.77810771e-13, 3.02998349e-13,\n 2.23939799e-15, 5.53366988e-12, 3.11412354e-12],\n [6.89227524e-14, 4.43917516e-12, 6.79811838e-14,\n 3.84796226e-14, 4.68237515e-12, 1.07565293e-11],\n [2.56983596e-14, 5.48996178e-13, 3.19121221e-14,\n 3.39858045e-14, 9.86410259e-12, 6.85449180e-12],\n [7.42128662e-14, 4.66550193e-12, 5.68850726e-14,\n 3.48152463e-14, 6.22827484e-12, 1.31156874e-11],\n [1.60594719e-14, 1.71755524e-12, 1.06847158e-14,\n 1.22288952e-13, 7.70055200e-12, 5.33613865e-12],\n [4.63322523e-13, 1.14702551e-12, 1.57314564e-13,\n 2.66102041e-14, 1.26338957e-11, 
4.69107826e-12],\n [4.32601410e-14, 5.60937147e-12, 1.38292283e-14,\n 2.03118992e-14, 7.40712960e-12, 1.63549868e-11],\n [2.93910370e-14, 3.07944230e-12, 1.11994588e-13,\n 1.30717258e-14, 1.41837774e-11, 1.42507377e-11],\n [1.22383210e-13, 3.66307982e-12, 2.05759360e-13,\n 2.79863480e-14, 8.02877122e-12, 1.67750067e-11],\n [1.39484880e-14, 2.84715610e-13, 3.10231766e-13,\n 4.56258803e-14, 5.54326496e-13, 2.24314382e-14],\n [1.79185510e-13, 2.01523046e-14, 2.22208496e-13,\n 2.34033899e-14, 2.82779713e-13, 1.60465798e-13],\n [2.10014907e-14, 4.73251429e-13, 1.02406812e-13,\n 2.46776781e-14, 2.52360285e-12, 2.99892576e-12],\n [1.51100188e-13, 2.05120122e-12, 8.49248616e-15,\n 1.26177179e-14, 8.98622583e-12, 1.23984642e-11],\n [3.71367366e-14, 5.48269026e-12, 4.56275913e-14,\n 8.22980464e-14, 8.02551775e-12, 1.46907157e-11],\n [3.55192357e-14, 4.99774224e-12, 5.69814039e-14,\n 3.79094881e-14, 1.74064999e-11, 1.94485643e-11],\n [4.41634644e-14, 2.94970342e-12, 4.22736390e-13,\n 7.28969565e-14, 1.15712117e-12, 4.70383498e-12]], dtype=float32)], [DeviceArray([ 3.0663330e-05, -6.0212583e-04, 2.3759302e-04,\n 2.5868230e-03, -2.6333719e-03, -6.5538799e-04], dtype=float32), DeviceArray([ 1.8194061e-06, 3.8155737e-05, 7.0559526e-07,\n -8.0754571e-06, 7.7572095e-06, 2.2165239e-05], dtype=float32), DeviceArray([1.7689330e-12, 8.7941432e-11, 8.0609998e-14, 4.3580421e-12,\n 3.2801665e-12, 8.3839498e-11], dtype=float32)], [DeviceArray([[ 0.19147131, -0.08741151, -0.2899834 , 0.62835646,\n 0.53938615, 0.40327066],\n [-0.32089502, 0.20948973, 0.12690711, -0.25749823,\n 0.7558288 , 0.26582175],\n [-0.31148365, 0.6439382 , -0.63071334, -0.25127968,\n -0.34781596, 0.8903316 ],\n [ 0.01745997, 0.08365437, -0.36128715, -0.45783097,\n 0.05735766, 0.17224273],\n [-0.02606061, -0.5212192 , 0.4008201 , -0.77488416,\n -0.3991631 , 0.32255793],\n [-0.47373995, 0.28477976, -0.8236459 , -0.6276811 ,\n -0.6463785 , -0.45341846]], dtype=float32), DeviceArray([[ 2.0765101e-06, 5.2552758e-07, 
1.5801023e-08,\n -2.0083421e-06, 1.3381332e-06, 1.4539712e-06],\n [-8.0421991e-10, 9.4330670e-07, 8.0152070e-08,\n -5.2841326e-08, -2.0777220e-08, 3.6099043e-06],\n [ 1.0393588e-06, -3.8457793e-07, 1.0858516e-09,\n -1.2779847e-06, 4.7307202e-07, 4.9989808e-07],\n [-1.4016753e-06, 6.8030022e-06, 1.8208079e-08,\n -1.4558456e-08, 1.2815885e-06, -5.2809555e-06],\n [ 3.7267771e-09, 8.3228333e-06, 4.8335327e-08,\n -1.2781840e-07, 4.3923883e-08, 4.3804215e-05],\n [-1.6840630e-08, 1.5559388e-05, 4.6703321e-09,\n -7.2126383e-08, 7.0047875e-08, 3.1726337e-05]], dtype=float32), DeviceArray([[3.1861910e-13, 1.9180499e-14, 2.2070947e-17, 2.0498165e-13,\n 8.7324645e-14, 1.1507772e-13],\n [3.9682659e-19, 7.1927083e-14, 7.9724488e-16, 2.8024347e-16,\n 1.8258127e-16, 1.6315354e-12],\n [8.7954194e-14, 4.7628832e-14, 1.1790737e-19, 8.0634026e-14,\n 1.1355760e-14, 3.2381348e-14],\n [1.5629323e-12, 2.7163892e-12, 3.3153412e-17, 8.8922520e-18,\n 9.2106539e-14, 1.9771992e-12],\n [7.5559045e-18, 3.8389882e-12, 1.7638762e-16, 1.5556227e-15,\n 1.1637820e-16, 1.7210415e-10],\n [2.8360685e-17, 1.3631652e-11, 2.1812002e-18, 5.9393463e-16,\n 4.5445975e-16, 9.1455538e-11]], dtype=float32)], [DeviceArray([-0.00038439, -0.00253191, 0. 
, -0.00024855,\n -0.00149335, -0.0005279 ], dtype=float32), DeviceArray([ 9.0186906e-05, 2.0501162e-05, 0.0000000e+00,\n -1.5605739e-05, 7.8567376e-05, 1.0919726e-05], dtype=float32), DeviceArray([7.7366902e-10, 2.8706129e-11, 0.0000000e+00, 4.0496065e-11,\n 4.4863854e-10, 5.6050446e-11], dtype=float32)], [DeviceArray([[ 7.09596097e-01, 5.57229042e-01, 1.05266884e-01,\n 8.37738737e-02, 4.01628464e-01, 1.04380406e-01],\n [ 1.63253590e-01, 2.98599415e-02, -6.43938184e-01,\n 8.21307227e-02, 6.04431272e-01, -6.21634185e-01],\n [-1.62911832e-01, -2.34109700e-01, -1.84899673e-01,\n -1.95914432e-01, 2.12329701e-01, 2.51049638e-01],\n [ 6.75464392e-01, -2.52610683e-01, -1.87426805e-01,\n 2.40000457e-01, -5.11981368e-01, -4.66484815e-01],\n [-3.82236729e-04, 4.80104297e-01, -6.04833551e-02,\n -8.24540913e-01, -3.41470331e-01, 1.19573558e-02],\n [ 1.48141757e-01, -1.05470764e-02, -2.78792083e-01,\n 1.00884447e-03, 7.85028264e-02, 6.96714282e-01]], dtype=float32), DeviceArray([[-3.7734679e-08, 2.0914214e-07, 0.0000000e+00,\n 4.4612163e-09, -1.8412505e-08, -1.6384351e-07],\n [ 2.5431862e-07, 5.8102722e-07, 0.0000000e+00,\n -3.9158030e-07, 6.5379754e-07, -3.2795475e-07],\n [-2.0847846e-08, -5.5632698e-09, 0.0000000e+00,\n 0.0000000e+00, -7.4284507e-09, 4.5505992e-08],\n [-4.6340386e-07, 1.2311507e-06, 0.0000000e+00,\n -8.9473851e-09, -4.2540904e-10, 3.4614908e-07],\n [-4.7590757e-07, 1.1430946e-06, 0.0000000e+00,\n -1.0665863e-09, 1.6385513e-08, -2.5949966e-08],\n [ 6.1280770e-07, 2.1922931e-06, 0.0000000e+00,\n -1.4440599e-07, 1.7412744e-06, 9.7419240e-07]], dtype=float32), DeviceArray([[4.89823272e-16, 2.44132429e-15, 0.00000000e+00,\n 2.45463613e-18, 1.79699840e-16, 1.53430765e-15],\n [4.67324764e-15, 2.39591468e-14, 0.00000000e+00,\n 1.44993897e-14, 2.56212982e-14, 6.71791317e-15],\n [2.65992102e-17, 1.07641758e-17, 0.00000000e+00,\n 0.00000000e+00, 3.05696052e-18, 1.45802733e-16],\n [1.06954859e-14, 8.00853592e-14, 0.00000000e+00,\n 1.86719016e-17, 3.72830896e-16, 
7.56701373e-15],\n [1.88198032e-14, 7.23305422e-14, 0.00000000e+00,\n 1.08007125e-19, 3.75352571e-15, 4.23270495e-15],\n [9.34799587e-14, 2.79318913e-13, 0.00000000e+00,\n 1.72723126e-15, 1.98988395e-13, 9.09553601e-14]], dtype=float32)], [DeviceArray([-0.00074821, -0.00081796, 0.00117048, -0.00178991,\n 0.00265362, 0. ], dtype=float32), DeviceArray([ 1.1074917e-07, 2.4995950e-04, 1.8515976e-04,\n 3.9137376e-06, -6.2360908e-05, 0.0000000e+00], dtype=float32), DeviceArray([2.8900531e-11, 7.9435143e-09, 4.1745047e-09, 3.8028711e-10,\n 2.4104238e-10, 0.0000000e+00], dtype=float32)], [DeviceArray([[-0.3056272 , 0.4231755 , -0.04128612, 0.46694696,\n 0.16514003, -0.45115194],\n [-0.02365104, 0.2639698 , -0.37985128, 0.8142715 ,\n 0.05928962, -0.27328548],\n [-0.3441138 , 0.3751544 , 0.05917086, 0.4884692 ,\n 0.04338139, 0.2134114 ],\n [-0.38789976, -0.38161513, 0.31562155, 0.10442398,\n -0.00480278, -0.19789281],\n [ 0.12581037, 0.37353298, -0.10928751, 0.503088 ,\n -0.2655051 , -0.17124975],\n [ 0.35595462, 0.18437892, 0.23258427, -0.85487086,\n -0.2029693 , -0.33262372]], dtype=float32), DeviceArray([[ 1.4422523e-07, 1.0348479e-06, -2.1897273e-07,\n 2.1539340e-06, -9.7019038e-06, 0.0000000e+00],\n [ 1.7801361e-07, -2.5086075e-07, -1.3055396e-07,\n 5.2816296e-07, -1.3365701e-06, 0.0000000e+00],\n [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,\n 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n [ 1.8810088e-08, 8.0061959e-08, 5.1003372e-08,\n 3.1219592e-08, -2.3568070e-08, 0.0000000e+00],\n [ 1.6152674e-07, 1.0853166e-06, 1.2019419e-06,\n 2.8108428e-07, -1.2651979e-07, 0.0000000e+00],\n [ 3.1606952e-08, 2.8423892e-06, 5.0167625e-07,\n 1.1964460e-06, -5.3287968e-06, 0.0000000e+00]], dtype=float32), DeviceArray([[1.1840859e-15, 7.9572782e-14, 3.1349770e-14, 2.4425308e-13,\n 4.6538823e-12, 0.0000000e+00],\n [2.5218905e-15, 6.5746011e-15, 9.1246296e-16, 1.5102157e-14,\n 9.8801487e-14, 0.0000000e+00],\n [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,\n 
0.0000000e+00, 0.0000000e+00],\n [3.2490592e-17, 8.4389236e-16, 5.7708000e-16, 1.1674825e-16,\n 3.7473287e-17, 0.0000000e+00],\n [1.6931966e-15, 1.3198515e-13, 1.7659864e-13, 2.0507665e-14,\n 9.0556134e-16, 0.0000000e+00],\n [2.1929211e-15, 7.6067006e-13, 2.3270353e-14, 7.4921501e-14,\n 1.4142405e-12, 0.0000000e+00]], dtype=float32)], [DeviceArray([ 0. , 0.00157121, -0.00141369, -0.00135325,\n -0.00251697, -0.00212993], dtype=float32), DeviceArray([ 0.00000000e+00, 1.29414868e-04, -5.93711302e-05,\n 1.11967165e-04, 3.14280886e-04, 8.22721340e-05], dtype=float32), DeviceArray([0.0000000e+00, 2.1965585e-08, 1.1964050e-09, 1.3787350e-09,\n 6.9818973e-09, 2.0284541e-09], dtype=float32)], [DeviceArray([[ 0.4786232 , 0.07762085, 0.23054789, 0.24556062,\n -0.00573604, 0.09755293],\n [-0.7429878 , 0.72179085, 0.18699315, -0.2683911 ,\n 0.57062393, 0.10477231],\n [-0.5879746 , 0.22517885, 0.58405346, 0.15380381,\n 0.5612985 , -0.374023 ],\n [-0.13496533, 0.31007463, -0.91230136, -0.00988471,\n -0.14309123, 0.64083695],\n [-0.47050482, -0.09354866, -0.3502471 , 0.02155269,\n -0.23451555, -0.5596603 ],\n [ 0.46003628, -0.26375806, 0.1943689 , 0.04007643,\n -0.29916492, 0.16186029]], dtype=float32), DeviceArray([[ 0.0000000e+00, 4.6203587e-07, -4.9911466e-07,\n 8.1376520e-08, 7.8805937e-07, 4.8135485e-07],\n [ 0.0000000e+00, -3.8448088e-06, 2.5146437e-06,\n 8.5733035e-08, 4.4749559e-06, 9.3717990e-06],\n [ 0.0000000e+00, 7.2909307e-07, -4.5917994e-07,\n 1.7394352e-07, 1.1992919e-06, 5.3257600e-07],\n [ 0.0000000e+00, -7.2328521e-07, 3.5359469e-07,\n 5.5574421e-09, 1.1884179e-06, 2.0294008e-06],\n [ 0.0000000e+00, -3.7733372e-07, 2.5784374e-07,\n 1.0839074e-07, 3.5430759e-07, 9.1234665e-07],\n [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,\n 0.0000000e+00, 0.0000000e+00, 0.0000000e+00]], dtype=float32), DeviceArray([[0.00000000e+00, 3.25262346e-14, 3.59916768e-14,\n 3.68739144e-16, 5.14987562e-14, 1.75674311e-14],\n [0.00000000e+00, 7.44021056e-13, 3.08127899e-13,\n 
4.13313856e-16, 1.09750945e-12, 4.47324686e-12],\n [0.00000000e+00, 1.40464599e-13, 5.10655903e-14,\n 2.17324683e-15, 1.08902335e-13, 1.50668662e-14],\n [0.00000000e+00, 3.73186691e-14, 6.29945264e-15,\n 3.80917003e-18, 7.84495827e-14, 2.20110930e-13],\n [0.00000000e+00, 1.75216490e-14, 3.24122409e-15,\n 1.44898835e-15, 6.33568871e-15, 4.17144380e-14],\n [0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n 0.00000000e+00, 0.00000000e+00, 0.00000000e+00]], dtype=float32)], [DeviceArray([0.00180455, 0.00180623, 0.00180595, 0.00259012], dtype=float32), DeviceArray([-0.00044176, -0.00043326, -0.00036839, -0.00124654], dtype=float32), DeviceArray([1.95022345e-08, 1.87587244e-08, 1.35600775e-08,\n 8.94127368e-08], dtype=float32)], [DeviceArray([[-0.02883793, -0.06974681, 0.24037988, -0.45696107],\n [-0.66388506, -0.35703114, -0.34853068, 0.36051708],\n [ 0.661668 , -0.03455266, 0.73796755, -0.4630023 ],\n [-0.11637773, -0.829776 , -0.0157351 , -0.2001823 ],\n [ 0.2375262 , -0.39262995, -0.61489004, -0.27574617],\n [ 0.9198022 , 0.37472868, 0.05519994, -0.7439775 ]], dtype=float32), DeviceArray([[ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,\n 0.0000000e+00],\n [-2.2636502e-06, -2.0661346e-06, -1.2783025e-06,\n -1.4003868e-05],\n [-5.6834824e-07, -5.8828175e-07, -5.8287685e-07,\n -1.2936071e-06],\n [-2.1617128e-11, -2.1906333e-11, -2.1303336e-11,\n -1.0811599e-07],\n [-8.0856603e-07, -6.9933321e-07, -3.0290084e-07,\n -7.9441652e-06],\n [-1.5500837e-07, -8.2882934e-08, 1.3715700e-07,\n -2.3562773e-06]], dtype=float32), DeviceArray([[0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n [5.1146001e-13, 4.2600835e-13, 1.6287263e-13, 1.0453084e-11],\n [3.2286870e-14, 3.4591704e-14, 3.3959343e-14, 9.9721216e-14],\n [5.6415981e-23, 5.7961612e-23, 5.4847972e-23, 1.3727235e-15],\n [6.5160449e-14, 4.8715518e-14, 9.0941515e-15, 3.3455365e-12],\n [2.3930796e-15, 6.8169646e-16, 1.8897782e-15, 2.9122923e-13]], dtype=float32)]), tree_def=PyTreeDef(CustomNode([()], [{'params': 
{'fc0': {'bias': *, 'kernel': *}, 'fc1': {'bias': *, 'kernel': *}, 'fc2': {'bias': *, 'kernel': *}, 'fc3': {'bias': *, 'kernel': *}, 'fc4': {'bias': *, 'kernel': *}, 'fc_last': {'bias': *, 'kernel': *}}}])), subtree_defs=(PyTreeDef((*, *, *)), PyTreeDef((*, *, *)), PyTreeDef((*, *, *)), PyTreeDef((*, *, *)), PyTreeDef((*, *, *)), PyTreeDef((*, *, *)), PyTreeDef((*, *, *)), PyTreeDef((*, *, *)), PyTreeDef((*, *, *)), PyTreeDef((*, *, *)), PyTreeDef((*, *, *)), PyTreeDef((*, *, *))))"},"metadata":{}}]},{"cell_type":"code","source":"# access param by layer name\nget_params(opt_state)['params']['fc0']","metadata":{},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"model.get_variable(name='fc_last')","metadata":{"execution":{"iopub.status.busy":"2021-07-04T06:40:47.867354Z","iopub.execute_input":"2021-07-04T06:40:47.868012Z","iopub.status.idle":"2021-07-04T06:40:47.892970Z","shell.execute_reply.started":"2021-07-04T06:40:47.867958Z","shell.execute_reply":"2021-07-04T06:40:47.891635Z"},"trusted":true},"execution_count":109,"outputs":[{"traceback":["\u001b[0;31m---------------------------------------------------------------------------\u001b[0m","\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)","\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_variable\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'fc_last'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m","\u001b[0;31mTypeError\u001b[0m: get_variable() missing 1 required positional argument: 'col'"],"ename":"TypeError","evalue":"get_variable() missing 1 required positional argument: 'col'","output_type":"error"}]},{"cell_type":"code","source":"","metadata":{},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"","metadata":{},"execution_count":null,"outputs":[]}]}