diff --git "a/notebooks/toy_test.ipynb" "b/notebooks/toy_test.ipynb" new file mode 100644--- /dev/null +++ "b/notebooks/toy_test.ipynb" @@ -0,0 +1,2354 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "b89d9358-ef03-4b85-82c3-7eb6366cc7cc", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "import torch.distributions as D\n", + "\n", + "from torch.utils.data import DataLoader, dataset, TensorDataset" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "9e29e664-80af-4102-a465-77b73e84f867", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import sys\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "sys.path.append('../insight')\n", + "from archive import archive \n", + "\n", + "import torch.nn.functional as F" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "ac8eedcf-ec15-4b95-8977-c3affc94f4de", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import torch.nn.functional as F\n", + "\n", + "class RandomToNormalNN(nn.Module):\n", + " def __init__(self, input_dim, hidden_dim, output_dim, num_hidden_layers):\n", + " super(RandomToNormalNN, self).__init__()\n", + " self.input_layer = nn.Linear(input_dim, hidden_dim)\n", + " self.hidden_layers = nn.ModuleList([\n", + " nn.Linear(hidden_dim, hidden_dim) for _ in range(num_hidden_layers)\n", + " ])\n", + " self.output_layer = nn.Linear(hidden_dim, output_dim)\n", + " self.activation = nn.ReLU() # You can experiment with other activation functions\n", + " \n", + " def forward(self, x):\n", + " x = self.activation(self.input_layer(x))\n", + " for hidden_layer in self.hidden_layers:\n", + " x = self.activation(hidden_layer(x))\n", + " x = self.output_layer(x)\n", + " return x\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "id": "617194de-d0d0-4d44-86ef-75185d1b5c0f", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "input_dim = 1\n", + "hidden_dim = 128\n", + "output_dim = 1\n", + "num_hidden_layers = 3\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "\n", + "\n", + "model = RandomToNormalNN(input_dim, hidden_dim, output_dim, num_hidden_layers).to(device)\n", + "\n", + "nepochs=100\n", + "bs = 1000\n", + "z_dim=1\n", + "epsilon = 0.5\n", + "\n", + "# Create an instance of the neural network\n", + "base_distribution = D.Normal(torch.zeros(z_dim), torch.ones(z_dim))\n", + "\n", + "optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "id": "706bf267-c870-4fe6-a1de-758aac722678", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "random_input = torch.rand(10000, input_dim)\n", + "dset = TensorDataset(random_input)\n", + "loader = DataLoader(dset, batch_size=100, shuffle=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "id": "825e88d5-c4b2-493c-8597-3c5859b8bf8a", + "metadata": { + "collapsed": true, + "jupyter": { + "outputs_hidden": true + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch 0 0.918938398361206\n", + "epoch 1 0.918938398361206\n", + "epoch 2 0.918938398361206\n", + "epoch 3 0.918938398361206\n", + "epoch 4 0.918938398361206\n", + "epoch 5 0.918938398361206\n", + "epoch 6 0.918938398361206\n", + "epoch 7 0.918938398361206\n", + "epoch 8 0.918938398361206\n", + "epoch 9 0.918938398361206\n", + "epoch 10 0.918938398361206\n", + "epoch 11 0.918938398361206\n", + "epoch 12 0.918938398361206\n", + "epoch 13 0.918938398361206\n", + "epoch 14 0.918938398361206\n" + ] + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[42], line 13\u001b[0m\n\u001b[1;32m 11\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m-\u001b[39mloss\u001b[38;5;241m.\u001b[39mmean()\n\u001b[1;32m 12\u001b[0m loss\u001b[38;5;241m.\u001b[39mbackward()\n\u001b[0;32m---> 13\u001b[0m \u001b[43moptimizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m \n\u001b[1;32m 15\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mepoch \u001b[39m\u001b[38;5;124m'\u001b[39m,e, loss\u001b[38;5;241m.\u001b[39mitem())\n", + "File \u001b[0;32m/data/astro/scratch/lcabayol/anaconda3/envs/DLenv2/lib/python3.9/site-packages/torch/optim/optimizer.py:88\u001b[0m, in \u001b[0;36mOptimizer._hook_for_profile..profile_hook_step..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 86\u001b[0m profile_name \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mOptimizer.step#\u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m.step\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mformat(obj\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m)\n\u001b[1;32m 87\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mautograd\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mrecord_function(profile_name):\n\u001b[0;32m---> 88\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/data/astro/scratch/lcabayol/anaconda3/envs/DLenv2/lib/python3.9/site-packages/torch/autograd/grad_mode.py:27\u001b[0m, in \u001b[0;36m_DecoratorContextManager.__call__..decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 25\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 26\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mclone():\n\u001b[0;32m---> 27\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/data/astro/scratch/lcabayol/anaconda3/envs/DLenv2/lib/python3.9/site-packages/torch/optim/adam.py:141\u001b[0m, in \u001b[0;36mAdam.step\u001b[0;34m(self, closure)\u001b[0m\n\u001b[1;32m 138\u001b[0m \u001b[38;5;66;03m# record the step after step update\u001b[39;00m\n\u001b[1;32m 139\u001b[0m state_steps\u001b[38;5;241m.\u001b[39mappend(state[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mstep\u001b[39m\u001b[38;5;124m'\u001b[39m])\n\u001b[0;32m--> 141\u001b[0m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madam\u001b[49m\u001b[43m(\u001b[49m\u001b[43mparams_with_grad\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 142\u001b[0m \u001b[43m \u001b[49m\u001b[43mgrads\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 143\u001b[0m \u001b[43m \u001b[49m\u001b[43mexp_avgs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 144\u001b[0m \u001b[43m \u001b[49m\u001b[43mexp_avg_sqs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 145\u001b[0m \u001b[43m \u001b[49m\u001b[43mmax_exp_avg_sqs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 146\u001b[0m \u001b[43m \u001b[49m\u001b[43mstate_steps\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 147\u001b[0m \u001b[43m \u001b[49m\u001b[43mamsgrad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgroup\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mamsgrad\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 148\u001b[0m \u001b[43m \u001b[49m\u001b[43mbeta1\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbeta1\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 149\u001b[0m \u001b[43m \u001b[49m\u001b[43mbeta2\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbeta2\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 150\u001b[0m \u001b[43m \u001b[49m\u001b[43mlr\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgroup\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mlr\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 151\u001b[0m \u001b[43m \u001b[49m\u001b[43mweight_decay\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgroup\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mweight_decay\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 152\u001b[0m \u001b[43m \u001b[49m\u001b[43meps\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgroup\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43meps\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 153\u001b[0m \u001b[43m \u001b[49m\u001b[43mmaximize\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgroup\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mmaximize\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 154\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m loss\n", + "File \u001b[0;32m/data/astro/scratch/lcabayol/anaconda3/envs/DLenv2/lib/python3.9/site-packages/torch/optim/_functional.py:94\u001b[0m, in \u001b[0;36madam\u001b[0;34m(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, amsgrad, beta1, beta2, lr, weight_decay, eps, maximize)\u001b[0m\n\u001b[1;32m 91\u001b[0m bias_correction2 \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m \u001b[38;5;241m-\u001b[39m beta2 \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39m step\n\u001b[1;32m 93\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m weight_decay \u001b[38;5;241m!=\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[0;32m---> 94\u001b[0m grad \u001b[38;5;241m=\u001b[39m \u001b[43mgrad\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madd\u001b[49m\u001b[43m(\u001b[49m\u001b[43mparam\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43malpha\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mweight_decay\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 96\u001b[0m \u001b[38;5;66;03m# Decay the first and second moment running average coefficient\u001b[39;00m\n\u001b[1;32m 97\u001b[0m exp_avg\u001b[38;5;241m.\u001b[39mmul_(beta1)\u001b[38;5;241m.\u001b[39madd_(grad, alpha\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m-\u001b[39m beta1)\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "base_distribution = D.Normal(torch.zeros(z_dim), torch.ones(z_dim))\n", + "\n", + "for e in range(nepochs):\n", + " for x in loader:\n", + " \n", + " optimizer.zero_grad()\n", + " \n", + " output = model(x[0].unsqueeze(1).to(device))\n", + " \n", + " loss = base_distribution.log_prob(output.cpu()).to(device)\n", + " loss = -loss.mean()\n", + " loss.backward()\n", + " optimizer.step() \n", + " \n", + " print('epoch ',e, loss.item())\n", + "\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 79, + "id": "40b99ce9-7bb8-4f8c-b218-339ddafd0747", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(0.9507, grad_fn=)" + ] + }, + "execution_count": 79, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "loss" + ] + }, + { + "cell_type": "code", + "execution_count": 80, + "id": "57deecda-8bd5-41d1-a176-359f130df5b7", + "metadata": { + "collapsed": true, + "jupyter": { + "outputs_hidden": true + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[0.9970],\n", + " [0.9916],\n", + " [1.0080],\n", + " [1.0096],\n", + " [0.9892],\n", + " [1.0066],\n", + " [0.9949],\n", + " [1.0210],\n", + " [1.0041],\n", + " [0.9981],\n", + " [1.0033],\n", + " [0.9885],\n", + " [1.0145],\n", + " [1.0075],\n", + " [1.0223],\n", + " [0.9913],\n", + " [1.0128],\n", + " [0.9889],\n", + " [1.0095],\n", + " [0.9927],\n", + " [0.9925],\n", + " [1.0221],\n", + " [1.0108],\n", + " [1.0112],\n", + " [1.0038],\n", + " [0.9891],\n", + " [0.9887],\n", + " [0.9949],\n", + " [0.9885],\n", + " [1.0030],\n", + " [1.0090],\n", + " [1.0156],\n", + " [0.9911],\n", + " [1.0166],\n", + " [0.9888],\n", + " [1.0066],\n", + " [1.0073],\n", + " [1.0067],\n", + " [0.9943],\n", + " [1.0096],\n", + " [0.9915],\n", + " [1.0128],\n", + " [0.9921],\n", + " [1.0151],\n", + " [1.0113],\n", + " [1.0126],\n", + " [1.0006],\n", + " [1.0071],\n", + " [1.0146],\n", + " [0.9968],\n", + " [1.0029],\n", + " [1.0010],\n", + " [1.0244],\n", + " [0.9951],\n", + " [0.9949],\n", + " [1.0102],\n", + " [0.9993],\n", + " [1.0032],\n", + " [0.9935],\n", + " [0.9942],\n", + " [1.0088],\n", + " [1.0008],\n", + " [1.0117],\n", + " [0.9902],\n", + " [1.0163],\n", + " [0.9884],\n", + " [1.0151],\n", + " [0.9895],\n", + " [0.9914],\n", + " [0.9977],\n", + " [0.9885],\n", + " [0.9928],\n", + " [0.9920],\n", + " [1.0260],\n", + " [1.0192],\n", + " [1.0039],\n", + " [0.9887],\n", + " [1.0172],\n", + " [1.0190],\n", + " [1.0116],\n", + " [1.0196],\n", + " [1.0089],\n", + " [0.9887],\n", + " [0.9902],\n", + " [0.9925],\n", + " [1.0005],\n", + " [0.9924],\n", + " [1.0215],\n", + " [1.0221],\n", + " [0.9885],\n", + " [1.0191],\n", + " [1.0151],\n", + " [1.0299],\n", + " [0.9932],\n", + " [0.9931],\n", + " [0.9895],\n", + " [1.0130],\n", + " [1.0238],\n", + " [1.0094],\n", + " [1.0108],\n", + " [0.9907],\n", + " [0.9910],\n", + " [0.9937],\n", + " [0.9900],\n", + " [0.9992],\n", + " [1.0072],\n", + " [0.9916],\n", + " [1.0194],\n", + " [1.0025],\n", + " [1.0133],\n", + " [1.0039],\n", + " [1.0015],\n", + " [1.0068],\n", + " [1.0010],\n", + " [1.0045],\n", + " [0.9930],\n", + " [1.0242],\n", + " [1.0062],\n", + " [1.0292],\n", + " [0.9945],\n", + " [0.9887],\n", + " [1.0163],\n", + " [1.0137],\n", + " [1.0147],\n", + " [1.0181],\n", + " [1.0111],\n", + " [1.0234],\n", + " [1.0001],\n", + " [1.0123],\n", + " [1.0025],\n", + " [1.0149],\n", + " [1.0201],\n", + " [1.0006],\n", + " [1.0049],\n", + " [0.9900],\n", + " [1.0222],\n", + " [1.0141],\n", + " [1.0128],\n", + " [1.0269],\n", + " [0.9897],\n", + " [1.0142],\n", + " [1.0018],\n", + " [1.0134],\n", + " [1.0003],\n", + " [0.9976],\n", + " [1.0122],\n", + " [1.0133],\n", + " [1.0175],\n", + " [1.0090],\n", + " [0.9900],\n", + " [0.9914],\n", + " [0.9954],\n", + " [1.0111],\n", + " [0.9926],\n", + " [1.0107],\n", + " [1.0039],\n", + " [1.0162],\n", + " [1.0009],\n", + " [0.9893],\n", + " [0.9886],\n", + " [1.0123],\n", + " [0.9892],\n", + " [1.0123],\n", + " [0.9942],\n", + " [1.0280],\n", + " [1.0013],\n", + " [1.0162],\n", + " [1.0118],\n", + " [0.9911],\n", + " [1.0116],\n", + " [1.0190],\n", + " [1.0093],\n", + " [1.0160],\n", + " [1.0304],\n", + " [0.9998],\n", + " [1.0076],\n", + " [0.9934],\n", + " [0.9984],\n", + " [1.0266],\n", + " [1.0161],\n", + " [1.0011],\n", + " [0.9920],\n", + " [0.9930],\n", + " [0.9954],\n", + " [1.0053],\n", + " [0.9911],\n", + " [1.0089],\n", + " [1.0089],\n", + " [1.0047],\n", + " [1.0155],\n", + " [1.0171],\n", + " [1.0071],\n", + " [1.0279],\n", + " [1.0241],\n", + " [1.0112],\n", + " [1.0052],\n", + " [0.9922],\n", + " [0.9921],\n", + " [0.9902],\n", + " [0.9894],\n", + " [1.0130],\n", + " [1.0140],\n", + " [1.0005],\n", + " [1.0161],\n", + " [1.0242],\n", + " [0.9938],\n", + " [1.0036],\n", + " [0.9885],\n", + " [1.0012],\n", + " [0.9917],\n", + " [0.9964],\n", + " [1.0080],\n", + " [0.9927],\n", + " [1.0123],\n", + " [0.9885],\n", + " [0.9915],\n", + " [0.9894],\n", + " [1.0004],\n", + " [1.0148],\n", + " [0.9885],\n", + " [0.9894],\n", + " [1.0306],\n", + " [0.9989],\n", + " [1.0092],\n", + " [1.0174],\n", + " [1.0064],\n", + " [1.0056],\n", + " [0.9907],\n", + " [1.0120],\n", + " [1.0086],\n", + " [0.9919],\n", + " [0.9958],\n", + " [0.9935],\n", + " [0.9890],\n", + " [1.0046],\n", + " [1.0200],\n", + " [1.0168],\n", + " [1.0193],\n", + " [1.0082],\n", + " [1.0081],\n", + " [1.0119],\n", + " [0.9885],\n", + " [1.0064],\n", + " [1.0245],\n", + " [1.0090],\n", + " [0.9966],\n", + " [1.0105],\n", + " [1.0213],\n", + " [1.0094],\n", + " [0.9895],\n", + " [1.0258],\n", + " [1.0028],\n", + " [1.0096],\n", + " [1.0165],\n", + " [1.0139],\n", + " [0.9919],\n", + " [1.0203],\n", + " [0.9902],\n", + " [1.0255],\n", + " [1.0126],\n", + " [1.0119],\n", + " [1.0091],\n", + " [1.0109],\n", + " [1.0093],\n", + " [1.0092],\n", + " [1.0114],\n", + " [0.9976],\n", + " [1.0145],\n", + " [1.0036],\n", + " [1.0166],\n", + " [1.0061],\n", + " [0.9901],\n", + " [1.0007],\n", + " [0.9888],\n", + " [0.9919],\n", + " [0.9905],\n", + " [1.0048],\n", + " [1.0268],\n", + " [0.9986],\n", + " [1.0288],\n", + " [1.0049],\n", + " [1.0252],\n", + " [0.9943],\n", + " [1.0019],\n", + " [0.9895],\n", + " [0.9956],\n", + " [1.0120],\n", + " [0.9917],\n", + " [0.9929],\n", + " [0.9885],\n", + " [1.0199],\n", + " [1.0069],\n", + " [1.0103],\n", + " [0.9887],\n", + " [1.0091],\n", + " [1.0149],\n", + " [1.0025],\n", + " [1.0291],\n", + " [0.9962],\n", + " [0.9968],\n", + " [1.0026],\n", + " [1.0129],\n", + " [0.9952],\n", + " [0.9994],\n", + " [0.9955],\n", + " [0.9959],\n", + " [0.9887],\n", + " [1.0178],\n", + " [1.0053],\n", + " [1.0167],\n", + " [1.0225],\n", + " [0.9982],\n", + " [0.9903],\n", + " [1.0007],\n", + " [1.0106],\n", + " [1.0052],\n", + " [1.0106],\n", + " [1.0097],\n", + " [1.0013],\n", + " [1.0160],\n", + " [0.9967],\n", + " [0.9931],\n", + " [1.0100],\n", + " [0.9893],\n", + " [1.0076],\n", + " [0.9893],\n", + " [1.0178],\n", + " [1.0060],\n", + " [0.9914],\n", + " [1.0051],\n", + " [1.0088],\n", + " [1.0092],\n", + " [1.0132],\n", + " [0.9895],\n", + " [1.0106],\n", + " [0.9893],\n", + " [1.0037],\n", + " [1.0116],\n", + " [1.0253],\n", + " [0.9914],\n", + " [1.0140],\n", + " [1.0188],\n", + " [1.0078],\n", + " [1.0077],\n", + " [1.0144],\n", + " [1.0066],\n", + " [0.9885],\n", + " [0.9998],\n", + " [1.0006],\n", + " [1.0017],\n", + " [1.0215],\n", + " [1.0127],\n", + " [1.0079],\n", + " [1.0105],\n", + " [0.9945],\n", + " [1.0002],\n", + " [1.0019],\n", + " [1.0253],\n", + " [0.9968],\n", + " [1.0103],\n", + " [1.0172],\n", + " [0.9893],\n", + " [1.0151],\n", + " [1.0161],\n", + " [0.9886],\n", + " [1.0117],\n", + " [0.9979],\n", + " [1.0238],\n", + " [1.0024],\n", + " [1.0097],\n", + " [0.9991],\n", + " [1.0186],\n", + " [1.0243],\n", + " [0.9965],\n", + " [0.9887],\n", + " [0.9888],\n", + " [1.0137],\n", + " [1.0066],\n", + " [0.9918],\n", + " [0.9967],\n", + " [1.0045],\n", + " [1.0295],\n", + " [1.0081],\n", + " [1.0296],\n", + " [0.9916],\n", + " [0.9885],\n", + " [1.0191],\n", + " [1.0212],\n", + " [1.0208],\n", + " [1.0129],\n", + " [1.0012],\n", + " [1.0065],\n", + " [1.0064],\n", + " [0.9978],\n", + " [0.9956],\n", + " [1.0156],\n", + " [1.0086],\n", + " [1.0067],\n", + " [1.0005],\n", + " [1.0160],\n", + " [1.0033],\n", + " [1.0070],\n", + " [1.0082],\n", + " [0.9993],\n", + " [1.0124],\n", + " [1.0223],\n", + " [1.0082],\n", + " [1.0116],\n", + " [1.0153],\n", + " [1.0167],\n", + " [1.0009],\n", + " [0.9939],\n", + " [0.9998],\n", + " [1.0113],\n", + " [1.0157],\n", + " [0.9887],\n", + " [1.0076],\n", + " [1.0024],\n", + " [0.9892],\n", + " [1.0115],\n", + " [1.0050],\n", + " [0.9885],\n", + " [0.9976],\n", + " [0.9901],\n", + " [0.9986],\n", + " [0.9953],\n", + " [0.9902],\n", + " [0.9895],\n", + " [1.0190],\n", + " [0.9924],\n", + " [0.9891],\n", + " [1.0197],\n", + " [0.9910],\n", + " [1.0096],\n", + " [1.0097],\n", + " [0.9893],\n", + " [1.0137],\n", + " [1.0060],\n", + " [1.0186],\n", + " [1.0189],\n", + " [1.0228],\n", + " [0.9887],\n", + " [1.0130],\n", + " [1.0138],\n", + " [0.9885],\n", + " [0.9901],\n", + " [0.9935],\n", + " [0.9937],\n", + " [0.9965],\n", + " [0.9899],\n", + " [0.9907],\n", + " [0.9891],\n", + " [1.0151],\n", + " [0.9926],\n", + " [0.9908],\n", + " [1.0120],\n", + " [1.0169],\n", + " [1.0109],\n", + " [1.0011],\n", + " [0.9950],\n", + " [1.0035],\n", + " [1.0086],\n", + " [0.9914],\n", + " [1.0056],\n", + " [1.0168],\n", + " [0.9897],\n", + " [0.9944],\n", + " [1.0207],\n", + " [1.0028],\n", + " [1.0000],\n", + " [1.0156],\n", + " [0.9914],\n", + " [1.0191],\n", + " [0.9995],\n", + " [1.0261],\n", + " [0.9915],\n", + " [1.0086],\n", + " [0.9916],\n", + " [0.9941],\n", + " [0.9896],\n", + " [0.9981],\n", + " [0.9934],\n", + " [1.0134],\n", + " [0.9928],\n", + " [0.9889],\n", + " [1.0098],\n", + " [0.9937],\n", + " [1.0140],\n", + " [1.0163],\n", + " [0.9960],\n", + " [0.9885],\n", + " [1.0008],\n", + " [1.0007],\n", + " [1.0204],\n", + " [0.9950],\n", + " [1.0292],\n", + " [1.0106],\n", + " [0.9886],\n", + " [0.9894],\n", + " [1.0085],\n", + " [0.9983],\n", + " [1.0120],\n", + " [1.0156],\n", + " [1.0171],\n", + " [0.9919],\n", + " [1.0134],\n", + " [1.0044],\n", + " [0.9990],\n", + " [1.0187],\n", + " [1.0171],\n", + " [0.9893],\n", + " [1.0013],\n", + " [0.9967],\n", + " [1.0267],\n", + " [0.9908],\n", + " [1.0025],\n", + " [1.0145],\n", + " [0.9886],\n", + " [1.0157],\n", + " [1.0136],\n", + " [1.0071],\n", + " [0.9936],\n", + " [1.0282],\n", + " [1.0228],\n", + " [0.9944],\n", + " [0.9901],\n", + " [1.0005],\n", + " [0.9962],\n", + " [0.9953],\n", + " [1.0040],\n", + " [0.9914],\n", + " [1.0087],\n", + " [1.0017],\n", + " [1.0055],\n", + " [0.9959],\n", + " [1.0153],\n", + " [0.9897],\n", + " [1.0072],\n", + " [0.9927],\n", + " [0.9885],\n", + " [1.0189],\n", + " [1.0147],\n", + " [1.0007],\n", + " [1.0035],\n", + " [1.0130],\n", + " [0.9953],\n", + " [0.9886],\n", + " [1.0103],\n", + " [1.0103],\n", + " [0.9921],\n", + " [0.9976],\n", + " [0.9971],\n", + " [1.0137],\n", + " [0.9908],\n", + " [1.0159],\n", + " [0.9933],\n", + " [0.9981],\n", + " [0.9887],\n", + " [1.0149],\n", + " [1.0147],\n", + " [1.0048],\n", + " [0.9892],\n", + " [1.0303],\n", + " [1.0129],\n", + " [1.0152],\n", + " [1.0134],\n", + " [1.0163],\n", + " [1.0111],\n", + " [0.9886],\n", + " [1.0111],\n", + " [1.0257],\n", + " [0.9974],\n", + " [1.0171],\n", + " [0.9921],\n", + " [1.0065],\n", + " [1.0025],\n", + " [1.0162],\n", + " [0.9992],\n", + " [0.9919],\n", + " [0.9920],\n", + " [0.9899],\n", + " [1.0135],\n", + " [0.9956],\n", + " [0.9951],\n", + " [1.0102],\n", + " [1.0075],\n", + " [1.0032],\n", + " [0.9915],\n", + " [0.9919],\n", + " [1.0170],\n", + " [0.9886],\n", + " [0.9920],\n", + " [0.9980],\n", + " [0.9950],\n", + " [1.0181],\n", + " [1.0089],\n", + " [1.0128],\n", + " [0.9886],\n", + " [1.0158],\n", + " [1.0190],\n", + " [1.0089],\n", + " [0.9886],\n", + " [0.9902],\n", + " [0.9901],\n", + " [1.0018],\n", + " [0.9928],\n", + " [0.9989],\n", + " [1.0079],\n", + " [1.0158],\n", + " [0.9933],\n", + " [1.0089],\n", + " [1.0011],\n", + " [0.9892],\n", + " [0.9938],\n", + " [1.0001],\n", + " [0.9907],\n", + " [0.9970],\n", + " [1.0192],\n", + " [1.0078],\n", + " [1.0137],\n", + " [1.0233],\n", + " [1.0036],\n", + " [0.9887],\n", + " [1.0061],\n", + " [1.0099],\n", + " [1.0042],\n", + " [0.9983],\n", + " [1.0021],\n", + " [1.0110],\n", + " [0.9984],\n", + " [1.0205],\n", + " [1.0062],\n", + " [1.0078],\n", + " [1.0014],\n", + " [0.9927],\n", + " [0.9895],\n", + " [1.0151],\n", + " [1.0113],\n", + " [0.9987],\n", + " [1.0060],\n", + " [1.0258],\n", + " [1.0050],\n", + " [1.0216],\n", + " [1.0195],\n", + " [1.0120],\n", + " [1.0046],\n", + " [1.0130],\n", + " [1.0131],\n", + " [0.9925],\n", + " [1.0105],\n", + " [0.9971],\n", + " [1.0256],\n", + " [0.9927],\n", + " [1.0154],\n", + " [1.0036],\n", + " [1.0176],\n", + " [0.9885],\n", + " [1.0279],\n", + " [0.9911],\n", + " [0.9931],\n", + " [1.0144],\n", + " [1.0088],\n", + " [0.9966],\n", + " [1.0016],\n", + " [1.0166],\n", + " [1.0174],\n", + " [0.9975],\n", + " [0.9918],\n", + " [0.9915],\n", + " [1.0046],\n", + " [0.9886],\n", + " [1.0147],\n", + " [1.0001],\n", + " [1.0019],\n", + " [0.9955],\n", + " [0.9970],\n", + " [0.9918],\n", + " [0.9957],\n", + " [0.9893],\n", + " [1.0132],\n", + " [1.0154],\n", + " [1.0027],\n", + " [1.0146],\n", + " [1.0203],\n", + " [1.0052],\n", + " [1.0010],\n", + " [1.0160],\n", + " [1.0121],\n", + " [0.9893],\n", + " [0.9886],\n", + " [0.9976],\n", + " [1.0003],\n", + " [0.9894],\n", + " [1.0237],\n", + " [1.0106],\n", + " [0.9987],\n", + " [0.9931],\n", + " [1.0118],\n", + " [0.9886],\n", + " [1.0004],\n", + " [0.9902],\n", + " [1.0198],\n", + " [0.9919],\n", + " [1.0073],\n", + " [1.0060],\n", + " [0.9897],\n", + " [0.9890],\n", + " [1.0018],\n", + " [1.0080],\n", + " [1.0166],\n", + " [0.9938],\n", + " [1.0027],\n", + " [1.0034],\n", + " [0.9914],\n", + " [0.9966],\n", + " [1.0029],\n", + " [1.0284],\n", + " [1.0191],\n", + " [1.0194],\n", + " [1.0211],\n", + " [1.0001],\n", + " [1.0191],\n", + " [1.0120],\n", + " [1.0018],\n", + " [1.0142],\n", + " [0.9898],\n", + " [0.9894],\n", + " [1.0251],\n", + " [0.9962],\n", + " [1.0180],\n", + " [1.0248],\n", + " [0.9895],\n", + " [0.9999],\n", + " [1.0104],\n", + " [0.9920],\n", + " [0.9984],\n", + " [1.0186],\n", + " [0.9909],\n", + " [1.0053],\n", + " [1.0025],\n", + " [1.0018],\n", + " [0.9935],\n", + " [0.9902],\n", + " [1.0165],\n", + " [0.9933],\n", + " [1.0121],\n", + " [0.9886],\n", + " [1.0088],\n", + " [1.0143],\n", + " [1.0159],\n", + " [0.9993],\n", + " [1.0196],\n", + " [0.9976],\n", + " [1.0055],\n", + " [0.9919],\n", + " [1.0172],\n", + " [0.9913],\n", + " [0.9911],\n", + " [0.9998],\n", + " [0.9886],\n", + " [0.9913],\n", + " [0.9995],\n", + " [1.0062],\n", + " [1.0155],\n", + " [0.9988],\n", + " [0.9909],\n", + " [0.9985],\n", + " [0.9887],\n", + " [0.9922],\n", + " [0.9896],\n", + " [1.0023],\n", + " [1.0051],\n", + " [0.9894],\n", + " [0.9916],\n", + " [0.9941],\n", + " [0.9898],\n", + " [1.0025],\n", + " [0.9894],\n", + " [1.0275],\n", + " [0.9901],\n", + " [1.0061],\n", + " [1.0181],\n", + " [1.0052],\n", + " [1.0079],\n", + " [1.0069],\n", + " [1.0138],\n", + " [0.9908],\n", + " [0.9918],\n", + " [0.9900],\n", + " [0.9915],\n", + " [1.0200],\n", + " [0.9967],\n", + " [0.9946],\n", + " [0.9977],\n", + " [1.0127],\n", + " [1.0114],\n", + " [0.9960],\n", + " [1.0024],\n", + " [1.0016],\n", + " [1.0221],\n", + " [1.0089],\n", + " [1.0177],\n", + " [0.9926],\n", + " [1.0214],\n", + " [1.0022],\n", + " [1.0048],\n", + " [0.9941],\n", + " [1.0168],\n", + " [1.0156],\n", + " [1.0033],\n", + " [1.0148],\n", + " [0.9950],\n", + " [0.9909],\n", + " [1.0190],\n", + " [1.0082],\n", + " [1.0075],\n", + " [1.0101],\n", + " [1.0111],\n", + " [1.0195],\n", + " [1.0041],\n", + " [1.0064],\n", + " [0.9930],\n", + " [1.0155],\n", + " [1.0103],\n", + " [1.0200],\n", + " [0.9993],\n", + " [0.9887],\n", + " [1.0086],\n", + " [0.9894],\n", + " [1.0022],\n", + " [0.9952],\n", + " [1.0106],\n", + " [0.9940],\n", + " [0.9894],\n", + " [1.0153],\n", + " [1.0005],\n", + " [1.0132],\n", + " [0.9937],\n", + " [1.0160],\n", + " [1.0131],\n", + " [1.0294],\n", + " [1.0165],\n", + " [1.0155],\n", + " [0.9922],\n", + " [0.9895],\n", + " [1.0063],\n", + " [0.9926],\n", + " [0.9944],\n", + " [1.0081],\n", + " [1.0137],\n", + " [1.0087],\n", + " [0.9918],\n", + " [0.9915],\n", + " [1.0092],\n", + " [1.0287],\n", + " [0.9906],\n", + " [0.9902],\n", + " [0.9896],\n", + " [0.9888],\n", + " [1.0291],\n", + " [1.0004],\n", + " [0.9894],\n", + " [0.9950],\n", + " [0.9961],\n", + " [1.0126],\n", + " [1.0084],\n", + " [1.0117],\n", + " [1.0202],\n", + " [1.0149],\n", + " [1.0147],\n", + " [0.9951],\n", + " [0.9934],\n", + " [1.0161],\n", + " [1.0140],\n", + " [1.0168],\n", + " [1.0188],\n", + " [1.0185],\n", + " [1.0100],\n", + " [1.0082],\n", + " [1.0204],\n", + " [0.9895],\n", + " [1.0106],\n", + " [1.0057],\n", + " [0.9895],\n", + " [0.9982],\n", + " [1.0010],\n", + " [1.0037],\n", + " [1.0080],\n", + " [0.9973],\n", + " [1.0180],\n", + " [1.0108],\n", + " [1.0078],\n", + " [0.9961],\n", + " [1.0000],\n", + " [1.0201],\n", + " [0.9893],\n", + " [0.9945],\n", + " [1.0174],\n", + " [1.0153],\n", + " [0.9899],\n", + " [1.0185],\n", + " [1.0305],\n", + " [0.9910],\n", + " [1.0147],\n", + " [1.0129],\n", + " [1.0104],\n", + " [1.0295],\n", + " [1.0130],\n", + " [0.9952],\n", + " [1.0135],\n", + " [0.9940],\n", + " [0.9901],\n", + " [1.0163],\n", + " [0.9893],\n", + " [1.0063],\n", + " [1.0091],\n", + " [1.0176],\n", + " [1.0226],\n", + " [0.9899],\n", + " [1.0023],\n", + " [1.0080],\n", + " [0.9926],\n", + " [1.0151],\n", + " [1.0185],\n", + " [1.0077],\n", + " [1.0113],\n", + " [0.9898],\n", + " [1.0042],\n", + " [0.9888],\n", + " [0.9945],\n", + " [1.0078],\n", + " [1.0082],\n", + " [0.9895],\n", + " [0.9886],\n", + " [1.0094],\n", + " [0.9896],\n", + " [1.0039],\n", + " [0.9896],\n", + " [0.9993],\n", + " [1.0217],\n", + " [0.9913],\n", + " [0.9887],\n", + " [0.9918],\n", + " [0.9922],\n", + " [0.9895],\n", + " [1.0186],\n", + " [0.9972],\n", + " [1.0097],\n", + " [1.0158],\n", + " [0.9996],\n", + " [0.9895],\n", + " [0.9907],\n", + " [1.0044],\n", + " [1.0167],\n", + " [0.9940],\n", + " [1.0066],\n", + " [1.0065],\n", + " [0.9943],\n", + " [0.9903],\n", + " [1.0094],\n", + " [1.0067],\n", + " [0.9895],\n", + " [1.0130],\n", + " [1.0171],\n", + " [0.9914],\n", + " [0.9992],\n", + " [1.0053],\n", + " [0.9964],\n", + " [1.0111],\n", + " [0.9917],\n", + " [0.9888],\n", + " [0.9960],\n", + " [0.9912],\n", + " [1.0021],\n", + " [0.9896],\n", + " [0.9932],\n", + " [0.9915],\n", + " [0.9922],\n", + " [0.9914],\n", + " [1.0132],\n", + " [0.9902],\n", + " [0.9904],\n", + " [1.0292],\n", + " [1.0082],\n", + " [1.0077],\n", + " [1.0159],\n", + " [1.0118],\n", + " [1.0162],\n", + " [1.0042],\n", + " [1.0155],\n", + " [1.0032],\n", + " [0.9900],\n", + " [1.0169],\n", + " [0.9894],\n", + " [1.0118],\n", + " [0.9975],\n", + " [0.9982],\n", + " [0.9926],\n", + " [1.0108],\n", + " [0.9944]], grad_fn=)\n" + ] + } + ], + "source": [ + "\n", + "# Generate random input\n", + "random_input = torch.rand(1000, input_dim) # You can adjust the batch size\n", + "output = model(random_input)\n", + "\n", + "print(output)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 56, + "id": "6c104591-885b-4a5d-bbf0-114a623e2bfb", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(array([ 0., 0., 0., 0., 0., 1000., 0., 0., 0.,\n", + " 0.]),\n", + " array([-4.9999651e-01, -3.9999652e-01, -2.9999653e-01, -1.9999652e-01,\n", + " -9.9996522e-02, 3.4766272e-06, 1.0000347e-01, 2.0000347e-01,\n", + " 3.0000347e-01, 4.0000346e-01, 5.0000346e-01], dtype=float32),\n", + " )" + ] + }, + "execution_count": 56, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plt.hist(output.detach().cpu().numpy()[:,0])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c6bb688e-0fdf-44e5-9312-fd1521920e5d", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7d8ce915-82b3-40da-8037-995c60ec5653", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "15a26ff5-fac9-4e9e-b122-ae81fff3ae0f", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "40ecf26c-e478-4303-bff7-d83fbffa7bd7", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "fb032c0d-4448-4985-8b96-ddc4dfe96a3e", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from torch import nn\n", + "\n", + "\n", + "class RBF(nn.Module):\n", + "\n", + " def __init__(self, n_kernels=5, mul_factor=2.0, bandwidth=None):\n", + " super().__init__()\n", + " self.bandwidth_multipliers = mul_factor ** (torch.arange(n_kernels) - n_kernels // 2)\n", + " self.bandwidth = bandwidth\n", + "\n", + " def get_bandwidth(self, L2_distances):\n", + " if self.bandwidth is None:\n", + " n_samples = L2_distances.shape[0]\n", + " return L2_distances.data.sum() / (n_samples ** 2 - n_samples)\n", + "\n", + " return self.bandwidth\n", + "\n", + " def forward(self, X):\n", + " L2_distances = torch.cdist(X, X) ** 2\n", + " return torch.exp(-L2_distances[None, ...] / (self.get_bandwidth(L2_distances) * self.bandwidth_multipliers)[:, None, None]).sum(dim=0)\n", + "\n", + "\n", + "class MMDLoss(nn.Module):\n", + "\n", + " def __init__(self, kernel=RBF()):\n", + " super().__init__()\n", + " self.kernel = kernel\n", + "\n", + " def forward(self, X, Y):\n", + " K = self.kernel(torch.vstack([X, Y]))\n", + "\n", + " X_size = X.shape[0]\n", + " XX = K[:X_size, :X_size].mean()\n", + " XY = K[:X_size, X_size:].mean()\n", + " YY = K[X_size:, X_size:].mean()\n", + " return XX - 2 * XY + YY" + ] + }, + { + "cell_type": "code", + "execution_count": 119, + "id": "b5fa6136-7df6-4cf0-bc58-1ffc2bd5b3ec", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "input_dim = 1\n", + "hidden_dim = 128\n", + "output_dim = 1\n", + "num_hidden_layers = 3\n", + "\n", + "model = RandomToNormalNN(input_dim, hidden_dim, output_dim, num_hidden_layers)\n", + "\n", + "nepochs=1000\n", + "bs = 1000\n", + "z_dim=1\n", + "epsilon = 0.5\n", + "\n", + "# Create an instance of the neural network\n", + "base_distribution = D.Normal(torch.zeros(z_dim), torch.ones(z_dim))\n", + "mmd = MMDLoss()\n", + "\n", + "optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 120, + "id": "752b4aac-3cd0-4d43-8abf-50a4a45d59a6", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[120], line 10\u001b[0m\n\u001b[1;32m 6\u001b[0m randomN_input \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mrandn(bs, input_dim)\n\u001b[1;32m 8\u001b[0m output \u001b[38;5;241m=\u001b[39m model(random_input)\n\u001b[0;32m---> 10\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1000\u001b[39m\u001b[38;5;241m*\u001b[39m\u001b[43mmmd\u001b[49m\u001b[43m(\u001b[49m\u001b[43moutput\u001b[49m\u001b[43m,\u001b[49m\u001b[43mrandomN_input\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;241m-\u001b[39mbase_distribution\u001b[38;5;241m.\u001b[39mlog_prob(output)\u001b[38;5;241m.\u001b[39mmean()\n\u001b[1;32m 11\u001b[0m \u001b[38;5;66;03m#loss = loss.mean()\u001b[39;00m\n\u001b[1;32m 12\u001b[0m loss\u001b[38;5;241m.\u001b[39mbackward()\n", + "File \u001b[0;32m/data/astro/scratch/lcabayol/anaconda3/envs/DLenv2/lib/python3.9/site-packages/torch/nn/modules/module.py:1110\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1106\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1107\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1108\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1109\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1110\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1111\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1112\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", + "Cell \u001b[0;32mIn[81], line 33\u001b[0m, in \u001b[0;36mMMDLoss.forward\u001b[0;34m(self, X, Y)\u001b[0m\n\u001b[1;32m 30\u001b[0m K \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mkernel(torch\u001b[38;5;241m.\u001b[39mvstack([X, Y]))\n\u001b[1;32m 32\u001b[0m X_size \u001b[38;5;241m=\u001b[39m X\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m0\u001b[39m]\n\u001b[0;32m---> 33\u001b[0m XX \u001b[38;5;241m=\u001b[39m \u001b[43mK\u001b[49m\u001b[43m[\u001b[49m\u001b[43m:\u001b[49m\u001b[43mX_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m:\u001b[49m\u001b[43mX_size\u001b[49m\u001b[43m]\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmean\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 34\u001b[0m XY \u001b[38;5;241m=\u001b[39m K[:X_size, X_size:]\u001b[38;5;241m.\u001b[39mmean()\n\u001b[1;32m 35\u001b[0m YY \u001b[38;5;241m=\u001b[39m K[X_size:, X_size:]\u001b[38;5;241m.\u001b[39mmean()\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "for e in range(nepochs):\n", + " \n", + " optimizer.zero_grad()\n", + " \n", + " random_input = torch.rand(bs, input_dim)\n", + " randomN_input = torch.randn(bs, input_dim)\n", + "\n", + " output = model(random_input)\n", + " \n", + " loss = mmd(output,randomN_input) #-base_distribution.log_prob(output).mean()\n", + " #loss = loss.mean()\n", + " loss.backward()\n", + " optimizer.step() \n", + " \n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 116, + "id": "91e6f895-a161-46d1-80ec-c93ceaba765f", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(array([ 99., 102., 110., 89., 105., 100., 100., 101., 105., 89.]),\n", + " array([4.2331219e-04, 1.0035120e-01, 2.0027909e-01, 3.0020696e-01,\n", + " 4.0013486e-01, 5.0006270e-01, 5.9999061e-01, 6.9991851e-01,\n", + " 7.9984641e-01, 8.9977425e-01, 9.9970216e-01], dtype=float32),\n", + " )" + ] + }, + "execution_count": 116, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiIAAAGcCAYAAADknMuyAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAaU0lEQVR4nO3df2xdd33/8ZeTuo7Q4lu1AxIHU8caaYEVOkazdWlQbGkD0SrKNMEyfqRjaBDUTUMZo81atgYyUtjUNYtgUjqNioLU0UEUpgpSgbwlke4froqwK0GHqZ2lLliUdNddmzhO7v3+wTcWpklw4dofX/fxkK7Kvff4nDcfWb3PnnN9b1uj0WgEAKCAZaUHAABeuoQIAFCMEAEAihEiAEAxQgQAKEaIAADFCBEAoJhLSg/w89Tr9Tz11FNZuXJl2traSo8DAMxBo9HIs88+m66urixbduHzHos+RJ566ql0d3eXHgMA+AUcP348r3rVqy74/KIPkZUrVyb5yf+Rzs7OwtMAAHMxOTmZ7u7umdfxC1n0IXLuckxnZ6cQAYAW8/PeVuHNqgBAMUIEAChGiAAAxQgRAKAYIQIAFCNEAIBihAgAUIwQAQCKESIAQDFCBAAoRogAAMUIEQCgGCECABQjRACAYoQIAFDMJaUHYOnrue2h0iP8QsbuurH0CABLnjMiAEAxQgQAKEaIAADFCBEAoBghAgAUI0QAgGKECABQjBABAIoRIgBAMUIEAChGiAAAxQgRAKAYIQIAFCNEAIBihAgAUMwlpQcAmqfntodKj/Cijd11Y+kRWKT8Pr80OCMCABQjRACAYoQIAFCMEAEAihEiAEAxQgQAKEaIAADFCBEAoBghAgAU45NVW0wrftIgAFzIL3xG5OTJk6nVas2cBQB4iXnRIVKv13Pfffdl3bp1+da3vjXz+OjoaLZv3579+/dn27ZtOXbs2JyeAwBeul70pZmnn346fX19efLJJ2ceq9fr2bx5c/bu3Zv+/v6sXbs2W7duTbVavehzACwMl3VZrF70GZFXvOIVufLKK2c9dujQoYyMjGTjxo1Jkv7+/gwNDWVwcPCizwEAL21NebNqtVpNb29v2tvbkyTLly9Pb29vBgYG8n//938XfO666657wb6mpqYyNTU1c39ycrIZIwIAi1BTQmRiYiKdnZ2zHqtUKhkfH8+pU6cu+Nz57NmzJ7t27WrGWEALcMkAXtqa8jki7e3tM2c8zqnX66nX6xd97nx27tyZWq02czt+/HgzRgQAFqGmhMjq1atfcAmlVqtlzZo1F33ufDo6OtLZ2TnrBgAsTU25NLNp06Z8+tOfTqPRSFtbW6anpzM2Npa+vr6cOXPmgs+V5pQwAJT1C50R+dnLKtdff326urpy5MiRJMnhw4fT09OT9evXX/Q5AOCl7UWfEfnRj36Ue++9N0nyhS98IatWrcrVV1+dgwcPZvfu3RkeHk61Ws2BAwfS1taWtra2Cz4HALy0tTUajUbpIS5mcnIylUoltVqt6e8XcWmGixm768bSI7xofqehrFb898Z8mevrt2/fBQCKESIAQDFCBAAoRogAAMUIEQCgGCECABQjRACAYoQIAFCMEAEAihEiAEAxQgQAKEaIAADFCBEAoBghAgAUI0QAgGKECABQjBABAIoRIgBAMUIEAChGiAAAxQgRAKAYIQIAFCNEAIBihAgAUIwQAQCKESIAQDFCBAAoRogAAMVcUnoAWKx6bnuo9AhAi2nFf2+M3XVj0eM7IwIAFCNEAIBihAgAUIwQAQCKESIAQDFCBAAoRogAAMUIEQCgGCECABQjRACAYoQIAFCMEAEAihEiAEAxQgQAKEaIAADFCBEAoBghAgAUI0QAgGKECABQjBABAIoRIgBAMUIEAChGiAAAxQgRAKCYS5q5s8OHD+cb3/hGfvVXfzWDg4O5/fbbc/XVV2d0dDSf+tSn8qY3vSlHjx7NJz7xiVx55ZXNPDQA0IKaFiJnz57N+973vjz++OO55JJL8l//9V/5sz/7szz88MPZvHlz9u7dm/7+/qxduzZbt25NtVpt1qEBgBbVtEszJ06cyFNPPZXnn38+SVKpVPLMM8/k0KFDGRkZycaNG5Mk/f39GRoayuDgYLMODQC0qKaFyMtf/vL85m/+Zt7znvfkf//3f/NP//RP2bVrV6rVanp7e9Pe3p4kWb58eXp7ezMwMHDe/UxNTWVycnLWDQBYmpr6ZtUHH3wwIyMjWb16dd761rfmpptuysTERDo7O2dtV6lUMj4+ft597NmzJ5VKZebW3d3dzBEBgEWkqW9WnZiYyNvf/vYcO3YsN998cy677LK0t7fPnA05p16vp16vn3cfO3fuzI4dO2buT05OihEAWKKaFiLPP/983vOe9+SRRx7JihUrcscdd+T9739/PvShD+Xo0aOztq3ValmzZs1599PR0ZGOjo5mjQUALGJNuzTz2GOPZeXKlVmxYkWS5M4778yzzz6bG264IaOjo2k0GkmS6enpjI2Npa+vr1mHBgBaVNNC5Nd+7dcyPj6e5557Lkly+vTprFq1KjfccEO6urpy5MiRJD/5rJGenp6sX7++WYcGAFpU0y7NXH755bn33nvz53/+57nmmmvy5JNP5v7778/y5ctz8ODB7N69O8PDw6lWqzlw4EDa2tqadWgAoEW1Nc5dM1mkJicnU6lUUqvVXvDXN7+sntseaur+AKDVjN1147zsd66v375rBgAoRogAAMUIEQCgGCECABQjRACAYoQIAFCMEAEAihEiAEAxQgQAKEaIAADFCBEAoBghAgAUI0QAgGKECABQjBABAIoRIgBAMUIEAChGiAAAxQgRAKAYIQIAFCNEAIBihAgAUIwQAQCKESIAQDFCBAAoRogAAMUIEQCgGCECABQjRACAYoQIAFCMEAEAihEiAEAxQgQAKEaIAADFCBEAoBghAgAUI0QAgGKECABQjBABAIoRIgBAMUIEAChGiAAAxQgRAKAYIQIAFCNEAIBihAgAUIwQAQCKESIAQDFCBAAoRogAAMUIEQCgGCECABRzyXzs9Dvf+U4OHDiQ7u7ubNmyJStXrpyPwwAALa7pIfKZz3wmDzzwQB544IGsWbMmSTI6OppPfepTedOb3pSjR4/mE5/4RK688spmHxoAaDFNDZGvfOUr+fjHP57HHnssL3/5y5Mk9Xo9mzdvzt69e9Pf35+1a9dm69atqVarzTw0ANCCmvYekTNnzuTDH/5wPvKRj8xESJIcOnQoIyMj2bhxY5Kkv78/Q0NDGRwcbNahAYAW1bQQOXz4cI4fP57HH388W7ZsyWtf+9o88MADqVar6e3tTXt7e5Jk+fLl6e3tzcDAwHn3MzU1lcnJyVk3AGBpatqlmaGhoVx22WX59Kc/ncsvvzxf//rXs3nz5vT19aWzs3PWtpVKJePj4+fdz549e7Jr165mjQUALGJNOyNy8uTJvPa1r83ll1+eJHnb296WV77ylTl69OjM2ZBz6vV66vX6efezc+fO1Gq1mdvx48ebNSIAsMg0LURWrVqV5557btZjr3rVq/LRj370BZdXarXazF/U/KyOjo50dnbOugEAS1PTQmTDhg0ZGxvLmTNnZh47depUkp/8+W6j0UiSTE9PZ2xsLH19fc06NADQopoWIuvWrcu1116bhx9+OEly4sSJPP300/mrv/qrdHV15ciRI0l+8qbWnp6erF+/vlmHBgBaVFM/R+T+++/PRz7ykQwNDWV0dDQPPvhgXvayl+XgwYPZvXt3hoeHU61Wc+DAgbS1tTXz0ABAC2prnLtmskhNTk6mUqmkVqs1/f0iPbc91NT9AUCrGbvrxnnZ71xfv33pHQBQjBABAIoRIgBAMUIEAChGiAAAxQgRAKAYIQIAFCNEAIBihAgAUIwQAQCKESIAQDFCBAAoRogAAMUIEQCgGCECABQjRACAYoQIAFCMEAEAihEiAEAxQgQAKEaIAADFCBEAoBghAgAUI0QAgGKECABQjBABAIoRIgBAMUIEAChGiAAAxQgRAKAYIQIAFCNEAIBihAgAUIwQAQCKESIAQDFCBAAoRogAAMUIEQCgGCECABQjRACAYoQIAFCMEAEAihEiAEAxQgQAKEaIAADFCBEAoBghAgAUI0QAgGKECABQjBABAIoRIgBAMUIEAChGiAAAxTQ9RJ5//vm87nWvy9jYWJJkdHQ027dvz/79+7Nt27YcO3as2YcEAFrUJc3e4b59+/Kd73wnSVKv17N58+bs3bs3/f39Wbt2bbZu3ZpqtdrswwIALaipZ0QOHjyYvr6+mfuHDh3KyMhINm7cmCTp7+/P0NBQBgcHm3lYAKBFNS1E/ud//ic/+MEPsn79+pnHqtVqent7097eniRZvnx5ent7MzAwcMH9TE1NZXJyctYNAFiamhIiZ8+ezb333psPfOADsx6fmJhIZ2fnrMcqlUrGx8cvuK89e/akUqnM3Lq7u5sxIgCwCDUlRD7zmc/kgx/8YJYtm7279vb2mbMh59Tr9dTr9Qvua+fOnanVajO348ePN2NEAGARasqbVfft25ePfvSjsx676qqrUq/X8/rXv37W47VaLWvWrLngvjo6OtLR0dGMsQCARa4pIfK9731v1v22trY8/vjjGR8fz9vf/vY0Go20tbVleno6Y2Njs97QCgC8dM3rB5pdf/316erqypEjR5Ikhw8fTk9Pz6w3tAIAL11N/xyRn7Zs2bIcPHgwu3fvzvDwcKrVag4cOJC2trb5PCwA0CLmJUQajcbM/163bl0+//nPJ0luueWW+TgcANCifNcMAFCMEAEAihEiAEAxQgQAKEaIAADFCBEAoBghAgAUI0QAgGKECABQjBABAIoRIgBAMUIEAChGiAAAxQgRAKAYIQIAFCNEAIBihAgAUIwQAQCKESIAQDFCBAAoRogAAMUIEQCgGCECABQjRACAYoQIAFCMEAEAihEiAEAxQgQAKEaIAADFCBEAoBghAgAUI0QAgGKECABQjBABAIoRIgBAMUIEAChGiAAAxQgRAKAYIQIAFCNEAIBihAgAUIwQAQCKESIAQDFCBAAoRogAAMUIEQCgGCECABQjRACAYoQIAFCMEAEAihEiAEAxQgQAKEaIAADFNC1EvvrVr+aqq65KZ2dn/uAP/iAnTpxIkoyOjmb79u3Zv39/tm3blmPHjjXrkABAi2tKiDzxxBN56KGH8pWvfCX33Xdf/vM//zO33npr6vV6Nm/enHe+8535wAc+kPe+973ZunVrMw4JACwBlzRjJ0ePHs2+ffty6aWX5vWvf32Ghoby4IMP5tChQxkZGcnGjRuTJP39/dmyZUsGBwdz3XXXNePQAEALa8oZkW3btuXSSy+duf/KV74yr371q1OtVtPb25v29vYkyfLly9Pb25uBgYEL7mtqaiqTk5OzbgDA0jQvb1Z99NFH88EPfjATExPp7Oyc9VylUsn4+PgFf3bPnj2pVCozt+7u7vkYEQBYBJoeIj/4wQ9y5syZbNmyJe3t7TNnQ86p1+up1+sX/PmdO3emVqvN3I4fP97sEQGARaIp7xE55+zZs7nnnnuyb9++JMnq1atz9OjRWdvUarWsWbPmgvvo6OhIR0dHM8cCABappp4R+cd//Mfs2LEjv/Irv5IkueGGGzI6OppGo5EkmZ6eztjYWPr6+pp5WACgRTXtjMg999yTdevW5ZlnnskzzzyTJ554ImfOnElXV1eOHDmSt7zlLTl8+HB6enqyfv36Zh0WAGhhTQmRL33pS9mxY8fMmY8kednLXpYf/vCHOXjwYHbv3p3h4eFUq9UcOHAgbW1tzTgsANDi2ho/XQ+L0OTkZCqVSmq12gv+AueX1XPbQ03dHwC0mrG7bpyX/c719dt3zQAAxQgRAKAYIQIAFCNEAIBihAgAUIwQAQCKESIAQDFCBAAoRogAAMUIEQCgGCECABQjRACAYoQIAFCMEAEAihEiAEAxQgQAKEaIAADFCBEAoBghAgAUI0QAgGKECABQjBABAIoRIgBAMUIEAChGiAAAxQgRAKAYIQIAFCNEAIBihAgAUIwQAQCKESIAQDFCBAAoRogAAMUIEQCgGCECABQjRACAYoQIAFCMEAEAihEiAEAxQgQAKEaIAADFCBEAoBghAgAUI0QAgGKECABQjBABAIoRIgBAMUIEAChGiAAAxQgRAKAYIQIAFCNEAIBihAgAUMyChMjo6Gi2b9+e/fv3Z9u2bTl27NhCHBYAWOQume8D1Ov1bN68OXv37k1/f3/Wrl2brVu3plqtzvehAYBFbt7PiBw6dCgjIyPZuHFjkqS/vz9DQ0MZHByc70MDAIvcvJ8RqVar6e3tTXt7e5Jk+fLl6e3tzcDAQK677roXbD81NZWpqamZ+7VaLUkyOTnZ9NnqU883fZ8A0Erm4/X1p/fbaDQuut28h8jExEQ6OztnPVapVDI+Pn7e7ffs2ZNdu3a94PHu7u55mQ8AXsoq98zv/p999tlUKpULPj/vIdLe3j5zNuScer2eer1+3u137tyZHTt2zNr2xIkTueKKK9LW1ta0uSYnJ9Pd3Z3jx4+/IJRoHuu8cKz1wrDOC8M6L4z5XOdGo5Fnn302XV1dF91u3kNk9erVOXr06KzHarVa1qxZc97tOzo60tHRMeuxyy67bL7GS2dnp1/yBWCdF461XhjWeWFY54UxX+t8sTMh58z7m1U3bdqU0dHRmWtE09PTGRsbS19f33wfGgBY5OY9RK6//vp0dXXlyJEjSZLDhw+np6cn69evn+9DAwCL3Lxfmlm2bFkOHjyY3bt3Z3h4ONVqNQcOHGjq+z1+ER0dHfnbv/3bF1wGorms88Kx1gvDOi8M67wwFsM6tzV+3t/VAADME981AwAUI0QAgGKECABQzLy/WRVgKZuens4Xv/jFnDhxIhs2bMhv/dZvlR4JLujkyZM5ffr0nD7fY6Es+TMio6Oj2b59e/bv359t27bl2LFj593unnvuySc/+cncfvvtufvuuxd4ytY3l3WenJzMu9/97lx22WXp7e3Nv/3bvxWYtLXN9ff5nM997nP54z/+44UZbgmZ6zpPTExkw4YNOXXqVHbs2CFCfgFzWevp6enccccd+exnP5udO3fm4x//eIFJW1u9Xs99992XdevW5Vvf+tYFtyvyWthYws6ePdv49V//9cY3v/nNRqPRaDz88MON3/7t337Bdl/60pcaGzdunLn/O7/zO42vfe1rCzZnq5vrOt96662N//iP/2h8+9vfbrz73e9utLe3N5544omFHrdlzXWdz3n88ccbb3zjGxs333zzAk24NMx1nU+fPt1485vf3LjzzjsXesQlY65rfffddzf+/u//fub+pk2bGkeOHFmwOZeCiYmJxtjYWCNJY2Bg4LzblHotXNJnRA4dOpSRkZFs3LgxSdLf35+hoaEMDg7O2u4f/uEf8nu/93sz92+66abs27dvQWdtZXNZ5+np6bzuda/LTTfdlDe84Q35l3/5lyxbtiyPPPJIqbFbzlx/n5Pk9OnTeeCBB7Jly5YFnrL1zXWd//Vf/zXf+9738td//dclxlwS5rrWIyMjOXHixMz9SqWSZ555ZkFnbXWveMUrcuWVV150m1KvhUs6RKrVanp7e2e+dG/58uXp7e3NwMDAzDanT5/Oo48+mquuumrmsXXr1s3ahoubyzq3t7dn27ZtM/dXrFiRSqWSV7/61Qs+b6uayzqfs2/fvtxyyy0LPeKSMNd1/uIXv5jVq1dnx44dufbaa/P7v//7s14s+fnmutZbtmzJ3r1787WvfS2PPPJIzpw5k7e+9a0lRl6ySr4WLukQmZiYeMGX+FQqlYyPj8/c//GPf5wzZ87M2q5SqeTkyZOKe47mss4/68knn8yaNWtcU38R5rrODz/8cH7jN34jV1xxxUKOt2TMdZ2Hhobyjne8I/v27cvg4GCefvrp3HbbbQs5asub61r/7u/+bv7u7/4uN910U2655ZZ8+ctfzqWXXrqQoy55JV8Ll3SItLe3z5T2OfV6PfV6fdY2P/3Pc9v89D+5uLms88/653/+5+zfv3++R1tS5rLOExMTGR4eTn9//0KPt2TM9ff55MmTueGGG2Z+5uabb85DDz20YHMuBXNd60ajkR//+Mf55Cc/me9///vZvHlzpqenF3LUJa/ka+GSDpHVq1dncnJy1mO1Wi1r1qyZuX/FFVfk0ksvnbVdrVbLihUr/BflHM1lnX/awMBA3vCGN+TNb37zQoy3ZMxlnQ8dOpTbb789K1asyIoVK7J79+7cf//9WbFiRWq12kKP3JLm+vu8atWqPPfcczP3u7u7nUV9kea61nfffXcqlUpuvfXWPPLII3nsscfyuc99biFHXfJKvhYu6RDZtGlTRkdH0/j/X6czPT2dsbGx9PX1zWzT1taWt7zlLfn+978/89h///d/Z9OmTQs9bsuayzqf893vfjdPPPFE/vAP/zBJcubMmZmf4+Lmss7btm3LqVOnZm533HFH3vve9+bUqVOL6nMDFrO5/j5v2LAhIyMjM/dPnTqVnp6ehRy15c11rb/5zW/mmmuuSZL09PTkL/7iL/Ltb397weddykq+Fi7pELn++uvT1dWVI0eOJEkOHz6cnp6erF+/Prt27crw8HCS5H3ve9+sU6pf//rX8yd/8idFZm5Fc13nH/7wh/nsZz+bDRs25Lvf/W6GhoayZ8+ekqO3lLmuM7+cua7zn/7pn+bf//3fZ37u6NGjef/7319k5lY117V+4xvfmEcffXTm506ePJlrr722xMgt7XyXWBbDa+GS/mTVZcuW5eDBg9m9e3eGh4dTrVZz4MCBtLW15atf/WquueaaXHPNNXnXu96VsbGx/M3f/E3Onj2bt73tbXnHO95RevyWMZd1fs1rXpMbb7wxjz766Kw/B/vYxz6Wtra2gtO3jrn+PvPLmes69/X15Y/+6I/yoQ99KN3d3anX6/nwhz9cevyWMte1/tjHPpa//Mu/zJ133pmVK1fm7Nmzou9F+tGPfpR77703SfKFL3whq1atytVXX70oXgvbGs6LAwCFLOlLMwDA4iZEAIBihAgAUIwQAQCKESIAQDFCBAAoRogAAMUIEQCgGCECABQjRACAYoQIAFDM/wM7pWpKJrfwNwAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plt.hist(random_input.detach().cpu().numpy()[:,0])" + ] + }, + { + "cell_type": "code", + "execution_count": 121, + "id": "a9647567-57f9-4bde-8db7-2565939da253", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(array([20., 17., 13., 8., 10., 12., 9., 11., 15., 10., 15., 15., 18.,\n", + " 16., 24., 35., 30., 29., 32., 25., 36., 27., 38., 27., 32., 32.,\n", + " 24., 31., 33., 28., 28., 26., 27., 27., 24., 27., 17., 18., 10.,\n", + " 9., 12., 7., 11., 8., 15., 5., 14., 18., 13., 12.]),\n", + " array([-1.945116 , -1.8679345 , -1.7907529 , -1.7135713 , -1.6363897 ,\n", + " -1.5592082 , -1.4820266 , -1.404845 , -1.3276634 , -1.2504818 ,\n", + " -1.1733003 , -1.0961187 , -1.0189371 , -0.9417555 , -0.8645739 ,\n", + " -0.7873923 , -0.71021074, -0.63302916, -0.5558476 , -0.47866598,\n", + " -0.4014844 , -0.32430282, -0.24712123, -0.16993965, -0.09275807,\n", + " -0.01557648, 0.0616051 , 0.13878669, 0.21596827, 0.29314986,\n", + " 0.37033144, 0.447513 , 0.5246946 , 0.6018762 , 0.6790578 ,\n", + " 0.75623935, 0.83342093, 0.9106025 , 0.9877841 , 1.0649657 ,\n", + " 1.1421473 , 1.2193289 , 1.2965105 , 1.373692 , 1.4508736 ,\n", + " 1.5280552 , 1.6052368 , 1.6824183 , 1.7595999 , 1.8367815 ,\n", + " 1.9139631 ], dtype=float32),\n", + " )" + ] + }, + "execution_count": 121, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plt.hist(output.detach().cpu().numpy()[:,0], bins = 50)" + ] + }, + { + "cell_type": "markdown", + "id": "5319a1e6-2724-45d7-ae33-e2c6798948a7", + "metadata": {}, + "source": [ + "## same with simple normalizing flow" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "5b5b786d-e36f-44c4-87cf-803cbad78805", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "import torch.distributions as D\n", + "import matplotlib.pyplot as plt\n", + "\n", + "class NormalizingFlow(nn.Module):\n", + " def __init__(self, input_dim, num_flows):\n", + " super(NormalizingFlow, self).__init__()\n", + " self.flows = nn.ModuleList([PlanarFlow(input_dim) for _ in range(num_flows)])\n", + "\n", + " def forward(self, z):\n", + " det_jacobian = torch.ones(z.size(0), 1)\n", + " for flow in self.flows:\n", + " z, jacobian = flow(z)\n", + " det_jacobian *= jacobian\n", + " return z, det_jacobian\n" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "9deae1f2-d9bd-43f2-b79c-b119eabb57f2", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "class PlanarFlow(nn.Module):\n", + " def __init__(self, input_dim):\n", + " super(PlanarFlow, self).__init__()\n", + " self.input_dim = input_dim\n", + " self.weight = nn.Parameter(torch.randn(1, input_dim))\n", + " self.bias = nn.Parameter(torch.randn(1))\n", + " self.scale = nn.Parameter(torch.randn(1))\n", + "\n", + " def forward(self, z):\n", + " # Transformation function\n", + " z_flow = z + self.scale * torch.tanh(torch.mm(z, self.weight.t()) + self.bias)\n", + " # Absolute determinant of the Jacobian\n", + " psi = (1 - torch.tanh(torch.mm(z, self.weight.t()) + self.bias) ** 2) * self.weight\n", + " det_jacobian = torch.abs(1 + torch.mm(psi, self.weight.t()))\n", + " return z_flow, det_jacobian" + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "id": "318a11de-b859-41c8-8a5b-675d44c0795e", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Define dimensions\n", + "input_dim = 1\n", + "\n", + "z_dim = input_dim\n", + "nepochs=100\n", + "bs=1000\n", + "num_flows=20\n", + "# Instantiate the PlanarFlow transformation\n", + "flow = NormalizingFlow(input_dim, num_flows)\n", + "base_distribution = D.Normal(torch.zeros(z_dim), torch.ones(z_dim))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 77, + "id": "f8d0d1d9-8693-4348-a089-87d3f84379e6", + "metadata": { + "collapsed": true, + "jupyter": { + "outputs_hidden": true + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch: 0 0.9581038951873779\n", + "epoch: 1 0.9572756886482239\n", + "epoch: 2 0.9593279361724854\n", + "epoch: 3 0.9568343758583069\n", + "epoch: 4 0.9572469592094421\n", + "epoch: 5 0.9553364515304565\n", + "epoch: 6 0.9575842022895813\n", + "epoch: 7 0.9560788869857788\n", + "epoch: 8 0.9587973356246948\n", + "epoch: 9 0.9580194354057312\n", + "epoch: 10 0.9589003920555115\n", + "epoch: 11 0.9569817781448364\n", + "epoch: 12 0.955956220626831\n", + "epoch: 13 0.958595871925354\n", + "epoch: 14 0.9574549198150635\n", + "epoch: 15 0.958065390586853\n", + "epoch: 16 0.9580419063568115\n", + "epoch: 17 0.9566380381584167\n", + "epoch: 18 0.957197904586792\n", + "epoch: 19 0.9565077424049377\n", + "epoch: 20 0.9570044875144958\n", + "epoch: 21 0.9565112590789795\n", + "epoch: 22 0.9574409127235413\n", + "epoch: 23 0.957895040512085\n", + "epoch: 24 0.9586047530174255\n", + "epoch: 25 0.9566912055015564\n", + "epoch: 26 0.958039402961731\n", + "epoch: 27 0.9593290686607361\n", + "epoch: 28 0.9575077891349792\n", + "epoch: 29 0.9567246437072754\n", + "epoch: 30 0.9582246541976929\n", + "epoch: 31 0.9556811451911926\n", + "epoch: 32 0.9571450352668762\n", + "epoch: 33 0.9564812183380127\n", + "epoch: 34 0.9580637216567993\n", + "epoch: 35 0.9575179815292358\n", + "epoch: 36 0.9590878486633301\n", + "epoch: 37 0.957419753074646\n", + "epoch: 38 0.9576168060302734\n", + "epoch: 39 0.9573911428451538\n", + "epoch: 40 0.9573774337768555\n", + "epoch: 41 0.9567323327064514\n", + "epoch: 42 0.959529459476471\n", + "epoch: 43 0.9581653475761414\n", + "epoch: 44 0.9583896398544312\n", + "epoch: 45 0.956044614315033\n", + "epoch: 46 0.9577047824859619\n", + "epoch: 47 0.958259642124176\n", + "epoch: 48 0.9573068022727966\n", + "epoch: 49 0.958161473274231\n", + "epoch: 50 0.957586407661438\n", + "epoch: 51 0.960275411605835\n", + "epoch: 52 0.9580078721046448\n", + "epoch: 53 0.9568662047386169\n", + "epoch: 54 0.9577730298042297\n", + "epoch: 55 0.9574154615402222\n", + "epoch: 56 0.9576438069343567\n", + "epoch: 57 0.9593215584754944\n", + "epoch: 58 0.9583775997161865\n", + "epoch: 59 0.9577406048774719\n", + "epoch: 60 0.9570046663284302\n", + "epoch: 61 0.9578263163566589\n", + "epoch: 62 0.9566665887832642\n", + "epoch: 63 0.9562719464302063\n", + "epoch: 64 0.9590163826942444\n", + "epoch: 65 0.9577448964118958\n", + "epoch: 66 0.9592074155807495\n", + "epoch: 67 0.9566142559051514\n", + "epoch: 68 0.9577342867851257\n", + "epoch: 69 0.9586719870567322\n", + "epoch: 70 0.9577504992485046\n", + "epoch: 71 0.9564600586891174\n", + "epoch: 72 0.9592216610908508\n", + "epoch: 73 0.9561935067176819\n", + "epoch: 74 0.9580640196800232\n", + "epoch: 75 0.9567023515701294\n", + "epoch: 76 0.9570690393447876\n", + "epoch: 77 0.9568561315536499\n", + "epoch: 78 0.9567307233810425\n", + "epoch: 79 0.9573178887367249\n", + "epoch: 80 0.9571306705474854\n", + "epoch: 81 0.9589489698410034\n", + "epoch: 82 0.957552433013916\n", + "epoch: 83 0.9554185271263123\n", + "epoch: 84 0.9571435451507568\n", + "epoch: 85 0.9579405784606934\n", + "epoch: 86 0.9575528502464294\n", + "epoch: 87 0.9572521448135376\n", + "epoch: 88 0.955950140953064\n", + "epoch: 89 0.9554963111877441\n", + "epoch: 90 0.9585280418395996\n", + "epoch: 91 0.9571262001991272\n", + "epoch: 92 0.9583413004875183\n", + "epoch: 93 0.9568629264831543\n", + "epoch: 94 0.9590401649475098\n", + "epoch: 95 0.9552928805351257\n", + "epoch: 96 0.9579176306724548\n", + "epoch: 97 0.9571013450622559\n", + "epoch: 98 0.9567534327507019\n", + "epoch: 99 0.9590332508087158\n" + ] + } + ], + "source": [ + "optimizer = optim.Adam(flow.parameters(), lr=1e-4, weight_decay=1e-4)\n", + "\n", + "for e in range(nepochs):\n", + " \n", + " \n", + " optimizer.zero_grad()\n", + " \n", + " random_input = torch.rand(bs, input_dim)\n", + " randomN_input = torch.randn(bs, input_dim)\n", + "\n", + " output, det_jacb = flow(random_input)\n", + " \n", + " loss = -base_distribution.log_prob(output) + torch.log(det_jacb) #+mmd(output,randomN_input)\n", + " loss = loss.mean()\n", + " loss.backward()\n", + " optimizer.step() \n", + " \n", + " print('epoch:',e, loss.item())\n", + " \n", + " \n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 78, + "id": "3c8313d8-80f3-4f80-af2e-adee65f4b775", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(array([117., 90., 96., 93., 113., 87., 95., 104., 102., 103.]),\n", + " array([0.00128156, 0.10107996, 0.20087835, 0.30067676, 0.40047514,\n", + " 0.5002736 , 0.60007197, 0.69987035, 0.7996687 , 0.8994672 ,\n", + " 0.99926555], dtype=float32),\n", + " )" + ] + }, + "execution_count": 78, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plt.hist(random_input.detach().cpu().numpy()[:,0])" + ] + }, + { + "cell_type": "code", + "execution_count": 79, + "id": "f4bff66f-2790-45f3-bc22-5b5539c04403", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(array([25., 17., 29., 24., 22., 22., 21., 13., 13., 21., 24., 16., 22.,\n", + " 26., 8., 23., 16., 23., 16., 15., 18., 24., 24., 26., 21., 17.,\n", + " 17., 19., 17., 17., 26., 22., 14., 19., 14., 14., 22., 29., 16.,\n", + " 23., 22., 18., 20., 27., 15., 18., 16., 24., 21., 24.]),\n", + " array([-0.4638622 , -0.4453545 , -0.42684677, -0.40833908, -0.38983136,\n", + " -0.37132365, -0.35281593, -0.3343082 , -0.31580052, -0.2972928 ,\n", + " -0.27878508, -0.26027736, -0.24176966, -0.22326194, -0.20475423,\n", + " -0.18624651, -0.1677388 , -0.14923109, -0.13072337, -0.11221566,\n", + " -0.09370795, -0.07520024, -0.05669252, -0.03818481, -0.0196771 ,\n", + " -0.00116938, 0.01733833, 0.03584604, 0.05435375, 0.07286147,\n", + " 0.09136918, 0.10987689, 0.1283846 , 0.14689232, 0.16540003,\n", + " 0.18390775, 0.20241547, 0.22092317, 0.23943089, 0.2579386 ,\n", + " 0.2764463 , 0.29495403, 0.31346175, 0.33196944, 0.35047716,\n", + " 0.36898488, 0.3874926 , 0.40600032, 0.424508 , 0.44301572,\n", + " 0.46152344], dtype=float32),\n", + " )" + ] + }, + "execution_count": 79, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plt.hist(output.detach().cpu().numpy()[:,0], bins = 50)" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "id": "826f3992-b2bd-4959-8282-a232a744bcaa", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "softplus = nn.Softplus()" + ] + }, + { + "cell_type": "code", + "execution_count": 70, + "id": "32bcbe63-7665-4bad-9b12-8389cd3e773a", + "metadata": {}, + "outputs": [], + "source": [ + "class RadialFlow(nn.Module):\n", + " def __init__(self, input_dim):\n", + " super(RadialFlow, self).__init__()\n", + " self.input_dim = input_dim\n", + " self.alpha = nn.Parameter(torch.randn(input_dim))\n", + " self.beta = nn.Parameter(torch.randn(input_dim))\n", + " self.gamma = nn.Parameter(torch.randn(input_dim))\n", + "\n", + " def forward(self, z):\n", + " \n", + " \n", + " num = softplus(self.alpha) * (torch.exp(self.beta)-1) * (z - self.gamma) \n", + " den = softplus(self.alpha) + (z - self.gamma) \n", + " \n", + " z_flow = z + num / den\n", + " \n", + " r = torch.abs(z-self.gamma)\n", + " h = 1 / (softplus(self.alpha) + r)\n", + " \n", + " term1 = (1 + softplus(self.alpha) * (torch.exp(self.beta)-1) * h)**(self.input_dim -1)\n", + " term2 = 1 + softplus(self.alpha) * (torch.exp(self.beta)-1) * h - softplus(self.alpha) * (torch.exp(self.beta)-1) * r *h**2\n", + " \n", + " det_jacobian = term1*term2\n", + " \n", + " return z_flow, det_jacobian\n", + "\n", + "class NormalizingFlow(nn.Module):\n", + " def __init__(self, input_dim, num_flows):\n", + " super(NormalizingFlow, self).__init__()\n", + " self.flows = nn.ModuleList([RadialFlow(input_dim) for _ in range(num_flows)])\n", + "\n", + " def forward(self, z):\n", + " det_jacobian = torch.ones(z.size(0), 1).to(device)\n", + " for flow in self.flows:\n", + " z, jacobian = flow(z)\n", + " det_jacobian *= jacobian\n", + " return z, det_jacobian" + ] + }, + { + "cell_type": "code", + "execution_count": 71, + "id": "bf6d28e8-544a-4ab9-8174-2d3ec35d0891", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import torch\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "\n", + "def MMD(x, y, kernel):\n", + " \"\"\"Emprical maximum mean discrepancy. The lower the result\n", + " the more evidence that distributions are the same.\n", + "\n", + " Args:\n", + " x: first sample, distribution P\n", + " y: second sample, distribution Q\n", + " kernel: kernel type such as \"multiscale\" or \"rbf\"\n", + " \"\"\"\n", + " xx, yy, zz = torch.mm(x, x.t()), torch.mm(y, y.t()), torch.mm(x, y.t())\n", + " rx = (xx.diag().unsqueeze(0).expand_as(xx))\n", + " ry = (yy.diag().unsqueeze(0).expand_as(yy))\n", + "\n", + " dxx = rx.t() + rx - 2. * xx # Used for A in (1)\n", + " dyy = ry.t() + ry - 2. * yy # Used for B in (1)\n", + " dxy = rx.t() + ry - 2. * zz # Used for C in (1)\n", + "\n", + " XX, YY, XY = (torch.zeros(xx.shape).to(device),\n", + " torch.zeros(xx.shape).to(device),\n", + " torch.zeros(xx.shape).to(device))\n", + "\n", + " if kernel == \"multiscale\":\n", + "\n", + " bandwidth_range = [0.2, 0.5, 0.9, 1.3]\n", + " for a in bandwidth_range:\n", + " XX += a**2 * (a**2 + dxx)**-1\n", + " YY += a**2 * (a**2 + dyy)**-1\n", + " XY += a**2 * (a**2 + dxy)**-1\n", + "\n", + " if kernel == \"rbf\":\n", + "\n", + " bandwidth_range = [1, 1.5, 2.0, 5.0]\n", + " for a in bandwidth_range:\n", + " XX += torch.exp(-0.5*dxx/a)\n", + " YY += torch.exp(-0.5*dyy/a)\n", + " XY += torch.exp(-0.5*dxy/a)\n", + "\n", + "\n", + " return torch.mean(XX + YY - 2. * XY)" + ] + }, + { + "cell_type": "code", + "execution_count": 98, + "id": "478d8261-7e95-48fa-8b5c-6194ebaf4524", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "RadialFlow()" + ] + }, + "execution_count": 98, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "RadialFlow(1)" + ] + }, + { + "cell_type": "code", + "execution_count": 81, + "id": "c2f42513-ed84-4bb5-9097-f7b1b7862da9", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Define dimensions\n", + "input_dim = 1\n", + "\n", + "z_dim = input_dim\n", + "nepochs=10000\n", + "\n", + "bs=1000\n", + "num_flows=100\n", + "# Instantiate the PlanarFlow transformation\n", + "flow = NormalizingFlow(input_dim, num_flows).to(device)\n", + "\n", + "base_distribution = D.Normal(torch.zeros(z_dim), torch.ones(z_dim))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 82, + "id": "94928333-6043-4cbf-b247-b2904f0a9719", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "random_input = torch.rand(10000, input_dim)\n", + "dset = TensorDataset(random_input)\n", + "loader = DataLoader(dset, batch_size=1000, shuffle=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 95, + "id": "18b1dcd2-7445-48af-bd62-46669d372ceb", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch: 0 28200.544921875\n", + "epoch: 1 1129.369384765625\n", + "epoch: 2 21.443578720092773\n", + "epoch: 3 112.31295776367188\n", + "epoch: 4 149.22096252441406\n", + "epoch: 5 33502.9375\n", + "epoch: 6 110.99540710449219\n", + "epoch: 7 271.1609191894531\n", + "epoch: 8 57.25232696533203\n", + "epoch: 9 73.51012420654297\n", + "epoch: 10 97.47554779052734\n", + "epoch: 11 123.52039337158203\n", + "epoch: 12 221.76556396484375\n", + "epoch: 13 225.6431427001953\n", + "epoch: 14 36.62788772583008\n", + "epoch: 15 788.5504150390625\n", + "epoch: 16 905.950927734375\n", + "epoch: 17 359.9490661621094\n", + "epoch: 18 94.8775863647461\n", + "epoch: 19 1562.370849609375\n", + "epoch: 20 68.43376159667969\n", + "epoch: 21 120.20706939697266\n", + "epoch: 22 117.14567565917969\n", + "epoch: 23 156.90737915039062\n", + "epoch: 24 64.27742767333984\n", + "epoch: 25 72.80873107910156\n", + "epoch: 26 299.83074951171875\n", + "epoch: 27 198.3571014404297\n", + "epoch: 28 67.56690979003906\n", + "epoch: 29 59.80242156982422\n", + "epoch: 30 152.69972229003906\n", + "epoch: 31 85.3028335571289\n", + "epoch: 32 2329.72314453125\n", + "epoch: 33 50.84280776977539\n", + "epoch: 34 25396.7109375\n", + "epoch: 35 1737.4556884765625\n", + "epoch: 36 101.53565216064453\n", + "epoch: 37 505.9686584472656\n", + "epoch: 38 972.81689453125\n", + "epoch: 39 48.345829010009766\n", + "epoch: 40 71052.6328125\n", + "epoch: 41 4206.85205078125\n", + "epoch: 42 168.01895141601562\n", + "epoch: 43 9705.7646484375\n", + "epoch: 44 337.1620788574219\n", + "epoch: 45 2976.151611328125\n", + "epoch: 46 2142248192.0\n", + "epoch: 47 18152.1484375\n", + "epoch: 48 26.833215713500977\n", + "epoch: 49 3471.46044921875\n", + "epoch: 50 583.5856323242188\n", + "epoch: 51 96.70631408691406\n", + "epoch: 52 170.61947631835938\n", + "epoch: 53 64.85731506347656\n", + "epoch: 54 914.636474609375\n", + "epoch: 55 153.5564727783203\n", + "epoch: 56 184.95700073242188\n", + "epoch: 57 45.27014923095703\n", + "epoch: 58 1541.3345947265625\n", + "epoch: 59 227.9755859375\n", + "epoch: 60 144.11129760742188\n", + "epoch: 61 55.70856475830078\n", + "epoch: 62 2118.783447265625\n", + "epoch: 63 161.55926513671875\n", + "epoch: 64 31.37447738647461\n", + "epoch: 65 1794.9423828125\n", + "epoch: 66 52.8399543762207\n", + "epoch: 67 51.3333854675293\n", + "epoch: 68 40.058067321777344\n", + "epoch: 69 51.833065032958984\n", + "epoch: 70 1669.7928466796875\n", + "epoch: 71 52.97251892089844\n", + "epoch: 72 56.90239715576172\n", + "epoch: 73 27612.794921875\n", + "epoch: 74 184.5417938232422\n", + "epoch: 75 70.6549072265625\n" + ] + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[95], line 14\u001b[0m\n\u001b[1;32m 12\u001b[0m loss \u001b[38;5;241m=\u001b[39m MMD(output,randomN_input, kernel\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mrbf\u001b[39m\u001b[38;5;124m'\u001b[39m)\u001b[38;5;241m-\u001b[39mbase_distribution\u001b[38;5;241m.\u001b[39mlog_prob(output\u001b[38;5;241m.\u001b[39mcpu())\u001b[38;5;241m.\u001b[39mto(device) \u001b[38;5;241m+\u001b[39m torch\u001b[38;5;241m.\u001b[39mlog(det_jacb) \u001b[38;5;66;03m#\u001b[39;00m\n\u001b[1;32m 13\u001b[0m loss \u001b[38;5;241m=\u001b[39m loss\u001b[38;5;241m.\u001b[39mmean()\n\u001b[0;32m---> 14\u001b[0m \u001b[43mloss\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 15\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mstep() \n\u001b[1;32m 17\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mepoch:\u001b[39m\u001b[38;5;124m'\u001b[39m,e, loss\u001b[38;5;241m.\u001b[39mitem())\n", + "File \u001b[0;32m/data/astro/scratch/lcabayol/anaconda3/envs/DLenv2/lib/python3.9/site-packages/torch/_tensor.py:363\u001b[0m, in \u001b[0;36mTensor.backward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 354\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m has_torch_function_unary(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 355\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(\n\u001b[1;32m 356\u001b[0m Tensor\u001b[38;5;241m.\u001b[39mbackward,\n\u001b[1;32m 357\u001b[0m (\u001b[38;5;28mself\u001b[39m,),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 361\u001b[0m create_graph\u001b[38;5;241m=\u001b[39mcreate_graph,\n\u001b[1;32m 362\u001b[0m inputs\u001b[38;5;241m=\u001b[39minputs)\n\u001b[0;32m--> 363\u001b[0m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautograd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgradient\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/data/astro/scratch/lcabayol/anaconda3/envs/DLenv2/lib/python3.9/site-packages/torch/autograd/__init__.py:173\u001b[0m, in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 168\u001b[0m retain_graph \u001b[38;5;241m=\u001b[39m create_graph\n\u001b[1;32m 170\u001b[0m \u001b[38;5;66;03m# The reason we repeat same the comment below is that\u001b[39;00m\n\u001b[1;32m 171\u001b[0m \u001b[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[1;32m 172\u001b[0m \u001b[38;5;66;03m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[0;32m--> 173\u001b[0m \u001b[43mVariable\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_execution_engine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_backward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[1;32m 174\u001b[0m \u001b[43m \u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgrad_tensors_\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 175\u001b[0m \u001b[43m \u001b[49m\u001b[43mallow_unreachable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maccumulate_grad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "optimizer = optim.Adam(flow.parameters(), lr=1e-3, weight_decay=1e-4)\n", + "\n", + "for e in range(nepochs):\n", + " for x in loader:\n", + " #assert False\n", + " \n", + " optimizer.zero_grad()\n", + " randomN_input = torch.randn(len(x[0]), input_dim).to(device)\n", + "\n", + " output, det_jacb = flow(x[0].to(device))\n", + " \n", + " loss = MMD(output,randomN_input, kernel='rbf')-base_distribution.log_prob(output.cpu()).to(device) + torch.log(det_jacb) #\n", + " loss = loss.mean()\n", + " loss.backward()\n", + " optimizer.step() \n", + " \n", + " print('epoch:',e, loss.item())\n", + " \n", + " \n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 97, + "id": "75356e3f-a4a8-4da3-98a4-1574c7bc3f68", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(array([ 0., 0., 0., 0., 0., 1., 0., 0., 3., 1., 0.,\n", + " 0., 0., 0., 0., 0., 0., 0., 0., 2., 4., 2.,\n", + " 0., 0., 0., 112., 59., 70., 72., 74., 71., 99., 111.,\n", + " 127., 71., 10., 9., 11., 17., 6., 6., 4., 5., 1.,\n", + " 1., 4., 4., 7., 2., 3.]),\n", + " array([-10. , -9.6, -9.2, -8.8, -8.4, -8. , -7.6, -7.2, -6.8,\n", + " -6.4, -6. , -5.6, -5.2, -4.8, -4.4, -4. , -3.6, -3.2,\n", + " -2.8, -2.4, -2. , -1.6, -1.2, -0.8, -0.4, 0. , 0.4,\n", + " 0.8, 1.2, 1.6, 2. , 2.4, 2.8, 3.2, 3.6, 4. ,\n", + " 4.4, 4.8, 5.2, 5.6, 6. , 6.4, 6.8, 7.2, 7.6,\n", + " 8. , 8.4, 8.8, 9.2, 9.6, 10. ], dtype=float32),\n", + " )" + ] + }, + "execution_count": 97, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plt.hist(randomN_input.detach().cpu().numpy()[:,0], bins = 50, range=(-10,10))\n", + "plt.hist(output.detach().cpu().numpy()[:,0], bins = 50, range=(-10,10))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 280, + "id": "72b15ab5-9031-444e-886d-b9f47859d01e", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import torch\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "\n", + "def MMD(x, y, kernel):\n", + " \"\"\"Emprical maximum mean discrepancy. The lower the result\n", + " the more evidence that distributions are the same.\n", + "\n", + " Args:\n", + " x: first sample, distribution P\n", + " y: second sample, distribution Q\n", + " kernel: kernel type such as \"multiscale\" or \"rbf\"\n", + " \"\"\"\n", + " xx, yy, zz = torch.mm(x, x.t()), torch.mm(y, y.t()), torch.mm(x, y.t())\n", + " rx = (xx.diag().unsqueeze(0).expand_as(xx))\n", + " ry = (yy.diag().unsqueeze(0).expand_as(yy))\n", + "\n", + " dxx = rx.t() + rx - 2. * xx # Used for A in (1)\n", + " dyy = ry.t() + ry - 2. * yy # Used for B in (1)\n", + " dxy = rx.t() + ry - 2. * zz # Used for C in (1)\n", + "\n", + " XX, YY, XY = (torch.zeros(xx.shape).to(device),\n", + " torch.zeros(xx.shape).to(device),\n", + " torch.zeros(xx.shape).to(device))\n", + "\n", + " if kernel == \"multiscale\":\n", + "\n", + " bandwidth_range = [0.2, 0.5, 0.9, 1.3]\n", + " for a in bandwidth_range:\n", + " XX += a**2 * (a**2 + dxx)**-1\n", + " YY += a**2 * (a**2 + dyy)**-1\n", + " XY += a**2 * (a**2 + dxy)**-1\n", + "\n", + " if kernel == \"rbf\":\n", + "\n", + " bandwidth_range = [10, 15, 20, 50]\n", + " for a in bandwidth_range:\n", + " XX += torch.exp(-0.5*dxx/a)\n", + " YY += torch.exp(-0.5*dyy/a)\n", + " XY += torch.exp(-0.5*dxy/a)\n", + "\n", + "\n", + " return torch.mean(XX + YY - 2. * XY)" + ] + }, + { + "cell_type": "code", + "execution_count": 281, + "id": "7c80b247-3722-4f3d-ae88-d4ee8b7b2d55", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + " x = torch.rand(bs, input_dim)\n", + " y = torch.randn(bs, input_dim)" + ] + }, + { + "cell_type": "code", + "execution_count": 283, + "id": "6df76ddf-876c-4eca-96c4-ee5dee57443e", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(0.0637, device='cuda:0')" + ] + }, + "execution_count": 283, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "MMD(x.cuda(), y.cuda(), kernel='rbf')" + ] + }, + { + "cell_type": "markdown", + "id": "a93e53ef-f6d6-4bb9-919c-2173e2634ac2", + "metadata": {}, + "source": [ + "## WITH A VARIATIONAL AUTOENCODER" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d6a6b88d-e9ea-4ee1-ad8e-bfe81442864e", + "metadata": {}, + "outputs": [], + "source": [ + "import torch.nn as nn\n", + "import torch\n", + "class CondVAE(nn.Module):\n", + " def __init__(self, dim_input, latent_dim=10, size=[150,150]):\n", + " super(CondVAE, self).__init__()\n", + " \n", + " self.latent_dim = latent_dim\n", + "\n", + " # Encoder\n", + " self.encoder = nn.Sequential(\n", + " nn.Linear(in_features=dim_input, out_features=size[0]*size[1]),\n", + " nn.Unflatten(1, (1, size[0], size[1], 60)),\n", + " nn.ReLU(),\n", + " nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1), \n", + " nn.ReLU(),\n", + " nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),\n", + " nn.ReLU(),\n", + " nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),\n", + " nn.ReLU(),\n", + " nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),\n", + " nn.ReLU(),\n", + " nn.Flatten()\n", + " )\n", + " \n", + " self.fc_mu = nn.Linear(16384, latent_dim)\n", + " self.fc_logvar = nn.Linear(16384, latent_dim)\n", + "\n", + " # Decoder\n", + " self.decoder = nn.Sequential(\n", + " nn.Linear(latent_dim + 8, 16384), \n", + " nn.Unflatten(1, (256, 8, 8)),\n", + " nn.ReLU(),\n", + " nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1),\n", + " nn.ReLU(),\n", + " nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1),\n", + " nn.ReLU(),\n", + " nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=0),\n", + " nn.ReLU(),\n", + " nn.ConvTranspose2d(32, 1, kernel_size=2, stride=1, padding=0),\n", + " nn.Sigmoid()\n", + " )\n", + "\n", + " def encode(self, x):\n", + " x = self.encoder(x)\n", + " mu = self.fc_mu(x)\n", + " log_var = self.fc_logvar(x)\n", + "\n", + " return mu, log_var\n", + " \n", + " def decode(self, z, properties):\n", + " # Concatenate the sampling (latent distribution) + embedding -> samples conditioned on both the input data and the specified label\n", + " #print(properties.shape, z.shape)\n", + " zcomb = torch.concat((z, properties), 1)\n", + " #print(zcomb.shape)\n", + " \n", + " return self.decoder(zcomb) \n", + " \n", + " def sampling(self, mu, log_var):\n", + " # calculate standard deviation\n", + " std = log_var.mul(0.5).exp_()\n", + " \n", + " # create noise tensor of same size as std to add to the latent vector\n", + " eps = torch.cuda.FloatTensor(std.size()).normal_()\n", + " \n", + " # multiply eps with std to scale the random noise according to the learned distribution + add combined\n", + " return eps.mul(std).add_(mu) # return z sample \n", + "\n", + " def forward(self, x, properties):\n", + " mu, log_var = self.encode(x)\n", + " z = self.sampling(mu, log_var)\n", + " #print(z.shape)\n", + "\n", + " return self.decode(z, properties), mu, log_var\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "DLenv2", + "language": "python", + "name": "dlenv2" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}