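"""Streamlit app that answers questions about Indian law using the
unsloth/llama-3-8b-bnb-4bit base model with the AIModels24/Indian_Constitution
LoRA adapter applied on top.

Run locally with: streamlit run app.py
"""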
import torch
import streamlit as st
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM
# An earlier revision loaded "AIModels24/Indian_Constitution" directly with
# AutoModelForCausalLM, optionally in 4-bit via BitsAndBytesConfig; the
# current code loads the base model and applies the LoRA adapter instead.
@st.cache_resource
def load_model_and_tokenizer():
    # Base model and the LoRA adapter fine-tuned on top of it
    base_model_name = "unsloth/llama-3-8b-bnb-4bit"
    adapter_name = "AIModels24/Indian_Constitution"

    # Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(base_model_name)

    # Load the base model; device_map="auto" lets accelerate place the
    # 4-bit weights on the available GPU
    model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        device_map="auto",
        low_cpu_mem_usage=True,
        use_cache=True,
    )

    # Apply the LoRA adapter to the base model
    model = PeftModel.from_pretrained(model, adapter_name)
    return model, tokenizer
# Load model and tokenizer using the function
model, tokenizer = load_model_and_tokenizer()
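# Optional speed-up (a sketch, not used by this app): PEFT can fold the LoRA
# weights into the base model via merge_and_unload(), removing the adapter
# indirection at inference time. Merging may not work with bitsandbytes
# 4-bit base weights, so it is shown commented out:
# model = model.merge_and_unload()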
# Alpaca-style prompt template
alpaca_prompt = "### Instruction:\n{}\n\n### Response:\n"
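# For example, alpaca_prompt.format("What does Article 21 provide?") yields:
#   ### Instruction:
#   What does Article 21 provide?
#
#   ### Response:
#
# (an illustrative question; the model completes the text after "### Response:")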
# Streamlit User Interface
st.title("भारतीय कानून व्यवस्था")  # "Indian Legal System"
st.subheader("AI-powered responses for legal questions in Indian law")
# Input text box for user question
instruction = st.text_area("Enter your question:", placeholder="Ask a question about Indian law...")
# Generate response button
if st.button("Generate Response"):
    if instruction.strip():
        with st.spinner("Generating response..."):
            # Build the prompt and move the tensors to the model's device
            inputs = tokenizer(
                [alpaca_prompt.format(instruction)],
                return_tensors="pt"
            ).to(model.device)

            # Generate the response (greedy decoding)
            outputs = model.generate(**inputs, max_new_tokens=150, use_cache=True)
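            # Sampling variant (an illustrative alternative, not what the app
            # uses): do_sample with temperature/top_p trades determinism for
            # more varied answers.
            # outputs = model.generate(
            #     **inputs, max_new_tokens=150, use_cache=True,
            #     do_sample=True, temperature=0.7, top_p=0.9,
            # )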
            response = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]

            # Keep only the text after the "### Response:" marker
            response_cleaned = response.split("### Response:\n")[-1].strip()

            # Display the response
            st.success("Response:")
            st.write(response_cleaned)
    else:
        st.error("Please enter a question to generate a response.")