Spaces:
Running
on
L4
Running
on
L4
from dataclasses import dataclass | |
import loralib as lora | |
class LoraConfig: | |
r: int | |
lora_alpha: float | |
lora_dropout: float = 0.0 | |
def setup_lora(model, lora_config): | |
# Replace the embedding layer with a LoRA layer | |
model.embeddings = lora.Embedding( | |
num_embeddings=model.embeddings.num_embeddings, | |
embedding_dim=model.embeddings.embedding_dim, | |
padding_idx=model.embeddings.padding_idx, | |
r=lora_config.r, | |
lora_alpha=lora_config.lora_alpha, | |
) | |
model.codebook_embeddings = lora.Embedding( | |
num_embeddings=model.codebook_embeddings.num_embeddings, | |
embedding_dim=model.codebook_embeddings.embedding_dim, | |
padding_idx=model.codebook_embeddings.padding_idx, | |
r=lora_config.r, | |
lora_alpha=lora_config.lora_alpha, | |
) | |
# Replace output layer with a LoRA layer | |
linears = [(model, "output")] | |
# Replace all linear layers with LoRA layers | |
for layer in model.layers: | |
linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")]) | |
linears.extend( | |
[ | |
(layer.feed_forward, "w1"), | |
(layer.feed_forward, "w2"), | |
(layer.feed_forward, "w3"), | |
] | |
) | |
if hasattr(model, "fast_layers"): | |
model.fast_embeddings = lora.Embedding( | |
num_embeddings=model.fast_embeddings.num_embeddings, | |
embedding_dim=model.fast_embeddings.embedding_dim, | |
padding_idx=model.fast_embeddings.padding_idx, | |
r=lora_config.r, | |
lora_alpha=lora_config.lora_alpha, | |
) | |
# Dual-AR model | |
linears.append((model, "fast_output")) | |
for layer in model.fast_layers: | |
linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")]) | |
linears.extend( | |
[ | |
(layer.feed_forward, "w1"), | |
(layer.feed_forward, "w2"), | |
(layer.feed_forward, "w3"), | |
] | |
) | |
for module, layer in linears: | |
updated_linear = lora.Linear( | |
in_features=getattr(module, layer).in_features, | |
out_features=getattr(module, layer).out_features, | |
bias=getattr(module, layer).bias, | |
r=lora_config.r, | |
lora_alpha=lora_config.lora_alpha, | |
lora_dropout=lora_config.lora_dropout, | |
) | |
setattr(module, layer, updated_linear) | |
# Mark only the LoRA layers as trainable | |
lora.mark_only_lora_as_trainable(model, bias="none") | |
def get_merged_state_dict(model): | |
# This line will merge the state dict of the model and the LoRA parameters | |
model.eval() | |
# Then we need to remove the LoRA parameters from the state dict | |
state_dict = model.state_dict() | |
for name in list(state_dict.keys()): | |
if "lora" in name: | |
state_dict.pop(name) | |
return state_dict | |