# Source: 1bit_llama3_instruct_xmad_qa_batch/backups/app_unquantized_backup.py
# Uploaded by Aston-xMAD — commit 9382e3f ("init commit", verified)
import html
import json
import os
import time

import torch
import gradio as gr
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
# Environment variables
os.environ.update(
    {
        # "0" turns off the Hugging Face tokenizers library's internal parallelism.
        "TOKENIZERS_PARALLELISM": "0",
        # Let PyTorch's CUDA caching allocator expand existing segments
        # (documented by PyTorch as a way to reduce fragmentation).
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
    }
)
# Load model and tokenizer
# Short dtype names accepted by load_model_and_tokenizer.
_DTYPES = {
    "bf16": torch.bfloat16,
    "fp16": torch.float16,
    "fp32": torch.float32,
}

# Most recently loaded (model, tokenizer) pair, keyed by (model_name, dtype).
# At most one entry is kept so that switching models releases the previous
# one instead of accumulating weights across requests.
_MODEL_CACHE = {}


def load_model_and_tokenizer(model_name, dtype):
    """Load (or reuse) a causal LM and its tokenizer, ready for padded chat generation.

    The result is cached, so repeated calls with the same arguments (e.g. one
    per UI request) do not reload the weights from disk every time.

    Args:
        model_name: Hugging Face Hub id or local path.
        dtype: "bf16", "fp16" or "fp32" select the matching torch dtype; any
            other value is forwarded unchanged to ``from_pretrained`` as
            ``torch_dtype``.

    Returns:
        ``(model, tokenizer)``. The tokenizer has a pad token and pads on the
        left, and ``model.config.pad_token_id`` matches that pad token.
    """
    cache_key = (model_name, dtype)
    cached = _MODEL_CACHE.get(cache_key)
    if cached is not None:
        return cached
    # Drop the reference to any previously cached model before loading a new
    # one, so its memory can be reclaimed.
    _MODEL_CACHE.clear()

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Only add a pad token when the tokenizer has none; an existing pad token
    # is kept rather than being replaced by a brand-new embedding row.
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({"pad_token": "<PAD>"})
    # Decoder-only generation needs the padding on the left so every prompt
    # ends right where generation starts.
    tokenizer.padding_side = "left"

    config = AutoConfig.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        config=config,
        torch_dtype=_DTYPES.get(dtype, dtype),
        device_map="auto",
    )
    # Grow the embedding matrix if the added pad token is outside the vocab.
    if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
        model.resize_token_embeddings(len(tokenizer))
    model.config.pad_token_id = tokenizer.pad_token_id

    _MODEL_CACHE[cache_key] = (model, tokenizer)
    return model, tokenizer
# Format response
def format_response(dialog, response):
    """Return a new dialog with the assistant's reply appended as the last turn.

    The input dialog list is not modified; the result is a shallow copy with
    one extra ``{"role": "assistant", "content": response}`` message.
    """
    assistant_turn = {"role": "assistant", "content": response}
    return [*dialog, assistant_turn]
# Load questions
def load_questions(prompts_path, num_questions, custom_question):
    """Load chat dialogs from a JSON file, optionally prepending a custom question.

    Args:
        prompts_path: Path to a UTF-8 JSON file holding a list of dialogs
            (each dialog is a list of ``{"role": ..., "content": ...}`` messages).
        num_questions: Maximum number of dialogs to return. It is cast to
            ``int`` because a UI slider may deliver a float such as ``20.0``,
            which would make slicing raise ``TypeError``. Negative values
            yield an empty list instead of slicing from the end. ``None``
            returns every dialog.
        custom_question: Optional user question. When it contains
            non-whitespace text, it is inserted as the first dialog and
            counts toward ``num_questions``.

    Returns:
        A list of at most ``num_questions`` dialogs.
    """
    with open(prompts_path, "r", encoding="utf-8") as file:
        dialogs = json.load(file)
    if custom_question and custom_question.strip():
        dialogs.insert(0, [{"role": "user", "content": custom_question}])
    if num_questions is None:
        return dialogs
    return dialogs[:max(0, int(num_questions))]
# Inference
def _stop_token_ids(model, tokenizer):
    """Merge the model's configured EOS id(s) with the tokenizer's EOS id.

    Passing ``eos_token_id`` to ``generate`` replaces the value stored in
    ``model.generation_config``. Chat models may list several stop tokens there
    (presumably including an end-of-turn token for Llama 3 Instruct — verify
    against the model's generation_config.json), so both sources are kept.
    """
    generation_config = getattr(model, "generation_config", None)
    configured = getattr(generation_config, "eos_token_id", None)
    if configured is None:
        candidates = []
    elif isinstance(configured, int):
        candidates = [configured]
    else:
        candidates = list(configured)
    if tokenizer.eos_token_id is not None:
        candidates.append(tokenizer.eos_token_id)
    # dict.fromkeys removes duplicates while preserving order.
    return list(dict.fromkeys(candidates))


