import os

import torch
from unsloth import FastLanguageModel, is_bfloat16_supported
from trl import SFTTrainer
from transformers import TrainingArguments
from datasets import load_dataset
from huggingface_hub import HfApi

# Training configuration.
max_seq_length = 4096   # maximum sequence length used for training examples
dtype = None            # None lets Unsloth auto-detect (bfloat16 if supported, else float16)
load_in_4bit = True     # load the base weights 4-bit quantized to reduce VRAM usage

hf_token = os.getenv("HF_TOKEN")
current_num = os.getenv("NUM")  # stage counter stored as a Space variable

print(f"stage {current_num}")

api = HfApi(token=hf_token)
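
# The model download and the final hub push both need a valid token, so fail
# fast here if it is missing (assumption: HF_TOKEN is always required in this workflow).
if hf_token is None:
    raise RuntimeError("HF_TOKEN environment variable is not set.")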

# 4-bit pre-quantized Gemma-2 27B base checkpoint published by Unsloth.
model_base = "unsloth/gemma-2-27b-bnb-4bit"

print("Starting model and tokenizer loading...")

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_base,
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
    token=hf_token,
)

print("Model and tokenizer loaded successfully.")

print("Configuring PEFT model...")
model = FastLanguageModel.get_peft_model(
    model,
    r=16,  # LoRA rank
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing="unsloth",  # Unsloth's memory-efficient checkpointing
    random_state=3407,
    use_rslora=False,
    loftq_config=None,
)
print("PEFT model configured.")

# Prompt templates, keyed by the instruction type returned by detect_prompt_type().
alpaca_prompt = {
    "learning_from": """Below is a CVE definition.

### CVE definition:
{}

### CVE details:
{}""",
    "definition": """Below is a definition of a software vulnerability. Explain it.

### Definition:
{}

### Explanation:
{}""",
    "code_vulnerability": """Below is a code snippet. Identify the vulnerable line of code and describe the type of software vulnerability.

### Code Snippet:
{}

### Vulnerability description:
{}""",
}

# Every training example ends with the EOS token so the model learns to stop.
EOS_TOKEN = tokenizer.eos_token

def detect_prompt_type(instruction):
    """Map an instruction to a template key. Order matters: the most specific
    prefix ("what is code vulnerable of this code:") must be checked before
    the generic "what is" prefix."""
    if instruction.startswith("what is code vulnerable of this code:"):
        return "code_vulnerability"
    elif instruction.startswith("Learning from"):
        return "learning_from"
    elif instruction.startswith("what is"):
        return "definition"
    else:
        return "unknown"

def formatting_prompts_func(examples):
    """Format a batch of examples into prompt strings for SFT."""
    instructions = examples["instruction"]
    outputs = examples["output"]
    texts = []

    for instruction, output in zip(instructions, outputs):
        prompt_type = detect_prompt_type(instruction)
        if prompt_type in alpaca_prompt:
            prompt = alpaca_prompt[prompt_type].format(instruction, output)
        else:
            # Unknown instruction types fall back to plain concatenation.
            prompt = instruction + "\n\n" + output
        texts.append(prompt + EOS_TOKEN)

    return {"text": texts}

print("Loading dataset...")
dataset = load_dataset("dad1909/DCSV", split="train")
print("Dataset loaded successfully.")

print("Applying formatting function to the dataset...")
dataset = dataset.map(formatting_prompts_func, batched=True)
print("Formatting function applied.")

print("Initializing trainer...")
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    dataset_num_proc=2,
    packing=False,
    args=TrainingArguments(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=1,
        learning_rate=2e-4,
        fp16=not is_bfloat16_supported(),  # fall back to fp16 on GPUs without bfloat16
        bf16=is_bfloat16_supported(),
        warmup_steps=5,
        logging_steps=10,
        max_steps=100,
        optim="adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type="linear",
        seed=3407,
        output_dir="outputs",
    ),
)
print("Trainer initialized.")

print("Starting training...")
trainer_stats = trainer.train()
print("Training completed.")

# Advance the stage counter (assumption: default to stage 0 if NUM was unset).
num = int(current_num or 0) + 1

uploads_models = f"cybersentinal-2.0-{num}"  # versioned model name (currently unused)

up = "sentinal-3.1-70B"  # hub repo name the merged model is pushed to

print("Saving the trained model...")
model.save_pretrained_merged("model", tokenizer, save_method="merged_16bit")
print("Model saved successfully.")

print("Pushing the model to the hub...")
model.push_to_hub_merged(
    up,
    tokenizer,
    save_method="merged_16bit",
    token=hf_token,
)
print("Model pushed to hub successfully.")

# Persist the incremented stage counter back to the Space. add_space_variable
# creates or updates the variable, so the preceding delete just forces a clean
# overwrite.
api.delete_space_variable(repo_id="dad1909/CyberCode", key="NUM")
api.add_space_variable(repo_id="dad1909/CyberCode", key="NUM", value=str(num))
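
# A later stage would pick up from the checkpoint pushed above, reloading it
# the same way the base model was loaded (a sketch; the "dad1909/" namespace
# is an assumption based on the Space repo id):
# model, tokenizer = FastLanguageModel.from_pretrained(
#     model_name=f"dad1909/{up}",
#     max_seq_length=max_seq_length,
#     dtype=dtype,
#     load_in_4bit=load_in_4bit,
#     token=hf_token,
# )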