sagar007 commited on
Commit
5ae24e1
·
verified ·
1 Parent(s): 559a174

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -3
app.py CHANGED
@@ -100,15 +100,27 @@ class GPT(nn.Module):
100
 
101
  return logits, loss
102
 
103
- # Load the trained model
104
  def load_model(model_path):
105
  config = GPTConfig()
106
  model = GPT(config)
107
- model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
 
 
 
 
 
 
 
 
 
 
 
108
  model.eval()
109
  return model
110
 
111
- model = load_model('gpt_5000.pt') # Replace with the actual path to your .pth file
 
112
  enc = tiktoken.get_encoding('gpt2')
113
 
114
  def generate_text(prompt, max_length=100, temperature=0.7):
 
100
 
101
  return logits, loss
102
 
103
+ # Updated load_model function
104
  def load_model(model_path):
105
  config = GPTConfig()
106
  model = GPT(config)
107
+
108
+ checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
109
+
110
+ print("Checkpoint keys:", checkpoint.keys()) # Debug print
111
+
112
+ if 'model_state_dict' in checkpoint:
113
+ # If the checkpoint contains a 'model_state_dict' key, use that
114
+ model.load_state_dict(checkpoint['model_state_dict'])
115
+ else:
116
+ # Otherwise, try to load the state dict directly
117
+ model.load_state_dict(checkpoint)
118
+
119
  model.eval()
120
  return model
121
 
122
+ # Load the trained model
123
+ model = load_model('gpt_5000.pt') # Replace with the actual path to your .pt file
124
  enc = tiktoken.get_encoding('gpt2')
125
 
126
  def generate_text(prompt, max_length=100, temperature=0.7):