# Spaces: Paused  (hosting-page status text captured above the source; not program code)
import gc

import streamlit as st
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer
def load_model():
    """Load the vulnerability-detection model and its tokenizer.

    Uses CUDA when ``torch.cuda.is_available()`` reports it, otherwise CPU.
    On CUDA the weights are loaded with 4-bit NF4 quantization (double
    quantization, bfloat16 compute) and then wrapped with ``PeftModel``.
    On CPU the weights are loaded in float32 without quantization.

    Returns:
        tuple: ``(model, tokenizer, device)`` where ``device`` is the string
        ``"cuda"`` or ``"cpu"``.
    """
    model_name = "peterxyz/detect-llama-34b"

    # AutoTokenizer picks the tokenizer class from the checkpoint's own
    # configuration; use_fast=False asks for the slow implementation.
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    st.info(f"Using device: {device}")

    # Free whatever can be freed before the large model allocation.
    if device == "cuda":
        torch.cuda.empty_cache()
    gc.collect()

    if device == "cuda":
        from transformers import BitsAndBytesConfig
        # Imported only so that a missing bitsandbytes install surfaces here
        # as an ImportError; the module itself is not referenced below.
        import bitsandbytes  # noqa: F401

        nf4_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        model_nf4 = AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=nf4_config,
            device_map="auto",
            trust_remote_code=True,
        )
        # NOTE(review): the adapter is loaded from the same repo id as the
        # base weights, and the CPU branch below applies no PeftModel wrapper
        # at all - confirm that this asymmetry is intended.
        model = PeftModel.from_pretrained(model_nf4, model_name)
    else:
        # For CPU, load in float32 and without 4-bit quantization.
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float32,
            device_map={"": device},
            low_cpu_mem_usage=True,
            trust_remote_code=True,
        )

    return model, tokenizer, device
def analyze_contract(contract_code, model, tokenizer, device):
    """Ask the model to identify vulnerabilities in a contract.

    Args:
        contract_code: Contract source text; the instruction sentence is
            appended after it to form the prompt.
        model: Object exposing ``generate(**kwargs)``.
        tokenizer: Callable tokenizer with ``decode``; when its ``pad_token``
            is ``None`` it is set in place to the EOS token.
        device: Device the encoded prompt is moved to before generation.

    Returns:
        str: The first generated sequence, decoded with special tokens
        skipped.
    """
    prompt = f"{contract_code}\n\nidentify vulnerability of this code given above"

    # Padding requires a pad token; fall back to EOS when none is defined.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # NOTE(review): the instruction sits at the end of the prompt, so if the
    # tokenizer truncates on the right a very long contract could lose it -
    # confirm the tokenizer's truncation side.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=2048,
    ).to(device)

    outputs = model.generate(
        **inputs,
        # Budget for *new* tokens only. A total-length cap (max_length) would
        # count the prompt too, and the prompt may be up to 2048 tokens.
        max_new_tokens=1024,
        # temperature only takes effect when sampling is enabled.
        do_sample=True,
        temperature=0.7,
        num_return_sequences=1,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Page identity: one emoji and one title string, shared by the browser-tab
# configuration and the on-page heading so the two cannot drift apart.
PAGE_ICON = "🔍"
PAGE_TITLE = "Smart Contract Vulnerability Detector"

# Set page config
st.set_page_config(
    page_title=PAGE_TITLE,
    page_icon=PAGE_ICON,
    layout="wide",
)

# Main app heading and short usage blurb
st.title(f"{PAGE_ICON} {PAGE_TITLE}")
st.markdown("""
This app analyzes Solidity smart contracts for potential vulnerabilities using a fine-tuned LLaMA model.
Simply paste your smart contract code below and click 'Analyze'.
""")
# Make sure the "loaded" flag exists before it is read below.
if 'model_loaded' not in st.session_state:
    st.session_state.model_loaded = False

# Load the model once and keep it, with its tokenizer and device, in
# st.session_state; on failure report the error and halt the script.
if not st.session_state.model_loaded:
    try:
        with st.spinner('Loading model... This might take a few minutes...'):
            loaded_model, loaded_tokenizer, loaded_device = load_model()
            st.session_state.model = loaded_model
            st.session_state.tokenizer = loaded_tokenizer
            st.session_state.device = loaded_device
            st.session_state.model_loaded = True
        st.success('Model loaded successfully!')
    except Exception as load_error:
        st.error(f"Error loading model: {load_error}")
        st.stop()
# Sample contract offered by the "Load Sample Contract" button.
SAMPLE_CONTRACT = """pragma solidity ^0.5.0;
contract ModifierEntrancy {
    mapping (address => uint) public tokenBalance;
    string constant name = "Nu Token";
    Bank bank;
    constructor() public{
        bank = new Bank();
    }
    function airDrop() hasNoBalance supportsToken public{
        tokenBalance[msg.sender] += 20;
    }
    modifier supportsToken() {
        require(keccak256(abi.encodePacked("Nu Token")) == bank.supportsToken());
        _;
    }
    modifier hasNoBalance {
        require(tokenBalance[msg.sender] == 0);
        _;
    }
}
contract Bank{
    function supportsToken() external returns(bytes32) {
        return keccak256(abi.encodePacked("Nu Token"));
    }
}"""


def _load_sample_contract():
    """Button callback: place the sample contract in the text area's state slot."""
    st.session_state.contract_code = SAMPLE_CONTRACT


# Create the main interface.
# The key ties the widget's value to st.session_state.contract_code, which is
# the slot the sample-loading callback writes to.
contract_code = st.text_area(
    "Paste your Solidity contract code here:",
    height=300,
    placeholder="pragma solidity ^0.5.0;\n\ncontract YourContract {\n // Your code here\n}",
    key="contract_code",
)

col1, col2 = st.columns([1, 4])
with col1:
    analyze_button = st.button("Analyze Contract", type="primary")
with col2:
    # The callback runs before the script re-executes, so the text area is
    # created with the sample already in its state; no explicit rerun needed.
    load_sample = st.button("Load Sample Contract", on_click=_load_sample_contract)
# Analysis section: runs only on the script pass triggered by the Analyze button.
if analyze_button and not contract_code:
    # Button pressed with an empty text area.
    st.warning("Please enter some contract code to analyze.")
elif analyze_button:
    try:
        with st.spinner('Analyzing contract...'):
            analysis = analyze_contract(
                contract_code,
                st.session_state.model,
                st.session_state.tokenizer,
                st.session_state.device,
            )
        st.subheader("Analysis Results")
        # Expandable section holding the analysis text, open by default
        with st.expander("View Full Analysis", expanded=True):
            st.markdown(analysis)
    except Exception as analysis_error:
        error_text = str(analysis_error)
        st.error(f"An error occurred during analysis: {error_text}")
        st.markdown("**Debug Information:**")
        st.code(error_text)
# Footer: a horizontal rule followed by centered attribution lines.
_footer_html = """
<div style='text-align: center'>
<p>Built with Streamlit and Hugging Face Transformers</p>
<p>Model: peterxyz/detect-llama-34b</p>
</div>
"""
st.markdown("---")
# Raw HTML is rendered as markup only because unsafe_allow_html is enabled.
st.markdown(_footer_html, unsafe_allow_html=True)