diff --git "a/notebooks/toy_test.ipynb" "b/notebooks/toy_test.ipynb" new file mode 100644--- /dev/null +++ "b/notebooks/toy_test.ipynb" @@ -0,0 +1,2354 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "b89d9358-ef03-4b85-82c3-7eb6366cc7cc", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "import torch.distributions as D\n", + "\n", + "from torch.utils.data import DataLoader, dataset, TensorDataset" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "9e29e664-80af-4102-a465-77b73e84f867", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import sys\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "sys.path.append('../insight')\n", + "from archive import archive \n", + "\n", + "import torch.nn.functional as F" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "ac8eedcf-ec15-4b95-8977-c3affc94f4de", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import torch.nn.functional as F\n", + "\n", + "class RandomToNormalNN(nn.Module):\n", + " def __init__(self, input_dim, hidden_dim, output_dim, num_hidden_layers):\n", + " super(RandomToNormalNN, self).__init__()\n", + " self.input_layer = nn.Linear(input_dim, hidden_dim)\n", + " self.hidden_layers = nn.ModuleList([\n", + " nn.Linear(hidden_dim, hidden_dim) for _ in range(num_hidden_layers)\n", + " ])\n", + " self.output_layer = nn.Linear(hidden_dim, output_dim)\n", + " self.activation = nn.ReLU() # You can experiment with other activation functions\n", + " \n", + " def forward(self, x):\n", + " x = self.activation(self.input_layer(x))\n", + " for hidden_layer in self.hidden_layers:\n", + " x = self.activation(hidden_layer(x))\n", + " x = self.output_layer(x)\n", + " return x\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "id": "617194de-d0d0-4d44-86ef-75185d1b5c0f", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "input_dim = 1\n", + "hidden_dim = 128\n", + "output_dim = 1\n", + "num_hidden_layers = 3\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "\n", + "\n", + "model = RandomToNormalNN(input_dim, hidden_dim, output_dim, num_hidden_layers).to(device)\n", + "\n", + "nepochs=100\n", + "bs = 1000\n", + "z_dim=1\n", + "epsilon = 0.5\n", + "\n", + "# Create an instance of the neural network\n", + "base_distribution = D.Normal(torch.zeros(z_dim), torch.ones(z_dim))\n", + "\n", + "optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "id": "706bf267-c870-4fe6-a1de-758aac722678", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "random_input = torch.rand(10000, input_dim)\n", + "dset = TensorDataset(random_input)\n", + "loader = DataLoader(dset, batch_size=100, shuffle=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "id": "825e88d5-c4b2-493c-8597-3c5859b8bf8a", + "metadata": { + "collapsed": true, + "jupyter": { + "outputs_hidden": true + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch 0 0.918938398361206\n", + "epoch 1 0.918938398361206\n", + "epoch 2 0.918938398361206\n", + "epoch 3 0.918938398361206\n", + "epoch 4 0.918938398361206\n", + "epoch 5 0.918938398361206\n", + "epoch 6 0.918938398361206\n", + "epoch 7 0.918938398361206\n", + "epoch 8 0.918938398361206\n", + "epoch 9 0.918938398361206\n", + "epoch 10 0.918938398361206\n", + "epoch 11 0.918938398361206\n", + "epoch 12 0.918938398361206\n", + "epoch 13 0.918938398361206\n", + "epoch 14 0.918938398361206\n" + ] + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[42], line 13\u001b[0m\n\u001b[1;32m 11\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m-\u001b[39mloss\u001b[38;5;241m.\u001b[39mmean()\n\u001b[1;32m 12\u001b[0m loss\u001b[38;5;241m.\u001b[39mbackward()\n\u001b[0;32m---> 13\u001b[0m \u001b[43moptimizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m \n\u001b[1;32m 15\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mepoch \u001b[39m\u001b[38;5;124m'\u001b[39m,e, loss\u001b[38;5;241m.\u001b[39mitem())\n", + "File \u001b[0;32m/data/astro/scratch/lcabayol/anaconda3/envs/DLenv2/lib/python3.9/site-packages/torch/optim/optimizer.py:88\u001b[0m, in \u001b[0;36mOptimizer._hook_for_profile..profile_hook_step..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 86\u001b[0m profile_name \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mOptimizer.step#\u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m.step\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mformat(obj\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m)\n\u001b[1;32m 87\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mautograd\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mrecord_function(profile_name):\n\u001b[0;32m---> 88\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/data/astro/scratch/lcabayol/anaconda3/envs/DLenv2/lib/python3.9/site-packages/torch/autograd/grad_mode.py:27\u001b[0m, in \u001b[0;36m_DecoratorContextManager.__call__..decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 25\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 26\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mclone():\n\u001b[0;32m---> 27\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/data/astro/scratch/lcabayol/anaconda3/envs/DLenv2/lib/python3.9/site-packages/torch/optim/adam.py:141\u001b[0m, in \u001b[0;36mAdam.step\u001b[0;34m(self, closure)\u001b[0m\n\u001b[1;32m 138\u001b[0m \u001b[38;5;66;03m# record the step after step update\u001b[39;00m\n\u001b[1;32m 139\u001b[0m state_steps\u001b[38;5;241m.\u001b[39mappend(state[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mstep\u001b[39m\u001b[38;5;124m'\u001b[39m])\n\u001b[0;32m--> 141\u001b[0m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madam\u001b[49m\u001b[43m(\u001b[49m\u001b[43mparams_with_grad\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 142\u001b[0m \u001b[43m \u001b[49m\u001b[43mgrads\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 143\u001b[0m \u001b[43m \u001b[49m\u001b[43mexp_avgs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 144\u001b[0m \u001b[43m \u001b[49m\u001b[43mexp_avg_sqs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 145\u001b[0m \u001b[43m \u001b[49m\u001b[43mmax_exp_avg_sqs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 146\u001b[0m \u001b[43m \u001b[49m\u001b[43mstate_steps\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 147\u001b[0m \u001b[43m \u001b[49m\u001b[43mamsgrad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgroup\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mamsgrad\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 148\u001b[0m \u001b[43m \u001b[49m\u001b[43mbeta1\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbeta1\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 149\u001b[0m \u001b[43m \u001b[49m\u001b[43mbeta2\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbeta2\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 150\u001b[0m \u001b[43m \u001b[49m\u001b[43mlr\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgroup\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mlr\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 151\u001b[0m \u001b[43m \u001b[49m\u001b[43mweight_decay\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgroup\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mweight_decay\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 152\u001b[0m \u001b[43m \u001b[49m\u001b[43meps\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgroup\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43meps\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 153\u001b[0m \u001b[43m \u001b[49m\u001b[43mmaximize\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgroup\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mmaximize\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 154\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m loss\n", + "File \u001b[0;32m/data/astro/scratch/lcabayol/anaconda3/envs/DLenv2/lib/python3.9/site-packages/torch/optim/_functional.py:94\u001b[0m, in \u001b[0;36madam\u001b[0;34m(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, amsgrad, beta1, beta2, lr, weight_decay, eps, maximize)\u001b[0m\n\u001b[1;32m 91\u001b[0m bias_correction2 \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m \u001b[38;5;241m-\u001b[39m beta2 \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39m step\n\u001b[1;32m 93\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m weight_decay \u001b[38;5;241m!=\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[0;32m---> 94\u001b[0m grad \u001b[38;5;241m=\u001b[39m \u001b[43mgrad\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madd\u001b[49m\u001b[43m(\u001b[49m\u001b[43mparam\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43malpha\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mweight_decay\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 96\u001b[0m \u001b[38;5;66;03m# Decay the first and second moment running average coefficient\u001b[39;00m\n\u001b[1;32m 97\u001b[0m exp_avg\u001b[38;5;241m.\u001b[39mmul_(beta1)\u001b[38;5;241m.\u001b[39madd_(grad, alpha\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m-\u001b[39m beta1)\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "base_distribution = D.Normal(torch.zeros(z_dim), torch.ones(z_dim))\n", + "\n", + "for e in range(nepochs):\n", + " for x in loader:\n", + " \n", + " optimizer.zero_grad()\n", + " \n", + " output = model(x[0].unsqueeze(1).to(device))\n", + " \n", + " loss = base_distribution.log_prob(output.cpu()).to(device)\n", + " loss = -loss.mean()\n", + " loss.backward()\n", + " optimizer.step() \n", + " \n", + " print('epoch ',e, loss.item())\n", + "\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 79, + "id": "40b99ce9-7bb8-4f8c-b218-339ddafd0747", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(0.9507, grad_fn=)" + ] + }, + "execution_count": 79, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "loss" + ] + }, + { + "cell_type": "code", + "execution_count": 80, + "id": "57deecda-8bd5-41d1-a176-359f130df5b7", + "metadata": { + "collapsed": true, + "jupyter": { + "outputs_hidden": true + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[0.9970],\n", + " [0.9916],\n", + " [1.0080],\n", + " [1.0096],\n", + " [0.9892],\n", + " [1.0066],\n", + " [0.9949],\n", + " [1.0210],\n", + " [1.0041],\n", + " [0.9981],\n", + " [1.0033],\n", + " [0.9885],\n", + " [1.0145],\n", + " [1.0075],\n", + " [1.0223],\n", + " [0.9913],\n", + " [1.0128],\n", + " [0.9889],\n", + " [1.0095],\n", + " [0.9927],\n", + " [0.9925],\n", + " [1.0221],\n", + " [1.0108],\n", + " [1.0112],\n", + " [1.0038],\n", + " [0.9891],\n", + " [0.9887],\n", + " [0.9949],\n", + " [0.9885],\n", + " [1.0030],\n", + " [1.0090],\n", + " [1.0156],\n", + " [0.9911],\n", + " [1.0166],\n", + " [0.9888],\n", + " [1.0066],\n", + " [1.0073],\n", + " [1.0067],\n", + " [0.9943],\n", + " [1.0096],\n", + " [0.9915],\n", + " [1.0128],\n", + " [0.9921],\n", + " [1.0151],\n", + " [1.0113],\n", + " [1.0126],\n", + " [1.0006],\n", + " [1.0071],\n", + " [1.0146],\n", + " [0.9968],\n", + " [1.0029],\n", + " [1.0010],\n", + " [1.0244],\n", + " [0.9951],\n", + " [0.9949],\n", + " [1.0102],\n", + " [0.9993],\n", + " [1.0032],\n", + " [0.9935],\n", + " [0.9942],\n", + " [1.0088],\n", + " [1.0008],\n", + " [1.0117],\n", + " [0.9902],\n", + " [1.0163],\n", + " [0.9884],\n", + " [1.0151],\n", + " [0.9895],\n", + " [0.9914],\n", + " [0.9977],\n", + " [0.9885],\n", + " [0.9928],\n", + " [0.9920],\n", + " [1.0260],\n", + " [1.0192],\n", + " [1.0039],\n", + " [0.9887],\n", + " [1.0172],\n", + " [1.0190],\n", + " [1.0116],\n", + " [1.0196],\n", + " [1.0089],\n", + " [0.9887],\n", + " [0.9902],\n", + " [0.9925],\n", + " [1.0005],\n", + " [0.9924],\n", + " [1.0215],\n", + " [1.0221],\n", + " [0.9885],\n", + " [1.0191],\n", + " [1.0151],\n", + " [1.0299],\n", + " [0.9932],\n", + " [0.9931],\n", + " [0.9895],\n", + " [1.0130],\n", + " [1.0238],\n", + " [1.0094],\n", + " [1.0108],\n", + " [0.9907],\n", + " [0.9910],\n", + " [0.9937],\n", + " [0.9900],\n", + " [0.9992],\n", + " [1.0072],\n", + " [0.9916],\n", + " [1.0194],\n", + " [1.0025],\n", + " [1.0133],\n", + " [1.0039],\n", + " [1.0015],\n", + " [1.0068],\n", + " [1.0010],\n", + " [1.0045],\n", + " [0.9930],\n", + " [1.0242],\n", + " [1.0062],\n", + " [1.0292],\n", + " [0.9945],\n", + " [0.9887],\n", + " [1.0163],\n", + " [1.0137],\n", + " [1.0147],\n", + " [1.0181],\n", + " [1.0111],\n", + " [1.0234],\n", + " [1.0001],\n", + " [1.0123],\n", + " [1.0025],\n", + " [1.0149],\n", + " [1.0201],\n", + " [1.0006],\n", + " [1.0049],\n", + " [0.9900],\n", + " [1.0222],\n", + " [1.0141],\n", + " [1.0128],\n", + " [1.0269],\n", + " [0.9897],\n", + " [1.0142],\n", + " [1.0018],\n", + " [1.0134],\n", + " [1.0003],\n", + " [0.9976],\n", + " [1.0122],\n", + " [1.0133],\n", + " [1.0175],\n", + " [1.0090],\n", + " [0.9900],\n", + " [0.9914],\n", + " [0.9954],\n", + " [1.0111],\n", + " [0.9926],\n", + " [1.0107],\n", + " [1.0039],\n", + " [1.0162],\n", + " [1.0009],\n", + " [0.9893],\n", + " [0.9886],\n", + " [1.0123],\n", + " [0.9892],\n", + " [1.0123],\n", + " [0.9942],\n", + " [1.0280],\n", + " [1.0013],\n", + " [1.0162],\n", + " [1.0118],\n", + " [0.9911],\n", + " [1.0116],\n", + " [1.0190],\n", + " [1.0093],\n", + " [1.0160],\n", + " [1.0304],\n", + " [0.9998],\n", + " [1.0076],\n", + " [0.9934],\n", + " [0.9984],\n", + " [1.0266],\n", + " [1.0161],\n", + " [1.0011],\n", + " [0.9920],\n", + " [0.9930],\n", + " [0.9954],\n", + " [1.0053],\n", + " [0.9911],\n", + " [1.0089],\n", + " [1.0089],\n", + " [1.0047],\n", + " [1.0155],\n", + " [1.0171],\n", + " [1.0071],\n", + " [1.0279],\n", + " [1.0241],\n", + " [1.0112],\n", + " [1.0052],\n", + " [0.9922],\n", + " [0.9921],\n", + " [0.9902],\n", + " [0.9894],\n", + " [1.0130],\n", + " [1.0140],\n", + " [1.0005],\n", + " [1.0161],\n", + " [1.0242],\n", + " [0.9938],\n", + " [1.0036],\n", + " [0.9885],\n", + " [1.0012],\n", + " [0.9917],\n", + " [0.9964],\n", + " [1.0080],\n", + " [0.9927],\n", + " [1.0123],\n", + " [0.9885],\n", + " [0.9915],\n", + " [0.9894],\n", + " [1.0004],\n", + " [1.0148],\n", + " [0.9885],\n", + " [0.9894],\n", + " [1.0306],\n", + " [0.9989],\n", + " [1.0092],\n", + " [1.0174],\n", + " [1.0064],\n", + " [1.0056],\n", + " [0.9907],\n", + " [1.0120],\n", + " [1.0086],\n", + " [0.9919],\n", + " [0.9958],\n", + " [0.9935],\n", + " [0.9890],\n", + " [1.0046],\n", + " [1.0200],\n", + " [1.0168],\n", + " [1.0193],\n", + " [1.0082],\n", + " [1.0081],\n", + " [1.0119],\n", + " [0.9885],\n", + " [1.0064],\n", + " [1.0245],\n", + " [1.0090],\n", + " [0.9966],\n", + " [1.0105],\n", + " [1.0213],\n", + " [1.0094],\n", + " [0.9895],\n", + " [1.0258],\n", + " [1.0028],\n", + " [1.0096],\n", + " [1.0165],\n", + " [1.0139],\n", + " [0.9919],\n", + " [1.0203],\n", + " [0.9902],\n", + " [1.0255],\n", + " [1.0126],\n", + " [1.0119],\n", + " [1.0091],\n", + " [1.0109],\n", + " [1.0093],\n", + " [1.0092],\n", + " [1.0114],\n", + " [0.9976],\n", + " [1.0145],\n", + " [1.0036],\n", + " [1.0166],\n", + " [1.0061],\n", + " [0.9901],\n", + " [1.0007],\n", + " [0.9888],\n", + " [0.9919],\n", + " [0.9905],\n", + " [1.0048],\n", + " [1.0268],\n", + " [0.9986],\n", + " [1.0288],\n", + " [1.0049],\n", + " [1.0252],\n", + " [0.9943],\n", + " [1.0019],\n", + " [0.9895],\n", + " [0.9956],\n", + " [1.0120],\n", + " [0.9917],\n", + " [0.9929],\n", + " [0.9885],\n", + " [1.0199],\n", + " [1.0069],\n", + " [1.0103],\n", + " [0.9887],\n", + " [1.0091],\n", + " [1.0149],\n", + " [1.0025],\n", + " [1.0291],\n", + " [0.9962],\n", + " [0.9968],\n", + " [1.0026],\n", + " [1.0129],\n", + " [0.9952],\n", + " [0.9994],\n", + " [0.9955],\n", + " [0.9959],\n", + " [0.9887],\n", + " [1.0178],\n", + " [1.0053],\n", + " [1.0167],\n", + " [1.0225],\n", + " [0.9982],\n", + " [0.9903],\n", + " [1.0007],\n", + " [1.0106],\n", + " [1.0052],\n", + " [1.0106],\n", + " [1.0097],\n", + " [1.0013],\n", + " [1.0160],\n", + " [0.9967],\n", + " [0.9931],\n", + " [1.0100],\n", + " [0.9893],\n", + " [1.0076],\n", + " [0.9893],\n", + " [1.0178],\n", + " [1.0060],\n", + " [0.9914],\n", + " [1.0051],\n", + " [1.0088],\n", + " [1.0092],\n", + " [1.0132],\n", + " [0.9895],\n", + " [1.0106],\n", + " [0.9893],\n", + " [1.0037],\n", + " [1.0116],\n", + " [1.0253],\n", + " [0.9914],\n", + " [1.0140],\n", + " [1.0188],\n", + " [1.0078],\n", + " [1.0077],\n", + " [1.0144],\n", + " [1.0066],\n", + " [0.9885],\n", + " [0.9998],\n", + " [1.0006],\n", + " [1.0017],\n", + " [1.0215],\n", + " [1.0127],\n", + " [1.0079],\n", + " [1.0105],\n", + " [0.9945],\n", + " [1.0002],\n", + " [1.0019],\n", + " [1.0253],\n", + " [0.9968],\n", + " [1.0103],\n", + " [1.0172],\n", + " [0.9893],\n", + " [1.0151],\n", + " [1.0161],\n", + " [0.9886],\n", + " [1.0117],\n", + " [0.9979],\n", + " [1.0238],\n", + " [1.0024],\n", + " [1.0097],\n", + " [0.9991],\n", + " [1.0186],\n", + " [1.0243],\n", + " [0.9965],\n", + " [0.9887],\n", + " [0.9888],\n", + " [1.0137],\n", + " [1.0066],\n", + " [0.9918],\n", + " [0.9967],\n", + " [1.0045],\n", + " [1.0295],\n", + " [1.0081],\n", + " [1.0296],\n", + " [0.9916],\n", + " [0.9885],\n", + " [1.0191],\n", + " [1.0212],\n", + " [1.0208],\n", + " [1.0129],\n", + " [1.0012],\n", + " [1.0065],\n", + " [1.0064],\n", + " [0.9978],\n", + " [0.9956],\n", + " [1.0156],\n", + " [1.0086],\n", + " [1.0067],\n", + " [1.0005],\n", + " [1.0160],\n", + " [1.0033],\n", + " [1.0070],\n", + " [1.0082],\n", + " [0.9993],\n", + " [1.0124],\n", + " [1.0223],\n", + " [1.0082],\n", + " [1.0116],\n", + " [1.0153],\n", + " [1.0167],\n", + " [1.0009],\n", + " [0.9939],\n", + " [0.9998],\n", + " [1.0113],\n", + " [1.0157],\n", + " [0.9887],\n", + " [1.0076],\n", + " [1.0024],\n", + " [0.9892],\n", + " [1.0115],\n", + " [1.0050],\n", + " [0.9885],\n", + " [0.9976],\n", + " [0.9901],\n", + " [0.9986],\n", + " [0.9953],\n", + " [0.9902],\n", + " [0.9895],\n", + " [1.0190],\n", + " [0.9924],\n", + " [0.9891],\n", + " [1.0197],\n", + " [0.9910],\n", + " [1.0096],\n", + " [1.0097],\n", + " [0.9893],\n", + " [1.0137],\n", + " [1.0060],\n", + " [1.0186],\n", + " [1.0189],\n", + " [1.0228],\n", + " [0.9887],\n", + " [1.0130],\n", + " [1.0138],\n", + " [0.9885],\n", + " [0.9901],\n", + " [0.9935],\n", + " [0.9937],\n", + " [0.9965],\n", + " [0.9899],\n", + " [0.9907],\n", + " [0.9891],\n", + " [1.0151],\n", + " [0.9926],\n", + " [0.9908],\n", + " [1.0120],\n", + " [1.0169],\n", + " [1.0109],\n", + " [1.0011],\n", + " [0.9950],\n", + " [1.0035],\n", + " [1.0086],\n", + " [0.9914],\n", + " [1.0056],\n", + " [1.0168],\n", + " [0.9897],\n", + " [0.9944],\n", + " [1.0207],\n", + " [1.0028],\n", + " [1.0000],\n", + " [1.0156],\n", + " [0.9914],\n", + " [1.0191],\n", + " [0.9995],\n", + " [1.0261],\n", + " [0.9915],\n", + " [1.0086],\n", + " [0.9916],\n", + " [0.9941],\n", + " [0.9896],\n", + " [0.9981],\n", + " [0.9934],\n", + " [1.0134],\n", + " [0.9928],\n", + " [0.9889],\n", + " [1.0098],\n", + " [0.9937],\n", + " [1.0140],\n", + " [1.0163],\n", + " [0.9960],\n", + " [0.9885],\n", + " [1.0008],\n", + " [1.0007],\n", + " [1.0204],\n", + " [0.9950],\n", + " [1.0292],\n", + " [1.0106],\n", + " [0.9886],\n", + " [0.9894],\n", + " [1.0085],\n", + " [0.9983],\n", + " [1.0120],\n", + " [1.0156],\n", + " [1.0171],\n", + " [0.9919],\n", + " [1.0134],\n", + " [1.0044],\n", + " [0.9990],\n", + " [1.0187],\n", + " [1.0171],\n", + " [0.9893],\n", + " [1.0013],\n", + " [0.9967],\n", + " [1.0267],\n", + " [0.9908],\n", + " [1.0025],\n", + " [1.0145],\n", + " [0.9886],\n", + " [1.0157],\n", + " [1.0136],\n", + " [1.0071],\n", + " [0.9936],\n", + " [1.0282],\n", + " [1.0228],\n", + " [0.9944],\n", + " [0.9901],\n", + " [1.0005],\n", + " [0.9962],\n", + " [0.9953],\n", + " [1.0040],\n", + " [0.9914],\n", + " [1.0087],\n", + " [1.0017],\n", + " [1.0055],\n", + " [0.9959],\n", + " [1.0153],\n", + " [0.9897],\n", + " [1.0072],\n", + " [0.9927],\n", + " [0.9885],\n", + " [1.0189],\n", + " [1.0147],\n", + " [1.0007],\n", + " [1.0035],\n", + " [1.0130],\n", + " [0.9953],\n", + " [0.9886],\n", + " [1.0103],\n", + " [1.0103],\n", + " [0.9921],\n", + " [0.9976],\n", + " [0.9971],\n", + " [1.0137],\n", + " [0.9908],\n", + " [1.0159],\n", + " [0.9933],\n", + " [0.9981],\n", + " [0.9887],\n", + " [1.0149],\n", + " [1.0147],\n", + " [1.0048],\n", + " [0.9892],\n", + " [1.0303],\n", + " [1.0129],\n", + " [1.0152],\n", + " [1.0134],\n", + " [1.0163],\n", + " [1.0111],\n", + " [0.9886],\n", + " [1.0111],\n", + " [1.0257],\n", + " [0.9974],\n", + " [1.0171],\n", + " [0.9921],\n", + " [1.0065],\n", + " [1.0025],\n", + " [1.0162],\n", + " [0.9992],\n", + " [0.9919],\n", + " [0.9920],\n", + " [0.9899],\n", + " [1.0135],\n", + " [0.9956],\n", + " [0.9951],\n", + " [1.0102],\n", + " [1.0075],\n", + " [1.0032],\n", + " [0.9915],\n", + " [0.9919],\n", + " [1.0170],\n", + " [0.9886],\n", + " [0.9920],\n", + " [0.9980],\n", + " [0.9950],\n", + " [1.0181],\n", + " [1.0089],\n", + " [1.0128],\n", + " [0.9886],\n", + " [1.0158],\n", + " [1.0190],\n", + " [1.0089],\n", + " [0.9886],\n", + " [0.9902],\n", + " [0.9901],\n", + " [1.0018],\n", + " [0.9928],\n", + " [0.9989],\n", + " [1.0079],\n", + " [1.0158],\n", + " [0.9933],\n", + " [1.0089],\n", + " [1.0011],\n", + " [0.9892],\n", + " [0.9938],\n", + " [1.0001],\n", + " [0.9907],\n", + " [0.9970],\n", + " [1.0192],\n", + " [1.0078],\n", + " [1.0137],\n", + " [1.0233],\n", + " [1.0036],\n", + " [0.9887],\n", + " [1.0061],\n", + " [1.0099],\n", + " [1.0042],\n", + " [0.9983],\n", + " [1.0021],\n", + " [1.0110],\n", + " [0.9984],\n", + " [1.0205],\n", + " [1.0062],\n", + " [1.0078],\n", + " [1.0014],\n", + " [0.9927],\n", + " [0.9895],\n", + " [1.0151],\n", + " [1.0113],\n", + " [0.9987],\n", + " [1.0060],\n", + " [1.0258],\n", + " [1.0050],\n", + " [1.0216],\n", + " [1.0195],\n", + " [1.0120],\n", + " [1.0046],\n", + " [1.0130],\n", + " [1.0131],\n", + " [0.9925],\n", + " [1.0105],\n", + " [0.9971],\n", + " [1.0256],\n", + " [0.9927],\n", + " [1.0154],\n", + " [1.0036],\n", + " [1.0176],\n", + " [0.9885],\n", + " [1.0279],\n", + " [0.9911],\n", + " [0.9931],\n", + " [1.0144],\n", + " [1.0088],\n", + " [0.9966],\n", + " [1.0016],\n", + " [1.0166],\n", + " [1.0174],\n", + " [0.9975],\n", + " [0.9918],\n", + " [0.9915],\n", + " [1.0046],\n", + " [0.9886],\n", + " [1.0147],\n", + " [1.0001],\n", + " [1.0019],\n", + " [0.9955],\n", + " [0.9970],\n", + " [0.9918],\n", + " [0.9957],\n", + " [0.9893],\n", + " [1.0132],\n", + " [1.0154],\n", + " [1.0027],\n", + " [1.0146],\n", + " [1.0203],\n", + " [1.0052],\n", + " [1.0010],\n", + " [1.0160],\n", + " [1.0121],\n", + " [0.9893],\n", + " [0.9886],\n", + " [0.9976],\n", + " [1.0003],\n", + " [0.9894],\n", + " [1.0237],\n", + " [1.0106],\n", + " [0.9987],\n", + " [0.9931],\n", + " [1.0118],\n", + " [0.9886],\n", + " [1.0004],\n", + " [0.9902],\n", + " [1.0198],\n", + " [0.9919],\n", + " [1.0073],\n", + " [1.0060],\n", + " [0.9897],\n", + " [0.9890],\n", + " [1.0018],\n", + " [1.0080],\n", + " [1.0166],\n", + " [0.9938],\n", + " [1.0027],\n", + " [1.0034],\n", + " [0.9914],\n", + " [0.9966],\n", + " [1.0029],\n", + " [1.0284],\n", + " [1.0191],\n", + " [1.0194],\n", + " [1.0211],\n", + " [1.0001],\n", + " [1.0191],\n", + " [1.0120],\n", + " [1.0018],\n", + " [1.0142],\n", + " [0.9898],\n", + " [0.9894],\n", + " [1.0251],\n", + " [0.9962],\n", + " [1.0180],\n", + " [1.0248],\n", + " [0.9895],\n", + " [0.9999],\n", + " [1.0104],\n", + " [0.9920],\n", + " [0.9984],\n", + " [1.0186],\n", + " [0.9909],\n", + " [1.0053],\n", + " [1.0025],\n", + " [1.0018],\n", + " [0.9935],\n", + " [0.9902],\n", + " [1.0165],\n", + " [0.9933],\n", + " [1.0121],\n", + " [0.9886],\n", + " [1.0088],\n", + " [1.0143],\n", + " [1.0159],\n", + " [0.9993],\n", + " [1.0196],\n", + " [0.9976],\n", + " [1.0055],\n", + " [0.9919],\n", + " [1.0172],\n", + " [0.9913],\n", + " [0.9911],\n", + " [0.9998],\n", + " [0.9886],\n", + " [0.9913],\n", + " [0.9995],\n", + " [1.0062],\n", + " [1.0155],\n", + " [0.9988],\n", + " [0.9909],\n", + " [0.9985],\n", + " [0.9887],\n", + " [0.9922],\n", + " [0.9896],\n", + " [1.0023],\n", + " [1.0051],\n", + " [0.9894],\n", + " [0.9916],\n", + " [0.9941],\n", + " [0.9898],\n", + " [1.0025],\n", + " [0.9894],\n", + " [1.0275],\n", + " [0.9901],\n", + " [1.0061],\n", + " [1.0181],\n", + " [1.0052],\n", + " [1.0079],\n", + " [1.0069],\n", + " [1.0138],\n", + " [0.9908],\n", + " [0.9918],\n", + " [0.9900],\n", + " [0.9915],\n", + " [1.0200],\n", + " [0.9967],\n", + " [0.9946],\n", + " [0.9977],\n", + " [1.0127],\n", + " [1.0114],\n", + " [0.9960],\n", + " [1.0024],\n", + " [1.0016],\n", + " [1.0221],\n", + " [1.0089],\n", + " [1.0177],\n", + " [0.9926],\n", + " [1.0214],\n", + " [1.0022],\n", + " [1.0048],\n", + " [0.9941],\n", + " [1.0168],\n", + " [1.0156],\n", + " [1.0033],\n", + " [1.0148],\n", + " [0.9950],\n", + " [0.9909],\n", + " [1.0190],\n", + " [1.0082],\n", + " [1.0075],\n", + " [1.0101],\n", + " [1.0111],\n", + " [1.0195],\n", + " [1.0041],\n", + " [1.0064],\n", + " [0.9930],\n", + " [1.0155],\n", + " [1.0103],\n", + " [1.0200],\n", + " [0.9993],\n", + " [0.9887],\n", + " [1.0086],\n", + " [0.9894],\n", + " [1.0022],\n", + " [0.9952],\n", + " [1.0106],\n", + " [0.9940],\n", + " [0.9894],\n", + " [1.0153],\n", + " [1.0005],\n", + " [1.0132],\n", + " [0.9937],\n", + " [1.0160],\n", + " [1.0131],\n", + " [1.0294],\n", + " [1.0165],\n", + " [1.0155],\n", + " [0.9922],\n", + " [0.9895],\n", + " [1.0063],\n", + " [0.9926],\n", + " [0.9944],\n", + " [1.0081],\n", + " [1.0137],\n", + " [1.0087],\n", + " [0.9918],\n", + " [0.9915],\n", + " [1.0092],\n", + " [1.0287],\n", + " [0.9906],\n", + " [0.9902],\n", + " [0.9896],\n", + " [0.9888],\n", + " [1.0291],\n", + " [1.0004],\n", + " [0.9894],\n", + " [0.9950],\n", + " [0.9961],\n", + " [1.0126],\n", + " [1.0084],\n", + " [1.0117],\n", + " [1.0202],\n", + " [1.0149],\n", + " [1.0147],\n", + " [0.9951],\n", + " [0.9934],\n", + " [1.0161],\n", + " [1.0140],\n", + " [1.0168],\n", + " [1.0188],\n", + " [1.0185],\n", + " [1.0100],\n", + " [1.0082],\n", + " [1.0204],\n", + " [0.9895],\n", + " [1.0106],\n", + " [1.0057],\n", + " [0.9895],\n", + " [0.9982],\n", + " [1.0010],\n", + " [1.0037],\n", + " [1.0080],\n", + " [0.9973],\n", + " [1.0180],\n", + " [1.0108],\n", + " [1.0078],\n", + " [0.9961],\n", + " [1.0000],\n", + " [1.0201],\n", + " [0.9893],\n", + " [0.9945],\n", + " [1.0174],\n", + " [1.0153],\n", + " [0.9899],\n", + " [1.0185],\n", + " [1.0305],\n", + " [0.9910],\n", + " [1.0147],\n", + " [1.0129],\n", + " [1.0104],\n", + " [1.0295],\n", + " [1.0130],\n", + " [0.9952],\n", + " [1.0135],\n", + " [0.9940],\n", + " [0.9901],\n", + " [1.0163],\n", + " [0.9893],\n", + " [1.0063],\n", + " [1.0091],\n", + " [1.0176],\n", + " [1.0226],\n", + " [0.9899],\n", + " [1.0023],\n", + " [1.0080],\n", + " [0.9926],\n", + " [1.0151],\n", + " [1.0185],\n", + " [1.0077],\n", + " [1.0113],\n", + " [0.9898],\n", + " [1.0042],\n", + " [0.9888],\n", + " [0.9945],\n", + " [1.0078],\n", + " [1.0082],\n", + " [0.9895],\n", + " [0.9886],\n", + " [1.0094],\n", + " [0.9896],\n", + " [1.0039],\n", + " [0.9896],\n", + " [0.9993],\n", + " [1.0217],\n", + " [0.9913],\n", + " [0.9887],\n", + " [0.9918],\n", + " [0.9922],\n", + " [0.9895],\n", + " [1.0186],\n", + " [0.9972],\n", + " [1.0097],\n", + " [1.0158],\n", + " [0.9996],\n", + " [0.9895],\n", + " [0.9907],\n", + " [1.0044],\n", + " [1.0167],\n", + " [0.9940],\n", + " [1.0066],\n", + " [1.0065],\n", + " [0.9943],\n", + " [0.9903],\n", + " [1.0094],\n", + " [1.0067],\n", + " [0.9895],\n", + " [1.0130],\n", + " [1.0171],\n", + " [0.9914],\n", + " [0.9992],\n", + " [1.0053],\n", + " [0.9964],\n", + " [1.0111],\n", + " [0.9917],\n", + " [0.9888],\n", + " [0.9960],\n", + " [0.9912],\n", + " [1.0021],\n", + " [0.9896],\n", + " [0.9932],\n", + " [0.9915],\n", + " [0.9922],\n", + " [0.9914],\n", + " [1.0132],\n", + " [0.9902],\n", + " [0.9904],\n", + " [1.0292],\n", + " [1.0082],\n", + " [1.0077],\n", + " [1.0159],\n", + " [1.0118],\n", + " [1.0162],\n", + " [1.0042],\n", + " [1.0155],\n", + " [1.0032],\n", + " [0.9900],\n", + " [1.0169],\n", + " [0.9894],\n", + " [1.0118],\n", + " [0.9975],\n", + " [0.9982],\n", + " [0.9926],\n", + " [1.0108],\n", + " [0.9944]], grad_fn=)\n" + ] + } + ], + "source": [ + "\n", + "# Generate random input\n", + "random_input = torch.rand(1000, input_dim) # You can adjust the batch size\n", + "output = model(random_input)\n", + "\n", + "print(output)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 56, + "id": "6c104591-885b-4a5d-bbf0-114a623e2bfb", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(array([ 0., 0., 0., 0., 0., 1000., 0., 0., 0.,\n", + " 0.]),\n", + " array([-4.9999651e-01, -3.9999652e-01, -2.9999653e-01, -1.9999652e-01,\n", + " -9.9996522e-02, 3.4766272e-06, 1.0000347e-01, 2.0000347e-01,\n", + " 3.0000347e-01, 4.0000346e-01, 5.0000346e-01], dtype=float32),\n", + " )" + ] + }, + "execution_count": 56, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAikAAAGcCAYAAAAcfDBFAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAZ0UlEQVR4nO3dYWzc913H8c/FdV1Vi69qNbVxlsmxIIGWDiZIIKyZbD+AiVYh2gREbHUZQ2kQDKFodIsK64IiUg3EkkUFKUWi0ngwWrHIhbKmEnJJIu5BpiLsamtYVjtL09SiJJzbtXUT3/GgiqmX2HPC1f45eb2kU3X3+9v+5lcnfuv/P99Vms1mMwAAhVm22AMAAFyKSAEAiiRSAIAiiRQAoEgiBQAokkgBAIokUgCAIl232AP8fzQajbz88stZvnx5KpXKYo8DAMxDs9nMa6+9lq6urixbNvv5kiUdKS+//HJWrVq12GMAAFfg5MmT+cAHPjDr+pKOlOXLlyd55w/Z2dm5yNMAAPMxMTGRVatWTf8cn82SjpQLl3g6OztFCgAsMT/qqRqeOAsAFEmkAABFEikAQJFECgBQJJECABRJpAAARRIpAECRRAoAUCSRAgAUSaQAAEW64kh58803U6/XWzkLAMC0y46URqORxx57LGvWrMm///u/Tz8+Ojqabdu2Zf/+/RkYGMiJEyf+32sAwLXrst9g8NVXX01fX19eeuml6ccajUY2bdqUvXv3pr+/P6tXr86WLVtSq9WueA0AuLZVms1m84o+sFLJ0NBQent7881vfjMf//jHMzExkfb29kxNTaWzszPPPvtsXn311StaW7du3Y+cYWJiItVqNfV63bsgA8ASMd+f35d9JuVSarVaenp60t7eniRpa2tLT09PhoaG8vrrr1/R2qUiZXJyMpOTkzP+kADA1aklkTI+Pn5RCVWr1Zw6dSpvvfXWFa1dyu7du7Nz585WjAwsAd1feGqxR7hsYw/fvdgjwFWjJb+C3N7ePn025IJGo5FGo3HFa5eyY8eO1Ov16dvJkydbMT4AUKCWRMqKFSsuuvRSr9ezcuXKK167lI6OjnR2ds64AQBXp5ZESm9vb0ZHR3PhObjnzp3L2NhY+vr6rngNALi2XVGk/PDlmA0bNqSrqyuHDx9Okhw6dCjd3d1Zv379Fa8BANe2y37i7H/913/l0UcfTZL83d/9XW677bb8xE/8RAYHB7Nr166MjIykVqvlwIEDqVQqqVQqV7QGAFzbrvh1UkrgdVLg6ua3e+DqNN+f395gEAAokkgBAIokUgCAIokUAKBIIgUAKJJIAQCKJFIAgCKJFACgSCIFACiSSAEAiiRSAIAiiRQAoEgiBQAokkgBAIokUgCAIokUAKBIIgUAKJJIAQCKJFIAgCKJFACgSCIFACiSSAEAiiRSAIAiiRQAoEgiBQAokkgBAIokUgCAIokUAKBIIgUAKJJIAQCKJFIAgCKJFACgSCIFACiSSAEAiiRSAIAiiRQAoEgiBQAokkgBAIokUgCAIokUAKBIIgUAKJJIAQCKJFIAgCKJFACgSCIFACiSSAEAiiRSAIAiiRQAoEgiBQAokkgBAIokUgCAIokUAKBIIgUAKJJIAQCKJFIAgCKJFACgSCIFACiSSAEAiiRSAIAitTRSDh06lC9+8Yv56le/mnvvvTcvvPBCkmR0dDTbtm3L/v37MzAwkBMnTkx/zFxrAMC167pWfaKpqal8+tOfzrFjx3LdddflX//1X/P7v//7eeaZZ7Jp06bs3bs3/f39Wb16dbZs2ZJarZZGozHrGgBwbWvZmZQzZ87k5ZdfzhtvvJEkqVarOXv2bA4ePJjjx49n48aNSZL+/v4MDw/n6NGjc64BANe2lkXK+9///vzsz/5sPvWpT+V//ud/8tWvfjU7d+5MrVZLT09P2tvbkyRtbW3p6enJ0NDQnGuXMjk5mYmJiRk3AODq1NLnpDzxxBM5fvx4VqxYkV/+5V/OPffck/Hx8XR2ds44rlqt5tSpU3OuXcru3btTrVanb6tWrWrl+ABAQVr2nJQkGR8fz6/8yq/kxIkTue+++3LTTTelvb19+kzJBY1GI41GY861S9mxY0e2b98+fX9iYkKoAMBVqmWR8sYbb+RTn/pUvvWtb+WGG27IH//xH+czn/lMfvd3fzdHjhyZcWy9Xs/KlSszNTU169qldHR0pKOjo1UjAwAFa9nlnueffz7Lly/PDTfckCT50pe+lNdeey133XVXRkdH02w2kyTnzp3L2NhY+vr60tvbO+saAHBta1mk/NiP/VhOnTqVH/zgB0mSt99+O7fddlvuuuuudHV15fDhw0neeS2V7u7urF+/Phs2bJh1DQC4trXscs/NN9+cRx99NJ/97Gdz55135qWXXsrXvva1tLW1ZXBwMLt27crIyEhqtVoOHDiQSqWSSqUy6xoAcG2rNC9ca1mCJiYmUq1WU6/XL/otIWDp6/7CU4s9wmUbe/juxR4Bijffn9/euwcAKJJIAQCKJFIAgCKJFACgSCIFACiSSAEAiiRSAIAiiRQAoEgiBQAokkgBAIokUgCAIokUAKBIIgUAKJJIAQCKJFIAgCKJFACgSCIFACiSSAEAiiRSAIAiiRQAoEgiBQAokkgBAIokUgCAIokUAKBIIgUAKJJIAQCKJFIAgCKJFACgSCIFACiSSAEAiiRSAIAiiRQAoEgiBQAokkgBAIokUgCAIokUAKBIIgUAKJJIAQCKJFIAgCKJFACgSCIFACiSSAEAiiRSAIAiiRQAoEgiBQAokkgBAIokUgCAIokUAKBIIgUAKJJIAQCKJFIAgCKJFACgSCIFACiSSAEAiiRSAIAiiRQAoEgiBQAokkgBAIokUgCAIl33XnzS73znOzlw4EBWrVqVzZs3Z/ny5e/FlwEArmItP5PyyCOPZOvWrbnvvvty7733Zvny5RkdHc22bduyf//+DAwM5MSJE9PHz7UGAFy7Wnom5Rvf+Eb+9E//NM8//3ze//73J0kajUY2bdqUvXv3pr+/P6tXr86WLVtSq9XmXAMArm0tO5Ny/vz5/OEf/mE+97nPTQdKkhw8eDDHjx/Pxo0bkyT9/f0ZHh7O0aNH51wDAK5tLTuTcujQoZw8eTLHjh3L5s2bc+zYsTz00EP59re/nZ6enrS3tydJ2tra0tPTk6Ghobz++uuzrq1bt+6irzE5OZnJycnp+xMTE60aHwAoTMsiZXh4ODfddFO+/OUv5+abb87TTz+dTZs2pa+vL52dnTOOrVarOXXqVN56661Z1y5l9+7d2blzZ6tGBgAK1rLLPW+++WZ+8id/MjfffHOS5GMf+1huvfXWHDlyZPpMyQWNRiONRiPt7e2zrl3Kjh07Uq/Xp28nT55s1fgAQGFaFim33XZbfvCDH8x47AMf+EAeeOCBiy7L1Ov1rFy5MitWrJh17VI6OjrS2dk54wYAXJ1aFikf+chHMjY2lvPnz08/9tZbbyV559eMm81mkuTcuXMZGxtLX19fent7Z10DAK5tLYuUNWvW5Gd+5mfyzDPPJEnOnDmTV199NX/0R3+Urq6uHD58OMk7T7Dt7u7O+vXrs2HDhlnXAIBrW0tfJ+VrX/taPve5z2V4eDijo6N54okncuONN2ZwcDC7du3KyMhIarVaDhw4kEqlkkqlMusaAHBtqzQvXGtZgiYmJlKtVlOv1z0/Ba5C3V94arFHuGxjD9+92CNA8eb789sbDAIARRIpAECRRAoAUCSRAgAUSaQAAEUSKQBAkUQKAFAkkQIAFEmkAABFEikAQJFECgBQJJECABRJpAAARRIpAECRRAoAUCSRAgAUSaQAAEUSKQBAkUQKAFAkkQIAFEmkAABFEikAQJFECgBQJJECABRJpAAARRIpAECRRAoAUCSRAgAUSaQAAEUSKQBAkUQKAFAkkQIAFEmkAABFEikAQJFECgBQJJECABRJpAAARRIpAECRRAoAUCSRAgAUSaQAAEUSKQBAkUQKAFAkkQIAFEmkAABFEikAQJFECgBQJJECABRJpAAARRIpAECRRAoAUCSRAgAUSaQAAEUSKQBAkUQKAFAkkQIAFEmkAABFEikAQJFECgBQpJZHyhtvvJHbb789Y2NjSZLR0dFs27Yt+/fvz8DAQE6cODF97FxrAMC17bpWf8J9+/blO9/5TpKk0Whk06ZN2bt3b/r7+7N69eps2bIltVptzjUAgJaeSRkcHExfX9/0/YMHD+b48ePZuHFjkqS/vz/Dw8M5evTonGsAAC07k/L9738/p0+fzq/+6q9OP1ar1dLT05P29vYkSVtbW3p6ejI0NJTXX3991rV169Zd8mtMTk5mcnJy+v7ExESrxgcACtOSMylTU1N59NFHs3Xr1hmPj4+Pp7Ozc8Zj1Wo1p06dmnNtNrt37061Wp2+rVq1qhXjAwAFakmkPPLII7n//vuzbNnMT9fe3j59puSCRqORRqMx59psduzYkXq9Pn07efJkK8YHAArUkss9+/btywMPPDDjsbVr16bRaOSOO+6Y8Xi9Xs/KlSszNTWVI0eOXHJtNh0dHeno6GjFyABA4VpyJuW73/1u3nrrrelbkhw7dizPPvtsRkdH02w2kyTnzp3L2NhY+vr60tvbO+saAMB7+mJuGzZsSFdXVw4fPpwkOXToULq7u7N+/fo51wAAWv46Ke+2bNmyDA4OZteuXRkZGUmtVsuBAwdSqVRSqVRmXQMAqDQvXG9ZgiYmJlKtVlOv1y/6TSFg6ev+wlOLPcJlG3v47sUeAYo335/f3rsHACiSSAEAiiRSAIAiiRQAoEgiBQAokkgBAIokUgCAIokUAKBIIgUAKJJIAQCKJFIAgCKJFACgSCIFACiSSAEAiiRSAIAiiRQAoEgiBQAokkgBAIokUgCAIokUAKBIIgUAKJJIAQCKJFIAgCKJFACgSCIFACiSSAEAiiRSAIAiiRQAoEgiBQAokkgBAIokUgCAIokUAKBIIgUAKJJIAQCKJFIAgCKJFACgSCIFACiSSAEAiiRSAIAiiRQAoEgiBQAokkgBAIokUgCAIokUAKBIIgUAKJJIAQCKJFIAgCKJFACgSCIFACiSSAEAiiRSAIAiiRQAoEgiBQAokkgBAIokUgCAIokUAKBIIgUAKJJIAQCKJFIAgCK1LFKefPLJrF27Np2dnfnEJz6RM2fOJElGR0ezbdu27N+/PwMDAzlx4sT0x8y1BgBc21oSKS+++GKeeuqpfOMb38hjjz2WZ599Np///OfTaDSyadOm/Pqv/3q2bt2ae++9N1u2bEmSOdcAAK5rxSc5cuRI9u3bl+uvvz533HFHhoeH88QTT+TgwYM5fvx4Nm7cmCTp7+/P5s2bc/To0bz66quzrq1bt64VYwEAS1hLzqQMDAzk+uuvn75/66235oMf/GBqtVp6enrS3t6eJGlra0tPT0+GhobmXAMAeE+eOPvcc8/l/vvvz/j4eDo7O2esVavVnDp1as612UxOTmZiYmLGDQC4OrU8Uk6fPp3z589n8+bNaW9vnz5TckGj0Uij0ZhzbTa7d+9OtVqdvq1atarV4wMAhWhppExNTWXPnj3Zt29fkmTFihUXne2o1+tZuXLlnGuz2bFjR+r1+vTt5MmTrRwfAChISyPlK1/5SrZv3573ve99SZK77roro6OjaTabSZJz585lbGwsfX196e3tnXVtNh0dHens7JxxAwCuTi357Z4k2bNnT9asWZOzZ8/m7NmzefHFF3P+/Pl0dXXl8OHD+ehHP5pDhw6lu7s769evT7PZnHUNAKAlkfL4449n+/bt02dFkuTGG2/MK6+8ksHBwezatSsjIyOp1Wo5cOBAKpVKKpXKrGsAAJXmu8tiiZmYmEi1Wk29XnfpB65C3V94arFHuGxjD9+92CNA8eb789t79wAARRIpAECRRAoAUCSRAgAUSaQAAEUSKQBAkUQKAFAkkQIAFEmkAABFEikAQJFECgBQJJECABRJpAAARRIpAECRRAoAUCSRAgAUSaQAAEUSKQBAkUQKAFAkkQIAFEmkAABFEikAQJFECgBQJJECABRJpAAARRIpAECRRAoAUCSRAgAUSaQAAEUSKQBAkUQKAFAkkQIAFEmkAABFEikAQJFECgBQJJECABRJpAAARRIpAECRRAoAUCSRAgAUSaQAAEUSKQBAkUQKAFAkkQIAFEmkAABFEikAQJFECgBQJJECABRJpAAARRIpAECRRAoAUCSRAgAUSaQAAEUSKQBAkUQKAFAkkQIAFEmkAABFEikAQJFECgBQJJECABRp0SNldHQ027Zty/79+zMwMJATJ04s9kgAQAGuW8wv3mg0smnTpuzduzf9/f1ZvXp1tmzZklqttphjAQAFWNQzKQcPHszx48ezcePGJEl/f3+Gh4dz9OjRxRwLACjAop5JqdVq6enpSXt7e5Kkra0tPT09GRoayrp16y46fnJyMpOTk9P36/V6kmRiYmJhBgYWVGPyjcUe4bL59wh+tAt/T5rN5pzHLWqkjI+Pp7Ozc8Zj1Wo1p06duuTxu3fvzs6dOy96fNWqVe/JfACXq7pnsSeApeO1115LtVqddX1RI6W9vX36LMoFjUYjjUbjksfv2LEj27dvn3HsmTNncsstt6RSqbyns5ZuYmIiq1atysmTJy8KP1rLXi8M+7ww7PPCsM8zNZvNvPbaa+nq6przuEWNlBUrVuTIkSMzHqvX61m5cuUlj+/o6EhHR8eMx2666ab3arwlqbOz01+ABWKvF4Z9Xhj2eWHY5/8z1xmUCxb1ibO9vb0ZHR2dviZ17ty5jI2Npa+vbzHHAgAKsKiRsmHDhnR1deXw4cNJkkOHDqW7uzvr169fzLEAgAIs6uWeZcuWZXBwMLt27crIyEhqtVoOHDhwzT+/5Ep0dHTkoYceuuhyGK1nrxeGfV4Y9nlh2OcrU2n+qN//AQBYBIv+svgAAJciUgCAIokUAKBIIgUAKJJIuQrs2bMnf/Znf5YHH3wwf/mXf/kjj3/jjTdy++23Z2xs7L0f7ioy331+8skns3bt2nR2duYTn/hEzpw5s4BTLi2jo6PZtm1b9u/fn4GBgZw4ceKSx13u9zgXm89eT0xM5JOf/GRuuumm9PT05O///u8XYdKlbb7f0xf87d/+bX7rt35rYYZbiposaY8//nhz48aN0/d/8Rd/sfnNb35zzo95+OGHm0mao6Oj7/F0V4/57vP3vve95tatW5vPP/988x/+4R+aN998c/N3fud3FnLUJWNqaqr5Uz/1U81/+Zd/aTabzeYzzzzT/IVf+IWLjruS73Fmmu9ef/7zn2/+4z/+Y/M//uM/mp/85Ceb7e3tzRdffHGhx12y5rvPFxw7dqz50z/908377rtvgSZcepxJWeL+4i/+Ir/0S780ff+ee+7Jvn37Zj1+cHDQK/pegfnu85EjR7Jv377ccccd+fjHP57Pfvaz+bd/+7eFHHXJOHjwYI4fP56NGzcmSfr7+zM8PJyjR4/OOO5yv8e52Hz2+ty5c7n99ttzzz335EMf+lD+5m/+JsuWLcu3vvWtxRp7yZnv93SSvP322/n617+ezZs3L/CUS4tIWcLefvvtPPfcc1m7du30Y2vWrMnQ0NAlj//+97+f06dPe0Xfy3Q5+zwwMJDrr79++v6tt96aD37wgwsy51JTq9XS09Mz/SajbW1t6enpmbGvl/s9zqXNZ6/b29szMDAwff+GG25ItVr1/XsZ5rPPF+zbty+/93u/t9AjLjkiZQn77//+75w/f37Gm1VVq9W8+eabOXv27Ixjp6am8uijj2br1q0LPeaSdzn7/MOee+653H///e/1iEvS+Pj4RW+0Vq1Wc+rUqen7/5+95//MZ69/2EsvvZSVK1fm53/+59/r8a4a893nZ555Jh/+8Idzyy23LOR4S5JIWcIu1PqF/yZJo9GY8d8LHnnkkdx///1Ztsz/8st1Ofv8bqdPn8758+edzp1Fe3v7jD1N3tnPd+/ple49M81nr3/YX//1X2f//v3v9WhXlfns8/j4eEZGRtLf37/Q4y1Ji/rePczt9OnT+fCHPzzr+m/8xm/k+uuvz8TExPRj9Xo9N9xww0WFvm/fvjzwwAMzHlu7dm3+4A/+IH/+53/e2sGXmFbu8wVTU1PZs2eP507MYcWKFTly5MiMx+r1elauXDl9/5Zbbrnsvedi89nrdxsaGsqHPvSh/NzP/dxCjHfVmM8+Hzx4MA8++GAefPDBJMn58+fTbDbz9a9/PePj46lWqws6c+lESsFWrFiRV155Zc5jvv3tb+d73/ve9P3//M//TG9v70XHffe7351xv1Kp5NixY+nu7m7FqEtaK/f5gq985SvZvn173ve+9yV557kV736uCklvb2++/OUvp9lsplKp5Ny5cxkbG5vxxO5KpZKPfvSjl7X3XGw+e33BCy+8kBdffDGf+cxnkrzzQ7Strc0bv87DfPZ5YGBgxnN/vvSlL2VsbCyPPfbYIkxcPuf+l7hPf/rTeeqpp6bvP/300/nt3/7tJMnIyEh27ty5WKNdVS5nn/fs2ZM1a9bk7NmzeeGFF/LP//zPefrppxd85tJt2LAhXV1dOXz4cJLk0KFD6e7uzvr167Nz586MjIwkmXvvmZ/57vUrr7ySv/qrv8pHPvKRvPDCCxkeHs7u3bsXc/QlZb77zPw5k7LE/eZv/mbGxsbyxS9+MVNTU/nYxz6WX/u1X0uSHDt2LP/0T/+Uhx56aJGnXPrmu8+PP/54tm/fnua73lz8xhtv/JFnaq5Fy5Yty+DgYHbt2pWRkZHUarUcOHAglUolTz75ZO68887ceeedc+498zOfvf7xH//x3H333XnuuedmXKb8kz/5E2dR5mm+39PMX6X57n9NAQAK4XIPAFAkkQIAFEmkAABFEikAQJFECgBQJJECABRJpAAARRIpAECRRAoAUCSRAgAUSaQAAEX6X9Bwo2YT7hXtAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plt.hist(output.detach().cpu().numpy()[:,0])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c6bb688e-0fdf-44e5-9312-fd1521920e5d", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7d8ce915-82b3-40da-8037-995c60ec5653", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "15a26ff5-fac9-4e9e-b122-ae81fff3ae0f", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "40ecf26c-e478-4303-bff7-d83fbffa7bd7", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "fb032c0d-4448-4985-8b96-ddc4dfe96a3e", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from torch import nn\n", + "\n", + "\n", + "class RBF(nn.Module):\n", + "\n", + " def __init__(self, n_kernels=5, mul_factor=2.0, bandwidth=None):\n", + " super().__init__()\n", + " self.bandwidth_multipliers = mul_factor ** (torch.arange(n_kernels) - n_kernels // 2)\n", + " self.bandwidth = bandwidth\n", + "\n", + " def get_bandwidth(self, L2_distances):\n", + " if self.bandwidth is None:\n", + " n_samples = L2_distances.shape[0]\n", + " return L2_distances.data.sum() / (n_samples ** 2 - n_samples)\n", + "\n", + " return self.bandwidth\n", + "\n", + " def forward(self, X):\n", + " L2_distances = torch.cdist(X, X) ** 2\n", + " return torch.exp(-L2_distances[None, ...] / (self.get_bandwidth(L2_distances) * self.bandwidth_multipliers)[:, None, None]).sum(dim=0)\n", + "\n", + "\n", + "class MMDLoss(nn.Module):\n", + "\n", + " def __init__(self, kernel=RBF()):\n", + " super().__init__()\n", + " self.kernel = kernel\n", + "\n", + " def forward(self, X, Y):\n", + " K = self.kernel(torch.vstack([X, Y]))\n", + "\n", + " X_size = X.shape[0]\n", + " XX = K[:X_size, :X_size].mean()\n", + " XY = K[:X_size, X_size:].mean()\n", + " YY = K[X_size:, X_size:].mean()\n", + " return XX - 2 * XY + YY" + ] + }, + { + "cell_type": "code", + "execution_count": 119, + "id": "b5fa6136-7df6-4cf0-bc58-1ffc2bd5b3ec", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "input_dim = 1\n", + "hidden_dim = 128\n", + "output_dim = 1\n", + "num_hidden_layers = 3\n", + "\n", + "model = RandomToNormalNN(input_dim, hidden_dim, output_dim, num_hidden_layers)\n", + "\n", + "nepochs=1000\n", + "bs = 1000\n", + "z_dim=1\n", + "epsilon = 0.5\n", + "\n", + "# Create an instance of the neural network\n", + "base_distribution = D.Normal(torch.zeros(z_dim), torch.ones(z_dim))\n", + "mmd = MMDLoss()\n", + "\n", + "optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 120, + "id": "752b4aac-3cd0-4d43-8abf-50a4a45d59a6", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[120], line 10\u001b[0m\n\u001b[1;32m 6\u001b[0m randomN_input \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mrandn(bs, input_dim)\n\u001b[1;32m 8\u001b[0m output \u001b[38;5;241m=\u001b[39m model(random_input)\n\u001b[0;32m---> 10\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1000\u001b[39m\u001b[38;5;241m*\u001b[39m\u001b[43mmmd\u001b[49m\u001b[43m(\u001b[49m\u001b[43moutput\u001b[49m\u001b[43m,\u001b[49m\u001b[43mrandomN_input\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;241m-\u001b[39mbase_distribution\u001b[38;5;241m.\u001b[39mlog_prob(output)\u001b[38;5;241m.\u001b[39mmean()\n\u001b[1;32m 11\u001b[0m \u001b[38;5;66;03m#loss = loss.mean()\u001b[39;00m\n\u001b[1;32m 12\u001b[0m loss\u001b[38;5;241m.\u001b[39mbackward()\n", + "File \u001b[0;32m/data/astro/scratch/lcabayol/anaconda3/envs/DLenv2/lib/python3.9/site-packages/torch/nn/modules/module.py:1110\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1106\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1107\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1108\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1109\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1110\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1111\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1112\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", + "Cell \u001b[0;32mIn[81], line 33\u001b[0m, in \u001b[0;36mMMDLoss.forward\u001b[0;34m(self, X, Y)\u001b[0m\n\u001b[1;32m 30\u001b[0m K \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mkernel(torch\u001b[38;5;241m.\u001b[39mvstack([X, Y]))\n\u001b[1;32m 32\u001b[0m X_size \u001b[38;5;241m=\u001b[39m X\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m0\u001b[39m]\n\u001b[0;32m---> 33\u001b[0m XX \u001b[38;5;241m=\u001b[39m \u001b[43mK\u001b[49m\u001b[43m[\u001b[49m\u001b[43m:\u001b[49m\u001b[43mX_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m:\u001b[49m\u001b[43mX_size\u001b[49m\u001b[43m]\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmean\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 34\u001b[0m XY \u001b[38;5;241m=\u001b[39m K[:X_size, X_size:]\u001b[38;5;241m.\u001b[39mmean()\n\u001b[1;32m 35\u001b[0m YY \u001b[38;5;241m=\u001b[39m K[X_size:, X_size:]\u001b[38;5;241m.\u001b[39mmean()\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "for e in range(nepochs):\n", + " \n", + " optimizer.zero_grad()\n", + " \n", + " random_input = torch.rand(bs, input_dim)\n", + " randomN_input = torch.randn(bs, input_dim)\n", + "\n", + " output = model(random_input)\n", + " \n", + " loss = mmd(output,randomN_input) #-base_distribution.log_prob(output).mean()\n", + " #loss = loss.mean()\n", + " loss.backward()\n", + " optimizer.step() \n", + " \n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 116, + "id": "91e6f895-a161-46d1-80ec-c93ceaba765f", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(array([ 99., 102., 110., 89., 105., 100., 100., 101., 105., 89.]),\n", + " array([4.2331219e-04, 1.0035120e-01, 2.0027909e-01, 3.0020696e-01,\n", + " 4.0013486e-01, 5.0006270e-01, 5.9999061e-01, 6.9991851e-01,\n", + " 7.9984641e-01, 8.9977425e-01, 9.9970216e-01], dtype=float32),\n", + " )" + ] + }, + "execution_count": 116, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiIAAAGcCAYAAADknMuyAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAaU0lEQVR4nO3df2xdd33/8ZeTuo7Q4lu1AxIHU8caaYEVOkazdWlQbGkD0SrKNMEyfqRjaBDUTUMZo81atgYyUtjUNYtgUjqNioLU0UEUpgpSgbwlke4froqwK0GHqZ2lLliUdNddmzhO7v3+wTcWpklw4dofX/fxkK7Kvff4nDcfWb3PnnN9b1uj0WgEAKCAZaUHAABeuoQIAFCMEAEAihEiAEAxQgQAKEaIAADFCBEAoJhLSg/w89Tr9Tz11FNZuXJl2traSo8DAMxBo9HIs88+m66urixbduHzHos+RJ566ql0d3eXHgMA+AUcP348r3rVqy74/KIPkZUrVyb5yf+Rzs7OwtMAAHMxOTmZ7u7umdfxC1n0IXLuckxnZ6cQAYAW8/PeVuHNqgBAMUIEAChGiAAAxQgRAKAYIQIAFCNEAIBihAgAUIwQAQCKESIAQDFCBAAoRogAAMUIEQCgGCECABQjRACAYoQIAFDMJaUHYOnrue2h0iP8QsbuurH0CABLnjMiAEAxQgQAKEaIAADFCBEAoBghAgAUI0QAgGKECABQjBABAIoRIgBAMUIEAChGiAAAxQgRAKAYIQIAFCNEAIBihAgAUMwlpQcAmqfntodKj/Cijd11Y+kRWKT8Pr80OCMCABQjRACAYoQIAFCMEAEAihEiAEAxQgQAKEaIAADFCBEAoBghAgAU45NVW0wrftIgAFzIL3xG5OTJk6nVas2cBQB4iXnRIVKv13Pfffdl3bp1+da3vjXz+OjoaLZv3579+/dn27ZtOXbs2JyeAwBeul70pZmnn346fX19efLJJ2ceq9fr2bx5c/bu3Zv+/v6sXbs2W7duTbVavehzACwMl3VZrF70GZFXvOIVufLKK2c9dujQoYyMjGTjxo1Jkv7+/gwNDWVwcPCizwEAL21NebNqtVpNb29v2tvbkyTLly9Pb29vBgYG8n//938XfO666657wb6mpqYyNTU1c39ycrIZIwIAi1BTQmRiYiKdnZ2zHqtUKhkfH8+pU6cu+Nz57NmzJ7t27WrGWEALcMkAXtqa8jki7e3tM2c8zqnX66nX6xd97nx27tyZWq02czt+/HgzRgQAFqGmhMjq1atfcAmlVqtlzZo1F33ufDo6OtLZ2TnrBgAsTU25NLNp06Z8+tOfTqPRSFtbW6anpzM2Npa+vr6cOXPmgs+V5pQwAJT1C50R+dnLKtdff326urpy5MiRJMnhw4fT09OT9evXX/Q5AOCl7UWfEfnRj36Ue++9N0nyhS98IatWrcrVV1+dgwcPZvfu3RkeHk61Ws2BAwfS1taWtra2Cz4HALy0tTUajUbpIS5mcnIylUoltVqt6e8XcWmGixm768bSI7xofqehrFb898Z8mevrt2/fBQCKESIAQDFCBAAoRogAAMUIEQCgGCECABQjRACAYoQIAFCMEAEAihEiAEAxQgQAKEaIAADFCBEAoBghAgAUI0QAgGKECABQjBABAIoRIgBAMUIEAChGiAAAxQgRAKAYIQIAFCNEAIBihAgAUIwQAQCKESIAQDFCBAAoRogAAMVcUnoAWKx6bnuo9AhAi2nFf2+M3XVj0eM7IwIAFCNEAIBihAgAUIwQAQCKESIAQDFCBAAoRogAAMUIEQCgGCECABQjRACAYoQIAFCMEAEAihEiAEAxQgQAKEaIAADFCBEAoBghAgAUI0QAgGKECABQjBABAIoRIgBAMUIEAChGiAAAxQgRAKCYS5q5s8OHD+cb3/hGfvVXfzWDg4O5/fbbc/XVV2d0dDSf+tSn8qY3vSlHjx7NJz7xiVx55ZXNPDQA0IKaFiJnz57N+973vjz++OO55JJL8l//9V/5sz/7szz88MPZvHlz9u7dm/7+/qxduzZbt25NtVpt1qEBgBbVtEszJ06cyFNPPZXnn38+SVKpVPLMM8/k0KFDGRkZycaNG5Mk/f39GRoayuDgYLMODQC0qKaFyMtf/vL85m/+Zt7znvfkf//3f/NP//RP2bVrV6rVanp7e9Pe3p4kWb58eXp7ezMwMHDe/UxNTWVycnLWDQBYmpr6ZtUHH3wwIyMjWb16dd761rfmpptuysTERDo7O2dtV6lUMj4+ft597NmzJ5VKZebW3d3dzBEBgEWkqW9WnZiYyNvf/vYcO3YsN998cy677LK0t7fPnA05p16vp16vn3cfO3fuzI4dO2buT05OihEAWKKaFiLPP/983vOe9+SRRx7JihUrcscdd+T9739/PvShD+Xo0aOztq3ValmzZs1599PR0ZGOjo5mjQUALGJNuzTz2GOPZeXKlVmxYkWS5M4778yzzz6bG264IaOjo2k0GkmS6enpjI2Npa+vr1mHBgBaVNNC5Nd+7dcyPj6e5557Lkly+vTprFq1KjfccEO6urpy5MiRJD/5rJGenp6sX7++WYcGAFpU0y7NXH755bn33nvz53/+57nmmmvy5JNP5v7778/y5ctz8ODB7N69O8PDw6lWqzlw4EDa2tqadWgAoEW1Nc5dM1mkJicnU6lUUqvVXvDXN7+sntseaur+AKDVjN1147zsd66v375rBgAoRogAAMUIEQCgGCECABQjRACAYoQIAFCMEAEAihEiAEAxQgQAKEaIAADFCBEAoBghAgAUI0QAgGKECABQjBABAIoRIgBAMUIEAChGiAAAxQgRAKAYIQIAFCNEAIBihAgAUIwQAQCKESIAQDFCBAAoRogAAMUIEQCgGCECABQjRACAYoQIAFCMEAEAihEiAEAxQgQAKEaIAADFCBEAoBghAgAUI0QAgGKECABQjBABAIoRIgBAMUIEAChGiAAAxQgRAKAYIQIAFCNEAIBihAgAUIwQAQCKESIAQDFCBAAoRogAAMUIEQCgGCECABRzyXzs9Dvf+U4OHDiQ7u7ubNmyJStXrpyPwwAALa7pIfKZz3wmDzzwQB544IGsWbMmSTI6OppPfepTedOb3pSjR4/mE5/4RK688spmHxoAaDFNDZGvfOUr+fjHP57HHnssL3/5y5Mk9Xo9mzdvzt69e9Pf35+1a9dm69atqVarzTw0ANCCmvYekTNnzuTDH/5wPvKRj8xESJIcOnQoIyMj2bhxY5Kkv78/Q0NDGRwcbNahAYAW1bQQOXz4cI4fP57HH388W7ZsyWtf+9o88MADqVar6e3tTXt7e5Jk+fLl6e3tzcDAwHn3MzU1lcnJyVk3AGBpatqlmaGhoVx22WX59Kc/ncsvvzxf//rXs3nz5vT19aWzs3PWtpVKJePj4+fdz549e7Jr165mjQUALGJNOyNy8uTJvPa1r83ll1+eJHnb296WV77ylTl69OjM2ZBz6vV66vX6efezc+fO1Gq1mdvx48ebNSIAsMg0LURWrVqV5557btZjr3rVq/LRj370BZdXarXazF/U/KyOjo50dnbOugEAS1PTQmTDhg0ZGxvLmTNnZh47depUkp/8+W6j0UiSTE9PZ2xsLH19fc06NADQopoWIuvWrcu1116bhx9+OEly4sSJPP300/mrv/qrdHV15ciRI0l+8qbWnp6erF+/vlmHBgBaVFM/R+T+++/PRz7ykQwNDWV0dDQPPvhgXvayl+XgwYPZvXt3hoeHU61Wc+DAgbS1tTXz0ABAC2prnLtmskhNTk6mUqmkVqs1/f0iPbc91NT9AUCrGbvrxnnZ71xfv33pHQBQjBABAIoRIgBAMUIEAChGiAAAxQgRAKAYIQIAFCNEAIBihAgAUIwQAQCKESIAQDFCBAAoRogAAMUIEQCgGCECABQjRACAYoQIAFCMEAEAihEiAEAxQgQAKEaIAADFCBEAoBghAgAUI0QAgGKECABQjBABAIoRIgBAMUIEAChGiAAAxQgRAKAYIQIAFCNEAIBihAgAUIwQAQCKESIAQDFCBAAoRogAAMUIEQCgGCECABQjRACAYoQIAFCMEAEAihEiAEAxQgQAKEaIAADFCBEAoBghAgAUI0QAgGKECABQjBABAIoRIgBAMUIEAChGiAAAxTQ9RJ5//vm87nWvy9jYWJJkdHQ027dvz/79+7Nt27YcO3as2YcEAFrUJc3e4b59+/Kd73wnSVKv17N58+bs3bs3/f39Wbt2bbZu3ZpqtdrswwIALaipZ0QOHjyYvr6+mfuHDh3KyMhINm7cmCTp7+/P0NBQBgcHm3lYAKBFNS1E/ud//ic/+MEPsn79+pnHqtVqent7097eniRZvnx5ent7MzAwcMH9TE1NZXJyctYNAFiamhIiZ8+ezb333psPfOADsx6fmJhIZ2fnrMcqlUrGx8cvuK89e/akUqnM3Lq7u5sxIgCwCDUlRD7zmc/kgx/8YJYtm7279vb2mbMh59Tr9dTr9Qvua+fOnanVajO348ePN2NEAGARasqbVfft25ePfvSjsx676qqrUq/X8/rXv37W47VaLWvWrLngvjo6OtLR0dGMsQCARa4pIfK9731v1v22trY8/vjjGR8fz9vf/vY0Go20tbVleno6Y2Njs97QCgC8dM3rB5pdf/316erqypEjR5Ikhw8fTk9Pz6w3tAIAL11N/xyRn7Zs2bIcPHgwu3fvzvDwcKrVag4cOJC2trb5PCwA0CLmJUQajcbM/163bl0+//nPJ0luueWW+TgcANCifNcMAFCMEAEAihEiAEAxQgQAKEaIAADFCBEAoBghAgAUI0QAgGKECABQjBABAIoRIgBAMUIEAChGiAAAxQgRAKAYIQIAFCNEAIBihAgAUIwQAQCKESIAQDFCBAAoRogAAMUIEQCgGCECABQjRACAYoQIAFCMEAEAihEiAEAxQgQAKEaIAADFCBEAoBghAgAUI0QAgGKECABQjBABAIoRIgBAMUIEAChGiAAAxQgRAKAYIQIAFCNEAIBihAgAUIwQAQCKESIAQDFCBAAoRogAAMUIEQCgGCECABQjRACAYoQIAFCMEAEAihEiAEAxQgQAKEaIAADFNC1EvvrVr+aqq65KZ2dn/uAP/iAnTpxIkoyOjmb79u3Zv39/tm3blmPHjjXrkABAi2tKiDzxxBN56KGH8pWvfCX33Xdf/vM//zO33npr6vV6Nm/enHe+8535wAc+kPe+973ZunVrMw4JACwBlzRjJ0ePHs2+ffty6aWX5vWvf32Ghoby4IMP5tChQxkZGcnGjRuTJP39/dmyZUsGBwdz3XXXNePQAEALa8oZkW3btuXSSy+duf/KV74yr371q1OtVtPb25v29vYkyfLly9Pb25uBgYEL7mtqaiqTk5OzbgDA0jQvb1Z99NFH88EPfjATExPp7Oyc9VylUsn4+PgFf3bPnj2pVCozt+7u7vkYEQBYBJoeIj/4wQ9y5syZbNmyJe3t7TNnQ86p1+up1+sX/PmdO3emVqvN3I4fP97sEQGARaIp7xE55+zZs7nnnnuyb9++JMnq1atz9OjRWdvUarWsWbPmgvvo6OhIR0dHM8cCABappp4R+cd//Mfs2LEjv/Irv5IkueGGGzI6OppGo5EkmZ6eztjYWPr6+pp5WACgRTXtjMg999yTdevW5ZlnnskzzzyTJ554ImfOnElXV1eOHDmSt7zlLTl8+HB6enqyfv36Zh0WAGhhTQmRL33pS9mxY8fMmY8kednLXpYf/vCHOXjwYHbv3p3h4eFUq9UcOHAgbW1tzTgsANDi2ho/XQ+L0OTkZCqVSmq12gv+AueX1XPbQ03dHwC0mrG7bpyX/c719dt3zQAAxQgRAKAYIQIAFCNEAIBihAgAUIwQAQCKESIAQDFCBAAoRogAAMUIEQCgGCECABQjRACAYoQIAFCMEAEAihEiAEAxQgQAKEaIAADFCBEAoBghAgAUI0QAgGKECABQjBABAIoRIgBAMUIEAChGiAAAxQgRAKAYIQIAFCNEAIBihAgAUIwQAQCKESIAQDFCBAAoRogAAMUIEQCgGCECABQjRACAYoQIAFCMEAEAihEiAEAxQgQAKEaIAADFCBEAoBghAgAUI0QAgGKECABQjBABAIoRIgBAMUIEAChGiAAAxQgRAKAYIQIAFCNEAIBihAgAUMyChMjo6Gi2b9+e/fv3Z9u2bTl27NhCHBYAWOQume8D1Ov1bN68OXv37k1/f3/Wrl2brVu3plqtzvehAYBFbt7PiBw6dCgjIyPZuHFjkqS/vz9DQ0MZHByc70MDAIvcvJ8RqVar6e3tTXt7e5Jk+fLl6e3tzcDAQK677roXbD81NZWpqamZ+7VaLUkyOTnZ9NnqU883fZ8A0Erm4/X1p/fbaDQuut28h8jExEQ6OztnPVapVDI+Pn7e7ffs2ZNdu3a94PHu7u55mQ8AXsoq98zv/p999tlUKpULPj/vIdLe3j5zNuScer2eer1+3u137tyZHTt2zNr2xIkTueKKK9LW1ta0uSYnJ9Pd3Z3jx4+/IJRoHuu8cKz1wrDOC8M6L4z5XOdGo5Fnn302XV1dF91u3kNk9erVOXr06KzHarVa1qxZc97tOzo60tHRMeuxyy67bL7GS2dnp1/yBWCdF461XhjWeWFY54UxX+t8sTMh58z7m1U3bdqU0dHRmWtE09PTGRsbS19f33wfGgBY5OY9RK6//vp0dXXlyJEjSZLDhw+np6cn69evn+9DAwCL3Lxfmlm2bFkOHjyY3bt3Z3h4ONVqNQcOHGjq+z1+ER0dHfnbv/3bF1wGorms88Kx1gvDOi8M67wwFsM6tzV+3t/VAADME981AwAUI0QAgGKECABQzLy/WRVgKZuens4Xv/jFnDhxIhs2bMhv/dZvlR4JLujkyZM5ffr0nD7fY6Es+TMio6Oj2b59e/bv359t27bl2LFj593unnvuySc/+cncfvvtufvuuxd4ytY3l3WenJzMu9/97lx22WXp7e3Nv/3bvxWYtLXN9ff5nM997nP54z/+44UZbgmZ6zpPTExkw4YNOXXqVHbs2CFCfgFzWevp6enccccd+exnP5udO3fm4x//eIFJW1u9Xs99992XdevW5Vvf+tYFtyvyWthYws6ePdv49V//9cY3v/nNRqPRaDz88MON3/7t337Bdl/60pcaGzdunLn/O7/zO42vfe1rCzZnq5vrOt96662N//iP/2h8+9vfbrz73e9utLe3N5544omFHrdlzXWdz3n88ccbb3zjGxs333zzAk24NMx1nU+fPt1485vf3LjzzjsXesQlY65rfffddzf+/u//fub+pk2bGkeOHFmwOZeCiYmJxtjYWCNJY2Bg4LzblHotXNJnRA4dOpSRkZFs3LgxSdLf35+hoaEMDg7O2u4f/uEf8nu/93sz92+66abs27dvQWdtZXNZ5+np6bzuda/LTTfdlDe84Q35l3/5lyxbtiyPPPJIqbFbzlx/n5Pk9OnTeeCBB7Jly5YFnrL1zXWd//Vf/zXf+9738td//dclxlwS5rrWIyMjOXHixMz9SqWSZ555ZkFnbXWveMUrcuWVV150m1KvhUs6RKrVanp7e2e+dG/58uXp7e3NwMDAzDanT5/Oo48+mquuumrmsXXr1s3ahoubyzq3t7dn27ZtM/dXrFiRSqWSV7/61Qs+b6uayzqfs2/fvtxyyy0LPeKSMNd1/uIXv5jVq1dnx44dufbaa/P7v//7s14s+fnmutZbtmzJ3r1787WvfS2PPPJIzpw5k7e+9a0lRl6ySr4WLukQmZiYeMGX+FQqlYyPj8/c//GPf5wzZ87M2q5SqeTkyZOKe47mss4/68knn8yaNWtcU38R5rrODz/8cH7jN34jV1xxxUKOt2TMdZ2Hhobyjne8I/v27cvg4GCefvrp3HbbbQs5asub61r/7u/+bv7u7/4uN910U2655ZZ8+ctfzqWXXrqQoy55JV8Ll3SItLe3z5T2OfV6PfV6fdY2P/3Pc9v89D+5uLms88/653/+5+zfv3++R1tS5rLOExMTGR4eTn9//0KPt2TM9ff55MmTueGGG2Z+5uabb85DDz20YHMuBXNd60ajkR//+Mf55Cc/me9///vZvHlzpqenF3LUJa/ka+GSDpHVq1dncnJy1mO1Wi1r1qyZuX/FFVfk0ksvnbVdrVbLihUr/BflHM1lnX/awMBA3vCGN+TNb37zQoy3ZMxlnQ8dOpTbb789K1asyIoVK7J79+7cf//9WbFiRWq12kKP3JLm+vu8atWqPPfcczP3u7u7nUV9kea61nfffXcqlUpuvfXWPPLII3nsscfyuc99biFHXfJKvhYu6RDZtGlTRkdH0/j/X6czPT2dsbGx9PX1zWzT1taWt7zlLfn+978/89h///d/Z9OmTQs9bsuayzqf893vfjdPPPFE/vAP/zBJcubMmZmf4+Lmss7btm3LqVOnZm533HFH3vve9+bUqVOL6nMDFrO5/j5v2LAhIyMjM/dPnTqVnp6ehRy15c11rb/5zW/mmmuuSZL09PTkL/7iL/Ltb397weddykq+Fi7pELn++uvT1dWVI0eOJEkOHz6cnp6erF+/Prt27crw8HCS5H3ve9+sU6pf//rX8yd/8idFZm5Fc13nH/7wh/nsZz+bDRs25Lvf/W6GhoayZ8+ekqO3lLmuM7+cua7zn/7pn+bf//3fZ37u6NGjef/7319k5lY117V+4xvfmEcffXTm506ePJlrr722xMgt7XyXWBbDa+GS/mTVZcuW5eDBg9m9e3eGh4dTrVZz4MCBtLW15atf/WquueaaXHPNNXnXu96VsbGx/M3f/E3Onj2bt73tbXnHO95RevyWMZd1fs1rXpMbb7wxjz766Kw/B/vYxz6Wtra2gtO3jrn+PvPLmes69/X15Y/+6I/yoQ99KN3d3anX6/nwhz9cevyWMte1/tjHPpa//Mu/zJ133pmVK1fm7Nmzou9F+tGPfpR77703SfKFL3whq1atytVXX70oXgvbGs6LAwCFLOlLMwDA4iZEAIBihAgAUIwQAQCKESIAQDFCBAAoRogAAMUIEQCgGCECABQjRACAYoQIAFDM/wM7pWpKJrfwNwAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plt.hist(random_input.detach().cpu().numpy()[:,0])" + ] + }, + { + "cell_type": "code", + "execution_count": 121, + "id": "a9647567-57f9-4bde-8db7-2565939da253", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(array([20., 17., 13., 8., 10., 12., 9., 11., 15., 10., 15., 15., 18.,\n", + " 16., 24., 35., 30., 29., 32., 25., 36., 27., 38., 27., 32., 32.,\n", + " 24., 31., 33., 28., 28., 26., 27., 27., 24., 27., 17., 18., 10.,\n", + " 9., 12., 7., 11., 8., 15., 5., 14., 18., 13., 12.]),\n", + " array([-1.945116 , -1.8679345 , -1.7907529 , -1.7135713 , -1.6363897 ,\n", + " -1.5592082 , -1.4820266 , -1.404845 , -1.3276634 , -1.2504818 ,\n", + " -1.1733003 , -1.0961187 , -1.0189371 , -0.9417555 , -0.8645739 ,\n", + " -0.7873923 , -0.71021074, -0.63302916, -0.5558476 , -0.47866598,\n", + " -0.4014844 , -0.32430282, -0.24712123, -0.16993965, -0.09275807,\n", + " -0.01557648, 0.0616051 , 0.13878669, 0.21596827, 0.29314986,\n", + " 0.37033144, 0.447513 , 0.5246946 , 0.6018762 , 0.6790578 ,\n", + " 0.75623935, 0.83342093, 0.9106025 , 0.9877841 , 1.0649657 ,\n", + " 1.1421473 , 1.2193289 , 1.2965105 , 1.373692 , 1.4508736 ,\n", + " 1.5280552 , 1.6052368 , 1.6824183 , 1.7595999 , 1.8367815 ,\n", + " 1.9139631 ], dtype=float32),\n", + " )" + ] + }, + "execution_count": 121, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAhsAAAGcCAYAAABwemJAAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAeXUlEQVR4nO3df2xV9f3H8del3l5TaeusFigD247cmrmYzFRWXOsoMTHTDUg0hBTosumUhZRfCZGyqukEa/xD6yBKamLQdBBBJGAAISFVStZkTWDSOXVBWlIpVCrktovl0vae7x8L9+uVFu69nHfbc/t8JOePe+7nfu77w6fQF597z+f4HMdxBAAAYGTSWBcAAABSG2EDAACYImwAAABThA0AAGCKsAEAAEwRNgAAgCnCBgAAMHXLWBcgSZFIRF1dXcrMzJTP5xvrcgAAQBwcx1FfX5/y8vI0adLI6xfjImx0dXVpxowZY10GAABIQmdnp3784x+P+Py4CBuZmZmS/ldsVlbWGFcDAADi0dvbqxkzZkR/j49kXISNqx+dZGVlETYAAPCYG30Fgi+IAgAAU4QNAABgirABAABMETYAAIApwgYAADBF2AAAAKYIGwAAwBRhAwAAmCJsAAAAU4QNAABgirABAABMETYAAIApwgYAADBF2AAAAKYIGwAAwNQtY10AgLGRv37/Ddt0vPzYKFQCINWxsgEAAEwRNgAAgCnCBgAAMEXYAAAApggbAADAFGEDAACYImwAAABThA0AAGCKsAEAAEwRNgAAgCnCBgAAMEXYAAAApggbAADAFGEDAACYImwAAABThA0AAGCKsAEAAEzdMtYFAEhM/vr9N2zT8fJjo1AJAMSHlQ0AAGCKsAEAAEwRNgAAgCnCBgAAMGUWNrq6uqy6BgAAHpJw2Dhx4oRKS0t1xx136OGHH1ZPT48kyXEcBYNB+Xw++Xw+LV261PViAQCA9yQUNi5fvqzdu3fr8OHD6uzs1HfffadXX31VknTw4EGtXLlSra2tam1t1a5du0wKBgAA3pJQ2AiFQnr++eeVkZGh2267TWVlZZo06X9dbNmyRWlpacrNzVVxcbFycnJMCgYAAN6SUNiYMmWK0tPTJUlXrlzR+fPntWbNGvX19SkcDqumpkYFBQWqqqqS4zgj9hMOh9Xb2xtzAACA1JTUDqIHDhxQTU2Nenp69Nlnn+mhhx7SkSNHNDAwoK1bt2rNmjWaNWuWVq1aNezr6+rqVFtbe1OFA25gN04AsJfU1SiPPPKI3n//fT344IMxXwT1+/2qqqrS+vXrtX379hFfX11drVAoFD06OzuTKQMAAHhAUmEjLS1NhYWFevvtt3XhwgVduHAh5vkFCxYoFAqN+PpAIKCsrKyYAwAApKab2mcjIyNDd955p370ox/FnB8aGlJRUdFNFQYAAFJDQmHj22+/1Ycffhj98ucnn3yiZcuWqaWlRY2NjdHzDQ0NWrdunfvVAgAAz0noC6Lt7e166qmnVFRUpCeeeEKTJ0/Wpk2btGPHDq1evVo7duxQSUmJKisrVVpaalUzAADwkITCRnFxsbq7u685X1FRoYqKCteKAgAAqYMbsQEAAFOEDQAAYIqwAQAATBE2AACAKcIGAAAwRdgAAACmCBsAAMAUYQMAAJgibAAAAFOEDQAAYIqwAQAATCV0bxQAGO/y1++/YZuOlx8bhUoAXMXKBgAAMEXYAAAApggbAADAFGEDAACYImwAAABThA0AAGCKsAEAAEwRNgAAgCk29QJGSapuNhXPuOLhxbEDiA8rGwAAwBRhAwAAmCJsAAAAU4QNAABgirABAABMETYAAIApwgYAADBF2AAAAKYIGwAAwBQ7iALwDLd2KwUwuljZAAAApggbAADAFGEDAACYImwAAABTCYeNEydOqLS0VHfccYcefvhh9fT0SJLa29u1fPlyNTQ0qLKyUmfOnHG9WAAA4D0JhY3Lly9r9+7dOnz4sDo7O/Xdd9/p1VdfVSQS0fz587Vo0SI9/fTTWrZsmRYvXmxVMwAA8JCEwkYoFNLzzz+vjIwM3XbbbSorK9OkSZN06NAhnTp1SmVlZZKkefPm6eTJk2ptbTUpGgAAeEdCYWPKlClKT0+XJF25ckXnz5/XmjVr1NLSosLCQvn9fklSWlqaCgsL1dTUNGw/4XBYvb29MQcAAEhNSW3qdeDAAdXU1Kinp0efffaZuru7lZWVFdMmOztbZ8+eHfb1dXV1qq2tTeatgVEXz0ZSHS8/NgqVAIA3JXU1yiOPPKL3339fDz74oJYuXSq/3x9d1bgqEokoEokM+/rq6mqFQqHo0dnZmUwZAADAA5Ja2bj6Mcnbb7+tnJwc3XXXXdd8FBIKhTR9+vRhXx8IBBQIBJJ5awAA4DE3tc9GRkaG7rzzTs2dO1ft7e1yHEeSNDAwoI6ODpWXl7tSJAAA8K6Ewsa3336rDz/8MBoqPvnkEy1btkxlZWXKy8tTc3OzJOno0aPKz8/X7Nmz3a8YAAB4SkIfo7S3t+upp55SUVGRnnjiCU2ePFmbNm2Sz+fT3r17tXHjRrW1tamlpUV79uyRz+ezqhsAAHhEQmGjuLhY3d3dwz4XDAb17rvvSpJWrFhx85UBAICUwL1RAACAKcIGAAAwldSlrwBixbPxFwBMVKxsAAAAU4QNAABgirABAABMETYAAIApwgYAADBF2AAAAKYIGwAAwBRhAwAAmCJsAAAAU+wgCmDCcWvH146XH3OlHyDVsbIBAABMETYAAIApwgYAADBF2AAAAKYIGwAAwBRhAwAAmCJsAAAAU4QNAABgik29gBTk1qZVbvUzkY3mBmLxvBcbkWEssLIBAABMETYAAIApwgYAADBF2AAAAKYIGwAAwBRhAwAAmCJsAAAAU4QNAABgik29ACBJbHoGxIeVDQAAYIqwAQAATBE2AACAKZOw0dXVZdEtAADwoITCxr59+1RUVKSsrCw9/vjjunjxoiTJcRwFg0H5fD75fD4tXbrUpFgAAOA9cV+Ncvr0ae3fv18ffPCBvvzyS/3xj3/Us88+q7feeksHDx7UypUrVVJSIkkqKCgwKxgAAHhL3GHj2LFj2rx5s9LT03Xvvffq5MmT2rVrlyRpy5Yt+u1vf6vc3FzNnDnTrFgAAOA9cX+MUllZqfT09OjjKVOmaObMmerr61M4HFZNTY0KCgpUVVUlx3Gu21c4HFZvb2/MAQAAUlPSm3odP35czzzzjDIzM3XkyBENDAxo69atWrNmjWbNmqVVq1aN+Nq6ujrV1tYm+9ZAymKTKACpKKmrUc6dO6fBwUEtXLgwes7v96uqqkrr16/X9u3br/v66upqhUKh6NHZ2ZlMGQAAwAMSDhtDQ0Oqr6/X5s2bh31+wYIFCoVC1+0jEAgoKysr5gAAAKkp4bDx2muvae3atZo8ebIk6cqVKzHPDw0NqaioyJ3qAACA5yX0nY36+noFg0FdunRJly5d0unTp9Xa2qqf/OQnWrJkiXw+nxoaGrRu3TqregEAgMfEHTZ27typtWvXxlxpkpGRoTfeeEOrV6/Wjh07VFJSosrKSpWWlpoUCwAAvCfusLFo0SItWrRo2Od+97vfuVYQAABILdyIDQAAmCJsAAAAU4QNAABgKukdRIHxjt04vYX5uj7+fOBlrGwAAABThA0AAGCKsAEAAEwRNgAAgCnCBgAAMEXYAAAApggbAADAFGEDAACYImwAAABThA0AAGCKsAEAAEwRNgAAgCnCBgAAMEXYAAAApggbAADAFGEDAACYImwAAABThA0AAGCKsAEAAEwRNgAAgCnCBgAAMEXYAAAApggbAADAFGEDAACYImwAAABThA0AAGCKsAEAAEwRNgAAgCnCBgAAMEXYAAAApggbAADAFGEDAACYSihs7Nu3T0VFRcrKytLjjz+uixcvSpLa29u1fPlyNTQ0qLKyUmfOnDEpFgAAeE/cYeP06dPav3+/PvjgA23btk0ff/yxnn32WUUiEc2fP1+LFi3S008/rWXLlmnx4sWWNQMAAA+5Jd6Gx44d0+bNm5Wenq57771XJ0+e1K5du3To0CGdOnVKZWVlkqR58+Zp4cKFam1t1QMPPGBWOAAA8Ia4VzYqKyuVnp4efTxlyhTNnDlTLS0tKiwslN/vlySlpaWpsLBQTU1NI/YVDofV29sbcwAAgNQU98rGDx0/flzPPPOMDh48qKysrJjnsrOzdfbs2RFfW1dXp9ra2mTfOiH56/ffsE3Hy4+NQiUAAExMSV2Ncu7cOQ0ODmrhwoXy+/3RVY2rIpGIIpHIiK+vrq5WKBSKHp2dncmUAQAAPCDhlY2hoSHV19dr8+bNkqRp06bp2LFjMW1CoZCmT58+Yh+BQECBQCDRtwYAAB6U8MrGa6+9prVr12ry5MmSpNLSUrW3t8txHEnSwMCAOjo6VF5e7m6lAADAkxJa2aivr1cwGNSlS5d06dIlnT59WoODg8rLy1Nzc7MeeughHT16VPn5+Zo9e7ZVzQAAwEPiDhs7d+7U2rVroysYkpSRkaHz589r79692rhxo9ra2tTS0qI9e/bI5/OZFAwAALwl7rCxaNEiLVq0aNjnMjMz9e6770qSVqxY4U5lAAAgJXBvFAAAYIqwAQAATBE2AACAKcIGAAAwRdgAAACmCBsAAMAUYQMAAJgibAAAAFOEDQAAYIqwAQAATBE2AACAqYTu+gpYy1+/P652HS8/ZlwJMHHF8/eQv4M3byL9ObOyAQAATBE2AACAKcIGAAAwRdgAAACmCBsAAMAUYQMAAJgibAAAAFOEDQAAYIpNvTSxNlYBMLHFu3Ee4CZWNgAAgCnCBgAAMEXYAAAApggbAADAFGEDAACYImwAAABThA0AAGCKsAEAAEwRNgAAgCl2EIUnsQsiAHgHKxsAAMAUYQMAAJgibAAAAFNJh43+/n6FQqERn+/q6kq2awAAkEISDhuRSETbtm1TMBjUiRMnoucdx1EwGJTP55PP59PSpUtdLRQAAHhTwlej9PT0qLy8XF9//XXM+YMHD2rlypUqKSmRJBUUFLhTIQAA8LSEVzZyc3N19913X3N+y5YtSktLU25uroqLi5WTk+NKgQAAwNtc+YJoX1+fwuGwampqVFBQoKqqKjmOM2L7cDis3t7emAMAAKQmVzb1yszM1JEjRzQwMKCtW7dqzZo1mjVrllatWjVs+7q6OtXW1rrx1rhJ8WyO1fHyY670AwDjnVv/JiKWq5e++v1+VVVVaf369dq+ffuI7aqrqxUKhaJHZ2enm2UAAIBxxGSfjQULFlz3sthAIKCsrKyYAwAApCaTsDE0NKSioiKLrgEAgMckFTYikUjM4+bmZjU2Nka/FNrQ0KB169bdfHUAAMDzEv6C6IULF/TWW29JkhobGzV16lR1dnZq9erV2rFjh0pKSlRZWanS0lLXiwUAAN6TcNi46667tGHDBm3YsCF67p577lFFRYWrhQEAgNTAjdgAAIApwgYAADDlyqZeAADg/7HRYSxWNgAAgCnCBgAAMEXYAAAApggbAADAFGEDAACYImwAAABThA0AAGCKsAEAAEyxqVec4tmgpePlx0ahEgDARJEqv3tY2QAAAKYIGwAAwBRhAwAAmCJsAAAAU4QNAABgirABAABMETYAAIApwgYAADBF2AAAAKbYQXQcSpUd4wDgRvj37uZ54c+QlQ0AAGCKsAEAAEwRNgAAgCnCBgAAMEXYAAAApggbAADAFGEDAACYImwAAABTbOqVwuLZ6AUAUoEXNraayFjZAAAApggbAADAFGEDAACYSjps9Pf3KxQKuVkLAABIQQmHjUgkom3btikYDOrEiRPR8+3t7Vq+fLkaGhpUWVmpM2fOuFooAADwpoSvRunp6VF5ebm+/vrr6LlIJKL58+fr9ddf17x581RQUKDFixerpaXF1WIBAID3JLyykZubq7vvvjvm3KFDh3Tq1CmVlZVJkubNm6eTJ0+qtbXVnSoBAIBnufIF0ZaWFhUWFsrv90uS0tLSVFhYqKampmHbh8Nh9fb2xhwAACA1ubKpV3d3t7KysmLOZWdn6+zZs8O2r6urU21trRtv7Tle3GjLizUDsMUmWkiEKysbfr8/uqpxVSQSUSQSGbZ9dXW1QqFQ9Ojs7HSjDAAAMA65srIxbdo0HTt2LOZcKBTS9OnTh20fCAQUCATceGsAADDOubKyMXfuXLW3t8txHEnSwMCAOjo6VF5e7kb3AADAw5IKGz/8eGTOnDnKy8tTc3OzJOno0aPKz8/X7Nmzb75CAADgaQl/jHLhwgW99dZbkqTGxkZNnTpV99xzj/bu3auNGzeqra1NLS0t2rNnj3w+n+sFAwAAb0k4bNx1113asGGDNmzYEHM+GAzq3XfflSStWLHCneoAAIDncSM2AABgirABAABMETYAAIApV/bZwOhjV08A4x3/TuEqVjYAAIApwgYAADBF2AAAAKYIGwAAwBRhAwAAmCJsAAAAU4QNAABgirABAABMsamXi9jABgDGL/6NHjusbAAAAFOEDQAAYIqwAQAATBE2AACAKcIGAAAwRdgAAACmCBsAAMAUYQMAAJhiUy8AwLjGZlzex8oGAAAwRdgAAACmCBsAAMAUYQMAAJgibAAAAFOEDQAAYIqwAQAATBE2AACAKcIGAAAwRdgAAACmCBsAAMAUYQMAAJgyCxtdXV1WXQMAAA9xLWw4jqNgMCifzyefz6elS5e61TUAAPAw124xf/DgQa1cuVIlJSWSpIKCAre6BgAAHubaysaWLVuUlpam3NxcFRcXKycnx62uAQCAh7kSNvr6+hQOh1VTU6OCggJVVVXJcRw3ugYAAB7nyscomZmZOnLkiAYGBrR161atWbNGs2bN0qpVq4ZtHw6HFQ6Ho497e3vdKAMAAIxDrn1nQ5L8fr+qqqrU3d2t7du3jxg26urqVFtb6+ZbAwAwKvLX7x/rEjzH5NLXBQsWKBQKjfh8dXW1QqFQ9Ojs7LQoAwAAjAOurmxcNTQ0pKKiohGfDwQCCgQCFm8NAADGGVdWNpqbm9XY2Bj9UmhDQ4PWrVvnRtcAAMDjXFnZ6Ozs1OrVq7Vjxw6VlJSosrJSpaWlbnQNAAA8zpWwUVFRoYqKCje6AgAAKYYbsQEAAFOEDQAAYIqwAQAATBE2AACAKcIGAAAwRdgAAACmCBsAAMAUYQMAAJgibAAAAFOEDQAAYIqwAQAATBE2AACAKcIGAAAwRdgAAACmCBsAAMAUYQMAAJgibAAAAFOEDQAAYIqwAQAATBE2AACAKcIGAAAwRdgAAACmCBsAAMAUYQMAAJgibAAAAFOEDQAAYIqwAQAATBE2AACAKcIGAAAwRdgAAACmCBsAAMAUYQMAAJgibAAAAFOEDQAAYIqwAQAATBE2AACAKdfCRnt7u5YvX66GhgZVVlbqzJkzbnUNAAA87BY3OolEIpo/f75ef/11zZs3TwUFBVq8eLFaWlrc6B4AAHiYKysbhw4d0qlTp1RWViZJmjdvnk6ePKnW1lY3ugcAAB7myspGS0uLCgsL5ff7JUlpaWkqLCxUU1OTHnjggWvah8NhhcPh6ONQKCRJ6u3tdaOcGJHwd673CQCAl1j8fv1+v47jXLedK2Gju7tbWVlZMeeys7N19uzZYdvX1dWptrb2mvMzZsxwoxwAAPA92fW2/ff19Sk7O3vE510JG36/P7qqcVUkElEkEhm2fXV1tdauXRvT9uLFi8rJyZHP50v4/Xt7ezVjxgx1dnZeE3pSyUQZpzRxxjpRxilNnLEyztQzUcaazDgdx1FfX5/y8vKu286VsDFt2jQdO3Ys5lwoFNL06dOHbR8IBBQIBGLO3X777TddR1ZWVkr/IFw1UcYpTZyxTpRxShNnrIwz9UyUsSY6zuutaFzlyhdE586dq/b29uhnNgMDA+ro6FB5ebkb3QMAAA9zJWzMmTNHeXl5am5uliQdPXpU+fn5mj17thvdAwAAD3PlY5RJkyZp79692rhxo9ra2tTS0qI9e/Yk9f2LZAQCAb3wwgvXfDSTaibKOKWJM9aJMk5p4oyVcaaeiTJWy3H6nBtdrwIAAHATuDcKAAAwRdgAAACmCBsAAMAUYcNjzp8/n1D7rq4uo0psJTrOicSrc4r/bYDk1fnr7++P3loi1SUyVi/P6WjybNjo7e3VkiVLdPvtt6uwsFDvvffeiG3b29u1fPlyNTQ0qLKyUmfOnBnFSt3R3t6uiooKVVRUXLed4zgKBoPy+Xzy+XxaunTpKFXojnjHmQpzWl9fr5deekl//vOf9eqrr47YzktzGu+8xDv28SzesTY2NkbnbtKkSTp16tQoV3pzIpGItm3bpmAwqBMnTozYLhXmNN6xen1O9+3bp6KiImVlZenxxx/XxYsXh23n6pw6HvXss886H374ofPpp586S5Yscfx+v3P69Olr2g0NDTk/+9nPnCNHjjiO4ziHDx92SkpKRrvcm/bVV185K1ascH71q19dt93+/fudzZs3O62trU5ra6vT09MzOgW6JJ5xpsKc7ty50ykrK4s+fvDBB52DBw8O29YrcxrvvCQy9vEqkZ/BJ598Mjp3bW1to1mmK7q7u52Ojg5HktPU1DRsm1SYU8eJb6yO4+05/eqrr5ynn37a+de//uXs3r3bueOOO5ynnnrqmnZuz6knw8aVK1ecd955J/q4v7/fCQQCzs6dO69pe+DAAefWW291rly54jiO4wwODjoZGRnOP/7xj1Gr1y0vvPDCDcPGr3/9a+eNN95wzpw5MzpFGbjROFNhTmfPnu28+OKL0ccvvfSS8+ijjw7b1itzGu+8JDL28Sresf797393Hn74Yefw4cPOwMDAWJTqmuv9Ak6FOf2+643V63P6zjvvOOFwOPr4hRdecH76059e087tOfXkxyh+v1+VlZXRx7feequys7M1c+bMa9q2tLSosLAweqO4tLQ0FRYWqqmpadTqHS19fX0Kh8OqqalRQUGBqqqqbnjbXy/y+pxeuXJFx48fV1FRUfRcMBgctn4vzWk885LI2MezeH8G//nPf+qbb77RI488olmzZun48eNjUa6pVJnTeHl9TisrK5Wenh59PGXKlGt+d1rMqSfDxg99/fXXmj59un7xi19c81x3d/c1N5TJzs7W2bNnR6u8UZOZmakjR47o/Pnzqq+v15tvvqm//vWvY12W67w+p99++60GBwdjxpCdna3+/n5dunQppq2X5jSeeUlk7ONZvD+Df/rTn/Tpp5/q888/19SpU/Wb3/xG/f39o1mquVSZ03il2pweP35czzzzTMw5izlNibDx5ptvqqGhYdjn/H5/9H8fV0UiEUUikdEobUz4/X5VVVVp/fr12r59+1iX4zqvz+nV2r8/hqu1jzQGL8xpPPOSzNjHo0R/BouKirRv3z5dvnxZH3/88ShUOHpSZU4TlQpzeu7cOQ0ODmrhwoUx5y3mdFyGjXPnzmnq1KkjHqtWrYq2bWpq0n333afi4uJh+5o2bZp6e3tjzoVCIU2fPt10DPFKZKyJWrBgwbi5VM3NcXp9Tl988UWlp6fHjCEUCunWW29VTk7OdfseT3P6Q/HMS05OTtJjH0+S+RnMzc3VnDlzxu38JStV5jQZXp7ToaEh1dfXa/Pmzdc8ZzGnrtyIzW3Tpk2La5+FL774QqdPn9aTTz4pSRocHFRaWlrMDeDmzp2rV155RY7jyOfzaWBgQB0dHSovLzerPxHxjjUZQ0NDMZ+5jSU3x5kKc/rvf/9bX331VfTxf/7zH82dO/eGfY+nOf2heObF5/PpoYceSmrs40myP4Pjef6SlSpzmiyvzulrr72mtWvXavLkyZL+9z2Nq9/lsJjTcbmyEY/z58/rjTfe0C9/+Ut98cUXOnnypOrq6iRJbW1tqq2tlSTNmTNHeXl5am5uliQdPXpU+fn5mj179pjVnqzhlq++P9bm5mY1NjZGv0DY0NCgdevWjWqNbrjROFNhTn//+99r//790ccfffSR/vCHP0jy7pxeb15qa2vV1tYm6fpj94p4xjo4OKhXXnlF7e3tkqTPP/9cgUBAP//5z8ey9KQM93cy1eb0quuNNVXmtL6+XsFgUJcuXdIXX3yhAwcO6KOPPrKd06SvYxlD/f39zv333+9Iijmee+45x3EcZ9euXU5xcXG0/ZdffuksW7bM2bJli7NkyRLnyy+/HKvSk/bJJ5849913n5OTk+Ps3r07esnd98f6t7/9zcnJyXEeffRR5y9/+ct1rxMfr+IZp+Okxpxu2rTJee6555wNGzY4GzdujJ738pyONC/333+/s3v37mi7kcbuJTca6+XLl53S0lLnzjvvdGpqapxXXnnF+e9//zvGVSfum2++cTZt2uRIcp588knn888/dxwnNef0RmNNhTl97733HJ/PF/O7MyMjw+nt7TWdU24xDwAATHn2YxQAAOANhA0AAGCKsAEAAEwRNgAAgCnCBgAAMEXYAAAApggbAADAFGEDAACYImwAAABThA0AAGCKsAEAAEz9H3OFpYhcRSUaAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plt.hist(output.detach().cpu().numpy()[:,0], bins = 50)" + ] + }, + { + "cell_type": "markdown", + "id": "5319a1e6-2724-45d7-ae33-e2c6798948a7", + "metadata": {}, + "source": [ + "## same with simple normalizing flow" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "5b5b786d-e36f-44c4-87cf-803cbad78805", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "import torch.distributions as D\n", + "import matplotlib.pyplot as plt\n", + "\n", + "class NormalizingFlow(nn.Module):\n", + " def __init__(self, input_dim, num_flows):\n", + " super(NormalizingFlow, self).__init__()\n", + " self.flows = nn.ModuleList([PlanarFlow(input_dim) for _ in range(num_flows)])\n", + "\n", + " def forward(self, z):\n", + " det_jacobian = torch.ones(z.size(0), 1)\n", + " for flow in self.flows:\n", + " z, jacobian = flow(z)\n", + " det_jacobian *= jacobian\n", + " return z, det_jacobian\n" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "9deae1f2-d9bd-43f2-b79c-b119eabb57f2", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "class PlanarFlow(nn.Module):\n", + " def __init__(self, input_dim):\n", + " super(PlanarFlow, self).__init__()\n", + " self.input_dim = input_dim\n", + " self.weight = nn.Parameter(torch.randn(1, input_dim))\n", + " self.bias = nn.Parameter(torch.randn(1))\n", + " self.scale = nn.Parameter(torch.randn(1))\n", + "\n", + " def forward(self, z):\n", + " # Transformation function\n", + " z_flow = z + self.scale * torch.tanh(torch.mm(z, self.weight.t()) + self.bias)\n", + " # Absolute determinant of the Jacobian\n", + " psi = (1 - torch.tanh(torch.mm(z, self.weight.t()) + self.bias) ** 2) * self.weight\n", + " det_jacobian = torch.abs(1 + torch.mm(psi, self.weight.t()))\n", + " return z_flow, det_jacobian" + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "id": "318a11de-b859-41c8-8a5b-675d44c0795e", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Define dimensions\n", + "input_dim = 1\n", + "\n", + "z_dim = input_dim\n", + "nepochs=100\n", + "bs=1000\n", + "num_flows=20\n", + "# Instantiate the PlanarFlow transformation\n", + "flow = NormalizingFlow(input_dim, num_flows)\n", + "base_distribution = D.Normal(torch.zeros(z_dim), torch.ones(z_dim))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 77, + "id": "f8d0d1d9-8693-4348-a089-87d3f84379e6", + "metadata": { + "collapsed": true, + "jupyter": { + "outputs_hidden": true + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch: 0 0.9581038951873779\n", + "epoch: 1 0.9572756886482239\n", + "epoch: 2 0.9593279361724854\n", + "epoch: 3 0.9568343758583069\n", + "epoch: 4 0.9572469592094421\n", + "epoch: 5 0.9553364515304565\n", + "epoch: 6 0.9575842022895813\n", + "epoch: 7 0.9560788869857788\n", + "epoch: 8 0.9587973356246948\n", + "epoch: 9 0.9580194354057312\n", + "epoch: 10 0.9589003920555115\n", + "epoch: 11 0.9569817781448364\n", + "epoch: 12 0.955956220626831\n", + "epoch: 13 0.958595871925354\n", + "epoch: 14 0.9574549198150635\n", + "epoch: 15 0.958065390586853\n", + "epoch: 16 0.9580419063568115\n", + "epoch: 17 0.9566380381584167\n", + "epoch: 18 0.957197904586792\n", + "epoch: 19 0.9565077424049377\n", + "epoch: 20 0.9570044875144958\n", + "epoch: 21 0.9565112590789795\n", + "epoch: 22 0.9574409127235413\n", + "epoch: 23 0.957895040512085\n", + "epoch: 24 0.9586047530174255\n", + "epoch: 25 0.9566912055015564\n", + "epoch: 26 0.958039402961731\n", + "epoch: 27 0.9593290686607361\n", + "epoch: 28 0.9575077891349792\n", + "epoch: 29 0.9567246437072754\n", + "epoch: 30 0.9582246541976929\n", + "epoch: 31 0.9556811451911926\n", + "epoch: 32 0.9571450352668762\n", + "epoch: 33 0.9564812183380127\n", + "epoch: 34 0.9580637216567993\n", + "epoch: 35 0.9575179815292358\n", + "epoch: 36 0.9590878486633301\n", + "epoch: 37 0.957419753074646\n", + "epoch: 38 0.9576168060302734\n", + "epoch: 39 0.9573911428451538\n", + "epoch: 40 0.9573774337768555\n", + "epoch: 41 0.9567323327064514\n", + "epoch: 42 0.959529459476471\n", + "epoch: 43 0.9581653475761414\n", + "epoch: 44 0.9583896398544312\n", + "epoch: 45 0.956044614315033\n", + "epoch: 46 0.9577047824859619\n", + "epoch: 47 0.958259642124176\n", + "epoch: 48 0.9573068022727966\n", + "epoch: 49 0.958161473274231\n", + "epoch: 50 0.957586407661438\n", + "epoch: 51 0.960275411605835\n", + "epoch: 52 0.9580078721046448\n", + "epoch: 53 0.9568662047386169\n", + "epoch: 54 0.9577730298042297\n", + "epoch: 55 0.9574154615402222\n", + "epoch: 56 0.9576438069343567\n", + "epoch: 57 0.9593215584754944\n", + "epoch: 58 0.9583775997161865\n", + "epoch: 59 0.9577406048774719\n", + "epoch: 60 0.9570046663284302\n", + "epoch: 61 0.9578263163566589\n", + "epoch: 62 0.9566665887832642\n", + "epoch: 63 0.9562719464302063\n", + "epoch: 64 0.9590163826942444\n", + "epoch: 65 0.9577448964118958\n", + "epoch: 66 0.9592074155807495\n", + "epoch: 67 0.9566142559051514\n", + "epoch: 68 0.9577342867851257\n", + "epoch: 69 0.9586719870567322\n", + "epoch: 70 0.9577504992485046\n", + "epoch: 71 0.9564600586891174\n", + "epoch: 72 0.9592216610908508\n", + "epoch: 73 0.9561935067176819\n", + "epoch: 74 0.9580640196800232\n", + "epoch: 75 0.9567023515701294\n", + "epoch: 76 0.9570690393447876\n", + "epoch: 77 0.9568561315536499\n", + "epoch: 78 0.9567307233810425\n", + "epoch: 79 0.9573178887367249\n", + "epoch: 80 0.9571306705474854\n", + "epoch: 81 0.9589489698410034\n", + "epoch: 82 0.957552433013916\n", + "epoch: 83 0.9554185271263123\n", + "epoch: 84 0.9571435451507568\n", + "epoch: 85 0.9579405784606934\n", + "epoch: 86 0.9575528502464294\n", + "epoch: 87 0.9572521448135376\n", + "epoch: 88 0.955950140953064\n", + "epoch: 89 0.9554963111877441\n", + "epoch: 90 0.9585280418395996\n", + "epoch: 91 0.9571262001991272\n", + "epoch: 92 0.9583413004875183\n", + "epoch: 93 0.9568629264831543\n", + "epoch: 94 0.9590401649475098\n", + "epoch: 95 0.9552928805351257\n", + "epoch: 96 0.9579176306724548\n", + "epoch: 97 0.9571013450622559\n", + "epoch: 98 0.9567534327507019\n", + "epoch: 99 0.9590332508087158\n" + ] + } + ], + "source": [ + "optimizer = optim.Adam(flow.parameters(), lr=1e-4, weight_decay=1e-4)\n", + "\n", + "for e in range(nepochs):\n", + " \n", + " \n", + " optimizer.zero_grad()\n", + " \n", + " random_input = torch.rand(bs, input_dim)\n", + " randomN_input = torch.randn(bs, input_dim)\n", + "\n", + " output, det_jacb = flow(random_input)\n", + " \n", + " loss = -base_distribution.log_prob(output) + torch.log(det_jacb) #+mmd(output,randomN_input)\n", + " loss = loss.mean()\n", + " loss.backward()\n", + " optimizer.step() \n", + " \n", + " print('epoch:',e, loss.item())\n", + " \n", + " \n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 78, + "id": "3c8313d8-80f3-4f80-af2e-adee65f4b775", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(array([117., 90., 96., 93., 113., 87., 95., 104., 102., 103.]),\n", + " array([0.00128156, 0.10107996, 0.20087835, 0.30067676, 0.40047514,\n", + " 0.5002736 , 0.60007197, 0.69987035, 0.7996687 , 0.8994672 ,\n", + " 0.99926555], dtype=float32),\n", + " )" + ] + }, + "execution_count": 78, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiIAAAGcCAYAAADknMuyAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAcqElEQVR4nO3df2zc9WH/8ZcTHEeo8SFYIT9qMFYXaDta1pVsDIJiS2urgqJMVbtsLWFdtZKKTauybpBBN2izpt0mmixqp4VpjQaVWNmIwoTaoFXpkkj3R1BQHaSW1Y2dhQAWNNmZQeI4ufv+0W8sXBLqtGe/febxkE70Pvfx5/PmrUP37PtzP9oajUYjAAAFzCk9AADgzUuIAADFCBEAoBghAgAUI0QAgGKECABQjBABAIq5oPQAfpZ6vZ7nnnsuCxYsSFtbW+nhAACT0Gg08vLLL2fx4sWZM+fc6x4zPkSee+65dHV1lR4GAPBzOHz4cN72tred8/EZHyILFixI8pN/kc7OzsKjAQAmY2RkJF1dXeOv4+cy40PkzOWYzs5OIQIALeZnva3Cm1UBgGJ+7hA5fvx4arVaM8cCALzJnHeI1Ov1bNu2LUuXLs1TTz01vv2xxx7LVVddlc7Oznz4wx/O0aNHxx8bHBzM2rVrs3Xr1qxZsyaHDh1qzugBgJZ23iHy0ksvpbe3N88+++z4toMHD+bxxx/Po48+mm3btuW73/1u7rzzziQ/CZeVK1fmox/9aD71qU/l1ltvzerVq5v3bwAAtKzzfrPqpZde+rpte/fuzZYtWzJv3ry8613vSn9/fx555JEkyc6dOzMwMJDly5cnSfr6+rJq1ars27cv11133S84fACglTXlzapr1qzJvHnzxu9fdtllufzyy5Mk1Wo1PT09aW9vT5LMnTs3PT092bVr11mPNTo6mpGRkQk3AGB2mpJPzezfvz+33357kmR4ePh1H7utVCo5cuTIWf9248aNqVQq4zdfZgYAs1fTQ+T555/PqVOnsmrVqiRJe3v7+GrIGfV6PfV6/ax/v379+tRqtfHb4cOHmz1EAGCGaOoXmp0+fTqbNm3Kli1bxrctWrQoe/funbBfrVbLkiVLznqMjo6OdHR0NHNYAMAM1dQVka985StZt25d3vKWtyRJTp48mRUrVmRwcDCNRiNJMjY2lqGhofT29jbz1ABAC/q5VkTOdlll06ZNWbp0aY4dO5Zjx47l4MGDOXXqVG655ZYsXrw4e/bsyU033ZTdu3enu7s7y5Yt+4UHDwC0tvMOkRdffDEPPPBAkuShhx7KwoUL09/fn3Xr1o2veiTJhRdemBdeeCFz5szJjh07smHDhhw4cCDVajXbt2//md89DwDMfm2N19bDDDQyMpJKpZJareZH7wCgRUz29duP3gEAxQgRAKCYpn58t9V03/V46SGct6Ev3Vx6CADQNFZEAIBihAgAUIwQAQCKESIAQDFCBAAoRogAAMUIEQCgGCECABQjRACAYoQIAFCMEAEAinlT/9YMzDZ+PwloNVZEAIBihAgAUIwQAQCKESIAQDFCBAAoRogAAMUIEQCgGCECABQjRACAYoQIAFCMEAEAihEiAEAxQgQAKEaIAADFCBEAoBghAgAUI0QAgGKECABQjBABAIoRIgBAMUIEAChGiAAAxQgRAKAYIQIAFCNEAIBihAgAUMwFpQcAwNTrvuvx0kM4b0Nfurn0EJgGVkQAgGKECABQjEszANAkLoGdPysiAEAxQgQAKMalGQBmpFa8zMH5+7lXRI4fP55ardbMsQAAbzLnHSL1ej3btm3L0qVL89RTT41vHxwczNq1a7N169asWbMmhw4dmtRjAMCb13lfmnnppZfS29ubZ599dnxbvV7PypUrs3nz5vT19eXKK6/M6tWrU61W3/AxAODN7bxXRC699NJcccUVE7bt3LkzAwMDWb58eZKkr68v/f392bdv3xs+BgC8uTXlzarVajU9PT1pb29PksydOzc9PT3ZtWtX/u///u+cj1133XWvO9bo6GhGR0fH74+MjDRjiADADNSUj+8ODw+ns7NzwrZKpZIjR4684WNns3HjxlQqlfFbV1dXM4YIAMxATQmR9vb28RWPM+r1eur1+hs+djbr169PrVYbvx0+fLgZQwQAZqCmhMiiRYtedwmlVqtlyZIlb/jY2XR0dKSzs3PCDQCYnZoSIitWrMjg4GAajUaSZGxsLENDQ+nt7X3DxwCAN7efK0R++rLK9ddfn8WLF2fPnj1Jkt27d6e7uzvLli17w8cAgDe38/7UzIsvvpgHHnggSfLQQw9l4cKFufrqq7Njx45s2LAhBw4cSLVazfbt29PW1pa2trZzPgYAvLm1Nc5cM5mhRkZGUqlUUqvVmv5+kVb8HYPSP9fMzOY5zbm04nOD6TFV/w1O9vXbr+8CAMUIEQCgGCECABQjRACAYoQIAFBMU370Dt5Iq75b36c5AKaeFREAoBghAgAU49IMwHlq1cuNMBNZEQEAihEiAEAxQgQAKEaIAADFCBEAoBifmoFz8MkIgKlnRQQAKEaIAADFCBEAoBghAgAUI0QAgGKECABQjBABAIoRIgBAMb7QrMX4ki0AZhMrIgBAMUIEAChGiAAAxQgRAKAYIQIAFCNEAIBihAgAUIwQAQCK8YVmQFG+pA/e3KyIAADFCBEAoBghAgAUI0QAgGKECABQjBABAIoRIgBAMUIEAChGiAAAxQgRAKAYIQIAFCNEAIBihAgAUIwQAQCKESIAQDFCBAAo5oJmHmz37t35z//8z/zSL/1S9u3bl7vvvjtXX311BgcH8+Uvfznvfe97s3fv3nzhC1/IFVdc0cxTAwAtqGkhcvr06XziE5/IM888kwsuuCD/9V//lT/6oz/KE088kZUrV2bz5s3p6+vLlVdemdWrV6darTbr1ABAi2rapZmjR4/mueeey6uvvpokqVQqOXbsWHbu3JmBgYEsX748SdLX15f+/v7s27evWacGAFpU00LkrW99a37t134tH//4x/O///u/+fu///vcd999qVar6enpSXt7e5Jk7ty56enpya5du856nNHR0YyMjEy4AQCzU1PfrPrII49kYGAgixYtygc+8IHccsstGR4eTmdn54T9KpVKjhw5ctZjbNy4MZVKZfzW1dXVzCECADNIU9+sOjw8nA996EM5dOhQbrvttlx00UVpb28fXw05o16vp16vn/UY69evz7p168bvj4yMiBEAmKWaFiKvvvpqPv7xj+fJJ5/M/Pnzc8899+STn/xkPv3pT2fv3r0T9q3ValmyZMlZj9PR0ZGOjo5mDQsAmMGadmnm6aefzoIFCzJ//vwkyb333puXX345N954YwYHB9NoNJIkY2NjGRoaSm9vb7NODQC0qKaFyNvf/vYcOXIkr7zySpLk5MmTWbhwYW688cYsXrw4e/bsSfKT7xrp7u7OsmXLmnVqAKBFNe3SzMUXX5wHHnggf/zHf5xrrrkmzz77bB588MHMnTs3O3bsyIYNG3LgwIFUq9Vs3749bW1tzTo1ANCi2hpnrpnMUCMjI6lUKqnVaq/79M0vqvuux5t6PABoNUNfunlKjjvZ12+/NQMAFCNEAIBihAgAUIwQAQCKESIAQDFCBAAoRogAAMUIEQCgGCECABQjRACAYoQIAFCMEAEAihEiAEAxQgQAKEaIAADFCBEAoBghAgAUI0QAgGKECABQjBABAIoRIgBAMUIEAChGiAAAxQgRAKAYIQIAFCNEAIBihAgAUIwQAQCKESIAQDFCBAAoRogAAMUIEQCgGCECABQjRACAYoQIAFCMEAEAihEiAEAxQgQAKEaIAADFCBEAoBghAgAUI0QAgGKECABQjBABAIoRIgBAMUIEAChGiAAAxQgRAKCYC6bioN///vezffv2dHV1ZdWqVVmwYMFUnAYAaHFND5GvfvWrefjhh/Pwww9nyZIlSZLBwcF8+ctfznvf+97s3bs3X/jCF3LFFVc0+9QAQItpaog8+uij+fznP5+nn346b33rW5Mk9Xo9K1euzObNm9PX15crr7wyq1evTrVabeapAYAW1LT3iJw6dSqf+cxn8tnPfnY8QpJk586dGRgYyPLly5MkfX196e/vz759+5p1agCgRTUtRHbv3p3Dhw/nmWeeyapVq/KOd7wjDz/8cKrVanp6etLe3p4kmTt3bnp6erJr166zHmd0dDQjIyMTbgDA7NS0SzP9/f256KKL8jd/8ze5+OKL8+1vfzsrV65Mb29vOjs7J+xbqVRy5MiRsx5n48aNue+++5o1LABgBmvaisjx48fzjne8IxdffHGS5IMf/GAuu+yy7N27d3w15Ix6vZ56vX7W46xfvz61Wm38dvjw4WYNEQCYYZoWIgsXLswrr7wyYdvb3va2/Pmf//nrLq/UarXxT9T8tI6OjnR2dk64AQCzU9NC5IYbbsjQ0FBOnTo1vu3EiRNJfvLx3UajkSQZGxvL0NBQent7m3VqAKBFNS1Eli5dmmuvvTZPPPFEkuTo0aN56aWX8md/9mdZvHhx9uzZk+Qnb2rt7u7OsmXLmnVqAKBFNfV7RB588MF89rOfTX9/fwYHB/PII4/kwgsvzI4dO7Jhw4YcOHAg1Wo127dvT1tbWzNPDQC0oLbGmWsmM9TIyEgqlUpqtVrT3y/SfdfjTT0eALSaoS/dPCXHnezrtx+9AwCKESIAQDFCBAAoRogAAMUIEQCgGCECABQjRACAYoQIAFCMEAEAihEiAEAxQgQAKEaIAADFCBEAoBghAgAUI0QAgGKECABQjBABAIoRIgBAMUIEAChGiAAAxQgRAKAYIQIAFCNEAIBihAgAUIwQAQCKESIAQDFCBAAoRogAAMUIEQCgGCECABQjRACAYoQIAFCMEAEAihEiAEAxQgQAKEaIAADFCBEAoBghAgAUI0QAgGKECABQjBABAIoRIgBAMUIEAChGiAAAxQgRAKAYIQIAFCNEAIBihAgAUIwQAQCKaXqIvPrqq3nnO9+ZoaGhJMng4GDWrl2brVu3Zs2aNTl06FCzTwkAtKgLmn3ALVu25Pvf/36SpF6vZ+XKldm8eXP6+vpy5ZVXZvXq1alWq80+LQDQgpq6IrJjx4709vaO39+5c2cGBgayfPnyJElfX1/6+/uzb9++Zp4WAGhRTQuR//mf/8nzzz+fZcuWjW+rVqvp6elJe3t7kmTu3Lnp6enJrl27mnVaAKCFNeXSzOnTp/PAAw/kvvvum7B9eHg4nZ2dE7ZVKpUcOXLknMcaHR3N6Ojo+P2RkZFmDBEAmIGasiLy1a9+NbfffnvmzJl4uPb29vHVkDPq9Xrq9fo5j7Vx48ZUKpXxW1dXVzOGCADMQE0JkS1btuTtb3975s+fn/nz5ydJrrrqqvzjP/7j61Y0arValixZcs5jrV+/PrVabfx2+PDhZgwRAJiBmnJp5oc//OGE+21tbXnmmWdy5MiRfOhDH0qj0UhbW1vGxsYyNDQ04Q2tP62joyMdHR3NGBYAMMNN6ReaXX/99Vm8eHH27NmTJNm9e3e6u7snvKEVAHjzavr3iLzWnDlzsmPHjmzYsCEHDhxItVrN9u3b09bWNpWnBQBaxJSESKPRGP/fS5cuzb/8y78kSe64446pOB0A0KL81gwAUIwQAQCKESIAQDFCBAAoRogAAMUIEQCgGCECABQjRACAYoQIAFCMEAEAihEiAEAxQgQAKEaIAADFCBEAoBghAgAUI0QAgGKECABQjBABAIoRIgBAMUIEAChGiAAAxQgRAKAYIQIAFCNEAIBihAgAUIwQAQCKESIAQDFCBAAoRogAAMUIEQCgGCECABQjRACAYoQIAFCMEAEAihEiAEAxQgQAKEaIAADFCBEAoBghAgAUI0QAgGKECABQjBABAIoRIgBAMUIEAChGiAAAxQgRAKAYIQIAFCNEAIBihAgAUIwQAQCKaVqIPPbYY7nqqqvS2dmZD3/4wzl69GiSZHBwMGvXrs3WrVuzZs2aHDp0qFmnBABaXFNC5ODBg3n88cfz6KOPZtu2bfnud7+bO++8M/V6PStXrsxHP/rRfOpTn8qtt96a1atXN+OUAMAscEEzDrJ3795s2bIl8+bNy7ve9a709/fnkUceyc6dOzMwMJDly5cnSfr6+rJq1ars27cv1113XTNODQC0sKasiKxZsybz5s0bv3/ZZZfl8ssvT7VaTU9PT9rb25Mkc+fOTU9PT3bt2nXOY42OjmZkZGTCDQCYnabkzar79+/P7bffnuHh4XR2dk54rFKp5MiRI+f8240bN6ZSqYzfurq6pmKIAMAM0PQQef7553Pq1KmsWrUq7e3t46shZ9Tr9dTr9XP+/fr161Or1cZvhw8fbvYQAYAZoinvETnj9OnT2bRpU7Zs2ZIkWbRoUfbu3Tthn1qtliVLlpzzGB0dHeno6GjmsACAGaqpKyJf+cpXsm7durzlLW9Jktx4440ZHBxMo9FIkoyNjWVoaCi9vb3NPC0A0KKatiKyadOmLF26NMeOHcuxY8dy8ODBnDp1KosXL86ePXty0003Zffu3enu7s6yZcuadVoAoIU1JUS++c1vZt26deMrH0ly4YUX5oUXXsiOHTuyYcOGHDhwINVqNdu3b09bW1szTgsAtLi2xmvrYQYaGRlJpVJJrVZ73SdwflHddz3e1OMBQKsZ+tLNU3Lcyb5++60ZAKAYIQIAFCNEAIBihAgAUIwQAQCKESIAQDFCBAAoRogAAMUIEQCgGCECABQjRACAYoQIAFCMEAEAihEiAEAxQgQAKEaIAADFCBEAoBghAgAUI0QAgGKECABQjBABAIoRIgBAMUIEAChGiAAAxQgRAKAYIQIAFCNEAIBihAgAUIwQAQCKESIAQDFCBAAoRogAAMUIEQCgGCECABQjRACAYoQIAFCMEAEAihEiAEAxQgQAKEaIAADFCBEAoBghAgAUI0QAgGKECABQjBABAIoRIgBAMUIEAChGiAAAxQgRAKCYaQmRwcHBrF27Nlu3bs2aNWty6NCh6TgtADDDXTDVJ6jX61m5cmU2b96cvr6+XHnllVm9enWq1epUnxoAmOGmfEVk586dGRgYyPLly5MkfX196e/vz759+6b61ADADDflKyLVajU9PT1pb29PksydOzc9PT3ZtWtXrrvuutftPzo6mtHR0fH7tVotSTIyMtL0sdVHX236MQGglUzF6+trj9toNN5wvykPkeHh4XR2dk7YVqlUcuTIkbPuv3Hjxtx3332v297V1TUl4wOAN7PKpqk9/ssvv5xKpXLOx6c8RNrb28dXQ86o1+up1+tn3X/9+vVZt27dhH2PHj2aSy65JG1tbU0b18jISLq6unL48OHXhRLNY56nh3mePuZ6epjn6TNVc91oNPLyyy9n8eLFb7jflIfIokWLsnfv3gnbarValixZctb9Ozo60tHRMWHbRRddNFXDS2dnpyf5NDDP08M8Tx9zPT3M8/SZirl+o5WQM6b8zaorVqzI4ODg+DWisbGxDA0Npbe3d6pPDQDMcFMeItdff30WL16cPXv2JEl2796d7u7uLFu2bKpPDQDMcFN+aWbOnDnZsWNHNmzYkAMHDqRarWb79u1Nfb/Hz6OjoyN/9Vd/9brLQDSXeZ4e5nn6mOvpYZ6nT+m5bmv8rM/VAABMEb81AwAUI0QAgGKECABQzJS/WRVgNhsbG8s3vvGNHD16NDfccEN+/dd/vfSQ4JyOHz+ekydPTur7PabLrF4RGRwczNq1a7N169asWbMmhw4dOut+mzZtyhe/+MXcfffduf/++6d5lLPDZOZ6ZGQkH/vYx3LRRRelp6cn//qv/1pgpK1tss/pM77+9a/n93//96dncLPIZOd5eHg4N9xwQ06cOJF169aJkJ/DZOZ6bGws99xzT772ta9l/fr1+fznP19gpK2tXq9n27ZtWbp0aZ566qlz7lfk9bAxS50+fbrxK7/yK43vfOc7jUaj0XjiiScav/Ebv/G6/b75zW82li9fPn7/N3/zNxvf+ta3pm2cs8Fk5/rOO+9s/Md//Efje9/7XuNjH/tYo729vXHw4MHpHm7Lmuw8n/HMM8803vOe9zRuu+22aRrh7DDZeT558mTjfe97X+Pee++d7iHOGpOd6/vvv7/xt3/7t+P3V6xY0dizZ8+0jXM2GB4ebgwNDTWSNHbt2nXWfUq9Hs7aFZGdO3dmYGAgy5cvT5L09fWlv78/+/btm7Df3/3d3+X973//+P1bbrklW7ZsmdaxtrrJzPXY2Fje+c535pZbbsm73/3u/NM//VPmzJmTJ598stSwW85kn9NJcvLkyTz88MNZtWrVNI+y9U12nv/5n/85P/zhD/MXf/EXJYY5K0x2rgcGBnL06NHx+5VKJceOHZvWsba6Sy+9NFdcccUb7lPq9XDWhki1Wk1PT8/4D+7NnTs3PT092bVr1/g+J0+ezP79+3PVVVeNb1u6dOmEffjZJjPX7e3tWbNmzfj9+fPnp1Kp5PLLL5/28baqyczzGVu2bMkdd9wx3UOcFSY7z9/4xjeyaNGirFu3Ltdee21++7d/e8KLJT/bZOd61apV2bx5c771rW/lySefzKlTp/KBD3ygxJBnrZKvh7M2RIaHh1/34z2VSiVHjhwZv//jH/84p06dmrBfpVLJ8ePH1fZ5mMxc/7Rnn302S5YscU39PEx2np944on86q/+ai655JLpHN6sMdl57u/vz0c+8pFs2bIl+/bty0svvZS77rprOofa8iY717/1W7+Vv/7rv84tt9ySO+64I//+7/+eefPmTedQZ72Sr4ezNkTa29vHK/uMer2eer0+YZ/X/vPMPq/9Jz/bZOb6p/3DP/xDtm7dOtVDm1UmM8/Dw8M5cOBA+vr6pnt4s8Zkn8/Hjx/PjTfeOP43t912Wx5//PFpG+dsMNm5bjQa+fGPf5wvfvGL+dGPfpSVK1dmbGxsOoc665V8PZy1IbJo0aKMjIxM2Far1bJkyZLx+5dccknmzZs3Yb9arZb58+f7f5PnYTJz/Vq7du3Ku9/97rzvfe+bjuHNGpOZ5507d+buu+/O/PnzM3/+/GzYsCEPPvhg5s+fn1qtNt1DbkmTfT4vXLgwr7zyyvj9rq4uK6nnabJzff/996dSqeTOO+/Mk08+maeffjpf//rXp3Oos17J18NZGyIrVqzI4OBgGv//p3TGxsYyNDSU3t7e8X3a2tpy00035Uc/+tH4tv/+7//OihUrpnu4LW0yc33GD37wgxw8eDC/8zu/kyQ5derU+N/xxiYzz2vWrMmJEyfGb/fcc09uvfXWnDhxYkZ9b8BMNtnn8w033JCBgYHx+ydOnEh3d/d0DrXlTXauv/Od7+Saa65JknR3d+dP/uRP8r3vfW/axzublXw9nLUhcv3112fx4sXZs2dPkmT37t3p7u7OsmXLct999+XAgQNJkk984hMTllO//e1v5w/+4A+KjLlVTXauX3jhhXzta1/LDTfckB/84Afp7+/Pxo0bSw69pUx2nvnFTHae//AP/zD/9m//Nv53e/fuzSc/+ckiY25Vk53r97znPdm/f//43x0/fjzXXnttiSG3tLNdYpkJr4ez9ptV58yZkx07dmTDhg05cOBAqtVqtm/fnra2tjz22GO55pprcs011+T3fu/3MjQ0lL/8y7/M6dOn88EPfjAf+chHSg+/pUxmrn/5l385N998c/bv3z/h42Cf+9zn0tbWVnD0rWOyz2l+MZOd597e3vzu7/5uPv3pT6erqyv1ej2f+cxnSg+/pUx2rj/3uc/lT//0T3PvvfdmwYIFOX36tOg7Ty+++GIeeOCBJMlDDz2UhQsX5uqrr54Rr4dtDeviAEAhs/bSDAAw8wkRAKAYIQIAFCNEAIBihAgAUIwQAQCKESIAQDFCBAAoRogAAMUIEQCgGCECABTz/wDZWVIoWUj2pwAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plt.hist(random_input.detach().cpu().numpy()[:,0])" + ] + }, + { + "cell_type": "code", + "execution_count": 79, + "id": "f4bff66f-2790-45f3-bc22-5b5539c04403", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(array([25., 17., 29., 24., 22., 22., 21., 13., 13., 21., 24., 16., 22.,\n", + " 26., 8., 23., 16., 23., 16., 15., 18., 24., 24., 26., 21., 17.,\n", + " 17., 19., 17., 17., 26., 22., 14., 19., 14., 14., 22., 29., 16.,\n", + " 23., 22., 18., 20., 27., 15., 18., 16., 24., 21., 24.]),\n", + " array([-0.4638622 , -0.4453545 , -0.42684677, -0.40833908, -0.38983136,\n", + " -0.37132365, -0.35281593, -0.3343082 , -0.31580052, -0.2972928 ,\n", + " -0.27878508, -0.26027736, -0.24176966, -0.22326194, -0.20475423,\n", + " -0.18624651, -0.1677388 , -0.14923109, -0.13072337, -0.11221566,\n", + " -0.09370795, -0.07520024, -0.05669252, -0.03818481, -0.0196771 ,\n", + " -0.00116938, 0.01733833, 0.03584604, 0.05435375, 0.07286147,\n", + " 0.09136918, 0.10987689, 0.1283846 , 0.14689232, 0.16540003,\n", + " 0.18390775, 0.20241547, 0.22092317, 0.23943089, 0.2579386 ,\n", + " 0.2764463 , 0.29495403, 0.31346175, 0.33196944, 0.35047716,\n", + " 0.36898488, 0.3874926 , 0.40600032, 0.424508 , 0.44301572,\n", + " 0.46152344], dtype=float32),\n", + " )" + ] + }, + "execution_count": 79, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAhsAAAGcCAYAAABwemJAAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAabElEQVR4nO3df2zU933H8ddBjSsW7Gq0pMBIwO3sqlk7tUoZ6UIF/LFVSpdEy8ai/HD3o0siVaQJUpbQZO3Q2Mh+qKEjaysyrdEUJVm6FpGOUiJFbIDmP5iIBluVbBQ7c/iltkFnpGQO+G5/VFhxgeTs3MfmzOMhnaI7f/31O/5w8NT3zv5U6vV6PQAAhcyY6gEAgOlNbAAARYkNAKAosQEAFCU2AICixAYAUJTYAACKetdUD5AktVotR48ezZw5c1KpVKZ6HACgAfV6PadOncqCBQsyY8aFr19cFLFx9OjRLFq0aKrHAAAmYHBwML/wC79wwY9fFLExZ86cJD8dtqOjY4qnAQAaMTQ0lEWLFo3+O34h446NF154IWvWrMkPfvCDfPzjH8/TTz+d9773venv789f/MVf5OMf/3j27t2bP/3TP82VV17Z0DnPvnTS0dEhNgCgxbzdWyDG9QbR//u//8u3v/3tPPfccxkcHMxrr72Wr3zlK6nVarn++uuzevXq3HHHHbn99ttz8803v6PBAYDpYVyxUa1W86UvfSmzZ8/Oz/3cz2X58uWZMWNGdu7cmUOHDmX58uVJklWrVuXAgQPZt29fkaEBgNYxrti4/PLLM2vWrCTJG2+8kePHj+fee+9NX19furq60tbWliSZOXNmurq6smvXrvOeZ3h4OENDQ2NuAMD0NKHfs/G9730vy5Yty65du/Jf//VfOXHixDnvtejs7MyRI0fO+/kbN25MZ2fn6M1PogDA9DWh2Pj1X//1/NM//VM++clP5rbbbktbW9voVY2zarVaarXaeT9/3bp1qVaro7fBwcGJjAEAtIAJ/ejr2ZdJ/v7v/z5z587N+973vnNeCqlWq1m4cOF5P7+9vT3t7e0T+dIAQIt5R7+ufPbs2Xnve9+bFStWpL+/P/V6PUly+vTpDAwMZOXKlU0ZEgBoXeOKjZ/85Cf57ne/OxoV//qv/5rbb789y5cvz4IFC7Jnz54kye7du7N48eIsXbq0+RMDAC1lXC+j9Pf353Of+1x6enryW7/1W7nsssvyZ3/2Z6lUKtm2bVs2bNiQgwcPpq+vL1u3brXPCQCQSv3sZYopNDQ0lM7OzlSrVb9BFABaRKP/fttiHgAoSmwAAEWJDQCgKLEBABQlNgCAoib0G0QvRYsf2P62xww8fN0kTAJQlr/vaDZXNgCAosQGAFCU2AAAihIbAEBRYgMAKEpsAABFiQ0AoCixAQAUJTYAgKLEBgBQlNgAAIoSGwBAUWIDAChKbAAARYkNAKAosQEAFCU2AICixAYAUJTYAACKEhsAQFFiAwAoSmwAAEWJDQCgqHdN9QAAMBkWP7D9bY8ZePi6SZjk0uPKBgBQlNgAAIoSGwBAUWIDAChKbAAARYkNAKAosQEAFCU2AICixAYAUJTYAACKEhsAQFH2RqEl2eNgclzK3+dL+f8dms2VDQCgKLEBABQlNgCAosQGAFCU2AAAihIbAEBRRWLj6NGjJU4LALSgccXGs88+m56ennR0dOSmm27Kq6++miSp1+vp7u5OpVJJpVLJbbfdVmRYAKD1NPxLvQ4fPpzt27fnO9/5Tl566aX84R/+Ye6///489thj2bFjR+6+++4sW7YsSbJkyZJiAwMAraXh2Ni7d282b96cWbNm5aqrrsqBAwfyrW99K0ny6KOP5jd+4zcyb968XHHFFcWGBQBaT8Mvo/T29mbWrFmj9y+//PJcccUVOXXqVIaHh/PQQw9lyZIlWbNmTer1+luea3h4OENDQ2NuAMD0NOG9Ufbv358777wzc+bMyfPPP5/Tp0/nG9/4Ru6999588IMfzBe+8IULfu7GjRuzfv36iX7pcbG/AQBMrQn9NMqxY8dy5syZ3HjjjaOPtbW1Zc2aNXnggQfy5JNPvuXnr1u3LtVqdfQ2ODg4kTEAgBYw7tgYGRnJpk2bsnnz5vN+/IYbbki1Wn3Lc7S3t6ejo2PMDQCYnsYdG4888kjWrl2byy67LEnyxhtvjPn4yMhIenp6mjMdANDyxvWejU2bNqW7uzsnT57MyZMnc/jw4ezbty8f+MAHcuutt6ZSqWTLli257777Ss0LALSYhmPjmWeeydq1a8f8pMns2bPzta99Lffcc0+eeuqpLFu2LL29vbn22muLDAsAtJ6GY2P16tVZvXr1eT/22c9+tmkDAQDTi43YAICixAYAUJTYAACKEhsAQFFiAwAoSmwAAEVNeCM2znUpb/p2Kf+/X4waWQ9gajXr781W+PvXlQ0AoCixAQAUJTYAgKLEBgBQlNgAAIoSGwBAUWIDAChKbAAARYkNAKAosQEAFCU2AICi7I3CJe1S2psAJpvnBWe5sgEAFCU2AICixAYAUJTYAACKEhsAQFFiAwAoSmwAAEWJDQCgKLEBABQlNgCAosQGAFCUvVEmWSN7BTRiuu4n0KzvD0yGVtz742J7jrXi95Dxc2UDAChKbAAARYkNAKAosQEAFCU2AICixAYAUJTYAACKEhsAQFFiAwAoSmwAAEWJDQCgKHujtKjJ3E/gYttLgYuLvS2gnOny968rGwBAUWIDAChKbAAARYkNAKAosQEAFCU2AICixAYAUNS4YuPZZ59NT09POjo6ctNNN+XVV19NkvT39+euu+7Kli1b0tvbm5dffrnIsABA62k4Ng4fPpzt27fnO9/5Th5//PH8y7/8S+6///7UarVcf/31Wb16de64447cfvvtufnmm0vODAC0kIZ/g+jevXuzefPmzJo1K1dddVUOHDiQb33rW9m5c2cOHTqU5cuXJ0lWrVqVG2+8Mfv27csnPvGJYoMDAK2h4Ssbvb29mTVr1uj9yy+/PFdccUX6+vrS1dWVtra2JMnMmTPT1dWVXbt2XfBcw8PDGRoaGnMDAKanCe+Nsn///tx5553ZsWNHOjo6xnyss7MzR44cueDnbty4MevXr5/olwaAMabLHiLT1YR+GuXYsWM5c+ZMbrzxxrS1tY1e1TirVqulVqtd8PPXrVuXarU6ehscHJzIGABACxj3lY2RkZFs2rQpmzdvTpLMnz8/e/fuHXNMtVrNwoULL3iO9vb2tLe3j/dLAwAtaNxXNh555JGsXbs2l112WZLk2muvTX9/f+r1epLk9OnTGRgYyMqVK5s7KQDQksZ1ZWPTpk3p7u7OyZMnc/LkyRw+fDhnzpzJggULsmfPnnzqU5/K7t27s3jx4ixdurTUzABAC2k4Np555pmsXbt29ApGksyePTvHjx/Ptm3bsmHDhhw8eDB9fX3ZunVrKpVKkYEBgNbScGysXr06q1evPu/H5syZk3/4h39Iknz+859vzmQAwLRgbxQAoCixAQAUJTYAgKLEBgBQlNgAAIqa8N4oAJOtkf0vBh6+bhImaa7p+v8FZ7myAQAUJTYAgKLEBgBQlNgAAIoSGwBAUWIDAChKbAAARYkNAKAosQEAFCU2AICixAYAUJS9UaDFNLKPBkwnF9ufeXvZjJ8rGwBAUWIDAChKbAAARYkNAKAosQEAFCU2AICixAYAUJTYAACKEhsAQFFiAwAoSmwAAEWJDQCgKBux5eLb5Aemm8l8jnk+w8XHlQ0AoCixAQAUJTYAgKLEBgBQlNgAAIoSGwBAUWIDAChKbAAARYkNAKAosQEAFCU2AICi7I3CpGnVPSuaNXcj5xl4+LqmfC0orVWfz5PF92csVzYAgKLEBgBQlNgAAIoSGwBAUWIDAChKbAAARU04Nl5//fVUq9ULfvzo0aMTPTUAMI2MOzZqtVoef/zxdHd354UXXhh9vF6vp7u7O5VKJZVKJbfddltTBwUAWtO4f6nXj3/846xcuTKvvPLKmMd37NiRu+++O8uWLUuSLFmypDkTAgAtbdxXNubNm5crr7zynMcfffTRzJw5M/PmzcvVV1+duXPnNmVAAKC1NeUNoqdOncrw8HAeeuihLFmyJGvWrEm9Xr/g8cPDwxkaGhpzAwCmp6bsjTJnzpw8//zzOX36dL7xjW/k3nvvzQc/+MF84QtfOO/xGzduzPr165vxpWFasZ8CTC3PwTKa+qOvbW1tWbNmTR544IE8+eSTFzxu3bp1qVaro7fBwcFmjgEAXESK/J6NG2644S1/LLa9vT0dHR1jbgDA9FQkNkZGRtLT01Pi1ABAi5lQbNRqtTH39+zZkyeeeGL0TaFbtmzJfffd986nAwBa3rjfIPqjH/0ojz32WJLkiSeeyPvf//4MDg7mnnvuyVNPPZVly5alt7c31157bdOHBQBaz7hj433ve1+++MUv5otf/OLoYx/60Idyyy23NHUwAGB6sBEbAFCU2AAAihIbAEBRYgMAKEpsAABFNWVvFADOr1l7bdizg1bmygYAUJTYAACKEhsAQFFiAwAoSmwAAEWJDQCgKLEBABQlNgCAosQGAFCU2AAAihIbAEBR9kaZxuylAMDFwJUNAKAosQEAFCU2AICixAYAUJTYAACKEhsAQFFiAwAoSmwAAEWJDQCgKLEBABQlNgCAosQGAFCU2AAAihIbAEBRYgMAKEpsAABFiQ0AoCixAQAUJTYAgKLEBgBQlNgAAIoSGwBAUWIDAChKbAAARYkNAKAosQEAFCU2AICixAYAUJTYAACKEhsAQFETjo3XX3891Wq1mbMAANPQuGOjVqvl8ccfT3d3d1544YXRx/v7+3PXXXdly5Yt6e3tzcsvv9zUQQGA1vSu8X7Cj3/846xcuTKvvPLK6GO1Wi3XX399vvrVr2bVqlVZsmRJbr755vT19TV1WACg9Yz7ysa8efNy5ZVXjnls586dOXToUJYvX54kWbVqVQ4cOJB9+/Y1Z0oAoGU15Q2ifX196erqSltbW5Jk5syZ6erqyq5du5pxegCghY37ZZTzOXHiRDo6OsY81tnZmSNHjpz3+OHh4QwPD4/eHxoaasYYAMBFqClXNtra2kavapxVq9VSq9XOe/zGjRvT2dk5elu0aFEzxgAALkJNiY358+efc3WiWq1m4cKF5z1+3bp1qVaro7fBwcFmjAEAXISaEhsrVqxIf39/6vV6kuT06dMZGBjIypUrz3t8e3t7Ojo6xtwAgOlpQrHxsy+PXHPNNVmwYEH27NmTJNm9e3cWL16cpUuXvvMJAYCWNu43iP7oRz/KY489liR54okn8v73vz8f+tCHsm3btmzYsCEHDx5MX19ftm7dmkql0vSBAYDWUqmffe1jCg0NDaWzszPVarXpL6ksfmB7U89H6xh4+Lq3PcafD+BS0MjfhxPR6L/fNmIDAIoSGwBAUWIDAChKbAAARYkNAKAosQEAFCU2AICixAYAUJTYAACKEhsAQFFiAwAoSmwAAEWJDQCgKLEBABQlNgCAosQGAFCU2AAAihIbAEBR75rqAaCUxQ9sn+oRAIgrGwBAYWIDAChKbAAARYkNAKAosQEAFCU2AICixAYAUJTYAACKEhsAQFFiAwAoSmwAAEWJDQCgKLEBABQlNgCAosQGAFCU2AAAihIbAEBRYgMAKEpsAABFiQ0AoCixAQAUJTYAgKLEBgBQlNgAAIoSGwBAUWIDAChKbAAARYkNAKAosQEAFCU2AICiisXG0aNHS50aAGghTYuNer2e7u7uVCqVVCqV3Hbbbc06NQDQwt7VrBPt2LEjd999d5YtW5YkWbJkSbNODQC0sKZd2Xj00Uczc+bMzJs3L1dffXXmzp3brFMDAC2sKbFx6tSpDA8P56GHHsqSJUuyZs2a1Ov1Cx4/PDycoaGhMTcAYHpqSmzMmTMnzz//fI4fP55Nmzbl61//ev7mb/7mgsdv3LgxnZ2do7dFixY1YwwA4CLU1J9GaWtry5o1a/LAAw/kySefvOBx69atS7VaHb0NDg42cwwA4CJS5Edfb7jhhlSr1Qt+vL29PR0dHWNuAMD0VCQ2RkZG0tPTU+LUAECLaUps7NmzJ0888cTom0K3bNmS++67rxmnBgBaXFN+z8bg4GDuueeePPXUU1m2bFl6e3tz7bXXNuPUAECLa0ps3HLLLbnllluacSoAYJqxERsAUJTYAACKEhsAQFFiAwAoSmwAAEWJDQCgKLEBABQlNgCAosQGAFCU2AAAihIbAEBRYgMAKEpsAABFiQ0AoCixAQAUJTYAgKLEBgBQlNgAAIoSGwBAUWIDAChKbAAARYkNAKAosQEAFCU2AICixAYAUJTYAACKEhsAQFFiAwAoSmwAAEWJDQCgKLEBABQlNgCAosQGAFCU2AAAihIbAEBRYgMAKEpsAABFiQ0AoCixAQAUJTYAgKLEBgBQlNgAAIoSGwBAUWIDAChKbAAARYkNAKAosQEAFCU2AICixAYAUJTYAACKalps9Pf356677sqWLVvS29ubl19+uVmnBgBa2LuacZJarZbrr78+X/3qV7Nq1aosWbIkN998c/r6+ppxegCghTXlysbOnTtz6NChLF++PEmyatWqHDhwIPv27WvG6QGAFtaUKxt9fX3p6upKW1tbkmTmzJnp6urKrl278olPfOKc44eHhzM8PDx6v1qtJkmGhoaaMc4YteHXmn5OAGglJf59ffN56/X6Wx7XlNg4ceJEOjo6xjzW2dmZI0eOnPf4jRs3Zv369ec8vmjRomaMAwC8Seemsuc/depUOjs7L/jxpsRGW1vb6FWNs2q1Wmq12nmPX7duXdauXTvm2FdffTVz585NpVJpxkhTbmhoKIsWLcrg4OA5IcbksQ5TzxpcHKzD1JuOa1Cv13Pq1KksWLDgLY9rSmzMnz8/e/fuHfNYtVrNwoULz3t8e3t72tvbxzz2nve8pxmjXHQ6OjqmzR+qVmYdpp41uDhYh6k33dbgra5onNWUN4iuWLEi/f39o6/ZnD59OgMDA1m5cmUzTg8AtLCmxMY111yTBQsWZM+ePUmS3bt3Z/HixVm6dGkzTg8AtLCmvIwyY8aMbNu2LRs2bMjBgwfT19eXrVu3Tpv3X0xEe3t7vvzlL5/zchGTyzpMPWtwcbAOU+9SXoNK/e1+XgUA4B2wNwoAUJTYAACKEhsAQFFiAwAoSmwUsGnTpvz5n/95HnzwwXzlK1952+Nfe+21fPjDH87AwED54S4Rja7Bs88+m56ennR0dOSmm27Kq6++OolTTi/9/f256667smXLlvT29ubll18+73HjfX7QuEbWYGhoKLfeemve8573pKurK//4j/84BZNOb40+F8765je/md/93d+dnOGmSp2meuaZZ+rLly8fvf/JT36yvmPHjrf8nIcffriepN7f3194uktDo2vwwx/+sH7HHXfU//M//7P+7W9/u/7zP//z9c997nOTOeq0MTIyUv+lX/ql+vPPP1+v1+v15557rr5s2bJzjpvI84PGNLoG999/f/273/1u/T/+4z/qt956a72tra1++PDhyR532mp0Hc566aWX6r/8y79c/+xnPztJE04NVzaa7K//+q/za7/2a6P3P/OZz2Tz5s0XPH7btm1+02qTNboGe/fuzebNm3PVVVflN3/zN7NmzZr827/922SOOm3s3Lkzhw4dyvLly5Mkq1atyoEDB7Jv374xx433+UHjGlmD06dP58Mf/nA+85nP5KMf/Wj+7u/+LjNmzMi///u/T9XY006jz4UkeeONN/L000/nxhtvnOQpJ5/YaKI33ngj+/fvT09Pz+hj3d3d2bVr13mP/9///d8cO3bMb1ptovGsQW9vb2bNmjV6//LLL88VV1wxKXNON319fenq6hrdkHHmzJnp6uoa830f7/OD8WlkDdra2tLb2zt6/93vfnc6Ozv9uW+iRtbhrM2bN+fzn//8ZI84JcRGE/3kJz/JmTNnxmyw09nZmddffz0nT54cc+zIyEgee+yx3HHHHZM95rQ2njX4Wfv378+dd95ZesRp6cSJE+dsLNXZ2ZkjR46M3n8na8Pba2QNftYrr7yShQsX5ld+5VdKj3fJaHQdnnvuuXzsYx/L3LlzJ3O8KSM2muhsyZ79b5LUarUx/z3rb//2b3PnnXdmxgxL0EzjWYM3O3bsWM6cOXNJXM4soa2tbcz3PPnp9/vN3/OJrg2NaWQNftbXv/71bNmypfRol5RG1uHEiRM5ePBgVq1aNdnjTZmm7I1yqTh27Fg+9rGPXfDjv/M7v5NZs2ZlaGho9LFqtZp3v/vd59Tr5s2b80d/9EdjHuvp6cndd9+dv/qrv2ru4NNIM9fgrJGRkWzatMl7B96B+fPnZ+/evWMeq1arWbhw4ej9uXPnjnttaFwja/Bmu3btykc/+tFcffXVkzHeJaORddi5c2cefPDBPPjgg0mSM2fOpF6v5+mnn86JEyca2rK91YiNcZg/f36OHz/+lsf84Ac/yA9/+MPR+//93/+dFStWnHPc//zP/4y5X6lU8tJLL2Xx4sXNGHXaauYanPXII49k7dq1ueyyy5L89L0Fb34vB29vxYoV+cu//MvU6/VUKpWcPn06AwMDY978XKlU8qlPfWpca0PjGlmDs1588cUcPnw4f/AHf5Dkp//YzZw585LePLNZGlmH3t7eMe+d+ZM/+ZMMDAzk8ccfn4KJJ4dr+E32e7/3e9m+ffvo/e9///v5/d///STJwYMHs379+qka7ZIxnjXYtGlTuru7c/Lkybz44ov53ve+l+9///uTPnOru+aaa7JgwYLs2bMnSbJ79+4sXrw4S5cuzfr163Pw4MEkb702vDONrsHx48fzta99Lb/6q7+aF198MQcOHMjGjRuncvRppdF1uNS4stFkt9xySwYGBvKlL30pIyMj+fSnP53f/u3fTpK89NJL+ed//ud8+ctfnuIpp7dG1+CZZ57J2rVrU3/TxsezZ89+2ysnnGvGjBnZtm1bNmzYkIMHD6avry9bt25NpVLJs88+m4985CP5yEc+8pZrwzvTyBr84i/+Yq677rrs379/zMuGf/zHf+yqRpM0+ly41NhiHgAoyssoAEBRYgMAKEpsAABFiQ0AoCixAQAUJTYAgKLEBgBQlNgAAIoSGwBAUWIDAChKbAAARf0/byenrUqnBfYAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plt.hist(output.detach().cpu().numpy()[:,0], bins = 50)" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "id": "826f3992-b2bd-4959-8282-a232a744bcaa", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "softplus = nn.Softplus()" + ] + }, + { + "cell_type": "code", + "execution_count": 70, + "id": "32bcbe63-7665-4bad-9b12-8389cd3e773a", + "metadata": {}, + "outputs": [], + "source": [ + "class RadialFlow(nn.Module):\n", + " def __init__(self, input_dim):\n", + " super(RadialFlow, self).__init__()\n", + " self.input_dim = input_dim\n", + " self.alpha = nn.Parameter(torch.randn(input_dim))\n", + " self.beta = nn.Parameter(torch.randn(input_dim))\n", + " self.gamma = nn.Parameter(torch.randn(input_dim))\n", + "\n", + " def forward(self, z):\n", + " \n", + " \n", + " num = softplus(self.alpha) * (torch.exp(self.beta)-1) * (z - self.gamma) \n", + " den = softplus(self.alpha) + (z - self.gamma) \n", + " \n", + " z_flow = z + num / den\n", + " \n", + " r = torch.abs(z-self.gamma)\n", + " h = 1 / (softplus(self.alpha) + r)\n", + " \n", + " term1 = (1 + softplus(self.alpha) * (torch.exp(self.beta)-1) * h)**(self.input_dim -1)\n", + " term2 = 1 + softplus(self.alpha) * (torch.exp(self.beta)-1) * h - softplus(self.alpha) * (torch.exp(self.beta)-1) * r *h**2\n", + " \n", + " det_jacobian = term1*term2\n", + " \n", + " return z_flow, det_jacobian\n", + "\n", + "class NormalizingFlow(nn.Module):\n", + " def __init__(self, input_dim, num_flows):\n", + " super(NormalizingFlow, self).__init__()\n", + " self.flows = nn.ModuleList([RadialFlow(input_dim) for _ in range(num_flows)])\n", + "\n", + " def forward(self, z):\n", + " det_jacobian = torch.ones(z.size(0), 1).to(device)\n", + " for flow in self.flows:\n", + " z, jacobian = flow(z)\n", + " det_jacobian *= jacobian\n", + " return z, det_jacobian" + ] + }, + { + "cell_type": "code", + "execution_count": 71, + "id": "bf6d28e8-544a-4ab9-8174-2d3ec35d0891", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import torch\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "\n", + "def MMD(x, y, kernel):\n", + " \"\"\"Emprical maximum mean discrepancy. The lower the result\n", + " the more evidence that distributions are the same.\n", + "\n", + " Args:\n", + " x: first sample, distribution P\n", + " y: second sample, distribution Q\n", + " kernel: kernel type such as \"multiscale\" or \"rbf\"\n", + " \"\"\"\n", + " xx, yy, zz = torch.mm(x, x.t()), torch.mm(y, y.t()), torch.mm(x, y.t())\n", + " rx = (xx.diag().unsqueeze(0).expand_as(xx))\n", + " ry = (yy.diag().unsqueeze(0).expand_as(yy))\n", + "\n", + " dxx = rx.t() + rx - 2. * xx # Used for A in (1)\n", + " dyy = ry.t() + ry - 2. * yy # Used for B in (1)\n", + " dxy = rx.t() + ry - 2. * zz # Used for C in (1)\n", + "\n", + " XX, YY, XY = (torch.zeros(xx.shape).to(device),\n", + " torch.zeros(xx.shape).to(device),\n", + " torch.zeros(xx.shape).to(device))\n", + "\n", + " if kernel == \"multiscale\":\n", + "\n", + " bandwidth_range = [0.2, 0.5, 0.9, 1.3]\n", + " for a in bandwidth_range:\n", + " XX += a**2 * (a**2 + dxx)**-1\n", + " YY += a**2 * (a**2 + dyy)**-1\n", + " XY += a**2 * (a**2 + dxy)**-1\n", + "\n", + " if kernel == \"rbf\":\n", + "\n", + " bandwidth_range = [1, 1.5, 2.0, 5.0]\n", + " for a in bandwidth_range:\n", + " XX += torch.exp(-0.5*dxx/a)\n", + " YY += torch.exp(-0.5*dyy/a)\n", + " XY += torch.exp(-0.5*dxy/a)\n", + "\n", + "\n", + " return torch.mean(XX + YY - 2. * XY)" + ] + }, + { + "cell_type": "code", + "execution_count": 98, + "id": "478d8261-7e95-48fa-8b5c-6194ebaf4524", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "RadialFlow()" + ] + }, + "execution_count": 98, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "RadialFlow(1)" + ] + }, + { + "cell_type": "code", + "execution_count": 81, + "id": "c2f42513-ed84-4bb5-9097-f7b1b7862da9", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Define dimensions\n", + "input_dim = 1\n", + "\n", + "z_dim = input_dim\n", + "nepochs=10000\n", + "\n", + "bs=1000\n", + "num_flows=100\n", + "# Instantiate the PlanarFlow transformation\n", + "flow = NormalizingFlow(input_dim, num_flows).to(device)\n", + "\n", + "base_distribution = D.Normal(torch.zeros(z_dim), torch.ones(z_dim))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 82, + "id": "94928333-6043-4cbf-b247-b2904f0a9719", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "random_input = torch.rand(10000, input_dim)\n", + "dset = TensorDataset(random_input)\n", + "loader = DataLoader(dset, batch_size=1000, shuffle=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 95, + "id": "18b1dcd2-7445-48af-bd62-46669d372ceb", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch: 0 28200.544921875\n", + "epoch: 1 1129.369384765625\n", + "epoch: 2 21.443578720092773\n", + "epoch: 3 112.31295776367188\n", + "epoch: 4 149.22096252441406\n", + "epoch: 5 33502.9375\n", + "epoch: 6 110.99540710449219\n", + "epoch: 7 271.1609191894531\n", + "epoch: 8 57.25232696533203\n", + "epoch: 9 73.51012420654297\n", + "epoch: 10 97.47554779052734\n", + "epoch: 11 123.52039337158203\n", + "epoch: 12 221.76556396484375\n", + "epoch: 13 225.6431427001953\n", + "epoch: 14 36.62788772583008\n", + "epoch: 15 788.5504150390625\n", + "epoch: 16 905.950927734375\n", + "epoch: 17 359.9490661621094\n", + "epoch: 18 94.8775863647461\n", + "epoch: 19 1562.370849609375\n", + "epoch: 20 68.43376159667969\n", + "epoch: 21 120.20706939697266\n", + "epoch: 22 117.14567565917969\n", + "epoch: 23 156.90737915039062\n", + "epoch: 24 64.27742767333984\n", + "epoch: 25 72.80873107910156\n", + "epoch: 26 299.83074951171875\n", + "epoch: 27 198.3571014404297\n", + "epoch: 28 67.56690979003906\n", + "epoch: 29 59.80242156982422\n", + "epoch: 30 152.69972229003906\n", + "epoch: 31 85.3028335571289\n", + "epoch: 32 2329.72314453125\n", + "epoch: 33 50.84280776977539\n", + "epoch: 34 25396.7109375\n", + "epoch: 35 1737.4556884765625\n", + "epoch: 36 101.53565216064453\n", + "epoch: 37 505.9686584472656\n", + "epoch: 38 972.81689453125\n", + "epoch: 39 48.345829010009766\n", + "epoch: 40 71052.6328125\n", + "epoch: 41 4206.85205078125\n", + "epoch: 42 168.01895141601562\n", + "epoch: 43 9705.7646484375\n", + "epoch: 44 337.1620788574219\n", + "epoch: 45 2976.151611328125\n", + "epoch: 46 2142248192.0\n", + "epoch: 47 18152.1484375\n", + "epoch: 48 26.833215713500977\n", + "epoch: 49 3471.46044921875\n", + "epoch: 50 583.5856323242188\n", + "epoch: 51 96.70631408691406\n", + "epoch: 52 170.61947631835938\n", + "epoch: 53 64.85731506347656\n", + "epoch: 54 914.636474609375\n", + "epoch: 55 153.5564727783203\n", + "epoch: 56 184.95700073242188\n", + "epoch: 57 45.27014923095703\n", + "epoch: 58 1541.3345947265625\n", + "epoch: 59 227.9755859375\n", + "epoch: 60 144.11129760742188\n", + "epoch: 61 55.70856475830078\n", + "epoch: 62 2118.783447265625\n", + "epoch: 63 161.55926513671875\n", + "epoch: 64 31.37447738647461\n", + "epoch: 65 1794.9423828125\n", + "epoch: 66 52.8399543762207\n", + "epoch: 67 51.3333854675293\n", + "epoch: 68 40.058067321777344\n", + "epoch: 69 51.833065032958984\n", + "epoch: 70 1669.7928466796875\n", + "epoch: 71 52.97251892089844\n", + "epoch: 72 56.90239715576172\n", + "epoch: 73 27612.794921875\n", + "epoch: 74 184.5417938232422\n", + "epoch: 75 70.6549072265625\n" + ] + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[95], line 14\u001b[0m\n\u001b[1;32m 12\u001b[0m loss \u001b[38;5;241m=\u001b[39m MMD(output,randomN_input, kernel\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mrbf\u001b[39m\u001b[38;5;124m'\u001b[39m)\u001b[38;5;241m-\u001b[39mbase_distribution\u001b[38;5;241m.\u001b[39mlog_prob(output\u001b[38;5;241m.\u001b[39mcpu())\u001b[38;5;241m.\u001b[39mto(device) \u001b[38;5;241m+\u001b[39m torch\u001b[38;5;241m.\u001b[39mlog(det_jacb) \u001b[38;5;66;03m#\u001b[39;00m\n\u001b[1;32m 13\u001b[0m loss \u001b[38;5;241m=\u001b[39m loss\u001b[38;5;241m.\u001b[39mmean()\n\u001b[0;32m---> 14\u001b[0m \u001b[43mloss\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 15\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mstep() \n\u001b[1;32m 17\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mepoch:\u001b[39m\u001b[38;5;124m'\u001b[39m,e, loss\u001b[38;5;241m.\u001b[39mitem())\n", + "File \u001b[0;32m/data/astro/scratch/lcabayol/anaconda3/envs/DLenv2/lib/python3.9/site-packages/torch/_tensor.py:363\u001b[0m, in \u001b[0;36mTensor.backward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 354\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m has_torch_function_unary(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 355\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(\n\u001b[1;32m 356\u001b[0m Tensor\u001b[38;5;241m.\u001b[39mbackward,\n\u001b[1;32m 357\u001b[0m (\u001b[38;5;28mself\u001b[39m,),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 361\u001b[0m create_graph\u001b[38;5;241m=\u001b[39mcreate_graph,\n\u001b[1;32m 362\u001b[0m inputs\u001b[38;5;241m=\u001b[39minputs)\n\u001b[0;32m--> 363\u001b[0m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautograd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgradient\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/data/astro/scratch/lcabayol/anaconda3/envs/DLenv2/lib/python3.9/site-packages/torch/autograd/__init__.py:173\u001b[0m, in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 168\u001b[0m retain_graph \u001b[38;5;241m=\u001b[39m create_graph\n\u001b[1;32m 170\u001b[0m \u001b[38;5;66;03m# The reason we repeat same the comment below is that\u001b[39;00m\n\u001b[1;32m 171\u001b[0m \u001b[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[1;32m 172\u001b[0m \u001b[38;5;66;03m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[0;32m--> 173\u001b[0m \u001b[43mVariable\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_execution_engine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_backward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[1;32m 174\u001b[0m \u001b[43m \u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgrad_tensors_\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 175\u001b[0m \u001b[43m \u001b[49m\u001b[43mallow_unreachable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maccumulate_grad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "optimizer = optim.Adam(flow.parameters(), lr=1e-3, weight_decay=1e-4)\n", + "\n", + "for e in range(nepochs):\n", + " for x in loader:\n", + " #assert False\n", + " \n", + " optimizer.zero_grad()\n", + " randomN_input = torch.randn(len(x[0]), input_dim).to(device)\n", + "\n", + " output, det_jacb = flow(x[0].to(device))\n", + " \n", + " loss = MMD(output,randomN_input, kernel='rbf')-base_distribution.log_prob(output.cpu()).to(device) + torch.log(det_jacb) #\n", + " loss = loss.mean()\n", + " loss.backward()\n", + " optimizer.step() \n", + " \n", + " print('epoch:',e, loss.item())\n", + " \n", + " \n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 97, + "id": "75356e3f-a4a8-4da3-98a4-1574c7bc3f68", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(array([ 0., 0., 0., 0., 0., 1., 0., 0., 3., 1., 0.,\n", + " 0., 0., 0., 0., 0., 0., 0., 0., 2., 4., 2.,\n", + " 0., 0., 0., 112., 59., 70., 72., 74., 71., 99., 111.,\n", + " 127., 71., 10., 9., 11., 17., 6., 6., 4., 5., 1.,\n", + " 1., 4., 4., 7., 2., 3.]),\n", + " array([-10. , -9.6, -9.2, -8.8, -8.4, -8. , -7.6, -7.2, -6.8,\n", + " -6.4, -6. , -5.6, -5.2, -4.8, -4.4, -4. , -3.6, -3.2,\n", + " -2.8, -2.4, -2. , -1.6, -1.2, -0.8, -0.4, 0. , 0.4,\n", + " 0.8, 1.2, 1.6, 2. , 2.4, 2.8, 3.2, 3.6, 4. ,\n", + " 4.4, 4.8, 5.2, 5.6, 6. , 6.4, 6.8, 7.2, 7.6,\n", + " 8. , 8.4, 8.8, 9.2, 9.6, 10. ], dtype=float32),\n", + " )" + ] + }, + "execution_count": 97, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiIAAAGcCAYAAADknMuyAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAh50lEQVR4nO3dcWzU9f3H8de1HNcwuCOgQFsr5WKKwmY2Jt2YlNHGhA1d18TENKhdHAYwSmaITjo1glY7+WMUOzVp/0FlRiVay8RZoim23S5ZN5itGaCVu1oKFBjsDqWUtvf9/eGPi0ev5dp928/d9flIvtHv9/vp9/v++vne3cvP93v3dViWZQkAAMCANNMFAACAyYsgAgAAjCGIAAAAYwgiAADAGIIIAAAwhiACAACMIYgAAABjppgu4GrC4bCOHz+uGTNmyOFwmC4HAADEwbIsnT9/XllZWUpLG37cI+GDyPHjx5WTk2O6DAAAMAZdXV267rrrhl2f8EFkxowZkr45ELfbbbgaAAAQj1AopJycnMjn+HASPohcvhzjdrsJIgAAJJmr3VbBzaoAAMAYgggAADCGIAIAAIwhiAAAAGMIIgAAwBiCCAAAMIYgAgAAjCGIAAAAYwgiAADAGIIIAAAwhiACAACMIYgAAABjCCIAAMAYgggAADCGIAIAAIyZYroAAKkrd/Peq7YJ/P72CagEQKJiRAQAABhDEAEAAMYQRAAAgDEEEQAAYAxBBAAAGEMQAQAAxhBEAACAMQQRAABgDEEEAAAYQxABAADGEEQAAIAxBBEAAGAMQQQAABhDEAEAAMYQRAAAgDEEEQAAYAxBBAAAGEMQAQAAxhBEAACAMQQRAABgzBTTBQCY3HI3771qm8Dvb5+ASgCYwIgIAAAwZsxBpLe3V8Fg0M5aAADAJDPqIBIOh7Vz507l5eXp4MGDQ9ZfuHBBixYtUiAQiCzz+/3asGGDampqVFZWps7Ozv+paAAAkBpGfY/ImTNnVFhYqGPHjsVcX11drUOHDkXmw+GwiouLtWPHDhUVFWnBggUqLS2Vz+cbe9UAACAljHpEZM6cOZo/f37MdfX19SosLIxa1tDQoI6ODhUUFEiSioqK1NbWptbW1jGUCwAAUoltN6t++eWXOnHihPLz86OW+3w+eb1eOZ1OSVJ6erq8Xq8aGxtjbqevr0+hUChqAgAAqcmWIDI4OKja2lqtW7duyLqenh653e6oZR6PR93d3TG3VVlZKY/HE5lycnLsKBEAACQgW4LIiy++qPXr1ystbejmnE5nZDTksnA4rHA4HHNb5eXlCgaDkamrq8uOEgEAQAKyJYhUV1frhhtuUEZGhjIyMiRJCxcu1KOPPqrMzMwhl1eCwaCys7NjbsvlcsntdkdNAAAgNdnyy6qff/551LzD4dCRI0eUm5urv/71r9q2bZssy5LD4VB/f78CgcCQm1oBAMDkM6YRkeEuq8SybNkyZWVlqbm5WZLU1NSk3NzcITe1AgCAyWfUIyKnT59WbW2tJGnXrl2aN2+ebrzxxmHbp6Wlqb6+XhUVFWpvb5fP51NdXZ0cDsfYqwYAACnBYVmWZbqIkYRCIXk8HgWDQe4XAZJMPA+0iwcPvQOST7yf3zz0DgAAGEMQAQAAxhBEAACAMQQRAABgDEEEAAAYQxABAADGEEQAAIAxBBEAAGAMQQQAABhDEAEAAMYQRAAAgDGjfugdAEj2PUcGwOTGiAgAADCGIAIAAIwhiAAAAGMIIgAAwBiCCAAAMIYgAgAAjCGIAAAAYwgiAADAGIIIAAAwhiACAACMIYgAAABjeNYMACSyLZ442gTHvw5gnDAiAgAAjCGIAAAAYwgiAADAGIIIAAAwhiACAACMIYgAAABjCCIAAMAYgggAADCGIAIAAIwhiAAAAGPGHER6e3sVDPKzwgAAYOxGHUTC4bB27typvLw8HTx4MLJ8z549Wrhwodxut+68806dPXs2ss7v92vDhg2qqalRWVmZOjs77akeAAAktVEHkTNnzqiwsFDHjh2LLDt69Kj27t2rd955Rzt37tT+/fv12GOPSfomuBQXF+uuu+7SunXrdO+996q0tNS+IwAAAElr1E/fnTNnzpBlLS0tqq6u1tSpU7V48WK1tbVp9+7dkqSGhgZ1dHSooKBAklRUVKSSkhK1trZq6dKl/2P5AAAgmdlys2pZWZmmTp0amZ87d66uv/56SZLP55PX65XT6ZQkpaeny+v1qrGxMea2+vr6FAqFoiYAAJCaxuVbMwcOHND69eslST09PXK73VHrPR6Puru7Y/5tZWWlPB5PZMrJyRmPEgEAQAKwPYicOHFCAwMDKikpkSQ5nc7IaMhl4XBY4XA45t+Xl5crGAxGpq6uLrtLBAAACWLU94iMZHBwUFVVVaquro4sy8zMVEtLS1S7YDCo7OzsmNtwuVxyuVx2lgUAABKUrSMi27dv16ZNmzR9+nRJ0qVLl7Ry5Ur5/X5ZliVJ6u/vVyAQUGFhoZ27BgAASWhMIyKxLqtUVVUpLy9P586d07lz53T06FENDAzojjvuUFZWlpqbm7VixQo1NTUpNzdX+fn5/3PxAAAguY06iJw+fVq1tbWSpF27dmnevHlqa2vTpk2bIqMekjRt2jSdPHlSaWlpqq+vV0VFhdrb2+Xz+VRXVyeHw2HfUQBIabmb9161TeD3t09AJQDs5rC+nR4SUCgUksfjUTAYHPLtGwDmxBMOJlLKBpEtnjja8LgNJJ54P7956B0AADCGIAIAAIwhiAAAAGMIIgAAwBiCCAAAMIYgAgAAjCGIAAAAYwgiAADAGIIIAAAwhiACAACMIYgAAABjCCIAAMCYUT99FwDsFMhYc9U2uRdfn4BKAJjAiAgAADCGEREAMGWLx3QFgHGMiAAAAGMIIgAAwBiCCAAAMIYgAgAAjCGIAAAAYwgiAADAGIIIAAAwhiACAACMIYgAAABjCCIAAMAYgggAADCGIAIAAIwhiAAAAGMIIgAAwBiCCAAAMGaK6QIAJJ7czXtNlwBgkmBEBAAAGEMQAQAAxhBEAACAMWMOIr29vQoGg3bWAgAAJplRB5FwOKydO3cqLy9PBw8ejCz3+/3asGGDampqVFZWps7OzrjWAQCAyWvU35o5c+aMCgsLdezYsciycDis4uJi7dixQ0VFRVqwYIFKS0vl8/lGXAcAACa3UY+IzJkzR/Pnz49a1tDQoI6ODhUUFEiSioqK1NbWptbW1hHXAQCAyc2W3xHx+Xzyer1yOp2SpPT0dHm9XjU2Nuqrr74adt3SpUuHbKuvr099fX2R+VAoZEeJAAAgAdnyrZmenh653e6oZR6PR93d3SOui6WyslIejycy5eTk2FEiAABIQLYEEafTGRnxuCwcDiscDo+4Lpby8nIFg8HI1NXVZUeJAAAgAdkSRDIzM4dcQgkGg8rOzh5xXSwul0tutztqAgAAqcmWe0RWrlypbdu2ybIsORwO9ff3KxAIqLCwUAMDA8OuAwC7xPN8nMDvb5+ASv7fFs/E7QtIYmMaEbnyssqyZcuUlZWl5uZmSVJTU5Nyc3OVn58/4joAADC5jXpE5PTp06qtrZUk7dq1S/PmzdONN96o+vp6VVRUqL29XT6fT3V1dXI4HHI4HMOuAwAAk5vDsizLdBEjCYVC8ng8CgaD3C8CTJB4LnPYJZCx5qptci++bs++UvXSzBYet4HEE+/nNw+9AwAAxhBEAACAMQQRAABgDEEEAAAYQxABAADGEEQAAIAxBBEAAGAMQQQAABhDEAEAAMYQRAAAgDEEEQAAYAxBBAAAGEMQAQAAxhBEAACAMQQRAABgDEEEAAAYQxABAADGEEQAAIAxBBEAAGAMQQQAABhDEAEAAMYQRAAAgDEEEQAAYAxBBAAAGEMQAQAAxhBEAACAMQQRAABgDEEEAAAYQxABAADGEEQAAIAxBBEAAGAMQQQAABhDEAEAAMYQRAAAgDEEEQAAYAxBBAAAGDPFzo01NTXpww8/1DXXXKPW1lY9/vjjuvHGG+X3+/X8889ryZIlamlp0TPPPKP58+fbuWsAAJCEbAsig4ODuu+++3TkyBFNmTJFH3/8sR566CHt27dPxcXF2rFjh4qKirRgwQKVlpbK5/PZtWsAAJCkbLs0c/bsWR0/flwXLlyQJHk8Hp07d04NDQ3q6OhQQUGBJKmoqEhtbW1qbW21a9cAACBJ2RZErr32Wv3whz/UPffco//+97964YUXtHXrVvl8Pnm9XjmdTklSenq6vF6vGhsbY26nr69PoVAoagIAAKnJ1ptVd+/erY6ODmVmZmrVqlW644471NPTI7fbHdXO4/Gou7s75jYqKyvl8XgiU05Ojp0lAgCABGLrzao9PT1avXq1Ojs79atf/UozZ86U0+mMjIZcFg6HFQ6HY26jvLxcmzZtisyHQiHCCAAAKcq2IHLhwgXdc889+sc//qGMjAw98cQTWrt2rR544AG1tLREtQ0Gg8rOzo65HZfLJZfLZVdZAAAggdl2aebTTz/VjBkzlJGRIUnasmWLzp8/r+XLl8vv98uyLElSf3+/AoGACgsL7do1AABIUrYFkRtuuEHd3d36+uuvJUmXLl3SvHnztHz5cmVlZam5uVnSN781kpubq/z8fLt2DQAAkpRtl2ZmzZql2tpabdy4Ud/73vd07Ngxvfbaa0pPT1d9fb0qKirU3t4un8+nuro6ORwOu3YNAACSlK03q65atUqrVq0asjwvL0+vvvqqJOnBBx+0c5cAACCJ8awZAABgDEEEAAAYQxABAADGEEQAAIAxBBEAAGAMQQQAABhDEAEAAMbY+jsiAJD0tnjiaBMc/zqASYIREQAAYAwjIsAkk7t5r+kSACCCEREAAGAMIyIAJo947v8AMKEYEQEAAMYwIgIAo5VoIyt80wdJjBERAABgDEEEAAAYw6UZACkhkLHGdAkAxoAREQAAYAxBBAAAGEMQAQAAxhBEAACAMQQRAABgDEEEAAAYQxABAADGEEQAAIAxBBEAAGAMQQQAABhDEAEAAMYQRAAAgDEEEQAAYAxBBAAAGEMQAQAAxhBEAACAMQQRAABgzJTx2OihQ4dUV1ennJwclZSUaMaMGeOxGwAAkORsDyIvvvii3njjDb3xxhvKzs6WJPn9fj3//PNasmSJWlpa9Mwzz2j+/Pl27xoAACQZW4PIO++8o6efflqffvqprr32WklSOBxWcXGxduzYoaKiIi1YsEClpaXy+Xx27hoAACQh2+4RGRgY0MMPP6xHHnkkEkIkqaGhQR0dHSooKJAkFRUVqa2tTa2trXbtGgAAJCnbgkhTU5O6urp05MgRlZSU6KabbtIbb7whn88nr9crp9MpSUpPT5fX61VjY2PM7fT19SkUCkVNAAAgNdl2aaatrU0zZ87Utm3bNGvWLH3wwQcqLi5WYWGh3G53VFuPx6Pu7u6Y26msrNTWrVvtKgsAACQw20ZEent7ddNNN2nWrFmSpJ/97GeaO3euWlpaIqMhl4XDYYXD4ZjbKS8vVzAYjExdXV12lQgAABKMbSMi8+bN09dffx217LrrrtP999+vurq6qOXBYDDyjZoruVwuuVwuu8oCkAICGWtMlwBgnNg2InLrrbcqEAhoYGAgsuzixYuSvvn6rmVZkqT+/n4FAgEVFhbatWsAAJCkbAsieXl5+v73v699+/ZJks6ePaszZ87o0UcfVVZWlpqbmyV9c1Nrbm6u8vPz7do1AABIUrb+jshrr72mRx55RG1tbfL7/dq9e7emTZum+vp6VVRUqL29XT6fT3V1dXI4HHbuGgAAJCGHdfmaSYIKhULyeDwKBoNDvn0DYPRyN+81XUIU7v+YIFuCpivAJBPv5zcPvQMAAMYQRAAAgDEEEQAAYAxBBAAAGEMQAQAAxhBEAACAMQQRAABgDEEEAAAYQxABAADGEEQAAIAxBBEAAGAMQQQAABhDEAEAAMYQRAAAgDEEEQAAYAxBBAAAGEMQAQAAxhBEAACAMQQRAABgDEEEAAAYQxABAADGEEQAAIAxBBEAAGAMQQQAABhDEAEAAMYQRAAAgDEEEQAAYAxBBAAAGEMQAQAAxhBEAACAMQQRAABgDEEEAAAYQxABAADGEEQAAIAxBBEAAGCM7UHkwoULWrRokQKBgCTJ7/drw4YNqqmpUVlZmTo7O+3eJQAASFJT7N5gdXW1Dh06JEkKh8MqLi7Wjh07VFRUpAULFqi0tFQ+n8/u3QIAgCRk64hIfX29CgsLI/MNDQ3q6OhQQUGBJKmoqEhtbW1qbW21c7cAACBJ2RZEvvzyS504cUL5+fmRZT6fT16vV06nU5KUnp4ur9erxsbGYbfT19enUCgUNQEAgNRkSxAZHBxUbW2t1q1bF7W8p6dHbrc7apnH41F3d/ew26qsrJTH44lMOTk5dpQIAAASkC1B5MUXX9T69euVlha9OafTGRkNuSwcDiscDg+7rfLycgWDwcjU1dVlR4kAACAB2XKzanV1tX77299GLVu4cKHC4bAWL14ctTwYDCo7O3vYbblcLrlcLjvKAgAACc6WIPL5559HzTscDh05ckTd3d1avXq1LMuSw+FQf3+/AoFA1A2tAABg8hrXHzRbtmyZsrKy1NzcLElqampSbm5u1A2tAABg8rL9d0S+LS0tTfX19aqoqFB7e7t8Pp/q6urkcDjGc7cAACBJjEsQsSwr8u95eXl69dVXJUkPPvjgeOwOAAAkKZ41AwAAjCGIAAAAYwgiAADAGIIIAAAwhiACAACMIYgAAABjCCIAAMAYgggAADCGIAIAAIwhiAAAAGMIIgAAwBiCCAAAMIYgAgAAjCGIAAAAY6aYLgCAfXI37zVdAgCMCiMiAADAGIIIAAAwhiACAACMIYgAAABjCCIAAMAYgggAADCGIAIAAIwhiAAAAGMIIgAAwBiCCAAAMIYgAgAAjCGIAAAAYwgiAADAGIIIAAAwhiACAACMIYgAAABjCCIAAMAYgggAADCGIAIAAIwhiAAAAGNsCyJ79uzRwoUL5Xa7deedd+rs2bOSJL/frw0bNqimpkZlZWXq7Oy0a5cAACDJ2RJEjh49qr179+qdd97Rzp07tX//fj322GMKh8MqLi7WXXfdpXXr1unee+9VaWmpHbsEAAApYIodG2lpaVF1dbWmTp2qxYsXq62tTbt371ZDQ4M6OjpUUFAgSSoqKlJJSYlaW1u1dOlSO3YNAACSmC0jImVlZZo6dWpkfu7cubr++uvl8/nk9XrldDolSenp6fJ6vWpsbLRjtwAAIMmNy82qBw4c0Pr169XT0yO32x21zuPxqLu7e9i/7evrUygUipoAAEBqsj2InDhxQgMDAyopKZHT6YyMhlwWDocVDoeH/fvKykp5PJ7IlJOTY3eJAAAgQdgaRAYHB1VVVaXq6mpJUmZm5pARjWAwqOzs7GG3UV5ermAwGJm6urrsLBEAACQQW4PI9u3btWnTJk2fPl2StHz5cvn9flmWJUnq7+9XIBBQYWHhsNtwuVxyu91REwAASE22fGtGkqqqqpSXl6dz587p3LlzOnr0qAYGBpSVlaXm5matWLFCTU1Nys3NVX5+vl27BQAAScyWIPLWW29p06ZNkZEPSZo2bZpOnjyp+vp6VVRUqL29XT6fT3V1dXI4HHbsFgAAJDmH9e30kIBCoZA8Ho+CwSCXaYCryN2813QJoxbIWGO6hMlhS9B0BZhk4v385lkzAADAGNvuEQEAQJK0xRNHG0Zo8A1GRAAAgDEEEQAAYAxBBAAAGMM9IkCSSMZvxADA1TAiAgAAjCGIAAAAYwgiAADAGIIIAAAwhiACAACMIYgAAABjCCIAAMAYgggAADCGHzQDAMQvngfaAaPAiAgAADCGIAIAAIzh0gwATAbxXFLZEhz/OoArMCICAACMYUQEAPANbkSFAYyIAAAAYwgiAADAGIIIAAAwhiACAACM4WZVIAHkbt5rugQAiWQSfd2aEREAAGAMIyIAgIln11eFU2RUYDJjRAQAABjDiAgAILVNovstkhEjIgAAwBiCCAAAMIZLM0AKCWSsuWqb3IuvT0AlwARJtOfjJFo9SYAREQAAYAwjIhgeN3gBQHJLgvdxRkQAAIAxjIgASSKe+z/s2g73kQBJIEXuR5mQERG/368NGzaopqZGZWVl6uzsnIjdAgCABDfuIyLhcFjFxcXasWOHioqKtGDBApWWlsrn8433rgEAQIIb9yDS0NCgjo4OFRQUSJKKiopUUlKi1tZWLV26dLx3n3oS7cajRKtngvHUXCBFpMhljmQ07kHE5/PJ6/XK6XRKktLT0+X1etXY2BgziPT19amvry8yHwx+8yEWCoXGu9Tk0GddvY1d/63i2Vc8Urjvwn0XJmxfIYdN/RGHVD0uADGM03v05c9tyxr5NT7uQaSnp0dutztqmcfjUXd3d8z2lZWV2rp165DlOTk541JfSvp9giX7RKsnSU3sf8W7JmxPnB2AYeP8Hn3+/Hl5PMPvY9yDiNPpjIyGXBYOhxUOh2O2Ly8v16ZNm6Lanj17VrNnz5bD4bCtrlAopJycHHV1dQ0JSqki1Y+R40t+qX6MHF/yS/VjHM/jsyxL58+fV1ZW1ojtxj2IZGZmqqWlJWpZMBhUdnZ2zPYul0sulytq2cyZM8erPLnd7pQ8ub4t1Y+R40t+qX6MHF/yS/VjHK/jG2kk5LJx//ruypUr5ff7I9eI+vv7FQgEVFhYON67BgAACW7cg8iyZcuUlZWl5uZmSVJTU5Nyc3OVn58/3rsGAAAJbtwvzaSlpam+vl4VFRVqb2+Xz+dTXV2drfd7jIXL5dJTTz015DJQKkn1Y+T4kl+qHyPHl/xS/RgT4fgc1tW+VwMAADBOeOgdAAAwhiACAACMIYgAAABjJlUQOXnypOkSEtrx48dNlzBqlmUlZd3jhf8WySOZz91krdtOydR/vb29kcelJKJJEUT8fr/WrFmjNWvWDFlXVVWl5557To8//rj+8Ic/jLiNDRs2qKamRmVlZers7BzPksfs4YcflsPhiJruuiv2z3VblqW8vLxIu3vuuWeCqx2bXbt2RWpOS0tTR0dHzHbJ0mdXCoVCuvvuuzVz5kx5vV69+eabw7ZNhj6Mtx/ifS0moj179mjhwoVyu9268847dfbs2Zjt4j13E02851ky9uHp06c1derUIe+b//73v4e0Tbb+C4fD2rlzp/Ly8nTw4MHI8tG8N05In1qTwBdffGE9+OCD1k9/+tOo5W+99ZZVUFAQmf/JT35i/eUvfxny94ODg9Z3v/td66OPPrIsy7L27dtn/fjHPx7Xmsfi0qVL1kMPPWQdPnzY8vv9lt/vtx555BFr586dMdvv3bvXqq6utlpbW63W1lbrzJkzE1zx2KxduzZSc3t7e8w2ydJnsTz22GPWn//8Z+uTTz6x7r77bsvpdFpHjx6N2TbR+zDefoj3tZiIvvjiC2vdunXWp59+ar399tvWrFmzrPvvvz9m23jO3UQUz3mWrH1YW1trvffee5H3zM8++8xatGhRzLbJ1n89PT1WIBCwJFmNjY2WZY3uvXGi+nRSBBHLsqynnnpqSBDJz8+3nnnmmcj8c889Z61evXrI377//vtWRkaGdenSJcuyLGtgYMCaNm2a9fe//31cax6tr776ygqFQlHLCgoKrFOnTsVs//Of/9x66aWXrM7OzokozxZ/+9vfrNtuu83at2+f1d/fP2y7ZOmzK126dMl65ZVXIvO9vb2Wy+Wy3nrrrZjtE70P4+2HeF+LieiVV16x+vr6IvNPPfVUzA+yeM/dRBTPeZasfXj8+PGo+X379lkbN24c0i6Z++/bQWQ0740T1aeT4tJMLJcuXdKBAwe0cOHCyLK8vDw1NjYOaevz+eT1eiMP70tPT5fX643Z1qTvfOc7mjFjRmT+1KlTkqRrr712SNvz58+rr69PTzzxhBYsWKCNGzde9VHNieBf//qXTp06pVWrVumGG27QgQMHYrZLlj67ktPpVFlZWWQ+IyNDHo9H119//ZC2ydCH8fTDaF6LiaisrExTp06NzM+dOzdmf8V77iaaeM6zZO7DzMzMqPl3331XxcXFQ9ola/9dKd73xons00kbRP7zn/9oYGAg6iE/Ho9Hvb29OnfuXFTbnp6eIQ8D8ng86u7unpBax2rPnj26/fbbY66bMWOGPvroI508eVJVVVV6+eWX9cILL0xwhaP3wAMP6JNPPtGhQ4c0b9483XHHHert7R3SLln77ErHjh1Tdna2fvSjHw1Zlwx9GE8/jOa1mAwOHDig9evXD1ke77mbaOI5z1KpD/fv368VK1YMWZ6s/XeleN8bJ7JPJ20QuZwGL/9T+ubGnm//89ttv93ucpsr2yWa+vp6/eIXvxixjdPp1MaNG7V582a9/vrrE1TZ/27hwoXas2ePLl68qP379w9Zn6x9dqWXX35ZNTU1I7ZJ5D6Mpx9G81pMdCdOnNDAwIBKSkqGbXO1czdRjXSepUof/vOf/9TixYujRriulKz9d1m8740T2adJHUROnDihefPmDTv95je/GfZvZ8+eralTpyoUCkWWBYNBZWRkaPbs2VFtMzMzo9pdbpudnW3vAV3FaI7366+/lt/v16JFi+La9i9/+UvjX+8abX/OmTNHy5Yti1l3ovTZlUZzjI2Njbr55pt1yy23xLXtROjDK8XTD6N5LSaywcFBVVVVqbq6+qptRzp3E12s8yxV+vDdd9+96v+8Scndf/G+N05kn477Q+/GU2Zm5ph/G8ThcGjFihX64osvIss+++wzrVy5ckjblStXatu2bbIsSw6HQ/39/QoEAiosLBxr6WMymuNtaGjQbbfdFve2BwcHo64FmjCW/hyu7kTpsyvFe4yHDx/W0aNHtXbtWknSwMCA0tPTR3xYZCL04ZXi6YfRvBYT2fbt27Vp0yZNnz5d0jfX2Ef6P+tE7K94xKo7Vfrwvffe04cffhhX22Ttv3jfGyeyT5N6RGQ0Yg0l3Xfffdq7d29k/oMPPtCvf/1rSVJ7e7u2bt0qSVq2bJmysrLU3NwsSWpqalJubq7y8/MnoPKxiXVZ5tvH1NzcrF27dkVuOqupqdGjjz464XWOxsDAgLZt2ya/3y9JOnTokFwul37wgx9ISv4+u+zkyZN66aWXdOutt+rw4cNqa2tTZWWlpOTrw5H6YevWrWpvb5c08msxGVRVVSkvL0/nzp3T4cOH9f777+uDDz6IHOPVzt1ENtJ5lkp96Pf7NX369Kj/20+F/rvys+9q741G+tT27+EkoI8//ti6+eabrdmzZ1tvv/125GtLlmVZzz77rPXkk09av/vd76yKiorI8t27d1u33HJLZP7IkSPWvffea/3xj3+07r77buvIkSMTegyjMTAwYHm93qjjtKzoY/rTn/5kzZ4921q9erX19NNPR77alcguXrxoLV++3LrmmmusJ554wtq2bZv11VdfRdYnc59d1tvbay1ZssSSFDU9+eSTlmUlZx8O1w9Lliyx3n777Ui74V6Lie7NN9+0HA5HVH9NmzbNCoVCkWO82rmbyEY6z1KlDy3LsrZv3249//zzUcuSvf9OnTplPfvss5Yka+3atdahQ4csyxr5vdFEnzosK8G+7wcAACaNSXNpBgAAJB6CCAAAMIYgAgAAjCGIAAAAYwgiAADAGIIIAAAwhiACAACMIYgAAABjCCIAAMAYgggAADCGIAIAAIz5P4fZqcDMAXuYAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plt.hist(randomN_input.detach().cpu().numpy()[:,0], bins = 50, range=(-10,10))\n", + "plt.hist(output.detach().cpu().numpy()[:,0], bins = 50, range=(-10,10))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 280, + "id": "72b15ab5-9031-444e-886d-b9f47859d01e", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import torch\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "\n", + "def MMD(x, y, kernel):\n", + " \"\"\"Emprical maximum mean discrepancy. The lower the result\n", + " the more evidence that distributions are the same.\n", + "\n", + " Args:\n", + " x: first sample, distribution P\n", + " y: second sample, distribution Q\n", + " kernel: kernel type such as \"multiscale\" or \"rbf\"\n", + " \"\"\"\n", + " xx, yy, zz = torch.mm(x, x.t()), torch.mm(y, y.t()), torch.mm(x, y.t())\n", + " rx = (xx.diag().unsqueeze(0).expand_as(xx))\n", + " ry = (yy.diag().unsqueeze(0).expand_as(yy))\n", + "\n", + " dxx = rx.t() + rx - 2. * xx # Used for A in (1)\n", + " dyy = ry.t() + ry - 2. * yy # Used for B in (1)\n", + " dxy = rx.t() + ry - 2. * zz # Used for C in (1)\n", + "\n", + " XX, YY, XY = (torch.zeros(xx.shape).to(device),\n", + " torch.zeros(xx.shape).to(device),\n", + " torch.zeros(xx.shape).to(device))\n", + "\n", + " if kernel == \"multiscale\":\n", + "\n", + " bandwidth_range = [0.2, 0.5, 0.9, 1.3]\n", + " for a in bandwidth_range:\n", + " XX += a**2 * (a**2 + dxx)**-1\n", + " YY += a**2 * (a**2 + dyy)**-1\n", + " XY += a**2 * (a**2 + dxy)**-1\n", + "\n", + " if kernel == \"rbf\":\n", + "\n", + " bandwidth_range = [10, 15, 20, 50]\n", + " for a in bandwidth_range:\n", + " XX += torch.exp(-0.5*dxx/a)\n", + " YY += torch.exp(-0.5*dyy/a)\n", + " XY += torch.exp(-0.5*dxy/a)\n", + "\n", + "\n", + " return torch.mean(XX + YY - 2. * XY)" + ] + }, + { + "cell_type": "code", + "execution_count": 281, + "id": "7c80b247-3722-4f3d-ae88-d4ee8b7b2d55", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + " x = torch.rand(bs, input_dim)\n", + " y = torch.randn(bs, input_dim)" + ] + }, + { + "cell_type": "code", + "execution_count": 283, + "id": "6df76ddf-876c-4eca-96c4-ee5dee57443e", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(0.0637, device='cuda:0')" + ] + }, + "execution_count": 283, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "MMD(x.cuda(), y.cuda(), kernel='rbf')" + ] + }, + { + "cell_type": "markdown", + "id": "a93e53ef-f6d6-4bb9-919c-2173e2634ac2", + "metadata": {}, + "source": [ + "## WITH A VARIATIONAL AUTOENCODER" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d6a6b88d-e9ea-4ee1-ad8e-bfe81442864e", + "metadata": {}, + "outputs": [], + "source": [ + "import torch.nn as nn\n", + "import torch\n", + "class CondVAE(nn.Module):\n", + " def __init__(self, dim_input, latent_dim=10, size=[150,150]):\n", + " super(CondVAE, self).__init__()\n", + " \n", + " self.latent_dim = latent_dim\n", + "\n", + " # Encoder\n", + " self.encoder = nn.Sequential(\n", + " nn.Linear(in_features=dim_input, out_features=size[0]*size[1]),\n", + " nn.Unflatten(1, (1, size[0], size[1], 60)),\n", + " nn.ReLU(),\n", + " nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1), \n", + " nn.ReLU(),\n", + " nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),\n", + " nn.ReLU(),\n", + " nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),\n", + " nn.ReLU(),\n", + " nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),\n", + " nn.ReLU(),\n", + " nn.Flatten()\n", + " )\n", + " \n", + " self.fc_mu = nn.Linear(16384, latent_dim)\n", + " self.fc_logvar = nn.Linear(16384, latent_dim)\n", + "\n", + " # Decoder\n", + " self.decoder = nn.Sequential(\n", + " nn.Linear(latent_dim + 8, 16384), \n", + " nn.Unflatten(1, (256, 8, 8)),\n", + " nn.ReLU(),\n", + " nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1),\n", + " nn.ReLU(),\n", + " nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1),\n", + " nn.ReLU(),\n", + " nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=0),\n", + " nn.ReLU(),\n", + " nn.ConvTranspose2d(32, 1, kernel_size=2, stride=1, padding=0),\n", + " nn.Sigmoid()\n", + " )\n", + "\n", + " def encode(self, x):\n", + " x = self.encoder(x)\n", + " mu = self.fc_mu(x)\n", + " log_var = self.fc_logvar(x)\n", + "\n", + " return mu, log_var\n", + " \n", + " def decode(self, z, properties):\n", + " # Concatenate the sampling (latent distribution) + embedding -> samples conditioned on both the input data and the specified label\n", + " #print(properties.shape, z.shape)\n", + " zcomb = torch.concat((z, properties), 1)\n", + " #print(zcomb.shape)\n", + " \n", + " return self.decoder(zcomb) \n", + " \n", + " def sampling(self, mu, log_var):\n", + " # calculate standard deviation\n", + " std = log_var.mul(0.5).exp_()\n", + " \n", + " # create noise tensor of same size as std to add to the latent vector\n", + " eps = torch.cuda.FloatTensor(std.size()).normal_()\n", + " \n", + " # multiply eps with std to scale the random noise according to the learned distribution + add combined\n", + " return eps.mul(std).add_(mu) # return z sample \n", + "\n", + " def forward(self, x, properties):\n", + " mu, log_var = self.encode(x)\n", + " z = self.sampling(mu, log_var)\n", + " #print(z.shape)\n", + "\n", + " return self.decode(z, properties), mu, log_var\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "DLenv2", + "language": "python", + "name": "dlenv2" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}