Nef Caballero commited on
Commit
7371635
·
1 Parent(s): c95cb7c

updating app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -1
app.py CHANGED
@@ -6,6 +6,7 @@ import torch
6
  import gradio as gr
7
  from huggingface_hub import hf_hub_download
8
  import spaces
 
9
 
10
  hf_hub_download(repo_id="Comfy-Org/stable-diffusion-v1-5-archive", filename="v1-5-pruned-emaonly-fp16.safetensors", local_dir="models/checkpoints")
11
 
@@ -119,11 +120,30 @@ def import_custom_nodes() -> None:
119
 
120
  from nodes import NODE_CLASS_MAPPINGS
121
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  @spaces.GPU(duration=60) #modify the duration for the average it takes for your worflow to run, in seconds
123
  def generate_image(prompt):
124
  import_custom_nodes()
125
  with torch.inference_mode():
126
- checkpointloadersimple = NODE_CLASS_MAPPINGS["CheckpointLoaderSimple"]()
127
  checkpointloadersimple_4 = checkpointloadersimple.load_checkpoint(
128
  ckpt_name="v1-5-pruned.safetensors"
129
  )
 
6
  import gradio as gr
7
  from huggingface_hub import hf_hub_download
8
  import spaces
9
+ from comfy import model_management
10
 
11
  hf_hub_download(repo_id="Comfy-Org/stable-diffusion-v1-5-archive", filename="v1-5-pruned-emaonly-fp16.safetensors", local_dir="models/checkpoints")
12
 
 
120
 
121
  from nodes import NODE_CLASS_MAPPINGS
122
 
123
+ checkpointloadersimple = NODE_CLASS_MAPPINGS["CheckpointLoaderSimple"]()
124
+ checkpointloadersimple_4 = checkpointloadersimple.load_checkpoint(
125
+ ckpt_name="v1-5-pruned.safetensors"
126
+ )
127
+
128
+
129
+ #Add all the models that load a safetensors file
130
+ model_loaders = [checkpointloadersimple_4]
131
+
132
+ # Check which models are valid and how to best load them
133
+ valid_models = [
134
+ getattr(loader[0], 'patcher', loader[0])
135
+ for loader in model_loaders
136
+ if not isinstance(loader[0], dict) and not isinstance(getattr(loader[0], 'patcher', None), dict)
137
+ ]
138
+
139
+ #Finally loads the models
140
+ model_management.load_models_gpu(valid_models)
141
+
142
  @spaces.GPU(duration=60) #modify the duration for the average it takes for your worflow to run, in seconds
143
  def generate_image(prompt):
144
  import_custom_nodes()
145
  with torch.inference_mode():
146
+
147
  checkpointloadersimple_4 = checkpointloadersimple.load_checkpoint(
148
  ckpt_name="v1-5-pruned.safetensors"
149
  )