Spaces:
Runtime error
Runtime error
File size: 5,517 Bytes
4028f2b 7805bd4 4028f2b 7805bd4 4028f2b 7805bd4 4028f2b 7805bd4 4028f2b 7805bd4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
import os
import torch
import tiktoken
from model import GPT, GPTConfig
import pickle
import string
import gradio as gr
from contextlib import nullcontext
# Runtime configuration: target device, autocast precision and tokenizer.
device = 'cpu'
device_type = 'cuda' if 'cuda' in device else 'cpu'

# bfloat16 is only considered off-CPU; the CUDA capability query is skipped
# entirely when running on CPU.
if device != 'cpu' and torch.cuda.is_bf16_supported():
    dtype = 'bfloat16'
else:
    dtype = 'float16'
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]

# Autocast is only entered for non-CPU devices; on CPU the context is a no-op.
if device_type == 'cpu':
    ctx = nullcontext()
else:
    ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)

# Extend the stock cl100k_base encoding with two extra special tokens.
cl100k_base = tiktoken.get_encoding("cl100k_base")
enc = tiktoken.Encoding(
    name="cl100k_im",
    pat_str=cl100k_base._pat_str,
    mergeable_ranks=cl100k_base._mergeable_ranks,
    special_tokens=dict(
        cl100k_base._special_tokens,
        **{"<EOR>": 100264, "<SPECIAL2>": 100265},
    ),
)
# Load model from checkpoint
model_save_path = 'latest_model.pt'
if not os.path.exists(model_save_path):
    raise FileNotFoundError(f"Model file {model_save_path} not found")

# The loaded object is used directly as a module below (.eval(), .to(), compile),
# so the checkpoint is a pickled model object rather than a state_dict.
# PyTorch >= 2.6 defaults to weights_only=True, which rejects arbitrary pickled
# classes, so full unpickling has to be requested explicitly.
# SECURITY: full unpickling can execute arbitrary code embedded in the file;
# only load checkpoints that come from a trusted source.
model = torch.load(model_save_path, map_location=device, weights_only=False)
model.eval()
model.to(device)
model = torch.compile(model)
# Function to encode and decode using the tokenizer
def encode(text):
    """Tokenize ``text`` with the extended encoding.

    The literal marker "<EOR>" is the only special token accepted inside
    the input text.
    """
    permitted_specials = {"<EOR>"}
    return enc.encode(text, allowed_special=permitted_specials)
def decode(tokens):
    """Convert a sequence of token IDs back into text with the extended encoding."""
    text = enc.decode(tokens)
    return text
# Function to truncate output to token limit
def truncate_output(text, token_limit):
    """Cap ``text`` at ``token_limit`` whitespace-separated words.

    Text within the limit is returned untouched. Longer text is rebuilt from
    its first ``token_limit`` words joined by single spaces, followed by
    '...'.
    """
    words = text.split()
    if len(words) <= token_limit:
        return text
    kept = words[:token_limit]
    return ' '.join(kept) + '...'
def ensure_complete_output(output, context, max_length, temperature, top_k, top_p, repetition_penalty, eor_token_id):
    """Extend ``output`` until it ends with '.' or reaches ``max_length`` words.

    Each round feeds the current ``output`` back into ``model.generate`` and
    appends only the newly generated text. Generation stops early when the
    end-of-response token shows up among the new tokens; text from that
    token onwards is discarded.

    Args:
        output: Text produced so far.
        context: Accepted for interface compatibility; not used.
        max_length: Word budget for the result, also passed to the model as
            ``max_new_tokens`` for every round.
        temperature, top_k, top_p, repetition_penalty: Forwarded unchanged
            to ``model.generate``.
        eor_token_id: Token ID that marks the end of a response.

    Returns:
        ``output``, possibly extended with additional generated text.
    """
    if not output:
        # Nothing to condition the model on: an empty string encodes to zero
        # tokens, which the model presumably cannot continue from.
        return output

    while len(output.split()) < max_length and not output.endswith('.'):
        input_ids = encode(output)
        continuation = model.generate(
            torch.tensor(input_ids, dtype=torch.long, device=device)[None, ...],
            max_new_tokens=max_length,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            eor_token_id=eor_token_id
        )
        new_ids = continuation[0].tolist()
        # generate() appears to echo its input ahead of the new tokens
        # (generate_text strips the prompt from its result for that reason).
        # Drop the echoed prefix so existing text is not appended twice.
        if new_ids[:len(input_ids)] == input_ids:
            new_ids = new_ids[len(input_ids):]
        if not new_ids:
            # No progress was made; stop instead of looping forever.
            break
        if eor_token_id in new_ids:
            # Cut at the token ID rather than at the literal "<EOR>" text so
            # a non-default eor_token_id is honoured as well.
            output += decode(new_ids[:new_ids.index(eor_token_id)])
            break
        output += decode(new_ids)
    return output
