"""Gradio demo: predict the class of a glycan sequence with SweetNet models.

The script installs its runtime dependencies on start-up, loads seven
checkpoints (``model1.pt`` .. ``model7.pt``) from the working directory and
serves a small web UI in which the user picks a model variant and enters a
glycan sequence. The prediction is a probability for every entry of
``class_list``.
"""
import os
import subprocess
import sys
from typing import Dict, Optional


def _pip_install(*args: str) -> None:
    """Run a best-effort ``pip install`` for the interpreter executing this script.

    ``sys.executable -m pip`` guarantees the packages land in the environment
    that will import them, and the argument list avoids going through a shell.
    Failures are deliberately not fatal (``check=False``), as before: the
    imports below will raise if something essential is really missing.
    """
    subprocess.run([sys.executable, "-m", "pip", "install", *args], check=False)


# These installs must run before the third-party imports that follow.
_pip_install(
    "torch", "torchvision", "torchaudio",
    "--extra-index-url", "https://download.pytorch.org/whl/cpu",
)
_pip_install(
    "torch-scatter", "torch-sparse", "torch-cluster", "torch-spline-conv",
    "torch-geometric",
    "-f", "https://data.pyg.org/whl/torch-1.12.0+cpu.html",
)

import gradio as gr  # noqa: E402  (imports must follow the installs above)
from glycowork.ml.processing import dataset_to_dataloader  # noqa: E402
import numpy as np  # noqa: E402
import torch  # noqa: E402
import torch.nn as nn  # noqa: E402

# NOTE: the three imports below are only needed by the graph-rendering step
# (glycan -> networkx graph -> pydot -> "graph.png"), which is currently
# disabled in `fn`. They are kept so that step can be re-enabled easily.
from glycowork.motif.graph import glycan_to_nxGraph  # noqa: E402,F401
import networkx as nx  # noqa: E402,F401
import pydot  # noqa: E402,F401

# All checkpoints are loaded onto the CPU, so single-model inference runs there.
_CPU = torch.device("cpu")


def _module_device(module: nn.Module) -> torch.device:
    """Return the device holding ``module``'s parameters (CPU if it has none)."""
    try:
        return next(module.parameters()).device
    except StopIteration:
        return _CPU


class EnsembleModel(nn.Module):
    """Average the outputs of several models that share one call signature.

    Every member is called as ``member(x, edge_index, batch)``.
    """

    def __init__(self, models):
        super().__init__()
        # Stored as a plain list: nn.Module does NOT register the members as
        # sub-modules, which is why `train` is overridden below.
        self.models = models

    def train(self, mode: bool = True):
        """Set training/eval mode on the ensemble AND on every member.

        Without this override ``eval()`` would only flip the flag of the
        wrapper, because the members live in an unregistered list.
        """
        super().train(mode)
        for member in self.models:
            member.train(mode)
        return self

    def forward(self, data):
        """Return the member-averaged output for the first graph in ``data``.

        Args:
            data: batch object exposing ``labels``, ``edge_index`` and
                ``batch`` tensors.

        Returns:
            NumPy array: mean over all members of their outputs, row 0.
        """
        member_outputs = []
        with torch.no_grad():
            for member in self.models:
                # Move the inputs to wherever this member actually lives.
                # Picking "cuda:0" whenever CUDA is available crashed with
                # checkpoints that were loaded onto the CPU.
                device = _module_device(member)
                out = member(
                    data.labels.to(device),
                    data.edge_index.to(device),
                    data.batch.to(device),
                )
                member_outputs.append(out.cpu().detach().numpy())
        return np.mean(member_outputs, axis=0)[0]


class_list = ['Amoebozoa', 'Animalia', 'Bacteria', 'Bamfordvirae', 'Chromista',
              'Euryarchaeota', 'Excavata', 'Fungi', 'Heunggongvirae',
              'Orthornavirae', 'Pararnavirae', 'Plantae', 'Proteoarchaeota',
              'Protista', 'Riboviria']


def _load_model(path: str) -> nn.Module:
    """Load a pickled model onto the CPU and put it into inference mode.

    NOTE(security): ``torch.load`` unpickles arbitrary Python objects. Only
    load checkpoint files that come from a trusted source.
    """
    model = torch.load(path, map_location=_CPU)
    model.eval()
    return model


