# NOTE(review): the three lines previously here ("Spaces:" / "Runtime error" /
# "Runtime error") were copy-paste artifacts, not Python; replaced with this
# comment so the module parses.
import os | |
import io_utils as io_uts | |
import vis_utils as v_uts | |
from vis_common import * | |
import pandas as pd | |
from GPT_prompts import ( | |
TEMPLATE_0, | |
TEMPLATE_1, | |
TEMPLATE_2 | |
) | |
from call_assistant_api import ( | |
EditActionClassifier | |
) | |
import json | |
from datasets import Dataset | |
unknown_action = "Unknown" | |
def dfs(actions, res, res_set):
    """Enumerate every combination of word choices in *actions*.

    Each element of ``actions`` is a list of alternative words; the cartesian
    product of all alternatives is accumulated into ``res_set``.

    Args:
        actions: remaining lists of word alternatives to expand.
        res: words chosen so far along the current recursion path.
        res_set: accumulator receiving every completed combination.

    Returns:
        ``res_set`` (also mutated in place); the base-case call returns None,
        so only the top-level call's return value is meaningful.
    """
    if not actions:
        # All choice points resolved: record the finished combination.
        res_set.append(res)
        return
    head, rest = actions[0], actions[1:]
    for choice in head:
        dfs(rest, res + [choice], res_set)
    return res_set
def split_actions(actions):
    """Expand a slash-delimited edit description into all concrete variants.

    A token such as ``"firework/rainbow"`` marks a choice point; one
    instruction string is produced per combination of choices, with
    underscores stripped from the chosen words. A description without
    ``'/'`` is returned unchanged as a single-element list.
    """
    if '/' not in actions:
        return [actions]

    template_parts = []
    option_groups = []
    for token in actions.split(" "):
        if "/" in token:
            # Numbered placeholder, substituted with each concrete option below.
            template_parts.append(unknown_action + f"{len(option_groups)} ")
            option_groups.append(token.split('/'))
        else:
            template_parts.append(token + " ")
    template = "".join(template_parts)

    instructions = []
    for combo in dfs(option_groups, [], []):
        filled = template
        for idx, option in enumerate(combo):
            filled = filled.replace(unknown_action + f"{idx}", option.replace('_', ''))
        instructions.append(filled.strip())
    return instructions
def sample_prompt(sub, class_name, edit_action):
    """Draw an input prompt suited to *edit_action*.

    Edits that target the backdrop (wall/ground/sky/weather) get a fixed
    themed scene; otherwise either a scene or a person/subject description
    is randomly sampled from the attribute pools in ``sub``.
    ``class_name`` is currently unused but kept for interface stability.
    """
    if "the subject" not in edit_action:
        if (" wall " in edit_action) or (" ground " in edit_action) or ("furnished" in edit_action):
            if random.uniform(0, 1) < 0.5:
                return "an indoor living room."
            return "a beautiful lobby"
        if " sky " in edit_action:
            return "a natural image of sea, mountains and sky"
        if (" weather" in edit_action) or (" snow" in edit_action):
            return "a naturalistic scene with trees"

    # Half the time, use a pre-listed scene instead of composing one.
    if random.uniform(0, 1) < 0.5:
        return random.choice(sub["scenes"])

    # Compose "<view> of <pose> <adj> <color> <age> <subject>." from the pools;
    # 70% of the draws describe a person, the rest an animal/other subject.
    person_keys = ["view", "pose", "adj", "color", "human_age", "people"]
    subject_keys = ["view", "pose", "adj", "color", "animal_age", "subjects"]
    separators = [" of ", " ", " ", " ", " ", "."]
    keys = person_keys if random.uniform(0, 1) < 0.7 else subject_keys
    pieces = []
    for key, sep in zip(keys, separators):
        pieces.append(random.choice(sub[key]) + sep)
    return "".join(pieces)
def prepare_our_prompt_v0():
    """
    Prepare the prompt with our coverage, simple prompt, found good for person.

    Reads the edit-class list and the subject attribute pools, expands every
    slash-delimited edit action, samples an input prompt for each, and dumps
    the items to ``edit_instructions_v0.jsonl``.
    """
    random.seed(0)  # deterministic sampling across runs
    data_root="/mlx/users/peng.wang/playground/data/chat_edit/assets/test200"
    edit_lines = io_uts.load_lines(f"{data_root}/edit_class.txt")
    sub = io_uts.load_yaml(f"{data_root}/subject.yaml")
    out_file = f"{data_root}/edit_instructions_v0.jsonl"

    # sample an item or empty each feature
    items = []
    for line in tqdm(edit_lines):
        class_name, raw_actions = line.split(":")
        for edit_action in split_actions(raw_actions):
            input_prompt = sample_prompt(sub, class_name, edit_action)
            items.append({
                "prompt_0": TEMPLATE_0.format(prompt1=input_prompt, edit_action=edit_action),
                "class": class_name,
                "input": input_prompt,
                "edit": edit_action,
                "output": f"{input_prompt} with {edit_action}",
            })
    print("number of examples:", len(items))
    io_uts.dump_jsonl(out_file, items)
def config_our_prompt_v1():
    """Placeholder for the v1 prompt configuration (not implemented yet)."""
    # Intended plan: if the edit is region-wise, first find and locate the region.
    pass
def config_our_prompt_v2():
    """Placeholder for the v2 prompt configuration (not implemented yet)."""
    # Intended plan: if the edit is region-wise, first find and locate the region.
    pass
def prepare_p2p_prompt_v0():
    """Reformat the instruct-p2p 700-example set and subsample a validation split.

    Two stages (the first is currently commented out at the bottom):
      1. ``classify_p2p_edit_action`` -- classifies each example's "edit" text
         with ``EditActionClassifier``, renders ``prompt_0`` from TEMPLATE_0,
         and caches one JSON per example under ``cache_root`` so an
         interrupted run can resume; writes the combined reformatted jsonl.
      2. ``subsample_p2p`` -- reads the reformatted jsonl and keeps at most
         ``max_each_class`` examples per edit class.
    """
    test_root="/mlx/users/peng.wang/playground/repo/instruct-pix2pix/data/chat_edit/assets/test200/"
    cache_root="/mlx/users/peng.wang/playground/repo/instruct-pix2pix/data/chat_edit/assets/p2p700"
    jsonl_file = f"{test_root}instruct_p2p_700.jsonl"
    jsonl_file_out = f"{test_root}instruct_p2p_700_reformat.jsonl"
    def classify_p2p_edit_action():
        # Annotate each example with an edit class; one cached JSON per example.
        classifier = EditActionClassifier()
        examples = io_uts.load_jsonl(jsonl_file)
        examples_out = []
        for count, example in tqdm(enumerate(examples)):
            res_file = f"{cache_root}/{count}.json"
            if os.path.exists(res_file):
                # Cache hit: reuse the previously classified example (resume support).
                example = io_uts.load_json(res_file)
                examples_out.append(example)
                continue
            edit_class = classifier.infer(example["edit"])
            example["class"] = edit_class
            example["prompt_0"] = TEMPLATE_0.format(prompt1=example["input"], edit_action=example["edit"])
            io_uts.dump_json(res_file, example)
            examples_out.append(example)
        io_uts.dump_jsonl(jsonl_file_out, examples_out)
    def subsample_p2p():
        # Keep at most `max_each_class` examples of each edit class.
        # NOTE(review): reads the reformatted file, so classify_p2p_edit_action
        # must have been run at least once before this stage.
        jsonl_file_sample_out = f"{test_root}/instruct_p2p_val.jsonl"
        examples = io_uts.load_jsonl(jsonl_file_out)
        classes = {}  # class name -> number of examples kept so far
        results = []
        max_each_class = 1
        for example in examples:
            if example["class"] not in classes.keys():
                classes[example["class"]] = 1
                results.append(example)
            else:
                if classes[example["class"]] < max_each_class:
                    classes[example["class"]] += 1
                    results.append(example)
        print("sample num: ", len(results))
        io_uts.dump_jsonl(jsonl_file_sample_out, results)
    # classify_p2p_edit_action()
    subsample_p2p()
