BlockDetail commited on
Commit
3214d99
1 Parent(s): 162c342
Files changed (1) hide show
  1. app.py +4 -1
app.py CHANGED
@@ -4,6 +4,7 @@ import torch
4
  import numpy as np
5
  from PIL import Image, ImageFilter
6
  from extension import CustomStableDiffusionControlNetPipeline
 
7
 
8
  negative_prompt = ""
9
  device = torch.device('cuda')
@@ -11,6 +12,7 @@ pipe = None
11
 
12
  print(gr.__version__)
13
 
 
14
  def load():
15
  global pipe
16
  controlnet = ControlNetModel.from_pretrained("BlockDetail/PartialSketchControlNet", torch_dtype=torch.float16).to(device)
@@ -54,6 +56,7 @@ with gr.Blocks() as demo:
54
  sketch_states = gr.State(start_state)
55
  checkbox_state = gr.State(True)
56
 
 
57
  def sketch(curr_sketch_image, dilation_mask, prompt, seed, num_steps, dilation):
58
  global curr_num_samples
59
  global pipe
@@ -159,5 +162,5 @@ with gr.Blocks() as demo:
159
  stroke_type[0].change(change_color, [stroke_type[0]], canvas)
160
  num_samples[0].change(change_num_samples, [num_samples[0]], None)
161
 
162
-
163
  demo.launch(share = True, debug = True)
 
4
  import numpy as np
5
  from PIL import Image, ImageFilter
6
  from extension import CustomStableDiffusionControlNetPipeline
7
+ import spaces
8
 
9
  negative_prompt = ""
10
  device = torch.device('cuda')
 
12
 
13
  print(gr.__version__)
14
 
15
+ @spaces.GPU
16
  def load():
17
  global pipe
18
  controlnet = ControlNetModel.from_pretrained("BlockDetail/PartialSketchControlNet", torch_dtype=torch.float16).to(device)
 
56
  sketch_states = gr.State(start_state)
57
  checkbox_state = gr.State(True)
58
 
59
+ @spaces.GPU
60
  def sketch(curr_sketch_image, dilation_mask, prompt, seed, num_steps, dilation):
61
  global curr_num_samples
62
  global pipe
 
162
  stroke_type[0].change(change_color, [stroke_type[0]], canvas)
163
  num_samples[0].change(change_num_samples, [num_samples[0]], None)
164
 
165
+ load()
166
  demo.launch(share = True, debug = True)