Spaces:
Running
on
L40S
Running
on
L40S
from peft.tuners.tuners_utils import BaseTunerLayer | |
from typing import List, Any, Optional, Type | |
class enable_lora:
    """Context manager that temporarily silences LoRA layers.

    When ``activated`` is true the manager is a no-op and the LoRA layers are
    left exactly as they are.  When ``activated`` is false, every
    ``BaseTunerLayer`` in ``lora_modules`` has ``scale_layer(0)`` called on it
    on entry, and the scaling values captured at construction time are written
    back on exit.

    NOTE(review): this relies on ``scale_layer`` multiplying the per-adapter
    scaling factor (so ``0`` zeroes the LoRA contribution) -- confirm against
    the installed PEFT version.

    Args:
        lora_modules: Candidate modules; entries that are not
            ``BaseTunerLayer`` instances are ignored.
        activated: ``True`` to keep LoRA enabled (no-op), ``False`` to disable
            it for the duration of the ``with`` block.
    """

    def __init__(self, lora_modules: List[BaseTunerLayer], activated: bool) -> None:
        self.activated: bool = activated
        # Always define both attributes so the instance is well-formed even in
        # the no-op case; ``lora_modules`` is deliberately not touched then.
        self.lora_modules: List[BaseTunerLayer] = []
        self.scales: List[dict] = []
        if activated:
            return
        self.lora_modules = [
            each for each in lora_modules if isinstance(each, BaseTunerLayer)
        ]
        # Snapshot the current scaling of every active adapter.  An adapter
        # can be active without being registered on a given layer, so skip
        # names that have no scaling entry instead of raising KeyError.
        self.scales = [
            {
                adapter: lora_module.scaling[adapter]
                for adapter in lora_module.active_adapters
                if adapter in lora_module.scaling
            }
            for lora_module in self.lora_modules
        ]

    def __enter__(self) -> None:
        """Zero the LoRA scaling of every tracked layer (unless activated)."""
        if self.activated:
            return
        # ``self.lora_modules`` was already filtered to BaseTunerLayer
        # instances in ``__init__``; no further type check is needed.
        for lora_module in self.lora_modules:
            lora_module.scale_layer(0)

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[Any],
    ) -> None:
        """Restore the scaling values captured in ``__init__``.

        Returns ``None``, so an exception raised inside the ``with`` block is
        propagated after the scales have been restored.
        """
        if self.activated:
            return
        # Restore from the snapshot rather than from the adapters that happen
        # to be active now: the active set may have changed inside the block,
        # which would otherwise raise KeyError or leave adapters zeroed.
        for lora_module, saved in zip(self.lora_modules, self.scales):
            for adapter, value in saved.items():
                # Do not resurrect an adapter that was removed in the meantime.
                if adapter in lora_module.scaling:
                    lora_module.scaling[adapter] = value
class set_lora_scale:
    """Context manager that temporarily rescales LoRA layers.

    On entry every ``BaseTunerLayer`` in ``lora_modules`` has
    ``scale_layer(scale)`` called on it; on exit the scaling values captured
    at construction time are written back.

    NOTE(review): ``scale_layer`` is expected to *multiply* the current
    per-adapter scaling by ``scale`` rather than assign it -- confirm against
    the installed PEFT version.

    Args:
        lora_modules: Candidate modules; entries that are not
            ``BaseTunerLayer`` instances are ignored.
        scale: Value forwarded to ``scale_layer`` on entry.
    """

    def __init__(self, lora_modules: List[BaseTunerLayer], scale: float) -> None:
        self.lora_modules: List[BaseTunerLayer] = [
            each for each in lora_modules if isinstance(each, BaseTunerLayer)
        ]
        # Snapshot the current scaling of every active adapter.  An adapter
        # can be active without being registered on a given layer, so skip
        # names that have no scaling entry instead of raising KeyError.
        self.scales: List[dict] = [
            {
                adapter: lora_module.scaling[adapter]
                for adapter in lora_module.active_adapters
                if adapter in lora_module.scaling
            }
            for lora_module in self.lora_modules
        ]
        self.scale = scale

    def __enter__(self) -> None:
        """Apply ``self.scale`` to every tracked layer via ``scale_layer``."""
        # ``self.lora_modules`` was already filtered to BaseTunerLayer
        # instances in ``__init__``; no further type check is needed.
        for lora_module in self.lora_modules:
            lora_module.scale_layer(self.scale)

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[Any],
    ) -> None:
        """Restore the scaling values captured in ``__init__``.

        Returns ``None``, so an exception raised inside the ``with`` block is
        propagated after the scales have been restored.
        """
        # Restore from the snapshot rather than from the adapters that happen
        # to be active now: the active set may have changed inside the block,
        # which would otherwise raise KeyError or leave adapters rescaled.
        for lora_module, saved in zip(self.lora_modules, self.scales):
            for adapter, value in saved.items():
                # Do not resurrect an adapter that was removed in the meantime.
                if adapter in lora_module.scaling:
                    lora_module.scaling[adapter] = value