# Origin: file page by VatsalPatel18 ("Model files", commit c238491, 8.47 kB;
# page navigation labels: raw / history / blame).
import numpy as np
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from lifelines.statistics import logrank_test
from itertools import combinations
import matplotlib.pyplot as plt
from yellowbrick.cluster import KElbowVisualizer
import pandas as pd
import seaborn as sns
from lifelines import KaplanMeierFitter
import matplotlib.cm as cm
import itertools
import torch
class GraphAnalysis:
    """Cluster patients by their latent features and compare survival across clusters.

    Latent feature tensors are read from ``EXTRACTER.latent_feat_dict``
    (patient ID -> tensor) and joined with a clinical survival table. The
    class then offers 2-D projections (PCA / t-SNE), K-means clustering,
    pairwise log-rank tests, Kaplan-Meier summaries, and plots. Plots are
    written under ``results_dir``.
    """

    def __init__(self, EXTRACTER, clinical_path='./data/survival.hnsc_data.csv', results_dir='./results'):
        """Build the analysis from an extracter's latent features.

        Args:
            EXTRACTER: object exposing ``latent_feat_dict``, a mapping of
                patient ID -> torch tensor. All tensors must share one shape
                so that ``torch.stack`` can combine them.
            clinical_path: CSV file with a ``PatientID`` column plus the
                survival columns used later ('Overall Survival (Months)',
                'Overall Survival Status').
            results_dir: directory that receives the saved figures.
        """
        self.extracter = EXTRACTER
        self.clinical_path = clinical_path
        self.results_dir = results_dir
        # Initialised here so that cluster_data() can raise its explanatory
        # ValueError instead of an AttributeError when find_optimal_clusters()
        # has not been run yet.
        self.optimal_clusters = None
        self.process()

    def _result_path(self, filename):
        """Return ``filename`` placed inside ``results_dir``."""
        return f"{str(self.results_dir).rstrip('/')}/{filename}"

    def process(self):
        """Stack the latent features into a 2-D array and join the clinical table.

        Sets ``pIDs`` (patient IDs in dict order), ``latentF``
        (n_patients x n_features numpy array), ``clnc_df`` (clinical table
        indexed by PatientID) and ``df``. ``df`` has one row per patient, with
        empty PC1/PC2/tX/tY/groups columns followed by the clinical columns.
        """
        latent_feat_dict = self.extracter.latent_feat_dict
        self.pIDs = list(latent_feat_dict.keys())
        stacked = torch.stack(list(latent_feat_dict.values()), dim=0)
        # detach()/cpu() make this work for tensors that require grad or live
        # on a GPU. reshape(N, -1) instead of squeeze keeps the matrix 2-D
        # even with a single patient or a single feature.
        latent = stacked.detach().cpu().numpy()
        self.latentF = latent.reshape(latent.shape[0], -1)
        self.df = pd.DataFrame(columns=['PC1', 'PC2', 'tX', 'tY', 'groups'], index=self.pIDs)
        self.clnc_df = pd.read_csv(self.clinical_path).set_index('PatientID')
        self.df = self.df.join(self.clnc_df)

    def pca_tsne(self):
        """Add 2-D PCA ('PC1', 'PC2') and t-SNE ('tX', 'tY') coordinates to ``df``."""
        X_pca = PCA(n_components=2).fit_transform(self.latentF)
        self.df['PC1'] = X_pca[:, 0]
        self.df['PC2'] = X_pca[:, 1]
        # scikit-learn requires perplexity < n_samples. Keep its default (30)
        # whenever the cohort is large enough, and shrink it for small ones.
        n_samples = self.latentF.shape[0]
        perplexity = float(min(30, max(n_samples - 1, 1)))
        X_tsne = TSNE(n_components=2, perplexity=perplexity).fit_transform(self.latentF)
        self.df['tX'] = X_tsne[:, 0]
        self.df['tY'] = X_tsne[:, 1]

    def find_optimal_clusters(self, min_clusters=2, max_clusters=11, save_path=None):
        """Choose the number of K-means clusters with yellowbrick's elbow method.

        Args:
            min_clusters, max_clusters: passed to ``KElbowVisualizer`` as
                ``k=(min_clusters, max_clusters)``.
            save_path: output path without extension. A ``.png`` and a
                ``.jpeg`` copy of the elbow plot are written. Defaults to
                ``<results_dir>/kelbow``.

        Sets ``optimal_clusters`` to ``visualizer.elbow_value_``. That value
        may be None if no elbow is detected (verify against the installed
        yellowbrick version); cluster_data() then refuses to run.
        """
        if save_path is None:
            save_path = self._result_path('kelbow')
        model = KMeans(random_state=42)
        visualizer = KElbowVisualizer(model, k=(min_clusters, max_clusters))
        visualizer.fit(self.latentF)
        visualizer.show()
        fig = visualizer.ax.get_figure()
        fig.savefig(save_path + ".png", dpi=150)
        fig.savefig(save_path + ".jpeg", format="jpeg", dpi=150)
        self.optimal_clusters = visualizer.elbow_value_

    def cluster_data(self):
        """Cluster with the ``optimal_clusters`` found by find_optimal_clusters().

        Raises:
            ValueError: if no optimal cluster count is available.
        """
        if self.optimal_clusters is None:
            raise ValueError("Please run 'find_optimal_clusters' method before clustering the data.")
        self.cluster_data2(self.optimal_clusters)

    def cluster_data2(self, kclust):
        """Run K-means with ``kclust`` clusters (random_state=0) on ``latentF``.

        Stores the labels in ``labels`` and ``df['groups']``, then rebuilds
        ``color_list`` for the new groups.
        """
        kmeans = KMeans(n_clusters=kclust, random_state=0).fit(self.latentF)
        self.labels = kmeans.labels_
        self.df['groups'] = self.labels
        self.generate_color_list_based_on_median_survival()

    def visualize_clusters(self):
        """Draw PCA (left) and t-SNE (right) scatter plots coloured by group."""
        plt.figure(figsize=(20, 8))
        plt.subplot(1, 2, 1)
        sns.scatterplot(data=self.df, x='PC1', y='PC2', hue='groups', palette=self.color_list)
        plt.subplot(1, 2, 2)
        sns.scatterplot(data=self.df, x='tX', y='tY', hue='groups', palette=self.color_list)

    def save_visualize_clusters(self):
        """Save the PCA and t-SNE group scatter plots as JPEG and PNG files.

        Writes ``temp_pca.{jpeg,png}`` and ``temp_tsne.{jpeg,png}`` into
        ``results_dir``. Each figure is closed after saving.
        """
        for x_col, y_col, stem in (('PC1', 'PC2', 'temp_pca'), ('tX', 'tY', 'temp_tsne')):
            plt.figure(figsize=(10, 8))
            sns.scatterplot(data=self.df, x=x_col, y=y_col, hue='groups', palette=self.color_list)
            plt.savefig(self._result_path(f'{stem}.jpeg'), dpi=300)
            plt.savefig(self._result_path(f'{stem}.png'), dpi=300)
            # Close every figure, including the t-SNE one that was previously
            # left open.
            plt.close()

    def map_group_to_color(self, group):
        """Return the colour assigned to ``group`` in ``color_list``."""
        return self.color_list[group]

    def generate_color_list_based_on_median_survival(self):
        """Assign each group a colour ordered by its median survival time.

        Groups are sorted by the plain median of 'Overall Survival (Months)'
        (not the Kaplan-Meier median), longest first. They then receive
        evenly spaced colours from matplotlib's rainbow colormap, so the
        longest-surviving group gets the colormap's first colour.
        """
        groups = self.df['groups'].unique()
        median_survival_times = {
            group: self.df.loc[self.df['groups'] == group, 'Overall Survival (Months)'].median()
            for group in groups
        }
        sorted_groups = sorted(groups, key=median_survival_times.get, reverse=True)
        rainbow_colors = cm.rainbow(np.linspace(0, 1, len(groups)))
        self.color_list = dict(zip(sorted_groups, rainbow_colors))

    def perform_log_rank_test(self, alpha=0.05):
        """Run a log-rank test for every pair of groups.

        Args:
            alpha: significance threshold on the p-value.

        Returns:
            List of ``(group_a, group_b)`` pairs whose p-value is below
            ``alpha``. The list is also stored as ``significant_pairs``.

        Raises:
            ValueError: if the data has not been clustered yet.
        """
        # ``df`` always exists after __init__. Before clustering, its
        # 'groups' column is entirely NaN, so that is the real precondition.
        if self.df is None or self.df['groups'].isna().all():
            raise ValueError("Please run 'cluster_data' or 'cluster_data2' method before performing log rank test.")
        significant_pairs = []
        for label_a, label_b in itertools.combinations(self.df['groups'].unique(), 2):
            group_a = self.df[self.df['groups'] == label_a]
            group_b = self.df[self.df['groups'] == label_b]
            results = logrank_test(
                group_a['Overall Survival (Months)'],
                group_b['Overall Survival (Months)'],
                group_a['Overall Survival Status'],
                group_b['Overall Survival Status'],
            )
            if results.p_value < alpha:
                significant_pairs.append((label_a, label_b))
        self.significant_pairs = significant_pairs
        return self.significant_pairs

    def generate_summary_table(self):
        """Summarise each group's size, status counts and Kaplan-Meier median.

        Returns:
            DataFrame indexed by group with numeric columns 'Total number of
            patients', 'Alive' (status == 0), 'Deceased' (status == 1) and
            'Median survival time' (``KaplanMeierFitter.median_survival_time_``).
        """
        groups = self.df['groups'].unique()
        rows = []
        for group in groups:
            group_data = self.df[self.df['groups'] == group]
            durations = group_data['Overall Survival (Months)']
            events = group_data['Overall Survival Status']
            kmf = KaplanMeierFitter()
            kmf.fit(durations, events)
            rows.append({
                'Total number of patients': len(group_data),
                'Alive': int((events == 0).sum()),
                'Deceased': int((events == 1).sum()),
                'Median survival time': kmf.median_survival_time_,
            })
        # Building the frame from row dicts gives numeric dtypes instead of the
        # all-object columns that cell-by-cell .loc assignment produces.
        return pd.DataFrame(
            rows,
            index=groups,
            columns=['Total number of patients', 'Alive', 'Deceased', 'Median survival time'],
        )

    def plot_kaplan_meier(self, plot_for_groups=True, name='temp_k5'):
        """Plot Kaplan-Meier curves per group (or for all data) and save them.

        Args:
            plot_for_groups: if True, draw one curve per group using
                ``color_list``. Otherwise draw a single black curve for the
                whole cohort.
            name: file-name prefix. Writes ``<name>_plan_meir.{jpeg,png}``
                into ``results_dir``.
        """
        kmf = KaplanMeierFitter()
        plt.figure(figsize=(8, 6))
        plt.grid(False)
        if plot_for_groups:
            for group in sorted(self.df['groups'].unique()):
                group_data = self.df[self.df['groups'] == group]
                kmf.fit(group_data['Overall Survival (Months)'], group_data['Overall Survival Status'], label=f'Group {group}')
                kmf.plot(ci_show=False, linewidth=2, color=self.color_list[group])
            plt.title("Kaplan-Meier Curves for Each Group")
        else:
            kmf.fit(self.df['Overall Survival (Months)'], self.df['Overall Survival Status'], label='All Data')
            kmf.plot(ci_show=False, linewidth=2, color='black')
            plt.title("Kaplan-Meier Curve for All Data")
        plt.gca().set_facecolor('#f5f5f5')
        plt.grid(color='lightgrey', linestyle='-', linewidth=0.5)
        plt.xlabel("Overall Survival (Months)", fontweight='bold')
        plt.ylabel("Survival Probability", fontweight='bold')
        plt.legend()
        plt.savefig(self._result_path('{}_plan_meir.jpeg'.format(name)), dpi=300)
        plt.savefig(self._result_path('{}_plan_meir.png'.format(name)), dpi=300)
        plt.show()

    def club_two_groups(self, primary_group, secondary_group):
        """Merge ``secondary_group`` into ``primary_group`` and renumber groups.

        After merging, the remaining labels are mapped, in sorted order, to
        0..k-1. ``color_list`` and ``summary_table`` are then recomputed.
        """
        self.df.loc[self.df['groups'] == secondary_group, 'groups'] = primary_group
        mapping = {old: new for new, old in enumerate(sorted(self.df['groups'].unique()))}
        self.df['groups'] = self.df['groups'].map(mapping)
        self.generate_color_list_based_on_median_survival()
        self.summary_table = self.generate_summary_table()

    def plot_median_survival_bar(self, name='temp_k5'):
        """Draw a horizontal bar chart of each group's Kaplan-Meier median survival.

        Groups whose median is infinite (their survival curve never falls to
        the median) are drawn at 1.1x the largest finite median. If no group
        has a finite median, the longest observed survival time is used as
        the reference instead. Writes ``<name>_median_survival.{jpeg,png}``
        into ``results_dir``.
        """
        summary_df = self.generate_summary_table()
        summary_df['group'] = summary_df.index
        medians = summary_df["Median survival time"].astype(float)
        max_val = medians.replace(np.inf, np.nan).max()
        if not np.isfinite(max_val):
            # Every median is inf (or there are no groups). Without this
            # fallback, max_val * 1.1 would be NaN and no bars would render.
            max_val = self.df['Overall Survival (Months)'].max()
        summary_df["Display Median"] = medians.replace(np.inf, max_val * 1.1)
        summary_df = summary_df.sort_index()
        colors = [self.color_list[group] for group in summary_df.index]
        num_groups = len(summary_df)
        plt.figure(figsize=(6, num_groups * 0.8))
        plt.grid(False)
        sns.barplot(data=summary_df, y='group', x="Display Median", palette=colors, orient="h", order=summary_df.index)
        plt.xlabel("Median Survival Time (Months)")
        plt.ylabel("Groups")
        plt.title("Median Survival Time by Group")
        plt.tight_layout()
        plt.savefig(self._result_path('{}_median_survival.jpeg'.format(name)), dpi=300)
        plt.savefig(self._result_path('{}_median_survival.png'.format(name)), dpi=300)
        plt.show()