# OminiControl / src/lora_controller.py
# Author: Yuanshi — commit 6ed1db6 ("add all")
from typing import Any, Dict, Iterable, List, Optional, Type

from peft.tuners.tuners_utils import BaseTunerLayer
class enable_lora:
    """Context manager that temporarily switches LoRA layers off unless ``activated``.

    When ``activated`` is False, entering the context records the current
    per-adapter ``scaling`` of every ``BaseTunerLayer`` in ``lora_modules`` and
    calls ``scale_layer(0)`` on each. Leaving the context writes the recorded
    values back. When ``activated`` is True the context manager does nothing.

    Args:
        lora_modules: Modules to control. Entries that are not
            ``BaseTunerLayer`` instances are ignored, so any iterable of
            modules may be passed.
        activated: If True, leave the LoRA layers untouched.

    NOTE(review): the layers are expected to expose the LoRA-specific
    ``scaling`` dict and ``scale_layer`` method (peft ``LoraLayer``); other
    ``BaseTunerLayer`` subclasses may lack them — confirm for non-LoRA tuners.
    """

    def __init__(
        self, lora_modules: Iterable[BaseTunerLayer], activated: bool
    ) -> None:
        self.activated: bool = activated
        # When activated nothing will be touched, so don't even consume the
        # iterable; otherwise keep only the peft tuner layers.
        self.lora_modules: List[BaseTunerLayer] = (
            []
            if activated
            else [each for each in lora_modules if isinstance(each, BaseTunerLayer)]
        )
        # Per-module {adapter_name: scaling} snapshots, captured on __enter__.
        self.scales: List[Dict[str, float]] = []

    def _snapshot(self) -> List[Dict[str, float]]:
        """Return the current scaling of each active adapter, per module."""
        return [
            {
                adapter: module.scaling[adapter]
                for adapter in module.active_adapters
                # An adapter that is active model-wide may not be attached to
                # this particular layer; skip it instead of raising KeyError.
                if adapter in module.scaling
            }
            for module in self.lora_modules
        ]

    def _restore(self) -> None:
        """Write the values captured by ``_snapshot`` back onto each module."""
        # Iterate the saved entries (not the current active_adapters) so the
        # restore is exact even if the active adapter set changed meanwhile.
        for module, saved in zip(self.lora_modules, self.scales):
            for adapter, value in saved.items():
                module.scaling[adapter] = value

    def __enter__(self) -> None:
        if self.activated:
            return
        # Snapshot at entry (not construction) so we restore the values that
        # were actually in effect right before we zeroed them.
        self.scales = self._snapshot()
        try:
            for module in self.lora_modules:
                module.scale_layer(0)
        except BaseException:
            # __exit__ is not invoked when __enter__ raises, so undo any
            # layers already zeroed before propagating the error.
            self._restore()
            raise

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[Any],
    ) -> None:
        if self.activated:
            return
        # Returning None lets any in-block exception propagate.
        self._restore()
class set_lora_scale:
    """Context manager that rescales LoRA layers for the duration of a block.

    Entering the context records the current per-adapter ``scaling`` of every
    ``BaseTunerLayer`` in ``lora_modules`` and calls ``scale_layer(scale)`` on
    each. Leaving the context writes the recorded values back.

    Args:
        lora_modules: Modules to control. Entries that are not
            ``BaseTunerLayer`` instances are ignored, so any iterable of
            modules may be passed.
        scale: Value passed to each layer's ``scale_layer``.

    NOTE(review): in peft's ``LoraLayer``, ``scale_layer`` multiplies the
    existing scaling by ``scale`` (and is a no-op for ``scale == 1``) rather
    than assigning it — confirm for the pinned peft version. The layers are
    also expected to expose the LoRA-specific ``scaling``/``scale_layer``
    members; other ``BaseTunerLayer`` subclasses may lack them.
    """

    def __init__(self, lora_modules: Iterable[BaseTunerLayer], scale: float) -> None:
        # Keep only the peft tuner layers; anything else is ignored.
        self.lora_modules: List[BaseTunerLayer] = [
            each for each in lora_modules if isinstance(each, BaseTunerLayer)
        ]
        self.scale = scale
        # Per-module {adapter_name: scaling} snapshots, captured on __enter__.
        self.scales: List[Dict[str, float]] = []

    def _snapshot(self) -> List[Dict[str, float]]:
        """Return the current scaling of each active adapter, per module."""
        return [
            {
                adapter: module.scaling[adapter]
                for adapter in module.active_adapters
                # An adapter that is active model-wide may not be attached to
                # this particular layer; skip it instead of raising KeyError.
                if adapter in module.scaling
            }
            for module in self.lora_modules
        ]

    def _restore(self) -> None:
        """Write the values captured by ``_snapshot`` back onto each module."""
        # Iterate the saved entries (not the current active_adapters) so the
        # restore is exact even if the active adapter set changed meanwhile.
        for module, saved in zip(self.lora_modules, self.scales):
            for adapter, value in saved.items():
                module.scaling[adapter] = value

    def __enter__(self) -> None:
        # Snapshot at entry (not construction) so we restore the values that
        # were actually in effect right before rescaling.
        self.scales = self._snapshot()
        try:
            for module in self.lora_modules:
                module.scale_layer(self.scale)
        except BaseException:
            # __exit__ is not invoked when __enter__ raises, so undo any
            # layers already rescaled before propagating the error.
            self._restore()
            raise

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[Any],
    ) -> None:
        # Returning None lets any in-block exception propagate.
        self._restore()