rudr4sarkar committed (verified) · Commit 836789f · Parent: 42322cb

Update app.py

Files changed (1):
  1. app.py +44 -26
app.py CHANGED
@@ -1,38 +1,55 @@
 import streamlit as st
 import torch
-from transformers import AutoModelForCausalLM, LlamaTokenizer, BitsAndBytesConfig
+from transformers import AutoModelForCausalLM, LlamaTokenizer
 from peft import PeftModel
-import bitsandbytes as bnb
 import gc
 
 @st.cache_resource
 def load_model():
     model_name = "peterxyz/detect-llama-34b"
-    # Use LlamaTokenizer instead of AutoTokenizer
     tokenizer = LlamaTokenizer.from_pretrained(model_name)
 
-    nf4_config = BitsAndBytesConfig(
-        load_in_4bit=True,
-        bnb_4bit_quant_type="nf4",
-        bnb_4bit_use_double_quant=True,
-        bnb_4bit_compute_dtype=torch.bfloat16
-    )
+    # Check if CUDA is available
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    st.info(f"Using device: {device}")
 
-    # Clear CUDA cache and garbage collect
-    torch.cuda.empty_cache()
+    # Clear memory
+    if device == "cuda":
+        torch.cuda.empty_cache()
     gc.collect()
 
-    model_nf4 = AutoModelForCausalLM.from_pretrained(
-        model_name,
-        quantization_config=nf4_config,
-        device_map="auto",
-        trust_remote_code=True  # Added this parameter for safety
-    )
-    model = PeftModel.from_pretrained(model_nf4, model_name)
+    # Load model with appropriate settings based on device
+    if device == "cuda":
+        from transformers import BitsAndBytesConfig
+        import bitsandbytes as bnb
+
+        nf4_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_quant_type="nf4",
+            bnb_4bit_use_double_quant=True,
+            bnb_4bit_compute_dtype=torch.bfloat16
+        )
+
+        model_nf4 = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            quantization_config=nf4_config,
+            device_map="auto",
+            trust_remote_code=True
+        )
+        model = PeftModel.from_pretrained(model_nf4, model_name)
+    else:
+        # For CPU, load with reduced precision but without 4-bit quantization
+        model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            torch_dtype=torch.float32,  # Use float32 for CPU
+            device_map={"": device},
+            low_cpu_mem_usage=True,
+            trust_remote_code=True
+        )
 
-    return model, tokenizer
+    return model, tokenizer, device
 
-def analyze_contract(contract_code, model, tokenizer):
+def analyze_contract(contract_code, model, tokenizer, device):
     prompt = f"{contract_code}\n\nidentify vulnerability of this code given above"
 
     # Add padding token if needed
@@ -44,16 +61,16 @@ def analyze_contract(contract_code, model, tokenizer):
         return_tensors="pt",
         padding=True,
         truncation=True,
-        max_length=2048  # Added max length for safety
-    ).to("cuda")
+        max_length=2048
+    ).to(device)
 
     outputs = model.generate(
         **inputs,
         max_length=1024,
         temperature=0.7,
         num_return_sequences=1,
-        pad_token_id=tokenizer.pad_token_id,  # Explicitly set pad token ID
-        eos_token_id=tokenizer.eos_token_id  # Explicitly set EOS token ID
+        pad_token_id=tokenizer.pad_token_id,
+        eos_token_id=tokenizer.eos_token_id
     )
 
     return tokenizer.decode(outputs[0], skip_special_tokens=True)
@@ -80,7 +97,7 @@ if 'model_loaded' not in st.session_state:
 if not st.session_state.model_loaded:
     try:
         with st.spinner('Loading model... This might take a few minutes...'):
-            st.session_state.model, st.session_state.tokenizer = load_model()
+            st.session_state.model, st.session_state.tokenizer, st.session_state.device = load_model()
             st.session_state.model_loaded = True
             st.success('Model loaded successfully!')
     except Exception as e:
@@ -143,7 +160,8 @@ if analyze_button and contract_code:
         analysis = analyze_contract(
             contract_code,
             st.session_state.model,
-            st.session_state.tokenizer
+            st.session_state.tokenizer,
+            st.session_state.device
        )
 
         st.subheader("Analysis Results")
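
A note for anyone reusing the generation settings above: in transformers, the max_length passed to generate() counts the prompt tokens too, so a prompt truncated at the tokenizer's 2048-token cap already exceeds max_length=1024 and leaves no budget for output. A minimal sketch of the max_new_tokens variant (an alternative, not what this commit ships; the 512-token budget is an arbitrary choice):

    def analyze_contract(contract_code, model, tokenizer, device):
        prompt = f"{contract_code}\n\nidentify vulnerability of this code given above"
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=2048,
        ).to(device)
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,   # counts generated tokens only, not the prompt
            do_sample=True,       # temperature has no effect under greedy decoding
            temperature=0.7,
            num_return_sequences=1,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
        return tokenizer.decode(outputs[0], skip_special_tokens=True)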
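
For an end-to-end check of the new three-value signatures without the UI, a throwaway script along these lines could work. Two loud assumptions: load_model and analyze_contract have been factored out into a hypothetical UI-free model_utils module (importing app.py directly would execute the whole Streamlit script top to bottom), and the sample contract is made up (a textbook reentrancy bug, in keeping with the app's purpose):

    # smoke_test.py: hypothetical, not part of the commit.
    # Assumes load_model/analyze_contract live in a UI-free "model_utils" module.
    from model_utils import load_model, analyze_contract

    SAMPLE_CONTRACT = """\
    pragma solidity ^0.8.0;

    contract Wallet {
        mapping(address => uint256) public balances;

        function deposit() external payable {
            balances[msg.sender] += msg.value;
        }

        function withdraw() external {
            // classic reentrancy: external call before the balance is zeroed
            (bool ok, ) = msg.sender.call{value: balances[msg.sender]}("");
            require(ok, "transfer failed");
            balances[msg.sender] = 0;
        }
    }
    """

    model, tokenizer, device = load_model()
    print(f"Model loaded on {device}")
    print(analyze_contract(SAMPLE_CONTRACT, model, tokenizer, device))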