model1 = _load_model("model1.pt")
model2 = _load_model("model2.pt")
model3 = _load_model("model3.pt")
model4 = _load_model("model4.pt")
model5 = _load_model("model5.pt")
model6 = _load_model("model6.pt")
model7 = _load_model("model7.pt")

# UI label (case-folded) -> model. The older names "Classical ensemble" and
# "Bagging ensemble" are kept as aliases of the labels the UI offers.
_MODELS_BY_LABEL = {
    "no data augmentation": model1,
    "random node deletion": model2,
    "ensemble": model3,
    "classical ensemble": model3,
    "bootstrap ensemble": model4,
    "bagging ensemble": model4,
    "random edge deletion": model5,
    "hierarchy substitution": model6,
    "adjusted class weights": model7,
}

# Labels whose model is called with the whole batch object (see EnsembleModel)
# instead of the unpacked (x, edge_index, batch) tensors.
# NOTE(review): assumes model3.pt / model4.pt are EnsembleModel checkpoints.
_ENSEMBLE_LABELS = frozenset(
    {"ensemble", "classical ensemble", "bootstrap ensemble", "bagging ensemble"}
)


def _softmax(logits) -> np.ndarray:
    """Numerically stable softmax over a 1-D array of logits."""
    logits = np.asarray(logits, dtype=np.float64)
    exps = np.exp(logits - np.max(logits))  # shift so exp() cannot overflow
    return exps / exps.sum()


def fn(glycan: str, model: Optional[str]) -> Dict[str, float]:
    """Predict class probabilities for a single glycan sequence.

    Args:
        glycan: glycan sequence as entered in the text box.
        model: label of the model variant chosen in the radio group. Matching
            is case-insensitive; ``None`` or an unknown label falls back to
            the "Random node deletion" model (``model2``), as before.

    Returns:
        Mapping of each name in ``class_list`` to its softmax probability.

    Raises:
        ValueError: if ``glycan`` is empty or not a string.
    """
    if not isinstance(glycan, str) or not glycan.strip():
        raise ValueError("Please provide a glycan sequence.")

    label = model.strip().casefold() if isinstance(model, str) else ""
    model_pred = _MODELS_BY_LABEL.get(label, model2)

    # The data loader needs a label per glycan; 0 is a dummy, never used here.
    data = next(iter(dataset_to_dataloader([glycan.strip()], [0], batch_size=1)))

    if label in _ENSEMBLE_LABELS:
        logits = model_pred(data)
    else:
        with torch.no_grad():
            output = model_pred(
                data.labels.to(_CPU),
                data.edge_index.to(_CPU),
                data.batch.to(_CPU),
            )
        logits = output.cpu().detach().numpy()[0]

    probabilities = _softmax(logits)
    return {name: float(p) for name, p in zip(class_list, probabilities)}


demo = gr.Interface(
    fn=fn,
    inputs=[
        gr.Textbox(
            label="Glycan sequence",
            value="Man(a1-2)Man(a1-3)[Man(a1-3)Man(a1-6)]Man(b1-4)GlcNAc(b1-4)GlcNAc",
        ),
        gr.Radio(
            label="Model",
            choices=[
                "No data augmentation",
                "Random node deletion",
                "Random edge deletion",
                "Ensemble",
                "Bootstrap ensemble",
                "Hierarchy substitution",
                "Adjusted class weights",
            ],
        ),
    ],
    outputs=[gr.Label(num_top_classes=len(class_list), label="Prediction")],
    allow_flagging="never",
    title="SweetNet demo",
    # Every example's model label must be one of the radio choices above.
    examples=[
        ["D-Rha(b1-2)D-Rha(b1-2)Gal(b1-4)[Glc(b1-2)]GlcAOMe", "Random node deletion"],
        ["Neu5Ac(a2-3)Gal(b1-4)GlcNAc(b1-3)GalNAc", "No data augmentation"],
        ["Kdo(a2-4)[Kdo(a2-8)]Kdo(a2-4)Kdo", "Ensemble"],
        ["Galf(b1-6)Galf(b1-5)Galf(b1-6)Galf", "Bootstrap ensemble"],
        ["GlcNAc(b1-2)Rha(a1-2)Rha(b1-3)Rha(a1-3)GlcNAc", "Random edge deletion"],
        ["Pse(b2-6)Glc(b1-6)Gal(b1-3)GalNAc(b1-3)[Glc(b1-6)]Gal(b1-3)GalNAc", "Adjusted class weights"],
    ],
)

demo.launch(debug=True)