# Author: yjernite
# Last commit: ebbd0d6 "print on launch" (file size 9.7 kB)
import gradio as gr
import json
import numpy as np
import pandas as pd
import operator
# All DataFrame.plot calls in this app produce plotly figures.
pd.options.plotting.backend = "plotly"
TITLE = "Diffusion Professions Cluster Explorer"


def _load_clusters(num_cl):
    """Load the professions-to-clusters JSON precomputed for ``num_cl`` clusters.

    The file is opened in a ``with`` block so the handle is closed as soon as
    it is parsed, instead of being left to the garbage collector.
    """
    path = f"clusters/professions_to_clusters_{num_cl}.json"
    with open(path, encoding="utf-8") as f:
        return json.load(f)


# Keyed by the integer cluster count offered in the UI (12, 24 or 48).
clusters_dicts = {num_cl: _load_clusters(num_cl) for num_cl in [12, 24, 48]}
prompts = pd.read_csv("promptsadjectives.csv")
# Lower-cased, alphabetically sorted profession names used as dropdown choices.
professions = sorted(p.lower() for p in prompts["Occupation-Noun"].tolist())
# Internal model key -> display name shown in the UI.
models = {
    "All": "All Models",
    "SD_14": "Stable Diffusion 1.4",
    "SD_2": "Stable Diffusion 2",
    "DallE": "Dall-E 2",
}
# Display name -> internal model key. Derived as the inverse of ``models`` so
# the two mappings cannot drift apart.
df_models = {display: key for key, display in models.items()}
def describe_cluster(num_clusters, block="label", cl_dict=None):
    """Return an HTML summary of how labels are distributed in a cluster.

    Args:
        num_clusters: key into ``clusters_dicts``; only used when ``cl_dict``
            is not supplied.
        block: noun describing what the labels are, inserted into the text.
        cl_dict: optional mapping of label -> count/weight. When omitted,
            ``clusters_dicts[num_clusters]`` is used.

    Returns:
        An HTML string naming the most represented label with its rounded
        percentage of the total, followed by every other label in descending
        order of weight.

    Raises:
        ValueError: if the mapping is empty or its values do not sum to a
            positive number.
    """
    # Local import keeps this block self-contained. Escaping protects the
    # generated markup from labels containing <, > or &.
    import html

    if cl_dict is None:
        # NOTE(review): the entries of ``clusters_dicts`` loaded by this file
        # are nested (model -> profession -> stats), not label -> count, so
        # this default path presumably needs a flatter mapping. Confirm.
        cl_dict = clusters_dicts[num_clusters]

    def _text(value):
        # The original called ``to_string``, which is not defined anywhere
        # in this file (NameError on every call). Use plain ``str`` instead.
        return html.escape(str(value))

    if not cl_dict:
        raise ValueError("describe_cluster needs a non-empty label -> count mapping")
    total = float(sum(cl_dict.values()))
    if total <= 0:
        raise ValueError("describe_cluster needs counts that sum to a positive total")

    # Ascending sort then reverse, preserving the original ordering of ties.
    labels_values = sorted(cl_dict.items(), key=operator.itemgetter(1))[::-1]
    lv_prcnt = [
        (label, round(value * 100 / total, 0)) for label, value in labels_values
    ]
    (top_label, top_pct), *rest = lv_prcnt

    parts = [
        "<span>The most represented %s is <b>%s</b>, making up about <b>%d%%</b> of the cluster.</span>"
        % (_text(block), _text(top_label), top_pct),
        "<p>This is followed by: ",
    ]
    parts.extend("<BR/><b>%s:</b> %d%%" % (_text(label), pct) for label, pct in rest)
    parts.append("</p>")
    return "".join(parts)
