"""Gradio app for exploring how generated profession images fall into clusters.

At import time the module loads one JSON file per supported cluster count
from ``clusters/`` and the prompt list from ``promptsadjectives.csv``, then
builds a two-tab Gradio interface (``demo``).
"""

import gradio as gr
import json
import numpy as np
import pandas as pd
import operator

pd.options.plotting.backend = "plotly"

TITLE = "Diffusion Professions Cluster Explorer"

# Cluster counts for which a ``professions_to_clusters_<n>.json`` file is loaded.
CLUSTER_COUNTS = [12, 24, 48]


def _load_clusters(num_cl):
    """Load the cluster-assignment JSON for ``num_cl`` clusters.

    The file is opened in a ``with`` block so the handle is closed right
    after parsing instead of being left to the garbage collector.
    """
    with open(f"clusters/professions_to_clusters_{num_cl}.json") as f:
        return json.load(f)


clusters_dicts = {num_cl: _load_clusters(num_cl) for num_cl in CLUSTER_COUNTS}

prompts = pd.read_csv("promptsadjectives.csv")
professions = sorted(p.lower() for p in prompts["Occupation-Noun"].tolist())

# Data key -> display name of each generation model.
models = {
    "All": "All Models",
    "SD_14": "Stable Diffusion 1.4",
    "SD_2": "Stable Diffusion 2",
    "DallE": "Dall-E 2",
}
# Display name -> data key (inverse of ``models``, same order).
df_models = {display: key for key, display in models.items()}


def describe_cluster(num_clusters, block="label"):
    """Return a text summary of the value shares in ``clusters_dicts[num_clusters]``.

    Entries are ordered by value (largest first); each share is expressed as
    a rounded percentage of the sum of all values.

    NOTE(review): this function calls ``to_string``, which is not defined in
    the visible source, and it has no visible caller. It also treats the
    values of ``clusters_dicts[num_clusters]`` as numbers, whereas the other
    functions in this file index them as nested dicts. Confirm before use.

    Args:
        num_clusters: key into ``clusters_dicts``.
        block: name of the kind of item being described (used in the text).

    Returns:
        The description string.
    """
    cl_dict = clusters_dicts[num_clusters]
    # Ascending sort followed by reverse() (not ``reverse=True``) so that
    # the relative order of tied entries is the same as it has always been.
    labels_values = sorted(cl_dict.items(), key=operator.itemgetter(1))
    labels_values.reverse()
    total = float(sum(cl_dict.values()))
    lv_prcnt = [
        (label, round(value * 100 / total, 0)) for label, value in labels_values
    ]
    top_label, top_prcnt = lv_prcnt[0]
    parts = [
        "The most represented %s is %s, making up about %d%% of the cluster."
        % (to_string(block), to_string(top_label), top_prcnt),
        "\n\nThis is followed by: ",
    ]
    parts.extend(
        "\n%s: %d%%" % (to_string(label), prcnt) for label, prcnt in lv_prcnt[1:]
    )
    parts.append("\n\n")
    return "".join(parts)


def make_profession_plot(num_clusters: int, prof_name: str):
    """Build a grouped bar chart of cluster proportions for one profession.

    One bar group per cluster, one bar per model. Clusters are ordered by
    their proportion in the "All" entry (largest first) and clusters whose
    "All" proportion is not positive are left out.

    Args:
        num_clusters: key into ``clusters_dicts``.
        prof_name: profession key inside each model's entry.

    Returns:
        The figure produced by ``DataFrame.plot`` (plotly backend).
    """
    by_model = clusters_dicts[num_clusters]
    # The cluster selection/order is the same for every model, so compute it once.
    ranked = sorted(
        by_model["All"][prof_name]["cluster_proportions"].items(),
        key=lambda x: x[1],
        reverse=True,
    )
    shown_ids = [k for k, v in ranked if v > 0]
    pre_pandas = {
        display_name: {
            f"Cluster {k}": by_model[mod_name][prof_name]["cluster_proportions"][k]
            for k in shown_ids
        }
        for mod_name, display_name in models.items()
    }
    df = pd.DataFrame.from_dict(pre_pandas)
    return df.plot(kind="bar", barmode="group")


def make_profession_table(num_clusters, prof_names, mod_name, max_cols=8):
    """Build the HTML comparison table for several professions.

    Args:
        num_clusters: key into ``clusters_dicts``; also the number of cluster
            ids (``0 .. num_clusters - 1``) that are looked up.
        prof_names: professions to show, one table row each.
        mod_name: model display name (a key of ``df_models``).
        max_cols: maximum number of cluster columns to keep.

    Returns:
        A pair ``(cluster_ids, html)``: the ids of the (at most ``max_cols``)
        clusters with the largest summed proportion over the selected
        professions, and the styled table rendered as HTML. When no
        profession is selected the HTML is an empty string.
    """
    if not prof_names:
        # Nothing selected (empty list or None): every total is 0, so the
        # ranking degenerates to the first ``max_cols`` ids; skip styling.
        return list(range(num_clusters))[:max_cols], ""

    model_data = clusters_dicts[num_clusters][df_models[mod_name]]
    professions_list_clusters = [
        (prof_name, model_data[prof_name]["cluster_proportions"])
        for prof_name in prof_names
    ]
    totals = sorted(
        (
            (
                k,
                sum(
                    prof_clusters[str(k)]
                    for _, prof_clusters in professions_list_clusters
                ),
            )
            for k in range(num_clusters)
        ),
        key=lambda x: x[1],
        reverse=True,
    )[:max_cols]
    # Cluster columns are only emitted for clusters with a positive total.
    shown_ids = [k for k, v in totals if v > 0]

    rows = []
    for prof_name, prof_clusters in professions_list_clusters:
        row = {
            "Profession": prof_name,
            "Entropy": model_data[prof_name]["entropy"],
            "Labor Women": model_data[prof_name]["labor_fm"][0],
            "": "",  # blank spacer column between the stats and the clusters
        }
        row.update((f"Cluster {k}", prof_clusters[str(k)]) for k in shown_ids)
        rows.append(row)

    clusters_df = pd.DataFrame.from_dict(rows)
    html = (
        clusters_df.style.background_gradient(
            axis=None, vmin=0, vmax=100, cmap="YlGnBu"
        )
        .format(precision=1)
        .to_html()
    )
    return [k for k, _ in totals], html


