Update app.py
app.py CHANGED
@@ -121,14 +121,12 @@ class GPT(nn.Module):
 
         return logits, loss
 
-
+@spaces.GPU
 def load_model(model_path):
     config = GPTConfig()
     model = GPT(config)
 
-    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
-
-    print("Checkpoint keys:", checkpoint.keys())  # Debug print
+    checkpoint = torch.load(model_path, map_location=torch.device('cuda'))
 
     if 'model_state_dict' in checkpoint:
         model.load_state_dict(checkpoint['model_state_dict'])
@@ -136,24 +134,17 @@ def load_model(model_path):
         model.load_state_dict(checkpoint)
 
     model.eval()
+    model.to('cuda')
     return model
 
 # Load the model
 model = load_model('gpt_model.pth')  # Replace with the actual path to your .pt file
 enc = tiktoken.get_encoding('gpt2')
 
-#
-
-import torch.nn as nn
-from torch.nn import functional as F
-import tiktoken
-import gradio as gr
-
-# [Your existing model code remains unchanged]
-
-# Modify the generate_text function to be asynchronous
+# Update the generate_text function
+@spaces.GPU(duration=60)  # Adjust duration as needed
 async def generate_text(prompt, max_length=432, temperature=0.8, top_k=40):
-    input_ids = torch.tensor(enc.encode(prompt)).unsqueeze(0)
+    input_ids = torch.tensor(enc.encode(prompt)).unsqueeze(0).cuda()
     generated = []
 
     with torch.no_grad():
@@ -179,7 +170,9 @@ async def generate_text(prompt, max_length=432, temperature=0.8, top_k=40):
 
     if len(generated) == max_length:
         yield "... (output truncated due to length)"
-
+
+# Update the gradio_generate function
+@spaces.GPU(duration=60)  # Adjust duration as needed
 async def gradio_generate(prompt, max_length, temperature, top_k):
     output = ""
     async for token in generate_text(prompt, max_length, temperature, top_k):
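For context: @spaces.GPU comes from Hugging Face's spaces package for ZeroGPU hardware. A ZeroGPU Space holds no GPU at rest; one is attached only while a decorated function runs, which is why this commit moves the .cuda() placement inside decorated functions. A minimal, self-contained sketch of the pattern follows; the Linear stand-in model and the run function are hypothetical, shown only to illustrate the decorator usage, not this app's GPT code.

import torch
import spaces  # Hugging Face ZeroGPU helper, available on ZeroGPU Spaces

model = torch.nn.Linear(4, 4)  # hypothetical stand-in for the app's GPT model
model.eval()

@spaces.GPU(duration=60)  # a GPU is attached for up to ~60 s per call
def run(x):
    # On ZeroGPU, CUDA is only usable inside a @spaces.GPU-decorated function
    model.to('cuda')
    with torch.no_grad():
        return model(x.cuda()).cpu()

print(run(torch.randn(1, 4)))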
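The body of generate_text is elided between the second and third hunks. For orientation only, a hypothetical top-k sampling loop consistent with the visible signature, the (logits, loss) return of the model, and the truncation yield might look like the sketch below; the exact logic is an assumption, not the Space's actual code, and it reuses the model and enc globals defined in app.py.

import torch
from torch.nn import functional as F

async def generate_text(prompt, max_length=432, temperature=0.8, top_k=40):
    input_ids = torch.tensor(enc.encode(prompt)).unsqueeze(0).cuda()
    generated = []
    with torch.no_grad():
        for _ in range(max_length):
            logits, _ = model(input_ids)             # model returns (logits, loss)
            logits = logits[:, -1, :] / temperature  # last position, scaled
            # Keep the top_k most likely tokens and sample one of them
            topk_vals, topk_idx = torch.topk(logits, top_k, dim=-1)
            probs = F.softmax(topk_vals, dim=-1)
            next_id = topk_idx.gather(-1, torch.multinomial(probs, 1))
            input_ids = torch.cat([input_ids, next_id], dim=1)
            generated.append(next_id.item())
            yield enc.decode([next_id.item()])       # stream one token at a time
    if len(generated) == max_length:
        yield "... (output truncated due to length)"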