chris1nexus commited on
Commit
c35f19e
1 Parent(s): 6356cd8

Add real2sim translation of the sim2real prediction

Browse files
Files changed (1) hide show
  1. app.py +16 -9
app.py CHANGED
@@ -21,15 +21,20 @@ def pred_pipeline(img, transforms):
21
  orig_shape = img.shape
22
  input = transforms(img)
23
  input = input.unsqueeze(0)
24
- output = model(input)
25
-
26
- out_img = make_grid(output,#.detach().cpu(),
27
  nrow=1, normalize=True)
 
 
 
 
 
28
  out_transform = Compose([
29
  T.Resize(orig_shape[:2]),
30
  T.ToPILImage()
31
  ])
32
- return out_transform(out_img)
33
 
34
 
35
 
@@ -46,12 +51,17 @@ transform = Compose([
46
  ])
47
 
48
 
49
- model = GeneratorResNet.from_pretrained('huggan/sim2real_cyclegan', input_shape=(n_channels, image_size, image_size),
 
 
50
  num_residual_blocks=9)
51
 
52
  gr.Interface(lambda image: pred_pipeline(image, transform),
53
  inputs=gr.inputs.Image( label='input synthetic image'),
54
- outputs=gr.outputs.Image( type="pil",label='style transfer to the real world'),#plot,
 
 
 
55
  title = "GTA5(simulated) to Cityscapes (real) translation",
56
  examples = [
57
  [example] for example in glob.glob('./samples/*.png')
@@ -59,6 +69,3 @@ gr.Interface(lambda image: pred_pipeline(image, transform),
59
  .launch()
60
 
61
 
62
-
63
- #iface = gr.Interface(fn=greet, inputs="text", outputs="text")
64
- #iface.launch()
 
21
  orig_shape = img.shape
22
  input = transforms(img)
23
  input = input.unsqueeze(0)
24
+ output_real = sim2real(input)
25
+ output_syn = real2sim(output_real)
26
+ out_img_real = make_grid(output_real,
27
  nrow=1, normalize=True)
28
+ out_syn_real = make_grid(output_syn,
29
+ nrow=1, normalize=True)
30
+
31
+
32
+
33
  out_transform = Compose([
34
  T.Resize(orig_shape[:2]),
35
  T.ToPILImage()
36
  ])
37
+ return out_transform(out_img_real), out_transform(out_syn_real)
38
 
39
 
40
 
 
51
  ])
52
 
53
 
54
+ sim2real = GeneratorResNet.from_pretrained('Chris1/sim2real-512', input_shape=(n_channels, image_size, image_size),
55
+ num_residual_blocks=9)
56
+ real2sim = GeneratorResNet.from_pretrained('Chris1/real2sim-512', input_shape=(n_channels, image_size, image_size),
57
  num_residual_blocks=9)
58
 
59
  gr.Interface(lambda image: pred_pipeline(image, transform),
60
  inputs=gr.inputs.Image( label='input synthetic image'),
61
+ outputs=[
62
+ gr.outputs.Image( type="pil",label='GAN sim2real prediction: style transfer of the input to the real world '),
63
+ gr.outputs.Image( type="pil",label='GAN real2sim prediction: translation to synthetic of the above prediction')
64
+ ],#plot,
65
  title = "GTA5(simulated) to Cityscapes (real) translation",
66
  examples = [
67
  [example] for example in glob.glob('./samples/*.png')
 
69
  .launch()
70
 
71