BK-Lee committed on
Commit
a3e7589
·
1 Parent(s): 8ab9b1a
Files changed (1) hide show
  1. trol/load_trol.py +8 -2
trol/load_trol.py CHANGED
@@ -64,8 +64,14 @@ def load_trol(link):
64
  # Loading tokenizer & Loading backbone model (error -> then delete flash attention)
65
  tok_trol = TroLTokenizer.from_pretrained(path, padding_side='left')
66
  try:
67
- trol = TroLForCausalLM.from_pretrained(path, **huggingface_config).cuda()
68
  except:
69
  del huggingface_config["attn_implementation"]
70
- trol = TroLForCausalLM.from_pretrained(path, **huggingface_config).cuda()
 
 
 
 
 
 
71
  return trol, tok_trol
 
64
  # Loading tokenizer & Loading backbone model (error -> then delete flash attention)
65
  tok_trol = TroLTokenizer.from_pretrained(path, padding_side='left')
66
  try:
67
+ trol = TroLForCausalLM.from_pretrained(path, **huggingface_config)
68
  except:
69
  del huggingface_config["attn_implementation"]
70
+ trol = TroLForCausalLM.from_pretrained(path, **huggingface_config)
71
+
72
+ # wrapping
73
+ try:
74
+ trol = trol.cuda()
75
+ except:
76
+ pass
77
  return trol, tok_trol