Spaces:
Sleeping
Sleeping
import openai | |
import json | |
from pydantic import BaseModel, Field | |
from PIL import Image | |
from tqdm import tqdm | |
from transformers import AutoProcessor, AutoModelForCausalLM | |
import torch | |
import requests | |
import spaces | |
# Schema for the JSON object that generate_tuples() parses: a single "tuples"
# list whose items follow the nested Tuple model.
class PromptTuple(BaseModel):
    # One extracted unit of information about the caption.
    class Tuple(BaseModel):
        # Category label. The description names entity / attribute / relation,
        # but the plain `str` annotation does not restrict the value.
        type: str = Field(
            description="The type of the tuple. One of entity, attribute, relation",
            example="attribute",
        )
        # Sub-category of `type`; the description enumerates the suggested
        # values per category. Also an unrestricted `str`.
        type_detail: str = Field(
            description="""The detail of the type. For example:
- Entity: whole (entire entity, e.g., chair), part (part of entity, e.g., back of chair).
- Attribute: color (e.g., red book), type (e.g., aviator goggles), material (e.g., wooden chair), count (e.g., 5 geese), texture (e.g., rough surface), text rendering (e.g., letters “Macaroni”), shape (e.g., triangle block), size (e.g., large fence).
- Relation: spatial (e.g., A next to B); action (A kicks B).""",
            example="color",
        )
        # Words/phrases backing the tuple. Annotated as a bare `list`, so the
        # element type is not declared here.
        semantics: list = Field(
            description="List of strings that explain the existence of type and type_detail in the tuple",
            example=["motorcycle", "blue"],
        )

    # All tuples for one caption. The 8-item limit is stated in the description
    # only; no length constraint is passed to Field here.
    tuples: list[Tuple] = Field(
        description="List of tuples. Maximum 8 tuples.",
        example=[
            {
                "type": "attribute",
                "type_detail": "color",
                "semantics": ["motorcycle", "blue"],
            }
        ],
    )
class DSGPromptProcessor: | |
def __init__(self, model_name="gpt-4o-mini"):
    """Create the chat-completion client and load the yes/no VQA scorer.

    Args:
        model_name: Model identifier passed to every chat-completion request
            made by the ``generate_*`` methods.
    """
    # Single source of truth for the checkpoint used by model and processor.
    vqa_checkpoint = "toilaluan/Florence-2-base-Yes-No-VQA"

    self.client = openai.OpenAI()
    self.model_name = model_name
    self.device = "cuda" if torch.cuda.is_available() else "cpu"
    # trust_remote_code=True lets the checkpoint repository supply its own
    # model/processor code.
    self.binary_vqa = AutoModelForCausalLM.from_pretrained(
        vqa_checkpoint, trust_remote_code=True
    ).to(self.device, torch.float16)
    self.binary_vqa_processor = AutoProcessor.from_pretrained(
        vqa_checkpoint, trust_remote_code=True
    )
def generate_tuples(self, input_text: str) -> tuple[PromptTuple, int]:
    """Ask the chat model to split a caption into entity/attribute/relation tuples.

    Args:
        input_text: The image caption to decompose.

    Returns:
        A pair ``(tuples, total_tokens)``: the reply parsed into a
        ``PromptTuple`` and the token usage reported for the request.

    Raises:
        json.JSONDecodeError: If the reply is not valid JSON. The request is
            capped at 512 completion tokens, so a long reply can be cut off
            mid-object.
    """
    system_message = """
Given an image caption, extract the relevant entities, attributes, and relations present in the caption, and structure them into JSON format according to the following schema:
Each tuple contains the following information:
- Id: A unique identifier for the tuple.
- Type: The category of the tuple. Choose from "entity," "attribute," or "relation."
- Type Detail: Provide additional details based on the selected type:
- Entity: Specify whether it refers to the whole entity (e.g., "chair") or a part of the entity (e.g., "back of chair").
- Attribute: Specify the attribute type, such as "color", "type", "material", "count", "style", "texture", "text rendering", "shape" or "size".
- Relation: Specify the relation type, such as "spatial" (e.g., "A next to B") or "action" (e.g., "A kicks B").
- Semantics: A list of strings that represent the words or phrases from the caption that correspond to the tuple.
Example Input: "A blue motorcycle parked next to a red car."
Example output:
{
"tuples": [
{
"type": "entity",
"type_detail": "whole",
"semantics": ["motorcycle"]
},
{
"type": "attribute",
"type_detail": "color",
"semantics": ["motorcycle", "blue"]
},
{
"type": "entity",
"type_detail": "whole",
"semantics": ["car"]
},
{
"type": "attribute",
"type_detail": "color",
"semantics": ["car", "red"]
},
{
"type": "relation",
"type_detail": "spatial",
"semantics": ["motorcycle", "next to", "car"]
}
]
}
The final JSON should contain a list of tuples, each describing a unique entity, attribute, or relation from the image caption. Each JSON should contain a maximum of 8 tuples.
"""
    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": input_text},
    ]
    response = self.client.chat.completions.create(
        model=self.model_name,
        messages=messages,
        response_format={"type": "json_object"},
        max_tokens=512,
    )
    output = json.loads(response.choices[0].message.content)
    return PromptTuple(**output), response.usage.total_tokens
