Spaces:
Runtime error
Runtime error
import gradio as gr | |
import json | |
import numpy as np | |
import pandas as pd | |
import operator | |
# DataFrame.plot returns Plotly figures; make_profession_plot relies on this
# (it passes the Plotly-specific ``barmode`` argument).
pd.options.plotting.backend = "plotly"
TITLE = "Diffusion Professions Cluster Explorer"


def _load_clusters(num_cl):
    """Load the precomputed cluster JSON for a clustering with ``num_cl`` clusters."""
    path = f"clusters/professions_to_clusters_{num_cl}.json"
    # A context manager closes the file deterministically instead of leaking
    # the handle until garbage collection.
    with open(path, encoding="utf-8") as fh:
        return json.load(fh)


# Cluster data keyed by clustering size (the sizes offered in the UI).
clusters_dicts = {num_cl: _load_clusters(num_cl) for num_cl in (12, 24, 48)}
prompts = pd.read_csv("promptsadjectives.csv")
# Lower-cased, alphabetically sorted profession names offered as UI choices.
professions = sorted(p.lower() for p in prompts["Occupation-Noun"].tolist())
# Internal model key (as used in the cluster JSON data) -> display name in the UI.
models = {
    "All": "All Models",
    "SD_14": "Stable Diffusion 1.4",
    "SD_2": "Stable Diffusion 2",
    "DallE": "Dall-E 2",
}
# Inverse mapping: UI display name -> internal model key. Derived from
# ``models`` so the two tables cannot drift apart.
df_models = {display_name: key for key, display_name in models.items()}
def describe_cluster(num_clusters, block="label", counts=None):
    """Build an HTML summary of how a cluster's members split across labels.

    Args:
        num_clusters: Key into ``clusters_dicts``; only used when ``counts``
            is not given, in which case ``clusters_dicts[num_clusters]`` must
            itself be a label -> count mapping.
        block: Name of the attribute being described (e.g. "label"); it is
            inserted verbatim into the text.
        counts: Optional label -> count mapping to describe directly.

    Returns:
        An HTML string naming the most frequent label and its share, followed
        by every other label with its share, all rounded to whole percents.
        If there are no labels or every count is zero, a short "no data"
        message is returned instead of raising.
    """
    cl_dict = clusters_dicts[num_clusters] if counts is None else counts
    # Ascending sort then reverse: largest count first; ties end up in
    # reverse insertion order.
    labels_values = sorted(cl_dict.items(), key=operator.itemgetter(1))[::-1]
    total = float(sum(cl_dict.values()))
    if not labels_values or total == 0:
        return f"<span>No {block} data available for this cluster.</span>"
    lv_prcnt = [
        (label, round(count * 100 / total, 0)) for label, count in labels_values
    ]
    (top_label, top_pct), *rest = lv_prcnt
    # Labels are formatted with str(); the previously called ``to_string``
    # helper does not exist in this module and raised NameError.
    description = (
        "<span>The most represented %s is <b>%s</b>, making up about <b>%d%%</b> of the cluster.</span>"
        % (block, top_label, top_pct)
    )
    followers = "".join("<BR/><b>%s:</b> %d%%" % (label, pct) for label, pct in rest)
    return description + "<p>This is followed by: " + followers + "</p>"
def make_profession_plot(num_clusters, prof_name):
    """Plot one profession's cluster proportions as grouped bars, one series per model.

    Clusters are ordered by their proportion in the "All" model (largest
    first). Clusters whose "All" proportion is not positive are omitted. Each
    entry of ``models`` becomes one bar series labelled with its display name.

    Args:
        num_clusters: Clustering size; a key of ``clusters_dicts``.
        prof_name: Profession key looked up in each model's data.

    Returns:
        The figure returned by ``DataFrame.plot(kind="bar", barmode="group")``
        under the module's configured plotting backend.
    """
    by_model = clusters_dicts[num_clusters]
    # The cluster order comes from the "All" model; compute it once instead
    # of re-sorting it for every model series.
    reference = by_model["All"][prof_name]["cluster_proportions"]
    cluster_ids = [
        k
        for k, v in sorted(reference.items(), key=lambda kv: kv[1], reverse=True)
        if v > 0
    ]
    pre_pandas = {}
    for mod_name, display_name in models.items():
        proportions = by_model[mod_name][prof_name]["cluster_proportions"]
        # Defensive: if a model lacks a cluster key that "All" has, plot 0
        # rather than failing with KeyError.
        pre_pandas[display_name] = {
            f"Cluster {k}": proportions.get(k, 0) for k in cluster_ids
        }
    df = pd.DataFrame.from_dict(pre_pandas)
    return df.plot(kind="bar", barmode="group")
