import gradio as gr
import torch
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans
from lifelines import KaplanMeierFitter
from yellowbrick.cluster import KElbowVisualizer
from itertools import combinations
from lifelines.statistics import logrank_test
import os
import subprocess

from MultiOmicsGraphAttentionAutoencoderModel import MultiOmicsGraphAttentionAutoencoderModel
from OmicsConfig import OmicsConfig
from Attention_Extracter import Attention_Extracter
from GraphAnalysis import GraphAnalysis

device = "cuda" if torch.cuda.is_available() else "cpu"

# Passed as ``save_path`` to GraphAnalysis.find_optimal_clusters (presumably where
# its elbow-method output is written — confirm).
# NOTE(review): absolute, environment-specific path kept from the original; adjust per deployment.
ELBOW_SAVE_PATH = '/workspace/MultiOmics-Graph-Attention-Autoencoder/temp'

# Load the autoencoder model.
# NOTE(review): only the *config* is read from the checkpoint directory; the model is
# then constructed from that config. Unless the model class (or Attention_Extracter)
# loads the trained weights elsewhere, the encoder runs with freshly initialised
# weights. Confirm — MultiOmicsGraphAttentionAutoencoderModel.from_pretrained(...) may
# be what was intended if the class supports it.
autoencoder_config = OmicsConfig.from_pretrained("./lc_models/MultiOmicsAutoencoder/trained_autoencoder")
autoencoder_model = MultiOmicsGraphAttentionAutoencoderModel(autoencoder_config).to(device)
# This app only runs the encoder for feature/attention extraction, so switch the
# module to inference mode (affects layers such as dropout, if the model has any).
autoencoder_model.eval()

# Initialize Attention Extracter over the patient graph data and the model's encoder.
graph_data_dict_path = './data/hnscc.patient.chg.network.pth'
extracter = Attention_Extracter(graph_data_dict_path, autoencoder_model.encoder, gpu=(device == "cuda"))


def _require_ga(ga):
    """Return *ga*, or raise a user-facing error if features were not extracted yet.

    ``ga`` is held in a ``gr.State`` that is ``None`` until the "Extract Features"
    button has run :func:`extract_features`; without this check every handler would
    fail with an opaque ``AttributeError`` on ``None``.
    """
    if ga is None:
        raise gr.Error("No features available yet. Click 'Extract Features' first.")
    return ga


def extract_features():
    """Build a GraphAnalysis over the module-level attention extracter.

    Returns:
        The GraphAnalysis instance; the UI stores it in session state and passes it
        to every other handler.
    """
    ga = GraphAnalysis(extracter)
    return ga


def find_optimal_clusters(ga, min_clusters, max_clusters, save_path=ELBOW_SAVE_PATH):
    """Run the elbow-method search over ``[min_clusters, max_clusters]``.

    Args:
        ga: GraphAnalysis from :func:`extract_features` (``None`` if not run yet).
        min_clusters: Lower bound of the cluster-count search (slider value).
        max_clusters: Upper bound of the cluster-count search (slider value).
        save_path: Forwarded to ``GraphAnalysis.find_optimal_clusters``.

    Returns:
        ``ga.optimal_clusters`` as set by the search.

    Raises:
        gr.Error: if features were not extracted or ``min_clusters >= max_clusters``.
    """
    ga = _require_ga(ga)
    # Slider values may arrive as floats; cluster counts must be integers.
    min_clusters, max_clusters = int(min_clusters), int(max_clusters)
    # The two sliders are independent, so the UI allows an empty/inverted range.
    if min_clusters >= max_clusters:
        raise gr.Error(
            f"Min clusters ({min_clusters}) must be smaller than max clusters ({max_clusters})."
        )
    ga.find_optimal_clusters(min_clusters=min_clusters, max_clusters=max_clusters, save_path=save_path)
    return ga.optimal_clusters


def perform_clustering(ga, num_clusters):
    """Cluster the extracted features into ``num_clusters`` groups via ``ga.cluster_data2``.

    Raises:
        gr.Error: if features were not extracted yet.
    """
    ga = _require_ga(ga)
    ga.cluster_data2(int(num_clusters))
    return "Clustering completed."


def plot_kaplan_meier(ga):
    """Ask ``ga`` to produce its Kaplan-Meier plot and return a status message.

    Raises:
        gr.Error: if features were not extracted yet.
    """
    ga = _require_ga(ga)
    ga.plot_kaplan_meier()
    return "Kaplan-Meier plot saved."


def plot_median_survival_bar(ga):
    """Ask ``ga`` to produce its median-survival bar plot (``name='temp'``).

    Raises:
        gr.Error: if features were not extracted yet.
    """
    ga = _require_ga(ga)
    ga.plot_median_survival_bar(name='temp')
    return "Median survival bar plot saved."
def perform_log_rank_test(ga):
    """Run ``ga.perform_log_rank_test`` and format its result for the Result box.

    Raises:
        gr.Error: if features were not extracted yet (``ga`` is still ``None``).
    """
    # ``ga`` comes from a gr.State that stays None until "Extract Features" runs.
    if ga is None:
        raise gr.Error("No features available yet. Click 'Extract Features' first.")
    significant_pairs = ga.perform_log_rank_test()
    return f"Significant pairs from log-rank test: {significant_pairs}"


css = """
#col-container {
    margin: 0 auto;
    max-width: 520px;
}
"""


def _extract_features_with_status():
    """Run ``extract_features`` and also report completion in the Result box."""
    return extract_features(), "Features extracted."


with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        # Built without leading indentation so Markdown does not render it as a code block.
        gr.Markdown(
            "# Graph to Features and Analysis\n\n"
            f"Currently running on {device}."
        )

        with gr.Row():
            extract_button = gr.Button("Extract Features")
            clustering_button = gr.Button("Find Optimal Clusters")
            cluster_data_button = gr.Button("Perform Clustering")
            kaplan_meier_button = gr.Button("Plot Kaplan-Meier")
            survival_bar_button = gr.Button("Plot Median Survival Bar")
            log_rank_button = gr.Button("Perform Log-Rank Test")

        num_clusters = gr.Slider(label="Number of Clusters", minimum=2, maximum=10, step=1, value=5)
        min_clusters = gr.Slider(label="Min Clusters for Elbow Method", minimum=2, maximum=10, step=1, value=2)
        max_clusters = gr.Slider(label="Max Clusters for Elbow Method", minimum=3, maximum=20, step=1, value=10)

        result = gr.Textbox(label="Result")

        # Per-session GraphAnalysis object; None until "Extract Features" is clicked.
        ga = gr.State()

        extract_button.click(fn=_extract_features_with_status, inputs=[], outputs=[ga, result])
        clustering_button.click(fn=find_optimal_clusters, inputs=[ga, min_clusters, max_clusters], outputs=[result])
        cluster_data_button.click(fn=perform_clustering, inputs=[ga, num_clusters], outputs=[result])
        kaplan_meier_button.click(fn=plot_kaplan_meier, inputs=[ga], outputs=[result])
        survival_bar_button.click(fn=plot_median_survival_bar, inputs=[ga], outputs=[result])
        log_rank_button.click(fn=perform_log_rank_test, inputs=[ga], outputs=[result])

# Only start the server when run as a script, so importing this module has no side effect.
if __name__ == "__main__":
    demo.queue().launch()