# Text generation function for Gradio interface
def generate_text(prompt, num_samples, max_new_tokens, temperature, top_k, top_p, repetition_penalty, eor_token_id):
    """Generate ``num_samples`` completions of ``prompt`` for the Gradio UI.

    Args:
        prompt: Text the model is conditioned on.
        num_samples: Number of independent completions to produce.
        max_new_tokens: Generation budget per sample; also used as the word
            limit when the final text is truncated.
        temperature, top_p, repetition_penalty: Forwarded unchanged to
            ``model.generate``.
        top_k: Forwarded to ``model.generate`` (coerced to int when given).
        eor_token_id: Token ID that marks the end of a response.

    Returns:
        The completions joined by blank lines, or '' for an empty prompt.
    """
    if not prompt:
        # An empty prompt encodes to zero tokens, leaving the model nothing
        # to continue from.
        return ''

    # Gradio Number/Slider widgets may deliver floats; range() and token-ID
    # comparisons need real integers.
    num_samples = int(num_samples)
    max_new_tokens = int(max_new_tokens)
    top_k = int(top_k) if top_k is not None else None
    eor_token_id = int(eor_token_id) if eor_token_id is not None else None

    outputs = []
    with torch.no_grad(), ctx:
        start_ids = encode(prompt)
        initial_prompt = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]
        for _ in range(num_samples):
            y = model.generate(
                initial_prompt,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                eor_token_id=eor_token_id
            )
            output_ids = y[0].tolist()

            # Remove the echoed prompt by token prefix. Unlike a textual
            # replace this leaves later repetitions of the prompt text in
            # the generated part alone, and keeps the end-marker scans below
            # from matching inside the prompt itself.
            if output_ids[:len(start_ids)] == start_ids:
                output_ids = output_ids[len(start_ids):]

            hit_eor = eor_token_id in output_ids
            if hit_eor:
                # Keep everything up to and including the EOR token.
                output_ids = output_ids[:output_ids.index(eor_token_id) + 1]
            else:
                # Fall back to marker-like tokens: cut before the first
                # token whose decoded text starts with '<E'.
                marker_index = next(
                    (i for i, token in enumerate(output_ids) if decode([token]).startswith('<E')),
                    None,
                )
                if marker_index is not None:
                    output_ids = output_ids[:marker_index]

            output = decode(output_ids).strip()
            if not hit_eor:
                # A sample that reached EOR is already complete; only
                # unfinished samples are extended.
                output = ensure_complete_output(
                    output, prompt, max_new_tokens, temperature, top_k, top_p,
                    repetition_penalty, eor_token_id,
                )
            outputs.append(truncate_output(output, max_new_tokens))
    return '\n\n'.join(outputs)
# Create a Gradio interface
# Components are taken from the top-level gradio namespace with ``value=``:
# the ``gr.inputs`` namespace and its ``default=`` argument were deprecated in
# Gradio 3 and removed in Gradio 4.
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter your prompt here...", value="Write a short story about a boy:"),
        gr.Slider(minimum=1, maximum=10, step=1, value=1, label="Number of Samples"),
        gr.Slider(minimum=10, maximum=200, step=1, value=75, label="Max New Tokens"),
        gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.8, label="Temperature"),
        gr.Slider(minimum=1, maximum=100, step=1, value=50, label="Top-k"),
        gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.85, label="Top-p"),
        gr.Slider(minimum=1.0, maximum=2.0, step=0.1, value=1.1, label="Repetition Penalty"),
        # precision=0 makes the widget return an integer token ID.
        gr.Number(value=100264, precision=0, label="End-of-Response Token ID"),
    ],
    outputs="text",
    title="GPT Text Generator",
    description="Generate text based on a prompt using a trained GPT model.",
    examples=[
        ["Write a short story about a boy:"],
        ["Explain the theory of relativity:"],
        ["What is the meaning of life?"]
    ]
)

# Launch the Gradio app only when run as a script, so importing this module
# does not start a server.
if __name__ == "__main__":
    demo.launch()
|