rudr4sarkar commited on
Commit
7f94d85
1 Parent(s): 7213d0c

Update app.py

Browse files

we need to explicitly use LlamaTokenizer since the model is Llama-based.

Files changed (1) hide show
  1. app.py +48 -17
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import streamlit as st
2
  import torch
3
- from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
4
  from peft import PeftModel
5
  import bitsandbytes as bnb
6
  import gc
@@ -8,7 +8,8 @@ import gc
8
  @st.cache_resource
9
  def load_model():
10
  model_name = "peterxyz/detect-llama-34b"
11
- tokenizer = AutoTokenizer.from_pretrained(model_name)
 
12
 
13
  nf4_config = BitsAndBytesConfig(
14
  load_in_4bit=True,
@@ -24,23 +25,35 @@ def load_model():
24
  model_nf4 = AutoModelForCausalLM.from_pretrained(
25
  model_name,
26
  quantization_config=nf4_config,
27
- device_map="auto"
 
28
  )
29
  model = PeftModel.from_pretrained(model_nf4, model_name)
30
 
31
  return model, tokenizer
32
 
33
  def analyze_contract(contract_code, model, tokenizer):
 
 
 
 
 
 
34
  inputs = tokenizer(
35
- f"{contract_code}\n\nidentify vulnerability of this code given above",
36
- return_tensors="pt"
 
 
 
37
  ).to("cuda")
38
 
39
  outputs = model.generate(
40
  **inputs,
41
  max_length=1024,
42
  temperature=0.7,
43
- num_return_sequences=1
 
 
44
  )
45
 
46
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
@@ -59,11 +72,20 @@ This app analyzes Solidity smart contracts for potential vulnerabilities using a
59
  Simply paste your smart contract code below and click 'Analyze'.
60
  """)
61
 
 
 
 
 
62
  # Initialize session state for the model
63
- if 'model' not in st.session_state:
64
- with st.spinner('Loading model... This might take a few minutes...'):
65
- st.session_state.model, st.session_state.tokenizer = load_model()
66
- st.success('Model loaded successfully!')
 
 
 
 
 
67
 
68
  # Create the main interface
69
  contract_code = st.text_area(
@@ -72,10 +94,14 @@ contract_code = st.text_area(
72
  placeholder="pragma solidity ^0.5.0;\n\ncontract YourContract {\n // Your code here\n}"
73
  )
74
 
75
- analyze_button = st.button("Analyze Contract", type="primary")
 
 
 
 
76
 
77
  # Sample contract button
78
- if st.button("Load Sample Contract"):
79
  contract_code = """pragma solidity ^0.5.0;
80
 
81
  contract ModifierEntrancy {
@@ -112,8 +138,8 @@ contract Bank{
112
 
113
  # Analysis section
114
  if analyze_button and contract_code:
115
- with st.spinner('Analyzing contract...'):
116
- try:
117
  analysis = analyze_contract(
118
  contract_code,
119
  st.session_state.model,
@@ -121,10 +147,15 @@ if analyze_button and contract_code:
121
  )
122
 
123
  st.subheader("Analysis Results")
124
- st.markdown(analysis)
125
 
126
- except Exception as e:
127
- st.error(f"An error occurred during analysis: {str(e)}")
 
 
 
 
 
 
128
 
129
  elif analyze_button:
130
  st.warning("Please enter some contract code to analyze.")
 
1
  import streamlit as st
2
  import torch
3
+ from transformers import AutoModelForCausalLM, LlamaTokenizer, BitsAndBytesConfig
4
  from peft import PeftModel
5
  import bitsandbytes as bnb
6
  import gc
 
8
  @st.cache_resource
9
  def load_model():
10
  model_name = "peterxyz/detect-llama-34b"
11
+ # Use LlamaTokenizer instead of AutoTokenizer
12
+ tokenizer = LlamaTokenizer.from_pretrained(model_name)
13
 
14
  nf4_config = BitsAndBytesConfig(
15
  load_in_4bit=True,
 
25
  model_nf4 = AutoModelForCausalLM.from_pretrained(
26
  model_name,
27
  quantization_config=nf4_config,
28
+ device_map="auto",
29
+ trust_remote_code=True # Added this parameter for safety
30
  )
31
  model = PeftModel.from_pretrained(model_nf4, model_name)
32
 
33
  return model, tokenizer
34
 
35
  def analyze_contract(contract_code, model, tokenizer):
36
+ prompt = f"{contract_code}\n\nidentify vulnerability of this code given above"
37
+
38
+ # Add padding token if needed
39
+ if tokenizer.pad_token is None:
40
+ tokenizer.pad_token = tokenizer.eos_token
41
+
42
  inputs = tokenizer(
43
+ prompt,
44
+ return_tensors="pt",
45
+ padding=True,
46
+ truncation=True,
47
+ max_length=2048 # Added max length for safety
48
  ).to("cuda")
49
 
50
  outputs = model.generate(
51
  **inputs,
52
  max_length=1024,
53
  temperature=0.7,
54
+ num_return_sequences=1,
55
+ pad_token_id=tokenizer.pad_token_id, # Explicitly set pad token ID
56
+ eos_token_id=tokenizer.eos_token_id # Explicitly set EOS token ID
57
  )
58
 
59
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
72
  Simply paste your smart contract code below and click 'Analyze'.
73
  """)
74
 
75
+ # Add a loading message while initializing
76
+ if 'model_loaded' not in st.session_state:
77
+ st.session_state.model_loaded = False
78
+
79
  # Initialize session state for the model
80
+ if not st.session_state.model_loaded:
81
+ try:
82
+ with st.spinner('Loading model... This might take a few minutes...'):
83
+ st.session_state.model, st.session_state.tokenizer = load_model()
84
+ st.session_state.model_loaded = True
85
+ st.success('Model loaded successfully!')
86
+ except Exception as e:
87
+ st.error(f"Error loading model: {str(e)}")
88
+ st.stop()
89
 
90
  # Create the main interface
91
  contract_code = st.text_area(
 
94
  placeholder="pragma solidity ^0.5.0;\n\ncontract YourContract {\n // Your code here\n}"
95
  )
96
 
97
+ col1, col2 = st.columns([1, 4])
98
+ with col1:
99
+ analyze_button = st.button("Analyze Contract", type="primary")
100
+ with col2:
101
+ load_sample = st.button("Load Sample Contract")
102
 
103
  # Sample contract button
104
+ if load_sample:
105
  contract_code = """pragma solidity ^0.5.0;
106
 
107
  contract ModifierEntrancy {
 
138
 
139
  # Analysis section
140
  if analyze_button and contract_code:
141
+ try:
142
+ with st.spinner('Analyzing contract...'):
143
  analysis = analyze_contract(
144
  contract_code,
145
  st.session_state.model,
 
147
  )
148
 
149
  st.subheader("Analysis Results")
 
150
 
151
+ # Create an expandable section for the analysis
152
+ with st.expander("View Full Analysis", expanded=True):
153
+ st.markdown(analysis)
154
+
155
+ except Exception as e:
156
+ st.error(f"An error occurred during analysis: {str(e)}")
157
+ st.markdown("**Debug Information:**")
158
+ st.code(str(e))
159
 
160
  elif analyze_button:
161
  st.warning("Please enter some contract code to analyze.")