from enum import Enum
import argparse
import dataclasses
from dataclasses import dataclass, field
from typing import Optional
import json

import numpy as np
from transformers import HfArgumentParser, TrainingArguments

from tasks.utils import *


@dataclass
class WatermarkTrainingArguments(TrainingArguments):
    removal: bool = field(
        default=False,
        metadata={
            "help": "Whether to perform watermark removal"
        }
    )
    max_steps: int = field(
        default=0,
        metadata={
            "help": "If > 0: total number of training steps to perform (overrides num_train_epochs)"
        }
    )
    trigger_num: int = field(
        metadata={
            "help": "Number of trigger tokens for task: " + ", ".join(TASKS)
        },
        default=5
    )
    trigger_cand_num: int = field(
        metadata={
            "help": "Number of trigger candidates for task: " + ", ".join(TASKS)
        },
        default=40
    )
    trigger_pos: str = field(
        metadata={
            "help": "Position of the trigger for task: " + ", ".join(TASKS)
        },
        default="prefix"
    )
    trigger: str = field(
        metadata={
            "help": "Initial trigger (comma-separated token ids) for task: " + ", ".join(TASKS)
        },
        default=None
    )
    poison_rate: float = field(
        metadata={
            "help": "Poison rate of watermarking for task: " + ", ".join(TASKS)
        },
        default=0.1
    )
    trigger_targeted: int = field(
        metadata={
            "help": "Whether to use a targeted trigger (1) or not (0) for task: " + ", ".join(TASKS)
        },
        default=0
    )
    trigger_acc_steps: int = field(
        metadata={
            "help": "Gradient accumulation steps for trigger search for task: " + ", ".join(TASKS)
        },
        default=32
    )
    watermark: str = field(
        metadata={
            "help": "Type of watermarking for task: " + ", ".join(TASKS)
        },
        default="targeted"
    )
    watermark_steps: int = field(
        metadata={
            "help": "Steps at which to conduct watermarking for task: " + ", ".join(TASKS)
        },
        default=200
    )
    warm_steps: int = field(
        metadata={
            "help": "Warmup steps of clean training for task: " + ", ".join(TASKS)
        },
        default=1000
    )
    clean_labels: str = field(
        metadata={
            "help": "Clean label tokens (JSON dict) for watermarking for task: " + ", ".join(TASKS)
        },
        default=None
    )
    target_labels: str = field(
        metadata={
            "help": "Target label tokens (JSON dict) for watermarking for task: " + ", ".join(TASKS)
        },
        default=None
    )
    deepseed: bool = field(
        metadata={
            "help": "Whether to use DeepSpeed for task: " + ", ".join(TASKS)
        },
        default=False
    )
    use_checkpoint: str = field(
        metadata={
            "help": "Path to a checkpoint to load for task: " + ", ".join(TASKS)
        },
        default=None
    )
    use_checkpoint_ori: str = field(
        metadata={
            "help": "Path to the original (clean) checkpoint to load for task: " + ", ".join(TASKS)
        },
        default=None
    )
    use_checkpoint_tag: str = field(
        metadata={
            "help": "Tag of the checkpoint to load for task: " + ", ".join(TASKS)
        },
        default=None
    )
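
# Hypothetical example of the watermark-related flags defined above; the values
# and the exact schema of the label dicts are illustrative assumptions, not
# defaults of this repository:
#
#   --watermark targeted --poison_rate 0.1 --trigger_num 5 --trigger_cand_num 40 \
#   --trigger_pos prefix --trigger "11,22,33,44,55" \
#   --clean_labels '{"0": [1111], "1": [2222]}' \
#   --target_labels '{"0": [2222], "1": [1111]}'
#
# `--trigger` is a comma-separated list of token ids; `--clean_labels` and
# `--target_labels` are JSON dicts whose per-class token-id lists are collected
# in get_args() below.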


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Using `HfArgumentParser` we can turn this class into argparse arguments that can be
    specified on the command line (a minimal sketch is given after this class).
    """
    task_name: str = field(
        metadata={
            "help": "The name of the task to train on: " + ", ".join(TASKS),
            "choices": TASKS
        }
    )
    dataset_name: str = field(
        metadata={
            "help": "The name of the dataset to use: " + ", ".join(DATASETS),
            "choices": DATASETS
        }
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    max_seq_length: int = field(
        default=128,
        metadata={
            "help": "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    overwrite_cache: bool = field(
        default=True, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
    )
    pad_to_max_length: bool = field(
        default=True,
        metadata={
            "help": "Whether to pad all samples to `max_seq_length`. "
            "If False, will pad the samples dynamically when batching to the maximum length in the batch."
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
            "value if set."
        },
    )
    max_predict_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
            "value if set."
        },
    )
    train_file: Optional[str] = field(
        default=None, metadata={"help": "A csv or a json file containing the training data."}
    )
    validation_file: Optional[str] = field(
        default=None, metadata={"help": "A csv or a json file containing the validation data."}
    )
    test_file: Optional[str] = field(
        default=None,
        metadata={"help": "A csv or a json file containing the test data."}
    )
    template_id: Optional[int] = field(
        default=0,
        metadata={
            "help": "The id of the prompt template to use"
        }
    )
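
# A minimal sketch of how `HfArgumentParser` maps the dataclass fields above onto
# command-line flags (illustrative only; assumes "glue" and "sst2" appear in the
# TASKS and DATASETS lists from tasks.utils):
#
#   parser = HfArgumentParser(DataTrainingArguments)
#   (data_args,) = parser.parse_args_into_dataclasses(
#       ["--task_name", "glue", "--dataset_name", "sst2", "--max_seq_length", "128"]
#   )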


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """
    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    model_name_or_path_ori: str = field(
        default=None, metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizers (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
            "with private models)."
        },
    )
    checkpoint: str = field(
        metadata={"help": "Path to a checkpoint to load"},
        default=None
    )
    autoprompt: bool = field(
        default=False,
        metadata={
            "help": "Will use autoprompt during training"
        }
    )
    prefix: bool = field(
        default=False,
        metadata={
            "help": "Will use P-tuning v2 during training"
        }
    )
    prompt_type: str = field(
        default="p-tuning-v2",
        metadata={
            "help": "The type of prompt tuning to use (e.g. p-tuning-v2)"
        }
    )
    prompt: bool = field(
        default=False,
        metadata={
            "help": "Will use prompt tuning during training"
        }
    )
    pre_seq_len: int = field(
        default=4,
        metadata={
            "help": "The length of prompt"
        }
    )
    prefix_projection: bool = field(
        default=False,
        metadata={
            "help": "Apply a two-layer MLP head over the prefix embeddings"
        }
    )
    prefix_hidden_size: int = field(
        default=512,
        metadata={
            "help": "The hidden size of the MLP projection head in Prefix Encoder if prefix projection is used"
        }
    )
    hidden_dropout_prob: float = field(
        default=0.1,
        metadata={
            "help": "The dropout probability used in the models"
        }
    )


@dataclass
class QuestionAnwseringArguments:
    n_best_size: int = field(
        default=20,
        metadata={"help": "The total number of n-best predictions to generate when looking for an answer."},
    )
    max_answer_length: int = field(
        default=30,
        metadata={
            "help": "The maximum length of an answer that can be generated. This is needed because the start "
            "and end predictions are not conditioned on one another."
        },
    )
    version_2_with_negative: bool = field(
        default=False, metadata={"help": "If true, some of the examples do not have an answer."}
    )
    null_score_diff_threshold: float = field(
        default=0.0,
        metadata={
            "help": "The threshold used to select the null answer: if the best answer has a score that is less than "
            "the score of the null answer minus this threshold, the null answer is selected for this example. "
            "Only useful when `version_2_with_negative=True`."
        },
    )


def get_args():
    """Parse all the args."""
    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, WatermarkTrainingArguments, QuestionAnwseringArguments)
    )
    model_args, data_args, training_args, qa_args = parser.parse_args_into_dataclasses()

    # A "clean" run trains without poisoned samples.
    if training_args.watermark == "clean":
        training_args.poison_rate = 0.0

    # Parse the trigger as comma-separated token ids, or sample random ids from
    # the first 20000 vocabulary entries if no trigger was given.
    if training_args.trigger is not None:
        raw_trigger = training_args.trigger.replace(" ", "").split(",")
        trigger = [int(x) for x in raw_trigger]
    else:
        trigger = np.random.choice(20000, training_args.trigger_num, replace=False).tolist()
    model_args.trigger = [trigger]
    training_args.trigger = [trigger]
    training_args.trigger_num = len(trigger)

    # clean_labels / target_labels are JSON dicts; keep only the per-class token-id lists.
    clean_label_ids = list(json.loads(str(training_args.clean_labels)).values())
    model_args.clean_labels = clean_label_ids
    training_args.clean_labels = clean_label_ids
    training_args.dataset_name = data_args.dataset_name

    target_label_ids = list(json.loads(str(training_args.target_labels)).values())
    model_args.target_labels = target_label_ids
    training_args.target_labels = target_label_ids
    training_args.label_names = ["labels"]

    print(f"-> clean label:{training_args.clean_labels}\n-> target label:{training_args.target_labels}")
    return model_args, data_args, training_args, qa_args
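

# A minimal, hypothetical usage sketch: call get_args() and unpack the four
# dataclasses. It assumes the script is invoked with the required flags
# (e.g. --model_name_or_path, --task_name, --dataset_name, --output_dir) and
# with --clean_labels / --target_labels given as JSON dicts, since get_args()
# json-decodes both of them.
if __name__ == "__main__":
    model_args, data_args, training_args, qa_args = get_args()
    print(f"-> model: {model_args.model_name_or_path}")
    print(f"-> task/dataset: {data_args.task_name}/{data_args.dataset_name}")
    print(f"-> watermark: {training_args.watermark}, poison rate: {training_args.poison_rate}")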