Spaces:
Runtime error
Runtime error
import os | |
import torch | |
import tiktoken | |
from model import GPT, GPTConfig | |
import pickle | |
import string | |
import gradio as gr | |
from contextlib import nullcontext | |
# Model and Tokenizer setup
# Inference is pinned to the CPU. The precision settings below only matter for
# a CUDA device: when device_type is 'cpu', ctx is a no-op nullcontext.
device = 'cpu'
device_type = 'cuda' if 'cuda' in device else 'cpu'
if device != 'cpu' and torch.cuda.is_bf16_supported():
    dtype = 'bfloat16'
else:
    dtype = 'float16'
ptdtype = dict(
    float32=torch.float32,
    bfloat16=torch.bfloat16,
    float16=torch.float16,
)[dtype]
if device_type == 'cpu':
    ctx = nullcontext()
else:
    ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
# Tokenizer: the cl100k_base vocabulary and split pattern, extended with two
# extra special tokens for the fine-tuned model. NOTE(review): ids 100264 and
# 100265 presumably match the ids used during fine-tuning and are unused by
# cl100k_base's own special tokens -- confirm against the training setup.
cl100k_base = tiktoken.get_encoding("cl100k_base")
enc = tiktoken.Encoding(
    name="cl100k_im",
    pat_str=cl100k_base._pat_str,
    mergeable_ranks=cl100k_base._mergeable_ranks,
    special_tokens=dict(
        list(cl100k_base._special_tokens.items())
        + [("<EOR>", 100264), ("<SPECIAL2>", 100265)]
    ),
)
# Load model from checkpoint
model_save_path = 'tuned_ckpt.pt'
if not os.path.exists(model_save_path):
    raise FileNotFoundError(f"Model file {model_save_path} not found")
# The file holds a whole pickled module (eval/to/generate are called on the
# loaded object directly), so it must be unpickled with weights_only=False:
# newer PyTorch releases default torch.load to weights_only=True, which refuses
# to rebuild arbitrary classes (presumably model.GPT here) and would crash the
# app at startup.
# SECURITY: unpickling can execute arbitrary code. Only ever load this trusted,
# locally shipped checkpoint here -- never a user-supplied path.
model = torch.load(model_save_path, map_location=device, weights_only=False)
model.eval()
model.to(device)
model = torch.compile(model)
# Function to encode and decode using the tokenizer
def encode(text):
    """Encode *text* into token ids with the extended ``cl100k_im`` tokenizer.

    The literal ``"<EOR>"`` is mapped to its special-token id. Any other
    special-token string in the text (for example ``"<SPECIAL2>"``, typed by a
    user or decoded back out of model output) is encoded as ordinary text.
    tiktoken's default ``disallowed_special="all"`` would instead raise
    ValueError and crash the request.
    """
    return enc.encode(text, allowed_special={"<EOR>"}, disallowed_special=())
def decode(tokens):
    """Turn a sequence of token ids back into text; counterpart of ``encode``."""
    text = enc.decode(tokens)
    return text
# Function to truncate output to a word budget
def truncate_output(text, token_limit):
    """Cap *text* at ``token_limit`` whitespace-separated words.

    The "tokens" counted here are the pieces produced by ``str.split()``, not
    tokenizer ids. Text within the budget is returned unchanged, with its
    original whitespace kept. Longer text is rebuilt from its first
    ``token_limit`` words joined by single spaces, followed by ``'...'``.
    """
    words = text.split()
    if len(words) <= token_limit:
        return text
    return ' '.join(words[:token_limit]) + '...'
def ensure_complete_output(output, context, max_length, temperature, top_k, top_p, repetition_penalty, eor_token_id):
    """Keep sampling until ``output`` looks like a finished response.

    Generation continues while ``output`` has fewer than ``max_length``
    whitespace-separated words and does not end with a period. Each round
    conditions the model on ``context`` followed by the text produced so far,
    and appends only the newly generated tokens to ``output``.

    Args:
        output: Response text generated so far (prompt already removed).
        context: The original prompt; prepended when re-conditioning the model.
        max_length: Word budget for ``output``; also passed as
            ``max_new_tokens`` to every extra ``model.generate`` call.
        temperature, top_k, top_p, repetition_penalty, eor_token_id:
            Forwarded unchanged to ``model.generate``.

    Returns:
        ``output`` extended with zero or more continuations. Stops early when
        ``output`` already contains ``"<EOR>"``, when the end-of-response token
        is generated, or when a round yields no new non-whitespace text.
    """
    while len(output.split()) < max_length and not output.endswith('.'):
        if "<EOR>" in output:
            # The model already ended its response; generating past the
            # marker would only append unrelated text.
            break
        seed_text = ' '.join(part for part in (context, output) if part)
        seed_ids = encode(seed_text)
        if not seed_ids:
            # Nothing to condition on (empty prompt and empty output).
            break
        continuation = model.generate(
            torch.tensor(seed_ids, dtype=torch.long, device=device)[None, ...],
            max_new_tokens=max_length,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            eor_token_id=eor_token_id
        )
        # model.generate returns the conditioning ids followed by the new ones
        # (the prompt stripping in generate_text assumes the same layout).
        # Decoding the whole sequence would re-append everything already in
        # `output`, so keep only the freshly generated ids.
        new_ids = continuation[0].tolist()[len(seed_ids):]
        hit_eor = eor_token_id in new_ids
        if hit_eor:
            new_ids = new_ids[:new_ids.index(eor_token_id)]
        continuation_text = decode(new_ids)
        output += continuation_text
        if hit_eor or not continuation_text.strip():
            # Either the response is finished, or the model made no progress,
            # which would otherwise loop forever.
            break
    return output
