Spaces:
Runtime error
Runtime error
File size: 4,573 Bytes
c20f071 44c8341 50edbe9 6506504 b7d6d94 633bc62 1a7661f 44c8341 50edbe9 85f8980 8a9a147 4e2f6bd bcf93fd b1533be 85f8980 840fdaa d54895f 555d33d 840fdaa 0fbae15 fc829e4 af8a4b3 fc829e4 af8a4b3 8a9a147 4e2f6bd 6225750 4e2f6bd 86f541d a376bf7 b1533be 0fbae15 fc829e4 44c8341 8b25912 47aa6b1 7ec6647 fc829e4 47aa6b1 8b25912 d54895f 44c8341 85f8980 44c8341 2b584be e2b405a d54895f 38d614d 44c8341 e8018b1 e9847e4 af8a4b3 e9847e4 e8018b1 44c8341 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
import os
# Install the CPU builds of PyTorch at start-up, before `torch` is imported below.
# NOTE(review): no torch version is pinned here, while the wheel index used by the
# next command is specific to torch-1.12.0+cpu -- confirm the two stay compatible.
# NOTE(review): the exit status returned by os.system is not checked, so a failed
# install would presumably only surface later as an import error.
os.system("pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu")
# Install the PyTorch Geometric packages from the wheel index for torch 1.12.0 (CPU).
os.system("pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.12.0+cpu.html")
import gradio as gr
from glycowork.ml.processing import dataset_to_dataloader
import numpy as np
import torch
import torch.nn as nn
from glycowork.motif.graph import glycan_to_nxGraph
import networkx as nx
import pydot
# import pygraphviz as pgv
class EnsembleModel(nn.Module):
    """Average the outputs of several models evaluated on the same batch.

    Every member is called as ``model(x, edge_index, batch)`` and the
    member outputs are averaged element-wise.
    """

    def __init__(self, models):
        """Store the ensemble members.

        Args:
            models: Iterable of callables invoked as
                ``model(x, edge_index, batch)`` and returning a tensor.
        """
        super().__init__()
        models = list(models)
        # Register real modules as sub-modules so that eval(), to() and
        # parameters() on the ensemble reach every member. Plain callables
        # are still accepted and kept in an ordinary list.
        if models and all(isinstance(member, nn.Module) for member in models):
            models = nn.ModuleList(models)
        self.models = models

    @staticmethod
    def _device_of(model):
        """Return the device of the model's first parameter, or CPU if none."""
        parameters = getattr(model, "parameters", None)
        if callable(parameters):
            first = next(iter(parameters()), None)
            if first is not None:
                return first.device
        return torch.device("cpu")

    def forward(self, data):
        """Run every member on ``data`` and return the averaged output.

        Args:
            data: Batch object exposing ``labels``, ``edge_index`` and
                ``batch`` tensors.

        Returns:
            NumPy array holding the mean member output for the first
            element of the batch.
        """
        outputs = []
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            for model in self.models:
                # Move the inputs to wherever this member lives instead of
                # assuming "cuda:0" whenever CUDA exists; members loaded
                # onto the CPU would otherwise hit a device mismatch.
                device = self._device_of(model)
                out = model(
                    data.labels.to(device),
                    data.edge_index.to(device),
                    data.batch.to(device),
                )
                outputs.append(out.cpu().detach().numpy())
        # Mean over members, then drop the batch dimension.
        return np.mean(outputs, axis=0)[0]
# Class names, in the index order that fn uses to label the model scores.
class_list = [
    'Amoebozoa',
    'Animalia',
    'Bacteria',
    'Bamfordvirae',
    'Chromista',
    'Euryarchaeota',
    'Excavata',
    'Fungi',
    'Heunggongvirae',
    'Orthornavirae',
    'Pararnavirae',
    'Plantae',
    'Proteoarchaeota',
    'Protista',
    'Riboviria',
]
# Load the seven serialized models, mapping their tensors onto the CPU.
# NOTE(review): torch.load unpickles these files -- only ship checkpoints
# that come from a trusted source.
(
    model1,
    model2,
    model3,
    model4,
    model5,
    model6,
    model7,
) = [
    torch.load(checkpoint, map_location=torch.device('cpu'))
    for checkpoint in (
        "model1.pt",
        "model2.pt",
        "model3.pt",
        "model4.pt",
        "model5.pt",
        "model6.pt",
        "model7.pt",
    )
]
def _softmax(scores):
    """Return the softmax of a 1-D score array without overflowing."""
    scores = np.asarray(scores)
    # Shifting by the maximum leaves the result unchanged mathematically
    # but keeps every exponent <= 0, so np.exp cannot overflow.
    weights = np.exp(scores - np.max(scores))
    return weights / weights.sum()


