Update app.py
Browse files
app.py
CHANGED
@@ -100,15 +100,27 @@ class GPT(nn.Module):
|
|
100 |
|
101 |
return logits, loss
|
102 |
|
103 |
-
#
|
104 |
def load_model(model_path):
    """Build a GPT model and load trained weights from *model_path*.

    Args:
        model_path: Path to a ``.pt`` checkpoint file, saved either as a bare
            state dict or as a dict wrapping the weights under
            ``'model_state_dict'``.

    Returns:
        The ``GPT`` model in eval mode with the checkpoint weights applied.
    """
    config = GPTConfig()
    model = GPT(config)
    # BUG FIX: model_path was previously ignored -- the function returned a
    # randomly initialized model. Load the checkpoint onto CPU so this works
    # on machines without a GPU.
    # NOTE(review): torch.load unpickles arbitrary objects; only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
    if 'model_state_dict' in checkpoint:
        # Training-loop checkpoints wrap the weights under 'model_state_dict'.
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        # Otherwise assume the file is a bare state dict.
        model.load_state_dict(checkpoint)
    model.eval()
    return model
|
110 |
|
111 |
-
|
|
|
112 |
enc = tiktoken.get_encoding('gpt2')
|
113 |
|
114 |
def generate_text(prompt, max_length=100, temperature=0.7):
|
|
|
100 |
|
101 |
return logits, loss
|
102 |
|
103 |
# Updated load_model function
def load_model(model_path):
    """Build a GPT model and load trained weights from *model_path*.

    Args:
        model_path: Path to a ``.pt`` checkpoint file, saved either as a bare
            state dict or as a dict wrapping the weights under
            ``'model_state_dict'``.

    Returns:
        The ``GPT`` model in eval mode with the checkpoint weights applied.
    """
    config = GPTConfig()
    model = GPT(config)
    # Load onto CPU so inference works on machines without a GPU.
    # NOTE(review): torch.load unpickles arbitrary objects; only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
    # Removed leftover debug print of checkpoint.keys().
    if 'model_state_dict' in checkpoint:
        # If the checkpoint contains a 'model_state_dict' key, use that
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        # Otherwise, try to load the state dict directly
        model.load_state_dict(checkpoint)
    model.eval()
    return model
|
121 |
|
122 |
# Load the trained model
# NOTE(review): this runs at import time with a hard-coded checkpoint path,
# so the module fails to import if the file is absent -- consider lazy loading.
model = load_model('gpt_5000.pt') # Replace with the actual path to your .pt file
# GPT-2 BPE tokenizer -- presumably matches the vocabulary the model was
# trained with; confirm against the training script.
enc = tiktoken.get_encoding('gpt2')
|
125 |
|
126 |
def generate_text(prompt, max_length=100, temperature=0.7):
|