Update app.py
app.py CHANGED
@@ -12,7 +12,7 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu'
 base_model_name = "google/gemma-7b"
 adapter_model_name = "samidh/cope-g7bq-2c-hs.s1.5fpc.9-sx.s1.5.9o-VL.s1.5.9-HR.s5-SH.s5-l5e5-e3-d25-r8"
 
-model = AutoModelForCausalLM.from_pretrained(base_model_name, token=os.environ['HF_TOKEN'])
+model = AutoModelForCausalLM.from_pretrained(base_model_name, token=os.environ['HF_TOKEN'], device_map="auto")
 model = PeftModel.from_pretrained(model, adapter_model_name, token=os.environ['HF_TOKEN'])
 model.merge_and_unload()
 
@@ -86,6 +86,7 @@ def predict(content, policy):
     with torch.no_grad():
         outputs = model(input_ids)
     logits = outputs.logits[:, -1, :]  # Get logits for the last token
+    model.gradient_checkpointing_enable()
     predicted_token_id = torch.argmax(logits, dim=-1).item()
     decoded_output = tokenizer.decode([predicted_token_id])
     if decoded_output == '0':
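For context, a minimal end-to-end sketch of the loading and prediction pattern these changes touch, assuming transformers, peft, torch, and accelerate are installed and HF_TOKEN is set in the environment. The tokenizer setup, prompt construction, and return labels are placeholders and not taken from app.py; only the lines shown in the diff above come from the actual file.

import os

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base_model_name = "google/gemma-7b"
adapter_model_name = "samidh/cope-g7bq-2c-hs.s1.5fpc.9-sx.s1.5.9o-VL.s1.5.9-HR.s5-SH.s5-l5e5-e3-d25-r8"

# Load the base model; device_map="auto" lets accelerate decide where to
# place the weights instead of loading everything onto a single device.
model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    token=os.environ["HF_TOKEN"],
    device_map="auto",
)

# Attach the LoRA adapter, then fold its weights into the base model.
# Note: merge_and_unload() returns the merged model, so assign the result.
model = PeftModel.from_pretrained(model, adapter_model_name, token=os.environ["HF_TOKEN"])
model = model.merge_and_unload()
model.eval()

# Tokenizer setup is assumed here; app.py presumably defines its own.
tokenizer = AutoTokenizer.from_pretrained(base_model_name, token=os.environ["HF_TOKEN"])


def predict(content, policy):
    # Hypothetical prompt template; the real one lives in app.py.
    prompt = f"{policy}\n\n{content}"
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)

    with torch.no_grad():
        outputs = model(input_ids)

    logits = outputs.logits[:, -1, :]  # logits for the last token position
    predicted_token_id = torch.argmax(logits, dim=-1).item()
    decoded_output = tokenizer.decode([predicted_token_id])

    # '0' is the label the diff branches on; the return strings are placeholders.
    return "non-violating" if decoded_output == "0" else "violating"

Merging the adapter up front keeps inference as a plain transformers forward pass, and the single-token argmax over the last position turns the causal LM into a lightweight classifier over its label vocabulary.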