chansung commited on
Commit
e66e083
1 Parent(s): b9d9ab2

Update gen.py

Browse files
Files changed (1) hide show
  1. gen.py +1 -1
gen.py CHANGED
@@ -69,7 +69,7 @@ def get_pretrained_models(
69
  tokenizer = Tokenizer(model_path=f"{tokenizer_weight_path}/tokenizer.model")
70
  model_args.vocab_size = tokenizer.n_words
71
  torch.set_default_tensor_type(torch.cuda.HalfTensor)
72
- model = Transformer(model_args)
73
  torch.set_default_tensor_type(torch.FloatTensor)
74
  model.load_state_dict(checkpoint, strict=False)
75
 
 
69
  tokenizer = Tokenizer(model_path=f"{tokenizer_weight_path}/tokenizer.model")
70
  model_args.vocab_size = tokenizer.n_words
71
  torch.set_default_tensor_type(torch.cuda.HalfTensor)
72
+ model = Transformer(model_args).cuda().half()
73
  torch.set_default_tensor_type(torch.FloatTensor)
74
  model.load_state_dict(checkpoint, strict=False)
75