chris1nexus commited on
Commit
b1c8049
1 Parent(s): c371f4e

Increased model output size to 512x512

Browse files
README.md CHANGED
@@ -1,10 +1,10 @@
1
  ---
2
  title: Sim2real
3
- emoji: 🌍
4
- colorFrom: red
5
- colorTo: red
6
  sdk: gradio
7
- sdk_version: 2.9.4
8
  app_file: app.py
9
  pinned: false
10
  license: mit
 
1
  ---
2
  title: Sim2real
3
+ emoji: 🔥
4
+ colorFrom: purple
5
+ colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 2.9.1
8
  app_file: app.py
9
  pinned: false
10
  license: mit
app.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+
4
+ from PIL import Image
5
+ from torchvision import transforms as T
6
+ from torchvision.transforms import Compose, Resize, ToTensor, Normalize, RandomCrop, RandomHorizontalFlip
7
+ from torchvision.utils import make_grid
8
+ from torch.utils.data import DataLoader
9
+ from huggan.pytorch.cyclegan.modeling_cyclegan import GeneratorResNet
10
+ import torch.nn as nn
11
+ import torch
12
+ import gradio as gr
13
+
14
+ from collections import OrderedDict
15
+ import glob
16
+
17
+
18
+
19
+
20
+ def pred_pipeline(img, transforms):
21
+ orig_shape = img.shape
22
+ input = transforms(img)
23
+ input = input.unsqueeze(0)
24
+ output = model(input)
25
+
26
+ out_img = make_grid(output,#.detach().cpu(),
27
+ nrow=1, normalize=True)
28
+ out_transform = Compose([
29
+ T.Resize(orig_shape[:2]),
30
+ T.ToPILImage()
31
+ ])
32
+ return out_transform(out_img)
33
+
34
+
35
+
36
+
37
+ n_channels = 3
38
+ image_size = 512
39
+ input_shape = (image_size, image_size)
40
+
41
+ transform = Compose([
42
+ T.ToPILImage(),
43
+ T.Resize(input_shape),
44
+ ToTensor(),
45
+ Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
46
+ ])
47
+
48
+
49
+ model = GeneratorResNet.from_pretrained('Chris1/sim2real', input_shape=(n_channels, image_size, image_size),
50
+ num_residual_blocks=9)
51
+
52
+ gr.Interface(lambda image: pred_pipeline(image, transform),
53
+ inputs=gr.inputs.Image( label='input synthetic image'),
54
+ outputs=gr.outputs.Image( type="pil",label='style transfer to the real world'),#plot,
55
+ title = "GTA5(simulated) to Cityscapes (real) translation",
56
+ examples = [
57
+ [example] for example in glob.glob('./samples/*.png')
58
+ ])\
59
+ .launch()
60
+
61
+
62
+
63
+ #iface = gr.Interface(fn=greet, inputs="text", outputs="text")
64
+ #iface.launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ huggingface-hub
2
+ numpy
3
+ torch
4
+ transformers
5
+ git+https://github.com/huggingface/community-events@main
samples/00012.png ADDED
samples/00237.png ADDED
samples/08164.png ADDED
samples/11603.png ADDED
samples/11607.png ADDED
samples/12073.png ADDED
samples/12227.png ADDED
samples/12605.png ADDED
samples/18621.png ADDED
samples/19627.png ADDED