Spaces:
Running
on
L4
Running
on
L4
Update app.py
Browse files
app.py
CHANGED
@@ -83,10 +83,9 @@ def predict(content, policy):
|
|
83 |
input_text = PROMPT.format(policy=policy, content=content)
|
84 |
input_ids = tokenizer.encode(input_text, return_tensors="pt")
|
85 |
|
86 |
-
with torch.no_grad():
|
87 |
outputs = model(input_ids)
|
88 |
logits = outputs.logits[:, -1, :] # Get logits for the last token
|
89 |
-
model.gradient_checkpointing_enable()
|
90 |
predicted_token_id = torch.argmax(logits, dim=-1).item()
|
91 |
decoded_output = tokenizer.decode([predicted_token_id])
|
92 |
if decoded_output == '0':
|
|
|
83 |
input_text = PROMPT.format(policy=policy, content=content)
|
84 |
input_ids = tokenizer.encode(input_text, return_tensors="pt")
|
85 |
|
86 |
+
with torch.inference_mode():
|
87 |
outputs = model(input_ids)
|
88 |
logits = outputs.logits[:, -1, :] # Get logits for the last token
|
|
|
89 |
predicted_token_id = torch.argmax(logits, dim=-1).item()
|
90 |
decoded_output = tokenizer.decode([predicted_token_id])
|
91 |
if decoded_output == '0':
|