def generate_dependencies(self, tuples: PromptTuple) -> tuple[dict, int]:
    """Ask the chat model which tuples depend on which other tuples.

    Args:
        tuples: The tuples previously extracted from the caption.

    Returns:
        A pair ``(dependencies, total_tokens)``. ``dependencies`` is the
        parsed JSON reply; the prompt asks for a mapping from tuple id to the
        list of ids that tuple depends on. ``total_tokens`` is the usage
        reported for the request.

    Raises:
        json.JSONDecodeError: If the reply is not valid JSON.
    """
    system_message = """
Given the following tuples extracted from an image caption, determine the dependencies between the entities, attributes, and relations in the JSON format.
Each tuple contains the following information:
- Id: A unique identifier for the tuple.
- Type: The category of the tuple. Choose from "entity," "attribute," or "relation."
- Type Detail: Provide additional details based on the selected type:
- Entity: Specify whether it refers to the whole entity (e.g., "chair") or a part of the entity (e.g., "back of chair").
- Attribute: Specify the attribute type, such as "color," "type," "material," "count," "texture," "text rendering," "shape," or "size."
- Relation: Specify the relation type, such as "spatial" (e.g., "A next to B") or "action" (e.g., "A kicks B").
- Semantics: A list of strings that represent the words or phrases from the caption that correspond to the tuple.
Output is a dictionary where the key is the id of the tuple and the value is a list of ids that the tuple depends on.
Example input:
[
{
"id": 1,
"type": "entity",
"type_detail": "whole",
"semantics": ["motorcycle"]
},
{
"id": 2,
"type": "attribute",
"type_detail": "color",
"semantics": ["motorcycle", "blue"]
},
{
"id": 3,
"type": "entity",
"type_detail": "whole",
"semantics": ["car"]
},
{
"id": 4,
"type": "attribute",
"type_detail": "color",
"semantics": ["car", "red"]
},
{
"id": 5,
"type": "relation",
"type_detail": "spatial",
"semantics": ["motorcycle", "next to", "car"]
}
]
Example output:
{
"1": [],
"2": [1],
"3": [],
"4": [3],
"5": [1, 3]
}
"""
    # Each tuple is tagged with its zero-based position in tuples.tuples;
    # that position is the id the model sees.
    input_obj = [{"id": i, **t.dict()} for i, t in enumerate(tuples.tuples)]
    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": json.dumps(input_obj)},
    ]
    response = self.client.chat.completions.create(
        model=self.model_name,
        messages=messages,
        response_format={"type": "json_object"},
    )
    dependencies = json.loads(response.choices[0].message.content)
    return dependencies, response.usage.total_tokens
