newaimodel / app.py
saintyboy's picture
Update app.py
48efaf5 verified
raw
history blame
No virus
5.76 kB
import os
import pickle
import re
import string
from contextlib import nullcontext

import gradio as gr
import tiktoken
import torch

from model import GPT, GPTConfig
# Model and Tokenizer setup
# Everything below is derived from `device`. With the hard-coded 'cpu' the
# autocast context is a no-op; the CUDA branches only matter if `device` is
# changed to a CUDA device string.
device = 'cpu'
device_type = 'cuda' if 'cuda' in device else 'cpu'
if device != 'cpu' and torch.cuda.is_bf16_supported():
    dtype = 'bfloat16'
else:
    dtype = 'float16'
ptdtype = {
    'float32': torch.float32,
    'bfloat16': torch.bfloat16,
    'float16': torch.float16,
}[dtype]
if device_type == 'cpu':
    ctx = nullcontext()
else:
    ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
# Tokenizer: the stock cl100k_base BPE extended with two project-specific
# special tokens. "<EOR>" is the end-of-response marker used by the app.
cl100k_base = tiktoken.get_encoding("cl100k_base")
# NOTE(review): ids 100264/100265 are presumably chosen to sit outside
# cl100k_base's own special-token ids — confirm they do not collide.
_custom_special_tokens = {"<EOR>": 100264, "<SPECIAL2>": 100265}
merged_special_tokens = dict(cl100k_base._special_tokens)
merged_special_tokens.update(_custom_special_tokens)
enc = tiktoken.Encoding(
    name="cl100k_im",
    pat_str=cl100k_base._pat_str,
    mergeable_ranks=cl100k_base._mergeable_ranks,
    special_tokens=merged_special_tokens,
)
# Load model from checkpoint
model_save_path = 'tuned_ckpt.pt'
if not os.path.exists(model_save_path):
    raise FileNotFoundError(f"Model file {model_save_path} not found")
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
checkpoint = torch.load(model_save_path, map_location=device)
gptconf = GPTConfig(**checkpoint['model_args'])
model = GPT(gptconf)
# A state dict saved from a torch.compile()d model has every key prefixed with
# '_orig_mod.' (the attribute holding the wrapped module). Strip it in place so
# the keys match the plain GPT instance built above; uncompiled checkpoints
# pass through unchanged.
_COMPILED_PREFIX = '_orig_mod.'
state_dict = checkpoint['model']
for param_name in list(state_dict):
    if param_name.startswith(_COMPILED_PREFIX):
        state_dict[param_name[len(_COMPILED_PREFIX):]] = state_dict.pop(param_name)
model.load_state_dict(state_dict)
model.eval()
model.to(device)
# torch.compile returns a wrapper module; `model` is rebound to that wrapper.
model = torch.compile(model)
# Function to encode and decode using the tokenizer
def encode(text):
    """Tokenize *text* with the extended cl100k tokenizer.

    "<EOR>" appearing in *text* is encoded as the special end-of-response
    token. Any other special-token string (e.g. "<SPECIAL2>" or
    "<|endoftext|>") is encoded as ordinary text rather than making tiktoken
    raise ValueError, which it does by default (``disallowed_special="all"``).
    This matters because encode() is applied both to user prompts and to
    previously decoded model output, either of which may contain such markers.

    Args:
        text: The string to tokenize.

    Returns:
        list[int]: Token ids.
    """
    return enc.encode(text, allowed_special={"<EOR>"}, disallowed_special=())
def decode(tokens):
    """Convert a sequence of token ids back into text.

    Uses the module-level extended tokenizer, so special-token ids come back
    as their literal marker strings (e.g. "<EOR>").
    """
    text = enc.decode(tokens)
    return text
# Function to truncate output to token limit
def truncate_output(text, token_limit):
    """Cut *text* down to at most *token_limit* whitespace-separated words.

    Despite the parameter name, the unit here is words as produced by
    ``str.split()``, not tokenizer tokens.

    Args:
        text: The string to shorten.
        token_limit: Maximum number of words to keep. Integral floats (as UI
            widgets may return) are accepted; negative values count as 0.

    Returns:
        *text* unchanged when it has at most *token_limit* words. Otherwise
        the span from the first word through the last kept word, with its
        original spacing and line breaks intact, followed by ``'...'``.
    """
    limit = max(int(token_limit), 0)
    words = list(re.finditer(r'\S+', text))
    if len(words) <= limit:
        return text
    if limit == 0:
        return '...'
    # Slice the original string instead of re-joining with ' ' so paragraph
    # breaks in generated text survive truncation.
    return text[words[0].start():words[limit - 1].end()] + '...'
