File size: 4,557 Bytes
c20f071
 
 
44c8341
 
 
 
50edbe9
6506504
 
b7d6d94
633bc62
1a7661f
44c8341
50edbe9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85f8980
 
 
 
 
 
8a9a147
4e2f6bd
bcf93fd
b1533be
85f8980
 
840fdaa
d54895f
 
 
 
 
 
 
555d33d
 
 
 
840fdaa
0fbae15
fc829e4
 
cb13d0d
fc829e4
 
e2b405a
8a9a147
 
4e2f6bd
6225750
4e2f6bd
86f541d
a376bf7
 
b1533be
 
 
0fbae15
fc829e4
 
 
44c8341
 
8b25912
47aa6b1
7ec6647
fc829e4
 
 
 
 
 
 
 
 
 
47aa6b1
8b25912
 
 
d54895f
44c8341
85f8980
44c8341
2b584be
e2b405a
d54895f
38d614d
44c8341
e8018b1
e9847e4
 
 
 
 
 
e8018b1
44c8341
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import os
os.system("pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu")
os.system("pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.12.0+cpu.html")
import gradio as gr
from glycowork.ml.processing import dataset_to_dataloader
import numpy as np
import torch
import torch.nn as nn
from glycowork.motif.graph import glycan_to_nxGraph
import networkx as nx
import pydot
# import pygraphviz as pgv


class EnsembleModel(nn.Module):
    """Average the predictions of several models that share one call signature.

    Each wrapped model is called as ``model(x, edge_index, batch)``; the outputs
    are detached, moved to the CPU, converted to NumPy and averaged element-wise.

    Args:
        models: Iterable of ``nn.Module`` instances. Stored exactly as given (not
            wrapped in ``nn.ModuleList``) so the attribute layout of previously
            pickled ensembles is unchanged.
    """

    def __init__(self, models):
        super().__init__()
        self.models = models

    def train(self, mode=True):
        """Set train/eval mode on the ensemble *and* on every wrapped model.

        If ``self.models`` is a plain list, ``nn.Module`` does not register its
        entries as children, so the inherited ``train``/``eval`` would never reach
        them (e.g. dropout would stay active at inference). Propagating explicitly
        fixes that; for an ``nn.ModuleList`` the extra loop is a harmless repeat.
        ``eval()`` routes through this method as ``train(False)``.
        """
        super().train(mode)
        for member in self.models:
            member.train(mode)
        return self

    @staticmethod
    def _device_of(model):
        """Return the device of *model*'s first parameter, or CPU if it has none."""
        first_param = next(model.parameters(), None)
        return first_param.device if first_param is not None else torch.device("cpu")

    def forward(self, data):
        """Return the ensemble-averaged output for the first graph in *data*.

        Args:
            data: Batch object exposing ``labels``, ``edge_index`` and ``batch``
                tensors (presumably a PyG batch from glycowork's dataloader —
                TODO confirm).

        Returns:
            NumPy array: the element-wise mean of all model outputs, indexed at
            batch position 0.
        """
        outputs = []
        with torch.no_grad():
            for member in self.models:
                # Send inputs to wherever this model lives instead of assuming
                # CUDA: the models are not moved off the device they were loaded on.
                device = self._device_of(member)
                out = member(
                    data.labels.to(device),
                    data.edge_index.to(device),
                    data.batch.to(device),
                )
                outputs.append(out.detach().cpu().numpy())
        return np.mean(outputs, axis=0)[0]
  
# Output class names; fn maps logit index i to class_list[i], so this order must
# match the models' output dimensions.
class_list = [
    'Amoebozoa',
    'Animalia',
    'Bacteria',
    'Bamfordvirae',
    'Chromista',
    'Euryarchaeota',
    'Excavata',
    'Fungi',
    'Heunggongvirae',
    'Orthornavirae',
    'Pararnavirae',
    'Plantae',
    'Proteoarchaeota',
    'Protista',
    'Riboviria',
]

# Load the seven pickled models (model1.pt ... model7.pt, in that order) onto the
# CPU; fn() picks one of them per request.
# NOTE(review): these files hold whole pickled modules, not state dicts. Recent
# PyTorch releases default torch.load to weights_only=True, which refuses such
# files — confirm the installed torch version before upgrading.
model1, model2, model3, model4, model5, model6, model7 = (
    torch.load(f"model{idx}.pt", map_location=torch.device('cpu'))
    for idx in range(1, 8)
)

