Spaces:
Runtime error
Runtime error
import os | |
import json | |
import random | |
import re | |
# Load JSON files
def load_json_file(file_name, data_dir="data"):
    """Read and parse one JSON file from the data directory.

    Args:
        file_name: File name, relative to *data_dir*.
        data_dir: Directory holding the file. Defaults to ``"data"``, which is
            resolved relative to the current working directory.

    Returns:
        The parsed JSON value.

    Raises:
        OSError: If the file cannot be opened (e.g. ``FileNotFoundError``).
        json.JSONDecodeError: If the file content is not valid JSON.
    """
    file_path = os.path.join(data_dir, file_name)
    # Decode explicitly as UTF-8 (as load_next_data does) so non-ASCII entries
    # load identically regardless of the platform's default locale encoding.
    with open(file_path, "r", encoding="utf-8") as file:
        return json.load(file)
# Option lists loaded once at import time through load_json_file.
# Gender-specific lists come in (female, male) pairs.
FEMALE_DEFAULT_TAGS, MALE_DEFAULT_TAGS = (
    load_json_file("female_default_tags.json"),
    load_json_file("male_default_tags.json"),
)
FEMALE_BODY_TYPES, MALE_BODY_TYPES = (
    load_json_file("female_body_types.json"),
    load_json_file("male_body_types.json"),
)
FEMALE_CLOTHING, MALE_CLOTHING = (
    load_json_file("female_clothing.json"),
    load_json_file("male_clothing.json"),
)
FEMALE_ADDITIONAL_DETAILS, MALE_ADDITIONAL_DETAILS = (
    load_json_file("female_additional_details.json"),
    load_json_file("male_additional_details.json"),
)

