saintyboy commited on
Commit
78c09a1
1 Parent(s): 1875730

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -1
app.py CHANGED
@@ -29,7 +29,10 @@ enc = tiktoken.Encoding(
29
  # Load model from checkpoint
30
  model_save_path = 'tuned_ckpt.pt'
31
  if os.path.exists(model_save_path):
32
- model = torch.load(model_save_path, map_location=device)
 
 
 
33
  else:
34
  raise FileNotFoundError(f"Model file {model_save_path} not found")
35
 
 
29
  # Load model from checkpoint
30
  model_save_path = 'tuned_ckpt.pt'
31
  if os.path.exists(model_save_path):
32
+ checkpoint = torch.load(model_save_path, map_location=device)
33
+ gptconf = GPTConfig(**checkpoint['model_args'])
34
+ model = GPT(gptconf)
35
+ model.load_state_dict(checkpoint['model'])
36
  else:
37
  raise FileNotFoundError(f"Model file {model_save_path} not found")
38