def _select_model(model_name):
    """Map a UI model name to ``(model, is_ensemble)``; unknown names -> model2."""
    # Keys are casefolded so capitalization variants (e.g. "Bootstrap Ensemble")
    # still select the intended model instead of silently falling back.
    choices = {
        "no data augmentation": (model1, False),
        "random node deletion": (model2, False),
        "ensemble": (model3, True),
        "bootstrap ensemble": (model4, True),
        "random edge deletion": (model5, False),
        "hierarchy substitution": (model6, False),
        "adjusted class weights": (model7, False),
    }
    key = (model_name or "").strip().casefold()
    return choices.get(key, (model2, False))


def _softmax(logits):
    """Numerically stable softmax over a 1-D array of logits."""
    logits = np.asarray(logits, dtype=np.float64)
    # Subtracting the max leaves the result unchanged but prevents exp() overflow.
    exps = np.exp(logits - logits.max())
    return exps / exps.sum()


def fn(glycan, model):
    """Predict a probability for each entry of ``class_list`` for one glycan.

    Args:
        glycan: Glycan sequence string from the UI (IUPAC-condensed notation,
            judging by the examples — not validated here).
        model: Name of the model variant selected in the UI. Matching ignores
            case and surrounding whitespace. Any unrecognized value, including
            ``None`` when nothing is selected, falls back to ``model2`` (the
            "Random node deletion" model), as before.

    Returns:
        dict mapping each name in ``class_list`` to its softmax probability.
    """
    model_pred, is_ensemble = _select_model(model)
    model_pred.eval()  # inference only

    # Single-item dataset; the label 0 is a placeholder since we only predict.
    data = next(iter(dataset_to_dataloader([glycan], [0], batch_size=1)))

    if is_ensemble:
        # The ensemble wrapper already returns averaged NumPy outputs.
        logits = model_pred(data)
    else:
        device = torch.device("cpu")  # all models are loaded with map_location='cpu'
        with torch.no_grad():
            out = model_pred(
                data.labels.to(device),
                data.edge_index.to(device),
                data.batch.to(device),
            )
        logits = out.cpu().detach().numpy()[0]

    probs = _softmax(logits)
    return {name: float(p) for name, p in zip(class_list, probs)}


# Gradio UI: a glycan sequence text box plus a radio button choosing the model
# variant; fn() returns per-class probabilities shown in a Label component.
demo = gr.Interface(
    fn=fn,
    inputs=[
        gr.Textbox(
            label="Glycan sequence",
            value="Man(a1-2)Man(a1-3)[Man(a1-3)Man(a1-6)]Man(b1-4)GlcNAc(b1-4)GlcNAc",
        ),
        gr.Radio(
            label="Model",
            choices=[
                "No data augmentation",
                "Random node deletion",
                "Random edge deletion",
                "Ensemble",
                "Bootstrap ensemble",
                "Hierarchy substitution",
                "Adjusted class weights",
            ],
            # fn() already runs this model when nothing is selected; preselect it
            # so the UI shows which model is actually used.
            value="Random node deletion",
        ),
    ],
    outputs=[gr.Label(num_top_classes=len(class_list), label="Prediction")],
    allow_flagging="never",
    title="SweetNet demo",
    # Model names here must match a Radio choice exactly (the previous
    # "Bootstrap Ensemble" capitalization did not).
    examples=[
        ["D-Rha(b1-2)D-Rha(b1-2)Gal(b1-4)[Glc(b1-2)]GlcAOMe", "Random node deletion"],
        ["Neu5Ac(a2-3)Gal(b1-4)GlcNAc(b1-3)GalNAc", "No data augmentation"],
        ["Kdo(a2-4)[Kdo(a2-8)]Kdo(a2-4)Kdo", "Ensemble"],
        ["Galf(b1-6)Galf(b1-5)Galf(b1-6)Galf", "Bootstrap ensemble"],
        ["GlcNAc(b1-2)Rha(a1-2)Rha(b1-3)Rha(a1-3)GlcNAc", "Random edge deletion"],
        ["Pse(b2-6)Glc(b1-6)Gal(b1-3)GalNAc(b1-3)[Glc(b1-6)]Gal(b1-3)GalNAc", "Adjusted class weights"],
    ],
)
demo.launch(debug=True)