alessandro trinca tornidor commited on
Commit
8036f02
·
1 Parent(s): 97b0909

refactor: remove unused modules

Browse files
lisa_on_cuda/utils/ade20k_classes.json DELETED
@@ -1,30 +0,0 @@
1
- [
2
- "wall", "building", "sky", "floor", "tree", "ceiling", "road",
3
- "bed", "windowpane", "grass", "cabinet", "sidewalk",
4
- "person", "earth", "door", "table", "mountain", "plant",
5
- "curtain", "chair", "car", "water", "painting", "sofa",
6
- "shelf", "house", "sea", "mirror", "rug", "field", "armchair",
7
- "seat", "fence", "desk", "rock", "wardrobe", "lamp",
8
- "bathtub", "railing", "cushion", "base", "box", "column",
9
- "signboard", "chest of drawers", "counter", "sand", "sink",
10
- "skyscraper", "fireplace", "refrigerator", "grandstand",
11
- "path", "stairs", "runway", "case", "pool table", "pillow",
12
- "screen door", "stairway", "river", "bridge", "bookcase",
13
- "blind", "coffee table", "toilet", "flower", "book", "hill",
14
- "bench", "countertop", "stove", "palm", "kitchen island",
15
- "computer", "swivel chair", "boat", "bar", "arcade machine",
16
- "hovel", "bus", "towel", "light", "truck", "tower",
17
- "chandelier", "awning", "streetlight", "booth",
18
- "television receiver", "airplane", "dirt track", "apparel",
19
- "pole", "land", "bannister", "escalator", "ottoman", "bottle",
20
- "buffet", "poster", "stage", "van", "ship", "fountain",
21
- "conveyer belt", "canopy", "washer", "plaything",
22
- "swimming pool", "stool", "barrel", "basket", "waterfall",
23
- "tent", "bag", "minibike", "cradle", "oven", "ball", "food",
24
- "step", "tank", "trade name", "microwave", "pot", "animal",
25
- "bicycle", "lake", "dishwasher", "screen", "blanket",
26
- "sculpture", "hood", "sconce", "vase", "traffic light",
27
- "tray", "ashcan", "fan", "pier", "crt screen", "plate",
28
- "monitor", "bulletin board", "shower", "radiator", "glass",
29
- "clock", "flag"
30
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lisa_on_cuda/utils/conversation.py DELETED
@@ -1,308 +0,0 @@
1
- """
2
- Conversation prompt templates.
3
- """
4
-
5
- import dataclasses
6
- from enum import Enum, auto
7
- from typing import Any, List
8
-
9
-
10
- class SeparatorStyle(Enum):
11
- """Different separator style."""
12
-
13
- ADD_COLON_SINGLE = auto()
14
- ADD_COLON_TWO = auto()
15
- NO_COLON_SINGLE = auto()
16
- BAIZE = auto()
17
- DOLLY = auto()
18
- RWKV = auto()
19
-
20
-
21
- @dataclasses.dataclass
22
- class Conversation:
23
- """A class that keeps all conversation history."""
24
-
25
- # System prompts
26
- system: str
27
- # Two roles
28
- roles: List[str]
29
- # All messages
30
- messages: List[List[str]]
31
- # Offset of few shot examples
32
- offset: int
33
- # Separator
34
- sep_style: SeparatorStyle
35
- sep: str
36
- sep2: str = None
37
- # Stop criteria (the default one is EOS token)
38
- stop_str: str = None
39
- # Stops generation if meeting any token in this list
40
- stop_token_ids: List[int] = None
41
-
42
- # Used for the state in the gradio servers.
43
- # TODO(lmzheng): refactor this
44
- conv_id: Any = None
45
- skip_next: bool = False
46
- model_name: str = None
47
-
48
- def get_prompt(self):
49
- if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
50
- ret = self.system + self.sep
51
- for role, message in self.messages:
52
- if message:
53
- ret += role + ": " + message + self.sep
54
- else:
55
- ret += role + ":"
56
- return ret
57
- elif self.sep_style == SeparatorStyle.ADD_COLON_TWO:
58
- seps = [self.sep, self.sep2]
59
- ret = self.system + seps[0]
60
- for i, (role, message) in enumerate(self.messages):
61
- if message:
62
- ret += role + ": " + message + seps[i % 2]
63
- else:
64
- ret += role + ":"
65
- return ret
66
- elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
67
- ret = self.system
68
- for role, message in self.messages:
69
- if message:
70
- ret += role + message + self.sep
71
- else:
72
- ret += role
73
- return ret
74
- elif self.sep_style == SeparatorStyle.BAIZE:
75
- ret = self.system + "\n"
76
- for role, message in self.messages:
77
- if message:
78
- ret += role + message + "\n"
79
- else:
80
- ret += role
81
- return ret
82
- elif self.sep_style == SeparatorStyle.DOLLY:
83
- seps = [self.sep, self.sep2]
84
- ret = self.system
85
- for i, (role, message) in enumerate(self.messages):
86
- if message:
87
- ret += role + ":\n" + message + seps[i % 2]
88
- if i % 2 == 1:
89
- ret += "\n\n"
90
- else:
91
- ret += role + ":\n"
92
- return ret
93
- elif self.sep_style == SeparatorStyle.RWKV:
94
- ret = self.system
95
- for i, (role, message) in enumerate(self.messages):
96
- if message:
97
- ret += (
98
- role
99
- + ": "
100
- + message.replace("\r\n", "\n").replace("\n\n", "\n")
101
- )
102
- ret += "\n\n"
103
- else:
104
- ret += role + ":"
105
- return ret
106
- else:
107
- raise ValueError(f"Invalid style: {self.sep_style}")
108
-
109
- def append_message(self, role, message):
110
- self.messages.append([role, message])
111
-
112
- def to_gradio_chatbot(self):
113
- ret = []
114
- for i, (role, msg) in enumerate(self.messages[self.offset :]):
115
- if i % 2 == 0:
116
- ret.append([msg, None])
117
- else:
118
- ret[-1][-1] = msg
119
- return ret
120
-
121
- def copy(self):
122
- return Conversation(
123
- system=self.system,
124
- roles=self.roles,
125
- messages=[[x, y] for x, y in self.messages],
126
- offset=self.offset,
127
- sep_style=self.sep_style,
128
- sep=self.sep,
129
- sep2=self.sep2,
130
- stop_str=self.stop_str,
131
- stop_token_ids=self.stop_token_ids,
132
- conv_id=self.conv_id,
133
- model_name=self.model_name,
134
- )
135
-
136
- def dict(self):
137
- return {
138
- "system": self.system,
139
- "roles": self.roles,
140
- "messages": self.messages,
141
- "offset": self.offset,
142
- "conv_id": self.conv_id,
143
- "model_name": self.model_name,
144
- }
145
-
146
-
147
- # A template with one conversation example
148
- conv_one_shot = Conversation(
149
- system="A chat between a curious human and an artificial intelligence assistant. "
150
- "The assistant gives helpful, detailed, and polite answers to the human's questions.",
151
- roles=("Human", "Assistant"),
152
- messages=(
153
- (
154
- "Human",
155
- "What are the key differences between renewable and non-renewable energy sources?",
156
- ),
157
- (
158
- "Assistant",
159
- "Renewable energy sources are those that can be replenished naturally in a relatively "
160
- "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
161
- "Non-renewable energy sources, on the other hand, are finite and will eventually be "
162
- "depleted, such as coal, oil, and natural gas. Here are some key differences between "
163
- "renewable and non-renewable energy sources:\n"
164
- "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
165
- "energy sources are finite and will eventually run out.\n"
166
- "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
167
- "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
168
- "and other negative effects.\n"
169
- "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
170
- "have lower operational costs than non-renewable sources.\n"
171
- "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
172
- "locations than non-renewable sources.\n"
173
- "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
174
- "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
175
- "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
176
- "non-renewable sources are not, and their depletion can lead to economic and social instability.",
177
- ),
178
- ),
179
- offset=2,
180
- sep_style=SeparatorStyle.ADD_COLON_SINGLE,
181
- sep="\n### ",
182
- stop_str="###",
183
- )
184
-
185
-
186
- # Vicuna v1.1 template
187
- conv_vicuna_v1_1 = Conversation(
188
- system="A chat between a curious user and an artificial intelligence assistant. "
189
- "The assistant gives helpful, detailed, and polite answers to the user's questions.",
190
- roles=("USER", "ASSISTANT"),
191
- messages=(),
192
- offset=0,
193
- sep_style=SeparatorStyle.ADD_COLON_TWO,
194
- sep=" ",
195
- sep2="</s>",
196
- )
197
-
198
- # Koala default template
199
- conv_koala_v1 = Conversation(
200
- system="BEGINNING OF CONVERSATION:",
201
- roles=("USER", "GPT"),
202
- messages=(),
203
- offset=0,
204
- sep_style=SeparatorStyle.ADD_COLON_TWO,
205
- sep=" ",
206
- sep2="</s>",
207
- )
208
-
209
- # Dolly V2 default template
210
- conv_dolly = Conversation(
211
- system="Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n",
212
- roles=("### Instruction", "### Response"),
213
- messages=(),
214
- offset=0,
215
- sep_style=SeparatorStyle.DOLLY,
216
- sep="\n\n",
217
- sep2="### End",
218
- )
219
-
220
- # OpenAssistant Pythia default template
221
- conv_oasst = Conversation(
222
- system="",
223
- roles=("<|prompter|>", "<|assistant|>"),
224
- messages=(),
225
- offset=0,
226
- sep_style=SeparatorStyle.NO_COLON_SINGLE,
227
- sep="<|endoftext|>",
228
- )
229
-
230
- # StableLM Alpha default template
231
- conv_stablelm = Conversation(
232
- system="""<|SYSTEM|># StableLM Tuned (Alpha version)
233
- - StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
234
- - StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
235
- - StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
236
- - StableLM will refuse to participate in anything that could harm a human.
237
- """,
238
- roles=("<|USER|>", "<|ASSISTANT|>"),
239
- messages=(),
240
- offset=0,
241
- sep_style=SeparatorStyle.NO_COLON_SINGLE,
242
- sep="",
243
- stop_token_ids=[50278, 50279, 50277, 1, 0],
244
- )
245
-
246
- # Baize default template
247
- conv_baize = Conversation(
248
- system="The following is a conversation between a human and an AI assistant named Baize (named after a mythical creature in Chinese folklore). Baize is an open-source AI assistant developed by UCSD and Sun Yat-Sen University. The human and the AI assistant take turns chatting. Human statements start with [|Human|] and AI assistant statements start with [|AI|]. The AI assistant always provides responses in as much detail as possible, and in Markdown format. The AI assistant always declines to engage with topics, questions and instructions related to unethical, controversial, or sensitive issues. Complete the transcript in exactly that format.",
249
- roles=("[|Human|]", "[|AI|]"),
250
- messages=(
251
- ("[|Human|]", "Hello!"),
252
- ("[|AI|]", "Hi!"),
253
- ),
254
- offset=2,
255
- sep_style=SeparatorStyle.BAIZE,
256
- sep="[|Human|]",
257
- stop_str="[|Human|]",
258
- )
259
-
260
- # RWKV-4-Raven default template
261
- conv_rwkv = Conversation(
262
- system="",
263
- roles=("Bob", "Alice"),
264
- messages=(),
265
- offset=0,
266
- sep_style=SeparatorStyle.RWKV,
267
- sep="",
268
- stop_str="\n\n",
269
- )
270
-
271
- conv_templates = {
272
- "baize": conv_baize,
273
- "conv_one_shot": conv_one_shot,
274
- "dolly": conv_dolly,
275
- "koala_v1": conv_koala_v1,
276
- "oasst": conv_oasst,
277
- "stablelm": conv_stablelm,
278
- "vicuna_v1.1": conv_vicuna_v1_1,
279
- "rwkv": conv_rwkv,
280
- }
281
-
282
-
283
- def get_default_conv_template(model_name):
284
- model_name = model_name.lower()
285
- if "vicuna" in model_name or "output" in model_name:
286
- return conv_vicuna_v1_1
287
- elif "koala" in model_name:
288
- return conv_koala_v1
289
- elif "dolly-v2" in model_name:
290
- return conv_dolly
291
- elif "oasst" in model_name and "pythia" in model_name:
292
- return conv_oasst
293
- elif "baize" in model_name:
294
- return conv_baize
295
- elif "stablelm" in model_name:
296
- return conv_stablelm
297
- elif "rwkv-4" in model_name:
298
- return conv_rwkv
299
- return conv_one_shot
300
-
301
-
302
- if __name__ == "__main__":
303
- conv = conv_templates["vicuna_v1.1"].copy()
304
- conv.append_message(conv.roles[0], "Hello!")
305
- conv.append_message(conv.roles[1], "Hi!")
306
- conv.append_message(conv.roles[0], "How are you?")
307
- conv.append_message(conv.roles[1], None)
308
- print(conv.get_prompt())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lisa_on_cuda/utils/create_folders_and_variables_if_not_exists.py DELETED
@@ -1,56 +0,0 @@
1
- import json
2
- import logging
3
- import os
4
- from pathlib import Path
5
-
6
-
7
- def stats_pathname(pathname: Path | str):
8
- current_pathname = Path(pathname)
9
- return current_pathname.is_dir()
10
-
11
-
12
- def create_folder_if_not_exists(pathname: Path | str):
13
- current_pathname = Path(pathname)
14
- try:
15
- print(f"Pathname exists? {current_pathname.exists()}, That's a folder? {current_pathname.is_dir()}...")
16
- logging.info(f"Pathname exists? {current_pathname.exists()}, That's a folder? {current_pathname.is_dir()}...")
17
- current_pathname.unlink(missing_ok=True)
18
- except PermissionError as pe:
19
- print(f"permission denied on removing pathname before folder creation:{pe}.")
20
- logging.error(f"permission denied on removing pathname before folder creation:{pe}.")
21
- except IsADirectoryError as errdir:
22
- print(f"that's a directory:{errdir}.")
23
- logging.error(f"that's a directory:{errdir}.")
24
-
25
- print(f"Creating pathname: {current_pathname} ...")
26
- logging.info(f"Creating pathname: {current_pathname} ...")
27
- current_pathname.mkdir(mode=0o770, parents=True, exist_ok=True)
28
-
29
- print(f"assertion: pathname exists and is a folder: {current_pathname} ...")
30
- logging.info(f"assertion: pathname exists and is a folder: {current_pathname} ...")
31
- assert current_pathname.is_dir()
32
-
33
-
34
- def folders_creation():
35
- folders_string = os.getenv("FOLDERS_MAP")
36
- try:
37
- folders_dict = json.loads(folders_string)
38
- for folder_env_ref, folder_env_path in folders_dict.items():
39
- print(f"folder_env_ref:{folder_env_ref}, folder_env_path:{folder_env_path}.")
40
- logging.info(f"folder_env_ref:{folder_env_ref}, folder_env_path:{folder_env_path}.")
41
- create_folder_if_not_exists(folder_env_path)
42
- print("========")
43
- assert os.getenv(folder_env_ref) == folder_env_path
44
- except (json.JSONDecodeError, TypeError) as jde:
45
- print(f"jde:{jde}.")
46
- logging.error(f"jde:{jde}.")
47
- print("double check your variables, e.g. for misspelling like 'FOLDER_MAP'...")
48
- logging.info("double check your variables, e.g. for misspelling like 'FOLDER_MAP' instead than 'FOLDERS_MAP'...")
49
- for k_env, v_env in dict(os.environ).items():
50
- print(f"{k_env}, v_env:{v_env}.")
51
- logging.info(f"{k_env}, v_env:{v_env}.")
52
-
53
-
54
- if __name__ == '__main__':
55
- folders_creation()
56
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lisa_on_cuda/utils/data_processing.py DELETED
@@ -1,90 +0,0 @@
1
- import glob
2
- import json
3
- import os
4
-
5
- import cv2
6
- import numpy as np
7
-
8
-
9
- def get_mask_from_json(json_path, img):
10
- try:
11
- with open(json_path, "r") as r:
12
- anno = json.loads(r.read())
13
- except:
14
- with open(json_path, "r", encoding="cp1252") as r:
15
- anno = json.loads(r.read())
16
-
17
- inform = anno["shapes"]
18
- comments = anno["text"]
19
- is_sentence = anno["is_sentence"]
20
-
21
- height, width = img.shape[:2]
22
-
23
- ### sort polies by area
24
- area_list = []
25
- valid_poly_list = []
26
- for i in inform:
27
- label_id = i["label"]
28
- points = i["points"]
29
- if "flag" == label_id.lower(): ## meaningless deprecated annotations
30
- continue
31
-
32
- tmp_mask = np.zeros((height, width), dtype=np.uint8)
33
- cv2.polylines(tmp_mask, np.array([points], dtype=np.int32), True, 1, 1)
34
- cv2.fillPoly(tmp_mask, np.array([points], dtype=np.int32), 1)
35
- tmp_area = tmp_mask.sum()
36
-
37
- area_list.append(tmp_area)
38
- valid_poly_list.append(i)
39
-
40
- ### ground-truth mask
41
- sort_index = np.argsort(area_list)[::-1].astype(np.int32)
42
- sort_index = list(sort_index)
43
- sort_inform = []
44
- for s_idx in sort_index:
45
- sort_inform.append(valid_poly_list[s_idx])
46
-
47
- mask = np.zeros((height, width), dtype=np.uint8)
48
- for i in sort_inform:
49
- label_id = i["label"]
50
- points = i["points"]
51
-
52
- if "ignore" in label_id.lower():
53
- label_value = 255 # ignored during evaluation
54
- else:
55
- label_value = 1 # target
56
-
57
- cv2.polylines(mask, np.array([points], dtype=np.int32), True, label_value, 1)
58
- cv2.fillPoly(mask, np.array([points], dtype=np.int32), label_value)
59
-
60
- return mask, comments, is_sentence
61
-
62
-
63
- if __name__ == "__main__":
64
- data_dir = "./train"
65
- vis_dir = "./vis"
66
-
67
- if not os.path.exists(vis_dir):
68
- os.makedirs(vis_dir)
69
-
70
- json_path_list = sorted(glob.glob(data_dir + "/*.json"))
71
- for json_path in json_path_list:
72
- img_path = json_path.replace(".json", ".jpg")
73
- img = cv2.imread(img_path)[:, :, ::-1]
74
-
75
- # In generated mask, value 1 denotes valid target region, and value 255 stands for region ignored during evaluaiton.
76
- mask, comments, is_sentence = get_mask_from_json(json_path, img)
77
-
78
- ## visualization. Green for target, and red for ignore.
79
- valid_mask = (mask == 1).astype(np.float32)[:, :, None]
80
- ignore_mask = (mask == 255).astype(np.float32)[:, :, None]
81
- vis_img = img * (1 - valid_mask) * (1 - ignore_mask) + (
82
- (np.array([0, 255, 0]) * 0.6 + img * 0.4) * valid_mask
83
- + (np.array([255, 0, 0]) * 0.6 + img * 0.4) * ignore_mask
84
- )
85
- vis_img = np.concatenate([img, vis_img], 1)
86
- vis_path = os.path.join(
87
- vis_dir, json_path.split("/")[-1].replace(".json", ".jpg")
88
- )
89
- cv2.imwrite(vis_path, vis_img[:, :, ::-1])
90
- print("Visualization has been saved to: ", vis_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lisa_on_cuda/utils/dataset.py DELETED
@@ -1,466 +0,0 @@
1
- import glob
2
- import os
3
- import random
4
-
5
- import cv2
6
- import numpy as np
7
- import torch
8
- import torch.nn.functional as F
9
- from pycocotools import mask
10
- from transformers import CLIPImageProcessor
11
-
12
- from lisa_on_cuda.llava import conversation as conversation_lib
13
- from lisa_on_cuda.llava.constants import (DEFAULT_IMAGE_TOKEN, IGNORE_INDEX,
14
- IMAGE_TOKEN_INDEX)
15
- from lisa_on_cuda.llava.mm_utils import tokenizer_image_token
16
- from lisa_on_cuda.segment_anything.utils.transforms import ResizeLongestSide
17
-
18
- from .conversation import get_default_conv_template
19
- from .data_processing import get_mask_from_json
20
- from .reason_seg_dataset import ReasonSegDataset
21
- from .refer import REFER
22
- from .refer_seg_dataset import ReferSegDataset
23
- from .sem_seg_dataset import SemSegDataset
24
- from .utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
25
- DEFAULT_IMAGE_TOKEN)
26
- from .vqa_dataset import VQADataset
27
-
28
-
29
- def collate_fn(
30
- batch, tokenizer=None, conv_type="llava_v1", use_mm_start_end=True, local_rank=-1
31
- ):
32
- image_path_list = []
33
- images_list = []
34
- images_clip_list = []
35
- conversation_list = []
36
- masks_list = []
37
- label_list = []
38
- resize_list = []
39
- questions_list = []
40
- sampled_classes_list = []
41
- offset_list = [0]
42
- cnt = 0
43
- inferences = []
44
- for (
45
- image_path,
46
- images,
47
- images_clip,
48
- conversations,
49
- masks,
50
- label,
51
- resize,
52
- questions,
53
- sampled_classes,
54
- inference,
55
- ) in batch:
56
- image_path_list.append(image_path)
57
- images_list.append(images)
58
- images_clip_list.append(images_clip)
59
- conversation_list.extend(conversations)
60
- label_list.append(label)
61
- masks_list.append(masks.float())
62
- resize_list.append(resize)
63
- questions_list.append(questions)
64
- sampled_classes_list.append(sampled_classes)
65
- cnt += len(conversations)
66
- offset_list.append(cnt)
67
- inferences.append(inference)
68
-
69
- if use_mm_start_end:
70
- # replace <image> token
71
- for i in range(len(conversation_list)):
72
- replace_token = DEFAULT_IMAGE_TOKEN
73
- replace_token = (
74
- DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
75
- )
76
- conversation_list[i] = conversation_list[i].replace(
77
- DEFAULT_IMAGE_TOKEN, replace_token
78
- )
79
- input_ids = [
80
- tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
81
- for prompt in conversation_list
82
- ]
83
- input_ids = torch.nn.utils.rnn.pad_sequence(
84
- input_ids, batch_first=True, padding_value=tokenizer.pad_token_id
85
- )
86
- attention_masks = input_ids.ne(tokenizer.pad_token_id)
87
-
88
- conv = conversation_lib.default_conversation.copy()
89
- targets = input_ids.clone()
90
-
91
- if conv_type == "llava_v1":
92
- sep = conv.sep + conv.roles[1] + ": "
93
- else:
94
- sep = "[/INST] "
95
- for conversation, target in zip(conversation_list, targets):
96
- total_len = int(target.ne(tokenizer.pad_token_id).sum())
97
-
98
- rounds = conversation.split(conv.sep2)
99
- cur_len = 1
100
- target[:cur_len] = IGNORE_INDEX
101
- for i, rou in enumerate(rounds):
102
- if rou == "":
103
- break
104
-
105
- parts = rou.split(sep)
106
- # if len(parts) != 2:
107
- # break
108
- assert len(parts) == 2, (len(parts), rou)
109
- parts[0] += sep
110
-
111
- if DEFAULT_IMAGE_TOKEN in conversation:
112
- round_len = len(tokenizer_image_token(rou, tokenizer))
113
- instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
114
- else:
115
- round_len = len(tokenizer(rou).input_ids)
116
- instruction_len = len(tokenizer(parts[0]).input_ids) - 2
117
-
118
- target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
119
-
120
- cur_len += round_len
121
- target[cur_len:] = IGNORE_INDEX
122
-
123
- if False:
124
- z = target.clone()
125
- z = torch.where(z == IGNORE_INDEX, tokenizer.unk_token_id, z)
126
- if local_rank == 0:
127
- print(
128
- "conversation: ",
129
- conversation,
130
- "tokenizer.decode(z): ",
131
- tokenizer.decode(z),
132
- )
133
-
134
- if cur_len < tokenizer.model_max_length:
135
- assert cur_len == total_len
136
-
137
- if inferences[0] == False:
138
- truncate_len = tokenizer.model_max_length - 255
139
-
140
- if input_ids.shape[1] > truncate_len:
141
- input_ids = input_ids[:, :truncate_len]
142
- targets = targets[:, :truncate_len]
143
- attention_masks = attention_masks[:, :truncate_len]
144
-
145
- return {
146
- "image_paths": image_path_list,
147
- "images": torch.stack(images_list, dim=0),
148
- "images_clip": torch.stack(images_clip_list, dim=0),
149
- "input_ids": input_ids,
150
- "labels": targets,
151
- "attention_masks": attention_masks,
152
- "masks_list": masks_list,
153
- "label_list": label_list,
154
- "resize_list": resize_list,
155
- "offset": torch.LongTensor(offset_list),
156
- "questions_list": questions_list,
157
- "sampled_classes_list": sampled_classes_list,
158
- "inference": inferences[0],
159
- "conversation_list": conversation_list,
160
- }
161
-
162
-
163
- class HybridDataset(torch.utils.data.Dataset):
164
- pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
165
- pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
166
- img_size = 1024
167
- ignore_label = 255
168
-
169
- def __init__(
170
- self,
171
- base_image_dir,
172
- tokenizer,
173
- vision_tower,
174
- samples_per_epoch=500 * 8 * 2 * 10,
175
- precision: str = "fp32",
176
- image_size: int = 224,
177
- num_classes_per_sample: int = 3,
178
- exclude_val=False,
179
- dataset="sem_seg||refer_seg||vqa||reason_seg",
180
- sample_rate=[9, 3, 3, 1],
181
- sem_seg_data="ade20k||cocostuff||partimagenet||pascal_part||paco_lvis||mapillary",
182
- refer_seg_data="refclef||refcoco||refcoco+||refcocog",
183
- vqa_data="llava_instruct_150k",
184
- reason_seg_data="ReasonSeg|train",
185
- explanatory=0.1,
186
- ):
187
- self.exclude_val = exclude_val
188
- self.dataset = dataset
189
- self.samples_per_epoch = samples_per_epoch
190
- self.explanatory = explanatory
191
- self.num_classes_per_sample = num_classes_per_sample
192
- sample_rate = np.array(sample_rate)
193
- self.sample_rate = sample_rate / sample_rate.sum()
194
-
195
- self.base_image_dir = base_image_dir
196
- self.image_size = image_size
197
- self.tokenizer = tokenizer
198
- self.precision = precision
199
-
200
- self.datasets = dataset.split("||")
201
-
202
- self.all_datasets = []
203
- for dataset in self.datasets:
204
- if dataset == "sem_seg":
205
- self.all_datasets.append(
206
- SemSegDataset(
207
- base_image_dir,
208
- tokenizer,
209
- vision_tower,
210
- samples_per_epoch,
211
- precision,
212
- image_size,
213
- num_classes_per_sample,
214
- exclude_val,
215
- sem_seg_data,
216
- )
217
- )
218
- elif dataset == "refer_seg":
219
- self.all_datasets.append(
220
- ReferSegDataset(
221
- base_image_dir,
222
- tokenizer,
223
- vision_tower,
224
- samples_per_epoch,
225
- precision,
226
- image_size,
227
- num_classes_per_sample,
228
- exclude_val,
229
- refer_seg_data,
230
- )
231
- )
232
- elif dataset == "vqa":
233
- self.all_datasets.append(
234
- VQADataset(
235
- base_image_dir,
236
- tokenizer,
237
- vision_tower,
238
- samples_per_epoch,
239
- precision,
240
- image_size,
241
- num_classes_per_sample,
242
- exclude_val,
243
- vqa_data,
244
- )
245
- )
246
- elif dataset == "reason_seg":
247
- self.all_datasets.append(
248
- ReasonSegDataset(
249
- base_image_dir,
250
- tokenizer,
251
- vision_tower,
252
- samples_per_epoch,
253
- precision,
254
- image_size,
255
- num_classes_per_sample,
256
- exclude_val,
257
- reason_seg_data,
258
- explanatory,
259
- )
260
- )
261
-
262
- def __len__(self):
263
- return self.samples_per_epoch
264
-
265
- def __getitem__(self, idx):
266
- ind = np.random.choice(list(range(len(self.datasets))), p=self.sample_rate)
267
- data = self.all_datasets[ind]
268
- inference = False
269
- return *data[0], inference
270
-
271
-
272
- class ValDataset(torch.utils.data.Dataset):
273
- pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
274
- pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
275
- img_size = 1024
276
- ignore_label = 255
277
-
278
- def __init__(
279
- self,
280
- base_image_dir,
281
- tokenizer,
282
- vision_tower,
283
- val_dataset,
284
- image_size=1024,
285
- ):
286
- self.base_image_dir = base_image_dir
287
- splits = val_dataset.split("|")
288
- if len(splits) == 2:
289
- ds, split = splits
290
- images = glob.glob(
291
- os.path.join(self.base_image_dir, "reason_seg", ds, split, "*.jpg")
292
- )
293
- self.images = images
294
- self.data_type = "reason_seg"
295
- elif len(splits) == 3:
296
- ds, splitBy, split = splits
297
- refer_api = REFER(self.base_image_dir, ds, splitBy)
298
- ref_ids_val = refer_api.getRefIds(split=split)
299
- images_ids_val = refer_api.getImgIds(ref_ids=ref_ids_val)
300
- refs_val = refer_api.loadRefs(ref_ids=ref_ids_val)
301
- refer_seg_ds = {}
302
- refer_seg_ds["images"] = []
303
- loaded_images = refer_api.loadImgs(image_ids=images_ids_val)
304
- for item in loaded_images:
305
- item = item.copy()
306
- if ds == "refclef":
307
- item["file_name"] = os.path.join(
308
- base_image_dir, "images/saiapr_tc-12", item["file_name"]
309
- )
310
- elif ds in ["refcoco", "refcoco+", "refcocog", "grefcoco"]:
311
- item["file_name"] = os.path.join(
312
- base_image_dir,
313
- "images/mscoco/images/train2014",
314
- item["file_name"],
315
- )
316
- refer_seg_ds["images"].append(item)
317
- refer_seg_ds["annotations"] = refer_api.Anns # anns_val
318
-
319
- img2refs = {}
320
- for ref in refs_val:
321
- image_id = ref["image_id"]
322
- img2refs[image_id] = img2refs.get(image_id, []) + [
323
- ref,
324
- ]
325
- refer_seg_ds["img2refs"] = img2refs
326
- self.refer_seg_ds = refer_seg_ds
327
- self.data_type = "refer_seg"
328
-
329
- self.ds = ds
330
- self.image_size = image_size
331
- self.tokenizer = tokenizer
332
- self.transform = ResizeLongestSide(image_size)
333
- self.clip_image_processor = CLIPImageProcessor.from_pretrained(vision_tower)
334
-
335
- def __len__(self):
336
- if self.data_type == "refer_seg":
337
- return len(self.refer_seg_ds["images"])
338
- else:
339
- return len(self.images)
340
-
341
- def preprocess(self, x: torch.Tensor) -> torch.Tensor:
342
- """Normalize pixel values and pad to a square input."""
343
- # Normalize colors
344
- x = (x - self.pixel_mean) / self.pixel_std
345
-
346
- # Pad
347
- h, w = x.shape[-2:]
348
- padh = self.img_size - h
349
- padw = self.img_size - w
350
- x = F.pad(x, (0, padw, 0, padh))
351
- return x
352
-
353
- def __getitem__(self, idx):
354
- if self.data_type == "refer_seg":
355
- refer_seg_ds = self.refer_seg_ds
356
- images = refer_seg_ds["images"]
357
- annotations = refer_seg_ds["annotations"]
358
- img2refs = refer_seg_ds["img2refs"]
359
-
360
- image_info = images[idx]
361
- image_path = image_info["file_name"]
362
- image_id = image_info["id"]
363
-
364
- refs = img2refs[image_id]
365
- if len(refs) == 0:
366
- raise ValueError("image {} has no refs".format(image_id))
367
-
368
- sents = []
369
- ann_ids = []
370
- for ref in refs:
371
- for sent in ref["sentences"]:
372
- sents.append(sent["sent"].strip().lower())
373
- ann_ids.append(ref["ann_id"])
374
-
375
- sampled_sents = sents
376
- sampled_ann_ids = ann_ids
377
- image = cv2.imread(image_path)
378
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
379
- is_sentence = False
380
- else:
381
- image_path = self.images[idx]
382
- image = cv2.imread(image_path)
383
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
384
- json_path = image_path.replace(".jpg", ".json")
385
- mask_json, sampled_sents, is_sentence = get_mask_from_json(json_path, image)
386
- sampled_sents = [sampled_sents[0]]
387
-
388
- conversations = []
389
- conv = conversation_lib.default_conversation.copy()
390
- i = 0
391
- while i < len(sampled_sents):
392
- conv.messages = []
393
- text = sampled_sents[i].strip()
394
- if is_sentence:
395
- conv.append_message(
396
- conv.roles[0],
397
- DEFAULT_IMAGE_TOKEN
398
- + "\n {} Please output segmentation mask.".format(text),
399
- )
400
- conv.append_message(conv.roles[1], "[SEG].")
401
- else:
402
- conv.append_message(
403
- conv.roles[0],
404
- DEFAULT_IMAGE_TOKEN
405
- + "\n What is {} in this image? Please output segmentation mask.".format(
406
- text
407
- ),
408
- )
409
- conv.append_message(conv.roles[1], "[SEG].")
410
- conversations.append(conv.get_prompt())
411
- i += 1
412
-
413
- # preprocess image for clip
414
- image_clip = self.clip_image_processor.preprocess(image, return_tensors="pt")[
415
- "pixel_values"
416
- ][0]
417
-
418
- # preprocess image for sam
419
- image = self.transform.apply_image(image)
420
- resize = image.shape[:2]
421
- image = self.preprocess(torch.from_numpy(image).permute(2, 0, 1).contiguous())
422
-
423
- if self.data_type == "refer_seg":
424
- masks = []
425
- for i, ann_id in enumerate(sampled_ann_ids):
426
- ann = annotations[ann_id]
427
- if len(ann["segmentation"]) == 0 and sampled_sents[i] != "":
428
- m = np.zeros((image_info["height"], image_info["width"], 1))
429
- else:
430
- if type(ann["segmentation"][0]) == list: # polygon
431
- rle = mask.frPyObjects(
432
- ann["segmentation"],
433
- image_info["height"],
434
- image_info["width"],
435
- )
436
- else:
437
- rle = ann["segmentation"]
438
- for i in range(len(rle)):
439
- if not isinstance(rle[i]["counts"], bytes):
440
- rle[i]["counts"] = rle[i]["counts"].encode()
441
- m = mask.decode(rle)
442
- m = np.sum(
443
- m, axis=2
444
- ) # sometimes there are multiple binary map (corresponding to multiple segs)
445
- m = m.astype(np.uint8) # convert to np.uint8
446
- masks.append(m)
447
- else:
448
- masks = [mask_json]
449
-
450
- masks = np.stack(masks, axis=0)
451
- masks = torch.from_numpy(masks)
452
- labels = torch.ones(masks.shape[1], masks.shape[2]) * self.ignore_label
453
- inference = True
454
-
455
- return (
456
- image_path,
457
- image,
458
- image_clip,
459
- conversations,
460
- masks,
461
- labels,
462
- resize,
463
- None,
464
- None,
465
- inference,
466
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lisa_on_cuda/utils/frontend_builder.py DELETED
@@ -1,89 +0,0 @@
1
- import logging
2
- import os
3
- import subprocess
4
- from pathlib import Path
5
-
6
- from dotenv import load_dotenv
7
-
8
- from lisa_on_cuda.utils import session_logger
9
-
10
-
11
- load_dotenv()
12
- LOGLEVEL = os.getenv('LOGLEVEL', 'INFO').upper()
13
- session_logger.change_logging(LOGLEVEL)
14
- root_folder = Path(globals().get("__file__", "./_")).absolute().parent.parent.parent
15
- env_project_root_folder = Path(os.getenv("PROJECT_ROOT_FOLDER", root_folder))
16
- env_input_css_path = Path(os.getenv("INPUT_CSS_PATH"))
17
-
18
-
19
- def assert_envs(envs_list):
20
- for current_env in envs_list:
21
- try:
22
- assert current_env is not None and current_env != ""
23
- except AssertionError as aex:
24
- logging.error(f"error on assertion for current_env: {current_env}.")
25
- raise aex
26
-
27
-
28
- def read_std_out_err(std_out_err, output_type: str, command: list):
29
- output = std_out_err.split("\n")
30
- logging.info(f"output type:{output_type} for command:{' '.join(command)}.")
31
- for line in iter(output):
32
- logging.info(f"output_content_home stdout:{line.strip()}.")
33
- logging.info("########")
34
-
35
-
36
- def run_command(commands_list: list, capture_output: bool = True, text: bool = True, check: bool = True) -> None:
37
- try:
38
- output_content_home = subprocess.run(
39
- commands_list,
40
- capture_output=capture_output,
41
- text=text,
42
- check=check
43
- )
44
- read_std_out_err(output_content_home.stdout, "stdout", commands_list)
45
- read_std_out_err(output_content_home.stderr, "stderr", commands_list)
46
- except Exception as ex:
47
- logging.error(f"ex:{ex}.")
48
- raise ex
49
-
50
-
51
- def build_frontend(
52
- project_root_folder: str | Path,
53
- input_css_path: str | Path,
54
- output_dist_folder: Path = root_folder / "static" / "dist",
55
- ) -> None:
56
- assert_envs([
57
- str(project_root_folder),
58
- str(input_css_path)
59
- ])
60
-
61
- # install deps
62
- os.chdir(Path(project_root_folder) / "static")
63
- current_folder = os.getcwd()
64
- logging.info(f"current_folder:{current_folder}, install pnpm...")
65
- run_command(["npm", "install", "-g", "npm", "pnpm"])
66
- logging.info(f"install pnpm dependencies...")
67
- run_command(["pnpm", "install"])
68
-
69
- # build frontend dist and assert for its correct build
70
- output_css = str(output_dist_folder / "output.css")
71
- output_index_html = str(output_dist_folder / "index.html")
72
- output_dist_folder = str(output_dist_folder)
73
- logging.info(f"pnpm: build '{output_dist_folder}'...")
74
- run_command(["pnpm", "build"])
75
- logging.info(f"pnpm: ls -l {output_index_html}:")
76
- run_command(["ls", "-l", output_index_html])
77
- cmd = ["pnpm", "tailwindcss", "-i", str(input_css_path), "-o", output_css]
78
- logging.info(f"pnpm: {' '.join(cmd)}...")
79
- run_command(["pnpm", "tailwindcss", "-i", str(input_css_path), "-o", output_css])
80
- logging.info(f"pnpm: ls -l {output_css}:")
81
- run_command(["ls", "-l", output_css])
82
- logging.info(f"end!")
83
-
84
-
85
- if __name__ == '__main__':
86
- build_frontend(
87
- project_root_folder=env_project_root_folder,
88
- input_css_path=env_input_css_path
89
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lisa_on_cuda/utils/grefcoco.py DELETED
@@ -1,198 +0,0 @@
1
- import contextlib
2
- import copy
3
- import io
4
- import logging
5
- import os
6
- import random
7
-
8
- import numpy as np
9
- import pycocotools.mask as mask_util
10
- from detectron2.structures import Boxes, BoxMode, PolygonMasks, RotatedBoxes
11
- from detectron2.utils.file_io import PathManager
12
- from fvcore.common.timer import Timer
13
- from PIL import Image
14
-
15
- """
16
- This file contains functions to parse RefCOCO-format annotations into dicts in "Detectron2 format".
17
- """
18
-
19
-
20
- logger = logging.getLogger(__name__)
21
-
22
- __all__ = ["load_refcoco_json"]
23
-
24
-
25
- def load_grefcoco_json(
26
- refer_root,
27
- dataset_name,
28
- splitby,
29
- split,
30
- image_root,
31
- extra_annotation_keys=None,
32
- extra_refer_keys=None,
33
- ):
34
- if dataset_name == "refcocop":
35
- dataset_name = "refcoco+"
36
- if dataset_name == "refcoco" or dataset_name == "refcoco+":
37
- splitby == "unc"
38
- if dataset_name == "refcocog":
39
- assert splitby == "umd" or splitby == "google"
40
-
41
- dataset_id = "_".join([dataset_name, splitby, split])
42
-
43
- from .grefer import G_REFER
44
-
45
- logger.info("Loading dataset {} ({}-{}) ...".format(dataset_name, splitby, split))
46
- logger.info("Refcoco root: {}".format(refer_root))
47
- timer = Timer()
48
- refer_root = PathManager.get_local_path(refer_root)
49
- with contextlib.redirect_stdout(io.StringIO()):
50
- refer_api = G_REFER(data_root=refer_root, dataset=dataset_name, splitBy=splitby)
51
- if timer.seconds() > 1:
52
- logger.info(
53
- "Loading {} takes {:.2f} seconds.".format(dataset_id, timer.seconds())
54
- )
55
-
56
- ref_ids = refer_api.getRefIds(split=split)
57
- img_ids = refer_api.getImgIds(ref_ids)
58
- refs = refer_api.loadRefs(ref_ids)
59
- imgs = [refer_api.loadImgs(ref["image_id"])[0] for ref in refs]
60
- anns = [refer_api.loadAnns(ref["ann_id"]) for ref in refs]
61
- imgs_refs_anns = list(zip(imgs, refs, anns))
62
-
63
- logger.info(
64
- "Loaded {} images, {} referring object sets in G_RefCOCO format from {}".format(
65
- len(img_ids), len(ref_ids), dataset_id
66
- )
67
- )
68
-
69
- dataset_dicts = []
70
-
71
- ann_keys = ["iscrowd", "bbox", "category_id"] + (extra_annotation_keys or [])
72
- ref_keys = ["raw", "sent_id"] + (extra_refer_keys or [])
73
-
74
- ann_lib = {}
75
-
76
- NT_count = 0
77
- MT_count = 0
78
-
79
- for img_dict, ref_dict, anno_dicts in imgs_refs_anns:
80
- record = {}
81
- record["source"] = "grefcoco"
82
- record["file_name"] = os.path.join(image_root, img_dict["file_name"])
83
- record["height"] = img_dict["height"]
84
- record["width"] = img_dict["width"]
85
- image_id = record["image_id"] = img_dict["id"]
86
-
87
- # Check that information of image, ann and ref match each other
88
- # This fails only when the data parsing logic or the annotation file is buggy.
89
- assert ref_dict["image_id"] == image_id
90
- assert ref_dict["split"] == split
91
- if not isinstance(ref_dict["ann_id"], list):
92
- ref_dict["ann_id"] = [ref_dict["ann_id"]]
93
-
94
- # No target samples
95
- if None in anno_dicts:
96
- assert anno_dicts == [None]
97
- assert ref_dict["ann_id"] == [-1]
98
- record["empty"] = True
99
- obj = {key: None for key in ann_keys if key in ann_keys}
100
- obj["bbox_mode"] = BoxMode.XYWH_ABS
101
- obj["empty"] = True
102
- obj = [obj]
103
-
104
- # Multi target samples
105
- else:
106
- record["empty"] = False
107
- obj = []
108
- for anno_dict in anno_dicts:
109
- ann_id = anno_dict["id"]
110
- if anno_dict["iscrowd"]:
111
- continue
112
- assert anno_dict["image_id"] == image_id
113
- assert ann_id in ref_dict["ann_id"]
114
-
115
- if ann_id in ann_lib:
116
- ann = ann_lib[ann_id]
117
- else:
118
- ann = {key: anno_dict[key] for key in ann_keys if key in anno_dict}
119
- ann["bbox_mode"] = BoxMode.XYWH_ABS
120
- ann["empty"] = False
121
-
122
- segm = anno_dict.get("segmentation", None)
123
- assert segm # either list[list[float]] or dict(RLE)
124
- if isinstance(segm, dict):
125
- if isinstance(segm["counts"], list):
126
- # convert to compressed RLE
127
- segm = mask_util.frPyObjects(segm, *segm["size"])
128
- else:
129
- # filter out invalid polygons (< 3 points)
130
- segm = [
131
- poly
132
- for poly in segm
133
- if len(poly) % 2 == 0 and len(poly) >= 6
134
- ]
135
- if len(segm) == 0:
136
- num_instances_without_valid_segmentation += 1
137
- continue # ignore this instance
138
- ann["segmentation"] = segm
139
- ann_lib[ann_id] = ann
140
-
141
- obj.append(ann)
142
-
143
- record["annotations"] = obj
144
-
145
- # Process referring expressions
146
- sents = ref_dict["sentences"]
147
- for sent in sents:
148
- ref_record = record.copy()
149
- ref = {key: sent[key] for key in ref_keys if key in sent}
150
- ref["ref_id"] = ref_dict["ref_id"]
151
- ref_record["sentence"] = ref
152
- dataset_dicts.append(ref_record)
153
- # if ref_record['empty']:
154
- # NT_count += 1
155
- # else:
156
- # MT_count += 1
157
-
158
- # logger.info("NT samples: %d, MT samples: %d", NT_count, MT_count)
159
-
160
- # Debug mode
161
- # return dataset_dicts[:100]
162
-
163
- return dataset_dicts
164
-
165
-
166
- if __name__ == "__main__":
167
- """
168
- Test the COCO json dataset loader.
169
-
170
- Usage:
171
- python -m detectron2.data.datasets.coco \
172
- path/to/json path/to/image_root dataset_name
173
-
174
- "dataset_name" can be "coco_2014_minival_100", or other
175
- pre-registered ones
176
- """
177
- import sys
178
-
179
- import detectron2.data.datasets # noqa # add pre-defined metadata
180
- from detectron2.utils.logger import setup_logger
181
- from detectron2.utils.visualizer import Visualizer
182
-
183
- REFCOCO_PATH = "/mnt/lustre/hhding/code/ReLA/datasets"
184
- COCO_TRAIN_2014_IMAGE_ROOT = "/mnt/lustre/hhding/code/ReLA/datasets/images"
185
- REFCOCO_DATASET = "grefcoco"
186
- REFCOCO_SPLITBY = "unc"
187
- REFCOCO_SPLIT = "train"
188
-
189
- logger = setup_logger(name=__name__)
190
-
191
- dicts = load_grefcoco_json(
192
- REFCOCO_PATH,
193
- REFCOCO_DATASET,
194
- REFCOCO_SPLITBY,
195
- REFCOCO_SPLIT,
196
- COCO_TRAIN_2014_IMAGE_ROOT,
197
- )
198
- logger.info("Done loading {} samples.".format(len(dicts)))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lisa_on_cuda/utils/grefer.py DELETED
@@ -1,352 +0,0 @@
1
- """
2
- grefer v0.1
3
- This interface provides access to gRefCOCO.
4
-
5
- The following API functions are defined:
6
- G_REFER - REFER api class
7
- getRefIds - get ref ids that satisfy given filter conditions.
8
- getAnnIds - get ann ids that satisfy given filter conditions.
9
- getImgIds - get image ids that satisfy given filter conditions.
10
- getCatIds - get category ids that satisfy given filter conditions.
11
- loadRefs - load refs with the specified ref ids.
12
- loadAnns - load anns with the specified ann ids.
13
- loadImgs - load images with the specified image ids.
14
- loadCats - load category names with the specified category ids.
15
- getRefBox - get ref's bounding box [x, y, w, h] given the ref_id
16
- showRef - show image, segmentation or box of the referred object with the ref
17
- getMaskByRef - get mask and area of the referred object given ref or ref ids
18
- getMask - get mask and area of the referred object given ref
19
- showMask - show mask of the referred object given ref
20
- """
21
-
22
- import itertools
23
- import json
24
- import os.path as osp
25
- import pickle
26
- import time
27
-
28
- import matplotlib.pyplot as plt
29
- import numpy as np
30
- import skimage.io as io
31
- from matplotlib.collections import PatchCollection
32
- from matplotlib.patches import Polygon, Rectangle
33
- from pycocotools import mask
34
-
35
-
36
- class G_REFER:
37
- def __init__(self, data_root, dataset="grefcoco", splitBy="unc"):
38
- # provide data_root folder which contains grefcoco
39
- print("loading dataset %s into memory..." % dataset)
40
- self.ROOT_DIR = osp.abspath(osp.dirname(__file__))
41
- self.DATA_DIR = osp.join(data_root, dataset)
42
- if dataset in ["grefcoco"]:
43
- self.IMAGE_DIR = osp.join(data_root, "images/train2014")
44
- else:
45
- raise KeyError("No refer dataset is called [%s]" % dataset)
46
-
47
- tic = time.time()
48
-
49
- # load refs from data/dataset/refs(dataset).json
50
- self.data = {}
51
- self.data["dataset"] = dataset
52
-
53
- ref_file = osp.join(self.DATA_DIR, f"grefs({splitBy}).p")
54
- if osp.exists(ref_file):
55
- self.data["refs"] = pickle.load(open(ref_file, "rb"), fix_imports=True)
56
- else:
57
- ref_file = osp.join(self.DATA_DIR, f"grefs({splitBy}).json")
58
- if osp.exists(ref_file):
59
- self.data["refs"] = json.load(open(ref_file, "rb"))
60
- else:
61
- raise FileNotFoundError("JSON file not found")
62
-
63
- # load annotations from data/dataset/instances.json
64
- instances_file = osp.join(self.DATA_DIR, "instances.json")
65
- instances = json.load(open(instances_file, "r"))
66
- self.data["images"] = instances["images"]
67
- self.data["annotations"] = instances["annotations"]
68
- self.data["categories"] = instances["categories"]
69
-
70
- # create index
71
- self.createIndex()
72
- print("DONE (t=%.2fs)" % (time.time() - tic))
73
-
74
- @staticmethod
75
- def _toList(x):
76
- return x if isinstance(x, list) else [x]
77
-
78
- @staticmethod
79
- def match_any(a, b):
80
- a = a if isinstance(a, list) else [a]
81
- b = b if isinstance(b, list) else [b]
82
- return set(a) & set(b)
83
-
84
- def createIndex(self):
85
- # create sets of mapping
86
- # 1) Refs: {ref_id: ref}
87
- # 2) Anns: {ann_id: ann}
88
- # 3) Imgs: {image_id: image}
89
- # 4) Cats: {category_id: category_name}
90
- # 5) Sents: {sent_id: sent}
91
- # 6) imgToRefs: {image_id: refs}
92
- # 7) imgToAnns: {image_id: anns}
93
- # 8) refToAnn: {ref_id: ann}
94
- # 9) annToRef: {ann_id: ref}
95
- # 10) catToRefs: {category_id: refs}
96
- # 11) sentToRef: {sent_id: ref}
97
- # 12) sentToTokens: {sent_id: tokens}
98
- print("creating index...")
99
- # fetch info from instances
100
- Anns, Imgs, Cats, imgToAnns = {}, {}, {}, {}
101
- Anns[-1] = None
102
- for ann in self.data["annotations"]:
103
- Anns[ann["id"]] = ann
104
- imgToAnns[ann["image_id"]] = imgToAnns.get(ann["image_id"], []) + [ann]
105
- for img in self.data["images"]:
106
- Imgs[img["id"]] = img
107
- for cat in self.data["categories"]:
108
- Cats[cat["id"]] = cat["name"]
109
-
110
- # fetch info from refs
111
- Refs, imgToRefs, refToAnn, annToRef, catToRefs = {}, {}, {}, {}, {}
112
- Sents, sentToRef, sentToTokens = {}, {}, {}
113
- availableSplits = []
114
- for ref in self.data["refs"]:
115
- # ids
116
- ref_id = ref["ref_id"]
117
- ann_id = ref["ann_id"]
118
- category_id = ref["category_id"]
119
- image_id = ref["image_id"]
120
-
121
- if ref["split"] not in availableSplits:
122
- availableSplits.append(ref["split"])
123
-
124
- # add mapping related to ref
125
- if ref_id in Refs:
126
- print("Duplicate ref id")
127
- Refs[ref_id] = ref
128
- imgToRefs[image_id] = imgToRefs.get(image_id, []) + [ref]
129
-
130
- category_id = self._toList(category_id)
131
- added_cats = []
132
- for cat in category_id:
133
- if cat not in added_cats:
134
- added_cats.append(cat)
135
- catToRefs[cat] = catToRefs.get(cat, []) + [ref]
136
-
137
- ann_id = self._toList(ann_id)
138
- refToAnn[ref_id] = [Anns[ann] for ann in ann_id]
139
- for ann_id_n in ann_id:
140
- annToRef[ann_id_n] = annToRef.get(ann_id_n, []) + [ref]
141
-
142
- # add mapping of sent
143
- for sent in ref["sentences"]:
144
- Sents[sent["sent_id"]] = sent
145
- sentToRef[sent["sent_id"]] = ref
146
- sentToTokens[sent["sent_id"]] = sent["tokens"]
147
-
148
- # create class members
149
- self.Refs = Refs
150
- self.Anns = Anns
151
- self.Imgs = Imgs
152
- self.Cats = Cats
153
- self.Sents = Sents
154
- self.imgToRefs = imgToRefs
155
- self.imgToAnns = imgToAnns
156
- self.refToAnn = refToAnn
157
- self.annToRef = annToRef
158
- self.catToRefs = catToRefs
159
- self.sentToRef = sentToRef
160
- self.sentToTokens = sentToTokens
161
- self.availableSplits = availableSplits
162
- print("index created.")
163
-
164
- def getRefIds(self, image_ids=[], cat_ids=[], split=[]):
165
- image_ids = self._toList(image_ids)
166
- cat_ids = self._toList(cat_ids)
167
- split = self._toList(split)
168
-
169
- for s in split:
170
- if s not in self.availableSplits:
171
- raise ValueError(f"Invalid split name: {s}")
172
-
173
- refs = self.data["refs"]
174
-
175
- if len(image_ids) > 0:
176
- lists = [self.imgToRefs[image_id] for image_id in image_ids]
177
- refs = list(itertools.chain.from_iterable(lists))
178
- if len(cat_ids) > 0:
179
- refs = [ref for ref in refs if self.match_any(ref["category_id"], cat_ids)]
180
- if len(split) > 0:
181
- refs = [ref for ref in refs if ref["split"] in split]
182
-
183
- ref_ids = [ref["ref_id"] for ref in refs]
184
- return ref_ids
185
-
186
- def getAnnIds(self, image_ids=[], ref_ids=[]):
187
- image_ids = self._toList(image_ids)
188
- ref_ids = self._toList(ref_ids)
189
-
190
- if any([len(image_ids), len(ref_ids)]):
191
- if len(image_ids) > 0:
192
- lists = [
193
- self.imgToAnns[image_id]
194
- for image_id in image_ids
195
- if image_id in self.imgToAnns
196
- ]
197
- anns = list(itertools.chain.from_iterable(lists))
198
- else:
199
- anns = self.data["annotations"]
200
- ann_ids = [ann["id"] for ann in anns]
201
- if len(ref_ids) > 0:
202
- lists = [self.Refs[ref_id]["ann_id"] for ref_id in ref_ids]
203
- anns_by_ref_id = list(itertools.chain.from_iterable(lists))
204
- ann_ids = list(set(ann_ids).intersection(set(anns_by_ref_id)))
205
- else:
206
- ann_ids = [ann["id"] for ann in self.data["annotations"]]
207
-
208
- return ann_ids
209
-
210
- def getImgIds(self, ref_ids=[]):
211
- ref_ids = self._toList(ref_ids)
212
-
213
- if len(ref_ids) > 0:
214
- image_ids = list(set([self.Refs[ref_id]["image_id"] for ref_id in ref_ids]))
215
- else:
216
- image_ids = self.Imgs.keys()
217
- return image_ids
218
-
219
- def getCatIds(self):
220
- return self.Cats.keys()
221
-
222
- def loadRefs(self, ref_ids=[]):
223
- return [self.Refs[ref_id] for ref_id in self._toList(ref_ids)]
224
-
225
- def loadAnns(self, ann_ids=[]):
226
- if isinstance(ann_ids, str):
227
- ann_ids = int(ann_ids)
228
- return [self.Anns[ann_id] for ann_id in self._toList(ann_ids)]
229
-
230
- def loadImgs(self, image_ids=[]):
231
- return [self.Imgs[image_id] for image_id in self._toList(image_ids)]
232
-
233
- def loadCats(self, cat_ids=[]):
234
- return [self.Cats[cat_id] for cat_id in self._toList(cat_ids)]
235
-
236
- def getRefBox(self, ref_id):
237
- anns = self.refToAnn[ref_id]
238
- return [ann["bbox"] for ann in anns] # [x, y, w, h]
239
-
240
- def showRef(self, ref, seg_box="seg"):
241
- ax = plt.gca()
242
- # show image
243
- image = self.Imgs[ref["image_id"]]
244
- I = io.imread(osp.join(self.IMAGE_DIR, image["file_name"]))
245
- ax.imshow(I)
246
- # show refer expression
247
- for sid, sent in enumerate(ref["sentences"]):
248
- print("%s. %s" % (sid + 1, sent["sent"]))
249
- # show segmentations
250
- if seg_box == "seg":
251
- ann_id = ref["ann_id"]
252
- ann = self.Anns[ann_id]
253
- polygons = []
254
- color = []
255
- c = "none"
256
- if type(ann["segmentation"][0]) == list:
257
- # polygon used for refcoco*
258
- for seg in ann["segmentation"]:
259
- poly = np.array(seg).reshape((len(seg) / 2, 2))
260
- polygons.append(Polygon(poly, True, alpha=0.4))
261
- color.append(c)
262
- p = PatchCollection(
263
- polygons,
264
- facecolors=color,
265
- edgecolors=(1, 1, 0, 0),
266
- linewidths=3,
267
- alpha=1,
268
- )
269
- ax.add_collection(p) # thick yellow polygon
270
- p = PatchCollection(
271
- polygons,
272
- facecolors=color,
273
- edgecolors=(1, 0, 0, 0),
274
- linewidths=1,
275
- alpha=1,
276
- )
277
- ax.add_collection(p) # thin red polygon
278
- else:
279
- # mask used for refclef
280
- rle = ann["segmentation"]
281
- m = mask.decode(rle)
282
- img = np.ones((m.shape[0], m.shape[1], 3))
283
- color_mask = np.array([2.0, 166.0, 101.0]) / 255
284
- for i in range(3):
285
- img[:, :, i] = color_mask[i]
286
- ax.imshow(np.dstack((img, m * 0.5)))
287
- # show bounding-box
288
- elif seg_box == "box":
289
- ann_id = ref["ann_id"]
290
- ann = self.Anns[ann_id]
291
- bbox = self.getRefBox(ref["ref_id"])
292
- box_plot = Rectangle(
293
- (bbox[0], bbox[1]),
294
- bbox[2],
295
- bbox[3],
296
- fill=False,
297
- edgecolor="green",
298
- linewidth=3,
299
- )
300
- ax.add_patch(box_plot)
301
-
302
- def getMask(self, ann):
303
- if not ann:
304
- return None
305
- if ann["iscrowd"]:
306
- raise ValueError("Crowd object")
307
- image = self.Imgs[ann["image_id"]]
308
- if type(ann["segmentation"][0]) == list: # polygon
309
- rle = mask.frPyObjects(ann["segmentation"], image["height"], image["width"])
310
- else:
311
- rle = ann["segmentation"]
312
-
313
- m = mask.decode(rle)
314
- m = np.sum(
315
- m, axis=2
316
- ) # sometimes there are multiple binary map (corresponding to multiple segs)
317
- m = m.astype(np.uint8) # convert to np.uint8
318
- # compute area
319
- area = sum(mask.area(rle)) # should be close to ann['area']
320
- return {"mask": m, "area": area}
321
-
322
- def getMaskByRef(self, ref=None, ref_id=None, merge=False):
323
- if not ref and not ref_id:
324
- raise ValueError
325
- if ref:
326
- ann_ids = ref["ann_id"]
327
- ref_id = ref["ref_id"]
328
- else:
329
- ann_ids = self.getAnnIds(ref_ids=ref_id)
330
-
331
- if ann_ids == [-1]:
332
- img = self.Imgs[self.Refs[ref_id]["image_id"]]
333
- return {
334
- "mask": np.zeros([img["height"], img["width"]], dtype=np.uint8),
335
- "empty": True,
336
- }
337
-
338
- anns = self.loadAnns(ann_ids)
339
- mask_list = [self.getMask(ann) for ann in anns if not ann["iscrowd"]]
340
-
341
- if merge:
342
- merged_masks = sum([mask["mask"] for mask in mask_list])
343
- merged_masks[np.where(merged_masks > 1)] = 1
344
- return {"mask": merged_masks, "empty": False}
345
- else:
346
- return mask_list
347
-
348
- def showMask(self, ref):
349
- M = self.getMask(ref)
350
- msk = M["mask"]
351
- ax = plt.gca()
352
- ax.imshow(msk)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lisa_on_cuda/utils/reason_seg_dataset.py DELETED
@@ -1,218 +0,0 @@
1
- import glob
2
- import json
3
- import os
4
- import random
5
-
6
- import cv2
7
- import numpy as np
8
- import torch
9
- import torch.nn.functional as F
10
- from transformers import CLIPImageProcessor
11
-
12
- from lisa_on_cuda.llava import conversation as conversation_lib
13
- from lisa_on_cuda.segment_anything.utils.transforms import ResizeLongestSide
14
-
15
- from .data_processing import get_mask_from_json
16
- from .utils import (ANSWER_LIST, DEFAULT_IMAGE_TOKEN,
17
- EXPLANATORY_QUESTION_LIST, LONG_QUESTION_LIST,
18
- SHORT_QUESTION_LIST)
19
-
20
-
21
- class ReasonSegDataset(torch.utils.data.Dataset):
22
- pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
23
- pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
24
- img_size = 1024
25
- ignore_label = 255
26
-
27
- def __init__(
28
- self,
29
- base_image_dir,
30
- tokenizer,
31
- vision_tower,
32
- samples_per_epoch=500 * 8 * 2 * 10,
33
- precision: str = "fp32",
34
- image_size: int = 224,
35
- num_classes_per_sample: int = 3,
36
- exclude_val=False,
37
- reason_seg_data="ReasonSeg|train",
38
- explanatory=0.1,
39
- ):
40
- self.exclude_val = exclude_val
41
- self.reason_seg_data = reason_seg_data
42
- self.samples_per_epoch = samples_per_epoch
43
- self.explanatory = explanatory
44
- self.num_classes_per_sample = num_classes_per_sample
45
-
46
- self.base_image_dir = base_image_dir
47
- self.image_size = image_size
48
- self.tokenizer = tokenizer
49
- self.precision = precision
50
- self.transform = ResizeLongestSide(image_size)
51
- self.clip_image_processor = CLIPImageProcessor.from_pretrained(vision_tower)
52
-
53
- self.short_question_list = SHORT_QUESTION_LIST
54
- self.long_question_list = LONG_QUESTION_LIST
55
- self.answer_list = ANSWER_LIST
56
-
57
- reason_seg_data, splits = reason_seg_data.split("|")
58
- splits = splits.split("_")
59
- images = []
60
- for split in splits:
61
- images_split = glob.glob(
62
- os.path.join(
63
- base_image_dir, "reason_seg", reason_seg_data, split, "*.jpg"
64
- )
65
- )
66
- images.extend(images_split)
67
- jsons = [path.replace(".jpg", ".json") for path in images]
68
- self.reason_seg_data = (images, jsons)
69
-
70
- print("number of reason_seg samples: ", len(images))
71
-
72
- if explanatory != -1:
73
- self.explanatory_question_list = EXPLANATORY_QUESTION_LIST
74
- self.img_to_explanation = {}
75
- with open(
76
- os.path.join(
77
- base_image_dir,
78
- "reason_seg",
79
- reason_seg_data,
80
- "explanatory",
81
- "train.json",
82
- )
83
- ) as f:
84
- items = json.load(f)
85
- for item in items:
86
- img_name = item["image"]
87
- self.img_to_explanation[img_name] = {
88
- "query": item["query"],
89
- "outputs": item["outputs"],
90
- }
91
-
92
- print("len(self.img_to_explanation): ", len(self.img_to_explanation))
93
-
94
- def __len__(self):
95
- return self.samples_per_epoch
96
-
97
- def preprocess(self, x: torch.Tensor) -> torch.Tensor:
98
- """Normalize pixel values and pad to a square input."""
99
- # Normalize colors
100
- x = (x - self.pixel_mean) / self.pixel_std
101
-
102
- # Pad
103
- h, w = x.shape[-2:]
104
- padh = self.img_size - h
105
- padw = self.img_size - w
106
- x = F.pad(x, (0, padw, 0, padh))
107
- return x
108
-
109
- def __getitem__(self, idx):
110
- images, jsons = self.reason_seg_data
111
- idx = random.randint(0, len(images) - 1)
112
- image_path = images[idx]
113
- json_path = jsons[idx]
114
-
115
- image = cv2.imread(image_path)
116
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
117
- ori_size = image.shape[:2]
118
- # preprocess image for clip
119
- image_clip = self.clip_image_processor.preprocess(image, return_tensors="pt")[
120
- "pixel_values"
121
- ][0]
122
-
123
- mask, sents, is_sentence = get_mask_from_json(json_path, image)
124
- if len(sents) >= self.num_classes_per_sample:
125
- sampled_inds = np.random.choice(
126
- list(range(len(sents))), size=self.num_classes_per_sample, replace=False
127
- )
128
- else:
129
- sampled_inds = list(range(len(sents)))
130
- sampled_sents = np.vectorize(sents.__getitem__)(sampled_inds).tolist()
131
- sampled_masks = [
132
- (mask == 1).astype(np.float32) for _ in range(len(sampled_inds))
133
- ]
134
-
135
- image = self.transform.apply_image(image) # preprocess image for sam
136
- resize = image.shape[:2]
137
-
138
- image_name = image_path.split("/")[-1]
139
- if self.explanatory != -1 and image_name in self.img_to_explanation:
140
- if random.random() < self.explanatory:
141
- choice = 2
142
- else:
143
- choice = random.randint(0, 1)
144
-
145
- questions = []
146
- answers = []
147
- for text in sampled_sents:
148
- if is_sentence:
149
- question_template = random.choice(self.long_question_list)
150
- questions.append(question_template.format(sent=text))
151
- else:
152
- question_template = random.choice(self.short_question_list)
153
- questions.append(question_template.format(class_name=text.lower()))
154
-
155
- # add explanation if applicable
156
- img_name = image_path.split("/")[-1]
157
- if self.explanatory != -1 and img_name in self.img_to_explanation:
158
- if choice == 0: # [SEG] token
159
- answers.append(random.choice(self.answer_list))
160
- elif choice == 1: # [SEG] token + text answer
161
- image_name = image_path.split("/")[-1]
162
- answer = self.img_to_explanation[image_name]["outputs"]
163
- answer = random.choice(self.answer_list) + " {}".format(answer)
164
- questions[-1] = (
165
- DEFAULT_IMAGE_TOKEN
166
- + "\n"
167
- + text
168
- + " {}".format(random.choice(self.explanatory_question_list))
169
- )
170
- answers.append(answer)
171
- elif choice == 2: # vanilla text answer
172
- image_name = image_path.split("/")[-1]
173
- answer = self.img_to_explanation[image_name]["outputs"]
174
- questions[-1] = DEFAULT_IMAGE_TOKEN + "\n" + text
175
- answers.append(answer)
176
- else:
177
- raise ValueError("Not implemented yet.")
178
- else:
179
- answers.append(random.choice(self.answer_list))
180
-
181
- conversations = []
182
- conv = conversation_lib.default_conversation.copy()
183
- roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
184
-
185
- i = 0
186
- while i < len(questions):
187
- conv.messages = []
188
- conv.append_message(conv.roles[0], questions[i])
189
- conv.append_message(conv.roles[1], answers[i])
190
- conversations.append(conv.get_prompt())
191
- i += 1
192
-
193
- image = self.preprocess(torch.from_numpy(image).permute(2, 0, 1).contiguous())
194
-
195
- image_name = image_path.split("/")[-1]
196
- if (
197
- self.explanatory != -1
198
- and image_name in self.img_to_explanation
199
- and choice == 2
200
- ):
201
- masks = torch.rand(0, *ori_size)
202
- label = torch.ones(ori_size) * self.ignore_label
203
- else:
204
- masks = np.stack(sampled_masks, axis=0)
205
- masks = torch.from_numpy(masks)
206
- label = torch.ones(masks.shape[1], masks.shape[2]) * self.ignore_label
207
-
208
- return (
209
- image_path,
210
- image,
211
- image_clip,
212
- conversations,
213
- masks,
214
- label,
215
- resize,
216
- questions,
217
- sampled_sents,
218
- )
lisa_on_cuda/utils/refer.py DELETED
@@ -1,391 +0,0 @@
1
- __author__ = "licheng"
2
-
3
- """
4
- This interface provides access to four datasets:
5
- 1) refclef
6
- 2) refcoco
7
- 3) refcoco+
8
- 4) refcocog
9
- split by unc and google
10
-
11
- The following API functions are defined:
12
- REFER - REFER api class
13
- getRefIds - get ref ids that satisfy given filter conditions.
14
- getAnnIds - get ann ids that satisfy given filter conditions.
15
- getImgIds - get image ids that satisfy given filter conditions.
16
- getCatIds - get category ids that satisfy given filter conditions.
17
- loadRefs - load refs with the specified ref ids.
18
- loadAnns - load anns with the specified ann ids.
19
- loadImgs - load images with the specified image ids.
20
- loadCats - load category names with the specified category ids.
21
- getRefBox - get ref's bounding box [x, y, w, h] given the ref_id
22
- showRef - show image, segmentation or box of the referred object with the ref
23
- getMask - get mask and area of the referred object given ref
24
- showMask - show mask of the referred object given ref
25
- """
26
-
27
- import itertools
28
- import json
29
- import os.path as osp
30
- import pickle
31
- import sys
32
- import time
33
- from pprint import pprint
34
-
35
- import matplotlib.pyplot as plt
36
- import numpy as np
37
- import skimage.io as io
38
- from matplotlib.collections import PatchCollection
39
- from matplotlib.patches import Polygon, Rectangle
40
- from pycocotools import mask
41
-
42
-
43
- class REFER:
44
- def __init__(self, data_root, dataset="refcoco", splitBy="unc"):
45
- # provide data_root folder which contains refclef, refcoco, refcoco+ and refcocog
46
- # also provide dataset name and splitBy information
47
- # e.g., dataset = 'refcoco', splitBy = 'unc'
48
- print("loading dataset %s into memory..." % dataset)
49
- self.ROOT_DIR = osp.abspath(osp.dirname(__file__))
50
- self.DATA_DIR = osp.join(data_root, dataset)
51
- if dataset in ["refcoco", "refcoco+", "refcocog"]:
52
- self.IMAGE_DIR = osp.join(data_root, "images/mscoco/images/train2014")
53
- elif dataset == "refclef":
54
- self.IMAGE_DIR = osp.join(data_root, "images/saiapr_tc-12")
55
- else:
56
- print("No refer dataset is called [%s]" % dataset)
57
- sys.exit()
58
-
59
- self.dataset = dataset
60
-
61
- # load refs from data/dataset/refs(dataset).json
62
- tic = time.time()
63
-
64
- ref_file = osp.join(self.DATA_DIR, "refs(" + splitBy + ").p")
65
- print("ref_file: ", ref_file)
66
- self.data = {}
67
- self.data["dataset"] = dataset
68
- self.data["refs"] = pickle.load(open(ref_file, "rb"))
69
-
70
- # load annotations from data/dataset/instances.json
71
- instances_file = osp.join(self.DATA_DIR, "instances.json")
72
- instances = json.load(open(instances_file, "rb"))
73
- self.data["images"] = instances["images"]
74
- self.data["annotations"] = instances["annotations"]
75
- self.data["categories"] = instances["categories"]
76
-
77
- # create index
78
- self.createIndex()
79
- print("DONE (t=%.2fs)" % (time.time() - tic))
80
-
81
- def createIndex(self):
82
- # create sets of mapping
83
- # 1) Refs: {ref_id: ref}
84
- # 2) Anns: {ann_id: ann}
85
- # 3) Imgs: {image_id: image}
86
- # 4) Cats: {category_id: category_name}
87
- # 5) Sents: {sent_id: sent}
88
- # 6) imgToRefs: {image_id: refs}
89
- # 7) imgToAnns: {image_id: anns}
90
- # 8) refToAnn: {ref_id: ann}
91
- # 9) annToRef: {ann_id: ref}
92
- # 10) catToRefs: {category_id: refs}
93
- # 11) sentToRef: {sent_id: ref}
94
- # 12) sentToTokens: {sent_id: tokens}
95
- print("creating index...")
96
- # fetch info from instances
97
- Anns, Imgs, Cats, imgToAnns = {}, {}, {}, {}
98
- for ann in self.data["annotations"]:
99
- Anns[ann["id"]] = ann
100
- imgToAnns[ann["image_id"]] = imgToAnns.get(ann["image_id"], []) + [ann]
101
- for img in self.data["images"]:
102
- Imgs[img["id"]] = img
103
- for cat in self.data["categories"]:
104
- Cats[cat["id"]] = cat["name"]
105
-
106
- # fetch info from refs
107
- Refs, imgToRefs, refToAnn, annToRef, catToRefs = {}, {}, {}, {}, {}
108
- Sents, sentToRef, sentToTokens = {}, {}, {}
109
- for ref in self.data["refs"]:
110
- # ids
111
- ref_id = ref["ref_id"]
112
- ann_id = ref["ann_id"]
113
- category_id = ref["category_id"]
114
- image_id = ref["image_id"]
115
-
116
- # add mapping related to ref
117
- Refs[ref_id] = ref
118
- imgToRefs[image_id] = imgToRefs.get(image_id, []) + [ref]
119
- catToRefs[category_id] = catToRefs.get(category_id, []) + [ref]
120
- refToAnn[ref_id] = Anns[ann_id]
121
- annToRef[ann_id] = ref
122
-
123
- # add mapping of sent
124
- for sent in ref["sentences"]:
125
- Sents[sent["sent_id"]] = sent
126
- sentToRef[sent["sent_id"]] = ref
127
- sentToTokens[sent["sent_id"]] = sent["tokens"]
128
-
129
- # create class members
130
- self.Refs = Refs
131
- self.Anns = Anns
132
- self.Imgs = Imgs
133
- self.Cats = Cats
134
- self.Sents = Sents
135
- self.imgToRefs = imgToRefs
136
- self.imgToAnns = imgToAnns
137
- self.refToAnn = refToAnn
138
- self.annToRef = annToRef
139
- self.catToRefs = catToRefs
140
- self.sentToRef = sentToRef
141
- self.sentToTokens = sentToTokens
142
- print("index created.")
143
-
144
- def getRefIds(self, image_ids=[], cat_ids=[], ref_ids=[], split=""):
145
- image_ids = image_ids if type(image_ids) == list else [image_ids]
146
- cat_ids = cat_ids if type(cat_ids) == list else [cat_ids]
147
- ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
148
-
149
- if len(image_ids) == len(cat_ids) == len(ref_ids) == len(split) == 0:
150
- refs = self.data["refs"]
151
- else:
152
- if not len(image_ids) == 0:
153
- refs = [self.imgToRefs[image_id] for image_id in image_ids]
154
- else:
155
- refs = self.data["refs"]
156
- if not len(cat_ids) == 0:
157
- refs = [ref for ref in refs if ref["category_id"] in cat_ids]
158
- if not len(ref_ids) == 0:
159
- refs = [ref for ref in refs if ref["ref_id"] in ref_ids]
160
- if not len(split) == 0:
161
- if split in ["testA", "testB", "testC"]:
162
- refs = [
163
- ref for ref in refs if split[-1] in ref["split"]
164
- ] # we also consider testAB, testBC, ...
165
- elif split in ["testAB", "testBC", "testAC"]:
166
- refs = [
167
- ref for ref in refs if ref["split"] == split
168
- ] # rarely used I guess...
169
- elif split == "test":
170
- refs = [ref for ref in refs if "test" in ref["split"]]
171
- elif split == "train" or split == "val":
172
- refs = [ref for ref in refs if ref["split"] == split]
173
- else:
174
- print("No such split [%s]" % split)
175
- sys.exit()
176
- ref_ids = [ref["ref_id"] for ref in refs]
177
- return ref_ids
178
-
179
- def getAnnIds(self, image_ids=[], cat_ids=[], ref_ids=[]):
180
- image_ids = image_ids if type(image_ids) == list else [image_ids]
181
- cat_ids = cat_ids if type(cat_ids) == list else [cat_ids]
182
- ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
183
-
184
- if len(image_ids) == len(cat_ids) == len(ref_ids) == 0:
185
- ann_ids = [ann["id"] for ann in self.data["annotations"]]
186
- else:
187
- if not len(image_ids) == 0:
188
- lists = [
189
- self.imgToAnns[image_id]
190
- for image_id in image_ids
191
- if image_id in self.imgToAnns
192
- ] # list of [anns]
193
- anns = list(itertools.chain.from_iterable(lists))
194
- else:
195
- anns = self.data["annotations"]
196
- if not len(cat_ids) == 0:
197
- anns = [ann for ann in anns if ann["category_id"] in cat_ids]
198
- ann_ids = [ann["id"] for ann in anns]
199
- if not len(ref_ids) == 0:
200
- ids = set(ann_ids).intersection(
201
- set([self.Refs[ref_id]["ann_id"] for ref_id in ref_ids])
202
- )
203
- return ann_ids
204
-
205
- def getImgIds(self, ref_ids=[]):
206
- ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
207
-
208
- if not len(ref_ids) == 0:
209
- image_ids = list(set([self.Refs[ref_id]["image_id"] for ref_id in ref_ids]))
210
- else:
211
- image_ids = self.Imgs.keys()
212
- return image_ids
213
-
214
- def getCatIds(self):
215
- return self.Cats.keys()
216
-
217
- def loadRefs(self, ref_ids=[]):
218
- if type(ref_ids) == list:
219
- return [self.Refs[ref_id] for ref_id in ref_ids]
220
- elif type(ref_ids) == int:
221
- return [self.Refs[ref_ids]]
222
-
223
- def loadAnns(self, ann_ids=[]):
224
- if type(ann_ids) == list:
225
- return [self.Anns[ann_id] for ann_id in ann_ids]
226
- elif type(ann_ids) == int or type(ann_ids) == str:
227
- return [self.Anns[ann_ids]]
228
-
229
- def loadImgs(self, image_ids=[]):
230
- if type(image_ids) == list:
231
- return [self.Imgs[image_id] for image_id in image_ids]
232
- elif type(image_ids) == int:
233
- return [self.Imgs[image_ids]]
234
-
235
- def loadCats(self, cat_ids=[]):
236
- if type(cat_ids) == list:
237
- return [self.Cats[cat_id] for cat_id in cat_ids]
238
- elif type(cat_ids) == int:
239
- return [self.Cats[cat_ids]]
240
-
241
- def getRefBox(self, ref_id):
242
- ref = self.Refs[ref_id]
243
- ann = self.refToAnn[ref_id]
244
- return ann["bbox"] # [x, y, w, h]
245
-
246
- def showRef(self, ref, seg_box="seg"):
247
- ax = plt.gca()
248
- # show image
249
- image = self.Imgs[ref["image_id"]]
250
- I = io.imread(osp.join(self.IMAGE_DIR, image["file_name"]))
251
- ax.imshow(I)
252
- # show refer expression
253
- for sid, sent in enumerate(ref["sentences"]):
254
- print("%s. %s" % (sid + 1, sent["sent"]))
255
- # show segmentations
256
- if seg_box == "seg":
257
- ann_id = ref["ann_id"]
258
- ann = self.Anns[ann_id]
259
- polygons = []
260
- color = []
261
- c = "none"
262
- if type(ann["segmentation"][0]) == list:
263
- # polygon used for refcoco*
264
- for seg in ann["segmentation"]:
265
- poly = np.array(seg).reshape((len(seg) // 2, 2)) # integer division: reshape requires int dimensions
266
- polygons.append(Polygon(poly, True, alpha=0.4))
267
- color.append(c)
268
- p = PatchCollection(
269
- polygons,
270
- facecolors=color,
271
- edgecolors=(1, 1, 0, 0),
272
- linewidths=3,
273
- alpha=1,
274
- )
275
- ax.add_collection(p) # thick yellow polygon
276
- p = PatchCollection(
277
- polygons,
278
- facecolors=color,
279
- edgecolors=(1, 0, 0, 0),
280
- linewidths=1,
281
- alpha=1,
282
- )
283
- ax.add_collection(p) # thin red polygon
284
- else:
285
- # mask used for refclef
286
- rle = ann["segmentation"]
287
- m = mask.decode(rle)
288
- img = np.ones((m.shape[0], m.shape[1], 3))
289
- color_mask = np.array([2.0, 166.0, 101.0]) / 255
290
- for i in range(3):
291
- img[:, :, i] = color_mask[i]
292
- ax.imshow(np.dstack((img, m * 0.5)))
293
- # show bounding-box
294
- elif seg_box == "box":
295
- ann_id = ref["ann_id"]
296
- ann = self.Anns[ann_id]
297
- bbox = self.getRefBox(ref["ref_id"])
298
- box_plot = Rectangle(
299
- (bbox[0], bbox[1]),
300
- bbox[2],
301
- bbox[3],
302
- fill=False,
303
- edgecolor="green",
304
- linewidth=3,
305
- )
306
- ax.add_patch(box_plot)
307
-
308
- def getMask(self, ref):
309
- # return mask, area and mask-center
310
- ann = self.refToAnn[ref["ref_id"]]
311
- image = self.Imgs[ref["image_id"]]
312
- if type(ann["segmentation"][0]) == list: # polygon
313
- rle = mask.frPyObjects(ann["segmentation"], image["height"], image["width"])
314
- else:
315
- rle = ann["segmentation"]
316
- m = mask.decode(rle)
317
- m = np.sum(
318
- m, axis=2
319
- ) # sometimes there are multiple binary maps (corresponding to multiple segs)
320
- m = m.astype(np.uint8) # convert to np.uint8
321
- # compute area
322
- area = sum(mask.area(rle)) # should be close to ann['area']
323
- return {"mask": m, "area": area}
324
- # # position
325
- # position_x = np.mean(np.where(m==1)[1]) # [1] means columns (matlab style) -> x (c style)
326
- # position_y = np.mean(np.where(m==1)[0]) # [0] means rows (matlab style) -> y (c style)
327
- # # mass position (if there were multiple regions, we use the largest one.)
328
- # label_m = label(m, connectivity=m.ndim)
329
- # regions = regionprops(label_m)
330
- # if len(regions) > 0:
331
- # largest_id = np.argmax(np.array([props.filled_area for props in regions]))
332
- # largest_props = regions[largest_id]
333
- # mass_y, mass_x = largest_props.centroid
334
- # else:
335
- # mass_x, mass_y = position_x, position_y
336
- # # if centroid is not in mask, we find the closest point to it from mask
337
- # if m[mass_y, mass_x] != 1:
338
- # print('Finding closest mask point ...')
339
- # kernel = np.ones((10, 10),np.uint8)
340
- # me = cv2.erode(m, kernel, iterations = 1)
341
- # points = zip(np.where(me == 1)[0].tolist(), np.where(me == 1)[1].tolist()) # row, col style
342
- # points = np.array(points)
343
- # dist = np.sum((points - (mass_y, mass_x))**2, axis=1)
344
- # id = np.argsort(dist)[0]
345
- # mass_y, mass_x = points[id]
346
- # # return
347
- # return {'mask': m, 'area': area, 'position_x': position_x, 'position_y': position_y, 'mass_x': mass_x, 'mass_y': mass_y}
348
- # # show image and mask
349
- # I = io.imread(osp.join(self.IMAGE_DIR, image['file_name']))
350
- # plt.figure()
351
- # plt.imshow(I)
352
- # ax = plt.gca()
353
- # img = np.ones( (m.shape[0], m.shape[1], 3) )
354
- # color_mask = np.array([2.0,166.0,101.0])/255
355
- # for i in range(3):
356
- # img[:,:,i] = color_mask[i]
357
- # ax.imshow(np.dstack( (img, m*0.5) ))
358
- # plt.show()
359
-
360
- def showMask(self, ref):
361
- M = self.getMask(ref)
362
- msk = M["mask"]
363
- ax = plt.gca()
364
- ax.imshow(msk)
365
-
366
-
367
- if __name__ == "__main__":
368
- refer = REFER("<path/to/refer_seg_data>", dataset="refcocog", splitBy="google") # data_root is required; path is a placeholder
369
- ref_ids = refer.getRefIds()
370
- print(len(ref_ids))
371
-
372
- print(len(refer.Imgs))
373
- print(len(refer.imgToRefs))
374
-
375
- ref_ids = refer.getRefIds(split="train")
376
- print("There are %s training referred objects." % len(ref_ids))
377
-
378
- for ref_id in ref_ids:
379
- ref = refer.loadRefs(ref_id)[0]
380
- if len(ref["sentences"]) < 2:
381
- continue
382
-
383
- pprint(ref)
384
- print("The label is %s." % refer.Cats[ref["category_id"]])
385
- plt.figure()
386
- refer.showRef(ref, seg_box="box")
387
- plt.show()
388
-
389
- # plt.figure()
390
- # refer.showMask(ref)
391
- # plt.show()
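
For context, the removed refer.py exposed the loader API listed in its module docstring (getRefIds, loadRefs, getRefBox, getMask, ...). A minimal usage sketch follows; the data_root path and the dataset/split choice are illustrative, not values taken from this repository.

from lisa_on_cuda.utils.refer import REFER  # module path prior to this commit

# Point data_root at the folder holding refclef/refcoco/refcoco+/refcocog (illustrative path).
refer = REFER(data_root="./dataset/refer_seg", dataset="refcoco", splitBy="unc")

ref_ids = refer.getRefIds(split="train")      # referring-expression ids for the train split
ref = refer.loadRefs(ref_ids[0])[0]           # one ref record: sentences, ann_id, image_id, ...
sentences = [s["sent"] for s in ref["sentences"]]
bbox = refer.getRefBox(ref["ref_id"])         # [x, y, w, h] of the referred object
mask_info = refer.getMask(ref)                # {"mask": HxW uint8 array, "area": ...}
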
lisa_on_cuda/utils/refer_seg_dataset.py DELETED
@@ -1,277 +0,0 @@
1
- import os
2
- import random
3
-
4
- import cv2
5
- import numpy as np
6
- import torch
7
- import torch.nn.functional as F
8
- from pycocotools import mask
9
- from transformers import CLIPImageProcessor
10
-
11
- from lisa_on_cuda.llava import conversation as conversation_lib
12
- from lisa_on_cuda.segment_anything.utils.transforms import ResizeLongestSide
13
-
14
- from .grefer import G_REFER
15
- from .refer import REFER
16
- from .utils import ANSWER_LIST, SHORT_QUESTION_LIST
17
-
18
-
19
- class ReferSegDataset(torch.utils.data.Dataset):
20
- pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
21
- pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
22
- img_size = 1024
23
- ignore_label = 255
24
-
25
- def __init__(
26
- self,
27
- base_image_dir,
28
- tokenizer,
29
- vision_tower,
30
- samples_per_epoch=500 * 8 * 2 * 10,
31
- precision: str = "fp32",
32
- image_size: int = 224,
33
- num_classes_per_sample: int = 3,
34
- exclude_val=False,
35
- refer_seg_data="refclef||refcoco||refcoco+||refcocog",
36
- ):
37
- self.exclude_val = exclude_val
38
- self.samples_per_epoch = samples_per_epoch
39
- self.num_classes_per_sample = num_classes_per_sample
40
-
41
- self.base_image_dir = base_image_dir
42
- self.image_size = image_size
43
- self.tokenizer = tokenizer
44
- self.precision = precision
45
- self.transform = ResizeLongestSide(image_size)
46
- self.clip_image_processor = CLIPImageProcessor.from_pretrained(vision_tower)
47
-
48
- self.short_question_list = SHORT_QUESTION_LIST
49
- self.answer_list = ANSWER_LIST
50
-
51
- DATA_DIR = os.path.join(base_image_dir, "refer_seg")
52
- self.refer_seg_ds_list = refer_seg_data.split(
53
- "||"
54
- ) # ['refclef', 'refcoco', 'refcoco+', 'refcocog']
55
- self.refer_seg_data = {}
56
- for ds in self.refer_seg_ds_list:
57
- if ds == "refcocog":
58
- splitBy = "umd"
59
- else:
60
- splitBy = "unc"
61
-
62
- if ds == "grefcoco":
63
- refer_api = G_REFER(DATA_DIR, ds, splitBy)
64
- else:
65
- refer_api = REFER(DATA_DIR, ds, splitBy)
66
- ref_ids_train = refer_api.getRefIds(split="train")
67
- images_ids_train = refer_api.getImgIds(ref_ids=ref_ids_train)
68
- refs_train = refer_api.loadRefs(ref_ids=ref_ids_train)
69
-
70
- refer_seg_ds = {}
71
- refer_seg_ds["images"] = []
72
- loaded_images = refer_api.loadImgs(image_ids=images_ids_train)
73
-
74
- for item in loaded_images:
75
- item = item.copy()
76
- if ds == "refclef":
77
- item["file_name"] = os.path.join(
78
- DATA_DIR, "images/saiapr_tc-12", item["file_name"]
79
- )
80
- else:
81
- item["file_name"] = os.path.join(
82
- DATA_DIR, "images/mscoco/images/train2014", item["file_name"]
83
- )
84
- refer_seg_ds["images"].append(item)
85
- refer_seg_ds["annotations"] = refer_api.Anns # anns_train
86
-
87
- print(
88
- "dataset {} (refs {}) (train split) has {} images and {} annotations.".format(
89
- ds,
90
- splitBy,
91
- len(refer_seg_ds["images"]),
92
- len(refer_seg_ds["annotations"]),
93
- )
94
- )
95
-
96
- img2refs = {}
97
- for ref in refs_train:
98
- image_id = ref["image_id"]
99
- img2refs[image_id] = img2refs.get(image_id, []) + [
100
- ref,
101
- ]
102
- refer_seg_ds["img2refs"] = img2refs
103
- self.refer_seg_data[ds] = refer_seg_ds
104
-
105
- def __len__(self):
106
- return self.samples_per_epoch
107
-
108
- def preprocess(self, x: torch.Tensor) -> torch.Tensor:
109
- """Normalize pixel values and pad to a square input."""
110
- # Normalize colors
111
- x = (x - self.pixel_mean) / self.pixel_std
112
-
113
- # Pad
114
- h, w = x.shape[-2:]
115
- padh = self.img_size - h
116
- padw = self.img_size - w
117
- x = F.pad(x, (0, padw, 0, padh))
118
- return x
119
-
120
- def __getitem__(self, idx):
121
- ds = random.randint(0, len(self.refer_seg_ds_list) - 1)
122
- ds = self.refer_seg_ds_list[ds]
123
- refer_seg_ds = self.refer_seg_data[ds]
124
- images = refer_seg_ds["images"]
125
- annotations = refer_seg_ds["annotations"]
126
- img2refs = refer_seg_ds["img2refs"]
127
- idx = random.randint(0, len(images) - 1)
128
- image_info = images[idx]
129
- image_path = image_info["file_name"]
130
- image_id = image_info["id"]
131
- refs = img2refs[image_id]
132
- if len(refs) == 0:
133
- return self.__getitem__(0)
134
-
135
- sents = []
136
- ann_ids = []
137
- for ref in refs:
138
- for sent in ref["sentences"]:
139
- text = sent["sent"]
140
- sents.append(text)
141
- ann_ids.append(ref["ann_id"])
142
- if len(sents) >= self.num_classes_per_sample:
143
- sampled_inds = np.random.choice(
144
- list(range(len(sents))), size=self.num_classes_per_sample, replace=False
145
- )
146
- else:
147
- sampled_inds = list(range(len(sents)))
148
- sampled_sents = np.vectorize(sents.__getitem__)(sampled_inds).tolist()
149
- # sampled_ann_ids = np.vectorize(ann_ids.__getitem__)(sampled_inds).tolist()
150
- sampled_ann_ids = [ann_ids[ind] for ind in sampled_inds]
151
- sampled_classes = sampled_sents
152
- image = cv2.imread(image_path)
153
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
154
-
155
- # preprocess image for clip
156
- image_clip = self.clip_image_processor.preprocess(image, return_tensors="pt")[
157
- "pixel_values"
158
- ][0]
159
-
160
- image = self.transform.apply_image(image) # preprocess image for sam
161
- resize = image.shape[:2]
162
-
163
- questions = []
164
- answers = []
165
- for text in sampled_classes:
166
- text = text.strip()
167
- assert len(text.split("||")) == 1
168
- question_template = random.choice(self.short_question_list)
169
- questions.append(question_template.format(class_name=text.lower()))
170
- answers.append(random.choice(self.answer_list))
171
-
172
- conversations = []
173
- conv = conversation_lib.default_conversation.copy()
174
-
175
- i = 0
176
- while i < len(questions):
177
- conv.messages = []
178
- conv.append_message(conv.roles[0], questions[i])
179
- conv.append_message(conv.roles[1], answers[i])
180
- conversations.append(conv.get_prompt())
181
- i += 1
182
-
183
- image = self.preprocess(torch.from_numpy(image).permute(2, 0, 1).contiguous())
184
-
185
- flag = False
186
- masks = []
187
- for ann_id in sampled_ann_ids:
188
- if isinstance(ann_id, list):
189
- flag = True
190
- if -1 in ann_id:
191
- assert len(ann_id) == 1
192
- m = np.zeros((image_info["height"], image_info["width"])).astype(
193
- np.uint8
194
- )
195
- else:
196
- m_final = np.zeros(
197
- (image_info["height"], image_info["width"])
198
- ).astype(np.uint8)
199
- for ann_id_i in ann_id:
200
- ann = annotations[ann_id_i]
201
-
202
- if len(ann["segmentation"]) == 0:
203
- m = np.zeros(
204
- (image_info["height"], image_info["width"])
205
- ).astype(np.uint8)
206
- else:
207
- if type(ann["segmentation"][0]) == list: # polygon
208
- rle = mask.frPyObjects(
209
- ann["segmentation"],
210
- image_info["height"],
211
- image_info["width"],
212
- )
213
- else:
214
- rle = ann["segmentation"]
215
- for i in range(len(rle)):
216
- if not isinstance(rle[i]["counts"], bytes):
217
- rle[i]["counts"] = rle[i]["counts"].encode()
218
- m = mask.decode(rle)
219
- m = np.sum(
220
- m, axis=2
221
- ) # sometimes there are multiple binary maps (corresponding to multiple segs)
222
- m = m.astype(np.uint8) # convert to np.uint8
223
- m_final = m_final | m
224
- m = m_final
225
- masks.append(m)
226
- continue
227
-
228
- ann = annotations[ann_id]
229
-
230
- if len(ann["segmentation"]) == 0:
231
- m = np.zeros((image_info["height"], image_info["width"])).astype(
232
- np.uint8
233
- )
234
- masks.append(m)
235
- continue
236
-
237
- if type(ann["segmentation"][0]) == list: # polygon
238
- rle = mask.frPyObjects(
239
- ann["segmentation"], image_info["height"], image_info["width"]
240
- )
241
- else:
242
- rle = ann["segmentation"]
243
- for i in range(len(rle)):
244
- if not isinstance(rle[i]["counts"], bytes):
245
- rle[i]["counts"] = rle[i]["counts"].encode()
246
- m = mask.decode(rle)
247
- m = np.sum(
248
- m, axis=2
249
- ) # sometimes there are multiple binary maps (corresponding to multiple segs)
250
- m = m.astype(np.uint8) # convert to np.uint8
251
- masks.append(m)
252
-
253
- masks = np.stack(masks, axis=0)
254
-
255
- # if ds == 'grefcoco' and flag:
256
- # import shutil
257
- # image_name = image_path.split("/")[-1]
258
- # save_dir = os.path.join("/group/30042/xlai/LISA_refactor_final/debug", image_name.split(".")[0])
259
- # os.makedirs(save_dir, exist_ok=True)
260
- # shutil.copy(image_path, save_dir)
261
- # for i in range(masks.shape[0]):
262
- # cv2.imwrite(os.path.join(save_dir, "{}_{}_{}.jpg".format(image_name, i, sampled_classes[i])), masks[i].astype(np.int32) * 100)
263
-
264
- masks = torch.from_numpy(masks)
265
- label = torch.ones(masks.shape[1], masks.shape[2]) * self.ignore_label
266
-
267
- return (
268
- image_path,
269
- image,
270
- image_clip,
271
- conversations,
272
- masks,
273
- label,
274
- resize,
275
- questions,
276
- sampled_classes,
277
- )
lisa_on_cuda/utils/sem_seg_dataset.py DELETED
@@ -1,335 +0,0 @@
1
- import glob
2
- import json
3
- import os
4
- import random
5
-
6
- import cv2
7
- import numpy as np
8
- import torch
9
- import torch.nn.functional as F
10
- from PIL import Image
11
- from pycocotools.coco import COCO
12
- from transformers import CLIPImageProcessor
13
-
14
- from lisa_on_cuda.llava import conversation as conversation_lib
15
- from lisa_on_cuda.segment_anything.utils.transforms import ResizeLongestSide
16
-
17
- from .utils import ANSWER_LIST, SHORT_QUESTION_LIST
18
-
19
-
20
- def init_mapillary(base_image_dir):
21
- mapillary_data_root = os.path.join(base_image_dir, "mapillary")
22
- with open(os.path.join(mapillary_data_root, "config_v2.0.json")) as f:
23
- mapillary_classes = json.load(f)["labels"]
24
- mapillary_classes = [x["readable"].lower() for x in mapillary_classes]
25
- mapillary_classes = np.array(mapillary_classes)
26
- mapillary_labels = sorted(
27
- glob.glob(
28
- os.path.join(mapillary_data_root, "training", "v2.0", "labels", "*.png")
29
- )
30
- )
31
- mapillary_images = [
32
- x.replace(".png", ".jpg").replace("v2.0/labels", "images")
33
- for x in mapillary_labels
34
- ]
35
- print("mapillary: ", len(mapillary_images))
36
- return mapillary_classes, mapillary_images, mapillary_labels
37
-
38
-
39
- def init_ade20k(base_image_dir):
40
- with open("utils/ade20k_classes.json", "r") as f:
41
- ade20k_classes = json.load(f)
42
- ade20k_classes = np.array(ade20k_classes)
43
- image_ids = sorted(
44
- os.listdir(os.path.join(base_image_dir, "ade20k/images", "training"))
45
- )
46
- ade20k_image_ids = []
47
- for x in image_ids:
48
- if x.endswith(".jpg"):
49
- ade20k_image_ids.append(x[:-4])
50
- ade20k_images = []
51
- for image_id in ade20k_image_ids: # self.descriptions:
52
- ade20k_images.append(
53
- os.path.join(
54
- base_image_dir,
55
- "ade20k",
56
- "images",
57
- "training",
58
- "{}.jpg".format(image_id),
59
- )
60
- )
61
- ade20k_labels = [
62
- x.replace(".jpg", ".png").replace("images", "annotations")
63
- for x in ade20k_images
64
- ]
65
- print("ade20k: ", len(ade20k_images))
66
- return ade20k_classes, ade20k_images, ade20k_labels
67
-
68
-
69
- def init_cocostuff(base_image_dir):
70
- cocostuff_classes = []
71
- with open("utils/cocostuff_classes.txt") as f:
72
- for line in f.readlines()[1:]:
73
- cocostuff_classes.append(line.strip().split(": ")[-1])
74
- cocostuff_classes = np.array(cocostuff_classes)
75
- cocostuff_images = []
76
-
77
- cocostuff_labels = glob.glob(
78
- os.path.join(base_image_dir, "cocostuff", "train2017", "*.png")
79
- )
80
- cocostuff_images = [
81
- x.replace(".png", ".jpg").replace("cocostuff", "coco") for x in cocostuff_labels
82
- ]
83
-
84
- print("cocostuff: ", len(cocostuff_images))
85
- return cocostuff_classes, cocostuff_images, cocostuff_labels
86
-
87
-
88
- def init_paco_lvis(base_image_dir):
89
- coco_api_paco_lvis = COCO(
90
- os.path.join(
91
- base_image_dir, "vlpart", "paco", "annotations", "paco_lvis_v1_train.json"
92
- )
93
- )
94
- all_classes = coco_api_paco_lvis.loadCats(coco_api_paco_lvis.getCatIds())
95
- class_map_paco_lvis = {}
96
- for cat in all_classes:
97
- cat_split = cat["name"].strip().split(":")
98
- if len(cat_split) == 1:
99
- name = cat_split[0].split("_(")[0]
100
- else:
101
- assert len(cat_split) == 2
102
- obj, part = cat_split
103
- obj = obj.split("_(")[0]
104
- part = part.split("_(")[0]
105
- name = (obj, part)
106
- class_map_paco_lvis[cat["id"]] = name
107
- img_ids = coco_api_paco_lvis.getImgIds()
108
- print("paco_lvis: ", len(img_ids))
109
- return class_map_paco_lvis, img_ids, coco_api_paco_lvis
110
-
111
-
112
- def init_pascal_part(base_image_dir):
113
- coco_api_pascal_part = COCO(
114
- os.path.join(base_image_dir, "vlpart", "pascal_part", "train.json")
115
- )
116
- all_classes = coco_api_pascal_part.loadCats(coco_api_pascal_part.getCatIds())
117
- class_map_pascal_part = {}
118
- for cat in all_classes:
119
- cat_main, cat_part = cat["name"].strip().split(":")
120
- name = (cat_main, cat_part)
121
- class_map_pascal_part[cat["id"]] = name
122
- img_ids = coco_api_pascal_part.getImgIds()
123
- print("pascal_part: ", len(img_ids))
124
- return class_map_pascal_part, img_ids, coco_api_pascal_part
125
-
126
-
127
- class SemSegDataset(torch.utils.data.Dataset):
128
- pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
129
- pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
130
- img_size = 1024
131
- ignore_label = 255
132
-
133
- def __init__(
134
- self,
135
- base_image_dir,
136
- tokenizer,
137
- vision_tower,
138
- samples_per_epoch=500 * 8 * 2 * 10,
139
- precision: str = "fp32",
140
- image_size: int = 224,
141
- num_classes_per_sample: int = 3,
142
- exclude_val=False,
143
- sem_seg_data="ade20k||cocostuff||partimagenet||pascal_part||paco_lvis||mapillary",
144
- ):
145
- self.exclude_val = exclude_val
146
- self.samples_per_epoch = samples_per_epoch
147
- self.num_classes_per_sample = num_classes_per_sample
148
-
149
- self.base_image_dir = base_image_dir
150
- self.image_size = image_size
151
- self.tokenizer = tokenizer
152
- self.precision = precision
153
- self.transform = ResizeLongestSide(image_size)
154
- self.clip_image_processor = CLIPImageProcessor.from_pretrained(vision_tower)
155
-
156
- self.short_question_list = SHORT_QUESTION_LIST
157
- self.answer_list = ANSWER_LIST
158
-
159
- self.data2list = {}
160
- self.data2classes = {}
161
-
162
- self.sem_seg_datas = sem_seg_data.split("||")
163
- for ds in self.sem_seg_datas:
164
- classes, images, labels = eval("init_{}".format(ds))(base_image_dir)
165
- self.data2list[ds] = (images, labels)
166
- self.data2classes[ds] = classes
167
-
168
- if "cocostuff" in self.sem_seg_datas:
169
- self.cocostuff_class2index = {
170
- c: i for i, c in enumerate(self.data2classes["cocostuff"])
171
- }
172
-
173
- def __len__(self):
174
- return self.samples_per_epoch
175
-
176
- def preprocess(self, x: torch.Tensor) -> torch.Tensor:
177
- """Normalize pixel values and pad to a square input."""
178
- # Normalize colors
179
- x = (x - self.pixel_mean) / self.pixel_std
180
-
181
- # Pad
182
- h, w = x.shape[-2:]
183
- padh = self.img_size - h
184
- padw = self.img_size - w
185
- x = F.pad(x, (0, padw, 0, padh))
186
- return x
187
-
188
- def __getitem__(self, idx):
189
- ds = random.randint(0, len(self.sem_seg_datas) - 1)
190
- ds = self.sem_seg_datas[ds]
191
-
192
- if ds in ["paco_lvis", "pascal_part"]:
193
- class_map = self.data2classes[ds]
194
- img_ids, coco_api = self.data2list[ds]
195
- idx = random.randint(0, len(img_ids) - 1)
196
- img_id = img_ids[idx]
197
- image_info = coco_api.loadImgs([img_id])[0]
198
- file_name = image_info["file_name"]
199
- if ds == "pascal_part":
200
- file_name = os.path.join(
201
- "VOCdevkit", "VOC2010", "JPEGImages", file_name
202
- )
203
- image_path = os.path.join(self.base_image_dir, "vlpart", ds, file_name)
204
- elif ds == "paco_lvis":
205
- image_path = os.path.join(self.base_image_dir, "coco", file_name)
206
- image = cv2.imread(image_path)
207
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
208
-
209
- # preprocess image for clip
210
- image_clip = self.clip_image_processor.preprocess(
211
- image, return_tensors="pt"
212
- )["pixel_values"][0]
213
- image = self.transform.apply_image(image) # preprocess image for sam
214
- resize = image.shape[:2]
215
- annIds = coco_api.getAnnIds(imgIds=image_info["id"])
216
- anns = coco_api.loadAnns(annIds)
217
- if len(anns) == 0:
218
- return self.__getitem__(0)
219
- if len(anns) >= self.num_classes_per_sample:
220
- sampled_anns = np.random.choice(
221
- anns, size=self.num_classes_per_sample, replace=False
222
- ).tolist()
223
- else:
224
- sampled_anns = anns
225
- sampled_classes = []
226
- for ann in sampled_anns:
227
- sampled_cls = class_map[ann["category_id"]]
228
- if isinstance(sampled_cls, tuple):
229
- obj, part = sampled_cls
230
- if random.random() < 0.5:
231
- name = obj + " " + part
232
- else:
233
- name = "the {} of the {}".format(part, obj)
234
- else:
235
- name = sampled_cls
236
- sampled_classes.append(name)
237
-
238
- elif ds in ["ade20k", "cocostuff", "mapillary"]:
239
- image, labels = self.data2list[ds]
240
- idx = random.randint(0, len(image) - 1)
241
- image_path = image[idx]
242
- label_path = labels[idx]
243
- label = Image.open(label_path)
244
- label = np.array(label)
245
- if ds == "ade20k":
246
- label[label == 0] = 255
247
- label -= 1
248
- label[label == 254] = 255
249
- elif ds == "cocostuff":
250
- for c, i in self.cocostuff_class2index.items():
251
- if "-" in c:
252
- label[label == i] = 255
253
- img = cv2.imread(image_path)
254
- image = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
255
- # preprocess image for clip
256
- image_clip = self.clip_image_processor.preprocess(
257
- image, return_tensors="pt"
258
- )["pixel_values"][0]
259
- image = self.transform.apply_image(image) # preprocess image for sam
260
- resize = image.shape[:2]
261
- unique_label = np.unique(label).tolist()
262
- if 255 in unique_label:
263
- unique_label.remove(255)
264
- if len(unique_label) == 0:
265
- return self.__getitem__(0)
266
-
267
- classes = [self.data2classes[ds][class_id] for class_id in unique_label]
268
- if len(classes) >= self.num_classes_per_sample:
269
- sampled_classes = np.random.choice(
270
- classes, size=self.num_classes_per_sample, replace=False
271
- ).tolist()
272
- else:
273
- sampled_classes = classes
274
-
275
- questions = []
276
- answers = []
277
- class_ids = []
278
- for sampled_cls in sampled_classes:
279
- text = sampled_cls
280
-
281
- assert len(text.split("||")) == 1
282
- question_template = random.choice(self.short_question_list)
283
- questions.append(question_template.format(class_name=text.lower()))
284
-
285
- answers.append(random.choice(self.answer_list))
286
-
287
- if ds in ["paco_lvis", "pascal_part"]:
288
- continue
289
-
290
- class_id = self.data2classes[ds].tolist().index(sampled_cls)
291
- class_ids.append(class_id)
292
-
293
- conversations = []
294
- conv = conversation_lib.default_conversation.copy()
295
-
296
- i = 0
297
- while i < len(questions):
298
- conv.messages = []
299
- conv.append_message(conv.roles[0], questions[i])
300
- conv.append_message(conv.roles[1], answers[i])
301
- conversations.append(conv.get_prompt())
302
- i += 1
303
-
304
- image = self.preprocess(torch.from_numpy(image).permute(2, 0, 1).contiguous())
305
-
306
- if ds in ["paco_lvis", "pascal_part"]:
307
- masks = []
308
- for ann in sampled_anns:
309
- try:
310
- masks.append(coco_api.annToMask(ann))
311
- except Exception as e:
312
- print(e)
313
- return self.__getitem__(0)
314
-
315
- masks = np.stack(masks, axis=0)
316
- masks = torch.from_numpy(masks)
317
- label = torch.ones(masks.shape[1], masks.shape[2]) * self.ignore_label
318
-
319
- else:
320
- label = torch.from_numpy(label).long()
321
- masks = []
322
- for class_id in class_ids:
323
- masks.append(label == class_id)
324
- masks = torch.stack(masks, dim=0)
325
- return (
326
- image_path,
327
- image,
328
- image_clip,
329
- conversations,
330
- masks,
331
- label,
332
- resize,
333
- questions,
334
- sampled_classes,
335
- )
lisa_on_cuda/utils/session_logger.py DELETED
@@ -1,36 +0,0 @@
1
- import contextvars
2
- import logging
3
- from functools import wraps
4
- from typing import Callable
5
-
6
- logging_uuid = contextvars.ContextVar("uuid")
7
- formatter = '%(asctime)s | %(uuid)s [%(pathname)s:%(module)s %(lineno)d] %(levelname)s | %(message)s'
8
-
9
-
10
- loggingType = logging.CRITICAL | logging.ERROR | logging.WARNING | logging.INFO | logging.DEBUG
11
-
12
-
13
- def change_logging(level_log: loggingType = logging.INFO) -> None:
14
- old_factory = logging.getLogRecordFactory()
15
-
16
- def record_factory(*args, **kwargs):
17
- record = old_factory(*args, **kwargs)
18
- record.uuid = logging_uuid.get("uuid")
19
- if isinstance(record.msg, str):
20
- record.msg = record.msg.replace("\\", "\\\\").replace("\n", "\\n")
21
- return record
22
-
23
- logging.setLogRecordFactory(record_factory)
24
- logging.basicConfig(level=level_log, format=formatter, force=True)
25
-
26
-
27
- def set_uuid_logging(func: Callable) -> Callable:
28
- @wraps(func)
29
- def wrapper(*args, **kwargs):
30
- import uuid
31
-
32
- current_uuid = f"{uuid.uuid4()}"
33
- logging_uuid.set(current_uuid)
34
- return func(*args, **kwargs)
35
-
36
- return wrapper
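
The removed session_logger module pairs a log-record factory (change_logging) with a decorator (set_uuid_logging) that stamps every log line emitted during a call with a per-call uuid. A small sketch of the intended combination follows; the function name and message are illustrative.

import logging

from lisa_on_cuda.utils.session_logger import change_logging, set_uuid_logging  # path prior to this commit

change_logging(logging.DEBUG)  # install the uuid-aware record factory and log format


@set_uuid_logging
def handle_request():
    # every invocation sets a fresh uuid, rendered through the %(uuid)s field of the formatter
    logging.info("processing request")


handle_request()
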
lisa_on_cuda/utils/vqa_dataset.py DELETED
@@ -1,135 +0,0 @@
1
- import json
2
- import os
3
- import random
4
-
5
- import cv2
6
- import torch
7
- import torch.nn.functional as F
8
- from transformers import CLIPImageProcessor
9
-
10
- from lisa_on_cuda.llava import conversation as conversation_lib
11
- from lisa_on_cuda.segment_anything.utils.transforms import ResizeLongestSide
12
-
13
- from .utils import DEFAULT_IMAGE_TOKEN
14
-
15
-
16
- def preprocess_multimodal(source, mm_use_im_start_end):
17
- for sentence in source:
18
- if DEFAULT_IMAGE_TOKEN in sentence["value"]:
19
- sentence["value"] = (
20
- sentence["value"].replace(DEFAULT_IMAGE_TOKEN, "").strip()
21
- )
22
- sentence["value"] = DEFAULT_IMAGE_TOKEN + "\n" + sentence["value"]
23
- sentence["value"] = sentence["value"].strip()
24
- if "mmtag" in conversation_lib.default_conversation.version:
25
- sentence["value"] = sentence["value"].replace(
26
- DEFAULT_IMAGE_TOKEN, "<Image>" + DEFAULT_IMAGE_TOKEN + "</Image>"
27
- )
28
- return source
29
-
30
-
31
- class VQADataset(torch.utils.data.Dataset):
32
- pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
33
- pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
34
- img_size = 1024
35
- ignore_label = 255
36
-
37
- def __init__(
38
- self,
39
- base_image_dir,
40
- tokenizer,
41
- vision_tower,
42
- samples_per_epoch=500 * 8 * 2 * 10,
43
- precision: str = "fp32",
44
- image_size: int = 224,
45
- num_classes_per_sample: int = 3,
46
- exclude_val=False,
47
- vqa_data="llava_instruct_150k",
48
- ):
49
- self.exclude_val = exclude_val
50
- self.samples_per_epoch = samples_per_epoch
51
- self.num_classes_per_sample = num_classes_per_sample
52
-
53
- self.base_image_dir = base_image_dir
54
- self.image_size = image_size
55
- self.tokenizer = tokenizer
56
- self.precision = precision
57
- self.transform = ResizeLongestSide(image_size)
58
- self.clip_image_processor = CLIPImageProcessor.from_pretrained(vision_tower)
59
-
60
- DATA_DIR = os.path.join(base_image_dir, "llava_dataset")
61
- self.vqa_image_root = os.path.join(base_image_dir, "coco/train2017")
62
- with open(os.path.join(DATA_DIR, "{}.json".format(vqa_data))) as f:
63
- vqa_data = json.load(f)
64
- self.vqa_data = vqa_data
65
-
66
- print("vqa_data: ", len(self.vqa_data))
67
-
68
- def __len__(self):
69
- return self.samples_per_epoch
70
-
71
- def preprocess(self, x: torch.Tensor) -> torch.Tensor:
72
- """Normalize pixel values and pad to a square input."""
73
- # Normalize colors
74
- x = (x - self.pixel_mean) / self.pixel_std
75
-
76
- # Pad
77
- h, w = x.shape[-2:]
78
- padh = self.img_size - h
79
- padw = self.img_size - w
80
- x = F.pad(x, (0, padw, 0, padh))
81
- return x
82
-
83
- def __getitem__(self, idx):
84
- idx = random.randint(0, len(self.vqa_data) - 1)
85
- item = self.vqa_data[idx]
86
- image_path = os.path.join(self.vqa_image_root, item["image"])
87
- image = cv2.imread(image_path)
88
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
89
- ori_size = image.shape[:2]
90
- image_clip = self.clip_image_processor.preprocess(image, return_tensors="pt")[
91
- "pixel_values"
92
- ][
93
- 0
94
- ] # preprocess image for clip
95
-
96
- image = self.transform.apply_image(image) # preprocess image for sam
97
- resize = image.shape[:2]
98
-
99
- conv = conversation_lib.default_conversation.copy()
100
- source = item["conversations"]
101
- source = preprocess_multimodal(
102
- source,
103
- mm_use_im_start_end=conv.sep_style == conversation_lib.SeparatorStyle.TWO,
104
- )
105
- roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
106
- conversations = []
107
- if roles[source[0]["from"]] != conv.roles[0]:
108
- # Skip the first one if it is not from human
109
- source = source[1:]
110
- conv.messages = []
111
- for j, sentence in enumerate(source):
112
- role = roles[sentence["from"]]
113
- assert role == conv.roles[j % 2], f"{j}"
114
- conv.append_message(role, sentence["value"])
115
- conversations.append(conv.get_prompt())
116
-
117
- questions = conversations
118
- sampled_classes = conversations
119
-
120
- image = self.preprocess(torch.from_numpy(image).permute(2, 0, 1).contiguous())
121
-
122
- masks = torch.rand(0, *ori_size)
123
- label = torch.ones(ori_size) * self.ignore_label
124
-
125
- return (
126
- image_path,
127
- image,
128
- image_clip,
129
- conversations,
130
- masks,
131
- label,
132
- resize,
133
- questions,
134
- sampled_classes,
135
- )
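
All of the removed dataset classes (ReasonSegDataset, ReferSegDataset, SemSegDataset, VQADataset) return the same 9-tuple from __getitem__. The sketch below only illustrates that tuple layout with a simple collate function; it is not the project's actual collate_fn.

import torch


def simple_collate(batch):
    # batch: list of the 9-tuples returned by the removed dataset classes
    (image_paths, images, images_clip, conversations,
     masks, labels, resizes, questions, sampled_classes) = zip(*batch)
    return {
        "image_paths": list(image_paths),
        "images": torch.stack(images, dim=0),            # SAM input: 3 x 1024 x 1024 after preprocess()
        "images_clip": torch.stack(images_clip, dim=0),  # CLIP pixel_values from CLIPImageProcessor
        "conversations": list(conversations),
        "masks": list(masks),                            # variable number of masks per sample
        "labels": list(labels),
        "resizes": list(resizes),
        "questions": list(questions),
        "sampled_classes": list(sampled_classes),
    }
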