diff --git a/README.md b/README.md index ca8790b77e59afc477d012a405a099bcba2e398a..98c0da340b2215eb1f69193c2e4b84b173623819 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ --- -title: AioMedica -emoji: 👁 -colorFrom: yellow -colorTo: indigo +title: MedFormer +emoji: 🏃 +colorFrom: purple +colorTo: yellow sdk: streamlit sdk_version: 1.10.0 app_file: app.py diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..f344eb743c8f74603498ead630ab944cc30058e4 --- /dev/null +++ b/app.py @@ -0,0 +1,145 @@ +import streamlit as st +import openslide +import os +from streamlit_option_menu import option_menu +import torch + + +if torch.cuda.is_available(): + os.system("pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.7.1+cu113.html") + os.system("pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.7.1+cu113.html") + os.system("pip install torch-geometric -f https://pytorch-geometric.com/whl/torch-1.7.1+cu113.html") +else: + os.system("pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.7.1+cpu.html") + os.system("pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.7.1+cpu.html") + os.system("pip install torch-geometric -f https://pytorch-geometric.com/whl/torch-1.7.1+cpu.html") + +from predict import Predictor + + + +# environment variables for the inference api +os.environ['DATA_DIR'] = 'queries' +os.environ['PATCHES_DIR'] = os.path.join(os.environ['DATA_DIR'], 'patches') +os.environ['SLIDES_DIR'] = os.path.join(os.environ['DATA_DIR'], 'slides') +os.environ['GRAPHCAM_DIR'] = os.path.join(os.environ['DATA_DIR'], 'graphcam_plots') +os.makedirs(os.environ['GRAPHCAM_DIR'], exist_ok=True) + + +# manually put the metadata in the metadata folder +os.environ['CLASS_METADATA'] ='metadata/label_map.pkl' + +# manually put the desired weights in the weights folder +os.environ['WEIGHTS_PATH'] = WEIGHTS_PATH='weights' +os.environ['FEATURE_EXTRACTOR_WEIGHT_PATH'] = os.path.join(os.environ['WEIGHTS_PATH'], 'feature_extractor', 'model.pth') +os.environ['GT_WEIGHT_PATH'] = os.path.join(os.environ['WEIGHTS_PATH'], 'graph_transformer', 'GraphCAM.pth') + + +st.set_page_config(page_title="",layout='wide') +predictor = Predictor() + + + + + +ABOUT_TEXT = "🤗 LastMinute Medical - Web diagnosis tool." +CONTACT_TEXT = """ +_Built by Christian Cancedda and LabLab lads with love_ ❤️ +[![Follow](https://img.shields.io/github/followers/Chris1nexus?style=social)](https://github.com/Chris1nexus) +[![Follow](https://img.shields.io/twitter/follow/chris_cancedda?style=social)](https://twitter.com/intent/follow?screen_name=chris_cancedda) +""" +VISUALIZE_TEXT = "Visualize WSI slide by uploading it on the provided window" +DETECT_TEXT = "Generate a preliminary diagnosis about the presence of pulmonary disease" + + + +with st.sidebar: + choice = option_menu("LastMinute - Diagnosis", + ["About", "Visualize WSI slide", "Cancer Detection", "Contact"], + icons=['house', 'upload', 'activity', 'person lines fill'], + menu_icon="app-indicator", default_index=0, + styles={ + # "container": {"padding": "5!important", "background-color": "#fafafa", }, + "container": {"border-radius": ".0rem"}, + # "icon": {"color": "orange", "font-size": "25px"}, + # "nav-link": {"font-size": "16px", "text-align": "left", "margin": "0px", + # "--hover-color": "#eee"}, + # "nav-link-selected": {"background-color": "#02ab21"}, + } + ) +st.sidebar.markdown( + """ + +

+ + + +

+ """, + unsafe_allow_html=True, +) + + + +if choice == "About": + st.title(choice) + + + +if choice == "Visualize WSI slide": + st.title(choice) + st.markdown(VISUALIZE_TEXT) + + uploaded_file = st.file_uploader("Choose a WSI slide file to diagnose (.svs)") + if uploaded_file is not None: + ori = openslide.OpenSlide(uploaded_file.name) + width, height = ori.dimensions + + REDUCTION_FACTOR = 20 + w, h = int(width/512), int(height/512) + w_r, h_r = int(width/20), int(height/20) + resized_img = ori.get_thumbnail((w_r,h_r)) + resized_img = resized_img.resize((w_r,h_r)) + ratio_w, ratio_h = width/resized_img.width, height/resized_img.height + #print('ratios ', ratio_w, ratio_h) + w_s, h_s = float(512/REDUCTION_FACTOR), float(512/REDUCTION_FACTOR) + st.image(resized_img, use_column_width='never') + +if choice == "Cancer Detection": + state = dict() + + st.title(choice) + st.markdown(DETECT_TEXT) + uploaded_file = st.file_uploader("Choose a WSI slide file to diagnose (.svs)") + if uploaded_file is not None: + # To read file as bytes: + #print(uploaded_file) + with open(os.path.join(uploaded_file.name),"wb") as f: + f.write(uploaded_file.getbuffer()) + with st.spinner(text="Computation is running"): + predicted_class, viz_dict = predictor.predict(uploaded_file.name) + st.info('Computation completed.') + st.header(f'Predicted to be: {predicted_class}') + st.text('Heatmap of the areas that show markers correlated with the disease.\nIncreasing red tones represent higher likelihood that the area is affected') + state['cur'] = predicted_class + mapper = {'ORI': predicted_class, predicted_class:'ORI'} + readable_mapper = {'ORI': 'Original', predicted_class :'Disease heatmap' } + #def fn(): + # st.image(viz_dict[mapper[state['cur']]], use_column_width='never', channels='BGR') + # state['cur'] = mapper[state['cur']] + # return + + #st.button(f'See {readable_mapper[mapper[state["cur"]] ]}', on_click=fn ) + #st.image(viz_dict[state['cur']], use_column_width='never', channels='BGR') + st.image([viz_dict[state['cur']],viz_dict['ORI']], caption=['Original', f'{predicted_class} heatmap'] ,channels='BGR' + # use_column_width='never', + ) + + +if choice == "Contact": + st.title(choice) + st.markdown(CONTACT_TEXT) \ No newline at end of file diff --git a/feature_extractor/.ipynb_checkpoints/weights_check-checkpoint.ipynb b/feature_extractor/.ipynb_checkpoints/weights_check-checkpoint.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..c91a0bd0b15954274975eaea110bd1110e834ff9 --- /dev/null +++ b/feature_extractor/.ipynb_checkpoints/weights_check-checkpoint.ipynb @@ -0,0 +1,3503 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 3, + "id": "11c4fe3c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "OrderedDict([('module.features.0.weight',\n", + " tensor([[[[ 2.3375e-02, 5.5993e-03, 4.2364e-02, ..., 1.2101e-02,\n", + " -2.6842e-02, -3.0364e-02],\n", + " [ 1.2243e-02, -9.1156e-03, -1.9976e-02, ..., -4.5601e-02,\n", + " -3.3681e-02, -4.0585e-03],\n", + " [-7.8155e-03, 1.4921e-02, 2.3364e-02, ..., 8.7603e-03,\n", + " -3.1223e-02, 7.4876e-03],\n", + " ...,\n", + " [-2.9299e-02, -2.6057e-02, -4.0052e-02, ..., 4.2710e-03,\n", + " 1.9729e-03, -1.2235e-02],\n", + " [ 1.6445e-02, 3.8652e-03, 4.2917e-03, ..., -8.5154e-03,\n", + " -1.3266e-02, -7.6143e-03],\n", + " [ 2.3999e-02, 8.4055e-03, -2.8751e-02, ..., -1.1365e-02,\n", + " 3.8881e-03, -1.7586e-02]],\n", + " \n", + " [[-1.4130e-02, -4.9244e-02, -1.6324e-02, ..., 4.1007e-02,\n", + " -4.0374e-02, -2.6552e-02],\n", + " [-3.2905e-02, -2.9145e-02, -5.5822e-03, ..., -1.6007e-02,\n", + " -1.5566e-03, -1.6690e-02],\n", + " [ 2.4724e-02, -2.8561e-02, 1.9321e-02, ..., -3.5075e-02,\n", + " -1.6752e-02, 2.1253e-02],\n", + " ...,\n", + " [-2.0854e-02, -1.6552e-02, -3.2742e-02, ..., 1.2465e-02,\n", + " 1.9453e-02, -4.9739e-02],\n", + " [-2.5184e-02, 3.3581e-02, 1.6366e-03, ..., -1.6559e-02,\n", + " -4.3148e-02, -8.8248e-03],\n", + " [-1.7976e-02, -1.0308e-02, 1.9864e-02, ..., -2.1598e-02,\n", + " 5.0608e-04, -2.4172e-02]],\n", + " \n", + " [[ 4.9666e-02, -1.2670e-02, 1.9931e-02, ..., 8.9254e-03,\n", + " 4.6066e-02, 4.8928e-02],\n", + " [ 1.5310e-02, -1.3443e-02, 2.6382e-02, ..., -3.9132e-03,\n", + " -1.9607e-03, -3.5969e-02],\n", + " [-1.9942e-02, -5.7225e-02, 1.8700e-02, ..., 3.6640e-02,\n", + " 3.5779e-03, 1.2500e-02],\n", + " ...,\n", + " [ 1.1875e-02, -3.3648e-03, -3.0441e-02, ..., -5.6659e-02,\n", + " 1.8092e-02, 4.2179e-02],\n", + " [-1.9221e-02, 8.7840e-03, 2.1695e-02, ..., 1.2839e-03,\n", + " -2.7966e-02, 5.1216e-03],\n", + " [-1.9038e-02, 9.0134e-04, 2.1077e-03, ..., 2.9699e-02,\n", + " 1.8513e-02, 3.3447e-02]]],\n", + " \n", + " \n", + " [[[ 3.7965e-02, -1.2022e-02, 2.2249e-02, ..., 1.6461e-02,\n", + " 1.6206e-02, -1.6585e-02],\n", + " [-2.7620e-02, -4.7865e-02, 1.3980e-02, ..., 3.9625e-02,\n", + " 1.2485e-02, -5.5151e-02],\n", + " [-4.1348e-04, -2.9432e-02, -1.4788e-02, ..., 2.3406e-02,\n", + " 1.6614e-02, 1.7552e-02],\n", + " ...,\n", + " [ 2.3246e-02, 2.1007e-02, -1.2156e-02, ..., -2.6140e-02,\n", + " 3.8020e-02, -3.0928e-02],\n", + " [ 2.1980e-02, -9.2860e-03, -5.9852e-03, ..., 2.2137e-02,\n", + " 4.7298e-03, 9.8544e-04],\n", + " [-2.5694e-02, -1.1514e-02, 5.8983e-02, ..., 3.6341e-02,\n", + " -2.8853e-02, -2.8959e-02]],\n", + " \n", + " [[-1.0707e-02, -2.0920e-02, -3.5067e-03, ..., 1.5692e-02,\n", + " -4.2022e-02, -5.8786e-02],\n", + " [-1.2049e-02, -6.1877e-03, -2.7395e-02, ..., 8.5530e-03,\n", + " 5.5608e-02, -3.1539e-02],\n", + " [-2.4982e-02, -1.3160e-02, 3.1325e-02, ..., -2.7327e-02,\n", + " -5.6841e-02, -2.9666e-03],\n", + " ...,\n", + " [ 2.1402e-03, 4.6426e-02, -2.4192e-02, ..., 5.4034e-03,\n", + " -5.9481e-02, -1.2253e-02],\n", + " [-1.2507e-02, -9.2855e-05, 3.8571e-03, ..., 4.7442e-02,\n", + " -1.3722e-02, 1.5930e-03],\n", + " [ 1.1638e-02, -4.4012e-02, 5.2303e-02, ..., 2.1680e-02,\n", + " 3.6332e-02, -2.9015e-02]],\n", + " \n", + " [[-6.1262e-03, 2.1868e-03, 2.9714e-03, ..., -3.9569e-03,\n", + " 8.6403e-03, -4.3809e-03],\n", + " [-5.1549e-03, -1.5469e-02, -1.4357e-03, ..., -2.9453e-02,\n", + " -7.2058e-04, 3.4153e-02],\n", + " [-3.5302e-04, -1.2305e-02, -4.1532e-02, ..., 7.4666e-03,\n", + " 3.7797e-02, 1.0348e-02],\n", + " ...,\n", + " [-3.3086e-02, 5.9402e-03, 5.4123e-02, ..., 1.9769e-02,\n", + " -6.2268e-02, 1.9798e-03],\n", + " [-5.6262e-02, -1.4047e-02, -1.7646e-02, ..., 1.5906e-03,\n", + " -7.5155e-03, -1.0734e-02],\n", + " [ 3.9025e-02, -4.2364e-02, 2.8937e-05, ..., -2.7328e-02,\n", + " 9.2807e-03, 9.3495e-04]]],\n", + " \n", + " \n", + " [[[ 2.3433e-03, -1.9035e-02, 1.0023e-02, ..., 1.1675e-02,\n", + " 4.3943e-05, 5.4307e-04],\n", + " [ 2.4110e-02, 2.1324e-02, 4.6116e-03, ..., -8.9188e-03,\n", + " 2.5617e-02, 7.1546e-03],\n", + " [-2.5087e-02, 2.3352e-03, 6.2174e-03, ..., 2.4089e-02,\n", + " -4.7804e-03, 5.0964e-03],\n", + " ...,\n", + " [-1.1212e-03, 1.2286e-02, 3.0641e-02, ..., 2.2648e-02,\n", + " 4.1326e-03, -3.0177e-02],\n", + " [ 3.4274e-03, 4.3688e-03, -3.4311e-02, ..., 8.7346e-03,\n", + " 2.9993e-02, 9.9230e-04],\n", + " [ 1.0400e-02, -4.4071e-03, 6.5103e-03, ..., -1.4547e-02,\n", + " -9.2856e-03, -2.8808e-03]],\n", + " \n", + " [[ 8.4968e-03, 3.1743e-02, 1.8266e-02, ..., -3.0564e-02,\n", + " 4.9976e-02, -1.1739e-02],\n", + " [-3.2356e-03, -1.9133e-02, 1.0271e-02, ..., 3.0369e-03,\n", + " 5.1930e-03, -2.6841e-03],\n", + " [-1.4563e-02, 7.1441e-03, -2.2971e-02, ..., -1.3364e-02,\n", + " -1.6875e-02, 5.2100e-02],\n", + " ...,\n", + " [-8.4915e-03, 8.6577e-03, 7.7203e-03, ..., 2.9735e-02,\n", + " 3.5839e-02, 6.2928e-03],\n", + " [ 1.7178e-02, 3.5923e-02, -8.4103e-02, ..., 2.5245e-02,\n", + " 2.9331e-02, -9.0751e-03],\n", + " [-4.7001e-02, 1.9169e-02, 6.7551e-03, ..., 4.2300e-02,\n", + " -3.5133e-02, -2.9424e-02]],\n", + " \n", + " [[ 4.2076e-02, 3.0278e-02, 9.7546e-03, ..., 2.6859e-02,\n", + " -3.9730e-03, 1.9414e-02],\n", + " [-1.0068e-03, 1.8244e-02, 2.1646e-02, ..., -7.3066e-03,\n", + " -2.0182e-02, 4.2991e-02],\n", + " [ 2.4365e-02, -1.9178e-02, 3.2335e-02, ..., -3.5166e-02,\n", + " 1.4258e-02, -4.2840e-02],\n", + " ...,\n", + " [-1.3621e-02, 2.2653e-02, 3.2518e-02, ..., -1.1237e-02,\n", + " 3.9021e-02, 4.3618e-03],\n", + " [ 4.7386e-03, -3.7146e-02, -1.8900e-02, ..., -5.6209e-03,\n", + " 3.0549e-02, 3.5451e-02],\n", + " [-2.3577e-02, 2.1167e-02, -4.5990e-02, ..., -6.4087e-03,\n", + " -2.0090e-02, 9.2213e-03]]],\n", + " \n", + " \n", + " ...,\n", + " \n", + " \n", + " [[[-1.7300e-02, -1.0968e-02, 2.3756e-02, ..., -2.8067e-02,\n", + " -1.2273e-02, -6.7963e-03],\n", + " [-1.9014e-02, 4.2994e-02, -2.9301e-02, ..., -1.5515e-02,\n", + " 3.0392e-02, 2.3771e-02],\n", + " [-3.1195e-02, 3.3641e-02, -4.9220e-02, ..., -1.7844e-02,\n", + " -1.1322e-02, -4.7037e-03],\n", + " ...,\n", + " [ 1.3981e-02, 5.4877e-02, -6.8753e-03, ..., 1.3635e-02,\n", + " -9.1420e-03, 3.1688e-02],\n", + " [ 1.3518e-02, 1.8627e-02, -2.9229e-02, ..., 8.1206e-03,\n", + " -6.5161e-03, -1.2963e-02],\n", + " [-4.1449e-02, 1.2236e-02, 1.1663e-02, ..., 3.8267e-02,\n", + " -1.9472e-02, 1.8405e-02]],\n", + " \n", + " [[-9.8533e-03, 4.9424e-02, 4.8071e-02, ..., -9.0419e-03,\n", + " 2.4288e-02, -4.3603e-03],\n", + " [ 5.8574e-04, 1.4928e-02, -2.0026e-02, ..., 1.6564e-02,\n", + " 7.6141e-03, 2.9512e-02],\n", + " [ 2.0338e-02, 1.0564e-02, -1.0914e-02, ..., -6.4865e-03,\n", + " 1.8478e-02, -2.3697e-03],\n", + " ...,\n", + " [ 5.6546e-03, 2.4127e-02, -5.7034e-02, ..., 3.6024e-02,\n", + " 2.8125e-02, -1.4412e-02],\n", + " [-1.4562e-03, 1.1494e-02, 2.0532e-02, ..., -3.4795e-02,\n", + " 1.6016e-02, 4.6263e-02],\n", + " [ 1.8855e-02, 2.8185e-02, 4.2835e-02, ..., -2.0183e-02,\n", + " 3.4963e-02, -5.9240e-03]],\n", + " \n", + " [[ 3.2398e-02, 2.7336e-02, -1.1512e-02, ..., 4.0024e-02,\n", + " 4.8417e-02, -8.8634e-03],\n", + " [-6.7418e-03, 6.6913e-03, 1.0604e-02, ..., 1.0133e-02,\n", + " -5.0271e-02, 2.7529e-02],\n", + " [ 4.2106e-02, 3.0234e-02, -3.6480e-03, ..., 1.3692e-02,\n", + " -1.0858e-02, -1.4118e-02],\n", + " ...,\n", + " [-2.4369e-02, -3.8789e-02, 4.2428e-03, ..., -7.5641e-03,\n", + " 4.2958e-02, 1.8423e-02],\n", + " [ 2.1679e-02, 2.9357e-02, 7.8422e-03, ..., 4.7591e-03,\n", + " 2.7958e-02, -3.0234e-03],\n", + " [ 5.8209e-03, 3.4338e-03, 2.4520e-02, ..., 2.5085e-03,\n", + " 5.5165e-02, 2.3223e-02]]],\n", + " \n", + " \n", + " [[[-6.9245e-02, -1.1918e-02, 3.0330e-02, ..., -7.0775e-03,\n", + " -1.3785e-02, -9.5928e-03],\n", + " [-1.1632e-02, -2.3685e-02, -3.9180e-02, ..., -5.9309e-02,\n", + " 2.6369e-02, 6.5659e-04],\n", + " [-1.1453e-03, 1.4085e-03, 9.6764e-03, ..., -3.5827e-02,\n", + " 5.9550e-03, 1.0719e-02],\n", + " ...,\n", + " [ 3.3185e-03, 2.2316e-02, -2.3351e-02, ..., -1.0927e-02,\n", + " -3.0209e-02, 1.1315e-02],\n", + " [-3.6965e-02, -2.6860e-02, -1.8028e-02, ..., -1.6357e-04,\n", + " 4.1140e-02, -1.8615e-03],\n", + " [ 1.7740e-02, 9.2312e-03, 4.3650e-03, ..., -3.0605e-02,\n", + " -1.1486e-02, 1.2793e-02]],\n", + " \n", + " [[-5.6697e-03, -1.4674e-02, -1.7249e-02, ..., -6.2317e-03,\n", + " 2.7349e-02, -3.9416e-02],\n", + " [-8.8279e-03, -2.5957e-03, 3.5875e-02, ..., 5.8308e-03,\n", + " -5.5580e-03, -2.8438e-02],\n", + " [-4.1452e-02, 2.2671e-02, -3.2239e-02, ..., 2.7244e-02,\n", + " -2.0010e-03, 5.7491e-02],\n", + " ...,\n", + " [-2.7492e-02, 5.1052e-02, -4.3853e-02, ..., -3.1139e-02,\n", + " 1.8314e-02, -4.4898e-03],\n", + " [-1.2398e-02, 3.1807e-02, 1.0428e-02, ..., -1.7304e-03,\n", + " 1.7393e-04, -2.6142e-02],\n", + " [-3.6152e-02, 2.2367e-02, -2.1544e-02, ..., 3.4823e-04,\n", + " -2.1448e-03, -4.5074e-03]],\n", + " \n", + " [[-1.9647e-02, -4.5944e-02, 1.5610e-02, ..., 1.8324e-02,\n", + " 1.8523e-02, -2.6029e-03],\n", + " [-5.0610e-02, 3.0383e-02, -1.2389e-02, ..., -1.4688e-02,\n", + " -3.4507e-03, -1.0137e-02],\n", + " [-2.5496e-02, -4.6650e-03, -3.2878e-02, ..., -4.2710e-02,\n", + " -4.7481e-03, 1.9729e-02],\n", + " ...,\n", + " [-2.5733e-03, 5.7768e-02, -1.2957e-04, ..., 8.6745e-03,\n", + " -2.5417e-02, -8.5791e-03],\n", + " [ 2.9283e-02, -8.3721e-03, -2.3964e-03, ..., -2.1602e-02,\n", + " -3.0959e-02, 4.2844e-02],\n", + " [ 2.9002e-02, -2.3411e-02, -3.9169e-02, ..., -3.9955e-02,\n", + " -2.9184e-02, 1.1949e-02]]],\n", + " \n", + " \n", + " [[[-6.4566e-03, 1.3145e-02, 1.3045e-02, ..., -3.0958e-02,\n", + " -2.2957e-02, -2.1346e-02],\n", + " [ 3.1348e-02, -2.1111e-02, 1.8779e-03, ..., -2.2084e-03,\n", + " 1.0736e-03, 6.5589e-03],\n", + " [-1.6862e-02, -1.4509e-02, -2.2391e-02, ..., 2.8254e-02,\n", + " 1.1151e-02, 3.5738e-02],\n", + " ...,\n", + " [ 3.2327e-02, -4.5994e-02, -7.5606e-03, ..., -1.7074e-02,\n", + " -6.4185e-03, -1.5941e-02],\n", + " [ 1.6623e-02, -1.3665e-02, 1.9817e-02, ..., -2.1725e-02,\n", + " -3.5567e-02, -2.4748e-02],\n", + " [-1.2258e-02, 3.5740e-02, -2.3850e-02, ..., -1.9130e-02,\n", + " -1.8982e-02, 9.1186e-03]],\n", + " \n", + " [[-4.1899e-02, 9.4187e-04, 4.6815e-02, ..., -1.1866e-02,\n", + " -9.6984e-03, -5.8996e-03],\n", + " [-2.1281e-02, 3.6341e-02, -1.4044e-02, ..., -6.0348e-04,\n", + " -1.1665e-02, -8.4142e-03],\n", + " [ 3.1403e-03, 2.9866e-02, 2.4069e-02, ..., -5.0233e-03,\n", + " -3.3209e-02, 1.8269e-02],\n", + " ...,\n", + " [-5.1381e-02, 3.3562e-02, 1.6398e-02, ..., -3.9078e-02,\n", + " -7.3412e-03, -2.5585e-03],\n", + " [-1.9666e-02, 1.8058e-02, -1.3134e-02, ..., 1.5786e-02,\n", + " 2.4163e-02, -1.1480e-02],\n", + " [-2.3845e-02, -1.3153e-02, 3.2843e-02, ..., -1.8643e-02,\n", + " 2.1755e-03, -2.5375e-02]],\n", + " \n", + " [[-1.9703e-02, 1.9801e-02, -3.0810e-02, ..., -1.6574e-02,\n", + " -2.3212e-02, -2.3294e-02],\n", + " [ 1.2184e-02, -3.8564e-02, 3.4919e-02, ..., -3.7129e-02,\n", + " 5.0054e-04, 1.6151e-03],\n", + " [ 4.5318e-02, -7.4944e-03, 2.0338e-02, ..., -1.8937e-02,\n", + " -2.8308e-02, -2.0223e-03],\n", + " ...,\n", + " [ 5.4708e-02, -2.0650e-02, -1.2043e-02, ..., 1.2570e-02,\n", + " -7.4001e-03, -3.8468e-02],\n", + " [ 3.1326e-05, -4.1838e-02, -2.3395e-03, ..., -8.4410e-03,\n", + " 3.6380e-02, 4.8269e-02],\n", + " [-9.6509e-03, -5.2517e-02, 3.6821e-02, ..., -2.9391e-02,\n", + " -3.9250e-02, 2.0744e-02]]]], device='cuda:0')),\n", + " ('module.features.4.0.conv1.weight',\n", + " tensor([[[[-9.2462e-02, -4.6187e-02, -6.3962e-02],\n", + " [ 3.6464e-02, 3.1752e-02, 5.5634e-02],\n", + " [ 5.3728e-02, 2.3950e-02, 5.7112e-02]],\n", + " \n", + " [[ 1.6550e-02, 3.7590e-02, -5.6057e-02],\n", + " [ 4.2331e-02, -4.6298e-02, -4.1679e-02],\n", + " [ 9.8096e-03, -2.4877e-02, -5.0601e-02]],\n", + " \n", + " [[ 1.6585e-02, 5.2531e-02, -1.1046e-02],\n", + " [-2.8249e-02, -1.3678e-02, -1.1900e-01],\n", + " [ 4.3205e-02, 1.1884e-01, 1.7970e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[ 1.5649e-02, 1.3387e-02, -1.0417e-01],\n", + " [-1.3216e-01, 3.2272e-03, -4.3647e-02],\n", + " [-3.5717e-02, 1.0049e-01, 3.0284e-02]],\n", + " \n", + " [[-4.5162e-02, 2.3759e-02, -1.8583e-02],\n", + " [-1.3453e-02, -5.3536e-02, 3.5570e-03],\n", + " [ 8.8063e-02, -2.8492e-02, -3.5059e-04]],\n", + " \n", + " [[ 6.3269e-02, -2.6969e-02, -7.0333e-03],\n", + " [ 1.4697e-02, -1.0806e-02, -8.7952e-02],\n", + " [-3.8361e-03, 1.5422e-04, -1.3386e-02]]],\n", + " \n", + " \n", + " [[[ 2.2461e-02, -1.3477e-01, -6.0933e-02],\n", + " [-7.8706e-02, -1.1187e-01, 7.2931e-02],\n", + " [-1.0034e-02, -1.3897e-02, 1.0557e-01]],\n", + " \n", + " [[ 2.0277e-02, -2.3068e-02, 9.4356e-02],\n", + " [-8.1401e-02, -1.4270e-01, -8.5097e-02],\n", + " [-6.9182e-02, 8.8395e-02, -3.0874e-02]],\n", + " \n", + " [[-4.1500e-02, -5.5379e-02, 3.7250e-02],\n", + " [-8.8914e-03, -1.9892e-02, -4.7787e-02],\n", + " [-9.7006e-02, -1.7777e-02, -4.0463e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[ 9.2270e-02, 5.8624e-03, 6.5071e-02],\n", + " [ 1.1497e-01, 6.8832e-02, 3.2386e-02],\n", + " [-3.6723e-03, -7.2844e-04, 1.0449e-01]],\n", + " \n", + " [[ 4.2105e-02, 1.7258e-02, 2.7150e-02],\n", + " [-1.3325e-02, -3.0972e-02, -2.2033e-03],\n", + " [ 9.3903e-04, 2.1462e-02, 4.3957e-02]],\n", + " \n", + " [[ 9.2400e-03, 7.4915e-02, 1.8164e-02],\n", + " [ 1.8690e-02, 2.0276e-02, -2.9706e-02],\n", + " [ 4.1231e-02, 6.7357e-02, -1.1762e-01]]],\n", + " \n", + " \n", + " [[[ 1.1161e-01, 6.0464e-02, 6.2415e-02],\n", + " [ 3.0668e-02, -1.3328e-01, 1.4706e-02],\n", + " [-6.8011e-02, -3.7102e-02, -8.1162e-02]],\n", + " \n", + " [[-3.5553e-02, 5.0089e-03, 2.1187e-02],\n", + " [-6.3591e-02, -4.6052e-02, 3.4658e-02],\n", + " [-1.2683e-01, 3.3427e-02, 1.3262e-01]],\n", + " \n", + " [[ 3.0757e-02, -2.3997e-02, -6.3890e-02],\n", + " [-4.0926e-03, 5.7332e-02, 1.7442e-02],\n", + " [-2.9423e-02, -5.4034e-03, -9.0974e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[ 4.0161e-02, -4.6695e-02, 4.1360e-03],\n", + " [ 7.7046e-03, -3.3115e-02, -4.3606e-02],\n", + " [-1.1876e-01, 2.6893e-03, -4.9484e-02]],\n", + " \n", + " [[-1.2849e-02, 1.6794e-01, -6.0573e-03],\n", + " [-2.2438e-02, 9.0174e-03, 5.0475e-03],\n", + " [ 3.6589e-02, -3.5933e-02, 1.2792e-02]],\n", + " \n", + " [[ 5.5376e-02, 7.7030e-02, 8.7771e-02],\n", + " [ 5.4986e-03, 1.0659e-01, 1.6883e-03],\n", + " [-3.4126e-02, 1.3669e-01, -2.8325e-02]]],\n", + " \n", + " \n", + " ...,\n", + " \n", + " \n", + " [[[-7.7317e-03, -4.6668e-02, 7.5650e-02],\n", + " [ 3.5873e-02, -4.3000e-02, -2.7691e-02],\n", + " [ 1.0248e-01, 1.3898e-02, -1.3904e-02]],\n", + " \n", + " [[-8.6134e-02, 6.3448e-02, -2.8724e-02],\n", + " [-4.2872e-02, 4.5522e-02, 1.0647e-01],\n", + " [-4.8477e-02, 1.3277e-02, 5.6465e-02]],\n", + " \n", + " [[-8.3595e-02, 6.6930e-02, -5.8604e-02],\n", + " [-6.2126e-02, -7.1881e-03, -1.2867e-02],\n", + " [-1.7585e-02, -8.0980e-03, -1.4380e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[-4.1383e-02, 6.8068e-03, -9.0357e-03],\n", + " [-6.6095e-02, 1.3102e-02, -7.3603e-04],\n", + " [ 8.8705e-03, 4.9332e-02, -2.4792e-03]],\n", + " \n", + " [[ 3.5180e-03, 1.0274e-01, 2.6740e-02],\n", + " [ 5.2504e-02, -3.7355e-02, 3.9357e-02],\n", + " [-3.7058e-02, 1.9517e-02, 5.6316e-02]],\n", + " \n", + " [[-2.6484e-02, 9.8777e-02, -4.1960e-02],\n", + " [ 7.4271e-02, -1.0216e-02, -5.2095e-02],\n", + " [ 6.0615e-02, 7.6350e-02, -4.5450e-02]]],\n", + " \n", + " \n", + " [[[ 3.4035e-02, 1.0943e-01, -1.3764e-01],\n", + " [-7.1371e-02, 5.8172e-02, -2.2341e-02],\n", + " [-3.0836e-02, 5.2470e-02, -6.0414e-02]],\n", + " \n", + " [[-2.3613e-02, -4.2929e-03, 7.0351e-02],\n", + " [ 5.8288e-02, 6.5354e-02, 5.5242e-02],\n", + " [-1.0364e-02, 2.7791e-02, -2.2508e-02]],\n", + " \n", + " [[ 8.9985e-02, 3.6585e-02, 1.3680e-02],\n", + " [-5.1424e-02, 4.0535e-02, 1.3652e-01],\n", + " [-7.4183e-02, 7.6157e-02, -1.9116e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[-9.7462e-03, -5.6896e-02, -9.0187e-02],\n", + " [-2.2037e-02, 1.3929e-01, -4.4938e-02],\n", + " [-4.9996e-02, 4.9984e-02, -6.0025e-02]],\n", + " \n", + " [[ 2.5050e-02, -4.9562e-02, 7.2298e-02],\n", + " [ 7.2273e-02, 1.2818e-02, -3.1320e-02],\n", + " [-3.0858e-02, 3.4074e-02, 6.9350e-02]],\n", + " \n", + " [[ 3.9936e-02, 6.4482e-03, -1.4794e-02],\n", + " [ 1.4991e-03, 8.6426e-02, 6.8410e-02],\n", + " [-4.8722e-03, -3.9833e-02, 4.4456e-02]]],\n", + " \n", + " \n", + " [[[-5.7302e-03, -1.8632e-01, -2.2117e-02],\n", + " [-1.2020e-01, -6.2009e-02, 9.3804e-02],\n", + " [ 5.9880e-02, -2.3881e-02, -1.2358e-02]],\n", + " \n", + " [[-3.5485e-02, -1.1522e-03, -8.2160e-02],\n", + " [ 8.2682e-02, -2.3291e-02, -2.8066e-02],\n", + " [-1.9191e-02, 1.2396e-02, -8.7263e-02]],\n", + " \n", + " [[ 3.0666e-02, 6.9301e-02, 2.3709e-02],\n", + " [-3.1250e-02, 8.0222e-02, 2.2501e-02],\n", + " [ 3.4345e-02, 5.6890e-02, 5.9732e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[-1.2429e-01, 6.9400e-02, -7.0972e-02],\n", + " [-7.4741e-02, -8.7197e-03, 6.0359e-03],\n", + " [ 8.8969e-02, 2.9732e-02, -7.2561e-02]],\n", + " \n", + " [[-1.1770e-02, 5.1681e-02, 6.5577e-02],\n", + " [-1.3407e-01, -1.3068e-01, -5.4128e-02],\n", + " [ 8.3333e-02, 7.6867e-02, -1.1552e-02]],\n", + " \n", + " [[-1.3286e-01, 5.4414e-02, -3.6517e-02],\n", + " [-5.2994e-02, 3.9329e-02, 1.1094e-02],\n", + " [-9.7109e-02, 6.7629e-02, -8.7167e-02]]]], device='cuda:0')),\n", + " ('module.features.4.0.conv2.weight',\n", + " tensor([[[[ 1.1233e-01, -2.2981e-03, 3.6705e-02],\n", + " [-1.2030e-02, -1.5351e-02, 9.2952e-02],\n", + " [-4.1985e-02, 5.4107e-02, -5.9251e-02]],\n", + " \n", + " [[ 8.1965e-02, -1.0954e-01, 6.8691e-02],\n", + " [ 6.0726e-02, 1.0515e-03, 1.0493e-01],\n", + " [-5.5332e-02, -3.6784e-02, 7.3365e-02]],\n", + " \n", + " [[-8.5047e-02, -2.0464e-02, -8.1830e-02],\n", + " [-2.3363e-02, 1.0971e-01, 7.4004e-02],\n", + " [-7.5470e-02, 8.6039e-02, -7.4229e-03]],\n", + " \n", + " ...,\n", + " \n", + " [[ 4.2190e-02, 1.4027e-01, 5.8237e-02],\n", + " [ 5.6310e-02, -6.7712e-02, -6.2967e-02],\n", + " [-3.5065e-02, 3.8621e-02, -2.5502e-03]],\n", + " \n", + " [[ 2.2199e-02, -7.8476e-02, 8.9634e-03],\n", + " [ 2.7594e-02, 9.7943e-02, 5.6846e-02],\n", + " [-6.5797e-02, 4.9289e-03, -2.2984e-02]],\n", + " \n", + " [[-3.6827e-02, -7.8728e-03, -4.4337e-02],\n", + " [ 8.1624e-02, 8.3161e-03, 5.9610e-02],\n", + " [ 5.9407e-02, -3.9335e-02, -7.0567e-02]]],\n", + " \n", + " \n", + " [[[-8.1664e-02, 7.6529e-03, -3.4560e-03],\n", + " [ 6.5344e-02, -1.4775e-01, -6.3836e-02],\n", + " [-3.1269e-03, -2.9126e-02, -1.3252e-01]],\n", + " \n", + " [[ 8.3339e-02, -7.4031e-02, 8.3576e-02],\n", + " [ 5.1510e-03, 7.2560e-02, -2.5081e-02],\n", + " [ 3.1496e-02, 1.3989e-02, -1.6677e-02]],\n", + " \n", + " [[-7.5721e-02, -2.4193e-02, 1.4763e-01],\n", + " [-2.1573e-02, -4.5769e-03, -5.6464e-02],\n", + " [-6.6852e-02, -2.0293e-02, -3.7721e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[ 1.2082e-02, -3.0667e-02, 3.9955e-02],\n", + " [-3.5713e-02, 7.0917e-03, 1.1693e-01],\n", + " [-1.1999e-01, -2.3686e-02, 1.0753e-01]],\n", + " \n", + " [[-8.8364e-02, 7.2671e-02, 1.1850e-02],\n", + " [ 4.0822e-02, 9.6776e-02, -7.8760e-02],\n", + " [ 1.0391e-01, 1.5550e-02, 1.0867e-01]],\n", + " \n", + " [[-4.8628e-02, -1.1506e-02, 5.5846e-02],\n", + " [-6.7405e-04, 1.8967e-02, 5.6449e-03],\n", + " [-2.2366e-02, 7.1308e-02, 1.0016e-02]]],\n", + " \n", + " \n", + " [[[-2.1073e-02, -7.4765e-02, 4.0197e-02],\n", + " [ 6.5436e-02, 6.8928e-02, -3.0027e-02],\n", + " [-1.6745e-02, -2.1730e-02, 8.4882e-02]],\n", + " \n", + " [[ 1.1282e-01, -2.6245e-02, 6.4570e-02],\n", + " [-1.1578e-01, -1.5154e-02, 2.7958e-02],\n", + " [ 3.0415e-03, 4.7246e-02, 1.2109e-02]],\n", + " \n", + " [[ 1.1064e-02, -1.1093e-01, -2.7692e-02],\n", + " [-4.2947e-02, 3.4327e-02, 3.7007e-02],\n", + " [-3.0342e-02, 7.2168e-03, 1.9143e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[-6.0603e-02, 1.8067e-01, 9.0881e-03],\n", + " [-5.0757e-02, 3.4056e-04, 6.2487e-02],\n", + " [-3.6068e-02, 4.6166e-02, 8.6541e-02]],\n", + " \n", + " [[ 6.9447e-02, -3.0232e-03, -1.4447e-02],\n", + " [-1.1953e-01, 3.6767e-02, 2.4693e-02],\n", + " [-1.2821e-01, -6.6559e-03, -4.7528e-02]],\n", + " \n", + " [[-4.7991e-02, 3.7157e-02, 1.9292e-02],\n", + " [-3.7560e-03, -7.4758e-02, 6.9171e-03],\n", + " [ 5.4880e-03, -1.0589e-01, 5.9222e-02]]],\n", + " \n", + " \n", + " ...,\n", + " \n", + " \n", + " [[[ 7.4470e-02, -3.2481e-02, 4.1360e-02],\n", + " [-3.6666e-03, 2.7318e-02, -2.3968e-02],\n", + " [ 9.2488e-02, -3.6974e-02, 9.5300e-03]],\n", + " \n", + " [[ 7.9547e-02, 2.0034e-02, -1.9778e-02],\n", + " [ 7.4181e-03, 6.6158e-02, -2.6734e-02],\n", + " [ 1.8545e-02, -1.0150e-01, 4.9060e-02]],\n", + " \n", + " [[-4.1054e-02, -1.2121e-01, 6.2199e-02],\n", + " [ 2.2207e-02, -1.8837e-02, 1.0597e-01],\n", + " [-3.9223e-02, 7.7222e-02, -2.2536e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[-2.1562e-02, -5.3268e-03, 1.2969e-01],\n", + " [ 3.4468e-02, -6.5299e-02, 9.2592e-02],\n", + " [ 1.6014e-02, 5.5946e-02, -2.9213e-02]],\n", + " \n", + " [[ 1.1376e-02, -1.4155e-02, -7.2439e-02],\n", + " [-1.5408e-02, -2.0305e-02, 5.3932e-02],\n", + " [ 6.8005e-02, 6.4583e-02, 1.0505e-01]],\n", + " \n", + " [[-6.2856e-03, -6.9690e-03, 7.1899e-02],\n", + " [-5.6182e-03, -6.7596e-02, 8.8580e-02],\n", + " [-5.1786e-02, 5.0984e-02, -4.0118e-02]]],\n", + " \n", + " \n", + " [[[-1.2869e-01, -9.6242e-03, -3.9101e-02],\n", + " [ 1.1690e-01, -2.0011e-02, 1.7888e-01],\n", + " [ 1.3289e-01, -3.4749e-02, -4.5241e-02]],\n", + " \n", + " [[ 4.1905e-02, -1.0684e-01, -4.4939e-02],\n", + " [-8.2466e-03, 3.0330e-02, 1.9460e-03],\n", + " [-1.0260e-01, -4.2263e-02, -5.0739e-02]],\n", + " \n", + " [[-1.4587e-02, -1.7936e-02, 2.9521e-02],\n", + " [-1.3464e-01, 4.0443e-02, 4.9810e-02],\n", + " [ 4.7797e-02, 1.4375e-02, 8.8259e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[ 9.2974e-03, 2.9998e-02, -3.2140e-02],\n", + " [ 8.5341e-02, -2.1467e-02, 5.0297e-02],\n", + " [ 7.9830e-02, 1.3020e-02, -3.3859e-02]],\n", + " \n", + " [[-3.1406e-02, -3.9117e-02, -1.1063e-01],\n", + " [ 2.8238e-02, -8.2110e-02, 2.4808e-02],\n", + " [ 4.7568e-02, -2.1383e-01, -9.1426e-03]],\n", + " \n", + " [[-2.0197e-02, -9.5517e-06, 4.0441e-02],\n", + " [ 6.8926e-02, 1.9768e-02, 2.5460e-02],\n", + " [ 4.4746e-02, 1.7082e-02, 3.0773e-02]]],\n", + " \n", + " \n", + " [[[ 6.1841e-02, -6.5545e-03, 5.5460e-03],\n", + " [-1.4118e-02, 1.2440e-02, -7.2500e-02],\n", + " [-1.3664e-03, -6.3462e-02, -5.2453e-02]],\n", + " \n", + " [[-2.1329e-02, 8.8623e-02, 1.7045e-02],\n", + " [ 2.5736e-02, 2.4333e-02, 9.1462e-02],\n", + " [ 1.6656e-02, -7.4451e-02, 6.4044e-02]],\n", + " \n", + " [[ 1.8426e-02, -2.6405e-02, 1.0379e-01],\n", + " [-1.7170e-02, 9.1147e-03, 3.7215e-02],\n", + " [-1.1100e-01, -4.4881e-02, 8.8977e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[ 2.5414e-02, -5.1767e-03, 1.1035e-02],\n", + " [ 7.9268e-02, -2.5036e-02, -4.6462e-02],\n", + " [-1.4646e-01, -2.6711e-02, -1.0542e-01]],\n", + " \n", + " [[-4.2877e-02, -8.2058e-02, -4.6482e-02],\n", + " [ 6.8054e-03, -1.4620e-02, -5.4781e-02],\n", + " [ 1.0607e-02, -8.7744e-03, -5.6498e-02]],\n", + " \n", + " [[-6.3612e-03, -2.1136e-02, 1.2030e-02],\n", + " [-1.0481e-02, -2.1997e-02, -9.0608e-03],\n", + " [-3.8871e-02, -6.0018e-04, 1.0755e-01]]]], device='cuda:0')),\n", + " ('module.features.4.1.conv1.weight',\n", + " tensor([[[[-2.9280e-02, 5.7228e-02, 3.5699e-02],\n", + " [-2.3619e-02, 1.8647e-02, 8.9333e-02],\n", + " [-1.3515e-02, -3.3410e-02, 6.6962e-02]],\n", + " \n", + " [[-2.7122e-02, -9.0241e-03, -8.4448e-02],\n", + " [ 1.0711e-02, 7.1210e-02, 2.8500e-02],\n", + " [-3.2766e-02, -1.5924e-02, 9.2216e-02]],\n", + " \n", + " [[ 4.7169e-02, -7.6823e-02, -7.7361e-02],\n", + " [-1.5131e-02, -8.4519e-03, 1.7770e-02],\n", + " [ 9.1017e-02, 1.6489e-01, -1.1881e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[-2.6887e-02, 1.1245e-02, 4.5791e-03],\n", + " [-3.6761e-02, 1.2582e-01, 1.1945e-02],\n", + " [ 6.3552e-02, -4.0764e-02, 7.0398e-02]],\n", + " \n", + " [[ 6.4111e-03, 5.8965e-02, -6.6704e-02],\n", + " [ 7.0903e-03, -4.4084e-02, 5.8607e-03],\n", + " [-7.9157e-02, 8.0710e-02, 5.3255e-02]],\n", + " \n", + " [[ 4.9754e-02, -3.8833e-02, 5.9919e-02],\n", + " [ 9.8103e-03, -4.2643e-04, 3.0538e-02],\n", + " [ 9.3018e-02, -5.5808e-02, 1.5056e-02]]],\n", + " \n", + " \n", + " [[[ 8.1160e-04, -8.5566e-02, -3.4079e-02],\n", + " [ 3.5419e-02, -1.5424e-02, -5.0355e-02],\n", + " [-3.8926e-02, -6.8501e-02, 9.5525e-03]],\n", + " \n", + " [[ 1.3876e-02, -3.5668e-02, -1.5049e-02],\n", + " [ 4.1467e-04, -9.9467e-03, 4.6658e-02],\n", + " [-5.0346e-03, 3.1488e-02, -1.3610e-04]],\n", + " \n", + " [[ 3.4313e-02, 6.0405e-02, 7.3079e-02],\n", + " [ 3.0630e-02, 2.2675e-03, -1.3825e-02],\n", + " [ 1.9256e-02, -8.3649e-02, -1.9868e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[-2.7544e-02, -7.5189e-02, 1.2101e-02],\n", + " [-5.2104e-02, -8.1663e-02, 2.3889e-02],\n", + " [-3.1997e-02, 1.0296e-01, 4.5187e-03]],\n", + " \n", + " [[ 4.1556e-02, 8.7092e-02, 1.0084e-02],\n", + " [ 2.0435e-02, 1.4916e-02, -5.1732e-02],\n", + " [ 5.2304e-02, 3.0570e-02, -2.0941e-02]],\n", + " \n", + " [[ 2.3169e-02, -4.5008e-02, -3.2341e-02],\n", + " [ 1.0699e-02, -8.6836e-02, 1.9169e-02],\n", + " [ 1.5688e-02, 1.6211e-01, 5.3303e-02]]],\n", + " \n", + " \n", + " [[[-6.5159e-03, 9.3245e-03, -1.1683e-02],\n", + " [-1.0310e-01, 7.4767e-02, 2.9678e-02],\n", + " [-3.9806e-03, -1.0650e-01, 8.1413e-02]],\n", + " \n", + " [[-5.1333e-02, 1.1462e-01, -9.6628e-02],\n", + " [ 1.0913e-01, -7.2096e-03, -8.1218e-02],\n", + " [ 2.6997e-02, -3.9076e-02, 1.6155e-03]],\n", + " \n", + " [[ 4.3165e-03, 6.9659e-02, -2.2012e-02],\n", + " [-7.1076e-03, 1.0487e-02, -1.6196e-02],\n", + " [-2.0985e-02, 5.7500e-02, 4.4786e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[-4.9050e-02, 1.3593e-02, 1.6235e-01],\n", + " [ 9.1918e-02, -1.8695e-02, 4.6067e-02],\n", + " [-1.2878e-02, -3.5153e-02, -3.4204e-03]],\n", + " \n", + " [[ 9.2478e-02, -8.1305e-02, 1.1401e-01],\n", + " [-5.5514e-02, 3.0807e-04, 2.9226e-02],\n", + " [-3.5388e-03, -1.0554e-02, 4.9842e-02]],\n", + " \n", + " [[-1.6965e-03, 6.3461e-02, -4.4206e-04],\n", + " [-6.7280e-02, 2.2864e-02, -2.1471e-02],\n", + " [-2.7613e-02, 3.8291e-02, 1.5294e-01]]],\n", + " \n", + " \n", + " ...,\n", + " \n", + " \n", + " [[[ 3.5842e-02, 7.2023e-02, 4.1372e-03],\n", + " [-5.6176e-03, -2.5475e-02, 7.2281e-02],\n", + " [ 5.3528e-02, -3.8676e-02, 5.2703e-03]],\n", + " \n", + " [[ 1.8223e-02, 6.9698e-02, -1.7604e-02],\n", + " [ 6.9803e-03, -5.7443e-02, -3.3450e-02],\n", + " [-4.4618e-03, 1.8633e-02, 1.2111e-01]],\n", + " \n", + " [[-1.1872e-01, -1.0022e-01, 2.0097e-02],\n", + " [ 1.9121e-02, 3.3582e-02, -4.3687e-02],\n", + " [-1.7471e-02, 4.4938e-02, -6.1471e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[ 6.0367e-03, 7.5599e-04, 8.8230e-02],\n", + " [-2.7833e-02, 1.1672e-01, -5.7861e-03],\n", + " [ 5.0682e-02, 3.0452e-02, -1.0254e-01]],\n", + " \n", + " [[ 3.3197e-02, 5.2363e-02, 4.0486e-02],\n", + " [-1.4445e-02, -1.1716e-02, -1.8212e-02],\n", + " [-1.2584e-02, -6.3745e-02, 4.8277e-02]],\n", + " \n", + " [[ 3.4524e-02, 4.8264e-02, 3.4181e-02],\n", + " [-2.1468e-02, -6.5613e-02, -4.3188e-02],\n", + " [ 5.4996e-03, -3.5989e-02, 7.9056e-03]]],\n", + " \n", + " \n", + " [[[-1.5350e-02, -3.6039e-02, -1.0346e-01],\n", + " [ 1.0253e-01, 4.1605e-02, -4.7496e-02],\n", + " [-1.3337e-02, 2.6657e-04, -1.8195e-02]],\n", + " \n", + " [[-5.3061e-02, -3.5731e-02, 3.0896e-02],\n", + " [-7.0008e-02, 8.0442e-02, -5.3065e-02],\n", + " [ 3.8142e-02, -7.9275e-02, -5.3120e-02]],\n", + " \n", + " [[-8.0647e-02, -2.1549e-02, -3.9406e-02],\n", + " [-1.5421e-02, 2.9551e-03, -4.9613e-02],\n", + " [-4.2970e-02, -1.4718e-02, -1.8991e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[ 1.2678e-02, -5.3625e-02, 2.6030e-02],\n", + " [-5.3147e-02, 2.6828e-02, 6.6291e-02],\n", + " [ 9.3168e-02, 1.2636e-02, 2.5365e-02]],\n", + " \n", + " [[-3.3369e-02, 1.1124e-02, -1.9820e-02],\n", + " [ 2.1811e-02, 6.4112e-03, 4.1800e-02],\n", + " [ 6.0804e-02, 7.5496e-02, 2.2505e-02]],\n", + " \n", + " [[ 2.9828e-02, -1.1538e-01, 7.6337e-02],\n", + " [-1.5814e-01, -3.1391e-02, -1.1998e-02],\n", + " [-4.7611e-02, -1.5590e-02, 1.2594e-02]]],\n", + " \n", + " \n", + " [[[-4.9918e-02, 1.6541e-02, -2.1747e-02],\n", + " [-5.1807e-02, 7.8236e-02, 4.5203e-02],\n", + " [ 1.6146e-01, 3.4237e-02, 1.7932e-03]],\n", + " \n", + " [[-5.0599e-02, 6.9354e-02, 4.7455e-02],\n", + " [-5.7678e-02, -5.2270e-02, 6.2546e-02],\n", + " [-5.2623e-02, -5.8615e-02, -5.7776e-03]],\n", + " \n", + " [[ 2.8584e-02, -4.6263e-03, 3.0092e-02],\n", + " [-4.2295e-02, -1.5103e-01, 9.4677e-02],\n", + " [ 7.6677e-02, 1.8689e-02, 1.9354e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[ 8.5414e-02, -4.9813e-02, -1.5877e-02],\n", + " [-5.5522e-02, -5.0945e-02, -8.9750e-03],\n", + " [-8.4041e-02, 4.0588e-02, -6.0091e-02]],\n", + " \n", + " [[-1.2544e-01, -9.1728e-02, 1.3723e-02],\n", + " [-8.0662e-02, 1.0641e-01, 3.4712e-02],\n", + " [-1.5880e-02, -3.3900e-02, -5.5494e-02]],\n", + " \n", + " [[-5.7554e-02, -7.3517e-02, 2.2063e-02],\n", + " [-2.3512e-02, -4.2891e-02, -7.6788e-02],\n", + " [ 7.7499e-03, 6.6297e-02, 1.2341e-01]]]], device='cuda:0')),\n", + " ('module.features.4.1.conv2.weight',\n", + " tensor([[[[-0.0525, -0.0483, -0.0494],\n", + " [ 0.0247, -0.0368, -0.0704],\n", + " [-0.0005, 0.0078, -0.0628]],\n", + " \n", + " [[ 0.0117, 0.0538, -0.1291],\n", + " [ 0.0146, 0.0787, -0.0116],\n", + " [-0.0727, 0.0086, -0.0126]],\n", + " \n", + " [[-0.0127, -0.0116, 0.0342],\n", + " [ 0.0765, 0.0809, 0.0526],\n", + " [-0.0483, 0.0173, 0.0572]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.0352, -0.0164, -0.0248],\n", + " [ 0.0715, 0.0104, 0.0451],\n", + " [-0.0418, 0.0218, -0.0910]],\n", + " \n", + " [[ 0.0437, 0.0141, -0.0259],\n", + " [ 0.0845, -0.0283, -0.0529],\n", + " [-0.0938, -0.0099, 0.0636]],\n", + " \n", + " [[-0.0853, -0.0144, -0.0787],\n", + " [-0.0351, 0.0832, -0.0776],\n", + " [ 0.1014, -0.0103, -0.1094]]],\n", + " \n", + " \n", + " [[[ 0.0730, -0.0990, 0.0356],\n", + " [-0.0599, 0.0117, -0.0096],\n", + " [ 0.0253, 0.1055, 0.1380]],\n", + " \n", + " [[ 0.0712, 0.1090, 0.0153],\n", + " [-0.0768, -0.0915, 0.0155],\n", + " [ 0.0987, -0.0024, -0.1088]],\n", + " \n", + " [[-0.0386, 0.0034, -0.0279],\n", + " [ 0.0505, -0.0153, 0.0605],\n", + " [-0.0701, 0.0598, 0.0538]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.0071, 0.0587, -0.0383],\n", + " [-0.0250, -0.0264, 0.0615],\n", + " [ 0.1125, 0.0474, 0.0286]],\n", + " \n", + " [[-0.0069, 0.0343, -0.0296],\n", + " [ 0.0915, 0.0430, 0.1088],\n", + " [-0.1369, -0.0314, 0.0983]],\n", + " \n", + " [[-0.0319, 0.0417, -0.0538],\n", + " [-0.0282, 0.0577, -0.0480],\n", + " [ 0.0273, -0.0052, -0.0819]]],\n", + " \n", + " \n", + " [[[ 0.0803, -0.0615, 0.0542],\n", + " [ 0.0250, -0.0350, -0.0608],\n", + " [-0.0744, -0.0055, 0.0029]],\n", + " \n", + " [[ 0.0872, 0.0505, 0.0315],\n", + " [-0.0831, -0.0840, -0.0597],\n", + " [ 0.0237, -0.1176, -0.0849]],\n", + " \n", + " [[-0.0290, 0.0226, 0.0660],\n", + " [ 0.1088, -0.0424, -0.0558],\n", + " [ 0.0751, 0.0066, 0.0674]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.1043, -0.0803, -0.0286],\n", + " [-0.0090, 0.0550, -0.0403],\n", + " [-0.1595, -0.1149, -0.0509]],\n", + " \n", + " [[-0.1280, -0.0498, -0.0465],\n", + " [ 0.0072, -0.1675, 0.0868],\n", + " [-0.0774, 0.0304, 0.0038]],\n", + " \n", + " [[ 0.0237, 0.0396, -0.0633],\n", + " [ 0.1348, 0.0573, -0.0941],\n", + " [ 0.0403, -0.0493, -0.0018]]],\n", + " \n", + " \n", + " ...,\n", + " \n", + " \n", + " [[[-0.0720, -0.0048, -0.0738],\n", + " [ 0.0150, 0.0566, 0.0630],\n", + " [-0.0041, 0.0416, 0.0473]],\n", + " \n", + " [[ 0.0707, 0.0833, -0.0563],\n", + " [-0.0464, -0.0921, -0.0284],\n", + " [ 0.0428, 0.0079, 0.0133]],\n", + " \n", + " [[-0.0681, -0.0689, 0.0829],\n", + " [ 0.1170, 0.0167, 0.0201],\n", + " [ 0.0783, 0.0763, 0.0563]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.0164, 0.0350, -0.0013],\n", + " [ 0.0690, 0.0359, -0.0513],\n", + " [ 0.0730, -0.0891, -0.0222]],\n", + " \n", + " [[-0.0428, -0.0946, 0.0147],\n", + " [-0.0037, -0.0407, -0.0290],\n", + " [ 0.0189, 0.0446, 0.0013]],\n", + " \n", + " [[-0.0965, -0.0170, 0.0327],\n", + " [-0.0066, 0.0590, 0.0330],\n", + " [-0.0344, -0.1080, -0.0929]]],\n", + " \n", + " \n", + " [[[ 0.0411, -0.0099, 0.0199],\n", + " [-0.0871, -0.0102, -0.2179],\n", + " [ 0.0105, -0.0675, 0.0096]],\n", + " \n", + " [[ 0.0683, -0.0130, -0.0562],\n", + " [ 0.0730, -0.0939, 0.0569],\n", + " [-0.0503, 0.0872, -0.0596]],\n", + " \n", + " [[-0.0970, 0.0281, 0.0215],\n", + " [ 0.0463, 0.0242, -0.0812],\n", + " [-0.0824, 0.0101, -0.0548]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.0455, -0.0058, -0.0019],\n", + " [-0.0193, -0.0986, -0.0407],\n", + " [ 0.0216, -0.0313, 0.0442]],\n", + " \n", + " [[-0.0578, 0.0639, -0.0347],\n", + " [ 0.0483, 0.0167, 0.0356],\n", + " [-0.0884, -0.0625, -0.0573]],\n", + " \n", + " [[-0.0386, -0.0107, 0.0538],\n", + " [-0.0215, 0.0030, -0.0279],\n", + " [-0.0193, 0.1219, 0.0516]]],\n", + " \n", + " \n", + " [[[ 0.0699, 0.0497, -0.0102],\n", + " [ 0.0046, 0.0519, 0.0270],\n", + " [ 0.0369, 0.0953, -0.0288]],\n", + " \n", + " [[-0.0620, 0.0500, -0.1316],\n", + " [-0.0377, -0.0071, -0.0139],\n", + " [-0.0591, 0.0661, 0.2031]],\n", + " \n", + " [[-0.0262, 0.0128, 0.0796],\n", + " [ 0.0171, -0.0781, -0.0751],\n", + " [ 0.0560, -0.0993, -0.1257]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.0683, -0.0576, -0.0644],\n", + " [-0.1106, -0.0743, -0.0878],\n", + " [ 0.0077, 0.0252, -0.0271]],\n", + " \n", + " [[ 0.1365, -0.0229, -0.0237],\n", + " [-0.0245, -0.0334, -0.0210],\n", + " [ 0.0896, 0.0498, 0.0945]],\n", + " \n", + " [[ 0.0300, 0.0274, -0.0963],\n", + " [-0.0513, 0.0832, -0.0052],\n", + " [-0.0037, -0.0797, -0.0482]]]], device='cuda:0')),\n", + " ('module.features.5.0.conv1.weight',\n", + " tensor([[[[-0.0792, 0.0218, -0.0899],\n", + " [-0.0803, -0.0315, 0.0240],\n", + " [-0.0841, -0.0110, -0.0109]],\n", + " \n", + " [[-0.0315, -0.0697, -0.0428],\n", + " [ 0.0572, 0.0261, -0.0217],\n", + " [ 0.0151, 0.0978, -0.0195]],\n", + " \n", + " [[ 0.0309, -0.0566, 0.0163],\n", + " [ 0.0194, -0.1011, -0.0228],\n", + " [-0.0361, -0.0042, -0.0763]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.0801, 0.0180, 0.0183],\n", + " [-0.0507, -0.0176, -0.0653],\n", + " [-0.0196, 0.0873, -0.0575]],\n", + " \n", + " [[-0.0665, -0.0464, -0.0791],\n", + " [ 0.0303, -0.0349, 0.0325],\n", + " [ 0.0673, 0.0472, 0.0432]],\n", + " \n", + " [[ 0.0340, 0.0556, -0.0336],\n", + " [ 0.0384, 0.0019, -0.0620],\n", + " [-0.0209, 0.0068, 0.0490]]],\n", + " \n", + " \n", + " [[[-0.0167, -0.0249, -0.0185],\n", + " [-0.0041, -0.0665, -0.0203],\n", + " [ 0.0526, 0.0094, -0.0457]],\n", + " \n", + " [[ 0.0726, 0.0042, 0.0038],\n", + " [-0.0489, 0.0161, -0.0138],\n", + " [-0.0468, -0.0317, 0.0218]],\n", + " \n", + " [[-0.0506, 0.0394, 0.0849],\n", + " [-0.0749, 0.0147, 0.0096],\n", + " [ 0.0436, 0.0361, 0.0326]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.0211, 0.0249, 0.0460],\n", + " [ 0.0015, -0.0474, -0.0434],\n", + " [ 0.0055, 0.0315, 0.0097]],\n", + " \n", + " [[-0.0690, 0.0033, 0.0164],\n", + " [-0.0765, 0.0196, -0.0821],\n", + " [ 0.0495, -0.0271, -0.0655]],\n", + " \n", + " [[ 0.0337, -0.0879, 0.0041],\n", + " [-0.0020, -0.0018, 0.0143],\n", + " [-0.0319, -0.0122, -0.0044]]],\n", + " \n", + " \n", + " [[[ 0.0778, -0.0654, 0.0032],\n", + " [ 0.0187, -0.0885, -0.0015],\n", + " [ 0.0671, 0.1071, -0.0319]],\n", + " \n", + " [[ 0.1421, -0.0081, -0.0021],\n", + " [-0.0644, -0.0045, 0.0067],\n", + " [-0.0229, -0.0169, 0.0082]],\n", + " \n", + " [[-0.0393, 0.0062, -0.0046],\n", + " [ 0.0598, -0.0135, -0.0100],\n", + " [ 0.0283, -0.0384, 0.0068]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.0349, -0.0346, 0.0055],\n", + " [-0.0319, -0.0370, -0.0279],\n", + " [ 0.0458, 0.0806, 0.0086]],\n", + " \n", + " [[ 0.0148, 0.0105, -0.0311],\n", + " [-0.0032, 0.0306, 0.0248],\n", + " [-0.0445, -0.0083, 0.0084]],\n", + " \n", + " [[ 0.0659, -0.0195, -0.0173],\n", + " [-0.0862, 0.0108, 0.0530],\n", + " [ 0.0390, 0.0169, -0.0187]]],\n", + " \n", + " \n", + " ...,\n", + " \n", + " \n", + " [[[-0.0533, 0.0452, -0.0518],\n", + " [ 0.0509, 0.0094, -0.0395],\n", + " [ 0.0027, -0.0050, -0.0940]],\n", + " \n", + " [[ 0.0047, -0.0107, 0.0322],\n", + " [ 0.0044, -0.0472, 0.0117],\n", + " [ 0.0032, -0.0558, 0.0117]],\n", + " \n", + " [[ 0.0715, -0.0191, -0.0130],\n", + " [-0.0261, -0.0300, 0.0227],\n", + " [-0.0532, 0.0113, 0.0065]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.0201, -0.0292, -0.0152],\n", + " [ 0.0303, -0.0559, -0.0149],\n", + " [-0.0054, 0.0347, 0.1112]],\n", + " \n", + " [[ 0.0415, 0.0173, 0.0301],\n", + " [-0.0307, 0.0392, -0.0117],\n", + " [ 0.0257, 0.0229, 0.0593]],\n", + " \n", + " [[-0.0990, -0.0523, -0.0409],\n", + " [-0.0661, -0.0069, 0.0011],\n", + " [-0.0444, 0.0089, 0.0445]]],\n", + " \n", + " \n", + " [[[ 0.0045, 0.0149, 0.0122],\n", + " [ 0.0050, -0.0330, 0.0177],\n", + " [-0.0690, -0.0824, -0.0246]],\n", + " \n", + " [[-0.0881, -0.0018, -0.0066],\n", + " [ 0.0561, -0.0397, -0.0240],\n", + " [-0.0405, 0.0066, -0.0229]],\n", + " \n", + " [[ 0.0256, 0.0282, 0.0244],\n", + " [-0.0167, -0.0100, -0.0101],\n", + " [-0.0276, 0.0515, -0.0074]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.0582, 0.0027, -0.0263],\n", + " [ 0.0084, 0.0272, 0.0462],\n", + " [ 0.0278, -0.0378, -0.0446]],\n", + " \n", + " [[ 0.0310, 0.0112, -0.0012],\n", + " [ 0.0115, 0.0438, -0.0174],\n", + " [-0.1078, -0.0189, -0.0014]],\n", + " \n", + " [[-0.0506, -0.0164, 0.0329],\n", + " [ 0.0816, -0.0127, 0.0256],\n", + " [-0.0311, -0.0202, -0.0431]]],\n", + " \n", + " \n", + " [[[ 0.0184, 0.0127, 0.0492],\n", + " [-0.0172, -0.0385, -0.0424],\n", + " [-0.0818, -0.0270, -0.0116]],\n", + " \n", + " [[ 0.0077, 0.0117, -0.0061],\n", + " [ 0.0166, 0.0163, 0.0286],\n", + " [-0.0163, -0.0531, 0.0770]],\n", + " \n", + " [[ 0.0229, -0.0362, -0.0435],\n", + " [ 0.0539, 0.0568, 0.0706],\n", + " [-0.0477, 0.0183, 0.0310]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.0473, -0.0633, 0.0155],\n", + " [ 0.0071, -0.0229, -0.0209],\n", + " [-0.0374, -0.0606, -0.0541]],\n", + " \n", + " [[ 0.0055, -0.0027, -0.0049],\n", + " [ 0.0064, 0.0350, -0.0610],\n", + " [ 0.0301, 0.0102, -0.0355]],\n", + " \n", + " [[-0.0422, -0.0496, 0.0068],\n", + " [-0.0090, -0.0634, -0.0383],\n", + " [-0.0983, -0.0244, -0.0193]]]], device='cuda:0')),\n", + " ('module.features.5.0.conv2.weight',\n", + " tensor([[[[-0.0185, -0.0055, 0.0204],\n", + " [-0.0669, 0.0242, 0.0155],\n", + " [-0.0176, 0.0231, 0.0618]],\n", + " \n", + " [[ 0.0329, -0.0596, 0.0462],\n", + " [ 0.0019, 0.0363, 0.0510],\n", + " [ 0.0255, -0.0271, 0.0377]],\n", + " \n", + " [[-0.0379, -0.0469, 0.0030],\n", + " [-0.0445, -0.0152, -0.0425],\n", + " [-0.0020, 0.0046, 0.0034]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.0956, -0.0019, -0.0946],\n", + " [ 0.0631, -0.0316, -0.0266],\n", + " [-0.0110, -0.1127, -0.0501]],\n", + " \n", + " [[ 0.0361, 0.0123, -0.0237],\n", + " [-0.0586, -0.0240, 0.0603],\n", + " [-0.0407, -0.0967, -0.0087]],\n", + " \n", + " [[-0.0953, -0.0075, 0.0781],\n", + " [-0.0586, -0.0116, 0.0293],\n", + " [ 0.0420, -0.0406, -0.0262]]],\n", + " \n", + " \n", + " [[[-0.0963, -0.0383, 0.0114],\n", + " [-0.0260, 0.0145, -0.0099],\n", + " [ 0.0706, 0.0572, -0.0383]],\n", + " \n", + " [[-0.0160, -0.0009, -0.0596],\n", + " [ 0.0447, 0.0360, 0.0334],\n", + " [-0.0005, 0.0293, -0.0442]],\n", + " \n", + " [[-0.0596, -0.0388, 0.0065],\n", + " [-0.0359, 0.0717, -0.0545],\n", + " [ 0.0501, -0.0117, -0.0065]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.0024, -0.0862, 0.0692],\n", + " [ 0.0315, 0.0387, -0.0095],\n", + " [ 0.0111, -0.0315, 0.0742]],\n", + " \n", + " [[-0.0501, -0.0397, 0.0452],\n", + " [-0.0079, 0.0929, 0.0956],\n", + " [-0.1150, 0.0330, -0.0026]],\n", + " \n", + " [[-0.0165, -0.0450, 0.0579],\n", + " [-0.0096, 0.0519, 0.0432],\n", + " [-0.0026, -0.0358, 0.0526]]],\n", + " \n", + " \n", + " [[[-0.0022, -0.0217, -0.0499],\n", + " [-0.0221, 0.0151, -0.0570],\n", + " [ 0.0224, 0.0505, 0.0402]],\n", + " \n", + " [[ 0.0356, 0.0084, 0.0051],\n", + " [-0.0006, -0.0410, -0.0303],\n", + " [ 0.0270, 0.0788, -0.0720]],\n", + " \n", + " [[ 0.0262, -0.0168, -0.0006],\n", + " [ 0.0143, 0.0763, 0.0362],\n", + " [ 0.0824, 0.0376, -0.0052]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.0246, 0.0417, -0.0386],\n", + " [ 0.0018, -0.0701, 0.0177],\n", + " [ 0.0582, 0.0484, 0.0029]],\n", + " \n", + " [[-0.0829, 0.0297, 0.1046],\n", + " [ 0.0008, 0.0256, -0.0059],\n", + " [-0.0159, 0.0485, 0.0155]],\n", + " \n", + " [[ 0.0113, 0.0179, -0.0300],\n", + " [-0.0117, -0.0168, -0.0579],\n", + " [ 0.0297, -0.0137, -0.0320]]],\n", + " \n", + " \n", + " ...,\n", + " \n", + " \n", + " [[[ 0.0665, -0.0067, -0.0355],\n", + " [-0.0112, -0.0227, 0.0079],\n", + " [-0.0444, -0.0311, -0.0343]],\n", + " \n", + " [[-0.0311, 0.1055, 0.0788],\n", + " [ 0.0565, -0.0003, -0.0352],\n", + " [-0.0467, 0.0318, -0.0082]],\n", + " \n", + " [[ 0.0322, -0.0605, -0.0607],\n", + " [ 0.0802, 0.0164, -0.0120],\n", + " [-0.0080, 0.0134, -0.0655]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.0505, -0.0788, 0.1130],\n", + " [ 0.0141, 0.0362, -0.0042],\n", + " [-0.0597, 0.0082, 0.0053]],\n", + " \n", + " [[ 0.0281, -0.0423, 0.0122],\n", + " [-0.0923, -0.0106, 0.0446],\n", + " [-0.0557, -0.0728, -0.0367]],\n", + " \n", + " [[-0.0470, 0.0243, 0.0581],\n", + " [ 0.0270, -0.0034, -0.0219],\n", + " [ 0.0516, -0.0335, 0.0021]]],\n", + " \n", + " \n", + " [[[ 0.0429, -0.0169, -0.0063],\n", + " [ 0.0552, -0.0109, -0.0119],\n", + " [-0.0203, -0.0496, -0.0205]],\n", + " \n", + " [[-0.0178, 0.0129, 0.0213],\n", + " [-0.0029, -0.0199, -0.0105],\n", + " [ 0.0424, -0.0083, 0.0296]],\n", + " \n", + " [[-0.0600, 0.0293, -0.0311],\n", + " [-0.0541, -0.0449, -0.0008],\n", + " [ 0.0119, 0.0300, -0.0152]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.0118, 0.0368, 0.0228],\n", + " [ 0.0384, 0.0141, -0.0213],\n", + " [-0.0364, -0.0394, 0.0516]],\n", + " \n", + " [[-0.0128, 0.0606, -0.0741],\n", + " [ 0.0360, 0.0348, 0.0741],\n", + " [-0.0049, -0.0374, 0.0287]],\n", + " \n", + " [[-0.0020, -0.0368, -0.0418],\n", + " [-0.0443, 0.0536, -0.0359],\n", + " [-0.0287, -0.0068, 0.0364]]],\n", + " \n", + " \n", + " [[[ 0.0482, 0.0189, -0.0050],\n", + " [-0.0643, 0.0241, 0.0150],\n", + " [-0.0293, 0.0479, -0.0457]],\n", + " \n", + " [[-0.0449, -0.0111, -0.0052],\n", + " [ 0.0568, -0.0459, 0.0137],\n", + " [-0.0959, 0.0218, -0.0872]],\n", + " \n", + " [[ 0.0153, -0.0173, -0.0511],\n", + " [-0.0229, -0.0133, 0.0028],\n", + " [-0.0202, 0.0880, -0.0106]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.0198, 0.0526, 0.0231],\n", + " [-0.0164, 0.0388, -0.0761],\n", + " [-0.0426, 0.1007, -0.0563]],\n", + " \n", + " [[-0.0545, -0.0352, -0.0286],\n", + " [-0.0113, 0.0061, -0.0081],\n", + " [ 0.0563, -0.0457, 0.0216]],\n", + " \n", + " [[ 0.0377, 0.0722, 0.0403],\n", + " [ 0.0199, 0.0028, 0.0053],\n", + " [-0.0022, 0.0155, 0.0596]]]], device='cuda:0')),\n", + " ('module.features.5.0.downsample.0.weight',\n", + " tensor([[[[-0.1062]],\n", + " \n", + " [[ 0.0730]],\n", + " \n", + " [[-0.1620]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.0529]],\n", + " \n", + " [[-0.1904]],\n", + " \n", + " [[-0.0081]]],\n", + " \n", + " \n", + " [[[-0.0272]],\n", + " \n", + " [[-0.0740]],\n", + " \n", + " [[-0.1000]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.1445]],\n", + " \n", + " [[-0.1923]],\n", + " \n", + " [[ 0.0117]]],\n", + " \n", + " \n", + " [[[ 0.0210]],\n", + " \n", + " [[-0.0334]],\n", + " \n", + " [[-0.0200]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.0884]],\n", + " \n", + " [[-0.1663]],\n", + " \n", + " [[ 0.0249]]],\n", + " \n", + " \n", + " ...,\n", + " \n", + " \n", + " [[[ 0.0371]],\n", + " \n", + " [[-0.1020]],\n", + " \n", + " [[-0.1673]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.0238]],\n", + " \n", + " [[-0.1666]],\n", + " \n", + " [[-0.0730]]],\n", + " \n", + " \n", + " [[[ 0.0192]],\n", + " \n", + " [[ 0.0836]],\n", + " \n", + " [[-0.2289]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.1308]],\n", + " \n", + " [[ 0.1366]],\n", + " \n", + " [[-0.0892]]],\n", + " \n", + " \n", + " [[[ 0.0063]],\n", + " \n", + " [[-0.0660]],\n", + " \n", + " [[-0.0632]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.0167]],\n", + " \n", + " [[ 0.1447]],\n", + " \n", + " [[-0.2353]]]], device='cuda:0')),\n", + " ('module.features.5.1.conv1.weight',\n", + " tensor([[[[ 0.0524, 0.0397, 0.0304],\n", + " [-0.0118, 0.0106, -0.0503],\n", + " [-0.0191, 0.0437, -0.0011]],\n", + " \n", + " [[ 0.0521, -0.0468, -0.0494],\n", + " [-0.0680, 0.0069, 0.0577],\n", + " [ 0.0327, 0.0409, 0.0039]],\n", + " \n", + " [[-0.0154, 0.0447, 0.0069],\n", + " [-0.0919, -0.0604, -0.0296],\n", + " [ 0.0298, 0.0329, 0.0491]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.0175, 0.0182, -0.0473],\n", + " [ 0.0262, 0.0171, -0.0327],\n", + " [-0.0018, -0.0248, 0.0211]],\n", + " \n", + " [[ 0.0266, 0.0194, 0.0980],\n", + " [-0.0007, 0.0087, -0.0767],\n", + " [-0.0205, -0.0228, 0.0293]],\n", + " \n", + " [[-0.0330, 0.0093, 0.0342],\n", + " [-0.0393, 0.0319, -0.0072],\n", + " [-0.0048, -0.0731, 0.0250]]],\n", + " \n", + " \n", + " [[[-0.0200, 0.0151, 0.0073],\n", + " [-0.0070, -0.0006, 0.0224],\n", + " [-0.0144, -0.0088, -0.0227]],\n", + " \n", + " [[-0.0728, 0.0609, -0.0212],\n", + " [ 0.0061, -0.0018, -0.0444],\n", + " [ 0.0704, -0.0230, 0.0283]],\n", + " \n", + " [[-0.0296, -0.0179, -0.0159],\n", + " [-0.0356, 0.0155, -0.0227],\n", + " [ 0.0409, -0.0235, -0.0307]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.0399, -0.0184, -0.1100],\n", + " [-0.0362, -0.0091, -0.0157],\n", + " [ 0.0774, -0.0126, 0.0616]],\n", + " \n", + " [[ 0.0709, 0.0422, 0.0342],\n", + " [ 0.0445, 0.0818, -0.0921],\n", + " [ 0.0237, 0.0218, 0.0722]],\n", + " \n", + " [[-0.0024, -0.0132, 0.0135],\n", + " [-0.0443, -0.0957, 0.0015],\n", + " [-0.0534, -0.0437, -0.0510]]],\n", + " \n", + " \n", + " [[[ 0.0410, 0.0319, 0.0060],\n", + " [ 0.0066, 0.0273, -0.0037],\n", + " [-0.0130, -0.0355, -0.0566]],\n", + " \n", + " [[ 0.0547, 0.0008, 0.0495],\n", + " [ 0.0512, -0.0130, 0.0659],\n", + " [ 0.0532, 0.0160, 0.0034]],\n", + " \n", + " [[-0.0293, -0.0005, 0.0227],\n", + " [ 0.0014, -0.0253, -0.0078],\n", + " [-0.0341, 0.0066, -0.0061]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.0146, 0.0126, -0.0179],\n", + " [-0.0398, -0.0328, 0.0236],\n", + " [ 0.0281, 0.0066, -0.0820]],\n", + " \n", + " [[ 0.0164, -0.0245, -0.0377],\n", + " [-0.0083, 0.0010, 0.0083],\n", + " [ 0.0375, -0.0253, -0.0360]],\n", + " \n", + " [[ 0.0161, -0.0290, 0.0348],\n", + " [ 0.0022, 0.0219, 0.0387],\n", + " [ 0.0141, -0.0822, 0.0453]]],\n", + " \n", + " \n", + " ...,\n", + " \n", + " \n", + " [[[-0.0518, 0.0440, 0.0333],\n", + " [ 0.0233, 0.0102, -0.0029],\n", + " [-0.0399, -0.0019, -0.0124]],\n", + " \n", + " [[ 0.0093, -0.0079, 0.0802],\n", + " [ 0.0376, 0.1255, 0.0040],\n", + " [-0.0386, -0.0161, 0.0193]],\n", + " \n", + " [[-0.1355, 0.0206, 0.0577],\n", + " [-0.0023, 0.0065, -0.0870],\n", + " [-0.0181, -0.0022, 0.0204]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.0522, 0.0307, 0.0033],\n", + " [-0.0065, 0.0160, 0.0701],\n", + " [-0.0301, -0.0474, 0.0469]],\n", + " \n", + " [[-0.0221, -0.1105, 0.0512],\n", + " [ 0.0373, -0.0592, -0.0210],\n", + " [-0.0781, -0.0006, -0.0032]],\n", + " \n", + " [[ 0.0072, 0.0435, -0.0008],\n", + " [-0.0510, -0.0071, 0.0293],\n", + " [ 0.0355, 0.0196, 0.0561]]],\n", + " \n", + " \n", + " [[[ 0.0175, -0.0339, 0.0416],\n", + " [ 0.0016, -0.0096, -0.0434],\n", + " [-0.0544, -0.0349, 0.0061]],\n", + " \n", + " [[ 0.0194, -0.0126, -0.0320],\n", + " [-0.0303, 0.0254, 0.0027],\n", + " [ 0.0055, -0.0507, -0.0462]],\n", + " \n", + " [[-0.0646, 0.0249, -0.0058],\n", + " [-0.0026, 0.0673, -0.0211],\n", + " [ 0.0457, 0.0309, -0.0090]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.0143, 0.0254, -0.0236],\n", + " [-0.0102, 0.0430, -0.0116],\n", + " [-0.0336, -0.0753, 0.0151]],\n", + " \n", + " [[-0.0260, 0.0030, 0.0509],\n", + " [ 0.0197, 0.0323, 0.0497],\n", + " [-0.0455, -0.0137, 0.0477]],\n", + " \n", + " [[ 0.0073, 0.0444, -0.0131],\n", + " [ 0.0492, 0.0232, -0.0728],\n", + " [ 0.0272, -0.0433, 0.0159]]],\n", + " \n", + " \n", + " [[[ 0.0218, 0.0135, -0.0812],\n", + " [-0.0456, 0.0107, -0.0395],\n", + " [-0.0329, 0.0774, -0.0346]],\n", + " \n", + " [[ 0.0271, 0.0287, 0.0262],\n", + " [-0.0314, 0.0923, 0.0007],\n", + " [ 0.0322, 0.0266, 0.0109]],\n", + " \n", + " [[-0.0257, -0.0564, 0.0011],\n", + " [ 0.1216, -0.0470, 0.0248],\n", + " [ 0.0016, -0.0504, 0.0032]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.0228, 0.0184, -0.0226],\n", + " [ 0.0335, -0.0399, -0.0486],\n", + " [ 0.0249, 0.0200, -0.0146]],\n", + " \n", + " [[-0.0255, 0.0459, -0.0260],\n", + " [ 0.0220, -0.0481, 0.0306],\n", + " [-0.0272, -0.1053, -0.0381]],\n", + " \n", + " [[-0.0147, -0.0229, -0.0310],\n", + " [ 0.0103, 0.0320, -0.1259],\n", + " [-0.0652, -0.0111, -0.0101]]]], device='cuda:0')),\n", + " ('module.features.5.1.conv2.weight',\n", + " tensor([[[[-6.3954e-02, 4.1020e-02, -1.6813e-02],\n", + " [ 1.9134e-02, 3.3471e-02, -4.8216e-03],\n", + " [-5.1764e-02, 5.2896e-02, -3.0942e-03]],\n", + " \n", + " [[-4.1595e-02, 3.2443e-02, -6.2537e-03],\n", + " [ 8.8430e-03, -6.9640e-03, -5.4212e-04],\n", + " [-9.1833e-03, 4.6340e-02, -5.6299e-02]],\n", + " \n", + " [[-2.6058e-02, 1.7892e-02, 7.1810e-02],\n", + " [ 2.3703e-02, 5.2424e-02, 3.3864e-02],\n", + " [ 4.2138e-02, -2.6445e-02, -3.3645e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[ 3.0310e-03, 4.2958e-02, 2.0308e-02],\n", + " [-1.2983e-02, -9.3949e-03, 5.3939e-02],\n", + " [ 1.3553e-02, -4.4959e-05, 2.0351e-02]],\n", + " \n", + " [[ 3.3951e-02, 3.9457e-02, -8.2942e-02],\n", + " [-3.9516e-03, 3.5715e-02, -7.5140e-02],\n", + " [ 2.2566e-02, 5.9749e-02, 2.4178e-03]],\n", + " \n", + " [[-5.9489e-02, 1.2264e-02, -1.7448e-02],\n", + " [-4.6808e-02, -1.0983e-01, 5.9044e-03],\n", + " [-4.2222e-02, 3.5628e-02, -1.8010e-02]]],\n", + " \n", + " \n", + " [[[-5.9920e-02, -5.8497e-02, 1.9054e-02],\n", + " [ 1.8326e-02, 3.2683e-02, -1.1116e-01],\n", + " [ 3.9147e-02, 1.1690e-02, -2.0936e-02]],\n", + " \n", + " [[-1.5919e-02, -5.8673e-02, -7.0556e-02],\n", + " [-2.0246e-02, 2.8831e-03, -6.7587e-03],\n", + " [ 6.8233e-02, 4.5227e-02, -1.0763e-02]],\n", + " \n", + " [[ 1.0176e-03, -5.5830e-02, 2.1531e-03],\n", + " [-2.0116e-03, -5.8720e-02, -1.6871e-02],\n", + " [ 2.0126e-02, -2.0070e-02, -2.5659e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[-6.4194e-03, 7.0518e-02, 9.0191e-04],\n", + " [-3.5949e-02, 9.2871e-02, 5.3642e-02],\n", + " [-2.0992e-02, -2.6704e-02, 5.2092e-02]],\n", + " \n", + " [[ 5.4812e-02, -1.6752e-03, -8.8176e-02],\n", + " [ 1.5927e-02, 1.6110e-03, 3.7639e-02],\n", + " [ 1.7232e-02, 4.7434e-02, -3.0740e-02]],\n", + " \n", + " [[-5.2121e-02, 3.7098e-02, 2.0256e-02],\n", + " [-6.0424e-02, 3.3092e-02, 5.1734e-02],\n", + " [-7.9362e-03, -3.2565e-03, -2.9208e-02]]],\n", + " \n", + " \n", + " [[[ 3.1005e-02, -1.9540e-02, 3.3992e-02],\n", + " [ 2.2688e-02, 5.2265e-02, -5.0763e-02],\n", + " [-6.6935e-03, -1.2542e-02, -8.2880e-03]],\n", + " \n", + " [[-2.8486e-02, 1.2028e-02, 3.1694e-02],\n", + " [ 2.8941e-02, 1.2840e-02, -5.0390e-02],\n", + " [ 7.9660e-02, 1.5672e-02, -3.8056e-02]],\n", + " \n", + " [[ 1.0606e-02, -4.1412e-02, -1.9782e-02],\n", + " [-3.5188e-02, 5.0918e-03, -4.2261e-02],\n", + " [-8.7299e-02, 6.5995e-02, -2.9643e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[ 7.1659e-02, -6.0586e-02, -9.8954e-03],\n", + " [ 5.6954e-02, 5.5151e-02, 3.5438e-02],\n", + " [-3.1127e-02, -6.1179e-02, -7.7983e-02]],\n", + " \n", + " [[ 6.6632e-02, 4.3427e-02, -4.0689e-02],\n", + " [-1.5278e-02, -4.2361e-02, 3.6805e-02],\n", + " [ 3.2855e-02, -4.2530e-02, -4.5892e-02]],\n", + " \n", + " [[ 2.3422e-02, 5.8044e-02, -7.6857e-03],\n", + " [-2.8058e-02, -1.5949e-02, -4.0950e-02],\n", + " [ 4.8697e-03, -2.4890e-02, 2.8388e-02]]],\n", + " \n", + " \n", + " ...,\n", + " \n", + " \n", + " [[[ 1.2142e-02, 7.3265e-02, -9.9022e-03],\n", + " [ 2.5017e-02, 5.8642e-02, -6.8056e-03],\n", + " [-9.9244e-02, -9.1136e-02, -7.3883e-02]],\n", + " \n", + " [[-1.9512e-02, 4.1599e-02, -1.2142e-01],\n", + " [ 1.7323e-02, 1.3393e-02, 3.3512e-02],\n", + " [ 1.1051e-02, -1.4824e-02, 8.3638e-03]],\n", + " \n", + " [[ 1.8719e-02, 6.2494e-02, 8.6195e-03],\n", + " [-1.0191e-02, 4.2297e-03, -4.0402e-02],\n", + " [ 4.0748e-02, 4.5447e-02, -2.4630e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[ 8.1352e-02, -1.6838e-02, 2.7775e-02],\n", + " [ 1.6260e-02, 2.2381e-02, 4.7176e-02],\n", + " [ 6.5905e-03, -2.0240e-03, 5.0318e-02]],\n", + " \n", + " [[ 3.5941e-02, -1.3546e-02, 2.5559e-02],\n", + " [ 8.7063e-02, 5.7019e-03, 8.7749e-02],\n", + " [ 3.1105e-02, 3.0270e-02, 7.4364e-02]],\n", + " \n", + " [[ 2.4678e-02, -1.4076e-02, 2.7644e-02],\n", + " [-7.6359e-02, 3.7058e-02, -2.3741e-02],\n", + " [-2.0144e-02, 9.5933e-02, 4.5360e-02]]],\n", + " \n", + " \n", + " [[[ 6.5741e-02, 4.9462e-02, -6.5293e-02],\n", + " [-4.4282e-02, -2.1336e-02, -5.6478e-02],\n", + " [ 1.1655e-02, -1.0429e-02, -3.8329e-02]],\n", + " \n", + " [[-4.5852e-02, 4.2555e-02, -3.2961e-02],\n", + " [-3.5738e-02, 4.2727e-02, 3.7960e-02],\n", + " [-5.9692e-02, -3.7927e-02, -6.3316e-02]],\n", + " \n", + " [[-3.1069e-02, -4.6978e-02, -6.3886e-02],\n", + " [-7.9316e-02, 6.1180e-03, -1.3502e-02],\n", + " [ 1.7732e-02, 2.7328e-02, -3.5424e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[-5.1502e-02, 1.1618e-02, 4.4762e-02],\n", + " [ 1.3293e-02, -2.8515e-02, 2.3477e-02],\n", + " [ 1.0963e-02, 6.3643e-03, 6.7758e-03]],\n", + " \n", + " [[ 4.9450e-02, -1.1083e-02, 4.1855e-02],\n", + " [ 2.1341e-02, 5.4715e-02, 5.2060e-02],\n", + " [ 1.9619e-02, 8.7005e-03, -1.0872e-02]],\n", + " \n", + " [[ 1.6183e-03, 1.7625e-02, -7.1724e-02],\n", + " [ 5.6331e-02, 1.0753e-01, 1.2052e-02],\n", + " [ 2.0485e-02, 1.2266e-02, 6.9537e-03]]],\n", + " \n", + " \n", + " [[[ 1.4428e-02, 7.6514e-03, 2.5245e-02],\n", + " [-3.3150e-02, 4.3417e-02, 1.9428e-02],\n", + " [ 4.7255e-02, -3.4982e-02, -4.3335e-02]],\n", + " \n", + " [[-1.1168e-02, -4.7031e-02, -3.6381e-02],\n", + " [ 4.1839e-02, 3.1655e-02, -1.7358e-02],\n", + " [-1.9284e-02, -3.4662e-03, -4.0539e-02]],\n", + " \n", + " [[ 5.0484e-02, -7.0366e-02, 7.1903e-02],\n", + " [ 2.5097e-02, 1.6200e-02, 2.6682e-02],\n", + " [-2.6516e-02, -4.4044e-02, -1.1757e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[ 2.7789e-02, -2.4881e-02, -1.3087e-02],\n", + " [-2.1391e-02, 1.0493e-02, -5.0449e-02],\n", + " [-3.7873e-02, 2.2887e-02, 2.3353e-02]],\n", + " \n", + " [[-2.7992e-02, -9.8221e-03, -2.6284e-02],\n", + " [ 5.5802e-02, -1.9979e-02, -5.3550e-02],\n", + " [-8.4603e-02, -4.5323e-02, -5.7951e-02]],\n", + " \n", + " [[ 1.6301e-02, 6.3142e-02, 9.8293e-02],\n", + " [ 4.3805e-03, 3.1449e-02, 5.8051e-02],\n", + " [-1.9028e-02, -1.3127e-02, -1.3450e-02]]]], device='cuda:0')),\n", + " ('module.features.6.0.conv1.weight',\n", + " tensor([[[[ 0.0240, -0.0227, -0.0309],\n", + " [-0.0247, 0.0033, 0.0432],\n", + " [ 0.0297, -0.0698, 0.0268]],\n", + " \n", + " [[-0.0087, -0.0099, -0.0024],\n", + " [ 0.0354, -0.0023, 0.0237],\n", + " [-0.0012, 0.0157, 0.0071]],\n", + " \n", + " [[ 0.0058, -0.0504, -0.0078],\n", + " [-0.0192, 0.0444, -0.0287],\n", + " [ 0.0058, -0.0301, 0.0103]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.0248, 0.0194, -0.0008],\n", + " [-0.0152, 0.0511, 0.0284],\n", + " [-0.0017, -0.0710, -0.0117]],\n", + " \n", + " [[ 0.0231, -0.0283, 0.0355],\n", + " [ 0.0168, -0.0039, -0.0019],\n", + " [ 0.0184, -0.0179, -0.0213]],\n", + " \n", + " [[ 0.0283, -0.0127, -0.0369],\n", + " [ 0.0271, -0.0027, -0.0016],\n", + " [ 0.0173, -0.0237, 0.0033]]],\n", + " \n", + " \n", + " [[[ 0.0057, -0.0025, 0.0064],\n", + " [ 0.0161, -0.0475, 0.0262],\n", + " [ 0.0325, -0.0095, -0.0054]],\n", + " \n", + " [[-0.0346, -0.0113, 0.0190],\n", + " [-0.0364, 0.0418, 0.0298],\n", + " [ 0.0088, 0.0369, 0.0125]],\n", + " \n", + " [[ 0.0016, 0.0035, 0.0206],\n", + " [ 0.0233, 0.0029, -0.0143],\n", + " [-0.0005, -0.0228, 0.0294]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.0209, -0.0253, 0.0257],\n", + " [-0.0404, -0.0363, 0.0393],\n", + " [-0.0203, 0.0136, 0.0041]],\n", + " \n", + " [[-0.0301, -0.0184, -0.0200],\n", + " [-0.0319, -0.0195, -0.0332],\n", + " [-0.0231, 0.0407, -0.0170]],\n", + " \n", + " [[-0.0198, 0.0130, -0.0241],\n", + " [-0.0224, 0.0324, -0.0427],\n", + " [ 0.0380, -0.0024, 0.0291]]],\n", + " \n", + " \n", + " [[[-0.0466, -0.0423, -0.0005],\n", + " [-0.0131, -0.0138, -0.0047],\n", + " [-0.0888, 0.0136, 0.0141]],\n", + " \n", + " [[-0.0107, -0.0105, -0.0020],\n", + " [ 0.0112, -0.0028, -0.0252],\n", + " [ 0.0050, 0.0019, 0.0032]],\n", + " \n", + " [[-0.0417, 0.0396, -0.0192],\n", + " [-0.0302, -0.0239, -0.0329],\n", + " [-0.0322, -0.0501, 0.0183]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.0651, -0.0232, -0.0549],\n", + " [-0.0417, -0.0100, 0.0518],\n", + " [-0.0001, 0.0329, -0.0203]],\n", + " \n", + " [[ 0.0449, 0.0181, -0.0199],\n", + " [-0.0355, -0.0602, 0.0449],\n", + " [-0.0516, -0.0057, 0.0202]],\n", + " \n", + " [[ 0.0048, -0.0382, -0.0481],\n", + " [ 0.0446, -0.0534, 0.0196],\n", + " [ 0.0420, -0.0488, -0.0025]]],\n", + " \n", + " \n", + " ...,\n", + " \n", + " \n", + " [[[ 0.0370, 0.0267, -0.0290],\n", + " [ 0.0086, 0.0268, 0.0252],\n", + " [ 0.0014, 0.0049, 0.0030]],\n", + " \n", + " [[-0.0172, 0.0166, -0.0116],\n", + " [ 0.0058, 0.0025, 0.0096],\n", + " [ 0.0088, -0.0045, 0.0109]],\n", + " \n", + " [[-0.0301, -0.0272, -0.0277],\n", + " [-0.0049, 0.0645, 0.0125],\n", + " [-0.0520, -0.0035, -0.0437]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.0150, -0.0118, -0.0402],\n", + " [-0.0030, 0.0220, -0.0568],\n", + " [-0.0604, -0.0319, 0.0031]],\n", + " \n", + " [[ 0.0017, -0.0345, 0.0604],\n", + " [ 0.0126, -0.0026, 0.0381],\n", + " [-0.0294, 0.0549, -0.0118]],\n", + " \n", + " [[-0.0526, -0.0483, 0.0265],\n", + " [ 0.0209, -0.0104, 0.0020],\n", + " [ 0.0164, 0.0114, 0.0161]]],\n", + " \n", + " \n", + " [[[-0.0355, 0.0480, -0.0118],\n", + " [ 0.0319, -0.0545, 0.0136],\n", + " [ 0.0028, -0.0574, 0.0331]],\n", + " \n", + " [[ 0.0020, 0.0231, -0.0033],\n", + " [ 0.0219, 0.0202, -0.0166],\n", + " [-0.0336, 0.0322, -0.0502]],\n", + " \n", + " [[ 0.0146, -0.0157, 0.0381],\n", + " [ 0.0105, 0.0060, -0.0070],\n", + " [-0.0114, 0.0013, -0.0418]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.0059, 0.0099, 0.0060],\n", + " [-0.0165, 0.0038, -0.0169],\n", + " [ 0.0343, -0.0244, -0.0167]],\n", + " \n", + " [[ 0.0479, 0.0010, 0.0205],\n", + " [-0.0019, 0.0591, -0.0672],\n", + " [-0.0070, -0.0364, 0.0232]],\n", + " \n", + " [[-0.0080, 0.0137, -0.0213],\n", + " [-0.0774, 0.0105, -0.0237],\n", + " [ 0.0721, -0.0167, 0.0276]]],\n", + " \n", + " \n", + " [[[-0.0220, -0.0279, 0.0012],\n", + " [-0.0209, 0.0293, -0.0029],\n", + " [-0.0057, 0.0015, 0.0172]],\n", + " \n", + " [[-0.0106, -0.0302, -0.0075],\n", + " [ 0.0008, -0.0210, 0.0442],\n", + " [-0.0106, 0.0031, -0.0311]],\n", + " \n", + " [[-0.0157, -0.0408, 0.0793],\n", + " [-0.0214, 0.0764, -0.0372],\n", + " [ 0.0025, 0.0271, -0.0315]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.0563, 0.0262, -0.0354],\n", + " [ 0.0141, -0.0001, -0.0292],\n", + " [-0.0230, 0.0063, -0.0463]],\n", + " \n", + " [[-0.0258, 0.0125, -0.0095],\n", + " [ 0.0223, -0.0436, 0.0133],\n", + " [ 0.0052, -0.0080, -0.0041]],\n", + " \n", + " [[-0.0072, 0.0499, -0.0308],\n", + " [-0.0131, -0.0604, 0.0236],\n", + " [-0.0735, 0.0252, -0.0268]]]], device='cuda:0')),\n", + " ('module.features.6.0.conv2.weight',\n", + " tensor([[[[ 0.0456, 0.0217, -0.0140],\n", + " [ 0.0218, -0.0242, -0.0408],\n", + " [-0.0064, 0.0213, -0.0625]],\n", + " \n", + " [[-0.0043, -0.0048, 0.0029],\n", + " [-0.0304, -0.0281, 0.0088],\n", + " [ 0.0585, 0.0351, 0.0145]],\n", + " \n", + " [[-0.0666, -0.0023, 0.0278],\n", + " [-0.0346, -0.0102, 0.0055],\n", + " [-0.0033, -0.0292, 0.0276]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.0203, 0.0189, 0.0237],\n", + " [ 0.0598, -0.0031, 0.0022],\n", + " [-0.0101, 0.0254, -0.0024]],\n", + " \n", + " [[ 0.0243, -0.0321, 0.0007],\n", + " [ 0.0324, -0.0349, 0.0275],\n", + " [-0.0017, 0.0128, 0.0202]],\n", + " \n", + " [[-0.0222, 0.0059, -0.0872],\n", + " [-0.0068, -0.0591, 0.0200],\n", + " [ 0.0156, 0.0124, 0.0116]]],\n", + " \n", + " \n", + " [[[ 0.0435, 0.0144, -0.0088],\n", + " [-0.0421, -0.0291, 0.0273],\n", + " [ 0.0186, -0.0065, -0.0051]],\n", + " \n", + " [[-0.0224, -0.0085, -0.0016],\n", + " [-0.0155, -0.0116, 0.0089],\n", + " [ 0.0052, -0.0223, 0.0146]],\n", + " \n", + " [[ 0.0664, 0.0152, 0.0241],\n", + " [ 0.0502, -0.0051, 0.0327],\n", + " [ 0.0381, -0.0349, -0.0250]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.0974, 0.0089, -0.0157],\n", + " [ 0.0427, 0.0091, -0.0036],\n", + " [-0.0220, -0.0030, -0.0207]],\n", + " \n", + " [[ 0.0463, -0.0679, 0.0149],\n", + " [-0.0382, -0.0128, -0.0297],\n", + " [ 0.0492, 0.0189, -0.0443]],\n", + " \n", + " [[ 0.0432, -0.0122, -0.0390],\n", + " [-0.0299, 0.0153, 0.0116],\n", + " [ 0.0074, 0.0139, 0.0156]]],\n", + " \n", + " \n", + " [[[-0.0378, -0.0024, 0.0227],\n", + " [-0.0338, 0.0147, 0.0021],\n", + " [ 0.0113, 0.0399, -0.0064]],\n", + " \n", + " [[-0.0599, -0.0307, -0.0259],\n", + " [ 0.0257, 0.0076, 0.0498],\n", + " [-0.0048, -0.0039, -0.0475]],\n", + " \n", + " [[ 0.0055, -0.0252, -0.0048],\n", + " [ 0.0249, -0.0032, 0.0166],\n", + " [-0.0380, 0.0109, 0.0167]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.0421, -0.0173, -0.0114],\n", + " [ 0.0343, -0.0060, 0.0394],\n", + " [-0.0232, -0.0279, 0.0052]],\n", + " \n", + " [[ 0.0079, -0.0103, -0.0056],\n", + " [ 0.0265, 0.0216, -0.0492],\n", + " [ 0.0082, 0.0359, -0.0071]],\n", + " \n", + " [[-0.0195, 0.0216, -0.0235],\n", + " [ 0.0362, 0.0314, 0.0027],\n", + " [ 0.0388, 0.0462, -0.0083]]],\n", + " \n", + " \n", + " ...,\n", + " \n", + " \n", + " [[[ 0.0158, 0.0483, 0.0320],\n", + " [ 0.0072, 0.0244, 0.0348],\n", + " [ 0.0027, 0.0525, 0.0169]],\n", + " \n", + " [[-0.0024, 0.0197, 0.0254],\n", + " [-0.0142, 0.0002, -0.0231],\n", + " [-0.0176, 0.0244, 0.0119]],\n", + " \n", + " [[-0.0151, 0.0136, 0.0562],\n", + " [ 0.0227, -0.0010, -0.0176],\n", + " [ 0.0216, 0.0213, -0.0401]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.0267, -0.0435, 0.0050],\n", + " [-0.0382, -0.0170, 0.0197],\n", + " [-0.0110, -0.0479, 0.0207]],\n", + " \n", + " [[ 0.0468, 0.0084, -0.0388],\n", + " [-0.0114, 0.0255, -0.0155],\n", + " [-0.0160, -0.0051, -0.0084]],\n", + " \n", + " [[-0.0102, 0.0076, 0.0167],\n", + " [-0.0097, 0.0808, 0.0072],\n", + " [-0.0422, -0.0090, 0.0205]]],\n", + " \n", + " \n", + " [[[ 0.0127, -0.0120, -0.0155],\n", + " [-0.0229, -0.0039, -0.0077],\n", + " [ 0.0269, 0.0339, 0.0376]],\n", + " \n", + " [[ 0.0109, -0.0058, -0.0114],\n", + " [ 0.0051, 0.0078, 0.0334],\n", + " [ 0.0142, 0.0040, -0.0676]],\n", + " \n", + " [[ 0.0029, -0.0156, 0.0024],\n", + " [-0.0088, 0.0022, 0.0056],\n", + " [ 0.0235, -0.0165, -0.0713]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.0219, -0.0182, 0.0225],\n", + " [ 0.0015, -0.0273, -0.0245],\n", + " [ 0.0080, -0.0202, -0.0027]],\n", + " \n", + " [[ 0.0078, 0.0203, -0.0222],\n", + " [ 0.0223, 0.0386, 0.0078],\n", + " [ 0.0222, 0.0257, -0.0107]],\n", + " \n", + " [[ 0.0005, -0.0625, 0.0093],\n", + " [ 0.0340, -0.0411, -0.0146],\n", + " [ 0.0081, 0.0240, 0.0127]]],\n", + " \n", + " \n", + " [[[-0.0027, -0.0428, -0.0054],\n", + " [-0.0341, -0.0059, -0.0241],\n", + " [-0.0252, -0.0168, 0.0242]],\n", + " \n", + " [[-0.0126, -0.0144, -0.0143],\n", + " [ 0.0255, 0.0032, -0.0261],\n", + " [-0.0114, 0.0082, -0.0139]],\n", + " \n", + " [[-0.0032, -0.0282, 0.0255],\n", + " [-0.0109, -0.0130, 0.0422],\n", + " [ 0.0156, 0.0132, -0.0362]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.0280, -0.0159, 0.0197],\n", + " [-0.0053, 0.0227, 0.0105],\n", + " [ 0.0252, -0.0133, 0.0017]],\n", + " \n", + " [[ 0.0088, -0.0180, -0.0219],\n", + " [-0.0258, -0.0302, 0.0063],\n", + " [-0.0330, 0.0104, 0.0190]],\n", + " \n", + " [[-0.0133, 0.0672, -0.0083],\n", + " [-0.0084, -0.0127, 0.0435],\n", + " [-0.0250, 0.0217, -0.0196]]]], device='cuda:0')),\n", + " ('module.features.6.0.downsample.0.weight',\n", + " tensor([[[[ 0.0167]],\n", + " \n", + " [[ 0.0906]],\n", + " \n", + " [[-0.0019]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.0699]],\n", + " \n", + " [[-0.0978]],\n", + " \n", + " [[ 0.0776]]],\n", + " \n", + " \n", + " [[[ 0.0425]],\n", + " \n", + " [[ 0.0029]],\n", + " \n", + " [[-0.0786]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.0315]],\n", + " \n", + " [[-0.0649]],\n", + " \n", + " [[ 0.0829]]],\n", + " \n", + " \n", + " [[[ 0.1541]],\n", + " \n", + " [[-0.1738]],\n", + " \n", + " [[-0.0216]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.0183]],\n", + " \n", + " [[-0.0233]],\n", + " \n", + " [[-0.0739]]],\n", + " \n", + " \n", + " ...,\n", + " \n", + " \n", + " [[[ 0.0135]],\n", + " \n", + " [[-0.1001]],\n", + " \n", + " [[-0.0375]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.0250]],\n", + " \n", + " [[ 0.1578]],\n", + " \n", + " [[-0.0817]]],\n", + " \n", + " \n", + " [[[-0.1361]],\n", + " \n", + " [[-0.1274]],\n", + " \n", + " [[ 0.1172]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.0530]],\n", + " \n", + " [[-0.0503]],\n", + " \n", + " [[ 0.1324]]],\n", + " \n", + " \n", + " [[[ 0.0932]],\n", + " \n", + " [[ 0.2297]],\n", + " \n", + " [[-0.1584]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.0520]],\n", + " \n", + " [[-0.0518]],\n", + " \n", + " [[ 0.1083]]]], device='cuda:0')),\n", + " ('module.features.6.1.conv1.weight',\n", + " tensor([[[[-1.0663e-02, -6.4592e-02, -1.7643e-02],\n", + " [-1.9915e-02, -2.3672e-02, -2.6418e-02],\n", + " [-3.7928e-03, 1.4260e-02, 8.1776e-04]],\n", + " \n", + " [[-1.3990e-02, 4.2493e-02, -4.5754e-02],\n", + " [ 6.5685e-02, -9.7300e-03, -4.2949e-03],\n", + " [ 5.9361e-03, 2.8047e-02, -7.9141e-03]],\n", + " \n", + " [[ 7.7116e-03, 5.7022e-02, -3.6133e-02],\n", + " [-1.7684e-02, -3.1758e-02, -3.3286e-02],\n", + " [-4.7406e-02, 2.8379e-02, -2.8364e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[-5.6726e-04, 1.2521e-02, -2.0426e-02],\n", + " [ 4.0585e-02, -4.6295e-02, 1.5691e-02],\n", + " [-2.1484e-02, -3.0863e-03, -4.5604e-02]],\n", + " \n", + " [[ 3.8378e-03, -3.6587e-03, 1.3426e-02],\n", + " [ 1.0853e-03, 8.8641e-03, -4.1009e-02],\n", + " [ 3.7223e-03, 9.7395e-03, -5.4633e-03]],\n", + " \n", + " [[-1.3050e-02, -3.5338e-02, 2.1541e-03],\n", + " [ 1.8484e-02, -1.7885e-02, 3.4567e-02],\n", + " [ 1.0993e-02, -2.7531e-02, 4.5742e-03]]],\n", + " \n", + " \n", + " [[[ 7.5872e-02, -3.7636e-02, 1.7796e-02],\n", + " [ 9.2064e-03, 4.1637e-03, 1.0417e-02],\n", + " [-3.1397e-02, -3.2680e-03, -4.1829e-02]],\n", + " \n", + " [[ 6.7461e-02, -9.0697e-03, -1.5819e-02],\n", + " [ 2.0520e-02, 1.0564e-02, 1.5427e-02],\n", + " [-3.1541e-02, -1.0761e-01, 1.0140e-02]],\n", + " \n", + " [[ 2.0647e-02, 3.4147e-03, 5.1170e-03],\n", + " [ 6.2912e-02, -2.5438e-02, -7.0194e-03],\n", + " [-5.7876e-03, -2.5348e-02, 6.9223e-03]],\n", + " \n", + " ...,\n", + " \n", + " [[-1.8312e-02, -1.5583e-02, 1.9101e-02],\n", + " [ 1.1937e-02, -1.1945e-02, 3.0733e-02],\n", + " [ 8.7930e-02, -8.5641e-03, -3.9050e-02]],\n", + " \n", + " [[-4.3650e-03, 3.8903e-02, -7.1692e-02],\n", + " [ 2.3118e-02, 7.5008e-03, -2.0700e-02],\n", + " [-5.0258e-02, 1.1372e-02, -1.7813e-02]],\n", + " \n", + " [[-8.8078e-04, 3.8180e-02, -2.6889e-02],\n", + " [ 1.2815e-02, 2.0097e-02, -3.1433e-02],\n", + " [-1.8129e-02, -1.6484e-02, -1.1331e-02]]],\n", + " \n", + " \n", + " [[[ 1.1774e-02, -4.3228e-02, -2.4902e-02],\n", + " [ 3.0643e-02, -7.8646e-03, 7.8888e-03],\n", + " [ 2.7589e-02, 3.7252e-03, -1.6763e-02]],\n", + " \n", + " [[ 3.5904e-02, 9.2736e-03, 1.5518e-03],\n", + " [-2.8347e-02, 7.6164e-03, -2.6277e-02],\n", + " [-2.5395e-03, 5.1337e-02, 3.6514e-02]],\n", + " \n", + " [[-8.9175e-03, 2.8995e-02, 3.6913e-03],\n", + " [ 4.7412e-02, -2.0663e-02, -7.5921e-02],\n", + " [ 4.4291e-02, 2.7490e-02, 3.7849e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[-4.9216e-02, 2.1112e-02, 2.7770e-02],\n", + " [-5.3211e-03, -1.9268e-02, 5.1660e-03],\n", + " [ 5.1669e-03, 9.7053e-03, -3.8086e-02]],\n", + " \n", + " [[-2.4785e-02, 1.6300e-02, -1.5496e-02],\n", + " [ 1.5307e-02, -2.9707e-02, 2.7436e-02],\n", + " [-1.9306e-02, 7.0717e-02, -1.0173e-02]],\n", + " \n", + " [[ 3.0778e-02, 4.9702e-04, 3.1833e-02],\n", + " [-2.5441e-02, -1.4679e-02, 3.4316e-02],\n", + " [-2.3471e-03, -3.3221e-03, -2.5022e-02]]],\n", + " \n", + " \n", + " ...,\n", + " \n", + " \n", + " [[[ 2.0778e-04, -2.3316e-03, -2.8760e-02],\n", + " [-4.4530e-02, -2.4329e-02, -3.9114e-02],\n", + " [-1.5464e-02, 1.0834e-02, -8.8251e-03]],\n", + " \n", + " [[ 1.0525e-04, 6.5905e-02, 3.1255e-02],\n", + " [ 4.8631e-02, 1.9551e-02, -1.5353e-02],\n", + " [-1.9567e-03, 2.6212e-04, -3.2542e-02]],\n", + " \n", + " [[-2.9586e-02, 3.8948e-02, 1.4704e-02],\n", + " [ 7.8863e-03, -2.3816e-02, 9.8579e-03],\n", + " [-2.5092e-02, 6.2358e-03, -6.0422e-03]],\n", + " \n", + " ...,\n", + " \n", + " [[-1.6740e-02, 2.9511e-02, -1.9156e-02],\n", + " [ 1.6318e-02, -3.4487e-04, -4.0301e-03],\n", + " [ 1.0853e-02, 2.3323e-02, -2.9527e-03]],\n", + " \n", + " [[ 1.6024e-02, -1.4094e-02, -1.2549e-02],\n", + " [ 5.7991e-03, 3.7044e-03, -3.6310e-02],\n", + " [-5.9241e-02, -2.2843e-02, 3.7137e-02]],\n", + " \n", + " [[-1.7307e-02, -3.0526e-02, -6.5269e-03],\n", + " [ 3.6899e-02, 3.5466e-02, -3.6425e-02],\n", + " [ 1.6369e-02, 2.1150e-03, 2.2940e-02]]],\n", + " \n", + " \n", + " [[[-4.8313e-02, -1.7220e-02, 2.5480e-03],\n", + " [-7.8973e-04, 1.3135e-02, 8.3401e-03],\n", + " [-3.0900e-03, -2.5085e-02, 2.9801e-02]],\n", + " \n", + " [[ 4.8227e-02, 1.4280e-02, 1.7196e-02],\n", + " [-4.0403e-03, -5.7785e-03, -3.2870e-02],\n", + " [-1.3931e-02, -7.0277e-02, -3.6799e-02]],\n", + " \n", + " [[ 1.6242e-02, -1.8689e-02, 1.0376e-03],\n", + " [ 3.9966e-02, 2.1135e-02, 2.2993e-02],\n", + " [-9.3555e-03, 5.1732e-02, -3.1082e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[-1.3077e-02, 2.8440e-02, 4.4365e-02],\n", + " [-1.3948e-02, -6.9563e-04, 2.8177e-02],\n", + " [ 3.3243e-02, -3.3940e-02, -5.2344e-02]],\n", + " \n", + " [[-3.1947e-02, -9.0099e-03, 4.4178e-02],\n", + " [ 8.0787e-03, 7.7619e-02, 5.6659e-04],\n", + " [-1.0522e-02, -6.2181e-03, -4.4128e-02]],\n", + " \n", + " [[ 4.1340e-03, -3.2932e-02, -6.5304e-02],\n", + " [-2.3988e-02, -2.4525e-03, 9.7465e-03],\n", + " [-3.9158e-03, -4.7578e-02, -1.2476e-02]]],\n", + " \n", + " \n", + " [[[ 3.4115e-02, 2.7232e-02, -4.2150e-02],\n", + " [-1.9656e-02, -2.1527e-02, -3.7997e-02],\n", + " [ 4.6947e-02, -4.2265e-02, 1.0091e-02]],\n", + " \n", + " [[ 2.7583e-03, -7.0453e-04, 2.0036e-02],\n", + " [ 2.1210e-02, 1.0332e-02, 3.0095e-02],\n", + " [-8.7497e-03, 2.3021e-02, 3.6112e-02]],\n", + " \n", + " [[ 2.3547e-03, -2.3150e-02, -3.4876e-02],\n", + " [ 8.3291e-03, 8.2892e-03, -2.3248e-02],\n", + " [-1.4488e-02, -2.4721e-02, 5.8742e-03]],\n", + " \n", + " ...,\n", + " \n", + " [[-8.3425e-02, -2.4922e-02, -6.0979e-03],\n", + " [ 2.5739e-02, 5.1047e-02, -1.5323e-02],\n", + " [ 9.1807e-03, 2.0354e-03, -3.6563e-04]],\n", + " \n", + " [[-1.8426e-02, 3.0250e-02, -3.0054e-03],\n", + " [ 2.6835e-02, -1.9758e-02, 2.7468e-02],\n", + " [ 3.3452e-02, -1.6044e-02, -1.7455e-02]],\n", + " \n", + " [[-5.7080e-03, 6.3050e-02, 3.1424e-02],\n", + " [ 9.9518e-03, 1.4794e-02, -1.4960e-02],\n", + " [ 1.0899e-02, 4.6267e-03, 1.7051e-02]]]], device='cuda:0')),\n", + " ('module.features.6.1.conv2.weight',\n", + " tensor([[[[ 3.4787e-02, 3.8060e-02, -1.5990e-02],\n", + " [ 1.7950e-02, -8.4159e-03, -6.6534e-02],\n", + " [-1.4080e-03, -3.2043e-03, 2.0568e-02]],\n", + " \n", + " [[-3.1033e-02, -2.5325e-02, 1.4106e-02],\n", + " [-2.0858e-02, -2.9413e-02, -1.6169e-02],\n", + " [ 4.9200e-03, 2.5553e-02, 1.8689e-02]],\n", + " \n", + " [[-7.8053e-03, -5.6583e-02, 1.0117e-02],\n", + " [ 2.1717e-02, -3.9589e-02, 8.9629e-03],\n", + " [ 7.4012e-03, 5.4506e-02, -2.3082e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[-6.1332e-02, 2.4637e-02, 8.5247e-03],\n", + " [-7.8043e-03, -5.0832e-02, -3.5482e-02],\n", + " [-1.0316e-02, -2.4707e-02, -5.4953e-02]],\n", + " \n", + " [[-3.7195e-02, -2.7308e-02, -9.6095e-04],\n", + " [-1.1219e-02, 5.7739e-02, -2.0120e-02],\n", + " [ 4.1952e-02, 3.4357e-02, -2.2346e-02]],\n", + " \n", + " [[ 6.2356e-02, 5.6113e-02, -1.4165e-02],\n", + " [ 3.5778e-03, -1.5435e-02, 2.8328e-02],\n", + " [ 1.1567e-02, 1.4654e-02, 8.2448e-03]]],\n", + " \n", + " \n", + " [[[-1.2194e-02, -5.0170e-03, 3.5250e-02],\n", + " [ 1.9931e-03, 6.0486e-03, 9.3887e-03],\n", + " [ 2.1655e-04, 1.3419e-02, 1.5752e-02]],\n", + " \n", + " [[ 5.1833e-02, -4.1004e-02, 1.4427e-02],\n", + " [ 5.7461e-03, 3.6509e-02, 4.9289e-02],\n", + " [ 6.3431e-02, 2.2885e-02, 2.8438e-02]],\n", + " \n", + " [[ 1.4997e-02, 5.6978e-02, -2.3945e-02],\n", + " [ 4.4951e-02, 3.1703e-02, -1.1541e-02],\n", + " [ 3.0593e-02, 1.4636e-02, 3.1538e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[-2.0998e-02, -2.3779e-02, -2.7013e-02],\n", + " [ 4.0677e-03, 1.0402e-02, 2.4015e-03],\n", + " [ 8.1995e-02, 4.0297e-02, -5.9132e-02]],\n", + " \n", + " [[-1.5711e-02, 2.0702e-02, 1.6739e-02],\n", + " [-1.7031e-02, -8.0594e-03, 8.3916e-03],\n", + " [-3.0904e-02, -1.6253e-02, -3.8727e-02]],\n", + " \n", + " [[-4.1036e-03, 2.9179e-02, -1.2883e-02],\n", + " [-6.3859e-03, 2.3857e-03, 5.9913e-03],\n", + " [ 3.0187e-02, -3.4250e-02, 4.2737e-02]]],\n", + " \n", + " \n", + " [[[-1.1786e-02, 2.8669e-02, 3.1567e-02],\n", + " [ 4.9293e-02, -4.7991e-03, -6.0374e-02],\n", + " [ 1.3316e-02, 6.5624e-02, -3.3356e-02]],\n", + " \n", + " [[-9.2111e-03, 1.4936e-02, 1.5386e-02],\n", + " [-3.3757e-02, 1.1189e-02, -3.3450e-02],\n", + " [ 3.7834e-02, -1.4629e-02, 1.1069e-02]],\n", + " \n", + " [[-5.7206e-03, -1.9249e-02, -7.7542e-03],\n", + " [ 2.9950e-02, 2.9959e-02, 2.7115e-02],\n", + " [-3.3033e-02, -4.4068e-02, -1.1164e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[ 2.6510e-03, -1.7953e-02, -1.1683e-02],\n", + " [ 3.0550e-03, 2.4446e-02, 1.6834e-02],\n", + " [ 4.1480e-02, -4.5584e-02, -5.1714e-02]],\n", + " \n", + " [[-2.7675e-02, 5.1731e-02, 2.0568e-02],\n", + " [ 4.1835e-02, 1.6490e-03, -3.9423e-02],\n", + " [ 1.4456e-02, 4.6297e-02, -2.3125e-02]],\n", + " \n", + " [[-5.5972e-02, 7.2672e-03, -4.7699e-02],\n", + " [ 1.6370e-02, 5.0866e-02, 3.3071e-02],\n", + " [-3.0254e-02, -4.6505e-03, -1.3263e-02]]],\n", + " \n", + " \n", + " ...,\n", + " \n", + " \n", + " [[[-5.5671e-02, -1.4973e-02, 2.4598e-02],\n", + " [-1.5751e-02, -1.9332e-02, -1.1173e-02],\n", + " [-1.1446e-02, 3.4020e-02, -5.6328e-03]],\n", + " \n", + " [[ 1.4628e-02, 3.8786e-02, -1.1804e-02],\n", + " [ 1.1848e-02, 1.8123e-02, -1.0171e-02],\n", + " [ 6.0197e-02, 4.3300e-02, 5.8398e-02]],\n", + " \n", + " [[-6.3173e-02, -2.6803e-02, 1.3401e-03],\n", + " [ 3.0209e-02, 3.8472e-02, 3.5204e-02],\n", + " [-1.4885e-02, 3.1834e-02, -7.7356e-03]],\n", + " \n", + " ...,\n", + " \n", + " [[ 1.1209e-02, -5.0194e-03, 4.2764e-02],\n", + " [-2.4707e-02, 3.6524e-04, -5.5653e-03],\n", + " [ 1.7874e-02, 1.0252e-02, 6.7133e-02]],\n", + " \n", + " [[ 1.4919e-02, -2.0242e-03, 1.3058e-02],\n", + " [-3.0284e-03, 4.6720e-02, 5.9795e-02],\n", + " [ 2.8785e-02, -1.5592e-02, 1.6045e-02]],\n", + " \n", + " [[-2.5472e-02, -8.5856e-02, -3.7504e-02],\n", + " [-3.0099e-02, -2.3069e-02, 1.2823e-02],\n", + " [-4.1428e-02, 1.5843e-02, 1.4451e-02]]],\n", + " \n", + " \n", + " [[[ 7.7966e-03, 1.6178e-02, 2.6000e-02],\n", + " [-6.4233e-02, -1.7636e-02, 1.2902e-03],\n", + " [-4.1026e-04, 2.8500e-02, -3.6673e-02]],\n", + " \n", + " [[-3.1332e-02, -1.7827e-02, 4.2891e-02],\n", + " [-2.4205e-03, -2.4863e-02, 2.2896e-03],\n", + " [ 1.0987e-02, -3.0397e-02, -4.4000e-02]],\n", + " \n", + " [[ 5.1481e-03, 7.3660e-04, 3.3303e-03],\n", + " [ 3.6612e-04, -5.4450e-02, -8.3123e-03],\n", + " [ 2.0301e-03, -2.1761e-02, 3.3226e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[ 1.2331e-03, 8.0546e-03, -3.6006e-02],\n", + " [ 3.8699e-02, -2.6648e-02, -1.8826e-02],\n", + " [ 2.7367e-02, 1.4244e-02, 5.2926e-04]],\n", + " \n", + " [[-6.9364e-02, -1.0560e-02, 9.4717e-03],\n", + " [ 5.0586e-02, 1.5329e-03, 3.8197e-02],\n", + " [ 2.4806e-02, 7.5918e-02, -1.8269e-02]],\n", + " \n", + " [[ 3.2278e-02, -1.0672e-03, 6.8488e-03],\n", + " [-4.4616e-02, -3.5674e-02, 9.5346e-04],\n", + " [-1.3379e-02, 6.8442e-03, 9.0560e-03]]],\n", + " \n", + " \n", + " [[[ 2.0115e-02, -5.3358e-03, -3.5381e-02],\n", + " [ 7.1608e-04, -6.9627e-03, -1.9737e-02],\n", + " [-8.2062e-03, -3.7454e-02, -7.4117e-02]],\n", + " \n", + " [[ 1.4927e-02, 7.1709e-02, 1.1718e-02],\n", + " [ 8.2372e-02, 5.8646e-03, 3.4174e-03],\n", + " [ 1.0936e-03, 3.0345e-02, -1.7796e-02]],\n", + " \n", + " [[-5.0297e-02, 2.2410e-02, 5.7437e-03],\n", + " [ 3.7350e-02, 1.3494e-02, -2.9290e-04],\n", + " [-6.8438e-03, 3.8460e-05, -9.2413e-03]],\n", + " \n", + " ...,\n", + " \n", + " [[ 9.5568e-03, 2.5887e-03, 1.0262e-02],\n", + " [-3.2448e-03, -1.7702e-04, 1.8214e-02],\n", + " [-8.0327e-03, 6.6512e-04, -1.5375e-02]],\n", + " \n", + " [[ 3.9076e-02, 1.3856e-03, 1.0307e-02],\n", + " [ 2.3322e-02, 4.0026e-03, 3.5763e-02],\n", + " [ 1.3618e-02, 3.7627e-02, 1.0824e-02]],\n", + " \n", + " [[-1.6677e-03, -2.2723e-02, -5.5824e-02],\n", + " [-1.3569e-02, -1.2928e-02, -6.1438e-03],\n", + " [ 6.2071e-02, 1.5035e-03, -7.6897e-02]]]], device='cuda:0')),\n", + " ('module.features.7.0.conv1.weight',\n", + " tensor([[[[-3.4441e-03, 1.0994e-02, 1.4760e-03],\n", + " [ 1.9767e-02, -1.8906e-02, -2.6181e-02],\n", + " [ 1.4694e-02, 1.9673e-02, -1.7175e-03]],\n", + " \n", + " [[-3.7813e-03, -3.9056e-02, 2.6098e-02],\n", + " [-3.7201e-03, -4.2006e-03, -9.9549e-03],\n", + " [ 4.3169e-02, -1.2532e-02, 5.1302e-03]],\n", + " \n", + " [[-4.2043e-02, 1.4136e-02, -9.0424e-03],\n", + " [-2.1283e-03, -8.4711e-03, -2.1926e-03],\n", + " [-6.4922e-03, 7.2286e-03, 8.8581e-03]],\n", + " \n", + " ...,\n", + " \n", + " [[ 1.8439e-02, -1.1141e-03, 2.0971e-02],\n", + " [-4.0137e-02, 1.4446e-02, -1.3450e-03],\n", + " [ 1.5027e-02, -7.5354e-03, -3.5068e-04]],\n", + " \n", + " [[ 4.2114e-03, 1.1973e-02, -4.3907e-03],\n", + " [ 7.0268e-03, 1.9686e-02, -2.6050e-02],\n", + " [ 1.4946e-03, -7.8286e-03, 9.1873e-03]],\n", + " \n", + " [[ 3.9229e-02, -4.6599e-03, -2.6167e-02],\n", + " [-1.6986e-02, -4.5638e-02, 3.0573e-03],\n", + " [-1.8197e-02, -2.7810e-02, -8.9829e-03]]],\n", + " \n", + " \n", + " [[[-2.4756e-02, 1.4825e-02, 1.5259e-02],\n", + " [-1.4752e-02, 1.0032e-02, -4.7665e-02],\n", + " [ 1.0060e-02, 8.5348e-03, 3.7074e-02]],\n", + " \n", + " [[ 1.5102e-03, -2.1007e-02, 7.4820e-03],\n", + " [-1.0579e-03, -9.3815e-03, -6.6800e-04],\n", + " [-2.1778e-02, 2.0677e-02, -1.1825e-02]],\n", + " \n", + " [[-1.0584e-02, 3.6972e-03, -4.1202e-03],\n", + " [ 1.5519e-02, -4.3264e-03, 6.3549e-03],\n", + " [ 2.9458e-03, 1.1394e-02, -9.6818e-03]],\n", + " \n", + " ...,\n", + " \n", + " [[ 4.2069e-03, -8.4145e-03, -1.2281e-02],\n", + " [-1.1881e-02, 2.5911e-02, 3.5466e-02],\n", + " [ 2.7593e-02, -4.0577e-02, 1.0283e-02]],\n", + " \n", + " [[-3.7335e-02, 3.5848e-02, -2.9818e-02],\n", + " [ 1.6793e-03, 6.0743e-04, -1.1339e-02],\n", + " [-5.2570e-02, -7.8037e-03, -3.1148e-02]],\n", + " \n", + " [[-4.5868e-02, 1.4489e-02, 5.0106e-03],\n", + " [ 1.0891e-02, -1.3956e-02, -7.5098e-03],\n", + " [-1.5168e-02, -6.3514e-04, -7.5874e-03]]],\n", + " \n", + " \n", + " [[[ 1.2017e-02, -3.1987e-03, -1.2760e-02],\n", + " [ 9.7807e-03, -5.9038e-03, -3.6333e-02],\n", + " [ 1.5559e-02, -1.4835e-02, 1.2544e-02]],\n", + " \n", + " [[-1.5986e-03, -3.9330e-03, -1.9158e-02],\n", + " [ 1.5867e-02, -2.2501e-03, 1.3388e-02],\n", + " [ 1.9915e-02, 1.5831e-02, 8.9993e-03]],\n", + " \n", + " [[-2.7614e-02, 3.2708e-02, -2.1841e-02],\n", + " [-9.4685e-03, -2.4966e-02, 1.2511e-02],\n", + " [-1.9887e-03, -4.3608e-02, 3.9580e-03]],\n", + " \n", + " ...,\n", + " \n", + " [[ 9.3413e-03, -5.7094e-03, 4.3542e-03],\n", + " [-6.8274e-03, 7.4210e-03, 1.1561e-02],\n", + " [ 6.4388e-03, 1.9781e-02, 4.2751e-03]],\n", + " \n", + " [[ 5.0592e-03, 2.7730e-03, -2.4055e-02],\n", + " [ 2.4276e-03, 8.3152e-04, 3.7686e-02],\n", + " [ 2.7420e-02, -9.2692e-03, 2.3494e-02]],\n", + " \n", + " [[ 5.9059e-03, -4.4244e-02, -1.3735e-03],\n", + " [-2.5886e-02, -3.9441e-02, 6.0072e-03],\n", + " [-1.1696e-02, -4.1958e-03, 6.9575e-03]]],\n", + " \n", + " \n", + " ...,\n", + " \n", + " \n", + " [[[ 3.8198e-02, -1.2132e-02, -1.1683e-02],\n", + " [-2.4794e-03, 2.5955e-02, 1.8615e-02],\n", + " [-1.7031e-03, 4.1369e-02, -7.3895e-03]],\n", + " \n", + " [[ 1.4308e-03, 4.2879e-03, -1.3985e-02],\n", + " [ 1.5767e-02, 4.8289e-03, -3.0731e-02],\n", + " [ 1.2513e-02, 5.6250e-02, 4.5197e-04]],\n", + " \n", + " [[ 2.4120e-02, -5.4435e-03, 4.5873e-03],\n", + " [ 2.0246e-03, 2.8319e-02, 9.0150e-03],\n", + " [-1.1607e-02, 1.8807e-02, 2.4154e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[-3.2622e-03, 1.1339e-02, 4.2190e-02],\n", + " [ 2.8931e-02, 1.8660e-02, -3.6523e-02],\n", + " [-9.0465e-03, 3.1880e-02, 3.1114e-02]],\n", + " \n", + " [[-4.3627e-03, 2.1465e-02, 6.3580e-03],\n", + " [ 3.1705e-03, 1.1819e-02, 3.9138e-02],\n", + " [ 2.9341e-03, 1.3085e-02, 2.7232e-02]],\n", + " \n", + " [[-1.6039e-02, -2.7102e-02, -3.2196e-02],\n", + " [-1.0371e-02, 3.2571e-03, 4.9135e-03],\n", + " [-4.6609e-04, -2.5075e-03, 2.4381e-02]]],\n", + " \n", + " \n", + " [[[ 1.0021e-02, 5.6412e-03, -2.6135e-02],\n", + " [-2.0356e-02, -1.7683e-04, 3.3079e-03],\n", + " [-1.4637e-02, 6.8626e-02, -4.9217e-02]],\n", + " \n", + " [[ 1.9138e-03, -2.6581e-02, -1.5232e-03],\n", + " [-1.0672e-02, 5.2147e-03, 1.7318e-02],\n", + " [ 2.1448e-03, -1.5667e-02, 6.1177e-03]],\n", + " \n", + " [[-2.5615e-02, 2.5841e-02, -2.5348e-03],\n", + " [-3.3380e-03, 1.4092e-02, -1.3205e-02],\n", + " [-6.7538e-03, 3.6098e-03, -1.2386e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[-1.8762e-03, -1.0771e-02, -2.6408e-02],\n", + " [ 2.6240e-02, 2.8887e-02, 2.0198e-02],\n", + " [ 6.9128e-05, -1.2740e-02, -1.1935e-03]],\n", + " \n", + " [[-1.0645e-02, -2.0356e-02, 3.9404e-02],\n", + " [ 6.0474e-03, 1.1074e-02, 2.7843e-02],\n", + " [ 7.6534e-03, -2.0945e-02, 3.9055e-02]],\n", + " \n", + " [[-2.9758e-02, -9.6599e-03, 1.1365e-02],\n", + " [ 6.4649e-03, 1.7953e-02, -2.0368e-02],\n", + " [-9.9998e-03, -4.2007e-03, 8.8341e-03]]],\n", + " \n", + " \n", + " [[[ 3.6774e-02, 6.3172e-03, 2.2152e-02],\n", + " [-2.1888e-03, 1.4285e-02, 1.0440e-02],\n", + " [ 1.6240e-02, -2.2892e-02, -1.0568e-02]],\n", + " \n", + " [[-1.1015e-02, 2.5463e-02, 9.1917e-04],\n", + " [ 3.3877e-02, 2.0663e-03, 1.8499e-02],\n", + " [ 4.4987e-02, -5.6371e-02, -1.6414e-02]],\n", + " \n", + " [[ 3.9439e-04, -3.3985e-02, -1.5881e-02],\n", + " [ 5.9595e-03, 3.1137e-02, -1.7259e-03],\n", + " [-2.1612e-02, -1.1537e-02, 6.6468e-03]],\n", + " \n", + " ...,\n", + " \n", + " [[ 2.9433e-02, -1.1377e-02, 8.4307e-03],\n", + " [ 4.2296e-03, 1.1831e-02, 1.6850e-02],\n", + " [-1.2929e-02, -7.1755e-03, 1.8415e-02]],\n", + " \n", + " [[-1.6618e-02, 7.8405e-03, 4.0864e-02],\n", + " [ 4.8490e-02, 4.2489e-02, 6.0160e-03],\n", + " [ 3.7563e-02, 1.0776e-02, -6.1331e-03]],\n", + " \n", + " [[ 2.6469e-02, -6.6585e-04, -6.7010e-04],\n", + " [ 3.8410e-03, -1.4263e-02, 6.1104e-03],\n", + " [-1.0586e-02, 2.6340e-02, -3.0821e-02]]]], device='cuda:0')),\n", + " ('module.features.7.0.conv2.weight',\n", + " tensor([[[[-1.2200e-02, 2.4671e-02, -9.8238e-03],\n", + " [-1.2586e-02, -1.3236e-02, 2.5329e-03],\n", + " [-1.6542e-03, 1.4174e-03, 8.2470e-03]],\n", + " \n", + " [[-4.4317e-02, 1.0416e-02, 3.3988e-02],\n", + " [-3.3332e-02, 8.5336e-03, -4.7323e-02],\n", + " [ 5.7513e-03, -1.8317e-02, -2.3702e-02]],\n", + " \n", + " [[-2.7567e-02, 3.0532e-02, 3.4585e-02],\n", + " [-1.6679e-02, -2.7876e-02, -5.6227e-03],\n", + " [-6.7172e-03, -9.3788e-03, 3.3215e-03]],\n", + " \n", + " ...,\n", + " \n", + " [[-2.6134e-02, -3.7077e-03, -2.7437e-02],\n", + " [ 8.5012e-06, 2.5855e-02, 7.1096e-03],\n", + " [-5.5645e-03, 5.7466e-03, -2.4518e-02]],\n", + " \n", + " [[-4.4125e-02, 2.1510e-03, -3.1742e-04],\n", + " [ 1.0338e-02, -1.1641e-02, 5.6042e-02],\n", + " [-8.5308e-03, -2.9704e-02, 2.9754e-02]],\n", + " \n", + " [[ 1.4525e-02, 7.8045e-03, 1.7892e-02],\n", + " [ 8.9423e-03, -2.1311e-02, 1.8905e-02],\n", + " [ 1.4926e-03, 5.7177e-02, 4.0705e-02]]],\n", + " \n", + " \n", + " [[[ 1.4293e-02, 2.2090e-02, 3.3967e-03],\n", + " [-3.1053e-02, -4.6200e-03, 1.2056e-02],\n", + " [-7.9522e-03, -4.3661e-02, -4.2358e-03]],\n", + " \n", + " [[ 4.4703e-02, -2.4620e-03, 3.6090e-03],\n", + " [-5.4236e-04, -3.4849e-02, 6.2990e-03],\n", + " [ 1.1801e-02, -6.0881e-03, -6.4435e-03]],\n", + " \n", + " [[-1.1111e-02, -4.8533e-03, -4.5175e-02],\n", + " [-1.0993e-02, 2.3395e-02, -2.0765e-02],\n", + " [-1.9418e-02, -1.3892e-03, -1.1269e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[ 7.2273e-03, -1.7110e-02, 5.6988e-03],\n", + " [ 2.1792e-02, -9.0769e-03, -1.5590e-03],\n", + " [ 2.9187e-02, 2.6378e-02, -4.0534e-03]],\n", + " \n", + " [[ 8.8414e-03, 2.6818e-02, 1.2076e-04],\n", + " [-8.8425e-04, 1.2134e-02, -1.3035e-02],\n", + " [ 1.3764e-02, 4.9568e-02, 7.7859e-03]],\n", + " \n", + " [[-3.3327e-02, -2.3628e-02, 3.3143e-02],\n", + " [ 2.0608e-02, 6.5762e-03, 8.5704e-03],\n", + " [ 4.0431e-02, -6.4119e-03, -2.8803e-02]]],\n", + " \n", + " \n", + " [[[ 3.0359e-03, 2.2854e-02, 3.4083e-02],\n", + " [-2.1899e-02, -4.9271e-03, 2.5522e-02],\n", + " [-2.8607e-02, -1.9181e-02, 3.9501e-03]],\n", + " \n", + " [[ 1.4994e-02, -2.1828e-02, 6.4722e-03],\n", + " [ 1.9912e-02, 4.4057e-03, 1.0549e-03],\n", + " [-2.5813e-02, -2.5785e-02, 1.4741e-02]],\n", + " \n", + " [[-2.3405e-03, -5.0771e-03, -1.5741e-02],\n", + " [-7.1172e-03, -2.1527e-02, 1.0617e-02],\n", + " [ 2.0363e-02, -1.7201e-02, 5.4293e-04]],\n", + " \n", + " ...,\n", + " \n", + " [[-8.7831e-03, 4.9798e-03, 2.3290e-02],\n", + " [-2.3432e-02, 3.7738e-03, -1.0715e-03],\n", + " [ 2.0324e-02, -1.2834e-02, 2.3588e-02]],\n", + " \n", + " [[-9.9962e-03, -2.8735e-02, -5.6838e-03],\n", + " [ 2.4581e-02, -7.0371e-03, -1.3003e-02],\n", + " [-4.5345e-02, -1.1756e-02, -2.2176e-02]],\n", + " \n", + " [[-3.6717e-02, 1.0350e-02, 1.0031e-02],\n", + " [ 1.9931e-02, 1.5897e-02, 8.8945e-04],\n", + " [-3.0208e-02, 2.8018e-02, 3.4711e-03]]],\n", + " \n", + " \n", + " ...,\n", + " \n", + " \n", + " [[[ 6.2676e-04, -5.8382e-03, 4.8320e-04],\n", + " [ 1.6859e-02, 2.4176e-02, 9.4217e-03],\n", + " [-3.6895e-03, 5.9570e-03, 3.2383e-02]],\n", + " \n", + " [[-5.2046e-02, 3.5914e-02, -2.2961e-03],\n", + " [-3.3788e-02, 1.8930e-02, -1.9586e-02],\n", + " [-2.6040e-02, -1.1494e-02, -1.3791e-02]],\n", + " \n", + " [[-2.6943e-02, 2.2764e-02, -8.1082e-03],\n", + " [ 1.9537e-02, 2.3907e-03, 2.4450e-02],\n", + " [ 1.1038e-02, -1.3544e-03, 8.8689e-03]],\n", + " \n", + " ...,\n", + " \n", + " [[-2.9265e-02, 2.9907e-03, -1.9701e-02],\n", + " [ 7.6547e-03, 2.3468e-02, 3.0149e-03],\n", + " [-5.6129e-02, -5.5228e-02, -1.4498e-02]],\n", + " \n", + " [[-7.0285e-03, -1.4398e-02, -8.0260e-03],\n", + " [ 8.6078e-03, 1.7272e-02, 3.7922e-03],\n", + " [ 1.7967e-02, -5.8396e-03, 1.2388e-02]],\n", + " \n", + " [[-1.9122e-02, 8.0995e-03, 1.4047e-02],\n", + " [-2.0732e-02, -1.7510e-03, -7.0565e-03],\n", + " [-8.9632e-03, 2.5373e-02, -2.1325e-03]]],\n", + " \n", + " \n", + " [[[-7.2139e-03, 7.5723e-04, 1.5858e-02],\n", + " [ 4.4614e-03, 2.5813e-03, 7.7735e-03],\n", + " [-3.8430e-02, -1.2453e-02, -1.4860e-03]],\n", + " \n", + " [[ 1.5345e-02, -1.1651e-02, -4.7176e-02],\n", + " [-1.6323e-02, -3.9558e-02, -6.2256e-02],\n", + " [-2.7249e-02, -3.9055e-03, 3.3146e-02]],\n", + " \n", + " [[ 2.7576e-03, -1.1254e-02, 1.5793e-02],\n", + " [ 9.0961e-04, -1.9929e-02, -3.9376e-02],\n", + " [-4.5218e-02, 6.8002e-03, 2.4895e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[-3.2264e-02, -4.5119e-02, -2.3231e-02],\n", + " [ 4.5156e-02, 1.1634e-02, -1.9521e-02],\n", + " [ 8.7846e-03, 1.1269e-02, -2.7752e-02]],\n", + " \n", + " [[-2.6495e-02, 2.2894e-02, 1.4803e-02],\n", + " [ 4.4075e-03, 2.3500e-02, 3.7441e-03],\n", + " [ 9.3967e-03, 1.2344e-02, 3.9956e-02]],\n", + " \n", + " [[ 4.6042e-04, -9.5107e-04, 5.8999e-02],\n", + " [ 1.6742e-03, 8.3862e-03, -1.2621e-02],\n", + " [-9.1627e-03, 3.0100e-02, 1.9895e-02]]],\n", + " \n", + " \n", + " [[[ 2.0569e-02, -1.0795e-02, -8.7080e-03],\n", + " [ 2.4528e-02, 1.6483e-02, 9.9702e-04],\n", + " [ 3.6408e-03, -7.5199e-04, -5.9227e-02]],\n", + " \n", + " [[ 1.1502e-02, 2.2854e-02, -4.8116e-03],\n", + " [-1.7260e-02, -6.9348e-04, 8.0554e-03],\n", + " [ 8.5994e-03, -2.4679e-02, -3.9365e-02]],\n", + " \n", + " [[ 4.1262e-02, 1.6001e-02, 1.1125e-02],\n", + " [ 1.9232e-02, -4.0470e-02, -4.2124e-03],\n", + " [-3.1845e-02, -2.8374e-03, 1.2675e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[-2.6239e-02, -3.9554e-02, 1.6393e-02],\n", + " [-6.0580e-03, 2.7392e-02, 2.6700e-02],\n", + " [-9.8317e-03, 2.6180e-02, -1.0239e-02]],\n", + " \n", + " [[-2.6586e-02, 3.0612e-02, -1.3597e-03],\n", + " [ 4.8483e-02, -1.3060e-02, -1.8707e-02],\n", + " [-6.5954e-03, 1.6304e-02, -2.2056e-02]],\n", + " \n", + " [[-2.0903e-03, -2.5995e-02, 4.7070e-02],\n", + " [ 1.9098e-02, -1.4131e-02, 1.0743e-02],\n", + " [-7.3639e-03, 3.8980e-02, -1.8740e-02]]]], device='cuda:0')),\n", + " ('module.features.7.0.downsample.0.weight',\n", + " tensor([[[[ 0.0280]],\n", + " \n", + " [[ 0.1087]],\n", + " \n", + " [[ 0.0847]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.0122]],\n", + " \n", + " [[-0.0237]],\n", + " \n", + " [[ 0.0892]]],\n", + " \n", + " \n", + " [[[ 0.1502]],\n", + " \n", + " [[-0.0710]],\n", + " \n", + " [[-0.0160]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.0163]],\n", + " \n", + " [[ 0.0837]],\n", + " \n", + " [[-0.0358]]],\n", + " \n", + " \n", + " [[[ 0.0068]],\n", + " \n", + " [[ 0.0029]],\n", + " \n", + " [[ 0.0566]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.0893]],\n", + " \n", + " [[ 0.0697]],\n", + " \n", + " [[ 0.0071]]],\n", + " \n", + " \n", + " ...,\n", + " \n", + " \n", + " [[[-0.1321]],\n", + " \n", + " [[-0.0198]],\n", + " \n", + " [[-0.0812]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.0576]],\n", + " \n", + " [[ 0.0471]],\n", + " \n", + " [[-0.0246]]],\n", + " \n", + " \n", + " [[[ 0.1082]],\n", + " \n", + " [[ 0.0021]],\n", + " \n", + " [[ 0.0406]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.0061]],\n", + " \n", + " [[-0.0029]],\n", + " \n", + " [[ 0.0266]]],\n", + " \n", + " \n", + " [[[ 0.1127]],\n", + " \n", + " [[-0.0219]],\n", + " \n", + " [[-0.0168]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.0372]],\n", + " \n", + " [[-0.0544]],\n", + " \n", + " [[-0.0213]]]], device='cuda:0')),\n", + " ('module.features.7.1.conv1.weight',\n", + " tensor([[[[-0.0412, 0.0398, 0.0164],\n", + " [ 0.0117, 0.0329, 0.0236],\n", + " [ 0.0188, -0.0364, -0.0361]],\n", + " \n", + " [[-0.0310, -0.0105, -0.0407],\n", + " [-0.0084, -0.0072, 0.0195],\n", + " [ 0.0076, 0.0183, 0.0473]],\n", + " \n", + " [[ 0.0089, 0.0133, -0.0262],\n", + " [-0.0185, -0.0104, -0.0033],\n", + " [ 0.0133, 0.0122, -0.0572]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.0086, 0.0117, -0.0214],\n", + " [ 0.0113, -0.0197, 0.0212],\n", + " [-0.0350, 0.0330, -0.0123]],\n", + " \n", + " [[ 0.0009, 0.0155, 0.0001],\n", + " [-0.0108, 0.0034, -0.0038],\n", + " [ 0.0087, 0.0137, -0.0251]],\n", + " \n", + " [[-0.0321, -0.0180, 0.0369],\n", + " [ 0.0048, 0.0178, 0.0107],\n", + " [-0.0123, 0.0056, 0.0049]]],\n", + " \n", + " \n", + " [[[-0.0111, -0.0211, -0.0223],\n", + " [-0.0191, -0.0136, 0.0010],\n", + " [-0.0029, 0.0223, 0.0629]],\n", + " \n", + " [[-0.0196, 0.0049, -0.0111],\n", + " [ 0.0151, -0.0025, -0.0261],\n", + " [ 0.0095, -0.0071, 0.0136]],\n", + " \n", + " [[ 0.0367, 0.0324, -0.0228],\n", + " [-0.0251, 0.0110, -0.0151],\n", + " [ 0.0457, -0.0145, 0.0162]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.0022, -0.0046, -0.0142],\n", + " [ 0.0224, 0.0348, 0.0015],\n", + " [-0.0036, 0.0051, -0.0091]],\n", + " \n", + " [[ 0.0116, 0.0166, 0.0132],\n", + " [-0.0150, 0.0014, 0.0032],\n", + " [-0.0103, 0.0153, -0.0235]],\n", + " \n", + " [[-0.0142, -0.0050, -0.0244],\n", + " [ 0.0185, -0.0016, 0.0012],\n", + " [-0.0144, 0.0141, -0.0072]]],\n", + " \n", + " \n", + " [[[-0.0123, 0.0130, -0.0075],\n", + " [-0.0071, -0.0082, -0.0034],\n", + " [ 0.0143, 0.0157, -0.0249]],\n", + " \n", + " [[-0.0159, 0.0027, -0.0388],\n", + " [ 0.0021, 0.0024, -0.0288],\n", + " [-0.0225, -0.0067, 0.0192]],\n", + " \n", + " [[-0.0156, -0.0163, 0.0312],\n", + " [ 0.0416, 0.0107, -0.0375],\n", + " [ 0.0049, -0.0461, -0.0219]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.0265, -0.0160, -0.0259],\n", + " [-0.0189, -0.0011, -0.0063],\n", + " [ 0.0012, 0.0453, -0.0172]],\n", + " \n", + " [[ 0.0210, 0.0088, -0.0411],\n", + " [-0.0089, 0.0210, -0.0035],\n", + " [ 0.0183, -0.0061, 0.0316]],\n", + " \n", + " [[-0.0117, -0.0214, -0.0032],\n", + " [-0.0275, -0.0606, -0.0150],\n", + " [ 0.0040, -0.0216, -0.0044]]],\n", + " \n", + " \n", + " ...,\n", + " \n", + " \n", + " [[[-0.0001, 0.0213, -0.0171],\n", + " [ 0.0011, -0.0231, 0.0122],\n", + " [-0.0306, 0.0143, 0.0080]],\n", + " \n", + " [[-0.0027, 0.0032, -0.0124],\n", + " [ 0.0044, -0.0284, 0.0182],\n", + " [ 0.0099, 0.0445, 0.0078]],\n", + " \n", + " [[-0.0175, 0.0329, -0.0161],\n", + " [-0.0447, -0.0261, -0.0158],\n", + " [-0.0196, 0.0309, -0.0215]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.0195, 0.0048, -0.0415],\n", + " [-0.0131, -0.0159, -0.0024],\n", + " [ 0.0295, 0.0165, -0.0192]],\n", + " \n", + " [[ 0.0355, 0.0102, 0.0040],\n", + " [ 0.0007, 0.0181, -0.0201],\n", + " [-0.0097, 0.0173, -0.0389]],\n", + " \n", + " [[ 0.0024, -0.0168, -0.0181],\n", + " [ 0.0059, 0.0004, -0.0522],\n", + " [ 0.0194, 0.0180, -0.0293]]],\n", + " \n", + " \n", + " [[[ 0.0062, 0.0105, -0.0109],\n", + " [-0.0213, 0.0117, 0.0115],\n", + " [-0.0232, -0.0389, 0.0185]],\n", + " \n", + " [[-0.0121, -0.0008, -0.0179],\n", + " [-0.0042, -0.0131, 0.0363],\n", + " [-0.0141, 0.0162, 0.0122]],\n", + " \n", + " [[-0.0172, 0.0188, -0.0150],\n", + " [-0.0093, 0.0270, 0.0506],\n", + " [-0.0337, -0.0070, -0.0267]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.0118, -0.0035, 0.0332],\n", + " [-0.0185, -0.0073, -0.0262],\n", + " [-0.0110, -0.0086, 0.0152]],\n", + " \n", + " [[ 0.0105, -0.0218, -0.0302],\n", + " [-0.0114, 0.0171, -0.0384],\n", + " [ 0.0255, -0.0277, 0.0629]],\n", + " \n", + " [[ 0.0015, 0.0425, -0.0147],\n", + " [-0.0056, -0.0232, -0.0242],\n", + " [ 0.0015, 0.0060, -0.0036]]],\n", + " \n", + " \n", + " [[[ 0.0189, 0.0194, 0.0064],\n", + " [-0.0075, -0.0141, -0.0050],\n", + " [ 0.0002, -0.0243, -0.0248]],\n", + " \n", + " [[ 0.0173, 0.0066, -0.0278],\n", + " [-0.0158, -0.0061, 0.0161],\n", + " [-0.0176, -0.0237, -0.0293]],\n", + " \n", + " [[ 0.0067, -0.0371, 0.0001],\n", + " [-0.0122, 0.0012, -0.0346],\n", + " [-0.0239, -0.0195, 0.0066]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.0004, -0.0357, -0.0282],\n", + " [-0.0071, -0.0012, -0.0346],\n", + " [-0.0103, -0.0152, -0.0183]],\n", + " \n", + " [[-0.0087, 0.0358, 0.0211],\n", + " [ 0.0090, -0.0186, 0.0573],\n", + " [-0.0072, 0.0191, -0.0075]],\n", + " \n", + " [[-0.0091, 0.0155, 0.0092],\n", + " [ 0.0244, 0.0233, 0.0293],\n", + " [ 0.0218, -0.0016, -0.0111]]]], device='cuda:0')),\n", + " ('module.features.7.1.conv2.weight',\n", + " tensor([[[[ 2.7669e-02, 1.0102e-02, -1.5389e-02],\n", + " [-7.7896e-03, 1.7454e-02, -5.2838e-03],\n", + " [ 3.2319e-02, -9.5478e-03, 2.1955e-02]],\n", + " \n", + " [[ 1.3091e-02, 1.2744e-02, -6.0428e-03],\n", + " [ 7.9706e-04, 4.0532e-03, -2.7187e-03],\n", + " [ 5.6271e-03, -2.0450e-02, -8.3630e-04]],\n", + " \n", + " [[ 2.6977e-02, -1.3292e-02, 9.1527e-03],\n", + " [-2.3563e-02, -2.7924e-02, -1.6096e-02],\n", + " [ 5.5006e-04, 1.0982e-02, 1.1167e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[ 3.2128e-03, -1.1398e-02, -1.7091e-02],\n", + " [-5.3630e-02, -1.8246e-02, -7.6490e-03],\n", + " [-1.8577e-02, 9.7755e-03, -2.1825e-02]],\n", + " \n", + " [[ 2.5462e-02, -4.8757e-03, 1.7404e-03],\n", + " [ 6.9207e-03, 2.9026e-02, 1.9597e-02],\n", + " [-2.6660e-03, -1.3143e-02, 1.3166e-02]],\n", + " \n", + " [[-9.2363e-03, 1.4657e-02, -1.7753e-02],\n", + " [-1.5085e-02, -8.5969e-03, 1.7411e-02],\n", + " [ 3.1844e-02, -1.2665e-02, 3.8360e-02]]],\n", + " \n", + " \n", + " [[[ 1.6494e-02, 4.0510e-02, -1.1034e-02],\n", + " [ 1.0801e-02, -4.6805e-02, 3.1894e-02],\n", + " [ 3.1285e-02, -2.7874e-02, 3.1512e-03]],\n", + " \n", + " [[ 1.3249e-02, -2.4612e-02, 8.0424e-03],\n", + " [ 1.4087e-02, 8.3113e-03, -1.6492e-02],\n", + " [-1.6472e-02, 1.4578e-02, -1.0045e-02]],\n", + " \n", + " [[-2.6437e-02, -1.5198e-02, 8.9365e-03],\n", + " [ 4.9480e-03, 3.0334e-02, 7.5042e-03],\n", + " [ 1.6577e-02, -5.5669e-03, -8.8701e-03]],\n", + " \n", + " ...,\n", + " \n", + " [[-3.6576e-02, -2.9337e-02, 3.2167e-03],\n", + " [-8.1815e-04, -1.7954e-02, -1.2874e-02],\n", + " [ 8.4388e-05, -2.0595e-02, -4.4706e-02]],\n", + " \n", + " [[-1.6041e-02, 1.0156e-02, -9.1008e-03],\n", + " [ 4.5301e-02, 4.7988e-04, 8.6943e-03],\n", + " [-1.1701e-02, 4.7244e-02, 1.7367e-02]],\n", + " \n", + " [[ 1.3813e-02, -4.0753e-02, -2.4185e-03],\n", + " [-2.4735e-02, -2.6664e-02, 9.9198e-04],\n", + " [ 1.4143e-02, -9.8058e-03, 2.7278e-03]]],\n", + " \n", + " \n", + " [[[-3.9752e-03, 2.7291e-03, 2.9844e-02],\n", + " [-8.5228e-03, -6.5568e-03, -1.9087e-02],\n", + " [ 5.2401e-03, 5.0350e-04, 1.7048e-02]],\n", + " \n", + " [[-9.2543e-03, 1.8186e-02, -1.7019e-02],\n", + " [-3.3054e-02, 2.0240e-03, 5.5067e-03],\n", + " [-4.1597e-03, 9.8459e-03, -8.0485e-03]],\n", + " \n", + " [[-3.1671e-02, -3.4744e-03, -2.2173e-02],\n", + " [ 1.9146e-02, 5.9635e-03, 2.6835e-02],\n", + " [-2.7219e-02, -1.4855e-02, 1.3716e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[-1.3893e-03, -2.3909e-03, 3.7731e-03],\n", + " [ 2.5951e-02, 2.2350e-02, 3.3649e-02],\n", + " [ 8.8714e-03, -3.2832e-03, 3.0970e-02]],\n", + " \n", + " [[ 1.6667e-02, -7.0636e-03, -8.0601e-05],\n", + " [-3.8572e-03, -2.2319e-02, -2.1308e-02],\n", + " [-2.4768e-02, -1.5440e-02, -1.5375e-02]],\n", + " \n", + " [[ 4.0461e-04, -1.9128e-02, 2.8521e-02],\n", + " [ 2.7569e-03, 1.2434e-02, 1.1612e-02],\n", + " [ 4.9260e-02, 2.2169e-02, -2.4459e-02]]],\n", + " \n", + " \n", + " ...,\n", + " \n", + " \n", + " [[[-2.1874e-02, 1.4048e-03, 2.3160e-02],\n", + " [-2.2029e-02, 1.3565e-02, -1.1423e-02],\n", + " [ 1.0922e-02, 1.6705e-02, 4.6132e-03]],\n", + " \n", + " [[ 1.5474e-03, 4.6071e-02, 9.5328e-03],\n", + " [ 4.8634e-03, 3.2872e-02, -4.5178e-04],\n", + " [-2.0327e-02, 5.2714e-03, 2.2405e-02]],\n", + " \n", + " [[-1.3570e-02, -5.1997e-02, -2.5401e-02],\n", + " [ 2.3371e-02, -2.0209e-02, -1.3148e-02],\n", + " [ 1.3217e-02, -9.4160e-03, -1.7271e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[-9.3603e-03, -1.0252e-02, -2.1926e-03],\n", + " [ 8.3288e-03, 1.4618e-02, 2.2519e-02],\n", + " [ 4.3726e-03, 1.5235e-02, 2.4148e-02]],\n", + " \n", + " [[ 2.1585e-02, -2.3678e-03, 1.3339e-03],\n", + " [-1.2002e-02, -1.8972e-02, -8.1394e-03],\n", + " [ 4.2480e-03, -1.7132e-02, 2.2052e-02]],\n", + " \n", + " [[ 1.0347e-02, 2.2348e-02, -9.3872e-03],\n", + " [-1.3332e-02, 1.9236e-02, -5.8983e-03],\n", + " [-2.6036e-02, -1.4431e-02, -4.9069e-02]]],\n", + " \n", + " \n", + " [[[ 3.4349e-03, 3.0972e-02, -2.8677e-02],\n", + " [ 7.1175e-03, 2.1691e-02, 4.0088e-03],\n", + " [ 2.0570e-02, 4.5587e-03, 5.5360e-03]],\n", + " \n", + " [[-2.6303e-02, -6.6955e-03, -3.8755e-02],\n", + " [-1.7761e-02, -9.3154e-03, 1.6708e-02],\n", + " [-2.8337e-02, 1.5408e-02, -1.2784e-02]],\n", + " \n", + " [[-1.0890e-02, 1.0057e-03, 4.9444e-03],\n", + " [ 1.5475e-02, 2.7684e-02, 3.2346e-03],\n", + " [-1.2827e-02, -8.9223e-03, 1.1571e-02]],\n", + " \n", + " ...,\n", + " \n", + " [[ 2.9467e-02, -1.7513e-02, 2.1790e-02],\n", + " [ 9.5623e-03, 1.5912e-03, -2.8656e-02],\n", + " [-2.0841e-02, 3.4052e-02, 1.4477e-02]],\n", + " \n", + " [[ 2.5121e-02, 1.3353e-02, 1.7004e-02],\n", + " [-2.2956e-02, -7.9317e-03, -7.7189e-03],\n", + " [-2.2805e-02, 5.5981e-03, -9.2481e-03]],\n", + " \n", + " [[-7.4896e-03, 1.5322e-03, -1.4587e-02],\n", + " [ 4.1217e-02, 1.2760e-02, -3.7466e-02],\n", + " [ 4.6187e-02, 3.8541e-03, -3.0564e-02]]],\n", + " \n", + " \n", + " [[[ 2.1435e-02, 4.5467e-02, 4.5000e-03],\n", + " [ 7.6490e-03, -2.5208e-03, 2.3527e-02],\n", + " [ 2.5580e-02, 1.6698e-02, -1.8860e-02]],\n", + " \n", + " [[ 1.4487e-02, -3.4432e-03, -7.5392e-03],\n", + " [ 1.8938e-02, 4.7559e-02, -2.5380e-02],\n", + " [ 2.1595e-02, -8.7880e-03, 2.2249e-03]],\n", + " \n", + " [[ 1.0577e-02, 2.2048e-02, -1.5811e-03],\n", + " [-3.6731e-02, -4.3263e-03, 9.8155e-03],\n", + " [ 5.3485e-03, 3.3231e-02, 4.9549e-03]],\n", + " \n", + " ...,\n", + " \n", + " [[-1.4723e-02, -1.8430e-02, 2.0312e-02],\n", + " [-4.0386e-02, 1.3468e-02, -9.7934e-03],\n", + " [-7.6757e-03, 5.5895e-02, -6.2044e-03]],\n", + " \n", + " [[-4.6726e-03, -1.3059e-02, -6.3235e-03],\n", + " [ 2.0474e-02, 9.5724e-03, -1.0297e-02],\n", + " [-2.9894e-02, -2.8968e-02, 1.5460e-02]],\n", + " \n", + " [[ 3.6078e-03, 4.0985e-03, -1.9945e-02],\n", + " [ 8.5199e-03, -1.5623e-02, 1.6469e-02],\n", + " [-4.2940e-03, 3.3834e-03, -2.3444e-02]]]], device='cuda:0')),\n", + " ('module.l1.weight',\n", + " tensor([[-2.8849e-02, -1.9119e-02, 3.8300e-02, ..., 4.1686e-02,\n", + " 3.6721e-02, 1.5111e-03],\n", + " [-2.2924e-03, -7.3991e-05, -4.6714e-04, ..., 2.8410e-04,\n", + " -7.1595e-05, 1.2524e-03],\n", + " [-4.3433e-02, -1.6126e-02, 2.4300e-02, ..., 3.1861e-02,\n", + " -4.4353e-03, 3.3997e-02],\n", + " ...,\n", + " [-4.2116e-02, 9.1577e-03, -2.8979e-03, ..., 8.2516e-03,\n", + " -2.0367e-02, -2.6846e-02],\n", + " [-3.9404e-02, 3.1663e-03, -3.1503e-02, ..., -5.0674e-03,\n", + " -1.2334e-02, 2.4472e-02],\n", + " [ 4.0532e-02, 2.3001e-02, -4.4442e-02, ..., 4.0665e-02,\n", + " -1.5093e-03, -3.1281e-02]], device='cuda:0')),\n", + " ('module.l1.bias',\n", + " tensor([ 3.0996e-02, -1.0540e-04, -1.4748e-02, 2.5318e-04, 4.0875e-03,\n", + " -2.7277e-02, 2.8669e-02, -2.7196e-15, -1.9390e-02, -6.0616e-03,\n", + " 2.7244e-02, -7.6942e-16, 7.5572e-03, -4.1416e-02, 3.1923e-02,\n", + " -4.3156e-02, -7.3542e-23, -1.2017e-24, -3.8204e-02, -2.5722e-02,\n", + " -3.6925e-15, -3.7085e-02, -5.6505e-15, -1.9316e-02, -5.5478e-03,\n", + " 4.0829e-02, -7.1897e-04, 4.1314e-02, -1.0092e-02, -5.3813e-04,\n", + " 2.4123e-02, -1.0446e-02, -3.5741e-18, -2.5170e-04, -3.2334e-02,\n", + " -7.4074e-16, -2.1480e-04, -9.5165e-03, 2.5351e-02, -8.0323e-03,\n", + " 2.2315e-02, -1.6049e-03, -9.0869e-03, -3.8037e-02, 3.1691e-02,\n", + " 3.4157e-02, -1.7944e-03, -1.9633e-02, 3.0051e-02, -1.3012e-02,\n", + " -3.7214e-02, 2.4073e-02, -9.3066e-03, 3.7976e-03, -2.2379e-03,\n", + " -3.3308e-39, -1.3419e-04, -2.2215e-02, 6.3669e-03, 7.4496e-03,\n", + " -2.0330e-03, -2.2457e-02, -2.8783e-04, -4.1772e-02, 3.2072e-03,\n", + " 1.4834e-02, 1.6824e-02, -2.8914e-24, 6.9507e-03, 4.3442e-02,\n", + " 3.6441e-02, 2.9667e-02, 3.8556e-02, -6.1635e-04, -4.2663e-02,\n", + " 1.3519e-02, 1.9096e-02, 1.7169e-02, 8.0511e-03, -9.5487e-03,\n", + " -3.1515e-02, 5.2946e-03, 8.7159e-03, 3.3859e-02, 2.5782e-02,\n", + " 3.0781e-02, -2.5431e-21, 2.2940e-03, -3.4296e-02, -3.2956e-04,\n", + " -3.8724e-02, -9.5437e-03, 6.4339e-03, 1.9079e-03, -4.8774e-40,\n", + " -1.2229e-02, -3.6624e-02, 1.6433e-03, 3.0857e-02, 5.2444e-03,\n", + " 2.5248e-02, -2.3934e-04, -5.5150e-22, 2.6404e-02, 1.1247e-02,\n", + " 2.0062e-02, -1.3908e-02, 2.9855e-02, -3.4427e-02, -1.1697e-02,\n", + " 2.0255e-02, -3.1341e-12, -4.0415e-03, -3.6996e-02, -3.4850e-03,\n", + " 5.8550e-03, 2.1957e-02, -8.1380e-30, 1.9003e-02, -3.1575e-02,\n", + " -2.5251e-02, 8.6628e-04, -3.8475e-04, -2.3658e-02, -3.6147e-40,\n", + " 3.5547e-02, 3.0704e-03, -2.3029e-22, -2.0700e-02, 1.6150e-02,\n", + " -5.2000e-04, 1.8059e-02, 1.1933e-03, 2.8345e-02, -3.0557e-03,\n", + " -1.0855e-13, -3.1424e-02, 3.4007e-02, 3.1265e-02, 2.9189e-02,\n", + " -1.8984e-02, 1.7491e-02, -2.1464e-02, -2.1183e-03, 9.4477e-03,\n", + " 1.9714e-02, 3.4079e-03, 4.1416e-02, -4.2228e-02, -5.4038e-03,\n", + " 6.3732e-03, 4.0005e-02, 9.1947e-03, 3.6073e-02, 1.7447e-02,\n", + " 8.5820e-04, 4.7003e-03, 1.5726e-02, 8.5460e-03, 1.0651e-02,\n", + " 1.8213e-02, 2.6473e-02, 2.8870e-02, 4.2648e-02, -3.7402e-02,\n", + " 2.4007e-02, 1.4670e-04, 1.2810e-02, 1.7543e-02, 1.9343e-02,\n", + " 2.0540e-02, 3.8007e-02, 4.3428e-03, 2.8722e-02, -5.9784e-04,\n", + " -3.5587e-16, -6.5072e-31, 1.9575e-02, -2.1262e-02, 3.3649e-02,\n", + " -3.7233e-02, 1.5873e-02, 1.8498e-02, 2.4411e-02, 2.9294e-02,\n", + " 1.4994e-02, -4.3407e-02, 1.4391e-02, -3.2271e-18, 1.4919e-02,\n", + " 8.2647e-03, -5.9837e-03, 1.9369e-02, 4.9382e-03, 8.6709e-04,\n", + " 3.7271e-02, -1.5391e-04, -1.3788e-04, 3.3446e-02, -7.8908e-04,\n", + " -2.7301e-02, 2.3324e-02, -1.2863e-04, -1.4473e-20, -2.2201e-02,\n", + " -6.5220e-03, 2.2164e-02, 2.2754e-02, -6.6216e-03, 2.3274e-02,\n", + " -4.5658e-04, 4.2457e-02, 2.4065e-02, -1.7021e-02, -6.6629e-03,\n", + " -2.7511e-15, 2.0085e-02, 2.0448e-02, -1.1625e-13, -3.0990e-02,\n", + " 1.5328e-02, -9.4063e-13, 3.2097e-02, 2.0661e-02, -1.9480e-02,\n", + " -4.0583e-03, -8.1608e-05, 2.4092e-02, -1.0616e-02, -1.0273e-02,\n", + " -4.5386e-24, 2.0436e-02, 1.3322e-02, -1.9890e-02, -4.9095e-22,\n", + " -6.6272e-03, -9.4388e-03, -2.7694e-03, 1.4548e-02, 1.6013e-02,\n", + " 2.4510e-02, 3.4344e-02, 3.7955e-02, -2.3670e-03, -5.5752e-04,\n", + " 4.0415e-02, -2.0931e-02, 4.2837e-02, 9.9007e-03, 2.6497e-02,\n", + " -2.4462e-02, -2.5526e-24, -4.6201e-04, 2.1199e-02, 1.5242e-02,\n", + " -3.6141e-02, -2.2361e-04, -1.5856e-40, 9.2270e-03, 3.1342e-02,\n", + " 1.3890e-02, -2.0049e-03, -1.7684e-03, 8.6534e-03, 1.0998e-02,\n", + " 2.7369e-02, -2.3437e-02, -2.9719e-06, -1.1690e-17, -1.8020e-06,\n", + " -4.4680e-02, 4.0719e-02, 1.7397e-02, 2.3354e-02, -3.3420e-02,\n", + " 1.0529e-02, 1.4536e-02, -9.0702e-04, 6.8670e-03, -1.8068e-02,\n", + " 6.0072e-03, 2.9915e-02, 1.5068e-02, 1.9626e-02, 1.1432e-02,\n", + " 2.7300e-02, 1.0894e-02, 3.3583e-02, 3.8062e-02, -4.0905e-02,\n", + " 2.9605e-02, 9.2011e-03, -1.2439e-02, -3.1910e-02, 2.0235e-02,\n", + " -2.3785e-02, 1.9151e-02, -2.7005e-12, 1.7179e-02, 1.5289e-03,\n", + " 1.0438e-02, -1.0442e-05, 1.4552e-02, 1.6959e-04, -2.5779e-02,\n", + " -7.5984e-03, -1.0531e-02, 1.1304e-02, 1.5698e-02, 4.2287e-02,\n", + " -1.8908e-04, 2.2671e-02, 2.6370e-02, 2.6532e-02, -3.1829e-02,\n", + " 1.7723e-02, 1.1734e-02, 4.3042e-02, -8.5903e-03, -1.0594e-02,\n", + " 3.6971e-02, -1.4430e-14, 2.7222e-02, 3.6229e-02, 3.7119e-02,\n", + " 1.8926e-02, 1.0625e-03, 9.2381e-04, -8.2238e-03, -6.4543e-14,\n", + " -1.7204e-02, -1.7061e-02, -2.6419e-02, 2.8717e-02, 4.7878e-03,\n", + " 1.5822e-02, -2.7476e-02, -3.3203e-02, -2.4455e-02, -6.5348e-03,\n", + " 2.9319e-02, -5.7435e-03, 2.5551e-02, 7.4299e-04, -7.3322e-23,\n", + " 2.6626e-02, -3.7447e-04, 5.5226e-03, -3.3814e-02, -4.4083e-02,\n", + " -1.4452e-38, 3.1598e-02, -2.7289e-02, 3.3592e-02, 3.0028e-02,\n", + " -4.3492e-03, -5.6239e-04, 4.2002e-02, 4.2847e-02, 2.1402e-02,\n", + " 3.6151e-02, 2.9313e-02, 3.4036e-02, 1.0820e-02, 8.1017e-03,\n", + " -3.2201e-02, -1.3255e-02, -2.0437e-19, -4.4835e-02, -1.6591e-15,\n", + " 7.3390e-03, -4.1098e-02, 1.0173e-03, -8.9903e-04, -1.2006e-02,\n", + " -1.7443e-02, 2.8147e-02, 2.3075e-02, 4.2859e-03, 1.2978e-02,\n", + " -1.5794e-02, 3.9542e-02, -1.5094e-06, -1.2039e-02, -4.7866e-04,\n", + " -7.7794e-04, -1.9704e-02, 1.4743e-02, 1.8525e-02, -9.4513e-03,\n", + " 2.2987e-02, 1.7009e-03, -1.9632e-02, -1.8319e-18, -7.6008e-04,\n", + " -1.8262e-02, 1.7942e-02, -3.4788e-02, 9.3117e-03, -1.2839e-03,\n", + " -1.9524e-13, -7.4299e-03, -2.1009e-12, 1.9936e-02, -2.3543e-07,\n", + " -3.7481e-12, -3.2467e-02, -2.0639e-04, -2.5341e-02, -2.8937e-39,\n", + " 1.9073e-02, -2.2175e-02, -7.8619e-13, 3.0309e-02, -1.4175e-05,\n", + " 3.3793e-02, -2.2715e-02, 2.4584e-02, 1.6416e-02, -4.0398e-02,\n", + " -2.8118e-03, 5.8474e-40, -1.3411e-04, -3.5482e-02, -2.5144e-02,\n", + " -2.7105e-18, 1.3947e-02, -2.8604e-18, -8.5239e-05, 2.1395e-02,\n", + " 8.6881e-03, -1.2376e-03, -6.1344e-04, 2.9223e-02, 3.6894e-02,\n", + " 4.1455e-02, -4.3520e-02, -7.9448e-17, -3.3451e-03, 3.6124e-02,\n", + " -1.7476e-02, 2.6451e-02, -3.3266e-03, -3.2752e-02, -1.3644e-02,\n", + " -3.3310e-03, -1.6137e-02, -2.0225e-17, 2.7003e-02, 2.9479e-02,\n", + " -2.6810e-02, -2.4228e-02, 8.1781e-03, 3.0183e-02, -1.3654e-02,\n", + " -2.8101e-02, 3.2361e-02, 2.3272e-02, 3.7588e-02, -1.9659e-37,\n", + " -7.2061e-19, 3.4653e-02, 1.1160e-02, -4.4758e-04, 1.0306e-02,\n", + " 6.9661e-03, -1.3581e-02, 4.0198e-03, -3.7042e-02, -2.3607e-07,\n", + " -1.2145e-37, 1.8899e-02, 4.1548e-02, 4.4042e-02, 2.8042e-02,\n", + " -1.1759e-02, 2.5871e-02, -2.8453e-21, 2.2797e-02, 8.1754e-04,\n", + " 9.6130e-03, 4.2659e-03, -3.3586e-02, 1.9345e-03, -7.2296e-03,\n", + " -5.0466e-13, -5.1403e-04, 2.8400e-02, -1.2116e-02, -4.2314e-02,\n", + " -4.1556e-02, 5.7938e-05, -1.3978e-03, -1.7389e-05, 3.2640e-02,\n", + " 2.0089e-02, 1.3020e-02, 8.1609e-03, -1.7976e-22, -8.7996e-03,\n", + " 1.0226e-02, 3.5521e-03, 4.3961e-02, -5.3368e-03, 3.8782e-02,\n", + " -4.5596e-04, -2.2731e-03, -2.3455e-02, 3.5500e-02, 3.9022e-02,\n", + " -2.9015e-02, 2.9233e-02], device='cuda:0')),\n", + " ('module.l2.weight',\n", + " tensor([[-1.9599e-02, 4.3520e-14, 3.3635e-02, ..., -5.6260e-03,\n", + " -1.0022e-02, 2.9938e-02],\n", + " [-3.7372e-02, -3.6296e-06, -4.1800e-02, ..., -2.4365e-03,\n", + " 1.1096e-02, -2.5845e-03],\n", + " [-1.3492e-02, 2.6660e-06, 4.1717e-02, ..., 3.4630e-02,\n", + " 3.4364e-02, 3.4099e-02],\n", + " ...,\n", + " [ 3.8297e-02, 1.7021e-07, 7.2816e-03, ..., -2.4668e-02,\n", + " -2.1351e-02, 2.6505e-02],\n", + " [ 3.9132e-02, -7.4625e-09, -3.8427e-02, ..., 2.9803e-02,\n", + " -4.0681e-02, -1.4239e-02],\n", + " [ 1.4177e-02, 5.8990e-06, 2.3557e-02, ..., -2.7589e-02,\n", + " 1.6648e-02, 4.2929e-03]], device='cuda:0')),\n", + " ('module.l2.bias',\n", + " tensor([ 2.9559e-02, -2.3026e-02, -3.2657e-02, -6.7057e-03, -3.2271e-02,\n", + " 3.4164e-02, 3.5729e-02, 2.8985e-02, -2.8901e-02, -4.7665e-03,\n", + " 2.8278e-02, 7.5709e-03, -2.6814e-02, 3.5719e-02, -5.4381e-03,\n", + " -3.8308e-02, 2.4235e-02, -1.6882e-03, 1.4637e-02, 1.0824e-03,\n", + " 2.4541e-02, 1.3036e-02, 2.0802e-02, -2.2542e-02, -2.9280e-02,\n", + " 1.1447e-02, -2.1490e-02, -3.9629e-02, 1.2565e-02, -2.7861e-03,\n", + " -4.3228e-02, -5.9347e-03, -5.7878e-03, -2.1625e-02, -2.8541e-02,\n", + " -1.0963e-02, -3.7994e-02, 1.5290e-02, -3.9319e-02, -2.2262e-03,\n", + " -1.2133e-02, 2.7171e-03, -3.4644e-02, -3.4570e-02, -1.1267e-02,\n", + " 3.2676e-02, 1.0274e-02, -2.3055e-02, -2.1700e-02, -4.1650e-02,\n", + " 1.4503e-02, 5.6783e-03, 5.8394e-03, -2.7246e-02, 1.2473e-02,\n", + " -3.5907e-02, 1.3339e-02, 3.1142e-02, -2.4338e-02, -1.5809e-02,\n", + " 1.3111e-02, 4.3001e-02, -7.0815e-04, 3.5192e-02, -2.2486e-02,\n", + " -3.7531e-03, 3.5100e-02, 1.3077e-02, -4.2701e-02, 1.2880e-02,\n", + " 1.6414e-03, -1.6710e-02, -3.9001e-02, -3.1517e-02, 9.5927e-03,\n", + " -1.2958e-02, 3.0725e-02, -3.7749e-02, 2.9525e-02, -2.4572e-02,\n", + " -1.8708e-02, -4.0700e-02, 2.5859e-02, -2.1719e-02, 2.3005e-02,\n", + " 3.6742e-02, -3.3517e-02, 3.1454e-02, -1.9672e-02, -7.4635e-03,\n", + " 4.1330e-02, 3.0112e-05, -4.4783e-02, -3.4733e-02, 1.2247e-02,\n", + " -2.9102e-02, 2.9429e-02, 3.7014e-02, 1.0661e-02, -3.4314e-02,\n", + " -3.4389e-02, -2.8635e-02, -3.3195e-02, 1.9766e-04, 6.7751e-03,\n", + " -4.0058e-02, 2.0636e-02, -3.0467e-02, -1.2753e-02, 3.1120e-03,\n", + " 2.5857e-02, 3.7302e-02, 7.2018e-04, 1.0489e-02, -2.2780e-02,\n", + " -2.2742e-02, -3.0345e-02, 4.4343e-02, -1.5053e-02, 6.2518e-03,\n", + " -2.8297e-02, 2.6441e-02, 1.0819e-02, 2.1973e-02, 3.7255e-02,\n", + " -2.0747e-03, 1.2554e-02, -3.6957e-02, -2.7261e-02, -3.3418e-02,\n", + " -3.4205e-02, 2.3845e-02, 3.0916e-02, -2.5197e-02, -1.6039e-02,\n", + " -2.7497e-02, -1.0755e-02, -1.5479e-02, -1.9881e-02, -9.5323e-04,\n", + " -7.9155e-03, -3.4123e-02, 1.5674e-02, -1.6659e-02, -1.4317e-02,\n", + " 1.5828e-02, 2.3558e-02, 3.0177e-02, 3.2466e-02, 4.0365e-02,\n", + " -2.6484e-02, -2.2293e-03, -3.8066e-02, -3.3669e-02, -3.0378e-02,\n", + " 6.4637e-03, 1.6316e-02, 2.5219e-02, 5.6825e-03, -3.4366e-02,\n", + " 1.3237e-02, 2.9095e-02, -3.5179e-02, 3.7836e-04, 1.5759e-03,\n", + " -2.0528e-02, 8.4841e-03, -1.6437e-02, 3.4154e-02, 2.6474e-02,\n", + " -1.2459e-02, 3.8338e-02, 4.1170e-02, -3.9659e-02, -1.1147e-02,\n", + " 1.8834e-02, 2.1028e-02, 1.8429e-02, -3.9750e-02, -2.1570e-02,\n", + " -3.3184e-02, -9.8474e-03, 4.1417e-02, 2.0477e-02, 3.0310e-02,\n", + " -1.6439e-02, -4.2043e-03, -1.9652e-02, 1.5055e-02, 1.3996e-02,\n", + " 2.5340e-02, -1.0289e-02, 2.1606e-02, 1.3181e-02, 1.4767e-02,\n", + " -1.7636e-02, -8.3252e-03, 2.8213e-02, -7.6369e-03, 3.5022e-02,\n", + " 2.7302e-02, 7.9103e-03, 1.9150e-02, -5.3064e-03, -2.1429e-02,\n", + " -3.9241e-03, 1.6526e-02, 3.0762e-02, -3.2304e-02, 4.2036e-02,\n", + " 3.4973e-02, -3.2242e-02, 1.4682e-02, 3.1378e-02, 1.3040e-02,\n", + " 1.3286e-02, 7.1529e-03, 4.0689e-02, 1.8221e-02, -8.5823e-04,\n", + " -3.9397e-02, -1.7632e-02, 3.8854e-02, 7.2799e-03, -8.8435e-03,\n", + " -3.7702e-02, 2.9500e-02, 1.3317e-02, 3.1820e-02, -2.5771e-02,\n", + " -2.0684e-02, 9.7548e-03, 1.1657e-02, 3.6020e-02, 3.8275e-03,\n", + " 2.5878e-02, -2.6804e-02, -2.7374e-02, -4.1425e-03, -8.4578e-03,\n", + " 3.6086e-02, -9.2384e-03, 3.6645e-02, 2.2945e-02, -3.1541e-02,\n", + " -3.6963e-02, -8.4283e-03, -1.2568e-02, -1.0610e-02, 4.6122e-03,\n", + " 8.0923e-03, 1.9675e-02, 1.6277e-02, 8.4814e-03, -1.7338e-02,\n", + " -2.2944e-03, -1.4254e-02, -1.1278e-02, -4.1493e-02, -1.1478e-02,\n", + " -4.1900e-03, -3.9914e-02, 3.6327e-02, 2.1176e-02, 2.1221e-02,\n", + " 2.3638e-03, -1.8431e-02, -1.6513e-02, 2.7354e-02, -2.7289e-02,\n", + " 5.6979e-03, -3.3893e-02, 3.4217e-02, 1.1927e-02, -3.7365e-02,\n", + " -4.3069e-02, 1.1643e-02, -3.8766e-02, 2.7033e-02, -3.3769e-02,\n", + " -8.2054e-03, 3.8923e-02, -1.3608e-02, -1.6946e-02, -5.1545e-03,\n", + " -3.2235e-02, -1.8647e-02, -2.6441e-02, 3.2820e-02, -6.6160e-03,\n", + " -1.0563e-02, -1.8934e-02, 1.0979e-02, 3.1983e-02, -3.0583e-02,\n", + " -5.1670e-03, -1.3826e-02, -4.7620e-03, -2.3506e-03, 2.0033e-02,\n", + " -2.9167e-02, 1.3401e-02, 3.6858e-05, -4.0258e-02, 3.0535e-03,\n", + " 1.0325e-02, -4.1490e-02, 3.5403e-02, -2.3996e-02, -2.3905e-02,\n", + " -2.4046e-02, 3.5478e-02, -2.8242e-02, 3.0915e-02, 1.1743e-02,\n", + " 2.9823e-02, 4.1825e-03, 2.7585e-02, 3.1861e-02, 3.5130e-02,\n", + " 1.4075e-02, 1.3824e-02, -3.5370e-02, 9.8739e-03, -4.4559e-02,\n", + " -3.6603e-03, -3.4666e-02, 7.4369e-03, -4.0257e-02, 1.5170e-02,\n", + " 3.8140e-02, -1.3233e-02, -4.7628e-03, 1.3612e-02, 5.4920e-03,\n", + " -3.4651e-02, 1.6557e-02, -3.7007e-03, -2.2401e-02, -3.1753e-02,\n", + " 1.8287e-02, 4.0354e-02, 1.9965e-02, 2.8328e-02, 3.8909e-02,\n", + " -2.5958e-02, 3.0594e-02, -2.6624e-02, 2.8449e-02, -3.4423e-02,\n", + " -1.9999e-02, 1.2986e-02, 1.7698e-04, 3.8171e-02, -1.8423e-02,\n", + " -4.0395e-02, 2.8430e-02, 3.3913e-02, -3.7508e-02, -2.3408e-02,\n", + " 2.0060e-02, 3.2080e-03, -4.0210e-02, -1.7192e-04, 4.2094e-02,\n", + " -4.1802e-02, -2.7460e-02, -2.8787e-02, -2.3993e-02, -2.0099e-02,\n", + " -5.7462e-03, -3.6673e-02, 2.0087e-02, 3.5864e-02, -3.5259e-02,\n", + " -1.8713e-02, -3.9394e-02, 2.8474e-02, -3.4466e-02, -2.5969e-02,\n", + " -2.6257e-02, -2.6545e-02, -8.0948e-03, 8.3270e-03, 8.8027e-03,\n", + " -1.9652e-02, 1.1662e-02, 2.7739e-02, -4.2084e-02, 2.5227e-02,\n", + " -4.2397e-02, -1.9872e-02, 2.2378e-02, -2.2397e-02, -1.0248e-03,\n", + " 2.3392e-02, 2.4300e-02, -3.3800e-03, -4.1069e-02, 1.6576e-02,\n", + " 2.9552e-02, 3.2363e-02, 3.5240e-02, -2.3632e-02, 3.7093e-02,\n", + " 3.5138e-02, -3.9624e-02, 4.2746e-02, -1.5768e-02, -2.0072e-02,\n", + " 9.3175e-03, 4.1433e-02, -3.0851e-03, 2.7544e-02, -2.1298e-02,\n", + " 4.0739e-02, 4.3524e-03, -1.0058e-02, -4.0087e-02, -1.2815e-02,\n", + " -4.3053e-02, 2.8493e-02, -2.3377e-02, -3.6020e-02, 2.6309e-02,\n", + " -1.2762e-02, -2.2052e-02, -1.7514e-02, 1.6878e-02, -2.3183e-02,\n", + " -4.4345e-02, -1.7994e-03, -1.6895e-03, 6.9794e-03, 4.4503e-02,\n", + " 3.6033e-02, -3.2598e-02, -3.0241e-03, 1.9046e-03, 2.1675e-02,\n", + " -3.1861e-02, 4.2725e-02, -3.1347e-02, -3.8474e-03, 2.6504e-02,\n", + " 1.1765e-02, -2.2374e-02, -1.1619e-02, -3.1221e-02, -6.7740e-03,\n", + " -3.4773e-03, -3.6647e-02, -3.2517e-02, -1.8775e-02, 4.6997e-03,\n", + " -3.5332e-02, 4.1909e-02, -1.8737e-02, -7.7933e-03, -2.4771e-02,\n", + " -2.5742e-02, -4.4500e-02, 3.2834e-02, 3.8324e-02, 1.2405e-02,\n", + " -3.8426e-02, 2.9064e-02, -3.6016e-02, 4.3653e-03, 3.8132e-02,\n", + " -4.2313e-02, -4.1442e-03, -3.4012e-02, 1.9081e-02, 7.1228e-03,\n", + " 1.9582e-02, -3.8668e-02, -3.8291e-02, -3.3956e-02, 9.5759e-04,\n", + " 2.0652e-02, -2.1983e-02, 1.7222e-02, 6.5477e-03, -4.0682e-02,\n", + " -2.4376e-02, 2.8711e-02, -3.0863e-02, 3.5675e-03, 3.1138e-03,\n", + " 2.8312e-03, 3.6704e-02, -1.4013e-02, 1.4549e-02, 4.7681e-03,\n", + " -1.5461e-03, -2.5192e-02, -2.2219e-02, -1.7596e-02, -3.3960e-03,\n", + " -3.1382e-02, 1.3545e-02, 8.4428e-04, 1.7972e-02, -7.7820e-03,\n", + " 2.1294e-02, 1.6943e-03, -3.1114e-02, 2.1719e-02, -2.6632e-02,\n", + " 4.6229e-03, 1.4209e-02], device='cuda:0'))])" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import torch\n", + "\n", + "\n", + "torch.load('papers/tmi2022/feature_extractor/runs/Oct29_16-15-55_xrh1/checkpoints/model.pth')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "931f65c8", + "metadata": {}, + "outputs": [], + "source": [ + "import cl as cl\n", + "\n", + "import torch\n", + "import torch.nn as nn\n", + "from torch.utils.data import DataLoader\n", + "import torchvision.models as models\n", + "import torchvision.transforms.functional as VF\n", + "from torchvision import transforms\n", + "\n", + "import sys, argparse, os, glob\n", + "import pandas as pd\n", + "import numpy as np\n", + "from PIL import Image\n", + "from collections import OrderedDict\n", + "from easydict import EasyDict as edict\n", + "\n", + "\n", + "edict({'backbone':'resnet18',\n", + " 'weights':})\n", + "\n", + "\n", + "if args.backbone == 'resnet18':\n", + " resnet = models.resnet18(pretrained=False, norm_layer=nn.InstanceNorm2d)\n", + " num_feats = 512\n", + "if args.backbone == 'resnet34':\n", + " resnet = models.resnet34(pretrained=False, norm_layer=nn.InstanceNorm2d)\n", + " num_feats = 512\n", + "if args.backbone == 'resnet50':\n", + " resnet = models.resnet50(pretrained=False, norm_layer=nn.InstanceNorm2d)\n", + " num_feats = 2048\n", + "if args.backbone == 'resnet101':\n", + " resnet = models.resnet101(pretrained=False, norm_layer=nn.InstanceNorm2d)\n", + " num_feats = 2048\n", + "for param in resnet.parameters():\n", + " param.requires_grad = False\n", + "resnet.fc = nn.Identity()\n", + "i_classifier = cl.IClassifier(resnet, num_feats, output_class=args.num_classes).cuda()\n", + "\n", + "# load feature extractor\n", + "if args.weights is None:\n", + " print('No feature extractor')\n", + " return\n", + "state_dict_weights = torch.load(args.weights)\n", + "print(state_dict_weights)\n", + "state_dict_init = i_classifier.state_dict()\n", + "new_state_dict = OrderedDict()\n", + "for (k, v), (k_0, v_0) in zip(state_dict_weights.items(), state_dict_init.items()):\n", + " name = k_0\n", + " new_state_dict[name] = v\n", + "i_classifier.load_state_dict(new_state_dict, strict=False)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/feature_extractor/__init__.py b/feature_extractor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/feature_extractor/__pycache__/__init__.cpython-38.pyc b/feature_extractor/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b556d42af77e9d7b94633e305322f0d57a69d37 Binary files /dev/null and b/feature_extractor/__pycache__/__init__.cpython-38.pyc differ diff --git a/feature_extractor/__pycache__/build_graph_utils.cpython-38.pyc b/feature_extractor/__pycache__/build_graph_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b44ae714b0972cd17a1d00c342e8a82cbf7f290 Binary files /dev/null and b/feature_extractor/__pycache__/build_graph_utils.cpython-38.pyc differ diff --git a/feature_extractor/__pycache__/build_graphs.cpython-38.pyc b/feature_extractor/__pycache__/build_graphs.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a83ffb9324aac47a3725f7017826b06bb41aaf13 Binary files /dev/null and b/feature_extractor/__pycache__/build_graphs.cpython-38.pyc differ diff --git a/feature_extractor/__pycache__/cl.cpython-38.pyc b/feature_extractor/__pycache__/cl.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e42920b563025ba18624dd86f37a0b02bc8b2848 Binary files /dev/null and b/feature_extractor/__pycache__/cl.cpython-38.pyc differ diff --git a/feature_extractor/__pycache__/simclr.cpython-36.pyc b/feature_extractor/__pycache__/simclr.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b22ef85b25820aa4f3bb208e1f4f05ae4b4d2040 Binary files /dev/null and b/feature_extractor/__pycache__/simclr.cpython-36.pyc differ diff --git a/feature_extractor/__pycache__/simclr.cpython-38.pyc b/feature_extractor/__pycache__/simclr.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d71b3c2c74d8b382691a15fc69feebd0bf73b1b3 Binary files /dev/null and b/feature_extractor/__pycache__/simclr.cpython-38.pyc differ diff --git a/feature_extractor/build_graph_utils.py b/feature_extractor/build_graph_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2b7b79b4af4974f364e81153ce78b4215120050e --- /dev/null +++ b/feature_extractor/build_graph_utils.py @@ -0,0 +1,85 @@ + +import torch +import torch.nn as nn +from torch.utils.data import DataLoader +import torchvision.models as models +import torchvision.transforms.functional as VF +from torchvision import transforms + +import sys, argparse, os, glob +import pandas as pd +import numpy as np +from PIL import Image +from collections import OrderedDict + +class ToPIL(object): + def __call__(self, sample): + img = sample + img = transforms.functional.to_pil_image(img) + return img + +class BagDataset(): + def __init__(self, csv_file, transform=None): + self.files_list = csv_file + self.transform = transform + def __len__(self): + return len(self.files_list) + def __getitem__(self, idx): + temp_path = self.files_list[idx] + img = os.path.join(temp_path) + img = Image.open(img) + img = img.resize((224, 224)) + sample = {'input': img} + + if self.transform: + sample = self.transform(sample) + return sample + +class ToTensor(object): + def __call__(self, sample): + img = sample['input'] + img = VF.to_tensor(img) + return {'input': img} + +class Compose(object): + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, img): + for t in self.transforms: + img = t(img) + return img + +def save_coords(txt_file, csv_file_path): + for path in csv_file_path: + x, y = path.split('/')[-1].split('.')[0].split('_') + txt_file.writelines(str(x) + '\t' + str(y) + '\n') + txt_file.close() + +def adj_matrix(csv_file_path, output, device='cpu'): + total = len(csv_file_path) + adj_s = np.zeros((total, total)) + + for i in range(total-1): + path_i = csv_file_path[i] + x_i, y_i = path_i.split('/')[-1].split('.')[0].split('_') + for j in range(i+1, total): + # sptial + path_j = csv_file_path[j] + x_j, y_j = path_j.split('/')[-1].split('.')[0].split('_') + if abs(int(x_i)-int(x_j)) <=1 and abs(int(y_i)-int(y_j)) <= 1: + adj_s[i][j] = 1 + adj_s[j][i] = 1 + + adj_s = torch.from_numpy(adj_s) + adj_s = adj_s.to(device) + + return adj_s + +def bag_dataset(args, csv_file_path): + transformed_dataset = BagDataset(csv_file=csv_file_path, + transform=Compose([ + ToTensor() + ])) + dataloader = DataLoader(transformed_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, drop_last=False) + return dataloader, len(transformed_dataset) \ No newline at end of file diff --git a/feature_extractor/build_graphs.py b/feature_extractor/build_graphs.py new file mode 100644 index 0000000000000000000000000000000000000000..64620387d1a607b32b7239e18739dc0e80f92567 --- /dev/null +++ b/feature_extractor/build_graphs.py @@ -0,0 +1,114 @@ + +from cl import IClassifier +from build_graph_utils import * +import torch +import torch.nn as nn +from torch.utils.data import DataLoader +import torchvision.models as models +import torchvision.transforms.functional as VF +from torchvision import transforms + +import sys, argparse, os, glob +import pandas as pd +import numpy as np +from PIL import Image +from collections import OrderedDict + + + +def compute_feats(args, bags_list, i_classifier, device, save_path=None, whole_slide_path=None): + num_bags = len(bags_list) + Tensor = torch.FloatTensor + for i in range(0, num_bags): + feats_list = [] + if args.magnification == '20x': + glob_path = os.path.join(bags_list[i], '*.jpeg') + csv_file_path = glob.glob(glob_path) + # line below was in the original version, commented due to errror with current version + #file_name = bags_list[i].split('/')[-3].split('_')[0] + + file_name = glob_path.split('/')[-3].split('_')[0] + + if args.magnification == '5x' or args.magnification == '10x': + csv_file_path = glob.glob(os.path.join(bags_list[i], '*.jpg')) + + dataloader, bag_size = bag_dataset(args, csv_file_path) + print('{} files to be processed: {}'.format(len(csv_file_path), file_name)) + + if os.path.isdir(os.path.join(save_path, 'simclr_files', file_name)) or len(csv_file_path) < 1: + print('alreday exists') + continue + with torch.no_grad(): + for iteration, batch in enumerate(dataloader): + patches = batch['input'].float().to(device) + feats, classes = i_classifier(patches) + #feats = feats.cpu().numpy() + feats_list.extend(feats) + + os.makedirs(os.path.join(save_path, 'simclr_files', file_name), exist_ok=True) + + txt_file = open(os.path.join(save_path, 'simclr_files', file_name, 'c_idx.txt'), "w+") + save_coords(txt_file, csv_file_path) + # save node features + output = torch.stack(feats_list, dim=0).to(device) + torch.save(output, os.path.join(save_path, 'simclr_files', file_name, 'features.pt')) + # save adjacent matrix + adj_s = adj_matrix(csv_file_path, output, device=device) + torch.save(adj_s, os.path.join(save_path, 'simclr_files', file_name, 'adj_s.pt')) + + print('\r Computed: {}/{}'.format(i+1, num_bags)) + + +def main(): + parser = argparse.ArgumentParser(description='Compute TCGA features from SimCLR embedder') + parser.add_argument('--num_classes', default=2, type=int, help='Number of output classes') + parser.add_argument('--num_feats', default=512, type=int, help='Feature size') + parser.add_argument('--batch_size', default=128, type=int, help='Batch size of dataloader') + parser.add_argument('--num_workers', default=0, type=int, help='Number of threads for datalodaer') + parser.add_argument('--dataset', default=None, type=str, help='path to patches') + parser.add_argument('--backbone', default='resnet18', type=str, help='Embedder backbone') + parser.add_argument('--magnification', default='20x', type=str, help='Magnification to compute features') + parser.add_argument('--weights', default=None, type=str, help='path to the pretrained weights') + parser.add_argument('--output', default=None, type=str, help='path to the output graph folder') + args = parser.parse_args() + + if args.backbone == 'resnet18': + resnet = models.resnet18(pretrained=False, norm_layer=nn.InstanceNorm2d) + num_feats = 512 + if args.backbone == 'resnet34': + resnet = models.resnet34(pretrained=False, norm_layer=nn.InstanceNorm2d) + num_feats = 512 + if args.backbone == 'resnet50': + resnet = models.resnet50(pretrained=False, norm_layer=nn.InstanceNorm2d) + num_feats = 2048 + if args.backbone == 'resnet101': + resnet = models.resnet101(pretrained=False, norm_layer=nn.InstanceNorm2d) + num_feats = 2048 + for param in resnet.parameters(): + param.requires_grad = False + resnet.fc = nn.Identity() + device = 'cuda' if torch.cuda.is_available() else 'cpu' + print("Running on:", device) + i_classifier = IClassifier(resnet, num_feats, output_class=args.num_classes).to(device) + + # load feature extractor + if args.weights is None: + print('No feature extractor') + return + state_dict_weights = torch.load(args.weights) + state_dict_init = i_classifier.state_dict() + new_state_dict = OrderedDict() + for (k, v), (k_0, v_0) in zip(state_dict_weights.items(), state_dict_init.items()): + if 'features' not in k: + continue + name = k_0 + new_state_dict[name] = v + i_classifier.load_state_dict(new_state_dict, strict=False) + + os.makedirs(args.output, exist_ok=True) + bags_list = glob.glob(args.dataset) + print(bags_list) + compute_feats(args, bags_list, i_classifier, device, args.output) + +if __name__ == '__main__': + main() diff --git a/feature_extractor/cl.py b/feature_extractor/cl.py new file mode 100644 index 0000000000000000000000000000000000000000..6de9ef291a50dcbe870185a1ec62a63ecbd4f161 --- /dev/null +++ b/feature_extractor/cl.py @@ -0,0 +1,83 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable + +class FCLayer(nn.Module): + def __init__(self, in_size, out_size=1): + super(FCLayer, self).__init__() + self.fc = nn.Sequential(nn.Linear(in_size, out_size)) + def forward(self, feats): + x = self.fc(feats) + return feats, x + +class IClassifier(nn.Module): + def __init__(self, feature_extractor, feature_size, output_class): + super(IClassifier, self).__init__() + + self.feature_extractor = feature_extractor + self.fc = nn.Linear(feature_size, output_class) + + + def forward(self, x): + device = x.device + feats = self.feature_extractor(x) # N x K + c = self.fc(feats.view(feats.shape[0], -1)) # N x C + return feats.view(feats.shape[0], -1), c + +class BClassifier(nn.Module): + def __init__(self, input_size, output_class, dropout_v=0.0): # K, L, N + super(BClassifier, self).__init__() + self.q = nn.Linear(input_size, 128) + self.v = nn.Sequential( + nn.Dropout(dropout_v), + nn.Linear(input_size, input_size) + ) + + ### 1D convolutional layer that can handle multiple class (including binary) + self.fcc = nn.Conv1d(output_class, output_class, kernel_size=input_size) + + def forward(self, feats, c): # N x K, N x C + device = feats.device + V = self.v(feats) # N x V, unsorted + Q = self.q(feats).view(feats.shape[0], -1) # N x Q, unsorted + + # handle multiple classes without for loop + _, m_indices = torch.sort(c, 0, descending=True) # sort class scores along the instance dimension, m_indices in shape N x C + m_feats = torch.index_select(feats, dim=0, index=m_indices[0, :]) # select critical instances, m_feats in shape C x K + q_max = self.q(m_feats) # compute queries of critical instances, q_max in shape C x Q + A = torch.mm(Q, q_max.transpose(0, 1)) # compute inner product of Q to each entry of q_max, A in shape N x C, each column contains unnormalized attention scores + A = F.softmax( A / torch.sqrt(torch.tensor(Q.shape[1], dtype=torch.float32, device=device)), 0) # normalize attention scores, A in shape N x C, + B = torch.mm(A.transpose(0, 1), V) # compute bag representation, B in shape C x V + + +# for i in range(c.shape[1]): +# _, indices = torch.sort(c[:, i], 0, True) +# feats = torch.index_select(feats, 0, indices) # N x K, sorted +# q_max = self.q(feats[0].view(1, -1)) # 1 x 1 x Q +# temp = torch.mm(Q, q_max.view(-1, 1)) / torch.sqrt(torch.tensor(Q.shape[1], dtype=torch.float32, device=device)) +# if i == 0: +# A = F.softmax(temp, 0) # N x 1 +# B = torch.sum(torch.mul(A, V), 0).view(1, -1) # 1 x V +# else: +# temp = F.softmax(temp, 0) # N x 1 +# A = torch.cat((A, temp), 1) # N x C +# B = torch.cat((B, torch.sum(torch.mul(temp, V), 0).view(1, -1)), 0) # C x V -> 1 x C x V + + B = B.view(1, B.shape[0], B.shape[1]) # 1 x C x V + C = self.fcc(B) # 1 x C x 1 + C = C.view(1, -1) + return C, A, B + +class MILNet(nn.Module): + def __init__(self, i_classifier, b_classifier): + super(MILNet, self).__init__() + self.i_classifier = i_classifier + self.b_classifier = b_classifier + + def forward(self, x): + feats, classes = self.i_classifier(x) + prediction_bag, A, B = self.b_classifier(feats, classes) + + return classes, prediction_bag, A, B + \ No newline at end of file diff --git a/feature_extractor/config.yaml b/feature_extractor/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4c8f4309e6cbefa7270b1beb7c639d9551b325a8 --- /dev/null +++ b/feature_extractor/config.yaml @@ -0,0 +1,23 @@ +batch_size: 256 +epochs: 20 +eval_every_n_epochs: 1 +fine_tune_from: '' +log_every_n_steps: 25 +weight_decay: 10e-6 +fp16_precision: False +n_gpu: 2 +gpu_ids: (0,1) + +model: + out_dim: 512 + base_model: "resnet18" + +dataset: + s: 1 + input_shape: (224,224,3) + num_workers: 10 + valid_size: 0.1 + +loss: + temperature: 0.5 + use_cosine_similarity: True diff --git a/feature_extractor/data_aug/__pycache__/dataset_wrapper.cpython-36.pyc b/feature_extractor/data_aug/__pycache__/dataset_wrapper.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f0693e648c7284afcca6e210918d3a23633b446f Binary files /dev/null and b/feature_extractor/data_aug/__pycache__/dataset_wrapper.cpython-36.pyc differ diff --git a/feature_extractor/data_aug/__pycache__/dataset_wrapper.cpython-38.pyc b/feature_extractor/data_aug/__pycache__/dataset_wrapper.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..46f9dc9314fbe1d8839cab934ce38e1cbee3428a Binary files /dev/null and b/feature_extractor/data_aug/__pycache__/dataset_wrapper.cpython-38.pyc differ diff --git a/feature_extractor/data_aug/__pycache__/gaussian_blur.cpython-36.pyc b/feature_extractor/data_aug/__pycache__/gaussian_blur.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..26111ce71915c4cf020bea014d5349d67314036c Binary files /dev/null and b/feature_extractor/data_aug/__pycache__/gaussian_blur.cpython-36.pyc differ diff --git a/feature_extractor/data_aug/__pycache__/gaussian_blur.cpython-38.pyc b/feature_extractor/data_aug/__pycache__/gaussian_blur.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e800e41ae6afff480e1435bb55d62e55c608915b Binary files /dev/null and b/feature_extractor/data_aug/__pycache__/gaussian_blur.cpython-38.pyc differ diff --git a/feature_extractor/data_aug/dataset_wrapper.py b/feature_extractor/data_aug/dataset_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..8c2ad19f9ee1de487027b9db55516a89531aa484 --- /dev/null +++ b/feature_extractor/data_aug/dataset_wrapper.py @@ -0,0 +1,93 @@ +import numpy as np +from torch.utils.data import DataLoader +from torch.utils.data.sampler import SubsetRandomSampler +import torchvision.transforms as transforms +from data_aug.gaussian_blur import GaussianBlur +from torchvision import datasets +import pandas as pd +from PIL import Image +from skimage import io, img_as_ubyte + +np.random.seed(0) + +class Dataset(): + def __init__(self, csv_file, transform=None): + lines = [] + with open(csv_file) as f: + for line in f: + line = line.rstrip().strip() + lines.append(line) + self.files_list = lines#pd.read_csv(csv_file) + self.transform = transform + def __len__(self): + return len(self.files_list) + def __getitem__(self, idx): + temp_path = self.files_list[idx]# self.files_list.iloc[idx, 0] + img = Image.open(temp_path) + img = transforms.functional.to_tensor(img) + if self.transform: + sample = self.transform(img) + return sample + +class ToPIL(object): + def __call__(self, sample): + img = sample + img = transforms.functional.to_pil_image(img) + return img + +class DataSetWrapper(object): + + def __init__(self, batch_size, num_workers, valid_size, input_shape, s): + self.batch_size = batch_size + self.num_workers = num_workers + self.valid_size = valid_size + self.s = s + self.input_shape = eval(input_shape) + + def get_data_loaders(self): + data_augment = self._get_simclr_pipeline_transform() + train_dataset = Dataset(csv_file='all_patches.csv', transform=SimCLRDataTransform(data_augment)) + train_loader, valid_loader = self.get_train_validation_data_loaders(train_dataset) + return train_loader, valid_loader + + def _get_simclr_pipeline_transform(self): + # get a set of data augmentation transformations as described in the SimCLR paper. + color_jitter = transforms.ColorJitter(0.8 * self.s, 0.8 * self.s, 0.8 * self.s, 0.2 * self.s) + data_transforms = transforms.Compose([ToPIL(), + # transforms.RandomResizedCrop(size=self.input_shape[0]), + transforms.Resize((self.input_shape[0],self.input_shape[1])), + transforms.RandomHorizontalFlip(), + transforms.RandomApply([color_jitter], p=0.8), + transforms.RandomGrayscale(p=0.2), + GaussianBlur(kernel_size=int(0.06 * self.input_shape[0])), + transforms.ToTensor()]) + return data_transforms + + def get_train_validation_data_loaders(self, train_dataset): + # obtain training indices that will be used for validation + num_train = len(train_dataset) + indices = list(range(num_train)) + np.random.shuffle(indices) + + split = int(np.floor(self.valid_size * num_train)) + train_idx, valid_idx = indices[split:], indices[:split] + + # define samplers for obtaining training and validation batches + train_sampler = SubsetRandomSampler(train_idx) + valid_sampler = SubsetRandomSampler(valid_idx) + + train_loader = DataLoader(train_dataset, batch_size=self.batch_size, sampler=train_sampler, + num_workers=self.num_workers, drop_last=True, shuffle=False) + valid_loader = DataLoader(train_dataset, batch_size=self.batch_size, sampler=valid_sampler, + num_workers=self.num_workers, drop_last=True) + return train_loader, valid_loader + + +class SimCLRDataTransform(object): + def __init__(self, transform): + self.transform = transform + + def __call__(self, sample): + xi = self.transform(sample) + xj = self.transform(sample) + return xi, xj diff --git a/feature_extractor/data_aug/gaussian_blur.py b/feature_extractor/data_aug/gaussian_blur.py new file mode 100644 index 0000000000000000000000000000000000000000..19669769637750ecc021e553483d71da3256174c --- /dev/null +++ b/feature_extractor/data_aug/gaussian_blur.py @@ -0,0 +1,26 @@ +import cv2 +import numpy as np + +np.random.seed(0) + + +class GaussianBlur(object): + # Implements Gaussian blur as described in the SimCLR paper + def __init__(self, kernel_size, min=0.1, max=2.0): + self.min = min + self.max = max + # kernel size is set to be 10% of the image height/width + self.kernel_size = kernel_size + + def __call__(self, sample): + sample = np.array(sample) + + # blur the image with a 50% chance + prob = np.random.random_sample() + + if prob < 0.5: +# print(self.kernel_size) + sigma = (self.max - self.min) * np.random.random_sample() + self.min + sample = cv2.GaussianBlur(sample, (self.kernel_size, self.kernel_size), sigma) + + return sample diff --git a/feature_extractor/load_patches.py b/feature_extractor/load_patches.py new file mode 100644 index 0000000000000000000000000000000000000000..0418cdbc185ef8a2d9c2870062b5ce18bcc347e7 --- /dev/null +++ b/feature_extractor/load_patches.py @@ -0,0 +1,37 @@ + +import os, glob +import argparse + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--data_path', type=str) + args = parser.parse_args() + + wsi_slides_paths = [] + + + def r(dirpath): + for file in os.listdir(dirpath): + path = os.path.join(dirpath, file) + if os.path.isfile(path) and file.endswith(".svs"): + wsi_slides_paths.append(path) + elif os.path.isdir(path): + r(path) + def r(dirpath): + for path in glob.glob(os.path.join(dirpath, '*','*.svs') ):#os.listdir(dirpath): + if os.path.isfile(path): + wsi_slides_paths.append(path) + def r(dirpath): + for path in glob.glob(os.path.join(dirpath, '*', '*', '*.jpeg') ):#os.listdir(dirpath): + if os.path.isfile(path): + wsi_slides_paths.append(path) + r(args.data_path) + with open('all_patches.csv', 'w') as f: + for filepath in wsi_slides_paths: + f.write(f'{filepath}\n') + + + + +if __name__ == "__main__": + main() diff --git a/feature_extractor/loss/__pycache__/nt_xent.cpython-36.pyc b/feature_extractor/loss/__pycache__/nt_xent.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f9b816d47e6570a705d9ed13c3962fbc3f04d39 Binary files /dev/null and b/feature_extractor/loss/__pycache__/nt_xent.cpython-36.pyc differ diff --git a/feature_extractor/loss/__pycache__/nt_xent.cpython-38.pyc b/feature_extractor/loss/__pycache__/nt_xent.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd661bb4c3737477f5da9b20be4bdfd94d22e595 Binary files /dev/null and b/feature_extractor/loss/__pycache__/nt_xent.cpython-38.pyc differ diff --git a/feature_extractor/loss/nt_xent.py b/feature_extractor/loss/nt_xent.py new file mode 100644 index 0000000000000000000000000000000000000000..ff2baff1d67613797c333b27be0cd29756f89bbe --- /dev/null +++ b/feature_extractor/loss/nt_xent.py @@ -0,0 +1,65 @@ +import torch +import numpy as np + + +class NTXentLoss(torch.nn.Module): + + def __init__(self, device, batch_size, temperature, use_cosine_similarity): + super(NTXentLoss, self).__init__() + self.batch_size = batch_size + self.temperature = temperature + self.device = device + self.softmax = torch.nn.Softmax(dim=-1) + self.mask_samples_from_same_repr = self._get_correlated_mask().type(torch.bool) + self.similarity_function = self._get_similarity_function(use_cosine_similarity) + self.criterion = torch.nn.CrossEntropyLoss(reduction="sum") + + def _get_similarity_function(self, use_cosine_similarity): + if use_cosine_similarity: + self._cosine_similarity = torch.nn.CosineSimilarity(dim=-1) + return self._cosine_simililarity + else: + return self._dot_simililarity + + def _get_correlated_mask(self): + diag = np.eye(2 * self.batch_size) + l1 = np.eye((2 * self.batch_size), 2 * self.batch_size, k=-self.batch_size) + l2 = np.eye((2 * self.batch_size), 2 * self.batch_size, k=self.batch_size) + mask = torch.from_numpy((diag + l1 + l2)) + mask = (1 - mask).type(torch.bool) + return mask.to(self.device) + + @staticmethod + def _dot_simililarity(x, y): + v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2) + # x shape: (N, 1, C) + # y shape: (1, C, 2N) + # v shape: (N, 2N) + return v + + def _cosine_simililarity(self, x, y): + # x shape: (N, 1, C) + # y shape: (1, 2N, C) + # v shape: (N, 2N) + v = self._cosine_similarity(x.unsqueeze(1), y.unsqueeze(0)) + return v + + def forward(self, zis, zjs): + representations = torch.cat([zjs, zis], dim=0) + + similarity_matrix = self.similarity_function(representations, representations) + + # filter out the scores from the positive samples + l_pos = torch.diag(similarity_matrix, self.batch_size) + r_pos = torch.diag(similarity_matrix, -self.batch_size) + positives = torch.cat([l_pos, r_pos]).view(2 * self.batch_size, 1) + + negatives = similarity_matrix[self.mask_samples_from_same_repr].view(2 * self.batch_size, -1) + + logits = torch.cat((positives, negatives), dim=1) + logits /= self.temperature + + labels = torch.zeros(2 * self.batch_size).to(self.device).long() + loss = self.criterion(logits, labels) + + return loss / (2 * self.batch_size) diff --git a/feature_extractor/models/__init__.py b/feature_extractor/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/feature_extractor/models/__pycache__/__init__.cpython-38.pyc b/feature_extractor/models/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ed96a932d406396d34df0b7ef0d78679b2ac52f Binary files /dev/null and b/feature_extractor/models/__pycache__/__init__.cpython-38.pyc differ diff --git a/feature_extractor/models/__pycache__/resnet_simclr.cpython-36.pyc b/feature_extractor/models/__pycache__/resnet_simclr.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dcc3d0e9ab9c84d08d1d299105dcce4c10c8f9c1 Binary files /dev/null and b/feature_extractor/models/__pycache__/resnet_simclr.cpython-36.pyc differ diff --git a/feature_extractor/models/__pycache__/resnet_simclr.cpython-38.pyc b/feature_extractor/models/__pycache__/resnet_simclr.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb91e49d76dc61d79b70295ce9ff335321500ac7 Binary files /dev/null and b/feature_extractor/models/__pycache__/resnet_simclr.cpython-38.pyc differ diff --git a/feature_extractor/models/baseline_encoder.py b/feature_extractor/models/baseline_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..87b9b931c34d5a91dcefa65d9f838bbd30707009 --- /dev/null +++ b/feature_extractor/models/baseline_encoder.py @@ -0,0 +1,43 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.models as models + + +class Encoder(nn.Module): + def __init__(self, out_dim=64): + super(Encoder, self).__init__() + self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1) + self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1) + self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1) + self.conv4 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) + self.pool = nn.MaxPool2d(2, 2) + + # projection MLP + self.l1 = nn.Linear(64, 64) + self.l2 = nn.Linear(64, out_dim) + + def forward(self, x): + x = self.conv1(x) + x = F.relu(x) + x = self.pool(x) + + x = self.conv2(x) + x = F.relu(x) + x = self.pool(x) + + x = self.conv3(x) + x = F.relu(x) + x = self.pool(x) + + x = self.conv4(x) + x = F.relu(x) + x = self.pool(x) + + h = torch.mean(x, dim=[2, 3]) + + x = self.l1(h) + x = F.relu(x) + x = self.l2(x) + + return h, x diff --git a/feature_extractor/models/resnet_simclr.py b/feature_extractor/models/resnet_simclr.py new file mode 100644 index 0000000000000000000000000000000000000000..957d2611229c5a452ccd73a62630e420cf1e2e70 --- /dev/null +++ b/feature_extractor/models/resnet_simclr.py @@ -0,0 +1,37 @@ +import torch.nn as nn +import torch.nn.functional as F +import torchvision.models as models + + +class ResNetSimCLR(nn.Module): + + def __init__(self, base_model, out_dim): + super(ResNetSimCLR, self).__init__() + self.resnet_dict = {"resnet18": models.resnet18(pretrained=False, norm_layer=nn.InstanceNorm2d), + "resnet50": models.resnet50(pretrained=False)} + + resnet = self._get_basemodel(base_model) + num_ftrs = resnet.fc.in_features + + self.features = nn.Sequential(*list(resnet.children())[:-1]) + + # projection MLP + self.l1 = nn.Linear(num_ftrs, num_ftrs) + self.l2 = nn.Linear(num_ftrs, out_dim) + + def _get_basemodel(self, model_name): + try: + model = self.resnet_dict[model_name] + print("Feature extractor:", model_name) + return model + except: + raise ("Invalid model name. Check the config file and pass one of: resnet18 or resnet50") + + def forward(self, x): + h = self.features(x) + h = h.squeeze() + + x = self.l1(h) + x = F.relu(x) + x = self.l2(x) + return h, x diff --git a/feature_extractor/run.py b/feature_extractor/run.py new file mode 100644 index 0000000000000000000000000000000000000000..50d357b15d364b8064f69d5ecc1cca9f670e4987 --- /dev/null +++ b/feature_extractor/run.py @@ -0,0 +1,21 @@ +from simclr import SimCLR +import yaml +from data_aug.dataset_wrapper import DataSetWrapper +import os, glob +import pandas as pd +import argparse + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--magnification', type=str, default='20x') + parser.add_argument('--dest_weights', type=str) + args = parser.parse_args() + config = yaml.load(open("config.yaml", "r"), Loader=yaml.FullLoader) + dataset = DataSetWrapper(config['batch_size'], **config['dataset']) + + simclr = SimCLR(dataset, config, args) + simclr.train() + + +if __name__ == "__main__": + main() diff --git a/feature_extractor/simclr.py b/feature_extractor/simclr.py new file mode 100644 index 0000000000000000000000000000000000000000..4165108714d9b8f677d2bb7d7de77fc7c11ad151 --- /dev/null +++ b/feature_extractor/simclr.py @@ -0,0 +1,165 @@ +import torch +from models.resnet_simclr import ResNetSimCLR +from torch.utils.tensorboard import SummaryWriter +import torch.nn.functional as F +from loss.nt_xent import NTXentLoss +import os +import shutil +import sys + +apex_support = False +try: + sys.path.append('./apex') + from apex import amp + + apex_support = True +except: + print("Please install apex for mixed precision training from: https://github.com/NVIDIA/apex") + apex_support = False + +import numpy as np + +torch.manual_seed(0) + + +def _save_config_file(model_checkpoints_folder): + if not os.path.exists(model_checkpoints_folder): + os.makedirs(model_checkpoints_folder) + shutil.copy('./config.yaml', os.path.join(model_checkpoints_folder, 'config.yaml')) + + +class SimCLR(object): + + def __init__(self, dataset, config, args=None): + self.config = config + self.device = self._get_device() + self.writer = SummaryWriter() + self.dataset = dataset + self.nt_xent_criterion = NTXentLoss(self.device, config['batch_size'], **config['loss']) + self.args = args + def _get_device(self): + device = 'cuda' if torch.cuda.is_available() else 'cpu' + print("Running on:", device) + return device + + def _step(self, model, xis, xjs, n_iter): + + # get the representations and the projections + ris, zis = model(xis) # [N,C] + + # get the representations and the projections + rjs, zjs = model(xjs) # [N,C] + + # normalize projection feature vectors + zis = F.normalize(zis, dim=1) + zjs = F.normalize(zjs, dim=1) + + loss = self.nt_xent_criterion(zis, zjs) + return loss + + def train(self): + + train_loader, valid_loader = self.dataset.get_data_loaders() + + model = ResNetSimCLR(**self.config["model"])# .to(self.device) + if self.config['n_gpu'] > 1: + model = torch.nn.DataParallel(model, device_ids=eval(self.config['gpu_ids'])) + model = self._load_pre_trained_weights(model) + model = model.to(self.device) + + + optimizer = torch.optim.Adam(model.parameters(), 1e-5, weight_decay=eval(self.config['weight_decay'])) + +# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(train_loader), eta_min=0, +# last_epoch=-1) + + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.config['epochs'], eta_min=0, + last_epoch=-1) + + + if apex_support and self.config['fp16_precision']: + model, optimizer = amp.initialize(model, optimizer, + opt_level='O2', + keep_batchnorm_fp32=True) + + if self.args is None: + model_checkpoints_folder = os.path.join(self.writer.log_dir, 'checkpoints') + else: + model_checkpoints_folder = self.args.dest_weights#os.environ['FEATURE_EXTRACTOR_WEIGHT_PATH'] + model_checkpoints_folder = os.path.dirname(model_checkpoints_folder) + # save config file + _save_config_file(model_checkpoints_folder) + + n_iter = 0 + valid_n_iter = 0 + best_valid_loss = np.inf + + for epoch_counter in range(self.config['epochs']): + for (xis, xjs) in train_loader: + optimizer.zero_grad() + xis = xis.to(self.device) + xjs = xjs.to(self.device) + + loss = self._step(model, xis, xjs, n_iter) + + if n_iter % self.config['log_every_n_steps'] == 0: + self.writer.add_scalar('train_loss', loss, global_step=n_iter) + print("[%d/%d] step: %d train_loss: %.3f" % (epoch_counter, self.config['epochs'], n_iter, loss)) + + if apex_support and self.config['fp16_precision']: + with amp.scale_loss(loss, optimizer) as scaled_loss: + scaled_loss.backward() + else: + loss.backward() + + optimizer.step() + n_iter += 1 + + # validate the model if requested + if epoch_counter % self.config['eval_every_n_epochs'] == 0: + valid_loss = self._validate(model, valid_loader) + print("[%d/%d] val_loss: %.3f" % (epoch_counter, self.config['epochs'], valid_loss)) + if valid_loss < best_valid_loss: + # save the model weights + best_valid_loss = valid_loss + torch.save(model.state_dict(), os.path.join(model_checkpoints_folder, 'model.pth')) + print('saved') + + self.writer.add_scalar('validation_loss', valid_loss, global_step=valid_n_iter) + valid_n_iter += 1 + + # warmup for the first 10 epochs + if epoch_counter >= 10: + scheduler.step() + self.writer.add_scalar('cosine_lr_decay', scheduler.get_lr()[0], global_step=n_iter) + + def _load_pre_trained_weights(self, model): + try: + checkpoints_folder = os.path.join('./runs', self.config['fine_tune_from'], 'checkpoints') + state_dict = torch.load(os.path.join(checkpoints_folder, 'model.pth')) + model.load_state_dict(state_dict) + print("Loaded pre-trained model with success.") + except FileNotFoundError: + print("Pre-trained weights not found. Training from scratch.") + + return model + + def _validate(self, model, valid_loader): + + # validation steps + with torch.no_grad(): + model.eval() + + valid_loss = 0.0 + counter = 0 + + for (xis, xjs) in valid_loader: + xis = xis.to(self.device) + xjs = xjs.to(self.device) + + loss = self._step(model, xis, xjs, counter) + valid_loss += loss.item() + counter += 1 + valid_loss /= counter + model.train() + return valid_loss diff --git a/feature_extractor/viewer.py b/feature_extractor/viewer.py new file mode 100644 index 0000000000000000000000000000000000000000..4b4ca901d07808dc1186efba948af0bd5e763559 --- /dev/null +++ b/feature_extractor/viewer.py @@ -0,0 +1,227 @@ +#!/usr/bin/env python +# +# deepzoom_server - Example web application for serving whole-slide images +# +# Copyright (c) 2010-2015 Carnegie Mellon University +# +# This library is free software; you can redistribute it and/or modify it +# under the terms of version 2.1 of the GNU Lesser General Public License +# as published by the Free Software Foundation. +# +# This library is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY +# or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public +# License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this library; if not, write to the Free Software Foundation, +# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. +# + +from io import BytesIO +from optparse import OptionParser +import os +import re +from unicodedata import normalize + +from flask import Flask, abort, make_response, render_template, url_for + +if os.name == 'nt': + _dll_path = os.getenv('OPENSLIDE_PATH') + if _dll_path is not None: + if hasattr(os, 'add_dll_directory'): + # Python >= 3.8 + with os.add_dll_directory(_dll_path): + import openslide + else: + # Python < 3.8 + _orig_path = os.environ.get('PATH', '') + os.environ['PATH'] = _orig_path + ';' + _dll_path + import openslide + + os.environ['PATH'] = _orig_path +else: + import openslide + +from openslide import ImageSlide, open_slide +from openslide.deepzoom import DeepZoomGenerator + +DEEPZOOM_SLIDE = None +DEEPZOOM_FORMAT = 'jpeg' +DEEPZOOM_TILE_SIZE = 254 +DEEPZOOM_OVERLAP = 1 +DEEPZOOM_LIMIT_BOUNDS = True +DEEPZOOM_TILE_QUALITY = 75 +SLIDE_NAME = 'slide' + +app = Flask(__name__) +app.config.from_object(__name__) +app.config.from_envvar('DEEPZOOM_TILER_SETTINGS', silent=True) + + +@app.before_first_request +def load_slide(): + slidefile = app.config['DEEPZOOM_SLIDE'] + if slidefile is None: + raise ValueError('No slide file specified') + config_map = { + 'DEEPZOOM_TILE_SIZE': 'tile_size', + 'DEEPZOOM_OVERLAP': 'overlap', + 'DEEPZOOM_LIMIT_BOUNDS': 'limit_bounds', + } + opts = {v: app.config[k] for k, v in config_map.items()} + slide = open_slide(slidefile) + app.slides = {SLIDE_NAME: DeepZoomGenerator(slide, **opts)} + app.associated_images = [] + app.slide_properties = slide.properties + for name, image in slide.associated_images.items(): + app.associated_images.append(name) + slug = slugify(name) + app.slides[slug] = DeepZoomGenerator(ImageSlide(image), **opts) + try: + mpp_x = slide.properties[openslide.PROPERTY_NAME_MPP_X] + mpp_y = slide.properties[openslide.PROPERTY_NAME_MPP_Y] + app.slide_mpp = (float(mpp_x) + float(mpp_y)) / 2 + except (KeyError, ValueError): + app.slide_mpp = 0 + + +@app.route('/') +def index(): + slide_url = url_for('dzi', slug=SLIDE_NAME) + associated_urls = { + name: url_for('dzi', slug=slugify(name)) for name in app.associated_images + } + return render_template( + 'slide-multipane.html', + slide_url=slide_url, + associated=associated_urls, + properties=app.slide_properties, + slide_mpp=app.slide_mpp, + ) + + +@app.route('/.dzi') +def dzi(slug): + format = app.config['DEEPZOOM_FORMAT'] + try: + resp = make_response(app.slides[slug].get_dzi(format)) + resp.mimetype = 'application/xml' + return resp + except KeyError: + # Unknown slug + abort(404) + + +@app.route('/_files//_.') +def tile(slug, level, col, row, format): + format = format.lower() + if format != 'jpeg' and format != 'png': + # Not supported by Deep Zoom + abort(404) + try: + tile = app.slides[slug].get_tile(level, (col, row)) + except KeyError: + # Unknown slug + abort(404) + except ValueError: + # Invalid level or coordinates + abort(404) + buf = BytesIO() + tile.save(buf, format, quality=app.config['DEEPZOOM_TILE_QUALITY']) + resp = make_response(buf.getvalue()) + resp.mimetype = 'image/%s' % format + return resp + + +def slugify(text): + text = normalize('NFKD', text.lower()).encode('ascii', 'ignore').decode() + return re.sub('[^a-z0-9]+', '-', text) + + +if __name__ == '__main__': + parser = OptionParser(usage='Usage: %prog [options] [slide]') + parser.add_option( + '-B', + '--ignore-bounds', + dest='DEEPZOOM_LIMIT_BOUNDS', + default=True, + action='store_false', + help='display entire scan area', + ) + parser.add_option( + '-c', '--config', metavar='FILE', dest='config', help='config file' + ) + parser.add_option( + '-d', + '--debug', + dest='DEBUG', + action='store_true', + help='run in debugging mode (insecure)', + ) + parser.add_option( + '-e', + '--overlap', + metavar='PIXELS', + dest='DEEPZOOM_OVERLAP', + type='int', + help='overlap of adjacent tiles [1]', + ) + parser.add_option( + '-f', + '--format', + metavar='{jpeg|png}', + dest='DEEPZOOM_FORMAT', + help='image format for tiles [jpeg]', + ) + parser.add_option( + '-l', + '--listen', + metavar='ADDRESS', + dest='host', + default='127.0.0.1', + help='address to listen on [127.0.0.1]', + ) + parser.add_option( + '-p', + '--port', + metavar='PORT', + dest='port', + type='int', + default=5000, + help='port to listen on [5000]', + ) + parser.add_option( + '-Q', + '--quality', + metavar='QUALITY', + dest='DEEPZOOM_TILE_QUALITY', + type='int', + help='JPEG compression quality [75]', + ) + parser.add_option( + '-s', + '--size', + metavar='PIXELS', + dest='DEEPZOOM_TILE_SIZE', + type='int', + help='tile size [254]', + ) + + (opts, args) = parser.parse_args() + # Load config file if specified + if opts.config is not None: + app.config.from_pyfile(opts.config) + # Overwrite only those settings specified on the command line + for k in dir(opts): + if not k.startswith('_') and getattr(opts, k) is None: + delattr(opts, k) + app.config.from_object(opts) + # Set slide file + try: + app.config['DEEPZOOM_SLIDE'] = args[0] + except IndexError: + if app.config['DEEPZOOM_SLIDE'] is None: + parser.error('No slide file specified') + + app.run(host=opts.host, port=opts.port, threaded=True) \ No newline at end of file diff --git a/helper.py b/helper.py new file mode 100644 index 0000000000000000000000000000000000000000..374c1a7345d298bc315051785be1727b470e454a --- /dev/null +++ b/helper.py @@ -0,0 +1,104 @@ +#!/usr/bin/env python +# coding: utf-8 + +from __future__ import absolute_import, division, print_function + +import cv2 +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable +from torchvision import transforms +from utils.metrics import ConfusionMatrix +from PIL import Image +import os + +# torch.cuda.synchronize() +# torch.backends.cudnn.benchmark = True +torch.backends.cudnn.deterministic = True + +def collate(batch): + image = [ b['image'] for b in batch ] # w, h + label = [ b['label'] for b in batch ] + id = [ b['id'] for b in batch ] + adj_s = [ b['adj_s'] for b in batch ] + return {'image': image, 'label': label, 'id': id, 'adj_s': adj_s} + +def preparefeatureLabel(batch_graph, batch_label, batch_adjs, device='cpu'): + batch_size = len(batch_graph) + labels = torch.LongTensor(batch_size) + max_node_num = 0 + + for i in range(batch_size): + labels[i] = batch_label[i] + max_node_num = max(max_node_num, batch_graph[i].shape[0]) + + masks = torch.zeros(batch_size, max_node_num) + adjs = torch.zeros(batch_size, max_node_num, max_node_num) + batch_node_feat = torch.zeros(batch_size, max_node_num, 512) + + for i in range(batch_size): + cur_node_num = batch_graph[i].shape[0] + #node attribute feature + tmp_node_fea = batch_graph[i] + batch_node_feat[i, 0:cur_node_num] = tmp_node_fea + + #adjs + adjs[i, 0:cur_node_num, 0:cur_node_num] = batch_adjs[i] + + #masks + masks[i,0:cur_node_num] = 1 + + node_feat = batch_node_feat.to(device) + labels = labels.to(device) + adjs = adjs.to(device) + masks = masks.to(device) + + return node_feat, labels, adjs, masks + +class Trainer(object): + def __init__(self, n_class): + self.metrics = ConfusionMatrix(n_class) + + def get_scores(self): + acc = self.metrics.get_scores() + + return acc + + def reset_metrics(self): + self.metrics.reset() + + def plot_cm(self): + self.metrics.plotcm() + + def train(self, sample, model): + node_feat, labels, adjs, masks = preparefeatureLabel(sample['image'], sample['label'], sample['adj_s']) + pred,labels,loss = model.forward(node_feat, labels, adjs, masks) + + return pred,labels,loss + +class Evaluator(object): + def __init__(self, n_class): + self.metrics = ConfusionMatrix(n_class) + + def get_scores(self): + acc = self.metrics.get_scores() + + return acc + + def reset_metrics(self): + self.metrics.reset() + + def plot_cm(self): + self.metrics.plotcm() + + def eval_test(self, sample, model, graphcam_flag=False): + node_feat, labels, adjs, masks = preparefeatureLabel(sample['image'], sample['label'], sample['adj_s']) + if not graphcam_flag: + with torch.no_grad(): + pred,labels,loss = model.forward(node_feat, labels, adjs, masks) + else: + torch.set_grad_enabled(True) + pred,labels,loss= model.forward(node_feat, labels, adjs, masks, graphcam_flag=graphcam_flag) + return pred,labels,loss \ No newline at end of file diff --git a/main.py b/main.py new file mode 100644 index 0000000000000000000000000000000000000000..9b1949847e2f6ae66707c93c86ecf72ccdb7b445 --- /dev/null +++ b/main.py @@ -0,0 +1,169 @@ +#!/usr/bin/env python +# coding: utf-8 + +from __future__ import absolute_import, division, print_function + +import os +import numpy as np +import torch +import torch.nn as nn +from torchvision import transforms + +from utils.dataset import GraphDataset +from utils.lr_scheduler import LR_Scheduler +from tensorboardX import SummaryWriter +from helper import Trainer, Evaluator, collate +from option import Options + +from models.GraphTransformer import Classifier +from models.weight_init import weight_init +import pickle +args = Options().parse() + +label_map = pickle.load(open(os.path.join(args.dataset_metadata_path, 'label_map.pkl'), 'rb')) + +n_class = len(label_map) + +torch.cuda.synchronize() +torch.backends.cudnn.deterministic = True + +data_path = args.data_path +model_path = args.model_path +if not os.path.isdir(model_path): os.mkdir(model_path) +log_path = args.log_path +if not os.path.isdir(log_path): os.mkdir(log_path) +task_name = args.task_name + +print(task_name) +################################### +train = args.train +test = args.test +graphcam = args.graphcam +print("train:", train, "test:", test, "graphcam:", graphcam) + +##### Load datasets +print("preparing datasets and dataloaders......") +batch_size = args.batch_size + +if train: + ids_train = open(args.train_set).readlines() + dataset_train = GraphDataset(os.path.join(data_path, ""), ids_train, args.dataset_metadata_path) + dataloader_train = torch.utils.data.DataLoader(dataset=dataset_train, batch_size=batch_size, num_workers=10, collate_fn=collate, shuffle=True, pin_memory=True, drop_last=True) + total_train_num = len(dataloader_train) * batch_size + +ids_val = open(args.val_set).readlines() +dataset_val = GraphDataset(os.path.join(data_path, ""), ids_val, args.dataset_metadata_path) +dataloader_val = torch.utils.data.DataLoader(dataset=dataset_val, batch_size=batch_size, num_workers=10, collate_fn=collate, shuffle=False, pin_memory=True) +total_val_num = len(dataloader_val) * batch_size + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +##### creating models ############# +print("creating models......") + +num_epochs = args.num_epochs +learning_rate = args.lr + +model = Classifier(n_class) +model = nn.DataParallel(model) +if args.resume: + print('load model{}'.format(args.resume)) + model.load_state_dict(torch.load(args.resume)) + +if torch.cuda.is_available(): + model = model.cuda() +#model.apply(weight_init) + +optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate, weight_decay = 5e-4) # best:5e-4, 4e-3 +scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[20,100], gamma=0.1) # gamma=0.3 # 30,90,130 # 20,90,130 -> 150 + +################################## + +criterion = nn.CrossEntropyLoss() + +if not test: + writer = SummaryWriter(log_dir=log_path + task_name) + f_log = open(log_path + task_name + ".log", 'w') + +trainer = Trainer(n_class) +evaluator = Evaluator(n_class) + +best_pred = 0.0 +for epoch in range(num_epochs): + # optimizer.zero_grad() + model.train() + train_loss = 0. + total = 0. + + current_lr = optimizer.param_groups[0]['lr'] + print('\n=>Epoches %i, learning rate = %.7f, previous best = %.4f' % (epoch+1, current_lr, best_pred)) + + if train: + for i_batch, sample_batched in enumerate(dataloader_train): + scheduler.step(epoch) + + preds,labels,loss = trainer.train(sample_batched, model) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + train_loss += loss + total += len(labels) + + trainer.metrics.update(labels, preds) + if (i_batch + 1) % args.log_interval_local == 0: + print("[%d/%d] train loss: %.3f; agg acc: %.3f" % (total, total_train_num, train_loss / total, trainer.get_scores())) + trainer.plot_cm() + + if not test: + print("[%d/%d] train loss: %.3f; agg acc: %.3f" % (total_train_num, total_train_num, train_loss / total, trainer.get_scores())) + trainer.plot_cm() + + + if epoch % 1 == 0: + with torch.no_grad(): + model.eval() + print("evaluating...") + + total = 0. + batch_idx = 0 + + for i_batch, sample_batched in enumerate(dataloader_val): + preds, labels, _ = evaluator.eval_test(sample_batched, model, graphcam) + + total += len(labels) + + evaluator.metrics.update(labels, preds) + + if (i_batch + 1) % args.log_interval_local == 0: + print('[%d/%d] val agg acc: %.3f' % (total, total_val_num, evaluator.get_scores())) + evaluator.plot_cm() + + print('[%d/%d] val agg acc: %.3f' % (total_val_num, total_val_num, evaluator.get_scores())) + evaluator.plot_cm() + + # torch.cuda.empty_cache() + + val_acc = evaluator.get_scores() + if val_acc > best_pred: + best_pred = val_acc + if not test: + print("saving model...") + torch.save(model.state_dict(), model_path + task_name + ".pth") + + log = "" + log = log + 'epoch [{}/{}] ------ acc: train = {:.4f}, val = {:.4f}'.format(epoch+1, num_epochs, trainer.get_scores(), evaluator.get_scores()) + "\n" + + log += "================================\n" + print(log) + if test: break + + f_log.write(log) + f_log.flush() + + writer.add_scalars('accuracy', {'train acc': trainer.get_scores(), 'val acc': evaluator.get_scores()}, epoch+1) + + trainer.reset_metrics() + evaluator.reset_metrics() + +if not test: f_log.close() \ No newline at end of file diff --git a/metadata/label_map.pkl b/metadata/label_map.pkl new file mode 100644 index 0000000000000000000000000000000000000000..d1fd4d98f76c037472de0a292a6da3586c7736ae --- /dev/null +++ b/metadata/label_map.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ce5be416a8667c9379502eaf8407e6d07bbae03749085190be630bd3b026eb52 +size 34 diff --git a/models/.gitkeep b/models/.gitkeep new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/models/.gitkeep @@ -0,0 +1 @@ + diff --git a/models/GraphTransformer.py b/models/GraphTransformer.py new file mode 100644 index 0000000000000000000000000000000000000000..34ecdf561e21dace9a68c81623dbddca2f37475b --- /dev/null +++ b/models/GraphTransformer.py @@ -0,0 +1,123 @@ +import sys +import os +import torch +import random +import numpy as np + +from torch.autograd import Variable +from torch.nn.parameter import Parameter +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim + +from .ViT import * +from .gcn import GCNBlock + +from torch_geometric.nn import GCNConv, DenseGraphConv, dense_mincut_pool +from torch.nn import Linear +class Classifier(nn.Module): + def __init__(self, n_class): + super(Classifier, self).__init__() + + self.n_class = n_class + self.embed_dim = 64 + self.num_layers = 3 + self.node_cluster_num = 100 + + self.transformer = VisionTransformer(num_classes=n_class, embed_dim=self.embed_dim) + self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) + self.criterion = nn.CrossEntropyLoss() + + self.bn = 1 + self.add_self = 1 + self.normalize_embedding = 1 + self.conv1 = GCNBlock(512,self.embed_dim,self.bn,self.add_self,self.normalize_embedding,0.,0) # 64->128 + self.pool1 = Linear(self.embed_dim, self.node_cluster_num) # 100-> 20 + + + def forward(self,node_feat,labels,adj,mask,is_print=False, graphcam_flag=False, to_file=True): + # node_feat, labels = self.PrepareFeatureLabel(batch_graph) + cls_loss=node_feat.new_zeros(self.num_layers) + rank_loss=node_feat.new_zeros(self.num_layers-1) + X=node_feat + p_t=[] + pred_logits=0 + visualize_tools=[] + if labels is not None: + visualize_tools1=[labels.cpu()] + embeds=0 + concats=[] + + layer_acc=[] + + X=mask.unsqueeze(2)*X + X = self.conv1(X, adj, mask) + s = self.pool1(X) + + + graphcam_tensors = {} + + if graphcam_flag: + s_matrix = torch.argmax(s[0], dim=1) + if to_file: + from os import path + os.makedirs('graphcam', exist_ok=True) + torch.save(s_matrix, 'graphcam/s_matrix.pt') + torch.save(s[0], 'graphcam/s_matrix_ori.pt') + + if path.exists('graphcam/att_1.pt'): + os.remove('graphcam/att_1.pt') + os.remove('graphcam/att_2.pt') + os.remove('graphcam/att_3.pt') + + if not to_file: + graphcam_tensors['s_matrix'] = s_matrix + graphcam_tensors['s_matrix_ori'] = s[0] + + + X, adj, mc1, o1 = dense_mincut_pool(X, adj, s, mask) + b, _, _ = X.shape + cls_token = self.cls_token.repeat(b, 1, 1) + X = torch.cat([cls_token, X], dim=1) + + out = self.transformer(X) + + loss = None + if labels is not None: + # loss + loss = self.criterion(out, labels) + loss = loss + mc1 + o1 + # pred + pred = out.data.max(1)[1] + + if graphcam_flag: + #print('GraphCAM enabled') + #print(out.shape) + p = F.softmax(out) + #print(p.shape) + if to_file: + torch.save(p, 'graphcam/prob.pt') + if not to_file: + graphcam_tensors['prob'] = p + index = np.argmax(out.cpu().data.numpy(), axis=-1) + + for index_ in range(self.n_class): + one_hot = np.zeros((1, out.size()[-1]), dtype=np.float32) + one_hot[0, index_] = out[0][index_] + one_hot_vector = one_hot + one_hot = torch.from_numpy(one_hot).requires_grad_(True) + one_hot = torch.sum(one_hot.to( 'cuda' if torch.cuda.is_available() else 'cpu') * out) #!!!!!!!!!!!!!!!!!!!!out-->p + self.transformer.zero_grad() + one_hot.backward(retain_graph=True) + + kwargs = {"alpha": 1} + cam = self.transformer.relprop(torch.tensor(one_hot_vector).to(X.device), method="transformer_attribution", is_ablation=False, + start_layer=0, **kwargs) + if to_file: + torch.save(cam, 'graphcam/cam_{}.pt'.format(index_)) + if not to_file: + graphcam_tensors[f'cam_{index_}'] = cam + + if not to_file: + return pred,labels,loss, graphcam_tensors + return pred,labels,loss diff --git a/models/ViT.py b/models/ViT.py new file mode 100644 index 0000000000000000000000000000000000000000..1e07347410897293d32e697d426531b521cc38b5 --- /dev/null +++ b/models/ViT.py @@ -0,0 +1,415 @@ +""" Vision Transformer (ViT) in PyTorch +""" +import torch +import torch.nn as nn +from einops import rearrange +from .layers import * +import math + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor, float, float, float, float) -> Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', + 'first_conv': 'patch_embed.proj', 'classifier': 'head', + **kwargs + } + + +default_cfgs = { + # patch models + 'vit_small_patch16_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth', + ), + 'vit_base_patch16_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + ), + 'vit_large_patch16_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), +} + +def compute_rollout_attention(all_layer_matrices, start_layer=0): + # adding residual consideration + num_tokens = all_layer_matrices[0].shape[1] + batch_size = all_layer_matrices[0].shape[0] + eye = torch.eye(num_tokens).expand(batch_size, num_tokens, num_tokens).to(all_layer_matrices[0].device) + all_layer_matrices = [all_layer_matrices[i] + eye for i in range(len(all_layer_matrices))] + # all_layer_matrices = [all_layer_matrices[i] / all_layer_matrices[i].sum(dim=-1, keepdim=True) + # for i in range(len(all_layer_matrices))] + joint_attention = all_layer_matrices[start_layer] + for i in range(start_layer+1, len(all_layer_matrices)): + joint_attention = all_layer_matrices[i].bmm(joint_attention) + return joint_attention + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = Linear(in_features, hidden_features) + self.act = GELU() + self.fc2 = Linear(hidden_features, out_features) + self.drop = Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + def relprop(self, cam, **kwargs): + cam = self.drop.relprop(cam, **kwargs) + cam = self.fc2.relprop(cam, **kwargs) + cam = self.act.relprop(cam, **kwargs) + cam = self.fc1.relprop(cam, **kwargs) + return cam + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False,attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = head_dim ** -0.5 + + # A = Q*K^T + self.matmul1 = einsum('bhid,bhjd->bhij') + # attn = A*V + self.matmul2 = einsum('bhij,bhjd->bhid') + + self.qkv = Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = Dropout(attn_drop) + self.proj = Linear(dim, dim) + self.proj_drop = Dropout(proj_drop) + self.softmax = Softmax(dim=-1) + + self.attn_cam = None + self.attn = None + self.v = None + self.v_cam = None + self.attn_gradients = None + + def get_attn(self): + return self.attn + + def save_attn(self, attn): + self.attn = attn + + def save_attn_cam(self, cam): + self.attn_cam = cam + + def get_attn_cam(self): + return self.attn_cam + + def get_v(self): + return self.v + + def save_v(self, v): + self.v = v + + def save_v_cam(self, cam): + self.v_cam = cam + + def get_v_cam(self): + return self.v_cam + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def forward(self, x): + b, n, _, h = *x.shape, self.num_heads + qkv = self.qkv(x) + q, k, v = rearrange(qkv, 'b n (qkv h d) -> qkv b h n d', qkv=3, h=h) + + self.save_v(v) + + dots = self.matmul1([q, k]) * self.scale + + attn = self.softmax(dots) + attn = self.attn_drop(attn) + + # Get attention + if False: + from os import path + if not path.exists('att_1.pt'): + torch.save(attn, 'att_1.pt') + elif not path.exists('att_2.pt'): + torch.save(attn, 'att_2.pt') + else: + torch.save(attn, 'att_3.pt') + + #comment in training + if x.requires_grad: + self.save_attn(attn) + attn.register_hook(self.save_attn_gradients) + + out = self.matmul2([attn, v]) + out = rearrange(out, 'b h n d -> b n (h d)') + + out = self.proj(out) + out = self.proj_drop(out) + return out + + def relprop(self, cam, **kwargs): + cam = self.proj_drop.relprop(cam, **kwargs) + cam = self.proj.relprop(cam, **kwargs) + cam = rearrange(cam, 'b n (h d) -> b h n d', h=self.num_heads) + + # attn = A*V + (cam1, cam_v)= self.matmul2.relprop(cam, **kwargs) + cam1 /= 2 + cam_v /= 2 + + self.save_v_cam(cam_v) + self.save_attn_cam(cam1) + + cam1 = self.attn_drop.relprop(cam1, **kwargs) + cam1 = self.softmax.relprop(cam1, **kwargs) + + # A = Q*K^T + (cam_q, cam_k) = self.matmul1.relprop(cam1, **kwargs) + cam_q /= 2 + cam_k /= 2 + + cam_qkv = rearrange([cam_q, cam_k, cam_v], 'qkv b h n d -> b n (qkv h d)', qkv=3, h=self.num_heads) + + return self.qkv.relprop(cam_qkv, **kwargs) + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.): + super().__init__() + self.norm1 = LayerNorm(dim, eps=1e-6) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) + self.norm2 = LayerNorm(dim, eps=1e-6) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, drop=drop) + + self.add1 = Add() + self.add2 = Add() + self.clone1 = Clone() + self.clone2 = Clone() + + def forward(self, x): + x1, x2 = self.clone1(x, 2) + x = self.add1([x1, self.attn(self.norm1(x2))]) + x1, x2 = self.clone2(x, 2) + x = self.add2([x1, self.mlp(self.norm2(x2))]) + return x + + def relprop(self, cam, **kwargs): + (cam1, cam2) = self.add2.relprop(cam, **kwargs) + cam2 = self.mlp.relprop(cam2, **kwargs) + cam2 = self.norm2.relprop(cam2, **kwargs) + cam = self.clone2.relprop((cam1, cam2), **kwargs) + + (cam1, cam2) = self.add1.relprop(cam, **kwargs) + cam2 = self.attn.relprop(cam2, **kwargs) + cam2 = self.norm1.relprop(cam2, **kwargs) + cam = self.clone1.relprop((cam1, cam2), **kwargs) + return cam + +class VisionTransformer(nn.Module): + """ Vision Transformer with support for patch or hybrid CNN input stage + """ + def __init__(self, num_classes=2, embed_dim=64, depth=3, + num_heads=8, mlp_ratio=2., qkv_bias=False, mlp_head=False, drop_rate=0., attn_drop_rate=0.): + super().__init__() + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, + drop=drop_rate, attn_drop=attn_drop_rate) + for i in range(depth)]) + + self.norm = LayerNorm(embed_dim) + if mlp_head: + # paper diagram suggests 'MLP head', but results in 4M extra parameters vs paper + self.head = Mlp(embed_dim, int(embed_dim * mlp_ratio), num_classes) + else: + # with a single Linear layer as head, the param count within rounding of paper + self.head = Linear(embed_dim, num_classes) + + #self.apply(self._init_weights) + + self.pool = IndexSelect() + self.add = Add() + + self.inp_grad = None + + def save_inp_grad(self,grad): + self.inp_grad = grad + + def get_inp_grad(self): + return self.inp_grad + + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @property + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + def forward(self, x): + if x.requires_grad: + x.register_hook(self.save_inp_grad) #comment it in train + + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + x = self.pool(x, dim=1, indices=torch.tensor(0, device=x.device)) + x = x.squeeze(1) + x = self.head(x) + return x + + def relprop(self, cam=None,method="transformer_attribution", is_ablation=False, start_layer=0, **kwargs): + # print(kwargs) + # print("conservation 1", cam.sum()) + cam = self.head.relprop(cam, **kwargs) + cam = cam.unsqueeze(1) + cam = self.pool.relprop(cam, **kwargs) + cam = self.norm.relprop(cam, **kwargs) + for blk in reversed(self.blocks): + cam = blk.relprop(cam, **kwargs) + + # print("conservation 2", cam.sum()) + # print("min", cam.min()) + + if method == "full": + (cam, _) = self.add.relprop(cam, **kwargs) + cam = cam[:, 1:] + cam = self.patch_embed.relprop(cam, **kwargs) + # sum on channels + cam = cam.sum(dim=1) + return cam + + elif method == "rollout": + # cam rollout + attn_cams = [] + for blk in self.blocks: + attn_heads = blk.attn.get_attn_cam().clamp(min=0) + avg_heads = (attn_heads.sum(dim=1) / attn_heads.shape[1]).detach() + attn_cams.append(avg_heads) + cam = compute_rollout_attention(attn_cams, start_layer=start_layer) + cam = cam[:, 0, 1:] + return cam + + # our method, method name grad is legacy + elif method == "transformer_attribution" or method == "grad": + cams = [] + for blk in self.blocks: + grad = blk.attn.get_attn_gradients() + cam = blk.attn.get_attn_cam() + cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1]) + grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1]) + cam = grad * cam + cam = cam.clamp(min=0).mean(dim=0) + cams.append(cam.unsqueeze(0)) + rollout = compute_rollout_attention(cams, start_layer=start_layer) + cam = rollout[:, 0, 1:] + return cam + + elif method == "last_layer": + cam = self.blocks[-1].attn.get_attn_cam() + cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1]) + if is_ablation: + grad = self.blocks[-1].attn.get_attn_gradients() + grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1]) + cam = grad * cam + cam = cam.clamp(min=0).mean(dim=0) + cam = cam[0, 1:] + return cam + + elif method == "last_layer_attn": + cam = self.blocks[-1].attn.get_attn() + cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1]) + cam = cam.clamp(min=0).mean(dim=0) + cam = cam[0, 1:] + return cam + + elif method == "second_layer": + cam = self.blocks[1].attn.get_attn_cam() + cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1]) + if is_ablation: + grad = self.blocks[1].attn.get_attn_gradients() + grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1]) + cam = grad * cam + cam = cam.clamp(min=0).mean(dim=0) + cam = cam[0, 1:] + return cam \ No newline at end of file diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/__pycache__/GraphTransformer.cpython-38.pyc b/models/__pycache__/GraphTransformer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b3d5ab870b19f73a63551535ccbf5902969cfd3 Binary files /dev/null and b/models/__pycache__/GraphTransformer.cpython-38.pyc differ diff --git a/models/__pycache__/ViT.cpython-38.pyc b/models/__pycache__/ViT.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dcfb6a6c7807778d0c1b2e2aed83292c8583cb50 Binary files /dev/null and b/models/__pycache__/ViT.cpython-38.pyc differ diff --git a/models/__pycache__/__init__.cpython-38.pyc b/models/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..759c53064228f849d841cb166a83d30ba0ff1580 Binary files /dev/null and b/models/__pycache__/__init__.cpython-38.pyc differ diff --git a/models/__pycache__/gcn.cpython-38.pyc b/models/__pycache__/gcn.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5cb6cff43d3474ff6fd1af86eed4198d6bff692f Binary files /dev/null and b/models/__pycache__/gcn.cpython-38.pyc differ diff --git a/models/__pycache__/layers.cpython-38.pyc b/models/__pycache__/layers.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a2b4eddb6fc10f1acab2da50966ee2f016ec5ad5 Binary files /dev/null and b/models/__pycache__/layers.cpython-38.pyc differ diff --git a/models/__pycache__/weight_init.cpython-38.pyc b/models/__pycache__/weight_init.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..410a7e0d7f64a4cbd2b1fc31df416e65fd63e1b5 Binary files /dev/null and b/models/__pycache__/weight_init.cpython-38.pyc differ diff --git a/models/gcn.py b/models/gcn.py new file mode 100644 index 0000000000000000000000000000000000000000..9d52daa8171c0dadfb06b6a8dde411ef778188e6 --- /dev/null +++ b/models/gcn.py @@ -0,0 +1,420 @@ +import torch +import torch.nn as nn +from torch.nn import init +import torch.nn.functional as F +import math + +import numpy as np + +torch.set_printoptions(precision=2,threshold=float('inf')) + +class AGCNBlock(nn.Module): + def __init__(self,input_dim,hidden_dim,gcn_layer=2,dropout=0.0,relu=0): + super(AGCNBlock,self).__init__() + if dropout > 0.001: + self.dropout_layer = nn.Dropout(p=dropout) + self.sort = 'sort' + self.model='agcn' + self.gcns=nn.ModuleList() + self.bn = 0 + self.add_self = 1 + self.normalize_embedding = 1 + self.gcns.append(GCNBlock(input_dim,hidden_dim,self.bn,self.add_self,self.normalize_embedding,dropout,relu)) + self.pool = 'mean' + self.tau = 1. + self.lamda = 1. + + for i in range(gcn_layer-1): + if i==gcn_layer-2 and (not 1): + self.gcns.append(GCNBlock(hidden_dim,hidden_dim,self.bn,self.add_self,self.normalize_embedding,dropout,0)) + else: + self.gcns.append(GCNBlock(hidden_dim,hidden_dim,self.bn,self.add_self,self.normalize_embedding,dropout,relu)) + + if self.model=='diffpool': + self.pool_gcns=nn.ModuleList() + tmp=input_dim + self.diffpool_k=200 + for i in range(3): + self.pool_gcns.append(GCNBlock(tmp,200,0,0,0,dropout,relu)) + tmp=200 + + self.w_a=nn.Parameter(torch.zeros(1,hidden_dim,1)) + self.w_b=nn.Parameter(torch.zeros(1,hidden_dim,1)) + torch.nn.init.normal_(self.w_a) + torch.nn.init.uniform_(self.w_b,-1,1) + + self.pass_dim=hidden_dim + + if self.pool=='mean': + self.pool=self.mean_pool + elif self.pool=='max': + self.pool=self.max_pool + elif self.pool=='sum': + self.pool=self.sum_pool + + self.softmax='global' + if self.softmax=='gcn': + self.att_gcn=GCNBlock(2,1,0,0,dropout,relu) + self.khop=1 + self.adj_norm='none' + + self.filt_percent=0.25 #default 0.5 + self.eps=1e-10 + + self.tau_config=1 + if 1==-1.: + self.tau=nn.Parameter(torch.tensor(1),requires_grad=False) + elif 1==-2.: + self.tau_fc=nn.Linear(hidden_dim,1) + torch.nn.init.constant_(self.tau_fc.bias,1) + torch.nn.init.xavier_normal_(self.tau_fc.weight.t()) + else: + self.tau=nn.Parameter(torch.tensor(self.tau)) + self.lamda1=nn.Parameter(torch.tensor(self.lamda)) + self.lamda2=nn.Parameter(torch.tensor(self.lamda)) + + self.att_norm=0 + + self.dnorm=0 + self.dnorm_coe=1 + + self.att_out=0 + self.single_att=0 + + + def forward(self,X,adj,mask,is_print=False): + ''' + input: + X: node input features , [batch,node_num,input_dim],dtype=float + adj: adj matrix, [batch,node_num,node_num], dtype=float + mask: mask for nodes, [batch,node_num] + outputs: + out:unormalized classification prob, [batch,hidden_dim] + H: batch of node hidden features, [batch,node_num,pass_dim] + new_adj: pooled new adj matrix, [batch, k_max, k_max] + new_mask: [batch, k_max] + ''' + hidden=X + #adj = adj.float() + # print('input size:') + # print(hidden.shape) + + is_print1=is_print2=is_print + if adj.shape[-1]>100: + is_print1=False + + for gcn in self.gcns: + hidden=gcn(hidden,adj,mask) + # print('gcn:') + # print(hidden.shape) + # print('mask:') + # print(mask.unsqueeze(2).shape) + # print(mask.sum(dim=1)) + + hidden=mask.unsqueeze(2)*hidden + # print(hidden[0][0]) + # print(hidden[0][-1]) + + if self.model=='unet': + att=torch.matmul(hidden,self.w_a).squeeze() + att=att/torch.sqrt((self.w_a.squeeze(2)**2).sum(dim=1,keepdim=True)) + elif self.model=='agcn': + if self.softmax=='global' or self.softmax=='mix': + if False: + dgree_w = torch.sum(adj, dim=2) / torch.sum(adj, dim=2).max(1, keepdim=True)[0] + att_a=torch.matmul(hidden,self.w_a).squeeze()*dgree_w+(mask-1)*1e10 + else: + att_a=torch.matmul(hidden,self.w_a).squeeze()+(mask-1)*1e10 + # print(att_a[0][:10]) + # print(att_a[0][-10:-1]) + att_a_1=att_a=torch.nn.functional.softmax(att_a,dim=1) + # print(att_a[0][:10]) + # print(att_a[0][-10:-1]) + + if self.dnorm: + scale=mask.sum(dim=1,keepdim=True)/self.dnorm_coe + att_a=scale*att_a + if self.softmax=='neibor' or self.softmax=='mix': + att_b=torch.matmul(hidden,self.w_b).squeeze()+(mask-1)*1e10 + att_b_max,_=att_b.max(dim=1,keepdim=True) + if self.tau_config!=-2: + att_b=torch.exp((att_b-att_b_max)*torch.abs(self.tau)) + else: + att_b=torch.exp((att_b-att_b_max)*torch.abs(self.tau_fc(self.pool(hidden,mask)))) + denom=att_b.unsqueeze(2) + for _ in range(self.khop): + denom=torch.matmul(adj,denom) + denom=denom.squeeze()+self.eps + att_b=(att_b*torch.diagonal(adj,0,1,2))/denom + if self.dnorm: + if self.adj_norm=='diag': + diag_scale=mask/(torch.diagonal(adj,0,1,2)+self.eps) + elif self.adj_norm=='none': + diag_scale=adj.sum(dim=1) + att_b=att_b*diag_scale + att_b=att_b*mask + + if self.softmax=='global': + att=att_a + elif self.softmax=='neibor' or self.softmax=='hardnei': + att=att_b + elif self.softmax=='mix': + att=att_a*torch.abs(self.lamda1)+att_b*torch.abs(self.lamda2) + # print('att:') + # print(att.shape) + Z=hidden + + if self.model=='unet': + Z=torch.tanh(att.unsqueeze(2))*Z + elif self.model=='agcn': + if self.single_att: + Z=Z + else: + Z=att.unsqueeze(2)*Z + # print('Z shape') + # print(Z.shape) + k_max=int(math.ceil(self.filt_percent*adj.shape[-1])) + # print('k_max') + # print(k_max) + if self.model=='diffpool': + k_max=min(k_max,self.diffpool_k) + + k_list=[int(math.ceil(self.filt_percent*x)) for x in mask.sum(dim=1).tolist()] + # print('k_list') + # print(k_list) + if self.model!='diffpool': + if self.sort=='sample': + att_samp = att * mask + att_samp = (att_samp/att_samp.sum(1)).detach().cpu().numpy() + top_index = () + for i in range(att.size(0)): + top_index = (torch.LongTensor(np.random.choice(att_samp.size(1), k_max, att_samp[i])) ,) + top_index = torch.stack(top_index,1) + elif self.sort=='random_sample': + top_index = torch.LongTensor(att.size(0), k_max)*0 + for i in range(att.size(0)): + top_index[i,0:k_list[i]] = torch.randperm(int(mask[i].sum().item()))[0:k_list[i]] + else: #sort + _,top_index=torch.topk(att,k_max,dim=1) + # print('top_index') + # print(top_index) + # print(len(top_index[0])) + new_mask=X.new_zeros(X.shape[0],k_max) + # print('new_mask') + # print(new_mask.shape) + visualize_tools=None + if self.model=='unet': + for i,k in enumerate(k_list): + for j in range(int(k),k_max): + top_index[i][j]=adj.shape[-1]-1 + new_mask[i][j]=-1. + new_mask=new_mask+1 + top_index,_=torch.sort(top_index,dim=1) + assign_m=X.new_zeros(X.shape[0],k_max,adj.shape[-1]) + for i,x in enumerate(top_index): + assign_m[i]=torch.index_select(adj[i],0,x) + new_adj=X.new_zeros(X.shape[0],k_max,k_max) + H=Z.new_zeros(Z.shape[0],k_max,Z.shape[-1]) + for i,x in enumerate(top_index): + new_adj[i]=torch.index_select(assign_m[i],1,x) + H[i]=torch.index_select(Z[i],0,x) + + elif self.model=='agcn': + assign_m=X.new_zeros(X.shape[0],k_max,adj.shape[-1]) + # print('assign_m.shape') + # print(assign_m.shape) + for i,k in enumerate(k_list): + #print('top_index[i][j]') + for j in range(int(k)): + #print(str(top_index[i][j].item())+' ', end='') + assign_m[i][j]=adj[i][top_index[i][j]] + #print(assign_m[i][j]) + new_mask[i][j]=1. + + assign_m=assign_m/(assign_m.sum(dim=1,keepdim=True)+self.eps) + H=torch.matmul(assign_m,Z) + # print('H') + # print(H.shape) + new_adj=torch.matmul(torch.matmul(assign_m,adj),torch.transpose(assign_m,1,2)) + # print(torch.matmul(assign_m,adj).shape) + # print('new_adj:') + # print(new_adj.shape) + + elif self.model=='diffpool': + hidden1=X + for gcn in self.pool_gcns: + hidden1=gcn(hidden1,adj,mask) + assign_m=X.new_ones(X.shape[0],X.shape[1],k_max)*(-100000000.) + for i,x in enumerate(hidden1): + k=min(k_list[i],k_max) + assign_m[i,:,0:k]=hidden1[i,:,0:k] + for j in range(int(k)): + new_mask[i][j]=1. + + assign_m=torch.nn.functional.softmax(assign_m,dim=2)*mask.unsqueeze(2) + assign_m_t=torch.transpose(assign_m,1,2) + new_adj=torch.matmul(torch.matmul(assign_m_t,adj),assign_m) + H=torch.matmul(assign_m_t,Z) + # print('pool') + if self.att_out and self.model=='agcn': + if self.softmax=='global': + out=self.pool(att_a_1.unsqueeze(2)*hidden,mask) + elif self.softmax=='neibor': + att_b_sum=att_b.sum(dim=1,keepdim=True) + out=self.pool((att_b/(att_b_sum+self.eps)).unsqueeze(2)*hidden,mask) + else: + # print('hidden.shape') + # print(hidden.shape) + out=self.pool(hidden,mask) + # print('out shape') + # print(out.shape) + + if self.adj_norm=='tanh' or self.adj_norm=='mix': + new_adj=torch.tanh(new_adj) + elif self.adj_norm=='diag' or self.adj_norm=='mix': + diag_elem=torch.pow(new_adj.sum(dim=2)+self.eps,-0.5) + diag=new_adj.new_zeros(new_adj.shape) + for i,x in enumerate(diag_elem): + diag[i]=torch.diagflat(x) + new_adj=torch.matmul(torch.matmul(diag,new_adj),diag) + + visualize_tools=[] + ''' + if (not self.training) and is_print1: + print('**********************************') + print('node_feat:',X.type(),X.shape) + print(X) + if self.model!='diffpool': + print('**********************************') + print('att:',att.type(),att.shape) + print(att) + print('**********************************') + print('top_index:',top_index.type(),top_index.shape) + print(top_index) + print('**********************************') + print('adj:',adj.type(),adj.shape) + print(adj) + print('**********************************') + print('assign_m:',assign_m.type(),assign_m.shape) + print(assign_m) + print('**********************************') + print('new_adj:',new_adj.type(),new_adj.shape) + print(new_adj) + print('**********************************') + print('new_mask:',new_mask.type(),new_mask.shape) + print(new_mask) + ''' + #visualization + from os import path + if not path.exists('att_1.pt'): + torch.save(att[0], 'att_1.pt') + torch.save(top_index[0], 'att_ind1.pt') + elif not path.exists('att_2.pt'): + torch.save(att[0], 'att_2.pt') + torch.save(top_index[0], 'att_ind2.pt') + else: + torch.save(att[0], 'att_3.pt') + torch.save(top_index[0], 'att_ind3.pt') + + if (not self.training) and is_print2: + if self.model!='diffpool': + visualize_tools.append(att[0]) + visualize_tools.append(top_index[0]) + visualize_tools.append(new_adj[0]) + visualize_tools.append(new_mask.sum()) + # print('**********************************') + return out,H,new_adj,new_mask,visualize_tools + + def mean_pool(self,x,mask): + return x.sum(dim=1)/(self.eps+mask.sum(dim=1,keepdim=True)) + + def sum_pool(self,x,mask): + return x.sum(dim=1) + + @staticmethod + def max_pool(x,mask): + #output: [batch,x.shape[2]] + m=(mask-1)*1e10 + r,_=(x+m.unsqueeze(2)).max(dim=1) + return r +# GCN basic operation +class GCNBlock(nn.Module): + def __init__(self, input_dim, output_dim, bn=0,add_self=0, normalize_embedding=0, + dropout=0.0,relu=0, bias=True): + super(GCNBlock,self).__init__() + self.add_self = add_self + self.dropout = dropout + self.relu=relu + self.bn=bn + if dropout > 0.001: + self.dropout_layer = nn.Dropout(p=dropout) + if self.bn: + self.bn_layer = torch.nn.BatchNorm1d(output_dim) + + self.normalize_embedding = normalize_embedding + self.input_dim = input_dim + self.output_dim = output_dim + + self.weight = nn.Parameter(torch.FloatTensor(input_dim, output_dim).to( 'cuda' if torch.cuda.is_available() else 'cpu') ) + torch.nn.init.xavier_normal_(self.weight) + if bias: + self.bias = nn.Parameter(torch.zeros(output_dim).to( 'cuda' if torch.cuda.is_available() else 'cpu') ) + else: + self.bias = None + + def forward(self, x, adj, mask): + y = torch.matmul(adj, x) + if self.add_self: + y += x + y = torch.matmul(y,self.weight) + if self.bias is not None: + y = y + self.bias + if self.normalize_embedding: + y = F.normalize(y, p=2, dim=2) + if self.bn: + index=mask.sum(dim=1).long().tolist() + bn_tensor_bf=mask.new_zeros((sum(index),y.shape[2])) + bn_tensor_af=mask.new_zeros(*y.shape) + start_index=[] + ssum=0 + for i in range(x.shape[0]): + start_index.append(ssum) + ssum+=index[i] + start_index.append(ssum) + for i in range(x.shape[0]): + bn_tensor_bf[start_index[i]:start_index[i+1]]=y[i,0:index[i]] + bn_tensor_bf=self.bn_layer(bn_tensor_bf) + for i in range(x.shape[0]): + bn_tensor_af[i,0:index[i]]=bn_tensor_bf[start_index[i]:start_index[i+1]] + y=bn_tensor_af + if self.dropout > 0.001: + y = self.dropout_layer(y) + if self.relu=='relu': + y=torch.nn.functional.relu(y) + print('hahah') + elif self.relu=='lrelu': + y=torch.nn.functional.leaky_relu(y,0.1) + return y + +#experimental function, untested +class masked_batchnorm(nn.Module): + def __init__(self,feat_dim,epsilon=1e-10): + super().__init__() + self.alpha=nn.Parameter(torch.ones(feat_dim)) + self.beta=nn.Parameter(torch.zeros(feat_dim)) + self.eps=epsilon + + def forward(self,x,mask): + ''' + x: node feat, [batch,node_num,feat_dim] + mask: [batch,node_num] + ''' + mask1 = mask.unsqueeze(2) + mask_sum = mask.sum() + mean = x.sum(dim=(0,1),keepdim=True)/(self.eps+mask_sum) + temp = (x - mean)**2 + temp = temp*mask1 + var = temp.sum(dim=(0,1),keepdim=True)/(self.eps+mask_sum) + rstd = torch.rsqrt(var+self.eps) + x=(x-mean)*rstd + return ((x*self.alpha) + self.beta)*mask1 \ No newline at end of file diff --git a/models/layers.py b/models/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..66703b11cdd30c0bba249dc7a376f3638a14a253 --- /dev/null +++ b/models/layers.py @@ -0,0 +1,280 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +__all__ = ['forward_hook', 'Clone', 'Add', 'Cat', 'ReLU', 'GELU', 'Dropout', 'BatchNorm2d', 'Linear', 'MaxPool2d', + 'AdaptiveAvgPool2d', 'AvgPool2d', 'Conv2d', 'Sequential', 'safe_divide', 'einsum', 'Softmax', 'IndexSelect', + 'LayerNorm', 'AddEye'] + + +def safe_divide(a, b): + den = b.clamp(min=1e-9) + b.clamp(max=1e-9) + den = den + den.eq(0).type(den.type()) * 1e-9 + return a / den * b.ne(0).type(b.type()) + + +def forward_hook(self, input, output): + if type(input[0]) in (list, tuple): + self.X = [] + for i in input[0]: + x = i.detach() + x.requires_grad = True + self.X.append(x) + else: + self.X = input[0].detach() + self.X.requires_grad = True + + self.Y = output + + +def backward_hook(self, grad_input, grad_output): + self.grad_input = grad_input + self.grad_output = grad_output + + +class RelProp(nn.Module): + def __init__(self): + super(RelProp, self).__init__() + # if not self.training: + self.register_forward_hook(forward_hook) + + def gradprop(self, Z, X, S): + C = torch.autograd.grad(Z, X, S, retain_graph=True) + return C + + def relprop(self, R, alpha): + return R + +class RelPropSimple(RelProp): + def relprop(self, R, alpha): + Z = self.forward(self.X) + S = safe_divide(R, Z) + C = self.gradprop(Z, self.X, S) + + if torch.is_tensor(self.X) == False: + outputs = [] + outputs.append(self.X[0] * C[0]) + outputs.append(self.X[1] * C[1]) + else: + outputs = self.X * (C[0]) + return outputs + +class AddEye(RelPropSimple): + # input of shape B, C, seq_len, seq_len + def forward(self, input): + return input + torch.eye(input.shape[2]).expand_as(input).to(input.device) + +class ReLU(nn.ReLU, RelProp): + pass + +class GELU(nn.GELU, RelProp): + pass + +class Softmax(nn.Softmax, RelProp): + pass + +class LayerNorm(nn.LayerNorm, RelProp): + pass + +class Dropout(nn.Dropout, RelProp): + pass + + +class MaxPool2d(nn.MaxPool2d, RelPropSimple): + pass + +class LayerNorm(nn.LayerNorm, RelProp): + pass + +class AdaptiveAvgPool2d(nn.AdaptiveAvgPool2d, RelPropSimple): + pass + + +class AvgPool2d(nn.AvgPool2d, RelPropSimple): + pass + + +class Add(RelPropSimple): + def forward(self, inputs): + return torch.add(*inputs) + + def relprop(self, R, alpha): + Z = self.forward(self.X) + S = safe_divide(R, Z) + C = self.gradprop(Z, self.X, S) + + a = self.X[0] * C[0] + b = self.X[1] * C[1] + + a_sum = a.sum() + b_sum = b.sum() + + a_fact = safe_divide(a_sum.abs(), a_sum.abs() + b_sum.abs()) * R.sum() + b_fact = safe_divide(b_sum.abs(), a_sum.abs() + b_sum.abs()) * R.sum() + + a = a * safe_divide(a_fact, a.sum()) + b = b * safe_divide(b_fact, b.sum()) + + outputs = [a, b] + + return outputs + +class einsum(RelPropSimple): + def __init__(self, equation): + super().__init__() + self.equation = equation + def forward(self, *operands): + return torch.einsum(self.equation, *operands) + +class IndexSelect(RelProp): + def forward(self, inputs, dim, indices): + self.__setattr__('dim', dim) + self.__setattr__('indices', indices) + + return torch.index_select(inputs, dim, indices) + + def relprop(self, R, alpha): + Z = self.forward(self.X, self.dim, self.indices) + S = safe_divide(R, Z) + C = self.gradprop(Z, self.X, S) + + if torch.is_tensor(self.X) == False: + outputs = [] + outputs.append(self.X[0] * C[0]) + outputs.append(self.X[1] * C[1]) + else: + outputs = self.X * (C[0]) + return outputs + + + +class Clone(RelProp): + def forward(self, input, num): + self.__setattr__('num', num) + outputs = [] + for _ in range(num): + outputs.append(input) + + return outputs + + def relprop(self, R, alpha): + Z = [] + for _ in range(self.num): + Z.append(self.X) + S = [safe_divide(r, z) for r, z in zip(R, Z)] + C = self.gradprop(Z, self.X, S)[0] + + R = self.X * C + + return R + +class Cat(RelProp): + def forward(self, inputs, dim): + self.__setattr__('dim', dim) + return torch.cat(inputs, dim) + + def relprop(self, R, alpha): + Z = self.forward(self.X, self.dim) + S = safe_divide(R, Z) + C = self.gradprop(Z, self.X, S) + + outputs = [] + for x, c in zip(self.X, C): + outputs.append(x * c) + + return outputs + + +class Sequential(nn.Sequential): + def relprop(self, R, alpha): + for m in reversed(self._modules.values()): + R = m.relprop(R, alpha) + return R + +class BatchNorm2d(nn.BatchNorm2d, RelProp): + def relprop(self, R, alpha): + X = self.X + beta = 1 - alpha + weight = self.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3) / ( + (self.running_var.unsqueeze(0).unsqueeze(2).unsqueeze(3).pow(2) + self.eps).pow(0.5)) + Z = X * weight + 1e-9 + S = R / Z + Ca = S * weight + R = self.X * (Ca) + return R + + +class Linear(nn.Linear, RelProp): + def relprop(self, R, alpha): + beta = alpha - 1 + pw = torch.clamp(self.weight, min=0) + nw = torch.clamp(self.weight, max=0) + px = torch.clamp(self.X, min=0) + nx = torch.clamp(self.X, max=0) + + def f(w1, w2, x1, x2): + Z1 = F.linear(x1, w1) + Z2 = F.linear(x2, w2) + S1 = safe_divide(R, Z1 + Z2) + S2 = safe_divide(R, Z1 + Z2) + C1 = x1 * torch.autograd.grad(Z1, x1, S1)[0] + C2 = x2 * torch.autograd.grad(Z2, x2, S2)[0] + + return C1 + C2 + + activator_relevances = f(pw, nw, px, nx) + inhibitor_relevances = f(nw, pw, px, nx) + + R = alpha * activator_relevances - beta * inhibitor_relevances + + return R + + +class Conv2d(nn.Conv2d, RelProp): + def gradprop2(self, DY, weight): + Z = self.forward(self.X) + + output_padding = self.X.size()[2] - ( + (Z.size()[2] - 1) * self.stride[0] - 2 * self.padding[0] + self.kernel_size[0]) + + return F.conv_transpose2d(DY, weight, stride=self.stride, padding=self.padding, output_padding=output_padding) + + def relprop(self, R, alpha): + if self.X.shape[1] == 3: + pw = torch.clamp(self.weight, min=0) + nw = torch.clamp(self.weight, max=0) + X = self.X + L = self.X * 0 + \ + torch.min(torch.min(torch.min(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3, + keepdim=True)[0] + H = self.X * 0 + \ + torch.max(torch.max(torch.max(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3, + keepdim=True)[0] + Za = torch.conv2d(X, self.weight, bias=None, stride=self.stride, padding=self.padding) - \ + torch.conv2d(L, pw, bias=None, stride=self.stride, padding=self.padding) - \ + torch.conv2d(H, nw, bias=None, stride=self.stride, padding=self.padding) + 1e-9 + + S = R / Za + C = X * self.gradprop2(S, self.weight) - L * self.gradprop2(S, pw) - H * self.gradprop2(S, nw) + R = C + else: + beta = alpha - 1 + pw = torch.clamp(self.weight, min=0) + nw = torch.clamp(self.weight, max=0) + px = torch.clamp(self.X, min=0) + nx = torch.clamp(self.X, max=0) + + def f(w1, w2, x1, x2): + Z1 = F.conv2d(x1, w1, bias=None, stride=self.stride, padding=self.padding) + Z2 = F.conv2d(x2, w2, bias=None, stride=self.stride, padding=self.padding) + S1 = safe_divide(R, Z1) + S2 = safe_divide(R, Z2) + C1 = x1 * self.gradprop(Z1, x1, S1)[0] + C2 = x2 * self.gradprop(Z2, x2, S2)[0] + return C1 + C2 + + activator_relevances = f(pw, nw, px, nx) + inhibitor_relevances = f(nw, pw, px, nx) + + R = alpha * activator_relevances - beta * inhibitor_relevances + return R \ No newline at end of file diff --git a/models/weight_init.py b/models/weight_init.py new file mode 100644 index 0000000000000000000000000000000000000000..aa71c04254105ef5ba0c89bb730270328fc49bb1 --- /dev/null +++ b/models/weight_init.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python +# -*- coding:UTF-8 -*- + +import torch +import torch.nn as nn +import torch.nn.init as init + + +def weight_init(m): + ''' + Usage: + model = Model() + model.apply(weight_init) + ''' + if isinstance(m, nn.Conv1d): + init.normal_(m.weight.data) + if m.bias is not None: + init.normal_(m.bias.data) + elif isinstance(m, nn.Conv2d): + init.xavier_normal_(m.weight.data) + if m.bias is not None: + init.normal_(m.bias.data) + elif isinstance(m, nn.Conv3d): + init.xavier_normal_(m.weight.data) + if m.bias is not None: + init.normal_(m.bias.data) + elif isinstance(m, nn.ConvTranspose1d): + init.normal_(m.weight.data) + if m.bias is not None: + init.normal_(m.bias.data) + elif isinstance(m, nn.ConvTranspose2d): + init.xavier_normal_(m.weight.data) + if m.bias is not None: + init.normal_(m.bias.data) + elif isinstance(m, nn.ConvTranspose3d): + init.xavier_normal_(m.weight.data) + if m.bias is not None: + init.normal_(m.bias.data) + elif isinstance(m, nn.BatchNorm1d): + init.normal_(m.weight.data, mean=1, std=0.02) + init.constant_(m.bias.data, 0) + elif isinstance(m, nn.BatchNorm2d): + init.normal_(m.weight.data, mean=1, std=0.02) + init.constant_(m.bias.data, 0) + elif isinstance(m, nn.BatchNorm3d): + init.normal_(m.weight.data, mean=1, std=0.02) + init.constant_(m.bias.data, 0) + elif isinstance(m, nn.Linear): + init.xavier_normal_(m.weight.data) + init.normal_(m.bias.data) + elif isinstance(m, nn.LSTM): + for param in m.parameters(): + if len(param.shape) >= 2: + init.orthogonal_(param.data) + else: + init.normal_(param.data) + elif isinstance(m, nn.LSTMCell): + for param in m.parameters(): + if len(param.shape) >= 2: + init.orthogonal_(param.data) + else: + init.normal_(param.data) + elif isinstance(m, nn.GRU): + for param in m.parameters(): + if len(param.shape) >= 2: + init.orthogonal_(param.data) + else: + init.normal_(param.data) + elif isinstance(m, nn.GRUCell): + for param in m.parameters(): + if len(param.shape) >= 2: + init.orthogonal_(param.data) + else: + init.normal_(param.data) + + +if __name__ == '__main__': + pass \ No newline at end of file diff --git a/option.py b/option.py new file mode 100644 index 0000000000000000000000000000000000000000..35af9d6c88465d28aaff5e5472c439b554b3c728 --- /dev/null +++ b/option.py @@ -0,0 +1,41 @@ +########################################################################### +# Created by: YI ZHENG +# Email: yizheng@bu.edu +# Copyright (c) 2020 +########################################################################### + +import os +import argparse +import torch + +class Options(): + def __init__(self): + parser = argparse.ArgumentParser(description='PyTorch Classification') + parser.add_argument('--data_path', type=str, help='path to dataset where images store') + parser.add_argument('--train_set', type=str, help='train') + parser.add_argument('--val_set', type=str, help='validation') + parser.add_argument('--model_path', type=str, help='path to trained model') + parser.add_argument('--log_path', type=str, help='path to log files') + parser.add_argument('--task_name', type=str, help='task name for naming saved model files and log files') + parser.add_argument('--train', action='store_true', default=False, help='train only') + parser.add_argument('--test', action='store_true', default=False, help='test only') + parser.add_argument('--batch_size', type=int, default=6, help='batch size for origin global image (without downsampling)') + parser.add_argument('--log_interval_local', type=int, default=10, help='classification classes') + parser.add_argument('--resume', type=str, default="", help='path for model') + parser.add_argument('--graphcam', action='store_true', default=False, help='GraphCAM') + parser.add_argument('--dataset_metadata_path', type=str, help='Location of the metadata associated with the created dataset: label mapping, splits and so on') + + + # the parser + self.parser = parser + + def parse(self): + args = self.parser.parse_args() + # default settings for epochs and lr + + args.num_epochs = 120 + args.lr = 1e-3 + + if args.test: + args.num_epochs = 1 + return args diff --git a/packages.txt b/packages.txt new file mode 100644 index 0000000000000000000000000000000000000000..717b1fa2b0f030c5b959f0ad5ae718867f2b564f --- /dev/null +++ b/packages.txt @@ -0,0 +1,3 @@ +openslide-tools +python3-openslide +python3-opencv \ No newline at end of file diff --git a/predict.py b/predict.py new file mode 100644 index 0000000000000000000000000000000000000000..a16eea20541f1975db22befd5c8490786768b7af --- /dev/null +++ b/predict.py @@ -0,0 +1,305 @@ + +from __future__ import absolute_import, division, print_function + +import os +import numpy as np +import torch +import torch.nn as nn +from torchvision import transforms + +import torchvision.models as models +from feature_extractor import cl +from models.GraphTransformer import Classifier +from models.weight_init import weight_init +from feature_extractor.build_graph_utils import ToTensor, Compose, bag_dataset, adj_matrix +import torchvision.transforms.functional as VF +from src.vis_graphcam import show_cam_on_image,cam_to_mask +from easydict import EasyDict as edict +from models.GraphTransformer import Classifier +from slide_tiling import save_tiles +import pickle +from collections import OrderedDict +import glob +import openslide +import numpy as np +import skimage.transform +import cv2 + + +class Predictor: + + def __init__(self): + self.classdict = pickle.load(open(os.environ['CLASS_METADATA'], 'rb' )) + self.label_map_inv = dict() + for label_name, label_id in self.classdict.items(): + self.label_map_inv[label_id] = label_name + + iclf_weights = os.environ['FEATURE_EXTRACTOR_WEIGHT_PATH'] + graph_transformer_weights = os.environ['GT_WEIGHT_PATH'] + self.device = 'cuda' if torch.cuda.is_available() else 'cpu' + + self.__init_iclf(iclf_weights, backbone='resnet18') + self.__init_graph_transformer(graph_transformer_weights) + + def predict(self, slide_path): + + # get tiles for a given WSI slide + save_tiles(slide_path) + + filename = os.path.basename(slide_path) + FILEID = filename.rsplit('.', maxsplit=1)[0] + patches_glob_path = os.path.join(os.environ['PATCHES_DIR'], f'{FILEID}_files', '*', '*.jpeg') + patches_paths = glob.glob(patches_glob_path) + + sample = self.iclf_predict(patches_paths) + + + torch.set_grad_enabled(True) + node_feat, adjs, masks = Predictor.preparefeatureLabel(sample['image'], sample['adj_s'], self.device) + pred,labels,loss,graphcam_tensors = self.model.forward(node_feat=node_feat, labels=None, adj=adjs, mask=masks, graphcam_flag=True, to_file=False) + + patches_coords = sample['c_idx'][0] + viz_dict = self.get_graphcams(graphcam_tensors, patches_coords, slide_path, FILEID) + return self.label_map_inv[pred.item()], viz_dict + + def iclf_predict(self, patches_paths): + feats_list = [] + + batch_size = 128 + num_workers = 0 + args = edict({'batch_size':batch_size, 'num_workers':num_workers} ) + dataloader, bag_size = bag_dataset(args, patches_paths) + + with torch.no_grad(): + for iteration, batch in enumerate(dataloader): + patches = batch['input'].float().to(self.device) + feats, classes = self.i_classifier(patches) + #feats = feats.cpu().numpy() + feats_list.extend(feats) + output = torch.stack(feats_list, dim=0).to(self.device) + # save adjacent matrix + adj_s = adj_matrix(patches_paths, output) + + + patch_infos = [] + for path in patches_paths: + x, y = path.split('/')[-1].split('.')[0].split('_') + patch_infos.append((x,y)) + + preds = {'image': [output], + 'adj_s': [adj_s], + 'c_idx': [patch_infos]} + return preds + + + + def get_graphcams(self, graphcam_tensors, patches_coords, slide_path, FILEID): + label_map = self.classdict + label_name_from_id = self.label_map_inv + + n_class = len(label_map) + + p = graphcam_tensors['prob'].cpu().detach().numpy()[0] + ori = openslide.OpenSlide(slide_path) + width, height = ori.dimensions + + REDUCTION_FACTOR = 20 + w, h = int(width/512), int(height/512) + w_r, h_r = int(width/20), int(height/20) + resized_img = ori.get_thumbnail((width,height))#ori.get_thumbnail((w_r,h_r)) + resized_img = resized_img.resize((w_r,h_r)) + ratio_w, ratio_h = width/resized_img.width, height/resized_img.height + #print('ratios ', ratio_w, ratio_h) + w_s, h_s = float(512/REDUCTION_FACTOR), float(512/REDUCTION_FACTOR) + + patches = [] + xmax, ymax = 0, 0 + for patch_coords in patches_coords: + x, y = patch_coords + if xmax < int(x): xmax = int(x) + if ymax < int(y): ymax = int(y) + patches.append('{}_{}.jpeg'.format(x,y)) + + + + output_img = np.asarray(resized_img)[:,:,::-1].copy() + #-----------------------------------------------------------------------------------------------------# + # GraphCAM + #print('visulize GraphCAM') + assign_matrix = graphcam_tensors['s_matrix_ori'] + m = nn.Softmax(dim=1) + assign_matrix = m(assign_matrix) + + # Thresholding for better visualization + p = np.clip(p, 0.4, 1) + + + + output_img_copy =np.copy(output_img) + gray = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY) + image_transformer_attribution = (output_img_copy - output_img_copy.min()) / (output_img_copy.max() - output_img_copy.min()) + cam_matrices = [] + masks = [] + visualizations = [] + viz_dict = dict() + + SAMPLE_VIZ_DIR = os.path.join(os.environ['GRAPHCAM_DIR'], + FILEID) + os.makedirs(SAMPLE_VIZ_DIR, exist_ok=True) + + for class_i in range(n_class): + + # Load graphcam for each class + cam_matrix = graphcam_tensors[f'cam_{class_i}'] + cam_matrix = torch.mm(assign_matrix, cam_matrix.transpose(1,0)) + cam_matrix = cam_matrix.cpu() + + # Normalize the graphcam + cam_matrix = (cam_matrix - cam_matrix.min()) / (cam_matrix.max() - cam_matrix.min()) + cam_matrix = cam_matrix.detach().numpy() + cam_matrix = p[class_i] * cam_matrix + cam_matrix = np.clip(cam_matrix, 0, 1) + + + mask = cam_to_mask(gray, patches, cam_matrix, w, h, w_s, h_s) + + vis = show_cam_on_image(image_transformer_attribution, mask) + vis = np.uint8(255 * vis) + + cam_matrices.append(cam_matrix) + masks.append(mask) + visualizations.append(vis) + viz_dict['{}'.format(label_name_from_id[class_i]) ] = vis + cv2.imwrite(os.path.join( + SAMPLE_VIZ_DIR, + '{}_all_types_cam_{}.png'.format(FILEID, label_name_from_id[class_i] ) + ), vis) + h, w, _ = output_img.shape + if h > w: + vis_merge = cv2.hconcat([output_img] + visualizations) + else: + vis_merge = cv2.vconcat([output_img] + visualizations) + + + cv2.imwrite(os.path.join( + SAMPLE_VIZ_DIR, + '{}_all_types_cam_all.png'.format(FILEID)), + vis_merge) + viz_dict['ALL'] = vis_merge + cv2.imwrite(os.path.join( + SAMPLE_VIZ_DIR, + '{}_all_types_ori.png'.format(FILEID ) + ), + output_img) + viz_dict['ORI'] = output_img + return viz_dict + + + + + + + def preparefeatureLabel(batch_graph, batch_adjs, device='cpu'): + batch_size = len(batch_graph) + max_node_num = 0 + + for i in range(batch_size): + max_node_num = max(max_node_num, batch_graph[i].shape[0]) + + masks = torch.zeros(batch_size, max_node_num) + adjs = torch.zeros(batch_size, max_node_num, max_node_num) + batch_node_feat = torch.zeros(batch_size, max_node_num, 512) + + for i in range(batch_size): + cur_node_num = batch_graph[i].shape[0] + #node attribute feature + tmp_node_fea = batch_graph[i] + batch_node_feat[i, 0:cur_node_num] = tmp_node_fea + + #adjs + adjs[i, 0:cur_node_num, 0:cur_node_num] = batch_adjs[i] + + #masks + masks[i,0:cur_node_num] = 1 + + node_feat = batch_node_feat.to() + adjs = adjs.to(device) + masks = masks.to(device) + + return node_feat, adjs, masks + + def __init_graph_transformer(self, graph_transformer_weights): + n_class = len(self.classdict) + model = Classifier(n_class) + model = nn.DataParallel(model) + model.load_state_dict(torch.load(graph_transformer_weights, + map_location=torch.device( 'cuda' if torch.cuda.is_available() else 'cpu' ) )) + if torch.cuda.is_available(): + model = model.cuda() + self.model = model + + + def __init_iclf(self, iclf_weights, backbone='resnet18'): + if backbone == 'resnet18': + resnet = models.resnet18(pretrained=False, norm_layer=nn.InstanceNorm2d) + num_feats = 512 + if backbone == 'resnet34': + resnet = models.resnet34(pretrained=False, norm_layer=nn.InstanceNorm2d) + num_feats = 512 + if backbone == 'resnet50': + resnet = models.resnet50(pretrained=False, norm_layer=nn.InstanceNorm2d) + num_feats = 2048 + if backbone == 'resnet101': + resnet = models.resnet101(pretrained=False, norm_layer=nn.InstanceNorm2d) + num_feats = 2048 + for param in resnet.parameters(): + param.requires_grad = False + resnet.fc = nn.Identity() + i_classifier = cl.IClassifier(resnet, num_feats, output_class=2).to(self.device) + + # load feature extractor + + state_dict_weights = torch.load(iclf_weights, map_location=torch.device( 'cuda' if torch.cuda.is_available() else 'cpu' )) + state_dict_init = i_classifier.state_dict() + new_state_dict = OrderedDict() + for (k, v), (k_0, v_0) in zip(state_dict_weights.items(), state_dict_init.items()): + if 'features' not in k: + continue + name = k_0 + new_state_dict[name] = v + i_classifier.load_state_dict(new_state_dict, strict=False) + + self.i_classifier = i_classifier + + + + + +#0 load metadata dicitonary for class names +#1 TILE THE IMAGE +#2 FEED IT TO FEATURE EXTRACTOR +#3 PRODUCE GRAPH +#4 predict graphcams +import subprocess +import argparse +import os +import shutil + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='PyTorch Classification') + parser.add_argument('--slide_path', type=str, help='path to the WSI slide') + args = parser.parse_args() + predictor = Predictor() + + predicted_class, viz_dict = predictor.predict(args.slide_path) + print('Class prediction is: ', predicted_class) + + + + + + + + + diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..42faad6293158915bc700a5366b49d2ae707a2d3 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,25 @@ +streamlit + +#-f https://download.pytorch.org/whl/cpu/torch_stable.html +#-f https://data.pyg.org/whl/torch-1.7.1+cpu.html +#torch==1.7.1+cpu +-f https://download.pytorch.org/whl/torch_stable.html +-f https://data.pyg.org/whl/torch-1.10.0+cu113.html +torch==1.10.0+cu113 +torchvision +#torch-scatter +#torch-sparse + +einops +streamlit-option-menu +numpy +pandas +scikit-image +opencv-python +PyYAML +tqdm +scipy +imageio +easydict +openslide-python +pydicom \ No newline at end of file diff --git a/set_env.sh b/set_env.sh new file mode 100755 index 0000000000000000000000000000000000000000..06021e440e5f23a62c786e5b3f85143f9ab345f9 --- /dev/null +++ b/set_env.sh @@ -0,0 +1,20 @@ + +# environment variables for model training + + + +# environment variables for the inference api +export DATA_DIR=queries +export PATCHES_DIR=${DATA_DIR}/patches +export SLIDES_DIR=${DATA_DIR}/slides +export GRAPHCAM_DIR=${DATA_DIR}/graphcam_plots +mkdir $GRAPHCAM_DIR -p + + +# manually put the metadata in the metadata folder +export CLASS_METADATA='metadata/label_map.pkl' + +# manually put the desired weights in the weights folder +export WEIGHTS_PATH='weights' +export FEATURE_EXTRACTOR_WEIGHT_PATH=${WEIGHTS_PATH}/feature_extractor/model.pth +export GT_WEIGHT_PATH=${WEIGHTS_PATH}/graph_transformer/GraphCAM.pth \ No newline at end of file diff --git a/slide_tiling.py b/slide_tiling.py new file mode 100644 index 0000000000000000000000000000000000000000..8736ce0d43bab3cc3825ae52747cd7fbe433a15f --- /dev/null +++ b/slide_tiling.py @@ -0,0 +1,41 @@ +import subprocess +import argparse +import os +import shutil + + +def save_tiles(slide_path): + + filename = os.path.basename(slide_path) + FILEID = filename.rsplit('.', maxsplit=1)[0] + PATCHES_DIR = os.environ['PATCHES_DIR'] + SLIDES_DIR = os.environ['SLIDES_DIR'] + os.makedirs(PATCHES_DIR, exist_ok=True) + os.makedirs(SLIDES_DIR, exist_ok=True) + shutil.copy(slide_path, SLIDES_DIR) + + INPUT_PATH = os.path.join(SLIDES_DIR, filename) + CMD = ['python3', 'src/tile_WSI.py', '-s', '512', '-e', '0', '-j', '16', '-B', '50', '-M', '20', '-o', PATCHES_DIR, INPUT_PATH] + subprocess.call(CMD) + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='PyTorch Classification') + parser.add_argument('--slide_path', type=str, help='path to the WSI slide') + args = parser.parse_args() + + + filename = os.path.basename(args.slide_path) + FILEID = filename.rsplit('.', maxsplit=1)[0] + PATCHES_DIR = os.environ['PATCHES_DIR'] + SLIDES_DIR = os.environ['SLIDES_DIR'] + os.makedirs(PATCHES_DIR, exist_ok=True) + os.makedirs(SLIDES_DIR, exist_ok=True) + shutil.move(args.slide_path, SLIDES_DIR) + + INPUT_PATH = os.path.join(SLIDES_DIR, filename) + + + CMD = ['python3', 'src/tile_WSI.py', '-s', '512', '-e', '0', '-j', '16', '-B', '50', '-M', '20', '-o', PATCHES_DIR, INPUT_PATH] + + subprocess.call(CMD) + diff --git a/src/__pycache__/vis_graphcam.cpython-38.pyc b/src/__pycache__/vis_graphcam.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1c8c10ff56515f819d2744de9d24a3d11a0ce64 Binary files /dev/null and b/src/__pycache__/vis_graphcam.cpython-38.pyc differ diff --git a/src/tile_WSI.py b/src/tile_WSI.py new file mode 100644 index 0000000000000000000000000000000000000000..47433c954a4252b4749f89cec7e5284963596a8b --- /dev/null +++ b/src/tile_WSI.py @@ -0,0 +1,980 @@ +''' + File name: tile_WSI.py + Date created: March/2021 + Source: + Tiling code inspired from + https://github.com/openslide/openslide-python/blob/master/examples/deepzoom/deepzoom_tile.py + + The code has been extensively modified + Objective: + Tile svs, jpg or dcm images with the possibility of rejecting some tiles based based on xml or jpg masks + Be careful: + Overload of the node - may have memory issue if node is shared with other jobs. +''' + +from __future__ import print_function +import json +import openslide +from openslide import open_slide, ImageSlide +from openslide.deepzoom import DeepZoomGenerator +from optparse import OptionParser +import re +import shutil +from unicodedata import normalize +import numpy as np +import scipy.misc +import subprocess +from glob import glob +from multiprocessing import Process, JoinableQueue +import time +import os +import sys +try: + import pydicom as dicom +except ImportError: + import dicom +# from scipy.misc import imsave +from imageio import imwrite as imsave +# from scipy.misc import imread +from imageio import imread +# from scipy.misc import imresize + +from xml.dom import minidom +from PIL import Image, ImageDraw, ImageCms +from skimage import color, io +from tqdm import tqdm +Image.MAX_IMAGE_PIXELS = None + + +VIEWER_SLIDE_NAME = 'slide' + + +class TileWorker(Process): + """A child process that generates and writes tiles.""" + + def __init__(self, queue, slidepath, tile_size, overlap, limit_bounds,quality, _Bkg, _ROIpc): + Process.__init__(self, name='TileWorker') + self.daemon = True + self._queue = queue + self._slidepath = slidepath + self._tile_size = tile_size + self._overlap = overlap + self._limit_bounds = limit_bounds + self._quality = quality + self._slide = None + self._Bkg = _Bkg + self._ROIpc = _ROIpc + + def RGB_to_lab(self, tile): + # srgb_p = ImageCms.createProfile("sRGB") + # lab_p = ImageCms.createProfile("LAB") + # rgb2lab = ImageCms.buildTransformFromOpenProfiles(srgb_p, lab_p, "RGB", "LAB") + # Lab = ImageCms.applyTransform(tile, rgb2lab) + # Lab = np.array(Lab) + # Lab = Lab.astype('float') + # Lab[:,:,0] = Lab[:,:,0] / 2.55 + # Lab[:,:,1] = Lab[:,:,1] - 128 + # Lab[:,:,2] = Lab[:,:,2] - 128 + print("RGB to Lab") + Lab = color.rgb2lab(tile) + return Lab + + def Lab_to_RGB(self,Lab): + # srgb_p = ImageCms.createProfile("sRGB") + # lab_p = ImageCms.createProfile("LAB") + # lab2rgb = ImageCms.buildTransformFromOpenProfiles(srgb_p, lab_p, "LAB", "RGB") + # Lab[:,:,0] = Lab[:,:,0] * 2.55 + # Lab[:,:,1] = Lab[:,:,1] + 128 + # Lab[:,:,2] = Lab[:,:,2] + 128 + # newtile = ImageCms.applyTransform(Lab, lab2rgb) + print("Lab to RGB") + newtile = (color.lab2rgb(Lab) * 255).astype(np.uint8) + return newtile + + + def normalize_tile(self, tile, NormVec): + Lab = self.RGB_to_lab(tile) + TileMean = [0,0,0] + TileStd = [1,1,1] + newMean = NormVec[0:3] + newStd = NormVec[3:6] + for i in range(3): + TileMean[i] = np.mean(Lab[:,:,i]) + TileStd[i] = np.std(Lab[:,:,i]) + # print("mean/std chanel " + str(i) + ": " + str(TileMean[i]) + " / " + str(TileStd[i])) + tmp = ((Lab[:,:,i] - TileMean[i]) * (newStd[i] / TileStd[i])) + newMean[i] + if i == 0: + tmp[tmp<0] = 0 + tmp[tmp>100] = 100 + Lab[:,:,i] = tmp + else: + tmp[tmp<-128] = 128 + tmp[tmp>127] = 127 + Lab[:,:,i] = tmp + tile = self.Lab_to_RGB(Lab) + return tile + + def run(self): + self._slide = open_slide(self._slidepath) + last_associated = None + dz = self._get_dz() + while True: + data = self._queue.get() + if data is None: + self._queue.task_done() + break + #associated, level, address, outfile = data + associated, level, address, outfile, format, outfile_bw, PercentMasked, SaveMasks, TileMask, Normalize = data + if last_associated != associated: + dz = self._get_dz(associated) + last_associated = associated + #try: + if True: + try: + tile = dz.get_tile(level, address) + # A single tile is being read + #check the percentage of the image with "information". Should be above 50% + gray = tile.convert('L') + bw = gray.point(lambda x: 0 if x<220 else 1, 'F') + arr = np.array(np.asarray(bw)) + avgBkg = np.average(bw) + bw = gray.point(lambda x: 0 if x<220 else 1, '1') + # check if the image is mostly background + + #print("res: " + outfile + " is " + str(avgBkg)) + if avgBkg <= (self._Bkg / 100.0): + # print("PercentMasked: %.6f, %.6f" % (PercentMasked, self._ROIpc / 100.0) ) + # if an Aperio selection was made, check if is within the selected region + if PercentMasked >= (self._ROIpc / 100.0): + + if Normalize != '': + print("normalize " + str(outfile)) + # arrtile = np.array(tile) + tile = Image.fromarray(self.normalize_tile(tile, Normalize).astype('uint8'),'RGB') + + tile.save(outfile, quality=self._quality) + if bool(SaveMasks)==True: + height = TileMask.shape[0] + width = TileMask.shape[1] + TileMaskO = np.zeros((height,width,3), 'uint8') + maxVal = float(TileMask.max()) + TileMaskO[...,0] = (TileMask[:,:].astype(float) / maxVal * 255.0).astype(int) + TileMaskO[...,1] = (TileMask[:,:].astype(float) / maxVal * 255.0).astype(int) + TileMaskO[...,2] = (TileMask[:,:].astype(float) / maxVal * 255.0).astype(int) + TileMaskO = numpy.array(Image.fromarray(TileMaskO).resize(arr.shape[0], arr.shape[1],3)) + # TileMaskO = imresize(TileMaskO, (arr.shape[0], arr.shape[1],3)) + TileMaskO[TileMaskO<10] = 0 + TileMaskO[TileMaskO>=10] = 255 + imsave(outfile_bw,TileMaskO) #(outfile_bw, quality=self._quality) + + #print("%s good: %f" %(outfile, avgBkg)) + #elif level>5: + # tile.save(outfile, quality=self._quality) + #print("%s empty: %f" %(outfile, avgBkg)) + self._queue.task_done() + except Exception as e: + # print(level, address) + print("image %s failed at dz.get_tile for level %f" % (self._slidepath, level)) + # e = sys.exc_info()[0] + print(e) + self._queue.task_done() + + def _get_dz(self, associated=None): + if associated is not None: + image = ImageSlide(self._slide.associated_images[associated]) + else: + image = self._slide + return DeepZoomGenerator(image, self._tile_size, self._overlap, limit_bounds=self._limit_bounds) + + +class DeepZoomImageTiler(object): + """Handles generation of tiles and metadata for a single image.""" + + def __init__(self, dz, basename, format, associated, queue, slide, basenameJPG, xmlfile, mask_type, xmlLabel, ROIpc, ImgExtension, SaveMasks, Mag, normalize, Fieldxml): + self._dz = dz + self._basename = basename + self._basenameJPG = basenameJPG + self._format = format + self._associated = associated + self._queue = queue + self._processed = 0 + self._slide = slide + self._xmlfile = xmlfile + self._mask_type = mask_type + self._xmlLabel = xmlLabel + self._ROIpc = ROIpc + self._ImgExtension = ImgExtension + self._SaveMasks = SaveMasks + self._Mag = Mag + self._normalize = normalize + self._Fieldxml = Fieldxml + + def run(self): + self._write_tiles() + self._write_dzi() + + def _write_tiles(self): + ########################################3 + # nc_added + #level = self._dz.level_count-1 + Magnification = 20 + tol = 2 + #get slide dimensions, zoom levels, and objective information + Factors = self._slide.level_downsamples + try: + Objective = float(self._slide.properties[openslide.PROPERTY_NAME_OBJECTIVE_POWER]) + # print(self._basename + " - Obj information found") + except: + print(self._basename + " - No Obj information found") + print(self._ImgExtension) + if ("jpg" in self._ImgExtension) | ("dcm" in self._ImgExtension) | ("tif" in self._ImgExtension): + #Objective = self._ROIpc + Objective = 1. + Magnification = Objective + print("input is jpg - will be tiled as such with %f" % Objective) + else: + return + #calculate magnifications + Available = tuple(Objective / x for x in Factors) + #find highest magnification greater than or equal to 'Desired' + Mismatch = tuple(x-Magnification for x in Available) + AbsMismatch = tuple(abs(x) for x in Mismatch) + if len(AbsMismatch) < 1: + print(self._basename + " - Objective field empty!") + return + ''' + if(min(AbsMismatch) <= tol): + Level = int(AbsMismatch.index(min(AbsMismatch))) + Factor = 1 + else: #pick next highest level, downsample + Level = int(max([i for (i, val) in enumerate(Mismatch) if val > 0])) + Factor = Magnification / Available[Level] + # end added + ''' + xml_valid = False + # a dir was provided for xml files + + ''' + ImgID = os.path.basename(self._basename) + Nbr_of_masks = 0 + if self._xmlfile != '': + xmldir = os.path.join(self._xmlfile, ImgID + '.xml') + print("xml:") + print(xmldir) + if os.path.isfile(xmldir): + xml_labels, xml_valid = self.xml_read_labels(xmldir) + Nbr_of_masks = len(xml_labels) + else: + print("No xml file found for slide %s.svs (expected: %s). Directory or xml file does not exist" % (ImgID, xmldir) ) + return + else: + Nbr_of_masks = 1 + ''' + + if True: + #if self._xmlfile != '' && : + # print(self._xmlfile, self._ImgExtension) + ImgID = os.path.basename(self._basename) + xmldir = os.path.join(self._xmlfile, ImgID + '.xml') + # print("xml:") + # print(xmldir) + if (self._xmlfile != '') & (self._ImgExtension != 'jpg') & (self._ImgExtension != 'dcm'): + # print("read xml file...") + mask, xml_valid, Img_Fact = self.xml_read(xmldir, self._xmlLabel, self._Fieldxml) + if xml_valid == False: + print("Error: xml %s file cannot be read properly - please check format" % xmldir) + return + elif (self._xmlfile != '') & (self._ImgExtension == 'dcm'): + # print("check mask for dcm") + mask, xml_valid, Img_Fact = self.jpg_mask_read(xmldir) + # mask <-- read mask + # Img_Fact <-- 1 + # xml_valid <-- True if mask file exists. + if xml_valid == False: + print("Error: xml %s file cannot be read properly - please check format" % xmldir) + return + + # print("current directory: %s" % self._basename) + + #return + #print(self._dz.level_count) + + for level in range(self._dz.level_count-1,-1,-1): + ThisMag = Available[0]/pow(2,self._dz.level_count-(level+1)) + if self._Mag > 0: + if ThisMag != self._Mag: + continue + ######################################## + #tiledir = os.path.join("%s_files" % self._basename, str(level)) + tiledir = os.path.join("%s_files" % self._basename, str(ThisMag)) + if not os.path.exists(tiledir): + os.makedirs(tiledir) + cols, rows = self._dz.level_tiles[level] + if xml_valid: + # print("xml valid") + '''# If xml file is used, check for each tile what are their corresponding coordinate in the base image + IndX_orig, IndY_orig = self._dz.level_tiles[-1] + CurrentLevel_ReductionFactor = (Img_Fact * float(self._dz.level_dimensions[-1][0]) / float(self._dz.level_dimensions[level][0])) + startIndX_current_level_conv = [int(i * CurrentLevel_ReductionFactor) for i in range(cols)] + print("***********") + endIndX_current_level_conv = [int(i * CurrentLevel_ReductionFactor) for i in range(cols)] + endIndX_current_level_conv.append(self._dz.level_dimensions[level][0]) + endIndX_current_level_conv.pop(0) + + startIndY_current_level_conv = [int(i * CurrentLevel_ReductionFactor) for i in range(rows)] + #endIndX_current_level_conv = [i * CurrentLevel_ReductionFactor - 1 for i in range(rows)] + endIndY_current_level_conv = [int(i * CurrentLevel_ReductionFactor) for i in range(rows)] + endIndY_current_level_conv.append(self._dz.level_dimensions[level][1]) + endIndY_current_level_conv.pop(0) + ''' + #startIndY_current_level_conv = [] + #endIndY_current_level_conv = [] + #startIndX_current_level_conv = [] + #endIndX_current_level_conv = [] + + #for row in range(rows): + # for col in range(cols): + # Dlocation, Dlevel, Dsize = self._dz.get_tile_coordinates(level,(col, row)) + # Ddimension = self._dz.get_tile_dimensions(level,(col, row)) + # startIndY_current_level_conv.append(int((Dlocation[1]) / Img_Fact)) + # endIndY_current_level_conv.append(int((Dlocation[1] + Ddimension[1]) / Img_Fact)) + # startIndX_current_level_conv.append(int((Dlocation[0]) / Img_Fact)) + # endIndX_current_level_conv.append(int((Dlocation[0] + Ddimension[0]) / Img_Fact)) + # print(Dlocation, Ddimension, int((Dlocation[1]) / Img_Fact), int((Dlocation[1] + Ddimension[1]) / Img_Fact), int((Dlocation[0]) / Img_Fact), int((Dlocation[0] + Ddimension[0]) / Img_Fact)) + for row in range(rows): + for col in range(cols): + InsertBaseName = False + if InsertBaseName: + tilename = os.path.join(tiledir, '%s_%d_%d.%s' % ( + self._basenameJPG, col, row, self._format)) + tilename_bw = os.path.join(tiledir, '%s_%d_%d_mask.%s' % ( + self._basenameJPG, col, row, self._format)) + else: + tilename = os.path.join(tiledir, '%d_%d.%s' % ( + col, row, self._format)) + tilename_bw = os.path.join(tiledir, '%d_%d_mask.%s' % ( + col, row, self._format)) + if xml_valid: + # compute percentage of tile in mask + # print(row, col) + # print(startIndX_current_level_conv[col]) + # print(endIndX_current_level_conv[col]) + # print(startIndY_current_level_conv[row]) + # print(endIndY_current_level_conv[row]) + # print(mask.shape) + # print(mask[startIndX_current_level_conv[col]:endIndX_current_level_conv[col], startIndY_current_level_conv[row]:endIndY_current_level_conv[row]]) + # TileMask = mask[startIndY_current_level_conv[row]:endIndY_current_level_conv[row], startIndX_current_level_conv[col]:endIndX_current_level_conv[col]] + # PercentMasked = mask[startIndY_current_level_conv[row]:endIndY_current_level_conv[row], startIndX_current_level_conv[col]:endIndX_current_level_conv[col]].mean() + # print(startIndY_current_level_conv[row], endIndY_current_level_conv[row], startIndX_current_level_conv[col], endIndX_current_level_conv[col]) + + Dlocation, Dlevel, Dsize = self._dz.get_tile_coordinates(level,(col, row)) + Ddimension = tuple([pow(2,(self._dz.level_count - 1 - level)) * x for x in self._dz.get_tile_dimensions(level,(col, row))]) + startIndY_current_level_conv = (int((Dlocation[1]) / Img_Fact)) + endIndY_current_level_conv = (int((Dlocation[1] + Ddimension[1]) / Img_Fact)) + startIndX_current_level_conv = (int((Dlocation[0]) / Img_Fact)) + endIndX_current_level_conv = (int((Dlocation[0] + Ddimension[0]) / Img_Fact)) + # print(Ddimension, Dlocation, Dlevel, Dsize, self._dz.level_count , level, col, row) + + #startIndY_current_level_conv = (int((Dlocation[1]) / Img_Fact)) + #endIndY_current_level_conv = (int((Dlocation[1] + Ddimension[1]) / Img_Fact)) + #startIndX_current_level_conv = (int((Dlocation[0]) / Img_Fact)) + #endIndX_current_level_conv = (int((Dlocation[0] + Ddimension[0]) / Img_Fact)) + TileMask = mask[startIndY_current_level_conv:endIndY_current_level_conv, startIndX_current_level_conv:endIndX_current_level_conv] + PercentMasked = mask[startIndY_current_level_conv:endIndY_current_level_conv, startIndX_current_level_conv:endIndX_current_level_conv].mean() + + # print(Ddimension, startIndY_current_level_conv, endIndY_current_level_conv, startIndX_current_level_conv, endIndX_current_level_conv) + + + if self._mask_type == 0: + # keep ROI outside of the mask + PercentMasked = 1.0 - PercentMasked + # print("Invert Mask percentage") + + # if PercentMasked > 0: + # print("PercentMasked_p %.3f" % (PercentMasked)) + # else: + # print("PercentMasked_0 %.3f" % (PercentMasked)) + + + else: + PercentMasked = 1.0 + TileMask = [] + + if not os.path.exists(tilename): + self._queue.put((self._associated, level, (col, row), + tilename, self._format, tilename_bw, PercentMasked, self._SaveMasks, TileMask, self._normalize)) + self._tile_done() + + def _tile_done(self): + self._processed += 1 + count, total = self._processed, self._dz.tile_count + if count % 100 == 0 or count == total: + #print("Tiling %s: wrote %d/%d tiles" % ( + # self._associated or 'slide', count, total), + # end='\r', file=sys.stderr) + if count == total: + print(file=sys.stderr) + + def _write_dzi(self): + with open('%s.dzi' % self._basename, 'w') as fh: + fh.write(self.get_dzi()) + + def get_dzi(self): + return self._dz.get_dzi(self._format) + + + def jpg_mask_read(self, xmldir): + # Original size of the image + ImgMaxSizeX_orig = float(self._dz.level_dimensions[-1][0]) + ImgMaxSizeY_orig = float(self._dz.level_dimensions[-1][1]) + # Number of centers at the highest resolution + cols, rows = self._dz.level_tiles[-1] + # Img_Fact = int(ImgMaxSizeX_orig / 1.0 / cols) + Img_Fact = 1 + try: + # xmldir: change extension from xml to *jpg + xmldir = xmldir[:-4] + "mask.jpg" + # xmlcontent = read xmldir image + xmlcontent = imread(xmldir) + xmlcontent = xmlcontent - np.min(xmlcontent) + mask = xmlcontent / np.max(xmlcontent) + # we want image between 0 and 1 + xml_valid = True + except: + xml_valid = False + print("error with minidom.parse(xmldir)") + return [], xml_valid, 1.0 + + return mask, xml_valid, Img_Fact + + + def xml_read(self, xmldir, Attribute_Name, Fieldxml): + + # Original size of the image + ImgMaxSizeX_orig = float(self._dz.level_dimensions[-1][0]) + ImgMaxSizeY_orig = float(self._dz.level_dimensions[-1][1]) + # Number of centers at the highest resolution + cols, rows = self._dz.level_tiles[-1] + + NewFact = max(ImgMaxSizeX_orig, ImgMaxSizeY_orig) / min(max(ImgMaxSizeX_orig, ImgMaxSizeY_orig),15000.0) + # Img_Fact = + # read_region(location, level, size) + # dz.get_tile_coordinates(14,(0,2)) + # ((0, 1792), 1, (320, 384)) + + Img_Fact = float(ImgMaxSizeX_orig) / 5.0 / float(cols) + + # print("image info:") + # print(ImgMaxSizeX_orig, ImgMaxSizeY_orig, cols, rows) + try: + xmlcontent = minidom.parse(xmldir) + xml_valid = True + except: + xml_valid = False + print("error with minidom.parse(xmldir)") + return [], xml_valid, 1.0 + + xy = {} + xy_neg = {} + NbRg = 0 + labelIDs = xmlcontent.getElementsByTagName('Annotation') + # print("%d labels" % len(labelIDs) ) + for labelID in labelIDs: + if (Attribute_Name==[]) | (Attribute_Name==''): + isLabelOK = True + else: + try: + labeltag = labelID.getElementsByTagName('Attribute')[0] + if (Attribute_Name==labeltag.attributes[Fieldxml].value): + # if (Attribute_Name==labeltag.attributes['Value'].value): + # if (Attribute_Name==labeltag.attributes['Name'].value): + isLabelOK = True + else: + isLabelOK = False + except: + isLabelOK = False + if Attribute_Name == "non_selected_regions": + isLabelOK = True + + #print("label ID, tag:") + #print(labelID, Attribute_Name, labeltag.attributes['Name'].value) + #if Attribute_Name==labeltag.attributes['Name'].value: + if isLabelOK: + regionlist = labelID.getElementsByTagName('Region') + for region in regionlist: + vertices = region.getElementsByTagName('Vertex') + NbRg += 1 + regionID = region.attributes['Id'].value + str(NbRg) + NegativeROA = region.attributes['NegativeROA'].value + # print("%d vertices" % len(vertices)) + if len(vertices) > 0: + #print( len(vertices) ) + if NegativeROA=="0": + xy[regionID] = [] + for vertex in vertices: + # get the x value of the vertex / convert them into index in the tiled matrix of the base image + # x = int(round(float(vertex.attributes['X'].value) / ImgMaxSizeX_orig * (cols*Img_Fact))) + # y = int(round(float(vertex.attributes['Y'].value) / ImgMaxSizeY_orig * (rows*Img_Fact))) + x = int(round(float(vertex.attributes['X'].value) / NewFact)) + y = int(round(float(vertex.attributes['Y'].value) / NewFact)) + xy[regionID].append((x,y)) + #print(vertex.attributes['X'].value, vertex.attributes['Y'].value, x, y ) + + elif NegativeROA=="1": + xy_neg[regionID] = [] + for vertex in vertices: + # get the x value of the vertex / convert them into index in the tiled matrix of the base image + # x = int(round(float(vertex.attributes['X'].value) / ImgMaxSizeX_orig * (cols*Img_Fact))) + # y = int(round(float(vertex.attributes['Y'].value) / ImgMaxSizeY_orig * (rows*Img_Fact))) + x = int(round(float(vertex.attributes['X'].value) / NewFact)) + y = int(round(float(vertex.attributes['Y'].value) / NewFact)) + xy_neg[regionID].append((x,y)) + + + #xy_a = np.array(xy[regionID]) + + # print("%d xy" % len(xy)) + #print(xy) + # print("%d xy_neg" % len(xy_neg)) + #print(xy_neg) + # print("Img_Fact:") + # print(NewFact) + # img = Image.new('L', (int(cols*Img_Fact), int(rows*Img_Fact)), 0) + img = Image.new('L', (int(ImgMaxSizeX_orig/NewFact), int(ImgMaxSizeY_orig/NewFact)), 0) + for regionID in xy.keys(): + xy_a = xy[regionID] + ImageDraw.Draw(img,'L').polygon(xy_a, outline=255, fill=255) + for regionID in xy_neg.keys(): + xy_a = xy_neg[regionID] + ImageDraw.Draw(img,'L').polygon(xy_a, outline=255, fill=0) + #img = img.resize((cols,rows), Image.ANTIALIAS) + mask = np.array(img) + #print(mask.shape) + if Attribute_Name == "non_selected_regions": + # scipy.misc.toimage(255-mask).save(os.path.join(os.path.split(self._basename[:-1])[0], "mask_" + os.path.basename(self._basename) + "_" + Attribute_Name + ".jpeg")) + Image.fromarray(255-mask).save(os.path.join(os.path.split(self._basename[:-1])[0], "mask_" + os.path.basename(self._basename) + "_" + Attribute_Name + ".jpeg")) + else: + if self._mask_type==0: + # scipy.misc.toimage(255-mask).save(os.path.join(os.path.split(self._basename[:-1])[0], "mask_" + os.path.basename(self._basename) + "_" + Attribute_Name + "_inv.jpeg")) + Image.fromarray(255-mask).save(os.path.join(os.path.split(self._basename[:-1])[0], "mask_" + os.path.basename(self._basename) + "_" + Attribute_Name + "_inv.jpeg")) + else: + # scipy.misc.toimage(mask).save(os.path.join(os.path.split(self._basename[:-1])[0], "mask_" + os.path.basename(self._basename) + "_" + Attribute_Name + ".jpeg")) + Image.fromarray(mask).save(os.path.join(os.path.split(self._basename[:-1])[0], "mask_" + os.path.basename(self._basename) + "_" + Attribute_Name + ".jpeg")) + #print(mask) + return mask / 255.0, xml_valid, NewFact + # Img_Fact + + +class DeepZoomStaticTiler(object): + """Handles generation of tiles and metadata for all images in a slide.""" + + def __init__(self, slidepath, basename, format, tile_size, overlap, + limit_bounds, quality, workers, with_viewer, Bkg, basenameJPG, xmlfile, mask_type, ROIpc, oLabel, ImgExtension, SaveMasks, Mag, normalize, Fieldxml): + if with_viewer: + # Check extra dependency before doing a bunch of work + import jinja2 + #print("line226 - %s " % (slidepath) ) + self._slide = open_slide(slidepath) + self._basename = basename + self._basenameJPG = basenameJPG + self._xmlfile = xmlfile + self._mask_type = mask_type + self._format = format + self._tile_size = tile_size + self._overlap = overlap + self._limit_bounds = limit_bounds + self._queue = JoinableQueue(2 * workers) + self._workers = workers + self._with_viewer = with_viewer + self._Bkg = Bkg + self._ROIpc = ROIpc + self._dzi_data = {} + self._xmlLabel = oLabel + self._ImgExtension = ImgExtension + self._SaveMasks = SaveMasks + self._Mag = Mag + self._normalize = normalize + self._Fieldxml = Fieldxml + + for _i in range(workers): + TileWorker(self._queue, slidepath, tile_size, overlap, + limit_bounds, quality, self._Bkg, self._ROIpc).start() + + def run(self): + self._run_image() + if self._with_viewer: + for name in self._slide.associated_images: + self._run_image(name) + self._write_html() + self._write_static() + self._shutdown() + + def _run_image(self, associated=None): + """Run a single image from self._slide.""" + if associated is None: + image = self._slide + if self._with_viewer: + basename = os.path.join(self._basename, VIEWER_SLIDE_NAME) + else: + basename = self._basename + else: + image = ImageSlide(self._slide.associated_images[associated]) + basename = os.path.join(self._basename, self._slugify(associated)) + # print("enter DeepZoomGenerator") + dz = DeepZoomGenerator(image, self._tile_size, self._overlap,limit_bounds=self._limit_bounds) + # print("enter DeepZoomImageTiler") + tiler = DeepZoomImageTiler(dz, basename, self._format, associated,self._queue, self._slide, self._basenameJPG, self._xmlfile, self._mask_type, self._xmlLabel, self._ROIpc, self._ImgExtension, self._SaveMasks, self._Mag, self._normalize, self._Fieldxml) + tiler.run() + self._dzi_data[self._url_for(associated)] = tiler.get_dzi() + + + + def _url_for(self, associated): + if associated is None: + base = VIEWER_SLIDE_NAME + else: + base = self._slugify(associated) + return '%s.dzi' % base + + def _write_html(self): + import jinja2 + env = jinja2.Environment(loader=jinja2.PackageLoader(__name__),autoescape=True) + template = env.get_template('slide-multipane.html') + associated_urls = dict((n, self._url_for(n)) + for n in self._slide.associated_images) + try: + mpp_x = self._slide.properties[openslide.PROPERTY_NAME_MPP_X] + mpp_y = self._slide.properties[openslide.PROPERTY_NAME_MPP_Y] + mpp = (float(mpp_x) + float(mpp_y)) / 2 + except (KeyError, ValueError): + mpp = 0 + # Embed the dzi metadata in the HTML to work around Chrome's + # refusal to allow XmlHttpRequest from file:///, even when + # the originating page is also a file:/// + data = template.render(slide_url=self._url_for(None),slide_mpp=mpp,associated=associated_urls, properties=self._slide.properties, dzi_data=json.dumps(self._dzi_data)) + with open(os.path.join(self._basename, 'index.html'), 'w') as fh: + fh.write(data) + + def _write_static(self): + basesrc = os.path.join(os.path.dirname(os.path.abspath(__file__)), + 'static') + basedst = os.path.join(self._basename, 'static') + self._copydir(basesrc, basedst) + self._copydir(os.path.join(basesrc, 'images'), + os.path.join(basedst, 'images')) + + def _copydir(self, src, dest): + if not os.path.exists(dest): + os.makedirs(dest) + for name in os.listdir(src): + srcpath = os.path.join(src, name) + if os.path.isfile(srcpath): + shutil.copy(srcpath, os.path.join(dest, name)) + + @classmethod + def _slugify(cls, text): + text = normalize('NFKD', text.lower()).encode('ascii', 'ignore').decode() + return re.sub('[^a-z0-9]+', '_', text) + + def _shutdown(self): + for _i in range(self._workers): + self._queue.put(None) + self._queue.join() + + + +def ImgWorker(queue): + # print("ImgWorker started") + while True: + cmd = queue.get() + if cmd is None: + queue.task_done() + break + # print("Execute: %s" % (cmd)) + subprocess.Popen(cmd, shell=True).wait() + queue.task_done() + +def xml_read_labels(xmldir, Fieldxml): + try: + xmlcontent = minidom.parse(xmldir) + xml_valid = True + except: + xml_valid = False + print("error with minidom.parse(xmldir)") + return [], xml_valid + labeltag = xmlcontent.getElementsByTagName('Attribute') + xml_labels = [] + for xmllabel in labeltag: + xml_labels.append(xmllabel.attributes[Fieldxml].value) + #xml_labels.append(xmllabel.attributes['Name'].value) + # xml_labels.append(xmllabel.attributes['Value'].value) + if xml_labels==[]: + xml_labels = [''] + # print(xml_labels) + return xml_labels, xml_valid + + +if __name__ == '__main__': + parser = OptionParser(usage='Usage: %prog [options] ') + + parser.add_option('-L', '--ignore-bounds', dest='limit_bounds', + default=True, action='store_false', + help='display entire scan area') + parser.add_option('-e', '--overlap', metavar='PIXELS', dest='overlap', + type='int', default=1, + help='overlap of adjacent tiles [1]') + parser.add_option('-f', '--format', metavar='{jpeg|png}', dest='format', + default='jpeg', + help='image format for tiles [jpeg]') + parser.add_option('-j', '--jobs', metavar='COUNT', dest='workers', + type='int', default=4, + help='number of worker processes to start [4]') + parser.add_option('-o', '--output', metavar='NAME', dest='basename', + help='base name of output file') + parser.add_option('-Q', '--quality', metavar='QUALITY', dest='quality', + type='int', default=90, + help='JPEG compression quality [90]') + parser.add_option('-r', '--viewer', dest='with_viewer', + action='store_true', + help='generate directory tree with HTML viewer') + parser.add_option('-s', '--size', metavar='PIXELS', dest='tile_size', + type='int', default=254, + help='tile size [254]') + parser.add_option('-B', '--Background', metavar='PIXELS', dest='Bkg', + type='float', default=50, + help='Max background threshold [50]; percentager of background allowed') + parser.add_option('-x', '--xmlfile', metavar='NAME', dest='xmlfile', + help='xml file if needed') + parser.add_option('-F', '--Fieldxml', metavar='{Name|Value}', dest='Fieldxml', + default='Value', + help='which field of the xml file is the label saved') + parser.add_option('-m', '--mask_type', metavar='COUNT', dest='mask_type', + type='int', default=1, + help='if xml file is used, keep tile within the ROI (1) or outside of it (0)') + parser.add_option('-R', '--ROIpc', metavar='PIXELS', dest='ROIpc', + type='float', default=50, + help='To be used with xml file - minimum percentage of tile covered by ROI (white)') + parser.add_option('-l', '--oLabelref', metavar='NAME', dest='oLabelref', + help='To be used with xml file - Only tile for label which contains the characters in oLabel') + parser.add_option('-S', '--SaveMasks', metavar='NAME', dest='SaveMasks', + default=False, + help='set to yes if you want to save ALL masks for ALL tiles (will be saved in same directory with suffix)') + parser.add_option('-t', '--tmp_dcm', metavar='NAME', dest='tmp_dcm', + help='base name of output folder to save intermediate dcm images converted to jpg (we assume the patient ID is the folder name in which the dcm images are originally saved)') + parser.add_option('-M', '--Mag', metavar='PIXELS', dest='Mag', + type='float', default=-1, + help='Magnification at which tiling should be done (-1 of all)') + parser.add_option('-N', '--normalize', metavar='NAME', dest='normalize', + help='if normalization is needed, N list the mean and std for each channel. For example \'57,22,-8,20,10,5\' with the first 3 numbers being the targeted means, and then the targeted stds') + + + + + (opts, args) = parser.parse_args() + + + try: + slidepath = args[0] + except IndexError: + parser.error('Missing slide argument') + if opts.basename is None: + opts.basename = os.path.splitext(os.path.basename(slidepath))[0] + if opts.xmlfile is None: + opts.xmlfile = '' + + try: + if opts.normalize is not None: + opts.normalize = [float(x) for x in opts.normalize.split(',')] + if len(opts.normalize) != 6: + opts.normalize = '' + parser.error("ERROR: NO NORMALIZATION APPLIED: input vector does not have the right length - 6 values expected") + else: + opts.normalize = '' + + except: + opts.normalize = '' + parser.error("ERROR: NO NORMALIZATION APPLIED: input vector does not have the right format") + #if ss != '': + # if os.path.isdir(opts.xmlfile): + + + # Initialization + # imgExample = "/ifs/home/coudrn01/NN/Lung/RawImages/*/*svs" + # tile_size = 512 + # max_number_processes = 10 + # NbrCPU = 4 + + # get images from the data/ file. + + files = glob(slidepath) + #ImgExtension = os.path.splitext(slidepath)[1] + ImgExtension = slidepath.split('*')[-1] + #files + #len(files) + # print(args) + # print(args[0]) + # print(slidepath) + # print(files) + # print("***********************") + + ''' + dz_queue = JoinableQueue() + procs = [] + print("Nb of processes:") + print(opts.max_number_processes) + for i in range(opts.max_number_processes): + p = Process(target = ImgWorker, args = (dz_queue,)) + #p.deamon = True + p.setDaemon = True + p.start() + procs.append(p) + ''' + files = sorted(files) + print(len(files), ' to process') + import time + time.sleep(5) + for imgNb in tqdm(range(len(files))): + filename = files[imgNb] + #print(filename) + opts.basenameJPG = os.path.splitext(os.path.basename(filename))[0] + #print("processing: " + opts.basenameJPG + " with extension: " + ImgExtension) + #opts.basenameJPG = os.path.splitext(os.path.basename(slidepath))[0] + #if os.path.isdir("%s_files" % (basename)): + # print("EXISTS") + #else: + # print("Not Found") + + if ("dcm" in ImgExtension) : + print("convert %s dcm to jpg" % filename) + if opts.tmp_dcm is None: + parser.error('Missing output folder for dcm>jpg intermediate files') + elif not os.path.isdir(opts.tmp_dcm): + parser.error('Missing output folder for dcm>jpg intermediate files') + + if filename[-3:] == 'jpg': + continue + ImageFile=dicom.read_file(filename) + im1 = ImageFile.pixel_array + maxVal = float(im1.max()) + minVal = float(im1.min()) + height = im1.shape[0] + width = im1.shape[1] + image = np.zeros((height,width,3), 'uint8') + image[...,0] = ((im1[:,:].astype(float) - minVal) / (maxVal - minVal) * 255.0).astype(int) + image[...,1] = ((im1[:,:].astype(float) - minVal) / (maxVal - minVal) * 255.0).astype(int) + image[...,2] = ((im1[:,:].astype(float) - minVal) / (maxVal - minVal) * 255.0).astype(int) + # dcm_ID = os.path.basename(os.path.dirname(filename)) + # opts.basenameJPG = dcm_ID + "_" + opts.basenameJPG + filename = os.path.join(opts.tmp_dcm, opts.basenameJPG + ".jpg") + # print(filename) + imsave(filename,image) + + output = os.path.join(opts.basename, opts.basenameJPG) + + try: + DeepZoomStaticTiler(filename, output, opts.format, opts.tile_size, opts.overlap, opts.limit_bounds, opts.quality, opts.workers, opts.with_viewer, opts.Bkg, opts.basenameJPG, opts.xmlfile, opts.mask_type, opts.ROIpc, '', ImgExtension, opts.SaveMasks, opts.Mag, opts.normalize, opts.Fieldxml).run() + except Exception as e: + print("Failed to process file %s, error: %s" % (filename, sys.exc_info()[0])) + print(e) + + #elif ("jpg" in ImgExtension) : + # output = os.path.join(opts.basename, opts.basenameJPG) + # if os.path.exists(output + "_files"): + # print("Image %s already tiled" % opts.basenameJPG) + # continue + + # DeepZoomStaticTiler(filename, output, opts.format, opts.tile_size, opts.overlap, opts.limit_bounds, opts.quality, opts.workers, opts.with_viewer, opts.Bkg, opts.basenameJPG, opts.xmlfile, opts.mask_type, opts.ROIpc, '', ImgExtension, opts.SaveMasks, opts.Mag, opts.normalize, opts.Fieldxml).run() + + elif opts.xmlfile != '': + xmldir = os.path.join(opts.xmlfile, opts.basenameJPG + '.xml') + # print("xml:") + # print(xmldir) + if os.path.isfile(xmldir): + if (opts.mask_type==1) or (opts.oLabelref!=''): + # either mask inside ROI, or mask outside but a reference label exist + xml_labels, xml_valid = xml_read_labels(xmldir, opts.Fieldxml) + if (opts.mask_type==1): + # No inverse mask + Nbr_ROIs_ForNegLabel = 1 + elif (opts.oLabelref!=''): + # Inverse mask and a label reference exist + Nbr_ROIs_ForNegLabel = 0 + + for oLabel in xml_labels: + # print("label is %s and ref is %s" % (oLabel, opts.oLabelref)) + if (opts.oLabelref in oLabel) or (opts.oLabelref==''): + # is a label is identified + if (opts.mask_type==0): + # Inverse mask and label exist in the image + Nbr_ROIs_ForNegLabel += 1 + # there is a label, and map is to be inverted + output = os.path.join(opts.basename, oLabel+'_inv', opts.basenameJPG) + if not os.path.exists(os.path.join(opts.basename, oLabel+'_inv')): + os.makedirs(os.path.join(opts.basename, oLabel+'_inv')) + else: + Nbr_ROIs_ForNegLabel += 1 + output = os.path.join(opts.basename, oLabel, opts.basenameJPG) + if not os.path.exists(os.path.join(opts.basename, oLabel)): + os.makedirs(os.path.join(opts.basename, oLabel)) + if 1: + #try: + DeepZoomStaticTiler(filename, output, opts.format, opts.tile_size, opts.overlap, opts.limit_bounds, opts.quality, opts.workers, opts.with_viewer, opts.Bkg, opts.basenameJPG, opts.xmlfile, opts.mask_type, opts.ROIpc, oLabel, ImgExtension, opts.SaveMasks, opts.Mag, opts.normalize, opts.Fieldxml).run() + #except: + # print("Failed to process file %s, error: %s" % (filename, sys.exc_info()[0])) + if Nbr_ROIs_ForNegLabel==0: + print("label %s is not in that image; invert everything" % (opts.oLabelref)) + # a label ref was given, and inverse mask is required but no ROI with this label in that map --> take everything + oLabel = opts.oLabelref + output = os.path.join(opts.basename, opts.oLabelref+'_inv', opts.basenameJPG) + if not os.path.exists(os.path.join(opts.basename, oLabel+'_inv')): + os.makedirs(os.path.join(opts.basename, oLabel+'_inv')) + if 1: + #try: + DeepZoomStaticTiler(filename, output, opts.format, opts.tile_size, opts.overlap, opts.limit_bounds, opts.quality, opts.workers, opts.with_viewer, opts.Bkg, opts.basenameJPG, opts.xmlfile, opts.mask_type, opts.ROIpc, oLabel, ImgExtension, opts.SaveMasks, opts.Mag, opts.normalize, opts.Fieldxml).run() + #except: + # print("Failed to process file %s, error: %s" % (filename, sys.exc_info()[0])) + + else: + # Background + oLabel = "non_selected_regions" + output = os.path.join(opts.basename, oLabel, opts.basenameJPG) + if not os.path.exists(os.path.join(opts.basename, oLabel)): + os.makedirs(os.path.join(opts.basename, oLabel)) + try: + DeepZoomStaticTiler(filename, output, opts.format, opts.tile_size, opts.overlap, opts.limit_bounds, opts.quality, opts.workers, opts.with_viewer, opts.Bkg, opts.basenameJPG, opts.xmlfile, opts.mask_type, opts.ROIpc, oLabel, ImgExtension, opts.SaveMasks, opts.Mag, opts.normalize, opts.Fieldxml).run() + except Exception as e: + print("Failed to process file %s, error: %s" % (filename, sys.exc_info()[0])) + print(e) + + else: + if (ImgExtension == ".jpg") | (ImgExtension == ".dcm") : + print("Input image to be tiled is jpg or dcm and not svs - will be treated as such") + output = os.path.join(opts.basename, opts.basenameJPG) + try: + DeepZoomStaticTiler(filename, output, opts.format, opts.tile_size, opts.overlap, opts.limit_bounds, opts.quality, opts.workers, opts.with_viewer, opts.Bkg, opts.basenameJPG, opts.xmlfile, opts.mask_type, opts.ROIpc, '', ImgExtension, opts.SaveMasks, opts.Mag, opts.normalize, opts.Fieldxml).run() + except Exception as e: + print("Failed to process file %s, error: %s" % (filename, sys.exc_info()[0])) + print(e) + + + else: + print("No xml file found for slide %s.svs (expected: %s). Directory or xml file does not exist" % (opts.basenameJPG, xmldir) ) + continue + else: + output = os.path.join(opts.basename, opts.basenameJPG) + if os.path.exists(output + "_files"): + print("Image %s already tiled" % opts.basenameJPG) + continue + try: + #if True: + DeepZoomStaticTiler(filename, output, opts.format, opts.tile_size, opts.overlap, opts.limit_bounds, opts.quality, opts.workers, opts.with_viewer, opts.Bkg, opts.basenameJPG, opts.xmlfile, opts.mask_type, opts.ROIpc, '', ImgExtension, opts.SaveMasks, opts.Mag, opts.normalize, opts.Fieldxml).run() + except Exception as e: + print("Failed to process file %s, error: %s" % (filename, sys.exc_info()[0])) + print(e) + ''' + dz_queue.join() + for i in range(opts.max_number_processes): + dz_queue.put( None ) + ''' + + print("End") diff --git a/src/vis_graphcam.py b/src/vis_graphcam.py new file mode 100644 index 0000000000000000000000000000000000000000..dec1c01aff96ddb7c1f03761b4b31cb1a02cec4c --- /dev/null +++ b/src/vis_graphcam.py @@ -0,0 +1,210 @@ +from PIL import Image +from matplotlib.pyplot import imshow, show +import matplotlib.pyplot as plt +from torchvision import models, transforms +from torch.autograd import Variable +from torch.nn import functional as F +import torch +import torch.nn as nn +from torch import topk +import numpy as np +import os +import skimage.transform +import cv2 +import math +import openslide +import argparse +import pickle + +def show_cam_on_image(img, mask): + heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET) + heatmap = np.float32(heatmap) / 255 + cam = heatmap + np.float32(img) + cam = cam / np.max(cam) + return cam + +def cam_to_mask(gray, patches, cam_matrix, w, h, w_s, h_s): + mask = np.full_like(gray, 0.).astype(np.float32) + for ind1, patch in enumerate(patches): + x, y = patch.split('.')[0].split('_') + x, y = int(x), int(y) + #if y <5 or x>w-5 or y>h-5: + # continue + mask[int(y*h_s):int((y+1)*h_s), int(x*w_s):int((x+1)*w_s)].fill(cam_matrix[ind1][0]) + + return mask + +def main(args): + label_map = pickle.load(open(os.path.join(args.dataset_metadata_path, 'label_map.pkl'), 'rb')) + + label_name_from_id = dict() + for label_name, label_id in label_map.items(): + label_name_from_id[label_id] = label_name + + n_class = len(label_map)#args.n_class + file_name, label = open(args.path_file, 'r').readlines()[-1].split('\t') + label = label.rstrip().strip() + #site, file_name = file_name.split('/') + file_path = os.path.join(args.path_patches, '{}_files/20.0/'.format(file_name)) + print(file_name) + print(label) + + p = torch.load('graphcam/prob.pt').cpu().detach().numpy()[0] + file_path = os.path.join(args.path_patches, '{}_files/20.0/'.format(file_name)) + #ori = openslide.OpenSlide(os.path.join(args.path_WSI, '{}.svs').format(file_name)) + ORIGINAL_FILEPATH = os.path.join(args.path_WSI,'TCGA',label, '{}.svs'.format(file_name)) + print('L', ORIGINAL_FILEPATH) + ori = openslide.OpenSlide(ORIGINAL_FILEPATH) + patch_info = open(os.path.join(args.path_graph, file_name, 'c_idx.txt'), 'r') + + width, height = ori.dimensions + + REDUCTION_FACTOR = 10 + w, h = int(width/512), int(height/512) + w_r, h_r = int(width/20), int(height/20) + resized_img = ori.get_thumbnail((width,height))#ori.get_thumbnail((w_r,h_r)) + resized_img = resized_img.resize((w_r,h_r)) + ratio_w, ratio_h = width/resized_img.width, height/resized_img.height + print('ratios ', ratio_w, ratio_h) + w_s, h_s = float(512/REDUCTION_FACTOR), float(512/REDUCTION_FACTOR) + print(w_s, h_s) + + patch_info = patch_info.readlines() + patches = [] + xmax, ymax = 0, 0 + for patch in patch_info: + x, y = patch.strip('\n').split('\t') + if xmax < int(x): xmax = int(x) + if ymax < int(y): ymax = int(y) + patches.append('{}_{}.jpeg'.format(x,y)) + + output_img = np.asarray(resized_img)[:,:,::-1].copy() + #-----------------------------------------------------------------------------------------------------# + # GraphCAM + print('visulize GraphCAM') + assign_matrix = torch.load('graphcam/s_matrix_ori.pt') + m = nn.Softmax(dim=1) + assign_matrix = m(assign_matrix) + + # Thresholding for better visualization + p = np.clip(p, 0.4, 1) + + + + output_img_copy =np.copy(output_img) + gray = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY) + image_transformer_attribution = (output_img_copy - output_img_copy.min()) / (output_img_copy.max() - output_img_copy.min()) + cam_matrices = [] + masks = [] + visualizations = [] + print(len(patches)) + os.makedirs('graphcam_vis', exist_ok=True) + for class_i in range(n_class): + + # Load graphcam for each class + cam_matrix = torch.load(f'graphcam/cam_{class_i}.pt') + print(cam_matrix.shape) + cam_matrix = torch.mm(assign_matrix, cam_matrix.transpose(1,0)) + cam_matrix = cam_matrix.cpu() + print(assign_matrix.shape) + print(cam_matrix.shape) + # Normalize the graphcam + cam_matrix = (cam_matrix - cam_matrix.min()) / (cam_matrix.max() - cam_matrix.min()) + cam_matrix = cam_matrix.detach().numpy() + cam_matrix = p[class_i] * cam_matrix + cam_matrix = np.clip(cam_matrix, 0, 1) + print(cam_matrix.shape) + #print() + + + mask = cam_to_mask(gray, patches, cam_matrix, w, h, w_s, h_s) + print('mask shape ', mask.shape) + print('imgtf attr ', image_transformer_attribution.shape) + vis = show_cam_on_image(image_transformer_attribution, mask) + vis = np.uint8(255 * vis) + + cam_matrices.append(cam_matrix) + masks.append(mask) + visualizations.append(vis) + print() + cv2.imwrite('graphcam_vis/{}_all_types_cam_{}.png'.format(file_name, label_name_from_id[class_i] ), vis) + h, w, _ = output_img.shape + if h > w: + vis_merge = cv2.hconcat([output_img] + visualizations) + else: + vis_merge = cv2.vconcat([output_img] + visualizations) + + + cv2.imwrite('graphcam_vis/{}_all_types_cam_all.png'.format(file_name), vis_merge) + cv2.imwrite('graphcam_vis/{}_all_types_ori.png'.format(file_name ), output_img) + + ''' + # Load graphcam for differnet class + cam_matrix_0 = torch.load('graphcam/cam_0.pt') + cam_matrix_0 = torch.mm(assign_matrix, cam_matrix_0.transpose(1,0)) + cam_matrix_0 = cam_matrix_0.cpu() + cam_matrix_1 = torch.load('graphcam/cam_1.pt') + cam_matrix_1 = torch.mm(assign_matrix, cam_matrix_1.transpose(1,0)) + cam_matrix_1 = cam_matrix_1.cpu() + cam_matrix_2 = torch.load('graphcam/cam_2.pt') + cam_matrix_2 = torch.mm(assign_matrix, cam_matrix_2.transpose(1,0)) + cam_matrix_2 = cam_matrix_2.cpu() + + # Normalize the graphcam + cam_matrix_0 = (cam_matrix_0 - cam_matrix_0.min()) / (cam_matrix_0.max() - cam_matrix_0.min()) + cam_matrix_0 = cam_matrix_0.detach().numpy() + cam_matrix_0 = p[0] * cam_matrix_0 + cam_matrix_0 = np.clip(cam_matrix_0, 0, 1) + cam_matrix_1 = (cam_matrix_1 - cam_matrix_1.min()) / (cam_matrix_1.max() - cam_matrix_1.min()) + cam_matrix_1 = cam_matrix_1.detach().numpy() + cam_matrix_1 = p[1] * cam_matrix_1 + cam_matrix_1 = np.clip(cam_matrix_1, 0, 1) + cam_matrix_2 = (cam_matrix_2 - cam_matrix_2.min()) / (cam_matrix_2.max() - cam_matrix_2.min()) + cam_matrix_2 = cam_matrix_2.detach().numpy() + cam_matrix_2 = p[2] * cam_matrix_2 + cam_matrix_2 = np.clip(cam_matrix_2, 0, 1) + + output_img_copy =np.copy(output_img) + + gray = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY) + image_transformer_attribution = (output_img_copy - output_img_copy.min()) / (output_img_copy.max() - output_img_copy.min()) + + mask0 = cam_to_mask(gray, patches, cam_matrix_0, w, h, w_s, h_s) + vis0 = show_cam_on_image(image_transformer_attribution, mask0) + vis0 = np.uint8(255 * vis0) + mask1 = cam_to_mask(gray, patches, cam_matrix_1, w, h, w_s, h_s) + vis1 = show_cam_on_image(image_transformer_attribution, mask1) + vis1 = np.uint8(255 * vis1) + mask2 = cam_to_mask(gray, patches, cam_matrix_2, w, h, w_s, h_s) + vis2 = show_cam_on_image(image_transformer_attribution, mask2) + vis2 = np.uint8(255 * vis2) + + ########################################## + h, w, _ = output_img.shape + if h > w: + vis_merge = cv2.hconcat([output_img, vis0, vis1, vis2]) + else: + vis_merge = cv2.vconcat([output_img, vis0, vis1, vis2]) + + #cv2.imwrite('graphcam_vis/{}_{}_all_types_cam_all.png'.format(file_name, site), vis_merge) + + #cv2.imwrite('graphcam_vis/{}_{}_all_types_ori.png'.format(file_name, site), output_img) + #cv2.imwrite('graphcam_vis/{}_{}_all_types_cam_luad.png'.format(file_name, site), vis1) + #cv2.imwrite('graphcam_vis/{}_{}_all_types_cam_lscc.png'.format(file_name, site), vis2) + cv2.imwrite('graphcam_vis/{}_all_types_cam_all.png'.format(file_name, ), vis_merge) + + cv2.imwrite('graphcam_vis/{}_all_types_ori.png'.format(file_name ), output_img) + cv2.imwrite('graphcam_vis/{}_all_types_cam_luad.png'.format(file_name ), vis1) + cv2.imwrite('graphcam_vis/{}_all_types_cam_lscc.png'.format(file_name ), vis2) + + ''' + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='GraphCAM') + parser.add_argument('--path_file', type=str, default='test.txt', help='txt file contains test sample') + parser.add_argument('--path_patches', type=str, default='', help='') + parser.add_argument('--path_WSI', type=str, default='', help='') + parser.add_argument('--path_graph', type=str, default='', help='') + parser.add_argument('--dataset_metadata_path', type=str, help='Location of the metadata associated with the created dataset: label mapping, splits and so on') + args = parser.parse_args() + main(args) \ No newline at end of file diff --git a/utils/__pycache__/dataset.cpython-38.pyc b/utils/__pycache__/dataset.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b210956f1abfd54032c697a9462b73ddd18ea8f9 Binary files /dev/null and b/utils/__pycache__/dataset.cpython-38.pyc differ diff --git a/utils/__pycache__/lr_scheduler.cpython-38.pyc b/utils/__pycache__/lr_scheduler.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..82a76349d767f6eddb063baf8e248265b4c8fc75 Binary files /dev/null and b/utils/__pycache__/lr_scheduler.cpython-38.pyc differ diff --git a/utils/__pycache__/metrics.cpython-38.pyc b/utils/__pycache__/metrics.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c91051c4f87985c42a6bf7b4e15eca2d30f50a44 Binary files /dev/null and b/utils/__pycache__/metrics.cpython-38.pyc differ diff --git a/utils/dataset.py b/utils/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..86766344717e223362c30caba1f70e5d8f99cef0 --- /dev/null +++ b/utils/dataset.py @@ -0,0 +1,147 @@ +import os +import torch +import torch.utils.data as data +import numpy as np +from PIL import Image, ImageFile +import random +from torchvision.transforms import ToTensor +from torchvision import transforms +import cv2 +import pickle +import torch.nn.functional as F + +ImageFile.LOAD_TRUNCATED_IMAGES = True + +def collate_features(batch): + img = torch.cat([item[0] for item in batch], dim = 0) + coords = np.vstack([item[1] for item in batch]) + return [img, coords] + +def eval_transforms(pretrained=False): + if pretrained: + mean = (0.485, 0.456, 0.406) + std = (0.229, 0.224, 0.225) + + else: + mean = (0.5,0.5,0.5) + std = (0.5,0.5,0.5) + + trnsfrms_val = transforms.Compose( + [ + transforms.Resize(224), + transforms.ToTensor(), + transforms.Normalize(mean = mean, std = std) + ] + ) + + return trnsfrms_val + +class GraphDataset(data.Dataset): + """input and label image dataset""" + + def __init__(self, root, ids, metadata_path, target_patch_size=-1): + super(GraphDataset, self).__init__() + """ + Args: + + fileDir(string): directory with all the input images. + transform(callable, optional): Optional transform to be applied on a sample + """ + self.root = root + self.ids = ids + #self.target_patch_size = target_patch_size + self.classdict = pickle.load(open(os.path.join(metadata_path, 'label_map.pkl'), 'rb' )) # {'normal': 0, 'luad': 1, 'lscc': 2} # + #self.classdict = {'normal': 0, 'tumor': 1} # + #self.classdict = {'Normal': 0, 'TCGA-LUAD': 1, 'TCGA-LUSC': 2} + self._up_kwargs = {'mode': 'bilinear'} + + def __getitem__(self, index): + sample = {} + info = self.ids[index].replace('\n', '') + #file_name, label = info.split('\t')[0].rsplit('.', 1)[0], info.split('\t')[1] + file_name, label = info.split('\t')[0], info.split('\t')[1] + + + sample['label'] = self.classdict[label] + sample['id'] = file_name + + + file_path = os.path.join(self.root, 'simclr_files') + #feature_path = os.path.join(self.root, file_name, 'features.pt') + feature_path = os.path.join(file_path, file_name, 'features.pt') + + if os.path.exists(feature_path): + features = torch.load(feature_path, map_location=lambda storage, loc: storage) + else: + print(feature_path + ' not exists') + features = torch.zeros(1, 512) + + #adj_s_path = os.path.join(self.root, file_name, 'adj_s.pt') + adj_s_path = os.path.join(file_path, file_name, 'adj_s.pt') + if os.path.exists(adj_s_path): + adj_s = torch.load(adj_s_path, map_location=lambda storage, loc: storage) + else: + print(adj_s_path + ' not exists') + adj_s = torch.ones(features.shape[0], features.shape[0]) + + #features = features.unsqueeze(0) + sample['image'] = features + sample['adj_s'] = adj_s #adj_s.to(torch.double) + # return {'image': image.astype(np.float32), 'label': label.astype(np.int64)} + + return sample + + + def __len__(self): + return len(self.ids) + + +''' def __getitem__(self, index): + sample = {} + info = self.ids[index].replace('\n', '') + file_name, label = info.split('\t')[0].rsplit('.', 1)[0], info.split('\t')[1] + site, file_name = file_name.split('/') + + # if site =='CCRCC': + # file_path = self.root + 'CPTAC_CCRCC_features/simclr_files' + if site =='LUAD' or site =='LSCC': + site = 'LUNG' + file_path = self.root + 'CPTAC_{}_features/simclr_files'.format(site) #_pre# with # rushin + + # For NLST only + if site =='NLST': + file_path = self.root + 'NLST_Lung_features/simclr_files' + + # For TCGA only + if site =='TCGA': + file_name = info.split('\t')[0] + _, file_name = file_name.split('/') + file_path = self.root + 'TCGA_LUNG_features/simclr_files' #_resnet_with + + sample['label'] = self.classdict[label] + sample['id'] = file_name + + #feature_path = os.path.join(self.root, file_name, 'features.pt') + feature_path = os.path.join(file_path, file_name, 'features.pt') + + if os.path.exists(feature_path): + features = torch.load(feature_path, map_location=lambda storage, loc: storage) + else: + print(feature_path + ' not exists') + features = torch.zeros(1, 512) + + #adj_s_path = os.path.join(self.root, file_name, 'adj_s.pt') + adj_s_path = os.path.join(file_path, file_name, 'adj_s.pt') + if os.path.exists(adj_s_path): + adj_s = torch.load(adj_s_path, map_location=lambda storage, loc: storage) + else: + print(adj_s_path + ' not exists') + adj_s = torch.ones(features.shape[0], features.shape[0]) + + #features = features.unsqueeze(0) + sample['image'] = features + sample['adj_s'] = adj_s #adj_s.to(torch.double) + # return {'image': image.astype(np.float32), 'label': label.astype(np.int64)} + + return sample +''' \ No newline at end of file diff --git a/utils/lr_scheduler.py b/utils/lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..5913949bdddf15602e55f82985c0a6b7b6656e23 --- /dev/null +++ b/utils/lr_scheduler.py @@ -0,0 +1,71 @@ +##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +## Created by: Hang Zhang +## ECE Department, Rutgers University +## Email: zhang.hang@rutgers.edu +## Copyright (c) 2017 +## +## This source code is licensed under the MIT-style license found in the +## LICENSE file in the root directory of this source tree +##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import math + +class LR_Scheduler(object): + """Learning Rate Scheduler + + Step mode: ``lr = baselr * 0.1 ^ {floor(epoch-1 / lr_step)}`` + + Cosine mode: ``lr = baselr * 0.5 * (1 + cos(iter/maxiter))`` + + Poly mode: ``lr = baselr * (1 - iter/maxiter) ^ 0.9`` + + Args: + args: :attr:`args.lr_scheduler` lr scheduler mode (`cos`, `poly`), + :attr:`args.lr` base learning rate, :attr:`args.epochs` number of epochs, + :attr:`args.lr_step` + + iters_per_epoch: number of iterations per epoch + """ + def __init__(self, mode, base_lr, num_epochs, iters_per_epoch=0, + lr_step=0, warmup_epochs=0): + self.mode = mode + print('Using {} LR Scheduler!'.format(self.mode)) + self.lr = base_lr + if mode == 'step': + assert lr_step + self.lr_step = lr_step + self.iters_per_epoch = iters_per_epoch + self.N = num_epochs * iters_per_epoch + self.epoch = -1 + self.warmup_iters = warmup_epochs * iters_per_epoch + + def __call__(self, optimizer, i, epoch, best_pred): + T = epoch * self.iters_per_epoch + i + if self.mode == 'cos': + lr = 0.5 * self.lr * (1 + math.cos(1.0 * T / self.N * math.pi)) + elif self.mode == 'poly': + lr = self.lr * pow((1 - 1.0 * T / self.N), 0.9) + elif self.mode == 'step': + lr = self.lr * (0.1 ** (epoch // self.lr_step)) + else: + raise NotImplemented + # warm up lr schedule + if self.warmup_iters > 0 and T < self.warmup_iters: + lr = lr * 1.0 * T / self.warmup_iters + if epoch > self.epoch: + print('\n=>Epoches %i, learning rate = %.7f, \ + previous best = %.4f' % (epoch+1, lr, best_pred)) + self.epoch = epoch + assert lr >= 0 + self._adjust_learning_rate(optimizer, lr) + + def _adjust_learning_rate(self, optimizer, lr): + if len(optimizer.param_groups) == 1: + optimizer.param_groups[0]['lr'] = lr + else: + # enlarge the lr at the head + for i in range(len(optimizer.param_groups)): + if optimizer.param_groups[i]['lr'] > 0: optimizer.param_groups[i]['lr'] = lr + # optimizer.param_groups[0]['lr'] = lr + # for i in range(1, len(optimizer.param_groups)): + # optimizer.param_groups[i]['lr'] = lr * 10 diff --git a/utils/metrics.py b/utils/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..6b7799650cc535052f0f9753cd8317f2e740830e --- /dev/null +++ b/utils/metrics.py @@ -0,0 +1,47 @@ +# Adapted from score written by wkentaro +# https://github.com/wkentaro/pytorch-fcn/blob/master/torchfcn/utils.py + +import numpy as np + +class ConfusionMatrix(object): + + def __init__(self, n_classes): + self.n_classes = n_classes + # axis = 0: prediction + # axis = 1: target + self.confusion_matrix = np.zeros((n_classes, n_classes)) + + def _fast_hist(self, label_true, label_pred, n_class): + hist = np.zeros((n_class, n_class)) + hist[label_pred, label_true] += 1 + + return hist + + def update(self, label_trues, label_preds): + for lt, lp in zip(label_trues, label_preds): + tmp = self._fast_hist(lt.item(), lp.item(), self.n_classes) #lt.item(), lp.item() + self.confusion_matrix += tmp + + def get_scores(self): + """Returns accuracy score evaluation result. + - overall accuracy + - mean accuracy + - mean IU + - fwavacc + """ + hist = self.confusion_matrix + # accuracy is recall/sensitivity for each class, predicted TP / all real positives + # axis in sum: perform summation along + + if sum(hist.sum(axis=1)) != 0: + acc = sum(np.diag(hist)) / sum(hist.sum(axis=1)) + else: + acc = 0.0 + + return acc + + def plotcm(self): + print(self.confusion_matrix) + + def reset(self): + self.confusion_matrix = np.zeros((self.n_classes, self.n_classes)) \ No newline at end of file diff --git a/weights/feature_extractor/config.yaml b/weights/feature_extractor/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4c8f4309e6cbefa7270b1beb7c639d9551b325a8 --- /dev/null +++ b/weights/feature_extractor/config.yaml @@ -0,0 +1,23 @@ +batch_size: 256 +epochs: 20 +eval_every_n_epochs: 1 +fine_tune_from: '' +log_every_n_steps: 25 +weight_decay: 10e-6 +fp16_precision: False +n_gpu: 2 +gpu_ids: (0,1) + +model: + out_dim: 512 + base_model: "resnet18" + +dataset: + s: 1 + input_shape: (224,224,3) + num_workers: 10 + valid_size: 0.1 + +loss: + temperature: 0.5 + use_cosine_similarity: True diff --git a/weights/feature_extractor/model.pth b/weights/feature_extractor/model.pth new file mode 100644 index 0000000000000000000000000000000000000000..86030da63ff1604b29d80d2bbed137d96526439d --- /dev/null +++ b/weights/feature_extractor/model.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c156857743ee3bd7b353fe8d34d2250e4c153e930c457b929ec461cafcd15fe4 +size 46779101 diff --git a/weights/graph_transformer/GraphCAM.pth b/weights/graph_transformer/GraphCAM.pth new file mode 100644 index 0000000000000000000000000000000000000000..7593168e87fc9e8b7adc62c297c80e2f190b768f --- /dev/null +++ b/weights/graph_transformer/GraphCAM.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3ab59c0a07d8a566f22ccece1d0ba4f05271be0c3927c5362bb4b7e220c432cb +size 577432