File size: 3,587 Bytes
ce0e9f1
 
648ca95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ce0e9f1
 
 
648ca95
f0e9ca6
648ca95
 
 
f0e9ca6
648ca95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ce0e9f1
 
 
 
 
 
 
 
 
648ca95
 
ce0e9f1
648ca95
ce0e9f1
648ca95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import gradio as gr
import torch
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans
from lifelines import KaplanMeierFitter
from yellowbrick.cluster import KElbowVisualizer
from itertools import combinations
from lifelines.statistics import logrank_test
import os
import subprocess

from MultiOmicsGraphAttentionAutoencoderModel import MultiOmicsGraphAttentionAutoencoderModel
from OmicsConfig import OmicsConfig
from Attention_Extracter import Attention_Extracter
from GraphAnalysis import GraphAnalysis

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the autoencoder model.
# NOTE(review): only the *config* is read from the checkpoint directory; the
# model is then constructed from that config. Unless the constructor itself
# loads weights, the network here is not the trained one. If trained weights
# are intended, something like
# MultiOmicsGraphAttentionAutoencoderModel.from_pretrained(...) may be needed.
# Likewise no .eval() call is made. Both are TODOs to confirm against the model
# class.
autoencoder_config = OmicsConfig.from_pretrained("./lc_models/MultiOmicsAutoencoder/trained_autoencoder")
autoencoder_model = MultiOmicsGraphAttentionAutoencoderModel(autoencoder_config).to(device)

# Initialize Attention Extracter.
# Path to the saved patient graph data. It is presumably a torch-saved dict of
# graphs, judging by the .pth suffix and variable name (verify).
graph_data_dict_path = './data/hnscc.patient.chg.network.pth'
# Only the encoder half of the autoencoder is handed to the extracter; the gpu
# flag mirrors the device chosen above.
extracter = Attention_Extracter(graph_data_dict_path, autoencoder_model.encoder, gpu=(device == "cuda"))

def extract_features():
    """Create a new GraphAnalysis wrapping the module-level attention extracter.

    The "Extract Features" button stores the returned object in the session's
    ``gr.State``, and every other callback receives it from there.

    Returns:
        A freshly constructed ``GraphAnalysis`` instance.
    """
    return GraphAnalysis(extracter)

def find_optimal_clusters(ga, min_clusters, max_clusters,
                          save_path='/workspace/MultiOmics-Graph-Attention-Autoencoder/temp'):
    """Run the elbow-method search for a good number of clusters.

    Args:
        ga: GraphAnalysis object from the Gradio session state, or None when
            "Extract Features" has not been clicked yet.
        min_clusters: Lower bound of the search range (slider value).
        max_clusters: Upper bound of the search range (slider value).
        save_path: Directory forwarded to ``ga.find_optimal_clusters``. The
            default keeps the previously hard-coded location.

    Returns:
        ``ga.optimal_clusters`` after the search, or a user-facing message
        string when the inputs cannot be used.
    """
    if ga is None:
        # The State is only populated by the "Extract Features" button.
        return "No features extracted yet. Click 'Extract Features' first."
    # Slider values may arrive as floats (e.g. 2.0); cluster counts must be ints.
    min_clusters, max_clusters = int(min_clusters), int(max_clusters)
    if min_clusters > max_clusters:
        return (f"Min clusters ({min_clusters}) must not exceed "
                f"max clusters ({max_clusters}).")
    ga.find_optimal_clusters(min_clusters=min_clusters, max_clusters=max_clusters, save_path=save_path)
    return ga.optimal_clusters

def perform_clustering(ga, num_clusters):
    """Cluster the extracted features into ``num_clusters`` groups.

    Args:
        ga: GraphAnalysis object from the Gradio session state, or None when
            "Extract Features" has not been clicked yet.
        num_clusters: Desired number of clusters (slider value).

    Returns:
        A status message for the result textbox.
    """
    if ga is None:
        return "No features extracted yet. Click 'Extract Features' first."
    # Slider values may arrive as floats; a cluster count must be an int.
    ga.cluster_data2(int(num_clusters))
    return "Clustering completed."

