newaimodel / app.py
saintyboy's picture
Update app.py
b6e3bcc verified
raw
history blame
5.52 kB
import os
import torch
import tiktoken
from model import GPT, GPTConfig
import pickle
import string
import gradio as gr
from contextlib import nullcontext
# Model and tokenizer setup: inference target and autocast configuration.
device = 'cpu'
# Coarse device family; only this decides whether autocast is used below.
device_type = 'cuda' if 'cuda' in device else 'cpu'
# Pick bfloat16 only on a non-CPU device that reports support for it; the
# short-circuit keeps the CUDA capability probe from running on CPU.
if device != 'cpu' and torch.cuda.is_bf16_supported():
    dtype = 'bfloat16'
else:
    dtype = 'float16'
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
# On CPU generation runs under a no-op context, so ptdtype is not applied there.
if device_type == 'cpu':
    ctx = nullcontext()
else:
    ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
# Start from the stock cl100k_base encoding and extend it with the two extra
# special tokens this app registers.
cl100k_base = tiktoken.get_encoding("cl100k_base")
_extra_special_tokens = {
    "<EOR>": 100264,
    "<SPECIAL2>": 100265,
}
enc = tiktoken.Encoding(
    name="cl100k_im",
    pat_str=cl100k_base._pat_str,
    mergeable_ranks=cl100k_base._mergeable_ranks,
    # Base special tokens first, then the additions (later keys win on clash).
    special_tokens={**cl100k_base._special_tokens, **_extra_special_tokens},
)
# Load the model object from the checkpoint file.
# NOTE(security): torch.load unpickles the file, so only load checkpoints
# that come from a trusted source.
# NOTE(review): recent PyTorch releases changed the default of
# torch.load(weights_only=...); confirm a fully pickled module still loads
# with the pinned torch version.
model_save_path = 'tuned_ckpt.pt'
if not os.path.exists(model_save_path):
    raise FileNotFoundError(f"Model file {model_save_path} not found")
model = torch.load(model_save_path, map_location=device)
# Switch to inference mode and place the model on the target device before
# handing it to the compiler.
model.eval()
model.to(device)
model = torch.compile(model)
# Function to encode and decode using the tokenizer
def encode(text):
    """Tokenize ``text`` with the module tokenizer ``enc``.

    Only the ``<EOR>`` special token is allowed to appear in the input text.
    """
    permitted_specials = {"<EOR>"}
    return enc.encode(text, allowed_special=permitted_specials)
def decode(tokens):
    """Turn a sequence of token ids back into text using the module tokenizer ``enc``."""
    text = enc.decode(tokens)
    return text
# Cap generated text at a maximum number of whitespace-separated words.
def truncate_output(text, token_limit):
    """Return ``text`` limited to ``token_limit`` whitespace-separated words.

    Despite the parameter name, the limit is counted in words produced by
    ``str.split()``, not in tokenizer tokens. Text within the limit is
    returned unchanged; longer text keeps its first ``token_limit`` words,
    re-joined with single spaces, followed by ``'...'``.
    """
    words = text.split()
    if len(words) <= token_limit:
        return text
    kept = words[:token_limit]
    return ' '.join(kept) + '...'
def ensure_complete_output(output, context, max_length, temperature, top_k, top_p, repetition_penalty, eor_token_id):
    """Extend ``output`` until it looks finished or reaches ``max_length`` words.

    The text is considered finished when it already contains the ``<EOR>``
    marker, ends with ``'.'``, or has at least ``max_length``
    whitespace-separated words. Otherwise the model is asked to continue it,
    and only the newly generated part is appended.

    Args:
        output: Text generated so far.
        context: Original prompt; used as the seed only when ``output`` is empty.
        max_length: Word budget for the result, also passed to the model as
            ``max_new_tokens`` for each continuation round.
        temperature, top_k, top_p, repetition_penalty: Sampling settings
            forwarded unchanged to ``model.generate``.
        eor_token_id: Token id that marks the end of a response.

    Returns:
        The (possibly extended) text, without any ``<EOR>`` marker.
    """
    eor_marker = "<EOR>"
    if eor_marker in output:
        # The response was already terminated; feeding it back to the model
        # would only re-generate after the end marker.
        return output.split(eor_marker)[0]

    while len(output.split()) < max_length and not output.endswith('.'):
        # An empty seed would give the model nothing to condition on, so fall
        # back to the prompt in that case.
        seed_ids = encode(output or context)
        if not seed_ids:
            break
        generated = model.generate(
            torch.tensor(seed_ids, dtype=torch.long, device=device)[None, ...],
            max_new_tokens=max_length,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            eor_token_id=eor_token_id
        )
        new_ids = generated[0].tolist()
        # generate_text strips the prompt from generate()'s result, so the
        # returned sequence appears to start with the seed. Drop that prefix
        # so the existing text is not appended to itself.
        if new_ids[:len(seed_ids)] == seed_ids:
            new_ids = new_ids[len(seed_ids):]
        if not new_ids:
            # No progress was made; stop instead of looping forever.
            break
        if eor_token_id in new_ids:
            # Keep only what precedes the end-of-response token and stop.
            output += decode(new_ids[:new_ids.index(eor_token_id)])
            break
        output += decode(new_ids)
    return output
