sagar007 committed (verified)
Commit e60730b · 1 parent: 25893d0

Update app.py

Files changed (1): app.py +9 -16
app.py CHANGED
@@ -121,14 +121,12 @@ class GPT(nn.Module):
 
         return logits, loss
 
-# Load the model
+@spaces.GPU
 def load_model(model_path):
     config = GPTConfig()
     model = GPT(config)
 
-    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
-
-    print("Checkpoint keys:", checkpoint.keys())  # Debug print
+    checkpoint = torch.load(model_path, map_location=torch.device('cuda'))
 
     if 'model_state_dict' in checkpoint:
         model.load_state_dict(checkpoint['model_state_dict'])
@@ -136,24 +134,17 @@ def load_model(model_path):
         model.load_state_dict(checkpoint)
 
     model.eval()
+    model.to('cuda')
     return model
 
 # Load the model
 model = load_model('gpt_model.pth')  # Replace with the actual path to your .pt file
 enc = tiktoken.get_encoding('gpt2')
 
-# Improved text generation function
-import torch
-import torch.nn as nn
-from torch.nn import functional as F
-import tiktoken
-import gradio as gr
-
-# [Your existing model code remains unchanged]
-
-# Modify the generate_text function to be asynchronous
+# Update the generate_text function
+@spaces.GPU(duration=60)  # Adjust duration as needed
 async def generate_text(prompt, max_length=432, temperature=0.8, top_k=40):
-    input_ids = torch.tensor(enc.encode(prompt)).unsqueeze(0)
+    input_ids = torch.tensor(enc.encode(prompt)).unsqueeze(0).cuda()
     generated = []
 
     with torch.no_grad():
@@ -179,7 +170,9 @@ async def generate_text(prompt, max_length=432, temperature=0.8, top_k=40):
 
     if len(generated) == max_length:
         yield "... (output truncated due to length)"
-# Modify the gradio_generate function to be asynchronous
+
+# Update the gradio_generate function
+@spaces.GPU(duration=60)  # Adjust duration as needed
 async def gradio_generate(prompt, max_length, temperature, top_k):
     output = ""
     async for token in generate_text(prompt, max_length, temperature, top_k):
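For comparison, the commit pins both the checkpoint load and the model to 'cuda'. A minimal sketch (not part of the commit) of a device-agnostic load_model, assuming the same GPTConfig/GPT classes defined earlier in app.py and the same two checkpoint layouts the diff handles, would fall back to CPU when no GPU is attached:

import torch

def load_model(model_path):
    # Use the GPU when present; fall back to CPU otherwise.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = GPT(GPTConfig())

    checkpoint = torch.load(model_path, map_location=device)
    # Accept both layouts: weights wrapped under 'model_state_dict',
    # or the checkpoint being the state dict itself.
    state_dict = checkpoint.get('model_state_dict', checkpoint)
    model.load_state_dict(state_dict)

    model.eval()
    return model.to(device)

This sidesteps the RuntimeError torch.load raises when map_location names a CUDA device on a machine without one.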
 
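On the decorator side, spaces.GPU comes from the `spaces` package used by Hugging Face ZeroGPU Spaces: a GPU is attached only while a decorated call runs, and duration caps the time requested per call. A minimal sketch of the pattern, with a hypothetical run_on_gpu helper (it assumes `model` has already been moved to the GPU, as the commit's load_model does):

import spaces
import torch

@spaces.GPU(duration=60)  # GPU is held for at most ~60 s per call
def run_on_gpu(input_ids: torch.Tensor) -> torch.Tensor:
    # CUDA work belongs inside the decorated call; outside it the GPU
    # may not be attached. The model returns (logits, loss) per app.py.
    logits, _ = model(input_ids.cuda())
    return logits

As the diff's own comment notes, duration should be sized to comfortably cover a full generation at max_length.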