Pouriarouzrokh commited on
Commit
a803714
·
1 Parent(s): a371dbe

removed osail_utils

Browse files
Files changed (1) hide show
  1. app.py +17 -2
app.py CHANGED
@@ -3,7 +3,6 @@ import os
3
  import gradio as gr
4
  import matplotlib.pyplot as plt
5
  import numpy as np
6
- import osail_utils
7
  import pandas as pd
8
  import skimage
9
  from mediffusion import DiffusionModule
@@ -29,6 +28,22 @@ BASELINE_NOISE = torch.randn(1, 1, 256, 256).cuda().half()
29
 
30
  # Model helper functions
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  def create_ds(img_paths):
33
  if type(img_paths) == str:
34
  img_paths = [img_paths]
@@ -36,7 +51,7 @@ def create_ds(img_paths):
36
 
37
  # Get the transforms
38
  Ts_list = [
39
- osail_utils.io.LoadImageD(keys=["img"], transpose=True, normalize=True),
40
  mn.transforms.EnsureChannelFirstD(
41
  keys=["img"], channel_dim="no_channel"
42
  ),
 
3
  import gradio as gr
4
  import matplotlib.pyplot as plt
5
  import numpy as np
 
6
  import pandas as pd
7
  import skimage
8
  from mediffusion import DiffusionModule
 
28
 
29
  # Model helper functions
30
 
31
+ class LoadImageD(mn.transforms.Transform):
32
+ def __init__(self, keys, transpose=False, normalize=False):
33
+ self.keys = keys
34
+ self.transpose = transpose
35
+ self.normalize = normalize
36
+ def __call__(self, data):
37
+ for key in self.keys:
38
+ img = skimage.io.imread(data[key])
39
+ if self.transpose:
40
+ img = img.transpose(0, 1)
41
+ if self.normalize:
42
+ img -= img.min()
43
+ img /= (img.max()+1e-6)
44
+ data[key] = img
45
+ return data
46
+
47
  def create_ds(img_paths):
48
  if type(img_paths) == str:
49
  img_paths = [img_paths]
 
51
 
52
  # Get the transforms
53
  Ts_list = [
54
+ LoadImageD(keys=["img"], transpose=True, normalize=True),
55
  mn.transforms.EnsureChannelFirstD(
56
  keys=["img"], channel_dim="no_channel"
57
  ),