|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import tempfile |
|
from collections import OrderedDict |
|
|
|
import torch |
|
|
|
from peft import ( |
|
LoraConfig, |
|
PeftModel, |
|
PrefixTuningConfig, |
|
PromptEncoderConfig, |
|
PromptTuningConfig, |
|
get_peft_model, |
|
get_peft_model_state_dict, |
|
prepare_model_for_int8_training, |
|
) |
|
|
|
|
|
# Config classes under test. Order matters: entry `i` is paired with entry `i` of `CONFIG_TESTING_KWARGS`
# when `CLASSES_MAPPING` is assembled below.
CONFIG_CLASSES = (LoraConfig, PrefixTuningConfig, PromptEncoderConfig, PromptTuningConfig)

# Default constructor kwargs, one dict per entry of `CONFIG_CLASSES` (same order).
CONFIG_TESTING_KWARGS = (
    # LoraConfig
    {"r": 8, "lora_alpha": 32, "target_modules": None, "lora_dropout": 0.05, "bias": "none"},
    # PrefixTuningConfig
    {"num_virtual_tokens": 10},
    # PromptEncoderConfig
    {"num_virtual_tokens": 10, "encoder_hidden_size": 32},
    # PromptTuningConfig
    {"num_virtual_tokens": 10},
)

# Short method name -> (config class, default kwargs). The kwargs dicts are the very same objects stored in
# `CONFIG_TESTING_KWARGS`, not copies.
CLASSES_MAPPING = dict(
    zip(
        ("lora", "prefix_tuning", "prompt_encoder", "prompt_tuning"),
        zip(CONFIG_CLASSES, CONFIG_TESTING_KWARGS),
    )
)
|
|
|
|
|
|
|
class ClassInstantier(OrderedDict):
    r"""
    Ordered mapping from a short PEFT method name to a `(config_class, default_kwargs)` tuple, with a helper that
    expands a parameter grid into a flat list of test cases.
    """

    def __getitem__(self, key, *args, **kwargs):
        r"""
        Returns the `(config_class, kwargs)` entry stored under `key`.

        When at least one keyword argument names a parameter that is already present in the stored default kwargs,
        a *copy* of the defaults updated with all the passed keyword arguments is returned; the stored entry is left
        untouched. In every other case the call is forwarded unchanged to `OrderedDict.__getitem__`.

        Raises:
            `KeyError`: if `key` is not in the mapping.
        """
        # The plain `self[key]` lookup re-enters this method without kwargs and therefore lands in the
        # `super()` branch below, so there is no infinite recursion.
        if kwargs:
            entry = self[key]
            default_kwargs = entry[1]
            if any(kwarg in default_kwargs for kwarg in kwargs):
                new_config_kwargs = default_kwargs.copy()
                new_config_kwargs.update(kwargs)
                return (entry[0], new_config_kwargs)

        return super().__getitem__(key, *args, **kwargs)

    def get_grid_parameters(self, grid_parameters, filter_params_func=None):
        r"""
        Expands `grid_parameters` into a list of test cases, one per model id and generated config.

        Note that the expansion is a one-at-a-time sweep, not a cartesian product: for a given config class the
        default kwargs are copied once, then every value of every listed parameter is written into that running
        copy in turn and a snapshot is emitted after each write. Values of a later parameter are therefore combined
        with the *last* value of the earlier ones.

        Args:
            grid_parameters (`dict`):
                A dictionary containing the parameters to be tested. The key "model_ids" is required and contains
                the list of model ids to be tested. The optional key "task_type" is added to every generated
                config when its value is not `None`. Any other key should be the name of a config class entry
                post-fixed with "_kwargs", and its value a dictionary mapping a parameter name to the list of
                values to be tested for that config class. Entries without such a key are tested once with
                their default kwargs.
            filter_params_func (`callable`, `optional`):
                A function that takes a list of tuples and returns a list of tuples. This function is used to
                filter out the tests that need, for example, to be skipped.

        Returns:
            generated_tests (`list`):
                A list of tuples containing the name of the test, the model id, the config class and the config
                class kwargs. All tests generated for the same model id and config entry share the same name.
        """
        generated_tests = []
        model_list = grid_parameters["model_ids"]
        task_type = grid_parameters.get("task_type")

        for model_id in model_list:
            for key, value in self.items():
                config_cls, default_kwargs = value[0], value[1]
                grid_key = f"{key}_kwargs"

                if grid_key in grid_parameters:
                    peft_configs = []
                    current_peft_config = default_kwargs.copy()
                    for current_key, current_values in grid_parameters[grid_key].items():
                        for kwarg in current_values:
                            current_peft_config[current_key] = kwarg

                            # Written after the swept parameter so that the grid-level task type always wins.
                            if task_type is not None:
                                current_peft_config["task_type"] = task_type

                            # Snapshot: the running config keeps being mutated by the next iterations.
                            peft_configs.append(current_peft_config.copy())
                else:
                    current_peft_config = default_kwargs.copy()
                    if task_type is not None:
                        current_peft_config["task_type"] = task_type
                    peft_configs = [current_peft_config]

                test_name = f"test_{model_id}_{key}"
                generated_tests.extend((test_name, model_id, config_cls, peft_config) for peft_config in peft_configs)

        if filter_params_func is not None:
            generated_tests = filter_params_func(generated_tests)

        return generated_tests