def generate_questions(
    self, prompt: str, tuples: list[dict], dependencies: dict
) -> tuple[dict, int]:
    """Generate one validation question per tuple via the chat model.

    Args:
        prompt (str): the prompt that describes the image
        tuples (list[dict]): each tuple is a unit of information extracted
            from the prompt; the list is interpolated into the request text
            with its default string formatting
        dependencies (dict): the dependencies between tuples

    Returns:
        A pair ``(questions, total_tokens)``. ``questions`` is the parsed JSON
        reply; the prompt asks for an object mapping tuple id to a yes/no
        question. ``total_tokens`` is the usage reported for the request.

    Raises:
        json.JSONDecodeError: If the reply is not valid JSON.
    """
    system_message = """
Task: Given a prompt that describe the image and a list of tuples extracted from the prompt. Generate questions based on tuple in natural language as a list.
Each tuple contains the following information:
- Id: A unique identifier for the tuple.
- Type: The category of the tuple. Choose from "entity," "attribute," or "relation."
- Type Detail: Provide additional details based on the selected type:
- Entity: Specify whether it refers to the whole entity (e.g., "chair") or a part of the entity (e.g., "back of chair").
- Attribute: Specify the attribute type, such as "color", "type", "material", "count", "style", "texture", "text rendering", "shape" or "size".
- Relation: Specify the relation type, such as "spatial" (e.g., "A next to B") or "action" (e.g., "A kicks B").
- Semantics: A list of strings that represent the words or phrases from the caption that correspond to the tuple.
Output is a list of questions, each question corresponds to a tuple. The number of questions must be the same as the number of tuples.
Example input:
Prompt: "A traffic light and a signpost at a crossroads intersection near a waterway"
Tuples:
[
{
"id": 1,
"type": "entity",
"type_detail": "whole",
"semantics": ["traffic light"]
},
{
"id": 2,
"type": "entity",
"type_detail": "whole",
"semantics": ["signpost"]
},
{
"id": 3,
"type": "relation",
"type_detail": "spatial",
"semantics": ["traffic light", "at", "crossroads intersection"]
},
{
"id": 4,
"type": "relation",
"type_detail": "spatial",
"semantics": ["crossroads intersection", "near", "waterway"]
}
]
Dependencies:
{
"1": [],
"2": [],
"3": [1, 2],
"4": [3]
}
Example output is a json object. Each question ask about the existence of the tuple in the prompt and the answer should always be yes.
{
"1": "Is there a light?",
"2": "Is there a signpost?",
"3": "Is the traffic light at a crossroads intersection?",
"4": "Is the crossroads intersection near a waterway?"
}
"""
    user_str = f"""
Prompt: {prompt}
Tuples: {tuples}
Dependencies: {dependencies}
"""
    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": user_str},
    ]
    response = self.client.chat.completions.create(
        model=self.model_name,
        messages=messages,
        response_format={"type": "json_object"},
    )
    questions = json.loads(response.choices[0].message.content)
    return questions, response.usage.total_tokens
def find_layers(self, dep_dict, max_layers=3):
    """Order tuple ids so each id appears after everything it depends on.

    Ids are grouped into layers: layer 0 holds ids without dependencies,
    layer N holds ids whose dependencies all sit in layers < N. The layers
    are then concatenated. Within a layer, ids keep the insertion order of
    ``dep_dict`` so the result is reproducible from run to run.

    Args:
        dep_dict: Mapping from tuple id to the ids it depends on. Keys and
            dependency ids are compared by their ``str()`` form, so
            ``{"2": [1]}`` and ``{2: [1]}`` both resolve.
        max_layers: Maximum number of layers to emit (default 3). Ids that
            would land in a deeper layer are left out. ``None`` removes the
            limit.

    Returns:
        A flat list of keys from ``dep_dict`` in layer order. Ids that are
        part of a cycle, depend on an unknown id, or lie beyond
        ``max_layers`` are omitted.
    """
    ordered = []
    placed = set()  # str() form of every id already emitted
    pending = list(dep_dict)
    layer_count = 0

    while pending and (max_layers is None or layer_count < max_layers):
        # `placed` is only updated after the scan, so an id can never join
        # the same layer as one of its own dependencies.
        ready = [
            key
            for key in pending
            if all(str(dep) in placed for dep in dep_dict[key])
        ]
        if not ready:
            # Nothing left can be resolved (cycle or unknown dependency).
            break
        ordered.extend(ready)
        placed.update(str(key) for key in ready)
        emitted = set(ready)
        pending = [key for key in pending if key not in emitted]
        layer_count += 1

    return ordered
