geekyrakshit committed
Commit 177344c • Parent: 883a576

update: LlamaGuardFineTuner + corresponding docs

Files changed:
- docs/train/train_llama_guard.md       +3 -0
- guardrails_genie/train/__init__.py    +4 -0
- guardrails_genie/train/llama_guard.py +38 -5
- mkdocs.yml                            +1 -0
docs/train/train_llama_guard.md  ADDED
@@ -0,0 +1,3 @@
+# Train Llama Guard
+
+::: guardrails_genie.train.llama_guard
guardrails_genie/train/__init__.py  ADDED
@@ -0,0 +1,4 @@
+from .train_classifier import train_binary_classifier
+from .llama_guard import LlamaGuardFineTuner, DatasetArgs
+
+__all__ = ["train_binary_classifier", "LlamaGuardFineTuner", "DatasetArgs"]
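Since the new `__init__.py` re-exports the trainer, it is now importable from the package root as well as from the submodule; a quick sketch (assuming the package is installed locally):

```python
# Both import paths resolve to the same class object, because the package
# root re-exports it from guardrails_genie/train/llama_guard.py.
from guardrails_genie.train import DatasetArgs, LlamaGuardFineTuner
from guardrails_genie.train.llama_guard import LlamaGuardFineTuner as Direct

assert LlamaGuardFineTuner is Direct
```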
guardrails_genie/train/llama_guard.py  CHANGED
@@ -30,6 +30,26 @@ class LlamaGuardFineTuner:
     classification tasks, specifically for detecting prompt injection attacks. It
     integrates with Weights & Biases for experiment tracking and optionally
     displays progress in a Streamlit app.
+
+    !!! example "Sample Usage"
+        ```python
+        from guardrails_genie.train.llama_guard import LlamaGuardFineTuner, DatasetArgs
+
+        fine_tuner = LlamaGuardFineTuner(
+            wandb_project="guardrails-genie",
+            wandb_entity="geekyrakshit",
+            streamlit_mode=False,
+        )
+        fine_tuner.load_dataset(
+            DatasetArgs(
+                dataset_address="wandb/synthetic-prompt-injections",
+                train_dataset_range=-1,
+                test_dataset_range=-1,
+            )
+        )
+        fine_tuner.load_model()
+        fine_tuner.train(save_interval=100)
+        ```
 
     Args:
         wandb_project (str): The name of the Weights & Biases project.
@@ -63,6 +83,7 @@
             train_dataset: The selected training dataset.
             test_dataset: The selected testing dataset.
         """
+        self.dataset_args = dataset_args
         dataset = load_dataset(dataset_args.dataset_address)
         self.train_dataset = (
             dataset["train"]
@@ -299,8 +320,8 @@
         batch_size: int = 32,
         lr: float = 5e-6,
         num_classes: int = 2,
-        log_interval: int =
-        save_interval: int =
+        log_interval: int = 1,
+        save_interval: int = 50,
     ):
         """
         Fine-tunes the pre-trained LlamaGuard model on the training dataset for a single epoch.
@@ -332,13 +353,21 @@
         wandb.init(
             project=self.wandb_project,
             entity=self.wandb_entity,
-            name=f"{self.model_name}-{self.
+            name=f"{self.model_name}-{self.dataset_args.dataset_address.split('/')[-1]}",
             job_type="fine-tune-llama-guard",
         )
+        wandb.config.dataset_args = self.dataset_args.model_dump()
+        wandb.config.model_name = self.model_name
+        wandb.config.batch_size = batch_size
+        wandb.config.lr = lr
+        wandb.config.num_classes = num_classes
+        wandb.config.log_interval = log_interval
+        wandb.config.save_interval = save_interval
         self.model.classifier = nn.Linear(
             self.model.classifier.in_features, num_classes
         )
         self.model.num_labels = num_classes
+        self.model = self.model.to(self.device)
         self.model.train()
         optimizer = optim.AdamW(self.model.parameters(), lr=lr)
         data_loader = DataLoader(
@@ -367,8 +396,12 @@
                     progress_percentage,
                     text=f"Training batch {i + 1}/{len(data_loader)}, Loss: {loss.item()}",
                 )
-            if (i + 1) % save_interval == 0:
+            if (i + 1) % save_interval == 0 or i + 1 == len(data_loader):
                 save_model(self.model, f"checkpoints/model-{i + 1}.safetensors")
-                wandb.log_model(
+                wandb.log_model(
+                    f"checkpoints/model-{i + 1}.safetensors",
+                    name=f"{wandb.run.id}-model",
+                    aliases=f"step-{i + 1}",
+                )
         wandb.finish()
         shutil.rmtree("checkpoints")
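Two of the hunks above change how a run is set up: the run name is now derived from the model and the dataset address, and every hyperparameter is recorded on `wandb.config`. A minimal sketch of that pattern, with placeholder values standing in for the instance state the commit reads (`self.model_name`, `self.dataset_args`):

```python
import wandb

# Placeholders; in the commit these come from the fine-tuner instance.
model_name = "llama-guard"
dataset_address = "wandb/synthetic-prompt-injections"

run = wandb.init(
    project="guardrails-genie",
    entity="geekyrakshit",
    name=f"{model_name}-{dataset_address.split('/')[-1]}",  # "llama-guard-synthetic-prompt-injections"
    job_type="fine-tune-llama-guard",
    mode="offline",  # lets the sketch run without W&B credentials
)
# Recording hyperparameters on wandb.config makes runs comparable in the W&B UI.
wandb.config.batch_size = 32
wandb.config.lr = 5e-6
wandb.config.num_classes = 2
wandb.config.log_interval = 1
wandb.config.save_interval = 50
wandb.finish()
```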
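The checkpointing rule also changed: instead of saving only every `save_interval` batches, the loop now saves on the final batch as well, and each checkpoint is logged to W&B as a model artifact. A self-contained sketch of just the save schedule (the toy model and directory handling are illustrative, not part of the commit):

```python
import os
import torch.nn as nn
from safetensors.torch import save_model

def checkpoint_steps(num_batches: int, save_interval: int) -> list[int]:
    # Save every `save_interval` batches, and always on the last batch, so the
    # final state is kept even when num_batches % save_interval != 0.
    return [
        i + 1
        for i in range(num_batches)
        if (i + 1) % save_interval == 0 or i + 1 == num_batches
    ]

os.makedirs("checkpoints", exist_ok=True)
model = nn.Linear(4, 2)  # stand-in for the fine-tuned classifier
for step in checkpoint_steps(num_batches=120, save_interval=50):
    save_model(model, f"checkpoints/model-{step}.safetensors")

print(checkpoint_steps(120, 50))  # [50, 100, 120]; the save at 120 is the new final-batch save
```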
mkdocs.yml  CHANGED
@@ -80,6 +80,7 @@ nav:
     - RegexModel: 'regex_model.md'
     - Training:
       - Train Classifier: 'train/train_classifier.md'
+      - Train Llama Guard: 'train/train_llama_guard.md'
     - Utils: 'utils.md'
 
 repo_url: https://github.com/soumik12345/guardrails-genie