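"""
Load a base diffusion pipeline, attach a LoRA adapter from the Hugging Face Hub,
run a test generation, and optionally save the resulting pipeline to disk.
"""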
import torch
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
from huggingface_hub import hf_hub_download


def display_image(image):
    """
    Replace this with your actual image display logic.
    """
    image.show()


def load_and_merge_lora(base_model_id, lora_id, lora_weight_name, lora_adapter_name):
    """Loads the base pipeline and attaches the given LoRA weights as a named adapter."""
    try:
        # Load the base pipeline first; the scheduler swap below needs the
        # pipeline's existing scheduler config.
        pipe = DiffusionPipeline.from_pretrained(
            base_model_id,
            torch_dtype=torch.float16,
            variant="fp16",
            use_safetensors=True,
        ).to("cuda")
        pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

        # Download the LoRA weight file from the Hub and load it as a named adapter.
        lora_path = hf_hub_download(repo_id=lora_id, filename=lora_weight_name, revision="main")
        pipe.load_lora_weights(
            lora_path,
            weight_name=lora_weight_name,
            adapter_name=lora_adapter_name,
        )
        print("LoRA loaded successfully!")
        return pipe
    except Exception as e:
        print(f"Error loading LoRA: {e}")
        return None
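

# Optional follow-up: if you want save_merged_model() below to write a standalone
# checkpoint with the LoRA baked into the base weights, the adapter can be fused
# first. A minimal sketch using diffusers' fuse_lora(); the 0.9 default scale is
# illustrative, not a value from the original script.
def fuse_lora_into_base(pipe, lora_scale=0.9):
    """Folds the currently loaded LoRA adapter into the base model weights."""
    pipe.fuse_lora(lora_scale=lora_scale)
    return pipe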


def save_merged_model(pipe, save_path):
    """Saves the merged model to the specified path."""
    try:
        pipe.save_pretrained(save_path)
        print(f"Merged model saved successfully to: {save_path}")
    except Exception as e:
        print(f"Error saving the merged model: {e}")


if __name__ == "__main__":
    base_model_id = input("Enter the base model ID: ")
    lora_id = input("Enter the LoRA Hugging Face Hub ID: ")
    lora_weight_name = input("Enter the LoRA weight file name: ")
    lora_adapter_name = input("Enter the LoRA adapter name: ")

    pipe = load_and_merge_lora(base_model_id, lora_id, lora_weight_name, lora_adapter_name)

    if pipe:
        prompt = input("Enter your prompt: ")
        lora_scale = float(input("Enter the LoRA scale (e.g., 0.9): "))
        image = pipe(
            prompt,
            num_inference_steps=30,
            cross_attention_kwargs={"scale": lora_scale},
            generator=torch.manual_seed(0),
        ).images[0]
        display_image(image)

        # Ask the user for a directory to save the model
        save_path = input(
            "Enter the directory where you want to save the merged model: "
        )
        save_merged_model(pipe, save_path)
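
# Example session (illustrative values only, not from the original script):
#   Enter the base model ID: stabilityai/stable-diffusion-xl-base-1.0
#   Enter the LoRA Hugging Face Hub ID: <user>/<lora-repo>
#   Enter the LoRA weight file name: pytorch_lora_weights.safetensors
#   Enter the LoRA adapter name: my_lora
#   Enter your prompt: a photo of an astronaut riding a horse
#   Enter the LoRA scale (e.g., 0.9): 0.9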