dalexanderch commited on
Commit
1a7661f
1 Parent(s): 9ab9535

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -5
app.py CHANGED
@@ -1,9 +1,6 @@
1
  import os
2
  os.system("pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu")
3
  os.system("pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.12.0+cpu.html")
4
- # os.system("apt-get install -y graphviz-dev")
5
- # os.system("pip install pygraphviz")
6
-
7
  import gradio as gr
8
  from glycowork.ml.processing import dataset_to_dataloader
9
  import numpy as np
@@ -12,6 +9,9 @@ import torch.nn as nn
12
  from glycowork.motif.graph import glycan_to_nxGraph
13
  import networkx as nx
14
  import matplotlib.pyplot as plt
 
 
 
15
 
16
  class EnsembleModel(nn.Module):
17
  def __init__(self, models):
@@ -44,8 +44,10 @@ def fn(glycan, model):
44
  node_labels = nx.get_node_attributes(graph, 'string_labels')
45
  labels = {i:node_labels[i] for i in range(len(graph.nodes))}
46
  graph = nx.relabel_nodes(graph, labels)
47
- nx.draw(graph, with_labels=True)
48
- plt.savefig("graph.png")
 
 
49
  # Perform inference
50
  if model == "No data augmentation":
51
  model_pred = model1
 
1
  import os
2
  os.system("pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu")
3
  os.system("pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.12.0+cpu.html")
 
 
 
4
  import gradio as gr
5
  from glycowork.ml.processing import dataset_to_dataloader
6
  import numpy as np
 
9
  from glycowork.motif.graph import glycan_to_nxGraph
10
  import networkx as nx
11
  import matplotlib.pyplot as plt
12
+ from networkx.drawing.nx_agraph import write_dot
13
+ import pygraphviz as pgv
14
+
15
 
16
  class EnsembleModel(nn.Module):
17
  def __init__(self, models):
 
44
  node_labels = nx.get_node_attributes(graph, 'string_labels')
45
  labels = {i:node_labels[i] for i in range(len(graph.nodes))}
46
  graph = nx.relabel_nodes(graph, labels)
47
+ write_dot(graph, "graph.dot")
48
+ graph=pgv.AGraph("graph.dot")
49
+ graph.layout(prog='dot')
50
+ graph.draw("graph.png")
51
  # Perform inference
52
  if model == "No data augmentation":
53
  model_pred = model1