msong97 commited on
Commit
4e6590f
·
1 Parent(s): 9eb8ea4

add resize img & convert to complex and to grayscale when necessary

Browse files
Files changed (1) hide show
  1. app.py +21 -2
app.py CHANGED
@@ -21,6 +21,16 @@ DEVICE_STR = 'cuda' if torch.cuda.is_available() else 'cpu'
21
 
22
 
23
  ### Gradio Utils
 
 
 
 
 
 
 
 
 
 
24
  def generate_imgs_from_user(image,
25
  physics: PhysicsWithGenerator, use_gen: bool,
26
  baseline: BaselineModel, model: EvalModel,
@@ -31,9 +41,17 @@ def generate_imgs_from_user(image,
31
 
32
  # PIL image -> torch.Tensor / (1, C, H, W) / move to DEVICE_STR
33
  x = transforms.ToTensor()(image).unsqueeze(0).to(DEVICE_STR)
 
 
34
 
35
- return generate_imgs(x, physics, use_gen, baseline, model, metrics)
 
 
 
 
 
36
 
 
37
  def generate_imgs_from_dataset(dataset: EvalDataset, idx: int,
38
  physics: PhysicsWithGenerator, use_gen: bool,
39
  baseline: BaselineModel, model: EvalModel,
@@ -189,7 +207,8 @@ with gr.Blocks(title=title, theme=gr.themes.Glass()) as interface:
189
  print(f"[Render] CUDA max allocated: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")
190
  print(f"[Render] CUDA max reserved: {torch.cuda.max_memory_reserved() / 1024**2:.2f} MB")
191
 
192
- @gr.render(inputs=[dataset_placeholder, physics_placeholder, available_physics_placeholder])
 
193
  def dynamic_layout(dataset, physics, available_physics):
194
  ### LAYOUT
195
 
 
21
 
22
 
23
  ### Gradio Utils
24
+ def resize_tensor_within_box(tensor_img: torch.Tensor, max_size: int = 512):
25
+ _, _, h, w = tensor_img.shape
26
+ scale = min(max_size / h, max_size / w)
27
+
28
+ if scale < 1.0:
29
+ new_h, new_w = int(h * scale), int(w * scale)
30
+ tensor_img = transforms.functional.resize(tensor_img, [new_h, new_w], antialias=True)
31
+
32
+ return tensor_img
33
+
34
  def generate_imgs_from_user(image,
35
  physics: PhysicsWithGenerator, use_gen: bool,
36
  baseline: BaselineModel, model: EvalModel,
 
41
 
42
  # PIL image -> torch.Tensor / (1, C, H, W) / move to DEVICE_STR
43
  x = transforms.ToTensor()(image).unsqueeze(0).to(DEVICE_STR)
44
+ # Resize img within a 512x512 box
45
+ x = resize_tensor_within_box(x)
46
 
47
+ C = x.shape[1]
48
+ if C == 3 and physics.name == 'CT':
49
+ x = transforms.Grayscale(num_output_channels=1)(x)
50
+ elif C == 3 and physics.name == 'MRI': # not working because MRI physics has a fixed img size
51
+ x = transforms.Grayscale(num_output_channels=1)(x)
52
+ x = torch.cat((x, torch.zeros_like(x)), dim=1)
53
 
54
+ return generate_imgs(x, physics, use_gen, baseline, model, metrics)
55
  def generate_imgs_from_dataset(dataset: EvalDataset, idx: int,
56
  physics: PhysicsWithGenerator, use_gen: bool,
57
  baseline: BaselineModel, model: EvalModel,
 
207
  print(f"[Render] CUDA max allocated: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")
208
  print(f"[Render] CUDA max reserved: {torch.cuda.max_memory_reserved() / 1024**2:.2f} MB")
209
 
210
+ @gr.render(inputs=[dataset_placeholder, physics_placeholder, available_physics_placeholder],
211
+ triggers=[dataset_placeholder.change, physics_placeholder.change])
212
  def dynamic_layout(dataset, physics, available_physics):
213
  ### LAYOUT
214