Spaces:
Running
Running
Download weights first on Spaces
Browse files
app.py
CHANGED
@@ -41,7 +41,7 @@ from app_scribble import create_demo as create_demo_scribble
|
|
41 |
from app_scribble_interactive import \
|
42 |
create_demo as create_demo_scribble_interactive
|
43 |
from app_seg import create_demo as create_demo_seg
|
44 |
-
from model import Model
|
45 |
|
46 |
DESCRIPTION = '# [ControlNet](https://github.com/lllyasviel/ControlNet)'
|
47 |
|
@@ -55,6 +55,9 @@ if SPACE_ID is not None:
|
|
55 |
MAX_IMAGES = int(os.getenv('MAX_IMAGES', '3'))
|
56 |
DEFAULT_NUM_IMAGES = min(MAX_IMAGES, int(os.getenv('DEFAULT_NUM_IMAGES', '1')))
|
57 |
|
|
|
|
|
|
|
58 |
DEFAULT_MODEL_ID = os.getenv('DEFAULT_MODEL_ID',
|
59 |
'runwayml/stable-diffusion-v1-5')
|
60 |
model = Model(base_model_id=DEFAULT_MODEL_ID, task_name='canny')
|
|
|
41 |
from app_scribble_interactive import \
|
42 |
create_demo as create_demo_scribble_interactive
|
43 |
from app_seg import create_demo as create_demo_seg
|
44 |
+
from model import Model, download_all_controlnet_weights
|
45 |
|
46 |
DESCRIPTION = '# [ControlNet](https://github.com/lllyasviel/ControlNet)'
|
47 |
|
|
|
55 |
MAX_IMAGES = int(os.getenv('MAX_IMAGES', '3'))
|
56 |
DEFAULT_NUM_IMAGES = min(MAX_IMAGES, int(os.getenv('DEFAULT_NUM_IMAGES', '1')))
|
57 |
|
58 |
+
if os.getenv('SYSTEM') == 'spaces':
|
59 |
+
download_all_controlnet_weights()
|
60 |
+
|
61 |
DEFAULT_MODEL_ID = os.getenv('DEFAULT_MODEL_ID',
|
62 |
'runwayml/stable-diffusion-v1-5')
|
63 |
model = Model(base_model_id=DEFAULT_MODEL_ID, task_name='canny')
|
model.py
CHANGED
@@ -38,6 +38,11 @@ CONTROLNET_MODEL_IDS = {
|
|
38 |
}
|
39 |
|
40 |
|
|
|
|
|
|
|
|
|
|
|
41 |
class Model:
|
42 |
def __init__(self,
|
43 |
base_model_id: str = 'runwayml/stable-diffusion-v1-5',
|
|
|
38 |
}
|
39 |
|
40 |
|
41 |
+
def download_all_controlnet_weights():
|
42 |
+
for model_id in CONTROLNET_MODEL_IDS.values():
|
43 |
+
ControlNetModel.from_pretrained(model_id)
|
44 |
+
|
45 |
+
|
46 |
class Model:
|
47 |
def __init__(self,
|
48 |
base_model_id: str = 'runwayml/stable-diffusion-v1-5',
|