def _create_graph_questions(self, questions: dict, dependencies: dict) -> set: | |
# create a question graph | |
layered_indexes = self.find_layers(dependencies) | |
print(layered_indexes) | |
sorted_questions = [questions[i] for i in layered_indexes] | |
return sorted_questions | |
def get_reward(
    self,
    questions: dict[str, str],
    dependencies: dict[str, list[int]],
    images: list,
    mode="hybrid",
):
    """Score each image against the questions, honouring tuple dependencies.

    Questions are visited in the order given by ``find_layers``. A question
    whose dependencies were not all answered "yes" (score > 0.5) for the same
    image is skipped and scored 0.

    Args:
        questions: Mapping from tuple id to its yes/no question.
        dependencies: Mapping from tuple id to the ids it depends on.
        images: ``PIL.Image.Image`` objects to evaluate.
        mode: Accepted for interface compatibility; not read by this method.

    Returns:
        A pair ``(scores, sorted_questions)``. ``scores`` maps the image index
        to a list with one score per entry of ``sorted_questions``, which is
        the list of question texts in evaluation order.

    Raises:
        ValueError: If an element of ``images`` is not a ``PIL.Image.Image``.
        KeyError: If an ordered tuple id has no entry in ``questions``.
    """
    # Re-resolve the device on every call and move the model accordingly.
    self.device = "cuda" if torch.cuda.is_available() else "cpu"
    self.binary_vqa.to(self.device)

    ordered_ids = self.find_layers(dependencies)
    print(ordered_ids)
    sorted_questions = [questions[tuple_id] for tuple_id in ordered_ids]
    print(sorted_questions)

    # Scores are stored by position in sorted_questions, while dependencies
    # refer to tuple ids, so translate ids to positions explicitly.
    position_of = {
        str(tuple_id): position for position, tuple_id in enumerate(ordered_ids)
    }
    scores = {j: [0] * len(sorted_questions) for j in range(len(images))}

    def get_reward_for_a_question(
        question: str,
        question_dependencies: list[int],
        image: Image.Image,
        prev_scores: list[int],
    ) -> float:
        def dependency_passed(dep) -> bool:
            position = position_of.get(str(dep))
            return position is not None and prev_scores[position] > 0.5

        if not all(dependency_passed(dep) for dep in question_dependencies):
            dependency_questions = [
                sorted_questions[position_of[str(dep)]]
                for dep in question_dependencies
                if str(dep) in position_of
            ]
            print(
                f"Skipping question: {question}. It depends on {dependency_questions} that was answered as No."
            )
            return 0
        if not isinstance(image, Image.Image):
            raise ValueError("Invalid image type")

        inputs = self.binary_vqa_processor(
            text=question, images=image, return_tensors="pt"
        ).to(self.device, torch.float16)
        language_config = self.binary_vqa.language_model.config
        decoder_input_ids = torch.LongTensor(
            [[language_config.pad_token_id, language_config.decoder_start_token_id]]
        ).to(self.device)
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            outputs = self.binary_vqa(
                input_ids=inputs["input_ids"],
                pixel_values=inputs["pixel_values"],
                decoder_input_ids=decoder_input_ids,
            )
        logits = outputs.logits[:, -1]
        # NOTE(review): .item() needs a single-element tensor, so this assumes
        # the checkpoint emits one "yes" logit per step -- confirm.
        score = logits[0].sigmoid().item()
        print(f"The answer Yes has {score} probs")
        return score

    with tqdm(
        total=len(sorted_questions) * len(images),
        desc=f"Calculating reward over {len(images)} images and {len(sorted_questions)} questions",
    ) as pbar:
        for position, tuple_id in enumerate(ordered_ids):
            question = sorted_questions[position]
            question_dependencies = dependencies[tuple_id]
            for j, image in enumerate(images):
                scores[j][position] = get_reward_for_a_question(
                    question, question_dependencies, image, scores[j]
                )
                pbar.update(1)
    return scores, sorted_questions
if __name__ == "__main__":
    # Manual smoke test: run the whole pipeline (tuples -> dependencies ->
    # questions -> reward) for one caption against one downloaded image.
    processor = DSGPromptProcessor(model_name="mistralai/Mixtral-8x7B-Instruct-v0.1")
    url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true"
    # Fail fast on a stalled or unsuccessful download instead of handing a
    # broken stream to PIL.
    http_response = requests.get(url, stream=True, timeout=30)
    http_response.raise_for_status()
    image = Image.open(http_response.raw)
    input_text = "ghibli style image of a cat"
    tuples, tokens = processor.generate_tuples(input_text)
    print(tuples)
    dependencies, tokens = processor.generate_dependencies(tuples)
    print(dependencies)
    questions, tokens = processor.generate_questions(
        input_text, tuples.tuples, dependencies
    )
    print(questions)
    # get_reward takes (questions, dependencies, images); the caption is not
    # one of its parameters.
    reward = processor.get_reward(questions, dependencies, [image])
    print(reward)