def make_profession_plot(num_clusters, prof_name):
    """Build a grouped bar chart of cluster proportions for one profession.

    There is one bar group per cluster and one bar per model; the columns
    are the display names from ``models``. Clusters are ordered by their
    proportion under the "All" model, largest first. Clusters whose "All"
    proportion is not positive are left out.

    Args:
        num_clusters: key into ``clusters_dicts`` (e.g. 12, 24 or 48).
        prof_name: profession key in the per-model cluster data.

    Returns:
        The figure returned by ``DataFrame.plot`` under the module's
        configured plotting backend.
    """
    model_data = clusters_dicts[num_clusters]
    # Cluster order and filtering depend only on the "All" model, so compute
    # them once rather than re-sorting for every model.
    all_props = model_data["All"][prof_name]["cluster_proportions"]
    cluster_keys = [
        k
        for k, v in sorted(all_props.items(), key=lambda x: x[1], reverse=True)
        if v > 0
    ]
    pre_pandas = {
        models[mod_name]: {
            f"Cluster {k}": model_data[mod_name][prof_name]["cluster_proportions"][k]
            for k in cluster_keys
        }
        for mod_name in models
    }
    df = pd.DataFrame.from_dict(pre_pandas)
    return df.plot(kind="bar", barmode="group")
def make_profession_table(num_clusters, prof_names, mod_name, max_cols=8):
    """Render a per-profession table of cluster proportions as styled HTML.

    Args:
        num_clusters: key into ``clusters_dicts`` (e.g. 12, 24 or 48).
        prof_names: professions to include, one table row each. ``None`` is
            treated as an empty selection.
        mod_name: model display name (a key of ``df_models``).
        max_cols: maximum number of cluster columns. The clusters with the
            largest proportion summed over the selected professions are kept.

    Returns:
        A ``(cluster_ids, html)`` pair:
        - ``cluster_ids``: the integer ids of the displayed cluster columns,
          largest total first.
        - ``html``: the table rendered with a 0-100 ``YlGnBu`` background
          gradient and one decimal place.
    """
    model_data = clusters_dicts[num_clusters][df_models[mod_name]]
    # Fetch each profession's stats once instead of once per column.
    prof_stats = [
        (prof_name, model_data[prof_name]) for prof_name in (prof_names or [])
    ]

    totals = sorted(
        (
            (k, sum(stats["cluster_proportions"][str(k)] for _, stats in prof_stats))
            for k in range(num_clusters)
        ),
        key=lambda x: x[1],
        reverse=True,
    )[:max_cols]
    # Only clusters with a positive total become columns. The returned id
    # list is kept identical to those columns, so zero-total clusters that
    # are not displayed are no longer reported.
    shown = [k for k, total in totals if total > 0]

    rows = []
    for prof_name, stats in prof_stats:
        row = {
            "Profession": prof_name,
            "Entropy": stats["entropy"],
            "Labor Women": stats["labor_fm"][0],
            # Empty spacer column between the metadata and the cluster columns.
            "": "",
        }
        props = stats["cluster_proportions"]
        row.update((f"Cluster {k}", props[str(k)]) for k in shown)
        rows.append(row)

    clusters_df = pd.DataFrame.from_dict(rows)
    table_html = (
        clusters_df.style.background_gradient(
            axis=None, vmin=0, vmax=100, cmap="YlGnBu"
        )
        .format(precision=1)
        .to_html()
    )
    return shown, table_html
def show_examplars(num_clusters, prof_name, mod_name, cl_id):
    """Serialize the stored exemplars of one cluster for a profession to JSON.

    The record is read from
    ``clusters_dicts[num_clusters][<model key>][prof_name]["cluster_examplars"]``,
    indexed by the string form of ``cl_id``. ``mod_name`` is a display name
    that is mapped to its internal key through ``df_models``.
    """
    # TODO: show the actual images
    per_model = clusters_dicts[num_clusters]
    profession_entry = per_model[df_models[mod_name]][prof_name]
    cluster_key = str(cl_id)
    exemplars = profession_entry["cluster_examplars"][cluster_key]
    return json.dumps(exemplars)
def _refresh_focus_clusters(num_cl, prof_name, mod_name):
    """Rebuild the cluster picker for a new cluster count and reload exemplars.

    Returns:
        - a ``gr.update`` offering cluster ids ``0..num_cl-1`` with 0 selected;
        - the exemplar JSON for cluster 0 under the new cluster count.

    This keeps the picker and the exemplar view from referring to a cluster
    id that does not exist for ``num_cl``.
    """
    return (
        gr.update(choices=list(range(num_cl)), value=0),
        show_examplars(num_cl, prof_name, mod_name, 0),
    )


