|
import os |
|
import torch |
|
|
|
def process_file(file_path, *, map_location=None):
    """Write the LoRA/audio entries of a checkpoint to a sibling file.

    Loads the object stored at ``file_path`` with ``torch.load``, keeps only
    the entries whose name contains ``'lora'`` or ``'audio'``, prints the
    kept names, and saves the reduced dict next to the input as
    ``<path without extension>_trainable.bin``.

    Args:
        file_path: Path of the checkpoint to read. The loaded object must
            expose ``.items()`` yielding ``(name, value)`` pairs.
        map_location: Forwarded unchanged to ``torch.load``. The default
            ``None`` keeps torch's own device placement; pass ``"cpu"`` to
            filter a GPU-saved checkpoint on a machine without that device.

    Returns:
        None. The result is the file written to disk.
    """
    # NOTE: torch.load unpickles the file; only point this at checkpoints
    # you trust.
    state_dict = torch.load(file_path, map_location=map_location)

    filtered_state_dict = {
        name: param
        for name, param in state_dict.items()
        if 'lora' in name or 'audio' in name
    }

    print(f"Parameters in file {file_path}:")
    for name in filtered_state_dict:
        print(name)

    # Strip the real extension instead of blindly dropping four characters,
    # so inputs such as "weights.pt" keep their full stem.
    stem, _ext = os.path.splitext(file_path)
    output_path = stem + '_trainable.bin'

    torch.save(filtered_state_dict, output_path)
    print(output_path)
    print('----------------------------------')
|
|
|
|
|
# Root directory that is scanned recursively for full checkpoints.
CHECKPOINT_ROOT = '/fs/nexus-projects/brain_project/acl_sk_24/GAMA//llm/alpaca-lora-main/'

# Number of checkpoints converted during this run.
count = 0

for dirpath, _dirnames, filenames in os.walk(CHECKPOINT_ROOT):
    # Names are unique within a directory, so a membership test is enough.
    if "pytorch_model.bin" not in filenames:
        continue

    cur_target = os.path.join(dirpath, "pytorch_model.bin")
    trainable_path = os.path.splitext(cur_target)[0] + '_trainable.bin'

    # Skip checkpoints whose filtered copy already exists, so the script can
    # be re-run without redoing finished work.
    if os.path.exists(trainable_path):
        continue

    print(cur_target)
    process_file(cur_target)
    count += 1

print(count)