rudr4sarkar commited on
Commit
75599b5
β€’
1 Parent(s): c08ad66

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +119 -50
app.py CHANGED
@@ -1,70 +1,139 @@
 
1
  import torch
2
  from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
3
  from peft import PeftModel
4
  import bitsandbytes as bnb
 
5
 
6
- model_name = "peterxyz/detect-llama-34b"
7
- tokenizer = AutoTokenizer.from_pretrained(model_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
- nf4_config = BitsAndBytesConfig(
10
- load_in_4bit=True,
11
- bnb_4bit_quant_type="nf4",
12
- bnb_4bit_use_double_quant=True,
13
- bnb_4bit_compute_dtype=torch.bfloat16
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  )
15
 
16
- import gc
17
- torch.cuda.empty_cache()
18
- gc.collect()
19
 
20
- model_nf4 = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=nf4_config)
21
- model = PeftModel.from_pretrained(model_nf4, model_name)
 
22
 
23
- from datasets import load_dataset
 
 
 
24
 
25
- dataset = load_dataset("peterxyz/smart-contract-vuln-detection")
 
 
26
 
27
- input_text = """
28
- pragma solidity ^0.5.0;
 
29
 
30
- contract ModifierEntrancy {
 
 
 
31
 
32
- mapping (address => uint) public tokenBalance;
33
- string constant name = "Nu Token";
34
- Bank bank;
35
-
36
- constructor() public{
37
- bank = new Bank();
38
- }
39
-
40
- //If a contract has a zero balance and supports the token give them some token
41
- function airDrop() hasNoBalance supportsToken public{
42
- tokenBalance[msg.sender] += 20;
43
- }
44
-
45
- //Checks that the contract responds the way we want
46
- modifier supportsToken() {
47
- require(keccak256(abi.encodePacked("Nu Token")) == bank.supportsToken());
48
- _;
49
- }
50
-
51
- //Checks that the caller has a zero balance
52
- modifier hasNoBalance {
53
- require(tokenBalance[msg.sender] == 0);
54
- _;
55
- }
56
  }
57
 
58
  contract Bank{
59
-
60
  function supportsToken() external returns(bytes32) {
61
  return keccak256(abi.encodePacked("Nu Token"));
62
  }
63
-
64
- }
65
-
66
- identify vulnerability of this code given above
67
- """
68
- inputs = tokenizer(input_text, return_tensors="pt").to("cuda")
69
- outputs = model.generate(**inputs)
70
- print(tokenizer.decode(outputs[0], skip_special_tokens=True))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
  import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
4
  from peft import PeftModel
5
  import bitsandbytes as bnb
6
+ import gc
7
 
8
+ @st.cache_resource
9
+ def load_model():
10
+ model_name = "peterxyz/detect-llama-34b"
11
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
12
+
13
+ nf4_config = BitsAndBytesConfig(
14
+ load_in_4bit=True,
15
+ bnb_4bit_quant_type="nf4",
16
+ bnb_4bit_use_double_quant=True,
17
+ bnb_4bit_compute_dtype=torch.bfloat16
18
+ )
19
+
20
+ # Clear CUDA cache and garbage collect
21
+ torch.cuda.empty_cache()
22
+ gc.collect()
23
+
24
+ model_nf4 = AutoModelForCausalLM.from_pretrained(
25
+ model_name,
26
+ quantization_config=nf4_config,
27
+ device_map="auto"
28
+ )
29
+ model = PeftModel.from_pretrained(model_nf4, model_name)
30
+
31
+ return model, tokenizer
32
+
33
+ def analyze_contract(contract_code, model, tokenizer):
34
+ inputs = tokenizer(
35
+ f"{contract_code}\n\nidentify vulnerability of this code given above",
36
+ return_tensors="pt"
37
+ ).to("cuda")
38
+
39
+ outputs = model.generate(
40
+ **inputs,
41
+ max_length=1024,
42
+ temperature=0.7,
43
+ num_return_sequences=1
44
+ )
45
+
46
+ return tokenizer.decode(outputs[0], skip_special_tokens=True)
47
+
48
+ # Set page config
49
+ st.set_page_config(
50
+ page_title="Smart Contract Vulnerability Detector",
51
+ page_icon="πŸ”",
52
+ layout="wide"
53
+ )
54
 
55
+ # Main app
56
+ st.title("πŸ” Smart Contract Vulnerability Detector")
57
+ st.markdown("""
58
+ This app analyzes Solidity smart contracts for potential vulnerabilities using a fine-tuned LLaMA model.
59
+ Simply paste your smart contract code below and click 'Analyze'.
60
+ """)
61
+
62
+ # Initialize session state for the model
63
+ if 'model' not in st.session_state:
64
+ with st.spinner('Loading model... This might take a few minutes...'):
65
+ st.session_state.model, st.session_state.tokenizer = load_model()
66
+ st.success('Model loaded successfully!')
67
+
68
+ # Create the main interface
69
+ contract_code = st.text_area(
70
+ "Paste your Solidity contract code here:",
71
+ height=300,
72
+ placeholder="pragma solidity ^0.5.0;\n\ncontract YourContract {\n // Your code here\n}"
73
  )
74
 
75
+ analyze_button = st.button("Analyze Contract", type="primary")
 
 
76
 
77
+ # Sample contract button
78
+ if st.button("Load Sample Contract"):
79
+ contract_code = """pragma solidity ^0.5.0;
80
 
81
+ contract ModifierEntrancy {
82
+ mapping (address => uint) public tokenBalance;
83
+ string constant name = "Nu Token";
84
+ Bank bank;
85
 
86
+ constructor() public{
87
+ bank = new Bank();
88
+ }
89
 
90
+ function airDrop() hasNoBalance supportsToken public{
91
+ tokenBalance[msg.sender] += 20;
92
+ }
93
 
94
+ modifier supportsToken() {
95
+ require(keccak256(abi.encodePacked("Nu Token")) == bank.supportsToken());
96
+ _;
97
+ }
98
 
99
+ modifier hasNoBalance {
100
+ require(tokenBalance[msg.sender] == 0);
101
+ _;
102
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  }
104
 
105
  contract Bank{
 
106
  function supportsToken() external returns(bytes32) {
107
  return keccak256(abi.encodePacked("Nu Token"));
108
  }
109
+ }"""
110
+ st.session_state.contract_code = contract_code
111
+ st.experimental_rerun()
112
+
113
+ # Analysis section
114
+ if analyze_button and contract_code:
115
+ with st.spinner('Analyzing contract...'):
116
+ try:
117
+ analysis = analyze_contract(
118
+ contract_code,
119
+ st.session_state.model,
120
+ st.session_state.tokenizer
121
+ )
122
+
123
+ st.subheader("Analysis Results")
124
+ st.markdown(analysis)
125
+
126
+ except Exception as e:
127
+ st.error(f"An error occurred during analysis: {str(e)}")
128
+
129
+ elif analyze_button:
130
+ st.warning("Please enter some contract code to analyze.")
131
+
132
+ # Add footer with information
133
+ st.markdown("---")
134
+ st.markdown("""
135
+ <div style='text-align: center'>
136
+ <p>Built with Streamlit and Hugging Face Transformers</p>
137
+ <p>Model: peterxyz/detect-llama-34b</p>
138
+ </div>
139
+ """, unsafe_allow_html=True)