|
|
|
|
|
|
|
""" |
|
This script automates the process of updating a Stable Diffusion training |
|
script with settings extracted from a LoRA model's JSON metadata. |
|
|
|
It performs the following main tasks: |
|
1. Reads a JSON file containing LoRA model metadata |
|
2. Parses an existing Stable Diffusion training script |
|
3. Maps metadata keys to corresponding script arguments |
|
4. Updates the script with values from the metadata |
|
5. Handles special cases and complex arguments (e.g., network_args) |
|
6. Writes the updated script to a new file |
|
|
|
Usage: |
|
python steal_sdscripts_metadata <metadata_file> <script_file> <output_file> |
|
|
|
This tool is particularly useful for replicating training conditions or |
|
fine-tuning existing models based on successful previous runs. |
|
""" |
|
|
|
import json |
|
import re |
|
import argparse |
|
|
|
|
|
# Command-line interface: three required positional paths, consumed in the
# order metadata JSON -> training script to read -> updated script to write.
parser = argparse.ArgumentParser(
    description='Update training script based on metadata.'
)
for _positional, _help_text in (
    ('metadata_file', 'Path to the metadata JSON file'),
    ('script_file', 'Path to the training script file'),
    ('output_file', 'Path to save the updated training script'),
):
    parser.add_argument(_positional, type=str, help=_help_text)
args = parser.parse_args()
|
|
|
|
|
# Load the metadata (parsed from JSON) that supplies the new values.
with open(args.metadata_file, mode='r', encoding='utf-8') as metadata_fh:
    metadata = json.loads(metadata_fh.read())

# Read the training script as one string; all later edits are plain text
# substitutions on this string.
with open(args.script_file, mode='r', encoding='utf-8') as script_fh:
    script_content = script_fh.read()
|
|
|
|
|
# Metadata key -> command-line flag of the training script.
# Every supported setting follows the same naming rule: the metadata key is
# the option name prefixed with "ss_", the flag is the option name prefixed
# with "--". Insertion order matches the order the options are listed in.
mappings = {
    f'ss_{option}': f'--{option}'
    for option in (
        'network_dim',
        'network_alpha',
        'learning_rate',
        'unet_lr',
        'text_encoder_lr',
        'max_train_steps',
        'train_batch_size',
        'gradient_accumulation_steps',
        'mixed_precision',
        'seed',
        'resolution',
        'clip_skip',
        'lr_scheduler',
        'network_module',
    )
}
|
|
|
|
|
# Copy every simple "--flag=value" setting present in the metadata into the
# script text. An existing "--flag=..." occurrence is overwritten in place;
# a flag the script does not mention yet is inserted right after "args=(".
# If the script contains no "args=(" either, the setting cannot be placed
# and the text is left unchanged for that flag.
for json_key, script_arg in mappings.items():
    if json_key not in metadata:
        continue  # nothing recorded for this setting; keep the script's value

    value = metadata[json_key]
    if json_key == 'ss_resolution':
        # Expected form is a bracketed pair such as "(512, 512)". Stripping
        # the brackets (instead of blindly slicing off the first and last
        # character) gives the same result for that form and does not
        # mangle a value that arrives without brackets or as a JSON list.
        resolution = str(value).strip('()[] ')
        value = f'"{resolution}"'
    elif isinstance(value, str):
        value = f'"{value}"'

    replacement = f'{script_arg}={value}'

    # The value part is a run of double-quoted segments and/or single
    # non-blank characters. Compared with a bare "\S+", this also consumes
    # quoted values that contain spaces (e.g. --resolution="512, 512", which
    # is exactly what this tool writes), so re-running on an already updated
    # script does not leave a dangling fragment behind.
    pattern = re.compile(re.escape(script_arg) + r'=(?:"[^"\n]*"|\S)+')

    if pattern.search(script_content):
        # A callable replacement is inserted literally; a plain string would
        # have its backslashes interpreted as escapes / group references.
        script_content = pattern.sub(
            lambda _match, text=replacement: text, script_content
        )
    else:
        script_content = script_content.replace(
            'args=(', f'args=(\n {replacement}'
        )
|
|
|
|
|
# --network_args takes a list of quoted "key=value" items rather than a
# single value, so it is rewritten separately from the simple flags.
# Only an existing --network_args entry followed by at least one quoted item
# is replaced; nothing is inserted when the script has no such entry.
if 'ss_network_args' in metadata:
    network_args = metadata['ss_network_args']
    # NOTE(review): metadata exported from a model header appears to store
    # this field as a JSON-encoded string rather than a nested object --
    # confirm against the exporter. Both forms are accepted here.
    if isinstance(network_args, str):
        network_args = json.loads(network_args)

    # "or {}" tolerates a null value, which yields an empty item list.
    NETWORK_ARGS_STR = ' '.join(
        f'"{k}={v}"' for k, v in (network_args or {}).items()
    )

    # Match the flag plus each following double-quoted item. Restricting an
    # item to '"[^"\n]*"' (instead of a greedy '".+"') keeps the match from
    # running on to the last quote of the line and swallowing unrelated
    # quoted flags that share the line.
    PATTERN = r'--network_args(?:\s+"[^"\n]*")+'
    replacement = f'--network_args\n {NETWORK_ARGS_STR}'

    # Callable replacement: inserted verbatim, so backslashes inside the
    # metadata values are not interpreted as regex escapes.
    script_content = re.sub(
        PATTERN, lambda _match: replacement, script_content
    )
|
|
|
|
|
# Persist the rewritten script text to the requested output path, then tell
# the user where it went.
with open(args.output_file, mode='w', encoding='utf-8') as output_fh:
    output_fh.write(script_content)

print(f"Updated training script has been saved as '{args.output_file}'")
|
|