Spaces:
Runtime error
Runtime error
File size: 4,091 Bytes
c20f071 44c8341 50edbe9 6506504 b7d6d94 633bc62 1a7661f 44c8341 50edbe9 85f8980 8a9a147 4e2f6bd 85f8980 840fdaa b7d6d94 9967649 b7d6d94 555d33d 840fdaa 0fbae15 fc829e4 cb13d0d fc829e4 8a9a147 4e2f6bd 6225750 4e2f6bd 0fbae15 fc829e4 44c8341 8b25912 47aa6b1 853a85f fc829e4 47aa6b1 8b25912 6585887 44c8341 85f8980 44c8341 2b584be 4e2f6bd aec7d72 38d614d 44c8341 713d064 6eb2572 44c8341 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 |
import os
# Install the CPU-only PyTorch build and the PyTorch Geometric packages at start-up,
# before the third-party imports below need them.
# NOTE(review): torch is installed unpinned, but the PyG extension wheels are taken from the
# index built for torch 1.12.0; if pip resolves a different torch release, the compiled
# extensions (torch-scatter, torch-sparse, ...) may fail to import — consider pinning torch.
# NOTE(review): os.system("pip ...") runs whichever pip is first on PATH and its exit status
# is ignored; confirm it installs into the interpreter running this app.
os.system("pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu")
os.system("pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.12.0+cpu.html")
import gradio as gr
from glycowork.ml.processing import dataset_to_dataloader
import numpy as np
import torch
import torch.nn as nn
from glycowork.motif.graph import glycan_to_nxGraph
import networkx as nx
# pydot backs networkx's nx_pydot export, used to render the glycan graph to PNG.
import pydot
# import pygraphviz as pgv
class EnsembleModel(nn.Module):
    """Average the raw outputs of several member models.

    Each member is called as ``model(labels, edge_index, batch)``; the member
    outputs are averaged element-wise and ``forward`` returns the first row of
    that mean as a NumPy array (the demo feeds batches of size 1).
    """

    def __init__(self, models):
        super().__init__()
        # ModuleList registers the members as submodules, so eval()/train(),
        # .to() and state_dict() all reach them.
        self.models = nn.ModuleList(models)

    def train(self, mode=True):
        """Set training mode on the ensemble and explicitly on every member.

        The previous definition stored ``models`` as a plain list, which
        nn.Module does not traverse; ensembles unpickled from checkpoints made
        with it keep that list. Without this override, ``eval()`` would leave
        such members (e.g. their dropout) in training mode.
        """
        super().train(mode)
        for member in self.models:
            member.train(mode)
        return self

    @staticmethod
    def _device_of(model):
        """Return the device of *model*'s first parameter (CPU if it has none)."""
        param = next(model.parameters(), None)
        return param.device if param is not None else torch.device("cpu")

    def forward(self, data):
        """Run every member on *data* and return the mean of their outputs.

        Args:
            data: Batch object exposing ``labels``, ``edge_index`` and ``batch``
                tensors (as produced by the demo's dataloader).

        Returns:
            NumPy array: first row of the element-wise mean of member outputs.
        """
        preds = []
        with torch.no_grad():
            for model in self.models:
                # Feed each member on the device its weights live on; picking CUDA
                # merely because it is available crashes for CPU-loaded members.
                device = self._device_of(model)
                out = model(
                    data.labels.to(device),
                    data.edge_index.to(device),
                    data.batch.to(device),
                )
                preds.append(out.cpu().numpy())
        return np.mean(preds, axis=0)[0]
# Output class names, one per model output; fn pairs the i-th predicted
# probability with class_list[i].
class_list = [
    'Amoebozoa',
    'Animalia',
    'Bacteria',
    'Bamfordvirae',
    'Chromista',
    'Euryarchaeota',
    'Excavata',
    'Fungi',
    'Heunggongvirae',
    'Orthornavirae',
    'Pararnavirae',
    'Plantae',
    'Proteoarchaeota',
    'Protista',
    'Riboviria',
]
def _load_model(path):
    """Load a fully pickled model object from *path* onto the CPU.

    The checkpoints hold whole model objects (the app calls them and their
    ``.eval()`` directly), so full unpickling is required. torch >= 2.6 defaults
    to ``weights_only=True``, which refuses such objects, so ``weights_only=False``
    is passed explicitly where the keyword exists (it was added in torch 1.13).

    NOTE(review): full unpickling can execute arbitrary code — only load the
    trusted checkpoint files bundled with this app.
    """
    import inspect

    kwargs = {"map_location": torch.device("cpu")}
    if "weights_only" in inspect.signature(torch.load).parameters:
        kwargs["weights_only"] = False
    return torch.load(path, **kwargs)