with gr.Blocks(title=TITLE) as demo:
    gr.Markdown("# 🤗 Diffusion Cluster Explorer")
    gr.Markdown("description will go here")
    with gr.Tab("Professions Overview"):
        gr.Markdown("TODO")
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("Select the parameters here:")
                num_clusters = gr.Radio(
                    [12, 24, 48],
                    value=12,
                    label="How many clusters do you want to use to represent identities?",
                )
                model_choices = gr.Dropdown(
                    list(models.values()),
                    value="All Models",
                    label="Which models do you want to compare?",
                    interactive=True,
                )
                # NOTE(review): ``professions`` is lower-cased at load time,
                # so "CEO" may not be one of the dropdown choices. Confirm
                # against the profession keys in the cluster JSON files.
                profession_choices_overview = gr.Dropdown(
                    professions,
                    value=["CEO", "director", "social assistant", "social worker"],
                    label="Which professions do you want to compare?",
                    multiselect=True,
                    interactive=True,
                )
            with gr.Column(scale=3):
                with gr.Row():
                    table = gr.HTML(
                        label="Profession assignment per cluster", wrap=True
                    )
                    # Hidden holder for the cluster ids returned with the table.
                    clusters = gr.Textbox(label="clusters", visible=False)
        overview_inputs = [num_clusters, profession_choices_overview, model_choices]
        demo.load(
            make_profession_table,
            overview_inputs,
            [clusters, table],
            queue=False,
        )
        for var in overview_inputs:
            var.change(
                make_profession_table,
                overview_inputs,
                [clusters, table],
                queue=False,
            )
    with gr.Tab("Profession Focus"):
        with gr.Row():
            with gr.Column():
                gr.Markdown("Select profession to visualize here:")
                num_clusters_focus = gr.Radio(
                    [12, 24, 48],
                    value=12,
                    label="How many clusters do you want to use to represent identities?",
                )
                profession_choice_focus = gr.Dropdown(
                    choices=professions,
                    value="social worker",
                    label="Select profession:",
                )
                gr.Markdown(
                    "You can show examples of profession images assigned to each cluster:"
                )
                model_choices_focus = gr.Dropdown(
                    list(models.values()),
                    value="All Models",
                    label="Select generation model:",
                    interactive=True,
                )
                cluster_id_focus = gr.Dropdown(
                    choices=list(range(num_clusters_focus.value)),
                    value=0,
                    label="Select cluster to visualize:",
                )
            with gr.Column():
                # Static text: interpolating the Dropdown component itself
                # would render the component object, not the selected value.
                plot = gr.Plot(
                    label="Makeup of the cluster assignments for the selected profession"
                )
        plot_inputs = [num_clusters_focus, profession_choice_focus]
        demo.load(make_profession_plot, plot_inputs, plot, queue=False)
        for var in plot_inputs:
            var.change(make_profession_plot, plot_inputs, plot, queue=False)
        with gr.Row():
            # TODO: turn this into a plot with the actual images
            examplars_plot = gr.JSON()
        examplar_inputs = [
            num_clusters_focus,
            profession_choice_focus,
            model_choices_focus,
            cluster_id_focus,
        ]
        demo.load(show_examplars, examplar_inputs, examplars_plot, queue=False)
        # Every selection that feeds show_examplars must refresh it;
        # otherwise the JSON view goes stale.
        for var in [profession_choice_focus, model_choices_focus, cluster_id_focus]:
            var.change(show_examplars, examplar_inputs, examplars_plot, queue=False)
        # A new cluster count changes which cluster ids exist. Rebuild the
        # picker and reload the exemplars for cluster 0 in one step.
        num_clusters_focus.change(
            _refresh_focus_clusters,
            [num_clusters_focus, profession_choice_focus, model_choices_focus],
            [cluster_id_focus, examplars_plot],
            queue=False,
        )


if __name__ == "__main__":
    demo.queue().launch(debug=True)