{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "#### Pickle file check: validation AUPRC for the random-lead model" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " 50\n", "\n", "----------------------\n", "epoch: \n", "model: \n", "train_auprc: \n", "valid_auprc: \n", "valid_targets: \n", "valid_outputs: \n", "-----------------------\n", "-----------------------\n", "[0.20795198881124255, 0.2924131615408049, 0.31194815399388126, 0.357671229080611, 0.3907590012977773, 0.39197022751675975, 0.39688932315376796, 0.41098642756821824, 0.4280303875603716, 0.4251116328825386, 0.41492397254078656, 0.44119503399957305, 0.42866565608661766, 0.42155615910506705, 0.4352771610735857, 0.4355309812927433, 0.4575302940022513, 0.4621060999031488, 0.4615244295921646, 0.4347042141353311, 0.4843673460502776, 0.49216570578173724, 0.49284316077316226, 0.4976730562122618, 0.4981241668777771, 0.4985906269863735, 0.5023674118168958, 0.5039947051779108, 0.5025596400291938, 0.501332454384853, 0.5017141509761979, 0.5033696471830942, 0.5035807094153067, 0.5044712423289812, 0.49912591150498187, 0.5036493639939076, 0.5073756144905568, 0.5066738446153692, 0.5041024684427422, 0.5061074251973712, 0.5079663458037375, 0.5080434717076571, 0.5071731389137064, 0.5066158069067092, 0.5059333249321385, 0.5078252460128987, 0.5081895157894929, 0.5079278975582764, 0.5073543066159428, 0.5078677916025073]\n", "0.5081895157894929 46\n" ] } ], "source": [
 "import pickle\n",
 "import torch\n",
 "\n",
 "# NOTE: pickle.load can execute arbitrary code - only load PROGRESS files produced by our own training runs.\n",
 "address = \"./model_output/model_group5/PROGRESS.pickle\"\n",
 "\n",
 "with open(address, 'rb') as file:\n",
 "    data = pickle.load(file)\n",
 "\n",
 "print(type(data), len(data))\n",
 "# Show the fields of the first record (this also works when only one epoch was saved).\n",
 "print(type(data[0]))\n",
 "print(\"----------------------\")\n",
 "for key in data[0]:\n",
 "    print(f\"{key}: \")\n",
 "\n",
 "print(\"-----------------------\")\n",
 "# Presumably one record per epoch (each record carries an \"epoch\" key) - confirm against the training loop.\n",
 "AUPRC_list = [record[\"valid_auprc\"] for record in data]\n",
 "\n",
 "print(\"-----------------------\")\n",
 "print(AUPRC_list)\n",
 "print(\"-----------------------\")\n",
 "# Position of the best validation AUPRC (first occurrence on ties).\n",
 "largest_number = max(AUPRC_list)\n",
 "index = AUPRC_list.index(largest_number)\n",
 "print(largest_number, index)"
 ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "group #1\n", "\n", "[0.24092945522182005, 0.3139675367502194, 0.3062163369752217, 0.32297163568130305, 0.3672050308180419, 0.3801609216698969, 0.3915211363523951, 0.4034875773118736, 0.41721359538446234, 0.41755420607909477, 0.4101699028342543, 0.42683222688245664, 0.4338339272938271, 0.4432706404963518, 0.4451886249025738, 0.4436839678451211, 0.46470292201596697, 0.4619959382638624, 0.4389299870874322, 0.4537386141609928, 0.4880276013143086, 0.48964141469390005, 0.49214694908533474, 0.49336784163926267, 0.4978899412041259, 0.4960868620495151, 0.4949812567178974, 0.49875221067947606, 0.4959535547710648, 0.49723019893878023, 0.49849758106937503, 0.5005045769636993, 0.4968324354226746, 0.4985954057932132, 0.4985684464062525, 0.4948398218890804, 0.5003443438290083, 0.49804674478254773, 0.5015115944170082, 0.5043099513157541, 0.5022930844045073, 0.502102123403741, 0.5025587387783707, 0.5026322695878688, 0.5028108420912678, 0.501853319716798, 0.5044486284061104, 0.5043333679462079, 0.503047975296802, 0.5021477867974229]\n", "\n", "0.5044486284061104, index: 46\n", "\n", "group #2\n", "\n", "[0.24668762844932296, 0.31123092790061574, 0.35728718371921886, 0.37858993755415526, 0.38325613445804607, 0.38183540019756823, 0.40688905625255206, 0.4050292403852287, 0.4103841963804383, 0.4288343036036706, 0.4293594683280219, 0.44373329349811874, 0.44694196761428867, 0.44516332505161516, 0.4570591656299683, 0.44925142278910385, 0.45783436251651694, 0.4512008966459152, 0.4628860929136446, 0.46190128250293605, 0.4891415053038087, 0.4933325648723347, 0.49795793473520533, 0.4989478549566136, 0.507199717375493, 0.5031777644234027, 0.5048360591023886, 
0.5026344145441939, 0.5070084702134143, 0.50851780828997, 0.5013767142024679, 0.5077028354409389, 0.5073222030725629, 0.5103865617070087, 0.5070321372047399, 0.5069057373554984, 0.5054984338086199, 0.5052088211513525, 0.5085875776438461, 0.5015018579996042, 0.507983738986951, 0.506001318616706, 0.5078548999343991, 0.5084694227173217, 0.5081644743764611, 0.5070537320211395, 0.5072728550164887, 0.5084469401746737, 0.5081580384861908, 0.5092361778552277]\n", "\n", "0.5103865617070087, index: 33\n", "\n", "group #3\n", "[0.20546938178065813, 0.31056285598824596, 0.3521164077944065, 0.36566363279169545, 0.3649970330628938, 0.3816742095036071, 0.408841252427171, 0.4192963362391232, 0.419725128897165, 0.4009845215509139, 0.4221866024862177, 0.4383579336817017, 0.41634488480301257, 0.4394011015343916, 0.42674958918677536, 0.4484833626141604, 0.43733868299572076, 0.42813204282903494, 0.44362467579095183, 0.4525213211300688, 0.47993303563958817, 0.48221178536835363, 0.4832912567732829, 0.485964752652683, 0.4894140885779246, 0.49081305081555826, 0.4835906970652839, 0.4881328848995447, 0.49108874994886303, 0.49205732309554323, 0.4918174541861535, 0.49104602501641953, 0.49033495002806987, 0.49255438103140303, 0.4982302563540638, 0.4919847023325378, 0.49138268849817107, 0.49216471663752714, 0.49367968532436873, 0.49558690171904884, 0.4952242601993453, 0.49709259551176815, 0.4969043181087201, 0.49722348299821856, 0.49599951407363857, 0.49572421827303714, 0.49551046935516674, 0.4969339282495756, 0.49522481850002315, 0.4956301125397299]\n", "\n", "0.4982302563540638, index: 34\n", "\n", "group #4\n", "\n", "[0.16705442847351432, 0.2811237847091236, 0.3227277423619332, 0.3459164670019608, 0.3433205542817934, 0.38953865811323535, 0.40093825754134493, 0.4042482476980622, 0.4179255247142833, 0.42026119275049384, 0.415263850960453, 0.4326573070148512, 0.4284856196846552, 0.455811861263988, 0.44742754829379755, 0.4428520431746461, 0.4288860834282809, 0.43801462440444205, 
0.441347802107846, 0.4560878428908129, 0.47952984096244766, 0.4859939647185739, 0.48291741623601653, 0.4863560035613435, 0.4879069301596515, 0.49283878286572264, 0.4925634321692941, 0.49296767067476266, 0.4925321693215088, 0.4930295366233496, 0.4927986378984127, 0.49612537918838245, 0.4992350455119594, 0.4951830005033058, 0.49014993853897326, 0.4924448141210762, 0.4945801109607605, 0.4971188401719394, 0.49753234729288465, 0.49315691206981155, 0.4963229926370793, 0.49660539254449804, 0.49752930191373473, 0.4983978705842285, 0.498218560630721, 0.49778016282127696, 0.4980937334749714, 0.4982398417549309, 0.49825272820647715, 0.4978916971990578]\n", "\n", "0.4992350455119594, index: 32\n", "\n", "group #5\n", "[0.20795198881124255, 0.2924131615408049, 0.31194815399388126, 0.357671229080611, 0.3907590012977773, 0.39197022751675975, 0.39688932315376796, 0.41098642756821824, 0.4280303875603716, 0.4251116328825386, 0.41492397254078656, 0.44119503399957305, 0.42866565608661766, 0.42155615910506705, 0.4352771610735857, 0.4355309812927433, 0.4575302940022513, 0.4621060999031488, 0.4615244295921646, 0.4347042141353311, 0.4843673460502776, 0.49216570578173724, 0.49284316077316226, 0.4976730562122618, 0.4981241668777771, 0.4985906269863735, 0.5023674118168958, 0.5039947051779108, 0.5025596400291938, 0.501332454384853, 0.5017141509761979, 0.5033696471830942, 0.5035807094153067, 0.5044712423289812, 0.49912591150498187, 0.5036493639939076, 0.5073756144905568, 0.5066738446153692, 0.5041024684427422, 0.5061074251973712, 0.5079663458037375, 0.5080434717076571, 0.5071731389137064, 0.5066158069067092, 0.5059333249321385, 0.5078252460128987, 0.5081895157894929, 0.5079278975582764, 0.5073543066159428, 0.5078677916025073]\n", "\n", "0.5081895157894929, index: 46\n", "\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "A0004.hea\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ 
"100%|██████████| 17651/17651 [00:02<00:00, 6737.50it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "164889003 1051.0\n", "164890007 1675.0\n", "6374002 103.0\n", "426627000 59.0\n", "733534002 299.0\n", "713427006 963.0\n", "270492004 706.0\n", "713426002 372.0\n", "39732003 1526.0\n", "445118002 437.0\n", "164947007 74.0\n", "251146004 320.0\n", "111975006 382.0\n", "698252002 354.0\n", "426783006 5794.0\n", "284470004 653.0\n", "10370003 296.0\n", "365413008 120.0\n", "427172004 387.0\n", "164917005 415.0\n", "47665007 256.0\n", "427393009 758.0\n", "426177001 3784.0\n", "427084000 1932.0\n", "164934002 2343.0\n", "59931005 798.0\n", "dtype: float64\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " 0%| | 0/138 [00:00 0.5 else 0) for prob in probs] for probs in np.array(outputs)])\n", " f1 = f1_score(targets, outputs_f1, average='weighted')\n", " print(\"This is the auroc of testing:\", auroc)\n", " print(\"This is the f1 of testing:\", f1)\n", "\n", " return auprc, targets, outputs, auroc, f1\n", "\n", "file_address = \"../collection_of_all_datasets/\"\n", "data_directory = \"./csv-file/training_validation_testing/group5\"\n", "\n", "############ testing area #########################\n", "the_testing_address = data_directory + \"/testing_group\"+data_directory[-1]+\".csv\"\n", "df = pd.read_csv(the_testing_address)\n", "print(df['Name'][0])\n", "\n", "testing_header_files=[]\n", "\n", "for i in range(len(df['Name'])):\n", " each_header_file = file_address + df['Name'][i]\n", " testing_header_files.append(each_header_file)\n", " \n", "test_dataset = dataset(testing_header_files)\n", "print(test_dataset.summary('pandas'))\n", " \n", "\n", "test_dataset.num_leads = 12\n", "test_dataset.sample = True\n", "###################################################\n", "valid = DataLoader(dataset=test_dataset,\n", " batch_size=128,\n", " shuffle=False,\n", " num_workers=8,\n", " collate_fn=collate,\n", " pin_memory=True,\n", " 
drop_last=False)\n", "\n", "model = NN(nOUT=26).to(DEVICE)\n", "model.load_state_dict(new_state_dict)\n", "\n", "auprc, targets, outputs, auroc, f1 = valid_part(model, valid)\n", "print(\"============================================\")\n", "print(\"This is the auprc:\", auprc)\n", "print(\"This is the auroc:\", auroc)\n", "print(\"This is the f1: \", f1)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### AUPRC Checking for the 12-lead" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " 70\n", "\n", "----------------------\n", "epoch: \n", "model: \n", "train_auprc: \n", "valid_auprc: \n", "valid_auroc: \n", "valid_targets: \n", "valid_outputs: \n", "-----------------------\n", "-----------------------\n", "[0.28910033219670556, 0.36661018186269884, 0.4050706458985195, 0.42316484282838984, 0.43350053497280694, 0.44553479020425846, 0.44136231989079977, 0.45122218897758415, 0.4644021743158985, 0.47293826752021145, 0.46926146531955976, 0.467247872269531, 0.4700967180951104, 0.4788817392392415, 0.467775651905869, 0.48255573911671956, 0.477212325777135, 0.4800058433950016, 0.4915499891655736, 0.48971557432198565, 0.5141404510781854, 0.5219609044964516, 0.5230429753848443, 0.5232866664000654, 0.5244318076853925, 0.5234401986426129, 0.5263752664655873, 0.5281336596938361, 0.528075972755201, 0.5302334312309755, 0.5307301559051676, 0.5304267506928821, 0.5318235204566354, 0.5306493348295291, 0.5350866957700305, 0.5345539436367677, 0.5360092534376171, 0.5331738336457497, 0.5317614713309697, 0.535248967523321, 0.5404416392354089, 0.540938736266318, 0.5408833722750755, 0.5412477678363515, 0.5417516390629749, 0.5412621901602691, 0.5420153503382734, 0.5413120998587926, 0.5421420766388778, 0.5429929117370075, 0.5426979188589047, 0.5427578748612708, 0.542525766866879, 0.5422344779496736, 0.542579324978226, 0.5435324050712601, 0.5427155565657983, 0.5432355489935663, 
0.5430772835279998, 0.5435757517477243, 0.543233155585495, 0.5437444745165606, 0.5433336568324321, 0.543740825076081, 0.543418673099526, 0.5436833187473836, 0.5439772667518001, 0.5442283931779751, 0.5438852947464646, 0.5438798138010161]\n", "0.5442283931779751 67\n", "This is the largest auroc: 0.8730646472382316\n" ] } ], "source": [
 "import pickle\n",
 "import torch\n",
 "\n",
 "# NOTE: pickle.load can execute arbitrary code - only load PROGRESS files produced by our own training runs.\n",
 "address = \"./model_output_12_lead_without_attention/model_group5/PROGRESS.pickle\"\n",
 "\n",
 "with open(address, 'rb') as file:\n",
 "    data = pickle.load(file)\n",
 "\n",
 "print(type(data), len(data))\n",
 "# Show the fields of the first record (this also works when only one epoch was saved).\n",
 "print(type(data[0]))\n",
 "print(\"----------------------\")\n",
 "for key in data[0]:\n",
 "    print(f\"{key}: \")\n",
 "\n",
 "print(\"-----------------------\")\n",
 "# Presumably one record per epoch (each record carries an \"epoch\" key) - confirm against the training loop.\n",
 "AUPRC_list = [record[\"valid_auprc\"] for record in data]\n",
 "AUROC_list = [record[\"valid_auroc\"] for record in data]\n",
 "\n",
 "print(\"-----------------------\")\n",
 "print(AUPRC_list)\n",
 "print(\"-----------------------\")\n",
 "# Pick the position with the best validation AUPRC (first occurrence on ties) and report\n",
 "# the AUROC recorded at that same position - this is not the maximum AUROC overall.\n",
 "largest_number = max(AUPRC_list)\n",
 "index = AUPRC_list.index(largest_number)\n",
 "print(largest_number, index)\n",
 "print(\"AUROC at the best-AUPRC epoch:\", AUROC_list[index])"
 ] } ], "metadata": { "kernelspec": { "display_name": "testing", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.14" } }, "nbformat": 4, "nbformat_minor": 2 }