# Text generation function for Gradio interface
def generate_text(prompt, num_samples, max_new_tokens, temperature, top_k, top_p, repetition_penalty, eor_token_id):
    """Generate ``num_samples`` completions of ``prompt``, separated by blank lines.

    Each sample has the prompt removed and is cut at the end-of-response
    token (or at the first token whose text starts with ``'<E'``). Samples
    that did not reach the end-of-response token are passed through
    ``ensure_complete_output``. Every sample is finally capped by
    ``truncate_output``.

    Args:
        prompt: Text to condition the model on.
        num_samples: How many completions to produce.
        max_new_tokens: Generation budget per sample; also the word cap
            applied to the final text.
        temperature, top_p, repetition_penalty: Sampling settings forwarded
            unchanged to ``model.generate``.
        top_k: Top-k sampling cutoff forwarded to ``model.generate``.
        eor_token_id: Token id that marks the end of a response.

    Returns:
        The completions joined with ``'\\n\\n'``, or ``''`` when the prompt
        encodes to no tokens.
    """
    # UI number widgets can deliver floats (e.g. 100264.0); counts and token
    # ids must be integers for range() and for the sampler.
    num_samples = int(num_samples)
    max_new_tokens = int(max_new_tokens)
    top_k = int(top_k) if top_k is not None else None
    if eor_token_id is not None:
        eor_token_id = int(eor_token_id)

    start_ids = encode(prompt)
    if not start_ids:
        # Nothing to condition the model on.
        return ''

    prompt_len = len(start_ids)
    outputs = []
    with torch.no_grad(), ctx:
        initial_prompt = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]
        for _ in range(num_samples):
            y = model.generate(
                initial_prompt,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                eor_token_id=eor_token_id
            )
            output_ids = y[0].tolist()

            # Remove the prompt by token position rather than by text search,
            # so later repetitions of the prompt text are preserved and the
            # marker scans below never match inside the prompt.
            prompt_removed = output_ids[:prompt_len] == start_ids
            if prompt_removed:
                output_ids = output_ids[prompt_len:]

            finished = eor_token_id in output_ids
            if finished:
                # Drop the end-of-response token itself so the literal marker
                # does not leak into the text shown to the user.
                output_ids = output_ids[:output_ids.index(eor_token_id)]
            else:
                # Fall back to cutting at the first token that looks like the
                # start of an end marker.
                marker_index = next(
                    (i for i, token in enumerate(output_ids) if decode([token]).startswith('<E')),
                    None,
                )
                if marker_index is not None:
                    output_ids = output_ids[:marker_index]

            output = decode(output_ids)
            if not prompt_removed:
                # The sequence did not start with the prompt tokens; fall back
                # to removing the prompt text.
                output = output.replace(prompt, '')
            output = output.strip()

            if not finished:
                # Only unfinished samples need a continuation; a sample that
                # hit the end-of-response token is already complete.
                output = ensure_complete_output(output, prompt, max_new_tokens, temperature, top_k, top_p, repetition_penalty, eor_token_id)
            outputs.append(truncate_output(output, max_new_tokens))
    return '\n\n'.join(outputs)
# Build the Gradio interface: one widget per generate_text parameter, in order.
# NOTE(review): `gr.inputs.*` with `default=` is the legacy Gradio component
# namespace; confirm it exists in the pinned Gradio version.
_input_components = [
    gr.inputs.Textbox(lines=2, placeholder="Enter your prompt here...", default="Write a short story about a boy:"),
    gr.inputs.Slider(minimum=1, maximum=10, step=1, default=1, label="Number of Samples"),
    gr.inputs.Slider(minimum=10, maximum=200, step=1, default=75, label="Max New Tokens"),
    gr.inputs.Slider(minimum=0.1, maximum=1.0, step=0.1, default=0.8, label="Temperature"),
    gr.inputs.Slider(minimum=1, maximum=100, step=1, default=50, label="Top-k"),
    gr.inputs.Slider(minimum=0.1, maximum=1.0, step=0.1, default=0.85, label="Top-p"),
    gr.inputs.Slider(minimum=1.0, maximum=2.0, step=0.1, default=1.1, label="Repetition Penalty"),
    gr.inputs.Number(default=100264, label="End-of-Response Token ID"),
]

# Each example row supplies only the prompt field.
_example_prompts = [
    ["Write a short story about a boy:"],
    ["Explain the theory of relativity:"],
    ["What is the meaning of life?"],
]

demo = gr.Interface(
    fn=generate_text,
    inputs=_input_components,
    outputs="text",
    title="GPT Text Generator",
    description="Generate text based on a prompt using a trained GPT model.",
    examples=_example_prompts,
)

# Start the Gradio app.
demo.launch()