def plot_kaplan_meier(ga):
    """Ask the GraphAnalysis object to produce its Kaplan-Meier plot.

    Args:
        ga: GraphAnalysis object from the Gradio session state, or None when
            "Extract Features" has not been clicked yet.

    Returns:
        A status message for the result textbox.
    """
    if ga is None:
        return "No features extracted yet. Click 'Extract Features' first."
    ga.plot_kaplan_meier()
    return "Kaplan-Meier plot saved."

def plot_median_survival_bar(ga):
    """Ask the GraphAnalysis object to produce its median-survival bar plot.

    Args:
        ga: GraphAnalysis object from the Gradio session state, or None when
            "Extract Features" has not been clicked yet.

    Returns:
        A status message for the result textbox.
    """
    if ga is None:
        return "No features extracted yet. Click 'Extract Features' first."
    ga.plot_median_survival_bar(name='temp')
    return "Median survival bar plot saved."

def perform_log_rank_test(ga):
    """Run the pairwise log-rank test and format its result for display.

    Args:
        ga: GraphAnalysis object from the Gradio session state, or None when
            "Extract Features" has not been clicked yet.

    Returns:
        A message containing whatever ``ga.perform_log_rank_test()`` returned,
        or a hint to extract features first.
    """
    if ga is None:
        return "No features extracted yet. Click 'Extract Features' first."
    significant_pairs = ga.perform_log_rank_test()
    return f"Significant pairs from log-rank test: {significant_pairs}"

# Keep the single main column narrow and horizontally centred.
css = """
#col-container {
    margin: 0 auto;
    max-width: 520px;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(f"""
        # Graph to Features and Analysis
        Currently running on {device}.
        """)

        # One button per pipeline step. "Extract Features" must run first,
        # because it is the only handler that fills the `ga` state below.
        with gr.Row():
            extract_button = gr.Button("Extract Features")
            clustering_button = gr.Button("Find Optimal Clusters")
            cluster_data_button = gr.Button("Perform Clustering")
            kaplan_meier_button = gr.Button("Plot Kaplan-Meier")
            survival_bar_button = gr.Button("Plot Median Survival Bar")
            log_rank_button = gr.Button("Perform Log-Rank Test")

        # Cluster count used by "Perform Clustering".
        num_clusters = gr.Slider(label="Number of Clusters", minimum=2, maximum=10, step=1, value=5)
        # Search range used by "Find Optimal Clusters". The two sliders are
        # independent, so nothing here stops min from exceeding max.
        min_clusters = gr.Slider(label="Min Clusters for Elbow Method", minimum=2, maximum=10, step=1, value=2)
        max_clusters = gr.Slider(label="Max Clusters for Elbow Method", minimum=3, maximum=20, step=1, value=10)

        # Shared text output for every step except feature extraction.
        result = gr.Textbox(label="Result")

    # Per-session holder for the GraphAnalysis object. No initial value is
    # given, so it stays empty until "Extract Features" is clicked.
    ga = gr.State()

    # Only the extract handler writes `ga`; the others read it as an input
    # and report back through `result`.
    extract_button.click(fn=extract_features, inputs=[], outputs=[ga])
    clustering_button.click(fn=find_optimal_clusters, inputs=[ga, min_clusters, max_clusters], outputs=[result])
    cluster_data_button.click(fn=perform_clustering, inputs=[ga, num_clusters], outputs=[result])
    kaplan_meier_button.click(fn=plot_kaplan_meier, inputs=[ga], outputs=[result])
    survival_bar_button.click(fn=plot_median_survival_bar, inputs=[ga], outputs=[result])
    log_rank_button.click(fn=perform_log_rank_test, inputs=[ga], outputs=[result])

# Enable the request queue and start the server. There is no __main__ guard,
# so this also runs whenever the module is imported.
demo.queue().launch()