sagar007 commited on
Commit
0be31e9
·
verified ·
1 Parent(s): bdc217e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -7
app.py CHANGED
@@ -3,7 +3,12 @@ import torch.nn as nn
3
  from torch.nn import functional as F
4
  import tiktoken
5
  import gradio as gr
6
-
 
 
 
 
 
7
  # Define the model architecture
8
  class GPTConfig:
9
  def __init__(self):
@@ -128,8 +133,8 @@ import gradio as gr
128
 
129
  # [Your existing model code remains unchanged]
130
 
131
- # Modified text generation function to yield tokens
132
- def generate_text(prompt, max_length=100, temperature=0.7, top_k=50):
133
  input_ids = torch.tensor(enc.encode(prompt)).unsqueeze(0)
134
  generated = []
135
 
@@ -150,13 +155,14 @@ def generate_text(prompt, max_length=100, temperature=0.7, top_k=50):
150
 
151
  if next_token.item() == enc.encode('\n')[0] and len(generated) > 20:
152
  break
 
 
153
 
154
- # Gradio interface
155
- def gradio_generate(prompt, max_length, temperature, top_k):
156
  output = ""
157
- for token in generate_text(prompt, max_length, temperature, top_k):
158
  output += token
159
- time.sleep(0.05) # Simulate typing effect
160
  yield output
161
 
162
  # Custom CSS for the animation effect
 
3
  from torch.nn import functional as F
4
  import tiktoken
5
  import gradio as gr
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.nn import functional as F
9
+ import tiktoken
10
+ import gradio as gr
11
+ import asyncio
12
  # Define the model architecture
13
  class GPTConfig:
14
  def __init__(self):
 
133
 
134
  # [Your existing model code remains unchanged]
135
 
136
+ # Modify the generate_text function to be asynchronous
137
+ async def generate_text(prompt, max_length=100, temperature=0.7, top_k=50):
138
  input_ids = torch.tensor(enc.encode(prompt)).unsqueeze(0)
139
  generated = []
140
 
 
155
 
156
  if next_token.item() == enc.encode('\n')[0] and len(generated) > 20:
157
  break
158
+
159
+ await asyncio.sleep(0.05) # Use asyncio.sleep instead of time.sleep
160
 
161
+ # Modify the gradio_generate function to be asynchronous
162
+ async def gradio_generate(prompt, max_length, temperature, top_k):
163
  output = ""
164
+ async for token in generate_text(prompt, max_length, temperature, top_k):
165
  output += token
 
166
  yield output
167
 
168
  # Custom CSS for the animation effect