def make_profession_table(num_clusters, prof_names, mod_name, max_cols=8):
    """Render an HTML heat-map table of cluster proportions for several professions.

    Args:
        num_clusters: Clustering size; a key of ``clusters_dicts`` and also the
            number of cluster ids (``0 .. num_clusters - 1``) considered.
        prof_names: Profession keys to show, one table row each.
        mod_name: Model display name (a key of ``df_models``), e.g. "All Models".
        max_cols: Maximum number of cluster columns. The clusters with the
            largest proportion summed over ``prof_names`` are kept, and of
            those only clusters with a positive sum are shown.

    Returns:
        A ``(cluster_ids, html)`` pair. ``cluster_ids`` lists the clusters
        shown as table columns, in column order. ``html`` is the styled table
        as an HTML string. When no profession is selected, the result is
        ``([], <short message>)``.
    """
    if not prof_names:
        return [], "<p>Select at least one profession to compare.</p>"
    # Every lookup below goes into the same model's data; resolve it once.
    model_data = clusters_dicts[num_clusters][df_models[mod_name]]
    professions_list_clusters = [
        (prof_name, model_data[prof_name]["cluster_proportions"])
        for prof_name in prof_names
    ]
    # Proportions come from JSON, so cluster ids are string keys. A missing
    # key counts as a proportion of 0.
    totals = sorted(
        (
            (k, sum(props.get(str(k), 0) for _, props in professions_list_clusters))
            for k in range(num_clusters)
        ),
        key=lambda x: x[1],
        reverse=True,
    )[:max_cols]
    # Return exactly the clusters that become columns, so the id list stays
    # aligned with the table.
    shown_ids = [k for k, total in totals if total > 0]
    rows = [
        {
            "Profession": prof_name,
            "Entropy": model_data[prof_name]["entropy"],
            "Labor Women": model_data[prof_name]["labor_fm"][0],
            "": "",  # empty column separating the metadata from the clusters
            **{f"Cluster {k}": props.get(str(k), 0) for k in shown_ids},
        }
        for prof_name, props in professions_list_clusters
    ]
    clusters_df = pd.DataFrame.from_dict(rows)
    html = (
        clusters_df.style.background_gradient(
            axis=None, vmin=0, vmax=100, cmap="YlGnBu"
        )
        .format(precision=1)
        .to_html()
    )
    return shown_ids, html
def show_examplars(num_clusters, prof_name, mod_name, cl_id):
    """Return the stored exemplars of one cluster for a profession as JSON text.

    Args:
        num_clusters: Clustering size; a key of ``clusters_dicts``.
        prof_name: Profession key.
        mod_name: Model display name (a key of ``df_models``).
        cl_id: Cluster id. It is looked up as ``str(cl_id)`` because the JSON
            data is keyed by strings.

    Returns:
        ``json.dumps`` of the entry stored for that cluster, or ``"null"``
        when the profession has no entry for it (including when no cluster
        id is selected) instead of raising KeyError.
    """
    # TODO: show the actual images
    examplars = clusters_dicts[num_clusters][df_models[mod_name]][prof_name][
        "cluster_examplars"
    ]
    return json.dumps(examplars.get(str(cl_id)))
def _reset_cluster_focus(n_clusters, prof_name, mod_name):
    """Offer cluster ids ``0 .. n_clusters - 1``, select 0 and show its exemplars.

    Returns a ``(dropdown update, exemplars JSON)`` pair for the cluster-id
    dropdown and the exemplars panel.
    """
    return (
        gr.update(choices=list(range(n_clusters)), value=0),
        show_examplars(n_clusters, prof_name, mod_name, 0),
    )


