{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "83d8d249-affe-45dd-915e-992b4b35b31a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import deepsort\n",
    "from sklearn.metrics import accuracy_score, f1_score\n",
    "from tqdm.notebook import tqdm\n",
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "25de46ec-8a41-484d-8e14-d2b19768fc2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_metrics(labels, preds):\n",
    "    \"\"\"Score predicted labels against reference labels.\n",
    "\n",
    "    Args:\n",
    "        labels: reference class labels, one entry per sample.\n",
    "        preds: predicted class labels, in the same order as `labels`.\n",
    "\n",
    "    Returns:\n",
    "        dict with two entries: 'accuracy' (sklearn accuracy_score) and\n",
    "        'macro_f1' (sklearn f1_score with average='macro').\n",
    "    \"\"\"\n",
    "    return {\n",
    "        'accuracy': accuracy_score(labels, preds),\n",
    "        'macro_f1': f1_score(labels, preds, average='macro'),\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "a4029b2b-afca-4300-82a2-082fec59f191",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['pancreas',\n",
       " 'liver',\n",
       " 'blood',\n",
       " 'lung',\n",
       " 'spleen',\n",
       " 'placenta',\n",
       " 'colorectum',\n",
       " 'kidney',\n",
       " 'brain']"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "rootdir = \"/path/to/data/\"\n",
    "\n",
    "# Each sub-directory of rootdir is used as one dataset (its name is later passed\n",
    "# to the classifier as the tissue). Entries whose name contains \"results\" and\n",
    "# plain files are skipped.\n",
    "dir_list = [\n",
    "    dir_i\n",
    "    for dir_i in os.listdir(rootdir)\n",
    "    if \"results\" not in dir_i and os.path.isdir(os.path.join(rootdir, dir_i))\n",
    "]\n",
    "dir_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ddcdc5cd-871e-4fd2-8457-18d3049fa76c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "output_dir = \"results_EDefault_filtered\"\n",
    "n_epochs = \"Default\" # scDeepsort default epochs = 300\n",
    "\n",
    "# Train one classifier per tissue directory, predict on its held-out test file,\n",
    "# score the predictions, and pickle the metrics (per tissue and combined).\n",
    "results_dict = dict()\n",
    "for dir_name in tqdm(dir_list):\n",
    "    print(f\"TRAINING: {dir_name}\")\n",
    "    subrootdir = f\"{rootdir}{dir_name}/\"\n",
    "    # the fitted model and the prediction csv are both written to this directory\n",
    "    model_dir = f\"{subrootdir}{output_dir}\"\n",
    "    train_files = [(f\"{subrootdir}{dir_name}_filtered_data_train.csv\",\n",
    "                    f\"{subrootdir}{dir_name}_filtered_celltype_train.csv\")]\n",
    "    test_file = f\"{subrootdir}{dir_name}_filtered_data_test.csv\"\n",
    "    label_file = f\"{subrootdir}{dir_name}_filtered_celltype_test.csv\"\n",
    "\n",
    "    # define the model\n",
    "    model = deepsort.DeepSortClassifier(species='human',\n",
    "                                        tissue=dir_name,\n",
    "                                        gpu_id=0,\n",
    "                                        random_seed=1,\n",
    "                                        validation_fraction=0) # use all training data (already held out 20% in test data file)\n",
    "\n",
    "    # fit the model\n",
    "    model.fit(train_files, save_path=model_dir)\n",
    "\n",
    "    # use the saved model to predict cell types in test data\n",
    "    model.predict(input_file=test_file,\n",
    "                  model_path=model_dir,\n",
    "                  save_path=model_dir,\n",
    "                  unsure_rate=0,\n",
    "                  file_type='csv')\n",
    "\n",
    "    labels_df = pd.read_csv(label_file)\n",
    "    preds_df = pd.read_csv(f\"{model_dir}/human_{dir_name}_{dir_name}_filtered_data_test.csv\")\n",
    "\n",
    "    # labels and predictions are compared row by row, so both files must list\n",
    "    # the same cells in the same order\n",
    "    assert list(labels_df[\"Cell\"]) == list(preds_df[\"index\"]), \\\n",
    "        f\"{dir_name}: cell ids in predictions do not match the label file\"\n",
    "\n",
    "    labels = list(labels_df[\"Cell_type\"])\n",
    "\n",
    "    # Score the cell_type column unconditionally so that `results` is assigned\n",
    "    # afresh for every tissue and can never carry over from the previous\n",
    "    # loop iteration.\n",
    "    results = compute_metrics(labels, list(preds_df[\"cell_type\"]))\n",
    "\n",
    "    # The first row decides whether the cell_subtype column is populated.\n",
    "    # NOTE(review): presumably the column is all-NaN when no subtype is\n",
    "    # predicted - confirm against the prediction file format.\n",
    "    first_subtype = preds_df[\"cell_subtype\"].iloc[0]\n",
    "    if not pd.isna(first_subtype):\n",
    "        # keep whichever of the two prediction columns is strictly more accurate\n",
    "        subtype_results = compute_metrics(labels, list(preds_df[\"cell_subtype\"]))\n",
    "        if subtype_results[\"accuracy\"] > results[\"accuracy\"]:\n",
    "            results = subtype_results\n",
    "\n",
    "    print(f\"{dir_name}: {results}\")\n",
    "    results_dict[dir_name] = results\n",
    "    with open(f\"{subrootdir}deepsort_E{n_epochs}_filtered_pred_{dir_name}.pickle\", \"wb\") as output_file:\n",
    "        pickle.dump(results, output_file)\n",
    "\n",
    "# save results\n",
    "with open(f\"{rootdir}deepsort_E{n_epochs}_filtered_pred_dict.pickle\", \"wb\") as output_file:\n",
    "    pickle.dump(results_dict, output_file)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.6 64-bit ('3.8.6')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.6"
  },
  "vscode": {
   "interpreter": {
    "hash": "eba1599a1f7e611c14c87ccff6793920aa63510b01fc0e229d6dd014149b8829"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}