import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr

# Load the model and tokenizer.
# Note: Llama 3.1 ships a fast BPE tokenizer, so AutoTokenizer is required;
# the SentencePiece-based LlamaTokenizer used for Llama 2 cannot load it.
MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B"  # Update this if using a custom LLaMA model
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16,  # Use float16 for better performance
    device_map="auto"           # Automatically place onto available GPU(s); requires accelerate
)
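
# Llama checkpoints generally ship without a pad token; reusing EOS avoids
# tokenizer/generate warnings. (A common workaround, not part of the original app.)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token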
# Define a function for generating responses
def generate_response(prompt):
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=512,  # Budget for the reply itself, independent of prompt length
            do_sample=True,      # Required for temperature/top_p to take effect
            temperature=0.7,     # Adjust creativity level
            top_p=0.95,          # Top-p (nucleus) sampling
            num_return_sequences=1
        )
    # Decode only the newly generated tokens, not the echoed prompt
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return response
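
# Quick sanity check from a Python shell (illustrative prompt; bypasses the Gradio UI):
#   >>> generate_response("Explain nucleus sampling in one sentence.")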
# Gradio UI
iface = gr.Interface(
    fn=generate_response,
    inputs=gr.Textbox(lines=3, placeholder="Enter your prompt here..."),
    outputs=gr.Textbox(label="LLaMA Response"),
    title="LLaMA 3.1 8B Chatbot",
    description="An interactive demo of the LLaMA 3.1 8B model using Hugging Face Spaces."
)
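
# Optional tweak (not part of the original app): on Spaces, queuing requests
# helps when generation is slow and several users hit the demo at once:
# iface.queue()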
# Launch the app
if __name__ == "__main__":
    iface.launch()
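
# Assumed requirements.txt for this Space (package names only; pin versions as needed):
#   torch
#   transformers
#   accelerate   # needed for device_map="auto"
#   gradio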