# yr / app.py (5.02 kB)
# Last commit: "Update app.py" by rudr4sarkar (7f94d85, verified)
import streamlit as st
import torch
from transformers import AutoModelForCausalLM, LlamaTokenizer, BitsAndBytesConfig
from peft import PeftModel
import bitsandbytes as bnb
import gc
@st.cache_resource
def load_model():
    """Load the quantized detector model and its tokenizer.

    The base causal LM is loaded in 4-bit NF4 precision (double
    quantization, bfloat16 compute) and then wrapped with the PEFT adapter
    published under the same repository id.  The function is registered with
    ``st.cache_resource``.

    Returns:
        tuple: ``(model, tokenizer)`` where ``model`` is the ``PeftModel``
        and ``tokenizer`` is the ``LlamaTokenizer`` for the repository.
    """
    repo_id = "peterxyz/detect-llama-34b"

    # LlamaTokenizer is requested explicitly rather than AutoTokenizer.
    llama_tokenizer = LlamaTokenizer.from_pretrained(repo_id)

    quantization = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    # Release cached CUDA blocks and unreachable Python objects before the
    # large allocation below.
    torch.cuda.empty_cache()
    gc.collect()

    # NOTE(review): trust_remote_code=True lets code shipped in the hub
    # repository run locally; it is not a safety measure. Confirm that it is
    # actually required for this checkpoint.
    base_model = AutoModelForCausalLM.from_pretrained(
        repo_id,
        quantization_config=quantization,
        device_map="auto",
        trust_remote_code=True,
    )

    # NOTE(review): the adapter is loaded from the same repo id as the base
    # model -- verify the repository really hosts both.
    adapted_model = PeftModel.from_pretrained(base_model, repo_id)
    return adapted_model, llama_tokenizer
def analyze_contract(contract_code, model, tokenizer, max_new_tokens=1024):
    """Ask the model to identify vulnerabilities in a Solidity contract.

    Args:
        contract_code: Source text of the contract to analyze.
        model: Object exposing ``generate(**kwargs)``; its ``device``
            attribute (when present) decides where the inputs are placed.
        tokenizer: Tokenizer matching ``model``.  Its ``pad_token`` is set
            to ``eos_token`` in place when no pad token is defined.
        max_new_tokens: Upper bound on the number of generated tokens,
            independent of the prompt length.

    Returns:
        str: The decoded text generated after the prompt (the prompt itself
        is not echoed back), with special tokens removed.
    """
    prompt = f"{contract_code}\n\nidentify vulnerability of this code given above"

    # Add padding token if needed
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Follow the model's own placement instead of assuming "cuda": the model
    # is loaded with device_map="auto", so the first device is not
    # guaranteed to be the default CUDA device.
    target_device = getattr(model, "device", "cuda")

    # NOTE(review): right-side truncation at 2048 tokens can cut off the
    # instruction suffix for very long contracts.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=2048,
    ).to(target_device)

    outputs = model.generate(
        **inputs,
        # max_length would count the prompt tokens too, leaving no room to
        # generate once the prompt reaches that length.
        max_new_tokens=max_new_tokens,
        # temperature only takes effect when sampling is enabled.
        do_sample=True,
        temperature=0.7,
        num_return_sequences=1,
        pad_token_id=tokenizer.pad_token_id,  # Explicitly set pad token ID
        eos_token_id=tokenizer.eos_token_id,  # Explicitly set EOS token ID
    )

    # The generated sequence starts with the prompt tokens; decode only what
    # was produced after them so the result is the analysis alone.
    prompt_length = len(inputs["input_ids"][0])
    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
# Set page config
# NOTE(review): the icon arrived as a mis-decoded byte sequence; it is
# restored here as the magnifying-glass emoji -- confirm against the
# original file.
st.set_page_config(
    page_title="Smart Contract Vulnerability Detector",
    page_icon="🔍",
    layout="wide",
)

