manhkhanhUIT commited on
Commit
7018e07
1 Parent(s): 73db728

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -24
app.py CHANGED
@@ -11,8 +11,8 @@ import torchvision.transforms as transforms
11
  from PIL import Image
12
  import torch
13
 
14
- os.system("pip install dlib")
15
- os.system('bash setup.sh')
16
 
17
  def lab2rgb(L, AB):
18
  """Convert an Lab tensor image to a RGB numpy output
@@ -48,11 +48,29 @@ def get_transform(params=None, grayscale=False, method=Image.BICUBIC):
48
 
49
  return transforms.Compose(transform_list)
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  def inferColorization(img,model_name):
52
  print(model_name)
53
- if model_name == "Pix2Pix_resnet9b":
54
  model = torch.hub.load('manhkhanhad/ImageRestorationInfer', 'pix2pixColorization_resnet9b')
55
- elif model_name == "Pix2Pix_unet256":
56
  model = torch.hub.load('manhkhanhad/ImageRestorationInfer', 'pix2pixColorization_unet256')
57
  elif model_name == "Deoldify":
58
  model = torch.hub.load('manhkhanhad/ImageRestorationInfer', 'DeOldifyColorization')
@@ -108,30 +126,37 @@ def run_cmd(command):
108
  print("Process interrupted")
109
  sys.exit(1)
110
 
111
- def run(image):
112
- if os.path.isdir("Temp"):
113
- shutil.rmtree("Temp")
114
-
115
- os.makedirs("Temp")
116
- os.makedirs("Temp/input")
117
- print(type(image))
118
- cv2.imwrite("Temp/input/input_img.png", image)
 
119
 
120
- command = ("python run.py --input_folder "
121
- + "Temp/input"
122
- + " --output_folder "
123
- + "Temp"
124
- + " --GPU "
125
- + "-1"
126
- + " --with_scratch")
127
- run_cmd(command)
128
 
129
- result_restoration = Image.open("Temp/final_output/input_img.png")
130
- shutil.rmtree("Temp")
 
 
 
131
 
132
- result_colorization = inferColorization(result_restoration,"Deoldify")
133
 
134
  return result_colorization
135
 
136
  examples = [['example/1.jpeg'],['example/2.jpg'],['example/3.jpg'],['example/4.jpg']]
137
- iface = gr.Interface(fn=run, inputs="image", outputs="image",examples=examples).launch(debug=True,share=False)
 
 
 
 
11
  from PIL import Image
12
  import torch
13
 
14
+ #os.system("pip install dlib")
15
+ #os.system('bash setup.sh')
16
 
17
  def lab2rgb(L, AB):
18
  """Convert an Lab tensor image to a RGB numpy output
 
48
 
49
  return transforms.Compose(transform_list)
50
 
51
+ def inferRestoration(img, model_name):
52
+ #if model_name == "Pix2Pix":
53
+ model = torch.hub.load('manhkhanhad/ImageRestorationInfer', 'pix2pixRestoration_unet256')
54
+ transform_list = [
55
+ transforms.ToTensor(),
56
+ transforms.Resize([256,256], Image.BICUBIC),
57
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
58
+ ]
59
+ transform = transforms.Compose(transform_list)
60
+ img = transform(img)
61
+ img = torch.unsqueeze(img, 0)
62
+ result = model(img)
63
+ result = result[0].detach()
64
+ result = (result +1)/2.0
65
+
66
+ result = transforms.ToPILImage()(result)
67
+ return result
68
+
69
  def inferColorization(img,model_name):
70
  print(model_name)
71
+ if model_name == "Pix2Pix Resnet 9block":
72
  model = torch.hub.load('manhkhanhad/ImageRestorationInfer', 'pix2pixColorization_resnet9b')
73
+ elif model_name == "Pix2Pix Unet 256":
74
  model = torch.hub.load('manhkhanhad/ImageRestorationInfer', 'pix2pixColorization_unet256')
75
  elif model_name == "Deoldify":
76
  model = torch.hub.load('manhkhanhad/ImageRestorationInfer', 'DeOldifyColorization')
 
126
  print("Process interrupted")
127
  sys.exit(1)
128
 
129
+ def run(image,Restoration_mode, colorizaition_mode):
130
+ if Restoration_mode == "BOPBTL":
131
+ if os.path.isdir("Temp"):
132
+ shutil.rmtree("Temp")
133
+
134
+ os.makedirs("Temp")
135
+ os.makedirs("Temp/input")
136
+ print(type(image))
137
+ cv2.imwrite("Temp/input/input_img.png", image)
138
 
139
+ command = ("python run.py --input_folder "
140
+ + "Temp/input"
141
+ + " --output_folder "
142
+ + "Temp"
143
+ + " --GPU "
144
+ + "-1"
145
+ + " --with_scratch")
146
+ run_cmd(command)
147
 
148
+ result_restoration = Image.open("Temp/final_output/input_img.png")
149
+ shutil.rmtree("Temp")
150
+
151
+ elif Restoration_mode == "Pix2Pix":
152
+ result_restoration = inferRestoration(image, Restoration_mode)
153
 
154
+ result_colorization = inferColorization(result_restoration,colorizaition_mode)
155
 
156
  return result_colorization
157
 
158
  examples = [['example/1.jpeg'],['example/2.jpg'],['example/3.jpg'],['example/4.jpg']]
159
+ iface = gr.Interface(run,
160
+ [gr.inputs.Image(),gr.inputs.Radio(["BOPBTL", "Pix2Pix"]),gr.inputs.Radio(["Deoldify", "Pix2Pix Resnet 9block","Pix2Pix Unet 256"])],
161
+ outputs="image",
162
+ examples=examples).launch(debug=True,share=True)