ariG23498 (HF staff) committed
Commit 03067d6 · 1 Parent(s): f00aad7

chore: format and resize

Files changed (1): app.py (+33, -18)
app.py CHANGED
@@ -1,13 +1,15 @@
 import os
 import warnings
-import torch
+
+import gradio as gr
 import numpy as np
-from PIL import Image
-from torchvision import transforms as T
+import torch
 import torch.nn.functional as F
-import gradio as gr
-from hydra import compose, initialize
 from huggingface_hub import Repository
+from hydra import compose, initialize
+from PIL import Image
+from torchvision import transforms as T
+
 from models.builder import build_model
 from segmentation.datasets import PascalVOCDataset
 from visualization import mask2rgb
@@ -41,12 +43,13 @@ COLORS = [
 # Initialize Hydra
 initialize(config_path=CONFIG_PATH, version_base=None)
 
+
 # Configuration and Model Initialization
 def load_model():
     Repository(
         local_dir="clip-dinoiser",
         clone_from="ariG23498/clip-dinoiser",
-        use_auth_token=os.environ.get("token")
+        use_auth_token=os.environ.get("token"),
     )
 
     device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -55,27 +58,35 @@ def load_model():
 
     model = build_model(cfg.model, class_names=PascalVOCDataset.CLASSES).to(device)
     model.clip_backbone.decode_head.use_templates = False
-    model.load_state_dict(checkpoint['model_state_dict'], strict=False)
+    model.load_state_dict(checkpoint["model_state_dict"], strict=False)
     return model.eval()
 
+
 def run_clip_dinoiser(input_image, text_prompts, model, device, colors):
-    image = input_image.convert("RGB")
+    # Resize the input image
+    image = input_image.resize((400, 700))
+
+    image = image.convert("RGB")
     text_prompts = text_prompts.split(",")
-    palette = colors[:len(text_prompts)]
+    palette = colors[: len(text_prompts)]
 
     model.clip_backbone.decode_head.update_vocab(text_prompts)
     model.to(device)
 
-    img_tens = T.PILToTensor()(image).unsqueeze(0).to(device) / 255.
+    img_tens = T.PILToTensor()(image).unsqueeze(0).to(device) / 255.0
     h, w = img_tens.shape[-2:]
     output = model(img_tens).cpu()
-    output = F.interpolate(output, scale_factor=model.clip_backbone.backbone.patch_size, mode="bilinear", align_corners=False)[..., :h, :w]
+    output = F.interpolate(
+        output,
+        scale_factor=model.clip_backbone.backbone.patch_size,
+        mode="bilinear",
+        align_corners=False,
+    )[..., :h, :w]
     output = output[0].argmax(dim=0)
 
     mask = mask2rgb(output, palette)
-    classes = np.unique(output).tolist()
     alpha = 0.5
-    blend = (alpha * np.array(image) / 255.) + ((1 - alpha) * mask / 255.)
+    blend = (alpha * np.array(image) / 255.0) + ((1 - alpha) * mask / 255.0)
 
     h_text = [(text, f"{idx}") for idx, text in enumerate(text_prompts)]
     return blend, mask, h_text
@@ -108,22 +119,26 @@ def setup_gradio_interface(model, device, colors, color_map):
             label="Labels",
             combine_adjacent=False,
             show_legend=False,
-            color_map=color_map
+            color_map=color_map,
         )
 
         run_button.click(
-            fn=lambda img, prompts: run_clip_dinoiser(img, prompts, model, device, colors),
+            fn=lambda img, prompts: run_clip_dinoiser(
+                img, prompts, model, device, colors
+            ),
            inputs=[input_image, text_prompts],
-            outputs=[overlay_mask, only_mask, h_text]
+            outputs=[overlay_mask, only_mask, h_text],
        )
 
        gr.Examples(
            examples=[["vintage_bike.jpeg", "background, vintage bike, leather bag"]],
            inputs=[input_image, text_prompts],
            outputs=[overlay_mask, only_mask, h_text],
-            fn=lambda img, prompts: run_clip_dinoiser(img, prompts, model, device, colors),
+            fn=lambda img, prompts: run_clip_dinoiser(
+                img, prompts, model, device, colors
+            ),
            cache_examples=True,
-            label='Try this example input!'
+            label="Try this example input!",
        )
 
    return block
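
For reference, here is a minimal, self-contained sketch of the resize / upsample-and-crop / alpha-blend path that this commit reformats and extends. The random logits, the two-color palette, and the patch_size=16 value are assumed stand-ins for the real CLIP-DINOiser output, its backbone patch size, and visualization.mask2rgb; the real app runs the model between the two stages.

# Standalone sketch (not the app itself): dummy data stands in for the model.
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms as T

# Stand-in input image; the app receives this from the Gradio image widget.
image = Image.new("RGB", (1024, 768), color=(120, 180, 90))

# New in this commit: resize the input before inference, then normalize.
image = image.resize((400, 700)).convert("RGB")
img_tens = T.PILToTensor()(image).unsqueeze(0) / 255.0
h, w = img_tens.shape[-2:]

# Dummy patch-level logits for 2 classes; patch_size=16 is an assumed
# ViT-style value, the app reads it from model.clip_backbone.backbone.
patch_size = 16
logits = torch.randn(
    1, 2, (h + patch_size - 1) // patch_size, (w + patch_size - 1) // patch_size
)

# Upsample by the patch size and crop back to the input resolution,
# mirroring the reformatted F.interpolate call in run_clip_dinoiser.
output = F.interpolate(
    logits,
    scale_factor=patch_size,
    mode="bilinear",
    align_corners=False,
)[..., :h, :w]
labels = output[0].argmax(dim=0)

# Alpha-blend a per-class color mask over the resized image (alpha = 0.5),
# as in the blend returned to the Gradio overlay output.
palette = np.array([[0, 0, 0], [255, 0, 0]], dtype=np.uint8)
mask = palette[labels.numpy()]  # simple stand-in for visualization.mask2rgb
alpha = 0.5
blend = (alpha * np.array(image) / 255.0) + ((1 - alpha) * mask / 255.0)
print(blend.shape)  # (700, 400, 3)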