Commit: 5e33295
Parent(s): 177344c
Author: geekyrakshit

update: LlamaGuardFineTuner
.gitignore CHANGED

````diff
@@ -168,4 +168,5 @@ temp.txt
 binary-classifier/
 wandb/
 artifacts/
-evaluation_results/
+evaluation_results/
+checkpoints/
````
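The two new ignore entries fit the rest of the commit: `evaluation_results/` is rewritten (presumably to gain a trailing newline), and `checkpoints/` presumably covers locally saved fine-tuned model checkpoints.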
application_pages/llama_guard_fine_tuning.py CHANGED

````diff
@@ -1,10 +1,16 @@
+import os
+
 import streamlit as st
 
 from guardrails_genie.train.llama_guard import DatasetArgs, LlamaGuardFineTuner
 
 
 def initialize_session_state():
-    st.session_state.llama_guard_fine_tuner = LlamaGuardFineTuner()
+    st.session_state.llama_guard_fine_tuner = LlamaGuardFineTuner(
+        wandb_project=os.getenv("WANDB_PROJECT_NAME"),
+        wandb_entity=os.getenv("WANDB_ENTITY_NAME"),
+        streamlit_mode=True,
+    )
     if "dataset_address" not in st.session_state:
         st.session_state.dataset_address = ""
     if "train_dataset_range" not in st.session_state:
@@ -25,6 +31,14 @@ def initialize_session_state():
         st.session_state.evaluation_batch_size = None
     if "evaluation_temperature" not in st.session_state:
         st.session_state.evaluation_temperature = None
+    if "checkpoint" not in st.session_state:
+        st.session_state.checkpoint = None
+    if "eval_batch_size" not in st.session_state:
+        st.session_state.eval_batch_size = 32
+    if "eval_positive_label" not in st.session_state:
+        st.session_state.eval_positive_label = 2
+    if "eval_temperature" not in st.session_state:
+        st.session_state.eval_temperature = 1.0
 
 
 initialize_session_state()
@@ -43,18 +57,34 @@ if st.session_state.dataset_address != "":
     st.session_state.train_dataset_range = train_dataset_range
     st.session_state.test_dataset_range = test_dataset_range
 
-    model_name = st.sidebar.selectbox(
-        "Model Name",
-        ["meta-llama/Prompt-Guard-86M"],
+    model_name = st.sidebar.text_input(
+        label="Model Name", value="meta-llama/Prompt-Guard-86M"
     )
     st.session_state.model_name = model_name
 
+    checkpoint = st.sidebar.text_input(label="Fine-tuned Model Checkpoint", value="")
+    st.session_state.checkpoint = checkpoint
+
     preview_dataset = st.sidebar.toggle("Preview Dataset")
     st.session_state.preview_dataset = preview_dataset
 
     evaluate_model = st.sidebar.toggle("Evaluate Model")
     st.session_state.evaluate_model = evaluate_model
 
+    if st.session_state.evaluate_model:
+        eval_batch_size = st.sidebar.slider(
+            label="Eval Batch Size", min_value=16, max_value=1024, value=32
+        )
+        st.session_state.eval_batch_size = eval_batch_size
+
+        eval_positive_label = st.sidebar.number_input("EVal Positive Label", value=2)
+        st.session_state.eval_positive_label = eval_positive_label
+
+        eval_temperature = st.sidebar.slider(
+            label="Eval Temperature", min_value=0.0, max_value=5.0, value=1.0
+        )
+        st.session_state.eval_temperature = eval_temperature
+
     load_fine_tuner_button = st.sidebar.button("Load Fine-Tuner")
     st.session_state.load_fine_tuner_button = load_fine_tuner_button
 
@@ -68,13 +98,19 @@ if st.session_state.dataset_address != "":
            )
        )
        st.session_state.llama_guard_fine_tuner.load_model(
-           model_name=st.session_state.model_name
+           model_name=st.session_state.model_name,
+           checkpoint=(
+               None
+               if st.session_state.checkpoint == ""
+               else st.session_state.checkpoint
+           ),
        )
        if st.session_state.preview_dataset:
            st.session_state.llama_guard_fine_tuner.show_dataset_sample()
        if st.session_state.evaluate_model:
            st.session_state.llama_guard_fine_tuner.evaluate_model(
-               batch_size=st.session_state.evaluation_batch_size,
-               temperature=st.session_state.evaluation_temperature,
+               batch_size=st.session_state.eval_batch_size,
+               positive_label=st.session_state.eval_positive_label,
+               temperature=st.session_state.eval_temperature,
            )
        st.session_state.is_fine_tuner_loaded = True
````
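Outside the UI, the same flow the updated page drives can be scripted directly. Below is a minimal sketch; only the keyword arguments shown in this diff (`wandb_project`, `wandb_entity`, `streamlit_mode`, `model_name`, `checkpoint`, `batch_size`, `positive_label`, `temperature`) are taken from the commit, while the `load_dataset`/`DatasetArgs` call mirrors the surrounding code and its dataset address and range values are hypothetical. `WANDB_PROJECT_NAME` and `WANDB_ENTITY_NAME` must be set in the environment.

```python
# Minimal sketch of the flow the updated Streamlit page drives.
# Hypothetical values are marked in comments.
import os

from guardrails_genie.train.llama_guard import DatasetArgs, LlamaGuardFineTuner

fine_tuner = LlamaGuardFineTuner(
    wandb_project=os.getenv("WANDB_PROJECT_NAME"),
    wandb_entity=os.getenv("WANDB_ENTITY_NAME"),
    streamlit_mode=False,  # no Streamlit page in a plain script
)
fine_tuner.load_dataset(
    DatasetArgs(
        dataset_address="my-entity/prompt-injection-dataset",  # hypothetical address
        train_dataset_range=1000,  # hypothetical range
        test_dataset_range=500,  # hypothetical range
    )
)
fine_tuner.load_model(model_name="meta-llama/Prompt-Guard-86M", checkpoint=None)
fine_tuner.evaluate_model(batch_size=32, positive_label=2, temperature=1.0)
```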
guardrails_genie/train/__init__.py CHANGED

````diff
@@ -1,4 +1,4 @@
+from .llama_guard import DatasetArgs, LlamaGuardFineTuner
 from .train_classifier import train_binary_classifier
-from .llama_guard import LlamaGuardFineTuner, DatasetArgs
 
-__all__ = ["train_binary_classifier", "LlamaGuardFineTuner", "DatasetArgs"]
+__all__ = ["train_binary_classifier", "LlamaGuardFineTuner", "DatasetArgs"]
````
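Only the import order changes here (`.llama_guard` now sorted before `.train_classifier`, with its names alphabetized); the API exported through `__all__` is unchanged.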
guardrails_genie/train/llama_guard.py CHANGED

````diff
@@ -1,5 +1,7 @@
 import os
 import shutil
+from glob import glob
+from typing import Optional
 
 import plotly.graph_objects as go
 import streamlit as st
@@ -7,15 +9,16 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
-import wandb
 from datasets import load_dataset
 from pydantic import BaseModel
 from rich.progress import track
-from safetensors.torch import save_model
+from safetensors.torch import load_model, save_model
 from sklearn.metrics import roc_auc_score, roc_curve
 from torch.utils.data import DataLoader
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
 
+import wandb
+
 
 class DatasetArgs(BaseModel):
     dataset_address: str
@@ -30,7 +33,7 @@ class LlamaGuardFineTuner:
     classification tasks, specifically for detecting prompt injection attacks. It
     integrates with Weights & Biases for experiment tracking and optionally
     displays progress in a Streamlit app.
-
+
     !!! example "Sample Usage"
         ```python
         from guardrails_genie.train.llama_guard import LlamaGuardFineTuner, DatasetArgs
@@ -98,7 +101,11 @@ class LlamaGuardFineTuner:
             else dataset["test"].select(range(dataset_args.test_dataset_range))
         )
 
-    def load_model(self, model_name: str = "meta-llama/Prompt-Guard-86M"):
+    def load_model(
+        self,
+        model_name: str = "meta-llama/Prompt-Guard-86M",
+        checkpoint: Optional[str] = None,
+    ):
         """
         Loads the specified pre-trained model and tokenizer for sequence classification tasks.
 
@@ -118,9 +125,20 @@
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         self.model_name = model_name
         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
-        self.model = AutoModelForSequenceClassification.from_pretrained(
-            model_name
-        ).to(self.device)
+        if checkpoint is None:
+            self.model = AutoModelForSequenceClassification.from_pretrained(
+                model_name
+            ).to(self.device)
+        else:
+            api = wandb.Api()
+            artifact = api.artifact(checkpoint.removeprefix("wandb://"))
+            artifact_dir = artifact.download()
+            model_file_path = glob(os.path.join(artifact_dir, "model-*.safetensors"))[0]
+            self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
+            self.model.classifier = nn.Linear(self.model.classifier.in_features, 2)
+            self.model.num_labels = 2
+            load_model(self.model, model_file_path)
+            self.model = self.model.to(self.device)
 
     def show_dataset_sample(self):
         """
````
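The substantive change is the new `checkpoint` branch of `load_model`: the reference is resolved as a W&B artifact (an optional `wandb://` prefix is stripped via `str.removeprefix`, which requires Python 3.9+), the first `model-*.safetensors` file in the downloaded artifact directory is selected, the classifier head is rebuilt with two outputs, and only then are the weights loaded. Rebuilding the head first matters because the base model's head can have a different label count than the two-way fine-tuned checkpoint, so loading the safetensors weights into the stock head would fail on a shape mismatch. A hypothetical call, with an artifact reference that is not from this commit:

```python
# Hypothetical usage of the new checkpoint branch; the artifact reference
# below is an assumed example, not a path from this commit.
# The "wandb://" prefix is optional.
fine_tuner.load_model(
    model_name="meta-llama/Prompt-Guard-86M",
    checkpoint="wandb://my-entity/my-project/checkpoint:v0",
)
```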
guardrails_genie/train/train_classifier.py CHANGED

````diff
@@ -1,7 +1,6 @@
 import evaluate
 import numpy as np
 import streamlit as st
-import wandb
 from datasets import load_dataset
 from transformers import (
     AutoModelForSequenceClassification,
@@ -11,6 +10,7 @@ from transformers import (
     TrainingArguments,
 )
 
+import wandb
 from guardrails_genie.utils import StreamlitProgressbarCallback
 
 
````
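As in `llama_guard.py`, `import wandb` moves out of the alphabetized third-party block into its own group next to the first-party imports, which looks like an isort-driven regrouping rather than a functional change.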