yjernite commited on
Commit
138a816
·
1 Parent(s): deadd68

remove extraneous model selection

Browse files
Files changed (1) hide show
  1. app.py +33 -32
app.py CHANGED
@@ -14,14 +14,22 @@ TITLE = "Diffusion Professions Cluster Explorer"
14
  professions_dset = load_from_disk("professions")
15
  professions_df = professions_dset.to_pandas()
16
 
 
17
  def get_image(model, fname):
18
- return professions_dset.select(professions_df[(professions_df["image_path"]==fname) & (professions_df["model"]==model)].index)["image"][0]
 
 
 
 
 
19
 
20
  clusters_dicts = dict(
21
  (num_cl, json.load(open(f"clusters/professions_to_clusters_{num_cl}.json")))
22
  for num_cl in [12, 24, 48]
23
  )
24
 
 
 
25
  prompts = pd.read_csv("promptsadjectives.csv")
26
  professions = list(sorted([p.lower() for p in prompts["Occupation-Noun"].tolist()]))
27
  models = {
@@ -60,6 +68,17 @@ def describe_cluster(num_clusters, block="label"):
60
 
61
 
62
  def make_profession_plot(num_clusters, prof_name):
 
 
 
 
 
 
 
 
 
 
 
63
  pre_pandas = dict(
64
  [
65
  (
@@ -71,14 +90,7 @@ def make_profession_plot(num_clusters, prof_name):
71
  "cluster_proportions"
72
  ][k],
73
  )
74
- for k, v in sorted(
75
- clusters_dicts[num_clusters]["All"][prof_name][
76
- "cluster_proportions"
77
- ].items(),
78
- key=lambda x: x[1],
79
- reverse=True,
80
- )
81
- if v > 0
82
  ),
83
  )
84
  for mod_name in models
@@ -86,7 +98,9 @@ def make_profession_plot(num_clusters, prof_name):
86
  )
87
  df = pd.DataFrame.from_dict(pre_pandas)
88
  prof_plot = df.plot(kind="bar", barmode="group")
89
- return prof_plot
 
 
90
 
91
 
92
  def make_profession_table(num_clusters, prof_names, mod_name, max_cols=8):
@@ -145,12 +159,12 @@ def make_profession_table(num_clusters, prof_names, mod_name, max_cols=8):
145
  )
146
 
147
 
148
- def show_examplars(num_clusters, prof_name, mod_name, cl_id):
149
- examplars_dict = clusters_dicts[num_clusters][df_models[mod_name]][prof_name][
150
  "cluster_examplars"
151
  ][str(cl_id)]
152
  l = list(chain(*[examplars_dict[k] for k in examplars_dict]))
153
- return [get_image(model,fname) for _,model,fname in l]
154
 
155
 
156
  with gr.Blocks(title=TITLE) as demo:
@@ -222,17 +236,6 @@ with gr.Blocks(title=TITLE) as demo:
222
  gr.Markdown(
223
  "You can show examples of profession images assigned to each cluster:"
224
  )
