# lab2/app.py: Gradio chat Space for a fine-tuned Llama 3.2 3B model
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
# Configure 4-bit quantization (NF4 with double quantization, fp16 compute)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)
# Load the base model once, quantized, letting accelerate handle device placement
base_model = AutoModelForCausalLM.from_pretrained(
    "unsloth/Llama-3.2-3B-Instruct-bnb-4bit",
    device_map="auto",
    torch_dtype=torch.float16,
    quantization_config=bnb_config,
)
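# Note: judging by the "-bnb-4bit" naming convention, this checkpoint already
# ships pre-quantized, so the explicit BitsAndBytesConfig above mainly pins the
# quantization settings explicitly (an assumption, not confirmed by the repo).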
model = PeftModel.from_pretrained(base_model, "emeses/lab2_model")
tokenizer = AutoTokenizer.from_pretrained("unsloth/Llama-3.2-3B-Instruct-bnb-4bit")
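# Llama tokenizers often ship without a pad token; falling back to EOS
# (a common convention, assumed acceptable here) keeps generate() from
# complaining when pad_token_id is requested below.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token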
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens=512,
    temperature=0.7,
    top_p=0.9,
):
    try:
        # Rebuild the prompt from prior turns so the model sees the conversation
        prompt = f"{system_message}\n\n"
        for user_msg, bot_msg in history:
            prompt += f"User: {user_msg}\nAssistant: {bot_msg}\n"
        prompt += f"User: {message}\nAssistant:"
        # Tokenize and move tensors to the model's device
        inputs = tokenizer(
            prompt, return_tensors="pt", truncation=True, max_length=2048
        ).to(model.device)
        # Generate a response; ** unpacks both input_ids and attention_mask
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
        )
        # Decode only the newly generated tokens, skipping the prompt
        response = tokenizer.decode(
            outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True
        )
        return response.strip()
    except Exception as e:
        return f"Error: {str(e)}"
# Create Gradio interface
iface = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            label="System Message",
            value="You are a helpful AI assistant.",
            lines=2,  # Room for multi-line system prompts
        ),
        gr.Slider(minimum=1, maximum=1024, value=512, step=1, label="Max Tokens"),
        gr.Slider(minimum=0, maximum=1, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0, maximum=1, value=0.9, step=0.1, label="Top P"),
    ],
    title="Chat with Fine-tuned Llama Model",
    description="A conversational AI powered by a fine-tuned Llama 3.2 3B model",
    retry_btn="Regenerate",   # Chat control labels (Gradio 4.x API)
    undo_btn="Delete Last",
    clear_btn="Clear Chat",
)
# Queue requests so concurrent users are served in order, then launch
iface.queue().launch(
    share=True,  # Ignored on Spaces; creates a public link when run locally
    server_name="0.0.0.0",
    server_port=7860,
    show_error=True,  # Surface exceptions in the UI for easier debugging
)