sagar007 commited on
Commit
85585aa
·
verified ·
1 Parent(s): 5ae24e1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -19
app.py CHANGED
@@ -1,12 +1,10 @@
1
- import os
2
- import math
3
  import torch
4
  import torch.nn as nn
5
  from torch.nn import functional as F
6
- import gradio as gr
7
  import tiktoken
 
8
 
9
- # GPT model code
10
  class GPTConfig:
11
  def __init__(self):
12
  self.block_size = 1024
@@ -100,7 +98,7 @@ class GPT(nn.Module):
100
 
101
  return logits, loss
102
 
103
- # Updated load_model function
104
  def load_model(model_path):
105
  config = GPTConfig()
106
  model = GPT(config)
@@ -110,46 +108,63 @@ def load_model(model_path):
110
  print("Checkpoint keys:", checkpoint.keys()) # Debug print
111
 
112
  if 'model_state_dict' in checkpoint:
113
- # If the checkpoint contains a 'model_state_dict' key, use that
114
  model.load_state_dict(checkpoint['model_state_dict'])
115
  else:
116
- # Otherwise, try to load the state dict directly
117
  model.load_state_dict(checkpoint)
118
 
119
  model.eval()
120
  return model
121
 
122
- # Load the trained model
123
  model = load_model('gpt_5000.pt') # Replace with the actual path to your .pt file
124
  enc = tiktoken.get_encoding('gpt2')
125
 
126
- def generate_text(prompt, max_length=100, temperature=0.7):
 
127
  input_ids = torch.tensor(enc.encode(prompt)).unsqueeze(0)
 
128
 
129
  with torch.no_grad():
130
  for _ in range(max_length):
131
  outputs, _ = model(input_ids)
132
- next_token_logits = outputs[:, -1, :] / temperature
133
- next_token = torch.multinomial(torch.softmax(next_token_logits, dim=-1), num_samples=1)
 
 
 
 
 
 
 
 
 
 
 
134
  input_ids = torch.cat([input_ids, next_token], dim=-1)
 
135
 
136
- if next_token.item() == enc.encode('\n')[0]:
 
137
  break
138
 
139
- generated_text = enc.decode(input_ids[0].tolist())
140
- return generated_text
141
 
142
  # Gradio interface
 
 
 
143
  iface = gr.Interface(
144
- fn=generate_text,
145
  inputs=[
146
  gr.Textbox(label="Prompt", placeholder="Enter your prompt here..."),
147
- gr.Slider(minimum=10, maximum=200, value=100, step=1, label="Max Length"),
148
- gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature")
 
149
  ],
150
  outputs=gr.Textbox(label="Generated Text"),
151
- title="GPT-2 Text Generator",
152
- description="Enter a prompt and generate text using a fine-tuned GPT-2 model."
153
  )
154
 
155
  # Launch the app
 
 
 
1
  import torch
2
  import torch.nn as nn
3
  from torch.nn import functional as F
 
4
  import tiktoken
5
+ import gradio as gr
6
 
7
+ # Define the model architecture
8
  class GPTConfig:
9
  def __init__(self):
10
  self.block_size = 1024
 
98
 
99
  return logits, loss
100
 
101
+ # Load the model
102
  def load_model(model_path):
103
  config = GPTConfig()
104
  model = GPT(config)
 
108
  print("Checkpoint keys:", checkpoint.keys()) # Debug print
109
 
110
  if 'model_state_dict' in checkpoint:
 
111
  model.load_state_dict(checkpoint['model_state_dict'])
112
  else:
 
113
  model.load_state_dict(checkpoint)
114
 
115
  model.eval()
116
  return model
117
 
118
+ # Load the model
119
  model = load_model('gpt_5000.pt') # Replace with the actual path to your .pt file
120
  enc = tiktoken.get_encoding('gpt2')
121
 
122
+ # Improved text generation function
123
+ def generate_text(prompt, max_length=100, temperature=0.7, top_k=50):
124
  input_ids = torch.tensor(enc.encode(prompt)).unsqueeze(0)
125
+ generated = []
126
 
127
  with torch.no_grad():
128
  for _ in range(max_length):
129
  outputs, _ = model(input_ids)
130
+ next_token_logits = outputs[:, -1, :]
131
+
132
+ # Apply temperature
133
+ next_token_logits = next_token_logits / temperature
134
+
135
+ # Apply top-k filtering
136
+ top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k, dim=-1)
137
+ next_token_probs = F.softmax(top_k_logits, dim=-1)
138
+
139
+ # Sample from the filtered distribution
140
+ next_token_index = torch.multinomial(next_token_probs, num_samples=1)
141
+ next_token = top_k_indices.gather(-1, next_token_index)
142
+
143
  input_ids = torch.cat([input_ids, next_token], dim=-1)
144
+ generated.append(next_token.item())
145
 
146
+ # Stop if we generate a newline, but only after generating at least 20 tokens
147
+ if next_token.item() == enc.encode('\n')[0] and len(generated) > 20:
148
  break
149
 
150
+ generated_text = enc.decode(generated)
151
+ return prompt + generated_text
152
 
153
  # Gradio interface
154
+ def gradio_generate(prompt, max_length, temperature, top_k):
155
+ return generate_text(prompt, max_length, temperature, top_k)
156
+
157
  iface = gr.Interface(
158
+ fn=gradio_generate,
159
  inputs=[
160
  gr.Textbox(label="Prompt", placeholder="Enter your prompt here..."),
161
+ gr.Slider(minimum=20, maximum=500, value=100, step=1, label="Max Length"),
162
+ gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature"),
163
+ gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-k")
164
  ],
165
  outputs=gr.Textbox(label="Generated Text"),
166
+ title="GPT Text Generator",
167
+ description="Enter a prompt and adjust parameters to generate text using a fine-tuned GPT model."
168
  )
169
 
170
  # Launch the app