Spaces:
Runtime error
Runtime error
Update gen.py
Browse files
gen.py
CHANGED
@@ -69,7 +69,7 @@ def get_pretrained_models(
|
|
69 |
tokenizer = Tokenizer(model_path=f"{tokenizer_weight_path}/tokenizer.model")
|
70 |
model_args.vocab_size = tokenizer.n_words
|
71 |
torch.set_default_tensor_type(torch.cuda.HalfTensor)
|
72 |
-
model = Transformer(model_args)
|
73 |
torch.set_default_tensor_type(torch.FloatTensor)
|
74 |
model.load_state_dict(checkpoint, strict=False)
|
75 |
|
|
|
69 |
tokenizer = Tokenizer(model_path=f"{tokenizer_weight_path}/tokenizer.model")
|
70 |
model_args.vocab_size = tokenizer.n_words
|
71 |
torch.set_default_tensor_type(torch.cuda.HalfTensor)
|
72 |
+
model = Transformer(model_args).cuda().half()
|
73 |
torch.set_default_tensor_type(torch.FloatTensor)
|
74 |
model.load_state_dict(checkpoint, strict=False)
|
75 |
|