SweetNet / app.py
dalexanderch's picture
Update app.py
e2b405a
raw
history blame
4.47 kB
import os
import subprocess
import sys

# Runtime dependency bootstrap: the packages imported below are installed when
# the app starts. pip is invoked as ``<this interpreter> -m pip`` with an
# argument list (no shell), so the packages land in the environment that is
# actually running this script rather than whichever ``pip`` is first on PATH.
# NOTE(review): the first command does not pin a torch version while the second
# pulls wheels from the index built for torch 1.12.0+cpu -- confirm that the
# two stay compatible.
_PIP_INSTALLS = (
    [
        "torch", "torchvision", "torchaudio",
        "--extra-index-url", "https://download.pytorch.org/whl/cpu",
    ],
    [
        "torch-scatter", "torch-sparse", "torch-cluster", "torch-spline-conv",
        "torch-geometric",
        "-f", "https://data.pyg.org/whl/torch-1.12.0+cpu.html",
    ],
)
for _pip_args in _PIP_INSTALLS:
    # check=False keeps the install best-effort (as before, a failed install
    # does not abort start-up), but the failure is now reported.
    _pip_result = subprocess.run(
        [sys.executable, "-m", "pip", "install", *_pip_args], check=False
    )
    if _pip_result.returncode != 0:
        print(
            f"warning: pip install {' '.join(_pip_args)} exited with "
            f"status {_pip_result.returncode}",
            file=sys.stderr,
        )
import gradio as gr
from glycowork.ml.processing import dataset_to_dataloader
import numpy as np
import torch
import torch.nn as nn
from glycowork.motif.graph import glycan_to_nxGraph
import networkx as nx
import pydot
# import pygraphviz as pgv
class EnsembleModel(nn.Module):
    """Average the outputs of several models evaluated on the same graph batch.

    Args:
        models: Iterable of models, each called as
            ``model(x, edge_index, batch)`` and returning a tensor.
    """

    def __init__(self, models):
        super().__init__()
        # Stored exactly as given (not wrapped in nn.ModuleList) so the
        # attribute layout of the class is unchanged.
        # NOTE(review): presumably serialized instances of this class are
        # loaded elsewhere in the file -- confirm before changing the layout.
        self.models = models

    def train(self, mode=True):
        """Set training/eval mode on the ensemble and on every member.

        ``nn.Module.train`` (and therefore ``eval``) only recurses into
        registered submodules. Members held in a plain container are not
        registered, so the mode is propagated to them explicitly.

        Returns:
            ``self``, matching ``nn.Module.train``.
        """
        super().train(mode)
        for model in self.models:
            if isinstance(model, nn.Module):
                model.train(mode)
        return self

    @staticmethod
    def _device_of(model):
        """Return the device of the model's first parameter, or CPU if none."""
        parameters = getattr(model, "parameters", None)
        first = next(parameters(), None) if callable(parameters) else None
        return first.device if first is not None else torch.device("cpu")

    def forward(self, data):
        """Run every member on ``data`` and average their outputs.

        Args:
            data: Batch object exposing ``labels``, ``edge_index`` and
                ``batch`` tensors.

        Returns:
            NumPy array: the element-wise mean of the member outputs, taken
            for the first item of the batch only.
        """
        member_outputs = []
        # Inference only: the result leaves torch as a NumPy array.
        with torch.no_grad():
            for model in self.models:
                # Follow the device the member actually lives on instead of
                # assuming CUDA whenever it is available; a CPU-resident
                # member would otherwise receive CUDA inputs and fail.
                device = self._device_of(model)
                output = model(
                    data.labels.to(device),
                    data.edge_index.to(device),
                    data.batch.to(device),
                )
                member_outputs.append(output.cpu().detach().numpy())
        return np.mean(member_outputs, axis=0)[0]
