Spaces: Runtime error
Update app.py

app.py CHANGED
@@ -1,15 +1,136 @@
 import gradio as gr
-from

-#
-

-
-
-

 # Create a Gradio interface
-demo = gr.Interface(

 # Launch the Gradio app
 demo.launch()

+import os
+import torch
+import tiktoken
+from model import GPT, GPTConfig
+import pickle
+import string
 import gradio as gr
+from contextlib import nullcontext

+# Model and Tokenizer setup
+device = 'cpu'
+dtype = 'bfloat16' if device != 'cpu' and torch.cuda.is_bf16_supported() else 'float16'
+device_type = 'cuda' if 'cuda' in device else 'cpu'
+ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
+ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

+cl100k_base = tiktoken.get_encoding("cl100k_base")
+enc = tiktoken.Encoding(
+    name="cl100k_im",
+    pat_str=cl100k_base._pat_str,
+    mergeable_ranks=cl100k_base._mergeable_ranks,
+    special_tokens={
+        **cl100k_base._special_tokens,
+        "<EOR>": 100264,
+        "<SPECIAL2>": 100265,
+    }
+)
+
+# Load model from checkpoint
+model_save_path = 'latest_model.pt'
+if os.path.exists(model_save_path):
+    model = torch.load(model_save_path, map_location=device)
+else:
+    raise FileNotFoundError(f"Model file {model_save_path} not found")
+
+model.eval()
+model.to(device)
+model = torch.compile(model)
+
+# Helpers to encode/decode text with the custom tokenizer
+def encode(text):
+    return enc.encode(text, allowed_special={"<EOR>"})
+
+def decode(tokens):
+    return enc.decode(tokens)
+
+# Truncate output to an approximate token limit (whitespace words, not true tokens)
+def truncate_output(text, token_limit):
+    tokens = text.split()
+    if len(tokens) > token_limit:
+        return ' '.join(tokens[:token_limit]) + '...'
+    return text
+
+# Keep generating until the output reaches max_length words or ends a sentence
+def ensure_complete_output(output, context, max_length, temperature, top_k, top_p, repetition_penalty, eor_token_id):
+    while len(output.split()) < max_length and not output.endswith('.'):
+        continuation = model.generate(
+            torch.tensor(encode(output), dtype=torch.long, device=device)[None, ...],
+            max_new_tokens=max_length,
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+            repetition_penalty=repetition_penalty,
+            eor_token_id=eor_token_id
+        )
+        continuation_text = decode(continuation[0].tolist())
+        # generate() returns the prompt plus the new tokens; keep only the new text
+        if continuation_text.startswith(output):
+            continuation_text = continuation_text[len(output):]
+        if eor_token_id in continuation[0].tolist():
+            continuation_text = continuation_text.split("<EOR>")[0]
+            output += continuation_text
+            break
+        else:
+            output += continuation_text
+        if len(output.split()) >= max_length:
+            break
+    return output
+
+# Text generation function for Gradio interface
+def generate_text(prompt, num_samples, max_new_tokens, temperature, top_k, top_p, repetition_penalty, eor_token_id):
+    with torch.no_grad():
+        with ctx:
+            start_ids = encode(prompt)
+            initial_prompt = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]
+            outputs = []
+            for _ in range(num_samples):
+                y = model.generate(
+                    initial_prompt,
+                    max_new_tokens=max_new_tokens,
+                    temperature=temperature,
+                    top_k=top_k,
+                    top_p=top_p,
+                    repetition_penalty=repetition_penalty,
+                    eor_token_id=eor_token_id
+                )
+                # Filter out tokens after the end-of-response token or similar markers
+                output_ids = y[0].tolist()
+                if eor_token_id in output_ids:
+                    output_ids = output_ids[:output_ids.index(eor_token_id) + 1]  # Include EOR token
+                else:
+                    # Check for similar markers like '<E' and handle them
+                    try:
+                        eor_index = next(i for i, token in enumerate(output_ids) if decode([token]).startswith('<E'))
+                        output_ids = output_ids[:eor_index]
+                    except StopIteration:
+                        pass
+
+                # Ensure the prompt is not included in the final output
+                output = decode(output_ids).replace(prompt, '').strip()
+                output = ensure_complete_output(output, prompt, max_new_tokens, temperature, top_k, top_p, repetition_penalty, eor_token_id)
+                truncated_output = truncate_output(output, max_new_tokens)
+                outputs.append(truncated_output)
+            return '\n\n'.join(outputs)

 # Create a Gradio interface
+demo = gr.Interface(
+    fn=generate_text,
+    inputs=[
+        gr.Textbox(lines=2, placeholder="Enter your prompt here...", value="Write a short story about a boy:"),
+        gr.Slider(minimum=1, maximum=10, step=1, value=1, label="Number of Samples"),
+        gr.Slider(minimum=10, maximum=200, step=1, value=75, label="Max New Tokens"),
+        gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.8, label="Temperature"),
+        gr.Slider(minimum=1, maximum=100, step=1, value=50, label="Top-k"),
+        gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.85, label="Top-p"),
+        gr.Slider(minimum=1.0, maximum=2.0, step=0.1, value=1.1, label="Repetition Penalty"),
+        gr.Number(value=100264, precision=0, label="End-of-Response Token ID")
+    ],
+    outputs="text",
+    title="GPT Text Generator",
+    description="Generate text based on a prompt using a trained GPT model.",
+    # Each example row must supply a value for every input component
+    examples=[
+        ["Write a short story about a boy:", 1, 75, 0.8, 50, 0.85, 1.1, 100264],
+        ["Explain the theory of relativity:", 1, 75, 0.8, 50, 0.85, 1.1, 100264],
+        ["What is the meaning of life?", 1, 75, 0.8, 50, 0.85, 1.1, 100264]
+    ]
+)

 # Launch the Gradio app
 demo.launch()
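For reference, the extended encoding defined above can be sanity-checked in isolation. This is a minimal sketch assuming only that tiktoken is installed; the "<EOR>" id 100264 is taken directly from the diff:

import tiktoken

cl100k_base = tiktoken.get_encoding("cl100k_base")
enc = tiktoken.Encoding(
    name="cl100k_im",
    pat_str=cl100k_base._pat_str,
    mergeable_ranks=cl100k_base._mergeable_ranks,
    special_tokens={**cl100k_base._special_tokens, "<EOR>": 100264, "<SPECIAL2>": 100265},
)

# Special text must be listed in allowed_special, otherwise encode() raises a ValueError
ids = enc.encode("hello<EOR>", allowed_special={"<EOR>"})
assert ids[-1] == 100264
assert enc.decode(ids) == "hello<EOR>"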
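The new app.py imports GPT and GPTConfig but never uses them, loading a fully pickled module with torch.load instead. If the checkpoint were a nanoGPT-style dict holding model args and a state_dict (an assumption; the actual layout of latest_model.pt is not shown in the diff), loading would look roughly like the sketch below. Note that PyTorch 2.6+ defaults torch.load to weights_only=True, which rejects pickled modules:

import torch
from model import GPT, GPTConfig  # provided by the Space's model.py

checkpoint = torch.load('latest_model.pt', map_location='cpu', weights_only=False)
if isinstance(checkpoint, dict) and 'model' in checkpoint:
    # hypothetical nanoGPT-style layout: config kwargs plus a state_dict
    model = GPT(GPTConfig(**checkpoint['model_args']))
    model.load_state_dict(checkpoint['model'])
else:
    # full pickled module, as app.py currently assumes
    model = checkpoint
model.eval()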
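app.py also calls model.generate with top_p, repetition_penalty, and eor_token_id keyword arguments, which stock nanoGPT's generate does not accept, so model.py must define an extended version. A hypothetical sketch of such a method, assuming a nanoGPT-style model (forward returning (logits, loss), a config.block_size, and batch size 1 as used here); illustrative only, not the Space's actual implementation:

import torch
import torch.nn.functional as F

@torch.no_grad()
def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None,
             top_p=None, repetition_penalty=1.0, eor_token_id=None):
    for _ in range(max_new_tokens):
        # Crop the context to the model's block size
        idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
        logits, _ = self(idx_cond)
        logits = logits[:, -1, :] / temperature
        # Repetition penalty (CTRL-style): damp logits of tokens already generated
        if repetition_penalty != 1.0:
            for token in set(idx[0].tolist()):
                score = logits[0, token]
                logits[0, token] = score / repetition_penalty if score > 0 else score * repetition_penalty
        # Top-k: keep only the k most likely tokens
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = -float('inf')
        # Top-p (nucleus): keep the smallest set of tokens whose mass exceeds top_p
        if top_p is not None:
            sorted_logits, sorted_idx = torch.sort(logits, descending=True)
            probs = F.softmax(sorted_logits, dim=-1)
            cumulative = torch.cumsum(probs, dim=-1)
            sorted_logits[cumulative - probs > top_p] = -float('inf')
            logits = logits.scatter(-1, sorted_idx, sorted_logits)
        probs = F.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        idx = torch.cat((idx, next_token), dim=1)
        # Stop early at the end-of-response token (batch size 1 assumed)
        if eor_token_id is not None and next_token.item() == int(eor_token_id):
            break
    return idx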