ironjr commited on
Commit
cada94d
1 Parent(s): ad24923

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -14
app.py CHANGED
@@ -320,7 +320,7 @@ def import_state(state, json_text):
320
  ### Main worker
321
 
322
 
323
- def register(state, drawpad):
324
  seed_everything(state.seed if state.seed >=0 else np.random.randint(2147483647))
325
  print('Generate!')
326
 
@@ -362,15 +362,15 @@ def register(state, drawpad):
362
  # prompts, negative_prompts = preprocess_prompts(
363
  # prompts, negative_prompts, style_name=state.style_name, quality_name=state.quality_name)
364
 
365
- state.model.update_background(
366
  background.convert('RGB'),
367
  prompt=None,
368
  negative_prompt=None,
369
  )
370
- state.prompts[0] = state.model.background.prompt
371
- state.neg_prompts[0] = state.model.background.negative_prompt
372
 
373
- state.model.update_layers(
374
  prompts=prompts,
375
  negative_prompts=negative_prompts,
376
  masks=masks.to(device),
@@ -384,23 +384,23 @@ def register(state, drawpad):
384
 
385
  @spaces.GPU(duration=120)
386
  def run(state, drawpad):
387
- state.model = model
388
- state.model.device = torch.device('cuda')
389
- state.model.reset_seed(state.model.generator, opt.seed)
390
- state.model.reset_latent()
391
- state.model.prepare()
392
 
393
- state = register(state, drawpad)
394
  state.is_running = True
395
 
396
  tic = time.time()
397
  while True:
398
- yield [state, state.model()]
399
  toc = time.time()
400
  tdelta = toc - tic
401
  if tdelta > opt.run_time:
402
  state.is_running = False
403
- return [state, state.model()]
 
404
 
405
 
406
  def hide_element():
@@ -412,7 +412,11 @@ def show_element():
412
 
413
 
414
  def draw(state, drawpad):
 
 
 
415
  if not state.is_running:
 
416
  return
417
 
418
  user_input = np.asarray(drawpad['layers'][0]) # (H, W, 4)
@@ -601,7 +605,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css, head=head) as demo:
601
  state.model_id = opt.model
602
  state.style_name = '(None)'
603
  state.quality_name = 'Standard v3.1'
604
- state.model = model
605
 
606
  # State variables (one-hot).
607
  state.active_palettes = 5
 
320
  ### Main worker
321
 
322
 
323
+ def register(state, drawpad, model):
324
  seed_everything(state.seed if state.seed >=0 else np.random.randint(2147483647))
325
  print('Generate!')
326
 
 
362
  # prompts, negative_prompts = preprocess_prompts(
363
  # prompts, negative_prompts, style_name=state.style_name, quality_name=state.quality_name)
364
 
365
+ model.update_background(
366
  background.convert('RGB'),
367
  prompt=None,
368
  negative_prompt=None,
369
  )
370
+ state.prompts[0] = model.background.prompt
371
+ state.neg_prompts[0] = model.background.negative_prompt
372
 
373
+ model.update_layers(
374
  prompts=prompts,
375
  negative_prompts=negative_prompts,
376
  masks=masks.to(device),
 
384
 
385
  @spaces.GPU(duration=120)
386
  def run(state, drawpad):
387
+ model.device = torch.device('cuda')
388
+ model.reset_seed(model.generator, opt.seed)
389
+ model.reset_latent()
390
+ model.prepare()
 
391
 
392
+ state = register(state, drawpad, model)
393
  state.is_running = True
394
 
395
  tic = time.time()
396
  while True:
397
+ yield [state, model()]
398
  toc = time.time()
399
  tdelta = toc - tic
400
  if tdelta > opt.run_time:
401
  state.is_running = False
402
+ state.model = None
403
+ return [state, model()]
404
 
405
 
406
  def hide_element():
 
412
 
413
 
414
  def draw(state, drawpad):
415
+ if not hasattr(state, 'model') or state.model is None:
416
+ print('[WARNING] Model is not registered, update ignored.')
417
+ return
418
  if not state.is_running:
419
+ print('[WARNING] Streaming is currently off, update ignored.')
420
  return
421
 
422
  user_input = np.asarray(drawpad['layers'][0]) # (H, W, 4)
 
605
  state.model_id = opt.model
606
  state.style_name = '(None)'
607
  state.quality_name = 'Standard v3.1'
608
+ state.model = None
609
 
610
  # State variables (one-hot).
611
  state.active_palettes = 5