File size: 16,016 Bytes
25bd6f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16508ee
25bd6f8
 
 
 
 
 
 
 
 
62786d6
 
 
25bd6f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f52b9ad
25bd6f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
import openai
import json
from pydantic import BaseModel, Field
from PIL import Image
from tqdm import tqdm
from transformers import AutoProcessor, AutoModelForCausalLM 
import torch
import requests
import spaces

# Schema for the JSON object produced by the tuple-extraction prompt:
# a single "tuples" list whose items are entity / attribute / relation records.
class PromptTuple(BaseModel):
    # One unit of information extracted from a caption.
    class Tuple(BaseModel):
        # Top-level category; the description names entity, attribute and
        # relation, but the field is a plain str and no validator in this
        # class restricts it to those values.
        type: str = Field(
            description="The type of the tuple. One of entity, attribute, relation",
            example="attribute",
        )
        # Sub-category within `type` (e.g. whole/part, color/material, spatial/action).
        type_detail: str = Field(
            description="""The detail of the type. For example: 
            - Entity: whole (entire entity, e.g., chair), part (part of entity, e.g., back of chair). 
            - Attribute: color (e.g., red book), type (e.g., aviator goggles), material (e.g., wooden chair), count (e.g., 5 geese), texture (e.g., rough surface), text rendering (e.g., letters “Macaroni”), shape (e.g., triangle block), size (e.g., large fence). 
            - Relation: spatial (e.g., A next to B); action (A kicks B).""",
            example="color",
        )
        # Words or phrases from the caption that back this tuple.
        # NOTE(review): annotated as a bare `list`, so the element type is not
        # constrained here even though the description says "strings".
        semantics: list = Field(
            description="List of strings that explain the existence of type and type_detail in the tuple",
            example=["motorcycle", "blue"],
        )

    # All tuples extracted from one caption. The "Maximum 8" limit is stated
    # only in the description text; nothing in this class enforces it.
    tuples: list[Tuple] = Field(
        description="List of tuples. Maximum 8 tuples.",
        example=[
            {
                "type": "attribute",
                "type_detail": "color",
                "semantics": ["motorcycle", "blue"],
            }
        ],
    )


