# yr / app.py
# Author: rudr4sarkar
# Last commit: 58367cb "Update app.py" (verified)
import streamlit as st
import torch
from transformers import AutoModelForCausalLM, LlamaTokenizer
from peft import PeftModel
import gc
@st.cache_resource
def load_model():
    """Load the vulnerability-detection model, its tokenizer and the device.

    Decorated with ``st.cache_resource`` so the expensive load is reused
    across Streamlit reruns instead of being repeated on every interaction.

    On CUDA the checkpoint is loaded 4-bit NF4-quantized (bfloat16 compute)
    and wrapped with ``PeftModel.from_pretrained``. On CPU it is loaded in
    full float32 precision without quantization.

    Returns:
        tuple: ``(model, tokenizer, device)`` where ``device`` is the string
        ``"cuda"`` when a CUDA device is available and ``"cpu"`` otherwise.
    """
    # AutoTokenizer is not part of the module-level transformers imports,
    # so bring it into scope here, like the other local transformers imports.
    from transformers import AutoTokenizer

    model_name = "peterxyz/detect-llama-34b"
    # use_fast=False requests the Python (non-Rust) tokenizer implementation.
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    st.info(f"Using device: {device}")

    if device == "cuda":
        # Release cached GPU memory and collect garbage before the large load.
        torch.cuda.empty_cache()
        gc.collect()

        from transformers import BitsAndBytesConfig

        nf4_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        base_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=nf4_config,
            device_map="auto",
            trust_remote_code=True,
        )
        model = PeftModel.from_pretrained(base_model, model_name)
    else:
        # CPU: full float32 precision, no 4-bit quantization.
        # NOTE(review): unlike the CUDA branch, this model is not wrapped in
        # PeftModel; confirm whether adapter weights should also apply on CPU.
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float32,
            device_map={"": device},
            low_cpu_mem_usage=True,
            trust_remote_code=True,
        )
    return model, tokenizer, device
