File size: 5,024 Bytes
75599b5
c08ad66
7f94d85
c08ad66
 
75599b5
c08ad66
75599b5
 
 
7f94d85
 
75599b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7f94d85
 
75599b5
 
 
 
 
 
7f94d85
 
 
 
 
 
75599b5
7f94d85
 
 
 
 
75599b5
 
 
 
 
 
7f94d85
 
 
75599b5
 
 
 
 
 
 
 
 
 
c08ad66
75599b5
 
 
 
 
 
 
7f94d85
 
 
 
75599b5
7f94d85
 
 
 
 
 
 
 
 
75599b5
 
 
 
 
 
c08ad66
 
7f94d85
 
 
 
 
c08ad66
75599b5
7f94d85
75599b5
c08ad66
75599b5
 
 
 
c08ad66
75599b5
 
 
c08ad66
75599b5
 
 
c08ad66
75599b5
 
 
 
c08ad66
75599b5
 
 
 
c08ad66
 
 
 
 
 
75599b5
 
 
 
 
 
7f94d85
 
75599b5
 
 
 
 
 
 
 
7f94d85
 
 
 
 
 
 
 
75599b5
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import streamlit as st
import torch
from transformers import AutoModelForCausalLM, LlamaTokenizer, BitsAndBytesConfig
from peft import PeftModel
import bitsandbytes as bnb
import gc

@st.cache_resource
def load_model():
    """Load the 4-bit quantized detect-llama-34b model and its tokenizer.

    Decorated with ``st.cache_resource`` so the expensive download /
    quantization happens only once per Streamlit server process.

    Returns:
        tuple: ``(model, tokenizer)`` — the NF4-quantized base model with
        its LoRA adapter applied, and the matching ``LlamaTokenizer``.
    """
    model_name = "peterxyz/detect-llama-34b"
    # Use LlamaTokenizer explicitly instead of AutoTokenizer for this
    # checkpoint (original comment preserved intent).
    tokenizer = LlamaTokenizer.from_pretrained(model_name)

    # NF4 4-bit quantization with double quantization so a 34B model fits
    # in limited GPU memory; bfloat16 for the compute dtype.
    nf4_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    # Free as much memory as possible before the large allocation.
    # Guard the CUDA call: the original unconditional empty_cache() can
    # fail on CPU-only hosts.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    model_nf4 = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=nf4_config,
        device_map="auto",
        trust_remote_code=True,  # checkpoint may ship custom model code
    )
    # Apply the LoRA adapter weights stored in the same repository.
    model = PeftModel.from_pretrained(model_nf4, model_name)

    return model, tokenizer

def analyze_contract(contract_code, model, tokenizer):
    """Run the vulnerability-detection model over a Solidity contract.

    Args:
        contract_code: Raw Solidity source code to analyze.
        model: Causal-LM (PEFT) model as returned by ``load_model()``.
        tokenizer: Tokenizer matching ``model``.

    Returns:
        str: Decoded model output (prompt plus generated analysis),
        special tokens stripped.
    """
    prompt = f"{contract_code}\n\nidentify vulnerability of this code given above"

    # LLaMA tokenizers ship without a pad token; reuse EOS for padding.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=2048,  # cap prompt length for memory safety
    ).to(model.device)  # follow the model's device; hard-coding "cuda" crashed on CPU-only hosts

    outputs = model.generate(
        **inputs,
        # max_new_tokens bounds only the *generated* text. The original
        # max_length=1024 counted the prompt too, so any prompt longer
        # than 1024 tokens made generation fail immediately.
        max_new_tokens=1024,
        temperature=0.7,
        do_sample=True,  # temperature is ignored unless sampling is enabled
        num_return_sequences=1,
        pad_token_id=tokenizer.pad_token_id,  # explicit pad token ID
        eos_token_id=tokenizer.eos_token_id,  # explicit EOS token ID
    )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# ---------------------------------------------------------------------------
# Streamlit page — everything below runs top-to-bottom on every rerun.
# ---------------------------------------------------------------------------

st.set_page_config(
    page_title="Smart Contract Vulnerability Detector",
    page_icon="🔍",  # repaired mojibake ("πŸ”" was a misdecoded 🔍)
    layout="wide"
)

# Sample contract used by the "Load Sample Contract" button.
SAMPLE_CONTRACT = """pragma solidity ^0.5.0;

contract ModifierEntrancy {
    mapping (address => uint) public tokenBalance;
    string constant name = "Nu Token";
    Bank bank;

    constructor() public{
        bank = new Bank();
    }

    function airDrop() hasNoBalance supportsToken public{
        tokenBalance[msg.sender] += 20;
    }

    modifier supportsToken() {
        require(keccak256(abi.encodePacked("Nu Token")) == bank.supportsToken());
        _;
    }

    modifier hasNoBalance {
        require(tokenBalance[msg.sender] == 0);
        _;
    }
}

contract Bank{
    function supportsToken() external returns(bytes32) {
        return keccak256(abi.encodePacked("Nu Token"));
    }
}"""


def _load_sample():
    """Button callback: fill the text area with the sample contract.

    Callbacks run *before* the next script pass, which is the supported way
    to write a widget's session-state key. The original code wrote
    st.session_state.contract_code after the widget was drawn and called the
    removed st.experimental_rerun(), so the sample never actually appeared.
    """
    st.session_state.contract_code = SAMPLE_CONTRACT


# Main app
st.title("🔍 Smart Contract Vulnerability Detector")
st.markdown("""
This app analyzes Solidity smart contracts for potential vulnerabilities using a fine-tuned LLaMA model.
Simply paste your smart contract code below and click 'Analyze'.
""")

# Session-state flag so the heavy model load happens exactly once.
if 'model_loaded' not in st.session_state:
    st.session_state.model_loaded = False

# Load the model on first run; stop the app with an error message on failure.
if not st.session_state.model_loaded:
    try:
        with st.spinner('Loading model... This might take a few minutes...'):
            st.session_state.model, st.session_state.tokenizer = load_model()
            st.session_state.model_loaded = True
            st.success('Model loaded successfully!')
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        st.stop()

# key= binds the widget to st.session_state["contract_code"] so the sample
# callback can prefill it.
contract_code = st.text_area(
    "Paste your Solidity contract code here:",
    height=300,
    placeholder="pragma solidity ^0.5.0;\n\ncontract YourContract {\n    // Your code here\n}",
    key="contract_code",
)

col1, col2 = st.columns([1, 4])
with col1:
    analyze_button = st.button("Analyze Contract", type="primary")
with col2:
    st.button("Load Sample Contract", on_click=_load_sample)

# Analysis section
if analyze_button and contract_code:
    try:
        with st.spinner('Analyzing contract...'):
            analysis = analyze_contract(
                contract_code,
                st.session_state.model,
                st.session_state.tokenizer
            )

            st.subheader("Analysis Results")

            # Expandable section so long analyses don't dominate the page.
            with st.expander("View Full Analysis", expanded=True):
                st.markdown(analysis)

    except Exception as e:
        st.error(f"An error occurred during analysis: {str(e)}")
        st.markdown("**Debug Information:**")
        st.code(str(e))

elif analyze_button:
    st.warning("Please enter some contract code to analyze.")

# Footer
st.markdown("---")
st.markdown("""
<div style='text-align: center'>
    <p>Built with Streamlit and Hugging Face Transformers</p>
    <p>Model: peterxyz/detect-llama-34b</p>
</div>
""", unsafe_allow_html=True)