# Spaces: Sleeping
import torch | |
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler | |
from tqdm.auto import tqdm | |
from huggingface_hub import cached_download, hf_hub_url | |
import os | |
def display_image(image):
    """Render *image* by delegating to its own ``show()`` method.

    This is a placeholder hook: swap in whatever display mechanism the
    environment needs (notebook output, GUI window, writing to disk, ...).
    """
    viewer = image.show
    viewer()
def load_and_merge_lora(
    base_model_id,
    lora_id,
    lora_weight_name,
    lora_adapter_name,
    device="cuda",
):
    """Load an fp16 diffusion pipeline and attach a LoRA adapter from the Hub.

    The pipeline's scheduler is replaced with ``DPMSolverMultistepScheduler``
    built from the original scheduler's config.

    Args:
        base_model_id: Hugging Face Hub ID of the base pipeline.
        lora_id: Hugging Face Hub ID of the repo holding the LoRA weights.
        lora_weight_name: File name of the LoRA weights inside ``lora_id``.
        lora_adapter_name: Name under which the adapter is registered.
        device: Device the pipeline is moved to (default ``"cuda"``).

    Returns:
        The pipeline with the LoRA loaded, or ``None`` if any step failed
        (the error message is printed).

    NOTE(review): the adapter is loaded but not fused into the base weights
    (there is no ``fuse_lora()`` call) despite the "merge" wording. Confirm
    whether fusing is intended before the pipeline is saved.
    """
    try:
        # Build the pipeline first: its default scheduler's config is needed
        # to construct the replacement DPM-Solver scheduler.
        pipe = DiffusionPipeline.from_pretrained(
            base_model_id,
            torch_dtype=torch.float16,
            variant="fp16",
            use_safetensors=True,
        )
        pipe.scheduler = DPMSolverMultistepScheduler.from_config(
            pipe.scheduler.config
        )
        pipe = pipe.to(device)

        # diffusers resolves and downloads ``weight_name`` from the Hub repo
        # itself, so no manual URL building / caching is required.
        pipe.load_lora_weights(
            lora_id,
            weight_name=lora_weight_name,
            adapter_name=lora_adapter_name,
        )
        print("LoRA merged successfully!")
        return pipe
    except Exception as e:
        # Boundary for the interactive script: report and signal failure
        # with None, which the caller checks for.
        print(f"Error merging LoRA: {e}")
        return None
def save_merged_model(pipe, save_path):
    """Save *pipe* to the directory *save_path* via ``pipe.save_pretrained``.

    Errors are reported (printed) rather than propagated, so the interactive
    script can finish cleanly; the boolean result lets callers react.

    Args:
        pipe: Object exposing ``save_pretrained(path)``.
        save_path: Destination directory; a leading ``~`` is expanded.

    Returns:
        True if the save succeeded, False if the path was blank or saving
        raised an exception.
    """
    if not str(save_path).strip():
        print("Error saving the merged model: no save directory given")
        return False

    # Paths typed at a prompt commonly use ``~``; the shell is not there to
    # expand it for us.
    target = os.path.expanduser(save_path)
    try:
        pipe.save_pretrained(target)
    except Exception as e:
        print(f"Error saving the merged model: {e}")
        return False
    print(f"Merged model saved successfully to: {target}")
    return True
def _prompt_float(message):
    """Prompt with *message* until the reply parses as a float; return it."""
    while True:
        raw = input(message)
        try:
            return float(raw)
        except ValueError:
            print(f"Not a number: {raw!r}. Please try again.")


def main():
    """Interactive entry point: load base model + LoRA, render one image, save."""
    base_model_id = input("Enter the base model ID: ").strip()
    lora_id = input("Enter the LoRA Hugging Face Hub ID: ").strip()
    lora_weight_name = input("Enter the LoRA weight file name: ").strip()
    lora_adapter_name = input("Enter the LoRA adapter name: ").strip()

    pipe = load_and_merge_lora(
        base_model_id, lora_id, lora_weight_name, lora_adapter_name
    )
    # load_and_merge_lora prints its own error and returns None on failure;
    # exit non-zero so callers/shells can see the run failed.
    if pipe is None:
        raise SystemExit(1)

    prompt = input("Enter your prompt: ")
    # Re-prompt on typos instead of crashing after the expensive model load.
    lora_scale = _prompt_float("Enter the LoRA scale (e.g., 0.9): ")
    image = pipe(
        prompt,
        num_inference_steps=30,
        cross_attention_kwargs={"scale": lora_scale},
        # Fixed seed so identical inputs reproduce the same image.
        generator=torch.manual_seed(0),
    ).images[0]
    display_image(image)

    # Ask the user for a directory to save the model
    save_path = input(
        "Enter the directory where you want to save the merged model: "
    ).strip()
    save_merged_model(pipe, save_path)


if __name__ == "__main__":
    main()