"""Streamlit app that asks a fine-tuned LLaMA model to flag vulnerabilities
in a pasted Solidity smart contract."""

import gc

import streamlit as st
import torch
from peft import PeftModel
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    LlamaTokenizer,
)

MODEL_NAME = "peterxyz/detect-llama-34b"

# Example contract offered by the "Load Sample Contract" button.
SAMPLE_CONTRACT = """pragma solidity ^0.5.0;

contract ModifierEntrancy {
    mapping (address => uint) public tokenBalance;
    string constant name = "Nu Token";
    Bank bank;

    constructor() public{
        bank = new Bank();
    }

    function airDrop() hasNoBalance supportsToken public{
        tokenBalance[msg.sender] += 20;
    }

    modifier supportsToken() {
        require(keccak256(abi.encodePacked("Nu Token")) == bank.supportsToken());
        _;
    }

    modifier hasNoBalance {
        require(tokenBalance[msg.sender] == 0);
        _;
    }
}

contract Bank{
    function supportsToken() external returns(bytes32) {
        return keccak256(abi.encodePacked("Nu Token"));
    }
}"""


@st.cache_resource
def load_model():
    """Load the tokenizer and model once and cache them via ``st.cache_resource``.

    On CUDA the model is loaded 4-bit (NF4) quantized with
    ``device_map="auto"`` and wrapped with the PEFT adapter from the same repo.
    On CPU it is loaded in float32 without quantization.

    Returns:
        tuple: ``(model, tokenizer, device)`` where ``device`` is ``"cuda"``
        if ``torch.cuda.is_available()`` and ``"cpu"`` otherwise.
    """
    model_name = MODEL_NAME

    # use_fast=False selects the slow (pure-Python) tokenizer implementation.
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    st.info(f"Using device: {device}")

    # Release what we can before a very large allocation.
    gc.collect()
    if device == "cuda":
        torch.cuda.empty_cache()

    if device == "cuda":
        # Not every GPU supports bfloat16 compute; fall back to float16 there.
        compute_dtype = (
            torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        )
        nf4_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=compute_dtype,
        )
        base_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=nf4_config,
            device_map="auto",
            trust_remote_code=True,
        )
        # Attach the PEFT adapter published under the same repo id.
        model = PeftModel.from_pretrained(base_model, model_name)
    else:
        # NOTE(review): unlike the CUDA path, this branch does not wrap the
        # model with PeftModel — confirm whether the adapter is required here.
        # Full float32 weights for a 34B model need a very large amount of RAM.
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float32,
            device_map={"": device},
            low_cpu_mem_usage=True,
            trust_remote_code=True,
        )

    return model, tokenizer, device


def analyze_contract(contract_code, model, tokenizer, device):
    """Ask the model to identify vulnerabilities in *contract_code*.

    Args:
        contract_code: Solidity source text supplied by the user.
        model: causal language model returned by :func:`load_model`.
        tokenizer: tokenizer returned by :func:`load_model`.
        device: device string the prompt tensors are moved to.

    Returns:
        str: the text generated by the model, without the echoed prompt.
    """
    prompt = f"{contract_code}\n\nidentify vulnerability of this code given above"

    # Some LLaMA tokenizers ship without a pad token; reuse EOS for padding.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # NOTE(review): assuming the tokenizer's default right-side truncation,
    # contracts longer than 2048 tokens lose the trailing instruction.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=2048,
    ).to(device)

    outputs = model.generate(
        **inputs,
        # max_new_tokens (not max_length) so long prompts still leave room
        # for the answer; max_length would count the prompt tokens too.
        max_new_tokens=1024,
        # temperature is ignored unless sampling is enabled.
        do_sample=True,
        temperature=0.7,
        num_return_sequences=1,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    # Decoder-only models return prompt + continuation; keep the continuation.
    prompt_length = inputs["input_ids"].shape[-1]
    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)


def _load_sample_contract():
    """Button callback: place the sample contract into the text area's state."""
    # Callbacks run before the next script run, so the keyed text_area below
    # is rendered with this value (setting it after the widget exists would fail).
    st.session_state.contract_code = SAMPLE_CONTRACT


# Set page config
st.set_page_config(
    page_title="Smart Contract Vulnerability Detector",
    page_icon="🔍",
    layout="wide",
)

# Main app
st.title("🔍 Smart Contract Vulnerability Detector")
st.markdown("""
This app analyzes Solidity smart contracts for potential vulnerabilities using a fine-tuned LLaMA model.
Simply paste your smart contract code below and click 'Analyze'.
""")

# Track whether the model has been loaded in this browser session.
if "model_loaded" not in st.session_state:
    st.session_state.model_loaded = False

if not st.session_state.model_loaded:
    try:
        with st.spinner("Loading model... This might take a few minutes..."):
            (
                st.session_state.model,
                st.session_state.tokenizer,
                st.session_state.device,
            ) = load_model()
            st.session_state.model_loaded = True
        st.success("Model loaded successfully!")
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        st.stop()

# Keyed so the "Load Sample Contract" callback can populate it.
contract_code = st.text_area(
    "Paste your Solidity contract code here:",
    height=300,
    placeholder="pragma solidity ^0.5.0;\n\ncontract YourContract {\n // Your code here\n}",
    key="contract_code",
)

col1, col2 = st.columns([1, 4])
with col1:
    analyze_button = st.button("Analyze Contract", type="primary")
with col2:
    st.button("Load Sample Contract", on_click=_load_sample_contract)

# Analysis section
if analyze_button and contract_code.strip():
    try:
        with st.spinner("Analyzing contract..."):
            analysis = analyze_contract(
                contract_code,
                st.session_state.model,
                st.session_state.tokenizer,
                st.session_state.device,
            )
        st.subheader("Analysis Results")
        with st.expander("View Full Analysis", expanded=True):
            st.markdown(analysis)
    except Exception as e:
        st.error(f"An error occurred during analysis: {str(e)}")
        st.markdown("**Debug Information:**")
        st.code(str(e))
elif analyze_button:
    st.warning("Please enter some contract code to analyze.")

# Footer
st.markdown("---")
st.markdown("""

Built with Streamlit and Hugging Face Transformers

Model: peterxyz/detect-llama-34b

""", unsafe_allow_html=True)