VatsalPatel18's picture
Model files
f0e9ca6
raw
history blame
3.59 kB
import gradio as gr
import torch
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans
from lifelines import KaplanMeierFitter
from yellowbrick.cluster import KElbowVisualizer
from itertools import combinations
from lifelines.statistics import logrank_test
import os
import subprocess
from MultiOmicsGraphAttentionAutoencoderModel import MultiOmicsGraphAttentionAutoencoderModel
from OmicsConfig import OmicsConfig
from Attention_Extracter import Attention_Extracter
from GraphAnalysis import GraphAnalysis
# Use the GPU when torch can see one, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load the autoencoder model
# NOTE(review): only the *config* is read from disk here; the model itself is
# built via its constructor. Unless that constructor loads weights on its own,
# the encoder below is randomly initialised. If the class is a Hugging Face
# PreTrainedModel, `MultiOmicsGraphAttentionAutoencoderModel.from_pretrained(...)`
# is probably what was intended — confirm. The model is also never switched to
# `.eval()` here; verify whether Attention_Extracter does that itself.
autoencoder_config = OmicsConfig.from_pretrained("./lc_models/MultiOmicsAutoencoder/trained_autoencoder")
autoencoder_model = MultiOmicsGraphAttentionAutoencoderModel(autoencoder_config).to(device)
# Initialize Attention Extracter
# Graph data file handed to Attention_Extracter (presumably a torch-saved dict
# of per-patient graphs, judging by the .pth suffix — verify).
graph_data_dict_path = './data/hnscc.patient.chg.network.pth'
# Only the encoder half of the autoencoder is passed on; `gpu` mirrors the
# device choice made above.
extracter = Attention_Extracter(graph_data_dict_path, autoencoder_model.encoder, gpu=(device == "cuda"))
# Returned by every analysis handler when the gr.State holding the
# GraphAnalysis is still empty (the user has not clicked "Extract Features").
_NO_FEATURES_MSG = "No features available yet - click 'Extract Features' first."


def extract_features():
    """Build a GraphAnalysis over the module-level attention extracter.

    Returns:
        A new ``GraphAnalysis`` instance. The UI stores it in a ``gr.State``
        so the other handlers can operate on it.
    """
    return GraphAnalysis(extracter)


def find_optimal_clusters(
    ga,
    min_clusters,
    max_clusters,
    save_path='/workspace/MultiOmics-Graph-Attention-Autoencoder/temp',
):
    """Run the elbow search for a good cluster count.

    Args:
        ga: The ``GraphAnalysis`` from ``extract_features``, or ``None`` if
            features have not been extracted yet.
        min_clusters: Lower bound of the search range (slider value).
        max_clusters: Upper bound of the search range (slider value).
        save_path: Directory handed to ``GraphAnalysis.find_optimal_clusters``.
            Defaults to the path originally hard-coded here.

    Returns:
        ``ga.optimal_clusters`` on success, otherwise a human-readable message
        for the result textbox.
    """
    if ga is None:
        return _NO_FEATURES_MSG
    # Slider values may arrive as floats; cluster counts must be integers.
    min_clusters, max_clusters = int(min_clusters), int(max_clusters)
    # Both sliders are independent, so the user can request an empty range.
    if min_clusters >= max_clusters:
        return (f"Min clusters ({min_clusters}) must be smaller than "
                f"max clusters ({max_clusters}).")
    ga.find_optimal_clusters(min_clusters=min_clusters, max_clusters=max_clusters, save_path=save_path)
    return ga.optimal_clusters


def perform_clustering(ga, num_clusters):
    """Cluster the extracted features into ``num_clusters`` groups.

    Returns a status message for the result textbox.
    """
    if ga is None:
        return _NO_FEATURES_MSG
    # Slider values may arrive as floats; KMeans needs an integer count.
    ga.cluster_data2(int(num_clusters))
    return "Clustering completed."


def plot_kaplan_meier(ga):
    """Ask ``ga`` to produce its Kaplan-Meier plot; return a status message."""
    if ga is None:
        return _NO_FEATURES_MSG
    ga.plot_kaplan_meier()
    return "Kaplan-Meier plot saved."


def plot_median_survival_bar(ga):
    """Ask ``ga`` to produce its median-survival bar plot; return a status message."""
    if ga is None:
        return _NO_FEATURES_MSG
    ga.plot_median_survival_bar(name='temp')
    return "Median survival bar plot saved."


def perform_log_rank_test(ga):
    """Run ``ga``'s pairwise log-rank test and describe the result."""
    if ga is None:
        return _NO_FEATURES_MSG
    significant_pairs = ga.perform_log_rank_test()
    return f"Significant pairs from log-rank test: {significant_pairs}"
# Center the app and cap its width.
css = """
#col-container {
margin: 0 auto;
max-width: 520px;
}
"""

# Build the UI. Each button maps to one handler; the GraphAnalysis produced by
# "Extract Features" lives in a per-session gr.State and is fed to the rest.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(f"""
# Graph to Features and Analysis
Currently running on {device}.
""")
        with gr.Row():
            extract_button = gr.Button("Extract Features")
            clustering_button = gr.Button("Find Optimal Clusters")
            cluster_data_button = gr.Button("Perform Clustering")
            kaplan_meier_button = gr.Button("Plot Kaplan-Meier")
            survival_bar_button = gr.Button("Plot Median Survival Bar")
            log_rank_button = gr.Button("Perform Log-Rank Test")
        num_clusters = gr.Slider(label="Number of Clusters", minimum=2, maximum=10, step=1, value=5)
        min_clusters = gr.Slider(label="Min Clusters for Elbow Method", minimum=2, maximum=10, step=1, value=2)
        max_clusters = gr.Slider(label="Max Clusters for Elbow Method", minimum=3, maximum=20, step=1, value=10)
        result = gr.Textbox(label="Result")
        # Holds the GraphAnalysis instance; starts as None until extraction runs.
        ga = gr.State()
    extract_button.click(fn=extract_features, inputs=[], outputs=[ga])
    clustering_button.click(fn=find_optimal_clusters, inputs=[ga, min_clusters, max_clusters], outputs=[result])
    cluster_data_button.click(fn=perform_clustering, inputs=[ga, num_clusters], outputs=[result])
    kaplan_meier_button.click(fn=plot_kaplan_meier, inputs=[ga], outputs=[result])
    survival_bar_button.click(fn=plot_median_survival_bar, inputs=[ga], outputs=[result])
    log_rank_button.click(fn=perform_log_rank_test, inputs=[ga], outputs=[result])

# Launch only when run as a script (e.g. `python app.py`), so importing this
# module (tests, `gradio app.py` reload mode) does not start a server.
if __name__ == "__main__":
    demo.queue().launch()