class DSGPromptProcessor:
    def __init__(self, model_name="gpt-4o-mini"):
        """Create the chat client and load the yes/no VQA model and processor.

        Args:
            model_name: Name of the chat-completions model used by the
                ``generate_*`` methods.
        """
        vqa_checkpoint = "toilaluan/Florence-2-base-Yes-No-VQA"

        self.client = openai.OpenAI()
        self.model_name = model_name
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # trust_remote_code=True runs Python code shipped with the checkpoint.
        # NOTE(review): the model is cast to float16 even when the device
        # resolves to "cpu" - confirm that is intended.
        self.binary_vqa = AutoModelForCausalLM.from_pretrained(
            vqa_checkpoint, trust_remote_code=True
        ).to(self.device, torch.float16)
        self.binary_vqa_processor = AutoProcessor.from_pretrained(
            vqa_checkpoint, trust_remote_code=True
        )
        

    def generate_tuples(self, input_text: str) -> PromptTuple:
        system_message = """
        Given an image caption, extract the relevant entities, attributes, and relations present in the caption, and structure them into JSON format according to the following schema:
Each tuple contains the following information:
    - Id: A unique identifier for the tuple.
    - Type: The category of the tuple. Choose from "entity," "attribute," or "relation."
    - Type Detail: Provide additional details based on the selected type:
        - Entity: Specify whether it refers to the whole entity (e.g., "chair") or a part of the entity (e.g., "back of chair").
        - Attribute: Specify the attribute type, such as "color", "type", "material", "count", "style", "texture", "text rendering", "shape" or "size".
        - Relation: Specify the relation type, such as "spatial" (e.g., "A next to B") or "action" (e.g., "A kicks B").
    - Semantics: A list of strings that represent the words or phrases from the caption that correspond to the tuple.
    Example Input: "A blue motorcycle parked next to a red car."
    Example output:
    {
        "tuples": [
            {
                "type": "entity",
                "type_detail": "whole",
                "semantics": ["motorcycle"]
            },
            {
                "type": "attribute",
                "type_detail": "color",
                "semantics": ["motorcycle", "blue"]
            },
            {
                "type": "entity",
                "type_detail": "whole",
                "semantics": ["car"]
            },
            {
                "type": "attribute",
                "type_detail": "color",
                "semantics": ["car", "red"]
            },
            {
                "type": "relation",
                "type_detail": "spatial",
                "semantics": ["motorcycle", "next to", "car"]
            }
        ]
    }
    The final JSON should contain a list of tuples, each describing a unique entity, attribute, or relation from the image caption. Each JSON should contain a maximum of 8 tuples.
        """
        messages = [
            {
                "role": "system",
                "content": system_message,
            },
            {
                "role": "user",
                "content": input_text,
            },
        ]

        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=messages,
            response_format={"type": "json_object"},
            max_tokens=512,
        )
        output = json.loads(response.choices[0].message.content)
        return PromptTuple(**output), response.usage.total_tokens

    def generate_dependencies(self, tuples: PromptTuple) -> tuple[dict, int]:
        """Ask the chat model which tuples depend on which other tuples.

        Args:
            tuples: The extracted tuples; each one is sent to the model with
                an ``id`` equal to its position in ``tuples.tuples``.

        Returns:
            A ``(dependencies, total_tokens)`` pair: the decoded JSON reply
            (expected to map a tuple id to the list of ids it depends on) and
            the ``usage.total_tokens`` value reported for the request.

        Raises:
            json.JSONDecodeError: If the model reply is not valid JSON.
        """
        DEPENDENCY_PROMPT = """
        Given the following tuples extracted from an image caption, determine the dependencies between the entities, attributes, and relations in the JSON format.
        Each tuple contains the following information:
        - Id: A unique identifier for the tuple.
        - Type: The category of the tuple. Choose from "entity," "attribute," or "relation."
        - Type Detail: Provide additional details based on the selected type:
          - Entity: Specify whether it refers to the whole entity (e.g., "chair") or a part of the entity (e.g., "back of chair").
          - Attribute: Specify the attribute type, such as "color," "type," "material," "count," "texture," "text rendering," "shape," or "size."
          - Relation: Specify the relation type, such as "spatial" (e.g., "A next to B") or "action" (e.g., "A kicks B").
        - Semantics: A list of strings that represent the words or phrases from the caption that correspond to the tuple.
        Output is a dictionary where the key is the id of the tuple and the value is a list of ids that the tuple depends on.
        Example input:
        [
            {
                "id": 1,
                "type": "entity",
                "type_detail": "whole",
                "semantics": ["motorcycle"]
            },
            {
                "id": 2,
                "type": "attribute",
                "type_detail": "color",
                "semantics": ["motorcycle", "blue"]
            },
            {
                "id": 3,
                "type": "entity",
                "type_detail": "whole",
                "semantics": ["car"]
            },
            {
                "id": 4,
                "type": "attribute",
                "type_detail": "color",
                "semantics": ["car", "red"]
            },
            {
                "id": 5,
                "type": "relation",
                "type_detail": "spatial",
                "semantics": ["motorcycle", "next to", "car"]
            }
        ]

        Example output:
        {
            "1": [],
            "2": [1],
            "3": [],
            "4": [3],
            "5": [1, 3]
        }

        """
        # Ids sent to the model are 0-based positions, whereas the example in
        # the prompt above starts at 1.
        input_obj = [{"id": i, **t.dict()} for i, t in enumerate(tuples.tuples)]

        messages = [
            {
                "role": "system",
                "content": DEPENDENCY_PROMPT,
            },
            {
                "role": "user",
                "content": json.dumps(input_obj),
            },
        ]

        # JSON mode so the reply can be handed straight to json.loads.
        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=messages,
            response_format={"type": "json_object"},
        )
        return (
            json.loads(response.choices[0].message.content),
            response.usage.total_tokens,
        )

    def generate_questions(
        self, prompt: str, tuples: list[dict], dependencies: dict
    ) -> tuple[dict, int]:
        """Ask the chat model for one yes/no validation question per tuple.

        Args:
            prompt (str): the prompt that describes the image
            tuples (list[dict]): units of information extracted from the
                prompt; they are interpolated into the request text as-is
            dependencies (dict): the dependencies between tuples; also
                interpolated into the request text as-is

        Returns:
            A ``(questions, total_tokens)`` pair: the decoded JSON reply
            (expected to map a tuple id to its question) and the
            ``usage.total_tokens`` value reported for the request.

        Raises:
            json.JSONDecodeError: If the model reply is not valid JSON.
        """
        system_message = """
        Task: Given a prompt that describe the image and a list of tuples extracted from the prompt. Generate questions based on tuple in natural language as a list.
        Each tuple contains the following information:
        - Id: A unique identifier for the tuple.
        - Type: The category of the tuple. Choose from "entity," "attribute," or "relation."
        - Type Detail: Provide additional details based on the selected type:
            - Entity: Specify whether it refers to the whole entity (e.g., "chair") or a part of the entity (e.g., "back of chair").
            - Attribute: Specify the attribute type, such as "color", "type", "material", "count", "style", "texture", "text rendering", "shape" or "size".
            - Relation: Specify the relation type, such as "spatial" (e.g., "A next to B") or "action" (e.g., "A kicks B").
        - Semantics: A list of strings that represent the words or phrases from the caption that correspond to the tuple.
        Output is a list of questions, each question corresponds to a tuple. The number of questions must be the same as the number of tuples.
        Example input:
        Prompt: "A traffic light and a signpost at a crossroads intersection near a waterway"
        Tuples:
        [
            {
                "id": 1,
                "type": "entity",
                "type_detail": "whole",
                "semantics": ["traffic light"]
            },
            {
                "id": 2,
                "type": "entity",
                "type_detail": "whole",
                "semantics": ["signpost"]
            },
            {
                "id": 3,
                "type": "relation",
                "type_detail": "spatial",
                "semantics": ["traffic light", "at", "crossroads intersection"]
            },
            {
                "id": 4,
                "type": "relation",
                "type_detail": "spatial",
                "semantics": ["crossroads intersection", "near", "waterway"]
            }
        ]
        Dependencies:
        {
            "1": [],
            "2": [],
            "3": [1, 2],
            "4": [3]
        }
        Example output is a json object. Each question ask about the existence of the tuple in the prompt and the answer should always be yes.
        {
            "1": "Is there a light?",
            "2": "Is there a signpost?",
            "3": "Is the traffic light at a crossroads intersection?",
            "4": "Is the crossroads intersection near a waterway?"
        }
        """

        # The tuples and dependencies are rendered with their default string
        # form; no ids are added to the tuples at this point.
        user_str = f"""
        Prompt: {prompt}
        Tuples: {tuples}
        Dependencies: {dependencies}
        """
        messages = [
            {
                "role": "system",
                "content": system_message,
            },
            {
                "role": "user",
                "content": user_str,
            },
        ]

        # JSON mode so the reply can be handed straight to json.loads.
        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=messages,
            response_format={"type": "json_object"},
        )
        return (
            json.loads(response.choices[0].message.content),
            response.usage.total_tokens,
        )

    def find_layers(self, dep_dict, max_layers=3):
        """Order tuple ids so each id comes after everything it depends on.

        Ids are grouped into layers: an id joins a layer once all of its
        dependencies sit in earlier layers. Within a layer, ids keep the
        order in which they appear in ``dep_dict``, so the result is
        reproducible from run to run.

        Args:
            dep_dict: Mapping of id -> iterable of ids it depends on.
                Dependencies are compared to the keys via ``str(dep)``, so
                integer dependencies match string keys.
            max_layers: Maximum number of layers to build. Ids that would
                land in a deeper layer are left out. ``None`` removes the
                limit.

        Returns:
            A flat list of ids, layer by layer. Ids whose dependencies can
            never be satisfied (unknown ids, cycles) are omitted.
        """
        layers = []
        placed = set()  # ids already assigned to a finished layer
        remaining = list(dep_dict)

        while remaining and (max_layers is None or len(layers) < max_layers):
            # Only earlier layers count: `placed` is updated after the whole
            # layer has been collected.
            current_layer = [
                key
                for key in remaining
                if all(str(dep) in placed for dep in dep_dict[key])
            ]

            # Nothing could be placed: the rest is unresolvable, stop here.
            if not current_layer:
                break

            layers.append(current_layer)
            placed.update(current_layer)
            remaining = [key for key in remaining if key not in placed]

        return [key for layer in layers for key in layer]

    def _create_graph_questions(self, questions: dict, dependencies: dict) -> set:
        # create a question graph
        layered_indexes = self.find_layers(dependencies)
        print(layered_indexes)
        sorted_questions = [questions[i] for i in layered_indexes]

        return sorted_questions

    def get_reward(
        self,
        questions: dict,
        dependencies: dict,
        images: list,
        mode="hybrid",
    ):
        """Score images against the questions, honouring the dependency graph.

        Questions are evaluated in dependency order. A question is skipped
        (scored 0) for an image when any question it depends on scored 0.5 or
        less on that same image.

        Args:
            questions: Mapping of tuple id -> question text.
            dependencies: Mapping of tuple id -> ids it depends on.
            images: List of ``PIL.Image.Image`` objects to evaluate.
            mode: Currently unused; kept so existing callers keep working.

        Returns:
            A ``(scores, sorted_questions)`` pair. ``scores`` maps the index
            of each image to a list of scores aligned with
            ``sorted_questions``, the question texts in evaluation order.

        Raises:
            ValueError: If an entry of ``images`` is not a PIL image.
            KeyError: If an ordered id has no entry in ``questions``.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.binary_vqa.to(self.device)

        # Resolve the evaluation order once and keep the ids, so that the
        # dependency lookups below use tuple ids and not list positions.
        ordered_ids = self.find_layers(dependencies)
        print(ordered_ids)
        sorted_questions = [questions[key] for key in ordered_ids]
        print(sorted_questions)

        # Translate each question's dependency ids into positions in
        # sorted_questions, which is how the per-image score lists are indexed.
        # find_layers only emits an id once all of its dependencies have been
        # emitted, so every lookup here succeeds.
        position_of = {str(key): pos for pos, key in enumerate(ordered_ids)}
        dependency_positions = [
            [position_of[str(dep)] for dep in dependencies[key]]
            for key in ordered_ids
        ]

        scores = {i: [0] * len(sorted_questions) for i in range(len(images))}

        def get_reward_for_a_question(
            question: str,
            question_dependencies: list[int],
            image: Image.Image,
            prev_scores: list[int],
        ) -> float:
            # question_dependencies holds positions in sorted_questions.
            failed = [
                pos for pos in question_dependencies if not (prev_scores[pos] > 0.5)
            ]
            if failed:
                print(
                    f"Skipping question: {question}. It depends on {[sorted_questions[pos] for pos in failed]} that was answered as No."
                )
                return 0
            if not isinstance(image, Image.Image):
                raise ValueError("Invalid image type")

            inputs = self.binary_vqa_processor(
                text=question, images=image, return_tensors="pt"
            ).to(self.device, torch.float16)
            decoder_input_ids = torch.LongTensor(
                [
                    [
                        self.binary_vqa.language_model.config.pad_token_id,
                        self.binary_vqa.language_model.config.decoder_start_token_id,
                    ]
                ]
            ).to(self.device)
            # Inference only: no autograd graph is needed for a scalar score.
            with torch.no_grad():
                outputs = self.binary_vqa(
                    input_ids=inputs["input_ids"],
                    pixel_values=inputs["pixel_values"],
                    decoder_input_ids=decoder_input_ids,
                )
            logits = outputs.logits[:, -1]
            # NOTE(review): .item() needs a single-element tensor, so this
            # presumably relies on the checkpoint emitting one logit at the
            # last position - confirm against the model's remote code.
            score = logits[0].sigmoid().item()
            print(f"The answer Yes has {score} probs")
            return score

        with tqdm(
            total=len(sorted_questions) * len(images),
            desc=f"Calculating reward over {len(images)} images and {len(sorted_questions)} questions",
        ) as pbar:
            for i, question in enumerate(sorted_questions):
                for j, image in enumerate(images):
                    scores[j][i] = get_reward_for_a_question(
                        question, dependency_positions[i], image, scores[j]
                    )
                    pbar.update(1)

        return scores, sorted_questions


if __name__ == "__main__":
    # Manual smoke test: run the full caption -> tuples -> dependencies ->
    # questions -> reward pipeline on one remote sample image.
    processor = DSGPromptProcessor(model_name="mistralai/Mixtral-8x7B-Instruct-v0.1")
    url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true"
    image = Image.open(requests.get(url, stream=True).raw)
    input_text = "ghibli style image of a cat"
    tuples, tokens = processor.generate_tuples(input_text)
    print(tuples)
    dependencies, tokens = processor.generate_dependencies(tuples)
    print(dependencies)
    questions, tokens = processor.generate_questions(
        input_text, tuples.tuples, dependencies
    )
    print(questions)

    # get_reward takes (questions, dependencies, images); the caption is not
    # one of its parameters.
    reward = processor.get_reward(questions, dependencies, [image])
    print(reward)