import numpy as np
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from lifelines.statistics import logrank_test
from itertools import combinations
import matplotlib.pyplot as plt
from yellowbrick.cluster import KElbowVisualizer
import pandas as pd
import seaborn as sns
from lifelines import KaplanMeierFitter
import matplotlib.cm as cm
import itertools
import torch


class GraphAnalysis:
    """Cluster per-patient latent features and compare survival across clusters.

    The object wraps an "extracter" that exposes ``latent_feat_dict`` (a
    mapping of patient ID to a torch tensor).  On construction the tensors
    are stacked into a 2-D NumPy matrix and a working DataFrame is built,
    indexed by patient ID and joined with the clinical survival table.

    Typical call order: ``pca_tsne`` -> ``find_optimal_clusters`` ->
    ``cluster_data`` (or ``cluster_data2``) -> plotting / statistics methods.
    """

    # Column names read from the clinical CSV.
    _TIME_COL = 'Overall Survival (Months)'
    _EVENT_COL = 'Overall Survival Status'
    _DEFAULT_CLINICAL_CSV = './data/survival.hnsc_data.csv'

    def __init__(self, EXTRACTER, clinical_csv=_DEFAULT_CLINICAL_CSV):
        """Store the extracter and immediately build the working tables.

        Args:
            EXTRACTER: object exposing a ``latent_feat_dict`` mapping.
            clinical_csv: path of the clinical CSV; it must contain a
                ``PatientID`` column, which becomes the join index.
        """
        self.extracter = EXTRACTER
        # Defined up front so `cluster_data` can detect "elbow search not run".
        self.optimal_clusters = None
        self.process(clinical_csv)

    def process(self, clinical_csv=_DEFAULT_CLINICAL_CSV):
        """Build ``latentF``, ``pIDs``, ``clnc_df`` and ``df`` from the extracter.

        Args:
            clinical_csv: path of the clinical CSV to join onto the patients.
        """
        latent_feat_dict = self.extracter.latent_feat_dict
        patient_ids = list(latent_feat_dict.keys())
        stacked = torch.stack(list(latent_feat_dict.values()), dim=0)

        # detach()/cpu() make the conversion work for tensors that require
        # grad or live on an accelerator.  Reshaping to (n_patients, -1)
        # removes singleton feature axes like np.squeeze did, but never drops
        # the patient axis when there is only one patient.
        self.latentF = stacked.detach().cpu().numpy().reshape(len(patient_ids), -1)
        self.pIDs = patient_ids

        placeholder = pd.DataFrame(
            columns=['PC1', 'PC2', 'tX', 'tY', 'groups'], index=self.pIDs
        )
        self.clnc_df = pd.read_csv(clinical_csv).set_index('PatientID')
        # Left join: every patient with latent features is kept; patients
        # absent from the clinical table get NaN clinical values.
        self.df = placeholder.join(self.clnc_df)

    def pca_tsne(self, random_state=None):
        """Compute 2-D PCA and t-SNE embeddings and store them in ``df``.

        Writes columns ``PC1``/``PC2`` (PCA) and ``tX``/``tY`` (t-SNE).

        Args:
            random_state: optional seed forwarded to PCA and t-SNE; the
                default ``None`` keeps the original unseeded behaviour.
        """
        pca_coords = PCA(n_components=2, random_state=random_state).fit_transform(self.latentF)
        self.df['PC1'] = pca_coords[:, 0]
        self.df['PC2'] = pca_coords[:, 1]

        # scikit-learn requires perplexity < n_samples; cap the default of 30
        # so small cohorts do not fail.
        n_samples = self.latentF.shape[0]
        perplexity = min(30.0, max(1.0, n_samples - 1))
        tsne = TSNE(n_components=2, perplexity=perplexity, random_state=random_state)
        tsne_coords = tsne.fit_transform(self.latentF)
        self.df['tX'] = tsne_coords[:, 0]
        self.df['tY'] = tsne_coords[:, 1]

    def find_optimal_clusters(self, min_clusters=2, max_clusters=11, save_path='./results/kelbow'):
        """Run a K-Means elbow search and store the result in ``optimal_clusters``.

        The elbow figure is shown and then saved as ``<save_path>.png`` and
        ``<save_path>.jpeg``.

        Args:
            min_clusters: first entry of the ``k`` tuple given to the visualizer.
            max_clusters: second entry of the ``k`` tuple given to the visualizer.
            save_path: output path without extension.
        """
        model = KMeans(random_state=42)
        visualizer = KElbowVisualizer(model, k=(min_clusters, max_clusters))
        visualizer.fit(self.latentF)
        visualizer.show()

        fig = visualizer.ax.get_figure()
        fig.savefig(save_path + ".png", dpi=150)
        fig.savefig(save_path + ".jpeg", format="jpeg", dpi=150)

        self.optimal_clusters = visualizer.elbow_value_

    def cluster_data(self):
        """Cluster with the ``k`` found by ``find_optimal_clusters``.

        Raises:
            ValueError: if no optimal cluster count is available.
        """
        # getattr: the attribute may be missing on instances built without
        # __init__, and is None when the elbow search has not produced a value.
        k = getattr(self, 'optimal_clusters', None)
        if k is None:
            raise ValueError("Please run 'find_optimal_clusters' method before clustering the data.")
        self.cluster_data2(k)

    def cluster_data2(self, kclust):
        """Cluster the latent features into ``kclust`` K-Means groups.

        Stores the labels in ``self.labels`` and ``df['groups']`` and
        refreshes the group colour mapping.

        Args:
            kclust: number of clusters.
        """
        kmeans = KMeans(n_clusters=kclust, random_state=0).fit(self.latentF)
        self.labels = kmeans.labels_
        self.df['groups'] = self.labels
        self.generate_color_list_based_on_median_survival()

    def visualize_clusters(self):
        """Draw side-by-side PCA and t-SNE scatter plots coloured by group."""
        plt.figure(figsize=(20, 8))
        plt.subplot(1, 2, 1)
        sns.scatterplot(data=self.df, x='PC1', y='PC2', hue='groups', palette=self.color_list)
        plt.subplot(1, 2, 2)
        sns.scatterplot(data=self.df, x='tX', y='tY', hue='groups', palette=self.color_list)

    def save_visualize_clusters(self):
        """Save the PCA and t-SNE scatter plots under ``./results``.

        Each plot is written as both JPEG and PNG (``temp_pca.*`` and
        ``temp_tsne.*``).  Both figures are closed after saving.
        """
        plots = (
            ('PC1', 'PC2', './results/temp_pca'),
            ('tX', 'tY', './results/temp_tsne'),
        )
        for x_col, y_col, stem in plots:
            plt.figure(figsize=(10, 8))
            sns.scatterplot(data=self.df, x=x_col, y=y_col, hue='groups', palette=self.color_list)
            plt.savefig(stem + '.jpeg', dpi=300)
            plt.savefig(stem + '.png', dpi=300)
            # Close every figure; the original left the t-SNE figure open.
            plt.close()

    def map_group_to_color(self, group):
        """Return the colour assigned to ``group`` in ``color_list``."""
        return self.color_list[group]

    def generate_color_list_based_on_median_survival(self):
        """Assign one rainbow colour per group, ordered by median survival.

        Groups are sorted by the plain median of the survival-time column
        (largest first) and paired with evenly spaced ``cm.rainbow`` colours.
        The mapping ``{group: rgba}`` is stored in ``self.color_list``.
        """
        groups = self.df['groups'].unique()
        median_by_group = {
            group: self.df.loc[self.df['groups'] == group, self._TIME_COL].median()
            for group in groups
        }
        ordered = sorted(groups, key=median_by_group.get, reverse=True)
        palette = cm.rainbow(np.linspace(0, 1, len(groups)))
        self.color_list = dict(zip(ordered, palette))

    def perform_log_rank_test(self, alpha=0.05):
        """Run pairwise log-rank tests between all groups.

        Args:
            alpha: significance threshold applied to each pair's p-value
                (no multiple-comparison correction is applied).

        Returns:
            List of ``(group_a, group_b)`` tuples with ``p_value < alpha``;
            also stored in ``self.significant_pairs``.

        Raises:
            ValueError: if groups have not been assigned yet.
        """
        df = getattr(self, 'df', None)
        # `df` always exists after process(); what actually signals "not
        # clustered yet" is a missing or entirely empty 'groups' column.
        if df is None or 'groups' not in df.columns or df['groups'].isna().all():
            raise ValueError("Please run 'cluster_data' or 'cluster_data2' method before performing log rank test.")

        significant_pairs = []
        for first, second in itertools.combinations(df['groups'].unique(), 2):
            rows_a = df[df['groups'] == first]
            rows_b = df[df['groups'] == second]
            results = logrank_test(
                rows_a[self._TIME_COL],
                rows_b[self._TIME_COL],
                rows_a[self._EVENT_COL],
                rows_b[self._EVENT_COL],
            )
            if results.p_value < alpha:
                significant_pairs.append((first, second))

        self.significant_pairs = significant_pairs
        return self.significant_pairs

    def generate_summary_table(self):
        """Build a per-group summary of counts and Kaplan-Meier median survival.

        Returns:
            DataFrame indexed by group with columns 'Total number of
            patients', 'Alive' (status == 0), 'Deceased' (status == 1) and
            'Median survival time' (may be ``inf``; see
            ``plot_median_survival_bar``).
        """
        groups = self.df['groups'].unique()
        summary_table = pd.DataFrame(
            columns=['Total number of patients', 'Alive', 'Deceased', 'Median survival time'],
            index=groups,
        )
        for group in groups:
            rows = self.df[self.df['groups'] == group]
            status = rows[self._EVENT_COL]
            kmf = KaplanMeierFitter()
            kmf.fit(rows[self._TIME_COL], status)
            summary_table.loc[group] = [
                len(rows),
                int((status == 0).sum()),
                int((status == 1).sum()),
                kmf.median_survival_time_,
            ]
        return summary_table

    def plot_kaplan_meier(self, plot_for_groups=True, name='temp_k5'):
        """Plot Kaplan-Meier curves and save them under ``./results``.

        Args:
            plot_for_groups: when true draw one curve per group using the
                group colours, otherwise a single black curve for all rows.
            name: prefix of the output files
                (``<name>_plan_meir.jpeg`` / ``.png``).
        """
        kmf = KaplanMeierFitter()
        plt.figure(figsize=(8, 6))
        plt.grid(False)

        if plot_for_groups:
            for group in sorted(self.df['groups'].unique()):
                rows = self.df[self.df['groups'] == group]
                kmf.fit(rows[self._TIME_COL], rows[self._EVENT_COL], label=f'Group {group}')
                kmf.plot(ci_show=False, linewidth=2, color=self.color_list[group])
            plt.title("Kaplan-Meier Curves for Each Group")
        else:
            kmf.fit(self.df[self._TIME_COL], self.df[self._EVENT_COL], label='All Data')
            kmf.plot(ci_show=False, linewidth=2, color='black')
            plt.title("Kaplan-Meier Curve for All Data")

        plt.gca().set_facecolor('#f5f5f5')
        plt.grid(color='lightgrey', linestyle='-', linewidth=0.5)
        plt.xlabel("Overall Survival (Months)", fontweight='bold')
        plt.ylabel("Survival Probability", fontweight='bold')
        plt.legend()
        plt.savefig('./results/{}_plan_meir.jpeg'.format(name), dpi=300)
        plt.savefig('./results/{}_plan_meir.png'.format(name), dpi=300)
        plt.show()

    def club_two_groups(self, primary_group, secondary_group):
        """Merge ``secondary_group`` into ``primary_group`` and renumber groups.

        After merging, the remaining labels are remapped to ``0..n-1`` in
        sorted order, the colour mapping is regenerated and
        ``self.summary_table`` is refreshed.

        Args:
            primary_group: label that absorbs the other group.
            secondary_group: label to be merged away.
        """
        self.df.loc[self.df['groups'] == secondary_group, 'groups'] = primary_group
        remaining = sorted(self.df['groups'].unique())
        relabel = {old: new for new, old in enumerate(remaining)}
        self.df['groups'] = self.df['groups'].map(relabel)
        self.generate_color_list_based_on_median_survival()
        self.summary_table = self.generate_summary_table()

    def plot_median_survival_bar(self, name='temp_k5'):
        """Draw a horizontal bar chart of median survival per group and save it.

        Infinite medians are drawn at 1.1x the largest finite median so the
        bar is still visible.  If no group has a finite median, the longest
        recorded survival time is used as the reference instead.

        Args:
            name: prefix of the output files
                (``<name>_median_survival.jpeg`` / ``.png``).
        """
        summary_df = self.generate_summary_table()
        summary_df['group'] = summary_df.index

        # The summary table is built with object dtype; use floats for plotting.
        medians = summary_df["Median survival time"].astype(float)
        max_val = medians.replace(np.inf, np.nan).max()
        if pd.isna(max_val):
            # Every median is infinite: fall back to the longest follow-up
            # so the bars do not end up as NaN.
            max_val = self.df[self._TIME_COL].max()
        summary_df["Display Median"] = medians.replace(np.inf, max_val * 1.1)
        summary_df = summary_df.sort_index()

        colors = [self.color_list[group] for group in summary_df.index]
        num_groups = len(summary_df)

        plt.figure(figsize=(6, num_groups * 0.8))
        plt.grid(False)
        sns.barplot(
            data=summary_df,
            y='group',
            x="Display Median",
            palette=colors,
            orient="h",
            order=summary_df.index,
        )
        plt.xlabel("Median Survival Time (Months)")
        plt.ylabel("Groups")
        plt.title("Median Survival Time by Group")
        plt.tight_layout()
        plt.savefig('./results/{}_median_survival.jpeg'.format(name), dpi=300)
        plt.savefig('./results/{}_median_survival.png'.format(name), dpi=300)
        plt.show()