model1, model2, model3, model4, model5 = (_load_model(f"model{i}.pt") for i in range(1, 6))
def _draw_glycan_graph(glycan):
    """Render the graph of *glycan* with Graphviz ``dot`` and return the PNG path."""
    import tempfile

    graph = glycan_to_nxGraph(glycan)
    # Show each node's string label through the Graphviz ``label`` attribute
    # rather than renaming nodes: relabel_nodes with a non-unique mapping
    # (e.g. several "Man" or "a1-3" nodes) merges those nodes into one.
    node_labels = nx.get_node_attributes(graph, "string_labels")
    nx.set_node_attributes(graph, node_labels, "label")
    dot_graph = nx.drawing.nx_pydot.to_pydot(graph)
    dot_graph.set_prog("dot")
    # One file per request so concurrent users cannot overwrite each other's image.
    fd, path = tempfile.mkstemp(suffix=".png")
    os.close(fd)
    dot_graph.write_png(path)
    return path


def _select_model(name):
    """Return the loaded model for a UI choice; unknown or missing choices use model2."""
    choices = {
        "No data augmentation": model1,
        "Random node deletion": model2,
        "Ensemble": model3,
        "Bootstrap Ensemble": model4,
        "Random edge deletion": model5,
    }
    return choices.get(name, model2)


def _softmax(logits):
    """Numerically stable softmax of a 1-D array of logits."""
    z = np.asarray(logits, dtype=np.float64)
    # Shifting by the max leaves the result unchanged but prevents exp overflow.
    e = np.exp(z - z.max())
    return e / e.sum()


def fn(glycan, model):
    """Predict class probabilities for a glycan and render its graph.

    Args:
        glycan: Glycan sequence string from the UI textbox (presumably
            IUPAC-condensed, like the bundled examples — confirm).
        model: Model choice from the Radio widget; unknown or missing choices
            fall back to model2 (the "Random node deletion" model).

    Returns:
        Tuple ``(probabilities, image_path)``: a dict mapping each name in
        ``class_list`` to its softmax probability, and the path of the PNG.
    """
    image_path = _draw_glycan_graph(glycan)

    model_pred = _select_model(model)
    model_pred.eval()
    # Placeholder label: only the model's prediction is used here.
    data = next(iter(dataset_to_dataloader([glycan], [0], batch_size=1)))

    with torch.no_grad():
        if model in ("Ensemble", "Bootstrap Ensemble"):
            # Ensemble models take the whole batch object and return NumPy already.
            logits = model_pred(data)
        else:
            device = "cpu"
            out = model_pred(
                data.labels.to(device),
                data.edge_index.to(device),
                data.batch.to(device),
            )
            logits = out.cpu().detach().numpy()[0]

    probs = _softmax(logits)
    return {name: float(p) for name, p in zip(class_list, probs)}, image_path
DEFAULT_GLYCAN = "Man(a1-2)Man(a1-3)[Man(a1-3)Man(a1-6)]Man(b1-4)GlcNAc(b1-4)GlcNAc"
MODEL_CHOICES = [
    "No data augmentation",
    "Random node deletion",
    "Random edge deletion",
    "Ensemble",
    "Bootstrap Ensemble",
]

demo = gr.Interface(
    fn=fn,
    inputs=[
        gr.Textbox(label="Glycan sequence", value=DEFAULT_GLYCAN),
        # Preselect the choice that fn falls back to (model2) anyway, so the UI
        # shows which model is used when the user submits without picking one.
        gr.Radio(label="Model", choices=MODEL_CHOICES, value="Random node deletion"),
    ],
    outputs=[
        # One label entry per output class.
        gr.Label(num_top_classes=len(class_list), label="Prediction"),
        gr.Image(label="Glycan graph"),
    ],
    allow_flagging="never",
    title="SweetNet demo",
    examples=[
        [DEFAULT_GLYCAN, "Random node deletion"],
        ["GlcOSN(a1-4)GlcA(b1-4)GlcOSN(a1-4)GlcAOS(b1-4)GlcOSN(a1-4)GlcOSN", "No data augmentation"],
        ["GlcNAc(b1-7)LDManHep(a1-6)Glc(a1-2)Glc(a1-3)[Gal(a1-6)]Glc(a1-3)[LDManHep(a1-7)]LDManHepOP(a1-3)LDManHepOP(a1-5)[Kdo(a2-4)]Kdo", "Ensemble"],
    ],
)

if __name__ == "__main__":
    demo.launch(debug=True)