yr / app.py
rudr4sarkar's picture
Update app.py
75599b5 verified
raw
history blame
3.91 kB
import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
import bitsandbytes as bnb
import gc
@st.cache_resource
def load_model(model_name="peterxyz/detect-llama-34b"):
    """Load the tokenizer and the NF4-quantized causal LM with its PEFT adapter.

    Wrapped in ``st.cache_resource`` so the returned objects are cached by
    Streamlit instead of being reloaded on every script rerun.

    Args:
        model_name: Hugging Face Hub repo id used for the tokenizer, the base
            model weights and the PEFT adapter. Defaults to the app's model.

    Returns:
        tuple: ``(model, tokenizer)``, where ``model`` is a ``PeftModel``
        wrapping the 4-bit quantized ``AutoModelForCausalLM``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # 4-bit NF4 weights with double quantization; compute dtype is bfloat16.
    nf4_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    # Collect unreachable Python objects first so any CUDA tensors they still
    # hold are freed; only after that can empty_cache() release that memory.
    gc.collect()
    torch.cuda.empty_cache()

    base_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=nf4_config,
        device_map="auto",  # let accelerate place the layers on available devices
    )
    # NOTE(review): the same repo id is used for both the base weights and the
    # PEFT adapter -- confirm the repo really provides both.
    model = PeftModel.from_pretrained(base_model, model_name)
    return model, tokenizer
def analyze_contract(contract_code, model, tokenizer, *, max_new_tokens=512, temperature=0.7):
    """Ask the fine-tuned model to identify vulnerabilities in a contract.

    Args:
        contract_code: Solidity source code to analyze.
        model: Causal LM exposing ``generate`` and ``device`` (in this app the
            PEFT-wrapped model returned by ``load_model``).
        tokenizer: Tokenizer matching ``model``.
        max_new_tokens: Maximum number of *generated* tokens. Unlike
            ``max_length``, this does not count the prompt, so long contracts
            still leave room for an answer.
        temperature: Sampling temperature. Values ``<= 0`` select greedy
            decoding.

    Returns:
        str: The generated analysis text, without the echoed prompt.
    """
    # Presumably the prompt format the model was fine-tuned on -- keep as is.
    prompt = f"{contract_code}\n\nidentify vulnerability of this code given above"
    # Send inputs to the model's (first) device rather than hard-coding "cuda".
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_len = inputs["input_ids"].shape[-1]

    # Llama-style tokenizers may define no pad token; fall back to EOS so
    # generate() does not have to pick one itself.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "num_return_sequences": 1,
        "pad_token_id": pad_token_id,
    }
    if temperature > 0:
        # temperature only takes effect when sampling is explicitly enabled.
        gen_kwargs.update(do_sample=True, temperature=temperature)
    else:
        gen_kwargs["do_sample"] = False

    outputs = model.generate(**inputs, **gen_kwargs)
    # A decoder-only model returns prompt + continuation; decode only the
    # continuation so the results don't repeat the whole contract.
    generated = outputs[0][prompt_len:]
    return tokenizer.decode(generated, skip_special_tokens=True).strip()
# Set page config. Kept ahead of every other rendering call; older Streamlit
# versions require it to be the first Streamlit command on the page.
# (The icon/title previously held the mojibake "πŸ”" instead of the emoji.)
st.set_page_config(
    page_title="Smart Contract Vulnerability Detector",
    page_icon="🔍",
    layout="wide",
)

# Main app
st.title("🔍 Smart Contract Vulnerability Detector")
st.markdown("""
This app analyzes Solidity smart contracts for potential vulnerabilities using a fine-tuned LLaMA model.
Simply paste your smart contract code below and click 'Analyze Contract'.
""")
# Load the model/tokenizer into this session's state the first time the
# script runs for the session; later reruns reuse the stored handles.
if "model" not in st.session_state:
    with st.spinner("Loading model... This might take a few minutes..."):
        loaded_model, loaded_tokenizer = load_model()
        st.session_state.model = loaded_model
        st.session_state.tokenizer = loaded_tokenizer
    st.success("Model loaded successfully!")
# Create the main interface.

# Example contract inserted into the editor by the "Load Sample Contract" button.
SAMPLE_CONTRACT = """pragma solidity ^0.5.0;
contract ModifierEntrancy {
mapping (address => uint) public tokenBalance;
string constant name = "Nu Token";
Bank bank;
constructor() public{
bank = new Bank();
}
function airDrop() hasNoBalance supportsToken public{
tokenBalance[msg.sender] += 20;
}
modifier supportsToken() {
require(keccak256(abi.encodePacked("Nu Token")) == bank.supportsToken());
_;
}
modifier hasNoBalance {
require(tokenBalance[msg.sender] == 0);
_;
}
}
contract Bank{
function supportsToken() external returns(bytes32) {
return keccak256(abi.encodePacked("Nu Token"));
}
}"""


def _load_sample_contract():
    """Put SAMPLE_CONTRACT into the contract text area.

    Runs as the sample button's ``on_click`` callback. Streamlit raises an
    error when a widget's session-state key is written after that widget was
    created in the current run. Callbacks execute before the script reruns,
    so writing the key here is allowed and the text area picks up the value.
    """
    st.session_state.contract_code = SAMPLE_CONTRACT


# Keyed by "contract_code" so its contents can be set through session state.
# (Previously the sample button wrote that key, but the un-keyed text area
# ignored it, so the sample never appeared.)
contract_code = st.text_area(
    "Paste your Solidity contract code here:",
    height=300,
    placeholder="pragma solidity ^0.5.0;\n\ncontract YourContract {\n // Your code here\n}",
    key="contract_code",
)
analyze_button = st.button("Analyze Contract", type="primary")
# Sample contract button: a click already triggers a rerun, so no explicit
# (and now-removed) st.experimental_rerun() call is needed.
st.button("Load Sample Contract", on_click=_load_sample_contract)
# Analysis section: runs on the rerun triggered by clicking "Analyze Contract".
# Whitespace-only input is treated like empty input, so the model is not
# invoked for nothing.
if analyze_button and contract_code.strip():
    with st.spinner('Analyzing contract...'):
        try:
            analysis = analyze_contract(
                contract_code,
                st.session_state.model,
                st.session_state.tokenizer,
            )
        except Exception as e:  # top-level UI boundary: report instead of crashing
            st.error(f"An error occurred during analysis: {e}")
        else:
            st.subheader("Analysis Results")
            st.markdown(analysis)
elif analyze_button:
    st.warning("Please enter some contract code to analyze.")
# Footer: a horizontal rule followed by centered attribution text. The text is
# raw HTML, which is why unsafe_allow_html=True is passed.
st.markdown("---")
_FOOTER_HTML = "\n".join(
    [
        "",
        "<div style='text-align: center'>",
        "<p>Built with Streamlit and Hugging Face Transformers</p>",
        "<p>Model: peterxyz/detect-llama-34b</p>",
        "</div>",
        "",
    ]
)
st.markdown(_FOOTER_HTML, unsafe_allow_html=True)