adamelliotfields commited on
Commit
a8ad716
·
verified ·
1 Parent(s): c95eb38

Fix progress bars

Browse files
Files changed (3) hide show
  1. app.py +7 -1
  2. lib/inference.py +7 -2
  3. lib/loader.py +6 -3
app.py CHANGED
@@ -87,7 +87,13 @@ async def generate_fn(*args):
87
  if prompt is None or prompt.strip() == "":
88
  raise gr.Error("You must enter a prompt")
89
  try:
90
- images = await async_call(generate, *args, Info=gr.Info, Error=gr.Error)
 
 
 
 
 
 
91
  except RuntimeError:
92
  raise gr.Error("RuntimeError: Please try again")
93
  return images
 
87
  if prompt is None or prompt.strip() == "":
88
  raise gr.Error("You must enter a prompt")
89
  try:
90
+ images = await async_call(
91
+ generate,
92
+ *args,
93
+ Info=gr.Info,
94
+ Error=gr.Error,
95
+ progress=gr.Progress(),
96
+ )
97
  except RuntimeError:
98
  raise gr.Error("RuntimeError: Please try again")
99
  return images
lib/inference.py CHANGED
@@ -9,7 +9,6 @@ from itertools import product
9
  from typing import Callable, TypeVar
10
 
11
  import anyio
12
- import gradio as gr
13
  import numpy as np
14
  import spaces
15
  import torch
@@ -123,7 +122,7 @@ def generate(
123
  clip_skip=False,
124
  Info: Callable[[str], None] = None,
125
  Error=Exception,
126
- progress=gr.Progress(),
127
  ):
128
  if not torch.cuda.is_available():
129
  raise Error("CUDA not available")
@@ -150,10 +149,15 @@ def generate(
150
  IP_ADAPTER = ""
151
 
152
  if progress is not None:
 
153
  progress((0, inference_steps), desc=f"Generating image {CURRENT_IMAGE}/{num_images}")
 
 
154
 
155
  def callback_on_step_end(pipeline, step, timestep, latents):
156
  nonlocal CURRENT_IMAGE
 
 
157
  strength = denoising_strength if KIND == "img2img" else 1
158
  total_steps = min(int(inference_steps * strength), inference_steps)
159
  current_step = step + 1
@@ -177,6 +181,7 @@ def generate(
177
  freeu,
178
  deepcache,
179
  scale,
 
180
  DEVICE,
181
  )
182
 
 
9
  from typing import Callable, TypeVar
10
 
11
  import anyio
 
12
  import numpy as np
13
  import spaces
14
  import torch
 
122
  clip_skip=False,
123
  Info: Callable[[str], None] = None,
124
  Error=Exception,
125
+ progress=None,
126
  ):
127
  if not torch.cuda.is_available():
128
  raise Error("CUDA not available")
 
149
  IP_ADAPTER = ""
150
 
151
  if progress is not None:
152
+ TQDM = False
153
  progress((0, inference_steps), desc=f"Generating image {CURRENT_IMAGE}/{num_images}")
154
+ else:
155
+ TQDM = True
156
 
157
  def callback_on_step_end(pipeline, step, timestep, latents):
158
  nonlocal CURRENT_IMAGE
159
+ if progress is None:
160
+ return latents
161
  strength = denoising_strength if KIND == "img2img" else 1
162
  total_steps = min(int(inference_steps * strength), inference_steps)
163
  current_step = step + 1
 
181
  freeu,
182
  deepcache,
183
  scale,
184
+ TQDM,
185
  DEVICE,
186
  )
187
 
lib/loader.py CHANGED
@@ -110,7 +110,7 @@ class Loader:
110
  self.upscaler = RealESRGAN(device=device, scale=scale)
111
  self.upscaler.load_weights()
112
 
113
- def _load_pipeline(self, kind, model, device, **kwargs):
114
  pipeline = Config.PIPELINES[kind]
115
  if self.pipe is None:
116
  print(f"Loading {model}...")
@@ -131,7 +131,9 @@ class Loader:
131
 
132
  if not isinstance(self.pipe, pipeline):
133
  self.pipe = pipeline.from_pipe(self.pipe).to(device)
134
- self.pipe.set_progress_bar_config(disable=True)
 
 
135
 
136
  def _load_vae(self, taesd=False, model=""):
137
  vae_type = type(self.pipe.vae)
@@ -202,6 +204,7 @@ class Loader:
202
  freeu,
203
  deepcache,
204
  scale,
 
205
  device,
206
  ):
207
  scheduler_kwargs = {
@@ -242,7 +245,7 @@ class Loader:
242
  pipe_kwargs["torch_dtype"] = torch.float16
243
 
244
  self._unload(kind, model, ip_adapter, scale)
245
- self._load_pipeline(kind, model, device, **pipe_kwargs)
246
 
247
  # error loading model
248
  if self.pipe is None:
 
110
  self.upscaler = RealESRGAN(device=device, scale=scale)
111
  self.upscaler.load_weights()
112
 
113
+ def _load_pipeline(self, kind, model, tqdm, device, **kwargs):
114
  pipeline = Config.PIPELINES[kind]
115
  if self.pipe is None:
116
  print(f"Loading {model}...")
 
131
 
132
  if not isinstance(self.pipe, pipeline):
133
  self.pipe = pipeline.from_pipe(self.pipe).to(device)
134
+
135
+ if not tqdm:
136
+ self.pipe.set_progress_bar_config(disable=True)
137
 
138
  def _load_vae(self, taesd=False, model=""):
139
  vae_type = type(self.pipe.vae)
 
204
  freeu,
205
  deepcache,
206
  scale,
207
+ tqdm,
208
  device,
209
  ):
210
  scheduler_kwargs = {
 
245
  pipe_kwargs["torch_dtype"] = torch.float16
246
 
247
  self._unload(kind, model, ip_adapter, scale)
248
+ self._load_pipeline(kind, model, tqdm, device, **pipe_kwargs)
249
 
250
  # error loading model
251
  if self.pipe is None: