Spaces:
Runtime error
Runtime error
yjernite
commited on
Commit
β’
ebbd0d6
1
Parent(s):
0593856
print on launch
Browse files
app.py
CHANGED
@@ -7,7 +7,7 @@ import operator
|
|
7 |
pd.options.plotting.backend = "plotly"
|
8 |
|
9 |
|
10 |
-
TITLE = "Diffusion
|
11 |
|
12 |
clusters_dicts = dict(
|
13 |
(num_cl, json.load(open(f"clusters/professions_to_clusters_{num_cl}.json")))
|
@@ -30,22 +30,27 @@ df_models = {
|
|
30 |
"Dall-E 2": "DallE",
|
31 |
}
|
32 |
|
|
|
33 |
def describe_cluster(num_clusters, block="label"):
|
34 |
-
cl_dict= clusters_dicts[num_clusters]
|
35 |
labels_values = sorted(cl_dict.items(), key=operator.itemgetter(1))
|
36 |
labels_values.reverse()
|
37 |
total = float(sum(cl_dict.values()))
|
38 |
lv_prcnt = list(
|
39 |
-
(item[0], round(item[1] * 100 / total, 0)) for item in labels_values
|
|
|
40 |
top_label = lv_prcnt[0][0]
|
41 |
-
description_string =
|
42 |
-
|
|
|
|
|
43 |
description_string += "<p>This is followed by: "
|
44 |
for lv in lv_prcnt[1:]:
|
45 |
description_string += "<BR/><b>%s:</b> %d%%" % (to_string(lv[0]), lv[1])
|
46 |
description_string += "</p>"
|
47 |
return description_string
|
48 |
|
|
|
49 |
def make_profession_plot(num_clusters, prof_name):
|
50 |
pre_pandas = dict(
|
51 |
[
|
@@ -123,19 +128,24 @@ def make_profession_table(num_clusters, prof_names, mod_name, max_cols=8):
|
|
123 |
for prof_name, prof_clusters in professions_list_clusters
|
124 |
]
|
125 |
clusters_df = pd.DataFrame.from_dict(prof_list_pre_pandas)
|
126 |
-
return [c[0] for c in totals], (
|
127 |
-
|
128 |
-
|
129 |
-
.format(precision=1)
|
130 |
-
.to_html()
|
131 |
)
|
|
|
|
|
|
|
|
|
132 |
|
133 |
def show_examplars(num_clusters, prof_name, mod_name, cl_id):
|
134 |
# TODO: show the actual images
|
135 |
-
examplars_dict = clusters_dicts[num_clusters][df_models[mod_name]][prof_name][
|
|
|
|
|
136 |
return json.dumps(examplars_dict)
|
137 |
|
138 |
-
|
|
|
139 |
gr.Markdown("# π€ Diffusion Cluster Explorer")
|
140 |
gr.Markdown("description will go here")
|
141 |
with gr.Tab("Professions Overview"):
|
@@ -167,12 +177,18 @@ with gr.Blocks() as demo:
|
|
167 |
interactive=True,
|
168 |
)
|
169 |
with gr.Column(scale=3):
|
170 |
-
with gr.Row():
|
171 |
table = gr.HTML(
|
172 |
label="Profession assignment per cluster", wrap=True
|
173 |
-
)
|
174 |
-
#clusters = gr.Dataframe(type="array", visible=False, col_count=1)
|
175 |
-
clusters = gr.Textbox(label=
|
|
|
|
|
|
|
|
|
|
|
|
|
176 |
for var in [num_clusters, model_choices, profession_choices_overview]:
|
177 |
var.change(
|
178 |
make_profession_table,
|
@@ -181,8 +197,6 @@ with gr.Blocks() as demo:
|
|
181 |
queue=False,
|
182 |
)
|
183 |
|
184 |
-
|
185 |
-
|
186 |
with gr.Tab("Profession Focus"):
|
187 |
with gr.Row():
|
188 |
with gr.Column():
|
@@ -197,7 +211,9 @@ with gr.Blocks() as demo:
|
|
197 |
value="social worker",
|
198 |
label="Select profession:",
|
199 |
)
|
200 |
-
gr.Markdown(
|
|
|
|
|
201 |
model_choices_focus = gr.Dropdown(
|
202 |
[
|
203 |
"All Models",
|
@@ -218,6 +234,12 @@ with gr.Blocks() as demo:
|
|
218 |
plot = gr.Plot(
|
219 |
label=f"Makeup of the cluster assignments for profession {profession_choice_focus}"
|
220 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
221 |
for var in [num_clusters_focus, profession_choice_focus]:
|
222 |
var.change(
|
223 |
make_profession_plot,
|
@@ -226,14 +248,33 @@ with gr.Blocks() as demo:
|
|
226 |
queue=False,
|
227 |
)
|
228 |
with gr.Row():
|
229 |
-
examplars_plot =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
230 |
for var in [model_choices_focus, cluster_id_focus]:
|
231 |
var.change(
|
232 |
show_examplars,
|
233 |
-
[
|
|
|
|
|
|
|
|
|
|
|
234 |
examplars_plot,
|
235 |
queue=False,
|
236 |
)
|
237 |
|
238 |
|
239 |
-
|
|
|
|
7 |
pd.options.plotting.backend = "plotly"
|
8 |
|
9 |
|
10 |
+
TITLE = "Diffusion Professions Cluster Explorer"
|
11 |
|
12 |
clusters_dicts = dict(
|
13 |
(num_cl, json.load(open(f"clusters/professions_to_clusters_{num_cl}.json")))
|
|
|
30 |
"Dall-E 2": "DallE",
|
31 |
}
|
32 |
|
33 |
+
|
34 |
def describe_cluster(num_clusters, block="label"):
|
35 |
+
cl_dict = clusters_dicts[num_clusters]
|
36 |
labels_values = sorted(cl_dict.items(), key=operator.itemgetter(1))
|
37 |
labels_values.reverse()
|
38 |
total = float(sum(cl_dict.values()))
|
39 |
lv_prcnt = list(
|
40 |
+
(item[0], round(item[1] * 100 / total, 0)) for item in labels_values
|
41 |
+
)
|
42 |
top_label = lv_prcnt[0][0]
|
43 |
+
description_string = (
|
44 |
+
"<span>The most represented %s is <b>%s</b>, making up about <b>%d%%</b> of the cluster.</span>"
|
45 |
+
% (to_string(block), to_string(top_label), lv_prcnt[0][1])
|
46 |
+
)
|
47 |
description_string += "<p>This is followed by: "
|
48 |
for lv in lv_prcnt[1:]:
|
49 |
description_string += "<BR/><b>%s:</b> %d%%" % (to_string(lv[0]), lv[1])
|
50 |
description_string += "</p>"
|
51 |
return description_string
|
52 |
|
53 |
+
|
54 |
def make_profession_plot(num_clusters, prof_name):
|
55 |
pre_pandas = dict(
|
56 |
[
|
|
|
128 |
for prof_name, prof_clusters in professions_list_clusters
|
129 |
]
|
130 |
clusters_df = pd.DataFrame.from_dict(prof_list_pre_pandas)
|
131 |
+
return [c[0] for c in totals], (
|
132 |
+
clusters_df.style.background_gradient(
|
133 |
+
axis=None, vmin=0, vmax=100, cmap="YlGnBu"
|
|
|
|
|
134 |
)
|
135 |
+
.format(precision=1)
|
136 |
+
.to_html()
|
137 |
+
)
|
138 |
+
|
139 |
|
140 |
def show_examplars(num_clusters, prof_name, mod_name, cl_id):
|
141 |
# TODO: show the actual images
|
142 |
+
examplars_dict = clusters_dicts[num_clusters][df_models[mod_name]][prof_name][
|
143 |
+
"cluster_examplars"
|
144 |
+
][str(cl_id)]
|
145 |
return json.dumps(examplars_dict)
|
146 |
|
147 |
+
|
148 |
+
with gr.Blocks(title=TITLE) as demo:
|
149 |
gr.Markdown("# π€ Diffusion Cluster Explorer")
|
150 |
gr.Markdown("description will go here")
|
151 |
with gr.Tab("Professions Overview"):
|
|
|
177 |
interactive=True,
|
178 |
)
|
179 |
with gr.Column(scale=3):
|
180 |
+
with gr.Row():
|
181 |
table = gr.HTML(
|
182 |
label="Profession assignment per cluster", wrap=True
|
183 |
+
)
|
184 |
+
# clusters = gr.Dataframe(type="array", visible=False, col_count=1)
|
185 |
+
clusters = gr.Textbox(label="clusters", visible=False)
|
186 |
+
demo.load(
|
187 |
+
make_profession_table,
|
188 |
+
[num_clusters, profession_choices_overview, model_choices],
|
189 |
+
[clusters, table],
|
190 |
+
queue=False,
|
191 |
+
)
|
192 |
for var in [num_clusters, model_choices, profession_choices_overview]:
|
193 |
var.change(
|
194 |
make_profession_table,
|
|
|
197 |
queue=False,
|
198 |
)
|
199 |
|
|
|
|
|
200 |
with gr.Tab("Profession Focus"):
|
201 |
with gr.Row():
|
202 |
with gr.Column():
|
|
|
211 |
value="social worker",
|
212 |
label="Select profession:",
|
213 |
)
|
214 |
+
gr.Markdown(
|
215 |
+
"You can show examples of profession images assigned to each cluster:"
|
216 |
+
)
|
217 |
model_choices_focus = gr.Dropdown(
|
218 |
[
|
219 |
"All Models",
|
|
|
234 |
plot = gr.Plot(
|
235 |
label=f"Makeup of the cluster assignments for profession {profession_choice_focus}"
|
236 |
)
|
237 |
+
demo.load(
|
238 |
+
make_profession_plot,
|
239 |
+
[num_clusters_focus, profession_choice_focus],
|
240 |
+
plot,
|
241 |
+
queue=False,
|
242 |
+
)
|
243 |
for var in [num_clusters_focus, profession_choice_focus]:
|
244 |
var.change(
|
245 |
make_profession_plot,
|
|
|
248 |
queue=False,
|
249 |
)
|
250 |
with gr.Row():
|
251 |
+
examplars_plot = (
|
252 |
+
gr.JSON()
|
253 |
+
) # TODO: turn this into a plot with the actual images
|
254 |
+
demo.load(
|
255 |
+
show_examplars,
|
256 |
+
[
|
257 |
+
num_clusters_focus,
|
258 |
+
profession_choice_focus,
|
259 |
+
model_choices_focus,
|
260 |
+
cluster_id_focus,
|
261 |
+
],
|
262 |
+
examplars_plot,
|
263 |
+
queue=False,
|
264 |
+
)
|
265 |
for var in [model_choices_focus, cluster_id_focus]:
|
266 |
var.change(
|
267 |
show_examplars,
|
268 |
+
[
|
269 |
+
num_clusters_focus,
|
270 |
+
profession_choice_focus,
|
271 |
+
model_choices_focus,
|
272 |
+
cluster_id_focus,
|
273 |
+
],
|
274 |
examplars_plot,
|
275 |
queue=False,
|
276 |
)
|
277 |
|
278 |
|
279 |
+
if __name__ == "__main__":
|
280 |
+
demo.queue().launch(debug=True)
|