alkzar90 commited on
Commit
0976e91
1 Parent(s): d087832

add mask to rgb label2color image

Browse files
Files changed (2) hide show
  1. app.py +21 -7
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,4 +1,6 @@
1
  import gradio as gr
 
 
2
  import torch
3
  from torch import nn
4
  from transformers import (SegformerFeatureExtractor,
@@ -23,23 +25,35 @@ def upscale_logits(logit_outputs, size):
23
  align_corners=False
24
  )
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  def query_image(img):
27
  """Función para generar predicciones a la escala origina"""
28
  inputs = preprocessor(images=img, return_tensors="pt")
29
  with torch.no_grad():
30
- #preds = model(inputs.unsqueeze(0).to(device))["logits"]
31
  preds = model(**inputs)["logits"]
32
- preds_upscale = upscale_logits(preds, img.shape[2])
33
  predict_label = torch.argmax(preds_upscale, dim=1).to(device)
34
- return predict_label[0,:,:].detach().cpu().numpy()
35
-
36
 
37
- def visualize_instance_seg_mask(mask):
38
- return mask
39
 
40
  demo = gr.Interface(
41
  query_image,
42
- inputs=[gr.Image()],
43
  outputs="image",
44
  title="SegFormer Model for rock glacier image segmentation"
45
  )
 
1
  import gradio as gr
2
+ import random
3
+ import numpy as np
4
  import torch
5
  from torch import nn
6
  from transformers import (SegformerFeatureExtractor,
 
25
  align_corners=False
26
  )
27
 
28
+
29
+ def visualize_instance_seg_mask(mask):
30
+ """Agrega colores RGB a cada una de las clases en la mask"""
31
+ image = np.zeros((mask.shape[0], mask.shape[1], 3))
32
+ labels = np.unique(mask)
33
+ label2color = {label: (random.randint(0, 1),
34
+ random.randint(0, 255),
35
+ random.randint(0, 255)) for label in labels}
36
+ for i in range(image.shape[0]):
37
+ for j in range(image.shape[1]):
38
+ image[i, j, :] = label2color[mask[i, j]]
39
+ image = image / 255
40
+ return image
41
+
42
+
43
  def query_image(img):
44
  """Función para generar predicciones a la escala origina"""
45
  inputs = preprocessor(images=img, return_tensors="pt")
46
  with torch.no_grad():
 
47
  preds = model(**inputs)["logits"]
48
+ preds_upscale = upscale_logits(preds, preds.shape[2])
49
  predict_label = torch.argmax(preds_upscale, dim=1).to(device)
50
+ result = predict_label[0,:,:].detach().cpu().numpy()
51
+ return visualize_instance_seg_mask(result)
52
 
 
 
53
 
54
  demo = gr.Interface(
55
  query_image,
56
+ inputs=[gr.Image(type="pil")],
57
  outputs="image",
58
  title="SegFormer Model for rock glacier image segmentation"
59
  )
requirements.txt CHANGED
@@ -1,2 +1,3 @@
1
  torch
2
  transformers
 
 
1
  torch
2
  transformers
3
+ numpy