diff --git "a/notebooks/.ipynb_checkpoints/Normalizing_flows_TEST-Copy2-checkpoint.ipynb" "b/notebooks/.ipynb_checkpoints/Normalizing_flows_TEST-Copy2-checkpoint.ipynb" new file mode 100644--- /dev/null +++ "b/notebooks/.ipynb_checkpoints/Normalizing_flows_TEST-Copy2-checkpoint.ipynb" @@ -0,0 +1,2402 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "38f65ab0-9aa9-4de1-bffb-2a70a5ebe152", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "import torch.distributions as D\n", + "\n", + "from torch.utils.data import DataLoader, dataset, TensorDataset" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "25375e3d-9f96-48fb-84fe-093d14f386ea", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import sys\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "sys.path.append('../insight')\n", + "from archive import archive " + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "a22153c3-c83b-4f3e-b25a-3a8a034c17fc", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "parent_dir = '/data/astro/scratch/lcabayol/Euclid/NNphotozs/Euclid_EXT_MER_PHZ_DC2_v1.5'\n", + "\n", + "photoz_archive = archive(path = parent_dir, Qz_cut=0.5)\n", + "f, ferr, specz, specqz = photoz_archive.get_training_data()\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "1a316118-9704-4ea9-a35e-7959e0b83ac1", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "dset = TensorDataset(torch.Tensor(f),torch.Tensor(specz))\n", + "loader = DataLoader(dset, batch_size=32, shuffle=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "f6f0feae-ce85-4840-ba5d-639c48d0e496", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "class InvertedPlanarFlow(nn.Module):\n", + " \"\"\"Implementation of the invertible transformation used in planar flow:\n", + " f(z) = z + u * h(dot(w.T, z) + b)\n", + " See Section 4.1 in https://arxiv.org/pdf/1505.05770.pdf. 
\n", + " \"\"\"\n", + " def __init__(self, n_dims):\n", + " super(InvertedPlanarFlow, self).__init__()\n", + " self.n_dims=n_dims\n", + " \n", + " def _assign_params_flow(self, flow_params):\n", + " \n", + " # Extract the u, w, and b components from flow_params\n", + " u = flow_params[:, :self.n_dims]\n", + " w = flow_params[:, self.n_dims:(2* self.n_dims)]\n", + " b = flow_params[:, -1].unsqueeze(1)\n", + " \n", + " # Update the parameters with the new values\n", + " self.u = nn.Parameter(u)\n", + " self.w = nn.Parameter(w)\n", + " self.b = nn.Parameter(b)\n", + "\n", + " #def forward(self, y, flow_params):\n", + " # self._assign_params_flow(flow_params)\n", + " # activation = torch.tanh(torch.einsum('ij,ij->i', y, self.w) + self.b)\n", + " # flow = y + torch.mm(activation, self.u)\n", + " # det_jacobian = self._determinant_jacobian(y,self._jacobian(y))\n", + " \n", + " # return flow, det_jacobian\n", + " \n", + " def forward(self, z):\n", + " activation = F.linear(z, self.weight, self.bias)\n", + " scale = self.get_scale(self.scale, self.weight)\n", + " return z + self.u * nn.Tanh()(activation)\n", + "\n", + " \n", + " def inverse(self, z, flow_params):\n", + " self._assign_params_flow(flow_params)\n", + " activation = torch.tanh(torch.einsum('ij,ij->i', z, self.w) + self.b)\n", + " flow = z - torch.mm(activation, self.u)\n", + " det_jacobian = self._determinant_jacobian(z,self._jacobian_inverse(z))\n", + " flow =flow*det_jacobian\n", + "\n", + " return flow, det_jacobian\n", + " \n", + " def _jacobian(self, y):\n", + " h = torch.einsum('ij,ij->i', y, self.w)\n", + " activation = torch.tanh(h + self.b)\n", + " dactivation = 1/torch.cosh(h[:,None] + self.b)**2\n", + " jacobian = dactivation[:,:,None] * torch.einsum('ij,ij->i', self.u,self.w)[:,None,None] + torch.eye(self.n_dims)[None,:,:]\n", + " return jacobian\n", + " \n", + " def _jacobian_inverse(self, x):\n", + " h = torch.einsum('ij,ij->i', x, self.w)\n", + " activation = torch.tanh(h + self.b)\n", + " dactivation = 1/torch.cosh(h[:,None] + self.b)**2\n", + " jacobian = -dactivation[:,:,None] * torch.einsum('ij,ij->i', self.u,self.w)[:,None,None] + torch.eye(self.n_dims)[None,:,:]\n", + " return jacobian\n", + "\n", + " def _determinant_jacobian(self, z, jacobian):\n", + " determinant_jacobian = torch.det(jacobian)\n", + " return determinant_jacobian\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 125, + "id": "8c2eb937-f1c5-4f11-aafa-0d45a4d2fdbf", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "30" + ] + }, + "execution_count": 125, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "10*(2+1)" + ] + }, + { + "cell_type": "code", + "execution_count": 126, + "id": "8ad6f89b-517c-47bb-a955-c94d80a5e4c9", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "class ConditionalNormalizingFlow(nn.Module):\n", + " def __init__(self, n_dims, n_context, n_flows):\n", + " super(ConditionalNormalizingFlow, self).__init__()\n", + " self.n_dims = n_dims\n", + " self.n_flows = n_flows\n", + " self.n_context = n_context\n", + " self.base_distribution=D.Normal(torch.zeros(self.n_dims), torch.ones(self.n_dims))\n", + "\n", + " # Create the flow layers\n", + " self.flows = nn.ModuleList([InvertedPlanarFlow(self.n_dims) for _ in range(n_flows)])\n", + "\n", + " # Context encoder network (maps context to flow parameters)\n", + " self.encoder = nn.Sequential(\n", + " nn.Linear(self.n_context, 32),\n", + " nn.ReLU(),\n", + " nn.Linear(32, 64),\n", + " 
nn.ReLU(),\n", + " nn.Linear(64, n_flows * (2 * n_dims + 1)) # u, w, b for each flow\n", + " )\n", + "\n", + " def forward(self, y, context_data):\n", + " # Compute flow parameters from the context\n", + " flow_params = self.encoder(context_data).view(-1, self.n_flows, 2 * self.n_dims + 1)\n", + "\n", + " # Apply each forward flow transformation\n", + " detjac_sum = 0\n", + " for i in range(self.n_flows):\n", + " if i==0:\n", + " z = y\n", + " z, det_jacobian = self.flows[i](z,flow_params[:, i, :])\n", + " \n", + " detjac_sum += torch.log(det_jacobian)\n", + " return z, detjac_sum\n", + " \n", + " def predict(self,context_data, zs):\n", + " flow_params = self.encoder(context_data).view(-1, self.n_flows, 2 * self.n_dims + 1)\n", + " pred_z = []\n", + " \n", + " for z in zs:\n", + " for i in range(self.n_flows):\n", + " z, det_invjac = self.flows[i].inverse(z,flow_params[:, i, :])\n", + " pred_z.append(z.detach().cpu().numpy()[0])\n", + " return np.array(pred_z) \n", + " \n" + ] + }, + { + "cell_type": "code", + "execution_count": 127, + "id": "59b89bbe-aab9-424d-be61-b49ecabbb589", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Example usage\n", + "n_dims = 1\n", + "n_context=6\n", + "n_flows = 10\n", + "batch_size = 32\n", + "nepochs=100\n", + "# Define the dimensionality of z\n", + "z_dim = 1\n", + "epsilon = 0.1\n", + "\n", + "base_distribution = D.Normal(torch.zeros(z_dim), torch.ones(z_dim))\n", + "base_distribution_perturbation = D.Normal(torch.zeros(z_dim), epsilon**2*torch.ones(z_dim))\n", + "\n", + "#data = torch.randn(batch_size, n_context)\n", + "#y = torch.randn(batch_size,1)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 129, + "id": "6567aace-64e6-4ad6-ac75-44b4c8ed63a4", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch 0\n", + "epoch 1\n", + "epoch 2\n", + "epoch 3\n", + "epoch 4\n", + "epoch 5\n", + "epoch 6\n", + "epoch 7\n", + "epoch 8\n", + "epoch 9\n", + "epoch 10\n", + "epoch 11\n", + "epoch 12\n", + "epoch 13\n", + "epoch 14\n", + "epoch 15\n", + "epoch 16\n", + "epoch 17\n", + "epoch 18\n", + "epoch 19\n", + "epoch 20\n", + "epoch 21\n", + "epoch 22\n", + "epoch 23\n", + "epoch 24\n", + "epoch 25\n", + "epoch 26\n", + "epoch 27\n", + "epoch 28\n", + "epoch 29\n", + "epoch 30\n", + "epoch 31\n", + "epoch 32\n", + "epoch 33\n", + "epoch 34\n", + "epoch 35\n", + "epoch 36\n", + "epoch 37\n", + "epoch 38\n", + "epoch 39\n", + "epoch 40\n", + "epoch 41\n", + "epoch 42\n", + "epoch 43\n", + "epoch 44\n", + "epoch 45\n", + "epoch 46\n", + "epoch 47\n", + "epoch 48\n", + "epoch 49\n", + "epoch 50\n", + "epoch 51\n", + "epoch 52\n", + "epoch 53\n", + "epoch 54\n", + "epoch 55\n", + "epoch 56\n", + "epoch 57\n", + "epoch 58\n", + "epoch 59\n", + "epoch 60\n", + "epoch 61\n", + "epoch 62\n", + "epoch 63\n", + "epoch 64\n", + "epoch 65\n", + "epoch 66\n", + "epoch 67\n", + "epoch 68\n", + "epoch 69\n", + "epoch 70\n", + "epoch 71\n", + "epoch 72\n" + ] + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[129], line 12\u001b[0m\n\u001b[1;32m 9\u001b[0m logloss \u001b[38;5;241m=\u001b[39m base_distribution\u001b[38;5;241m.\u001b[39mlog_prob(z) \u001b[38;5;241m+\u001b[39m 
\u001b[38;5;241m100\u001b[39m\u001b[38;5;241m*\u001b[39mlogdet_jacb \u001b[38;5;66;03m#+ base_distribution_perturbation.log_prob(v).mean(1)\u001b[39;00m\n\u001b[1;32m 10\u001b[0m logloss \u001b[38;5;241m=\u001b[39m logloss\u001b[38;5;241m.\u001b[39msum()\n\u001b[0;32m---> 12\u001b[0m \u001b[43mlogloss\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 13\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mstep() \n", + "File \u001b[0;32m/data/astro/scratch/lcabayol/anaconda3/envs/NFlow/lib/python3.10/site-packages/torch/_tensor.py:487\u001b[0m, in \u001b[0;36mTensor.backward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 477\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m has_torch_function_unary(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 478\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(\n\u001b[1;32m 479\u001b[0m Tensor\u001b[38;5;241m.\u001b[39mbackward,\n\u001b[1;32m 480\u001b[0m (\u001b[38;5;28mself\u001b[39m,),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 485\u001b[0m inputs\u001b[38;5;241m=\u001b[39minputs,\n\u001b[1;32m 486\u001b[0m )\n\u001b[0;32m--> 487\u001b[0m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautograd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 488\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgradient\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs\u001b[49m\n\u001b[1;32m 489\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/data/astro/scratch/lcabayol/anaconda3/envs/NFlow/lib/python3.10/site-packages/torch/autograd/__init__.py:200\u001b[0m, in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 195\u001b[0m retain_graph \u001b[38;5;241m=\u001b[39m create_graph\n\u001b[1;32m 197\u001b[0m \u001b[38;5;66;03m# The reason we repeat same the comment below is that\u001b[39;00m\n\u001b[1;32m 198\u001b[0m \u001b[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[1;32m 199\u001b[0m \u001b[38;5;66;03m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[0;32m--> 200\u001b[0m \u001b[43mVariable\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_execution_engine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_backward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[1;32m 201\u001b[0m \u001b[43m \u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgrad_tensors_\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 202\u001b[0m \u001b[43m \u001b[49m\u001b[43mallow_unreachable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[43maccumulate_grad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "flow_model = ConditionalNormalizingFlow(n_dims, n_context, n_flows)\n", + "optimizer = optim.Adam(flow_model.parameters(), lr=1e-3, weight_decay=1e-4)\n", + "\n", + "for e in range(nepochs):\n", + " print('epoch',e)\n", + " for x, y in loader:\n", + " #v = epsilon**2*torch.randn(size=x.size())\n", + " z, logdet_jacb = flow_model(y.unsqueeze(1),x)\n", + " logloss = base_distribution.log_prob(z) + 100*logdet_jacb #+ base_distribution_perturbation.log_prob(v).mean(1)\n", + " logloss = logloss.sum()\n", + "\n", + " logloss.backward()\n", + " optimizer.step() \n", + " \n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 130, + "id": "1572b201-1427-4aba-93ef-11dcf7b6130e", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "Ntest=10\n", + "z = base_distribution.sample(sample_shape=torch.Size([1000]))\n", + "ppz = np.zeros(shape=(Ntest,1000))\n", + "for ii in range(Ntest):\n", + " ypred = flow_model.predict(x[ii].unsqueeze(0),z.unsqueeze(1))\n", + " ppz[ii] = ypred.squeeze(1)" + ] + }, + { + "cell_type": "code", + "execution_count": 131, + "id": "89a36a6d-70a0-403b-9877-212fafecd110", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([0.1227, 0.4276, 0.3734, 0.7514, 0.3263, 0.4419, 0.0294, 0.4950, 1.1767,\n", + " 2.5304])" + ] + }, + "execution_count": 131, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "y[:Ntest]" + ] + }, + { + "cell_type": "code", + "execution_count": 132, + "id": "0759273d-e80a-413f-bd3c-7f6955419b6e", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "array([0.00618994, 0.00119137, 0.00665953, 0.00178801, 0.00802471,\n", + " 0.00547702, 0.00611226, 0.01258649, 0.00101891, 0.0086738 ])" + ] + }, + "execution_count": 132, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ppz.mean(1)" + ] + }, + { + "cell_type": "code", + "execution_count": 124, + "id": "99ad0512-a947-4f84-a473-3e2401fb9281", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(array([ 1., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0.,\n", + " 0., 0., 0., 0., 0., 1., 1., 1., 2., 0., 4., 2., 3.,\n", + " 2., 1., 3., 0., 0., 3., 2., 8., 5., 4., 4., 10., 4.,\n", + " 8., 7., 10., 10., 12., 19., 16., 14., 17., 24., 27., 36., 35.,\n", + " 26., 54., 32., 40., 44., 49., 45., 58., 54., 43., 31., 34., 22.,\n", + " 29., 19., 16., 12., 19., 15., 12., 11., 9., 2., 5., 6., 2.,\n", + " 0., 4., 1., 3., 0., 0., 3., 0., 0., 0., 0., 0., 0.,\n", + " 0., 0., 0., 0., 0., 0., 1., 0., 1.]),\n", + " array([-2.11697197, -2.07361054, -2.03024936, -1.98688793, -1.94352663,\n", + " -1.90016532, -1.85680389, -1.81344259, -1.77008128, -1.72671998,\n", + " -1.68335855, -1.63999724, -1.59663594, -1.55327463, -1.50991321,\n", + " -1.4665519 , -1.42319059, -1.37982929, -1.33646786, -1.29310656,\n", + " -1.24974525, -1.20638394, -1.16302252, -1.11966121, -1.07629991,\n", + " -1.03293848, -0.98957717, -0.94621587, -0.9028545 , -0.8594932 ,\n", + " -0.81613183, -0.77277052, -0.72940916, -0.68604785, -0.64268649,\n", + " -0.59932518, -0.55596381, -0.51260251, -0.46924114, -0.42587981,\n", + " -0.38251847, -0.33915713, -0.2957958 , -0.25243446, -0.20907313,\n", + " -0.16571179, -0.12235046, -0.07898912, -0.03562779, 0.00773355,\n", + " 
0.05109489, 0.09445623, 0.13781756, 0.1811789 , 0.22454023,\n", + " 0.26790157, 0.31126291, 0.35462424, 0.39798558, 0.44134691,\n", + " 0.48470825, 0.52806962, 0.57143092, 0.61479229, 0.65815359,\n", + " 0.70151496, 0.74487627, 0.78823763, 0.83159894, 0.8749603 ,\n", + " 0.91832161, 0.96168298, 1.00504434, 1.04840565, 1.09176695,\n", + " 1.13512826, 1.17848969, 1.22185099, 1.2652123 , 1.30857372,\n", + " 1.35193503, 1.39529634, 1.43865764, 1.48201907, 1.52538037,\n", + " 1.56874168, 1.61210299, 1.65546441, 1.69882572, 1.74218702,\n", + " 1.78554833, 1.82890975, 1.87227106, 1.91563237, 1.95899367,\n", + " 2.0023551 , 2.04571629, 2.08907771, 2.13243914, 2.17580032,\n", + " 2.21916175]),\n", + " )" + ] + }, + "execution_count": 124, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAhsAAAGcCAYAAABwemJAAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAaR0lEQVR4nO3df2zc913H8ZfjuUahtat2axtnWR0LJdsqTYC60Kw1ii0kEAUTMVRCWfwPqI1Uddqi/WgkYCoEUvYHaxQGKEWo6ioo+1HLQx3LJORhR9wfQa3kgNaJqHZw3TZaaXWuaOs5veMPFKtOnNV279PzOY+H9JXmy/nr93Tz/NTne5/7ttXr9XoAAArZ1OwBAICNTWwAAEWJDQCgKLEBABQlNgCAosQGAFCU2AAAinpfswdIklqtlhdeeCHXXHNN2tramj0OALAC9Xo9r732Wnp6erJp0+XXL9ZFbLzwwgvZtm1bs8cAANZgZmYmH/zgBy/772uOjR/84AcZGRnJtm3bsnfv3lxzzTVrPdXi987MzKSrq2vN5wEA3jtzc3PZtm3bOzbAmmLjq1/9ap544ok88cQT2bp1a5Jkamoqf/7nf56f//mfz8mTJ/Mnf/Inufnmm1d0vguXTrq6usQGALSYd3oLxKpj48knn8wf//Ef5z/+4z/ygQ98IMn/v+diaGgoR48ezeDgYLZv3559+/alUqmsbWoAYMNY1W6U8+fP5zOf+Uw+97nPLYZGkpw4cSJnzpxJf39/kmRwcDCTk5M5depUY6cFAFrOqmJjfHw8MzMz+eEPf5i9e/fmIx/5SJ544olUKpX09fWlo6MjSdLe3p6+vr6MjY0te575+fnMzc0tOQCAjWlVl1EmJydz7bXX5stf/nKuu+66fPe7383Q0FAGBgYuea9Fd3d3Zmdnlz3PkSNH8uCDD659agCgZaxqZeONN97IRz7ykVx33XVJkl/5lV/JjTfemJMnTy6ualxQq9VSq9WWPc+hQ4dSrVYXj5mZmTWODwCsd6uKjZtuuin/+7//u+SxD37wg/nCF75wyaWQarW6uFPlYp2dnYs7T+xAAYCNbVWxcfvtt2d6ejrnz59ffOzNN99M8v9bX+v1epJkYWEh09PTGRgYaOCoAEArWlVs7NixIz/7sz+b733ve0mSV155JS+//HI+//nPp6enJxMTE0n+/42kvb292bVrV+MnBgBayqo/Z+NrX/taPve5z2VycjJTU1P5xje+kc2bN2d0dDSHDx/O6dOnU6lUMjIy4j4nAEDa6heufTTR3Nxcuru7U61WvX8DAFrESv9+u8U8AFCU2AAAihIbAEBRYgMAKEpsAABFrXrrK8BG0/vAU5c8Nv3QnU2YBDYmKxsAQFFiAwAoSmwAAEWJDQCgKLEBABQlNgCAosQGAFCU2AAAihIbAEBRYgMAKEpsAABFiQ0AoCixAQAUJTYAgKLEBgBQlNgAAIoSGwBAUWIDAChKbAAARYkNAKAosQEAFCU2AICixAYAUNT7mj0AwHJ6H3jqksemH7qzaT//vfzZsNFY2QAAihIbAEBRYgMAKEpsAABFiQ0AoCixAQAUJTYAgKLEBgBQlNgAAIoSGwBAUWIDAChKbAAARYkNAKAosQEAFCU2AICixAYAUJTYAACKEhsAQFHvKjZeeOGFRs0BAGxQq4qNer2eHTt2pK2tLW1tbfnUpz6VJJmamsqBAwdy/PjxDA8P5+zZs0WGBQBaz/tW8+R//ud/zqc//encdtttSZLt27enVqtlaGgoR48ezeDgYLZv3559+/alUqkUGRgAaC2rWtn4y7/8y7S3t+eGG27Irbfemuuvvz4nTpzImTNn0t/fnyQZHBzM5ORkTp06VWRgAKC1rDg2XnvttczPz+cP/uAPsn379tx///2p1+upVCrp6+tLR0dHkqS9vT19fX0ZGxu77Lnm5+czNze35AAANqYVX0a55ppr8i//8i9ZWFjI3/zN3+Szn/1sfuZnfibnzp1LV1fXkud2d3dndnb2suc6cuRIHnzwwbVPDQC0jFXvRuno6Mj999+fBx54IH//93+fjo6OxVWNC2q1Wmq12mXPcejQoVSr1cVjZmZm9ZMDAC1hzVtff+M3fiPVajVbtmy55DJItVrN1q1bL/u9nZ2d6erqWnIAABvTmmPjrbfeys6dO7Nnz55MTU2lXq8nSRYWFjI9PZ2BgYGGDQkAtK4Vx8bExEQef/zxxag4fvx4Pv/5z2f37t3p6enJxMREkmR8fDy9vb3ZtWtXmYkBgJay4jeIzszM5DOf+Uz+4R/+IbfddluGh4dzxx13JElGR0dz+PDhnD59OpVKJSMjI2lrays2NADQOlYcG3fffXfuvvvuZf9tx44deeyxx5Ik9913X2MmAwA2hFV9gihAq+l94KlLHpt+6M4mTAJXLnd9BQCKEhsAQFFiAwAoSmwAAEWJDQCgKLEBABQlNgCAosQGAFCU2AAAihIbAEBRYgMAKEpsAABFiQ0AoCixAQAUJTYAgKLEBgBQlNgAAIoSGwBAUWIDAChKbAAARYkNAKAosQEAFCU2AICi3tfsAQDea70PPNXsEeCKYmUDAChKbAAARYkNAKAosQEAFCU2AICi7EYBNhQ7TWD9sbIBABQlNgCAosQGAFCU2AAAihIbAEBRdqMArMByu1ymH7qzCZNA67GyAQAUJTYAgKLEBgBQlNgAAIoSGwBAUWIDAChKbA
AARYkNAKAosQEAFCU2AICixAYAUJTYAACKEhsAQFFrio3XX389H/3oRzM9PZ0kmZqayoEDB3L8+PEMDw/n7NmzjZwRAGhha7rF/LFjx/KDH/wgSVKr1TI0NJSjR49mcHAw27dvz759+1KpVBo6KADQmla9sjE6OpqBgYHFr0+cOJEzZ86kv78/STI4OJjJycmcOnWqcVMCAC1rVbHx3//933nxxReza9euxccqlUr6+vrS0dGRJGlvb09fX1/GxsYue575+fnMzc0tOQCAjWnFl1HeeuutPPLII3nwwQeXPH7u3Ll0dXUteay7uzuzs7OXPdeRI0cuOQ/AavU+8FSzRwBWYMUrG1/96ldz7733ZtOmpd/S0dGxuKpxQa1WS61Wu+y5Dh06lGq1unjMzMyscmwAoFWseGXj2LFj+cIXvrDksZ07d6ZWq+WWW25Z8ni1Ws3WrVsve67Ozs50dnauclQAoBWtODb+67/+a8nXbW1t+eEPf5jZ2dn86q/+aur1etra2rKwsJDp6eklbyIFAK5c7/pDvXbv3p2enp5MTEwkScbHx9Pb27vkTaQAwJVrTZ+z8XabNm3K6OhoDh8+nNOnT6dSqWRkZCRtbW2NmA8AaHFrjo16vb74n3fs2JHHHnssSXLfffe9+6kAgA3DvVEAgKLEBgBQlNgAAIoSGwBAUWIDAChKbAAARYkNAKAosQEAFCU2AICixAYAUJTYAACKEhsAQFFiAwAoSmwAAEWJDQCgKLEBABQlNgCAot7X7AGAja/3gaeWfD390J1NmgRoBisbAEBRYgMAKEpsAABFiQ0AoCixAQAUZTcK8J67eHdKYocKbGRWNgCAosQGAFCU2AAAihIbAEBRYgMAKEpsAABF2foK0CC29MLyrGwAAEWJDQCgKLEBABQlNgCAosQGAFCU2AAAihIbAEBRYgMAKEpsAABFiQ0AoCixAQAU5d4oQMtY7t4jwPpnZQMAKEpsAABFiQ0AoCixAQAUJTYAgKLsRgFYI7tjYGWsbAAARa06Np555pnccccdue666/JLv/RLefnll5MkU1NTOXDgQI4fP57h4eGcPXu24cMCAK1nVbHx5ptv5lvf+la+973vZWZmJq+//nr+4i/+IrVaLUNDQ7nrrrtyzz33ZP/+/dm3b1+pmQGAFrKq2KhWq/mjP/qjbN68OT/90z+d/v7+bNq0KSdOnMiZM2fS39+fJBkcHMzk5GROnTpVZGgAoHWsKjZuvPHGXHXVVUmSH//4x3nppZfy2c9+NpVKJX19feno6EiStLe3p6+vL2NjY8ueZ35+PnNzc0sOAGBjWtMbRL/zne/ktttuy9jYWP7zP/8z586dS1dX15LndHd3Z3Z2dtnvP3LkSLq7uxePbdu2rWUMAKAFrCk2fvmXfznf/OY384lPfCKf+tSn0tHRsbiqcUGtVkutVlv2+w8dOpRqtbp4zMzMrGUMAKAFrOlzNi5cJvm7v/u7XH/99fnABz5wyaWQarWarVu3Lvv9nZ2d6ezsXMuPBgBazLv6nI3Nmzfn/e9/f/bs2ZOpqanU6/UkycLCQqanpzMwMNCQIQGA1rWq2Pif//mf/NM//dNiVPzrv/5r9u/fn/7+/vT09GRiYiJJMj4+nt7e3uzatavxEwMALWVVl1Gmpqby+7//+9m5c2d+67d+K1dffXX+9E//NG1tbRkdHc3hw4dz+vTpVCqVjIyMpK2trdTcAECLWFVs3HrrrTl37tyy/7Zjx4489thjSZL77rvv3U8GsAEtdz+V6YfubMIk8N5xbxQAoCixAQAUJTYAgKLEBgBQlNgAAIoSGwBAUWIDAChKbAAARYkNAKAosQEAFCU2AICiVnVvFIB3sty9P4Arm5UNAKAosQEAFCU2AICixAYAUJTYAACKshsFWBfsYoGNy8oGAFCU2AAAihIbAEBRYgMAKEpsAABFiQ0AoCixAQAUJTYAgKLEBgBQlNgAAIoSGwBAUe6NArwr7mkCvBMrGwBAUWIDAChKbAAARYkNAKAosQEAFCU2AICixAYAUJTYAACKEhsAQFFiAwAoSmwAAEW5NwpAQe4dA1Y2AIDCxAYAUJTYAACKEhsAQFFiAwAoSmwAAEWJDQCgqFXFxre//e3s3LkzXV1d+eQnP5lXXnklSTI1NZUDBw7k+PHjGR4eztmzZ4sMCwC0nhXHxnPPPZennnoqTz75ZB599NF8//vfzxe/+MXUarUMDQ3lrrvuyj333JP9+/dn3759JWcGAFrIij9B9OTJkzl27Fiuuuqq3HLLLZmcnMw3vvGNnDhxImfOnEl/f3+SZHBwMHv37s2pU6fy8Y9/vNjgAEBrWPHKxvDwcK666qrFr2+88cZ86EMfSqVSSV9fXzo6OpIk7e3t6evry9jYWOOnBQBazprfIPr000/n3nvvzblz59LV1bXk37q7uzM7O3vZ752fn8/c3NySAwDYmNYUGy+++GLOnz+fvXv3pqOjY3FV44JarZZarXbZ7z9y5Ei6u7sXj23btq1lDACgBaw6Nt566608/PDDOXbsWJJky5Ytl6xMVKvVbN269bLnOHToUKrV6uIxMzOz2jEAgBax6tj4yle+koMHD+bqq69Oktxxxx2ZmppKvV5PkiwsLGR6ejoDAwOXPUdnZ2e6urqWHADAxrTi3ShJ8vDDD2fHjh159dVX8+qrr+a5557L+fPn09PTk4mJifziL/5ixsfH09vbm127dpWaGQBoISuOja9//es5ePDg4gpGkmzevDkvvfRSRkdHc/jw4Zw+fTqVSiUjIyNpa2srMjAA0Fra6m+vhyaZm5tLd3d3qtWqSyrQYnofeKrZI7S86YfubPYIsCYr/fu9qssowMZxcST4g9c8a3ktlos8ryHrlRuxAQBFiQ0AoCixAQAUJTYAgKLEBgBQlNgAAIoSGwBAUWIDAChKbAAARYkNAKAosQEAFOXeKMCKuekasBZWNgCAosQGAFCU2AAAihIbAEBRYgMAKMpuFCDJ8jtNph+6swmT4LVgo7GyAQAUJTYAgKLEBgBQlNgAAIoSGwBAUXajAGwQF+9isYOF9cLKBgBQlNgAAIoSGwBAUWIDAChKbAAARdmNAlzWcvfoAFgtKxsAQFFiAwAoSmwAAEWJDQCgKLEBABQlNgCAosQGAFCU2AAAihIbAEBRYgMAKEpsAABFtdXr9Xqzh5ibm0t3d3eq1Wq6urqaPQ5sOO5xwgXTD93Z7BHYQFb699vKBgBQlNgAAIoSGwBAUWIDAChKbAAARb2v2QMA8N65eGeS3Sm8F9a8svHGG2+kWq02chYAYANadWzUarU8+uij2bFjR5555pnFx6empnLgwIEcP348w8PDOXv2bEMHBQBa06ovo7z88ssZGBjI888/v/hYrVbL0NBQjh49msHBwWzfvj379u1LpVJp6LAAQOtZ9crGDTfckJtvvnnJYydOnMiZM2fS39+fJBkcHMzk5GROnTrVmCkBgJbVkN0olUolfX196ejoSJK0t7enr68vY2Njyz5/fn4+c3NzSw4AYGNqyG6Uc+fOXfKZ6
N3d3ZmdnV32+UeOHMmDDz7YiB8NLWu5+5WsZWeA+54A611DVjY6OjoWVzUuqNVqqdVqyz7/0KFDqVari8fMzEwjxgAA1qGGrGxs2bIlJ0+eXPJYtVrN1q1bl31+Z2dnOjs7G/GjAYB1riErG3v27MnU1FQu3K1+YWEh09PTGRgYaMTpAYAWtqbYuPjyyO7du9PT05OJiYkkyfj4eHp7e7Nr1653PyEA0NJWfRnlRz/6UR555JEkyeOPP56bbropH/7whzM6OprDhw/n9OnTqVQqGRkZSVtbW8MHBgBaS1v9wrWPJpqbm0t3d3eq1eolu1pgo7IbhfXK/VJYqZX+/XbXVwCgKLEBABQlNgCAosQGAFCU2AAAimrIJ4jClaRRu0jW288CKMXKBgBQlNgAAIoSGwBAUWIDAChKbAAARdmNAi3GvVCAVmNlAwAoSmwAAEWJDQCgKLEBABQlNgCAouxGAWDVLt4V5Z49/CRWNgCAosQGAFCU2AAAihIbAEBRYgMAKEpsAABF2foKDWAbIBuZm//xblnZAACKEhsAQFFiAwAoSmwAAEWJDQCgKLtRaKqNuotjre/e965/NrLl/ve9UX7n+cmsbAAARYkNAKAosQEAFCU2AICixAYAUJTdKKx7a30H+1p2uni3PDSf392Nx8oGAFCU2AAAihIbAEBRYgMAKEpsAABF2Y1CMY16R3mjzr1W7lfClWYt/5u3G4SfxMoGAFCU2AAAihIbAEBRYgMAKEpsAABFbfjdKCXfIX2lfH5/K+zGaIUZgUv53V29lfzteS937K2ElQ0AoKiGxcbU1FQOHDiQ48ePZ3h4OGfPnm3UqQGAFtaQyyi1Wi1DQ0M5evRoBgcHs3379uzbty+VSqURpwcAWlhDVjZOnDiRM2fOpL+/P0kyODiYycnJnDp1qhGnBwBaWENWNiqVSvr6+tLR0ZEkaW9vT19fX8bGxvLxj3/8kufPz89nfn5+8etqtZokmZuba8Q4S9TmX7/ksUb9nIvPvZLzlpynlOVmXovl/ns26txr/fkXey/ngY3u4t+5tf5+bdT/b12rlfztWcvfp7W4cN56vf6Tn1hvgHvuuad+2223LXns9ttvr3/6059e9vlf+tKX6kkcDofD4XBsgGNmZuYndkJDVjY6OjoWVzUuqNVqqdVqyz7/0KFDOXjw4JLnvvLKK7n++uvT1tbWiJHWZG5uLtu2bcvMzEy6urqaNgfL8/qsb16f9c3rs7616utTr9fz2muvpaen5yc+ryGxsWXLlpw8eXLJY9VqNVu3bl32+Z2dnens7Fzy2LXXXtuIURqiq6urpV7sK43XZ33z+qxvXp/1rRVfn+7u7nd8TkPeILpnz55MTU0tXrNZWFjI9PR0BgYGGnF6AKCFNSQ2du/enZ6enkxMTCRJxsfH09vbm127djXi9ABAC2vIZZRNmzZldHQ0hw8fzunTp1OpVDIyMtLU91+sRWdnZ770pS9dcomH9cHrs755fdY3r8/6ttFfn7b6O+5XAQBYO/dGAQCKEhsAQFFiAwAoSmzQ8l566aVmjwAt4Y033li8PQS8l8TGZczNzeV3f/d3c+2116avry//+I//2OyRuMjU1FTuvvvu3H333c0e5Yo3NTWVAwcO5Pjx4xkeHs7Zs2ebPRJvU6vV8uijj2bHjh155plnmj0OF/n2t7+dnTt3pqurK5/85CfzyiuvNHukhhMbl/Fnf/Zn+Z3f+Z2Mj4/nE5/4RPbv35+pqalmj8Xb1Ov1XHfddZf9WHzeG7VaLUNDQ7nrrrtyzz33ZP/+/dm3b1+zx+JtXn755QwMDOT5559v9ihc5LnnnstTTz2VJ598Mo8++mi+//3v54tf/GKzx2o4sbGMhYWFfPSjH82v/dqv5WMf+1j+9m//Nps2bcq///u/N3s03qavry/vf//7mz3GFe/EiRM5c+ZM+vv7kySDg4OZnJzMqVOnmjwZF9xwww25+eabmz0Gyzh58mSOHTuWW265Jb/5m7+Z+++/P//2b//W7LEaTmwso6OjI8PDw4tf/9RP/VS6u7vzoQ99qIlTwfpUqVTS19e3eDPG9vb29PX1ZWxsrMmTwfo3PDycq666avHrG2+8cUP+rREbK/D8889n69at+YVf+IVmjwLrzrlz5y65cVR3d3dmZ2ebNBG0rqeffjr33ntvs8doOLGxAn/913+d48ePN3sMWJc6OjoWVzUuqNVq3ksDq/Tiiy/m/Pnz2bt3b7NHabiG3Bul1bz44ov5uZ/7ucv++2//9m/n6NGjSZKxsbF87GMfy6233vpejXfFW83rQ/Nt2bIlJ0+eXPJYtVrN1q1bmzQRtJ633norDz/8cI4dO9bsUYq4ImNjy5YtK/pshmeffTbPPfdcfu/3fi9Jcv78+bS3t7fcDeZazUpfH9aHPXv25Mtf/nLq9Xra2tqysLCQ6enpDAwMNHs0aBlf+cpXcvDgwVx99dVJkh//+MdL3svR6lxGuYyXXnopf/VXf5Xbb789zz77bCYnJ3PkyJFmj8VFLNU33+7du9PT05OJiYkkyfj4eHp7e7Nr164mT8bb+V1Zvx5++OHs2LEjr776ap599tl85zvfyXe/+91mj9VQV+TKxjt58803c+edd+bpp59esqT1h3/4h1Y11pHx8fGMjo5mdnY2Tz75ZH7913/9kvcOUN6mTZsyOjqaw4cP5/Tp06lUKhkZGfG7so786Ec/yiOPPJIkefzxx3PTTTflwx/+cJOnIkm+/vWv5+DBg3n7Ddg3b9684VZ33WIeACjKZRQAoCixAQAUJTYAgKLEBgBQlNgAAIoSGwBAUWIDAChKbAAARYkNAKAosQEAFCU2AICi/g9g1ZIO9L0YCQAAAABJRU5ErkJggg==", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plt.hist(ypred, bins = 100)" + ] + }, + { + "cell_type": "code", + "execution_count": 113, + "id": "a1553335-9913-49da-98d8-a97a4011f52b", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "0.25293207" + ] + }, + "execution_count": 113, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ypred.mean()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f14fa5e5-408a-4f27-9be9-a518ba50d9c6", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "df362e62-98dd-4c6e-b6a0-aa2d3a9e7b1c", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "767868c1-c425-4f69-ac21-46cf0f0d1abf", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4e3081da-ea32-4c2e-8615-51ab626a8d64", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9b3bd1ad-9c17-4846-8cbd-02b05a54a1ef", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "de267aa7-5079-4c00-8a83-9399ca59070f", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "3d462c6f-db0c-4439-ac84-c41ff5704fb0", + "metadata": { + "tags": [] + }, + "source": [ + "## second part" + ] + }, + { + "cell_type": "code", + "execution_count": 206, + "id": "0e937f99-22c8-4d7b-8db2-a2bba14b6ecf", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.distributions as D\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 331, + "id": "e6fbfa63-3f7d-4619-a910-d39097ba69af", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "base_distribution = D.normal.Normal(loc=0, scale=1)" + ] + }, + { + "cell_type": "code", + "execution_count": 333, + "id": "75c1b13e-bef8-4ad1-9af5-acd09be079c2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Log probability of the sample in the base distribution:\n", + "tensor([-90.7244])\n" + ] + } + ], + "source": [ + "# Sample from the base Gaussian distribution\n", + "z_sample = torch.Tensor([-13.4019])\n", + "\n", + "# Compute log probability for a given sample\n", + "log_prob_z = base_distribution.log_prob(z_sample)\n", + "print(\"Log probability of the sample in the base distribution:\")\n", + "print(log_prob_z)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 336, + "id": "b72d9238-ff08-4071-93e0-be175676dadb", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([ 1.1044e+00, 5.3608e-01, 2.7223e-01, -4.0974e-01, -1.2630e+00,\n", + " 2.7436e-01, 4.5724e-01, 7.1562e-01, 2.5585e-01, -1.7401e+00,\n", + " 3.6557e-01, 1.9815e+00, 8.3127e-02, -2.8310e-02, 1.5944e-01,\n", + " -1.5677e+00, 8.7571e-01, -1.4026e+00, 9.6980e-01, -5.7682e-01,\n", + " 2.6239e-01, 5.7484e-01, -5.7243e-01, -2.1229e-01, 6.5120e-01,\n", + " -4.3072e-01, 4.2276e-01, 7.4024e-01, -5.8200e-01, -1.1932e+00,\n", + " -7.1932e-01, 2.9953e-01, 1.0649e+00, -1.2855e-01, 3.5041e-01,\n", + " 4.8927e-01, 6.0499e-01, 1.1458e+00, -6.1707e-01, 2.1834e+00,\n", + " 9.1538e-01, 7.9090e-01, 6.4550e-01, 9.1880e-01, 3.5349e-01,\n", + " 8.1434e-01, 1.1906e-01, -5.1368e-01, -8.0531e-01, 3.7218e-01,\n", + " -2.8462e-01, 5.1760e-01, 
5.2333e-01, -1.1095e+00, -1.7090e-01,\n", + " -5.8244e-02, -3.0572e-01, 2.3676e-01, 4.0046e-01, 9.4080e-01,\n", + " 1.8561e+00, -5.2990e-01, 5.1445e-02, -6.2367e-02, 5.9578e-01,\n", + " 1.5714e-01, 6.5467e-01, 5.2255e-01, -5.2550e-01, -4.4779e-01,\n", + " 9.3960e-01, 8.7942e-01, 5.3973e-02, 5.6147e-01, -1.0643e+00,\n", + " -5.3023e-01, -1.5540e+00, -1.2863e-01, 9.1037e-01, -4.5574e-01,\n", + " -2.3140e-02, 1.2005e-01, 1.5731e+00, 1.6403e-01, 8.3374e-01,\n", + " -3.0596e-01, 6.6844e-01, -1.3816e+00, -4.3452e-01, -2.6365e-01,\n", + " -8.3384e-01, 2.4906e-01, -1.1455e+00, -1.9485e-01, -2.0116e+00,\n", + " 7.8016e-02, -9.5747e-01, 1.4937e-01, -2.5925e-01, -1.5145e+00,\n", + " 3.1507e-01, -6.2256e-01, 1.5290e+00, 6.3536e-01, 4.4222e-01,\n", + " -1.1336e+00, 8.3353e-01, -1.3868e-02, -1.4489e+00, -4.7428e-01,\n", + " 2.0545e+00, -4.0540e-01, -9.5925e-01, 1.4628e-01, 9.3416e-01,\n", + " 1.2972e-01, 2.8183e+00, -9.3750e-01, -1.3405e+00, 9.7499e-01,\n", + " 1.4499e+00, 2.3749e-01, -3.5307e-01, 1.7120e-01, 1.4509e+00,\n", + " -5.8411e-01, -1.6602e+00, -2.8575e-01, -8.8415e-01, 4.7089e-01,\n", + " -5.0467e-01, 2.1970e-01, -8.0355e-01, -6.7352e-01, -1.8423e+00,\n", + " 1.3197e+00, 7.5914e-01, -1.4738e+00, -2.4028e-01, 2.0500e+00,\n", + " 2.5433e-02, 1.3753e+00, 3.0881e-01, -7.2112e-01, -6.8736e-01,\n", + " -4.6874e-02, -1.4603e-01, 1.2329e+00, -4.2673e-01, -1.3867e+00,\n", + " 1.6686e+00, 4.1296e-01, -6.9377e-01, -1.2500e+00, -4.4740e-01,\n", + " -1.5826e+00, -4.4897e-01, 1.3707e+00, 8.3077e-01, 6.6048e-02,\n", + " 5.5191e-01, -2.4148e+00, 2.8505e-01, -2.8783e-01, 2.9958e-01,\n", + " 6.1308e-01, 1.2282e+00, 4.6869e-01, -1.7393e+00, -3.4317e-01,\n", + " -7.1529e-01, 1.1789e+00, -8.7934e-01, 5.3453e-01, 3.5537e-01,\n", + " -1.2870e+00, 1.2393e+00, 1.3256e+00, 4.3599e-01, 7.0790e-01,\n", + " -2.9851e-01, -2.5265e+00, 1.3229e+00, -3.3910e-01, -1.8691e-01,\n", + " -6.6734e-02, 2.7792e-01, -8.5344e-01, 6.6565e-01, -8.7678e-01,\n", + " -3.7833e-01, -4.5579e-01, -1.7119e+00, -4.4352e-01, -1.4852e+00,\n", + " -7.9120e-01, -7.5951e-01, -9.6410e-01, 9.1380e-01, -5.1296e-01,\n", + " -6.8361e-01, -1.2658e+00, 2.4770e+00, 2.1466e-01, -3.8021e-01,\n", + " -1.2697e+00, -2.3289e-01, -1.1961e+00, -3.6432e-03, -1.5040e-01,\n", + " 2.7548e-01, -1.7180e+00, 8.5279e-01, 1.6454e+00, 2.8954e-01,\n", + " 2.3629e+00, 1.7833e+00, -7.7760e-01, -5.9648e-03, 1.7025e-01,\n", + " -9.8024e-03, 2.1076e-01, -3.6466e-01, -4.6469e-01, 1.2334e+00,\n", + " -6.9017e-01, -9.6955e-01, -2.1257e+00, 8.3989e-01, -1.1027e-01,\n", + " -6.6110e-01, 4.2615e-01, -1.1948e-01, -1.8326e+00, -2.4196e+00,\n", + " -2.0621e-01, -4.6798e-01, -4.8596e-01, -1.6059e+00, -8.3020e-01,\n", + " -2.2515e+00, 4.5996e-01, 1.5545e+00, 9.9873e-01, 3.9866e-01,\n", + " 2.4323e-01, 1.3789e+00, 6.7689e-01, -2.6021e-02, 5.1618e-01,\n", + " -3.6143e-01, -9.7261e-01, 5.2108e-01, -8.0599e-01, 8.5165e-01,\n", + " 1.4695e+00, 1.1753e+00, -3.0389e-01, -2.4687e-01, 2.5467e-01,\n", + " -3.8649e-01, -1.3182e+00, 7.6464e-01, 1.5648e+00, 3.0884e-01,\n", + " 1.3561e-01, -9.0414e-01, -7.4026e-01, -4.4211e-01, 2.3025e-01,\n", + " -1.8918e-01, -4.2127e-01, 2.2499e-01, 5.1024e-01, -1.2259e+00,\n", + " -3.8215e-01, 3.7701e-01, -8.2513e-01, 1.2285e+00, -8.4012e-01,\n", + " -2.6810e-02, -4.3456e-01, 4.8574e-01, -1.0033e-01, -1.3786e+00,\n", + " 1.0564e+00, 4.2144e-01, 6.4922e-01, 1.0002e+00, -1.7159e+00,\n", + " 1.2216e+00, -8.9945e-03, -1.1934e+00, -2.8279e-01, -7.0490e-01,\n", + " -1.9081e-01, -9.7353e-01, -2.7479e+00, -4.5797e-01, 6.8582e-01,\n", + " -1.3553e+00, 2.7013e-01, -2.3826e-01, 
-1.6231e+00, 2.1371e-01,\n", + " 5.6384e-01, 2.1960e-01, 2.4150e-01, 6.2061e-01, -1.0489e-01,\n", + " 1.7517e+00, 6.7076e-01, -1.3978e-01, -8.2445e-02, 3.9530e-01,\n", + " 5.1329e-01, -1.9216e-01, 1.1939e+00, -2.8467e+00, 1.5854e+00,\n", + " 2.2636e-01, 1.0687e+00, -7.1430e-01, 1.1974e+00, 1.3638e-02,\n", + " 1.1439e+00, -1.5609e+00, -4.2655e-01, 3.8575e-01, 7.7837e-01,\n", + " -4.0150e-01, 5.4693e-01, 1.9081e+00, -7.4662e-01, -4.7376e-01,\n", + " 6.9946e-01, -1.9914e+00, -2.4258e-01, -9.3687e-01, 6.8541e-01,\n", + " 1.0665e+00, 2.1555e-01, -2.4295e-01, -1.1957e+00, 6.0446e-01,\n", + " 4.0213e-01, 1.4017e+00, 2.7869e-01, -9.5894e-01, -1.5638e+00,\n", + " -1.0110e-01, 4.6773e-01, 1.2827e+00, -9.1729e-01, 1.6297e+00,\n", + " 4.5764e-01, 8.8216e-01, 2.2360e-01, 1.0108e+00, 7.2386e-01,\n", + " 2.0961e+00, 2.4810e-01, -1.4284e+00, 5.6877e-01, 1.2502e+00,\n", + " -1.3409e-01, -3.0333e-01, -1.8817e-01, -2.2237e+00, -5.6682e-02,\n", + " -6.2232e-01, -4.6465e-01, -5.4288e-01, 7.8277e-01, 4.5549e-01,\n", + " 5.9486e-01, 6.6038e-02, 1.5156e+00, 8.7062e-01, -4.8555e-01,\n", + " 2.7363e-01, -2.2114e-01, 3.1821e-01, 6.6405e-01, 1.2378e+00,\n", + " 4.2984e-01, 2.2287e-01, -1.2077e+00, 9.1887e-01, 3.2677e-01,\n", + " -7.3290e-01, -2.2486e+00, 8.0229e-01, -8.8998e-01, -1.9330e+00,\n", + " 1.1855e+00, 6.3310e-01, -6.8345e-02, -1.1391e-01, 4.5725e-01,\n", + " 1.6477e+00, -2.6330e-01, 5.4432e-01, 1.3206e+00, 2.0895e+00,\n", + " 8.0402e-01, -1.7470e+00, 9.2834e-01, -8.7679e-01, 3.2405e-01,\n", + " 6.0144e-01, -1.2082e+00, 1.2006e+00, -6.3478e-01, 9.0314e-01,\n", + " -1.3630e+00, 1.3225e+00, -3.9263e-02, 2.1543e-01, -5.2079e-01,\n", + " 7.9209e-01, 2.1396e+00, 1.8184e+00, -1.5520e+00, 2.4373e-01,\n", + " 4.9069e-01, 3.7375e-01, 5.0519e-01, 8.2489e-02, -8.1836e-01,\n", + " -3.7315e-01, -6.3905e-01, 1.4168e+00, 1.0993e+00, -4.3926e-01,\n", + " 5.3011e-02, -2.4621e-01, 4.7379e-01, -3.1596e-01, 1.1466e+00,\n", + " 1.9753e-02, -4.1847e-01, -1.6455e+00, 6.2989e-01, 3.0642e-01,\n", + " -3.4217e-01, 3.7732e-01, 5.7856e-01, 3.5808e-01, -6.4435e-01,\n", + " -1.2664e+00, 8.9284e-01, 8.6703e-01, -1.1849e+00, 2.2854e-01,\n", + " -8.9665e-01, -7.4685e-01, -8.3032e-01, -7.6299e-01, 6.6947e-01,\n", + " -6.4320e-01, 1.2720e+00, -4.8314e-01, 1.0369e+00, -8.2754e-02,\n", + " -3.6780e-01, -8.5765e-01, -2.6584e-01, -8.4592e-01, 2.0900e+00,\n", + " -3.3571e-01, -2.7365e-01, 4.2045e-01, -5.0189e-01, 2.1774e+00,\n", + " 1.2684e+00, 1.0876e+00, 6.7469e-01, 1.8769e+00, 1.7561e+00,\n", + " 6.3241e-02, 1.1177e-01, -6.9672e-01, -1.3041e-01, -1.5295e-01,\n", + " 1.1734e+00, 2.0096e+00, 1.3440e+00, 4.1154e-01, -2.0337e+00,\n", + " 2.9413e-01, 6.3165e-01, 6.6015e-02, 7.9139e-01, -4.2177e-01,\n", + " -9.5036e-01, -9.5504e-01, -2.5653e-01, -9.4473e-01, -1.0939e+00,\n", + " 1.1041e-01, -5.3169e-01, -7.5699e-01, 5.3319e-01, -6.8033e-01,\n", + " 5.3278e-01, -1.2020e+00, -1.9659e-01, -7.2072e-01, 3.5778e-01,\n", + " 4.1554e-01, -1.4242e-01, -1.2204e+00, 1.3340e-01, 1.3910e+00,\n", + " 2.8282e-02, 3.6570e-01, 5.4003e-01, 2.2741e+00, 7.6066e-02,\n", + " -1.1488e+00, 6.1675e-01, -9.8228e-01, -6.6162e-01, 1.6759e+00,\n", + " -1.1585e+00, 9.2543e-01, -4.0777e-01, 5.8472e-01, 5.7255e-01,\n", + " 2.0956e+00, 1.6141e-01, 1.3108e+00, -4.2879e-01, -1.1836e-01,\n", + " 3.0701e-01, -2.3178e+00, 1.8452e-01, 1.0042e+00, -1.0290e+00,\n", + " 1.3379e+00, -7.2549e-01, -6.1016e-01, -1.3704e+00, -5.7965e-01,\n", + " -6.2787e-01, 4.3845e-02, 3.0263e-01, -2.3582e-01, -1.6874e+00,\n", + " -1.2794e-01, -6.3943e-01, -1.7925e+00, 4.9686e-01, 6.6856e-02,\n", + " 
-9.9491e-01, 1.2036e-01, -1.0225e+00, -4.4577e-01, 8.7954e-01,\n", + " -2.9278e-01, 1.1696e+00, 1.3101e-02, -2.4229e-01, 1.3554e+00,\n", + " -6.3189e-02, 5.8783e-01, -9.8299e-01, 9.6724e-01, 8.7076e-01,\n", + " -1.1199e+00, -1.5642e+00, -1.8559e+00, -8.2385e-01, 6.6140e-01,\n", + " -3.9254e-01, 6.1634e-01, 5.2854e-01, 1.0742e+00, -6.4902e-02,\n", + " 5.3602e-02, -1.5826e+00, 1.0997e+00, 2.4608e-01, -3.8086e-01,\n", + " 6.6185e-01, -9.8247e-01, 7.5218e-02, -4.1243e-02, -9.2047e-01,\n", + " -5.8027e-01, 6.8770e-01, -1.4893e+00, -9.4504e-01, 1.9702e-01,\n", + " 6.6918e-01, -1.1709e+00, -7.0232e-01, -4.1784e-01, 8.1003e-01,\n", + " -2.7804e-01, -2.0031e-01, 7.0527e-01, -9.6937e-01, 1.1689e+00,\n", + " 7.4006e-01, -2.5950e-01, 3.4501e-01, -2.3079e-01, -9.5328e-01,\n", + " 4.5314e-02, -1.5330e+00, -3.3205e-01, -9.7520e-02, -1.1451e+00,\n", + " 1.2243e+00, 4.0116e-01, 5.1238e-01, 1.0563e+00, 7.9822e-01,\n", + " 3.2488e-04, 4.1915e-01, -1.8082e-01, 9.7909e-01, -5.6570e-02,\n", + " 3.7949e-01, -4.4500e-02, 1.9252e-01, -9.0285e-01, -1.4165e+00,\n", + " 6.6683e-01, -1.2207e+00, 2.6244e-01, -1.0841e+00, -3.8274e-01,\n", + " 7.0959e-01, 9.6215e-01, -1.5444e+00, 4.9438e-03, 1.3579e+00,\n", + " 4.7280e-01, 1.2425e+00, 2.5537e-01, -5.1405e-01, 6.0571e-02,\n", + " -3.0876e-01, -1.6462e+00, -1.1509e-01, -4.7552e-02, -2.6791e-01,\n", + " -4.3740e-01, 3.1958e-01, -1.2730e+00, -1.2231e+00, -1.4263e+00,\n", + " 4.9323e-02, -1.3655e+00, 1.0031e+00, 2.3985e+00, 1.5285e+00,\n", + " -8.1026e-01, 3.5734e-01, -8.6429e-01, 1.7859e+00, -1.0850e+00,\n", + " -4.4293e-01, 1.5550e+00, -1.6774e-01, 1.4308e+00, -2.0135e+00,\n", + " -8.0758e-01, 6.4787e-01, -5.6576e-02, -3.9324e-01, 8.2594e-01,\n", + " 5.1615e-01, 5.9828e-01, 2.7394e+00, -1.4974e+00, 1.3907e+00,\n", + " 2.0605e+00, 1.7912e+00, -3.8680e-01, -1.8634e+00, 1.7960e-01,\n", + " -7.3744e-01, -3.6939e-01, -2.5540e-01, -3.5333e-01, 3.7214e-01,\n", + " 1.9065e+00, 1.7407e+00, 1.2818e+00, 4.7474e-01, -1.1279e+00,\n", + " -1.3318e-01, 1.0841e+00, 1.1111e+00, 2.5677e-01, 9.6003e-02,\n", + " 2.4728e+00, -6.9190e-01, -2.4013e+00, -1.9088e+00, -2.7976e-01,\n", + " 2.8647e-01, -1.0314e+00, 4.9384e-01, -9.7575e-01, 2.4125e-01,\n", + " -4.5776e-01, -2.1729e+00, 6.8295e-01, -5.9783e-01, -7.6568e-01,\n", + " -6.8171e-02, 7.2051e-01, 1.0234e+00, 1.3814e+00, 1.0481e+00,\n", + " 1.6249e+00, 1.4694e-02, 1.1664e+00, -8.8426e-01, 1.7891e-01,\n", + " 8.7546e-02, -9.6954e-01, 3.5897e-01, 8.3710e-01, -1.1977e-01,\n", + " 6.1722e-01, 2.1218e-01, -1.0700e+00, -1.0210e+00, 7.9014e-02,\n", + " 6.6533e-01, 5.0188e-01, 3.6993e-01, 6.7140e-01, -2.1729e-01,\n", + " 1.1855e+00, 9.7701e-01, 3.4555e-02, 5.2570e-02, 1.7804e+00,\n", + " -4.3345e-01, -9.1795e-01, 8.4220e-01, -5.2025e-01, -2.3512e+00,\n", + " -2.9133e-01, 1.8137e+00, -1.4506e-01, 7.7703e-01, -1.1247e+00,\n", + " 1.0406e+00, 1.4300e-01, -3.2812e-01, -5.9724e-01, -1.5874e+00,\n", + " 1.2625e+00, 2.5443e-01, 6.8958e-01, -1.5735e-01, -2.1421e+00,\n", + " -7.0841e-01, -4.5094e-01, -1.5603e+00, 7.7987e-01, -2.1264e+00,\n", + " 1.3828e+00, 6.0764e-01, -7.3796e-01, -2.4417e-01, -7.1150e-01,\n", + " 8.2985e-01, -7.7541e-02, 7.6857e-01, 3.1817e-01, -7.7757e-01,\n", + " -4.2241e-01, 6.5826e-01, -6.4465e-01, -5.9923e-01, -4.3871e-01,\n", + " 7.2439e-02, -3.3670e-01, 1.5531e-01, 5.0956e-01, 6.4318e-01,\n", + " 1.0634e-01, -5.2283e-02, 7.2015e-01, 1.1089e+00, 7.8364e-01,\n", + " -3.0195e-01, -3.6416e-01, 1.3103e+00, -4.6703e-01, 1.9210e+00,\n", + " 2.4408e-01, -1.4922e+00, 5.6181e-01, 7.4204e-01, 9.7397e-01,\n", + " 3.7613e-01, 1.2033e-03, 
-5.6261e-01, 1.3709e-01, 3.0908e-01,\n", + " 4.3589e-01, 9.6681e-01, 5.2447e-01, -6.7887e-01, 1.3580e+00,\n", + " 2.0728e-01, 1.3712e-01, -4.1442e-01, 9.4468e-02, 2.5556e-01,\n", + " -3.7138e-01, 1.5530e+00, 5.2251e-01, 4.6009e-01, 5.7644e-01,\n", + " 1.6343e+00, -9.3365e-01, -1.7162e+00, 5.0816e-01, 1.9417e-02,\n", + " -3.2217e-01, -1.3262e+00, -7.3292e-01, 1.6861e+00, -8.8239e-02,\n", + " -1.5955e+00, -9.8843e-01, 1.5626e+00, 2.0443e-01, -3.3774e-01,\n", + " -8.3594e-01, 2.5315e-01, 1.6415e+00, 1.0761e-01, 4.3909e-02,\n", + " -9.7492e-02, -6.0125e-01, -1.0249e+00, 4.4471e-01, 8.7049e-01,\n", + " -2.8564e-01, 3.5769e-01, -6.4562e-01, 2.4414e-01, 4.4689e-01,\n", + " 9.8379e-02, -2.0082e+00, 2.6912e+00, -2.9739e-02, 7.9058e-02,\n", + " -1.0467e+00, 1.2589e+00, 7.7871e-01, 1.6134e+00, -7.5795e-02,\n", + " 2.4459e-01, -9.5959e-01, -1.0011e+00, 1.5439e+00, 8.5851e-01,\n", + " 6.6582e-01, 8.2767e-02, 6.9187e-01, -8.3814e-02, 1.0908e+00,\n", + " 5.5916e-01, -3.6419e-01, 2.0466e-01, -8.0997e-01, 1.4612e-01,\n", + " -5.8160e-01, 3.4547e-01, -2.0894e-01, -2.3499e+00, -2.0449e-01,\n", + " 1.2959e-01, -1.0992e+00, -3.7099e-01, 5.7603e-01, -6.8101e-01,\n", + " 7.4702e-01, -9.2913e-01, -5.2333e-01, -1.1561e+00, 1.0236e+00,\n", + " 9.8274e-01, 1.8651e-01, 1.2910e-01, 7.5824e-01, -1.4217e+00,\n", + " 1.5712e+00, -1.5876e+00, -5.8153e-01, -9.3299e-01, -8.1509e-01,\n", + " 1.5932e+00, 1.6746e+00, -4.7422e-01, -7.5078e-01, 1.5711e+00,\n", + " -2.0033e-01, 1.1153e+00, 2.1060e+00, 5.4210e-01, -7.9347e-01,\n", + " 4.7698e-01, -1.0490e+00, -1.9694e-01, -5.3319e-02, 5.9719e-01,\n", + " -2.7117e-01, 5.3615e-01, 3.0759e-01, 6.2320e-01, 1.4930e-01,\n", + " 1.7076e+00, 9.3622e-02, -3.6664e-01, 1.7003e+00, -9.6456e-01,\n", + " 2.9677e-01, -9.8081e-02, 1.1619e+00, -5.0380e-01, -1.2547e+00,\n", + " -1.6950e+00, -1.7951e-01, -9.8662e-02, 6.6682e-01, 9.5207e-01,\n", + " -1.6040e-01, -1.6418e+00, -8.9001e-01, 8.7406e-02, -7.2148e-01,\n", + " 5.4667e-01, -1.5340e-01, 4.3919e-01, 2.4276e-01, -1.5593e+00,\n", + " 2.0327e+00, -1.7976e+00, -1.0149e+00, -1.3462e+00, -2.3772e+00,\n", + " 1.4590e+00, -3.0744e-01, 6.6394e-01, 4.7298e-02, -1.0030e+00,\n", + " 7.8366e-01, 1.8789e-01, 4.3368e-02, -6.8919e-01, -3.3976e-01,\n", + " -6.4345e-01, -3.7919e-01, 1.0008e+00, 9.6835e-01, -4.1755e-01,\n", + " -1.7020e-03, 2.9330e-01, 5.1436e-01, -2.6706e+00, -1.9252e-01,\n", + " -1.1766e-01, 3.9349e-01, 1.3534e+00, -4.5487e-01, -1.5176e+00,\n", + " 2.6356e-01, -2.7808e-01, -4.0173e-01, -7.6142e-01, -1.2565e+00,\n", + " 2.6196e-01, -3.3675e-01, 6.0729e-01, 9.9332e-01, -7.8205e-02,\n", + " -5.5831e-01, 6.3940e-01, -9.9503e-01, -1.8994e-01, 2.5763e-01,\n", + " -6.7560e-01, 6.0766e-02, 2.5248e-01, 1.2508e+00, 9.3228e-01])" + ] + }, + "execution_count": 336, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 252, + "id": "e66d6adc-0be2-45d0-8be1-adc1662c0374", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Object `log_prob_z` not found.\n" + ] + } + ], + "source": [ + "log_prob_z" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "26503a10-7780-485b-907c-54b03d1ee868", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "from torch import Tensor\n", + "\n", + "\n", + "class InvertedPlanarTransform(nn.Module):\n", + " \"\"\"Implementation of the invertible transformation used in planar flow:\n", + " f(z) = z + u * h(dot(w.T, z) + b)\n", + 
" See Section 4.1 in https://arxiv.org/pdf/1505.05770.pdf. \n", + " \"\"\"\n", + "\n", + " def __init__(self, dim):\n", + " \"\"\"Initialise weights and bias.\n", + " \n", + " Args:\n", + " dim: Dimensionality of the distribution to be estimated.\n", + " \"\"\"\n", + " super().__init__()\n", + " self.w = nn.Parameter(torch.randn(1, dim).normal_(0, 0.1))\n", + " self.b = nn.Parameter(torch.randn(1).normal_(0, 0.1))\n", + " self.u = nn.Parameter(torch.randn(1, dim).normal_(0, 0.1))\n", + "\n", + " def forward(self, z: Tensor) -> Tensor:\n", + " if torch.mm(self.u, self.w.T) < -1:\n", + " self.get_u_hat()\n", + "\n", + " return z + self.u * nn.Tanh()(torch.mm(z, self.w.T) + self.b)\n", + " \n", + " def log_det_J(self, z: Tensor) -> Tensor:\n", + " if torch.mm(self.u, self.w.T) < -1:\n", + " self.get_u_hat()\n", + " a = torch.mm(z, self.w.T) + self.b\n", + " psi = (1 - nn.Tanh()(a) ** 2) * self.w\n", + " abs_det = (1 + torch.mm(self.u, psi.T)).abs()\n", + " log_det = torch.log(1e-4 + abs_det)\n", + "\n", + " return log_det\n", + "\n", + " def get_u_hat(self) -> None:\n", + " \"\"\"Enforce w^T u >= -1. When using h(.) = tanh(.), this is a sufficient condition \n", + " for invertibility of the transformation f(z). See Appendix A.1.\n", + " \"\"\"\n", + " wtu = torch.mm(self.u, self.w.T)\n", + " m_wtu = -1 + torch.log(1 + torch.exp(wtu))\n", + " self.u.data = (\n", + " self.u + (m_wtu - wtu) * self.w / torch.norm(self.w, p=2, dim=1) ** 2\n", + " )\n", + "\n", + " def inverse(self, x: Tensor) -> Tensor:\n", + " \"\"\"\n", + " Compute the inverse transformation.\n", + "\n", + " Args:\n", + " x: Input tensor.\n", + "\n", + " Returns:\n", + " Inverse transformed tensor.\n", + " \"\"\"\n", + " # Compute inverse of tanh using atanh (inverse hyperbolic tangent)\n", + " inverse_h = torch.atanh((x - self.u) / torch.mm(self.w, self.w.T))\n", + "\n", + " # Compute the inverse transformation\n", + " inverse_z = torch.mm(inverse_h - self.b, torch.inverse(self.w))\n", + "\n", + " return inverse_z\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "89a9e463-5c22-428c-8548-08182fb77c5c", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "\n", + "class ConditionalPlanarFlowModel(nn.Module):\n", + " def __init__(self, input_dim=6, target_dim=1, num_flows=1):\n", + " super(ConditionalPlanarFlowModel, self).__init__()\n", + " self.input_dim = input_dim\n", + " self.target_dim = target_dim\n", + " self.num_flows = num_flows\n", + "\n", + " # Create a list to hold the InvertedPlanarTransform instances\n", + " self.flow_list = nn.ModuleList([InvertedPlanarTransform(dim=input_dim) for _ in range(num_flows)])\n", + "\n", + " # Linear layer to predict the CPD for the target variable\n", + " self.linear_layer = nn.Linear(input_dim, target_dim)\n", + "\n", + " def forward(self, input_features):\n", + " # Apply all the InvertedPlanarTransform instances sequentially\n", + " transformed_features = input_features\n", + " for flow in self.flow_list:\n", + " transformed_features = flow(transformed_features)\n", + "\n", + " # Predict the CPD for the target variable\n", + " predicted_cpd = self.linear_layer(transformed_features)\n", + "\n", + " return predicted_cpd\n", + "\n", + "\n", + "def train_model(model, input_data, target_data, num_epochs=100, learning_rate=0.001):\n", + " optimizer = optim.Adam(model.parameters(), lr=learning_rate)\n", + "\n", + " for epoch in range(num_epochs):\n", + " 
optimizer.zero_grad()\n", + " predicted_cpd = model(input_data)\n", + " loss = criterion(predicted_cpd, target_data)\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + "# Step 5: Inference\n", + "def predict_cpd(model, input_features):\n", + " with torch.no_grad():\n", + " model.eval()\n", + " transformed_features = model.feature_transform.inverse(input_features)\n", + " predicted_cpd = model.linear_layer(transformed_features)\n", + " return predicted_cpd\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c5dbcc9d-8661-41e6-a3e8-c31441e0491f", + "metadata": {}, + "outputs": [], + "source": [ + "# Step 3: Define your loss function (e.g., negative log-likelihood)\n", + "criterion = nn.NLLLoss() # Use a suitable loss function based on the nature of your CPD\n", + "\n", + "# Step 4: Training the model" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "e92ade13-b05c-44d2-a5fb-7cbe61dc63e0", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "num_samples=10\n", + "input_dim=6\n", + "target_dim=1" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "6aa8e26a-007e-461c-8fb9-e4cd103e363a", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Assuming you have your input_data and target_data as torch tensors:\n", + "input_data = torch.randn(num_samples, input_dim) # Replace with your actual data\n", + "target_data = torch.randn(num_samples, target_dim) # Replace with your actual data" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "802c6a15-5109-4ad5-ad42-bec6dc3dd838", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "ename": "RuntimeError", + "evalue": "0D or 1D target tensor expected, multi-target not supported", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[9], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mtrain_model\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minput_data\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtarget_data\u001b[49m\u001b[43m)\u001b[49m\n", + "Cell \u001b[0;32mIn[3], line 35\u001b[0m, in \u001b[0;36mtrain_model\u001b[0;34m(model, input_data, target_data, num_epochs, learning_rate)\u001b[0m\n\u001b[1;32m 33\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[1;32m 34\u001b[0m predicted_cpd \u001b[38;5;241m=\u001b[39m model(input_data)\n\u001b[0;32m---> 35\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[43mcriterion\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpredicted_cpd\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtarget_data\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 36\u001b[0m loss\u001b[38;5;241m.\u001b[39mbackward()\n\u001b[1;32m 37\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mstep()\n", + "File \u001b[0;32m/data/astro/scratch/lcabayol/anaconda3/envs/NFlow/lib/python3.10/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m 
(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1502\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", + "File \u001b[0;32m/data/astro/scratch/lcabayol/anaconda3/envs/NFlow/lib/python3.10/site-packages/torch/nn/modules/loss.py:216\u001b[0m, in \u001b[0;36mNLLLoss.forward\u001b[0;34m(self, input, target)\u001b[0m\n\u001b[1;32m 215\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Tensor, target: Tensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[0;32m--> 216\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnll_loss\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtarget\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mweight\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mignore_index\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mignore_index\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreduction\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mreduction\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/data/astro/scratch/lcabayol/anaconda3/envs/NFlow/lib/python3.10/site-packages/torch/nn/functional.py:2704\u001b[0m, in \u001b[0;36mnll_loss\u001b[0;34m(input, target, weight, size_average, ignore_index, reduce, reduction)\u001b[0m\n\u001b[1;32m 2702\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m size_average \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mor\u001b[39;00m reduce \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 2703\u001b[0m reduction \u001b[38;5;241m=\u001b[39m _Reduction\u001b[38;5;241m.\u001b[39mlegacy_get_string(size_average, reduce)\n\u001b[0;32m-> 2704\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m 
\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_C\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_nn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnll_loss_nd\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtarget\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m_Reduction\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_enum\u001b[49m\u001b[43m(\u001b[49m\u001b[43mreduction\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mignore_index\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[0;31mRuntimeError\u001b[0m: 0D or 1D target tensor expected, multi-target not supported" + ] + } + ], + "source": [ + "train_model(model, input_data, target_data)" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "id": "7296cbf5-10f7-4dc6-a30a-96bf70b7e577", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "ename": "TypeError", + "evalue": "randn(): argument 'size' must be tuple of ints, but found element of type Tensor at pos 2", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[38], line 6\u001b[0m\n\u001b[1;32m 3\u001b[0m target_data \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mrandn(num_samples, target_dim) \u001b[38;5;66;03m# Replace with your actual data\u001b[39;00m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;66;03m# Step 2: Create the model\u001b[39;00m\n\u001b[0;32m----> 6\u001b[0m model \u001b[38;5;241m=\u001b[39m \u001b[43mConditionalPlanarFlowModel\u001b[49m\u001b[43m(\u001b[49m\u001b[43minput_dim\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_data\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtarget_dim\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtarget_data\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 8\u001b[0m \u001b[38;5;66;03m# Step 4: Train the model\u001b[39;00m\n\u001b[1;32m 9\u001b[0m train_model(model, input_data, target_data)\n", + "Cell \u001b[0;32mIn[35], line 13\u001b[0m, in \u001b[0;36mConditionalPlanarFlowModel.__init__\u001b[0;34m(self, input_dim, target_dim)\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtarget_dim \u001b[38;5;241m=\u001b[39m target_dim\n\u001b[1;32m 12\u001b[0m \u001b[38;5;66;03m# Feature transformation using InvertedPlanarTransform\u001b[39;00m\n\u001b[0;32m---> 13\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfeature_transform \u001b[38;5;241m=\u001b[39m \u001b[43mInvertedPlanarTransform\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdim\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_dim\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 15\u001b[0m \u001b[38;5;66;03m# Linear layer to predict the CPD for the target variable\u001b[39;00m\n\u001b[1;32m 16\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlinear_layer \u001b[38;5;241m=\u001b[39m nn\u001b[38;5;241m.\u001b[39mLinear(input_dim, target_dim)\n", + "Cell \u001b[0;32mIn[34], line 19\u001b[0m, in \u001b[0;36mInvertedPlanarTransform.__init__\u001b[0;34m(self, dim)\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Initialise weights and bias.\u001b[39;00m\n\u001b[1;32m 14\u001b[0m \u001b[38;5;124;03m\u001b[39;00m\n\u001b[1;32m 
+  {
+   "cell_type": "code",
+   "execution_count": 32,
+   "id": "fe0b227d-9aa5-4c99-91a4-0e558795328c",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "\u001b[0;31mDocstring:\u001b[0m\n",
+       "randn(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor\n",
+       "\n",
+       "\n",
+       "Returns a tensor filled with random numbers from a normal distribution\n",
+       "with mean `0` and variance `1` (also called the standard normal\n",
+       "distribution).\n",
+       "\n",
+       ".. 
math::\n", + " \\text{out}_{i} \\sim \\mathcal{N}(0, 1)\n", + "\n", + "The shape of the tensor is defined by the variable argument :attr:`size`.\n", + "\n", + "Args:\n", + " size (int...): a sequence of integers defining the shape of the output tensor.\n", + " Can be a variable number of arguments or a collection like a list or tuple.\n", + "\n", + "Keyword args:\n", + " generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling\n", + " out (Tensor, optional): the output tensor.\n", + " dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.\n", + " Default: if ``None``, uses a global default (see :func:`torch.set_default_tensor_type`).\n", + " layout (:class:`torch.layout`, optional): the desired layout of returned Tensor.\n", + " Default: ``torch.strided``.\n", + " device (:class:`torch.device`, optional): the desired device of returned tensor.\n", + " Default: if ``None``, uses the current device for the default tensor type\n", + " (see :func:`torch.set_default_tensor_type`). :attr:`device` will be the CPU\n", + " for CPU tensor types and the current CUDA device for CUDA tensor types.\n", + " requires_grad (bool, optional): If autograd should record operations on the\n", + " returned tensor. Default: ``False``.\n", + " pin_memory (bool, optional): If set, returned tensor would be allocated in\n", + " the pinned memory. Works only for CPU tensors. Default: ``False``.\n", + "\n", + "Example::\n", + "\n", + " >>> torch.randn(4)\n", + " tensor([-2.1436, 0.9966, 2.3426, -0.6366])\n", + " >>> torch.randn(2, 3)\n", + " tensor([[ 1.5954, 2.8929, -1.0923],\n", + " [ 1.1719, -0.4709, -0.1996]])\n", + "\u001b[0;31mType:\u001b[0m builtin_function_or_method" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "torch.randn?" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "160e45c5-79fa-43bc-b4a5-3bb633880230", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cfa9af6a-e0e9-4171-b46f-ca21255e253d", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "918bd133-21b6-40b1-a264-c1ba88cc8b49", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "3f8c8a08-854d-44f7-83eb-23a6a9e0cf93", + "metadata": {}, + "outputs": [], + "source": [ + "class PlanarFlow(nn.Module):\n", + " def __init__(self, dim: int = 2, K: int = 6):\n", + " \"\"\"Make a planar flow by stacking planar transformations in sequence.\n", + "\n", + " Args:\n", + " dim: Dimensionality of the distribution to be estimated.\n", + " K: Number of transformations in the flow. 
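Transforms are applied in order, and forward() accumulates each layer's log|det J| contribution.\n",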
\n", + " \"\"\"\n", + " super().__init__()\n", + " self.layers = [InvertedPlanarTransform(dim) for _ in range(K)]\n", + " self.model = nn.Sequential(*self.layers)\n", + "\n", + " def forward(self, z: Tensor) -> Tuple[Tensor, float]:\n", + " log_det_J = 0\n", + "\n", + " for layer in self.layers:\n", + " log_det_J += layer.log_det_J(z)\n", + " z = layer(z)\n", + "\n", + " return z, log_det_J" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "789351d3-262f-43b8-8132-89d053681ed7", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "model = PlanarFlow(dim=2, K=2)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "7cf722c0-6d93-437e-9679-1a4704831d73", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "PlanarFlow(\n", + " (model): Sequential(\n", + " (0): InvertedPlanarTransform()\n", + " (1): InvertedPlanarTransform()\n", + " )\n", + ")" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "48a9b31f-f3e3-49d7-8b1a-6cfe426bbbe3", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "ename": "TypeError", + "evalue": "object.__init__() takes exactly one argument (the instance to initialize)", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[3], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m identity_flow \u001b[38;5;241m=\u001b[39m \u001b[43mIdentityFlow\u001b[49m\u001b[43m(\u001b[49m\u001b[43mparams\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mzeros\u001b[49m\u001b[43m(\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mn_dims\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/EUCLID/INSIGHT/NF/flows/IdentityFlow.py:14\u001b[0m, in \u001b[0;36mIdentityFlow.__init__\u001b[0;34m(self, params, n_dims, name)\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__init__\u001b[39m(\u001b[38;5;28mself\u001b[39m, params, n_dims, name\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mIdentityFlow\u001b[39m\u001b[38;5;124m'\u001b[39m):\n\u001b[1;32m 10\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 11\u001b[0m \u001b[38;5;124;03m :param params: shape (?, 1), this will become alpha and define the slow of ReLU for x < 0\u001b[39;00m\n\u001b[1;32m 12\u001b[0m \u001b[38;5;124;03m :param n_dims: Dimension of the distribution that's being transformed\u001b[39;00m\n\u001b[1;32m 13\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m---> 14\u001b[0m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mIdentityFlow\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__init__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mparams\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 15\u001b[0m \u001b[43m \u001b[49m\u001b[43mn_dims\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 16\u001b[0m 
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "48a9b31f-f3e3-49d7-8b1a-6cfe426bbbe3",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "ename": "TypeError",
+     "evalue": "object.__init__() takes exactly one argument (the instance to initialize)",
+     "output_type": "error",
+     "traceback": [
+      "Cell In[3], line 1\n",
+      "-> 1 identity_flow = IdentityFlow(params=torch.zeros((1, 1)), n_dims=1)\n",
+      "\n",
+      "File ~/EUCLID/INSIGHT/NF/flows/IdentityFlow.py:14, in IdentityFlow.__init__(self, params, n_dims, name)\n",
+      "-> 14 super(IdentityFlow, self).__init__(params,\n",
+      "   15                                    n_dims,\n",
+      "   16                                    name=name)\n",
+      "\n",
+      "TypeError: object.__init__() takes exactly one argument (the instance to initialize)"
+     ]
+    }
+   ],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "id": "b2d50e1e-5adc-4bdd-87cb-32c78c883b68",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "planar_flow = InvertedPlanarFlow(params=torch.zeros((1, 2 * 1 + 1)), n_dims=1)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "id": "cb981f0f-113e-45f2-a027-2e003065bdd8",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "InvertedPlanarFlow()"
+      ]
+     },
+     "execution_count": 13,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "planar_flow"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "74688079-9a3f-4f87-b93f-4a191325e6ac",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ce29cd37-0b15-4cf9-a55c-da73f38d67ea",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "80b9fbda-6377-4df2-9caa-3fe741ec3bb6",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "5b1d73f2-51a8-4775-850a-da4981fd7d15",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "2a1bd67f-8253-4ebd-b69e-16a1418c6e31",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "2023-08-03 07:26:44.817779: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
+      "2023-08-03 07:26:47.725659: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
+      "To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
+      "2023-08-03 07:27:10.847442: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n",
+      "2023-08-03 07:27:53.690091: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1960] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. 
Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.\n", + "Skipping registering GPU devices...\n" + ] + } + ], + "source": [ + "import tensorflow as tf\n", + "import tensorflow_probability as tfp\n", + "import numpy as np\n", + "\n", + "tfd = tfp.distributions\n", + "\n", + "\n", + "class BaseEstimator(tf.keras.Sequential):\n", + " x_noise_std = tf.Variable(initial_value=0.0, dtype=tf.float32, trainable=False)\n", + " y_noise_std = tf.Variable(initial_value=0.0, dtype=tf.float32, trainable=False)\n", + "\n", + " def __init__(self, layers, noise_fn_type=\"fixed_rate\", noise_scale_factor=0.0, random_seed=22):\n", + " tf.random.set_seed(random_seed)\n", + " self.noise_fn_type = noise_fn_type\n", + " self.noise_scale_factor = noise_scale_factor\n", + "\n", + " super().__init__(layers)\n", + "\n", + " def fit(self, x, y, batch_size=None, epochs=None, verbose=1, **kwargs):\n", + " self._assign_data_normalization(x, y)\n", + " assert len(x.shape) == len(y.shape) == 2, \"Please pass a matrix not a vector\"\n", + " self._assign_noise_regularisation(n_dims=x.shape[1] + y.shape[1], n_datapoints=x.shape[0])\n", + " super().fit(\n", + " x=x,\n", + " y=y,\n", + " batch_size=batch_size,\n", + " epochs=epochs,\n", + " verbose=verbose,\n", + " callbacks=[tf.keras.callbacks.TerminateOnNaN()],\n", + " **kwargs,\n", + " )\n", + "\n", + " def _assign_noise_regularisation(self, n_dims, n_datapoints):\n", + " assert self.noise_fn_type in [\"rule_of_thumb\", \"fixed_rate\"]\n", + " if self.noise_fn_type == \"rule_of_thumb\":\n", + " noise_std = self.noise_scale_factor * (n_datapoints + 1) ** (-1 / (4 + n_dims))\n", + " self.x_noise_std.assign(noise_std)\n", + " self.y_noise_std.assign(noise_std)\n", + " elif self.noise_fn_type == \"fixed_rate\":\n", + " self.x_noise_std.assign(self.noise_scale_factor)\n", + " self.y_noise_std.assign(self.noise_scale_factor)\n", + "\n", + " def score(self, x_data, y_data):\n", + " x_data = x_data.astype(np.float32)\n", + " y_data = y_data.astype(np.float32)\n", + " nll = self._get_neg_log_likelihood()\n", + " return -nll(y_data, self.call(x_data, training=False)).numpy().mean()\n", + "\n", + " def _assign_data_normalization(self, x, y):\n", + " self.x_mean = np.mean(x, axis=0, dtype=np.float32)\n", + " self.y_mean = np.mean(y, axis=0, dtype=np.float32)\n", + " self.x_std = np.std(x, axis=0, dtype=np.float32)\n", + " self.y_std = np.std(y, axis=0, dtype=np.float32)\n", + "\n", + " def _get_neg_log_likelihood(self):\n", + " y_input_model = self._get_input_model()\n", + " return lambda y, p_y: -p_y.log_prob(y_input_model(y)) + tf.reduce_sum(\n", + " tf.math.log(self.y_std)\n", + " )\n", + "\n", + " def _get_input_model(self):\n", + " y_input_model = tf.keras.Sequential()\n", + " # add data normalization layer\n", + " y_input_model.add(\n", + " tf.keras.layers.Lambda(lambda y: (y - tf.ones_like(y) * self.y_mean) / self.y_std)\n", + " )\n", + " # noise will be switched on during training and switched off otherwise automatically\n", + " y_input_model.add(tf.keras.layers.GaussianNoise(self.y_noise_std))\n", + " return y_input_model\n", + "\n", + " def pdf(self, x, y):\n", + " assert x.shape == y.shape\n", + " output = self(x)\n", + " y_circ = (y - tf.ones_like(y) * self.y_mean) / self.y_std\n", + " return output.prob(y_circ) / tf.reduce_prod(self.y_std)\n", + "\n", + " def log_pdf(self, x, y):\n", + " x = x.astype(np.float32)\n", + " y = y.astype(np.float32)\n", + " assert x.shape == y.shape\n", + "\n", + " 
output = self(x)\n",
+    "        assert output.event_shape == y.shape[-1]\n",
+    "\n",
+    "        y_circ = (y - tf.ones_like(y) * self.y_mean) / self.y_std\n",
+    "        return output.log_prob(y_circ) - tf.reduce_sum(tf.math.log(self.y_std))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "id": "84797c3a-3973-40d8-991d-cf375a88a416",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "import tensorflow as tf\n",
+    "import tensorflow_probability as tfp\n",
+    "from tensorflow.python import tf2\n",
+    "import numpy as np\n",
+    "from sklearn.cluster import KMeans\n",
+    "from sklearn.metrics.pairwise import cosine_distances\n",
+    "\n",
+    "if not tf2.enabled():\n",
+    "    import tensorflow.compat.v2 as tf\n",
+    "\n",
+    "    tf.enable_v2_behavior()\n",
+    "    assert tf2.enabled()\n",
+    "tfd = tfp.distributions\n",
+    "\n",
+    "\n",
+    "class MeanFieldLayer(tfp.layers.DistributionLambda):\n",
+    "    def __init__(self, n_dims, scale=None, map_mode=False, dtype=None):\n",
+    "        \"\"\"\n",
+    "        A subclass of DistributionLambda: a layer that uses its input to parametrize n_dims-many independent\n",
+    "        normal distributions (aka mean field).\n",
+    "        Requires input size n_dims for fixed scale, 2*n_dims for trainable scale.\n",
+    "        Mean field also works for scalars.\n",
+    "        The input tensors for this layer should be initialized to zero for a standard normal distribution.\n",
+    "        :param n_dims: Dimension of the distribution that's being output by the layer\n",
+    "        :param scale: (float) None if scale should be trainable. If not None, specifies the fixed scale of the\n",
+    "            independent normals. If map mode is activated, this is ignored and set to 1.0\n",
+    "        \"\"\"\n",
+    "        self.n_dims = n_dims\n",
+    "        self.scale = scale\n",
+    "\n",
+    "        if map_mode:\n",
+    "            self.scale = 1.0\n",
+    "        convert_ttf = tfd.Distribution.mean if map_mode else tfd.Distribution.sample\n",
+    "\n",
+    "        make_dist_fn = self._get_distribution_fn(self.n_dims, self.scale)\n",
+    "\n",
+    "        super().__init__(\n",
+    "            make_distribution_fn=make_dist_fn, convert_to_tensor_fn=convert_ttf, dtype=dtype\n",
+    "        )\n",
+    "\n",
+    "    @staticmethod\n",
+    "    def _get_distribution_fn(n_dims, scale=None):\n",
+    "        if scale is None:\n",
+    "\n",
+    "            def dist_fn(t):\n",
+    "                assert t.shape[-1] == 2 * n_dims\n",
+    "                return tfd.Independent(\n",
+    "                    tfd.Normal(\n",
+    "                        loc=t[..., 0:n_dims],\n",
+    "                        scale=1e-3\n",
+    "                        + tf.nn.softplus(\n",
+    "                            tf.math.log(tf.math.expm1(1.0)) + 0.05 * t[..., n_dims : 2 * n_dims]\n",
+    "                        ),\n",
+    "                    ),\n",
+    "                    reinterpreted_batch_ndims=1,\n",
+    "                )\n",
+    "\n",
+    "        else:\n",
+    "            assert scale > 0.0\n",
+    "\n",
+    "            def dist_fn(t):\n",
+    "                assert t.shape[-1] == n_dims\n",
+    "                return tfd.Independent(\n",
+    "                    tfd.Normal(loc=t[..., 0:n_dims], scale=scale), reinterpreted_batch_ndims=1\n",
+    "                )\n",
+    "\n",
+    "        return dist_fn\n",
+    "\n",
+    "    def get_total_param_size(self):\n",
+    "        return 2 * self.n_dims if self.scale is None else self.n_dims"
+   ]
+  },
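+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "added-meanfield-shape-check",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "# Added sketch: with trainable scale (scale=None) the layer consumes 2*n_dims\n",
+    "# parameters per sample and returns an Independent Normal over n_dims.\n",
+    "layer = MeanFieldLayer(n_dims=3, scale=None)\n",
+    "dist = layer(tf.zeros((5, 2 * 3)))\n",
+    "print(dist.event_shape, dist.batch_shape)  # expected: (3,) and (5,)"
+   ]
+  },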
\"\"\"\n", + " :param n_dims: The dimension of the distribution to be transformed by the flow\n", + " :return: (int) The dimension of the parameter space for the flow\n", + " \"\"\"\n", + " return 2 * n_dims" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "id": "5561b844-a362-4b0c-8ad2-002c68cff427", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "ename": "TypeError", + "evalue": "missing a required argument: 'out_units'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[45], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mtest_radial\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n", + "Cell \u001b[0;32mIn[36], line 46\u001b[0m, in \u001b[0;36mtest_radial\u001b[0;34m()\u001b[0m\n\u001b[1;32m 45\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mtest_radial\u001b[39m():\n\u001b[0;32m---> 46\u001b[0m \u001b[43mflow_dimension_testing\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mradial\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n", + "Cell \u001b[0;32mIn[36], line 17\u001b[0m, in \u001b[0;36mflow_dimension_testing\u001b[0;34m(flow_name)\u001b[0m\n\u001b[1;32m 14\u001b[0m flow \u001b[38;5;241m=\u001b[39m flow_class(tf\u001b[38;5;241m.\u001b[39mones((batch_size, flow_class\u001b[38;5;241m.\u001b[39mget_param_size(dim) \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m1\u001b[39m)), dim)\n\u001b[1;32m 16\u001b[0m flow \u001b[38;5;241m=\u001b[39m flow_class(tf\u001b[38;5;241m.\u001b[39mones((batch_size, flow_class\u001b[38;5;241m.\u001b[39mget_param_size(dim))), dim)\n\u001b[0;32m---> 17\u001b[0m reference \u001b[38;5;241m=\u001b[39m \u001b[43mAffineFlow\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtf\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mones\u001b[49m\u001b[43m(\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mAffineFlow\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_param_size\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdim\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdim\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 19\u001b[0m test_tensors \u001b[38;5;241m=\u001b[39m [[[\u001b[38;5;241m0.0\u001b[39m] \u001b[38;5;241m*\u001b[39m dim], [[\u001b[38;5;241m1.0\u001b[39m] \u001b[38;5;241m*\u001b[39m dim] \u001b[38;5;241m*\u001b[39m batch_size]\n\u001b[1;32m 20\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m flow\u001b[38;5;241m.\u001b[39mforward_min_event_ndims \u001b[38;5;241m==\u001b[39m reference\u001b[38;5;241m.\u001b[39mforward_min_event_ndims\n", + "File \u001b[0;32m/data/astro/scratch/lcabayol/anaconda3/envs/NFlow/lib/python3.10/site-packages/decorator.py:232\u001b[0m, in \u001b[0;36mdecorate..fun\u001b[0;34m(*args, **kw)\u001b[0m\n\u001b[1;32m 230\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m kwsyntax:\n\u001b[1;32m 231\u001b[0m args, kw \u001b[38;5;241m=\u001b[39m fix(args, kw, sig)\n\u001b[0;32m--> 232\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mcaller\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfunc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mextras\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m 
\u001b[49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkw\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/data/astro/scratch/lcabayol/anaconda3/envs/NFlow/lib/python3.10/site-packages/tensorflow_probability/python/distributions/distribution.py:342\u001b[0m, in \u001b[0;36m_DistributionMeta.__new__..wrapped_init\u001b[0;34m(***failed resolving arguments***)\u001b[0m\n\u001b[1;32m 339\u001b[0m \u001b[38;5;66;03m# Note: if we ever want to have things set in `self` before `__init__` is\u001b[39;00m\n\u001b[1;32m 340\u001b[0m \u001b[38;5;66;03m# called, here is the place to do it.\u001b[39;00m\n\u001b[1;32m 341\u001b[0m self_\u001b[38;5;241m.\u001b[39m_parameters \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 342\u001b[0m \u001b[43mdefault_init\u001b[49m\u001b[43m(\u001b[49m\u001b[43mself_\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 343\u001b[0m \u001b[38;5;66;03m# Note: if we ever want to override things set in `self` by subclass\u001b[39;00m\n\u001b[1;32m 344\u001b[0m \u001b[38;5;66;03m# `__init__`, here is the place to do it.\u001b[39;00m\n\u001b[1;32m 345\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m self_\u001b[38;5;241m.\u001b[39m_parameters \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 346\u001b[0m \u001b[38;5;66;03m# We prefer subclasses will set `parameters = dict(locals())` because\u001b[39;00m\n\u001b[1;32m 347\u001b[0m \u001b[38;5;66;03m# this has nearly zero overhead. 
However, failing to do this, we will\u001b[39;00m\n\u001b[1;32m 348\u001b[0m \u001b[38;5;66;03m# resolve the input arguments dynamically and only when needed.\u001b[39;00m\n", + "Cell \u001b[0;32mIn[44], line 7\u001b[0m, in \u001b[0;36mAffineFlow.__init__\u001b[0;34m(self, t, n_dims, name)\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__init__\u001b[39m(\u001b[38;5;28mself\u001b[39m, t, n_dims, name\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mAffineFlow\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m t\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m] \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m2\u001b[39m \u001b[38;5;241m*\u001b[39m n_dims\n\u001b[0;32m----> 7\u001b[0m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mAffineFlow\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__init__\u001b[39;49m\u001b[43m(\u001b[49m\n\u001b[1;32m 8\u001b[0m \u001b[43m \u001b[49m\u001b[43mshift\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mt\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43mn_dims\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mscale_diag\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m1.0\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mt\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mn_dims\u001b[49m\u001b[43m \u001b[49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mn_dims\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mname\u001b[49m\n\u001b[1;32m 9\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/data/astro/scratch/lcabayol/anaconda3/envs/NFlow/lib/python3.10/site-packages/decorator.py:231\u001b[0m, in \u001b[0;36mdecorate..fun\u001b[0;34m(*args, **kw)\u001b[0m\n\u001b[1;32m 229\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mfun\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkw):\n\u001b[1;32m 230\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m kwsyntax:\n\u001b[0;32m--> 231\u001b[0m args, kw \u001b[38;5;241m=\u001b[39m \u001b[43mfix\u001b[49m\u001b[43m(\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkw\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msig\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 232\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m caller(func, \u001b[38;5;241m*\u001b[39m(extras \u001b[38;5;241m+\u001b[39m args), \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkw)\n", + "File \u001b[0;32m/data/astro/scratch/lcabayol/anaconda3/envs/NFlow/lib/python3.10/site-packages/decorator.py:203\u001b[0m, in \u001b[0;36mfix\u001b[0;34m(args, kwargs, sig)\u001b[0m\n\u001b[1;32m 199\u001b[0m 
\u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mfix\u001b[39m(args, kwargs, sig):\n\u001b[1;32m 200\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 201\u001b[0m \u001b[38;5;124;03m Fix args and kwargs to be consistent with the signature\u001b[39;00m\n\u001b[1;32m 202\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 203\u001b[0m ba \u001b[38;5;241m=\u001b[39m \u001b[43msig\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbind\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 204\u001b[0m ba\u001b[38;5;241m.\u001b[39mapply_defaults() \u001b[38;5;66;03m# needed for test_dan_schult\u001b[39;00m\n\u001b[1;32m 205\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m ba\u001b[38;5;241m.\u001b[39margs, ba\u001b[38;5;241m.\u001b[39mkwargs\n", + "File \u001b[0;32m/data/astro/scratch/lcabayol/anaconda3/envs/NFlow/lib/python3.10/inspect.py:3185\u001b[0m, in \u001b[0;36mSignature.bind\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 3180\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mbind\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m/\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 3181\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Get a BoundArguments object, that maps the passed `args`\u001b[39;00m\n\u001b[1;32m 3182\u001b[0m \u001b[38;5;124;03m and `kwargs` to the function's signature. Raises `TypeError`\u001b[39;00m\n\u001b[1;32m 3183\u001b[0m \u001b[38;5;124;03m if the passed arguments can not be bound.\u001b[39;00m\n\u001b[1;32m 3184\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m-> 3185\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_bind\u001b[49m\u001b[43m(\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/data/astro/scratch/lcabayol/anaconda3/envs/NFlow/lib/python3.10/inspect.py:3100\u001b[0m, in \u001b[0;36mSignature._bind\u001b[0;34m(self, args, kwargs, partial)\u001b[0m\n\u001b[1;32m 3098\u001b[0m msg \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mmissing a required argument: \u001b[39m\u001b[38;5;132;01m{arg!r}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\n\u001b[1;32m 3099\u001b[0m msg \u001b[38;5;241m=\u001b[39m msg\u001b[38;5;241m.\u001b[39mformat(arg\u001b[38;5;241m=\u001b[39mparam\u001b[38;5;241m.\u001b[39mname)\n\u001b[0;32m-> 3100\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(msg) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 3101\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 3102\u001b[0m \u001b[38;5;66;03m# We have a positional argument to process\u001b[39;00m\n\u001b[1;32m 3103\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n", + "\u001b[0;31mTypeError\u001b[0m: missing a required argument: 'out_units'" + ] + } + ], + "source": [ + "test_radial()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "727a821f-0fc9-4ec9-8ba1-9451c3e2bda7", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import tensorflow as tf\n", + "import tensorflow_probability as tfp\n", + "\n", + 
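+    "\n",
+    "# Added note: the Affine base class used for AffineFlow above does not accept\n",
+    "# (shift, scale_diag) in this TFP version, hence the TypeErrors in this\n",
+    "# notebook. A possible replacement sketch, assuming tfp.bijectors.Shift,\n",
+    "# tfp.bijectors.ScaleMatvecDiag and tfp.bijectors.Chain are available:\n",
+    "#\n",
+    "# class AffineFlow(tfp.bijectors.Chain):\n",
+    "#     def __init__(self, t, n_dims, name=\"AffineFlow\"):\n",
+    "#         assert t.shape[-1] == 2 * n_dims\n",
+    "#         super().__init__(\n",
+    "#             [tfp.bijectors.Shift(t[..., 0:n_dims]),\n",
+    "#              tfp.bijectors.ScaleMatvecDiag(1.0 + t[..., n_dims : 2 * n_dims])],\n",
+    "#             name=name,\n",
+    "#         )\n",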
"\n", + "class PlanarFlow(tfp.bijectors.Bijector):\n", + " \"\"\"\n", + " Implements a bijector x = y + u * tanh(w_t * y + b)\n", + "\n", + " Args:\n", + " params: Tensor shape (?, 2*n_dims+1). This will be split into the parameters\n", + " u (?, n_dims), w (?, n_dims), b (?, 1).\n", + " Furthermore u will be constrained to assure the invertability of the flow\n", + " n_dims: The dimension of the distribution that will be transformed\n", + " name: The name to give this particular flow\n", + "\n", + " \"\"\"\n", + "\n", + " _u, _w, _b = None, None, None\n", + "\n", + " def __init__(self, t, n_dims, name=\"Inverted_Planar_Flow\"):\n", + " super().__init__(validate_args=False, name=name, inverse_min_event_ndims=1)\n", + " assert t.shape[-1] == 2 * n_dims + 1\n", + " u, w, b = (\n", + " t[..., 0:n_dims],\n", + " # initialize w to 1.0\n", + " t[..., n_dims : 2 * n_dims] + 1,\n", + " t[..., 2 * n_dims : 2 * n_dims + 1],\n", + " )\n", + "\n", + " # constrain u before assigning it\n", + " self._u = self._u_circ(u, w)\n", + " self._w = w\n", + " self._b = b\n", + "\n", + " @staticmethod\n", + " def get_param_size(n_dims):\n", + " \"\"\"\n", + " :param n_dims: The dimension of the distribution to be transformed by the flow\n", + " :return: (int array) The dimension of the parameter space for this flow, n_dims + n_dims + 1\n", + " \"\"\"\n", + " return n_dims + n_dims + 1\n", + "\n", + " @staticmethod\n", + " def _u_circ(u, w):\n", + " \"\"\"\n", + " To ensure invertibility of the flow, the following condition needs to hold: w_t * u >= -1\n", + " :return: The transformed u\n", + " \"\"\"\n", + " wtu = tf.math.reduce_sum(w * u, 1, keepdims=True)\n", + " # add constant to make it more numerically stable\n", + " m_wtu = -1.0 + tf.nn.softplus(wtu) + 1e-5\n", + " norm_w_squared = tf.math.reduce_sum(w ** 2, 1, keepdims=True) + 1e-9\n", + " return u + (m_wtu - wtu) * (w / norm_w_squared)\n", + "\n", + " def _wzb(self, z):\n", + " \"\"\"\n", + " Computes w_t * z + b\n", + " \"\"\"\n", + " return tf.math.reduce_sum(self._w * z, 1, keepdims=True) + self._b\n", + "\n", + " @staticmethod\n", + " def _der_tanh(z):\n", + " \"\"\"\n", + " Computes the derivative of hyperbolic tangent\n", + " \"\"\"\n", + " return 1.0 - tf.math.tanh(z) ** 2\n", + "\n", + " def _forward(self, z):\n", + " \"\"\"\n", + " Runs a forward pass through the bijector\n", + " \"\"\"\n", + " return z + self._u * tf.math.tanh(self._wzb(z))\n", + "\n", + " def _forward_log_det_jacobian(self, z):\n", + " \"\"\"\n", + " Computes the ln of the absolute determinant of the jacobian\n", + " \"\"\"\n", + " psi = self._der_tanh(self._wzb(z)) * self._w\n", + " det_grad = 1.0 + tf.math.reduce_sum(self._u * psi, 1)\n", + " return tf.math.log(tf.math.abs(det_grad))" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "b3aa80cb-8d33-429a-bf44-17228f745289", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import tensorflow as tf\n", + "import tensorflow_probability as tfp\n", + "\n", + "\n", + "class RadialFlow(tfp.bijectors.Bijector):\n", + " \"\"\"\n", + " Implements a bijector x = y + (alpha * beta * (y - y_0)) / (alpha + abs(y - y_0)).\n", + " Args:\n", + " params: Tensor shape (?, n_dims+2). 
+  {
+   "cell_type": "code",
+   "execution_count": 21,
+   "id": "b3aa80cb-8d33-429a-bf44-17228f745289",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "import tensorflow as tf\n",
+    "import tensorflow_probability as tfp\n",
+    "\n",
+    "\n",
+    "class RadialFlow(tfp.bijectors.Bijector):\n",
+    "    \"\"\"\n",
+    "    Implements a bijector x = y + (alpha * beta * (y - y_0)) / (alpha + abs(y - y_0)).\n",
+    "    Args:\n",
+    "        params: Tensor shape (?, n_dims+2). This will be split into the parameters\n",
+    "            alpha (?, 1), beta (?, 1), gamma (?, n_dims).\n",
+    "            Furthermore alpha will be constrained to ensure the invertibility of the flow\n",
+    "        n_dims: The dimension of the distribution that will be transformed\n",
+    "        name: The name to give this particular flow\n",
+    "    \"\"\"\n",
+    "\n",
+    "    _alpha = None\n",
+    "    _beta = None\n",
+    "    _gamma = None\n",
+    "\n",
+    "    def __init__(self, t, n_dims, name='RadialFlow'):\n",
+    "        super().__init__(validate_args=False, name=name, inverse_min_event_ndims=1)\n",
+    "\n",
+    "        assert t.shape[-1] == n_dims + 2\n",
+    "        alpha = t[..., 0:1]\n",
+    "        beta = t[..., 1:2]\n",
+    "        gamma = t[..., 2 : n_dims + 2]\n",
+    "\n",
+    "        # constraining the parameters before they are assigned to ensure invertibility.\n",
+    "        # slightly shift alpha, softplus(zero centered input - 2) = small\n",
+    "        self._alpha = self._alpha_circ(0.3 * alpha - 2.0)\n",
+    "        # slightly shift beta, softplus(zero centered input + ln(e - 1)) = 0\n",
+    "        self._beta = self._beta_circ(0.1 * beta + tf.math.log(tf.math.expm1(1.0)))\n",
+    "        self._gamma = gamma\n",
+    "        self.n_dims = n_dims\n",
+    "\n",
+    "    @staticmethod\n",
+    "    def get_param_size(n_dims):\n",
+    "        \"\"\"\n",
+    "        :param n_dims: The dimension of the distribution to be transformed by the flow\n",
+    "        :return: (int) The dimension of the parameter space for the flow\n",
+    "        \"\"\"\n",
+    "        return 1 + 1 + n_dims\n",
+    "\n",
+    "    def _r(self, z):\n",
+    "        return tf.math.reduce_sum(tf.abs(z - self._gamma), 1, keepdims=True)\n",
+    "\n",
+    "    def _h(self, r):\n",
+    "        return 1.0 / (self._alpha + r)\n",
+    "\n",
+    "    def _forward(self, z):\n",
+    "        \"\"\"\n",
+    "        Runs a forward pass through the bijector\n",
+    "        \"\"\"\n",
+    "        r = self._r(z)\n",
+    "        h = self._h(r)\n",
+    "        return z + (self._alpha * self._beta * h) * (z - self._gamma)\n",
+    "\n",
+    "    def _forward_log_det_jacobian(self, z):\n",
+    "        \"\"\"\n",
+    "        Computes the ln of the absolute determinant of the jacobian\n",
+    "        \"\"\"\n",
+    "        r = self._r(z)\n",
+    "        with tf.GradientTape() as g:\n",
+    "            g.watch(r)\n",
+    "            h = self._h(r)\n",
+    "        der_h = g.gradient(h, r)\n",
+    "        ab = self._alpha * self._beta\n",
+    "        det = (1.0 + ab * h) ** (self.n_dims - 1) * (1.0 + ab * h + ab * der_h * r)\n",
+    "        det = tf.squeeze(det, axis=-1)\n",
+    "        return tf.math.log(det)\n",
+    "\n",
+    "    @staticmethod\n",
+    "    def _alpha_circ(alpha):\n",
+    "        \"\"\"\n",
+    "        Method for constraining the alpha parameter to meet the invertibility requirements\n",
+    "        \"\"\"\n",
+    "        return tf.nn.softplus(alpha)\n",
+    "\n",
+    "    @staticmethod\n",
+    "    def _beta_circ(beta):\n",
+    "        \"\"\"\n",
+    "        Method for constraining the beta parameter to meet the invertibility requirements\n",
+    "        \"\"\"\n",
+    "        return tf.nn.softplus(beta) - 1.0"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 22,
+   "id": "f74ef9a1-83b6-45ba-ac25-5f0f5ec5665c",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "import numpy as np\n",
+    "import tensorflow as tf\n",
+    "import tensorflow_probability as tfp\n",
+    "\n",
+    "tfd = tfp.distributions\n",
+    "\n",
+    "\n",
+    "class BayesianNNEstimator(BaseEstimator):\n",
+    "    def __init__(\n",
+    "        self,\n",
+    "        dist_layer,\n",
+    "        kl_weight_scale,\n",
+    "        kl_use_exact=True,\n",
+    "        hidden_sizes=(10,),\n",
+    "        activation=\"tanh\",\n",
+    "        learning_rate=3e-2,\n",
+    "        noise_reg=(\"fixed_rate\", 0.0),\n",
+    "        trainable_prior=False,\n",
+    "        map_mode=False,\n",
+    "        prior_scale=1.0,\n",
+    "        random_seed=22,\n",
+    "    ):\n",
+    "        \"\"\"\n",
+    "        A Bayesian net parametrizing a normalizing flow 
distribution\n", + " :param dist_layer: A Tfp Distribution Lambda Layer that converts the neural net output into a distribution\n", + " :param kl_weight_scale: Scales how much KL(posterior|prior) influences the loss\n", + " :param hidden_sizes: size and depth of net\n", + " :param noise_reg: Tuple with (type_of_reg, scale_factor)\n", + " :param trainable_prior: empirical bayes\n", + " :param map_mode: If true, will use the mean of the posterior instead of a sample. Default False\n", + " :param prior_scale: The scale of the zero centered priors\n", + "\n", + " A note on kl_weight_scale: Keras calculates the loss per sample and not for the full dataset. Therefore,\n", + " we need to scale the KL(q||p) loss down to a single sample, which means setting kl_weight_scale = 1/n_datapoints\n", + " \"\"\"\n", + " self.x_noise_std = tf.Variable(initial_value=0.0, dtype=tf.float32, trainable=False)\n", + " self.y_noise_std = tf.Variable(initial_value=0.0, dtype=tf.float32, trainable=False)\n", + " self.map_mode = map_mode\n", + "\n", + " posterior = self._get_posterior_fn(map_mode=map_mode)\n", + " prior = self._get_prior_fn(trainable_prior, prior_scale)\n", + " dense_layers = self._get_dense_layers(\n", + " hidden_sizes=hidden_sizes,\n", + " output_size=dist_layer.get_total_param_size(),\n", + " posterior=posterior,\n", + " prior=prior,\n", + " kl_weight_scale=kl_weight_scale,\n", + " kl_use_exact=kl_use_exact,\n", + " activation=activation,\n", + " )\n", + "\n", + " super().__init__(\n", + " dense_layers + [dist_layer],\n", + " noise_fn_type=noise_reg[0],\n", + " noise_scale_factor=noise_reg[1],\n", + " random_seed=random_seed,\n", + " )\n", + "\n", + " self.compile(\n", + " optimizer=tf.keras.optimizers.Adam(learning_rate), loss=self._get_neg_log_likelihood()\n", + " )\n", + "\n", + " def score(self, x_data, y_data):\n", + " x_data = x_data.astype(np.float32)\n", + " y_data = y_data.astype(np.float32)\n", + "\n", + " scores = None\n", + " nll = self._get_neg_log_likelihood()\n", + " posterior_draws = 1 if self.map_mode else 50\n", + " for _ in range(posterior_draws):\n", + " res = tf.expand_dims(-nll(y_data, self.call(x_data, training=False)), axis=0)\n", + " scores = res if scores is None else tf.concat([scores, res], axis=0)\n", + " logsumexp = tf.math.reduce_logsumexp(scores, axis=0).numpy() - np.log(posterior_draws)\n", + " return logsumexp.mean()\n", + "\n", + " @staticmethod\n", + " def _get_prior_fn(trainable=False, prior_scale=1.0):\n", + " def prior_fn(kernel_size, bias_size=0, dtype=None):\n", + " size = kernel_size + bias_size\n", + " layers = [\n", + " tfp.layers.VariableLayer(\n", + " shape=size, initializer=\"zeros\", dtype=dtype, trainable=trainable\n", + " ),\n", + " MeanFieldLayer(size, scale=prior_scale, map_mode=False, dtype=dtype),\n", + " ]\n", + " return tf.keras.Sequential(layers)\n", + "\n", + " return prior_fn\n", + "\n", + " @staticmethod\n", + " def _get_posterior_fn(map_mode=False):\n", + " def posterior_fn(kernel_size, bias_size=0, dtype=None):\n", + " size = kernel_size + bias_size\n", + " layers = [\n", + " tfp.layers.VariableLayer(\n", + " size if map_mode else 2 * size,\n", + " initializer=\"normal\",\n", + " dtype=dtype,\n", + " trainable=True,\n", + " ),\n", + " MeanFieldLayer(size, scale=None, map_mode=map_mode, dtype=dtype),\n", + " ]\n", + " return tf.keras.Sequential(layers)\n", + "\n", + " return posterior_fn\n", + "\n", + " def _get_dense_layers(\n", + " self,\n", + " hidden_sizes,\n", + " output_size,\n", + " posterior,\n", + " prior,\n", + " 
kl_weight_scale=1.0,\n", + " kl_use_exact=True,\n", + " activation=\"relu\",\n", + " ):\n", + " assert type(hidden_sizes) == tuple or type(hidden_sizes) == list\n", + " assert kl_weight_scale <= 1.0\n", + "\n", + " # these values are assigned once fit is called\n", + " normalization = [tf.keras.layers.Lambda(lambda x: (x - self.x_mean) / (self.x_std + 1e-8))]\n", + " noise_reg = [tf.keras.layers.GaussianNoise(self.x_noise_std)]\n", + " hidden = [\n", + " tfp.layers.DenseVariational(\n", + " units=size,\n", + " make_posterior_fn=posterior,\n", + " make_prior_fn=prior,\n", + " kl_weight=kl_weight_scale,\n", + " kl_use_exact=kl_use_exact,\n", + " activation=activation,\n", + " )\n", + " for size in hidden_sizes\n", + " ]\n", + " output = [\n", + " tfp.layers.DenseVariational(\n", + " units=output_size,\n", + " make_posterior_fn=posterior,\n", + " make_prior_fn=prior,\n", + " kl_weight=kl_weight_scale,\n", + " kl_use_exact=kl_use_exact,\n", + " activation=\"linear\",\n", + " )\n", + " ]\n", + " return normalization + noise_reg + hidden + output" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "881c0b13-71fc-4a27-88a6-590de016cdea", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "class InverseNormalizingFlowLayer(tfp.layers.DistributionLambda):\n", + " _flow_types = None\n", + " _trainable_base_dist = None\n", + " _n_dims = None\n", + "\n", + " def __init__(self, flow_types, n_dims, trainable_base_dist=False):\n", + " \"\"\"\n", + " Subclass of a DistributionLambda. A layer that uses it's input to parametrize a normalizing flow\n", + " that transforms a base normal distribution. The Normalizing flows are inverted to enable fast likelihood\n", + " calculation of externally provided data. This is useful for density estimation.\n", + " As a result, sampling from this layer is not possible.\n", + " This layer does not work for scalars!\n", + " :param flow_types: Types of flows to use, applied in order from base_dist -> transformed_dist\n", + " :param n_dims: dimension of the underlying distribution being transformed\n", + " :param trainable_base_dist: whether the base normal distribution should have trainable loc and scale diag\n", + " \"\"\"\n", + " assert all([flow_type in FLOWS for flow_type in flow_types])\n", + "\n", + " self._flow_types = flow_types\n", + " self._trainable_base_dist = trainable_base_dist\n", + " self._n_dims = n_dims\n", + "\n", + " # as keras transforms tensors, this layer needs to have an tensor-like output\n", + " # therefore a function needs to be provided that transforms a distribution into a tensor\n", + " # per default the .sample() function is used, but our reversed flows cannot perform that operation\n", + " convert_ttfn = lambda d: tf.constant([0.0])\n", + " make_flow_dist = self._get_distribution_fn(n_dims, flow_types, trainable_base_dist)\n", + " super().__init__(make_distribution_fn=make_flow_dist, convert_to_tensor_fn=convert_ttfn)\n", + "\n", + " @staticmethod\n", + " def _get_distribution_fn(n_dims, flow_types, trainable_base_dist):\n", + " return lambda t: tfd.TransformedDistribution(\n", + " distribution=InverseNormalizingFlowLayer._get_base_dist(\n", + " t, n_dims, trainable_base_dist\n", + " ),\n", + " bijector=tfp.bijectors.Invert(\n", + " InverseNormalizingFlowLayer._get_bijector(\n", + " (t[..., 2 * n_dims :] if trainable_base_dist else t), flow_types, n_dims\n", + " )\n", + " ),\n", + " )\n", + "\n", + " def get_total_param_size(self):\n", + " \"\"\"\n", + " :return: The total number of parameters to specify 
this distribution\n", + " \"\"\"\n", + " num_flow_params = sum(\n", + " [FLOWS[flow_type].get_param_size(self._n_dims) for flow_type in self._flow_types]\n", + " )\n", + " base_dist_params = 2 * self._n_dims if self._trainable_base_dist else 0\n", + " return num_flow_params + base_dist_params\n", + "\n", + " @staticmethod\n", + " def _get_bijector(t, flow_types, n_dims):\n", + " # intuitively, we want to flows to go from base_dist -> transformed dist\n", + " flow_types = list(reversed(flow_types))\n", + " param_sizes = [FLOWS[flow_type].get_param_size(n_dims) for flow_type in flow_types]\n", + " assert sum(param_sizes) == t.shape[-1]\n", + " split_beginnings = [sum(param_sizes[0:i]) for i in range(len(param_sizes))]\n", + " chain = [\n", + " FLOWS[flow_type](t[..., begin : begin + size], n_dims)\n", + " for begin, size, flow_type in zip(split_beginnings, param_sizes, flow_types)\n", + " ]\n", + " return tfp.bijectors.Chain(chain)\n", + "\n", + " @staticmethod\n", + " def _get_base_dist(t, n_dims, trainable):\n", + " if trainable:\n", + " return tfd.MultivariateNormalDiag(\n", + " loc=t[..., 0:n_dims],\n", + " scale_diag=1e-3\n", + " + tf.math.softplus(\n", + " tf.math.log(tf.math.expm1(1.0)) + 0.1 * t[..., n_dims : 2 * n_dims]\n", + " ),\n", + " )\n", + " else:\n", + " # we still need to know the batch size, therefore we need t for reference\n", + " return tfd.MultivariateNormalDiag(\n", + " loc=tf.zeros_like(t[..., 0:n_dims]), scale_diag=tf.ones_like(t[..., 0:n_dims])\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "4b62f5ad-f94a-4a4d-955a-76f75adaee76", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import tensorflow as tf\n", + "import tensorflow_probability as tfp\n", + "\n", + "\n", + "tfd = tfp.distributions\n", + "\n", + "\n", + "class BayesNormalizingFlowNetwork(BayesianNNEstimator):\n", + " def __init__(self, n_dims, kl_weight_scale, n_flows=2, trainable_base_dist=True, **kwargs):\n", + " \"\"\"\n", + " A bayesian net parametrizing a normalizing flow distribution\n", + " :param n_dims: The dimension of the output distribution\n", + " :param kl_weight_scale: Scales how much KL(posterior|prior) influences the loss\n", + " :param n_flows: The number of flows to use\n", + " :param hidden_sizes: size and depth of net\n", + " :param trainable_base_dist: whether to train the base normal dist\n", + " :param noise_reg: Tuple with (type_of_reg, scale_factor)\n", + " :param trainable_prior: empirical bayes\n", + " :param map_mode: If true, will use the mean of the posterior instead of a sample. Default False\n", + " :param prior_scale: The scale of the zero centered priors\n", + "\n", + " A note on kl_weight_scale: Keras calculates the loss per sample and not for the full dataset. 
Therefore,\n", + " we need to scale the KL(q||p) loss down to a single sample, which means setting kl_weight_scale = 1/n_datapoints\n", + " \"\"\"\n", + " dist_layer = InverseNormalizingFlowLayer(\n", + " flow_types=[\"radial\"] * n_flows, n_dims=n_dims, trainable_base_dist=trainable_base_dist\n", + " )\n", + " super().__init__(dist_layer, kl_weight_scale, **kwargs)\n", + "\n", + " @staticmethod\n", + " def build_function(\n", + " n_dims,\n", + " kl_weight_scale,\n", + " n_flows=2,\n", + " trainable_base_dist=True,\n", + " kl_use_exact=True,\n", + " hidden_sizes=(10,),\n", + " activation=\"tanh\",\n", + " noise_reg=(\"fixed_rate\", 0.0),\n", + " learning_rate=2e-2,\n", + " trainable_prior=False,\n", + " map_mode=False,\n", + " prior_scale=1.0,\n", + " ):\n", + " # this is necessary, else there'll be processes hanging around hogging memory\n", + " tf.keras.backend.clear_session()\n", + " return BayesNormalizingFlowNetwork(\n", + " n_dims=n_dims,\n", + " kl_weight_scale=kl_weight_scale,\n", + " n_flows=n_flows,\n", + " trainable_base_dist=trainable_base_dist,\n", + " kl_use_exact=kl_use_exact,\n", + " hidden_sizes=hidden_sizes,\n", + " activation=activation,\n", + " noise_reg=noise_reg,\n", + " learning_rate=learning_rate,\n", + " trainable_prior=trainable_prior,\n", + " map_mode=map_mode,\n", + " prior_scale=prior_scale,\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "65992896-1ca4-4c63-8775-7485fe274321", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "class MaximumLikelihoodNNEstimator(BaseEstimator):\n", + " def __init__(\n", + " self,\n", + " dist_layer,\n", + " hidden_sizes=(16, 16),\n", + " noise_reg=(\"fixed_rate\", 0.0),\n", + " learning_rate=3e-3,\n", + " activation=\"relu\",\n", + " random_seed=22,\n", + " ):\n", + " assert len(noise_reg) == 2\n", + "\n", + " dense_layers = self._get_dense_layers(\n", + " hidden_sizes=hidden_sizes,\n", + " output_size=dist_layer.get_total_param_size(),\n", + " activation=activation,\n", + " )\n", + "\n", + " super().__init__(\n", + " dense_layers + [dist_layer],\n", + " noise_fn_type=noise_reg[0],\n", + " noise_scale_factor=noise_reg[1],\n", + " random_seed=random_seed,\n", + " )\n", + "\n", + " self.compile(\n", + " optimizer=tf.keras.optimizers.Adam(learning_rate), loss=self._get_neg_log_likelihood()\n", + " )\n", + "\n", + " def _get_dense_layers(self, hidden_sizes, output_size, activation):\n", + " assert type(hidden_sizes) == tuple or type(hidden_sizes) == list\n", + " # the data normalization values are assigned once fit is called\n", + " normalization = [tf.keras.layers.Lambda(lambda x: (x - self.x_mean) / (self.x_std + 1e-8))]\n", + " noise_reg = [tf.keras.layers.GaussianNoise(self.x_noise_std)]\n", + " hidden = [tf.keras.layers.Dense(size, activation=activation) for size in hidden_sizes]\n", + " output = [tf.keras.layers.Dense(output_size, activation=\"linear\")]\n", + " return normalization + noise_reg + hidden + output" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "a97995fc-130a-4b21-a8dd-98b39f940647", + "metadata": {}, + "outputs": [], + "source": [ + "import tensorflow as tf\n", + "import tensorflow_probability as tfp\n", + "\n", + "tfd = tfp.distributions\n", + "\n", + "\n", + "class NormalizingFlowNetwork(MaximumLikelihoodNNEstimator):\n", + " def __init__(self, n_dims, n_flows=10, trainable_base_dist=True, **kwargs):\n", + " \"\"\"\n", + " :param n_dims: Dimensionsion of Y. 
The dimension of X is automatically inferred from the data\n", + " :param n_flows: The number of radial flows to use.\n", + " :param trainable_base_dist: Whether the standard normal base dist has trainable mean + diag. convariance\n", + " \"\"\"\n", + " dist_layer = InverseNormalizingFlowLayer(\n", + " flow_types=[\"radial\"] * n_flows, n_dims=n_dims, trainable_base_dist=trainable_base_dist\n", + " )\n", + " super().__init__(dist_layer, **kwargs)\n", + "\n", + " @staticmethod\n", + " def build_function(\n", + " n_dims=1,\n", + " n_flows=3,\n", + " hidden_sizes=(16, 16),\n", + " trainable_base_dist=True,\n", + " noise_reg=(\"fixed_rate\", 0.0),\n", + " learning_rate=3e-3,\n", + " activation=\"tanh\",\n", + " ):\n", + " # this is necessary, else there'll be processes hanging around hogging memory\n", + " tf.keras.backend.clear_session()\n", + " return NormalizingFlowNetwork(\n", + " n_dims=n_dims,\n", + " n_flows=n_flows,\n", + " hidden_sizes=hidden_sizes,\n", + " trainable_base_dist=trainable_base_dist,\n", + " noise_reg=noise_reg,\n", + " learning_rate=learning_rate,\n", + " activation=activation,\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a024be78-1775-48f0-a82b-c35059acab44", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e94580e6-639e-4e81-8e00-5d05fdf42067", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c5d66b5d-ffc9-42e6-82e3-909245a0907b", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 36, + "id": "640b9386-fbe4-488a-a15c-7cf1596d369f", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import tensorflow_probability as tfp\n", + "tfd = tfp.distributions\n", + "import pytest\n", + "\n", + "\n", + "def flow_dimension_testing(flow_name):\n", + " FLOWS = {\"planar\": PlanarFlow, \"radial\": RadialFlow, \"affine\": AffineFlow}\n", + " # all are tested against Affine Flow, since that's the reference implementation\n", + " batch_size = 10\n", + " for dim in [1, 4]:\n", + " flow_class = FLOWS[flow_name]\n", + " # test dimension of parameter space\n", + " with pytest.raises(AssertionError):\n", + " flow = flow_class(tf.ones((batch_size, flow_class.get_param_size(dim) + 1)), dim)\n", + "\n", + " flow = flow_class(tf.ones((batch_size, flow_class.get_param_size(dim))), dim)\n", + " reference = AffineFlow(tf.ones((batch_size, AffineFlow.get_param_size(dim))), dim)\n", + "\n", + " test_tensors = [[[0.0] * dim], [[1.0] * dim] * batch_size]\n", + " assert flow.forward_min_event_ndims == reference.forward_min_event_ndims\n", + " for tensor in test_tensors:\n", + " assert flow.forward(tensor).shape == reference.forward(tensor).shape\n", + " assert (\n", + " flow._forward_log_det_jacobian(tensor).shape\n", + " == reference._forward_log_det_jacobian(tensor).shape\n", + " )\n", + "\n", + " tensor = [[1.0] * dim] + ([[0.0] * dim] * (batch_size - 2)) + [[1.0] * dim]\n", + " res = flow.forward(tensor).numpy()\n", + " assert res[0] == pytest.approx(res[-1], rel=1e-5)\n", + " assert res[1] == pytest.approx(res[-2], rel=1e-5)\n", + " assert not all(res[0] == res[1])\n", + "\n", + " tensor = [[1.0] * dim] + ([[0.0] * dim] * (batch_size - 2)) + [[1.0] * dim]\n", + " res = flow._forward_log_det_jacobian(tensor).numpy()\n", + " assert res[0] == pytest.approx(res[-1], rel=1e-5)\n", + " assert res[1] == pytest.approx(res[-2], rel=1e-5)\n", + " assert not res[0] == 
pytest.approx(res[1])\n", + "\n", + "\n", + "def test_planar():\n", + " flow_dimension_testing(\"planar\")\n", + "\n", + "\n", + "def test_radial():\n", + " flow_dimension_testing(\"radial\")" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "id": "7d08e49a-50a5-4f01-8af9-c61d07a7630a", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "ename": "TypeError", + "evalue": "Affine.__init__() got an unexpected keyword argument 'shift'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[37], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mtest_radial\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n", + "Cell \u001b[0;32mIn[36], line 46\u001b[0m, in \u001b[0;36mtest_radial\u001b[0;34m()\u001b[0m\n\u001b[1;32m 45\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mtest_radial\u001b[39m():\n\u001b[0;32m---> 46\u001b[0m \u001b[43mflow_dimension_testing\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mradial\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n", + "Cell \u001b[0;32mIn[36], line 17\u001b[0m, in \u001b[0;36mflow_dimension_testing\u001b[0;34m(flow_name)\u001b[0m\n\u001b[1;32m 14\u001b[0m flow \u001b[38;5;241m=\u001b[39m flow_class(tf\u001b[38;5;241m.\u001b[39mones((batch_size, flow_class\u001b[38;5;241m.\u001b[39mget_param_size(dim) \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m1\u001b[39m)), dim)\n\u001b[1;32m 16\u001b[0m flow \u001b[38;5;241m=\u001b[39m flow_class(tf\u001b[38;5;241m.\u001b[39mones((batch_size, flow_class\u001b[38;5;241m.\u001b[39mget_param_size(dim))), dim)\n\u001b[0;32m---> 17\u001b[0m reference \u001b[38;5;241m=\u001b[39m \u001b[43mAffineFlow\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtf\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mones\u001b[49m\u001b[43m(\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mAffineFlow\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_param_size\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdim\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdim\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 19\u001b[0m test_tensors \u001b[38;5;241m=\u001b[39m [[[\u001b[38;5;241m0.0\u001b[39m] \u001b[38;5;241m*\u001b[39m dim], [[\u001b[38;5;241m1.0\u001b[39m] \u001b[38;5;241m*\u001b[39m dim] \u001b[38;5;241m*\u001b[39m batch_size]\n\u001b[1;32m 20\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m flow\u001b[38;5;241m.\u001b[39mforward_min_event_ndims \u001b[38;5;241m==\u001b[39m reference\u001b[38;5;241m.\u001b[39mforward_min_event_ndims\n", + "Cell \u001b[0;32mIn[9], line 7\u001b[0m, in \u001b[0;36mAffineFlow.__init__\u001b[0;34m(self, t, n_dims, name)\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__init__\u001b[39m(\u001b[38;5;28mself\u001b[39m, t, n_dims, name\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mAffineFlow\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m t\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m] \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m2\u001b[39m \u001b[38;5;241m*\u001b[39m n_dims\n\u001b[0;32m----> 7\u001b[0m 
\u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mAffineFlow\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__init__\u001b[39;49m\u001b[43m(\u001b[49m\n\u001b[1;32m 8\u001b[0m \u001b[43m \u001b[49m\u001b[43mshift\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mt\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43mn_dims\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mscale_diag\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m1.0\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mt\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mn_dims\u001b[49m\u001b[43m \u001b[49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mn_dims\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mname\u001b[49m\n\u001b[1;32m 9\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[0;31mTypeError\u001b[0m: Affine.__init__() got an unexpected keyword argument 'shift'" + ] + } + ], + "source": [ + "test_radial()" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "97946900-bc4f-4bca-8faf-e7cfbf238d61", + "metadata": { + "collapsed": true, + "jupyter": { + "outputs_hidden": true + }, + "tags": [] + }, + "outputs": [ + { + "ename": "ValueError", + "evalue": "name for name_scope must be a string.", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[32], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mflow_dimension_testing\u001b[49m\u001b[43m(\u001b[49m\u001b[43mPlanarFlow\u001b[49m\u001b[43m(\u001b[49m\u001b[43mt\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mt\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mn_dims\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mndims\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n", + "Cell \u001b[0;32mIn[27], line 12\u001b[0m, in \u001b[0;36mflow_dimension_testing\u001b[0;34m(flow_class)\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m dim \u001b[38;5;129;01min\u001b[39;00m [\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m4\u001b[39m]:\n\u001b[1;32m 10\u001b[0m \u001b[38;5;66;03m# test dimension of parameter space\u001b[39;00m\n\u001b[1;32m 11\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m pytest\u001b[38;5;241m.\u001b[39mraises(\u001b[38;5;167;01mAssertionError\u001b[39;00m):\n\u001b[0;32m---> 12\u001b[0m flow \u001b[38;5;241m=\u001b[39m \u001b[43mflow_class\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtf\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mones\u001b[49m\u001b[43m(\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[43mflow_class\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_param_size\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdim\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdim\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 14\u001b[0m flow \u001b[38;5;241m=\u001b[39m flow_class(tf\u001b[38;5;241m.\u001b[39mones((batch_size, flow_class\u001b[38;5;241m.\u001b[39mget_param_size(dim))), dim)\n\u001b[1;32m 15\u001b[0m reference \u001b[38;5;241m=\u001b[39m AffineFlow(tf\u001b[38;5;241m.\u001b[39mones((batch_size, AffineFlow\u001b[38;5;241m.\u001b[39mget_param_size(dim))), dim)\n", + "File \u001b[0;32m/data/astro/scratch/lcabayol/anaconda3/envs/NFlow/lib/python3.10/site-packages/tensorflow_probability/python/bijectors/bijector.py:913\u001b[0m, in \u001b[0;36mBijector.__call__\u001b[0;34m(self, value, name, **kwargs)\u001b[0m\n\u001b[1;32m 910\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(value, Bijector):\n\u001b[1;32m 911\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m chain\u001b[38;5;241m.\u001b[39mChain([\u001b[38;5;28mself\u001b[39m, value], name\u001b[38;5;241m=\u001b[39mname, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m--> 913\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mforward\u001b[49m\u001b[43m(\u001b[49m\u001b[43mvalue\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mname\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mforward\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/data/astro/scratch/lcabayol/anaconda3/envs/NFlow/lib/python3.10/site-packages/tensorflow_probability/python/bijectors/bijector.py:1326\u001b[0m, in \u001b[0;36mBijector.forward\u001b[0;34m(self, x, name, **kwargs)\u001b[0m\n\u001b[1;32m 1310\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, x, name\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mforward\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 1311\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Returns the forward `Bijector` evaluation, i.e., X = g(Y).\u001b[39;00m\n\u001b[1;32m 1312\u001b[0m \n\u001b[1;32m 1313\u001b[0m \u001b[38;5;124;03m Args:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1324\u001b[0m \u001b[38;5;124;03m NotImplementedError: if `_forward` is not implemented.\u001b[39;00m\n\u001b[1;32m 1325\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m-> 1326\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File 
\u001b[0;32m/data/astro/scratch/lcabayol/anaconda3/envs/NFlow/lib/python3.10/site-packages/tensorflow_probability/python/bijectors/bijector.py:1300\u001b[0m, in \u001b[0;36mBijector._call_forward\u001b[0;34m(self, x, name, **kwargs)\u001b[0m\n\u001b[1;32m 1298\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_call_forward\u001b[39m(\u001b[38;5;28mself\u001b[39m, x, name, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 1299\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Wraps call to _forward, allowing extra shared logic.\"\"\"\u001b[39;00m\n\u001b[0;32m-> 1300\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_name_and_control_scope(name):\n\u001b[1;32m 1301\u001b[0m dtype \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39minverse_dtype(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 1302\u001b[0m x \u001b[38;5;241m=\u001b[39m nest_util\u001b[38;5;241m.\u001b[39mconvert_to_nested_tensor(\n\u001b[1;32m 1303\u001b[0m x, name\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mx\u001b[39m\u001b[38;5;124m'\u001b[39m, dtype_hint\u001b[38;5;241m=\u001b[39mdtype,\n\u001b[1;32m 1304\u001b[0m dtype\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01mif\u001b[39;00m SKIP_DTYPE_CHECKS \u001b[38;5;28;01melse\u001b[39;00m dtype,\n\u001b[1;32m 1305\u001b[0m allow_packing\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n", + "File \u001b[0;32m/data/astro/scratch/lcabayol/anaconda3/envs/NFlow/lib/python3.10/contextlib.py:135\u001b[0m, in \u001b[0;36m_GeneratorContextManager.__enter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 133\u001b[0m \u001b[38;5;28;01mdel\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mkwds, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfunc\n\u001b[1;32m 134\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 135\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mnext\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgen\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 136\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mStopIteration\u001b[39;00m:\n\u001b[1;32m 137\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mgenerator didn\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt yield\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n", + "File \u001b[0;32m/data/astro/scratch/lcabayol/anaconda3/envs/NFlow/lib/python3.10/site-packages/tensorflow_probability/python/bijectors/bijector.py:1816\u001b[0m, in \u001b[0;36mBijector._name_and_control_scope\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m 1812\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Helper function to standardize op scope.\"\"\"\u001b[39;00m\n\u001b[1;32m 1813\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m name_util\u001b[38;5;241m.\u001b[39minstance_scope(\n\u001b[1;32m 1814\u001b[0m instance_name\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mname,\n\u001b[1;32m 1815\u001b[0m constructor_name_scope\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_constructor_name_scope):\n\u001b[0;32m-> 
1816\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[43mtf\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mname_scope\u001b[49m\u001b[43m(\u001b[49m\u001b[43mname\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mas\u001b[39;00m name_scope:\n\u001b[1;32m 1817\u001b[0m deps \u001b[38;5;241m=\u001b[39m []\n\u001b[1;32m 1818\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_defer_all_assertions:\n", + "File \u001b[0;32m/data/astro/scratch/lcabayol/anaconda3/envs/NFlow/lib/python3.10/site-packages/tensorflow/python/framework/ops.py:6423\u001b[0m, in \u001b[0;36mname_scope_v2.__init__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m 6414\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Initialize the context manager.\u001b[39;00m\n\u001b[1;32m 6415\u001b[0m \n\u001b[1;32m 6416\u001b[0m \u001b[38;5;124;03mArgs:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 6420\u001b[0m \u001b[38;5;124;03m ValueError: If name is not a string.\u001b[39;00m\n\u001b[1;32m 6421\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 6422\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(name, \u001b[38;5;28mstr\u001b[39m):\n\u001b[0;32m-> 6423\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mname for name_scope must be a string.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 6424\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_name \u001b[38;5;241m=\u001b[39m name\n\u001b[1;32m 6425\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_exit_fns \u001b[38;5;241m=\u001b[39m []\n", + "\u001b[0;31mValueError\u001b[0m: name for name_scope must be a string." ] } ], "source": [ "# flow_dimension_testing here expects the flow class itself, not an instance:\n", + "# calling the PlanarFlow instance invokes the bijector's __call__, which routes dim\n", + "# into tfp's name argument and raises the ValueError above\n", + "flow_dimension_testing(PlanarFlow(t=t, n_dims=ndims))" ] }, { "cell_type": "code", "execution_count": 31, "id": "c58f187a-4d63-45e9-bd80-8349e94edd63", "metadata": { "tags": [] }, "outputs": [], "source": [ "# 5 parameters = 2*ndims + 1 for the planar flow: u (ndims), w (ndims), b (1)\n", + "t = tf.random.normal(shape=(1, 5))\n", + "ndims = 2" ] }, { "cell_type": "code", "execution_count": null, "id": "41869a9e-26ea-4729-a19f-3463eca5f4a8", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "fcc6c787-4142-4992-8a39-4f358aa43bd6", "metadata": {}, "outputs": [], "source": [ " \"\"\"\n", + " Implements a bijector x = y + (alpha * beta * (y - y_0)) / (alpha + abs(y - y_0)).\n", + " Args:\n", + " params: Tensor of shape (?, n_dims+2). This will be split into the parameters\n", + " alpha (?, 1), beta (?, 1), gamma (?, n_dims).\n", + " Furthermore, alpha will be constrained to ensure the invertibility of the flow.\n", + " n_dims: The dimension of the distribution that will be transformed\n", + " name: The name to give this particular flow\n", + " \"\"\"" ] },
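{ "cell_type": "code", "execution_count": null, "id": "radial-flow-bijector-sketch", "metadata": { "tags": [] }, "outputs": [], "source": [ "# A minimal sketch of the radial-flow bijector described by the docstring above,\n", + "# x = y + (alpha * beta * (y - y_0)) / (alpha + ||y - y_0||), written against\n", + "# tfp.bijectors.Bijector; the class name, the softplus constraint on alpha, and\n", + "# get_param_size are illustrative assumptions, not the reference implementation\n", + "class RadialFlowSketch(tfp.bijectors.Bijector):\n", + "    def __init__(self, t, n_dims, name=\"RadialFlowSketch\"):\n", + "        # t has shape (?, n_dims + 2): alpha (?, 1), beta (?, 1), gamma (?, n_dims)\n", + "        assert t.shape[-1] == n_dims + 2\n", + "        super().__init__(forward_min_event_ndims=1, name=name)\n", + "        self._alpha = tf.math.softplus(t[..., 0:1])  # alpha > 0, per the invertibility note above\n", + "        self._beta = t[..., 1:2]\n", + "        self._gamma = t[..., 2:]  # gamma plays the role of y_0\n", + "\n", + "    @staticmethod\n", + "    def get_param_size(n_dims):\n", + "        return n_dims + 2  # alpha (1) + beta (1) + gamma (n_dims)\n", + "\n", + "    def _forward(self, y):\n", + "        diff = y - self._gamma\n", + "        r = tf.norm(diff, axis=-1, keepdims=True)  # r = ||y - y_0||\n", + "        return y + self._alpha * self._beta * diff / (self._alpha + r)\n", + "\n", + "    # _inverse and _forward_log_det_jacobian are omitted in this sketch\n", + "\n", + "# e.g. RadialFlowSketch(tf.ones((10, RadialFlowSketch.get_param_size(2))), 2).forward([[0.0, 0.0]])" ] }, { "cell_type": "code", "execution_count": null, "id": "b215fb29-b609-4def-afa5-a610b188fadf", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "NFlows", "language": "python", "name": "nflows" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.9" } }, "nbformat": 4, "nbformat_minor": 5 }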