dalexanderch commited on
Commit
85f8980
1 Parent(s): 7a972f8

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -9
app.py CHANGED
@@ -6,19 +6,24 @@ import gradio as gr
6
  from glycowork.ml.processing import dataset_to_dataloader
7
  import numpy as np
8
  import torch
9
- from glycowork.glycan_data.loader import lib
10
 
11
 
12
- def fn(class_list):
13
- def f(glycan, model):
 
 
 
 
 
 
14
  if model == "No data augmentation":
15
- model = torch.load("model1.pt", map_location=torch.device('cpu'))
16
  model.eval()
17
  elif model == "Ensemble":
18
- model = torch.load("model3.pt", map_location=torch.device('cpu'))
19
  model.eval()
20
  else:
21
- model = torch.load("model2.pt", map_location=torch.device('cpu'))
22
  model.eval()
23
  glycan = [glycan]
24
  label = [0]
@@ -37,10 +42,9 @@ def fn(class_list):
37
  pred = [float(x) for x in pred]
38
  pred = {class_list[i]:pred[i] for i in range(15)}
39
  return pred
40
- return f
41
 
42
- class_list=['Amoebozoa', 'Animalia', 'Bacteria', 'Bamfordvirae', 'Chromista', 'Euryarchaeota', 'Excavata', 'Fungi', 'Heunggongvirae',
43
- 'Orthornavirae', 'Pararnavirae', 'Plantae', 'Proteoarchaeota', 'Protista', 'Riboviria']
44
 
45
  f = fn(class_list)
46
 
 
6
  from glycowork.ml.processing import dataset_to_dataloader
7
  import numpy as np
8
  import torch
 
9
 
10
 
11
+ class_list=['Amoebozoa', 'Animalia', 'Bacteria', 'Bamfordvirae', 'Chromista', 'Euryarchaeota', 'Excavata', 'Fungi', 'Heunggongvirae',
12
+ 'Orthornavirae', 'Pararnavirae', 'Plantae', 'Proteoarchaeota', 'Protista', 'Riboviria']
13
+
14
+ model1 = torch.load("model1.pt", map_location=torch.device('cpu'))
15
+ model2 = torch.load("model2.pt", map_location=torch.device('cpu'))
16
+ model3 = torch.load("model3.pt", map_location=torch.device('cpu'))
17
+
18
+ def fn(glycan, model):
19
  if model == "No data augmentation":
20
+ model = model1
21
  model.eval()
22
  elif model == "Ensemble":
23
+ model = model3
24
  model.eval()
25
  else:
26
+ model = model2
27
  model.eval()
28
  glycan = [glycan]
29
  label = [0]
 
42
  pred = [float(x) for x in pred]
43
  pred = {class_list[i]:pred[i] for i in range(15)}
44
  return pred
 
45
 
46
+
47
+
48
 
49
  f = fn(class_list)
50