def show_examplars(num_clusters, prof_name, mod_name, cl_id) -> str:
    """Return the exemplars of one cluster for a profession/model as JSON text.

    Args:
        num_clusters: key into ``clusters_dicts``.
        prof_name: profession key.
        mod_name: model display name (a key of ``df_models``).
        cl_id: cluster id; converted with ``str`` for the lookup.
    """
    # TODO: show the actual images
    examplars_dict = clusters_dicts[num_clusters][df_models[mod_name]][prof_name][
        "cluster_examplars"
    ][str(cl_id)]
    return json.dumps(examplars_dict)


with gr.Blocks(title=TITLE) as demo:
    gr.Markdown("# 🤗 Diffusion Cluster Explorer")
    gr.Markdown("description will go here")
    with gr.Tab("Professions Overview"):
        gr.Markdown("TODO")
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("Select the parameters here:")
                num_clusters = gr.Radio(
                    CLUSTER_COUNTS,
                    value=12,
                    label="How many clusters do you want to use to represent identities?",
                )
                model_choices = gr.Dropdown(
                    list(df_models),
                    value="All Models",
                    label="Which models do you want to compare?",
                    interactive=True,
                )
                # NOTE(review): "CEO" is not in ``professions`` (which is
                # lower-cased); confirm which spelling the cluster files use.
                profession_choices_overview = gr.Dropdown(
                    professions,
                    value=["CEO", "director", "social assistant", "social worker"],
                    label="Which professions do you want to compare?",
                    multiselect=True,
                    interactive=True,
                )
            with gr.Column(scale=3):
                with gr.Row():
                    table = gr.HTML(
                        label="Profession assignment per cluster", wrap=True
                    )
                    # Hidden holder for the ids of the displayed clusters.
                    clusters = gr.Textbox(label="clusters", visible=False)
        table_inputs = [num_clusters, profession_choices_overview, model_choices]
        table_outputs = [clusters, table]
        demo.load(make_profession_table, table_inputs, table_outputs, queue=False)
        for var in [num_clusters, model_choices, profession_choices_overview]:
            var.change(
                make_profession_table, table_inputs, table_outputs, queue=False
            )
    with gr.Tab("Profession Focus"):
        with gr.Row():
            with gr.Column():
                gr.Markdown("Select profession to visualize here:")
                num_clusters_focus = gr.Radio(
                    CLUSTER_COUNTS,
                    value=12,
                    label="How many clusters do you want to use to represent identities?",
                )
                profession_choice_focus = gr.Dropdown(
                    choices=professions,
                    value="social worker",
                    label="Select profession:",
                )
                gr.Markdown(
                    "You can show examples of profession images assigned to each cluster:"
                )
                model_choices_focus = gr.Dropdown(
                    list(df_models),
                    value="All Models",
                    label="Select generation model:",
                    interactive=True,
                )
                # NOTE: the choices are computed once, from the initial
                # cluster count; they are not refreshed when it changes.
                cluster_id_focus = gr.Dropdown(
                    choices=list(range(num_clusters_focus.value)),
                    value=0,
                    label="Select cluster to visualize:",
                )
            with gr.Column():
                # A static label: interpolating the Dropdown component into
                # an f-string would embed the component object, not the
                # currently selected profession.
                plot = gr.Plot(
                    label="Makeup of the cluster assignments for the selected profession"
                )
                plot_inputs = [num_clusters_focus, profession_choice_focus]
                demo.load(make_profession_plot, plot_inputs, plot, queue=False)
                for var in plot_inputs:
                    var.change(make_profession_plot, plot_inputs, plot, queue=False)
        with gr.Row():
            examplars_plot = (
                gr.JSON()
            )  # TODO: turn this into a plot with the actual images
            examplars_inputs = [
                num_clusters_focus,
                profession_choice_focus,
                model_choices_focus,
                cluster_id_focus,
            ]
            demo.load(show_examplars, examplars_inputs, examplars_plot, queue=False)
            # Every input is a trigger, so the exemplars never go stale when
            # the profession or the cluster count is changed.
            for var in examplars_inputs:
                var.change(
                    show_examplars, examplars_inputs, examplars_plot, queue=False
                )


if __name__ == "__main__":
    demo.queue().launch(debug=True)