import os
import torch
import tiktoken
from model import GPT, GPTConfig
import pickle
import string
import gradio as gr
from contextlib import nullcontext

# Model and Tokenizer setup
device = 'cpu'
dtype = 'bfloat16' if device != 'cpu' and torch.cuda.is_bf16_supported() else 'float16'
device_type = 'cuda' if 'cuda' in device else 'cpu'
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

# Assumed ChatML-style special tokens: the ids 100264/100265 and the encoding name
# "cl100k_im" match tiktoken's documented example, but the exact strings are an
# assumption; substitute the tokens your model was actually trained with.
EOR_TOKEN = "<|im_end|>"  # assumed end-of-response marker string

cl100k_base = tiktoken.get_encoding("cl100k_base")
enc = tiktoken.Encoding(
    name="cl100k_im",
    pat_str=cl100k_base._pat_str,
    mergeable_ranks=cl100k_base._mergeable_ranks,
    special_tokens={
        **cl100k_base._special_tokens,
        "<|im_start|>": 100264,
        "<|im_end|>": 100265,
    }
)

# Load model from checkpoint
model_save_path = 'latest_model.pt'
if os.path.exists(model_save_path):
    model = torch.load(model_save_path, map_location=device)
else:
    raise FileNotFoundError(f"Model file {model_save_path} not found")

model.eval()
model.to(device)
model = torch.compile(model)

# Encode and decode helpers for the tokenizer
def encode(text):
    return enc.encode(text, allowed_special={EOR_TOKEN})

def decode(tokens):
    return enc.decode(tokens)

# Truncate output to a word limit (a rough stand-in for a token limit)
def truncate_output(text, token_limit):
    tokens = text.split()
    if len(tokens) > token_limit:
        return ' '.join(tokens[:token_limit]) + '...'
    return text

# Keep generating until the output reaches max_length words, ends with a period,
# or the model emits the end-of-response token
def ensure_complete_output(output, context, max_length, temperature, top_k, top_p, repetition_penalty, eor_token_id):
    while len(output.split()) < max_length and not output.endswith('.'):
        continuation = model.generate(
            torch.tensor(encode(output), dtype=torch.long, device=device)[None, ...],
            max_new_tokens=max_length,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            eor_token_id=eor_token_id
        )
        continuation_text = decode(continuation[0].tolist())
        if eor_token_id in continuation[0].tolist():
            # Cut the continuation at the end-of-response marker and stop
            continuation_text = continuation_text.split(EOR_TOKEN)[0]
            output += continuation_text
            break
        else:
            output += continuation_text
            if len(output.split()) >= max_length:
                break
    return output

# Text generation function for the Gradio interface
def generate_text(prompt, num_samples, max_new_tokens, temperature, top_k, top_p, repetition_penalty, eor_token_id):
    with torch.no_grad():
        with ctx:
            start_ids = encode(prompt)
            initial_prompt = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]
            outputs = []
            for _ in range(num_samples):
                y = model.generate(
                    initial_prompt,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    top_k=top_k,
                    top_p=top_p,
                    repetition_penalty=repetition_penalty,
                    eor_token_id=eor_token_id
                )
                # Filter out tokens after the end-of-response token or similar markers
                output_ids = y[0].tolist()
                if eor_token_id in output_ids:
                    output_ids = output_ids[:output_ids.index(eor_token_id) + 1]  # Include EOR token
                else:
                    # Check for similar markers like '