# Gender-neutral lists (tuple elements evaluate left to right, so files are
# read in the listed order).
ARTFORM, PHOTO_TYPE, ROLES, HAIRSTYLES = (
    load_json_file("artform.json"),
    load_json_file("photo_type.json"),
    load_json_file("roles.json"),
    load_json_file("hairstyles.json"),
)
PLACE, LIGHTING, COMPOSITION, POSE, BACKGROUND = (
    load_json_file("place.json"),
    load_json_file("lighting.json"),
    load_json_file("composition.json"),
    load_json_file("pose.json"),
    load_json_file("background.json"),
)
PHOTOGRAPHY_STYLES, DEVICE, PHOTOGRAPHER, ARTIST, DIGITAL_ARTFORM = (
    load_json_file("photography_styles.json"),
    load_json_file("device.json"),
    load_json_file("photographer.json"),
    load_json_file("artist.json"),
    load_json_file("digital_artform.json"),
)
class PromptGenerator:
    """Build randomized text-to-image prompts from option lists.

    Most options accept ``"disabled"`` (leave the section out), ``"random"``
    (draw from the matching module-level list) or literal text; see
    ``get_choice`` and ``generate_prompt`` for the exact rules per option.
    """

    # Markers delimiting the CLIP-G / CLIP-L sections of an assembled prompt;
    # process_string splits the prompt on them.
    _BREAK_CLIPL = "BREAK_CLIPL"
    _BREAK_CLIPG = "BREAK_CLIPG"

    def __init__(self, seed=None):
        """Create a generator with its own RNG and the optional ``data/next`` lists.

        Args:
            seed: Seed for the internal ``random.Random``; ``None`` gives an
                unseeded generator.
        """
        self.rng = random.Random(seed)
        self.next_data = self.load_next_data()

    def split_and_choose(self, input_str):
        """Return one randomly chosen, whitespace-stripped entry of a comma-separated string."""
        choices = [choice.strip() for choice in input_str.split(",")]
        # rng.choices (rather than rng.choice) is kept so a given seed keeps
        # producing the same picks.
        return self.rng.choices(choices, k=1)[0]

    def get_choice(self, input_str, default_choices):
        """Resolve an option value to concrete prompt text.

        Args:
            input_str: ``"disabled"`` (any case) resolves to ``""``; a string
                containing a comma resolves to one of its entries; ``"random"``
                (any case) resolves to a random element of *default_choices*;
                any other string is returned unchanged. ``None`` is treated as
                ``""``.
            default_choices: Sequence drawn from for ``"random"``.

        Returns:
            The resolved text. ``""`` when disabled, or when ``"random"`` is
            requested but *default_choices* is empty.
        """
        if input_str is None:
            return ""
        lowered = input_str.lower()
        if lowered == "disabled":
            return ""
        if "," in input_str:
            return self.split_and_choose(input_str)
        if lowered == "random":
            if not default_choices:
                # rng.choices would raise IndexError on an empty population.
                return ""
            return self.rng.choices(default_choices, k=1)[0]
        return input_str

    def clean_consecutive_commas(self, input_string):
        """Collapse each run of two or more commas into a single ``", "``.

        Whitespace between the commas, and directly after the run, is absorbed
        too, so ``"a, , , b"`` becomes ``"a, b"``. Single commas are untouched.
        """
        return re.sub(r",(?:\s*,)+\s*", ", ", input_string)

    @staticmethod
    def _strip_leading_article(text):
        """Remove a leading ``"a "`` / ``"an "`` from *text*; words inside are untouched."""
        return re.sub(r"^\s*an?\s+", "", text)

    @staticmethod
    def _normalize_separators(text):
        """Collapse comma runs (with surrounding whitespace) to ``", "`` and trim both ends."""
        text = re.sub(r"\s*,[\s,]*", ", ", text)
        return re.sub(r"^[\s,]+|[\s,]+$", "", text)

    @staticmethod
    def _split_section(text, marker):
        """Split out the text between the first two occurrences of *marker*.

        Returns:
            ``(outside, inside)``. *outside* is the text before the first
            marker and after the second, joined with ``", "``; *inside* is the
            text between them. With fewer than two markers: ``(text, "")``.
        """
        start = text.find(marker)
        if start == -1:
            return text, ""
        end = text.find(marker, start + len(marker))
        if end == -1:
            return text, ""
        inside = text[start + len(marker):end]
        # Join with a separator so the words on either side of the removed
        # section do not run together.
        outside = ", ".join((text[:start], text[end + len(marker):]))
        return outside, inside

    def process_string(self, replaced, seed):
        """Split a marked-up prompt into full, T5-XXL, CLIP-L and CLIP-G variants.

        Text between the first pair of ``BREAK_CLIPL`` markers becomes the
        CLIP-L prompt; text between the first pair of ``BREAK_CLIPG`` markers
        in what remains becomes the CLIP-G prompt; everything outside both
        pairs is the T5-XXL prompt. The full prompt keeps all text, with every
        marker turned into a separator. In each variant, comma runs collapse
        to a single ``", "`` and commas/whitespace are trimmed from both ends.

        Args:
            replaced: The marked-up prompt.
            seed: Returned unchanged.

        Returns:
            ``(original, seed, t5xxl, clip_l, clip_g)``.
        """
        remaining, clip_l = self._split_section(replaced, self._BREAK_CLIPL)
        t5xxl, clip_g = self._split_section(remaining, self._BREAK_CLIPG)

        def finish(text):
            # Any marker still present (e.g. an unpaired one) becomes a
            # separator so it neither leaks into the output nor glues the
            # neighbouring text together.
            for marker in (self._BREAK_CLIPL, self._BREAK_CLIPG):
                text = text.replace(marker, ", ")
            return self._normalize_separators(text)

        return finish(replaced), seed, finish(t5xxl), finish(clip_l), finish(clip_g)

    def load_next_data(self):
        """Load extra option lists from ``data/next/<category>/<field>.json``.

        Returns:
            ``{category: {field: parsed_json}}`` for every subdirectory of
            ``data/next`` (relative to the working directory) and every
            ``.json`` file inside it; ``{}`` if ``data/next`` does not exist.
        """
        next_data = {}
        next_path = os.path.join("data", "next")
        if not os.path.isdir(next_path):
            return next_data
        for category in os.listdir(next_path):
            category_path = os.path.join(next_path, category)
            if not os.path.isdir(category_path):
                continue
            fields = next_data.setdefault(category, {})
            for file_name in os.listdir(category_path):
                if not file_name.endswith(".json"):
                    continue
                file_path = os.path.join(category_path, file_name)
                with open(file_path, "r", encoding="utf-8") as f:
                    fields[file_name[:-len(".json")]] = json.load(f)
        return next_data

    def process_next_data(self, prompt, separator, category, field, value):
        """Append a selection from ``self.next_data[category][field]`` to *prompt*.

        Args:
            prompt: Text to extend.
            separator: Inserted before the appended text and between multiple picks.
            category: First-level key into ``self.next_data``.
            field: Second-level key; its data is a list of items or a dict
                with an ``"items"`` list.
            value: ``"None"`` appends nothing; ``"Random"`` appends one random
                item; ``"Multiple Random"`` appends 1-3 distinct random items;
                anything else is appended verbatim.

        Returns:
            The extended prompt, or *prompt* unchanged when the category/field
            is unknown, its data is neither list nor dict, *value* is
            ``"None"``, or a random pick is requested from an empty item list.
        """
        field_data = self.next_data.get(category, {}).get(field)
        if isinstance(field_data, list):
            items = field_data
        elif isinstance(field_data, dict):
            items = field_data.get("items", [])
        else:
            return prompt
        if value == "None":
            return prompt
        if value in ("Random", "Multiple Random") and not items:
            # Nothing to draw from (rng.choice would raise on an empty list).
            return prompt
        if value == "Random":
            selected_items = [self.rng.choice(items)]
        elif value == "Multiple Random":
            count = self.rng.randint(1, 3)
            selected_items = self.rng.sample(items, min(count, len(items)))
        else:
            selected_items = [value]
        return prompt + separator + separator.join(selected_items)

    def _append_option(self, components, value, default_choices, prefix=", ", random_suffix=""):
        """Append *prefix* and the resolved *value* to *components*.

        ``"disabled"`` appends nothing; ``"random"`` appends a random element
        of *default_choices* followed by *random_suffix*; any other value is
        appended verbatim (comma-separated values are not split here).
        """
        if value == "disabled":
            return
        if value == "random":
            text = self.get_choice(value, default_choices) + random_suffix
        else:
            text = value
        components.append(prefix)
        components.append(text)

    def _append_subject(self, components, subject, default_tags, body_type, female):
        """Append the subject phrase (optional ``"a <body type>"`` plus subject) to *components*."""
        body_pool = FEMALE_BODY_TYPES if female else MALE_BODY_TYPES
        if subject:
            if body_type != "disabled":
                if body_type == "random":
                    body_type = self.get_choice(body_type, body_pool)
                components.extend(["a ", body_type])
            components.append(subject)
            return
        if default_tags == "disabled":
            return
        if default_tags != "random":
            components.append(default_tags)
            return
        tag_pool = FEMALE_DEFAULT_TAGS if female else MALE_DEFAULT_TAGS
        if body_type == "disabled":
            components.append(self.get_choice(default_tags, tag_pool))
            return
        if body_type == "random":
            body_type = self.get_choice(body_type, body_pool)
        # "a " is emitted explicitly before the body type, so drop the tag's
        # own leading article (only the leading one; inner words stay intact).
        tag = self._strip_leading_article(self.get_choice(default_tags, tag_pool))
        components.extend(["a ", body_type, tag])

    def generate_prompt(self, seed, custom, subject, gender, artform, photo_type, body_types, default_tags, roles,
                        hairstyles, additional_details, photography_styles, device, photographer, artist,
                        digital_artform, place, lighting, clothing, composition, pose, background, input_image,
                        next_params):
        """Assemble a prompt and split it into model-specific variants.

        The prompt has three parts: subject/appearance text (plus clothing,
        composition and pose); a ``BREAK_CLIPG`` section with background, place
        and lighting; and a ``BREAK_CLIPL`` section with photo type, device and
        photographer (photography) or digital art form and artist (otherwise).
        Selections from *next_params* are appended after those sections.

        Args:
            seed: Reseeds the RNG when not ``None`` so output is reproducible.
            custom: Free text placed first when non-empty.
            subject: Explicit subject; when empty, ``default_tags`` decides.
            gender: ``"female"`` selects the female lists, anything else the male ones.
            artform: ``"photography"``, ``"random"`` (coin flip) or anything else
                for the digital-art path; compared case-insensitively.
            body_types, default_tags, clothing, composition, pose, background,
                place: ``"disabled"``, ``"random"`` or literal text (exact
                lowercase keywords).
            roles, hairstyles, additional_details, photography_styles,
                photo_type, device, photographer, artist, digital_artform:
                resolved through ``get_choice``.
            lighting: ``"random"`` picks between 2 and 5 entries (fewer if the
                list is shorter); ``"disabled"`` omits it; keywords are
                case-insensitive and other text is used as given.
            input_image: Accepted for interface compatibility; not used here.
            next_params: ``{category: {field: value}}`` forwarded to
                ``process_next_data``; ``None`` is treated as empty.

        Returns:
            ``(original, seed, t5xxl, clip_l, clip_g)`` from ``process_string``.
        """
        if seed is not None:
            self.rng = random.Random(seed)

        female = gender == "female"
        components = []
        if custom:
            components.append(custom)

        artform_mode = artform.lower()
        # The coin flip only consumes RNG state when artform is "random".
        is_photographer = artform_mode == "photography" or (
            artform_mode == "random" and self.rng.choice([True, False])
        )

        if is_photographer:
            photo_style = self.get_choice(photography_styles, PHOTOGRAPHY_STYLES) or "photography"
            components.append(photo_style)
            # A style word is always present (fallback above), so " of" only
            # depends on whether a subject phrase will follow.
            if default_tags != "disabled" or subject != "":
                components.append(" of")

        self._append_subject(components, subject, default_tags, body_types, female)

        details_pool = FEMALE_ADDITIONAL_DETAILS if female else MALE_ADDITIONAL_DETAILS
        for value, pool in ((roles, ROLES), (hairstyles, HAIRSTYLES), (additional_details, details_pool)):
            components.append(self.get_choice(value, pool))

        # If a component chosen so far is itself an entry of PLACE, follow the
        # last such component with a separator.
        for index in reversed(range(len(components))):
            if components[index] in PLACE:
                components[index] += ", "
                break

        self._append_option(components, clothing, FEMALE_CLOTHING if female else MALE_CLOTHING,
                            prefix=", dressed in ")
        self._append_option(components, composition, COMPOSITION)
        self._append_option(components, pose, POSE)

        components.append(self._BREAK_CLIPG)
        self._append_option(components, background, BACKGROUND)
        self._append_option(components, place, PLACE, random_suffix=", ")
        lighting_mode = lighting.lower()
        if lighting_mode == "random":
            count = self.rng.randint(2, 5)
            # Clamp so a short LIGHTING list cannot make rng.sample raise.
            selected_lighting = ", ".join(self.rng.sample(LIGHTING, min(count, len(LIGHTING))))
            components.extend([", ", selected_lighting])
        elif lighting_mode != "disabled":
            # Keep the caller's text as given (only the keyword check is case-insensitive).
            components.extend([", ", lighting])
        components.append(self._BREAK_CLIPG)

        components.append(self._BREAK_CLIPL)
        if is_photographer:
            if photo_type != "disabled":
                photo_type_choice = self.get_choice(photo_type, PHOTO_TYPE)
                if photo_type_choice and photo_type_choice not in ("random", "disabled"):
                    weight = round(self.rng.uniform(1.1, 1.5), 1)
                    components.append(f", ({photo_type_choice}:{weight}), ")
            device_choice = self.get_choice(device, DEVICE)
            photographer_choice = self.get_choice(photographer, PHOTOGRAPHER)
            components.append(f", shot on {device_choice}" if device != "disabled" else device_choice)
            components.append(
                f", photo by {photographer_choice}" if photographer != "disabled" else photographer_choice
            )
        else:
            digital_artform_choice = self.get_choice(digital_artform, DIGITAL_ARTFORM)
            if digital_artform_choice:
                components.append(f"{digital_artform_choice}")
            if artist != "disabled":
                components.append(f"by {self.get_choice(artist, ARTIST)}")
        components.append(self._BREAK_CLIPL)

        prompt = re.sub(" +", " ", " ".join(components))
        prompt = prompt.replace("of as", "of")
        prompt = self.clean_consecutive_commas(prompt)

        next_prompts = []
        for category, fields in (next_params or {}).items():
            for field, value in fields.items():
                next_prompt = self.process_next_data("", ", ", category, field, value)
                if next_prompt:
                    next_prompts.append(next_prompt.strip())

        combined_prompt = self.clean_consecutive_commas(prompt + " " + " ".join(next_prompts))
        return self.process_string(combined_prompt.strip(), seed)

    def add_caption_to_prompt(self, prompt, caption):
        """Return *prompt* with ``", <caption>"`` appended when *caption* is truthy."""
        if caption:
            return f"{prompt}, {caption}"
        return prompt