chansung commited on
Commit
32d38c1
1 Parent(s): 3c01434

Update gen.py

Browse files
Files changed (1) hide show
  1. gen.py +1 -1
gen.py CHANGED
@@ -67,7 +67,7 @@ def get_pretrained_models(
67
  with open(Path(llama_weight_path) / "params.json", "r") as f:
68
  params = json.loads(f.read())
69
 
70
- model_args: ModelArgs = ModelArgs(max_seq_len=1024, max_batch_size=1, **params)
71
  tokenizer = Tokenizer(model_path=f"{tokenizer_weight_path}/tokenizer.model")
72
  model_args.vocab_size = tokenizer.n_words
73
  torch.set_default_tensor_type(torch.cuda.HalfTensor)
 
67
  with open(Path(llama_weight_path) / "params.json", "r") as f:
68
  params = json.loads(f.read())
69
 
70
+ model_args: ModelArgs = ModelArgs(max_seq_len=512, max_batch_size=1, **params)
71
  tokenizer = Tokenizer(model_path=f"{tokenizer_weight_path}/tokenizer.model")
72
  model_args.vocab_size = tokenizer.n_words
73
  torch.set_default_tensor_type(torch.cuda.HalfTensor)