samidh committed on
Commit
8f9126b
·
verified ·
1 Parent(s): cad7e4e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -14
app.py CHANGED
@@ -1,5 +1,4 @@
1
  # This is a project of Chakra Lab LLC. All rights reserved.
2
- import spaces
3
 
4
  import gradio as gr
5
  import os
@@ -22,19 +21,14 @@ bnb_config = BitsAndBytesConfig(
22
  #bnb_4bit_use_double_quant=True
23
  )
24
 
25
- @spaces.GPU
26
- def load_model():
27
- model = AutoModelForCausalLM.from_pretrained(base_model_name,
28
- token=os.environ['HF_TOKEN'],
29
- quantization_config=bnb_config,
30
- device_map="auto")
31
- model = PeftModel.from_pretrained(model, adapter_model_name, token=os.environ['HF_TOKEN'])
32
- model.merge_and_unload()
33
-
34
- model = model.to(device)
35
- return model
36
 
37
- model = load_model()
38
 
39
  tokenizer = AutoTokenizer.from_pretrained(base_model_name)
40
 
@@ -97,7 +91,6 @@ This policy is designed to determine whether or not content is hate speech.
97
  DEFAULT_CONTENT = "LLMs steal our jobs."
98
 
99
  # Function to make predictions
100
- @spaces.GPU
101
  def predict(content, policy):
102
  input_text = PROMPT.format(policy=policy, content=content)
103
  input_ids = tokenizer.encode(input_text, return_tensors="pt")
 
1
  # This is a project of Chakra Lab LLC. All rights reserved.
 
2
 
3
  import gradio as gr
4
  import os
 
21
  #bnb_4bit_use_double_quant=True
22
  )
23
 
24
+ model = AutoModelForCausalLM.from_pretrained(base_model_name,
25
+ token=os.environ['HF_TOKEN'],
26
+ quantization_config=bnb_config,
27
+ device_map="auto")
28
+ model = PeftModel.from_pretrained(model, adapter_model_name, token=os.environ['HF_TOKEN'])
29
+ model.merge_and_unload()
 
 
 
 
 
30
 
31
+ model = model.to(device)
32
 
33
  tokenizer = AutoTokenizer.from_pretrained(base_model_name)
34
 
 
91
  DEFAULT_CONTENT = "LLMs steal our jobs."
92
 
93
  # Function to make predictions
 
94
  def predict(content, policy):
95
  input_text = PROMPT.format(policy=policy, content=content)
96
  input_ids = tokenizer.encode(input_text, return_tensors="pt")