# newaimodel/app.py
import os
import torch
import tiktoken
from model import GPT, GPTConfig
import gradio as gr
from contextlib import nullcontext
# Model and Tokenizer setup
device = 'cpu'
dtype = 'bfloat16' if device != 'cpu' and torch.cuda.is_bf16_supported() else 'float16'
device_type = 'cuda' if 'cuda' in device else 'cpu'
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
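# Extend cl100k_base with the custom special tokens used during fine-tuning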
cl100k_base = tiktoken.get_encoding("cl100k_base")
enc = tiktoken.Encoding(
name="cl100k_im",
pat_str=cl100k_base._pat_str,
mergeable_ranks=cl100k_base._mergeable_ranks,
special_tokens={
**cl100k_base._special_tokens,
"<EOR>": 100264,
"<SPECIAL2>": 100265,
}
)
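# Sanity check (an assumption of this app: the model was fine-tuned with this
# exact vocab extension): "<EOR>" must round-trip to its reserved id.
assert enc.encode("<EOR>", allowed_special={"<EOR>"}) == [100264]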
# Load model from checkpoint
model_save_path = 'tuned_ckpt.pt'
if os.path.exists(model_save_path):
    checkpoint = torch.load(model_save_path, map_location=device)
    gptconf = GPTConfig(**checkpoint['model_args'])
    model = GPT(gptconf)
    state_dict = checkpoint['model']
    # Checkpoints saved from a torch.compile'd model prefix every key with
    # '_orig_mod.'; strip it (a no-op otherwise) so the weights load cleanly
    unwanted_prefix = '_orig_mod.'
    for k in list(state_dict.keys()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)
else:
raise FileNotFoundError(f"Model file {model_save_path} not found")
model.eval()
model.to(device)
model = torch.compile(model)
# Helper functions to encode and decode with the extended tokenizer
def encode(text):
return enc.encode(text, allowed_special={"<EOR>"})
def decode(tokens):
return enc.decode(tokens)
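# e.g. encode("done<EOR>")[-1] == 100264, because <EOR> is whitelisted above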
# Truncate decoded text to at most token_limit tokens
def truncate_output(text, token_limit):
    tokens = encode(text)
    if len(tokens) > token_limit:
        return decode(tokens[:token_limit]) + '...'
    return text
def ensure_complete_output(output, context, max_length, temperature, top_k, top_p, repetition_penalty, eor_token_id):
    """Keep sampling until the text spans max_length tokens or ends in a period.

    `context` (the original prompt) is accepted for signature compatibility;
    the continuation is conditioned on the generated text alone.
    """
    while len(encode(output)) < max_length and not output.endswith('.'):
        input_ids = encode(output)
        continuation = model.generate(
            torch.tensor(input_ids, dtype=torch.long, device=device)[None, ...],
            max_new_tokens=max_length,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            eor_token_id=eor_token_id
        )
        # generate() echoes its conditioning tokens, so slice them off before
        # decoding to avoid duplicating the text accumulated so far
        new_ids = continuation[0].tolist()[len(input_ids):]
        if eor_token_id in new_ids:
            # Stop cleanly at the end-of-response token
            output += decode(new_ids[:new_ids.index(eor_token_id)])
            break
        output += decode(new_ids)
        if len(encode(output)) >= max_length:
            break
    return output
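# Example (hypothetical values): ensure_complete_output("Once upon a time",
# prompt, 75, 0.8, 50, 0.85, 1.1, 100264) keeps sampling until the text spans
# 75 tokens or ends with a period.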
# Text generation function for the Gradio interface
def generate_text(prompt, num_samples, max_new_tokens, temperature, top_k, top_p, repetition_penalty, eor_token_id):
    # Gradio sliders and number boxes can deliver floats, so normalise the
    # integer-valued inputs before they reach the model
    num_samples = 1 if num_samples is None else int(num_samples)
    max_new_tokens = int(max_new_tokens)
    top_k = int(top_k)
    eor_token_id = int(eor_token_id)
    with torch.no_grad():
        with ctx:
            start_ids = encode(prompt)
            initial_prompt = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]
outputs = []
for _ in range(num_samples):
y = model.generate(
initial_prompt,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
eor_token_id=eor_token_id
)
                # generate() echoes the prompt tokens, so keep only the newly
                # generated ids, then cut at the end-of-response token
                output_ids = y[0].tolist()[len(start_ids):]
                if eor_token_id in output_ids:
                    output_ids = output_ids[:output_ids.index(eor_token_id)]
                else:
                    # Fall back to scanning for a partially decoded marker such
                    # as '<E' left behind when the model almost emits <EOR>
                    try:
                        eor_index = next(i for i, token in enumerate(output_ids) if decode([token]).startswith('<E'))
                        output_ids = output_ids[:eor_index]
                    except StopIteration:
                        pass
                output = decode(output_ids).strip()
output = ensure_complete_output(output, prompt, max_new_tokens, temperature, top_k, top_p, repetition_penalty, eor_token_id)
truncated_output = truncate_output(output, max_new_tokens)
outputs.append(truncated_output)
return '\n\n'.join(outputs)
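# Quick local smoke test (commented out so the Space doesn't generate at
# import time); the values mirror the slider defaults below:
# print(generate_text("Write a short story about a boy:", 1, 75, 0.8, 50, 0.85, 1.1, 100264))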
# Create a Gradio interface
demo = gr.Interface(
fn=generate_text,
inputs=[
gr.Textbox(lines=2, placeholder="Enter your prompt here...", value="Write a short story about a boy:"),
gr.Slider(minimum=1, maximum=10, step=1, value=1, label="Number of Samples"),
gr.Slider(minimum=10, maximum=200, step=1, value=75, label="Max New Tokens"),
gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.8, label="Temperature"),
gr.Slider(minimum=1, maximum=100, step=1, value=50, label="Top-k"),
gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.85, label="Top-p"),
gr.Slider(minimum=1.0, maximum=2.0, step=0.1, value=1.1, label="Repetition Penalty"),
        gr.Number(value=100264, precision=0, label="End-of-Response Token ID")
],
outputs="text",
title="GPT Text Generator",
description="Generate text based on a prompt using a trained GPT model.",
    examples=[
        # Each example row must supply one value per input component
        ["Write a short story about a boy:", 1, 75, 0.8, 50, 0.85, 1.1, 100264],
        ["Explain the theory of relativity:", 1, 75, 0.8, 50, 0.85, 1.1, 100264],
        ["What is the meaning of life?", 1, 75, 0.8, 50, 0.85, 1.1, 100264]
    ]
)
# Launch the Gradio app
if __name__ == "__main__":
    demo.launch()