macrdel committed on
Commit 8fe6de7 · 1 Parent(s): 91ba022

update app.py

Files changed (1)
  1. app.py +12 -10
app.py CHANGED
@@ -2,7 +2,7 @@ import gradio as gr
 import numpy as np
 import random
 
-from diffusers import DiffusionPipeline
+from diffusers import StableDiffusionPipeline # DiffusionPipeline
 from peft import PeftModel, PeftConfig
 import torch
 
@@ -25,7 +25,7 @@ else:
 # Cache to avoid re-initializing pipelines repeatedly
 model_cache = {}
 
-def load_pipeline(model_id: str):
+def load_pipeline(model_id: str, lora_scale):
     """
     Loads or retrieves a cached DiffusionPipeline.
 
@@ -38,7 +38,7 @@ def load_pipeline(model_id: str):
     if model_id == "macrdel/unico_proj":
         # Use the specified base model for your LoRA adapter.
         base_model = "CompVis/stable-diffusion-v1-4"
-        pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch_dtype)
+        pipe = StableDiffusionPipeline.from_pretrained(base_model, torch_dtype=torch_dtype)
         # Load the LoRA weights
         pipe.unet = PeftModel.from_pretrained(
             pipe.unet,
@@ -53,7 +53,9 @@ def load_pipeline(model_id: str):
             torch_dtype=torch_dtype
         )
     else:
-        pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype)
+        pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype, safety_checker=None).to(device)
+        pipe.unet.load_state_dict({k: lora_scale * v for k, v in pipe.unet.state_dict().items()})
+        pipe.text_encoder.load_state_dict({k: lora_scale * v for k, v in pipe.text_encoder.state_dict().items()})
 
     pipe.to(device)
     model_cache[model_id] = pipe
@@ -76,7 +78,7 @@ def infer(
     progress=gr.Progress(track_tqdm=True),
 ):
     # Load the pipeline for the chosen model
-    pipe = load_pipeline(model_id)
+    pipe = load_pipeline(model_id, lora_scale)
 
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
@@ -84,12 +86,12 @@
     generator = torch.Generator(device=device).manual_seed(seed)
 
     # If using the LoRA model, update the LoRA scale if supported.
-    if model_id == "macrdel/unico_proj":
+    # if model_id == "macrdel/unico_proj":
         # This assumes your pipeline's unet has a method to update the LoRA scale.
-        if hasattr(pipe.unet, "set_lora_scale"):
-            pipe.unet.set_lora_scale(lora_scale)
-        else:
-            print("Warning: LoRA scale adjustment method not found on UNet.")
+    #     if hasattr(pipe.unet, "set_lora_scale"):
+    #         pipe.unet.set_lora_scale(lora_scale)
+    #     else:
+    #         print("Warning: LoRA scale adjustment method not found on UNet.")
 
     image = pipe(
         prompt=prompt,
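
For reference, a minimal sketch of how the revised load_pipeline is driven, mirroring the new call site in infer. The prompt and generation settings are illustrative assumptions; device, torch_dtype, and MAX_SEED are the globals defined in the unchanged top of app.py.

# Minimal usage sketch (illustrative values only; device, torch_dtype and
# MAX_SEED come from the unchanged portion of app.py referenced in the hunks).
import random
import torch

# Build (or fetch from model_cache) the pipeline for the LoRA project model.
pipe = load_pipeline("macrdel/unico_proj", lora_scale=0.8)

# Seed a generator the same way infer does.
seed = random.randint(0, MAX_SEED)
generator = torch.Generator(device=device).manual_seed(seed)

# Run inference; standard StableDiffusionPipeline call returning PIL images.
image = pipe(
    prompt="an astronaut riding a horse",  # hypothetical prompt
    num_inference_steps=25,
    guidance_scale=7.5,
    generator=generator,
).images[0]

Note that the new code folds lora_scale directly into the weights of the non-LoRA branch, and model_cache is keyed by model_id alone, so the scale supplied on the first call for a given model would presumably stay in effect for later calls that hit the cache.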