def fn(glycan, model):
    """Predict class probabilities for one glycan with the selected model.

    Args:
        glycan: Glycan sequence string, as entered in the demo textbox.
        model: Name of the model variant. Matching ignores case and
            surrounding whitespace. Both the labels shown in the UI
            ("Ensemble", "Bootstrap ensemble") and the older names
            ("Classical ensemble", "Bagging ensemble") are accepted.
            Any other value, including ``None``, selects ``model2``.

    Returns:
        Dict mapping each name in ``class_list`` to its probability as a
        Python float.
    """
    # name -> (model object, whether it is called with the whole batch).
    # NOTE(review): "Ensemble" / "Bootstrap ensemble" are assumed to be the
    # renamed "Classical Ensemble" (model3) / "Bagging ensemble" (model4),
    # and those two are assumed to take the batch object directly --
    # confirm against the saved checkpoints.
    variants = {
        "no data augmentation": (model1, False),
        "random node deletion": (model2, False),
        "ensemble": (model3, True),
        "classical ensemble": (model3, True),
        "bootstrap ensemble": (model4, True),
        "bagging ensemble": (model4, True),
        "random edge deletion": (model5, False),
        "hierarchy substitution": (model6, False),
        "adjusted class weights": (model7, False),
    }
    key = model.strip().lower() if isinstance(model, str) else ""
    model_pred, takes_batch = variants.get(key, (model2, False))
    model_pred.eval()

    # The loader needs a label per glycan; 0 is a placeholder that is not
    # read again in this function.
    data = next(iter(dataset_to_dataloader([glycan], [0], batch_size=1)))

    if takes_batch:
        pred = model_pred(data)
    else:
        device = "cpu"
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            output = model_pred(
                data.labels.to(device),
                data.edge_index.to(device),
                data.batch.to(device),
            )
        pred = output.cpu().detach().numpy()[0]

    probabilities = _softmax(pred)
    return {
        name: float(probability)
        for name, probability in zip(class_list, probabilities)
    }


# Every model name used in the examples is one of the Radio choices, so
# selecting an example also selects a valid option in the UI.
demo = gr.Interface(
    fn=fn,
    inputs=[
        gr.Textbox(
            label="Glycan sequence",
            value="Man(a1-2)Man(a1-3)[Man(a1-3)Man(a1-6)]Man(b1-4)GlcNAc(b1-4)GlcNAc",
        ),
        gr.Radio(
            label="Model",
            choices=[
                "No data augmentation",
                "Random node deletion",
                "Random edge deletion",
                "Ensemble",
                "Bootstrap ensemble",
                "Hierarchy substitution",
                "Adjusted class weights",
            ],
        ),
    ],
    outputs=[gr.Label(num_top_classes=15, label="Prediction")],
    allow_flagging="never",
    title="SweetNet demo",
    examples=[
        ["D-Rha(b1-2)D-Rha(b1-2)Gal(b1-4)[Glc(b1-2)]GlcAOMe", "Random node deletion"],
        ["Neu5Ac(a2-3)Gal(b1-4)GlcNAc(b1-3)GalNAc", "No data augmentation"],
        ["Kdo(a2-4)[Kdo(a2-8)]Kdo(a2-4)Kdo", "Ensemble"],
        ["Galf(b1-6)Galf(b1-5)Galf(b1-6)Galf", "Bootstrap ensemble"],
        ["GlcNAc(b1-2)Rha(a1-2)Rha(b1-3)Rha(a1-3)GlcNAc", "Random edge deletion"],
        ["Pse(b2-6)Glc(b1-6)Gal(b1-3)GalNAc(b1-3)[Glc(b1-6)]Gal(b1-3)GalNAc", "Adjusted class weights"],
    ],
)
demo.launch(debug=True)