import sys

import torch
from safetensors.torch import save_file


# Strip a checkpoint: keep every entry that precedes `stop_key` (in the
# checkpoint's own iteration order) plus the entries named in `final_keys`,
# then write the result both as safetensors and as a regular torch pickle.

# First key to drop; it and everything after it is removed.
stop_key = 'h.12.ln_1.weight'

# Keys that are kept no matter where they appear relative to `stop_key`.
final_keys = ['ln_f.weight', 'ln_f.bias']


def strip_state_dict(state_dict, stop_key, final_keys):
    """Return a new dict without the entries from *stop_key* onward.

    Entries are visited in ``state_dict``'s iteration order. Everything seen
    before *stop_key* is kept; *stop_key* itself and everything after it is
    dropped, except keys listed in *final_keys*, which are always kept.

    Args:
        state_dict: Mapping of entry name to value. It is not modified.
        stop_key: Name of the first entry to drop.
        final_keys: Iterable of names to keep regardless of position.

    Returns:
        A new ``dict`` holding the kept entries in their original order.
        If *stop_key* never occurs, this is a full shallow copy.
    """
    always_keep = frozenset(final_keys)
    kept = {}
    stopped = False
    for key, value in state_dict.items():
        if key == stop_key:
            # The stop key itself is never kept.
            stopped = True
            continue
        if not stopped or key in always_keep:
            kept[key] = value
    return kept


def main():
    """Load the checkpoint, list its keys, strip it, and save both formats."""
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code. Only run this on checkpoints from a trusted source (newer torch
    # versions offer `weights_only=True` to restrict what is unpickled).
    model = torch.load('pytorch_model.bin', map_location="cpu")

    for key in model:
        print(key)

    if stop_key not in model:
        # Without the stop key nothing is removed; make that visible instead
        # of silently writing a full copy under the "stripped" name.
        print(
            f"warning: stop key {stop_key!r} not found; nothing was stripped",
            file=sys.stderr,
        )

    stripped = strip_state_dict(model, stop_key, final_keys)

    # safetensors refuses non-contiguous tensors; .contiguous() returns the
    # tensor itself when it is already contiguous.
    save_file(
        {key: value.contiguous() for key, value in stripped.items()},
        'pytorch_model_stripped.safetensors',
    )
    torch.save(stripped, 'pytorch_model_stripped.bin')


if __name__ == "__main__":
    main()