File size: 216,093 Bytes
83586b8 |
|
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "982e76f5",
"metadata": {
"execution": {
"iopub.execute_input": "2024-03-22T16:54:38.420949Z",
"iopub.status.busy": "2024-03-22T16:54:38.420627Z",
"iopub.status.idle": "2024-03-22T16:54:38.453783Z",
"shell.execute_reply": "2024-03-22T16:54:38.452888Z"
},
"papermill": {
"duration": 0.04789,
"end_time": "2024-03-22T16:54:38.455806",
"exception": false,
"start_time": "2024-03-22T16:54:38.407916",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": [
"# joblib is imported so its parallel backend can be switched here if needed,\n",
"# e.g. joblib.parallel_backend(\"threading\"); the default backend is used as-is.\n",
"import joblib"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "675f0b41",
"metadata": {
"execution": {
"iopub.execute_input": "2024-03-22T16:54:38.481275Z",
"iopub.status.busy": "2024-03-22T16:54:38.480905Z",
"iopub.status.idle": "2024-03-22T16:54:38.487727Z",
"shell.execute_reply": "2024-03-22T16:54:38.486854Z"
},
"papermill": {
"duration": 0.021698,
"end_time": "2024-03-22T16:54:38.489673",
"exception": false,
"start_time": "2024-03-22T16:54:38.467975",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": [
"# Kaggle environment setup, kept disabled. Uncomment the single-\"#\" lines to\n",
"# reinstall ml-utility-loss from GitHub; the \"# #\" lines are optional alternatives\n",
"# (first-time clone / install including dependencies).\n",
"# %cd /kaggle/working\n",
"# # !git clone https://github.com/R-N/ml-utility-loss --depth=1 --single-branch --branch=main\n",
"# %cd ml-utility-loss\n",
"# !git pull\n",
"# # !pip install .\n",
"# !pip install . --no-deps --force-reinstall --upgrade"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "5ae30f5c",
"metadata": {
"execution": {
"iopub.execute_input": "2024-03-22T16:54:38.512890Z",
"iopub.status.busy": "2024-03-22T16:54:38.512651Z",
"iopub.status.idle": "2024-03-22T16:54:38.516571Z",
"shell.execute_reply": "2024-03-22T16:54:38.515766Z"
},
"papermill": {
"duration": 0.017861,
"end_time": "2024-03-22T16:54:38.518464",
"exception": false,
"start_time": "2024-03-22T16:54:38.500603",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"# Small default figure size for every plot drawn in this notebook.\n",
"plt.rcParams.update({\"figure.figsize\": [3, 3]})"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "9f42c810",
"metadata": {
"execution": {
"iopub.execute_input": "2024-03-22T16:54:38.541803Z",
"iopub.status.busy": "2024-03-22T16:54:38.541370Z",
"iopub.status.idle": "2024-03-22T16:54:38.545217Z",
"shell.execute_reply": "2024-03-22T16:54:38.544371Z"
},
"executionInfo": {
"elapsed": 678,
"status": "ok",
"timestamp": 1696841022168,
"user": {
"displayName": "Rizqi Nur",
"userId": "09644007964068789560"
},
"user_tz": -420
},
"id": "ns5hFcVL2yvs",
"papermill": {
"duration": 0.017736,
"end_time": "2024-03-22T16:54:38.547270",
"exception": false,
"start_time": "2024-03-22T16:54:38.529534",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": [
"# Datasets covered by this evaluation series.\n",
"datasets = [\"insurance\", \"treatment\", \"contraceptive\"]\n",
"\n",
"# Output directory for the study, relative to the run directory.\n",
"study_dir = \"./\""
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "85d0c8ce",
"metadata": {
"execution": {
"iopub.execute_input": "2024-03-22T16:54:38.571177Z",
"iopub.status.busy": "2024-03-22T16:54:38.570870Z",
"iopub.status.idle": "2024-03-22T16:54:38.576291Z",
"shell.execute_reply": "2024-03-22T16:54:38.575410Z"
},
"papermill": {
"duration": 0.019685,
"end_time": "2024-03-22T16:54:38.578251",
"exception": false,
"start_time": "2024-03-22T16:54:38.558566",
"status": "completed"
},
"tags": [
"parameters"
]
},
"outputs": [],
"source": [
"#Parameters\n",
"# Defaults for this run. Papermill inserts an \"injected-parameters\" cell right\n",
"# after this one, so any value it injects overrides the defaults below.\n",
"import os\n",
"\n",
"# Paths\n",
"path_prefix = \"../../../../\"\n",
"# NOTE: dataset_dir is derived here from the default path_prefix; because injection\n",
"# happens after this cell, injecting path_prefix alone does not change dataset_dir.\n",
"dataset_dir = os.path.join(path_prefix, \"ml-utility-loss/datasets\")\n",
"folder = \"eval\"\n",
"path = None\n",
"\n",
"# Dataset and models\n",
"dataset_name = \"treatment\"\n",
"model_name = \"ml_utility_2\"\n",
"models = [\"tvae\", \"realtabformer\", \"lct_gan\", \"tab_ddpm_concat\"]\n",
"single_model = \"lct_gan\"\n",
"param_index = 0\n",
"\n",
"# Training options\n",
"random_seed = 42\n",
"gp = True\n",
"gp_multiply = True\n",
"allow_same_prediction = True\n",
"\n",
"# Run options\n",
"debug = False\n",
"log_wandb = False"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "e2d3d897",
"metadata": {
"execution": {
"iopub.execute_input": "2024-03-22T16:54:38.603452Z",
"iopub.status.busy": "2024-03-22T16:54:38.603196Z",
"iopub.status.idle": "2024-03-22T16:54:38.608013Z",
"shell.execute_reply": "2024-03-22T16:54:38.607193Z"
},
"papermill": {
"duration": 0.019844,
"end_time": "2024-03-22T16:54:38.609908",
"exception": false,
"start_time": "2024-03-22T16:54:38.590064",
"status": "completed"
},
"tags": [
"injected-parameters"
]
},
"outputs": [],
"source": [
"# Parameters\n",
"# Injected by papermill for this run; these override the defaults of the parameters cell.\n",
"# NOTE(review): `dataset` is not declared in the parameters cell and no visible cell\n",
"# reads it; the notebook uses `dataset_name`.\n",
"dataset = \"contraceptive\"\n",
"dataset_name = \"contraceptive\"\n",
"single_model = \"tvae\"\n",
"gp = False\n",
"gp_multiply = False\n",
"random_seed = 42\n",
"debug = False\n",
"folder = \"eval\"\n",
"path_prefix = \"../../../../\"\n",
"path = \"eval/contraceptive/tvae/42\"\n",
"param_index = 0\n",
"allow_same_prediction = True\n",
"log_wandb = False\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bd7c02d6",
"metadata": {
"papermill": {
"duration": 0.011211,
"end_time": "2024-03-22T16:54:38.632439",
"exception": false,
"start_time": "2024-03-22T16:54:38.621228",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 7,
"id": "5f45b1d0",
"metadata": {
"execution": {
"iopub.execute_input": "2024-03-22T16:54:38.655861Z",
"iopub.status.busy": "2024-03-22T16:54:38.655594Z",
"iopub.status.idle": "2024-03-22T16:54:38.664480Z",
"shell.execute_reply": "2024-03-22T16:54:38.663684Z"
},
"executionInfo": {
"elapsed": 7,
"status": "ok",
"timestamp": 1696841022169,
"user": {
"displayName": "Rizqi Nur",
"userId": "09644007964068789560"
},
"user_tz": -420
},
"id": "UdvXYv3c3LXy",
"papermill": {
"duration": 0.022742,
"end_time": "2024-03-22T16:54:38.666331",
"exception": false,
"start_time": "2024-03-22T16:54:38.643589",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/kaggle/working\n",
"/kaggle/working/eval/contraceptive/tvae/42\n"
]
}
],
"source": [
"from pathlib import Path\n",
"import os\n",
"\n",
"%cd /kaggle/working/\n",
"\n",
"# The run directory defaults to <folder>/<dataset_name>/<single_model>/<random_seed>.\n",
"# os.path.join only accepts str/path components, so the integer seed is converted to str.\n",
"if path is None:\n",
"    path = os.path.join(folder, dataset_name, single_model, str(random_seed))\n",
"Path(path).mkdir(parents=True, exist_ok=True)\n",
"\n",
"%cd {path}"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "f85bf540",
"metadata": {
"execution": {
"iopub.execute_input": "2024-03-22T16:54:38.690470Z",
"iopub.status.busy": "2024-03-22T16:54:38.690185Z",
"iopub.status.idle": "2024-03-22T16:54:40.726590Z",
"shell.execute_reply": "2024-03-22T16:54:40.725650Z"
},
"papermill": {
"duration": 2.050871,
"end_time": "2024-03-22T16:54:40.728604",
"exception": false,
"start_time": "2024-03-22T16:54:38.677733",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": [
"from ml_utility_loss.util import seed\n",
"\n",
"# Tag the model name with the single role model (study_name is built from model_name later).\n",
"# The endswith guard keeps the cell idempotent: re-running it does not append the suffix twice.\n",
"if single_model and not model_name.endswith(f\"_{single_model}\"):\n",
"    model_name = f\"{model_name}_{single_model}\"\n",
"if random_seed is not None:\n",
"    seed(random_seed)\n",
"    # Report the seed value itself, not the `seed` helper function.\n",
"    print(\"Set seed to\", random_seed)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "8489feae",
"metadata": {
"execution": {
"iopub.execute_input": "2024-03-22T16:54:40.760710Z",
"iopub.status.busy": "2024-03-22T16:54:40.760306Z",
"iopub.status.idle": "2024-03-22T16:54:40.772850Z",
"shell.execute_reply": "2024-03-22T16:54:40.771812Z"
},
"papermill": {
"duration": 0.03277,
"end_time": "2024-03-22T16:54:40.774761",
"exception": false,
"start_time": "2024-03-22T16:54:40.741991",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": [
"import json\n",
"import os\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"# Raw dataset (CSV) plus its metadata (JSON with task, target and feature groups).\n",
"df = pd.read_csv(os.path.join(dataset_dir, dataset_name + \".csv\"))\n",
"with open(os.path.join(dataset_dir, dataset_name + \".json\")) as f:\n",
"    info = json.load(f)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "debcc684",
"metadata": {
"execution": {
"iopub.execute_input": "2024-03-22T16:54:40.799242Z",
"iopub.status.busy": "2024-03-22T16:54:40.798929Z",
"iopub.status.idle": "2024-03-22T16:54:40.806026Z",
"shell.execute_reply": "2024-03-22T16:54:40.805321Z"
},
"executionInfo": {
"elapsed": 6,
"status": "ok",
"timestamp": 1696841022169,
"user": {
"displayName": "Rizqi Nur",
"userId": "09644007964068789560"
},
"user_tz": -420
},
"id": "Vrl2QkoV3o_8",
"papermill": {
"duration": 0.021718,
"end_time": "2024-03-22T16:54:40.808141",
"exception": false,
"start_time": "2024-03-22T16:54:40.786423",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": [
"# Unpack the dataset metadata.\n",
"task, target = info[\"task\"], info[\"target\"]\n",
"cat_features = info[\"cat_features\"]\n",
"mixed_features = info[\"mixed_features\"]\n",
"longtail_features = info[\"longtail_features\"]\n",
"integer_features = info[\"integer_features\"]\n",
"\n",
"# Fixed 80/20 hold-out split; random_state is pinned to 42 independently of random_seed.\n",
"test = df.sample(frac=0.2, random_state=42)\n",
"train = df.drop(index=test.index)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "7538184a",
"metadata": {
"execution": {
"iopub.execute_input": "2024-03-22T16:54:40.833629Z",
"iopub.status.busy": "2024-03-22T16:54:40.833384Z",
"iopub.status.idle": "2024-03-22T16:54:40.928573Z",
"shell.execute_reply": "2024-03-22T16:54:40.927760Z"
},
"executionInfo": {
"elapsed": 6,
"status": "ok",
"timestamp": 1696841022169,
"user": {
"displayName": "Rizqi Nur",
"userId": "09644007964068789560"
},
"user_tz": -420
},
"id": "TilUuFk9vqMb",
"papermill": {
"duration": 0.110635,
"end_time": "2024-03-22T16:54:40.930702",
"exception": false,
"start_time": "2024-03-22T16:54:40.820067",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": [
"import ml_utility_loss.synthesizers.tab_ddpm.params as TAB_DDPM_PARAMS\n",
"import ml_utility_loss.synthesizers.lct_gan.params as LCT_GAN_PARAMS\n",
"import ml_utility_loss.synthesizers.realtabformer.params as RTF_PARAMS\n",
"from ml_utility_loss.synthesizers.realtabformer.params.default import GPT2_PARAMS, REALTABFORMER_PARAMS\n",
"from ml_utility_loss.util import filter_dict_2, filter_dict\n",
"\n",
"# BEST hyperparameters of each synthesizer for this dataset.\n",
"tab_ddpm_params = getattr(TAB_DDPM_PARAMS, dataset_name).BEST\n",
"lct_gan_params = getattr(LCT_GAN_PARAMS, dataset_name).BEST\n",
"lct_ae_params = filter_dict_2(lct_gan_params, LCT_GAN_PARAMS.default.AE_PARAMS)\n",
"rtf_params = filter_dict(getattr(RTF_PARAMS, dataset_name).BEST, REALTABFORMER_PARAMS)\n",
"\n",
"# Preprocessing options forwarded to DataPreprocessor.\n",
"lct_ae_embedding_size = lct_gan_params[\"embedding_size\"]\n",
"tab_ddpm_normalization = \"quantile\"\n",
"tab_ddpm_cat_encoding = tab_ddpm_params[\"cat_encoding\"]  # alternative: \"one-hot\"\n",
"tab_ddpm_y_policy = \"default\"\n",
"tab_ddpm_is_y_cond = True"
]
{
"cell_type": "code",
"execution_count": 12,
"id": "cca61838",
"metadata": {
"execution": {
"iopub.execute_input": "2024-03-22T16:54:40.958708Z",
"iopub.status.busy": "2024-03-22T16:54:40.958059Z",
"iopub.status.idle": "2024-03-22T16:54:45.689430Z",
"shell.execute_reply": "2024-03-22T16:54:45.688570Z"
},
"executionInfo": {
"elapsed": 3113,
"status": "ok",
"timestamp": 1696841025277,
"user": {
"displayName": "Rizqi Nur",
"userId": "09644007964068789560"
},
"user_tz": -420
},
"id": "7Abt8nStvr9Z",
"papermill": {
"duration": 4.747577,
"end_time": "2024-03-22T16:54:45.691899",
"exception": false,
"start_time": "2024-03-22T16:54:40.944322",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-03-22 16:54:43.265606: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
"2024-03-22 16:54:43.265666: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
"2024-03-22 16:54:43.267291: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n"
]
}
],
"source": [
"from ml_utility_loss.loss_learning.estimator.pipeline import load_lct_ae\n",
"\n",
"# NOTE(review): the pretrained LCT autoencoder is loaded and then immediately\n",
"# discarded, so DataPreprocessor receives lct_ae=None. Presumably this deliberately\n",
"# disables the pretrained AE — confirm; if so, the load call is wasted work.\n",
"lct_ae = load_lct_ae(\n",
"    dataset_name=dataset_name,\n",
"    model_dir=os.path.join(path_prefix, \"ml-utility-loss/models\"),\n",
"    model_name=\"lct_ae\",\n",
"    df_name=\"df\",\n",
")\n",
"lct_ae = None"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "6f83b7b6",
"metadata": {
"execution": {
"iopub.execute_input": "2024-03-22T16:54:45.718090Z",
"iopub.status.busy": "2024-03-22T16:54:45.716978Z",
"iopub.status.idle": "2024-03-22T16:54:45.722995Z",
"shell.execute_reply": "2024-03-22T16:54:45.722101Z"
},
"papermill": {
"duration": 0.020744,
"end_time": "2024-03-22T16:54:45.725071",
"exception": false,
"start_time": "2024-03-22T16:54:45.704327",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": [
"from ml_utility_loss.loss_learning.estimator.pipeline import load_rtf_embed\n",
"\n",
"# REaLTabFormer embedding from the \"best-disc-model\" checkpoint; passed to DataPreprocessor.\n",
"rtf_embed = load_rtf_embed(\n",
"    dataset_name=dataset_name,\n",
"    model_name=\"realtabformer\",\n",
"    model_dir=os.path.join(path_prefix, \"ml-utility-loss/models\"),\n",
"    ckpt_type=\"best-disc-model\",\n",
"    df_name=\"df\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "0026de74",
"metadata": {
"execution": {
"iopub.execute_input": "2024-03-22T16:54:45.751786Z",
"iopub.status.busy": "2024-03-22T16:54:45.751472Z",
"iopub.status.idle": "2024-03-22T16:54:54.134274Z",
"shell.execute_reply": "2024-03-22T16:54:54.133246Z"
},
"executionInfo": {
"elapsed": 20137,
"status": "ok",
"timestamp": 1696841045408,
"user": {
"displayName": "Rizqi Nur",
"userId": "09644007964068789560"
},
"user_tz": -420
},
"id": "tbaguWxAvtPi",
"papermill": {
"duration": 8.399093,
"end_time": "2024-03-22T16:54:54.136682",
"exception": false,
"start_time": "2024-03-22T16:54:45.737589",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/conda/lib/python3.10/site-packages/sklearn/mixture/_base.py:274: ConvergenceWarning: Initialization 1 did not converge. Try different init parameters, or increase max_iter, tol or check for degenerate data.\n",
" warnings.warn(\n",
"/opt/conda/lib/python3.10/site-packages/sklearn/mixture/_base.py:119: ConvergenceWarning: Number of distinct clusters (4) found smaller than n_clusters (10). Possibly due to duplicate points in X.\n",
" .fit(X)\n",
"/opt/conda/lib/python3.10/site-packages/sklearn/mixture/_base.py:119: ConvergenceWarning: Number of distinct clusters (4) found smaller than n_clusters (10). Possibly due to duplicate points in X.\n",
" .fit(X)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/conda/lib/python3.10/site-packages/sklearn/mixture/_base.py:274: ConvergenceWarning: Initialization 1 did not converge. Try different init parameters, or increase max_iter, tol or check for degenerate data.\n",
" warnings.warn(\n",
"/opt/conda/lib/python3.10/site-packages/sklearn/mixture/_base.py:119: ConvergenceWarning: Number of distinct clusters (4) found smaller than n_clusters (10). Possibly due to duplicate points in X.\n",
" .fit(X)\n"
]
}
],
"source": [
"from ml_utility_loss.loss_learning.estimator.preprocessing import DataPreprocessor\n",
"\n",
"# One preprocessor configured for all synthesizer representations.\n",
"# NOTE(review): it is fit on the full df, including the `test` split sampled earlier — confirm intended.\n",
"preprocessor = DataPreprocessor(\n",
"    task,\n",
"    target=target,\n",
"    # feature groups from the dataset metadata\n",
"    cat_features=cat_features,\n",
"    mixed_features=mixed_features,\n",
"    longtail_features=longtail_features,\n",
"    integer_features=integer_features,\n",
"    # LCT-GAN autoencoder\n",
"    lct_ae_embedding_size=lct_ae_embedding_size,\n",
"    lct_ae_params=lct_ae_params,\n",
"    lct_ae=lct_ae,\n",
"    # TabDDPM\n",
"    tab_ddpm_normalization=tab_ddpm_normalization,\n",
"    tab_ddpm_cat_encoding=tab_ddpm_cat_encoding,\n",
"    tab_ddpm_y_policy=tab_ddpm_y_policy,\n",
"    tab_ddpm_is_y_cond=tab_ddpm_is_y_cond,\n",
"    # REaLTabFormer\n",
"    realtabformer_embedding=rtf_embed,\n",
"    realtabformer_params=rtf_params,\n",
")\n",
"preprocessor.fit(df)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "a9c9b110",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"execution": {
"iopub.execute_input": "2024-03-22T16:54:54.164061Z",
"iopub.status.busy": "2024-03-22T16:54:54.163711Z",
"iopub.status.idle": "2024-03-22T16:54:54.171531Z",
"shell.execute_reply": "2024-03-22T16:54:54.170687Z"
},
"executionInfo": {
"elapsed": 13,
"status": "ok",
"timestamp": 1696841045411,
"user": {
"displayName": "Rizqi Nur",
"userId": "09644007964068789560"
},
"user_tz": -420
},
"id": "OxUH_GBEv2qK",
"outputId": "76464c90-3baf-4bdc-a955-6f4fddc16b9c",
"papermill": {
"duration": 0.023874,
"end_time": "2024-03-22T16:54:54.173518",
"exception": false,
"start_time": "2024-03-22T16:54:54.149644",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"{'tvae': 46,\n",
" 'realtabformer': (24, 72, Embedding(72, 672), True),\n",
" 'lct_gan': 40,\n",
" 'tab_ddpm_concat': 10}"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Per-synthesizer adapter input sizes; the model's adapters are selected from these later.\n",
"preprocessor.adapter_sizes"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "3cb9ed90",
"metadata": {
"execution": {
"iopub.execute_input": "2024-03-22T16:54:54.198629Z",
"iopub.status.busy": "2024-03-22T16:54:54.198356Z",
"iopub.status.idle": "2024-03-22T16:54:54.203955Z",
"shell.execute_reply": "2024-03-22T16:54:54.203117Z"
},
"papermill": {
"duration": 0.020638,
"end_time": "2024-03-22T16:54:54.205880",
"exception": false,
"start_time": "2024-03-22T16:54:54.185242",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": [
"from ml_utility_loss.loss_learning.estimator.pipeline import load_dataset_3_factory\n",
"\n",
"# Factory returning (train_set, val_set) for a role model; it is called once params are known.\n",
"datasetsn = load_dataset_3_factory(\n",
"    dataset_name=dataset_name,\n",
"    dataset_dir=os.path.join(path_prefix, \"ml-utility-loss/\"),\n",
"    cache_dir=path_prefix,\n",
"    preprocessor=preprocessor,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "ad1eb833",
"metadata": {
"execution": {
"iopub.execute_input": "2024-03-22T16:54:54.230490Z",
"iopub.status.busy": "2024-03-22T16:54:54.230224Z",
"iopub.status.idle": "2024-03-22T17:02:21.663470Z",
"shell.execute_reply": "2024-03-22T17:02:21.662503Z"
},
"papermill": {
"duration": 447.460222,
"end_time": "2024-03-22T17:02:21.677822",
"exception": false,
"start_time": "2024-03-22T16:54:54.217600",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Caching in ../../../../contraceptive/_cache_aug_test/tvae/all inf False\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"../../../../ml-utility-loss/aug_test/contraceptive [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]\n",
"Caching in ../../../../contraceptive/_cache_bs_test/tvae/all inf False\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"../../../../ml-utility-loss/bs_test/contraceptive [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]\n",
"Caching in ../../../../contraceptive/_cache_synth_test/tvae/all inf False\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"../../../../ml-utility-loss/synthetics/contraceptive [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]\n",
"1050\n"
]
}
],
"source": [
"from ml_utility_loss.loss_learning.estimator.pipeline import load_dataset_4\n",
"\n",
"# Test set for the chosen single_model, cached under path_prefix.\n",
"test_set = load_dataset_4(\n",
"    dataset_name=dataset_name,\n",
"    model=single_model,\n",
"    dataset_dir=os.path.join(path_prefix, \"ml-utility-loss/\"),\n",
"    cache_dir=path_prefix,\n",
"    preprocessor=preprocessor,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "14ff8b40",
"metadata": {
"execution": {
"iopub.execute_input": "2024-03-22T17:02:21.706111Z",
"iopub.status.busy": "2024-03-22T17:02:21.705309Z",
"iopub.status.idle": "2024-03-22T17:02:22.031100Z",
"shell.execute_reply": "2024-03-22T17:02:22.030076Z"
},
"executionInfo": {
"elapsed": 588,
"status": "ok",
"timestamp": 1696841049215,
"user": {
"displayName": "Rizqi Nur",
"userId": "09644007964068789560"
},
"user_tz": -420
},
"id": "NgahtU1q9uLO",
"papermill": {
"duration": 0.342422,
"end_time": "2024-03-22T17:02:22.033338",
"exception": false,
"start_time": "2024-03-22T17:02:21.690916",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"{'Body': 'twin_encoder',\n",
" 'loss_balancer_meta': True,\n",
" 'loss_balancer_log': False,\n",
" 'loss_balancer_lbtw': False,\n",
" 'pma_skip_small': False,\n",
" 'isab_skip_small': False,\n",
" 'layer_norm': False,\n",
" 'pma_layer_norm': False,\n",
" 'attn_residual': True,\n",
" 'tf_n_layers_dec': False,\n",
" 'tf_isab_rank': 0,\n",
" 'tf_layer_norm': False,\n",
" 'tf_pma_start': -1,\n",
" 'head_n_seeds': 0,\n",
" 'tf_pma_low': 16,\n",
" 'dropout': 0,\n",
" 'combine_mode': 'diff_left',\n",
" 'tf_isab_mode': 'separate',\n",
" 'grad_loss_fn': <function torch.nn.functional.l1_loss(input: torch.Tensor, target: torch.Tensor, size_average: Optional[bool] = None, reduce: Optional[bool] = None, reduction: str = 'mean') -> torch.Tensor>,\n",
" 'single_model': True,\n",
" 'bias': True,\n",
" 'bias_final': True,\n",
" 'pma_ffn_mode': 'none',\n",
" 'patience': 10,\n",
" 'inds_init_mode': 'fixnorm',\n",
" 'grad_clip': 0.73,\n",
" 'gradient_penalty_mode': {'gradient_penalty': False,\n",
" 'calc_grad_m': False,\n",
" 'avg_non_role_model_m': False,\n",
" 'inverse_avg_non_role_model_m': False},\n",
" 'synth_data': 2,\n",
" 'bias_lr_mul': 1.0,\n",
" 'bias_weight_decay': 0.05,\n",
" 'head_activation': torch.nn.modules.activation.Softsign,\n",
" 'loss_balancer_beta': 0.67,\n",
" 'loss_balancer_r': 0.943,\n",
" 'tf_activation': torch.nn.modules.activation.Tanh,\n",
" 'dataset_size': 2048,\n",
" 'batch_size': 4,\n",
" 'epochs': 100,\n",
" 'lr_mul': 0.09,\n",
" 'n_warmup_steps': 100,\n",
" 'Optim': functools.partial(<class 'torch.optim.adamw.AdamW'>, amsgrad=True),\n",
" 'fixed_role_model': 'tvae',\n",
" 'd_model': 256,\n",
" 'attn_activation': torch.nn.modules.activation.PReLU,\n",
" 'tf_d_inner': 512,\n",
" 'tf_n_layers_enc': 3,\n",
" 'tf_n_head': 32,\n",
" 'tf_activation_final': ml_utility_loss.activations.LeakyHardtanh,\n",
" 'ada_d_hid': 1024,\n",
" 'ada_n_layers': 9,\n",
" 'ada_activation': torch.nn.modules.activation.Softsign,\n",
" 'ada_activation_final': ml_utility_loss.activations.LeakyHardsigmoid,\n",
" 'head_d_hid': 256,\n",
" 'head_n_layers': 9,\n",
" 'head_n_head': 32,\n",
" 'head_activation_final': ml_utility_loss.activations.LeakyHardsigmoid,\n",
" 'models': ['tvae'],\n",
" 'max_seconds': 3600,\n",
" 'tf_lora': False,\n",
" 'tf_num_inds': 128,\n",
" 'ada_n_seeds': 0,\n",
" 'gradient_penalty_kwargs': {'mag_loss': True,\n",
" 'mse_mag': False,\n",
" 'mag_corr': False,\n",
" 'seq_mag': False,\n",
" 'cos_loss': False,\n",
" 'mag_corr_kwargs': {'only_sign': False},\n",
" 'cos_loss_kwargs': {'only_sign': True, 'cos_matrix': False},\n",
" 'mse_mag_kwargs': {'target': 0.65, 'multiply': True, 'forgive_over': True}}}"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import ml_utility_loss.loss_learning.estimator.params2 as PARAMS\n",
"from ml_utility_loss.tuning import map_parameters\n",
"from ml_utility_loss.loss_learning.estimator.params.default import update_param_space, update_param_space_2\n",
"import wandb\n",
"\n",
"#\"\"\"\n",
"# Start from the dataset's search space and its param_index-th best configuration.\n",
"param_space = {\n",
"    **getattr(PARAMS, dataset_name).PARAM_SPACE,\n",
"}\n",
"params = {\n",
"    **getattr(PARAMS, dataset_name).BESTS[param_index],\n",
"}\n",
"# Gradient-penalty settings are controlled by the gp / gp_multiply parameters.\n",
"if gp:\n",
"    params[\"gradient_penalty_mode\"] = \"ALL\"\n",
"    params[\"mse_mag\"] = True\n",
"    if gp_multiply:\n",
"        params[\"mse_mag_multiply\"] = True\n",
"        #params[\"mse_mag_target\"] = 1.0\n",
"    else:\n",
"        params[\"mse_mag_multiply\"] = False\n",
"        #params[\"mse_mag_target\"] = 0.1\n",
"else:\n",
"    params[\"gradient_penalty_mode\"] = \"NONE\"\n",
"    params[\"mse_mag\"] = False\n",
"# Model selection: `models` sets the role models; a single_model run overrides it with\n",
"# just that model and fixes it as the role model (order matters: single_model wins).\n",
"params[\"single_model\"] = False\n",
"if models:\n",
"    params[\"models\"] = models\n",
"if single_model:\n",
"    params[\"fixed_role_model\"] = single_model\n",
"    params[\"single_model\"] = True\n",
"    params[\"models\"] = [single_model]\n",
"# Smaller batch for REaLTabFormer on the treatment dataset (NOTE(review): presumably memory-bound — confirm).\n",
"if params[\"fixed_role_model\"] == \"realtabformer\" and dataset_name == \"treatment\":\n",
"    params[\"batch_size\"] = 2\n",
"# Fixed training budget for evaluation runs; debug shortens training to 2 epochs.\n",
"params[\"max_seconds\"] = 3600\n",
"params[\"patience\"] = 10\n",
"params[\"epochs\"] = 100\n",
"if debug:\n",
"    params[\"epochs\"] = 2\n",
"# Save the raw (pre-mapping) params to the run directory, then map them through the\n",
"# search space into concrete values (classes/functions, as the displayed output shows).\n",
"with open(\"params.json\", \"w\") as f:\n",
"    json.dump(params, f)\n",
"params = map_parameters(params, param_space=param_space)\n",
"params"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "a48bd9e9",
"metadata": {
"execution": {
"iopub.execute_input": "2024-03-22T17:02:22.061219Z",
"iopub.status.busy": "2024-03-22T17:02:22.060881Z",
"iopub.status.idle": "2024-03-22T17:12:00.661298Z",
"shell.execute_reply": "2024-03-22T17:12:00.660201Z"
},
"papermill": {
"duration": 578.632009,
"end_time": "2024-03-22T17:12:00.678673",
"exception": false,
"start_time": "2024-03-22T17:02:22.046664",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Caching in ../../../../contraceptive/_cache_aug_train/tvae/all inf False\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"split df ratio is 0\n",
"../../../../ml-utility-loss/aug_train/contraceptive [400, 0]\n",
"Caching in ../../../../contraceptive/_cache_aug_val/tvae/all inf False\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"split df ratio is 1\n",
"../../../../ml-utility-loss/aug_val/contraceptive [0, 200]\n",
"Caching in ../../../../contraceptive/_cache_bs_train/tvae/all inf False\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"split df ratio is 0\n",
"../../../../ml-utility-loss/bs_train/contraceptive [100, 0]\n",
"Caching in ../../../../contraceptive/_cache_bs_val/tvae/all inf False\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"split df ratio is 1\n",
"../../../../ml-utility-loss/bs_val/contraceptive [0, 50]\n",
"Caching in ../../../../contraceptive/_cache_synth/tvae/all inf False\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Splitting without random!\n",
"Split with reverse index!\n",
"../../../../ml-utility-loss/synthetics/contraceptive [400, 200]\n",
"[900, 450]\n",
"[900, 450]\n"
]
}
],
"source": [
"# Train/validation sets for the role model and synthetic-data setting chosen in params.\n",
"train_set, val_set = datasetsn(\n",
"    model=params[\"fixed_role_model\"],\n",
"    synth_data=params[\"synth_data\"],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "2fcb1418",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"execution": {
"iopub.execute_input": "2024-03-22T17:12:00.708438Z",
"iopub.status.busy": "2024-03-22T17:12:00.708106Z",
"iopub.status.idle": "2024-03-22T17:12:01.197733Z",
"shell.execute_reply": "2024-03-22T17:12:01.196772Z"
},
"executionInfo": {
"elapsed": 396850,
"status": "error",
"timestamp": 1696841446059,
"user": {
"displayName": "Rizqi Nur",
"userId": "09644007964068789560"
},
"user_tz": -420
},
"id": "_bt1MQc5kpSk",
"outputId": "01c1d3e5-ac64-461d-835a-b76f4a66e6d6",
"papermill": {
"duration": 0.506346,
"end_time": "2024-03-22T17:12:01.199915",
"exception": false,
"start_time": "2024-03-22T17:12:00.693569",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Creating model of type <class 'ml_utility_loss.loss_learning.estimator.model.models.TwinEncoder'>\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[*] Embedding False True\n",
"['tvae'] 1\n"
]
}
],
"source": [
"from ml_utility_loss.loss_learning.estimator.model.pipeline import remove_non_model_params\n",
"from ml_utility_loss.loss_learning.estimator.pipeline import create_model\n",
"from ml_utility_loss.util import filter_dict, clear_memory\n",
"\n",
"clear_memory()\n",
"\n",
"# Constructor arguments (presumably params minus training-only keys) and the adapters\n",
"# of the selected role models only.\n",
"params2 = remove_non_model_params(params)\n",
"adapters = filter_dict(preprocessor.adapter_sizes, params[\"models\"])\n",
"\n",
"model = create_model(adapters=adapters, **params2)\n",
"print(model.models, len(model.adapters))"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "938f94fc",
"metadata": {
"execution": {
"iopub.execute_input": "2024-03-22T17:12:01.229980Z",
"iopub.status.busy": "2024-03-22T17:12:01.229573Z",
"iopub.status.idle": "2024-03-22T17:12:01.234682Z",
"shell.execute_reply": "2024-03-22T17:12:01.233844Z"
},
"papermill": {
"duration": 0.022181,
"end_time": "2024-03-22T17:12:01.236594",
"exception": false,
"start_time": "2024-03-22T17:12:01.214413",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": [
"# Study identifier: <model_name>_<dataset_name>.\n",
"study_name = \"_\".join([model_name, dataset_name])"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "12fb613e",
"metadata": {
"execution": {
"iopub.execute_input": "2024-03-22T17:12:01.264684Z",
"iopub.status.busy": "2024-03-22T17:12:01.264409Z",
"iopub.status.idle": "2024-03-22T17:12:01.271364Z",
"shell.execute_reply": "2024-03-22T17:12:01.270425Z"
},
"papermill": {
"duration": 0.023683,
"end_time": "2024-03-22T17:12:01.273479",
"exception": false,
"start_time": "2024-03-22T17:12:01.249796",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"11895304"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def count_parameters(model):\n",
"    \"\"\"Return the total number of elements across `model`'s trainable parameters.\n",
"\n",
"    Only parameters with ``requires_grad=True`` contribute to the count.\n",
"    \"\"\"\n",
"    trainable_sizes = (param.numel() for param in model.parameters() if param.requires_grad)\n",
"    return sum(trainable_sizes)\n",
"\n",
"count_parameters(model)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "bd386e57",
"metadata": {
"execution": {
"iopub.execute_input": "2024-03-22T17:12:01.303569Z",
"iopub.status.busy": "2024-03-22T17:12:01.302995Z",
"iopub.status.idle": "2024-03-22T17:12:01.383209Z",
"shell.execute_reply": "2024-03-22T17:12:01.382160Z"
},
"papermill": {
"duration": 0.096944,
"end_time": "2024-03-22T17:12:01.385367",
"exception": false,
"start_time": "2024-03-22T17:12:01.288423",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"========================================================================================================================\n",
"Layer (type:depth-idx) Output Shape Param #\n",
"========================================================================================================================\n",
"MLUtilitySingle [2, 1179, 46] --\n",
"├─Adapter: 1-1 [2, 1179, 46] --\n",
"│ └─Sequential: 2-1 [2, 1179, 256] --\n",
"│ │ └─FeedForward: 3-1 [2, 1179, 1024] --\n",
"│ │ │ └─Linear: 4-1 [2, 1179, 1024] 48,128\n",
"│ │ │ └─Softsign: 4-2 [2, 1179, 1024] --\n",
"│ │ └─FeedForward: 3-2 [2, 1179, 1024] --\n",
"│ │ │ └─Linear: 4-3 [2, 1179, 1024] 1,049,600\n",
"│ │ │ └─Softsign: 4-4 [2, 1179, 1024] --\n",
"│ │ └─FeedForward: 3-3 [2, 1179, 1024] --\n",
"│ │ │ └─Linear: 4-5 [2, 1179, 1024] 1,049,600\n",
"│ │ │ └─Softsign: 4-6 [2, 1179, 1024] --\n",
"│ │ └─FeedForward: 3-4 [2, 1179, 1024] --\n",
"│ │ │ └─Linear: 4-7 [2, 1179, 1024] 1,049,600\n",
"│ │ │ └─Softsign: 4-8 [2, 1179, 1024] --\n",
"│ │ └─FeedForward: 3-5 [2, 1179, 1024] --\n",
"│ │ │ └─Linear: 4-9 [2, 1179, 1024] 1,049,600\n",
"│ │ │ └─Softsign: 4-10 [2, 1179, 1024] --\n",
"│ │ └─FeedForward: 3-6 [2, 1179, 1024] --\n",
"│ │ │ └─Linear: 4-11 [2, 1179, 1024] 1,049,600\n",
"│ │ │ └─Softsign: 4-12 [2, 1179, 1024] --\n",
"│ │ └─FeedForward: 3-7 [2, 1179, 1024] --\n",
"│ │ │ └─Linear: 4-13 [2, 1179, 1024] 1,049,600\n",
"│ │ │ └─Softsign: 4-14 [2, 1179, 1024] --\n",
"│ │ └─FeedForward: 3-8 [2, 1179, 1024] --\n",
"│ │ │ └─Linear: 4-15 [2, 1179, 1024] 1,049,600\n",
"│ │ │ └─Softsign: 4-16 [2, 1179, 1024] --\n",
"│ │ └─FeedForward: 3-9 [2, 1179, 256] --\n",
"│ │ │ └─Linear: 4-17 [2, 1179, 256] 262,400\n",
"│ │ │ └─LeakyHardsigmoid: 4-18 [2, 1179, 256] --\n",
"├─Adapter: 1-2 [2, 294, 46] (recursive)\n",
"│ └─Sequential: 2-2 [2, 294, 256] (recursive)\n",
"│ │ └─FeedForward: 3-10 [2, 294, 1024] (recursive)\n",
"│ │ │ └─Linear: 4-19 [2, 294, 1024] (recursive)\n",
"│ │ │ └─Softsign: 4-20 [2, 294, 1024] --\n",
"│ │ └─FeedForward: 3-11 [2, 294, 1024] (recursive)\n",
"│ │ │ └─Linear: 4-21 [2, 294, 1024] (recursive)\n",
"│ │ │ └─Softsign: 4-22 [2, 294, 1024] --\n",
"│ │ └─FeedForward: 3-12 [2, 294, 1024] (recursive)\n",
"│ │ │ └─Linear: 4-23 [2, 294, 1024] (recursive)\n",
"│ │ │ └─Softsign: 4-24 [2, 294, 1024] --\n",
"│ │ └─FeedForward: 3-13 [2, 294, 1024] (recursive)\n",
"│ │ │ └─Linear: 4-25 [2, 294, 1024] (recursive)\n",
"│ │ │ └─Softsign: 4-26 [2, 294, 1024] --\n",
"│ │ └─FeedForward: 3-14 [2, 294, 1024] (recursive)\n",
"│ │ │ └─Linear: 4-27 [2, 294, 1024] (recursive)\n",
"│ │ │ └─Softsign: 4-28 [2, 294, 1024] --\n",
"│ │ └─FeedForward: 3-15 [2, 294, 1024] (recursive)\n",
"│ │ │ └─Linear: 4-29 [2, 294, 1024] (recursive)\n",
"│ │ │ └─Softsign: 4-30 [2, 294, 1024] --\n",
"│ │ └─FeedForward: 3-16 [2, 294, 1024] (recursive)\n",
"│ │ │ └─Linear: 4-31 [2, 294, 1024] (recursive)\n",
"│ │ │ └─Softsign: 4-32 [2, 294, 1024] --\n",
"│ │ └─FeedForward: 3-17 [2, 294, 1024] (recursive)\n",
"│ │ │ └─Linear: 4-33 [2, 294, 1024] (recursive)\n",
"│ │ │ └─Softsign: 4-34 [2, 294, 1024] --\n",
"│ │ └─FeedForward: 3-18 [2, 294, 256] (recursive)\n",
"│ │ │ └─Linear: 4-35 [2, 294, 256] (recursive)\n",
"│ │ │ └─LeakyHardsigmoid: 4-36 [2, 294, 256] --\n",
"├─TwinEncoder: 1-3 [2, 4096] --\n",
"│ └─Encoder: 2-3 [2, 16, 256] --\n",
"│ │ └─ModuleList: 3-20 -- (recursive)\n",
"│ │ │ └─EncoderLayer: 4-37 [2, 1179, 256] --\n",
"│ │ │ │ └─SimpleInducedSetAttention: 5-1 [2, 1179, 256] --\n",
"│ │ │ │ │ └─TensorInductionPoint: 6-1 [2, 128, 256] 32,768\n",
"│ │ │ │ │ └─MultiHeadAttention: 6-2 [2, 128, 256] --\n",
"│ │ │ │ │ │ └─Linear: 7-1 [2, 128, 256] 65,536\n",
"│ │ │ │ │ │ └─Linear: 7-2 [2, 1179, 256] 65,536\n",
"│ │ │ │ │ │ └─Linear: 7-3 [2, 1179, 256] 65,536\n",
"│ │ │ │ │ │ └─ScaledDotProductAttention: 7-4 [2, 32, 128, 8] --\n",
"│ │ │ │ │ │ │ └─Softmax: 8-1 [2, 32, 128, 1179] --\n",
"│ │ │ │ │ │ └─Linear: 7-5 [2, 128, 256] 65,792\n",
"│ │ │ │ │ │ └─PReLU: 7-6 [2, 128, 256] 1\n",
"│ │ │ │ │ └─MultiHeadAttention: 6-3 [2, 1179, 256] --\n",
"│ │ │ │ │ │ └─Linear: 7-7 [2, 1179, 256] 65,536\n",
"│ │ │ │ │ │ └─Linear: 7-8 [2, 128, 256] 65,536\n",
"│ │ │ │ │ │ └─Linear: 7-9 [2, 128, 256] 65,536\n",
"│ │ │ │ │ │ └─ScaledDotProductAttention: 7-10 [2, 32, 1179, 8] --\n",
"│ │ │ │ │ │ │ └─Softmax: 8-2 [2, 32, 1179, 128] --\n",
"│ │ │ │ │ │ └─Linear: 7-11 [2, 1179, 256] 65,792\n",
"│ │ │ │ │ │ └─PReLU: 7-12 [2, 1179, 256] 1\n",
"│ │ │ │ └─DoubleFeedForward: 5-2 [2, 1179, 256] --\n",
"│ │ │ │ │ └─Linear: 6-4 [2, 1179, 512] 131,584\n",
"│ │ │ │ │ └─Tanh: 6-5 [2, 1179, 512] --\n",
"│ │ │ │ │ └─Linear: 6-6 [2, 1179, 256] 131,328\n",
"│ │ │ └─EncoderLayer: 4-38 [2, 1179, 256] --\n",
"│ │ │ │ └─SimpleInducedSetAttention: 5-3 [2, 1179, 256] --\n",
"│ │ │ │ │ └─TensorInductionPoint: 6-7 [2, 128, 256] 32,768\n",
"│ │ │ │ │ └─MultiHeadAttention: 6-8 [2, 128, 256] --\n",
"│ │ │ │ │ │ └─Linear: 7-13 [2, 128, 256] 65,536\n",
"│ │ │ │ │ │ └─Linear: 7-14 [2, 1179, 256] 65,536\n",
"│ │ │ │ │ │ └─Linear: 7-15 [2, 1179, 256] 65,536\n",
"│ │ │ │ │ │ └─ScaledDotProductAttention: 7-16 [2, 32, 128, 8] --\n",
"│ │ │ │ │ │ │ └─Softmax: 8-3 [2, 32, 128, 1179] --\n",
"│ │ │ │ │ │ └─Linear: 7-17 [2, 128, 256] 65,792\n",
"│ │ │ │ │ │ └─PReLU: 7-18 [2, 128, 256] 1\n",
"│ │ │ │ │ └─MultiHeadAttention: 6-9 [2, 1179, 256] --\n",
"│ │ │ │ │ │ └─Linear: 7-19 [2, 1179, 256] 65,536\n",
"│ │ │ │ │ │ └─Linear: 7-20 [2, 128, 256] 65,536\n",
"│ │ │ │ │ │ └─Linear: 7-21 [2, 128, 256] 65,536\n",
"│ │ │ │ │ │ └─ScaledDotProductAttention: 7-22 [2, 32, 1179, 8] --\n",
"│ │ │ │ │ │ │ └─Softmax: 8-4 [2, 32, 1179, 128] --\n",
"│ │ │ │ │ │ └─Linear: 7-23 [2, 1179, 256] 65,792\n",
"│ │ │ │ │ │ └─PReLU: 7-24 [2, 1179, 256] 1\n",
"│ │ │ │ └─DoubleFeedForward: 5-4 [2, 1179, 256] --\n",
"│ │ │ │ │ └─Linear: 6-10 [2, 1179, 512] 131,584\n",
"│ │ │ │ │ └─Tanh: 6-11 [2, 1179, 512] --\n",
"│ │ │ │ │ └─Linear: 6-12 [2, 1179, 256] 131,328\n",
"│ │ │ └─EncoderLayer: 4-39 [2, 16, 256] --\n",
"│ │ │ │ └─SimpleInducedSetAttention: 5-5 [2, 1179, 256] --\n",
"│ │ │ │ │ └─TensorInductionPoint: 6-13 [2, 128, 256] 32,768\n",
"│ │ │ │ │ └─MultiHeadAttention: 6-14 [2, 128, 256] --\n",
"│ │ │ │ │ │ └─Linear: 7-25 [2, 128, 256] 65,536\n",
"│ │ │ │ │ │ └─Linear: 7-26 [2, 1179, 256] 65,536\n",
"│ │ │ │ │ │ └─Linear: 7-27 [2, 1179, 256] 65,536\n",
"│ │ │ │ │ │ └─ScaledDotProductAttention: 7-28 [2, 32, 128, 8] --\n",
"│ │ │ │ │ │ │ └─Softmax: 8-5 [2, 32, 128, 1179] --\n",
"│ │ │ │ │ │ └─Linear: 7-29 [2, 128, 256] 65,792\n",
"│ │ │ │ │ │ └─PReLU: 7-30 [2, 128, 256] 1\n",
"│ │ │ │ │ └─MultiHeadAttention: 6-15 [2, 1179, 256] --\n",
"│ │ │ │ │ │ └─Linear: 7-31 [2, 1179, 256] 65,536\n",
"│ │ │ │ │ │ └─Linear: 7-32 [2, 128, 256] 65,536\n",
"│ │ │ │ │ │ └─Linear: 7-33 [2, 128, 256] 65,536\n",
"│ │ │ │ │ │ └─ScaledDotProductAttention: 7-34 [2, 32, 1179, 8] --\n",
"│ │ │ │ │ │ │ └─Softmax: 8-6 [2, 32, 1179, 128] --\n",
"│ │ │ │ │ │ └─Linear: 7-35 [2, 1179, 256] 65,792\n",
"│ │ │ │ │ │ └─PReLU: 7-36 [2, 1179, 256] 1\n",
"│ │ │ │ └─DoubleFeedForward: 5-6 [2, 1179, 256] --\n",
"│ │ │ │ │ └─Linear: 6-16 [2, 1179, 512] 131,584\n",
"│ │ │ │ │ └─LeakyHardtanh: 6-17 [2, 1179, 512] --\n",
"│ │ │ │ │ └─Linear: 6-18 [2, 1179, 256] 131,328\n",
"│ │ │ │ └─PoolingByMultiheadAttention: 5-7 [2, 16, 256] --\n",
"│ │ │ │ │ └─TensorInductionPoint: 6-19 [2, 16, 256] 4,096\n",
"│ │ │ │ │ └─SimpleMultiHeadAttention: 6-20 [2, 16, 256] --\n",
"│ │ │ │ │ │ └─Linear: 7-37 [2, 16, 256] 65,536\n",
"│ │ │ │ │ │ └─Linear: 7-38 [2, 1179, 256] 65,536\n",
"│ │ │ │ │ │ └─Linear: 7-39 [2, 1179, 256] 65,536\n",
"│ │ │ │ │ │ └─ScaledDotProductAttention: 7-40 [2, 32, 16, 8] --\n",
"│ │ │ │ │ │ │ └─Softmax: 8-7 [2, 32, 16, 1179] --\n",
"│ │ │ │ │ │ └─Linear: 7-41 [2, 16, 256] 65,792\n",
"│ │ │ │ │ │ └─PReLU: 7-42 [2, 16, 256] 1\n",
"│ └─Encoder: 2-4 [2, 16, 256] (recursive)\n",
"│ │ └─ModuleList: 3-20 -- (recursive)\n",
"│ │ │ └─EncoderLayer: 4-40 [2, 294, 256] (recursive)\n",
"│ │ │ │ └─SimpleInducedSetAttention: 5-8 [2, 294, 256] (recursive)\n",
"│ │ │ │ │ └─TensorInductionPoint: 6-21 [2, 128, 256] (recursive)\n",
"│ │ │ │ │ └─MultiHeadAttention: 6-22 [2, 128, 256] (recursive)\n",
"│ │ │ │ │ │ └─Linear: 7-43 [2, 128, 256] (recursive)\n",
"│ │ │ │ │ │ └─Linear: 7-44 [2, 294, 256] (recursive)\n",
"│ │ │ │ │ │ └─Linear: 7-45 [2, 294, 256] (recursive)\n",
"│ │ │ │ │ │ └─ScaledDotProductAttention: 7-46 [2, 32, 128, 8] --\n",
"│ │ │ │ │ │ │ └─Softmax: 8-8 [2, 32, 128, 294] --\n",
"│ │ │ │ │ │ └─Linear: 7-47 [2, 128, 256] (recursive)\n",
"│ │ │ │ │ │ └─PReLU: 7-48 [2, 128, 256] (recursive)\n",
"│ │ │ │ │ └─MultiHeadAttention: 6-23 [2, 294, 256] (recursive)\n",
"│ │ │ │ │ │ └─Linear: 7-49 [2, 294, 256] (recursive)\n",
"│ │ │ │ │ │ └─Linear: 7-50 [2, 128, 256] (recursive)\n",
"│ │ │ │ │ │ └─Linear: 7-51 [2, 128, 256] (recursive)\n",
"│ │ │ │ │ │ └─ScaledDotProductAttention: 7-52 [2, 32, 294, 8] --\n",
"│ │ │ │ │ │ │ └─Softmax: 8-9 [2, 32, 294, 128] --\n",
"│ │ │ │ │ │ └─Linear: 7-53 [2, 294, 256] (recursive)\n",
"│ │ │ │ │ │ └─PReLU: 7-54 [2, 294, 256] (recursive)\n",
"│ │ │ │ └─DoubleFeedForward: 5-9 [2, 294, 256] (recursive)\n",
"│ │ │ │ │ └─Linear: 6-24 [2, 294, 512] (recursive)\n",
"│ │ │ │ │ └─Tanh: 6-25 [2, 294, 512] --\n",
"│ │ │ │ │ └─Linear: 6-26 [2, 294, 256] (recursive)\n",
"│ │ │ └─EncoderLayer: 4-41 [2, 294, 256] (recursive)\n",
"│ │ │ │ └─SimpleInducedSetAttention: 5-10 [2, 294, 256] (recursive)\n",
"│ │ │ │ │ └─TensorInductionPoint: 6-27 [2, 128, 256] (recursive)\n",
"│ │ │ │ │ └─MultiHeadAttention: 6-28 [2, 128, 256] (recursive)\n",
"│ │ │ │ │ │ └─Linear: 7-55 [2, 128, 256] (recursive)\n",
"│ │ │ │ │ │ └─Linear: 7-56 [2, 294, 256] (recursive)\n",
"│ │ │ │ │ │ └─Linear: 7-57 [2, 294, 256] (recursive)\n",
"│ │ │ │ │ │ └─ScaledDotProductAttention: 7-58 [2, 32, 128, 8] --\n",
"│ │ │ │ │ │ │ └─Softmax: 8-10 [2, 32, 128, 294] --\n",
"│ │ │ │ │ │ └─Linear: 7-59 [2, 128, 256] (recursive)\n",
"│ │ │ │ │ │ └─PReLU: 7-60 [2, 128, 256] (recursive)\n",
"│ │ │ │ │ └─MultiHeadAttention: 6-29 [2, 294, 256] (recursive)\n",
"│ │ │ │ │ │ └─Linear: 7-61 [2, 294, 256] (recursive)\n",
"│ │ │ │ │ │ └─Linear: 7-62 [2, 128, 256] (recursive)\n",
"│ │ │ │ │ │ └─Linear: 7-63 [2, 128, 256] (recursive)\n",
"│ │ │ │ │ │ └─ScaledDotProductAttention: 7-64 [2, 32, 294, 8] --\n",
"│ │ │ │ │ │ │ └─Softmax: 8-11 [2, 32, 294, 128] --\n",
"│ │ │ │ │ │ └─Linear: 7-65 [2, 294, 256] (recursive)\n",
"│ │ │ │ │ │ └─PReLU: 7-66 [2, 294, 256] (recursive)\n",
"│ │ │ │ └─DoubleFeedForward: 5-11 [2, 294, 256] (recursive)\n",
"│ │ │ │ │ └─Linear: 6-30 [2, 294, 512] (recursive)\n",
"│ │ │ │ │ └─Tanh: 6-31 [2, 294, 512] --\n",
"│ │ │ │ │ └─Linear: 6-32 [2, 294, 256] (recursive)\n",
"│ │ │ └─EncoderLayer: 4-42 [2, 16, 256] (recursive)\n",
"│ │ │ │ └─SimpleInducedSetAttention: 5-12 [2, 294, 256] (recursive)\n",
"│ │ │ │ │ └─TensorInductionPoint: 6-33 [2, 128, 256] (recursive)\n",
"│ │ │ │ │ └─MultiHeadAttention: 6-34 [2, 128, 256] (recursive)\n",
"│ │ │ │ │ │ └─Linear: 7-67 [2, 128, 256] (recursive)\n",
"│ │ │ │ │ │ └─Linear: 7-68 [2, 294, 256] (recursive)\n",
"│ │ │ │ │ │ └─Linear: 7-69 [2, 294, 256] (recursive)\n",
"│ │ │ │ │ │ └─ScaledDotProductAttention: 7-70 [2, 32, 128, 8] --\n",
"│ │ │ │ │ │ │ └─Softmax: 8-12 [2, 32, 128, 294] --\n",
"│ │ │ │ │ │ └─Linear: 7-71 [2, 128, 256] (recursive)\n",
"│ │ │ │ │ │ └─PReLU: 7-72 [2, 128, 256] (recursive)\n",
"│ │ │ │ │ └─MultiHeadAttention: 6-35 [2, 294, 256] (recursive)\n",
"│ │ │ │ │ │ └─Linear: 7-73 [2, 294, 256] (recursive)\n",
"│ │ │ │ │ │ └─Linear: 7-74 [2, 128, 256] (recursive)\n",
"│ │ │ │ │ │ └─Linear: 7-75 [2, 128, 256] (recursive)\n",
"│ │ │ │ │ │ └─ScaledDotProductAttention: 7-76 [2, 32, 294, 8] --\n",
"│ │ │ │ │ │ │ └─Softmax: 8-13 [2, 32, 294, 128] --\n",
"│ │ │ │ │ │ └─Linear: 7-77 [2, 294, 256] (recursive)\n",
"│ │ │ │ │ │ └─PReLU: 7-78 [2, 294, 256] (recursive)\n",
"│ │ │ │ └─DoubleFeedForward: 5-13 [2, 294, 256] (recursive)\n",
"│ │ │ │ │ └─Linear: 6-36 [2, 294, 512] (recursive)\n",
"│ │ │ │ │ └─LeakyHardtanh: 6-37 [2, 294, 512] --\n",
"│ │ │ │ │ └─Linear: 6-38 [2, 294, 256] (recursive)\n",
"│ │ │ │ └─PoolingByMultiheadAttention: 5-14 [2, 16, 256] (recursive)\n",
"│ │ │ │ │ └─TensorInductionPoint: 6-39 [2, 16, 256] (recursive)\n",
"│ │ │ │ │ └─SimpleMultiHeadAttention: 6-40 [2, 16, 256] (recursive)\n",
"│ │ │ │ │ │ └─Linear: 7-79 [2, 16, 256] (recursive)\n",
"│ │ │ │ │ │ └─Linear: 7-80 [2, 294, 256] (recursive)\n",
"│ │ │ │ │ │ └─Linear: 7-81 [2, 294, 256] (recursive)\n",
"│ │ │ │ │ │ └─ScaledDotProductAttention: 7-82 [2, 32, 16, 8] --\n",
"│ │ │ │ │ │ │ └─Softmax: 8-14 [2, 32, 16, 294] --\n",
"│ │ │ │ │ │ └─Linear: 7-83 [2, 16, 256] (recursive)\n",
"│ │ │ │ │ │ └─PReLU: 7-84 [2, 16, 256] (recursive)\n",
"├─Head: 1-4 [2] --\n",
"│ └─Sequential: 2-5 [2, 1] --\n",
"│ │ └─FeedForward: 3-21 [2, 256] --\n",
"│ │ │ └─Linear: 4-43 [2, 256] 1,048,832\n",
"│ │ │ └─Softsign: 4-44 [2, 256] --\n",
"│ │ └─FeedForward: 3-22 [2, 256] --\n",
"│ │ │ └─Linear: 4-45 [2, 256] 65,792\n",
"│ │ │ └─Softsign: 4-46 [2, 256] --\n",
"│ │ └─FeedForward: 3-23 [2, 256] --\n",
"│ │ │ └─Linear: 4-47 [2, 256] 65,792\n",
"│ │ │ └─Softsign: 4-48 [2, 256] --\n",
"│ │ └─FeedForward: 3-24 [2, 256] --\n",
"│ │ │ └─Linear: 4-49 [2, 256] 65,792\n",
"│ │ │ └─Softsign: 4-50 [2, 256] --\n",
"│ │ └─FeedForward: 3-25 [2, 256] --\n",
"│ │ │ └─Linear: 4-51 [2, 256] 65,792\n",
"│ │ │ └─Softsign: 4-52 [2, 256] --\n",
"│ │ └─FeedForward: 3-26 [2, 256] --\n",
"│ │ │ └─Linear: 4-53 [2, 256] 65,792\n",
"│ │ │ └─Softsign: 4-54 [2, 256] --\n",
"│ │ └─FeedForward: 3-27 [2, 256] --\n",
"│ │ │ └─Linear: 4-55 [2, 256] 65,792\n",
"│ │ │ └─Softsign: 4-56 [2, 256] --\n",
"│ │ └─FeedForward: 3-28 [2, 256] --\n",
"│ │ │ └─Linear: 4-57 [2, 256] 65,792\n",
"│ │ │ └─Softsign: 4-58 [2, 256] --\n",
"│ │ └─FeedForward: 3-29 [2, 1] --\n",
"│ │ │ └─Linear: 4-59 [2, 1] 257\n",
"│ │ │ └─LeakyHardsigmoid: 4-60 [2, 1] --\n",
"========================================================================================================================\n",
"Total params: 11,895,304\n",
"Trainable params: 11,895,304\n",
"Non-trainable params: 0\n",
"Total mult-adds (M): 44.15\n",
"========================================================================================================================\n",
"Input size (MB): 0.54\n",
"Forward/backward pass size (MB): 375.40\n",
"Params size (MB): 47.58\n",
"Estimated Total Size (MB): 423.53\n",
"========================================================================================================================"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from torchinfo import summary\n",
"\n",
"# Summarize the sub-model for the fixed role model. The input sizes are the shapes\n",
"# of the first training sample's two inputs, with a dummy batch dimension of 2 prepended.\n",
"role_model = params[\"fixed_role_model\"]\n",
"s = train_set[0][role_model]\n",
"summary(\n",
"    model[role_model],\n",
"    input_size=tuple((2, *t.shape) for t in (s[0], s[1])),\n",
"    depth=9,  # the printed summary nests at most 8 levels, so 9 shows every layer\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "0f42c4d1",
"metadata": {
"execution": {
"iopub.execute_input": "2024-03-22T17:12:01.417191Z",
"iopub.status.busy": "2024-03-22T17:12:01.416837Z",
"iopub.status.idle": "2024-03-22T18:16:21.301787Z",
"shell.execute_reply": "2024-03-22T18:16:21.300655Z"
},
"papermill": {
"duration": 3859.922538,
"end_time": "2024-03-22T18:16:21.323259",
"exception": false,
"start_time": "2024-03-22T17:12:01.400721",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3 datasets [900, 450, 1050]\n",
"Creating model of type <class 'ml_utility_loss.loss_learning.estimator.model.models.TwinEncoder'>\n",
"[*] Embedding False True\n",
"g_loss_mul 0.1\n",
"Epoch 0\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train loss {'avg_role_model_loss': 0.020646917774962883, 'avg_role_model_std_loss': 0.7029616862738138, 'avg_role_model_mean_pred_loss': 0.0014198488189189598, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.020646917774962883, 'n_size': 900, 'n_batch': 225, 'duration': 211.21505284309387, 'duration_batch': 0.9387335681915283, 'duration_size': 0.23468339204788208, 'avg_pred_std': 0.12668553197549448}\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Val loss {'avg_role_model_loss': 0.00800067097414285, 'avg_role_model_std_loss': 0.8323721770502058, 'avg_role_model_mean_pred_loss': 0.00025833536566152146, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.00800067097414285, 'n_size': 450, 'n_batch': 113, 'duration': 91.81465721130371, 'duration_batch': 0.8125190903655196, 'duration_size': 0.20403257158067492, 'avg_pred_std': 0.07484065717399384}\n",
"Epoch 1\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train loss {'avg_role_model_loss': 0.007335189075060447, 'avg_role_model_std_loss': 0.5605360431658958, 'avg_role_model_mean_pred_loss': 0.00016509135985034204, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.007335189075060447, 'n_size': 900, 'n_batch': 225, 'duration': 211.18131518363953, 'duration_batch': 0.9385836230383979, 'duration_size': 0.23464590575959948, 'avg_pred_std': 0.09665914196934965}\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Val loss {'avg_role_model_loss': 0.005770373049502572, 'avg_role_model_std_loss': 1.328719814272311, 'avg_role_model_mean_pred_loss': 0.0001455103024895009, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.005770373049502572, 'n_size': 450, 'n_batch': 113, 'duration': 92.31423711776733, 'duration_batch': 0.8169401514846667, 'duration_size': 0.20514274915059408, 'avg_pred_std': 0.06159803802889269}\n",
"Epoch 2\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train loss {'avg_role_model_loss': 0.004101434572932905, 'avg_role_model_std_loss': 0.40740934907574267, 'avg_role_model_mean_pred_loss': 4.4437185304309564e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.004101434572932905, 'n_size': 900, 'n_batch': 225, 'duration': 211.43363165855408, 'duration_batch': 0.9397050295935737, 'duration_size': 0.23492625739839343, 'avg_pred_std': 0.10036693361898263}\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Val loss {'avg_role_model_loss': 0.005793615489771279, 'avg_role_model_std_loss': 2.93090469723272, 'avg_role_model_mean_pred_loss': 7.996645985781352e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.005793615489771279, 'n_size': 450, 'n_batch': 113, 'duration': 92.21822166442871, 'duration_batch': 0.8160904572073338, 'duration_size': 0.20492938147650824, 'avg_pred_std': 0.0686845871407654}\n",
"Epoch 3\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train loss {'avg_role_model_loss': 0.0032525908350200753, 'avg_role_model_std_loss': 0.45932424604433697, 'avg_role_model_mean_pred_loss': 1.5064653623929553e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0032525908350200753, 'n_size': 900, 'n_batch': 225, 'duration': 210.6352882385254, 'duration_batch': 0.9361568366156684, 'duration_size': 0.2340392091539171, 'avg_pred_std': 0.09941185830367937}\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Val loss {'avg_role_model_loss': 0.0025840073977209006, 'avg_role_model_std_loss': 2.9610207560058215, 'avg_role_model_mean_pred_loss': 8.98458365437745e-06, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0025840073977209006, 'n_size': 450, 'n_batch': 113, 'duration': 91.6226761341095, 'duration_batch': 0.81082014277973, 'duration_size': 0.20360594696468778, 'avg_pred_std': 0.04679510168797147}\n",
"Epoch 4\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train loss {'avg_role_model_loss': 0.0028994390259807308, 'avg_role_model_std_loss': 0.27792873476118807, 'avg_role_model_mean_pred_loss': 1.2238399814356409e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0028994390259807308, 'n_size': 900, 'n_batch': 225, 'duration': 210.90299940109253, 'duration_batch': 0.9373466640048557, 'duration_size': 0.23433666600121392, 'avg_pred_std': 0.10629830273903078}\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Val loss {'avg_role_model_loss': 0.004226378290137897, 'avg_role_model_std_loss': 3.139442252464839, 'avg_role_model_mean_pred_loss': 2.8896792241464515e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.004226378290137897, 'n_size': 450, 'n_batch': 113, 'duration': 92.30420923233032, 'duration_batch': 0.8168514091356666, 'duration_size': 0.20512046496073405, 'avg_pred_std': 0.04041025609363167}\n",
"Epoch 5\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train loss {'avg_role_model_loss': 0.0029357373788823477, 'avg_role_model_std_loss': 0.35013624048834924, 'avg_role_model_mean_pred_loss': 1.5252040759503315e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0029357373788823477, 'n_size': 900, 'n_batch': 225, 'duration': 210.8098328113556, 'duration_batch': 0.9369325902726915, 'duration_size': 0.23423314756817287, 'avg_pred_std': 0.10065761231092943}\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Val loss {'avg_role_model_loss': 0.003366937771077371, 'avg_role_model_std_loss': 2.7330855187934775, 'avg_role_model_mean_pred_loss': 2.1526505657937356e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.003366937771077371, 'n_size': 450, 'n_batch': 113, 'duration': 92.45483732223511, 'duration_batch': 0.8181844010817266, 'duration_size': 0.20545519404941134, 'avg_pred_std': 0.048643345435137604}\n",
"Epoch 6\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train loss {'avg_role_model_loss': 0.0025804581021010463, 'avg_role_model_std_loss': 0.33193543293685124, 'avg_role_model_mean_pred_loss': 1.1130432750345e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0025804581021010463, 'n_size': 900, 'n_batch': 225, 'duration': 210.82577991485596, 'duration_batch': 0.9370034662882487, 'duration_size': 0.23425086657206218, 'avg_pred_std': 0.10346181529677577}\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Val loss {'avg_role_model_loss': 0.0029045950072920986, 'avg_role_model_std_loss': 4.791847205054414, 'avg_role_model_mean_pred_loss': 1.7926797626652248e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0029045950072920986, 'n_size': 450, 'n_batch': 113, 'duration': 92.71966814994812, 'duration_batch': 0.8205280367252046, 'duration_size': 0.20604370699988472, 'avg_pred_std': 0.04583418705378066}\n",
"Epoch 7\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train loss {'avg_role_model_loss': 0.0024444010488999385, 'avg_role_model_std_loss': 0.27924930442016427, 'avg_role_model_mean_pred_loss': 1.0806890312169106e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0024444010488999385, 'n_size': 900, 'n_batch': 225, 'duration': 211.59207558631897, 'duration_batch': 0.9404092248280843, 'duration_size': 0.23510230620702108, 'avg_pred_std': 0.10225464255238573}\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Val loss {'avg_role_model_loss': 0.002720622533104486, 'avg_role_model_std_loss': 3.2791891266000865, 'avg_role_model_mean_pred_loss': 1.838483538128186e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.002720622533104486, 'n_size': 450, 'n_batch': 113, 'duration': 96.57148313522339, 'duration_batch': 0.8546148950019768, 'duration_size': 0.21460329585605198, 'avg_pred_std': 0.04418732333965435}\n",
"Epoch 8\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train loss {'avg_role_model_loss': 0.00242780985414154, 'avg_role_model_std_loss': 0.25769682647126535, 'avg_role_model_mean_pred_loss': 1.037415603605397e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.00242780985414154, 'n_size': 900, 'n_batch': 225, 'duration': 213.52410340309143, 'duration_batch': 0.9489960151248508, 'duration_size': 0.2372490037812127, 'avg_pred_std': 0.10494643683855732}\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Val loss {'avg_role_model_loss': 0.002887863010659607, 'avg_role_model_std_loss': 1.8791022815315657, 'avg_role_model_mean_pred_loss': 2.555802672021092e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.002887863010659607, 'n_size': 450, 'n_batch': 113, 'duration': 92.99426054954529, 'duration_batch': 0.8229580579605777, 'duration_size': 0.20665391233232286, 'avg_pred_std': 0.059518284711418096}\n",
"Epoch 9\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train loss {'avg_role_model_loss': 0.002208985082106665, 'avg_role_model_std_loss': 0.42546326303130916, 'avg_role_model_mean_pred_loss': 7.684016633753033e-06, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.002208985082106665, 'n_size': 900, 'n_batch': 225, 'duration': 213.6311333179474, 'duration_batch': 0.9494717036353217, 'duration_size': 0.23736792590883043, 'avg_pred_std': 0.10293637524772849}\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Val loss {'avg_role_model_loss': 0.0024638745025731624, 'avg_role_model_std_loss': 3.782287786237916, 'avg_role_model_mean_pred_loss': 1.655159604749657e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0024638745025731624, 'n_size': 450, 'n_batch': 113, 'duration': 97.28867506980896, 'duration_batch': 0.8609617262814953, 'duration_size': 0.21619705571068656, 'avg_pred_std': 0.04808065608046965}\n",
"Epoch 10\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train loss {'avg_role_model_loss': 0.0020764596655175813, 'avg_role_model_std_loss': 0.3896139937922279, 'avg_role_model_mean_pred_loss': 7.334516867581215e-06, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0020764596655175813, 'n_size': 900, 'n_batch': 225, 'duration': 221.03631401062012, 'duration_batch': 0.9823836178249783, 'duration_size': 0.24559590445624457, 'avg_pred_std': 0.10016421435814765}\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Val loss {'avg_role_model_loss': 0.0025203956082178692, 'avg_role_model_std_loss': 2.2354346593865277, 'avg_role_model_mean_pred_loss': 2.0540576969837897e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0025203956082178692, 'n_size': 450, 'n_batch': 113, 'duration': 97.72019529342651, 'duration_batch': 0.8647804893223585, 'duration_size': 0.2171559895409478, 'avg_pred_std': 0.057214381701758014}\n",
"Epoch 11\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train loss {'avg_role_model_loss': 0.0019838308382329867, 'avg_role_model_std_loss': 0.26038020301704434, 'avg_role_model_mean_pred_loss': 7.2093371565769966e-06, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0019838308382329867, 'n_size': 900, 'n_batch': 225, 'duration': 221.59594750404358, 'duration_batch': 0.9848708777957492, 'duration_size': 0.2462177194489373, 'avg_pred_std': 0.1044646823985709}\n",
"Time out: 3600.622545480728/3600\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Eval loss {'role_model': 'tvae', 'n_size': 1050, 'n_batch': 263, 'role_model_metrics': {'avg_loss': 0.0027671898508283663, 'avg_g_mag_loss': 0.06750814222799723, 'avg_g_cos_loss': 0.031209569611071075, 'pred_duration': 4.135732650756836, 'grad_duration': 12.466698169708252, 'total_duration': 16.602430820465088, 'pred_std': 0.10057252645492554, 'std_loss': 0.008725129999220371, 'mean_pred_loss': 2.4340070012840442e-05, 'pred_rmse': 0.05260408669710159, 'pred_mae': 0.03820032626390457, 'pred_mape': 0.1258891075849533, 'grad_rmse': 0.04161679372191429, 'grad_mae': 0.021082276478409767, 'grad_mape': 0.4825449585914612}, 'non_role_model_metrics': {'avg_loss': 0, 'avg_g_mag_loss': 0, 'avg_g_cos_loss': 0, 'avg_pred_duration': 0, 'avg_grad_duration': 0, 'avg_total_duration': 0, 'avg_pred_std': 0, 'avg_std_loss': 0, 'avg_mean_pred_loss': 0}, 'avg_metrics': {'avg_loss': 0.0027671898508283663, 'avg_g_mag_loss': 0.06750814222799723, 'avg_g_cos_loss': 0.031209569611071075, 'avg_pred_duration': 4.135732650756836, 'avg_grad_duration': 12.466698169708252, 'avg_total_duration': 16.602430820465088, 'avg_pred_std': 0.10057252645492554, 'avg_std_loss': 0.008725129999220371, 'avg_mean_pred_loss': 2.4340070012840442e-05}, 'min_metrics': {'avg_loss': 0.0027671898508283663, 'avg_g_mag_loss': 0.06750814222799723, 'avg_g_cos_loss': 0.031209569611071075, 'pred_duration': 4.135732650756836, 'grad_duration': 12.466698169708252, 'total_duration': 16.602430820465088, 'pred_std': 0.10057252645492554, 'std_loss': 0.008725129999220371, 'mean_pred_loss': 2.4340070012840442e-05, 'pred_rmse': 0.05260408669710159, 'pred_mae': 0.03820032626390457, 'pred_mape': 0.1258891075849533, 'grad_rmse': 0.04161679372191429, 'grad_mae': 0.021082276478409767, 'grad_mape': 0.4825449585914612}, 'model_metrics': {'tvae': {'avg_loss': 0.0027671898508283663, 'avg_g_mag_loss': 0.06750814222799723, 'avg_g_cos_loss': 0.031209569611071075, 'pred_duration': 4.135732650756836, 'grad_duration': 12.466698169708252, 'total_duration': 16.602430820465088, 'pred_std': 0.10057252645492554, 'std_loss': 0.008725129999220371, 'mean_pred_loss': 2.4340070012840442e-05, 'pred_rmse': 0.05260408669710159, 'pred_mae': 0.03820032626390457, 'pred_mape': 0.1258891075849533, 'grad_rmse': 0.04161679372191429, 'grad_mae': 0.021082276478409767, 'grad_mape': 0.4825449585914612}}}\n"
]
}
],
"source": [
"import torch\n",
"from ml_utility_loss.loss_learning.estimator.pipeline import train, train_2\n",
"from ml_utility_loss.loss_learning.estimator.process_simple import train_epoch, eval as _eval\n",
"from ml_utility_loss.params import GradientPenaltyMode\n",
"from ml_utility_loss.util import clear_memory\n",
"import time\n",
"#torch.autograd.set_detect_anomaly(True)\n",
"\n",
"del model\n",
"clear_memory()\n",
"\n",
"#opt = params[\"Optim\"](model.parameters())\n",
"loss = train_2(\n",
" [train_set, val_set, test_set],\n",
" preprocessor=preprocessor,\n",
" #whole_model=model,\n",
" #optim=opt,\n",
" log_dir=\"logs\",\n",
" checkpoint_dir=\"checkpoints\",\n",
" verbose=True,\n",
" allow_same_prediction=allow_same_prediction,\n",
" wandb=wandb if log_wandb else None,\n",
" study_name=study_name,\n",
" **params\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "9b514a07",
"metadata": {
"execution": {
"iopub.execute_input": "2024-03-22T18:16:21.365486Z",
"iopub.status.busy": "2024-03-22T18:16:21.365041Z",
"iopub.status.idle": "2024-03-22T18:16:21.370168Z",
"shell.execute_reply": "2024-03-22T18:16:21.369073Z"
},
"papermill": {
"duration": 0.028958,
"end_time": "2024-03-22T18:16:21.372457",
"exception": false,
"start_time": "2024-03-22T18:16:21.343499",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": [
"model = loss[\"whole_model\"]\n",
"opt = loss[\"optim\"]"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "331a49e1",
"metadata": {
"execution": {
"iopub.execute_input": "2024-03-22T18:16:21.409276Z",
"iopub.status.busy": "2024-03-22T18:16:21.408968Z",
"iopub.status.idle": "2024-03-22T18:16:21.506071Z",
"shell.execute_reply": "2024-03-22T18:16:21.505164Z"
},
"papermill": {
"duration": 0.117463,
"end_time": "2024-03-22T18:16:21.508351",
"exception": false,
"start_time": "2024-03-22T18:16:21.390888",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": [
"import torch\n",
"from copy import deepcopy\n",
"\n",
"torch.save(deepcopy(model.state_dict()), \"model.pt\")\n",
"#torch.save(deepcopy(opt.state_dict()), \"optim.pt\")"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "123b4b17",
"metadata": {
"execution": {
"iopub.execute_input": "2024-03-22T18:16:21.544277Z",
"iopub.status.busy": "2024-03-22T18:16:21.543923Z",
"iopub.status.idle": "2024-03-22T18:16:21.818454Z",
"shell.execute_reply": "2024-03-22T18:16:21.817370Z"
},
"papermill": {
"duration": 0.295571,
"end_time": "2024-03-22T18:16:21.820671",
"exception": false,
"start_time": "2024-03-22T18:16:21.525100",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"<Axes: >"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 300x300 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"history = loss[\"history\"]\n",
"history.to_csv(\"history.csv\")\n",
"history[[\"avg_loss_train\", \"avg_loss_test\"]].plot()"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "2586ba0a",
"metadata": {
"execution": {
"iopub.execute_input": "2024-03-22T18:16:21.858143Z",
"iopub.status.busy": "2024-03-22T18:16:21.857786Z",
"iopub.status.idle": "2024-03-22T18:20:40.954647Z",
"shell.execute_reply": "2024-03-22T18:20:40.953627Z"
},
"papermill": {
"duration": 259.118235,
"end_time": "2024-03-22T18:20:40.957108",
"exception": false,
"start_time": "2024-03-22T18:16:21.838873",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": [
"\n",
"from ml_utility_loss.loss_learning.estimator.pipeline import eval\n",
"#eval_loss = loss[\"eval_loss\"]\n",
"\n",
"batch_size = params[\"batch_size_low\"] if \"batch_size_low\" in params else params[\"batch_size\"]\n",
"\n",
"eval_loss = eval(\n",
" test_set, model,\n",
" batch_size=batch_size,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "187137f6",
"metadata": {
"execution": {
"iopub.execute_input": "2024-03-22T18:20:40.994176Z",
"iopub.status.busy": "2024-03-22T18:20:40.993823Z",
"iopub.status.idle": "2024-03-22T18:20:41.015411Z",
"shell.execute_reply": "2024-03-22T18:20:41.014472Z"
},
"papermill": {
"duration": 0.04219,
"end_time": "2024-03-22T18:20:41.017464",
"exception": false,
"start_time": "2024-03-22T18:20:40.975274",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>avg_g_cos_loss</th>\n",
" <th>avg_g_mag_loss</th>\n",
" <th>avg_loss</th>\n",
" <th>grad_duration</th>\n",
" <th>grad_mae</th>\n",
" <th>grad_mape</th>\n",
" <th>grad_rmse</th>\n",
" <th>mean_pred_loss</th>\n",
" <th>pred_duration</th>\n",
" <th>pred_mae</th>\n",
" <th>pred_mape</th>\n",
" <th>pred_rmse</th>\n",
" <th>pred_std</th>\n",
" <th>std_loss</th>\n",
" <th>total_duration</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>tvae</th>\n",
" <td>0.030277</td>\n",
" <td>0.045636</td>\n",
" <td>0.002767</td>\n",
" <td>12.488235</td>\n",
" <td>0.021082</td>\n",
" <td>0.482545</td>\n",
" <td>0.041617</td>\n",
" <td>0.000024</td>\n",
" <td>4.115197</td>\n",
" <td>0.0382</td>\n",
" <td>0.125889</td>\n",
" <td>0.052604</td>\n",
" <td>0.100573</td>\n",
" <td>0.008725</td>\n",
" <td>16.603431</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" avg_g_cos_loss avg_g_mag_loss avg_loss grad_duration grad_mae \\\n",
"tvae 0.030277 0.045636 0.002767 12.488235 0.021082 \n",
"\n",
" grad_mape grad_rmse mean_pred_loss pred_duration pred_mae \\\n",
"tvae 0.482545 0.041617 0.000024 4.115197 0.0382 \n",
"\n",
" pred_mape pred_rmse pred_std std_loss total_duration \n",
"tvae 0.125889 0.052604 0.100573 0.008725 16.603431 "
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"\n",
"metrics = pd.DataFrame(eval_loss[\"model_metrics\"]).T\n",
"metrics.to_csv(\"eval.csv\")\n",
"metrics"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "123d305b",
"metadata": {
"execution": {
"iopub.execute_input": "2024-03-22T18:20:41.051948Z",
"iopub.status.busy": "2024-03-22T18:20:41.051665Z",
"iopub.status.idle": "2024-03-22T18:20:41.424283Z",
"shell.execute_reply": "2024-03-22T18:20:41.423284Z"
},
"papermill": {
"duration": 0.392558,
"end_time": "2024-03-22T18:20:41.426797",
"exception": false,
"start_time": "2024-03-22T18:20:41.034239",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": [
"from ml_utility_loss.util import clear_memory\n",
"clear_memory()"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "a3eecc2a",
"metadata": {
"execution": {
"iopub.execute_input": "2024-03-22T18:20:41.464627Z",
"iopub.status.busy": "2024-03-22T18:20:41.464010Z",
"iopub.status.idle": "2024-03-22T18:25:10.980556Z",
"shell.execute_reply": "2024-03-22T18:25:10.979659Z"
},
"papermill": {
"duration": 269.538205,
"end_time": "2024-03-22T18:25:10.983174",
"exception": false,
"start_time": "2024-03-22T18:20:41.444969",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Caching in ../../../../contraceptive/_cache_aug_test/tvae/all inf False\n",
"Caching in ../../../../contraceptive/_cache_bs_test/tvae/all inf False\n",
"Caching in ../../../../contraceptive/_cache_synth_test/tvae/all inf False\n"
]
}
],
"source": [
"#\"\"\"\n",
"from ml_utility_loss.loss_learning.estimator.process import pred, pred_2\n",
"from ml_utility_loss.util import stack_samples\n",
"\n",
"#samples = test_set[list(range(len(test_set)))]\n",
"#y = {m: pred(model[m], s) for m, s in samples.items()}\n",
"y = pred_2(model, test_set, batch_size=batch_size)\n",
"#\"\"\""
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "6ab51db8",
"metadata": {
"execution": {
"iopub.execute_input": "2024-03-22T18:25:11.021429Z",
"iopub.status.busy": "2024-03-22T18:25:11.021082Z",
"iopub.status.idle": "2024-03-22T18:25:11.048484Z",
"shell.execute_reply": "2024-03-22T18:25:11.047685Z"
},
"papermill": {
"duration": 0.049077,
"end_time": "2024-03-22T18:25:11.050646",
"exception": false,
"start_time": "2024-03-22T18:25:11.001569",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": [
"import os\n",
"import pandas as pd\n",
"from ml_utility_loss.util import transpose_dict\n",
"\n",
"os.makedirs(\"pred\", exist_ok=True)\n",
"y2 = transpose_dict(y)\n",
"for k, v in y2.items():\n",
" df = pd.DataFrame(v)\n",
" df.to_csv(f\"pred/{k}.csv\")"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "d81a30f1",
"metadata": {
"execution": {
"iopub.execute_input": "2024-03-22T18:25:11.085900Z",
"iopub.status.busy": "2024-03-22T18:25:11.085593Z",
"iopub.status.idle": "2024-03-22T18:25:11.091285Z",
"shell.execute_reply": "2024-03-22T18:25:11.090350Z"
},
"papermill": {
"duration": 0.025923,
"end_time": "2024-03-22T18:25:11.093469",
"exception": false,
"start_time": "2024-03-22T18:25:11.067546",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'tvae': 0.3965440729686192}\n"
]
}
],
"source": [
"print({k: sum(v[\"pred\"])/len(v[\"pred\"]) for k, v in y.items()})"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "3b3ff322",
"metadata": {
"execution": {
"iopub.execute_input": "2024-03-22T18:25:11.130756Z",
"iopub.status.busy": "2024-03-22T18:25:11.130471Z",
"iopub.status.idle": "2024-03-22T18:25:11.518772Z",
"shell.execute_reply": "2024-03-22T18:25:11.517773Z"
},
"papermill": {
"duration": 0.409639,
"end_time": "2024-03-22T18:25:11.520956",
"exception": false,
"start_time": "2024-03-22T18:25:11.111317",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 300x300 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from ml_utility_loss.loss_learning.visualization import plot_pred_density_2\n",
"\n",
"_ = plot_pred_density_2(y)"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "e79e4b0f",
"metadata": {
"execution": {
"iopub.execute_input": "2024-03-22T18:25:11.558170Z",
"iopub.status.busy": "2024-03-22T18:25:11.557829Z",
"iopub.status.idle": "2024-03-22T18:25:11.910037Z",
"shell.execute_reply": "2024-03-22T18:25:11.908917Z"
},
"papermill": {
"duration": 0.37333,
"end_time": "2024-03-22T18:25:11.912313",
"exception": false,
"start_time": "2024-03-22T18:25:11.538983",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 300x300 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from ml_utility_loss.loss_learning.visualization import plot_density_3\n",
"\n",
"_ = plot_density_3(y2[\"pred\"], next(iter(y2[\"y\"].values())))"
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "745adde1",
"metadata": {
"execution": {
"iopub.execute_input": "2024-03-22T18:25:11.954190Z",
"iopub.status.busy": "2024-03-22T18:25:11.953310Z",
"iopub.status.idle": "2024-03-22T18:25:12.111585Z",
"shell.execute_reply": "2024-03-22T18:25:12.110496Z"
},
"papermill": {
"duration": 0.182951,
"end_time": "2024-03-22T18:25:12.114578",
"exception": false,
"start_time": "2024-03-22T18:25:11.931627",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 300x300 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from ml_utility_loss.loss_learning.visualization import plot_box_3\n",
"\n",
"_ = plot_box_3(y2[\"pred\"], next(iter(y2[\"y\"].values())))"
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "eabe1bab",
"metadata": {
"execution": {
"iopub.execute_input": "2024-03-22T18:25:12.154928Z",
"iopub.status.busy": "2024-03-22T18:25:12.154594Z",
"iopub.status.idle": "2024-03-22T18:25:12.439272Z",
"shell.execute_reply": "2024-03-22T18:25:12.438355Z"
},
"papermill": {
"duration": 0.306768,
"end_time": "2024-03-22T18:25:12.441359",
"exception": false,
"start_time": "2024-03-22T18:25:12.134591",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 300x300 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"#\"\"\"\n",
"from ml_utility_loss.loss_learning.visualization import plot_grad, plot_grad_2, plot_grad_3\n",
"import matplotlib.pyplot as plt\n",
"\n",
"#plot_grad_2(y, model.models)\n",
"for m in model.models:\n",
" ym = y[m]\n",
" fig, ax = plt.subplots()\n",
" plot_grad_3(ym[\"error\"], ym[\"grad\"], name=f\"{m}_grad\", fig=fig, ax=ax)\n",
"#\"\"\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "54c0e9f3",
"metadata": {
"papermill": {
"duration": 0.019727,
"end_time": "2024-03-22T18:25:12.481176",
"exception": false,
"start_time": "2024-03-22T18:25:12.461449",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": []
}
],
"metadata": {
"accelerator": "GPU",
"celltoolbar": "Tags",
"colab": {
"authorship_tag": "ABX9TyOOVfelovKP9fLGU7SvvRie",
"gpuType": "T4",
"mount_file_id": "17POSGAvge8y9DW9WGs2jLkibaRjToayg",
"provenance": []
},
"kaggle": {
"accelerator": "gpu",
"dataSources": [],
"dockerImageVersionId": 30648,
"isGpuEnabled": true,
"isInternetEnabled": true,
"language": "python",
"sourceType": "notebook"
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
},
"papermill": {
"default_parameters": {},
"duration": 5438.254231,
"end_time": "2024-03-22T18:25:15.223561",
"environment_variables": {},
"exception": null,
"input_path": "eval/contraceptive/tvae/42/mlu-eval.ipynb",
"output_path": "eval/contraceptive/tvae/42/mlu-eval.ipynb",
"parameters": {
"allow_same_prediction": true,
"dataset": "contraceptive",
"dataset_name": "contraceptive",
"debug": false,
"folder": "eval",
"gp": false,
"gp_multiply": false,
"log_wandb": false,
"param_index": 0,
"path": "eval/contraceptive/tvae/42",
"path_prefix": "../../../../",
"random_seed": 42,
"single_model": "tvae"
},
"start_time": "2024-03-22T16:54:36.969330",
"version": "2.5.0"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 5
} |