ndhieunguyen committed
Commit 23a7a4b
1 Parent(s): 22c5f0f

feat: add device

Files changed (1)
  1. app.py +12 -17
app.py CHANGED
@@ -1,26 +1,18 @@
 import torch
-import argparse
 import selfies as sf
-from tqdm import tqdm
 from transformers import T5EncoderModel
-from transformers import set_seed
 from src.scripts.mytokenizers import Tokenizer
 from src.improved_diffusion import gaussian_diffusion as gd
-from src.improved_diffusion import dist_util, logger
 from src.improved_diffusion.respace import SpacedDiffusion
 from src.improved_diffusion.transformer_model import TransformerNetModel
-from src.improved_diffusion.script_util import (
-    model_and_diffusion_defaults,
-    add_dict_to_argparser,
-)
-from src.scripts.mydatasets import Lang2molDataset_submission
 import streamlit as st
 import os
 
 
 @st.cache_resource
-def get_encoder():
+def get_encoder(device):
     model = T5EncoderModel.from_pretrained("QizhiPei/biot5-base-text2mol")
+    model.to(device)
     model.eval()
     return model
 
@@ -31,7 +23,7 @@ def get_tokenizer():
 
 
 @st.cache_resource
-def get_model():
+def get_model(device):
     model = TransformerNetModel(
         in_channels=32,
         model_channels=128,
@@ -44,9 +36,10 @@ def get_model():
     model.load_state_dict(
         torch.load(
             os.path.join("checkpoints", "PLAIN_ema_0.9999_360000.pt"),
-            map_location=torch.device("cpu"),
+            map_location=torch.device(device),
         )
     )
+    model.to(device)
     model.eval()
     return model
 
@@ -65,9 +58,11 @@ def get_diffusion():
     )
 
 
+device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
 tokenizer = get_tokenizer()
-encoder = get_encoder()
-model = get_model()
+encoder = get_encoder(device)
+model = get_model(device)
 diffusion = get_diffusion()
 
 st.title("Lang2mol-Diff")
@@ -85,8 +80,8 @@ if button:
         return_attention_mask=True,
     )
     caption_state = encoder(
-        input_ids=output["input_ids"],
-        attention_mask=output["attention_mask"],
+        input_ids=output["input_ids"].to(device),
+        attention_mask=output["attention_mask"].to(device),
    ).last_hidden_state
    caption_mask = output["attention_mask"]
 
@@ -98,7 +93,7 @@ if button:
         model_kwargs={},
         top_p=1.0,
         progress=True,
-        caption=(caption_state, caption_mask),
+        caption=(caption_state.to(device), caption_mask.to(device)),
     )
     logits = model.get_logits(torch.tensor(outputs))
     cands = torch.topk(logits, k=1, dim=-1)
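
The change follows the usual PyTorch device-placement idiom: pick a device once at startup, load and move the cached model weights onto it, and send every input tensor to the same device before the forward pass. A minimal sketch of that idiom, using a hypothetical toy module rather than the app's actual encoder and diffusion models:

import torch

# Choose the device once: CUDA when available, otherwise CPU
# (same selection expression as in app.py above).
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Hypothetical stand-in for the app's models; not the real TransformerNetModel.
model = torch.nn.Linear(8, 2)
model.to(device)   # move parameters onto the chosen device
model.eval()

# Inputs must live on the same device as the parameters before calling the model.
x = torch.randn(1, 8).to(device)
with torch.no_grad():
    y = model(x)

print(y.device)  # cuda:0 on a GPU machine, cpu otherwise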