cocktailpeanut commited on
Commit
0c5448c
·
1 Parent(s): 998a442
Files changed (1) hide show
  1. injection_main_HF.py +3 -1
injection_main_HF.py CHANGED
@@ -36,8 +36,10 @@ from typing import List, Tuple
36
  import omegaconf
37
  import utils.exp_utils
38
  import json
 
39
 
40
- device = "cuda"
 
41
 
42
 
43
  def _get_text_embeddings(prompt: str, tokenizer, text_encoder, device):
 
36
  import omegaconf
37
  import utils.exp_utils
38
  import json
39
+ import devicetorch
40
 
41
+ device = devicetorch.get(torch)
42
+ #device = "cuda"
43
 
44
 
45
  def _get_text_embeddings(prompt: str, tokenizer, text_encoder, device):