from collections import OrderedDict
import copy
import json
import re
from dataclasses import dataclass
from typing import Dict, Sequence

import torch
import transformers
from torch.utils.data import Dataset

from llava.constants import IGNORE_INDEX
from llava.util.tokenization import (
    preprocess_llama_2_obj_identifier,
    preprocess_multimodal,
)
from llava import conversation as conversation_lib


class ObjIdentifierDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(
        self,
        tokenizer: transformers.PreTrainedTokenizer,
        data_path: str | list,
        scene_to_obj_mapping: str,
        obj_context_feature_type: str = "text",
        mode: str = "train",
        **kwargs,
    ):
        super(ObjIdentifierDataset, self).__init__()

        self.tokenizer = tokenizer
        with open(scene_to_obj_mapping, "r") as f:
            self.scene_to_obj_mapping = json.load(f)
        self.update_data(data_path)
        self.obj_context_feature_type = obj_context_feature_type
        self.mode = mode

    def __len__(self):
        return len(self.list_data_dict)

    def update_data(self, data_path: str | list):
        assert self.scene_to_obj_mapping is not None, "scene_to_obj_mapping needs to be set first."
        if isinstance(data_path, str):
            with open(data_path, "r") as f:
                self.list_data_dict = json.load(f)
        elif isinstance(data_path, list):
            self.list_data_dict = data_path
        else:
            raise TypeError(f"data_path must be a path or a list, got {type(data_path)}")

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        sources = copy.deepcopy(self.list_data_dict[i])
        if isinstance(i, int):
            sources = [sources]
        assert len(sources) == 1, "expected exactly one sample after wrapping the indexed item in a list"  # FIXME

        ############
        scene_id = sources[0]["scene_id"]
        input_obj_dict = copy.deepcopy(self.scene_to_obj_mapping[scene_id])

        # prepare the object-centric features
        # we want the LLM to see:
        # "%%%% Object-centric context: <obj_0>: <obj_0_feat>, <obj_1>: <obj_1_feat>, ..."
        # where <obj_i_feat> will later be replaced by the actual feature in vector form,
        # everything else is pure text string.
        # 1. We need to first change the object_id to a new object_id, e.g., 'obj_0', 'obj_1', ...,
        #   and replace the old object_id with new object_id in the text conversation
        # 2. Tokenize the conversation, and add object context to the tokenized conversation
        # 3. Gather and return the necessary information for each object,
        #   so that it can later be embedded into vectors

        # 1. change the object_id to a new object_id
        # original object_id: 'wardrobe-0', 'three-seat/multi-seat sofa-1', ...
        # convert to obj_id: 'obj_0', 'obj_1', ...
        # and remember the mapping
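        # e.g. {'wardrobe-0': 'obj_0', 'three-seat/multi-seat sofa-1': 'obj_1', ...}
        # (illustrative mapping; actual keys follow the scene graph's insertion order)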
        old_id_to_new_id_mapping = {}
        result_obj_dict = OrderedDict()
        # first pass, map the old object_id to new object_id
        for old_id, obj_info in input_obj_dict.items():
            # make sure old_id doesn't contain < or >
            assert (
                "<" not in old_id and ">" not in old_id
            ), "object_id in scene graph should not contain < or >"
            new_id = f"obj_{len(old_id_to_new_id_mapping)}"
            old_id_to_new_id_mapping[old_id] = new_id
        # second pass, create the new object-centric context, modify the object_id in the text content
        for old_id, obj_info in input_obj_dict.items():
            new_id = old_id_to_new_id_mapping[old_id]
            # TODO: Determine what information to include in the object-centric context
            result_obj_info_dict = {}
            result_obj_info_dict["category"] = obj_info["category"]
            # result_obj_info_dict["category_id"] = obj_info["category_id"]
            # for relations, we need to replace the old object_id with new object_id
            # result_obj_info_dict["relations"] = []
            # for relation in obj_info["relations"]:
            #     for local_old_id, local_new_id in old_id_to_new_id_mapping.items():
            #         if local_old_id in relation:
            #             result_obj_info_dict["relations"].append(
            #                 re.sub(rf"<{local_old_id}>", f"<{local_new_id}>", relation)
            #             )
            if "description" in obj_info:
                result_obj_info_dict["description"] = obj_info["description"]
            else:
                # print(f"WARNING: Object {old_id} does not have a description.")
                pass

            # use two decimal places for the centroid and extent
            result_obj_info_dict["centroid"] = (
                f"[{obj_info['centroid'][0]:.2f}, {obj_info['centroid'][1]:.2f}, {obj_info['centroid'][2]:.2f}]"
            )
            result_obj_info_dict["extent"] = (
                f"[{obj_info['extent'][0]:.2f}, {obj_info['extent'][1]:.2f}, {obj_info['extent'][2]:.2f}]"
            )
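            # e.g. centroid "[0.38, -0.07, -0.02]" and extent "[1.14, 1.54, 0.81]" (illustrative values)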
            result_obj_dict[new_id] = result_obj_info_dict

        # replace the old object_id with new object_id in the text content
        # text conversation example:
        #     {
        #     "id": "55f2b905-d367-443d-8f88-ef71b958c81f@LivingRoom-3973@1",
        #     "scene_id": "55f2b905-d367-443d-8f88-ef71b958c81f@LivingRoom-3973",
        #     "conversations": [
        #         {
        #             "from": "human",
        #             "value": "Can you describe the ambiance of this room?"
        #         },
        #         {
        #             "from": "gpt",
        #             "value": "In this Living Room, the arrangement of furniture caters to both style and function. The <p>warm wooden hue wardrobe</p>[<wardrobe-0>] stands with a retro flair, while the <p>neutral grey, sleek rectangular form multi-seat sofa</p>[<three-seat/multi-seat sofa-1>] and <p>neutral grey, sleek rectangular form three-seat</p>[<three-seat/multi-seat sofa-8>] provide modern and comfortable seating options. The <p>sleek black, dark grey, brown, rectangular coffee table</p>[<coffee table-2>] and <p>sleek black, dark grey, brown coffee table</p>[<coffee table-3>] in minimalist style serve as focal points and functional pieces for gatherings. The <p>light grey and dark grey (two-tone), shell-like backrest, smooth armchair</p>[<armchair-4>] and <p>light grey and dark grey (two-tone), smooth armchair</p>[<armchair-5>] add additional seating, complemented by the <p>rich walnut brown top and contrasting light grey base side table</p>[<corner/side table-6>] for convenience. Suspended above, the <p>gradient of grey to bronze hues, cylindrical with abstract cityscape cutouts pendant lamp</p>[<pendant lamp-7>] offers a decorative element with its unique Chinoiserie design. The room's setup is ideal for hosting guests or enjoying quiet evenings, with thoughtful placement of each piece to enhance the living experience."
        #         }
        #     ]
        # },
        for conv in sources[0]["conversations"]:
            for old_id, new_id in old_id_to_new_id_mapping.items():
                # escape old_id so regex metacharacters in category names are matched literally
                conv["value"] = re.sub(rf"<{re.escape(old_id)}>", f"<{new_id}>", conv["value"])

        # if in generate mode, shave off the last conversation if it is from the assistant
        if self.mode == "generate" and sources[0]["conversations"][-1]["from"] == "gpt":
            sources[0]["conversations"] = sources[0]["conversations"][:-1]

        # 2. Tokenize the conversation, and add object context to the tokenized conversation
        sources = preprocess_multimodal(
            copy.deepcopy([e["conversations"] for e in sources]),
            is_multimodal=True,
            mm_use_im_start_end=False,
        )
        data_dict = preprocess_llama_2_obj_identifier(
            sources=sources,
            tokenizer=self.tokenizer,
            obj_dict=result_obj_dict,
            obj_context_feature_type=self.obj_context_feature_type,
            mode=self.mode,
        )
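        # data_dict is expected to hold at least "input_ids" and "labels";
        # these are the keys the training collator below pads and truncates.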

        # if in generate mode, add the obj context and bbox label to the data_dict
        # so that we can use them later to compute the metrics
        if self.mode == "generate":
            data_dict["obj_context"] = result_obj_dict
            if "bbox" in self.list_data_dict[i]:
                data_dict["bbox_label"] = self.list_data_dict[i]["bbox"]
            # full_info_dict is the full information of this data sample
            # {'id': 'scene0643_00$desk-0@0', 'scene_id': 'scene0643_00', 'conversations': [{...}],
            # 'referred_object_id': '0', 'referred_object_text': 'desk',
            # 'grounded_object_reference': 'a brown wooden office desk on the left to the gray shelf.',
            # 'bbox': [0.3769365990161897, -0.06906220592784873, -0.020513275327205656, 1.1370925301275254, 1.5355764355778696, 0.8130822017173767]
            # }
            data_dict["full_info_dict"] = self.list_data_dict[i]

        return data_dict


@dataclass
class DataCollatorForObjIdentifierDataset(object):
    """Collate examples for supervised fine-tuning."""

    def __init__(self, tokenizer: transformers.PreTrainedTokenizer, **kwargs):
        self.tokenizer = tokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple(
            [instance[key] for instance in instances] for key in ("input_ids", "labels")
        )
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        labels = torch.nn.utils.rnn.pad_sequence(
            labels, batch_first=True, padding_value=IGNORE_INDEX
        )
        input_ids = input_ids[:, : self.tokenizer.model_max_length]
        labels = labels[:, : self.tokenizer.model_max_length]
        batch = dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )

        return batch


@dataclass
class DataCollatorForBatchDecodingObjIdentifierDataset(object):
    """Collate examples for batch decoding."""

    def __init__(self, tokenizer: transformers.PreTrainedTokenizer, **kwargs):
        self.tokenizer = tokenizer

    def pad_sequence(self, input_ids, batch_first, padding_value):
        # torch's pad_sequence always pads on the right; to honor left padding
        # (the usual setting for batched generation), reverse each sequence,
        # pad, then reverse the padded batch back.
        if self.tokenizer.padding_side == "left":
            input_ids = [torch.flip(_input_ids, [0]) for _input_ids in input_ids]
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=batch_first, padding_value=padding_value
        )
        if self.tokenizer.padding_side == "left":
            input_ids = torch.flip(input_ids, [1])
        return input_ids

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids = [instance["input_ids"] for instance in instances]
        input_ids = self.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )

        batch = dict(input_ids=input_ids)
        if "bbox_label" in instances[0].keys():
            batch["bbox_label"] = [instance["bbox_label"] for instance in instances]

        if "obj_context" in instances[0].keys():
            batch["obj_context"] = [instance["obj_context"] for instance in instances]

        if "full_info_dict" in instances[0].keys():
            batch["full_info_dict"] = [instance["full_info_dict"] for instance in instances]

        return batch
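
# Usage note (a sketch, not enforced anywhere in this module): for batched
# decoding, build ObjIdentifierDataset with mode="generate" and set
# tokenizer.padding_side = "left" so prompts are left-padded before generation.
# "obj_context", "bbox_label", and "full_info_dict" are passed through untouched
# so downstream evaluation can compute metrics.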


# test the dataset
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(
        "/data/jianingy/3d-llama/checkpoints/llava-llama-2-7b-chat-lightning-preview",
        use_fast=False,
    )
    tokenizer.pad_token = tokenizer.unk_token
    conversation_lib.default_conversation = conversation_lib.conv_templates["llava_llama_2"]

    dataset = ObjIdentifierDataset(
        tokenizer,
        data_path="/home/jianingy/research/LLaVA-original/llava/dataset/3dfront/grounded_scene_description_gpt_format.json",
        scene_to_obj_mapping="/home/jianingy/research/LLaVA-original/llava/dataset/3dfront/compressed_organized_data.json",
    )
    print(len(dataset))
    print(dataset[0])
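
    # A minimal batching sketch (assumes the local checkpoint and data paths
    # above exist): wrap the dataset with the training collator and inspect
    # the padded tensor shapes.
    from torch.utils.data import DataLoader

    collator = DataCollatorForObjIdentifierDataset(tokenizer)
    loader = DataLoader(dataset, batch_size=2, shuffle=False, collate_fn=collator)
    batch = next(iter(loader))
    print(batch["input_ids"].shape, batch["labels"].shape, batch["attention_mask"].shape)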