n1kkqt commited on
Commit
f7c8faa
1 Parent(s): 7023fbb
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
- title: Interior Semantic Segmentation
3
- emoji: 🔥
4
- colorFrom: green
5
- colorTo: blue
6
  sdk: gradio
7
  sdk_version: 3.17.0
8
  app_file: app.py
 
1
  ---
2
+ title: Semantic Interior Segmentation
3
+ emoji: 👀
4
+ colorFrom: white
5
+ colorTo: black
6
  sdk: gradio
7
  sdk_version: 3.17.0
8
  app_file: app.py
ade20k_classes.pickle ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:486ec4a95f1118e3e3f3903e18d4d34fb766d3a578f5f79d25ef8aff760466d9
3
+ size 3147
app.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import glob
3
+ import torch
4
+ import pickle
5
+ from PIL import Image, ImageDraw
6
+ import numpy as np
7
+ from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
8
+
9
+ import numpy as np
10
+ from scipy.ndimage import center_of_mass
11
+
12
+
13
+ def combine_ims(im1, im2, val=128):
14
+ p = Image.new("L", im1.size, val)
15
+ im = Image.composite(im1, im2, p)
16
+ return im
17
+
18
+ def get_class_centers(segmentation_mask, class_dict):
19
+ segmentation_mask = segmentation_mask.numpy() + 1
20
+ class_centers = {}
21
+ for class_index, _ in class_dict.items():
22
+ class_mask = (segmentation_mask == class_index).astype(int)
23
+ center_of_mass_list = center_of_mass(class_mask)
24
+
25
+ class_centers[class_index] = center_of_mass_list
26
+
27
+ class_centers = {k:list(map(int, v)) for k,v in class_centers.items() if not np.isnan(sum(v))}
28
+ return class_centers
29
+
30
+ def visualize_mask(predicted_semantic_map, class_ids, class_colors):
31
+ h, w = predicted_semantic_map.shape
32
+ color_indexes = np.zeros((h, w), dtype=np.uint8)
33
+ color_indexes[:] = predicted_semantic_map.numpy()
34
+ color_indexes = color_indexes.flatten()
35
+
36
+ colors = class_colors[class_ids[color_indexes]]
37
+ output = colors.reshape(h, w, 3).astype(np.uint8)
38
+ image_mask = Image.fromarray(output)
39
+ return image_mask
40
+
41
+
42
+ def get_out_image(image, predicted_semantic_map):
43
+ class_centers = get_class_centers(predicted_semantic_map, class_dict)
44
+ mask = visualize_mask(predicted_semantic_map, class_ids, class_colors)
45
+ image_mask = combine_ims(image, mask, val=128)
46
+ draw = ImageDraw.Draw(image_mask)
47
+ for id, (y, x) in class_centers.items():
48
+ draw.text((x, y), str(class_names[id-1]), fill='black')
49
+
50
+ return image_mask
51
+
52
+ def gradio_process(image):
53
+ inputs = processor(images=image, return_tensors="pt")
54
+
55
+ with torch.no_grad():
56
+ outputs = model(**inputs)
57
+
58
+ predicted_semantic_map = processor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
59
+
60
+ out_image = get_out_image(image, predicted_semantic_map)
61
+ return out_image
62
+
63
+ with open('ade20k_classes.pickle', 'rb') as f:
64
+ class_names, class_ids, class_colors = pickle.load(f)
65
+ class_names, class_ids, class_colors = np.array(class_names), np.array(class_ids), np.array(class_colors)
66
+ class_dict = dict(zip(class_ids, class_names))
67
+
68
+ device = torch.device("cpu")
69
+ processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-large-ade-semantic")
70
+ model = Mask2FormerForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-large-ade-semantic").to(device)
71
+ model.eval()
72
+
73
+ demo = gr.Interface(
74
+ gradio_process,
75
+ inputs=gr.inputs.Image(type="pil"),
76
+ outputs=gr.outputs.Image(type="pil"),
77
+ title="Semantic Interior Segmentation (Demo for Craftwork)",
78
+ examples=[glob.glob('examples')],
79
+ allow_flagging="never",
80
+
81
+ )
82
+
83
+ demo.launch()
examples/image (1).jpg ADDED
examples/image (3).jpg ADDED
examples/image (4).jpg ADDED
examples/image (5).jpg ADDED
examples/image (6).jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch
2
+ transformers
3
+ scipy