rizqinur2010 committed
Commit 83586b8 · 1 Parent(s): 7693c4d
This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. contraceptive/lct_gan/eval.csv +2 -0
  2. contraceptive/lct_gan/history.csv +13 -0
  3. contraceptive/lct_gan/mlu-eval.ipynb +0 -0
  4. contraceptive/lct_gan/model.pt +3 -0
  5. contraceptive/lct_gan/params.json +1 -0
  6. contraceptive/realtabformer/eval.csv +2 -0
  7. contraceptive/realtabformer/history.csv +11 -0
  8. contraceptive/realtabformer/mlu-eval.ipynb +0 -0
  9. contraceptive/realtabformer/model.pt +3 -0
  10. contraceptive/realtabformer/params.json +1 -0
  11. contraceptive/tab_ddpm_concat/eval.csv +2 -0
  12. contraceptive/tab_ddpm_concat/history.csv +12 -0
  13. contraceptive/tab_ddpm_concat/mlu-eval.ipynb +0 -0
  14. contraceptive/tab_ddpm_concat/model.pt +3 -0
  15. contraceptive/tab_ddpm_concat/params.json +1 -0
  16. contraceptive/tvae/eval.csv +2 -0
  17. contraceptive/tvae/history.csv +12 -0
  18. contraceptive/tvae/mlu-eval.ipynb +0 -0
  19. contraceptive/tvae/model.pt +3 -0
  20. contraceptive/tvae/params.json +1 -0
  21. insurance/lct_gan/eval.csv +2 -0
  22. insurance/lct_gan/history.csv +21 -0
  23. insurance/lct_gan/mlu-eval.ipynb +0 -0
  24. insurance/lct_gan/model.pt +3 -0
  25. insurance/lct_gan/params.json +1 -0
  26. insurance/realtabformer/eval.csv +2 -0
  27. insurance/realtabformer/history.csv +17 -0
  28. insurance/realtabformer/mlu-eval.ipynb +0 -0
  29. insurance/realtabformer/model.pt +3 -0
  30. insurance/realtabformer/params.json +1 -0
  31. insurance/tab_ddpm_concat/eval.csv +2 -0
  32. insurance/tab_ddpm_concat/history.csv +23 -0
  33. insurance/tab_ddpm_concat/mlu-eval.ipynb +0 -0
  34. insurance/tab_ddpm_concat/model.pt +3 -0
  35. insurance/tab_ddpm_concat/params.json +1 -0
  36. insurance/tvae/eval.csv +2 -0
  37. insurance/tvae/history.csv +19 -0
  38. insurance/tvae/mlu-eval.ipynb +0 -0
  39. insurance/tvae/model.pt +3 -0
  40. insurance/tvae/params.json +1 -0
  41. treatment/lct_gan/eval.csv +2 -0
  42. treatment/lct_gan/history.csv +11 -0
  43. treatment/lct_gan/mlu-eval.ipynb +0 -0
  44. treatment/lct_gan/model.pt +3 -0
  45. treatment/lct_gan/params.json +1 -0
  46. treatment/realtabformer/eval.csv +2 -0
  47. treatment/realtabformer/history.csv +6 -0
  48. treatment/realtabformer/mlu-eval.ipynb +0 -0
  49. treatment/realtabformer/model.pt +3 -0
  50. treatment/realtabformer/params.json +1 -0
contraceptive/lct_gan/eval.csv ADDED
@@ -0,0 +1,2 @@
+ ,avg_g_cos_loss,avg_g_mag_loss,avg_loss,grad_duration,grad_mae,grad_mape,grad_rmse,mean_pred_loss,pred_duration,pred_mae,pred_mape,pred_rmse,pred_std,std_loss,total_duration
+ lct_gan,0.03109242501225145,0.09055740089827473,0.0031398469558362625,12.489086151123047,0.04181426391005516,1.1387050151824951,0.15034998953342438,3.8336263969540596e-05,3.991712808609009,0.04172290489077568,0.13467051088809967,0.05603433772921562,0.09302742779254913,0.02486800216138363,16.480798959732056
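Each eval.csv added in this commit is a one-row summary keyed by the synthesizer name, with loss, gradient-error (grad_mae/grad_mape/grad_rmse), prediction-error, and timing columns. A minimal sketch, assuming pandas and a local checkout of this repo, of how such a file could be inspected; the path and column selection are illustrative only:

```python
import pandas as pd

# Hypothetical local path; adjust to wherever this repo is checked out.
df = pd.read_csv("contraceptive/lct_gan/eval.csv", index_col=0)

# Each eval.csv holds a single row keyed by the synthesizer name
# (here "lct_gan"), with loss, gradient-error, and timing columns.
row = df.loc["lct_gan"]
print(row[["avg_loss", "pred_rmse", "grad_rmse", "total_duration"]])
```

The same pattern applies to the other eval.csv files in this commit; only the path and the index value change.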
contraceptive/lct_gan/history.csv ADDED
@@ -0,0 +1,13 @@
+ ,avg_role_model_loss_train,avg_role_model_std_loss_train,avg_role_model_mean_pred_loss_train,avg_role_model_g_mag_loss_train,avg_role_model_g_cos_loss_train,avg_non_role_model_g_mag_loss_train,avg_non_role_model_g_cos_loss_train,avg_non_role_model_embed_loss_train,avg_loss_train,n_size_train,n_batch_train,duration_train,duration_batch_train,duration_size_train,avg_pred_std_train,avg_role_model_loss_test,avg_role_model_std_loss_test,avg_role_model_mean_pred_loss_test,avg_role_model_g_mag_loss_test,avg_role_model_g_cos_loss_test,avg_non_role_model_g_mag_loss_test,avg_non_role_model_g_cos_loss_test,avg_non_role_model_embed_loss_test,avg_loss_test,n_size_test,n_batch_test,duration_test,duration_batch_test,duration_size_test,avg_pred_std_test
+ 0,0.017055697603565123,0.7837836365105371,0.000854024517475194,0.0,0.0,0.0,0.0,0.0,0.017055697603565123,900,225,209.5561990737915,0.9313608847724066,0.23284022119310166,0.11550631119145287,0.02723778669618898,0.35535829481002906,0.0011922009129072346,0.0,0.0,0.0,0.0,0.0,0.02723778669618898,450,113,91.00356817245483,0.8053413112606622,0.20223015149434406,0.12002328037391458
+ 1,0.007602974076354359,0.6694789405275434,0.00012108089975418718,0.0,0.0,0.0,0.0,0.0,0.007602974076354359,900,225,209.41460299491882,0.9307315688663059,0.23268289221657648,0.0941959698839734,0.004413039641092635,0.45001505491487465,4.256673724010495e-05,0.0,0.0,0.0,0.0,0.0,0.004413039641092635,450,113,91.49867558479309,0.8097227927857795,0.20333039018842908,0.06237377326902563
+ 2,0.005956742855616742,0.6394465912903183,6.889980071151411e-05,0.0,0.0,0.0,0.0,0.0,0.005956742855616742,900,225,209.05843949317932,0.9291486199696859,0.23228715499242147,0.09135436112475064,0.004399605713939915,1.215607570140407,2.922537710473547e-05,0.0,0.0,0.0,0.0,0.0,0.004399605713939915,450,113,88.99790143966675,0.7875920481386438,0.19777311431037056,0.04754088749589844
+ 3,0.005293925739824772,0.6930031799355281,5.091876555952298e-05,0.0,0.0,0.0,0.0,0.0,0.005293925739824772,900,225,209.8546106815338,0.9326871585845947,0.23317178964614868,0.09368091402575374,0.005122498869895935,0.6850604875349726,0.0001063596828682662,0.0,0.0,0.0,0.0,0.0,0.005122498869895935,450,113,90.19163846969604,0.7981560926521774,0.2004258632659912,0.06673125488749515
+ 4,0.004406307417407839,0.5391749190130491,3.865118440087487e-05,0.0,0.0,0.0,0.0,0.0,0.004406307417407839,900,225,208.6917781829834,0.9275190141465929,0.23187975353664822,0.0995840290437142,0.005340991177492671,0.6295711460246051,7.291253590576869e-05,0.0,0.0,0.0,0.0,0.0,0.005340991177492671,450,113,91.7313449382782,0.8117818136130814,0.20384743319617377,0.06550657832418132
+ 5,0.00469050889802424,0.9336842445089064,3.8414434089242215e-05,0.0,0.0,0.0,0.0,0.0,0.00469050889802424,900,225,208.08384490013123,0.9248170884450276,0.2312042721112569,0.09579447591263388,0.004810444195496125,0.3487526955594519,8.624510295827805e-05,0.0,0.0,0.0,0.0,0.0,0.004810444195496125,450,113,90.39807367324829,0.7999829528606044,0.20088460816277398,0.0735004077937487
+ 6,0.005119060436975107,0.6433215488478984,6.873547569881813e-05,0.0,0.0,0.0,0.0,0.0,0.005119060436975107,900,225,207.70152282714844,0.9231178792317708,0.2307794698079427,0.09771538318652245,0.004957494798161659,1.346057123426031,5.000207313663115e-05,0.0,0.0,0.0,0.0,0.0,0.004957494798161659,450,113,90.1600124835968,0.7978762166689982,0.20035558329688177,0.06382843967428249
+ 7,0.00532732381252572,0.4333305762650606,5.976084272669491e-05,0.0,0.0,0.0,0.0,0.0,0.00532732381252572,900,225,207.5520989894867,0.9224537732866075,0.23061344332165187,0.09995691900038058,0.0050655570465864405,1.4249071690357547,4.881459755913574e-05,0.0,0.0,0.0,0.0,0.0,0.0050655570465864405,450,113,90.09133553504944,0.7972684560623844,0.2002029678556654,0.05493936902612646
+ 8,0.004071665801651155,0.3260331416654134,2.1085142560956504e-05,0.0,0.0,0.0,0.0,0.0,0.004071665801651155,900,225,209.66889786720276,0.931861768298679,0.23296544207466974,0.10195785622629855,0.005279655439727422,0.9929788724251226,6.48734679797379e-05,0.0,0.0,0.0,0.0,0.0,0.005279655439727422,450,113,91.38696694374084,0.8087342207410694,0.20308214876386854,0.06704739362940984
+ 9,0.004601976428077453,0.3923278355509547,4.1437573981469074e-05,0.0,0.0,0.0,0.0,0.0,0.004601976428077453,900,225,202.2607822418213,0.8989368099636502,0.22473420249091255,0.1009762748144567,0.003612730009287285,1.519655740890833,3.1731855463335754e-05,0.0,0.0,0.0,0.0,0.0,0.003612730009287285,450,113,86.41371130943298,0.7647231089330353,0.19203046957651773,0.05825823241436805
+ 10,0.003639938447223459,0.2671018096537945,2.6972246297165055e-05,0.0,0.0,0.0,0.0,0.0,0.003639938447223459,900,225,204.09564805030823,0.907091769112481,0.22677294227812025,0.10696077811221281,0.002935796806381808,1.918275244656198,1.3822624714984066e-05,0.0,0.0,0.0,0.0,0.0,0.002935796806381808,450,113,87.83020257949829,0.7772584299070645,0.19517822795444065,0.04963276007591821
+ 11,0.0033298865797243907,0.48618020502341713,2.140108548165483e-05,0.0,0.0,0.0,0.0,0.0,0.0033298865797243907,900,225,201.81425046920776,0.89695222430759,0.2242380560768975,0.10155585827512874,0.002946255624992773,1.365417978676351,1.4400359725914503e-05,0.0,0.0,0.0,0.0,0.0,0.002946255624992773,450,113,87.90818929672241,0.777948577847101,0.19535153177049425,0.05489195802813577
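The history.csv files record one row per training epoch, with parallel *_train and *_test columns (losses, batch counts, durations, prediction std). A minimal sketch, again assuming pandas and a local checkout, of reading one back; the selected columns are just examples taken from the header row above:

```python
import pandas as pd

# Hypothetical path; the unnamed first column is the epoch index.
hist = pd.read_csv("contraceptive/lct_gan/history.csv", index_col=0)

# Columns follow a <metric>_train / <metric>_test naming pattern; this pulls
# the average role-model loss per epoch for both splits.
print(hist[["avg_role_model_loss_train", "avg_role_model_loss_test"]])
```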
contraceptive/lct_gan/mlu-eval.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
contraceptive/lct_gan/model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bf30df9b8181b82f551f57bee82a280b3f49fc96eb848404aabd6765c121091e
+ size 47605515
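The model.pt entries are Git LFS pointer files, not the weights themselves; only the SHA-256 and byte size are committed here. A hedged sketch, assuming the real checkpoint has been fetched via Git LFS and is a standard PyTorch-serialized file (its exact contents are not stated in this commit):

```python
import torch

# Assumes `git lfs pull` (or equivalent) has already replaced the pointer
# file with the real ~47 MB checkpoint. Whether it holds a state dict or a
# pickled module is not specified here, so inspect the loaded object first.
ckpt = torch.load("contraceptive/lct_gan/model.pt", map_location="cpu")
print(type(ckpt))
```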
contraceptive/lct_gan/params.json ADDED
@@ -0,0 +1 @@
+ {"Body": "twin_encoder", "loss_balancer_meta": true, "loss_balancer_log": false, "loss_balancer_lbtw": false, "pma_skip_small": false, "isab_skip_small": false, "layer_norm": false, "pma_layer_norm": false, "attn_residual": true, "tf_n_layers_dec": false, "tf_isab_rank": 0, "tf_lora": false, "tf_layer_norm": false, "tf_pma_start": -1, "ada_n_seeds": 0, "head_n_seeds": 0, "tf_pma_low": 16, "gradient_penalty_kwargs": {"mag_loss": true, "mse_mag": true, "mag_corr": false, "seq_mag": false, "cos_loss": false, "mse_mag_kwargs": {"target": 1.0, "multiply": true, "forgive_over": true}, "mag_corr_kwargs": {"only_sign": false}, "cos_loss_kwargs": {"only_sign": true, "cos_matrix": false}}, "dropout": 0, "combine_mode": "diff_left", "tf_isab_mode": "separate", "grad_loss_fn": "mae", "single_model": true, "bias": true, "bias_final": true, "pma_ffn_mode": "none", "patience": 10, "inds_init_mode": "fixnorm", "grad_clip": 0.73, "gradient_penalty_mode": "NONE", "synth_data": 2, "bias_lr_mul": 1.0, "bias_weight_decay": 0.05, "head_activation": "softsign", "loss_balancer_beta": 0.67, "loss_balancer_r": 0.943, "tf_activation": "tanh", "dataset_size": 2048, "batch_size": 4, "epochs": 100, "lr_mul": 0.09, "n_warmup_steps": 100, "Optim": "amsgradw", "fixed_role_model": "lct_gan", "mse_mag": false, "mse_mag_target": 0.65, "mse_mag_multiply": true, "d_model": 256, "attn_activation": "prelu", "tf_d_inner": 512, "tf_n_layers_enc": 3, "tf_n_head": 32, "tf_activation_final": "leakyhardtanh", "tf_num_inds": 128, "ada_d_hid": 1024, "ada_n_layers": 9, "ada_activation": "softsign", "ada_activation_final": "leakyhardsigmoid", "head_d_hid": 256, "head_n_layers": 9, "head_n_head": 32, "head_activation_final": "leakyhardsigmoid", "models": ["lct_gan"], "max_seconds": 3600}
contraceptive/realtabformer/eval.csv ADDED
@@ -0,0 +1,2 @@
+ ,avg_g_cos_loss,avg_g_mag_loss,avg_loss,grad_duration,grad_mae,grad_mape,grad_rmse,mean_pred_loss,pred_duration,pred_mae,pred_mape,pred_rmse,pred_std,std_loss,total_duration
+ realtabformer,0.01376279273286768,0.006844550730648704,0.002682759171099557,8.117686986923218,0.21929873526096344,5.288559913635254,0.37969863414764404,2.591257907624822e-05,8.684867143630981,0.037153489887714386,0.12090718746185303,0.0517953597009182,0.09103206545114517,0.03082490712404251,16.8025541305542
contraceptive/realtabformer/history.csv ADDED
@@ -0,0 +1,11 @@
+ ,avg_role_model_loss_train,avg_role_model_std_loss_train,avg_role_model_mean_pred_loss_train,avg_role_model_g_mag_loss_train,avg_role_model_g_cos_loss_train,avg_non_role_model_g_mag_loss_train,avg_non_role_model_g_cos_loss_train,avg_non_role_model_embed_loss_train,avg_loss_train,n_size_train,n_batch_train,duration_train,duration_batch_train,duration_size_train,avg_pred_std_train,avg_role_model_loss_test,avg_role_model_std_loss_test,avg_role_model_mean_pred_loss_test,avg_role_model_g_mag_loss_test,avg_role_model_g_cos_loss_test,avg_non_role_model_g_mag_loss_test,avg_non_role_model_g_cos_loss_test,avg_non_role_model_embed_loss_test,avg_loss_test,n_size_test,n_batch_test,duration_test,duration_batch_test,duration_size_test,avg_pred_std_test
+ 0,0.015436244466497252,0.6971071543745386,0.0005885604529329952,0.0,0.0,0.0,0.0,0.0,0.015436244466497252,900,225,233.85006093978882,1.0393336041768393,0.2598334010442098,0.1107698762230575,0.014223382426425814,0.39072058395795967,0.0007193372912492969,0.0,0.0,0.0,0.0,0.0,0.014223382426425814,450,113,101.00125312805176,0.8938163993632899,0.22444722917344834,0.1077716978992113
+ 1,0.006033640582399029,0.4650218232954108,0.00013023597284128328,0.0,0.0,0.0,0.0,0.0,0.006033640582399029,900,225,229.77366161346436,1.0212162738376194,0.25530406845940484,0.09818139906144804,0.0075027576179450585,0.3676717987151777,0.00021706114066001486,0.0,0.0,0.0,0.0,0.0,0.0075027576179450585,450,113,105.09904932975769,0.9300800825642274,0.2335534429550171,0.09515106879932954
+ 2,0.004512168372780757,0.38041619546195704,8.644590021956824e-05,0.0,0.0,0.0,0.0,0.0,0.004512168372780757,900,225,229.63529801368713,1.020601324505276,0.255150331126319,0.10015762750473288,0.0042034335448018585,1.679159187696908,0.00013376625657448635,0.0,0.0,0.0,0.0,0.0,0.0042034335448018585,450,113,100.8151388168335,0.8921693700604734,0.22403364181518554,0.06495650028253584
+ 3,0.0035539452421491863,0.47440502540713714,2.18277604056823e-05,0.0,0.0,0.0,0.0,0.0,0.0035539452421491863,900,225,229.67535948753357,1.0207793755001493,0.2551948438750373,0.10116369463089439,0.0038588620393743946,3.2039021881799323,5.9090333590049076e-05,0.0,0.0,0.0,0.0,0.0,0.0038588620393743946,450,113,102.2182068824768,0.904585901614839,0.22715157084994847,0.04514441881030056
+ 4,0.00319372646672289,0.37545427183122226,1.550665479323714e-05,0.0,0.0,0.0,0.0,0.0,0.00319372646672289,900,225,230.0189757347107,1.0223065588209364,0.2555766397052341,0.10247180342260334,0.0040023055252256905,3.8167529671252023,7.529611734713794e-05,0.0,0.0,0.0,0.0,0.0,0.0040023055252256905,450,113,101.59934687614441,0.8991092643906585,0.22577632639143203,0.04470131049484872
+ 5,0.0029433389956506693,0.4368025242733054,1.580498963409443e-05,0.0,0.0,0.0,0.0,0.0,0.0029433389956506693,900,225,229.80718541145325,1.0213652684953478,0.25534131712383695,0.1031194214626319,0.0028074885386094035,2.145103556606243,1.1650951160528297e-05,0.0,0.0,0.0,0.0,0.0,0.0028074885386094035,450,113,101.39577078819275,0.8973077060902013,0.22532393508487278,0.05228710767763576
+ 6,0.0026049669001885277,0.44792478884607234,1.0308605452903701e-05,0.0,0.0,0.0,0.0,0.0,0.0026049669001885277,900,225,230.53775358200073,1.0246122381422256,0.2561530595355564,0.10098279579128656,0.002687264719667534,4.189337836805609,1.241189197620803e-05,0.0,0.0,0.0,0.0,0.0,0.002687264719667534,450,113,100.64273428916931,0.8906436662758346,0.22365052064259847,0.042642332664979375
+ 7,0.002563537800257715,0.34648435719401377,9.35702638039536e-06,0.0,0.0,0.0,0.0,0.0,0.002563537800257715,900,225,227.69733333587646,1.011988148159451,0.25299703703986276,0.10545726788747642,0.0028965300656919784,3.1744629383202634,1.6787577942173726e-05,0.0,0.0,0.0,0.0,0.0,0.0028965300656919784,450,113,100.22019076347351,0.8869043430395885,0.22271153502994112,0.05538224283527517
+ 8,0.002383246352579186,0.3083862728209352,9.936972519990587e-06,0.0,0.0,0.0,0.0,0.0,0.002383246352579186,900,225,227.7537636756897,1.012238949669732,0.253059737417433,0.10515566083292166,0.0030194523966767723,3.0041199968275354,1.1394547741565071e-05,0.0,0.0,0.0,0.0,0.0,0.0030194523966767723,450,113,100.15035581588745,0.8862863346538713,0.22255634625752768,0.04854994764514432
+ 9,0.0021668040852819105,0.2028817977175666,7.674208641294594e-06,0.0,0.0,0.0,0.0,0.0,0.0021668040852819105,900,225,226.46019649505615,1.0064897622002495,0.2516224405500624,0.10634001894543568,0.0023035917821754184,3.5878168070930125,7.551760705609701e-06,0.0,0.0,0.0,0.0,0.0,0.0023035917821754184,450,113,100.57623744010925,0.8900551985850377,0.22350274986690946,0.05207151871749439
contraceptive/realtabformer/mlu-eval.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
contraceptive/realtabformer/model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9fedab6ff24590104807dd04fe22f394cb4a5e57d14949b65f1a9438d41d21a1
+ size 50388737
contraceptive/realtabformer/params.json ADDED
@@ -0,0 +1 @@
+ {"Body": "twin_encoder", "loss_balancer_meta": true, "loss_balancer_log": false, "loss_balancer_lbtw": false, "pma_skip_small": false, "isab_skip_small": false, "layer_norm": false, "pma_layer_norm": false, "attn_residual": true, "tf_n_layers_dec": false, "tf_isab_rank": 0, "tf_lora": false, "tf_layer_norm": false, "tf_pma_start": -1, "ada_n_seeds": 0, "head_n_seeds": 0, "tf_pma_low": 16, "gradient_penalty_kwargs": {"mag_loss": true, "mse_mag": true, "mag_corr": false, "seq_mag": false, "cos_loss": false, "mse_mag_kwargs": {"target": 1.0, "multiply": true, "forgive_over": true}, "mag_corr_kwargs": {"only_sign": false}, "cos_loss_kwargs": {"only_sign": true, "cos_matrix": false}}, "dropout": 0, "combine_mode": "diff_left", "tf_isab_mode": "separate", "grad_loss_fn": "mae", "single_model": true, "bias": true, "bias_final": true, "pma_ffn_mode": "none", "patience": 10, "inds_init_mode": "fixnorm", "grad_clip": 0.73, "gradient_penalty_mode": "NONE", "synth_data": 2, "bias_lr_mul": 1.0, "bias_weight_decay": 0.05, "head_activation": "softsign", "loss_balancer_beta": 0.67, "loss_balancer_r": 0.943, "tf_activation": "tanh", "dataset_size": 2048, "batch_size": 4, "epochs": 100, "lr_mul": 0.09, "n_warmup_steps": 100, "Optim": "amsgradw", "fixed_role_model": "realtabformer", "mse_mag": false, "mse_mag_target": 0.65, "mse_mag_multiply": true, "d_model": 256, "attn_activation": "prelu", "tf_d_inner": 512, "tf_n_layers_enc": 3, "tf_n_head": 32, "tf_activation_final": "leakyhardtanh", "tf_num_inds": 128, "ada_d_hid": 1024, "ada_n_layers": 9, "ada_activation": "softsign", "ada_activation_final": "leakyhardsigmoid", "head_d_hid": 256, "head_n_layers": 9, "head_n_head": 32, "head_activation_final": "leakyhardsigmoid", "models": ["realtabformer"], "max_seconds": 3600}
contraceptive/tab_ddpm_concat/eval.csv ADDED
@@ -0,0 +1,2 @@
+ ,avg_g_cos_loss,avg_g_mag_loss,avg_loss,grad_duration,grad_mae,grad_mape,grad_rmse,mean_pred_loss,pred_duration,pred_mae,pred_mape,pred_rmse,pred_std,std_loss,total_duration
+ tab_ddpm_concat,0.009295055350909631,0.044223827542087174,0.0025627935946630756,12.520719766616821,0.03127303719520569,0.8461189270019531,0.043997641652822495,6.7261853473610245e-06,3.9416027069091797,0.037461522966623306,0.11356883496046066,0.05062403902411461,0.10244974493980408,0.00610949145630002,16.462322473526
contraceptive/tab_ddpm_concat/history.csv ADDED
@@ -0,0 +1,12 @@
+ ,avg_role_model_loss_train,avg_role_model_std_loss_train,avg_role_model_mean_pred_loss_train,avg_role_model_g_mag_loss_train,avg_role_model_g_cos_loss_train,avg_non_role_model_g_mag_loss_train,avg_non_role_model_g_cos_loss_train,avg_non_role_model_embed_loss_train,avg_loss_train,n_size_train,n_batch_train,duration_train,duration_batch_train,duration_size_train,avg_pred_std_train,avg_role_model_loss_test,avg_role_model_std_loss_test,avg_role_model_mean_pred_loss_test,avg_role_model_g_mag_loss_test,avg_role_model_g_cos_loss_test,avg_non_role_model_g_mag_loss_test,avg_non_role_model_g_cos_loss_test,avg_non_role_model_embed_loss_test,avg_loss_test,n_size_test,n_batch_test,duration_test,duration_batch_test,duration_size_test,avg_pred_std_test
+ 0,0.02347481836991695,0.5124517322441693,0.002339297849103639,0.0,0.0,0.0,0.0,0.0,0.02347481836991695,900,225,223.80061292648315,0.9946693907843696,0.2486673476960924,0.13361252259876993,0.008576864742596323,0.43590039322429097,0.0002531669465378804,0.0,0.0,0.0,0.0,0.0,0.008576864742596323,450,113,98.30213046073914,0.8699303580596384,0.21844917880164252,0.06112736243490888
+ 1,0.008378146543917763,0.5896725166847361,0.00015907265293133425,0.0,0.0,0.0,0.0,0.0,0.008378146543917763,900,225,223.44736337661743,0.9930993927849664,0.2482748481962416,0.10096969366073609,0.004039558287994522,1.381628237529861,3.3886849553694544e-05,0.0,0.0,0.0,0.0,0.0,0.004039558287994522,450,113,98.5942611694336,0.8725155855702088,0.21909835815429687,0.052547984426857625
+ 2,0.010694745298888949,0.8260207763179106,0.00029481776721037526,0.0,0.0,0.0,0.0,0.0,0.010694745298888949,900,225,222.78780055046082,0.9901680024464925,0.24754200061162313,0.10044084888158573,0.006571635709972017,0.6343476063295598,7.651879237107392e-05,0.0,0.0,0.0,0.0,0.0,0.006571635709972017,450,113,98.61985659599304,0.8727420937698499,0.21915523687998453,0.0684763636147455
+ 3,0.007882197938549022,0.5442602433114256,0.00010187150954560703,0.0,0.0,0.0,0.0,0.0,0.007882197938549022,900,225,222.28374481201172,0.9879277547200521,0.24698193868001303,0.10185433081868622,0.010345627184336384,0.40530077661136166,0.00020547013780565572,0.0,0.0,0.0,0.0,0.0,0.010345627184336384,450,113,97.98167514801025,0.8670944703363739,0.21773705588446723,0.0684052016874528
+ 4,0.004765987225756463,0.33747415092830846,4.3614522179710367e-05,0.0,0.0,0.0,0.0,0.0,0.004765987225756463,900,225,222.1056261062622,0.987136116027832,0.246784029006958,0.10019376738617818,0.004264881536364555,1.3016880564829196,2.9635327498389935e-05,0.0,0.0,0.0,0.0,0.0,0.004264881536364555,450,113,99.50796246528625,0.880601437745896,0.2211288054784139,0.06006413273538042
+ 5,0.004931693592548577,0.49496390551068475,4.218088193112541e-05,0.0,0.0,0.0,0.0,0.0,0.004931693592548577,900,225,224.20910120010376,0.9964848942226834,0.24912122355567085,0.0999689629011684,0.004869184112176299,3.0701807380044164,3.4871858414750014e-05,0.0,0.0,0.0,0.0,0.0,0.004869184112176299,450,113,98.850257396698,0.8747810389088319,0.21966723865932888,0.057760195345082116
+ 6,0.004765996111996679,0.4421812962271803,3.558628672912029e-05,0.0,0.0,0.0,0.0,0.0,0.004765996111996679,900,225,223.28885769844055,0.9923949231041802,0.24809873077604505,0.10080746096455388,0.004297750624181289,2.2529975814067793,3.2225562565454145e-05,0.0,0.0,0.0,0.0,0.0,0.004297750624181289,450,113,98.77103543281555,0.8740799595824386,0.21949118985070123,0.05614519555768641
+ 7,0.0035699382384255943,0.3529214430276707,1.8523309853427525e-05,0.0,0.0,0.0,0.0,0.0,0.0035699382384255943,900,225,222.1509132385254,0.987337392171224,0.246834348042806,0.10173722859472037,0.0035803620515636996,2.1854055460992443,2.0855983310910675e-05,0.0,0.0,0.0,0.0,0.0,0.0035803620515636996,450,113,98.14665222167969,0.8685544444396432,0.21810367160373265,0.05454383128058923
+ 8,0.0035860408845150636,0.2396017117707293,2.2730216006070548e-05,0.0,0.0,0.0,0.0,0.0,0.0035860408845150636,900,225,221.9519350528717,0.9864530446794298,0.24661326116985746,0.10408703482813306,0.0036512146290624513,1.5390193092272952,3.486896375240715e-05,0.0,0.0,0.0,0.0,0.0,0.0036512146290624513,450,113,98.24018931388855,0.8693822063175978,0.2183115318086412,0.0619035135116016
+ 9,0.0029799208378406346,0.2575948480180224,1.6468858607504945e-05,0.0,0.0,0.0,0.0,0.0,0.0029799208378406346,900,225,222.18345999717712,0.9874820444318984,0.2468705111079746,0.10268474409770634,0.00386869330644711,2.897202919005004,3.095258879357379e-05,0.0,0.0,0.0,0.0,0.0,0.00386869330644711,450,113,98.27193331718445,0.8696631267007473,0.21838207403818766,0.06031067110495065
+ 10,0.002677944270108128,0.2724713624029512,1.7092457439067723e-05,0.0,0.0,0.0,0.0,0.0,0.002677944270108128,900,225,222.04818081855774,0.9868808036380344,0.2467202009095086,0.10342358542606235,0.002709867898762847,3.8843037776192246,8.260416360805687e-06,0.0,0.0,0.0,0.0,0.0,0.002709867898762847,450,113,98.23442721366882,0.8693312142802551,0.21829872714148627,0.045337541868489865
contraceptive/tab_ddpm_concat/mlu-eval.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
contraceptive/tab_ddpm_concat/model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fd9a78df91e76837706914b0b3878144b373cc365fa95fcfaac8d1f26709426c
+ size 47482955
contraceptive/tab_ddpm_concat/params.json ADDED
@@ -0,0 +1 @@
+ {"Body": "twin_encoder", "loss_balancer_meta": true, "loss_balancer_log": false, "loss_balancer_lbtw": false, "pma_skip_small": false, "isab_skip_small": false, "layer_norm": false, "pma_layer_norm": false, "attn_residual": true, "tf_n_layers_dec": false, "tf_isab_rank": 0, "tf_lora": false, "tf_layer_norm": false, "tf_pma_start": -1, "ada_n_seeds": 0, "head_n_seeds": 0, "tf_pma_low": 16, "gradient_penalty_kwargs": {"mag_loss": true, "mse_mag": true, "mag_corr": false, "seq_mag": false, "cos_loss": false, "mse_mag_kwargs": {"target": 1.0, "multiply": true, "forgive_over": true}, "mag_corr_kwargs": {"only_sign": false}, "cos_loss_kwargs": {"only_sign": true, "cos_matrix": false}}, "dropout": 0, "combine_mode": "diff_left", "tf_isab_mode": "separate", "grad_loss_fn": "mae", "single_model": true, "bias": true, "bias_final": true, "pma_ffn_mode": "none", "patience": 10, "inds_init_mode": "fixnorm", "grad_clip": 0.73, "gradient_penalty_mode": "NONE", "synth_data": 2, "bias_lr_mul": 1.0, "bias_weight_decay": 0.05, "head_activation": "softsign", "loss_balancer_beta": 0.67, "loss_balancer_r": 0.943, "tf_activation": "tanh", "dataset_size": 2048, "batch_size": 4, "epochs": 100, "lr_mul": 0.09, "n_warmup_steps": 100, "Optim": "amsgradw", "fixed_role_model": "tab_ddpm_concat", "mse_mag": false, "mse_mag_target": 0.65, "mse_mag_multiply": true, "d_model": 256, "attn_activation": "prelu", "tf_d_inner": 512, "tf_n_layers_enc": 3, "tf_n_head": 32, "tf_activation_final": "leakyhardtanh", "tf_num_inds": 128, "ada_d_hid": 1024, "ada_n_layers": 9, "ada_activation": "softsign", "ada_activation_final": "leakyhardsigmoid", "head_d_hid": 256, "head_n_layers": 9, "head_n_head": 32, "head_activation_final": "leakyhardsigmoid", "models": ["tab_ddpm_concat"], "max_seconds": 3600}
contraceptive/tvae/eval.csv ADDED
@@ -0,0 +1,2 @@
+ ,avg_g_cos_loss,avg_g_mag_loss,avg_loss,grad_duration,grad_mae,grad_mape,grad_rmse,mean_pred_loss,pred_duration,pred_mae,pred_mape,pred_rmse,pred_std,std_loss,total_duration
+ tvae,0.030277189869137557,0.04563582819014448,0.0027671899940053535,12.488234519958496,0.021082285791635513,0.4825451970100403,0.04161680117249489,2.434007910778746e-05,4.115196704864502,0.03820033371448517,0.1258891224861145,0.05260408669710159,0.10057253390550613,0.008725106716156006,16.603431224822998
contraceptive/tvae/history.csv ADDED
@@ -0,0 +1,12 @@
+ ,avg_role_model_loss_train,avg_role_model_std_loss_train,avg_role_model_mean_pred_loss_train,avg_role_model_g_mag_loss_train,avg_role_model_g_cos_loss_train,avg_non_role_model_g_mag_loss_train,avg_non_role_model_g_cos_loss_train,avg_non_role_model_embed_loss_train,avg_loss_train,n_size_train,n_batch_train,duration_train,duration_batch_train,duration_size_train,avg_pred_std_train,avg_role_model_loss_test,avg_role_model_std_loss_test,avg_role_model_mean_pred_loss_test,avg_role_model_g_mag_loss_test,avg_role_model_g_cos_loss_test,avg_non_role_model_g_mag_loss_test,avg_non_role_model_g_cos_loss_test,avg_non_role_model_embed_loss_test,avg_loss_test,n_size_test,n_batch_test,duration_test,duration_batch_test,duration_size_test,avg_pred_std_test
+ 0,0.020646917774962883,0.7029616862738138,0.0014198488189189598,0.0,0.0,0.0,0.0,0.0,0.020646917774962883,900,225,211.21505284309387,0.9387335681915283,0.23468339204788208,0.12668553197549448,0.00800067097414285,0.8323721770502058,0.00025833536566152146,0.0,0.0,0.0,0.0,0.0,0.00800067097414285,450,113,91.81465721130371,0.8125190903655196,0.20403257158067492,0.07484065717399384
+ 1,0.007335189075060447,0.5605360431658958,0.00016509135985034204,0.0,0.0,0.0,0.0,0.0,0.007335189075060447,900,225,211.18131518363953,0.9385836230383979,0.23464590575959948,0.09665914196934965,0.005770373049502572,1.328719814272311,0.0001455103024895009,0.0,0.0,0.0,0.0,0.0,0.005770373049502572,450,113,92.31423711776733,0.8169401514846667,0.20514274915059408,0.06159803802889269
+ 2,0.004101434572932905,0.40740934907574267,4.4437185304309564e-05,0.0,0.0,0.0,0.0,0.0,0.004101434572932905,900,225,211.43363165855408,0.9397050295935737,0.23492625739839343,0.10036693361898263,0.005793615489771279,2.93090469723272,7.996645985781352e-05,0.0,0.0,0.0,0.0,0.0,0.005793615489771279,450,113,92.21822166442871,0.8160904572073338,0.20492938147650824,0.0686845871407654
+ 3,0.0032525908350200753,0.45932424604433697,1.5064653623929553e-05,0.0,0.0,0.0,0.0,0.0,0.0032525908350200753,900,225,210.6352882385254,0.9361568366156684,0.2340392091539171,0.09941185830367937,0.0025840073977209006,2.9610207560058215,8.98458365437745e-06,0.0,0.0,0.0,0.0,0.0,0.0025840073977209006,450,113,91.6226761341095,0.81082014277973,0.20360594696468778,0.04679510168797147
+ 4,0.0028994390259807308,0.27792873476118807,1.2238399814356409e-05,0.0,0.0,0.0,0.0,0.0,0.0028994390259807308,900,225,210.90299940109253,0.9373466640048557,0.23433666600121392,0.10629830273903078,0.004226378290137897,3.139442252464839,2.8896792241464515e-05,0.0,0.0,0.0,0.0,0.0,0.004226378290137897,450,113,92.30420923233032,0.8168514091356666,0.20512046496073405,0.04041025609363167
+ 5,0.0029357373788823477,0.35013624048834924,1.5252040759503315e-05,0.0,0.0,0.0,0.0,0.0,0.0029357373788823477,900,225,210.8098328113556,0.9369325902726915,0.23423314756817287,0.10065761231092943,0.003366937771077371,2.7330855187934775,2.1526505657937356e-05,0.0,0.0,0.0,0.0,0.0,0.003366937771077371,450,113,92.45483732223511,0.8181844010817266,0.20545519404941134,0.048643345435137604
+ 6,0.0025804581021010463,0.33193543293685124,1.1130432750345e-05,0.0,0.0,0.0,0.0,0.0,0.0025804581021010463,900,225,210.82577991485596,0.9370034662882487,0.23425086657206218,0.10346181529677577,0.0029045950072920986,4.791847205054414,1.7926797626652248e-05,0.0,0.0,0.0,0.0,0.0,0.0029045950072920986,450,113,92.71966814994812,0.8205280367252046,0.20604370699988472,0.04583418705378066
+ 7,0.0024444010488999385,0.27924930442016427,1.0806890312169106e-05,0.0,0.0,0.0,0.0,0.0,0.0024444010488999385,900,225,211.59207558631897,0.9404092248280843,0.23510230620702108,0.10225464255238573,0.002720622533104486,3.2791891266000865,1.838483538128186e-05,0.0,0.0,0.0,0.0,0.0,0.002720622533104486,450,113,96.57148313522339,0.8546148950019768,0.21460329585605198,0.04418732333965435
+ 8,0.00242780985414154,0.25769682647126535,1.037415603605397e-05,0.0,0.0,0.0,0.0,0.0,0.00242780985414154,900,225,213.52410340309143,0.9489960151248508,0.2372490037812127,0.10494643683855732,0.002887863010659607,1.8791022815315657,2.555802672021092e-05,0.0,0.0,0.0,0.0,0.0,0.002887863010659607,450,113,92.99426054954529,0.8229580579605777,0.20665391233232286,0.059518284711418096
+ 9,0.002208985082106665,0.42546326303130916,7.684016633753033e-06,0.0,0.0,0.0,0.0,0.0,0.002208985082106665,900,225,213.6311333179474,0.9494717036353217,0.23736792590883043,0.10293637524772849,0.0024638745025731624,3.782287786237916,1.655159604749657e-05,0.0,0.0,0.0,0.0,0.0,0.0024638745025731624,450,113,97.28867506980896,0.8609617262814953,0.21619705571068656,0.04808065608046965
+ 10,0.0020764596655175813,0.3896139937922279,7.334516867581215e-06,0.0,0.0,0.0,0.0,0.0,0.0020764596655175813,900,225,221.03631401062012,0.9823836178249783,0.24559590445624457,0.10016421435814765,0.0025203956082178692,2.2354346593865277,2.0540576969837897e-05,0.0,0.0,0.0,0.0,0.0,0.0025203956082178692,450,113,97.72019529342651,0.8647804893223585,0.2171559895409478,0.057214381701758014
contraceptive/tvae/mlu-eval.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
contraceptive/tvae/model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fdb996dcc92c0d5229ad2651364afc4a47042431d9167333d68206233446423e
+ size 47629899
contraceptive/tvae/params.json ADDED
@@ -0,0 +1 @@
+ {"Body": "twin_encoder", "loss_balancer_meta": true, "loss_balancer_log": false, "loss_balancer_lbtw": false, "pma_skip_small": false, "isab_skip_small": false, "layer_norm": false, "pma_layer_norm": false, "attn_residual": true, "tf_n_layers_dec": false, "tf_isab_rank": 0, "tf_lora": false, "tf_layer_norm": false, "tf_pma_start": -1, "ada_n_seeds": 0, "head_n_seeds": 0, "tf_pma_low": 16, "gradient_penalty_kwargs": {"mag_loss": true, "mse_mag": true, "mag_corr": false, "seq_mag": false, "cos_loss": false, "mse_mag_kwargs": {"target": 1.0, "multiply": true, "forgive_over": true}, "mag_corr_kwargs": {"only_sign": false}, "cos_loss_kwargs": {"only_sign": true, "cos_matrix": false}}, "dropout": 0, "combine_mode": "diff_left", "tf_isab_mode": "separate", "grad_loss_fn": "mae", "single_model": true, "bias": true, "bias_final": true, "pma_ffn_mode": "none", "patience": 10, "inds_init_mode": "fixnorm", "grad_clip": 0.73, "gradient_penalty_mode": "NONE", "synth_data": 2, "bias_lr_mul": 1.0, "bias_weight_decay": 0.05, "head_activation": "softsign", "loss_balancer_beta": 0.67, "loss_balancer_r": 0.943, "tf_activation": "tanh", "dataset_size": 2048, "batch_size": 4, "epochs": 100, "lr_mul": 0.09, "n_warmup_steps": 100, "Optim": "amsgradw", "fixed_role_model": "tvae", "mse_mag": false, "mse_mag_target": 0.65, "mse_mag_multiply": true, "d_model": 256, "attn_activation": "prelu", "tf_d_inner": 512, "tf_n_layers_enc": 3, "tf_n_head": 32, "tf_activation_final": "leakyhardtanh", "tf_num_inds": 128, "ada_d_hid": 1024, "ada_n_layers": 9, "ada_activation": "softsign", "ada_activation_final": "leakyhardsigmoid", "head_d_hid": 256, "head_n_layers": 9, "head_n_head": 32, "head_activation_final": "leakyhardsigmoid", "models": ["tvae"], "max_seconds": 3600}
insurance/lct_gan/eval.csv ADDED
@@ -0,0 +1,2 @@
+ ,avg_g_cos_loss,avg_g_mag_loss,avg_loss,grad_duration,grad_mae,grad_mape,grad_rmse,mean_pred_loss,pred_duration,pred_mae,pred_mape,pred_rmse,pred_std,std_loss,total_duration
+ lct_gan,0.0014364243935034166,0.007374788956466156,0.0008153657148691958,6.472357988357544,0.0165903028100729,0.5433363318443298,0.05834271013736725,1.2863068832302815e-06,2.436617374420166,0.017312856391072273,0.7524448037147522,0.028554610908031464,0.1444527804851532,0.00022781474399380386,8.90897536277771
insurance/lct_gan/history.csv ADDED
@@ -0,0 +1,21 @@
+ ,avg_role_model_loss_train,avg_role_model_std_loss_train,avg_role_model_mean_pred_loss_train,avg_role_model_g_mag_loss_train,avg_role_model_g_cos_loss_train,avg_non_role_model_g_mag_loss_train,avg_non_role_model_g_cos_loss_train,avg_non_role_model_embed_loss_train,avg_loss_train,n_size_train,n_batch_train,duration_train,duration_batch_train,duration_size_train,avg_pred_std_train,avg_role_model_loss_test,avg_role_model_std_loss_test,avg_role_model_mean_pred_loss_test,avg_role_model_g_mag_loss_test,avg_role_model_g_cos_loss_test,avg_non_role_model_g_mag_loss_test,avg_non_role_model_g_cos_loss_test,avg_non_role_model_embed_loss_test,avg_loss_test,n_size_test,n_batch_test,duration_test,duration_batch_test,duration_size_test,avg_pred_std_test
+ 0,0.03844140047931837,0.7188729810216106,0.013317291441673443,0.0,0.0,0.0,0.0,0.0,0.03844140047931837,900,113,110.54995560646057,0.9783181912076157,0.12283328400717841,0.11123340914800631,0.008303070723971663,2.3807693450168474,9.48858276775326e-05,0.0,0.0,0.0,0.0,0.0,0.008303070723971663,450,57,45.44606399536133,0.7972993683396724,0.10099125332302518,0.034816826579340716
+ 1,0.00653058495786455,0.9053095802807317,0.00012669656981902897,0.0,0.0,0.0,0.0,0.0,0.00653058495786455,900,113,110.65184760093689,0.9792198902737778,0.12294649733437432,0.08656557510087712,0.006352552234764315,0.8825915555632886,2.5333444609595783e-05,0.0,0.0,0.0,0.0,0.0,0.006352552234764315,450,57,45.84219408035278,0.8042490189535576,0.10187154240078396,0.04755824741400909
+ 2,0.005102309041556307,0.5227623938307345,0.0001850789159238102,0.0,0.0,0.0,0.0,0.0,0.005102309041556307,900,113,111.50907039642334,0.986805932711711,0.12389896710713705,0.09074666479651906,0.004312427961267531,0.17915711068516546,0.0001109809016019134,0.0,0.0,0.0,0.0,0.0,0.004312427961267531,450,57,45.79855394363403,0.8034834025198954,0.10177456431918674,0.07046655134150856
+ 3,0.004384625833740251,0.5408107707275669,0.0001490451962195263,0.0,0.0,0.0,0.0,0.0,0.004384625833740251,900,113,111.38601303100586,0.985716929477928,0.12376223670111762,0.08961118024205213,0.004137574709020555,0.30141546075263476,7.981840502843144e-05,0.0,0.0,0.0,0.0,0.0,0.004137574709020555,450,57,45.67794322967529,0.8013674250820226,0.10150654051038954,0.06743310128845144
+ 4,0.00368658812250942,0.46921445248119603,2.151566522561455e-05,0.0,0.0,0.0,0.0,0.0,0.00368658812250942,900,113,111.5419237613678,0.9870966704545823,0.12393547084596422,0.09260421899040189,0.0039208269826809155,0.39560824827086744,0.00012955872286843473,0.0,0.0,0.0,0.0,0.0,0.0039208269826809155,450,57,46.108590602874756,0.8089226421556974,0.10246353467305501,0.06722364254937131
+ 5,0.0037500500336237665,0.38660638477140624,8.546291791805232e-05,0.0,0.0,0.0,0.0,0.0,0.0037500500336237665,900,113,110.84228444099426,0.980905172044197,0.12315809382332696,0.0917343535847896,0.0035700482581887626,0.11842851034616386,0.0002132376919723574,0.0,0.0,0.0,0.0,0.0,0.0035700482581887626,450,57,48.897982120513916,0.8578593354476126,0.10866218249003093,0.0700296791874918
+ 6,0.0031003811301585906,0.2770809364794333,5.226455323332238e-05,0.0,0.0,0.0,0.0,0.0,0.0031003811301585906,900,113,113.76307344436646,1.006752862339526,0.12640341493818494,0.09047676906385253,0.0031285733293690203,0.25899296287125934,6.688645929435053e-05,0.0,0.0,0.0,0.0,0.0,0.0031285733293690203,450,57,45.767940282821655,0.8029463207512572,0.1017065339618259,0.07119475702117932
+ 7,0.002322382250438548,0.34646086611721244,1.220678437067165e-05,0.0,0.0,0.0,0.0,0.0,0.002322382250438548,900,113,109.93844747543335,0.9729066148268438,0.12215383052825927,0.09019952334991072,0.002956038533043789,0.4516964087274689,0.00018173267813042553,0.0,0.0,0.0,0.0,0.0,0.002956038533043789,450,57,46.75201344490051,0.8202107621912371,0.10389336321089003,0.07300392768637833
+ 8,0.0017050553576781467,0.2169923465751567,1.014655068901434e-05,0.0,0.0,0.0,0.0,0.0,0.0017050553576781467,900,113,114.35391974449158,1.0119815906592176,0.12705991082721285,0.09638774892793292,0.0030709211031595867,0.11738263850605092,0.0002604710703875041,0.0,0.0,0.0,0.0,0.0,0.0030709211031595867,450,57,45.301653146743774,0.7947658446797153,0.10067034032609727,0.07127821202831049
+ 9,0.0015025745798872474,0.20518340586134937,6.9952671160537674e-06,0.0,0.0,0.0,0.0,0.0,0.0015025745798872474,900,113,111.49351572990417,0.9866682807956122,0.12388168414433798,0.09498348522239027,0.0029102627569550857,0.35421860132333804,0.00010688639051289546,0.0,0.0,0.0,0.0,0.0,0.0029102627569550857,450,57,44.809088468551636,0.7861243590973971,0.09957575215233697,0.06238056484021639
+ 10,0.0015078903491828695,0.20325854836928797,1.2002418886963234e-05,0.0,0.0,0.0,0.0,0.0,0.0015078903491828695,900,113,110.71620535850525,0.9797894279513739,0.12301800595389473,0.09301089175638899,0.0021557584258051874,0.1434766953186761,6.070327296319708e-05,0.0,0.0,0.0,0.0,0.0,0.0021557584258051874,450,57,45.4062922000885,0.7966016175454123,0.10090287155575223,0.07540236165126164
+ 11,0.0009480594642486217,0.0989064481881752,1.0000245586616449e-06,0.0,0.0,0.0,0.0,0.0,0.0009480594642486217,900,113,110.84421229362488,0.9809222326869458,0.12316023588180541,0.09785350819274916,0.001960534858582024,0.5104609808862529,4.5300757508451246e-05,0.0,0.0,0.0,0.0,0.0,0.001960534858582024,450,57,45.625545501708984,0.8004481666966489,0.10139010111490886,0.06512660700255972
+ 12,0.0007876609438487018,0.10056789984725945,2.650213134493559e-06,0.0,0.0,0.0,0.0,0.0,0.0007876609438487018,900,113,111.02137565612793,0.9824900500542295,0.12335708406236437,0.09717012973156124,0.00240250650444068,0.20876344482377662,0.00015140544695302596,0.0,0.0,0.0,0.0,0.0,0.00240250650444068,450,57,45.895328760147095,0.8051812063183701,0.10198961946699354,0.06988455496499674
+ 13,0.001071412844466977,0.08979928175546043,2.8548677461531438e-05,0.0,0.0,0.0,0.0,0.0,0.001071412844466977,900,113,110.40404558181763,0.9770269520514834,0.12267116175757514,0.09673032805785141,0.002080558299649415,0.33820552927165826,1.7519274976611518e-05,0.0,0.0,0.0,0.0,0.0,0.002080558299649415,450,57,45.13530921936035,0.7918475301642167,0.10030068715413411,0.06275787556609302
+ 14,0.0009464439855623318,0.11049716685496896,1.5092843134298528e-06,0.0,0.0,0.0,0.0,0.0,0.0009464439855623318,900,113,110.10486841201782,0.9743793664780338,0.1223387426800198,0.0963555359306325,0.0019128535713470126,0.2881092154617132,9.992767801322522e-05,0.0,0.0,0.0,0.0,0.0,0.0019128535713470126,450,57,45.23128914833069,0.7935313885672051,0.10051397588517931,0.07075215538293776
+ 15,0.0007565623041591607,0.10640527475413109,6.627467923142204e-07,0.0,0.0,0.0,0.0,0.0,0.0007565623041591607,900,113,110.25560998916626,0.9757133627359846,0.12250623332129584,0.09287066381853239,0.0023708047673830558,0.21841555112681696,0.00019829534488743176,0.0,0.0,0.0,0.0,0.0,0.0023708047673830558,450,57,45.1703155040741,0.7924616755100719,0.10037847889794244,0.0748492696673789
+ 16,0.0013394151283298722,0.16507210116501242,1.1373421445702734e-05,0.0,0.0,0.0,0.0,0.0,0.0013394151283298722,900,113,110.16077661514282,0.9748741293375471,0.12240086290571425,0.0946754277148078,0.00335533164971922,0.19024144669157844,1.7741167042891626e-05,0.0,0.0,0.0,0.0,0.0,0.00335533164971922,450,57,45.20561957359314,0.7930810451507568,0.10045693238576253,0.0715425031161622
+ 17,0.0012095504338503816,0.17023994228466374,2.78441173262379e-06,0.0,0.0,0.0,0.0,0.0,0.0012095504338503816,900,113,110.36352491378784,0.9766683620689189,0.1226261387930976,0.09928367804505128,0.0016752594984912625,0.32005773657970704,1.843289076943494e-05,0.0,0.0,0.0,0.0,0.0,0.0016752594984912625,450,57,45.38660478591919,0.7962562243143717,0.10085912174648709,0.06659886517088141
+ 18,0.0006933588838890298,0.12216246798768343,7.663908191116571e-07,0.0,0.0,0.0,0.0,0.0,0.0006933588838890298,900,113,110.25478553771973,0.9757060667054843,0.12250531726413302,0.09808128885100637,0.0015123529919138592,0.2243189437115992,5.470234152319095e-05,0.0,0.0,0.0,0.0,0.0,0.0015123529919138592,450,57,45.31874370574951,0.795065679048237,0.10070831934611002,0.07162666100224382
+ 19,0.0004732113509251374,0.04918161084480092,2.0673360931155154e-07,0.0,0.0,0.0,0.0,0.0,0.0004732113509251374,900,113,110.03912854194641,0.9737975977163399,0.12226569837994046,0.09703757702908684,0.0011618334987562977,0.31888309370694506,6.430942659246202e-07,0.0,0.0,0.0,0.0,0.0,0.0011618334987562977,450,57,45.194119930267334,0.7928792970222339,0.1004313776228163,0.07293911375464839
insurance/lct_gan/mlu-eval.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
insurance/lct_gan/model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:56f2e5d2620ac45edefaad1f1b2207c7976d90edb48ff23aaf22597d30cb38e0
+ size 38583573
insurance/lct_gan/params.json ADDED
@@ -0,0 +1 @@
+ {"Body": "twin_encoder", "loss_balancer_meta": true, "loss_balancer_log": false, "loss_balancer_lbtw": false, "pma_skip_small": false, "isab_skip_small": false, "layer_norm": false, "pma_layer_norm": false, "attn_residual": true, "tf_n_layers_dec": false, "tf_isab_rank": 0, "tf_lora": false, "tf_layer_norm": false, "tf_pma_start": -1, "ada_n_seeds": 0, "head_n_seeds": 0, "tf_pma_low": 16, "gradient_penalty_kwargs": {"mag_loss": true, "mse_mag": true, "mag_corr": false, "seq_mag": false, "cos_loss": false, "mse_mag_kwargs": {"target": 1.0, "multiply": true, "forgive_over": true}, "mag_corr_kwargs": {"only_sign": false}, "cos_loss_kwargs": {"only_sign": true, "cos_matrix": false}}, "dropout": 0, "combine_mode": "diff_left", "tf_isab_mode": "separate", "grad_loss_fn": "mae", "single_model": true, "bias": true, "bias_final": true, "pma_ffn_mode": "shared", "patience": 10, "inds_init_mode": "fixnorm", "grad_clip": 0.7, "head_final_mul": "identity", "gradient_penalty_mode": "NONE", "synth_data": 2, "bias_lr_mul": 1.0, "bias_weight_decay": 0.05, "loss_balancer_beta": 0.79, "loss_balancer_r": 0.95, "dataset_size": 2048, "batch_size": 8, "epochs": 100, "n_warmup_steps": 100, "Optim": "diffgrad", "fixed_role_model": "lct_gan", "mse_mag": false, "mse_mag_target": 0.1, "mse_mag_multiply": true, "d_model": 256, "attn_activation": "leakyrelu", "tf_d_inner": 512, "tf_n_layers_enc": 4, "tf_n_head": 64, "tf_activation": "leakyhardsigmoid", "tf_activation_final": "leakyhardsigmoid", "tf_num_inds": 32, "ada_d_hid": 1024, "ada_n_layers": 7, "ada_activation": "relu", "ada_activation_final": "softsign", "head_d_hid": 128, "head_n_layers": 9, "head_n_head": 64, "head_activation": "prelu", "head_activation_final": "softsign", "models": ["lct_gan"], "max_seconds": 3600}
insurance/realtabformer/eval.csv ADDED
@@ -0,0 +1,2 @@
+ ,avg_g_cos_loss,avg_g_mag_loss,avg_loss,grad_duration,grad_mae,grad_mape,grad_rmse,mean_pred_loss,pred_duration,pred_mae,pred_mape,pred_rmse,pred_std,std_loss,total_duration
+ realtabformer,0.004295671779534695,0.005862241469609823,0.0007627403576736382,4.4447386264801025,0.20265215635299683,9.470953941345215,0.4992390275001526,8.549841936655866e-07,5.64751410484314,0.01658162847161293,0.6326260566711426,0.02761775441467762,0.13868294656276703,0.002371334470808506,10.092252731323242
insurance/realtabformer/history.csv ADDED
@@ -0,0 +1,17 @@
+ ,avg_role_model_loss_train,avg_role_model_std_loss_train,avg_role_model_mean_pred_loss_train,avg_role_model_g_mag_loss_train,avg_role_model_g_cos_loss_train,avg_non_role_model_g_mag_loss_train,avg_non_role_model_g_cos_loss_train,avg_non_role_model_embed_loss_train,avg_loss_train,n_size_train,n_batch_train,duration_train,duration_batch_train,duration_size_train,avg_pred_std_train,avg_role_model_loss_test,avg_role_model_std_loss_test,avg_role_model_mean_pred_loss_test,avg_role_model_g_mag_loss_test,avg_role_model_g_cos_loss_test,avg_non_role_model_g_mag_loss_test,avg_non_role_model_g_cos_loss_test,avg_non_role_model_embed_loss_test,avg_loss_test,n_size_test,n_batch_test,duration_test,duration_batch_test,duration_size_test,avg_pred_std_test
+ 0,0.03506428348624872,0.4538130360197343,0.00792232936242202,0.0,0.0,0.0,0.0,0.0,0.03506428348624872,900,113,125.6448929309845,1.111901707353845,0.13960543658998278,0.13160241616115106,0.01397167039879908,0.13293859043912098,0.002445040742022052,0.0,0.0,0.0,0.0,0.0,0.01397167039879908,450,57,51.65371060371399,0.9062054491879648,0.11478602356380886,0.0867383274241563
+ 1,0.00870141801217364,0.7070900697936303,0.0004758712758562291,0.0,0.0,0.0,0.0,0.0,0.00870141801217364,900,113,124.63241243362427,1.102941702952427,0.13848045825958252,0.09212413534060516,0.008885490597587907,0.16697826483946382,0.001734984528900567,0.0,0.0,0.0,0.0,0.0,0.008885490597587907,450,57,51.52860236167908,0.9040105677487558,0.11450800524817573,0.07432921211186208
+ 2,0.005305013003655606,0.5324694752190876,0.00016715914267360208,0.0,0.0,0.0,0.0,0.0,0.005305013003655606,900,113,124.62815642356873,1.1029040391466258,0.1384757293595208,0.09308656097381515,0.0029972506677990573,0.6092274434202413,5.252186918157018e-05,0.0,0.0,0.0,0.0,0.0,0.0029972506677990573,450,57,51.712756872177124,0.9072413486346864,0.11491723749372694,0.05699627291362144
+ 3,0.0040883604651834405,0.7858639722572959,3.0227560953076958e-05,0.0,0.0,0.0,0.0,0.0,0.0040883604651834405,900,113,124.75911545753479,1.1040629686507504,0.13862123939726087,0.08776520042622511,0.006447157069859612,0.8308172208510114,0.0002398531965743563,0.0,0.0,0.0,0.0,0.0,0.006447157069859612,450,57,51.98431992530823,0.9120056127247057,0.11552071094512939,0.05349372850175489
+ 4,0.002022715635701186,0.3040784713077333,9.828579557799773e-06,0.0,0.0,0.0,0.0,0.0,0.002022715635701186,900,113,125.08361649513245,1.1069346592489597,0.13898179610570272,0.09222022145656886,0.00254811349499505,0.07651601909258302,2.4787554458056445e-05,0.0,0.0,0.0,0.0,0.0,0.00254811349499505,450,57,51.83201622962952,0.9093336180636757,0.1151822582880656,0.0869178914775451
+ 5,0.001798461573101425,0.2503146630212105,1.3495003242444538e-05,0.0,0.0,0.0,0.0,0.0,0.001798461573101425,900,113,125.20910906791687,1.1080452129904148,0.13912123229768542,0.09624636015005872,0.00245836095231223,0.28769085691075147,1.3764144745344905e-05,0.0,0.0,0.0,0.0,0.0,0.00245836095231223,450,57,51.629480838775635,0.905780365592555,0.11473217964172364,0.07527359396529694
+ 6,0.0013259276232889129,0.23208431061799636,4.7836219041035795e-06,0.0,0.0,0.0,0.0,0.0,0.0013259276232889129,900,113,124.9141948223114,1.1054353524098355,0.13879354980256822,0.09697889212716733,0.0033668742331469225,0.8360707927999607,0.00017527166296598044,0.0,0.0,0.0,0.0,0.0,0.0033668742331469225,450,57,52.27232575416565,0.9170583465643096,0.11616072389814588,0.05769461206683334
+ 7,0.001323487046950807,0.2029421608147736,4.776099965856173e-06,0.0,0.0,0.0,0.0,0.0,0.001323487046950807,900,113,127.35284996032715,1.1270163713303287,0.14150316662258572,0.09650758934100118,0.0012376802931086989,0.1487765556306695,3.4437281075641283e-06,0.0,0.0,0.0,0.0,0.0,0.0012376802931086989,450,57,53.573572397232056,0.9398872350391588,0.11905238310496012,0.07214564077654168
+ 8,0.0009430775794155327,0.13147762805823807,2.152724269898497e-06,0.0,0.0,0.0,0.0,0.0,0.0009430775794155327,900,113,127.76970887184143,1.1307053882463844,0.14196634319093493,0.09628061390589032,0.0007872949669318688,0.18410880861784518,3.336931345282087e-06,0.0,0.0,0.0,0.0,0.0,0.0007872949669318688,450,57,53.49133658409119,0.938444501475284,0.11886963685353596,0.07019469502643404
+ 9,0.0009089046284659869,0.097098820553298,9.63799082052168e-06,0.0,0.0,0.0,0.0,0.0,0.0009089046284659869,900,113,127.74243187904358,1.1304639989295893,0.14193603542115954,0.09850719976609787,0.0009410815907176584,0.7538224018950778,2.0871374143320728e-06,0.0,0.0,0.0,0.0,0.0,0.0009410815907176584,450,57,53.4216194152832,0.9372213932505825,0.11871470981174045,0.06436136332616259
+ 10,0.0005724812648905855,0.12220282757398657,1.0156194488626337e-06,0.0,0.0,0.0,0.0,0.0,0.0005724812648905855,900,113,127.79642581939697,1.1309418214105926,0.14199602868821887,0.09549391159243815,0.0011512360008944396,0.48300685461102216,1.8063996093559748e-06,0.0,0.0,0.0,0.0,0.0,0.0011512360008944396,450,57,53.28672909736633,0.9348548964450234,0.11841495354970297,0.07293740793931902
+ 11,0.000575461219123099,0.07051163187836926,1.427012848514419e-06,0.0,0.0,0.0,0.0,0.0,0.000575461219123099,900,113,127.76422357559204,1.1306568458016995,0.1419602484173245,0.09717198909647697,0.0008690591479858591,0.8044250670493349,2.2709774584184216e-07,0.0,0.0,0.0,0.0,0.0,0.0008690591479858591,450,57,53.31593942642212,0.9353673583582828,0.11847986539204915,0.06398673398264994
+ 12,0.0005083383906619727,0.0760834188396614,3.377694875707717e-07,0.0,0.0,0.0,0.0,0.0,0.0005083383906619727,900,113,127.2828004360199,1.1263964640355744,0.1414253338177999,0.09947530912087027,0.0012337127990516214,0.6099953325571875,5.033288178477591e-06,0.0,0.0,0.0,0.0,0.0,0.0012337127990516214,450,57,53.24035835266113,0.93404137460809,0.11831190745035808,0.0727369972022675
+ 13,0.0008676275550452474,0.0859259017737159,6.595443268466302e-07,0.0,0.0,0.0,0.0,0.0,0.0008676275550452474,900,113,127.27700209617615,1.1263451512935942,0.1414188912179735,0.09762696997649901,0.0010123544972803857,0.5137522601167133,2.4046812298386368e-06,0.0,0.0,0.0,0.0,0.0,0.0010123544972803857,450,57,53.15213441848755,0.9324935862892553,0.11811585426330566,0.07203809639618716
+ 14,0.00046521230586222374,0.06125880159901201,3.3935850183466135e-07,0.0,0.0,0.0,0.0,0.0,0.00046521230586222374,900,113,127.70906496047974,1.1301687164644225,0.1418989610671997,0.09863226325045117,0.0008455607482270958,0.9153862733114074,1.1769104552178559e-05,0.0,0.0,0.0,0.0,0.0,0.0008455607482270958,450,57,53.55443787574768,0.9395515416797838,0.11900986194610595,0.07057310182503179
+ 15,0.0004456719228215257,0.04856209839710041,4.839545050809473e-07,0.0,0.0,0.0,0.0,0.0,0.0004456719228215257,900,113,127.83944082260132,1.131322485155764,0.14204382313622368,0.09787282459767518,0.000815138778561959,0.3977327043629637,1.0837310133676681e-05,0.0,0.0,0.0,0.0,0.0,0.000815138778561959,450,57,53.49911832809448,0.9385810232999032,0.11888692961798773,0.06857214915860248
insurance/realtabformer/mlu-eval.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
insurance/realtabformer/model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:29e1d4cdba4270251db8b7f22a9c1763e7f4475a12e215667377096e4e7e0f64
+ size 43505805
insurance/realtabformer/params.json ADDED
@@ -0,0 +1 @@
+ {"Body": "twin_encoder", "loss_balancer_meta": true, "loss_balancer_log": false, "loss_balancer_lbtw": false, "pma_skip_small": false, "isab_skip_small": false, "layer_norm": false, "pma_layer_norm": false, "attn_residual": true, "tf_n_layers_dec": false, "tf_isab_rank": 0, "tf_lora": false, "tf_layer_norm": false, "tf_pma_start": -1, "ada_n_seeds": 0, "head_n_seeds": 0, "tf_pma_low": 16, "gradient_penalty_kwargs": {"mag_loss": true, "mse_mag": true, "mag_corr": false, "seq_mag": false, "cos_loss": false, "mse_mag_kwargs": {"target": 1.0, "multiply": true, "forgive_over": true}, "mag_corr_kwargs": {"only_sign": false}, "cos_loss_kwargs": {"only_sign": true, "cos_matrix": false}}, "dropout": 0, "combine_mode": "diff_left", "tf_isab_mode": "separate", "grad_loss_fn": "mae", "single_model": true, "bias": true, "bias_final": true, "pma_ffn_mode": "shared", "patience": 10, "inds_init_mode": "fixnorm", "grad_clip": 0.7, "head_final_mul": "identity", "gradient_penalty_mode": "NONE", "synth_data": 2, "bias_lr_mul": 1.0, "bias_weight_decay": 0.05, "loss_balancer_beta": 0.79, "loss_balancer_r": 0.95, "dataset_size": 2048, "batch_size": 8, "epochs": 100, "n_warmup_steps": 100, "Optim": "diffgrad", "fixed_role_model": "realtabformer", "mse_mag": false, "mse_mag_target": 0.1, "mse_mag_multiply": true, "d_model": 256, "attn_activation": "leakyrelu", "tf_d_inner": 512, "tf_n_layers_enc": 4, "tf_n_head": 64, "tf_activation": "leakyhardsigmoid", "tf_activation_final": "leakyhardsigmoid", "tf_num_inds": 32, "ada_d_hid": 1024, "ada_n_layers": 7, "ada_activation": "relu", "ada_activation_final": "softsign", "head_d_hid": 128, "head_n_layers": 9, "head_n_head": 64, "head_activation": "prelu", "head_activation_final": "softsign", "models": ["realtabformer"], "max_seconds": 3600}
insurance/tab_ddpm_concat/eval.csv ADDED
@@ -0,0 +1,2 @@
+ ,avg_g_cos_loss,avg_g_mag_loss,avg_loss,grad_duration,grad_mae,grad_mape,grad_rmse,mean_pred_loss,pred_duration,pred_mae,pred_mape,pred_rmse,pred_std,std_loss,total_duration
+ tab_ddpm_concat,0.010057446495336643,,0.018736911170362008,6.4715576171875,0.08312908560037613,0.9944403171539307,0.13684093952178955,9.326592589786742e-06,2.3459200859069824,0.08326641470193863,3.0670642852783203,0.13688284158706665,0.04763114079833031,0.8643301725387573,8.817477703094482
insurance/tab_ddpm_concat/history.csv ADDED
@@ -0,0 +1,23 @@
+ ,avg_role_model_loss_train,avg_role_model_std_loss_train,avg_role_model_mean_pred_loss_train,avg_role_model_g_mag_loss_train,avg_role_model_g_cos_loss_train,avg_non_role_model_g_mag_loss_train,avg_non_role_model_g_cos_loss_train,avg_non_role_model_embed_loss_train,avg_loss_train,n_size_train,n_batch_train,duration_train,duration_batch_train,duration_size_train,avg_pred_std_train,avg_role_model_loss_test,avg_role_model_std_loss_test,avg_role_model_mean_pred_loss_test,avg_role_model_g_mag_loss_test,avg_role_model_g_cos_loss_test,avg_non_role_model_g_mag_loss_test,avg_non_role_model_g_cos_loss_test,avg_non_role_model_embed_loss_test,avg_loss_test,n_size_test,n_batch_test,duration_test,duration_batch_test,duration_size_test,avg_pred_std_test
+ 0,0.016726151166690722,6.681427996787429,0.0007995499490898637,0.0,0.0,0.0,0.0,0.0,0.016726151166690722,900,113,113.29283952713013,1.0025915002400896,0.12588093280792237,0.04266167516960243,0.013463321674304704,5.035531684387279,0.0003260195470640307,0.0,0.0,0.0,0.0,0.0,0.013463321674304704,450,57,46.18605709075928,0.810281703346654,0.10263568242390951,0.02743101706564949
+ 1,0.013343319840108355,7.374786443262831,0.00018221230446668457,0.0,0.0,0.0,0.0,0.0,0.013343319840108355,900,113,113.2004017829895,1.0017734671061018,0.12577822420332166,0.033394961601403435,0.013095296140398002,4.660537766719707,0.0003757634961128373,0.0,0.0,0.0,0.0,0.0,0.013095296140398002,450,57,46.44777750968933,0.8148732896436724,0.10321728335486519,0.0280298860416862
+ 2,0.013104651555832889,5.731545712810247,0.00016133929766812432,0.0,0.0,0.0,0.0,0.0,0.013104651555832889,900,113,113.58225321769714,1.0051526833424527,0.12620250357521906,0.03886226541568748,0.01287688842560682,3.2905811405824417,0.0003377204293245642,0.0,0.0,0.0,0.0,0.0,0.01287688842560682,450,57,47.43314838409424,0.8321604979665655,0.10540699640909831,0.03367136853436629
+ 3,0.013487551987895535,6.5134642504676314,0.0002293419825039046,0.0,0.0,0.0,0.0,0.0,0.013487551987895535,900,113,112.98276948928833,0.9998475176043216,0.1255364105436537,0.037405478042772916,0.013358833422470424,6.6747851123451865,0.000459600389062934,0.0,0.0,0.0,0.0,0.0,0.013358833422470424,450,57,46.22313857078552,0.8109322556278162,0.10271808571285672,0.020856579136626238
+ 4,0.013331625766845214,6.895425937681058,0.00015765312878396825,0.0,0.0,0.0,0.0,0.0,0.013331625766845214,900,113,112.3973479270935,0.9946667958149867,0.12488594214121501,0.03705255519104215,0.015463804398766823,5.035077505065904,0.0007592894264012608,0.0,0.0,0.0,0.0,0.0,0.015463804398766823,450,57,46.07612347602844,0.8083530434390955,0.10239138550228542,0.03493314252741504
+ 5,0.013444565045129921,6.88672233373744,0.0002211289928355282,0.0,0.0,0.0,0.0,0.0,0.013444565045129921,900,113,108.1660463809967,0.9572216493893514,0.12018449597888523,0.03778205175710991,0.013861719023229347,5.471305163596984,0.0005421772549703974,0.0,0.0,0.0,0.0,0.0,0.013861719023229347,450,57,44.389360666275024,0.7787607134434215,0.09864302370283339,0.025694214632701978
+ 6,0.013114380333055224,5.976081524419583,0.00027622638839541855,0.0,0.0,0.0,0.0,0.0,0.013114380333055224,900,113,111.54308128356934,0.987106914013888,0.12393675698174371,0.038547951167663644,0.012811201657168567,3.3985922191304563,0.00027678651687506,0.0,0.0,0.0,0.0,0.0,0.012811201657168567,450,57,45.941070795059204,0.8059836981589334,0.1020912684334649,0.03597862523441252
+ 7,0.013140132142644789,6.532687967344641,0.00014520771743971967,0.0,0.0,0.0,0.0,0.0,0.013140132142644789,900,113,108.82160568237305,0.9630230591360447,0.12091289520263672,0.03722985821520596,0.012903139407539533,3.578145878625659,0.0003883562163952724,0.0,0.0,0.0,0.0,0.0,0.012903139407539533,450,57,44.52135992050171,0.7810764898333633,0.09893635537889268,0.03204377959564067
+ 8,0.012976793174942335,5.416468192214529,0.00026929119924884524,0.0,0.0,0.0,0.0,0.0,0.012976793174942335,900,113,110.61598777770996,0.9789025467053979,0.1229066530863444,0.04178034744134783,0.013890320318751037,3.7336453318763363,0.0005463759208261207,0.0,0.0,0.0,0.0,0.0,0.013890320318751037,450,57,45.11232590675354,0.7914443141535709,0.10024961312611898,0.037292727285571266
+ 9,0.013011944204982783,6.452372519567247,0.0001708998166064858,0.0,0.0,0.0,0.0,0.0,0.013011944204982783,900,113,108.52205491065979,0.9603721673509716,0.12058006101184421,0.036126165387047604,0.01294027495249692,4.294494187321249,0.0004066406136260209,0.0,0.0,0.0,0.0,0.0,0.01294027495249692,450,57,45.50484013557434,0.7983305286942867,0.10112186696794298,0.029192425953959555
+ 10,0.01281262576735268,5.773519790146028,0.00017255405067423624,0.0,0.0,0.0,0.0,0.0,0.01281262576735268,900,113,110.79968285560608,0.9805281668637706,0.1231107587284512,0.04245836566721575,0.012781418984652394,2.9514776281277206,0.00034504475288883897,0.0,0.0,0.0,0.0,0.0,0.012781418984652394,450,57,46.79314422607422,0.8209323548434073,0.10398476494683159,0.03811844466907675
+ 11,0.013128257239651348,5.248821901264445,0.00028772900066602206,0.0,0.0,0.0,0.0,0.0,0.013128257239651348,900,113,112.41662859916687,0.9948374212315653,0.12490736511018541,0.039066291603762494,0.012926328503009345,4.4593022893954055,0.00044354169371757735,0.0,0.0,0.0,0.0,0.0,0.012926328503009345,450,57,46.57635951042175,0.8171291142179254,0.10350302113427057,0.02754177166998648
+ 12,0.013119876470623746,5.730330355907093,0.00015690460682182876,0.0,0.0,0.0,0.0,0.0,0.013119876470623746,900,113,113.1574444770813,1.0013933139564717,0.12573049386342366,0.04159025917142893,0.013785471472785705,4.6235486922847935,0.0005177017974133631,0.0,0.0,0.0,0.0,0.0,0.013785471472785705,450,57,45.30210542678833,0.7947737794173392,0.10067134539286296,0.029443472623825073
+ 13,0.012679316052235663,5.353459168564728,0.00018473560166641948,0.0,0.0,0.0,0.0,0.0,0.012679316052235663,900,113,109.81215143203735,0.9717889507259942,0.12201350159115261,0.03975286655358772,0.01452305714185867,6.82052209327945,0.0006187541751139886,0.0,0.0,0.0,0.0,0.0,0.01452305714185867,450,57,45.597567081451416,0.7999573172184459,0.10132792684766981,0.028516328191025217
+ 14,0.012685477107038929,4.885799459394998,9.410771968618429e-05,0.0,0.0,0.0,0.0,0.0,0.012685477107038929,900,113,112.43323254585266,0.9949843588128554,0.12492581393983629,0.043478165464723,0.013303691690218531,4.612162083723674,0.0004478314968633236,0.0,0.0,0.0,0.0,0.0,0.013303691690218531,450,57,46.16491627693176,0.8099108118759958,0.10258870283762614,0.030379774735162134
+ 15,0.012951436785774099,5.6180261968808605,0.0001260765955767044,0.0,0.0,0.0,0.0,0.0,0.012951436785774099,900,113,112.08748507499695,0.9919246466813889,0.12454165008332994,0.0389383625469904,0.013147126083624446,5.849595955204062,0.00047660569695545343,0.0,0.0,0.0,0.0,0.0,0.013147126083624446,450,57,46.19347643852234,0.8104118673424971,0.10265216986338298,0.024712568949581237
+ 16,0.013076940375483698,6.462012599512772,0.0003414776301815831,0.0,0.0,0.0,0.0,0.0,0.013076940375483698,900,113,112.6052417755127,0.9965065643850681,0.12511693530612522,0.03943786635764141,0.013706634990457031,2.6824540562413044,0.00021449473134149837,0.0,0.0,0.0,0.0,0.0,0.013706634990457031,450,57,46.32444953918457,0.8127096410383258,0.10294322119818794,0.04417712587797851
+ 17,0.013131085654927625,7.07163172544333,0.00013269016459129516,0.0,0.0,0.0,0.0,0.0,0.013131085654927625,900,113,112.05166149139404,0.9916076238176464,0.12450184610154894,0.03639758312332947,0.013306772463385843,3.5713896107044203,0.0004134395122505591,0.0,0.0,0.0,0.0,0.0,0.013306772463385843,450,57,45.87373447418213,0.8048023591961777,0.10194163216484918,0.03344645112622203
+ 18,0.013018570302778648,5.669214043441515,0.00020193976867753112,0.0,0.0,0.0,0.0,0.0,0.013018570302778648,900,113,111.77040791511536,0.9891186541160651,0.12418934212790596,0.03950189917752173,0.013145547265497347,4.102805095807446,0.00038455364744524763,0.0,0.0,0.0,0.0,0.0,0.013145547265497347,450,57,45.8624849319458,0.8046049988060667,0.10191663318210178,0.034161844544047325
+ 19,0.012822051873275389,5.499414050482283,0.00015386838247107923,0.0,0.0,0.0,0.0,0.0,0.012822051873275389,900,113,111.74085307121277,0.9888571068248918,0.12415650341245864,0.03887049661767957,0.012638136894804322,4.0768485525334475,0.0003349378748481513,0.0,0.0,0.0,0.0,0.0,0.012638136894804322,450,57,46.025766372680664,0.8074695854856256,0.10227948082817925,0.031167056632081146
+ 20,0.012691254431588783,5.174604149312652,5.650333046863428e-05,0.0,0.0,0.0,0.0,0.0,0.012691254431588783,900,113,111.79648971557617,0.9893494665095236,0.12421832190619575,0.041397298776866064,0.012594756139959726,3.085094394737736,0.0002528211896517658,0.0,0.0,0.0,0.0,0.0,0.012594756139959726,450,57,46.05099582672119,0.8079122074863367,0.10233554628160264,0.03782708470693283
+ 21,0.012717143062295185,5.476102162761738,0.00015562177650104382,0.0,0.0,0.0,0.0,0.0,0.012717143062295185,900,113,111.75997877120972,0.9890263608071657,0.12417775419023303,0.04065807265089413,0.012463197727791137,3.047649189417655,0.0003119605865624761,0.0,0.0,0.0,0.0,0.0,0.012463197727791137,450,57,45.863757610321045,0.8046273264968604,0.10191946135626899,0.03551346450848015
insurance/tab_ddpm_concat/mlu-eval.ipynb ADDED
The diff for this file is too large to render. See raw diff
insurance/tab_ddpm_concat/model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:eee2095aa46c24eef37ab46962474cd61a69f63df282cbec474316fad5cb6434
+ size 38514197
insurance/tab_ddpm_concat/params.json ADDED
@@ -0,0 +1 @@
+ {"Body": "twin_encoder", "loss_balancer_meta": true, "loss_balancer_log": false, "loss_balancer_lbtw": false, "pma_skip_small": false, "isab_skip_small": false, "layer_norm": false, "pma_layer_norm": false, "attn_residual": true, "tf_n_layers_dec": false, "tf_isab_rank": 0, "tf_lora": false, "tf_layer_norm": false, "tf_pma_start": -1, "ada_n_seeds": 0, "head_n_seeds": 0, "tf_pma_low": 16, "gradient_penalty_kwargs": {"mag_loss": true, "mse_mag": true, "mag_corr": false, "seq_mag": false, "cos_loss": false, "mse_mag_kwargs": {"target": 1.0, "multiply": true, "forgive_over": true}, "mag_corr_kwargs": {"only_sign": false}, "cos_loss_kwargs": {"only_sign": true, "cos_matrix": false}}, "dropout": 0, "combine_mode": "diff_left", "tf_isab_mode": "separate", "grad_loss_fn": "mae", "single_model": true, "bias": true, "bias_final": true, "pma_ffn_mode": "shared", "patience": 10, "inds_init_mode": "fixnorm", "grad_clip": 0.7, "head_final_mul": "identity", "gradient_penalty_mode": "NONE", "synth_data": 2, "bias_lr_mul": 1.0, "bias_weight_decay": 0.05, "loss_balancer_beta": 0.79, "loss_balancer_r": 0.95, "dataset_size": 2048, "batch_size": 8, "epochs": 100, "n_warmup_steps": 100, "Optim": "diffgrad", "fixed_role_model": "tab_ddpm_concat", "mse_mag": false, "mse_mag_target": 0.1, "mse_mag_multiply": true, "d_model": 256, "attn_activation": "leakyrelu", "tf_d_inner": 512, "tf_n_layers_enc": 4, "tf_n_head": 64, "tf_activation": "leakyhardsigmoid", "tf_activation_final": "leakyhardsigmoid", "tf_num_inds": 32, "ada_d_hid": 1024, "ada_n_layers": 7, "ada_activation": "relu", "ada_activation_final": "softsign", "head_d_hid": 128, "head_n_layers": 9, "head_n_head": 64, "head_activation": "prelu", "head_activation_final": "softsign", "models": ["tab_ddpm_concat"], "max_seconds": 3600}
insurance/tvae/eval.csv ADDED
@@ -0,0 +1,2 @@
+ ,avg_g_cos_loss,avg_g_mag_loss,avg_loss,grad_duration,grad_mae,grad_mape,grad_rmse,mean_pred_loss,pred_duration,pred_mae,pred_mape,pred_rmse,pred_std,std_loss,total_duration
+ tvae,0.006532281896756364,0.006128870656899121,0.0005013272912696092,6.598551273345947,0.004663586150854826,0.35133224725723267,0.013413011096417904,4.282439931557747e-06,2.3958842754364014,0.013242576271295547,0.5489071607589722,0.02239033952355385,0.14861759543418884,2.2368631107383408e-06,8.994435548782349
insurance/tvae/history.csv ADDED
@@ -0,0 +1,19 @@
+ ,avg_role_model_loss_train,avg_role_model_std_loss_train,avg_role_model_mean_pred_loss_train,avg_role_model_g_mag_loss_train,avg_role_model_g_cos_loss_train,avg_non_role_model_g_mag_loss_train,avg_non_role_model_g_cos_loss_train,avg_non_role_model_embed_loss_train,avg_loss_train,n_size_train,n_batch_train,duration_train,duration_batch_train,duration_size_train,avg_pred_std_train,avg_role_model_loss_test,avg_role_model_std_loss_test,avg_role_model_mean_pred_loss_test,avg_role_model_g_mag_loss_test,avg_role_model_g_cos_loss_test,avg_non_role_model_g_mag_loss_test,avg_non_role_model_g_cos_loss_test,avg_non_role_model_embed_loss_test,avg_loss_test,n_size_test,n_batch_test,duration_test,duration_batch_test,duration_size_test,avg_pred_std_test
+ 0,0.043772892344924104,0.9691945454827052,0.020741946386080963,0.0,0.0,0.0,0.0,0.0,0.043772892344924104,900,113,111.33565783500671,0.9852713082743957,0.1237062864833408,0.11609506887276616,0.007104319949220452,0.48655429802302186,0.002825021496607657,0.0,0.0,0.0,0.0,0.0,0.007104319949220452,450,57,46.08293104171753,0.808472474416097,0.10240651342603896,0.062048352895337236
+ 1,0.005487369931975587,0.7738947147227733,0.0006779940362103653,0.0,0.0,0.0,0.0,0.0,0.005487369931975587,900,113,112.22859740257263,0.9931734283413507,0.12469844155841403,0.08416169975777116,0.006982278515998688,0.4736000236227578,0.0015233677555578614,0.0,0.0,0.0,0.0,0.0,0.006982278515998688,450,57,46.10599446296692,0.8088770958415249,0.10245776547325983,0.05901416842090456
+ 2,0.0028347509825188254,0.3354986571064682,1.5617633661304407e-05,0.0,0.0,0.0,0.0,0.0,0.0028347509825188254,900,113,111.97222852706909,0.9909046772307,0.124413587252299,0.09286229135221348,0.0037852605881117697,0.25047095601733244,0.0007742868338338995,0.0,0.0,0.0,0.0,0.0,0.0037852605881117697,450,57,46.205010652542114,0.810614221974423,0.10267780145009359,0.07840819853325293
+ 3,0.0019783554043776045,0.21749820665938444,1.8945457476570026e-05,0.0,0.0,0.0,0.0,0.0,0.0019783554043776045,900,113,111.77877640724182,0.9891927115685117,0.12419864045249092,0.09672643638224201,0.003601218523528789,0.548966435297753,0.0004845707171051153,0.0,0.0,0.0,0.0,0.0,0.003601218523528789,450,57,45.79237937927246,0.8033750768293414,0.10176084306504991,0.060725935640859235
+ 4,0.0015077938184711254,0.1864952617832508,1.1623385739693803e-05,0.0,0.0,0.0,0.0,0.0,0.0015077938184711254,900,113,111.0480124950409,0.9827257742923973,0.12338668055004544,0.09504026204215742,0.002632722834694303,0.21741927998951283,0.00043822755775758303,0.0,0.0,0.0,0.0,0.0,0.002632722834694303,450,57,45.50333523750305,0.7983041269737378,0.10111852275000678,0.07700717940583433
+ 5,0.0007533894496090296,0.0954851591936076,1.1907169007617899e-06,0.0,0.0,0.0,0.0,0.0,0.0007533894496090296,900,113,112.60472917556763,0.9965020281023684,0.1251163657506307,0.09669116178972531,0.0019194719764669167,0.1813253381316862,8.795860155298997e-05,0.0,0.0,0.0,0.0,0.0,0.0019194719764669167,450,57,47.84044623374939,0.8393060742763051,0.1063121027416653,0.07234944092322207
+ 6,0.0008143671194233724,0.07775253331788894,2.0004750451023264e-06,0.0,0.0,0.0,0.0,0.0,0.0008143671194233724,900,113,113.22330784797668,1.0019761756458114,0.12580367538664075,0.09869995380265523,0.0025437849643316843,0.1558871896347034,0.0004921607287944122,0.0,0.0,0.0,0.0,0.0,0.0025437849643316843,450,57,45.85625100135803,0.8044956316027725,0.10190278000301785,0.07156802042951121
+ 7,0.0006387245669288354,0.05651727137908936,2.687618409259153e-06,0.0,0.0,0.0,0.0,0.0,0.0006387245669288354,900,113,111.7732937335968,0.9891441923327151,0.12419254859288534,0.09920014948707766,0.0021676956941114947,0.39689580948561143,0.00020905914819549176,0.0,0.0,0.0,0.0,0.0,0.0021676956941114947,450,57,45.570988178253174,0.7994910206711083,0.10126886261834038,0.06560296467668786
+ 8,0.0005393729344682975,0.09357239947843243,5.010212360393345e-07,0.0,0.0,0.0,0.0,0.0,0.0005393729344682975,900,113,112.51995134353638,0.9957517818012068,0.12502216815948486,0.09689272767493287,0.0020265914760804866,0.585885123172634,0.00028776428296029374,0.0,0.0,0.0,0.0,0.0,0.0020265914760804866,450,57,47.63360095024109,0.8356772096533525,0.10585244655609131,0.07264613465757289
+ 9,0.00030413587897783147,0.03623591635612869,6.036034924474541e-08,0.0,0.0,0.0,0.0,0.0,0.00030413587897783147,900,113,113.07414937019348,1.0006561891167565,0.12563794374465942,0.09631216084271406,0.001967307192001802,0.5783613658454081,0.00021626046908912677,0.0,0.0,0.0,0.0,0.0,0.001967307192001802,450,57,45.948439598083496,0.8061129754049736,0.10210764355129666,0.0713601329470086
+ 10,0.0002241133607943387,0.02460616329780395,5.1179004124437636e-08,0.0,0.0,0.0,0.0,0.0,0.0002241133607943387,900,113,111.44915223121643,0.9862756834620923,0.12383239136801825,0.09834809370536719,0.001825745276728412,0.44179768500308064,0.00015850803496538857,0.0,0.0,0.0,0.0,0.0,0.001825745276728412,450,57,45.60394787788391,0.8000692610155072,0.10134210639529757,0.06884222721942422
+ 11,0.00033303717160985496,0.03369398813643545,2.4871881194393203e-07,0.0,0.0,0.0,0.0,0.0,0.00033303717160985496,900,113,111.49134683609009,0.9866490870450451,0.12387927426232231,0.10159270219768571,0.0023319287935414145,0.2600294268881457,0.00038547293633122223,0.0,0.0,0.0,0.0,0.0,0.0023319287935414145,450,57,45.5991792678833,0.7999856011909351,0.10133150948418511,0.07801991980522871
+ 12,0.0002562747773643221,0.028784445838175392,2.593463381772865e-08,0.0,0.0,0.0,0.0,0.0,0.0002562747773643221,900,113,111.06798338890076,0.9829025078663783,0.12340887043211195,0.09953929795430297,0.001840519470117417,0.3625226045031678,0.00015924821697381686,0.0,0.0,0.0,0.0,0.0,0.001840519470117417,450,57,45.293455839157104,0.7946220322659141,0.10065212408701579,0.07107936053969816
+ 13,0.0001757131654206508,0.018957749903747066,4.845866322408336e-08,0.0,0.0,0.0,0.0,0.0,0.0001757131654206508,900,113,110.83889031410217,0.9808751355230281,0.12315432257122463,0.09919166324281059,0.0019331971814648974,0.6399211409059746,0.00020584285959550467,0.0,0.0,0.0,0.0,0.0,0.0019331971814648974,450,57,45.47143840789795,0.7977445334718939,0.10104764090643989,0.07058916541363783
+ 14,0.00021455827955732174,0.031199176489886352,2.883131269934558e-08,0.0,0.0,0.0,0.0,0.0,0.00021455827955732174,900,113,110.77680969238281,0.9803257494901134,0.12308534410264757,0.10197576969466378,0.0018430328381752285,0.4276724044195193,0.00019247611021483608,0.0,0.0,0.0,0.0,0.0,0.0018430328381752285,450,57,45.94011998176575,0.8059670172239605,0.10208915551503499,0.07040622951707949
+ 15,0.0004361093971238006,0.05909316669737025,2.2801978405284285e-07,0.0,0.0,0.0,0.0,0.0,0.0004361093971238006,900,113,112.09613871574402,0.992001227572956,0.12455126523971558,0.1005768165882446,0.001964426144291388,0.25929170938355406,0.00015272240531687408,0.0,0.0,0.0,0.0,0.0,0.001964426144291388,450,57,45.62297987937927,0.8004031557785837,0.10138439973195394,0.06955495999570478
+ 16,0.0001958047746868235,0.016338927822493122,4.4443133083278144e-08,0.0,0.0,0.0,0.0,0.0,0.0001958047746868235,900,113,111.31650757789612,0.9851018369725321,0.12368500841988457,0.09745020749030915,0.002082800301011755,0.7195598722440445,0.0001506238370249388,0.0,0.0,0.0,0.0,0.0,0.002082800301011755,450,57,45.92464828491211,0.8056955839458265,0.10205477396647135,0.0661866773580346
+ 17,0.0003942243970474616,0.05046572757249398,1.0653629163880412e-07,0.0,0.0,0.0,0.0,0.0,0.0003942243970474616,900,113,112.87545800209045,0.9988978584255792,0.1254171755578783,0.09861106915086244,0.0017001370618042225,0.6304228919395235,0.0001222305167977231,0.0,0.0,0.0,0.0,0.0,0.0017001370618042225,450,57,48.029250144958496,0.8426184235957631,0.10673166698879666,0.07135208257424988
insurance/tvae/mlu-eval.ipynb ADDED
The diff for this file is too large to render. See raw diff
insurance/tvae/model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:63ada6dad85f692025cc8aa68c63fcd8c153ce0ce7eeb74552993acf519f6e26
+ size 38612117
insurance/tvae/params.json ADDED
@@ -0,0 +1 @@
+ {"Body": "twin_encoder", "loss_balancer_meta": true, "loss_balancer_log": false, "loss_balancer_lbtw": false, "pma_skip_small": false, "isab_skip_small": false, "layer_norm": false, "pma_layer_norm": false, "attn_residual": true, "tf_n_layers_dec": false, "tf_isab_rank": 0, "tf_lora": false, "tf_layer_norm": false, "tf_pma_start": -1, "ada_n_seeds": 0, "head_n_seeds": 0, "tf_pma_low": 16, "gradient_penalty_kwargs": {"mag_loss": true, "mse_mag": true, "mag_corr": false, "seq_mag": false, "cos_loss": false, "mse_mag_kwargs": {"target": 1.0, "multiply": true, "forgive_over": true}, "mag_corr_kwargs": {"only_sign": false}, "cos_loss_kwargs": {"only_sign": true, "cos_matrix": false}}, "dropout": 0, "combine_mode": "diff_left", "tf_isab_mode": "separate", "grad_loss_fn": "mae", "single_model": true, "bias": true, "bias_final": true, "pma_ffn_mode": "shared", "patience": 10, "inds_init_mode": "fixnorm", "grad_clip": 0.7, "head_final_mul": "identity", "gradient_penalty_mode": "NONE", "synth_data": 2, "bias_lr_mul": 1.0, "bias_weight_decay": 0.05, "loss_balancer_beta": 0.79, "loss_balancer_r": 0.95, "dataset_size": 2048, "batch_size": 8, "epochs": 100, "n_warmup_steps": 100, "Optim": "diffgrad", "fixed_role_model": "tvae", "mse_mag": false, "mse_mag_target": 0.1, "mse_mag_multiply": true, "d_model": 256, "attn_activation": "leakyrelu", "tf_d_inner": 512, "tf_n_layers_enc": 4, "tf_n_head": 64, "tf_activation": "leakyhardsigmoid", "tf_activation_final": "leakyhardsigmoid", "tf_num_inds": 32, "ada_d_hid": 1024, "ada_n_layers": 7, "ada_activation": "relu", "ada_activation_final": "softsign", "head_d_hid": 128, "head_n_layers": 9, "head_n_head": 64, "head_activation": "prelu", "head_activation_final": "softsign", "models": ["tvae"], "max_seconds": 3600}
treatment/lct_gan/eval.csv ADDED
@@ -0,0 +1,2 @@
+ ,avg_g_cos_loss,avg_g_mag_loss,avg_loss,grad_duration,grad_mae,grad_mape,grad_rmse,mean_pred_loss,pred_duration,pred_mae,pred_mape,pred_rmse,pred_std,std_loss,total_duration
+ lct_gan,0.0,0.013995862244785677,0.004479979896409315,11.700660228729248,0.10576117038726807,1.955125331878662,0.1900063008069992,5.562088335864246e-05,6.2541420459747314,0.04323010519146919,2659710.75,0.06693265587091446,0.24518649280071259,1.0283116580467322e-07,17.95480227470398
treatment/lct_gan/history.csv ADDED
@@ -0,0 +1,11 @@
+ ,avg_role_model_loss_train,avg_role_model_std_loss_train,avg_role_model_mean_pred_loss_train,avg_role_model_g_mag_loss_train,avg_role_model_g_cos_loss_train,avg_non_role_model_g_mag_loss_train,avg_non_role_model_g_cos_loss_train,avg_non_role_model_embed_loss_train,avg_loss_train,n_size_train,n_batch_train,duration_train,duration_batch_train,duration_size_train,avg_pred_std_train,avg_role_model_loss_test,avg_role_model_std_loss_test,avg_role_model_mean_pred_loss_test,avg_role_model_g_mag_loss_test,avg_role_model_g_cos_loss_test,avg_non_role_model_g_mag_loss_test,avg_non_role_model_g_cos_loss_test,avg_non_role_model_embed_loss_test,avg_loss_test,n_size_test,n_batch_test,duration_test,duration_batch_test,duration_size_test,avg_pred_std_test
+ 0,0.1568385149637389,35.06780496466763,0.06735080731399752,0.0,0.0,0.0,0.0,0.0,0.1568385149637389,900,225,251.36616969108582,1.1171829764048258,0.27929574410120644,0.17218083781136656,0.016657082374949443,0.6645107248224963,0.0009922695764642195,0.0,0.0,0.0,0.0,0.0,0.016657082374949443,450,113,88.97782921791077,0.7874144178576175,0.19772850937313505,0.12430944651547114
+ 1,0.013602297218005535,0.7166610674290971,0.000675762991130306,0.0,0.0,0.0,0.0,0.0,0.013602297218005535,900,225,252.238872051239,1.1210616535610622,0.28026541339026556,0.2211370967532922,0.010051749147562027,1.011745478109109,0.0004482628718359983,0.0,0.0,0.0,0.0,0.0,0.010051749147562027,450,113,89.63172817230225,0.7932011342681615,0.19918161816067165,0.13382232088393575
+ 2,0.009725581167071545,0.12862642994658147,0.00025997880217867837,0.0,0.0,0.0,0.0,0.0,0.009725581167071545,900,225,252.7475070953369,1.1233222537570529,0.2808305634392632,0.23383142546362554,0.009780957170489981,2.8031414232013194,0.0003623511504917325,0.0,0.0,0.0,0.0,0.0,0.009780957170489981,450,113,89.94277811050415,0.7959537885885323,0.1998728402455648,0.12263423703156696
+ 3,0.006879980324560569,0.1529390263434276,0.0001266732238959002,0.0,0.0,0.0,0.0,0.0,0.006879980324560569,900,225,253.08241868019104,1.124810749689738,0.2812026874224345,0.23675976120452913,0.009132593497263567,2.349483191307895,0.00031874165080905583,0.0,0.0,0.0,0.0,0.0,0.009132593497263567,450,113,89.89717721939087,0.7955502408795652,0.1997715049319797,0.12599473668269962
+ 4,0.005557837531780226,0.10897728779356435,0.00010824206081475525,0.0,0.0,0.0,0.0,0.0,0.005557837531780226,900,225,252.57941794395447,1.1225751908620198,0.28064379771550496,0.24095872966145787,0.007977272890635857,1.5098599609827625,0.00018786837228384554,0.0,0.0,0.0,0.0,0.0,0.007977272890635857,450,113,89.83993029594421,0.7950436309375594,0.1996442895465427,0.13417804435596584
+ 5,0.004936696493952898,0.09284011989876738,9.488825836738386e-05,0.0,0.0,0.0,0.0,0.0,0.004936696493952898,900,225,252.75963163375854,1.1233761405944824,0.2808440351486206,0.2406902328133583,0.009863585224850592,3.727565739812436,0.0004114330789984428,0.0,0.0,0.0,0.0,0.0,0.009863585224850592,450,113,89.957528591156,0.7960843238155398,0.1999056190914578,0.1339049061824502
+ 6,0.003923159539046234,0.060394366573882884,7.543663868940605e-05,0.0,0.0,0.0,0.0,0.0,0.003923159539046234,900,225,253.02440571784973,1.1245529143015545,0.2811382285753886,0.23062875824508308,0.006808710179325948,1.3504862622616833,0.00019859803310097395,0.0,0.0,0.0,0.0,0.0,0.006808710179325948,450,113,89.9857177734375,0.7963337856056416,0.19996826171875,0.1247912253736337
+ 7,0.0027492816115166838,0.13162874239411848,4.767503536376824e-05,0.0,0.0,0.0,0.0,0.0,0.0027492816115166838,900,225,253.01451897621155,1.1245089732276068,0.2811272433069017,0.23854128369026714,0.01000592881146531,3.7195748066041157,0.0004304353354714678,0.0,0.0,0.0,0.0,0.0,0.01000592881146531,450,113,90.18618249893188,0.7981078097250609,0.20041373888651529,0.1279042437419486
+ 8,0.0024273904216483466,0.11162616240413387,4.6507011236350724e-05,0.0,0.0,0.0,0.0,0.0,0.0024273904216483466,900,225,252.93700242042542,1.1241644552018908,0.2810411138004727,0.24135627153213135,0.006604909333373143,1.494413609870378,0.00017382205726341554,0.0,0.0,0.0,0.0,0.0,0.006604909333373143,450,113,90.14172339439392,0.79771436632207,0.20031494087643092,0.11940927416499832
+ 9,0.001791998714582203,0.02997249101643904,3.2543868669724134e-05,0.0,0.0,0.0,0.0,0.0,0.001791998714582203,900,225,253.0452380180359,1.1246455023023818,0.28116137557559545,0.2516257931456332,0.006831648460120555,2.186309788208537,0.0002196008244519274,0.0,0.0,0.0,0.0,0.0,0.006831648460120555,450,113,89.89717698097229,0.7955502387696662,0.19977150440216065,0.13539934268081083
treatment/lct_gan/mlu-eval.ipynb ADDED
The diff for this file is too large to render. See raw diff
treatment/lct_gan/model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9ecae4001da9abe16fe2126bd7970b74a51e8b787afd26d4400d57c64385c710
+ size 74778241
treatment/lct_gan/params.json ADDED
@@ -0,0 +1 @@
+ {"Body": "twin_encoder", "loss_balancer_meta": true, "loss_balancer_log": false, "loss_balancer_lbtw": false, "pma_skip_small": false, "isab_skip_small": false, "layer_norm": false, "pma_layer_norm": false, "attn_residual": true, "tf_n_layers_dec": false, "tf_isab_rank": 0, "tf_lora": false, "tf_layer_norm": false, "tf_pma_start": -1, "ada_n_seeds": 0, "head_n_seeds": 0, "tf_pma_low": 16, "gradient_penalty_kwargs": {"mag_loss": true, "mse_mag": true, "mag_corr": false, "seq_mag": false, "cos_loss": false, "mse_mag_kwargs": {"target": 1.0, "multiply": true, "forgive_over": true}, "mag_corr_kwargs": {"only_sign": false}, "cos_loss_kwargs": {"only_sign": true, "cos_matrix": false}}, "dropout": 0, "combine_mode": "diff_left", "tf_isab_mode": "separate", "grad_loss_fn": "mae", "single_model": true, "bias": true, "bias_final": true, "pma_ffn_mode": "shared", "patience": 10, "inds_init_mode": "torch", "grad_clip": 0.8, "gradient_penalty_mode": "NONE", "synth_data": 2, "bias_lr_mul": 1.0, "bias_weight_decay": 0.1, "loss_balancer_beta": 0.73, "loss_balancer_r": 0.94, "dataset_size": 2048, "batch_size": 4, "epochs": 100, "lr_mul": 0.04, "n_warmup_steps": 220, "Optim": "diffgrad", "fixed_role_model": "lct_gan", "mse_mag": false, "mse_mag_target": 0.2, "mse_mag_multiply": false, "d_model": 512, "attn_activation": "leakyhardsigmoid", "tf_d_inner": 512, "tf_n_layers_enc": 4, "tf_n_head": 64, "tf_activation": "leakyhardtanh", "tf_activation_final": "leakyhardtanh", "tf_num_inds": 64, "ada_d_hid": 1024, "ada_n_layers": 7, "ada_activation": "selu", "ada_activation_final": "leakyhardsigmoid", "head_d_hid": 128, "head_n_layers": 8, "head_n_head": 64, "head_activation": "leakyhardsigmoid", "head_activation_final": "leakyhardsigmoid", "models": ["lct_gan"], "max_seconds": 3600}
treatment/realtabformer/eval.csv ADDED
@@ -0,0 +1,2 @@
+ ,avg_g_cos_loss,avg_g_mag_loss,avg_loss,grad_duration,grad_mae,grad_mape,grad_rmse,mean_pred_loss,pred_duration,pred_mae,pred_mape,pred_rmse,pred_std,std_loss,total_duration
+ realtabformer,1.3982927692787988e-05,0.396190471649171,0.004820992158222723,6.060643196105957,0.4483191967010498,7.328680992126465,1.377678632736206,6.148087413748726e-05,27.904374599456787,0.04260999336838722,5612931.0,0.06943336874246597,0.23635642230510712,0.0008501424454152584,33.965017795562744
treatment/realtabformer/history.csv ADDED
@@ -0,0 +1,6 @@
+ ,avg_role_model_loss_train,avg_role_model_std_loss_train,avg_role_model_mean_pred_loss_train,avg_role_model_g_mag_loss_train,avg_role_model_g_cos_loss_train,avg_non_role_model_g_mag_loss_train,avg_non_role_model_g_cos_loss_train,avg_non_role_model_embed_loss_train,avg_loss_train,n_size_train,n_batch_train,duration_train,duration_batch_train,duration_size_train,avg_pred_std_train,avg_role_model_loss_test,avg_role_model_std_loss_test,avg_role_model_mean_pred_loss_test,avg_role_model_g_mag_loss_test,avg_role_model_g_cos_loss_test,avg_non_role_model_g_mag_loss_test,avg_non_role_model_g_cos_loss_test,avg_non_role_model_embed_loss_test,avg_loss_test,n_size_test,n_batch_test,duration_test,duration_batch_test,duration_size_test,avg_pred_std_test
+ 0,0.07077005374701326,7.7690568529201345,0.028255754237470072,0.0,0.0,0.0,0.0,0.0,0.07077005374701326,900,450,490.3524549007416,1.0896721220016479,0.5448360610008239,0.15626739564745068,0.007018332518638089,5.231984866959922,0.00025027377600039385,0.0,0.0,0.0,0.0,0.0,0.007018332518638089,450,225,197.8156943321228,0.8791808636983236,0.4395904318491618,0.09444081523999052
+ 1,0.009379093581180642,2.6030105899689175,0.0004631821013883831,0.0,0.0,0.0,0.0,0.0,0.009379093581180642,900,450,485.68966460227966,1.0793103657828438,0.5396551828914219,0.2023497386896573,0.00616615310408106,7.329303962527855,0.00015382666007410638,0.0,0.0,0.0,0.0,0.0,0.00616615310408106,450,225,198.96484756469727,0.8842882113986545,0.44214410569932727,0.07584491885941583
+ 2,0.007078851991594096,2.0421946790576384,0.00021883253530301686,0.0,0.0,0.0,0.0,0.0,0.007078851991594096,900,450,487.61613154411316,1.0835914034313625,0.5417957017156813,0.1953078424528414,0.0064364255767268555,6.630073926702688,0.00046532210426971943,0.0,0.0,0.0,0.0,0.0,0.0064364255767268555,450,225,197.56161880493164,0.8780516391330295,0.43902581956651476,0.0723328961941095
+ 3,0.005932264231048244,2.1949213939262298,0.0001685581630674178,0.0,0.0,0.0,0.0,0.0,0.005932264231048244,900,450,485.7743980884552,1.0794986624187892,0.5397493312093946,0.1975122331425367,0.005354185296421948,6.1442645163203204,0.0003165367537867994,0.0,0.0,0.0,0.0,0.0,0.005354185296421948,450,225,197.5341498851776,0.8779295550452338,0.4389647775226169,0.07811297937988507
+ 4,0.005074872308783256,2.5619912489473453,9.861685678809986e-05,0.0,0.0,0.0,0.0,0.0,0.005074872308783256,900,450,481.7331359386444,1.0705180798636542,0.5352590399318271,0.18441032754652825,0.006421878464142566,6.356629352012609,0.0003719029411907014,0.0,0.0,0.0,0.0,0.0,0.006421878464142566,450,225,195.86327409744263,0.8705034404330784,0.4352517202165392,0.0808127932534593
treatment/realtabformer/mlu-eval.ipynb ADDED
The diff for this file is too large to render. See raw diff
treatment/realtabformer/model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0acff5eb9c30e75b1ef15ed8122fe200e94020ff0583fdfc871bca2015d3ce43
+ size 78481207
treatment/realtabformer/params.json ADDED
@@ -0,0 +1 @@
+ {"Body": "twin_encoder", "loss_balancer_meta": true, "loss_balancer_log": false, "loss_balancer_lbtw": false, "pma_skip_small": false, "isab_skip_small": false, "layer_norm": false, "pma_layer_norm": false, "attn_residual": true, "tf_n_layers_dec": false, "tf_isab_rank": 0, "tf_lora": false, "tf_layer_norm": false, "tf_pma_start": -1, "ada_n_seeds": 0, "head_n_seeds": 0, "tf_pma_low": 16, "gradient_penalty_kwargs": {"mag_loss": true, "mse_mag": true, "mag_corr": false, "seq_mag": false, "cos_loss": false, "mse_mag_kwargs": {"target": 1.0, "multiply": true, "forgive_over": true}, "mag_corr_kwargs": {"only_sign": false}, "cos_loss_kwargs": {"only_sign": true, "cos_matrix": false}}, "dropout": 0, "combine_mode": "diff_left", "tf_isab_mode": "separate", "grad_loss_fn": "mae", "single_model": true, "bias": true, "bias_final": true, "pma_ffn_mode": "shared", "patience": 10, "inds_init_mode": "torch", "grad_clip": 0.8, "gradient_penalty_mode": "NONE", "synth_data": 2, "bias_lr_mul": 1.0, "bias_weight_decay": 0.1, "loss_balancer_beta": 0.73, "loss_balancer_r": 0.94, "dataset_size": 2048, "batch_size": 2, "epochs": 100, "lr_mul": 0.04, "n_warmup_steps": 220, "Optim": "diffgrad", "fixed_role_model": "realtabformer", "mse_mag": false, "mse_mag_target": 0.2, "mse_mag_multiply": false, "d_model": 512, "attn_activation": "leakyhardsigmoid", "tf_d_inner": 512, "tf_n_layers_enc": 4, "tf_n_head": 64, "tf_activation": "leakyhardtanh", "tf_activation_final": "leakyhardtanh", "tf_num_inds": 64, "ada_d_hid": 1024, "ada_n_layers": 7, "ada_activation": "selu", "ada_activation_final": "leakyhardsigmoid", "head_d_hid": 128, "head_n_layers": 8, "head_n_head": 64, "head_activation": "leakyhardsigmoid", "head_activation_final": "leakyhardsigmoid", "models": ["realtabformer"], "max_seconds": 3600}