geekyrakshit committed
Commit: 177344c
Parent: 883a576

update: LlamaGuardFineTuner + corresponding docs

docs/train/train_llama_guard.md ADDED
@@ -0,0 +1,3 @@
+# Train Llama Guard
+
+::: guardrails_genie.train.llama_guard
guardrails_genie/train/__init__.py ADDED
@@ -0,0 +1,4 @@
+from .train_classifier import train_binary_classifier
+from .llama_guard import LlamaGuardFineTuner, DatasetArgs
+
+__all__ = ["train_binary_classifier", "LlamaGuardFineTuner", "DatasetArgs"]
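With these re-exports, the fine-tuner can be imported from the package root instead of the full module path. A minimal sketch, assuming the constructor arguments and `DatasetArgs` fields shown in the "Sample Usage" docstring added to `llama_guard.py` below:

```python
# Import via the package root, as enabled by the re-exports in __init__.py above.
from guardrails_genie.train import LlamaGuardFineTuner, DatasetArgs

# Argument values mirror the "Sample Usage" docstring added in this commit.
fine_tuner = LlamaGuardFineTuner(
    wandb_project="guardrails-genie",
    wandb_entity="geekyrakshit",
    streamlit_mode=False,
)
dataset_args = DatasetArgs(
    dataset_address="wandb/synthetic-prompt-injections",
    train_dataset_range=-1,
    test_dataset_range=-1,
)
```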
guardrails_genie/train/llama_guard.py CHANGED
@@ -30,6 +30,26 @@ class LlamaGuardFineTuner:
     classification tasks, specifically for detecting prompt injection attacks. It
     integrates with Weights & Biases for experiment tracking and optionally
     displays progress in a Streamlit app.
+
+    !!! example "Sample Usage"
+        ```python
+        from guardrails_genie.train.llama_guard import LlamaGuardFineTuner, DatasetArgs
+
+        fine_tuner = LlamaGuardFineTuner(
+            wandb_project="guardrails-genie",
+            wandb_entity="geekyrakshit",
+            streamlit_mode=False,
+        )
+        fine_tuner.load_dataset(
+            DatasetArgs(
+                dataset_address="wandb/synthetic-prompt-injections",
+                train_dataset_range=-1,
+                test_dataset_range=-1,
+            )
+        )
+        fine_tuner.load_model()
+        fine_tuner.train(save_interval=100)
+        ```

     Args:
         wandb_project (str): The name of the Weights & Biases project.
@@ -63,6 +83,7 @@ class LlamaGuardFineTuner:
             train_dataset: The selected training dataset.
             test_dataset: The selected testing dataset.
         """
+        self.dataset_args = dataset_args
         dataset = load_dataset(dataset_args.dataset_address)
         self.train_dataset = (
             dataset["train"]
@@ -299,8 +320,8 @@ class LlamaGuardFineTuner:
         batch_size: int = 32,
         lr: float = 5e-6,
         num_classes: int = 2,
-        log_interval: int = 20,
-        save_interval: int = 1000,
+        log_interval: int = 1,
+        save_interval: int = 50,
     ):
         """
         Fine-tunes the pre-trained LlamaGuard model on the training dataset for a single epoch.
@@ -332,13 +353,21 @@ class LlamaGuardFineTuner:
         wandb.init(
             project=self.wandb_project,
             entity=self.wandb_entity,
-            name=f"{self.model_name}-{self.dataset_name}",
+            name=f"{self.model_name}-{self.dataset_args.dataset_address.split('/')[-1]}",
             job_type="fine-tune-llama-guard",
         )
+        wandb.config.dataset_args = self.dataset_args.model_dump()
+        wandb.config.model_name = self.model_name
+        wandb.config.batch_size = batch_size
+        wandb.config.lr = lr
+        wandb.config.num_classes = num_classes
+        wandb.config.log_interval = log_interval
+        wandb.config.save_interval = save_interval
         self.model.classifier = nn.Linear(
             self.model.classifier.in_features, num_classes
         )
         self.model.num_labels = num_classes
+        self.model = self.model.to(self.device)
         self.model.train()
         optimizer = optim.AdamW(self.model.parameters(), lr=lr)
         data_loader = DataLoader(
@@ -367,8 +396,12 @@ class LlamaGuardFineTuner:
                     progress_percentage,
                     text=f"Training batch {i + 1}/{len(data_loader)}, Loss: {loss.item()}",
                 )
-            if (i + 1) % save_interval == 0:
+            if (i + 1) % save_interval == 0 or i + 1 == len(data_loader):
                 save_model(self.model, f"checkpoints/model-{i + 1}.safetensors")
-                wandb.log_model(f"checkpoints/model-{i + 1}.safetensors")
+                wandb.log_model(
+                    f"checkpoints/model-{i + 1}.safetensors",
+                    name=f"{wandb.run.id}-model",
+                    aliases=f"step-{i + 1}",
+                )
         wandb.finish()
         shutil.rmtree("checkpoints")
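Because each checkpoint is now logged as a W&B model artifact named `{run.id}-model` with a `step-{i + 1}` alias, a saved checkpoint can be pulled back down in a later session. A minimal sketch, assuming the entity and project from the sample usage plus a hypothetical run id and step; the artifact name and alias follow the `wandb.log_model` call above:

```python
import wandb
from safetensors.torch import load_file

# Hypothetical run id and step -- substitute the values from your own fine-tuning run.
run_id = "abcd1234"
step = 100

# Fetch the logged checkpoint artifact (entity/project assumed from the sample usage).
artifact = wandb.Api().artifact(
    f"geekyrakshit/guardrails-genie/{run_id}-model:step-{step}", type="model"
)
checkpoint_dir = artifact.download()

# wandb.log_model stores the safetensors file under its original basename.
state_dict = load_file(f"{checkpoint_dir}/model-{step}.safetensors")

# The weights can then be loaded into a model whose classification head was
# resized to num_classes, mirroring the nn.Linear swap in train() above.
print(sorted(state_dict.keys())[:5])
```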
mkdocs.yml CHANGED
@@ -80,6 +80,7 @@ nav:
     - RegexModel: 'regex_model.md'
     - Training:
       - Train Classifier: 'train/train_classifier.md'
+      - Train Llama Guard: 'train/train_llama_guard.md'
     - Utils: 'utils.md'

 repo_url: https://github.com/soumik12345/guardrails-genie