def prepare_emu_set():
    """Convert the emu test txt files into a single jsonl validation set.

    Each txt file contributes three lines (indexed 0/1/2, used as edit /
    input / output respectively) and its parent folder name is mapped to an
    edit class via ``class_map``.
    """
    test_root="/mlx/users/peng.wang/playground/repo/instruct-pix2pix/data/chat_edit/assets/emu_test/"
    output_root="/mlx/users/peng.wang/playground/repo/instruct-pix2pix/data/chat_edit/assets/test200/"
    class_map = {
        "add": "Local,Add",
        "background": "Global,Background",
        "color": "Global,Color",
        "global": "Global",
        "local": "Local",
        "remove": "Local,Remove",
        "style": "Global,Stylization",
        "text": "Local,Add,Text"
    }
    items = []
    for edit_file in tqdm(v_uts.list_all_files(test_root, exts=["txt"])):
        lines = io_uts.load_lines(edit_file)
        # Line layout per file: [0] edit instruction, [1] input, [2] output.
        item = {"input": lines[1], "edit": lines[0], "output": lines[2]}
        item["prompt_0"] = TEMPLATE_0.format(prompt1=item["input"], edit_action=item["edit"])
        # Parent directory name encodes the edit class.
        folder = edit_file.split('/')[-2]
        item["class"] = class_map[folder]
        items.append(item)
    io_uts.dump_jsonl(f"{output_root}/emu_val_90.jsonl", items)
def merge_prompts():
    """Concatenate the three validation sets into one ``full_val.jsonl``.

    Every item's keys are re-ordered to a canonical layout and the two extra
    prompt variants (``prompt_1``/``prompt_2``) are rendered from TEMPLATE_1
    and TEMPLATE_2.
    """
    output_root="/mlx/users/peng.wang/playground/repo/instruct-pix2pix/data/chat_edit/assets/ChatEdit/"
    key_order = ["input", "edit", "output", "prompt_0", "class"]
    full_items = []
    for val_set in ["edit_instructions_val", "instruct_p2p_val", "emu_val_90"]:
        items = io_uts.load_jsonl(f"{output_root}/{val_set}.jsonl")
        print(val_set, len(items))
        for item in items:
            # Re-order keys into the canonical layout, then add the extra prompts.
            ordered = {key: item[key] for key in key_order}
            ordered["prompt_1"] = TEMPLATE_1.format(
                prompt1=item["input"],
                prompt2=item['output'],
                edit_action=item["edit"])
            ordered["prompt_2"] = TEMPLATE_2.format(
                prompt1=item["input"],
                prompt2=item['output'],
                edit_action=item["edit"])
            full_items.append(ordered)
    print("num: ", len(full_items))
    io_uts.dump_jsonl(f"{output_root}/full_val.jsonl", full_items)
def classify_and_sample_p2p_prompts():
    """Placeholder (not implemented); prepare_p2p_prompt_v0 currently holds
    the classify-and-subsample logic."""
    pass
def write_dataset_toparquet():
    """Pack ``full_val.jsonl`` plus base64-encoded images into a parquet file.

    Each item gains an ``image_id`` (zero-padded index) and an ``image``
    field holding the matching PNG encoded as base64, then the whole list is
    written as ``data/prompt_0.parquet`` via a Hugging Face Dataset.
    """
    dataroot = "/mnt/bn/datacompv6/data/chat_edit/assets/ChatEdit/"
    folder_name = "prompt_0"
    image_folder = f"{dataroot}/{folder_name}"
    output_path = f"{dataroot}/data/"
    v_uts.mkdir(output_path)

    records = []
    for i, item in enumerate(tqdm(io_uts.load_jsonl(f"{dataroot}/full_val.jsonl"))):
        item['image_id'] = f"{i:03}"
        # Images ride along as base64 strings so they survive the parquet round-trip.
        item['image'] = v_uts.encode_b64(f"{image_folder}/{i:03}.png")
        records.append(item)

    # DataFrame -> HF Dataset -> parquet.
    dataset = Dataset.from_pandas(pd.DataFrame(records))
    dataset.to_parquet(f"{output_path}/{folder_name}.parquet")
if __name__ == '__main__':
    # Pipeline stages; uncomment the one(s) to (re)generate. Earlier stages
    # produce the jsonl files that later stages read.
    # res = "make firework/rainbow in sky/ground region in the image"
    # print(split_actions(res))
    # prepare_our_prompt_v0()
    # prepare_p2p_prompt_v0()
    # prepare_emu_set()
    # merge_prompts()
    write_dataset_toparquet()