225
- model_choices_focus = gr.Dropdown(
226
- [
227
- "All Models",
228
- "Stable Diffusion 1.4",
229
- "Stable Diffusion 2",
230
- "Dall-E 2",
231
- ],
232
- value="All Models",
233
- label="Select generation model:",
234
- interactive=True,
235
- )
236
  cluster_id_focus = gr.Dropdown(
237
  choices=[i for i in range(num_clusters_focus.value)],
238
  value=0,
@@ -245,38 +248,36 @@ with gr.Blocks(title=TITLE) as demo:
245
  demo.load(
246
  make_profession_plot,
247
  [num_clusters_focus, profession_choice_focus],
248
- plot,
249
  queue=False,
250
  )
251
  for var in [num_clusters_focus, profession_choice_focus]:
252
  var.change(
253
  make_profession_plot,
254
  [num_clusters_focus, profession_choice_focus],
255
- plot,
256
  queue=False,
257
  )
258
  with gr.Row():
259
- examplars_plot = (
260
- gr.Gallery().style(grid=9, height="auto")
261
- )
262
  demo.load(
263
  show_examplars,
264
  [
265
  num_clusters_focus,
266
  profession_choice_focus,
267
- model_choices_focus,
268
  cluster_id_focus,
269
  ],
270
  examplars_plot,
271
  queue=False,
272
  )
273
- for var in [model_choices_focus, cluster_id_focus]:
274
  var.change(
275
  show_examplars,
276
  [
277
  num_clusters_focus,
278
  profession_choice_focus,
279
- model_choices_focus,
280
  cluster_id_focus,
281
  ],
282
  examplars_plot,
 
14
  professions_dset = load_from_disk("professions")
15
  professions_df = professions_dset.to_pandas()
16
 
17
+
18
  def get_image(model, fname):
19
+ return professions_dset.select(
20
+ professions_df[
21
+ (professions_df["image_path"] == fname) & (professions_df["model"] == model)
22
+ ].index
23
+ )["image"][0]
24
+
25
 
26
  clusters_dicts = dict(
27
  (num_cl, json.load(open(f"clusters/professions_to_clusters_{num_cl}.json")))
28
  for num_cl in [12, 24, 48]
29
  )
30
 
31
+ cluster_summaries_by_size = json.load(open("clusters/cluster_summaries_by_size.json"))
32
+
33
  prompts = pd.read_csv("promptsadjectives.csv")
34
  professions = list(sorted([p.lower() for p in prompts["Occupation-Noun"].tolist()]))
35
  models = {
 
68
 
69
 
70
  def make_profession_plot(num_clusters, prof_name):
71
+ sorted_cl_scores = [
72
+ (k, v)
73
+ for k, v in sorted(
74
+ clusters_dicts[num_clusters]["All"][prof_name][
75
+ "cluster_proportions"
76
+ ].items(),
77
+ key=lambda x: x[1],
78
+ reverse=True,
79
+ )
80
+ if v > 0
81
+ ]
82
  pre_pandas = dict(
83
  [
84
  (
 
90
  "cluster_proportions"
91
  ][k],
92
  )
93
+ for k, _ in sorted_cl_scores
 
 
 
 
 
 
 
94
  ),
95
  )
96
  for mod_name in models
 
98
  )
99
  df = pd.DataFrame.from_dict(pre_pandas)
100
  prof_plot = df.plot(kind="bar", barmode="group")
101
+ return prof_plot, gr.update(
102
+ choices=[k for k, _ in sorted_cl_scores], value=sorted_cl_scores[0][0]
103
+ )
104
 
105
 
106
  def make_profession_table(num_clusters, prof_names, mod_name, max_cols=8):
 
159
  )
160
 
161
 
162
+ def show_examplars(num_clusters, prof_name, cl_id):
163
+ examplars_dict = clusters_dicts[num_clusters]["All"][prof_name][
164
  "cluster_examplars"
165
  ][str(cl_id)]
166
  l = list(chain(*[examplars_dict[k] for k in examplars_dict]))
167
+ return [get_image(model, fname) for _, model, fname in l]
168
 
169
 
170
  with gr.Blocks(title=TITLE) as demo:
 
236
  gr.Markdown(
237
  "You can show examples of profession images assigned to each cluster:"
238
  )
 
 
 
 
 
 
 
 
 
 
 
239
  cluster_id_focus = gr.Dropdown(
240
  choices=[i for i in range(num_clusters_focus.value)],
241
  value=0,
 
248
  demo.load(
249
  make_profession_plot,
250
  [num_clusters_focus, profession_choice_focus],
251
+ [plot, cluster_id_focus],
252
  queue=False,
253
  )
254
  for var in [num_clusters_focus, profession_choice_focus]:
255
  var.change(
256
  make_profession_plot,
257
  [num_clusters_focus, profession_choice_focus],
258
+ [plot, cluster_id_focus],
259
  queue=False,
260
  )
261
  with gr.Row():
262
+ examplars_plot = gr.Gallery(
263
+ label="Profession images assigned to the selected cluster."
264
+ ).style(grid=5, height="auto")
265
  demo.load(
266
  show_examplars,
267
  [
268
  num_clusters_focus,
269
  profession_choice_focus,
 
270
  cluster_id_focus,
271
  ],
272
  examplars_plot,
273
  queue=False,
274
  )
275
+ for var in [cluster_id_focus]:
276
  var.change(
277
  show_examplars,
278
  [
279
  num_clusters_focus,
280
  profession_choice_focus,
 
281
  cluster_id_focus,
282
  ],
283
  examplars_plot,