File size: 8,471 Bytes
c238491
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import numpy as np
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from lifelines.statistics import logrank_test
from itertools import combinations
import matplotlib.pyplot as plt
from yellowbrick.cluster import KElbowVisualizer
import pandas as pd
import seaborn as sns
from lifelines import KaplanMeierFitter
import matplotlib.cm as cm
import itertools
import torch

class GraphAnalysis:
    """Cluster per-patient latent features and relate the clusters to survival.

    ``EXTRACTER`` must expose ``latent_feat_dict``: a mapping from patient ID
    to a torch tensor of latent features. Those features are joined with a
    clinical survival CSV (indexed by its ``PatientID`` column). They can then
    be projected (PCA / t-SNE), clustered with k-means, and compared via
    Kaplan-Meier curves and pairwise log-rank tests.

    Survival columns used throughout are ``'Overall Survival (Months)'``
    (duration) and ``'Overall Survival Status'`` (event flag, counted as
    alive when 0 and deceased when 1).

    Typical order of calls::

        ga = GraphAnalysis(extracter)
        ga.pca_tsne()
        ga.find_optimal_clusters()   # or skip it and call ga.cluster_data2(k)
        ga.cluster_data()
        ga.perform_log_rank_test()
    """

    DEFAULT_SURVIVAL_PATH = './data/survival.hnsc_data.csv'
    SURVIVAL_TIME = 'Overall Survival (Months)'
    SURVIVAL_STATUS = 'Overall Survival Status'

    def __init__(self, EXTRACTER, survival_path=DEFAULT_SURVIVAL_PATH):
        """Store the extracter and immediately build the feature/clinical tables.

        Args:
            EXTRACTER: object exposing ``latent_feat_dict`` (patient ID -> tensor).
            survival_path: clinical CSV with a ``PatientID`` column.
        """
        self.extracter = EXTRACTER
        # Initialised up front so cluster_data() can raise its intended
        # ValueError (not AttributeError) when no k has been chosen yet.
        self.optimal_clusters = None
        self.process(survival_path)

    def process(self, survival_path=DEFAULT_SURVIVAL_PATH):
        """Build ``self.latentF``, ``self.pIDs`` and ``self.df``.

        ``self.latentF`` holds one row per patient, in ``latent_feat_dict``
        order. ``self.df`` is indexed by patient ID. It has empty projection
        and cluster columns (PC1, PC2, tX, tY, groups) and is left-joined with
        the clinical CSV at ``survival_path``.
        """
        feat_dict = self.extracter.latent_feat_dict
        patient_list = list(feat_dict.keys())
        latentF = torch.stack(list(feat_dict.values()), dim=0)
        # detach()/cpu() allow tensors that track gradients or live on a GPU.
        # reshape (instead of squeeze) keeps the matrix 2-D even with a
        # single patient.
        self.latentF = latentF.detach().cpu().numpy().reshape(len(patient_list), -1)
        self.pIDs = patient_list
        self.df = pd.DataFrame(columns=['PC1', 'PC2', 'tX', 'tY', 'groups'], index=self.pIDs)
        self.clnc_df = pd.read_csv(survival_path).set_index('PatientID')
        self.df = self.df.join(self.clnc_df)

    def pca_tsne(self, random_state=None):
        """Project the latent features to 2-D with PCA and t-SNE.

        Fills ``self.df`` columns PC1/PC2 (PCA) and tX/tY (t-SNE).

        Args:
            random_state: forwarded to ``TSNE``. The default ``None`` keeps
                the previous non-deterministic behaviour.
        """
        X_pca = PCA(n_components=2).fit_transform(self.latentF)
        self.df['PC1'] = X_pca[:, 0]
        self.df['PC2'] = X_pca[:, 1]
        X_tsne = TSNE(n_components=2, random_state=random_state).fit_transform(self.latentF)
        self.df['tX'] = X_tsne[:, 0]
        self.df['tY'] = X_tsne[:, 1]

    def find_optimal_clusters(self, min_clusters=2, max_clusters=11, save_path='./results/kelbow'):
        """Choose k for k-means with yellowbrick's elbow method.

        Shows the elbow plot and saves it as ``save_path + '.png'`` and
        ``'.jpeg'``. Stores the chosen k in ``self.optimal_clusters``. If
        yellowbrick detects no elbow, this may be ``None``, in which case
        ``cluster_data`` will refuse to run.
        """
        model = KMeans(random_state=42)
        visualizer = KElbowVisualizer(model, k=(min_clusters, max_clusters))
        visualizer.fit(self.latentF)
        visualizer.show()
        fig = visualizer.ax.get_figure()
        fig.savefig(save_path + ".png", dpi=150)
        fig.savefig(save_path + ".jpeg", format="jpeg", dpi=150)
        self.optimal_clusters = visualizer.elbow_value_

    def cluster_data(self):
        """Run k-means with the k found by ``find_optimal_clusters``.

        Raises:
            ValueError: if no optimal k is available yet.
        """
        # getattr also covers instances created without __init__.
        if getattr(self, 'optimal_clusters', None) is None:
            raise ValueError("Please run 'find_optimal_clusters' method before clustering the data.")
        self.cluster_data2(self.optimal_clusters)

    def cluster_data2(self, kclust):
        """Run k-means with an explicit ``kclust`` and refresh the group colours.

        Sets ``self.labels`` and ``self.df['groups']``.
        """
        kmeans = KMeans(n_clusters=kclust, random_state=0).fit(self.latentF)
        self.labels = kmeans.labels_
        self.df['groups'] = self.labels
        self.generate_color_list_based_on_median_survival()

    def visualize_clusters(self):
        """Draw PCA and t-SNE scatters side by side, coloured by cluster."""
        plt.figure(figsize=(20, 8))
        plt.subplot(1, 2, 1)
        sns.scatterplot(data=self.df, x='PC1', y='PC2', hue='groups', palette=self.color_list)
        plt.subplot(1, 2, 2)
        sns.scatterplot(data=self.df, x='tX', y='tY', hue='groups', palette=self.color_list)

    def save_visualize_clusters(self):
        """Save the PCA and t-SNE cluster scatters under ./results/ (JPEG + PNG)."""
        self._save_scatter('PC1', 'PC2', './results/temp_pca')
        self._save_scatter('tX', 'tY', './results/temp_tsne')

    def _save_scatter(self, x, y, stem):
        """Save one cluster-coloured ``x``/``y`` scatter as ``stem``.jpeg/.png and close it."""
        plt.figure(figsize=(10, 8))
        sns.scatterplot(data=self.df, x=x, y=y, hue='groups', palette=self.color_list)
        plt.savefig(stem + '.jpeg', dpi=300)
        plt.savefig(stem + '.png', dpi=300)
        # Close every figure we open; the original left the t-SNE one open.
        plt.close()

    def map_group_to_color(self, group):
        """Return the colour assigned to ``group`` in ``self.color_list``."""
        return self.color_list[group]

    def generate_color_list_based_on_median_survival(self):
        """Assign rainbow colours to groups, ordered by descending median survival.

        The group with the longest median survival gets the first colour of
        ``cm.rainbow``. Groups whose median is NaN are placed last.
        """
        groups = self.df['groups'].unique()
        medians = {
            group: self.df.loc[self.df['groups'] == group, self.SURVIVAL_TIME].median()
            for group in groups
        }
        # Sorting on -median keeps the original descending order (and tie
        # stability); NaN maps to +inf so it cannot scramble the sort.
        sorted_groups = sorted(
            groups, key=lambda g: np.inf if pd.isna(medians[g]) else -medians[g]
        )
        vibgyor_colors = cm.rainbow(np.linspace(0, 1, len(groups)))
        self.color_list = dict(zip(sorted_groups, vibgyor_colors))

    def perform_log_rank_test(self, alpha=0.05):
        """Run pairwise log-rank tests between all cluster groups.

        Args:
            alpha: significance threshold on the p-value.

        Returns:
            List of ``(group_a, group_b)`` pairs with p-value < ``alpha``. This
            list is also stored in ``self.significant_pairs``. Every pair's
            p-value is stored in ``self.log_rank_p_values``.

        Raises:
            ValueError: if groups have not been assigned yet.
        """
        if self.df is None or 'groups' not in self.df or self.df['groups'].isna().all():
            raise ValueError("Please run 'cluster_data' or 'cluster_data2' method before performing log rank test.")
        groups = self.df['groups'].unique()
        significant_pairs = []
        p_values = {}
        for pair in itertools.combinations(groups, 2):
            group_a = self.df[self.df['groups'] == pair[0]]
            group_b = self.df[self.df['groups'] == pair[1]]
            results = logrank_test(
                group_a[self.SURVIVAL_TIME], group_b[self.SURVIVAL_TIME],
                group_a[self.SURVIVAL_STATUS], group_b[self.SURVIVAL_STATUS],
            )
            p_values[pair] = results.p_value
            if results.p_value < alpha:
                significant_pairs.append(pair)
        self.log_rank_p_values = p_values
        self.significant_pairs = significant_pairs
        return self.significant_pairs

    def generate_summary_table(self):
        """Return per-group patient counts and Kaplan-Meier median survival.

        Columns: 'Total number of patients', 'Alive' (status == 0),
        'Deceased' (status == 1), 'Median survival time'. The index is the
        groups in ``unique()`` order. Columns get numeric dtypes instead of
        the object dtype produced by row-wise ``.loc`` assignment.
        """
        rows = {}
        for group in self.df['groups'].unique():
            group_data = self.df[self.df['groups'] == group]
            status = group_data[self.SURVIVAL_STATUS]
            kmf = KaplanMeierFitter()
            kmf.fit(group_data[self.SURVIVAL_TIME], status)
            rows[group] = [
                len(group_data),
                int((status == 0).sum()),
                int((status == 1).sum()),
                kmf.median_survival_time_,
            ]
        return pd.DataFrame.from_dict(
            rows,
            orient='index',
            columns=['Total number of patients', 'Alive', 'Deceased', 'Median survival time'],
        )

    def plot_kaplan_meier(self, plot_for_groups=True, name='temp_k5'):
        """Plot Kaplan-Meier curves and save them under ./results/.

        Args:
            plot_for_groups: if True, draw one curve per group using
                ``self.color_list``. Otherwise draw a single black curve for
                all patients.
            name: file stem; output is ``./results/{name}_plan_meir.{jpeg,png}``.
        """
        kmf = KaplanMeierFitter()
        plt.figure(figsize=(8, 6))
        plt.grid(False)
        if plot_for_groups:
            for group in sorted(self.df['groups'].unique()):
                group_data = self.df[self.df['groups'] == group]
                kmf.fit(group_data[self.SURVIVAL_TIME], group_data[self.SURVIVAL_STATUS], label=f'Group {group}')
                kmf.plot(ci_show=False, linewidth=2, color=self.color_list[group])
            plt.title("Kaplan-Meier Curves for Each Group")
        else:
            kmf.fit(self.df[self.SURVIVAL_TIME], self.df[self.SURVIVAL_STATUS], label='All Data')
            kmf.plot(ci_show=False, linewidth=2, color='black')
            plt.title("Kaplan-Meier Curve for All Data")
        plt.gca().set_facecolor('#f5f5f5')
        plt.grid(color='lightgrey', linestyle='-', linewidth=0.5)
        plt.xlabel("Overall Survival (Months)", fontweight='bold')
        plt.ylabel("Survival Probability", fontweight='bold')
        plt.legend()
        plt.savefig('./results/{}_plan_meir.jpeg'.format(name), dpi=300)
        plt.savefig('./results/{}_plan_meir.png'.format(name), dpi=300)
        plt.show()

    def club_two_groups(self, primary_group, secondary_group):
        """Merge ``secondary_group`` into ``primary_group``, then relabel groups 0..k-1.

        Recomputes the colour map and stores a fresh ``self.summary_table``.
        """
        self.df.loc[self.df['groups'] == secondary_group, 'groups'] = primary_group
        unique_groups = sorted(self.df['groups'].unique())
        mapping = {old: new for new, old in enumerate(unique_groups)}
        self.df['groups'] = self.df['groups'].map(mapping)
        self.generate_color_list_based_on_median_survival()
        self.summary_table = self.generate_summary_table()

    def plot_median_survival_bar(self, name='temp_k5'):
        """Draw a horizontal bar chart of median survival per group and save it.

        Infinite medians are drawn at 1.1x the largest finite median, or at
        1.1x the longest observed survival time if no finite median exists.
        Output goes to ``./results/{name}_median_survival.{jpeg,png}``.
        """
        summary_df = self.generate_summary_table()
        summary_df['group'] = summary_df.index
        medians = pd.to_numeric(summary_df["Median survival time"])
        max_val = medians.replace(np.inf, np.nan).max()
        if pd.isna(max_val):
            # Every median is infinite (or there are no groups): fall back to
            # the longest observed time so the bars still get a finite length.
            max_val = self.df[self.SURVIVAL_TIME].max()
        summary_df["Display Median"] = medians.replace(np.inf, max_val * 1.1)
        summary_df = summary_df.sort_index()
        colors = [self.color_list[group] for group in summary_df.index]
        num_groups = len(summary_df)
        # Floor the height so one or two groups still leave room for the
        # title and labels under tight_layout.
        plt.figure(figsize=(6, max(num_groups * 0.8, 2.0)))
        plt.grid(False)
        sns.barplot(data=summary_df, y='group', x="Display Median", palette=colors, orient="h", order=summary_df.index)
        plt.xlabel("Median Survival Time (Months)")
        plt.ylabel("Groups")
        plt.title("Median Survival Time by Group")
        plt.tight_layout()
        plt.savefig('./results/{}_median_survival.jpeg'.format(name), dpi=300)
        plt.savefig('./results/{}_median_survival.png'.format(name), dpi=300)
        plt.show()