alessandro trinca tornidor committed
Commit · 8036f02
1 Parent(s): 97b0909
refactor: remove unused modules
- lisa_on_cuda/utils/ade20k_classes.json +0 -30
- lisa_on_cuda/utils/conversation.py +0 -308
- lisa_on_cuda/utils/create_folders_and_variables_if_not_exists.py +0 -56
- lisa_on_cuda/utils/data_processing.py +0 -90
- lisa_on_cuda/utils/dataset.py +0 -466
- lisa_on_cuda/utils/frontend_builder.py +0 -89
- lisa_on_cuda/utils/grefcoco.py +0 -198
- lisa_on_cuda/utils/grefer.py +0 -352
- lisa_on_cuda/utils/reason_seg_dataset.py +0 -218
- lisa_on_cuda/utils/refer.py +0 -391
- lisa_on_cuda/utils/refer_seg_dataset.py +0 -277
- lisa_on_cuda/utils/sem_seg_dataset.py +0 -335
- lisa_on_cuda/utils/session_logger.py +0 -36
- lisa_on_cuda/utils/vqa_dataset.py +0 -135
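
As a quick sanity check for a removal like this, the remaining package can be scanned for imports that still reference the deleted modules. The snippet below is a minimal sketch and not part of this commit: it assumes the package sources live under a local lisa_on_cuda/ directory and matches on bare module names, so it may report false positives.

import ast
import pathlib

# Module names taken from the file list above (all removed from lisa_on_cuda/utils).
REMOVED = {
    "ade20k_classes", "conversation", "create_folders_and_variables_if_not_exists",
    "data_processing", "dataset", "frontend_builder", "grefcoco", "grefer",
    "reason_seg_dataset", "refer", "refer_seg_dataset", "sem_seg_dataset",
    "session_logger", "vqa_dataset",
}

for path in pathlib.Path("lisa_on_cuda").rglob("*.py"):
    tree = ast.parse(path.read_text(encoding="utf-8"))
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            names = [alias.name for alias in node.names]
        elif isinstance(node, ast.ImportFrom):
            names = ([node.module] if node.module else []) + [a.name for a in node.names]
        else:
            continue
        hits = {part for name in names for part in name.split(".") if part in REMOVED}
        if hits:
            print(f"{path}: still references {sorted(hits)}")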
lisa_on_cuda/utils/ade20k_classes.json
DELETED
@@ -1,30 +0,0 @@
-[
-    "wall", "building", "sky", "floor", "tree", "ceiling", "road",
-    "bed", "windowpane", "grass", "cabinet", "sidewalk",
-    "person", "earth", "door", "table", "mountain", "plant",
-    "curtain", "chair", "car", "water", "painting", "sofa",
-    "shelf", "house", "sea", "mirror", "rug", "field", "armchair",
-    "seat", "fence", "desk", "rock", "wardrobe", "lamp",
-    "bathtub", "railing", "cushion", "base", "box", "column",
-    "signboard", "chest of drawers", "counter", "sand", "sink",
-    "skyscraper", "fireplace", "refrigerator", "grandstand",
-    "path", "stairs", "runway", "case", "pool table", "pillow",
-    "screen door", "stairway", "river", "bridge", "bookcase",
-    "blind", "coffee table", "toilet", "flower", "book", "hill",
-    "bench", "countertop", "stove", "palm", "kitchen island",
-    "computer", "swivel chair", "boat", "bar", "arcade machine",
-    "hovel", "bus", "towel", "light", "truck", "tower",
-    "chandelier", "awning", "streetlight", "booth",
-    "television receiver", "airplane", "dirt track", "apparel",
-    "pole", "land", "bannister", "escalator", "ottoman", "bottle",
-    "buffet", "poster", "stage", "van", "ship", "fountain",
-    "conveyer belt", "canopy", "washer", "plaything",
-    "swimming pool", "stool", "barrel", "basket", "waterfall",
-    "tent", "bag", "minibike", "cradle", "oven", "ball", "food",
-    "step", "tank", "trade name", "microwave", "pot", "animal",
-    "bicycle", "lake", "dishwasher", "screen", "blanket",
-    "sculpture", "hood", "sconce", "vase", "traffic light",
-    "tray", "ashcan", "fan", "pier", "crt screen", "plate",
-    "monitor", "bulletin board", "shower", "radiator", "glass",
-    "clock", "flag"
-]
lisa_on_cuda/utils/conversation.py
DELETED
@@ -1,308 +0,0 @@
-"""
-Conversation prompt templates.
-"""
-
-import dataclasses
-from enum import Enum, auto
-from typing import Any, List
-
-
-class SeparatorStyle(Enum):
-    """Different separator style."""
-
-    ADD_COLON_SINGLE = auto()
-    ADD_COLON_TWO = auto()
-    NO_COLON_SINGLE = auto()
-    BAIZE = auto()
-    DOLLY = auto()
-    RWKV = auto()
-
-
-@dataclasses.dataclass
-class Conversation:
-    """A class that keeps all conversation history."""
-
-    # System prompts
-    system: str
-    # Two roles
-    roles: List[str]
-    # All messages
-    messages: List[List[str]]
-    # Offset of few shot examples
-    offset: int
-    # Separator
-    sep_style: SeparatorStyle
-    sep: str
-    sep2: str = None
-    # Stop criteria (the default one is EOS token)
-    stop_str: str = None
-    # Stops generation if meeting any token in this list
-    stop_token_ids: List[int] = None
-
-    # Used for the state in the gradio servers.
-    # TODO(lmzheng): refactor this
-    conv_id: Any = None
-    skip_next: bool = False
-    model_name: str = None
-
-    def get_prompt(self):
-        if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
-            ret = self.system + self.sep
-            for role, message in self.messages:
-                if message:
-                    ret += role + ": " + message + self.sep
-                else:
-                    ret += role + ":"
-            return ret
-        elif self.sep_style == SeparatorStyle.ADD_COLON_TWO:
-            seps = [self.sep, self.sep2]
-            ret = self.system + seps[0]
-            for i, (role, message) in enumerate(self.messages):
-                if message:
-                    ret += role + ": " + message + seps[i % 2]
-                else:
-                    ret += role + ":"
-            return ret
-        elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
-            ret = self.system
-            for role, message in self.messages:
-                if message:
-                    ret += role + message + self.sep
-                else:
-                    ret += role
-            return ret
-        elif self.sep_style == SeparatorStyle.BAIZE:
-            ret = self.system + "\n"
-            for role, message in self.messages:
-                if message:
-                    ret += role + message + "\n"
-                else:
-                    ret += role
-            return ret
-        elif self.sep_style == SeparatorStyle.DOLLY:
-            seps = [self.sep, self.sep2]
-            ret = self.system
-            for i, (role, message) in enumerate(self.messages):
-                if message:
-                    ret += role + ":\n" + message + seps[i % 2]
-                    if i % 2 == 1:
-                        ret += "\n\n"
-                else:
-                    ret += role + ":\n"
-            return ret
-        elif self.sep_style == SeparatorStyle.RWKV:
-            ret = self.system
-            for i, (role, message) in enumerate(self.messages):
-                if message:
-                    ret += (
-                        role
-                        + ": "
-                        + message.replace("\r\n", "\n").replace("\n\n", "\n")
-                    )
-                    ret += "\n\n"
-                else:
-                    ret += role + ":"
-            return ret
-        else:
-            raise ValueError(f"Invalid style: {self.sep_style}")
-
-    def append_message(self, role, message):
-        self.messages.append([role, message])
-
-    def to_gradio_chatbot(self):
-        ret = []
-        for i, (role, msg) in enumerate(self.messages[self.offset :]):
-            if i % 2 == 0:
-                ret.append([msg, None])
-            else:
-                ret[-1][-1] = msg
-        return ret
-
-    def copy(self):
-        return Conversation(
-            system=self.system,
-            roles=self.roles,
-            messages=[[x, y] for x, y in self.messages],
-            offset=self.offset,
-            sep_style=self.sep_style,
-            sep=self.sep,
-            sep2=self.sep2,
-            stop_str=self.stop_str,
-            stop_token_ids=self.stop_token_ids,
-            conv_id=self.conv_id,
-            model_name=self.model_name,
-        )
-
-    def dict(self):
-        return {
-            "system": self.system,
-            "roles": self.roles,
-            "messages": self.messages,
-            "offset": self.offset,
-            "conv_id": self.conv_id,
-            "model_name": self.model_name,
-        }
-
-
-# A template with one conversation example
-conv_one_shot = Conversation(
-    system="A chat between a curious human and an artificial intelligence assistant. "
-    "The assistant gives helpful, detailed, and polite answers to the human's questions.",
-    roles=("Human", "Assistant"),
-    messages=(
-        (
-            "Human",
-            "What are the key differences between renewable and non-renewable energy sources?",
-        ),
-        (
-            "Assistant",
-            "Renewable energy sources are those that can be replenished naturally in a relatively "
-            "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
-            "Non-renewable energy sources, on the other hand, are finite and will eventually be "
-            "depleted, such as coal, oil, and natural gas. Here are some key differences between "
-            "renewable and non-renewable energy sources:\n"
-            "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
-            "energy sources are finite and will eventually run out.\n"
-            "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
-            "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
-            "and other negative effects.\n"
-            "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
-            "have lower operational costs than non-renewable sources.\n"
-            "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
-            "locations than non-renewable sources.\n"
-            "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
-            "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
-            "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
-            "non-renewable sources are not, and their depletion can lead to economic and social instability.",
-        ),
-    ),
-    offset=2,
-    sep_style=SeparatorStyle.ADD_COLON_SINGLE,
-    sep="\n### ",
-    stop_str="###",
-)
-
-
-# Vicuna v1.1 template
-conv_vicuna_v1_1 = Conversation(
-    system="A chat between a curious user and an artificial intelligence assistant. "
-    "The assistant gives helpful, detailed, and polite answers to the user's questions.",
-    roles=("USER", "ASSISTANT"),
-    messages=(),
-    offset=0,
-    sep_style=SeparatorStyle.ADD_COLON_TWO,
-    sep=" ",
-    sep2="</s>",
-)
-
-# Koala default template
-conv_koala_v1 = Conversation(
-    system="BEGINNING OF CONVERSATION:",
-    roles=("USER", "GPT"),
-    messages=(),
-    offset=0,
-    sep_style=SeparatorStyle.ADD_COLON_TWO,
-    sep=" ",
-    sep2="</s>",
-)
-
-# Dolly V2 default template
-conv_dolly = Conversation(
-    system="Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n",
-    roles=("### Instruction", "### Response"),
-    messages=(),
-    offset=0,
-    sep_style=SeparatorStyle.DOLLY,
-    sep="\n\n",
-    sep2="### End",
-)
-
-# OpenAssistant Pythia default template
-conv_oasst = Conversation(
-    system="",
-    roles=("<|prompter|>", "<|assistant|>"),
-    messages=(),
-    offset=0,
-    sep_style=SeparatorStyle.NO_COLON_SINGLE,
-    sep="<|endoftext|>",
-)
-
-# StableLM Alpha default template
-conv_stablelm = Conversation(
-    system="""<|SYSTEM|># StableLM Tuned (Alpha version)
-- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
-- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
-- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
-- StableLM will refuse to participate in anything that could harm a human.
-""",
-    roles=("<|USER|>", "<|ASSISTANT|>"),
-    messages=(),
-    offset=0,
-    sep_style=SeparatorStyle.NO_COLON_SINGLE,
-    sep="",
-    stop_token_ids=[50278, 50279, 50277, 1, 0],
-)
-
-# Baize default template
-conv_baize = Conversation(
-    system="The following is a conversation between a human and an AI assistant named Baize (named after a mythical creature in Chinese folklore). Baize is an open-source AI assistant developed by UCSD and Sun Yat-Sen University. The human and the AI assistant take turns chatting. Human statements start with [|Human|] and AI assistant statements start with [|AI|]. The AI assistant always provides responses in as much detail as possible, and in Markdown format. The AI assistant always declines to engage with topics, questions and instructions related to unethical, controversial, or sensitive issues. Complete the transcript in exactly that format.",
-    roles=("[|Human|]", "[|AI|]"),
-    messages=(
-        ("[|Human|]", "Hello!"),
-        ("[|AI|]", "Hi!"),
-    ),
-    offset=2,
-    sep_style=SeparatorStyle.BAIZE,
-    sep="[|Human|]",
-    stop_str="[|Human|]",
-)
-
-# RWKV-4-Raven default template
-conv_rwkv = Conversation(
-    system="",
-    roles=("Bob", "Alice"),
-    messages=(),
-    offset=0,
-    sep_style=SeparatorStyle.RWKV,
-    sep="",
-    stop_str="\n\n",
-)
-
-conv_templates = {
-    "baize": conv_baize,
-    "conv_one_shot": conv_one_shot,
-    "dolly": conv_dolly,
-    "koala_v1": conv_koala_v1,
-    "oasst": conv_oasst,
-    "stablelm": conv_stablelm,
-    "vicuna_v1.1": conv_vicuna_v1_1,
-    "rwkv": conv_rwkv,
-}
-
-
-def get_default_conv_template(model_name):
-    model_name = model_name.lower()
-    if "vicuna" in model_name or "output" in model_name:
-        return conv_vicuna_v1_1
-    elif "koala" in model_name:
-        return conv_koala_v1
-    elif "dolly-v2" in model_name:
-        return conv_dolly
-    elif "oasst" in model_name and "pythia" in model_name:
-        return conv_oasst
-    elif "baize" in model_name:
-        return conv_baize
-    elif "stablelm" in model_name:
-        return conv_stablelm
-    elif "rwkv-4" in model_name:
-        return conv_rwkv
-    return conv_one_shot
-
-
-if __name__ == "__main__":
-    conv = conv_templates["vicuna_v1.1"].copy()
-    conv.append_message(conv.roles[0], "Hello!")
-    conv.append_message(conv.roles[1], "Hi!")
-    conv.append_message(conv.roles[0], "How are you?")
-    conv.append_message(conv.roles[1], None)
-    print(conv.get_prompt())
lisa_on_cuda/utils/create_folders_and_variables_if_not_exists.py
DELETED
@@ -1,56 +0,0 @@
-import json
-import logging
-import os
-from pathlib import Path
-
-
-def stats_pathname(pathname: Path | str):
-    current_pathname = Path(pathname)
-    return current_pathname.is_dir()
-
-
-def create_folder_if_not_exists(pathname: Path | str):
-    current_pathname = Path(pathname)
-    try:
-        print(f"Pathname exists? {current_pathname.exists()}, That's a folder? {current_pathname.is_dir()}...")
-        logging.info(f"Pathname exists? {current_pathname.exists()}, That's a folder? {current_pathname.is_dir()}...")
-        current_pathname.unlink(missing_ok=True)
-    except PermissionError as pe:
-        print(f"permission denied on removing pathname before folder creation:{pe}.")
-        logging.error(f"permission denied on removing pathname before folder creation:{pe}.")
-    except IsADirectoryError as errdir:
-        print(f"that's a directory:{errdir}.")
-        logging.error(f"that's a directory:{errdir}.")
-
-    print(f"Creating pathname: {current_pathname} ...")
-    logging.info(f"Creating pathname: {current_pathname} ...")
-    current_pathname.mkdir(mode=0o770, parents=True, exist_ok=True)
-
-    print(f"assertion: pathname exists and is a folder: {current_pathname} ...")
-    logging.info(f"assertion: pathname exists and is a folder: {current_pathname} ...")
-    assert current_pathname.is_dir()
-
-
-def folders_creation():
-    folders_string = os.getenv("FOLDERS_MAP")
-    try:
-        folders_dict = json.loads(folders_string)
-        for folder_env_ref, folder_env_path in folders_dict.items():
-            print(f"folder_env_ref:{folder_env_ref}, folder_env_path:{folder_env_path}.")
-            logging.info(f"folder_env_ref:{folder_env_ref}, folder_env_path:{folder_env_path}.")
-            create_folder_if_not_exists(folder_env_path)
-            print("========")
-            assert os.getenv(folder_env_ref) == folder_env_path
-    except (json.JSONDecodeError, TypeError) as jde:
-        print(f"jde:{jde}.")
-        logging.error(f"jde:{jde}.")
-        print("double check your variables, e.g. for misspelling like 'FOLDER_MAP'...")
-        logging.info("double check your variables, e.g. for misspelling like 'FOLDER_MAP' instead than 'FOLDERS_MAP'...")
-        for k_env, v_env in dict(os.environ).items():
-            print(f"{k_env}, v_env:{v_env}.")
-            logging.info(f"{k_env}, v_env:{v_env}.")
-
-
-if __name__ == '__main__':
-    folders_creation()
-
lisa_on_cuda/utils/data_processing.py
DELETED
@@ -1,90 +0,0 @@
-import glob
-import json
-import os
-
-import cv2
-import numpy as np
-
-
-def get_mask_from_json(json_path, img):
-    try:
-        with open(json_path, "r") as r:
-            anno = json.loads(r.read())
-    except:
-        with open(json_path, "r", encoding="cp1252") as r:
-            anno = json.loads(r.read())
-
-    inform = anno["shapes"]
-    comments = anno["text"]
-    is_sentence = anno["is_sentence"]
-
-    height, width = img.shape[:2]
-
-    ### sort polies by area
-    area_list = []
-    valid_poly_list = []
-    for i in inform:
-        label_id = i["label"]
-        points = i["points"]
-        if "flag" == label_id.lower():  ## meaningless deprecated annotations
-            continue
-
-        tmp_mask = np.zeros((height, width), dtype=np.uint8)
-        cv2.polylines(tmp_mask, np.array([points], dtype=np.int32), True, 1, 1)
-        cv2.fillPoly(tmp_mask, np.array([points], dtype=np.int32), 1)
-        tmp_area = tmp_mask.sum()
-
-        area_list.append(tmp_area)
-        valid_poly_list.append(i)
-
-    ### ground-truth mask
-    sort_index = np.argsort(area_list)[::-1].astype(np.int32)
-    sort_index = list(sort_index)
-    sort_inform = []
-    for s_idx in sort_index:
-        sort_inform.append(valid_poly_list[s_idx])
-
-    mask = np.zeros((height, width), dtype=np.uint8)
-    for i in sort_inform:
-        label_id = i["label"]
-        points = i["points"]
-
-        if "ignore" in label_id.lower():
-            label_value = 255  # ignored during evaluation
-        else:
-            label_value = 1  # target
-
-        cv2.polylines(mask, np.array([points], dtype=np.int32), True, label_value, 1)
-        cv2.fillPoly(mask, np.array([points], dtype=np.int32), label_value)
-
-    return mask, comments, is_sentence
-
-
-if __name__ == "__main__":
-    data_dir = "./train"
-    vis_dir = "./vis"
-
-    if not os.path.exists(vis_dir):
-        os.makedirs(vis_dir)
-
-    json_path_list = sorted(glob.glob(data_dir + "/*.json"))
-    for json_path in json_path_list:
-        img_path = json_path.replace(".json", ".jpg")
-        img = cv2.imread(img_path)[:, :, ::-1]
-
-        # In generated mask, value 1 denotes valid target region, and value 255 stands for region ignored during evaluaiton.
-        mask, comments, is_sentence = get_mask_from_json(json_path, img)
-
-        ## visualization. Green for target, and red for ignore.
-        valid_mask = (mask == 1).astype(np.float32)[:, :, None]
-        ignore_mask = (mask == 255).astype(np.float32)[:, :, None]
-        vis_img = img * (1 - valid_mask) * (1 - ignore_mask) + (
-            (np.array([0, 255, 0]) * 0.6 + img * 0.4) * valid_mask
-            + (np.array([255, 0, 0]) * 0.6 + img * 0.4) * ignore_mask
-        )
-        vis_img = np.concatenate([img, vis_img], 1)
-        vis_path = os.path.join(
-            vis_dir, json_path.split("/")[-1].replace(".json", ".jpg")
-        )
-        cv2.imwrite(vis_path, vis_img[:, :, ::-1])
-        print("Visualization has been saved to: ", vis_path)
lisa_on_cuda/utils/dataset.py
DELETED
@@ -1,466 +0,0 @@
-import glob
-import os
-import random
-
-import cv2
-import numpy as np
-import torch
-import torch.nn.functional as F
-from pycocotools import mask
-from transformers import CLIPImageProcessor
-
-from lisa_on_cuda.llava import conversation as conversation_lib
-from lisa_on_cuda.llava.constants import (DEFAULT_IMAGE_TOKEN, IGNORE_INDEX,
-                                          IMAGE_TOKEN_INDEX)
-from lisa_on_cuda.llava.mm_utils import tokenizer_image_token
-from lisa_on_cuda.segment_anything.utils.transforms import ResizeLongestSide
-
-from .conversation import get_default_conv_template
-from .data_processing import get_mask_from_json
-from .reason_seg_dataset import ReasonSegDataset
-from .refer import REFER
-from .refer_seg_dataset import ReferSegDataset
-from .sem_seg_dataset import SemSegDataset
-from .utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
-                    DEFAULT_IMAGE_TOKEN)
-from .vqa_dataset import VQADataset
-
-
-def collate_fn(
-    batch, tokenizer=None, conv_type="llava_v1", use_mm_start_end=True, local_rank=-1
-):
-    image_path_list = []
-    images_list = []
-    images_clip_list = []
-    conversation_list = []
-    masks_list = []
-    label_list = []
-    resize_list = []
-    questions_list = []
-    sampled_classes_list = []
-    offset_list = [0]
-    cnt = 0
-    inferences = []
-    for (
-        image_path,
-        images,
-        images_clip,
-        conversations,
-        masks,
-        label,
-        resize,
-        questions,
-        sampled_classes,
-        inference,
-    ) in batch:
-        image_path_list.append(image_path)
-        images_list.append(images)
-        images_clip_list.append(images_clip)
-        conversation_list.extend(conversations)
-        label_list.append(label)
-        masks_list.append(masks.float())
-        resize_list.append(resize)
-        questions_list.append(questions)
-        sampled_classes_list.append(sampled_classes)
-        cnt += len(conversations)
-        offset_list.append(cnt)
-        inferences.append(inference)
-
-    if use_mm_start_end:
-        # replace <image> token
-        for i in range(len(conversation_list)):
-            replace_token = DEFAULT_IMAGE_TOKEN
-            replace_token = (
-                DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
-            )
-            conversation_list[i] = conversation_list[i].replace(
-                DEFAULT_IMAGE_TOKEN, replace_token
-            )
-    input_ids = [
-        tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
-        for prompt in conversation_list
-    ]
-    input_ids = torch.nn.utils.rnn.pad_sequence(
-        input_ids, batch_first=True, padding_value=tokenizer.pad_token_id
-    )
-    attention_masks = input_ids.ne(tokenizer.pad_token_id)
-
-    conv = conversation_lib.default_conversation.copy()
-    targets = input_ids.clone()
-
-    if conv_type == "llava_v1":
-        sep = conv.sep + conv.roles[1] + ": "
-    else:
-        sep = "[/INST] "
-    for conversation, target in zip(conversation_list, targets):
-        total_len = int(target.ne(tokenizer.pad_token_id).sum())
-
-        rounds = conversation.split(conv.sep2)
-        cur_len = 1
-        target[:cur_len] = IGNORE_INDEX
-        for i, rou in enumerate(rounds):
-            if rou == "":
-                break
-
-            parts = rou.split(sep)
-            # if len(parts) != 2:
-            #     break
-            assert len(parts) == 2, (len(parts), rou)
-            parts[0] += sep
-
-            if DEFAULT_IMAGE_TOKEN in conversation:
-                round_len = len(tokenizer_image_token(rou, tokenizer))
-                instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
-            else:
-                round_len = len(tokenizer(rou).input_ids)
-                instruction_len = len(tokenizer(parts[0]).input_ids) - 2
-
-            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
-
-            cur_len += round_len
-        target[cur_len:] = IGNORE_INDEX
-
-        if False:
-            z = target.clone()
-            z = torch.where(z == IGNORE_INDEX, tokenizer.unk_token_id, z)
-            if local_rank == 0:
-                print(
-                    "conversation: ",
-                    conversation,
-                    "tokenizer.decode(z): ",
-                    tokenizer.decode(z),
-                )
-
-        if cur_len < tokenizer.model_max_length:
-            assert cur_len == total_len
-
-    if inferences[0] == False:
-        truncate_len = tokenizer.model_max_length - 255
-
-        if input_ids.shape[1] > truncate_len:
-            input_ids = input_ids[:, :truncate_len]
-            targets = targets[:, :truncate_len]
-            attention_masks = attention_masks[:, :truncate_len]
-
-    return {
-        "image_paths": image_path_list,
-        "images": torch.stack(images_list, dim=0),
-        "images_clip": torch.stack(images_clip_list, dim=0),
-        "input_ids": input_ids,
-        "labels": targets,
-        "attention_masks": attention_masks,
-        "masks_list": masks_list,
-        "label_list": label_list,
-        "resize_list": resize_list,
-        "offset": torch.LongTensor(offset_list),
-        "questions_list": questions_list,
-        "sampled_classes_list": sampled_classes_list,
-        "inference": inferences[0],
-        "conversation_list": conversation_list,
-    }
-
-
-class HybridDataset(torch.utils.data.Dataset):
-    pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
-    pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
-    img_size = 1024
-    ignore_label = 255
-
-    def __init__(
-        self,
-        base_image_dir,
-        tokenizer,
-        vision_tower,
-        samples_per_epoch=500 * 8 * 2 * 10,
-        precision: str = "fp32",
-        image_size: int = 224,
-        num_classes_per_sample: int = 3,
-        exclude_val=False,
-        dataset="sem_seg||refer_seg||vqa||reason_seg",
-        sample_rate=[9, 3, 3, 1],
-        sem_seg_data="ade20k||cocostuff||partimagenet||pascal_part||paco_lvis||mapillary",
-        refer_seg_data="refclef||refcoco||refcoco+||refcocog",
-        vqa_data="llava_instruct_150k",
-        reason_seg_data="ReasonSeg|train",
-        explanatory=0.1,
-    ):
-        self.exclude_val = exclude_val
-        self.dataset = dataset
-        self.samples_per_epoch = samples_per_epoch
-        self.explanatory = explanatory
-        self.num_classes_per_sample = num_classes_per_sample
-        sample_rate = np.array(sample_rate)
-        self.sample_rate = sample_rate / sample_rate.sum()
-
-        self.base_image_dir = base_image_dir
-        self.image_size = image_size
-        self.tokenizer = tokenizer
-        self.precision = precision
-
-        self.datasets = dataset.split("||")
-
-        self.all_datasets = []
-        for dataset in self.datasets:
-            if dataset == "sem_seg":
-                self.all_datasets.append(
-                    SemSegDataset(
-                        base_image_dir,
-                        tokenizer,
-                        vision_tower,
-                        samples_per_epoch,
-                        precision,
-                        image_size,
-                        num_classes_per_sample,
-                        exclude_val,
-                        sem_seg_data,
-                    )
-                )
-            elif dataset == "refer_seg":
-                self.all_datasets.append(
-                    ReferSegDataset(
-                        base_image_dir,
-                        tokenizer,
-                        vision_tower,
-                        samples_per_epoch,
-                        precision,
-                        image_size,
-                        num_classes_per_sample,
-                        exclude_val,
-                        refer_seg_data,
-                    )
-                )
-            elif dataset == "vqa":
-                self.all_datasets.append(
-                    VQADataset(
-                        base_image_dir,
-                        tokenizer,
-                        vision_tower,
-                        samples_per_epoch,
-                        precision,
-                        image_size,
-                        num_classes_per_sample,
-                        exclude_val,
-                        vqa_data,
-                    )
-                )
-            elif dataset == "reason_seg":
-                self.all_datasets.append(
-                    ReasonSegDataset(
-                        base_image_dir,
-                        tokenizer,
-                        vision_tower,
-                        samples_per_epoch,
-                        precision,
-                        image_size,
-                        num_classes_per_sample,
-                        exclude_val,
-                        reason_seg_data,
-                        explanatory,
-                    )
-                )
-
-    def __len__(self):
-        return self.samples_per_epoch
-
-    def __getitem__(self, idx):
-        ind = np.random.choice(list(range(len(self.datasets))), p=self.sample_rate)
-        data = self.all_datasets[ind]
-        inference = False
-        return *data[0], inference
-
-
-class ValDataset(torch.utils.data.Dataset):
-    pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
-    pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
-    img_size = 1024
-    ignore_label = 255
-
-    def __init__(
-        self,
-        base_image_dir,
-        tokenizer,
-        vision_tower,
-        val_dataset,
-        image_size=1024,
-    ):
-        self.base_image_dir = base_image_dir
-        splits = val_dataset.split("|")
-        if len(splits) == 2:
-            ds, split = splits
-            images = glob.glob(
-                os.path.join(self.base_image_dir, "reason_seg", ds, split, "*.jpg")
-            )
-            self.images = images
-            self.data_type = "reason_seg"
-        elif len(splits) == 3:
-            ds, splitBy, split = splits
-            refer_api = REFER(self.base_image_dir, ds, splitBy)
-            ref_ids_val = refer_api.getRefIds(split=split)
-            images_ids_val = refer_api.getImgIds(ref_ids=ref_ids_val)
-            refs_val = refer_api.loadRefs(ref_ids=ref_ids_val)
-            refer_seg_ds = {}
-            refer_seg_ds["images"] = []
-            loaded_images = refer_api.loadImgs(image_ids=images_ids_val)
-            for item in loaded_images:
-                item = item.copy()
-                if ds == "refclef":
-                    item["file_name"] = os.path.join(
-                        base_image_dir, "images/saiapr_tc-12", item["file_name"]
-                    )
-                elif ds in ["refcoco", "refcoco+", "refcocog", "grefcoco"]:
-                    item["file_name"] = os.path.join(
-                        base_image_dir,
-                        "images/mscoco/images/train2014",
-                        item["file_name"],
-                    )
-                refer_seg_ds["images"].append(item)
-            refer_seg_ds["annotations"] = refer_api.Anns  # anns_val
-
-            img2refs = {}
-            for ref in refs_val:
-                image_id = ref["image_id"]
-                img2refs[image_id] = img2refs.get(image_id, []) + [
-                    ref,
-                ]
-            refer_seg_ds["img2refs"] = img2refs
-            self.refer_seg_ds = refer_seg_ds
-            self.data_type = "refer_seg"
-
-        self.ds = ds
-        self.image_size = image_size
-        self.tokenizer = tokenizer
-        self.transform = ResizeLongestSide(image_size)
-        self.clip_image_processor = CLIPImageProcessor.from_pretrained(vision_tower)
-
-    def __len__(self):
-        if self.data_type == "refer_seg":
-            return len(self.refer_seg_ds["images"])
-        else:
-            return len(self.images)
-
-    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
-        """Normalize pixel values and pad to a square input."""
-        # Normalize colors
-        x = (x - self.pixel_mean) / self.pixel_std
-
-        # Pad
-        h, w = x.shape[-2:]
-        padh = self.img_size - h
-        padw = self.img_size - w
-        x = F.pad(x, (0, padw, 0, padh))
-        return x
-
-    def __getitem__(self, idx):
-        if self.data_type == "refer_seg":
-            refer_seg_ds = self.refer_seg_ds
-            images = refer_seg_ds["images"]
-            annotations = refer_seg_ds["annotations"]
-            img2refs = refer_seg_ds["img2refs"]
-
-            image_info = images[idx]
-            image_path = image_info["file_name"]
-            image_id = image_info["id"]
-
-            refs = img2refs[image_id]
-            if len(refs) == 0:
-                raise ValueError("image {} has no refs".format(image_id))
-
-            sents = []
-            ann_ids = []
-            for ref in refs:
-                for sent in ref["sentences"]:
-                    sents.append(sent["sent"].strip().lower())
-                    ann_ids.append(ref["ann_id"])
-
-            sampled_sents = sents
-            sampled_ann_ids = ann_ids
-            image = cv2.imread(image_path)
-            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-            is_sentence = False
-        else:
-            image_path = self.images[idx]
-            image = cv2.imread(image_path)
-            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-            json_path = image_path.replace(".jpg", ".json")
-            mask_json, sampled_sents, is_sentence = get_mask_from_json(json_path, image)
-            sampled_sents = [sampled_sents[0]]
-
-        conversations = []
-        conv = conversation_lib.default_conversation.copy()
-        i = 0
-        while i < len(sampled_sents):
-            conv.messages = []
-            text = sampled_sents[i].strip()
-            if is_sentence:
-                conv.append_message(
-                    conv.roles[0],
-                    DEFAULT_IMAGE_TOKEN
-                    + "\n {} Please output segmentation mask.".format(text),
-                )
-                conv.append_message(conv.roles[1], "[SEG].")
-            else:
-                conv.append_message(
-                    conv.roles[0],
-                    DEFAULT_IMAGE_TOKEN
-                    + "\n What is {} in this image? Please output segmentation mask.".format(
-                        text
-                    ),
-                )
-                conv.append_message(conv.roles[1], "[SEG].")
-            conversations.append(conv.get_prompt())
-            i += 1
-
-        # preprocess image for clip
-        image_clip = self.clip_image_processor.preprocess(image, return_tensors="pt")[
-            "pixel_values"
-        ][0]
-
-        # preprocess image for sam
-        image = self.transform.apply_image(image)
-        resize = image.shape[:2]
-        image = self.preprocess(torch.from_numpy(image).permute(2, 0, 1).contiguous())
-
-        if self.data_type == "refer_seg":
-            masks = []
-            for i, ann_id in enumerate(sampled_ann_ids):
-                ann = annotations[ann_id]
-                if len(ann["segmentation"]) == 0 and sampled_sents[i] != "":
-                    m = np.zeros((image_info["height"], image_info["width"], 1))
-                else:
-                    if type(ann["segmentation"][0]) == list:  # polygon
-                        rle = mask.frPyObjects(
-                            ann["segmentation"],
-                            image_info["height"],
-                            image_info["width"],
-                        )
-                    else:
-                        rle = ann["segmentation"]
-                        for i in range(len(rle)):
-                            if not isinstance(rle[i]["counts"], bytes):
-                                rle[i]["counts"] = rle[i]["counts"].encode()
-                    m = mask.decode(rle)
-                    m = np.sum(
-                        m, axis=2
-                    )  # sometimes there are multiple binary map (corresponding to multiple segs)
-                    m = m.astype(np.uint8)  # convert to np.uint8
-                masks.append(m)
-        else:
-            masks = [mask_json]
-
-        masks = np.stack(masks, axis=0)
-        masks = torch.from_numpy(masks)
-        labels = torch.ones(masks.shape[1], masks.shape[2]) * self.ignore_label
-        inference = True
-
-        return (
-            image_path,
-            image,
-            image_clip,
-            conversations,
-            masks,
-            labels,
-            resize,
-            None,
-            None,
-            inference,
-        )
lisa_on_cuda/utils/frontend_builder.py
DELETED
@@ -1,89 +0,0 @@
-import logging
-import os
-import subprocess
-from pathlib import Path
-
-from dotenv import load_dotenv
-
-from lisa_on_cuda.utils import session_logger
-
-
-load_dotenv()
-LOGLEVEL = os.getenv('LOGLEVEL', 'INFO').upper()
-session_logger.change_logging(LOGLEVEL)
-root_folder = Path(globals().get("__file__", "./_")).absolute().parent.parent.parent
-env_project_root_folder = Path(os.getenv("PROJECT_ROOT_FOLDER", root_folder))
-env_input_css_path = Path(os.getenv("INPUT_CSS_PATH"))
-
-
-def assert_envs(envs_list):
-    for current_env in envs_list:
-        try:
-            assert current_env is not None and current_env != ""
-        except AssertionError as aex:
-            logging.error(f"error on assertion for current_env: {current_env}.")
-            raise aex
-
-
-def read_std_out_err(std_out_err, output_type: str, command: list):
-    output = std_out_err.split("\n")
-    logging.info(f"output type:{output_type} for command:{' '.join(command)}.")
-    for line in iter(output):
-        logging.info(f"output_content_home stdout:{line.strip()}.")
-    logging.info("########")
-
-
-def run_command(commands_list: list, capture_output: bool = True, text: bool = True, check: bool = True) -> None:
-    try:
-        output_content_home = subprocess.run(
-            commands_list,
-            capture_output=capture_output,
-            text=text,
-            check=check
-        )
-        read_std_out_err(output_content_home.stdout, "stdout", commands_list)
-        read_std_out_err(output_content_home.stderr, "stderr", commands_list)
-    except Exception as ex:
-        logging.error(f"ex:{ex}.")
-        raise ex
-
-
-def build_frontend(
-    project_root_folder: str | Path,
-    input_css_path: str | Path,
-    output_dist_folder: Path = root_folder / "static" / "dist",
-) -> None:
-    assert_envs([
-        str(project_root_folder),
-        str(input_css_path)
-    ])
-
-    # install deps
-    os.chdir(Path(project_root_folder) / "static")
-    current_folder = os.getcwd()
-    logging.info(f"current_folder:{current_folder}, install pnpm...")
-    run_command(["npm", "install", "-g", "npm", "pnpm"])
-    logging.info(f"install pnpm dependencies...")
-    run_command(["pnpm", "install"])
-
-    # build frontend dist and assert for its correct build
-    output_css = str(output_dist_folder / "output.css")
-    output_index_html = str(output_dist_folder / "index.html")
-    output_dist_folder = str(output_dist_folder)
-    logging.info(f"pnpm: build '{output_dist_folder}'...")
-    run_command(["pnpm", "build"])
-    logging.info(f"pnpm: ls -l {output_index_html}:")
-    run_command(["ls", "-l", output_index_html])
-    cmd = ["pnpm", "tailwindcss", "-i", str(input_css_path), "-o", output_css]
-    logging.info(f"pnpm: {' '.join(cmd)}...")
-    run_command(["pnpm", "tailwindcss", "-i", str(input_css_path), "-o", output_css])
-    logging.info(f"pnpm: ls -l {output_css}:")
-    run_command(["ls", "-l", output_css])
-    logging.info(f"end!")
-
-
-if __name__ == '__main__':
-    build_frontend(
-        project_root_folder=env_project_root_folder,
-        input_css_path=env_input_css_path
-    )
lisa_on_cuda/utils/grefcoco.py
DELETED
@@ -1,198 +0,0 @@
-import contextlib
-import copy
-import io
-import logging
-import os
-import random
-
-import numpy as np
-import pycocotools.mask as mask_util
-from detectron2.structures import Boxes, BoxMode, PolygonMasks, RotatedBoxes
-from detectron2.utils.file_io import PathManager
-from fvcore.common.timer import Timer
-from PIL import Image
-
-"""
-This file contains functions to parse RefCOCO-format annotations into dicts in "Detectron2 format".
-"""
-
-
-logger = logging.getLogger(__name__)
-
-__all__ = ["load_refcoco_json"]
-
-
-def load_grefcoco_json(
-    refer_root,
-    dataset_name,
-    splitby,
-    split,
-    image_root,
-    extra_annotation_keys=None,
-    extra_refer_keys=None,
-):
-    if dataset_name == "refcocop":
-        dataset_name = "refcoco+"
-    if dataset_name == "refcoco" or dataset_name == "refcoco+":
-        splitby == "unc"
-    if dataset_name == "refcocog":
-        assert splitby == "umd" or splitby == "google"
-
-    dataset_id = "_".join([dataset_name, splitby, split])
-
-    from .grefer import G_REFER
-
-    logger.info("Loading dataset {} ({}-{}) ...".format(dataset_name, splitby, split))
-    logger.info("Refcoco root: {}".format(refer_root))
-    timer = Timer()
-    refer_root = PathManager.get_local_path(refer_root)
-    with contextlib.redirect_stdout(io.StringIO()):
-        refer_api = G_REFER(data_root=refer_root, dataset=dataset_name, splitBy=splitby)
-    if timer.seconds() > 1:
-        logger.info(
-            "Loading {} takes {:.2f} seconds.".format(dataset_id, timer.seconds())
-        )
-
-    ref_ids = refer_api.getRefIds(split=split)
-    img_ids = refer_api.getImgIds(ref_ids)
-    refs = refer_api.loadRefs(ref_ids)
-    imgs = [refer_api.loadImgs(ref["image_id"])[0] for ref in refs]
-    anns = [refer_api.loadAnns(ref["ann_id"]) for ref in refs]
-    imgs_refs_anns = list(zip(imgs, refs, anns))
-
-    logger.info(
-        "Loaded {} images, {} referring object sets in G_RefCOCO format from {}".format(
-            len(img_ids), len(ref_ids), dataset_id
-        )
-    )
-
-    dataset_dicts = []
-
-    ann_keys = ["iscrowd", "bbox", "category_id"] + (extra_annotation_keys or [])
-    ref_keys = ["raw", "sent_id"] + (extra_refer_keys or [])
-
-    ann_lib = {}
-
-    NT_count = 0
-    MT_count = 0
-
-    for img_dict, ref_dict, anno_dicts in imgs_refs_anns:
-        record = {}
-        record["source"] = "grefcoco"
-        record["file_name"] = os.path.join(image_root, img_dict["file_name"])
-        record["height"] = img_dict["height"]
-        record["width"] = img_dict["width"]
-        image_id = record["image_id"] = img_dict["id"]
-
-        # Check that information of image, ann and ref match each other
-        # This fails only when the data parsing logic or the annotation file is buggy.
-        assert ref_dict["image_id"] == image_id
-        assert ref_dict["split"] == split
-        if not isinstance(ref_dict["ann_id"], list):
-            ref_dict["ann_id"] = [ref_dict["ann_id"]]
-
-        # No target samples
-        if None in anno_dicts:
-            assert anno_dicts == [None]
-            assert ref_dict["ann_id"] == [-1]
-            record["empty"] = True
-            obj = {key: None for key in ann_keys if key in ann_keys}
-            obj["bbox_mode"] = BoxMode.XYWH_ABS
-            obj["empty"] = True
-            obj = [obj]
-
-        # Multi target samples
-        else:
-            record["empty"] = False
-            obj = []
-            for anno_dict in anno_dicts:
-                ann_id = anno_dict["id"]
-                if anno_dict["iscrowd"]:
-                    continue
-                assert anno_dict["image_id"] == image_id
-                assert ann_id in ref_dict["ann_id"]
-
-                if ann_id in ann_lib:
-                    ann = ann_lib[ann_id]
-                else:
-                    ann = {key: anno_dict[key] for key in ann_keys if key in anno_dict}
-                    ann["bbox_mode"] = BoxMode.XYWH_ABS
-                    ann["empty"] = False
-
-                    segm = anno_dict.get("segmentation", None)
-                    assert segm  # either list[list[float]] or dict(RLE)
-                    if isinstance(segm, dict):
-                        if isinstance(segm["counts"], list):
-                            # convert to compressed RLE
-                            segm = mask_util.frPyObjects(segm, *segm["size"])
-                    else:
-                        # filter out invalid polygons (< 3 points)
-                        segm = [
-                            poly
-                            for poly in segm
-                            if len(poly) % 2 == 0 and len(poly) >= 6
-                        ]
-                        if len(segm) == 0:
-                            num_instances_without_valid_segmentation += 1
-                            continue  # ignore this instance
-                    ann["segmentation"] = segm
-                    ann_lib[ann_id] = ann
-
-                obj.append(ann)
-
-        record["annotations"] = obj
-
-        # Process referring expressions
-        sents = ref_dict["sentences"]
-        for sent in sents:
-            ref_record = record.copy()
-            ref = {key: sent[key] for key in ref_keys if key in sent}
-            ref["ref_id"] = ref_dict["ref_id"]
-            ref_record["sentence"] = ref
-            dataset_dicts.append(ref_record)
-            # if ref_record['empty']:
-            #     NT_count += 1
-            # else:
-            #     MT_count += 1
-
-    # logger.info("NT samples: %d, MT samples: %d", NT_count, MT_count)
-
-    # Debug mode
-    # return dataset_dicts[:100]
-
-    return dataset_dicts
-
-
-if __name__ == "__main__":
-    """
-    Test the COCO json dataset loader.
-
-    Usage:
-        python -m detectron2.data.datasets.coco \
-            path/to/json path/to/image_root dataset_name
-
-    "dataset_name" can be "coco_2014_minival_100", or other
-    pre-registered ones
-    """
-    import sys
-
-    import detectron2.data.datasets  # noqa # add pre-defined metadata
-    from detectron2.utils.logger import setup_logger
-    from detectron2.utils.visualizer import Visualizer
-
-    REFCOCO_PATH = "/mnt/lustre/hhding/code/ReLA/datasets"
-    COCO_TRAIN_2014_IMAGE_ROOT = "/mnt/lustre/hhding/code/ReLA/datasets/images"
-    REFCOCO_DATASET = "grefcoco"
-    REFCOCO_SPLITBY = "unc"
-    REFCOCO_SPLIT = "train"
-
-    logger = setup_logger(name=__name__)
-
-    dicts = load_grefcoco_json(
-        REFCOCO_PATH,
-        REFCOCO_DATASET,
-        REFCOCO_SPLITBY,
-        REFCOCO_SPLIT,
-        COCO_TRAIN_2014_IMAGE_ROOT,
-    )
-    logger.info("Done loading {} samples.".format(len(dicts)))
lisa_on_cuda/utils/grefer.py
DELETED
@@ -1,352 +0,0 @@
-"""
-grefer v0.1
-This interface provides access to gRefCOCO.
-
-The following API functions are defined:
-G_REFER      - REFER api class
-getRefIds    - get ref ids that satisfy given filter conditions.
-getAnnIds    - get ann ids that satisfy given filter conditions.
-getImgIds    - get image ids that satisfy given filter conditions.
-getCatIds    - get category ids that satisfy given filter conditions.
-loadRefs     - load refs with the specified ref ids.
-loadAnns     - load anns with the specified ann ids.
-loadImgs     - load images with the specified image ids.
-loadCats     - load category names with the specified category ids.
-getRefBox    - get ref's bounding box [x, y, w, h] given the ref_id
-showRef      - show image, segmentation or box of the referred object with the ref
-getMaskByRef - get mask and area of the referred object given ref or ref ids
-getMask      - get mask and area of the referred object given ref
-showMask     - show mask of the referred object given ref
-"""
-
-import itertools
-import json
-import os.path as osp
-import pickle
-import time
-
-import matplotlib.pyplot as plt
-import numpy as np
-import skimage.io as io
-from matplotlib.collections import PatchCollection
-from matplotlib.patches import Polygon, Rectangle
-from pycocotools import mask
-
-
-class G_REFER:
-    def __init__(self, data_root, dataset="grefcoco", splitBy="unc"):
-        # provide data_root folder which contains grefcoco
-        print("loading dataset %s into memory..." % dataset)
-        self.ROOT_DIR = osp.abspath(osp.dirname(__file__))
-        self.DATA_DIR = osp.join(data_root, dataset)
-        if dataset in ["grefcoco"]:
-            self.IMAGE_DIR = osp.join(data_root, "images/train2014")
-        else:
-            raise KeyError("No refer dataset is called [%s]" % dataset)
-
-        tic = time.time()
-
-        # load refs from data/dataset/refs(dataset).json
-        self.data = {}
-        self.data["dataset"] = dataset
-
-        ref_file = osp.join(self.DATA_DIR, f"grefs({splitBy}).p")
-        if osp.exists(ref_file):
-            self.data["refs"] = pickle.load(open(ref_file, "rb"), fix_imports=True)
-        else:
-            ref_file = osp.join(self.DATA_DIR, f"grefs({splitBy}).json")
-            if osp.exists(ref_file):
-                self.data["refs"] = json.load(open(ref_file, "rb"))
-            else:
-                raise FileNotFoundError("JSON file not found")
-
-        # load annotations from data/dataset/instances.json
-        instances_file = osp.join(self.DATA_DIR, "instances.json")
-        instances = json.load(open(instances_file, "r"))
-        self.data["images"] = instances["images"]
-        self.data["annotations"] = instances["annotations"]
-        self.data["categories"] = instances["categories"]
-
-        # create index
-        self.createIndex()
-        print("DONE (t=%.2fs)" % (time.time() - tic))
-
-    @staticmethod
-    def _toList(x):
-        return x if isinstance(x, list) else [x]
-
-    @staticmethod
-    def match_any(a, b):
-        a = a if isinstance(a, list) else [a]
-        b = b if isinstance(b, list) else [b]
-        return set(a) & set(b)
-
-    def createIndex(self):
-        # create sets of mapping
-        # 1)  Refs:         {ref_id: ref}
-        # 2)  Anns:         {ann_id: ann}
-        # 3)  Imgs:         {image_id: image}
-        # 4)  Cats:         {category_id: category_name}
-        # 5)  Sents:        {sent_id: sent}
-        # 6)  imgToRefs:    {image_id: refs}
-        # 7)  imgToAnns:    {image_id: anns}
-        # 8)  refToAnn:     {ref_id: ann}
-        # 9)  annToRef:     {ann_id: ref}
-        # 10) catToRefs:    {category_id: refs}
-        # 11) sentToRef:    {sent_id: ref}
-        # 12) sentToTokens: {sent_id: tokens}
-        print("creating index...")
-        # fetch info from instances
-        Anns, Imgs, Cats, imgToAnns = {}, {}, {}, {}
-        Anns[-1] = None
-        for ann in self.data["annotations"]:
-            Anns[ann["id"]] = ann
-            imgToAnns[ann["image_id"]] = imgToAnns.get(ann["image_id"], []) + [ann]
-        for img in self.data["images"]:
-            Imgs[img["id"]] = img
-        for cat in self.data["categories"]:
-            Cats[cat["id"]] = cat["name"]
-
-        # fetch info from refs
-        Refs, imgToRefs, refToAnn, annToRef, catToRefs = {}, {}, {}, {}, {}
-        Sents, sentToRef, sentToTokens = {}, {}, {}
-        availableSplits = []
-        for ref in self.data["refs"]:
-            # ids
-            ref_id = ref["ref_id"]
-            ann_id = ref["ann_id"]
-            category_id = ref["category_id"]
-            image_id = ref["image_id"]
-
-            if ref["split"] not in availableSplits:
-                availableSplits.append(ref["split"])
-
-            # add mapping related to ref
-            if ref_id in Refs:
-                print("Duplicate ref id")
-            Refs[ref_id] = ref
-            imgToRefs[image_id] = imgToRefs.get(image_id, []) + [ref]
-
-            category_id = self._toList(category_id)
-            added_cats = []
-            for cat in category_id:
-                if cat not in added_cats:
-                    added_cats.append(cat)
-                    catToRefs[cat] = catToRefs.get(cat, []) + [ref]
-
-            ann_id = self._toList(ann_id)
-            refToAnn[ref_id] = [Anns[ann] for ann in ann_id]
-            for ann_id_n in ann_id:
-                annToRef[ann_id_n] = annToRef.get(ann_id_n, []) + [ref]
-
-            # add mapping of sent
-            for sent in ref["sentences"]:
-                Sents[sent["sent_id"]] = sent
-                sentToRef[sent["sent_id"]] = ref
-                sentToTokens[sent["sent_id"]] = sent["tokens"]
-
-        # create class members
-        self.Refs = Refs
-        self.Anns = Anns
-        self.Imgs = Imgs
-        self.Cats = Cats
-        self.Sents = Sents
-        self.imgToRefs = imgToRefs
-        self.imgToAnns = imgToAnns
-        self.refToAnn = refToAnn
-        self.annToRef = annToRef
-        self.catToRefs = catToRefs
-        self.sentToRef = sentToRef
-        self.sentToTokens = sentToTokens
-        self.availableSplits = availableSplits
-        print("index created.")
-
-    def getRefIds(self, image_ids=[], cat_ids=[], split=[]):
-        image_ids = self._toList(image_ids)
-        cat_ids = self._toList(cat_ids)
-        split = self._toList(split)
-
-        for s in split:
-            if s not in self.availableSplits:
-                raise ValueError(f"Invalid split name: {s}")
-
-        refs = self.data["refs"]
-
-        if len(image_ids) > 0:
-            lists = [self.imgToRefs[image_id] for image_id in image_ids]
-            refs = list(itertools.chain.from_iterable(lists))
-        if len(cat_ids) > 0:
-            refs = [ref for ref in refs if self.match_any(ref["category_id"], cat_ids)]
-        if len(split) > 0:
-            refs = [ref for ref in refs if ref["split"] in split]
-
-        ref_ids = [ref["ref_id"] for ref in refs]
-        return ref_ids
-
-    def getAnnIds(self, image_ids=[], ref_ids=[]):
-        image_ids = self._toList(image_ids)
-        ref_ids = self._toList(ref_ids)
-
-        if any([len(image_ids), len(ref_ids)]):
-            if len(image_ids) > 0:
-                lists = [
-                    self.imgToAnns[image_id]
-                    for image_id in image_ids
-                    if image_id in self.imgToAnns
-                ]
-                anns = list(itertools.chain.from_iterable(lists))
-            else:
-                anns = self.data["annotations"]
-            ann_ids = [ann["id"] for ann in anns]
-            if len(ref_ids) > 0:
-                lists = [self.Refs[ref_id]["ann_id"] for ref_id in ref_ids]
-                anns_by_ref_id = list(itertools.chain.from_iterable(lists))
-                ann_ids = list(set(ann_ids).intersection(set(anns_by_ref_id)))
-        else:
-            ann_ids = [ann["id"] for ann in self.data["annotations"]]
-
-        return ann_ids
-
-    def getImgIds(self, ref_ids=[]):
-        ref_ids = self._toList(ref_ids)
|
212 |
-
|
213 |
-
if len(ref_ids) > 0:
|
214 |
-
image_ids = list(set([self.Refs[ref_id]["image_id"] for ref_id in ref_ids]))
|
215 |
-
else:
|
216 |
-
image_ids = self.Imgs.keys()
|
217 |
-
return image_ids
|
218 |
-
|
219 |
-
def getCatIds(self):
|
220 |
-
return self.Cats.keys()
|
221 |
-
|
222 |
-
def loadRefs(self, ref_ids=[]):
|
223 |
-
return [self.Refs[ref_id] for ref_id in self._toList(ref_ids)]
|
224 |
-
|
225 |
-
def loadAnns(self, ann_ids=[]):
|
226 |
-
if isinstance(ann_ids, str):
|
227 |
-
ann_ids = int(ann_ids)
|
228 |
-
return [self.Anns[ann_id] for ann_id in self._toList(ann_ids)]
|
229 |
-
|
230 |
-
def loadImgs(self, image_ids=[]):
|
231 |
-
return [self.Imgs[image_id] for image_id in self._toList(image_ids)]
|
232 |
-
|
233 |
-
def loadCats(self, cat_ids=[]):
|
234 |
-
return [self.Cats[cat_id] for cat_id in self._toList(cat_ids)]
|
235 |
-
|
236 |
-
def getRefBox(self, ref_id):
|
237 |
-
anns = self.refToAnn[ref_id]
|
238 |
-
return [ann["bbox"] for ann in anns] # [x, y, w, h]
|
239 |
-
|
240 |
-
def showRef(self, ref, seg_box="seg"):
|
241 |
-
ax = plt.gca()
|
242 |
-
# show image
|
243 |
-
image = self.Imgs[ref["image_id"]]
|
244 |
-
I = io.imread(osp.join(self.IMAGE_DIR, image["file_name"]))
|
245 |
-
ax.imshow(I)
|
246 |
-
# show refer expression
|
247 |
-
for sid, sent in enumerate(ref["sentences"]):
|
248 |
-
print("%s. %s" % (sid + 1, sent["sent"]))
|
249 |
-
# show segmentations
|
250 |
-
if seg_box == "seg":
|
251 |
-
ann_id = ref["ann_id"]
|
252 |
-
ann = self.Anns[ann_id]
|
253 |
-
polygons = []
|
254 |
-
color = []
|
255 |
-
c = "none"
|
256 |
-
if type(ann["segmentation"][0]) == list:
|
257 |
-
# polygon used for refcoco*
|
258 |
-
for seg in ann["segmentation"]:
|
259 |
-
poly = np.array(seg).reshape((len(seg) / 2, 2))
|
260 |
-
polygons.append(Polygon(poly, True, alpha=0.4))
|
261 |
-
color.append(c)
|
262 |
-
p = PatchCollection(
|
263 |
-
polygons,
|
264 |
-
facecolors=color,
|
265 |
-
edgecolors=(1, 1, 0, 0),
|
266 |
-
linewidths=3,
|
267 |
-
alpha=1,
|
268 |
-
)
|
269 |
-
ax.add_collection(p) # thick yellow polygon
|
270 |
-
p = PatchCollection(
|
271 |
-
polygons,
|
272 |
-
facecolors=color,
|
273 |
-
edgecolors=(1, 0, 0, 0),
|
274 |
-
linewidths=1,
|
275 |
-
alpha=1,
|
276 |
-
)
|
277 |
-
ax.add_collection(p) # thin red polygon
|
278 |
-
else:
|
279 |
-
# mask used for refclef
|
280 |
-
rle = ann["segmentation"]
|
281 |
-
m = mask.decode(rle)
|
282 |
-
img = np.ones((m.shape[0], m.shape[1], 3))
|
283 |
-
color_mask = np.array([2.0, 166.0, 101.0]) / 255
|
284 |
-
for i in range(3):
|
285 |
-
img[:, :, i] = color_mask[i]
|
286 |
-
ax.imshow(np.dstack((img, m * 0.5)))
|
287 |
-
# show bounding-box
|
288 |
-
elif seg_box == "box":
|
289 |
-
ann_id = ref["ann_id"]
|
290 |
-
ann = self.Anns[ann_id]
|
291 |
-
bbox = self.getRefBox(ref["ref_id"])
|
292 |
-
box_plot = Rectangle(
|
293 |
-
(bbox[0], bbox[1]),
|
294 |
-
bbox[2],
|
295 |
-
bbox[3],
|
296 |
-
fill=False,
|
297 |
-
edgecolor="green",
|
298 |
-
linewidth=3,
|
299 |
-
)
|
300 |
-
ax.add_patch(box_plot)
|
301 |
-
|
302 |
-
def getMask(self, ann):
|
303 |
-
if not ann:
|
304 |
-
return None
|
305 |
-
if ann["iscrowd"]:
|
306 |
-
raise ValueError("Crowd object")
|
307 |
-
image = self.Imgs[ann["image_id"]]
|
308 |
-
if type(ann["segmentation"][0]) == list: # polygon
|
309 |
-
rle = mask.frPyObjects(ann["segmentation"], image["height"], image["width"])
|
310 |
-
else:
|
311 |
-
rle = ann["segmentation"]
|
312 |
-
|
313 |
-
m = mask.decode(rle)
|
314 |
-
m = np.sum(
|
315 |
-
m, axis=2
|
316 |
-
) # sometimes there are multiple binary map (corresponding to multiple segs)
|
317 |
-
m = m.astype(np.uint8) # convert to np.uint8
|
318 |
-
# compute area
|
319 |
-
area = sum(mask.area(rle)) # should be close to ann['area']
|
320 |
-
return {"mask": m, "area": area}
|
321 |
-
|
322 |
-
def getMaskByRef(self, ref=None, ref_id=None, merge=False):
|
323 |
-
if not ref and not ref_id:
|
324 |
-
raise ValueError
|
325 |
-
if ref:
|
326 |
-
ann_ids = ref["ann_id"]
|
327 |
-
ref_id = ref["ref_id"]
|
328 |
-
else:
|
329 |
-
ann_ids = self.getAnnIds(ref_ids=ref_id)
|
330 |
-
|
331 |
-
if ann_ids == [-1]:
|
332 |
-
img = self.Imgs[self.Refs[ref_id]["image_id"]]
|
333 |
-
return {
|
334 |
-
"mask": np.zeros([img["height"], img["width"]], dtype=np.uint8),
|
335 |
-
"empty": True,
|
336 |
-
}
|
337 |
-
|
338 |
-
anns = self.loadAnns(ann_ids)
|
339 |
-
mask_list = [self.getMask(ann) for ann in anns if not ann["iscrowd"]]
|
340 |
-
|
341 |
-
if merge:
|
342 |
-
merged_masks = sum([mask["mask"] for mask in mask_list])
|
343 |
-
merged_masks[np.where(merged_masks > 1)] = 1
|
344 |
-
return {"mask": merged_masks, "empty": False}
|
345 |
-
else:
|
346 |
-
return mask_list
|
347 |
-
|
348 |
-
def showMask(self, ref):
|
349 |
-
M = self.getMask(ref)
|
350 |
-
msk = M["mask"]
|
351 |
-
ax = plt.gca()
|
352 |
-
ax.imshow(msk)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
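The G_REFER class above is the gRefCOCO flavour of the referring-expression API that the (also removed) refer_seg_dataset.py further down in this commit consumed. A minimal usage sketch follows; the data_root path is an assumed layout and the import targets the file deleted here, so this is illustrative only, not runnable against the current tree.

# Illustrative sketch (assumed data layout); G_REFER is removed by this commit.
from lisa_on_cuda.utils.grefer import G_REFER

refer_api = G_REFER(data_root="./dataset/refer_seg", dataset="grefcoco", splitBy="unc")
ref_ids = refer_api.getRefIds(split=["train"])        # filter references by split
ref = refer_api.loadRefs(ref_ids[0])[0]               # load one reference record
result = refer_api.getMaskByRef(ref=ref, merge=True)  # rasterize and merge its masks
print(result["mask"].shape, result["empty"])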
lisa_on_cuda/utils/reason_seg_dataset.py
DELETED
@@ -1,218 +0,0 @@
import glob
import json
import os
import random

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from transformers import CLIPImageProcessor

from lisa_on_cuda.llava import conversation as conversation_lib
from lisa_on_cuda.segment_anything.utils.transforms import ResizeLongestSide

from .data_processing import get_mask_from_json
from .utils import (ANSWER_LIST, DEFAULT_IMAGE_TOKEN,
                    EXPLANATORY_QUESTION_LIST, LONG_QUESTION_LIST,
                    SHORT_QUESTION_LIST)


class ReasonSegDataset(torch.utils.data.Dataset):
    pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
    pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
    img_size = 1024
    ignore_label = 255

    def __init__(
        self,
        base_image_dir,
        tokenizer,
        vision_tower,
        samples_per_epoch=500 * 8 * 2 * 10,
        precision: str = "fp32",
        image_size: int = 224,
        num_classes_per_sample: int = 3,
        exclude_val=False,
        reason_seg_data="ReasonSeg|train",
        explanatory=0.1,
    ):
        self.exclude_val = exclude_val
        self.reason_seg_data = reason_seg_data
        self.samples_per_epoch = samples_per_epoch
        self.explanatory = explanatory
        self.num_classes_per_sample = num_classes_per_sample

        self.base_image_dir = base_image_dir
        self.image_size = image_size
        self.tokenizer = tokenizer
        self.precision = precision
        self.transform = ResizeLongestSide(image_size)
        self.clip_image_processor = CLIPImageProcessor.from_pretrained(vision_tower)

        self.short_question_list = SHORT_QUESTION_LIST
        self.long_question_list = LONG_QUESTION_LIST
        self.answer_list = ANSWER_LIST

        reason_seg_data, splits = reason_seg_data.split("|")
        splits = splits.split("_")
        images = []
        for split in splits:
            images_split = glob.glob(
                os.path.join(
                    base_image_dir, "reason_seg", reason_seg_data, split, "*.jpg"
                )
            )
            images.extend(images_split)
        jsons = [path.replace(".jpg", ".json") for path in images]
        self.reason_seg_data = (images, jsons)

        print("number of reason_seg samples: ", len(images))

        if explanatory != -1:
            self.explanatory_question_list = EXPLANATORY_QUESTION_LIST
            self.img_to_explanation = {}
            with open(
                os.path.join(
                    base_image_dir,
                    "reason_seg",
                    reason_seg_data,
                    "explanatory",
                    "train.json",
                )
            ) as f:
                items = json.load(f)
            for item in items:
                img_name = item["image"]
                self.img_to_explanation[img_name] = {
                    "query": item["query"],
                    "outputs": item["outputs"],
                }

            print("len(self.img_to_explanation): ", len(self.img_to_explanation))

    def __len__(self):
        return self.samples_per_epoch

    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize pixel values and pad to a square input."""
        # Normalize colors
        x = (x - self.pixel_mean) / self.pixel_std

        # Pad
        h, w = x.shape[-2:]
        padh = self.img_size - h
        padw = self.img_size - w
        x = F.pad(x, (0, padw, 0, padh))
        return x

    def __getitem__(self, idx):
        images, jsons = self.reason_seg_data
        idx = random.randint(0, len(images) - 1)
        image_path = images[idx]
        json_path = jsons[idx]

        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        ori_size = image.shape[:2]
        # preprocess image for clip
        image_clip = self.clip_image_processor.preprocess(image, return_tensors="pt")[
            "pixel_values"
        ][0]

        mask, sents, is_sentence = get_mask_from_json(json_path, image)
        if len(sents) >= self.num_classes_per_sample:
            sampled_inds = np.random.choice(
                list(range(len(sents))), size=self.num_classes_per_sample, replace=False
            )
        else:
            sampled_inds = list(range(len(sents)))
        sampled_sents = np.vectorize(sents.__getitem__)(sampled_inds).tolist()
        sampled_masks = [
            (mask == 1).astype(np.float32) for _ in range(len(sampled_inds))
        ]

        image = self.transform.apply_image(image)  # preprocess image for sam
        resize = image.shape[:2]

        image_name = image_path.split("/")[-1]
        if self.explanatory != -1 and image_name in self.img_to_explanation:
            if random.random() < self.explanatory:
                choice = 2
            else:
                choice = random.randint(0, 1)

        questions = []
        answers = []
        for text in sampled_sents:
            if is_sentence:
                question_template = random.choice(self.long_question_list)
                questions.append(question_template.format(sent=text))
            else:
                question_template = random.choice(self.short_question_list)
                questions.append(question_template.format(class_name=text.lower()))

            # add explanation if applicable
            img_name = image_path.split("/")[-1]
            if self.explanatory != -1 and img_name in self.img_to_explanation:
                if choice == 0:  # [SEG] token
                    answers.append(random.choice(self.answer_list))
                elif choice == 1:  # [SEG] token + text answer
                    image_name = image_path.split("/")[-1]
                    answer = self.img_to_explanation[image_name]["outputs"]
                    answer = random.choice(self.answer_list) + " {}".format(answer)
                    questions[-1] = (
                        DEFAULT_IMAGE_TOKEN
                        + "\n"
                        + text
                        + " {}".format(random.choice(self.explanatory_question_list))
                    )
                    answers.append(answer)
                elif choice == 2:  # vanilla text answer
                    image_name = image_path.split("/")[-1]
                    answer = self.img_to_explanation[image_name]["outputs"]
                    questions[-1] = DEFAULT_IMAGE_TOKEN + "\n" + text
                    answers.append(answer)
                else:
                    raise ValueError("Not implemented yet.")
            else:
                answers.append(random.choice(self.answer_list))

        conversations = []
        conv = conversation_lib.default_conversation.copy()
        roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

        i = 0
        while i < len(questions):
            conv.messages = []
            conv.append_message(conv.roles[0], questions[i])
            conv.append_message(conv.roles[1], answers[i])
            conversations.append(conv.get_prompt())
            i += 1

        image = self.preprocess(torch.from_numpy(image).permute(2, 0, 1).contiguous())

        image_name = image_path.split("/")[-1]
        if (
            self.explanatory != -1
            and image_name in self.img_to_explanation
            and choice == 2
        ):
            masks = torch.rand(0, *ori_size)
            label = torch.ones(ori_size) * self.ignore_label
        else:
            masks = np.stack(sampled_masks, axis=0)
            masks = torch.from_numpy(masks)
            label = torch.ones(masks.shape[1], masks.shape[2]) * self.ignore_label

        return (
            image_path,
            image,
            image_clip,
            conversations,
            masks,
            label,
            resize,
            questions,
            sampled_sents,
        )
lisa_on_cuda/utils/refer.py
DELETED
@@ -1,391 +0,0 @@
__author__ = "licheng"

"""
This interface provides access to four datasets:
1) refclef
2) refcoco
3) refcoco+
4) refcocog
split by unc and google

The following API functions are defined:
REFER      - REFER api class
getRefIds  - get ref ids that satisfy given filter conditions.
getAnnIds  - get ann ids that satisfy given filter conditions.
getImgIds  - get image ids that satisfy given filter conditions.
getCatIds  - get category ids that satisfy given filter conditions.
loadRefs   - load refs with the specified ref ids.
loadAnns   - load anns with the specified ann ids.
loadImgs   - load images with the specified image ids.
loadCats   - load category names with the specified category ids.
getRefBox  - get ref's bounding box [x, y, w, h] given the ref_id
showRef    - show image, segmentation or box of the referred object with the ref
getMask    - get mask and area of the referred object given ref
showMask   - show mask of the referred object given ref
"""

import itertools
import json
import os.path as osp
import pickle
import sys
import time
from pprint import pprint

import matplotlib.pyplot as plt
import numpy as np
import skimage.io as io
from matplotlib.collections import PatchCollection
from matplotlib.patches import Polygon, Rectangle
from pycocotools import mask


class REFER:
    def __init__(self, data_root, dataset="refcoco", splitBy="unc"):
        # provide data_root folder which contains refclef, refcoco, refcoco+ and refcocog
        # also provide dataset name and splitBy information
        # e.g., dataset = 'refcoco', splitBy = 'unc'
        print("loading dataset %s into memory..." % dataset)
        self.ROOT_DIR = osp.abspath(osp.dirname(__file__))
        self.DATA_DIR = osp.join(data_root, dataset)
        if dataset in ["refcoco", "refcoco+", "refcocog"]:
            self.IMAGE_DIR = osp.join(data_root, "images/mscoco/images/train2014")
        elif dataset == "refclef":
            self.IMAGE_DIR = osp.join(data_root, "images/saiapr_tc-12")
        else:
            print("No refer dataset is called [%s]" % dataset)
            sys.exit()

        self.dataset = dataset

        # load refs from data/dataset/refs(dataset).json
        tic = time.time()

        ref_file = osp.join(self.DATA_DIR, "refs(" + splitBy + ").p")
        print("ref_file: ", ref_file)
        self.data = {}
        self.data["dataset"] = dataset
        self.data["refs"] = pickle.load(open(ref_file, "rb"))

        # load annotations from data/dataset/instances.json
        instances_file = osp.join(self.DATA_DIR, "instances.json")
        instances = json.load(open(instances_file, "rb"))
        self.data["images"] = instances["images"]
        self.data["annotations"] = instances["annotations"]
        self.data["categories"] = instances["categories"]

        # create index
        self.createIndex()
        print("DONE (t=%.2fs)" % (time.time() - tic))

    def createIndex(self):
        # create sets of mapping
        # 1)  Refs:         {ref_id: ref}
        # 2)  Anns:         {ann_id: ann}
        # 3)  Imgs:         {image_id: image}
        # 4)  Cats:         {category_id: category_name}
        # 5)  Sents:        {sent_id: sent}
        # 6)  imgToRefs:    {image_id: refs}
        # 7)  imgToAnns:    {image_id: anns}
        # 8)  refToAnn:     {ref_id: ann}
        # 9)  annToRef:     {ann_id: ref}
        # 10) catToRefs:    {category_id: refs}
        # 11) sentToRef:    {sent_id: ref}
        # 12) sentToTokens: {sent_id: tokens}
        print("creating index...")
        # fetch info from instances
        Anns, Imgs, Cats, imgToAnns = {}, {}, {}, {}
        for ann in self.data["annotations"]:
            Anns[ann["id"]] = ann
            imgToAnns[ann["image_id"]] = imgToAnns.get(ann["image_id"], []) + [ann]
        for img in self.data["images"]:
            Imgs[img["id"]] = img
        for cat in self.data["categories"]:
            Cats[cat["id"]] = cat["name"]

        # fetch info from refs
        Refs, imgToRefs, refToAnn, annToRef, catToRefs = {}, {}, {}, {}, {}
        Sents, sentToRef, sentToTokens = {}, {}, {}
        for ref in self.data["refs"]:
            # ids
            ref_id = ref["ref_id"]
            ann_id = ref["ann_id"]
            category_id = ref["category_id"]
            image_id = ref["image_id"]

            # add mapping related to ref
            Refs[ref_id] = ref
            imgToRefs[image_id] = imgToRefs.get(image_id, []) + [ref]
            catToRefs[category_id] = catToRefs.get(category_id, []) + [ref]
            refToAnn[ref_id] = Anns[ann_id]
            annToRef[ann_id] = ref

            # add mapping of sent
            for sent in ref["sentences"]:
                Sents[sent["sent_id"]] = sent
                sentToRef[sent["sent_id"]] = ref
                sentToTokens[sent["sent_id"]] = sent["tokens"]

        # create class members
        self.Refs = Refs
        self.Anns = Anns
        self.Imgs = Imgs
        self.Cats = Cats
        self.Sents = Sents
        self.imgToRefs = imgToRefs
        self.imgToAnns = imgToAnns
        self.refToAnn = refToAnn
        self.annToRef = annToRef
        self.catToRefs = catToRefs
        self.sentToRef = sentToRef
        self.sentToTokens = sentToTokens
        print("index created.")

    def getRefIds(self, image_ids=[], cat_ids=[], ref_ids=[], split=""):
        image_ids = image_ids if type(image_ids) == list else [image_ids]
        cat_ids = cat_ids if type(cat_ids) == list else [cat_ids]
        ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]

        if len(image_ids) == len(cat_ids) == len(ref_ids) == len(split) == 0:
            refs = self.data["refs"]
        else:
            if not len(image_ids) == 0:
                refs = [self.imgToRefs[image_id] for image_id in image_ids]
            else:
                refs = self.data["refs"]
            if not len(cat_ids) == 0:
                refs = [ref for ref in refs if ref["category_id"] in cat_ids]
            if not len(ref_ids) == 0:
                refs = [ref for ref in refs if ref["ref_id"] in ref_ids]
            if not len(split) == 0:
                if split in ["testA", "testB", "testC"]:
                    refs = [
                        ref for ref in refs if split[-1] in ref["split"]
                    ]  # we also consider testAB, testBC, ...
                elif split in ["testAB", "testBC", "testAC"]:
                    refs = [
                        ref for ref in refs if ref["split"] == split
                    ]  # rarely used I guess...
                elif split == "test":
                    refs = [ref for ref in refs if "test" in ref["split"]]
                elif split == "train" or split == "val":
                    refs = [ref for ref in refs if ref["split"] == split]
                else:
                    print("No such split [%s]" % split)
                    sys.exit()
        ref_ids = [ref["ref_id"] for ref in refs]
        return ref_ids

    def getAnnIds(self, image_ids=[], cat_ids=[], ref_ids=[]):
        image_ids = image_ids if type(image_ids) == list else [image_ids]
        cat_ids = cat_ids if type(cat_ids) == list else [cat_ids]
        ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]

        if len(image_ids) == len(cat_ids) == len(ref_ids) == 0:
            ann_ids = [ann["id"] for ann in self.data["annotations"]]
        else:
            if not len(image_ids) == 0:
                lists = [
                    self.imgToAnns[image_id]
                    for image_id in image_ids
                    if image_id in self.imgToAnns
                ]  # list of [anns]
                anns = list(itertools.chain.from_iterable(lists))
            else:
                anns = self.data["annotations"]
            if not len(cat_ids) == 0:
                anns = [ann for ann in anns if ann["category_id"] in cat_ids]
            ann_ids = [ann["id"] for ann in anns]
            if not len(ref_ids) == 0:
                ids = set(ann_ids).intersection(
                    set([self.Refs[ref_id]["ann_id"] for ref_id in ref_ids])
                )
        return ann_ids

    def getImgIds(self, ref_ids=[]):
        ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]

        if not len(ref_ids) == 0:
            image_ids = list(set([self.Refs[ref_id]["image_id"] for ref_id in ref_ids]))
        else:
            image_ids = self.Imgs.keys()
        return image_ids

    def getCatIds(self):
        return self.Cats.keys()

    def loadRefs(self, ref_ids=[]):
        if type(ref_ids) == list:
            return [self.Refs[ref_id] for ref_id in ref_ids]
        elif type(ref_ids) == int:
            return [self.Refs[ref_ids]]

    def loadAnns(self, ann_ids=[]):
        if type(ann_ids) == list:
            return [self.Anns[ann_id] for ann_id in ann_ids]
        elif type(ann_ids) == int or type(ann_ids) == unicode:
            return [self.Anns[ann_ids]]

    def loadImgs(self, image_ids=[]):
        if type(image_ids) == list:
            return [self.Imgs[image_id] for image_id in image_ids]
        elif type(image_ids) == int:
            return [self.Imgs[image_ids]]

    def loadCats(self, cat_ids=[]):
        if type(cat_ids) == list:
            return [self.Cats[cat_id] for cat_id in cat_ids]
        elif type(cat_ids) == int:
            return [self.Cats[cat_ids]]

    def getRefBox(self, ref_id):
        ref = self.Refs[ref_id]
        ann = self.refToAnn[ref_id]
        return ann["bbox"]  # [x, y, w, h]

    def showRef(self, ref, seg_box="seg"):
        ax = plt.gca()
        # show image
        image = self.Imgs[ref["image_id"]]
        I = io.imread(osp.join(self.IMAGE_DIR, image["file_name"]))
        ax.imshow(I)
        # show refer expression
        for sid, sent in enumerate(ref["sentences"]):
            print("%s. %s" % (sid + 1, sent["sent"]))
        # show segmentations
        if seg_box == "seg":
            ann_id = ref["ann_id"]
            ann = self.Anns[ann_id]
            polygons = []
            color = []
            c = "none"
            if type(ann["segmentation"][0]) == list:
                # polygon used for refcoco*
                for seg in ann["segmentation"]:
                    poly = np.array(seg).reshape((len(seg) / 2, 2))
                    polygons.append(Polygon(poly, True, alpha=0.4))
                    color.append(c)
                p = PatchCollection(
                    polygons,
                    facecolors=color,
                    edgecolors=(1, 1, 0, 0),
                    linewidths=3,
                    alpha=1,
                )
                ax.add_collection(p)  # thick yellow polygon
                p = PatchCollection(
                    polygons,
                    facecolors=color,
                    edgecolors=(1, 0, 0, 0),
                    linewidths=1,
                    alpha=1,
                )
                ax.add_collection(p)  # thin red polygon
            else:
                # mask used for refclef
                rle = ann["segmentation"]
                m = mask.decode(rle)
                img = np.ones((m.shape[0], m.shape[1], 3))
                color_mask = np.array([2.0, 166.0, 101.0]) / 255
                for i in range(3):
                    img[:, :, i] = color_mask[i]
                ax.imshow(np.dstack((img, m * 0.5)))
        # show bounding-box
        elif seg_box == "box":
            ann_id = ref["ann_id"]
            ann = self.Anns[ann_id]
            bbox = self.getRefBox(ref["ref_id"])
            box_plot = Rectangle(
                (bbox[0], bbox[1]),
                bbox[2],
                bbox[3],
                fill=False,
                edgecolor="green",
                linewidth=3,
            )
            ax.add_patch(box_plot)

    def getMask(self, ref):
        # return mask, area and mask-center
        ann = self.refToAnn[ref["ref_id"]]
        image = self.Imgs[ref["image_id"]]
        if type(ann["segmentation"][0]) == list:  # polygon
            rle = mask.frPyObjects(ann["segmentation"], image["height"], image["width"])
        else:
            rle = ann["segmentation"]
        m = mask.decode(rle)
        m = np.sum(
            m, axis=2
        )  # sometimes there are multiple binary map (corresponding to multiple segs)
        m = m.astype(np.uint8)  # convert to np.uint8
        # compute area
        area = sum(mask.area(rle))  # should be close to ann['area']
        return {"mask": m, "area": area}
        # # position
        # position_x = np.mean(np.where(m==1)[1]) # [1] means columns (matlab style) -> x (c style)
        # position_y = np.mean(np.where(m==1)[0]) # [0] means rows (matlab style) -> y (c style)
        # # mass position (if there were multiple regions, we use the largest one.)
        # label_m = label(m, connectivity=m.ndim)
        # regions = regionprops(label_m)
        # if len(regions) > 0:
        #     largest_id = np.argmax(np.array([props.filled_area for props in regions]))
        #     largest_props = regions[largest_id]
        #     mass_y, mass_x = largest_props.centroid
        # else:
        #     mass_x, mass_y = position_x, position_y
        # # if centroid is not in mask, we find the closest point to it from mask
        # if m[mass_y, mass_x] != 1:
        #     print('Finding closes mask point ...')
        #     kernel = np.ones((10, 10),np.uint8)
        #     me = cv2.erode(m, kernel, iterations = 1)
        #     points = zip(np.where(me == 1)[0].tolist(), np.where(me == 1)[1].tolist()) # row, col style
        #     points = np.array(points)
        #     dist = np.sum((points - (mass_y, mass_x))**2, axis=1)
        #     id = np.argsort(dist)[0]
        #     mass_y, mass_x = points[id]
        # # return
        # return {'mask': m, 'area': area, 'position_x': position_x, 'position_y': position_y, 'mass_x': mass_x, 'mass_y': mass_y}
        # # show image and mask
        # I = io.imread(osp.join(self.IMAGE_DIR, image['file_name']))
        # plt.figure()
        # plt.imshow(I)
        # ax = plt.gca()
        # img = np.ones( (m.shape[0], m.shape[1], 3) )
        # color_mask = np.array([2.0,166.0,101.0])/255
        # for i in range(3):
        #     img[:,:,i] = color_mask[i]
        # ax.imshow(np.dstack( (img, m*0.5) ))
        # plt.show()

    def showMask(self, ref):
        M = self.getMask(ref)
        msk = M["mask"]
        ax = plt.gca()
        ax.imshow(msk)


if __name__ == "__main__":
    refer = REFER(dataset="refcocog", splitBy="google")
    ref_ids = refer.getRefIds()
    print(len(ref_ids))

    print(len(refer.Imgs))
    print(len(refer.imgToRefs))

    ref_ids = refer.getRefIds(split="train")
    print("There are %s training referred objects." % len(ref_ids))

    for ref_id in ref_ids:
        ref = refer.loadRefs(ref_id)[0]
        if len(ref["sentences"]) < 2:
            continue

        pprint(ref)
        print("The label is %s." % refer.Cats[ref["category_id"]])
        plt.figure()
        refer.showRef(ref, seg_box="box")
        plt.show()

        # plt.figure()
        # refer.showMask(ref)
        # plt.show()
lisa_on_cuda/utils/refer_seg_dataset.py
DELETED
@@ -1,277 +0,0 @@
import os
import random

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from pycocotools import mask
from transformers import CLIPImageProcessor

from lisa_on_cuda.llava import conversation as conversation_lib
from lisa_on_cuda.segment_anything.utils.transforms import ResizeLongestSide

from .grefer import G_REFER
from .refer import REFER
from .utils import ANSWER_LIST, SHORT_QUESTION_LIST


class ReferSegDataset(torch.utils.data.Dataset):
    pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
    pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
    img_size = 1024
    ignore_label = 255

    def __init__(
        self,
        base_image_dir,
        tokenizer,
        vision_tower,
        samples_per_epoch=500 * 8 * 2 * 10,
        precision: str = "fp32",
        image_size: int = 224,
        num_classes_per_sample: int = 3,
        exclude_val=False,
        refer_seg_data="refclef||refcoco||refcoco+||refcocog",
    ):
        self.exclude_val = exclude_val
        self.samples_per_epoch = samples_per_epoch
        self.num_classes_per_sample = num_classes_per_sample

        self.base_image_dir = base_image_dir
        self.image_size = image_size
        self.tokenizer = tokenizer
        self.precision = precision
        self.transform = ResizeLongestSide(image_size)
        self.clip_image_processor = CLIPImageProcessor.from_pretrained(vision_tower)

        self.short_question_list = SHORT_QUESTION_LIST
        self.answer_list = ANSWER_LIST

        DATA_DIR = os.path.join(base_image_dir, "refer_seg")
        self.refer_seg_ds_list = refer_seg_data.split(
            "||"
        )  # ['refclef', 'refcoco', 'refcoco+', 'refcocog']
        self.refer_seg_data = {}
        for ds in self.refer_seg_ds_list:
            if ds == "refcocog":
                splitBy = "umd"
            else:
                splitBy = "unc"

            if ds == "grefcoco":
                refer_api = G_REFER(DATA_DIR, ds, splitBy)
            else:
                refer_api = REFER(DATA_DIR, ds, splitBy)
            ref_ids_train = refer_api.getRefIds(split="train")
            images_ids_train = refer_api.getImgIds(ref_ids=ref_ids_train)
            refs_train = refer_api.loadRefs(ref_ids=ref_ids_train)

            refer_seg_ds = {}
            refer_seg_ds["images"] = []
            loaded_images = refer_api.loadImgs(image_ids=images_ids_train)

            for item in loaded_images:
                item = item.copy()
                if ds == "refclef":
                    item["file_name"] = os.path.join(
                        DATA_DIR, "images/saiapr_tc-12", item["file_name"]
                    )
                else:
                    item["file_name"] = os.path.join(
                        DATA_DIR, "images/mscoco/images/train2014", item["file_name"]
                    )
                refer_seg_ds["images"].append(item)
            refer_seg_ds["annotations"] = refer_api.Anns  # anns_train

            print(
                "dataset {} (refs {}) (train split) has {} images and {} annotations.".format(
                    ds,
                    splitBy,
                    len(refer_seg_ds["images"]),
                    len(refer_seg_ds["annotations"]),
                )
            )

            img2refs = {}
            for ref in refs_train:
                image_id = ref["image_id"]
                img2refs[image_id] = img2refs.get(image_id, []) + [
                    ref,
                ]
            refer_seg_ds["img2refs"] = img2refs
            self.refer_seg_data[ds] = refer_seg_ds

    def __len__(self):
        return self.samples_per_epoch

    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize pixel values and pad to a square input."""
        # Normalize colors
        x = (x - self.pixel_mean) / self.pixel_std

        # Pad
        h, w = x.shape[-2:]
        padh = self.img_size - h
        padw = self.img_size - w
        x = F.pad(x, (0, padw, 0, padh))
        return x

    def __getitem__(self, idx):
        ds = random.randint(0, len(self.refer_seg_ds_list) - 1)
        ds = self.refer_seg_ds_list[ds]
        refer_seg_ds = self.refer_seg_data[ds]
        images = refer_seg_ds["images"]
        annotations = refer_seg_ds["annotations"]
        img2refs = refer_seg_ds["img2refs"]
        idx = random.randint(0, len(images) - 1)
        image_info = images[idx]
        image_path = image_info["file_name"]
        image_id = image_info["id"]
        refs = img2refs[image_id]
        if len(refs) == 0:
            return self.__getitem__(0)

        sents = []
        ann_ids = []
        for ref in refs:
            for sent in ref["sentences"]:
                text = sent["sent"]
                sents.append(text)
                ann_ids.append(ref["ann_id"])
        if len(sents) >= self.num_classes_per_sample:
            sampled_inds = np.random.choice(
                list(range(len(sents))), size=self.num_classes_per_sample, replace=False
            )
        else:
            sampled_inds = list(range(len(sents)))
        sampled_sents = np.vectorize(sents.__getitem__)(sampled_inds).tolist()
        # sampled_ann_ids = np.vectorize(ann_ids.__getitem__)(sampled_inds).tolist()
        sampled_ann_ids = [ann_ids[ind] for ind in sampled_inds]
        sampled_classes = sampled_sents
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # preprocess image for clip
        image_clip = self.clip_image_processor.preprocess(image, return_tensors="pt")[
            "pixel_values"
        ][0]

        image = self.transform.apply_image(image)  # preprocess image for sam
        resize = image.shape[:2]

        questions = []
        answers = []
        for text in sampled_classes:
            text = text.strip()
            assert len(text.split("||")) == 1
            question_template = random.choice(self.short_question_list)
            questions.append(question_template.format(class_name=text.lower()))
            answers.append(random.choice(self.answer_list))

        conversations = []
        conv = conversation_lib.default_conversation.copy()

        i = 0
        while i < len(questions):
            conv.messages = []
            conv.append_message(conv.roles[0], questions[i])
            conv.append_message(conv.roles[1], answers[i])
            conversations.append(conv.get_prompt())
            i += 1

        image = self.preprocess(torch.from_numpy(image).permute(2, 0, 1).contiguous())

        flag = False
        masks = []
        for ann_id in sampled_ann_ids:
            if isinstance(ann_id, list):
                flag = True
                if -1 in ann_id:
                    assert len(ann_id) == 1
                    m = np.zeros((image_info["height"], image_info["width"])).astype(
                        np.uint8
                    )
                else:
                    m_final = np.zeros(
                        (image_info["height"], image_info["width"])
                    ).astype(np.uint8)
                    for ann_id_i in ann_id:
                        ann = annotations[ann_id_i]

                        if len(ann["segmentation"]) == 0:
                            m = np.zeros(
                                (image_info["height"], image_info["width"])
                            ).astype(np.uint8)
                        else:
                            if type(ann["segmentation"][0]) == list:  # polygon
                                rle = mask.frPyObjects(
                                    ann["segmentation"],
                                    image_info["height"],
                                    image_info["width"],
                                )
                            else:
                                rle = ann["segmentation"]
                                for i in range(len(rle)):
                                    if not isinstance(rle[i]["counts"], bytes):
                                        rle[i]["counts"] = rle[i]["counts"].encode()
                            m = mask.decode(rle)
                            m = np.sum(
                                m, axis=2
                            )  # sometimes there are multiple binary map (corresponding to multiple segs)
                            m = m.astype(np.uint8)  # convert to np.uint8
                        m_final = m_final | m
                    m = m_final
                masks.append(m)
                continue

            ann = annotations[ann_id]

            if len(ann["segmentation"]) == 0:
                m = np.zeros((image_info["height"], image_info["width"])).astype(
                    np.uint8
                )
                masks.append(m)
                continue

            if type(ann["segmentation"][0]) == list:  # polygon
                rle = mask.frPyObjects(
                    ann["segmentation"], image_info["height"], image_info["width"]
                )
            else:
                rle = ann["segmentation"]
                for i in range(len(rle)):
                    if not isinstance(rle[i]["counts"], bytes):
                        rle[i]["counts"] = rle[i]["counts"].encode()
            m = mask.decode(rle)
            m = np.sum(
                m, axis=2
            )  # sometimes there are multiple binary map (corresponding to multiple segs)
            m = m.astype(np.uint8)  # convert to np.uint8
            masks.append(m)

        masks = np.stack(masks, axis=0)

        # if ds == 'grefcoco' and flag:
        #     import shutil
        #     image_name = image_path.split("/")[-1]
        #     save_dir = os.path.join("/group/30042/xlai/LISA_refactor_final/debug", image_name.split(".")[0])
        #     os.makedirs(save_dir, exist_ok=True)
        #     shutil.copy(image_path, save_dir)
        #     for i in range(masks.shape[0]):
        #         cv2.imwrite(os.path.join(save_dir, "{}_{}_{}.jpg".format(image_name, i, sampled_classes[i])), masks[i].astype(np.int32) * 100)

        masks = torch.from_numpy(masks)
        label = torch.ones(masks.shape[1], masks.shape[2]) * self.ignore_label

        return (
            image_path,
            image,
            image_clip,
            conversations,
            masks,
            label,
            resize,
            questions,
            sampled_classes,
        )
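ReferSegDataset above ties the REFER/G_REFER annotation APIs to CLIP preprocessing and SAM-style resizing. A hedged sketch of how such a class would have been instantiated follows; the tokenizer name, CLIP vision-tower checkpoint, and base_image_dir layout are assumptions for illustration and are not defined by this commit.

# Hypothetical wiring (assumed checkpoints and paths); ReferSegDataset is removed by this commit.
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from lisa_on_cuda.utils.refer_seg_dataset import ReferSegDataset

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # assumed LLM tokenizer
dataset = ReferSegDataset(
    base_image_dir="./dataset",                    # assumed layout: ./dataset/refer_seg/{refcoco,...}
    tokenizer=tokenizer,
    vision_tower="openai/clip-vit-large-patch14",  # assumed CLIP checkpoint
    image_size=1024,
    refer_seg_data="refcoco||refcoco+||refcocog",
)
loader = DataLoader(dataset, batch_size=1, collate_fn=lambda batch: batch)  # samples are plain tuples
image_path, image, image_clip, conversations, masks, label, resize, questions, classes = next(iter(loader))[0]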
lisa_on_cuda/utils/sem_seg_dataset.py
DELETED
@@ -1,335 +0,0 @@
|
|
1 |
-
import glob
|
2 |
-
import json
|
3 |
-
import os
|
4 |
-
import random
|
5 |
-
|
6 |
-
import cv2
|
7 |
-
import numpy as np
|
8 |
-
import torch
|
9 |
-
import torch.nn.functional as F
|
10 |
-
from PIL import Image
|
11 |
-
from pycocotools.coco import COCO
|
12 |
-
from transformers import CLIPImageProcessor
|
13 |
-
|
14 |
-
from lisa_on_cuda.llava import conversation as conversation_lib
|
15 |
-
from lisa_on_cuda.segment_anything.utils.transforms import ResizeLongestSide
|
16 |
-
|
17 |
-
from .utils import ANSWER_LIST, SHORT_QUESTION_LIST
|
18 |
-
|
19 |
-
|
20 |
-
def init_mapillary(base_image_dir):
|
21 |
-
mapillary_data_root = os.path.join(base_image_dir, "mapillary")
|
22 |
-
with open(os.path.join(mapillary_data_root, "config_v2.0.json")) as f:
|
23 |
-
mapillary_classes = json.load(f)["labels"]
|
24 |
-
mapillary_classes = [x["readable"].lower() for x in mapillary_classes]
|
25 |
-
mapillary_classes = np.array(mapillary_classes)
|
26 |
-
mapillary_labels = sorted(
|
27 |
-
glob.glob(
|
28 |
-
os.path.join(mapillary_data_root, "training", "v2.0", "labels", "*.png")
|
29 |
-
)
|
30 |
-
)
|
31 |
-
mapillary_images = [
|
32 |
-
x.replace(".png", ".jpg").replace("v2.0/labels", "images")
|
33 |
-
for x in mapillary_labels
|
34 |
-
]
|
35 |
-
print("mapillary: ", len(mapillary_images))
|
36 |
-
return mapillary_classes, mapillary_images, mapillary_labels
|
37 |
-
|
38 |
-
|
39 |
-
def init_ade20k(base_image_dir):
|
40 |
-
with open("utils/ade20k_classes.json", "r") as f:
|
41 |
-
ade20k_classes = json.load(f)
|
42 |
-
ade20k_classes = np.array(ade20k_classes)
|
43 |
-
image_ids = sorted(
|
44 |
-
os.listdir(os.path.join(base_image_dir, "ade20k/images", "training"))
|
45 |
-
)
|
46 |
-
ade20k_image_ids = []
|
47 |
-
for x in image_ids:
|
48 |
-
if x.endswith(".jpg"):
|
49 |
-
ade20k_image_ids.append(x[:-4])
|
50 |
-
ade20k_images = []
|
51 |
-
for image_id in ade20k_image_ids: # self.descriptions:
|
52 |
-
ade20k_images.append(
|
53 |
-
os.path.join(
|
54 |
-
base_image_dir,
|
55 |
-
"ade20k",
|
56 |
-
"images",
|
57 |
-
"training",
|
58 |
-
"{}.jpg".format(image_id),
|
59 |
-
)
|
60 |
-
)
|
61 |
-
ade20k_labels = [
|
62 |
-
x.replace(".jpg", ".png").replace("images", "annotations")
|
63 |
-
for x in ade20k_images
|
64 |
-
]
|
65 |
-
print("ade20k: ", len(ade20k_images))
|
66 |
-
return ade20k_classes, ade20k_images, ade20k_labels
|
67 |
-
|
68 |
-
|
69 |
-
def init_cocostuff(base_image_dir):
|
70 |
-
cocostuff_classes = []
|
71 |
-
with open("utils/cocostuff_classes.txt") as f:
|
72 |
-
for line in f.readlines()[1:]:
|
73 |
-
cocostuff_classes.append(line.strip().split(": ")[-1])
|
74 |
-
cocostuff_classes = np.array(cocostuff_classes)
|
75 |
-
cocostuff_images = []
|
76 |
-
|
77 |
-
cocostuff_labels = glob.glob(
|
78 |
-
os.path.join(base_image_dir, "cocostuff", "train2017", "*.png")
|
79 |
-
)
|
80 |
-
cocostuff_images = [
|
81 |
-
x.replace(".png", ".jpg").replace("cocostuff", "coco") for x in cocostuff_labels
|
82 |
-
]
|
83 |
-
|
84 |
-
print("cocostuff: ", len(cocostuff_images))
|
85 |
-
return cocostuff_classes, cocostuff_images, cocostuff_labels
|
86 |
-
|
87 |
-
|
88 |
-
def init_paco_lvis(base_image_dir):
|
89 |
-
coco_api_paco_lvis = COCO(
|
90 |
-
os.path.join(
|
91 |
-
base_image_dir, "vlpart", "paco", "annotations", "paco_lvis_v1_train.json"
|
92 |
-
)
|
93 |
-
)
|
94 |
-
all_classes = coco_api_paco_lvis.loadCats(coco_api_paco_lvis.getCatIds())
|
95 |
-
class_map_paco_lvis = {}
|
96 |
-
for cat in all_classes:
|
97 |
-
cat_split = cat["name"].strip().split(":")
|
98 |
-
if len(cat_split) == 1:
|
99 |
-
name = cat_split[0].split("_(")[0]
|
100 |
-
else:
|
101 |
-
assert len(cat_split) == 2
|
102 |
-
obj, part = cat_split
|
103 |
-
obj = obj.split("_(")[0]
|
104 |
-
part = part.split("_(")[0]
|
105 |
-
name = (obj, part)
|
106 |
-
class_map_paco_lvis[cat["id"]] = name
|
107 |
-
img_ids = coco_api_paco_lvis.getImgIds()
|
108 |
-
print("paco_lvis: ", len(img_ids))
|
109 |
-
return class_map_paco_lvis, img_ids, coco_api_paco_lvis
|
110 |
-
|
111 |
-
|
112 |
-
def init_pascal_part(base_image_dir):
|
113 |
-
coco_api_pascal_part = COCO(
|
114 |
-
os.path.join(base_image_dir, "vlpart", "pascal_part", "train.json")
|
115 |
-
)
|
116 |
-
all_classes = coco_api_pascal_part.loadCats(coco_api_pascal_part.getCatIds())
|
117 |
-
class_map_pascal_part = {}
|
118 |
-
for cat in all_classes:
|
119 |
-
cat_main, cat_part = cat["name"].strip().split(":")
|
120 |
-
name = (cat_main, cat_part)
|
121 |
-
class_map_pascal_part[cat["id"]] = name
|
122 |
-
img_ids = coco_api_pascal_part.getImgIds()
|
123 |
-
print("pascal_part: ", len(img_ids))
|
124 |
-
return class_map_pascal_part, img_ids, coco_api_pascal_part
|
125 |
-
|
126 |
-
|
127 |
-
class SemSegDataset(torch.utils.data.Dataset):
|
128 |
-
pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
|
129 |
-
pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
|
130 |
-
img_size = 1024
|
131 |
-
ignore_label = 255
|
132 |
-
|
133 |
-
def __init__(
|
134 |
-
self,
|
135 |
-
base_image_dir,
|
136 |
-
tokenizer,
|
137 |
-
vision_tower,
|
138 |
-
samples_per_epoch=500 * 8 * 2 * 10,
|
139 |
-
precision: str = "fp32",
|
140 |
-
image_size: int = 224,
|
141 |
-
num_classes_per_sample: int = 3,
|
142 |
-
exclude_val=False,
|
143 |
-
sem_seg_data="ade20k||cocostuff||partimagenet||pascal_part||paco_lvis||mapillary",
|
144 |
-
):
|
145 |
-
self.exclude_val = exclude_val
|
146 |
-
self.samples_per_epoch = samples_per_epoch
|
147 |
-
self.num_classes_per_sample = num_classes_per_sample
|
148 |
-
|
149 |
-
self.base_image_dir = base_image_dir
|
150 |
-
self.image_size = image_size
|
151 |
-
self.tokenizer = tokenizer
|
152 |
-
self.precision = precision
|
153 |
-
self.transform = ResizeLongestSide(image_size)
|
154 |
-
self.clip_image_processor = CLIPImageProcessor.from_pretrained(vision_tower)
|
155 |
-
|
156 |
-
self.short_question_list = SHORT_QUESTION_LIST
|
157 |
-
self.answer_list = ANSWER_LIST
|
158 |
-
|
159 |
-
self.data2list = {}
|
160 |
-
self.data2classes = {}
|
161 |
-
|
162 |
-
self.sem_seg_datas = sem_seg_data.split("||")
|
163 |
-
for ds in self.sem_seg_datas:
|
164 |
-
classes, images, labels = eval("init_{}".format(ds))(base_image_dir)
|
165 |
-
self.data2list[ds] = (images, labels)
|
166 |
-
            self.data2classes[ds] = classes

        if "cocostuff" in self.sem_seg_datas:
            self.cocostuff_class2index = {
                c: i for i, c in enumerate(self.data2classes["cocostuff"])
            }

    def __len__(self):
        return self.samples_per_epoch

    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize pixel values and pad to a square input."""
        # Normalize colors
        x = (x - self.pixel_mean) / self.pixel_std

        # Pad
        h, w = x.shape[-2:]
        padh = self.img_size - h
        padw = self.img_size - w
        x = F.pad(x, (0, padw, 0, padh))
        return x

    def __getitem__(self, idx):
        ds = random.randint(0, len(self.sem_seg_datas) - 1)
        ds = self.sem_seg_datas[ds]

        if ds in ["paco_lvis", "pascal_part"]:
            class_map = self.data2classes[ds]
            img_ids, coco_api = self.data2list[ds]
            idx = random.randint(0, len(img_ids) - 1)
            img_id = img_ids[idx]
            image_info = coco_api.loadImgs([img_id])[0]
            file_name = image_info["file_name"]
            if ds == "pascal_part":
                file_name = os.path.join(
                    "VOCdevkit", "VOC2010", "JPEGImages", file_name
                )
                image_path = os.path.join(self.base_image_dir, "vlpart", ds, file_name)
            elif ds == "paco_lvis":
                image_path = os.path.join(self.base_image_dir, "coco", file_name)
            image = cv2.imread(image_path)
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            # preprocess image for clip
            image_clip = self.clip_image_processor.preprocess(
                image, return_tensors="pt"
            )["pixel_values"][0]
            image = self.transform.apply_image(image)  # preprocess image for sam
            resize = image.shape[:2]
            annIds = coco_api.getAnnIds(imgIds=image_info["id"])
            anns = coco_api.loadAnns(annIds)
            if len(anns) == 0:
                return self.__getitem__(0)
            if len(anns) >= self.num_classes_per_sample:
                sampled_anns = np.random.choice(
                    anns, size=self.num_classes_per_sample, replace=False
                ).tolist()
            else:
                sampled_anns = anns
            sampled_classes = []
            for ann in sampled_anns:
                sampled_cls = class_map[ann["category_id"]]
                if isinstance(sampled_cls, tuple):
                    obj, part = sampled_cls
                    if random.random() < 0.5:
                        name = obj + " " + part
                    else:
                        name = "the {} of the {}".format(part, obj)
                else:
                    name = sampled_cls
                sampled_classes.append(name)

        elif ds in ["ade20k", "cocostuff", "mapillary"]:
            image, labels = self.data2list[ds]
            idx = random.randint(0, len(image) - 1)
            image_path = image[idx]
            label_path = labels[idx]
            label = Image.open(label_path)
            label = np.array(label)
            if ds == "ade20k":
                label[label == 0] = 255
                label -= 1
                label[label == 254] = 255
            elif ds == "cocostuff":
                for c, i in self.cocostuff_class2index.items():
                    if "-" in c:
                        label[label == i] = 255
            img = cv2.imread(image_path)
            image = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            # preprocess image for clip
            image_clip = self.clip_image_processor.preprocess(
                image, return_tensors="pt"
            )["pixel_values"][0]
            image = self.transform.apply_image(image)  # preprocess image for sam
            resize = image.shape[:2]
            unique_label = np.unique(label).tolist()
            if 255 in unique_label:
                unique_label.remove(255)
            if len(unique_label) == 0:
                return self.__getitem__(0)

            classes = [self.data2classes[ds][class_id] for class_id in unique_label]
            if len(classes) >= self.num_classes_per_sample:
                sampled_classes = np.random.choice(
                    classes, size=self.num_classes_per_sample, replace=False
                ).tolist()
            else:
                sampled_classes = classes

        questions = []
        answers = []
        class_ids = []
        for sampled_cls in sampled_classes:
            text = sampled_cls

            assert len(text.split("||")) == 1
            question_template = random.choice(self.short_question_list)
            questions.append(question_template.format(class_name=text.lower()))

            answers.append(random.choice(self.answer_list))

            if ds in ["paco_lvis", "pascal_part"]:
                continue

            class_id = self.data2classes[ds].tolist().index(sampled_cls)
            class_ids.append(class_id)

        conversations = []
        conv = conversation_lib.default_conversation.copy()

        i = 0
        while i < len(questions):
            conv.messages = []
            conv.append_message(conv.roles[0], questions[i])
            conv.append_message(conv.roles[1], answers[i])
            conversations.append(conv.get_prompt())
            i += 1

        image = self.preprocess(torch.from_numpy(image).permute(2, 0, 1).contiguous())

        if ds in ["paco_lvis", "pascal_part"]:
            masks = []
            for ann in sampled_anns:
                try:
                    masks.append(coco_api.annToMask(ann))
                except Exception as e:
                    print(e)
                    return self.__getitem__(0)

            masks = np.stack(masks, axis=0)
            masks = torch.from_numpy(masks)
            label = torch.ones(masks.shape[1], masks.shape[2]) * self.ignore_label

        else:
            label = torch.from_numpy(label).long()
            masks = []
            for class_id in class_ids:
                masks.append(label == class_id)
            masks = torch.stack(masks, dim=0)
        return (
            image_path,
            image,
            image_clip,
            conversations,
            masks,
            label,
            resize,
            questions,
            sampled_classes,
        )
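The ADE20K branch above remaps raw label ids before masks are built: id 0 ("unlabeled") is first sent to 255, the in-place subtraction shifts every remaining id down by one, and the 254 produced from the former zeros is pushed back to the ignore value 255. A minimal standalone sketch of that remapping, using dummy label values rather than a real ADE20K annotation:

import numpy as np

label = np.array([[0, 1, 2], [150, 149, 0]], dtype=np.uint8)  # raw ids, 0 = unlabeled
label[label == 0] = 255    # mark "unlabeled" with the ignore value
label -= 1                 # shift class ids 1..150 down to 0..149 (255 becomes 254)
label[label == 254] = 255  # restore the ignore value for the former zeros
print(label)
# [[255   0   1]
#  [149 148 255]]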
lisa_on_cuda/utils/session_logger.py
DELETED
@@ -1,36 +0,0 @@
import contextvars
import logging
from functools import wraps
from typing import Callable

logging_uuid = contextvars.ContextVar("uuid")
formatter = '%(asctime)s | %(uuid)s [%(pathname)s:%(module)s %(lineno)d] %(levelname)s | %(message)s'


loggingType = logging.CRITICAL | logging.ERROR | logging.WARNING | logging.INFO | logging.DEBUG


def change_logging(level_log: loggingType = logging.INFO) -> None:
    old_factory = logging.getLogRecordFactory()

    def record_factory(*args, **kwargs):
        record = old_factory(*args, **kwargs)
        record.uuid = logging_uuid.get("uuid")
        if isinstance(record.msg, str):
            record.msg = record.msg.replace("\\", "\\\\").replace("\n", "\\n")
        return record

    logging.setLogRecordFactory(record_factory)
    logging.basicConfig(level=level_log, format=formatter, force=True)


def set_uuid_logging(func: Callable) -> Callable:
    @wraps(func)
    def wrapper(*args, **kwargs):
        import uuid

        current_uuid = f"{uuid.uuid4()}"
        logging_uuid.set(current_uuid)
        return func(*args, **kwargs)

    return wrapper
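The removed session_logger attaches a per-request UUID to every log record through a contextvars.ContextVar and a custom LogRecord factory, so the %(uuid)s field of its formatter string resolves. A minimal standalone sketch of the same pattern (shortened formatter and the handle_request name are assumptions, not the removed API):

import contextvars
import logging
import uuid

request_uuid = contextvars.ContextVar("uuid", default="n/a")
_old_factory = logging.getLogRecordFactory()

def _record_factory(*args, **kwargs):
    # Inject the current context's uuid into every record, as change_logging() did.
    record = _old_factory(*args, **kwargs)
    record.uuid = request_uuid.get()
    return record

logging.setLogRecordFactory(_record_factory)
logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(uuid)s | %(message)s", force=True)

def handle_request(prompt: str) -> None:
    # set_uuid_logging() did this via a decorator; here it is inlined for brevity.
    request_uuid.set(str(uuid.uuid4()))
    logging.info("processing prompt: %s", prompt)

handle_request("segment the dog")  # each call logs under a different uuid
handle_request("segment the cat")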
lisa_on_cuda/utils/vqa_dataset.py
DELETED
@@ -1,135 +0,0 @@
import json
import os
import random

import cv2
import torch
import torch.nn.functional as F
from transformers import CLIPImageProcessor

from lisa_on_cuda.llava import conversation as conversation_lib
from lisa_on_cuda.segment_anything.utils.transforms import ResizeLongestSide

from .utils import DEFAULT_IMAGE_TOKEN


def preprocess_multimodal(source, mm_use_im_start_end):
    for sentence in source:
        if DEFAULT_IMAGE_TOKEN in sentence["value"]:
            sentence["value"] = (
                sentence["value"].replace(DEFAULT_IMAGE_TOKEN, "").strip()
            )
            sentence["value"] = DEFAULT_IMAGE_TOKEN + "\n" + sentence["value"]
            sentence["value"] = sentence["value"].strip()
            if "mmtag" in conversation_lib.default_conversation.version:
                sentence["value"] = sentence["value"].replace(
                    DEFAULT_IMAGE_TOKEN, "<Image>" + DEFAULT_IMAGE_TOKEN + "</Image>"
                )
    return source


class VQADataset(torch.utils.data.Dataset):
    pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
    pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
    img_size = 1024
    ignore_label = 255

    def __init__(
        self,
        base_image_dir,
        tokenizer,
        vision_tower,
        samples_per_epoch=500 * 8 * 2 * 10,
        precision: str = "fp32",
        image_size: int = 224,
        num_classes_per_sample: int = 3,
        exclude_val=False,
        vqa_data="llava_instruct_150k",
    ):
        self.exclude_val = exclude_val
        self.samples_per_epoch = samples_per_epoch
        self.num_classes_per_sample = num_classes_per_sample

        self.base_image_dir = base_image_dir
        self.image_size = image_size
        self.tokenizer = tokenizer
        self.precision = precision
        self.transform = ResizeLongestSide(image_size)
        self.clip_image_processor = CLIPImageProcessor.from_pretrained(vision_tower)

        DATA_DIR = os.path.join(base_image_dir, "llava_dataset")
        self.vqa_image_root = os.path.join(base_image_dir, "coco/train2017")
        with open(os.path.join(DATA_DIR, "{}.json".format(vqa_data))) as f:
            vqa_data = json.load(f)
        self.vqa_data = vqa_data

        print("vqa_data: ", len(self.vqa_data))

    def __len__(self):
        return self.samples_per_epoch

    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize pixel values and pad to a square input."""
        # Normalize colors
        x = (x - self.pixel_mean) / self.pixel_std

        # Pad
        h, w = x.shape[-2:]
        padh = self.img_size - h
        padw = self.img_size - w
        x = F.pad(x, (0, padw, 0, padh))
        return x

    def __getitem__(self, idx):
        idx = random.randint(0, len(self.vqa_data) - 1)
        item = self.vqa_data[idx]
        image_path = os.path.join(self.vqa_image_root, item["image"])
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        ori_size = image.shape[:2]
        image_clip = self.clip_image_processor.preprocess(image, return_tensors="pt")[
            "pixel_values"
        ][0]  # preprocess image for clip

        image = self.transform.apply_image(image)  # preprocess image for sam
        resize = image.shape[:2]

        conv = conversation_lib.default_conversation.copy()
        source = item["conversations"]
        source = preprocess_multimodal(
            source,
            mm_use_im_start_end=conv.sep_style == conversation_lib.SeparatorStyle.TWO,
        )
        roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
        conversations = []
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]
        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{j}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

        questions = conversations
        sampled_classes = conversations

        image = self.preprocess(torch.from_numpy(image).permute(2, 0, 1).contiguous())

        masks = torch.rand(0, *ori_size)
        label = torch.ones(ori_size) * self.ignore_label

        return (
            image_path,
            image,
            image_clip,
            conversations,
            masks,
            label,
            resize,
            questions,
            sampled_classes,
        )
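preprocess_multimodal above normalizes LLaVA-style conversations so that the image placeholder always sits at the start of the turn on its own line. A standalone sketch of that normalization (DEFAULT_IMAGE_TOKEN is assumed to be "<image>" as in the LLaVA convention, and the "mmtag" branch is omitted):

DEFAULT_IMAGE_TOKEN = "<image>"  # assumed value of the imported constant

def move_image_token_to_front(source):
    # Strip the placeholder wherever it appears, then re-insert it at the start of the turn.
    for sentence in source:
        if DEFAULT_IMAGE_TOKEN in sentence["value"]:
            stripped = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, "").strip()
            sentence["value"] = (DEFAULT_IMAGE_TOKEN + "\n" + stripped).strip()
    return source

turns = [
    {"from": "human", "value": "What is in this picture? <image>"},
    {"from": "gpt", "value": "A dog on a couch."},
]
print(move_image_token_to_front(turns)[0]["value"])
# <image>
# What is in this picture?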