import openai import json from pydantic import BaseModel, Field from PIL import Image from tqdm import tqdm from transformers import AutoProcessor, AutoModelForCausalLM import torch import requests import spaces class PromptTuple(BaseModel): class Tuple(BaseModel): type: str = Field( description="The type of the tuple. One of entity, attribute, relation", example="attribute", ) type_detail: str = Field( description="""The detail of the type. For example: - Entity: whole (entire entity, e.g., chair), part (part of entity, e.g., back of chair). - Attribute: color (e.g., red book), type (e.g., aviator goggles), material (e.g., wooden chair), count (e.g., 5 geese), texture (e.g., rough surface), text rendering (e.g., letters “Macaroni”), shape (e.g., triangle block), size (e.g., large fence). - Relation: spatial (e.g., A next to B); action (A kicks B).""", example="color", ) semantics: list = Field( description="List of strings that explain the existence of type and type_detail in the tuple", example=["motorcycle", "blue"], ) tuples: list[Tuple] = Field( description="List of tuples. Maximum 8 tuples.", example=[ { "type": "attribute", "type_detail": "color", "semantics": ["motorcycle", "blue"], } ], ) class DSGPromptProcessor: def __init__(self, model_name="gpt-4o-mini"): self.client = openai.OpenAI() self.model_name = model_name self.device = "cuda" if torch.cuda.is_available() else "cpu" self.binary_vqa = AutoModelForCausalLM.from_pretrained("toilaluan/Florence-2-base-Yes-No-VQA", trust_remote_code=True).to(self.device, torch.float16) self.binary_vqa_processor = processor = AutoProcessor.from_pretrained("toilaluan/Florence-2-base-Yes-No-VQA", trust_remote_code=True) def generate_tuples(self, input_text: str) -> PromptTuple: system_message = """ Given an image caption, extract the relevant entities, attributes, and relations present in the caption, and structure them into JSON format according to the following schema: Each tuple contains the following information: - Id: A unique identifier for the tuple. - Type: The category of the tuple. Choose from "entity," "attribute," or "relation." - Type Detail: Provide additional details based on the selected type: - Entity: Specify whether it refers to the whole entity (e.g., "chair") or a part of the entity (e.g., "back of chair"). - Attribute: Specify the attribute type, such as "color", "type", "material", "count", "style", "texture", "text rendering", "shape" or "size". - Relation: Specify the relation type, such as "spatial" (e.g., "A next to B") or "action" (e.g., "A kicks B"). - Semantics: A list of strings that represent the words or phrases from the caption that correspond to the tuple. Example Input: "A blue motorcycle parked next to a red car." Example output: { "tuples": [ { "type": "entity", "type_detail": "whole", "semantics": ["motorcycle"] }, { "type": "attribute", "type_detail": "color", "semantics": ["motorcycle", "blue"] }, { "type": "entity", "type_detail": "whole", "semantics": ["car"] }, { "type": "attribute", "type_detail": "color", "semantics": ["car", "red"] }, { "type": "relation", "type_detail": "spatial", "semantics": ["motorcycle", "next to", "car"] } ] } The final JSON should contain a list of tuples, each describing a unique entity, attribute, or relation from the image caption. Each JSON should contain a maximum of 8 tuples. """ messages = [ { "role": "system", "content": system_message, }, { "role": "user", "content": input_text, }, ] response = self.client.chat.completions.create( model=self.model_name, messages=messages, response_format={"type": "json_object"}, max_tokens=512, ) output = json.loads(response.choices[0].message.content) return PromptTuple(**output), response.usage.total_tokens def generate_dependencies(self, tuples: PromptTuple) -> dict: DEPENDENCY_PROMPT = """ Given the following tuples extracted from an image caption, determine the dependencies between the entities, attributes, and relations in the JSON format. Each tuple contains the following information: - Id: A unique identifier for the tuple. - Type: The category of the tuple. Choose from "entity," "attribute," or "relation." - Type Detail: Provide additional details based on the selected type: - Entity: Specify whether it refers to the whole entity (e.g., "chair") or a part of the entity (e.g., "back of chair"). - Attribute: Specify the attribute type, such as "color," "type," "material," "count," "texture," "text rendering," "shape," or "size." - Relation: Specify the relation type, such as "spatial" (e.g., "A next to B") or "action" (e.g., "A kicks B"). - Semantics: A list of strings that represent the words or phrases from the caption that correspond to the tuple. Output is a dictionary where the key is the id of the tuple and the value is a list of ids that the tuple depends on. Example input: [ { "id": 1, "type": "entity", "type_detail": "whole", "semantics": ["motorcycle"] }, { "id": 2, "type": "attribute", "type_detail": "color", "semantics": ["motorcycle", "blue"] }, { "id": 3, "type": "entity", "type_detail": "whole", "semantics": ["car"] }, { "id": 4, "type": "attribute", "type_detail": "color", "semantics": ["car", "red"] }, { "id": 5, "type": "relation", "type_detail": "spatial", "semantics": ["motorcycle", "next to", "car"] } ] Example output: { "1": [], "2": [1], "3": [], "4": [3], "5": [1, 3] } """ input_obj = [{"id": i, **t.dict()} for i, t in enumerate(tuples.tuples)] messages = [ { "role": "system", "content": DEPENDENCY_PROMPT, }, { "role": "user", "content": json.dumps(input_obj), }, ] response = self.client.chat.completions.create( model=self.model_name, messages=messages, response_format={"type": "json_object"}, ) return ( json.loads(response.choices[0].message.content), response.usage.total_tokens, ) def generate_questions( self, prompt: str, tuples: list[dict], dependencies: dict ) -> list[str]: """Generate validate question based on tuples and dependencies. Args: prompt (str): a prompt describe the image tuples (list[dict]): each tuple is a unit of information extracted from the prompt dependencies (dict): the dependencies between tuples """ system_message = """ Task: Given a prompt that describe the image and a list of tuples extracted from the prompt. Generate questions based on tuple in natural language as a list. Each tuple contains the following information: - Id: A unique identifier for the tuple. - Type: The category of the tuple. Choose from "entity," "attribute," or "relation." - Type Detail: Provide additional details based on the selected type: - Entity: Specify whether it refers to the whole entity (e.g., "chair") or a part of the entity (e.g., "back of chair"). - Attribute: Specify the attribute type, such as "color", "type", "material", "count", "style", "texture", "text rendering", "shape" or "size". - Relation: Specify the relation type, such as "spatial" (e.g., "A next to B") or "action" (e.g., "A kicks B"). - Semantics: A list of strings that represent the words or phrases from the caption that correspond to the tuple. Output is a list of questions, each question corresponds to a tuple. The number of questions must be the same as the number of tuples. Example input: Prompt: "A traffic light and a signpost at a crossroads intersection near a waterway" Tuples: [ { "id": 1, "type": "entity", "type_detail": "whole", "semantics": ["traffic light"] }, { "id": 2, "type": "entity", "type_detail": "whole", "semantics": ["signpost"] }, { "id": 3, "type": "relation", "type_detail": "spatial", "semantics": ["traffic light", "at", "crossroads intersection"] }, { "id": 4, "type": "relation", "type_detail": "spatial", "semantics": ["crossroads intersection", "near", "waterway"] } ] Dependencies: { "1": [], "2": [], "3": [1, 2], "4": [3] } Example output is a json object. Each question ask about the existence of the tuple in the prompt and the answer should always be yes. { "1": "Is there a light?", "2": "Is there a signpost?", "3": "Is the traffic light at a crossroads intersection?", "4": "Is the crossroads intersection near a waterway?" } """ user_str = f""" Prompt: {prompt} Tuples: {tuples} Dependencies: {dependencies} """ messages = [ { "role": "system", "content": system_message, }, { "role": "user", "content": user_str, }, ] response = self.client.chat.completions.create( model=self.model_name, messages=messages, response_format={"type": "json_object"}, ) return ( json.loads(response.choices[0].message.content), response.usage.total_tokens, ) def find_layers(self, dep_dict): layers = [] remaining_keys = set(dep_dict.keys()) while remaining_keys: current_layer = [] for key in list(remaining_keys): # If all dependencies of the key are in previous layers if all( str(dep) in [k for layer in layers for k in layer] for dep in dep_dict[key] ): current_layer.append(key) # If no new layer is formed, break to avoid infinite loop if not current_layer: break # Add the current layer to the list of layers layers.append(current_layer) # Remove the keys that are now layered remaining_keys -= set(current_layer) if len(layers) == 3: break ordered_indexes = [item for sublist in layers for item in sublist] return ordered_indexes def _create_graph_questions(self, questions: dict, dependencies: dict) -> set: # create a question graph layered_indexes = self.find_layers(dependencies) print(layered_indexes) sorted_questions = [questions[i] for i in layered_indexes] return sorted_questions def get_reward( self, questions: list[str], dependencies: dict[list], images: list, mode="hybrid", ): """Get reward for the generated questions use structured question graph. Args: questions (list[str]): a list of questions generated based on the tuples dependencies (dict[list]): the dependencies between tuples images (list[str]): a list of image urls """ self.device = "cuda" if torch.cuda.is_available() else "cpu" self.binary_vqa.to(self.device) scores = {} sorted_questions = self._create_graph_questions(questions, dependencies) print(sorted_questions) for i in range(len(images)): scores[i] = [0] * len(sorted_questions) def get_reward_for_a_question( question: str, question_dependencies: list[int], image: Image.Image, prev_scores: list[int], ) -> float: if any([not (prev_scores[i] > 0.5) for i in question_dependencies]): print( f"Skipping question: {question}. It depends on {[sorted_questions[i] for i in range(len(question_dependencies))]} that was answered as No." ) return 0 if not isinstance(image, Image.Image): raise ValueError("Invalid image type") inputs = self.binary_vqa_processor(text=question, images=image, return_tensors="pt").to(self.device, torch.float16) decoder_input_ids = torch.LongTensor([[self.binary_vqa.language_model.config.pad_token_id, self.binary_vqa.language_model.config.decoder_start_token_id]]).to(self.device) outputs = self.binary_vqa( input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"], decoder_input_ids=decoder_input_ids ) logits = outputs.logits[:, -1] score = logits[0].sigmoid().item() print(f"The answer Yes has {score} probs") return score pbar = tqdm( total=len(sorted_questions) * len(images), desc=f"Calculating reward over {len(images)} images and {len(sorted_questions)} questions", ) for i, question in enumerate(sorted_questions): for j, image in enumerate(images): scores[j][i] = get_reward_for_a_question( question, dependencies[str(i)], image, scores[j] ) pbar.update(1) return scores, sorted_questions if __name__ == "__main__": processor = DSGPromptProcessor(model_name="mistralai/Mixtral-8x7B-Instruct-v0.1") url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true" image = Image.open(requests.get(url, stream=True).raw) input_text = "ghibli style image of a cat" tuples, tokens = processor.generate_tuples(input_text) print(tuples) dependencies, tokens = processor.generate_dependencies(tuples) print(dependencies) questions, tokens = processor.generate_questions( input_text, tuples.tuples, dependencies ) print(questions) reward = processor.get_reward(input_text, questions, dependencies, [image]) print(reward)