import os

import torch
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
from huggingface_hub import hf_hub_download
from tqdm.auto import tqdm


def display_image(image):
    """Show a generated image.

    Replace this with your actual image display logic. The default simply
    calls ``image.show()`` on the object it is given.
    """
    image.show()


def load_and_merge_lora(base_model_id, lora_id, lora_weight_name, lora_adapter_name):
    """Load a base pipeline on CUDA and attach a LoRA adapter to it.

    Args:
        base_model_id: Model ID (or local path) passed to
            ``DiffusionPipeline.from_pretrained``.
        lora_id: Hugging Face Hub repository ID that hosts the LoRA weights.
        lora_weight_name: File name of the LoRA weights inside ``lora_id``.
        lora_adapter_name: Name to register the adapter under; an empty
            string is passed on as ``None``.

    Returns:
        The pipeline with the LoRA adapter loaded, or ``None`` if any step
        failed (the error is printed, not raised).

    Note:
        The adapter is only loaded here, not fused into the base weights, so
        a per-call LoRA scale can still be applied at inference time. Fusing
        happens in ``save_merged_model`` when a ``lora_scale`` is supplied.
    """
    try:
        pipe = DiffusionPipeline.from_pretrained(
            base_model_id,
            torch_dtype=torch.float16,
            variant="fp16",
            use_safetensors=True,
        )
        # The scheduler can only be derived from the pipeline's own config
        # once the pipeline exists, so it is swapped in after loading.
        pipe.scheduler = DPMSolverMultistepScheduler.from_config(
            pipe.scheduler.config
        )
        pipe = pipe.to("cuda")

        # Fetch (or reuse from the local cache) the LoRA weight file.
        lora_path = hf_hub_download(
            repo_id=lora_id,
            filename=lora_weight_name,
            revision="main",
        )

        # One coarse step: the load call is not given a progress hook here,
        # so the bar is advanced once it returns.
        with tqdm(total=1, desc="Loading LoRA weights", unit="step") as pbar:
            pipe.load_lora_weights(
                lora_path,
                weight_name=lora_weight_name,
                adapter_name=lora_adapter_name or None,
            )
            pbar.update(1)

        print("LoRA merged successfully!")
        return pipe
    except Exception as e:
        # Script-level boundary: report the failure and signal it with None.
        print(f"Error merging LoRA: {e}")
        return None


def save_merged_model(pipe, save_path, lora_scale=None):
    """Save the pipeline to ``save_path``, optionally fusing the LoRA first.

    Args:
        pipe: Pipeline returned by ``load_and_merge_lora``.
        save_path: Directory handed to ``pipe.save_pretrained``.
        lora_scale: When not ``None``, the loaded LoRA is fused into the base
            weights at this scale and the adapter is unloaded before saving,
            so the saved checkpoint is intended to contain the merged
            weights. This mutates ``pipe``. When ``None`` the pipeline is
            saved as-is.

    Errors are printed rather than raised.
    """
    try:
        if lora_scale is not None:
            # Bake the adapter into the base weights, then drop the adapter
            # layers so only the fused weights are serialized.
            pipe.fuse_lora(lora_scale=lora_scale)
            pipe.unload_lora_weights()
        pipe.save_pretrained(save_path)
        print(f"Merged model saved successfully to: {save_path}")
    except Exception as e:
        print(f"Error saving the merged model: {e}")


def _prompt_float(message):
    """Keep asking with ``message`` until the user enters a valid float."""
    while True:
        raw_value = input(message)
        try:
            return float(raw_value)
        except ValueError:
            print(f"Invalid number: {raw_value!r}. Please try again.")


def main():
    """Interactive flow: load model + LoRA, render a preview, save the result."""
    base_model_id = input("Enter the base model ID: ").strip()
    lora_id = input("Enter the LoRA Hugging Face Hub ID: ").strip()
    lora_weight_name = input("Enter the LoRA weight file name: ").strip()
    lora_adapter_name = input("Enter the LoRA adapter name: ").strip()

    pipe = load_and_merge_lora(
        base_model_id, lora_id, lora_weight_name, lora_adapter_name
    )
    if pipe is None:
        # The failure was already reported by load_and_merge_lora.
        return

    prompt = input("Enter your prompt: ")
    lora_scale = _prompt_float("Enter the LoRA scale (e.g., 0.9): ")

    image = pipe(
        prompt,
        num_inference_steps=30,
        cross_attention_kwargs={"scale": lora_scale},
        generator=torch.manual_seed(0),  # fixed seed for a repeatable preview
    ).images[0]
    display_image(image)

    # Ask the user for a directory to save the model
    save_path = input(
        "Enter the directory where you want to save the merged model: "
    ).strip()
    # Fuse at the same scale that was used for the preview image.
    save_merged_model(pipe, save_path, lora_scale=lora_scale)


if __name__ == "__main__":
    main()