with gr.Blocks(title=TITLE) as demo:
    gr.Markdown("# π€ Diffusion Cluster Explorer")
    gr.Markdown("description will go here")
    with gr.Tab("Professions Overview"):
        gr.Markdown("TODO")
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("Select the parameters here:")
                num_clusters = gr.Radio(
                    [12, 24, 48],
                    value=12,
                    label="How many clusters do you want to use to represent identities?",
                )
                model_choices = gr.Dropdown(
                    [
                        "All Models",
                        "Stable Diffusion 1.4",
                        "Stable Diffusion 2",
                        "Dall-E 2",
                    ],
                    value="All Models",
                    label="Which models do you want to compare?",
                    interactive=True,
                )
                # `professions` holds lower-cased names, so the defaults must
                # be lower-cased too to be valid choices ("CEO" was not one).
                # The data lookups presumably use the same lower-cased keys.
                profession_choices_overview = gr.Dropdown(
                    professions,
                    value=["ceo", "director", "social assistant", "social worker"],
                    label="Which professions do you want to compare?",
                    multiselect=True,
                    interactive=True,
                )
            with gr.Column(scale=3):
                with gr.Row():
                    table = gr.HTML(
                        label="Profession assignment per cluster", wrap=True
                    )
                    # Hidden holder for the cluster ids shown in the table.
                    clusters = gr.Textbox(label="clusters", visible=False)
        overview_inputs = [num_clusters, profession_choices_overview, model_choices]
        demo.load(
            make_profession_table,
            overview_inputs,
            [clusters, table],
            queue=False,
        )
        for var in [num_clusters, model_choices, profession_choices_overview]:
            var.change(
                make_profession_table,
                overview_inputs,
                [clusters, table],
                queue=False,
            )
    with gr.Tab("Profession Focus"):
        with gr.Row():
            with gr.Column():
                gr.Markdown("Select profession to visualize here:")
                num_clusters_focus = gr.Radio(
                    [12, 24, 48],
                    value=12,
                    label="How many clusters do you want to use to represent identities?",
                )
                profession_choice_focus = gr.Dropdown(
                    choices=professions,
                    value="social worker",
                    label="Select profession:",
                )
                gr.Markdown(
                    "You can show examples of profession images assigned to each cluster:"
                )
                model_choices_focus = gr.Dropdown(
                    [
                        "All Models",
                        "Stable Diffusion 1.4",
                        "Stable Diffusion 2",
                        "Dall-E 2",
                    ],
                    value="All Models",
                    label="Select generation model:",
                    interactive=True,
                )
                # Initial choices match the initial cluster count. They are
                # refreshed by _reset_cluster_focus when the count changes.
                cluster_id_focus = gr.Dropdown(
                    choices=list(range(num_clusters_focus.value)),
                    value=0,
                    label="Select cluster to visualize:",
                )
            with gr.Column():
                # The label is fixed when the UI is built, so use the initial
                # profession *value*. Interpolating the component itself would
                # print the component object.
                plot = gr.Plot(
                    label=f"Makeup of the cluster assignments for profession {profession_choice_focus.value}"
                )
        demo.load(
            make_profession_plot,
            [num_clusters_focus, profession_choice_focus],
            plot,
            queue=False,
        )
        for var in [num_clusters_focus, profession_choice_focus]:
            var.change(
                make_profession_plot,
                [num_clusters_focus, profession_choice_focus],
                plot,
                queue=False,
            )
        with gr.Row():
            examplars_plot = (
                gr.JSON()
            )  # TODO: turn this into a plot with the actual images
        examplars_inputs = [
            num_clusters_focus,
            profession_choice_focus,
            model_choices_focus,
            cluster_id_focus,
        ]
        demo.load(
            show_examplars,
            examplars_inputs,
            examplars_plot,
            queue=False,
        )
        # Changing the profession must refresh the exemplars too, not only
        # the model or cluster selection.
        for var in [profession_choice_focus, model_choices_focus, cluster_id_focus]:
            var.change(
                show_examplars,
                examplars_inputs,
                examplars_plot,
                queue=False,
            )
        # A new cluster count invalidates the cluster-id choices: reset them
        # and show cluster 0, so a stale id is never looked up.
        num_clusters_focus.change(
            _reset_cluster_focus,
            [num_clusters_focus, profession_choice_focus, model_choices_focus],
            [cluster_id_focus, examplars_plot],
            queue=False,
        )
if __name__ == "__main__":
    demo.queue().launch(debug=True)