import json
import os
import random
import time

import torch
import gradio as gr
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

# Environment variables
os.environ["TOKENIZERS_PARALLELISM"] = "0"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Global variables to store the model and tokenizer
model = None
tokenizer = None


# Load model and tokenizer (cached in the module-level globals after the first call)
def load_model_and_tokenizer(model_name, dtype, kv_bits):
    global model, tokenizer

    if model is None or tokenizer is None:
        print("Loading model and tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Add a dedicated padding token so batched, left-padded inputs can be built
        special_tokens = {"pad_token": "<PAD>"}
        tokenizer.add_special_tokens(special_tokens)

        config = AutoConfig.from_pretrained(model_name)

        # Point the config at the quantizer codebook when a quantized KV cache is requested
        if kv_bits != "unquantized":
            quantizer_path = f"codebooks/{model_name.split('/')[-1]}_{kv_bits}bit.xmad"
            setattr(config, "quantizer_path", quantizer_path)

        if dtype == "bf16":
            dtype = torch.bfloat16
        elif dtype == "fp16":
            dtype = torch.float16
        elif dtype == "fp32":
            dtype = torch.float32

        model = AutoModelForCausalLM.from_pretrained(model_name, config=config, torch_dtype=dtype, device_map="auto")

        # Grow the embedding matrix if the added pad token pushed the vocabulary past its original size
        if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
            model.resize_token_embeddings(len(tokenizer))

        tokenizer.padding_side = "left"
        model.config.pad_token_id = tokenizer.pad_token_id

    return model, tokenizer


# Format response
def format_response(dialog, response):
    question = next((turn["content"] for turn in dialog if turn["role"] == "user"), "No question found")
    answer = response.split("assistant")[-1].strip()
    return {"question": question, "answer": answer}


# Load questions, putting any custom questions first and topping up to 60 from the prompt file
def load_questions(prompts_path, custom_questions):
    with open(prompts_path, "r") as file:
        dialogs = json.load(file)

    selected_dialogs = []

    if custom_questions:
        for question in custom_questions:
            if question.strip():
                custom_dialog = [{"role": "user", "content": question}]
                selected_dialogs.append(custom_dialog)

    num_questions = 60 - len(selected_dialogs)
    random.shuffle(dialogs)
    selected_dialogs.extend(dialogs[:num_questions])

    return selected_dialogs[:60]


# Inference
def infer(model_name, dialogs, num_new_tokens, temperature, dtype, kv_bits, progress=gr.Progress()):
    print("Starting inference...")
    model, tokenizer = load_model_and_tokenizer(model_name, dtype, kv_bits)

    batch_inputs = [
        tokenizer.apply_chat_template(dialog, tokenize=False, add_generation_prompt=True)
        for dialog in dialogs
    ]

    responses = []
    start_time = time.time()

    batch_size = 30  # Batch size for processing; adjust to fit the available GPU memory
    num_dialogs = len(dialogs)
    total_time = 0
    total_tokens = 0
    num_batches = (num_dialogs + batch_size - 1) // batch_size

    for batch_idx in range(num_batches):
        start_idx = batch_idx * batch_size
        end_idx = min(start_idx + batch_size, num_dialogs)
        batch = batch_inputs[start_idx:end_idx]

        encoded_inputs = tokenizer(batch, padding=True, truncation=False, return_tensors="pt")
        input_ids = encoded_inputs["input_ids"].to(model.device)
        attention_mask = encoded_inputs["attention_mask"].to(model.device)

        with torch.no_grad():
            torch.cuda.synchronize()
            batch_start_time = time.perf_counter()

            # Generate responses for this batch and time the full generation pass
            output_tokens = model.generate(
                input_ids,
                attention_mask=attention_mask,
                max_new_tokens=num_new_tokens,
                do_sample=True,
                temperature=temperature,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

            torch.cuda.synchronize()
            batch_end_time = time.perf_counter()

        batch_time = batch_end_time - batch_start_time
        total_time += batch_time
        total_tokens += output_tokens.numel()  # counts prompt, generated, and padding tokens

        # Approximate time to first token (TTFT) from the first batch only
        if batch_idx == 0:
            ttft = batch_time / input_ids.size(0)  # first batch's generation time averaged per sequence

        decoded_outputs = tokenizer.batch_decode(output_tokens, skip_special_tokens=True)

        for i, response in enumerate(decoded_outputs):
            original_dialog = dialogs[start_idx + i]
            formatted_response = format_response(original_dialog, response)
            responses.append(formatted_response)

        # Stream the responses accumulated so far back to the UI
        formatted_responses = "\n\n---\n\n".join(
            f"**Question**: {res['question']}\n\n**Answer**: {res['answer']}" for res in responses
        )
        yield formatted_responses

        progress((batch_idx + 1) / num_batches, desc="Processing batches")

    elapsed_time = time.time() - start_time
    tokens_per_second = total_tokens / total_time if total_time > 0 else 0
    print(f"Inference completed in {elapsed_time:.2f} seconds.")

    # Final yield carries the timing metrics alongside the formatted responses
    yield {
        "Time Taken (seconds)": elapsed_time,
        "Tokens per Second": tokens_per_second,
        "Time to First Token (TTFT, seconds)": ttft,
        "Formatted Responses": formatted_responses,
    }


# Demo function: wires the UI inputs to infer() and streams its results to the output components
def demo(num_new_tokens, temperature, custom_questions_text, kv_bits, progress=gr.Progress()):
    custom_questions = custom_questions_text.split("\n")
    print("Loading questions...")
    dialogs = load_questions("chats_sys_none.json", custom_questions)
    print(f"{len(dialogs)} questions loaded. Starting inference...")

    result_gen = infer("NousResearch/Meta-Llama-3-8B-Instruct", dialogs, num_new_tokens, temperature, "fp16", kv_bits, progress=progress)

    formatted_responses = ""
    for result in result_gen:
        if isinstance(result, str):
            # Intermediate yields only update the responses pane
            formatted_responses = result
            yield None, None, None, formatted_responses
        else:
            # Final yield also fills in the timing metrics
            time_taken = result["Time Taken (seconds)"]
            tokens_per_second = result["Tokens per Second"]
            ttft = result["Time to First Token (TTFT, seconds)"]
            formatted_responses = result["Formatted Responses"]
            yield time_taken, tokens_per_second, ttft, formatted_responses


# Load JSON data
with open("chats_sys_none.json", "r") as file:
    json_data = json.load(file)
json_data_str = json.dumps(json_data, indent=2)


# Show JSON function
def show_json():
    return json_data_str


# Gradio interface
app = gr.Blocks(css=".scrollable {height: 400px; overflow-y: auto; padding: 10px; border: 1px solid #ccc;}")

with app:
    with gr.Tab("LLM Inference Demo"):
        with gr.Row():
            with gr.Column():
                num_new_tokens = gr.Slider(label="Number of New Tokens", minimum=128, maximum=1024, step=128, value=512)
                temperature = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, step=0.1, value=0.4)
                kv_bits = gr.Dropdown(label="KV Bits", choices=["1", "2", "4", "unquantized"], value="1")
            with gr.Column():
                time_taken = gr.Number(label="Time Taken (seconds)")
                tokens_per_second = gr.Number(label="Tokens per Second")
                ttft = gr.Number(label="Time to First Token (TTFT, seconds)")

        with gr.Row():
            custom_questions_text = gr.Textbox(label="Custom Questions", placeholder="Type your custom questions here, one per line...", lines=5)

        with gr.Row():
            demo_btn = gr.Button("Run Inference")

        with gr.Row():
            formatted_responses = gr.Markdown(label="Formatted Responses")

        demo_btn.click(
            demo,
            inputs=[num_new_tokens, temperature, custom_questions_text, kv_bits],
            outputs=[time_taken, tokens_per_second, ttft, formatted_responses],
        )

    with gr.Tab("Show JSON"):
        # Render the prompt file inside the scrollable container styled above
        json_output = gr.HTML("<pre class='scrollable'>{}</pre>".format(json_data_str))
        json_interface = gr.Interface(fn=show_json, inputs=[], outputs=[json_output], live=False)
        json_interface.render()

if __name__ == "__main__":
    print("Loading model and tokenizer on startup...")
    load_model_and_tokenizer("NousResearch/Meta-Llama-3-8B-Instruct", "fp16", "1")
    print("Model and tokenizer loaded. Starting Gradio interface...")
    app.queue(default_concurrency_limit=5).launch()