# Text generation function for Gradio interface
def _cut_at_end_marker(token_ids, eor_token_id):
    """Truncate ``token_ids`` before the first end-of-response marker.

    The marker is ``eor_token_id`` itself or, if that id is absent, the first
    token whose decoded text starts with ``'<E'`` (a similar, mis-tokenized
    marker). Returns ``(kept_ids, found_marker)``; the marker is not kept.
    """
    if eor_token_id in token_ids:
        return token_ids[:token_ids.index(eor_token_id)], True
    for i, token in enumerate(token_ids):
        if decode([token]).startswith('<E'):
            return token_ids[:i], True
    return token_ids, False


def generate_text(prompt, num_samples, max_new_tokens, temperature, top_k, top_p, repetition_penalty, eor_token_id):
    """Sample ``num_samples`` responses to ``prompt`` and join them for display.

    Each sample is drawn with ``model.generate`` from the encoded prompt. The
    echoed prompt ids are dropped and the response is cut at the first
    end-of-response marker. Responses that ended without a marker are extended
    with ``ensure_complete_output``. Every response is finally capped at
    ``max_new_tokens`` words by ``truncate_output``.

    Args:
        prompt: User prompt text.
        num_samples: Number of independent samples to draw.
        max_new_tokens: Token budget per ``model.generate`` call, and the word
            cap for each displayed response.
        temperature, top_k, top_p, repetition_penalty: Sampling controls
            forwarded to ``model.generate``.
        eor_token_id: Id of the end-of-response token; the marker itself is
            not included in the returned text.

    Returns:
        The responses joined by blank lines ('' for an empty prompt).
    """
    # Gradio number widgets may deliver floats. These values are used as a
    # loop count, a slice bound, a top-k size and a token id, so coerce them.
    num_samples = int(num_samples)
    max_new_tokens = int(max_new_tokens)
    top_k = int(top_k) if top_k is not None else None
    eor_token_id = int(eor_token_id) if eor_token_id is not None else None

    start_ids = encode(prompt)
    if not start_ids:
        # An empty prompt would hand the model a zero-length sequence.
        return ''

    outputs = []
    with torch.no_grad():
        with ctx:
            initial_prompt = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]
            for _ in range(num_samples):
                y = model.generate(
                    initial_prompt,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    top_k=top_k,
                    top_p=top_p,
                    repetition_penalty=repetition_penalty,
                    eor_token_id=eor_token_id
                )
                # Drop the echoed prompt by position. A text-level
                # str.replace(prompt, '') would also delete any later repeat of
                # the prompt inside the response.
                new_ids = y[0].tolist()[len(start_ids):]
                new_ids, finished = _cut_at_end_marker(new_ids, eor_token_id)
                output = decode(new_ids).strip()
                if not finished:
                    # Only responses the model did not explicitly end are
                    # extended; continuing past a marker appends unrelated text.
                    output = ensure_complete_output(output, prompt, max_new_tokens, temperature, top_k, top_p, repetition_penalty, eor_token_id)
                outputs.append(truncate_output(output, max_new_tokens))
    return '\n\n'.join(outputs)
# Create a Gradio interface.
# Uses the top-level component classes with `value=`. The legacy `gr.inputs.*`
# namespace and its `default=` argument were removed in Gradio 4, where they
# raise AttributeError as soon as this module is imported.
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter your prompt here...", value="Write a short story about a boy:"),
        gr.Slider(minimum=1, maximum=10, step=1, value=1, label="Number of Samples"),
        gr.Slider(minimum=10, maximum=200, step=1, value=75, label="Max New Tokens"),
        gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.8, label="Temperature"),
        gr.Slider(minimum=1, maximum=100, step=1, value=50, label="Top-k"),
        gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.85, label="Top-p"),
        gr.Slider(minimum=1.0, maximum=2.0, step=0.1, value=1.1, label="Repetition Penalty"),
        # precision=0 makes Gradio pass the token id as an int, not a float.
        gr.Number(value=100264, label="End-of-Response Token ID", precision=0),
    ],
    outputs="text",
    title="GPT Text Generator",
    description="Generate text based on a prompt using a trained GPT model.",
    # Examples fill only the prompt box; the other inputs keep their values.
    examples=[
        ["Write a short story about a boy:"],
        ["Explain the theory of relativity:"],
        ["What is the meaning of life?"],
    ],
)
# Launch the Gradio app when run as a script (e.g. `python app.py`). Importing
# the module (tests, tooling) then does not start a blocking web server.
if __name__ == "__main__":
    demo.launch()