Spaces:
Runtime error
Runtime error
Update gen.py
Browse files
gen.py
CHANGED
@@ -67,7 +67,7 @@ def get_pretrained_models(
|
|
67 |
with open(Path(llama_weight_path) / "params.json", "r") as f:
|
68 |
params = json.loads(f.read())
|
69 |
|
70 |
-
model_args: ModelArgs = ModelArgs(max_seq_len=
|
71 |
tokenizer = Tokenizer(model_path=f"{tokenizer_weight_path}/tokenizer.model")
|
72 |
model_args.vocab_size = tokenizer.n_words
|
73 |
torch.set_default_tensor_type(torch.cuda.HalfTensor)
|
|
|
67 |
with open(Path(llama_weight_path) / "params.json", "r") as f:
|
68 |
params = json.loads(f.read())
|
69 |
|
70 |
+
model_args: ModelArgs = ModelArgs(max_seq_len=512, max_batch_size=1, **params)
|
71 |
tokenizer = Tokenizer(model_path=f"{tokenizer_weight_path}/tokenizer.model")
|
72 |
model_args.vocab_size = tokenizer.n_words
|
73 |
torch.set_default_tensor_type(torch.cuda.HalfTensor)
|