|
|
|
|
|
# Module-level registry instance built from `CLASSES_MAPPING`; exposes `get_grid_parameters` for the test modules.
PeftTestConfigManager = ClassInstantier(CLASSES_MAPPING)
|
|
|
|
|
class PeftCommonTester:
    r"""
    A large testing suite for testing common functionality of the PEFT models.

    The `_test_*` helpers call `unittest.TestCase`-style assertion methods (`assertTrue`, `assertRaises`, ...) on
    `self` without defining them, so this class has to be combined with a class providing them. Concrete test
    classes must also set `transformers_class` and override `prepare_inputs_for_testing`.

    Attributes:
        torch_device (`str`):
            The device on which the tests will be run: "cuda" when available, "cpu" otherwise.
        transformers_class (`transformers.PreTrainedModel`):
            The transformers class that is being tested.
    """
    torch_device = "cuda" if torch.cuda.is_available() else "cpu"
    transformers_class = None

    def prepare_inputs_for_testing(self):
        r"""
        Returns the keyword inputs fed to the model in the `_test_*` helpers.

        The helpers unpack the result with `**` into `model(...)` / `model.generate(...)` and read its "input_ids"
        entry, so an override must return a mapping containing at least that key.

        Raises:
            `NotImplementedError`: always, unless overridden by the concrete test class.
        """
        raise NotImplementedError(f"{type(self).__name__} must implement `prepare_inputs_for_testing`.")

    def prepare_inputs_for_common(self):
        r"""
        Backward-compatible alias of `prepare_inputs_for_testing`, kept because the hook was historically declared
        under this name while every helper calls `prepare_inputs_for_testing`.
        """
        return self.prepare_inputs_for_testing()

    def _test_model_attr(self, model_id, config_cls, config_kwargs):
        r"""
        Checks that the PEFT-wrapped model exposes `save_pretrained`, `from_pretrained` and `push_to_hub`.
        """
        model = self.transformers_class.from_pretrained(model_id)
        config = config_cls(
            base_model_name_or_path=model_id,
            **config_kwargs,
        )
        model = get_peft_model(model, config)

        for attribute in ("save_pretrained", "from_pretrained", "push_to_hub"):
            self.assertTrue(hasattr(model, attribute))

    def _test_prepare_for_training(self, model_id, config_cls, config_kwargs):
        r"""
        Checks the `requires_grad` state of the input embeddings output, first on a plain PEFT model and then on a
        model passed through `prepare_model_for_int8_training` with input gradients enabled.
        """
        model = self.transformers_class.from_pretrained(model_id).to(self.torch_device)
        config = config_cls(
            base_model_name_or_path=model_id,
            **config_kwargs,
        )
        model = get_peft_model(model, config)

        dummy_input = self.prepare_inputs_for_testing()
        dummy_output = model.get_input_embeddings()(dummy_input["input_ids"])

        # Without any preparation the embedding output must not require gradients.
        self.assertFalse(dummy_output.requires_grad)

        # Reload the base model and run it through `prepare_model_for_int8_training`.
        model = self.transformers_class.from_pretrained(model_id).to(self.torch_device)
        model = prepare_model_for_int8_training(model)

        for param in model.parameters():
            self.assertFalse(param.requires_grad)

        config = config_cls(
            base_model_name_or_path=model_id,
            **config_kwargs,
        )
        model = get_peft_model(model, config)

        # Fall back to a forward hook when the model does not provide `enable_input_require_grads`.
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        else:

            def make_inputs_require_grad(module, inputs, output):
                output.requires_grad_(True)

            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

        dummy_input = self.prepare_inputs_for_testing()
        dummy_output = model.get_input_embeddings()(dummy_input["input_ids"])

        self.assertTrue(dummy_output.requires_grad)

    def _test_save_pretrained(self, model_id, config_cls, config_kwargs):
        r"""
        Saves the PEFT model to a temporary directory, reloads it with `PeftModel.from_pretrained` and checks that
        the adapter state dict round-trips and that only the adapter files are written.
        """
        model = self.transformers_class.from_pretrained(model_id)
        config = config_cls(
            base_model_name_or_path=model_id,
            **config_kwargs,
        )
        model = get_peft_model(model, config)
        model = model.to(self.torch_device)

        with tempfile.TemporaryDirectory() as tmp_dirname:
            model.save_pretrained(tmp_dirname)

            model_from_pretrained = self.transformers_class.from_pretrained(model_id)
            model_from_pretrained = PeftModel.from_pretrained(model_from_pretrained, tmp_dirname)

            # Compare the adapter state dicts of the original and the reloaded model.
            state_dict = get_peft_model_state_dict(model)
            state_dict_from_pretrained = get_peft_model_state_dict(model_from_pretrained)

            # Same keys ...
            self.assertEqual(state_dict.keys(), state_dict_from_pretrained.keys())

            # ... and same tensors (both moved to the test device before comparing).
            for key, tensor in state_dict.items():
                self.assertTrue(
                    torch.allclose(
                        tensor.to(self.torch_device), state_dict_from_pretrained[key].to(self.torch_device)
                    )
                )

            # The adapter weights and config must have been written ...
            self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_model.bin")))
            self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_config.json")))

            # ... while the files named like the base model weights and config must be absent.
            self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "pytorch_model.bin")))
            self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "config.json")))

    def _test_merge_layers(self, model_id, config_cls, config_kwargs):
        r"""
        Checks `merge_and_unload`: it must raise `AttributeError` for non-LoRA configs and `ValueError` for gpt2
        models; otherwise the merged model must reproduce the LoRA logits, differ from the base model, and survive
        a save / reload round trip.
        """
        model = self.transformers_class.from_pretrained(model_id)
        config = config_cls(
            base_model_name_or_path=model_id,
            **config_kwargs,
        )
        model = get_peft_model(model, config)
        model = model.to(self.torch_device)

        if config.peft_type != "LORA":
            with self.assertRaises(AttributeError):
                model = model.merge_and_unload()
        elif model.config.model_type == "gpt2":
            with self.assertRaises(ValueError):
                model = model.merge_and_unload()
        else:
            dummy_input = self.prepare_inputs_for_testing()
            model.eval()
            logits_lora = model(**dummy_input)[0]

            model = model.merge_and_unload()

            logits_merged = model(**dummy_input)[0]

            transformers_model = self.transformers_class.from_pretrained(model_id).to(self.torch_device)

            logits_transformers = transformers_model(**dummy_input)[0]

            # Merging must not change the outputs, but the merged model must differ from the base model.
            self.assertTrue(torch.allclose(logits_lora, logits_merged, atol=1e-4, rtol=1e-4))
            self.assertFalse(torch.allclose(logits_merged, logits_transformers, atol=1e-10, rtol=1e-10))

            # The merged model is saved and reloaded through the plain transformers class.
            with tempfile.TemporaryDirectory() as tmp_dirname:
                model.save_pretrained(tmp_dirname)

                model_from_pretrained = self.transformers_class.from_pretrained(tmp_dirname).to(self.torch_device)

                logits_merged_from_pretrained = model_from_pretrained(**dummy_input)[0]

                self.assertTrue(torch.allclose(logits_merged, logits_merged_from_pretrained, atol=1e-4, rtol=1e-4))

    def _test_generate(self, model_id, config_cls, config_kwargs):
        r"""
        Checks that `generate` works with keyword inputs and raises `TypeError` when given a positional input.
        """
        model = self.transformers_class.from_pretrained(model_id)
        config = config_cls(
            base_model_name_or_path=model_id,
            **config_kwargs,
        )
        model = get_peft_model(model, config)
        model = model.to(self.torch_device)

        inputs = self.prepare_inputs_for_testing()

        # Keyword-only call must succeed.
        _ = model.generate(**inputs)

        with self.assertRaises(TypeError):
            # Positional input is expected to be rejected.
            _ = model.generate(inputs["input_ids"])

    def _test_generate_half_prec(self, model_id, config_cls, config_kwargs):
        r"""
        Same check as `_test_generate` with the base model loaded in `torch.bfloat16` and hard-coded inputs.
        Returns immediately for config classes other than `LoraConfig` and `PrefixTuningConfig`.
        """
        if config_cls not in (LoraConfig, PrefixTuningConfig):
            return

        model = self.transformers_class.from_pretrained(model_id, torch_dtype=torch.bfloat16)
        config = config_cls(
            base_model_name_or_path=model_id,
            **config_kwargs,
        )
        model = get_peft_model(model, config)
        model = model.to(self.torch_device)

        input_ids = torch.LongTensor([[1, 1, 1], [2, 1, 2]]).to(self.torch_device)
        attention_mask = torch.LongTensor([[1, 1, 1], [1, 0, 1]]).to(self.torch_device)

        # Keyword-only call must succeed.
        _ = model.generate(input_ids=input_ids, attention_mask=attention_mask)

        with self.assertRaises(TypeError):
            # Positional input is expected to be rejected.
            _ = model.generate(input_ids, attention_mask=attention_mask)

    def _test_training(self, model_id, config_cls, config_kwargs):
        r"""
        Runs one forward / backward pass and checks that exactly the parameters whose name contains "lora" received
        a gradient. Returns immediately for config classes other than `LoraConfig`.
        """
        if config_cls not in (LoraConfig,):
            return

        model = self.transformers_class.from_pretrained(model_id)
        config = config_cls(
            base_model_name_or_path=model_id,
            **config_kwargs,
        )
        model = get_peft_model(model, config)
        model = model.to(self.torch_device)

        inputs = self.prepare_inputs_for_testing()

        # Use the sum of the first model output as a dummy scalar loss.
        output = model(**inputs)[0]
        loss = output.sum()
        loss.backward()

        for name, param in model.named_parameters():
            if "lora" in name:
                self.assertIsNotNone(param.grad)
            else:
                self.assertIsNone(param.grad)
|
|