sagar007 commited on
Commit
324c98e
·
verified ·
1 Parent(s): 0be31e9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -12
app.py CHANGED
@@ -9,6 +9,22 @@ from torch.nn import functional as F
9
  import tiktoken
10
  import gradio as gr
11
  import asyncio
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  # Define the model architecture
13
  class GPTConfig:
14
  def __init__(self):
@@ -134,7 +150,7 @@ import gradio as gr
134
  # [Your existing model code remains unchanged]
135
 
136
  # Modify the generate_text function to be asynchronous
137
- async def generate_text(prompt, max_length=100, temperature=0.7, top_k=50):
138
  input_ids = torch.tensor(enc.encode(prompt)).unsqueeze(0)
139
  generated = []
140
 
@@ -151,13 +167,16 @@ async def generate_text(prompt, max_length=100, temperature=0.7, top_k=50):
151
  input_ids = torch.cat([input_ids, next_token], dim=-1)
152
  generated.append(next_token.item())
153
 
154
- yield enc.decode([next_token.item()])
 
155
 
156
- if next_token.item() == enc.encode('\n')[0] and len(generated) > 20:
157
  break
158
 
159
- await asyncio.sleep(0.05) # Use asyncio.sleep instead of time.sleep
160
 
 
 
161
  # Modify the gradio_generate function to be asynchronous
162
  async def gradio_generate(prompt, max_length, temperature, top_k):
163
  output = ""
@@ -178,20 +197,23 @@ css = """
178
  </style>
179
  """
180
 
181
- # 6. Gradio App Definition
182
  with gr.Blocks(css=css) as demo:
183
- gr.HTML("<div class='header'><h1>🌟 GPT-2 Text Generator</h1></div>")
184
 
185
  with gr.Row():
186
  with gr.Column(scale=3):
187
- prompt = gr.Textbox(placeholder="Enter your prompt here...", label="Prompt", elem_classes="user-input")
 
 
 
 
188
  with gr.Column(scale=1):
189
- generate_btn = gr.Button("Generate", elem_classes="generate-btn")
190
 
191
  with gr.Row():
192
- max_length = gr.Slider(minimum=20, maximum=500, value=100, step=1, label="Max Length")
193
- temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature")
194
- top_k = gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-k")
195
 
196
  output = gr.Markdown(elem_classes="output-box")
197
 
@@ -201,6 +223,6 @@ with gr.Blocks(css=css) as demo:
201
  outputs=output
202
  )
203
 
204
- # 7. Launch the app
205
  if __name__ == "__main__":
206
  demo.launch()
 
9
  import tiktoken
10
  import gradio as gr
11
  import asyncio
12
+
13
+ # Add the post-processing function here
14
+ def post_process_text(text):
15
+ # Ensure the text starts with a capital letter
16
+ text = text.capitalize()
17
+
18
+ # Remove any incomplete sentences at the end
19
+ sentences = text.split('.')
20
+ complete_sentences = sentences[:-1] if len(sentences) > 1 else sentences
21
+
22
+ # Rejoin sentences and add a period if missing
23
+ processed_text = '. '.join(complete_sentences)
24
+ if not processed_text.endswith('.'):
25
+ processed_text += '.'
26
+
27
+ return processed_text
28
  # Define the model architecture
29
  class GPTConfig:
30
  def __init__(self):
 
150
  # [Your existing model code remains unchanged]
151
 
152
  # Modify the generate_text function to be asynchronous
153
+ async def generate_text(prompt, max_length=432, temperature=0.8, top_k=40):
154
  input_ids = torch.tensor(enc.encode(prompt)).unsqueeze(0)
155
  generated = []
156
 
 
167
  input_ids = torch.cat([input_ids, next_token], dim=-1)
168
  generated.append(next_token.item())
169
 
170
+ next_token_str = enc.decode([next_token.item()])
171
+ yield next_token_str
172
 
173
+ if next_token.item() == enc.encode('\n')[0] and len(generated) > 100:
174
  break
175
 
176
+ await asyncio.sleep(0.02) # Slightly faster typing effect
177
 
178
+ if len(generated) == max_length:
179
+ yield "... (output truncated due to length)"
180
  # Modify the gradio_generate function to be asynchronous
181
  async def gradio_generate(prompt, max_length, temperature, top_k):
182
  output = ""
 
197
  </style>
198
  """
199
 
 
200
  with gr.Blocks(css=css) as demo:
201
+ gr.HTML("<div class='header'><h1>🌟 GPT-2 Storyteller</h1></div>")
202
 
203
  with gr.Row():
204
  with gr.Column(scale=3):
205
+ prompt = gr.Textbox(
206
+ placeholder="Start your story here (e.g., 'Once upon a time in a magical forest...')",
207
+ label="Story Prompt",
208
+ elem_classes="user-input"
209
+ )
210
  with gr.Column(scale=1):
211
+ generate_btn = gr.Button("Generate Story", elem_classes="generate-btn")
212
 
213
  with gr.Row():
214
+ max_length = gr.Slider(minimum=50, maximum=500, value=432, step=1, label="Max Length")
215
+ temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.8, step=0.1, label="Temperature")
216
+ top_k = gr.Slider(minimum=1, maximum=100, value=40, step=1, label="Top-k")
217
 
218
  output = gr.Markdown(elem_classes="output-box")
219
 
 
223
  outputs=output
224
  )
225
 
226
+
227
  if __name__ == "__main__":
228
  demo.launch()