hysts HF staff commited on
Commit
253feed
·
1 Parent(s): 10f851e

Support changing base model

Browse files
Files changed (2) hide show
  1. app.py +38 -2
  2. model.py +18 -9
app.py CHANGED
@@ -49,7 +49,14 @@ This is an unofficial demo for [https://github.com/lllyasviel/ControlNet](https:
49
 
50
  If you are interested in trying out other base models, check out [this Space](https://huggingface.co/spaces/hysts/ControlNet-with-other-models) as well.
51
  '''
52
- if (SPACE_ID := os.getenv('SPACE_ID')) is not None:
 
 
 
 
 
 
 
53
  DESCRIPTION += f'''<p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.<br/>
54
  <a href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true">
55
  <img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
@@ -59,7 +66,9 @@ if (SPACE_ID := os.getenv('SPACE_ID')) is not None:
59
  MAX_IMAGES = int(os.getenv('MAX_IMAGES', '3'))
60
  DEFAULT_NUM_IMAGES = min(MAX_IMAGES, int(os.getenv('DEFAULT_NUM_IMAGES', '1')))
61
 
62
- model = Model()
 
 
63
 
64
  with gr.Blocks(css='style.css') as demo:
65
  gr.Markdown(DESCRIPTION)
@@ -106,4 +115,31 @@ with gr.Blocks(css='style.css') as demo:
106
  max_images=MAX_IMAGES,
107
  default_num_images=DEFAULT_NUM_IMAGES)
108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  demo.queue(api_open=False).launch()
 
49
 
50
  If you are interested in trying out other base models, check out [this Space](https://huggingface.co/spaces/hysts/ControlNet-with-other-models) as well.
51
  '''
52
+
53
+ SPACE_ID = os.getenv('SPACE_ID')
54
+ ALLOW_CHANGING_BASE_MODEL = SPACE_ID != 'hysts/ControlNet'
55
+
56
+ if not ALLOW_CHANGING_BASE_MODEL:
57
+ DESCRIPTION += 'In this Space, the base model is not allowed to be changed so as not to slow down the demo, but it can be changed if you duplicate the Space.'
58
+
59
+ if SPACE_ID is not None:
60
  DESCRIPTION += f'''<p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.<br/>
61
  <a href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true">
62
  <img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
 
66
  MAX_IMAGES = int(os.getenv('MAX_IMAGES', '3'))
67
  DEFAULT_NUM_IMAGES = min(MAX_IMAGES, int(os.getenv('DEFAULT_NUM_IMAGES', '1')))
68
 
69
+ DEFAULT_MODEL_ID = os.getenv('DEFAULT_MODEL_ID',
70
+ 'runwayml/stable-diffusion-v1-5')
71
+ model = Model(base_model_id=DEFAULT_MODEL_ID, task_name='canny')
72
 
73
  with gr.Blocks(css='style.css') as demo:
74
  gr.Markdown(DESCRIPTION)
 
115
  max_images=MAX_IMAGES,
116
  default_num_images=DEFAULT_NUM_IMAGES)
117
 
118
+ with gr.Accordion(label='Base model', open=False):
119
+ with gr.Row():
120
+ with gr.Column():
121
+ current_base_model = gr.Text(label='Current base model')
122
+ with gr.Column(scale=0.3):
123
+ check_base_model_button = gr.Button('Check current base model')
124
+ with gr.Row():
125
+ with gr.Column():
126
+ base_model_id = gr.Text(
127
+ label='Base model repo',
128
+ max_lines=1,
129
+ placeholder='runwayml/stable-diffusion-v1-5',
130
+ info=
131
+ 'The base model must be compatible with Stable Diffusion v1.5.',
132
+ interactive=ALLOW_CHANGING_BASE_MODEL)
133
+ with gr.Column(scale=0.3):
134
+ change_base_model_button = gr.Button('Change base model')
135
+
136
+ check_base_model_button.click(fn=lambda: model.base_model_id,
137
+ outputs=current_base_model)
138
+ base_model_id.submit(fn=model.set_base_model,
139
+ inputs=base_model_id,
140
+ outputs=current_base_model)
141
+ change_base_model_button.click(fn=model.set_base_model,
142
+ inputs=base_model_id,
143
+ outputs=current_base_model)
144
+
145
  demo.queue(api_open=False).launch()
model.py CHANGED
@@ -39,18 +39,21 @@ CONTROLNET_MODEL_IDS = {
39
 
40
 
41
  class Model:
42
- def __init__(self):
43
- # FIXME
44
- self.base_model_id = 'andite/anything-v4.0'
45
- self.task_name = 'pose'
46
- self.pipe = self.load_pipe()
47
-
48
- def load_pipe(self) -> DiffusionPipeline:
49
- model_id = CONTROLNET_MODEL_IDS[self.task_name]
 
 
 
50
  controlnet = ControlNetModel.from_pretrained(model_id,
51
  torch_dtype=torch.float16)
52
  pipe = StableDiffusionControlNetPipeline.from_pretrained(
53
- self.base_model_id,
54
  safety_checker=None,
55
  controlnet=controlnet,
56
  torch_dtype=torch.float16)
@@ -58,8 +61,14 @@ class Model:
58
  pipe.scheduler.config)
59
  pipe.enable_xformers_memory_efficient_attention()
60
  pipe.enable_model_cpu_offload()
 
 
61
  return pipe
62
 
 
 
 
 
63
  def load_controlnet_weight(self, task_name: str) -> None:
64
  if task_name == self.task_name:
65
  return
 
39
 
40
 
41
  class Model:
42
+ def __init__(self,
43
+ base_model_id: str = 'runwayml/stable-diffusion-v1-5',
44
+ task_name: str = 'canny'):
45
+ self.base_model_id = ''
46
+ self.task_name = ''
47
+ self.pipe = self.load_pipe(base_model_id, task_name)
48
+
49
+ def load_pipe(self, base_model_id: str, task_name) -> DiffusionPipeline:
50
+ if base_model_id == self.base_model_id and task_name == self.task_name:
51
+ return self.pipe
52
+ model_id = CONTROLNET_MODEL_IDS[task_name]
53
  controlnet = ControlNetModel.from_pretrained(model_id,
54
  torch_dtype=torch.float16)
55
  pipe = StableDiffusionControlNetPipeline.from_pretrained(
56
+ base_model_id,
57
  safety_checker=None,
58
  controlnet=controlnet,
59
  torch_dtype=torch.float16)
 
61
  pipe.scheduler.config)
62
  pipe.enable_xformers_memory_efficient_attention()
63
  pipe.enable_model_cpu_offload()
64
+ self.base_model_id = base_model_id
65
+ self.task_name = task_name
66
  return pipe
67
 
68
+ def set_base_model(self, base_model_id: str) -> str:
69
+ self.pipe = self.load_pipe(base_model_id, self.task_name)
70
+ return self.base_model_id
71
+
72
  def load_controlnet_weight(self, task_name: str) -> None:
73
  if task_name == self.task_name:
74
  return