# Main app: title and a short usage blurb.
st.title("🔍 Smart Contract Vulnerability Detector")
st.markdown("""
This app analyzes Solidity smart contracts for potential vulnerabilities using a fine-tuned LLaMA model.
Simply paste your smart contract code below and click 'Analyze'.
""")
# Track in session state whether this session already holds the model.
if "model_loaded" not in st.session_state:
    st.session_state.model_loaded = False

# Load the model and tokenizer once per session; halt the script on failure
# because nothing below can work without them.
if not st.session_state.model_loaded:
    try:
        with st.spinner('Loading model... This might take a few minutes...'):
            loaded_model, loaded_tokenizer = load_model()
            st.session_state.model = loaded_model
            st.session_state.tokenizer = loaded_tokenizer
            st.session_state.model_loaded = True
        st.success('Model loaded successfully!')
    except Exception as load_error:
        st.error(f"Error loading model: {load_error}")
        st.stop()
# Sample contract offered by the "Load Sample Contract" button.
SAMPLE_CONTRACT = """pragma solidity ^0.5.0;
contract ModifierEntrancy {
mapping (address => uint) public tokenBalance;
string constant name = "Nu Token";
Bank bank;
constructor() public{
bank = new Bank();
}
function airDrop() hasNoBalance supportsToken public{
tokenBalance[msg.sender] += 20;
}
modifier supportsToken() {
require(keccak256(abi.encodePacked("Nu Token")) == bank.supportsToken());
_;
}
modifier hasNoBalance {
require(tokenBalance[msg.sender] == 0);
_;
}
}
contract Bank{
function supportsToken() external returns(bytes32) {
return keccak256(abi.encodePacked("Nu Token"));
}
}"""


def _load_sample_contract():
    """Button callback: place the sample contract into the text area.

    Writing to the widget's session-state key from a callback is what makes
    the sample show up: callbacks run before the script reruns, i.e. before
    the text area is instantiated again.
    """
    st.session_state.contract_code = SAMPLE_CONTRACT


# Create the main interface
# The text area is bound to st.session_state["contract_code"] through `key`
# so that the sample loader can fill it in.
contract_code = st.text_area(
    "Paste your Solidity contract code here:",
    height=300,
    placeholder="pragma solidity ^0.5.0;\n\ncontract YourContract {\n // Your code here\n}",
    key="contract_code",
)

col1, col2 = st.columns([1, 4])
with col1:
    analyze_button = st.button("Analyze Contract", type="primary")
with col2:
    # Sample contract button: the click triggers a rerun on its own, so no
    # explicit st.experimental_rerun() call is needed.
    load_sample = st.button("Load Sample Contract", on_click=_load_sample_contract)
# Analysis section: runs only on the rerun triggered by the Analyze button.
if analyze_button:
    if not contract_code:
        st.warning("Please enter some contract code to analyze.")
    else:
        try:
            with st.spinner('Analyzing contract...'):
                analysis = analyze_contract(
                    contract_code,
                    st.session_state.model,
                    st.session_state.tokenizer,
                )
            st.subheader("Analysis Results")
            # Show the model output in an expander that starts opened.
            with st.expander("View Full Analysis", expanded=True):
                st.markdown(analysis)
        except Exception as analysis_error:
            # Report the failure and repeat the raw message as debug output.
            detail = str(analysis_error)
            st.error(f"An error occurred during analysis: {detail}")
            st.markdown("**Debug Information:**")
            st.code(detail)
# Footer: a horizontal rule followed by centered attribution text.
st.markdown("---")
footer_html = """
<div style='text-align: center'>
<p>Built with Streamlit and Hugging Face Transformers</p>
<p>Model: peterxyz/detect-llama-34b</p>
</div>
"""
# Raw HTML is required for the centered <div>, hence unsafe_allow_html.
st.markdown(footer_html, unsafe_allow_html=True)