import torch
from safetensors.torch import safe_open
from modules import scripts, sd_models, shared
import gradio as gr
from modules.processing import process_images


class KeyBasedModelMerger(scripts.Script):
    """Script that merges selected tensors into the currently loaded model.

    Tensors are chosen by substring match against the keys of
    ``shared.sd_model.state_dict()`` and are overwritten in place with a
    blend of tensors read from the selected checkpoint files, after which
    image generation is run.
    """

    # Merge-mode labels; these exact strings are also the dropdown choices.
    _MODE_NORMAL = "Normal"
    _MODE_DIFF_TO_CURRENT = "Add difference (B-C to Current)"
    _MODE_A_PLUS_DIFF = "Add difference (A + (B-C) to Current)"

    def title(self):
        """Return the name under which the script is listed."""
        return "Key-based model merging"

    def ui(self, is_txt2img):
        """Build the script's controls.

        Returns:
            The list of components, in the same order as the parameters of
            ``run`` that follow ``p``.
        """
        model_names = sorted(sd_models.checkpoints_list.keys(), key=str.casefold)
        default_model = model_names[0] if model_names else None

        model_a_dropdown = gr.Dropdown(
            label="Model A", choices=model_names, value=default_model
        )
        model_b_dropdown = gr.Dropdown(
            label="Model B", choices=model_names, value=default_model
        )
        model_c_dropdown = gr.Dropdown(
            label="Model C (Add difference mode用)",
            choices=model_names,
            value=default_model,
        )
        keys_and_alphas_textbox = gr.Textbox(
            label="マージするテンソルのキーとマージ比率 (部分一致, 1行に1つ, カンマ区切り)",
            lines=5,
            placeholder="例:\nmodel.diffusion_model.input_blocks.0,0.5\nmodel.diffusion_model.middle_block,0.3",
        )
        merge_checkbox = gr.Checkbox(label="モデルのマージを有効にする", value=True)
        use_gpu_checkbox = gr.Checkbox(label="GPUを使用", value=True)
        batch_size_slider = gr.Slider(
            minimum=1, maximum=500, step=1, value=250, label="KeyMgerge_BatchSize"
        )
        merge_mode_dropdown = gr.Dropdown(
            label="Merge Mode",
            choices=[
                self._MODE_NORMAL,
                self._MODE_DIFF_TO_CURRENT,
                self._MODE_A_PLUS_DIFF,
            ],
            value=self._MODE_NORMAL,
        )

        return [
            model_a_dropdown,
            model_b_dropdown,
            model_c_dropdown,
            keys_and_alphas_textbox,
            merge_checkbox,
            use_gpu_checkbox,
            batch_size_slider,
            merge_mode_dropdown,
        ]

    def run(self, p, model_a_name, model_b_name, model_c_name, keys_and_alphas_str,
            merge_enabled, use_gpu, batch_size, merge_mode):
        """Optionally merge the selected tensors, then generate images.

        Args:
            p: Processing object handed to ``process_images``.
            model_a_name: Key of model A in ``sd_models.checkpoints_list``.
            model_b_name: Key of model B in ``sd_models.checkpoints_list``.
            model_c_name: Key of model C in ``sd_models.checkpoints_list``.
            keys_and_alphas_str: One ``key_substring,alpha`` pair per line.
            merge_enabled: When false, no tensors are modified.
            use_gpu: Load checkpoint tensors on CUDA when it is available.
            batch_size: Number of keys whose tensors are loaded at a time.
            merge_mode: One of the three merge-mode labels offered by ``ui``.

        Returns:
            The result of ``process_images(p)``; ``p`` itself when model B is
            not selected or a selected model is missing from the checkpoint
            list.

        Raises:
            ValueError: If ``merge_mode`` is not a known mode.
        """
        if not model_b_name:
            print("Error: Model B is not selected.")
            # NOTE(review): the error paths return ``p`` rather than a
            # processing result; kept as-is for compatibility.
            return p

        model_a_filename = model_b_filename = model_c_filename = None
        try:
            # Resolve only the checkpoint files the chosen mode needs.
            if merge_mode == self._MODE_NORMAL:
                model_a_filename = sd_models.checkpoints_list[model_a_name].filename
                model_b_filename = sd_models.checkpoints_list[model_b_name].filename
            elif merge_mode == self._MODE_DIFF_TO_CURRENT:
                model_b_filename = sd_models.checkpoints_list[model_b_name].filename
                model_c_filename = sd_models.checkpoints_list[model_c_name].filename
            elif merge_mode == self._MODE_A_PLUS_DIFF:
                model_a_filename = sd_models.checkpoints_list[model_a_name].filename
                model_b_filename = sd_models.checkpoints_list[model_b_name].filename
                model_c_filename = sd_models.checkpoints_list[model_c_name].filename
            else:
                raise ValueError(f"Invalid merge mode: {merge_mode}")
        except KeyError as e:
            print(f"Error: Selected model is not found in checkpoints list. {e}")
            return p

        if merge_enabled:
            # Parse "key_substring,alpha" lines; lines without a comma are ignored.
            input_keys_and_alphas = []
            for line in (keys_and_alphas_str or "").split("\n"):
                if "," not in line:
                    continue
                key_part, alpha_str = line.split(",", 1)
                try:
                    alpha = float(alpha_str)
                except ValueError:
                    print(f"Invalid alpha value in line '{line}', skipping...")
                    continue
                # Surrounding whitespace would silently prevent any match.
                input_keys_and_alphas.append((key_part.strip(), alpha))

            # Build the key list once instead of once per pattern.
            model_keys = list(shared.sd_model.state_dict().keys())

            # Substring match; a later line overrides an earlier one for the
            # same model key.
            final_keys_and_alphas = {}
            for key_part, alpha in input_keys_and_alphas:
                for model_key in model_keys:
                    if key_part in model_key:
                        final_keys_and_alphas[model_key] = alpha

            # Device on which checkpoint tensors are loaded and blended.
            device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"

            batched_keys = list(final_keys_and_alphas.items())

            # The slider value is used as a range step, so it must be a
            # positive int.
            batch_size = max(1, int(batch_size))

            # Open only the files the mode needs; unused handles are None.
            if merge_mode == self._MODE_NORMAL:
                with safe_open(model_a_filename, framework="pt", device=device) as f_a, \
                        safe_open(model_b_filename, framework="pt", device=device) as f_b:
                    self._merge_models(f_a, f_b, None, batched_keys,
                                       final_keys_and_alphas, batch_size,
                                       merge_mode, device)
            elif merge_mode == self._MODE_DIFF_TO_CURRENT:
                with safe_open(model_b_filename, framework="pt", device=device) as f_b, \
                        safe_open(model_c_filename, framework="pt", device=device) as f_c:
                    self._merge_models(None, f_b, f_c, batched_keys,
                                       final_keys_and_alphas, batch_size,
                                       merge_mode, device)
            elif merge_mode == self._MODE_A_PLUS_DIFF:
                with safe_open(model_a_filename, framework="pt", device=device) as f_a, \
                        safe_open(model_b_filename, framework="pt", device=device) as f_b, \
                        safe_open(model_c_filename, framework="pt", device=device) as f_c:
                    self._merge_models(f_a, f_b, f_c, batched_keys,
                                       final_keys_and_alphas, batch_size,
                                       merge_mode, device)
            else:
                raise ValueError(f"Invalid merge mode: {merge_mode}")

        return process_images(p)

    def _merge_models(self, f_a, f_b, f_c, batched_keys, final_keys_and_alphas,
                      batch_size, merge_mode, device):
        """Blend checkpoint tensors into the loaded model, batch by batch.

        Args:
            f_a: Open safetensors handle for model A, or None if unused.
            f_b: Open safetensors handle for model B, or None if unused.
            f_c: Open safetensors handle for model C, or None if unused.
            batched_keys: List of ``(state_dict_key, alpha)`` pairs to merge.
            final_keys_and_alphas: Mapping the pairs came from; unused here
                and kept only so the signature stays unchanged.
            batch_size: Number of keys whose tensors are loaded at a time.
            merge_mode: One of the three merge-mode labels.
            device: Device the checkpoint tensors were loaded on.

        Raises:
            ValueError: If ``merge_mode`` is not a known mode.
        """
        # state_dict() builds a fresh mapping on every call, so take it once.
        state_dict = shared.sd_model.state_dict()
        batch_size = max(1, int(batch_size))

        # Keys present in each open file, so a key missing from a checkpoint
        # is skipped instead of aborting after a partial in-place merge.
        available = [set(f.keys()) for f in (f_a, f_b, f_c) if f is not None]

        for start in range(0, len(batched_keys), batch_size):
            batch = []
            for key, alpha in batched_keys[start:start + batch_size]:
                if all(key in keys for keys in available):
                    batch.append((key, alpha))
                else:
                    print(f"Key not found in every selected checkpoint, skipping: {key}")

            # Load the whole batch from each file before blending.
            tensors_a = [f_a.get_tensor(key) for key, _ in batch] if f_a is not None else None
            tensors_b = [f_b.get_tensor(key) for key, _ in batch] if f_b is not None else None
            tensors_c = [f_c.get_tensor(key) for key, _ in batch] if f_c is not None else None

            for j, (key, alpha) in enumerate(batch):
                tensor_a = tensors_a[j] if tensors_a is not None else None
                tensor_b = tensors_b[j] if tensors_b is not None else None
                tensor_c = tensors_c[j] if tensors_c is not None else None
                target = state_dict[key]

                if merge_mode == self._MODE_NORMAL:
                    merged_tensor = torch.lerp(tensor_a, tensor_b, alpha)
                    print(f"NomalMerged:{alpha}:{key}")
                elif merge_mode == self._MODE_DIFF_TO_CURRENT:
                    # The live tensor may sit on a different device than the
                    # loaded ones; bring it over before adding.
                    merged_tensor = target.to(device) + alpha * (tensor_b - tensor_c)
                    print(f"(B-C to Current):{alpha}:{key}")
                elif merge_mode == self._MODE_A_PLUS_DIFF:
                    merged_tensor = tensor_a + alpha * (tensor_b - tensor_c)
                    print(f"(A + (B-C) to Current):{alpha}:{key}")
                else:
                    raise ValueError(f"Invalid merge mode: {merge_mode}")

                # copy_ converts device and dtype to match the destination.
                target.copy_(merged_tensor)