def _generation_kwargs(model, tokenizer, num_new_tokens, temperature):
    """Build the keyword arguments shared by every ``model.generate`` call."""
    kwargs = {
        "max_new_tokens": int(num_new_tokens),
        "pad_token_id": tokenizer.pad_token_id,
    }
    stop_ids = _stop_token_ids(model, tokenizer)
    if stop_ids:
        kwargs["eos_token_id"] = stop_ids
    if float(temperature) > 0:
        kwargs["do_sample"] = True
        kwargs["temperature"] = float(temperature)
    else:
        # transformers rejects a non-positive temperature when sampling;
        # treat 0 as a request for greedy decoding instead of crashing.
        kwargs["do_sample"] = False
    return kwargs


def infer(model_name, dialogs, num_new_tokens, temperature, dtype, batch_size=1):
    """Generate one assistant reply for every dialog.

    Args:
        model_name: Hugging Face model id or local path passed to
            ``load_model_and_tokenizer``.
        dialogs: List of chat dialogs; each is a list of
            ``{"role": ..., "content": ...}`` messages.
        num_new_tokens: Maximum number of tokens generated per reply.
        temperature: Sampling temperature. Values <= 0 select greedy decoding.
        dtype: Forwarded to ``load_model_and_tokenizer`` (e.g. "fp16").
        batch_size: Number of dialogs sent to ``generate`` at once. Defaults
            to 1 (one prompt per call).

    Returns:
        ``{"Responses": [...]}``, where each entry is the input dialog with
        the generated assistant turn appended.
    """
    model, tokenizer = load_model_and_tokenizer(model_name, dtype)
    prompts = [
        tokenizer.apply_chat_template(dialog, tokenize=False, add_generation_prompt=True)
        for dialog in dialogs
    ]
    generation_kwargs = _generation_kwargs(model, tokenizer, num_new_tokens, temperature)
    step = max(1, int(batch_size))

    responses = []
    for start in range(0, len(prompts), step):
        batch_dialogs = dialogs[start:start + step]
        encoded = tokenizer(
            prompts[start:start + step],
            padding=True,
            truncation=False,
            # The chat template already contains the special tokens (e.g.
            # BOS); adding them again would prepend a duplicate BOS.
            add_special_tokens=False,
            return_tensors="pt",
        )
        input_ids = encoded["input_ids"].to(model.device)
        attention_mask = encoded["attention_mask"].to(model.device)
        with torch.no_grad():
            output_tokens = model.generate(
                input_ids,
                attention_mask=attention_mask,
                **generation_kwargs,
            )
        # generate() returns prompt + continuation. With left padding, every
        # row's prompt fills the first input_ids.shape[1] positions, so slicing
        # there keeps only the newly generated reply.
        new_tokens = output_tokens[:, input_ids.shape[1]:]
        replies = tokenizer.batch_decode(new_tokens, skip_special_tokens=True)
        responses.extend(
            format_response(dialog, reply)
            for dialog, reply in zip(batch_dialogs, replies)
        )
        torch.cuda.empty_cache()

    return {"Responses": responses}
# Demo function
def demo(num_new_tokens, temperature, num_questions, custom_question):
    """Gradio callback: load the prompt dialogs and run inference on them.

    Reads dialogs from ``chats_sys_none.json`` (optionally prefixed by the
    custom question) and answers them with Meta-Llama-3-8B-Instruct in fp16.
    """
    selected_dialogs = load_questions("chats_sys_none.json", num_questions, custom_question)
    return infer(
        "NousResearch/Meta-Llama-3-8B-Instruct",
        selected_dialogs,
        num_new_tokens,
        temperature,
        "fp16",
    )
# Load JSON data
# Read the prompts file once at import time so the "Show JSON" tab can display
# it. An explicit UTF-8 encoding avoids depending on the platform's locale
# default when the prompts contain non-ASCII text.
with open("chats_sys_none.json", "r", encoding="utf-8") as file:
    json_data = json.load(file)
json_data_str = json.dumps(json_data, indent=2)
# Show JSON function
def show_json():
    """Return the pretty-printed prompts JSON that was loaded at import time."""
    pretty_json = json_data_str
    return pretty_json
# Gradio interface
def _json_preview_html():
    """Return the prompts JSON as HTML-escaped text inside a ``<pre>`` element.

    Escaping keeps characters such as ``<`` and ``&`` in prompts from being
    interpreted as markup by ``gr.HTML``. The same wrapper is used for the
    initial value and the button callback, so the formatting survives a click.
    """
    return "<pre>{}</pre>".format(html.escape(show_json()))


# Main tab: run inference over the selected questions.
interface = gr.Interface(
    fn=demo,
    inputs=[
        gr.Slider(label="Number of New Tokens", minimum=1, maximum=1024, step=1, value=512),
        gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, step=0.1, value=0.4),
        gr.Slider(minimum=20, maximum=100, step=1, label="Number of Questions", value=20),
        gr.Textbox(label="Custom Question", placeholder="Type your custom question here..."),
    ],
    outputs=[
        gr.JSON(label="Responses"),
    ],
    title="LLM Inference Demo",
    description="A demo for running LLM inference using Gradio and Hugging Face.",
    live=False,
)

# Secondary tab: display the raw prompts file.
json_interface = gr.Interface(
    fn=_json_preview_html,
    inputs=[],
    outputs=[
        gr.HTML(_json_preview_html()),
    ],
    live=False,
)

app = gr.Blocks()
with app:
    with gr.Tab("LLM Inference Demo"):
        interface.render()
    with gr.Tab("Show JSON"):
        json_interface.render()

if __name__ == "__main__":
    app.launch()