# Class names shown in the prediction output. The order matters: the
# prediction code pairs entry i of this list with score i of the model output.
class_list = [
    'Amoebozoa',
    'Animalia',
    'Bacteria',
    'Bamfordvirae',
    'Chromista',
    'Euryarchaeota',
    'Excavata',
    'Fungi',
    'Heunggongvirae',
    'Orthornavirae',
    'Pararnavirae',
    'Plantae',
    'Proteoarchaeota',
    'Protista',
    'Riboviria',
]
# Load the seven serialized models from the working directory, mapping all
# storages onto the CPU. The files are read in order, model1.pt .. model7.pt.
_CPU_DEVICE = torch.device('cpu')
_MODEL_FILES = (
    "model1.pt",
    "model2.pt",
    "model3.pt",
    "model4.pt",
    "model5.pt",
    "model6.pt",
    "model7.pt",
)
(model1, model2, model3, model4, model5, model6, model7) = [
    torch.load(model_file, map_location=_CPU_DEVICE)
    for model_file in _MODEL_FILES
]
def fn(glycan, model):
    """Render a glycan as a graph image and predict its class probabilities.

    Args:
        glycan: Glycan sequence string, in whatever notation
            ``glycan_to_nxGraph`` and ``dataset_to_dataloader`` accept.
        model: Name of the model variant selected in the UI. Any value that is
            not one of the named variants (including ``None``) falls back to
            ``model2``.

    Returns:
        Tuple ``(probabilities, image_path)``: a dict mapping each entry of
        ``class_list`` to its softmax probability (a Python float), and the
        path of the rendered graph image, ``"graph.png"``.
    """
    # Draw graph. Nodes keep their original ids and only receive a display
    # ``label``: relabelling the nodes themselves with their (non-unique)
    # string labels would merge every pair of nodes sharing a label, e.g. all
    # "Man" residues, and distort the drawn structure.
    graph = glycan_to_nxGraph(glycan)
    node_labels = nx.get_node_attributes(graph, 'string_labels')
    nx.set_node_attributes(graph, node_labels, "label")
    dot_graph = nx.drawing.nx_pydot.to_pydot(graph)
    dot_graph.set_prog("dot")
    dot_graph.write_png("graph.png")

    # Pick the model for the requested variant; model2 is the default.
    models_by_choice = {
        "No data augmentation": model1,
        "Ensemble": model3,
        "Bootstrap ensemble": model4,
        "Random edge deletion": model5,
        "Hierarchy substitution": model6,
        "Adjusted class weights": model7,
    }
    model_pred = models_by_choice.get(model, model2)
    model_pred.eval()

    # Perform inference on a single-item batch; the label 0 is a placeholder.
    data = next(iter(dataset_to_dataloader([glycan], [0], batch_size=1)))
    # The ensemble variants are called with the whole batch object, the others
    # with its individual tensors. The names must match the UI choices exactly
    # (note the lowercase "ensemble" in "Bootstrap ensemble").
    # NOTE(review): assumes model3/model4 take a single ``data`` argument and
    # return the logits of the first batch item -- confirm.
    if model in ("Ensemble", "Bootstrap ensemble"):
        pred = model_pred(data)
    else:
        device = "cpu"
        with torch.no_grad():
            output = model_pred(
                data.labels.to(device),
                data.edge_index.to(device),
                data.batch.to(device),
            )
        pred = output.cpu().detach().numpy()[0]

    # Softmax; subtracting the maximum leaves the result unchanged but keeps
    # np.exp from overflowing on large scores.
    exp_scores = np.exp(pred - np.max(pred))
    probabilities = exp_scores / exp_scores.sum()
    pred = {name: float(p) for name, p in zip(class_list, probabilities)}
    return pred, "graph.png"
# UI components, created in the same order as they appear in the interface:
# the two inputs first, then the two outputs.
_glycan_input = gr.Textbox(
    label="Glycan sequence",
    value="Man(a1-2)Man(a1-3)[Man(a1-3)Man(a1-6)]Man(b1-4)GlcNAc(b1-4)GlcNAc",
)
_model_input = gr.Radio(
    label="Model",
    choices=[
        "No data augmentation",
        "Random node deletion",
        "Random edge deletion",
        "Ensemble",
        "Bootstrap ensemble",
        "Hierarchy substitution",
        "Adjusted class weights",
    ],
)
_prediction_output = gr.Label(num_top_classes=15, label="Prediction")
_graph_output = gr.Image(label="Glycan graph")

# Each example row is [glycan sequence, model choice].
_examples = [
    [
        "Man(a1-2)Man(a1-3)[Man(a1-3)Man(a1-6)]Man(b1-4)GlcNAc(b1-4)GlcNAc",
        "Random node deletion",
    ],
    [
        "GlcOSN(a1-4)GlcA(b1-4)GlcOSN(a1-4)GlcAOS(b1-4)GlcOSN(a1-4)GlcOSN",
        "No data augmentation",
    ],
    [
        "GlcNAc(b1-7)LDManHep(a1-6)Glc(a1-2)Glc(a1-3)[Gal(a1-6)]Glc(a1-3)[LDManHep(a1-7)]LDManHepOP(a1-3)LDManHepOP(a1-5)[Kdo(a2-4)]Kdo",
        "Ensemble",
    ],
]

# Wire fn to the components and start the app.
demo = gr.Interface(
    fn=fn,
    inputs=[_glycan_input, _model_input],
    outputs=[_prediction_output, _graph_output],
    allow_flagging="never",
    title="SweetNet demo",
    examples=_examples,
)
demo.launch(debug=True)