def ensure_complete_output(output, context, max_length, temperature, top_k, top_p, repetition_penalty, eor_token_id, max_rounds=8):
    """Keep sampling until *output* looks like a finished answer.

    *output* is extended in rounds of up to *max_length* new tokens. Sampling
    stops as soon as any of these holds:

    - *output* ends with '.', '!' or '?' (ignoring trailing whitespace);
    - *output* reaches *max_length* whitespace-separated words;
    - the model emits *eor_token_id*;
    - a round adds no visible text;
    - *max_rounds* rounds have run.

    Args:
        output: Text generated so far, without the prompt.
        context: The original prompt. It is used as conditioning text only
            when *output* is empty, because the model cannot sample from an
            empty context.
        max_length: Word budget for *output*. Also used as ``max_new_tokens``
            for each sampling round.
        temperature, top_k, top_p, repetition_penalty: Passed through to
            ``model.generate``.
        eor_token_id: End-of-response token id. Generation stops at it, and
            the token itself is not added to *output*.
        max_rounds: Upper bound on sampling rounds. It keeps a model that
            never produces punctuation or new words from looping forever.

    Returns:
        *output* with any continuation appended.
    """
    for _ in range(max_rounds):
        if len(output.split()) >= max_length or output.rstrip().endswith(('.', '!', '?')):
            break
        seed_ids = encode(output or context)
        if not seed_ids:
            break  # nothing to condition on
        with torch.no_grad():
            generated = model.generate(
                torch.tensor(seed_ids, dtype=torch.long, device=device)[None, ...],
                max_new_tokens=max_length,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                eor_token_id=eor_token_id
            )
        new_ids = generated[0].tolist()
        # model.generate is assumed to return the conditioning ids followed by
        # the new ids (NOTE(review): confirm in model.py). Drop the echoed
        # prefix; otherwise the existing text would be appended a second time.
        if new_ids[:len(seed_ids)] == seed_ids:
            new_ids = new_ids[len(seed_ids):]
        hit_eor = eor_token_id in new_ids
        if hit_eor:
            new_ids = new_ids[:new_ids.index(eor_token_id)]
        continuation_text = decode(new_ids)
        output += continuation_text
        if hit_eor or not continuation_text.strip():
            break
    return output
def _split_response(output_ids, prompt_ids, eor_token_id):
    """Extract the answer ids from a generated sequence.

    Returns a ``(answer_ids, ended)`` pair. *answer_ids* are the ids after
    the prompt and before any end-of-response marker. *ended* says whether
    such a marker was found.
    """
    # model.generate is assumed to echo the prompt ids at the front of its
    # result (NOTE(review): confirm in model.py). Remove them by token prefix
    # so that later repetitions of the prompt text inside the answer are kept.
    n_prompt = len(prompt_ids)
    if output_ids[:n_prompt] == prompt_ids:
        output_ids = output_ids[n_prompt:]
    if eor_token_id in output_ids:
        # Exclude the EOR token itself so "<EOR>" never reaches the UI.
        return output_ids[:output_ids.index(eor_token_id)], True
    # Fallback for markers similar to EOR: cut at the first token whose text
    # starts with '<E'.
    for index, token in enumerate(output_ids):
        if decode([token]).startswith('<E'):
            return output_ids[:index], True
    return output_ids, False


# Text generation function for Gradio interface
def generate_text(prompt, num_samples, max_new_tokens, temperature, top_k, top_p, repetition_penalty, eor_token_id):
    """Generate *num_samples* independent completions of *prompt*.

    For each sample, the prompt and everything from the end-of-response
    marker onward are removed. Answers that ended without a marker are
    extended by ``ensure_complete_output``. Every answer is then truncated to
    *max_new_tokens* words.

    Args:
        prompt: Non-empty prompt text.
        num_samples: Number of samples. ``None`` means 1. Integral floats
            are accepted.
        max_new_tokens: Token budget per generate call. Also the word limit
            for each returned sample.
        temperature, top_k, top_p, repetition_penalty, eor_token_id: Passed
            through to ``model.generate``.

    Returns:
        str: The samples joined by blank lines.

    Raises:
        ValueError: If *num_samples* is not an integer or *prompt* encodes to
            no tokens.
    """
    # Add input validation
    if num_samples is None:
        num_samples = 1
    elif isinstance(num_samples, float) and num_samples.is_integer():
        # UI widgets may deliver whole numbers as floats (e.g. 3.0).
        num_samples = int(num_samples)
    elif not isinstance(num_samples, int):
        raise ValueError("Number of samples must be an integer")
    start_ids = encode(prompt)
    if not start_ids:
        # The model cannot sample from an empty context.
        raise ValueError("Prompt must not be empty")
    outputs = []
    with torch.no_grad(), ctx:
        initial_prompt = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]
        for _ in range(num_samples):
            y = model.generate(
                initial_prompt,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                eor_token_id=eor_token_id
            )
            answer_ids, ended = _split_response(y[0].tolist(), start_ids, eor_token_id)
            output = decode(answer_ids).strip()
            if not ended:
                # The model did not signal the end of its answer; try to
                # finish the sentence.
                output = ensure_complete_output(output, prompt, max_new_tokens, temperature, top_k, top_p, repetition_penalty, eor_token_id)
            outputs.append(truncate_output(output, max_new_tokens))
    return '\n\n'.join(outputs)
# Create a Gradio interface.
# Input order must match generate_text's parameters: prompt, num_samples,
# max_new_tokens, temperature, top_k, top_p, repetition_penalty, eor_token_id.
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter your prompt here...", value="Write a short story about a boy:"),
        gr.Slider(minimum=1, maximum=10, step=1, value=1, label="Number of Samples"),
        gr.Slider(minimum=10, maximum=200, step=1, value=75, label="Max New Tokens"),
        gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.8, label="Temperature"),
        gr.Slider(minimum=1, maximum=100, step=1, value=50, label="Top-k"),
        gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.85, label="Top-p"),
        gr.Slider(minimum=1.0, maximum=2.0, step=0.1, value=1.1, label="Repetition Penalty"),
        # precision=0 makes Gradio round to and return an int. Without it the
        # value arrives as a float (100264.0), but a token id must be an int.
        gr.Number(value=100264, precision=0, label="End-of-Response Token ID")
    ],
    outputs="text",
    title="GPT Text Generator",
    description="Generate text based on a prompt using a trained GPT model.",
    # Examples fill only the prompt; the other inputs keep their current values.
    examples=[
        ["Write a short story about a boy:"],
        ["Explain the theory of relativity:"],
        ["What is the meaning of life?"]
    ]
)
# Launch the Gradio app
demo.launch()