ohayonguy commited on
Commit
1fef40b
·
1 Parent(s): 20ac05d

trying to fix interface

Browse files
Files changed (1) hide show
  1. app.py +19 -12
app.py CHANGED
@@ -5,6 +5,7 @@ if os.getenv('SPACES_ZERO_GPU') == "true":
5
  os.environ['K_DIFFUSION_USE_COMPILE'] = "0"
6
  import spaces
7
  import cv2
 
8
  import gradio as gr
9
  import random
10
  import torch
@@ -51,10 +52,12 @@ def generate_reconstructions(pmrf_model, x, y, non_noisy_z0, num_flow_steps, dev
51
  dt = (1.0 / num_flow_steps) * (1.0 - pmrf_model.hparams.eps)
52
  x_t_next = source_dist_samples.clone()
53
  t_one = torch.ones(x.shape[0], device=device)
54
- for i in range(num_flow_steps):
 
55
  num_t = (i / num_flow_steps) * (1.0 - pmrf_model.hparams.eps) + pmrf_model.hparams.eps
56
  v_t_next = pmrf_model(x_t=x_t_next, t=t_one * num_t, y=y).to(x_t_next.dtype)
57
  x_t_next = x_t_next.clone() + v_t_next * dt
 
58
 
59
  return x_t_next.clip(0, 1).to(torch.float32)
60
 
@@ -78,7 +81,7 @@ def enhance_face(img, face_helper, has_aligned, num_flow_steps, only_center_face
78
  # prepare data
79
  h, w = cropped_face.shape[0], cropped_face.shape[1]
80
  cropped_face = cv2.resize(cropped_face, (512, 512), interpolation=cv2.INTER_LINEAR)
81
- face_helper.cropped_faces[i] = cropped_face
82
  cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
83
  cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
84
 
@@ -108,7 +111,11 @@ def enhance_face(img, face_helper, has_aligned, num_flow_steps, only_center_face
108
 
109
  @torch.inference_mode()
110
  @spaces.GPU()
111
- def inference(seed, randomize_seed, img, aligned, scale, num_flow_steps):
 
 
 
 
112
  if randomize_seed:
113
  seed = random.randint(0, MAX_SEED)
114
  torch.manual_seed(seed)
@@ -139,16 +146,16 @@ def inference(seed, randomize_seed, img, aligned, scale, num_flow_steps):
139
  scale=scale)
140
  if has_aligned:
141
  output = restored_aligned[0]
142
- input = cropped_face[0].astype('uint8')
143
  else:
144
  output = restored_img
145
- input = img
146
 
147
  output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
148
- h, w = output.shape[0:2]
149
- input = cv2.cvtColor(input, cv2.COLOR_BGR2RGB)
150
- input = cv2.resize(input, (h, w), interpolation=cv2.INTER_LINEAR)
151
- return [input, output, seed]
152
 
153
 
154
  intro = """
@@ -215,7 +222,7 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
215
 
216
  with gr.Row():
217
  with gr.Column(scale=2):
218
- input_im = gr.Image(label="Input Image", type="filepath")
219
  with gr.Column(scale=1):
220
  num_inference_steps = gr.Slider(
221
  label="Number of Inference Steps",
@@ -246,7 +253,7 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
246
  run_button = gr.Button(value="Submit", variant="primary")
247
 
248
  with gr.Row():
249
- result = ImageSlider(label="Input / Output", type="numpy", interactive=True, show_label=True)
250
 
251
  gr.Markdown(article)
252
  gr.on(
@@ -266,4 +273,4 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
266
  )
267
 
268
  demo.queue()
269
- demo.launch(state_session_capacity=15, show_api=False)
 
5
  os.environ['K_DIFFUSION_USE_COMPILE'] = "0"
6
  import spaces
7
  import cv2
8
+ from tqdm import tqdm
9
  import gradio as gr
10
  import random
11
  import torch
 
52
  dt = (1.0 / num_flow_steps) * (1.0 - pmrf_model.hparams.eps)
53
  x_t_next = source_dist_samples.clone()
54
  t_one = torch.ones(x.shape[0], device=device)
55
+ pbar = tqdm(range(num_flow_steps))
56
+ for i in pbar:
57
  num_t = (i / num_flow_steps) * (1.0 - pmrf_model.hparams.eps) + pmrf_model.hparams.eps
58
  v_t_next = pmrf_model(x_t=x_t_next, t=t_one * num_t, y=y).to(x_t_next.dtype)
59
  x_t_next = x_t_next.clone() + v_t_next * dt
60
+ pbar.set_description(f'Flow step {i}')
61
 
62
  return x_t_next.clip(0, 1).to(torch.float32)
63
 
 
81
  # prepare data
82
  h, w = cropped_face.shape[0], cropped_face.shape[1]
83
  cropped_face = cv2.resize(cropped_face, (512, 512), interpolation=cv2.INTER_LINEAR)
84
+ # face_helper.cropped_faces[i] = cropped_face
85
  cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
86
  cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
87
 
 
111
 
112
  @torch.inference_mode()
113
  @spaces.GPU()
114
+ def inference(seed, randomize_seed, img, aligned, scale, num_flow_steps,
115
+ progress=gr.Progress(track_tqdm=True)):
116
+ if img is None:
117
+ gr.Info("Please upload an image before submitting")
118
+ return [None, None, None]
119
  if randomize_seed:
120
  seed = random.randint(0, MAX_SEED)
121
  torch.manual_seed(seed)
 
146
  scale=scale)
147
  if has_aligned:
148
  output = restored_aligned[0]
149
+ # input = cropped_face[0].astype('uint8')
150
  else:
151
  output = restored_img
152
+ # input = img
153
 
154
  output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
155
+ # h, w = output.shape[0:2]
156
+ # input = cv2.cvtColor(input, cv2.COLOR_BGR2RGB)
157
+ # input = cv2.resize(input, (h, w), interpolation=cv2.INTER_LINEAR)
158
+ return output
159
 
160
 
161
  intro = """
 
222
 
223
  with gr.Row():
224
  with gr.Column(scale=2):
225
+ input_im = gr.Image(label="Input", type="filepath", show_label=True)
226
  with gr.Column(scale=1):
227
  num_inference_steps = gr.Slider(
228
  label="Number of Inference Steps",
 
253
  run_button = gr.Button(value="Submit", variant="primary")
254
 
255
  with gr.Row():
256
+ result = gr.Image(label="Output", type="numpy", show_label=True)
257
 
258
  gr.Markdown(article)
259
  gr.on(
 
273
  )
274
 
275
  demo.queue()
276
+ demo.launch(state_session_capacity=15)