def analyze_contract(
    contract_code,
    model,
    tokenizer,
    device,
    *,
    max_new_tokens=1024,
    temperature=0.7,
    max_input_tokens=2048,
):
    """Ask the model to identify vulnerabilities in a smart contract.

    Args:
        contract_code: Solidity source to analyze.
        model: causal language model exposing ``generate``.
        tokenizer: tokenizer matching ``model``. If it has no ``pad_token``,
            its ``eos_token`` is assigned as ``pad_token`` (this mutates the
            tokenizer that was passed in).
        device: device the encoded prompt is moved to (e.g. ``"cuda"``).
        max_new_tokens: maximum number of tokens to generate, not counting
            the prompt.
        temperature: sampling temperature; must be > 0 because sampling is
            enabled.
        max_input_tokens: the encoded prompt is truncated to this length.

    Returns:
        str: the decoded model continuation, without the echoed prompt.
    """
    prompt = f"{contract_code}\n\nidentify vulnerability of this code given above"

    # generate() below needs a pad token id; reuse EOS when none is defined.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # NOTE(review): truncation removes tokens from the tokenizer's truncation
    # side (right by default for HF tokenizers - confirm), which for very long
    # contracts would cut off the trailing instruction.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_input_tokens,
    ).to(device)

    outputs = model.generate(
        **inputs,
        # max_new_tokens (unlike max_length) does not count the prompt, so a
        # long prompt still leaves room for the answer.
        max_new_tokens=max_new_tokens,
        # temperature only takes effect when sampling is enabled.
        do_sample=True,
        temperature=temperature,
        num_return_sequences=1,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    # For a causal LM, generate() returns prompt + continuation; keep only
    # the newly generated tokens so the contract is not echoed back.
    prompt_length = len(inputs["input_ids"][0])
    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
# Set page config: browser-tab title and icon, wide layout.
# (The icon/title emoji was previously mis-encoded as "πŸ”", the UTF-8 bytes
# of the magnifying-glass emoji decoded with a Windows code page.)
st.set_page_config(
    page_title="Smart Contract Vulnerability Detector",
    page_icon="🔍",
    layout="wide",
)

# Main app
st.title("🔍 Smart Contract Vulnerability Detector")
st.markdown("""
This app analyzes Solidity smart contracts for potential vulnerabilities using a fine-tuned LLaMA model.
Simply paste your smart contract code below and click 'Analyze'.
""")
# First run of a browser session: mark the model as not yet loaded.
if "model_loaded" not in st.session_state:
    st.session_state["model_loaded"] = False

# Populate this session's model, tokenizer and device once. On any failure,
# report the error and halt the current script run.
if not st.session_state["model_loaded"]:
    try:
        with st.spinner("Loading model... This might take a few minutes..."):
            loaded_model, loaded_tokenizer, loaded_device = load_model()
            st.session_state["model"] = loaded_model
            st.session_state["tokenizer"] = loaded_tokenizer
            st.session_state["device"] = loaded_device
            st.session_state["model_loaded"] = True
            st.success("Model loaded successfully!")
    except Exception as exc:
        st.error(f"Error loading model: {exc}")
        st.stop()
# Create the main interface.

# Example contract inserted by the "Load Sample Contract" button.
SAMPLE_CONTRACT = """pragma solidity ^0.5.0;
contract ModifierEntrancy {
mapping (address => uint) public tokenBalance;
string constant name = "Nu Token";
Bank bank;
constructor() public{
bank = new Bank();
}
function airDrop() hasNoBalance supportsToken public{
tokenBalance[msg.sender] += 20;
}
modifier supportsToken() {
require(keccak256(abi.encodePacked("Nu Token")) == bank.supportsToken());
_;
}
modifier hasNoBalance {
require(tokenBalance[msg.sender] == 0);
_;
}
}
contract Bank{
function supportsToken() external returns(bytes32) {
return keccak256(abi.encodePacked("Nu Token"));
}
}"""


def _load_sample_contract():
    """Place SAMPLE_CONTRACT into the contract text area's session-state slot.

    Used as the sample button's ``on_click`` callback. Callbacks run before
    the script reruns, i.e. before the keyed text area is created again.
    Streamlit does not allow a keyed widget's value to be changed after the
    widget has been created in the current run, so setting it here is safe.
    Clicking the button already triggers a rerun, so no explicit rerun call
    is needed.
    """
    st.session_state.contract_code = SAMPLE_CONTRACT


# key="contract_code" binds the widget value to st.session_state.contract_code,
# so the sample callback can populate the text area.
contract_code = st.text_area(
    "Paste your Solidity contract code here:",
    height=300,
    placeholder="pragma solidity ^0.5.0;\n\ncontract YourContract {\n // Your code here\n}",
    key="contract_code",
)

col1, col2 = st.columns([1, 4])
with col1:
    analyze_button = st.button("Analyze Contract", type="primary")
with col2:
    load_sample = st.button("Load Sample Contract", on_click=_load_sample_contract)
# Analysis section: run the model when "Analyze Contract" is clicked and the
# text area contains something other than whitespace.
has_contract_code = bool(contract_code and contract_code.strip())
if analyze_button and has_contract_code:
    try:
        with st.spinner("Analyzing contract..."):
            analysis = analyze_contract(
                contract_code,
                st.session_state.model,
                st.session_state.tokenizer,
                st.session_state.device,
            )
    except Exception as e:
        st.error(f"An error occurred during analysis: {e}")
        st.markdown("**Debug Information:**")
        # st.exception shows the exception with its traceback, instead of
        # repeating the message already shown by st.error above.
        st.exception(e)
    else:
        st.subheader("Analysis Results")
        # Expandable section holding the model's answer.
        with st.expander("View Full Analysis", expanded=True):
            st.markdown(analysis)
elif analyze_button:
    st.warning("Please enter some contract code to analyze.")
# Footer: a horizontal rule, then centred credits rendered as raw HTML
# (static markup only, no user input involved).
_FOOTER_HTML = """
<div style='text-align: center'>
<p>Built with Streamlit and Hugging Face Transformers</p>
<p>Model: peterxyz/detect-llama-34b</p>
</div>
"""
st.markdown("---")
st.markdown(_FOOTER_HTML, unsafe_allow_html=True)