ludusc commited on
Commit
4098d00
·
1 Parent(s): 78d8811
backend/disentangle_concepts.py CHANGED
@@ -31,7 +31,7 @@ def generate_composite_images(model, z, decision_boundaries, lambdas, latent_spa
31
 
32
  z = torch.from_numpy(z.copy()).to(device)
33
  repetitions = 16
34
- z_0 = z.copy()
35
 
36
  for decision_boundary, lmbd in zip(decision_boundaries, lambdas):
37
  decision_boundary = torch.from_numpy(decision_boundary.copy()).to(device)
 
31
 
32
  z = torch.from_numpy(z.copy()).to(device)
33
  repetitions = 16
34
+ z_0 = z
35
 
36
  for decision_boundary, lmbd in zip(decision_boundaries, lambdas):
37
  decision_boundary = torch.from_numpy(decision_boundary.copy()).to(device)
pages/1_Textiles_Disentanglement.py CHANGED
@@ -47,7 +47,14 @@ if 'color_ids' not in st.session_state:
47
  st.session_state.concept_ids = COLORS_LIST[-1]
48
  if 'space_id' not in st.session_state:
49
  st.session_state.space_id = 'W'
50
-
 
 
 
 
 
 
 
51
  # def on_change_random_input():
52
  # st.session_state.image_id = st.session_state.image_id
53
 
@@ -87,9 +94,12 @@ with input_col_2:
87
  color_lambda = chosen_color_lambda_input.number_input('Lambda:', min_value=0, step=1, value=7)
88
  color_lambda_button = st.form_submit_button('Choose the defined lambda')
89
 
90
- if colors_button:
91
  st.session_state.concept_ids = type_col
92
- st.session_state.space_id = space_id
 
 
 
93
 
94
  with input_col_3:
95
  with st.form('text_form'):
@@ -104,6 +114,10 @@ with input_col_3:
104
  value_lambda = chosen_value_lambda_input.number_input('Lambda:', min_value=0, step=1, key=1)
105
  value_lambda_button = st.form_submit_button('Choose the defined lambda for Value')
106
 
 
 
 
 
107
  # with input_col_4:
108
  # with st.form('Network specifics:'):
109
  # st.write('**Choose a latent space to use**')
@@ -142,7 +156,7 @@ with header_col_2:
142
  color_separation_vector, performance_color = concept_vectors[concept_vectors['color'] == st.session_state.concept_ids].reset_index().loc[0, ['vector', 'score']]
143
  saturation_separation_vector, performance_saturation = concept_vectors[concept_vectors['color'] == 'Saturation'].reset_index().loc[0, ['vector', 'score']]
144
  value_separation_vector, performance_value = concept_vectors[concept_vectors['color'] == 'Value'].reset_index().loc[0, ['vector', 'score']]
145
- st.write(f'Change in {st.session_state.concept_ids} of {np.round(color_lambda, 2)}, in saturation of {np.round(saturation_lambda, 2)}, in value of {np.round(value_lambda, 2)}. - Performance color vector: {performance_color}, saturation vector: {performance_saturation}, value vector: {performance_value}')
146
 
147
  # ---------------------------- DISPLAY COL 2 ROW 1 ------------------------------
148
 
@@ -157,5 +171,5 @@ with output_col_1:
157
  st.image(img)
158
 
159
  with output_col_2:
160
- image_updated = generate_composite_images(model, original_image_vec, [separation_vector_color, saturation_separation_vector, value_separation_vector], lambdas=[color_lambda, saturation_lambda, value_lambda])
161
  st.image(image_updated)
 
47
  st.session_state.concept_ids = COLORS_LIST[-1]
48
  if 'space_id' not in st.session_state:
49
  st.session_state.space_id = 'W'
50
+ if 'color_lambda' not in st.session_state:
51
+ st.session_state.color_lambda = 7
52
+ if 'saturation_lambda' not in st.session_state:
53
+ st.session_state.saturation_lambda = 0
54
+ if 'value_lambda' not in st.session_state:
55
+ st.session_state.value_lambda = 0
56
+
57
+
58
  # def on_change_random_input():
59
  # st.session_state.image_id = st.session_state.image_id
60
 
 
94
  color_lambda = chosen_color_lambda_input.number_input('Lambda:', min_value=0, step=1, value=7)
95
  color_lambda_button = st.form_submit_button('Choose the defined lambda')
96
 
97
+ if colors_button or color_lambda_button:
98
  st.session_state.concept_ids = type_col
99
+ st.session_state.color_lambda = color_lambda
100
+
101
+
102
+
103
 
104
  with input_col_3:
105
  with st.form('text_form'):
 
114
  value_lambda = chosen_value_lambda_input.number_input('Lambda:', min_value=0, step=1, key=1)
115
  value_lambda_button = st.form_submit_button('Choose the defined lambda for Value')
116
 
117
+ if saturation_lambda_button or value_lambda_button:
118
+ st.session_state.saturation_lambda = int(saturation_lambda)
119
+ st.session_state.value_lambda = int(value_lambda)
120
+
121
  # with input_col_4:
122
  # with st.form('Network specifics:'):
123
  # st.write('**Choose a latent space to use**')
 
156
  color_separation_vector, performance_color = concept_vectors[concept_vectors['color'] == st.session_state.concept_ids].reset_index().loc[0, ['vector', 'score']]
157
  saturation_separation_vector, performance_saturation = concept_vectors[concept_vectors['color'] == 'Saturation'].reset_index().loc[0, ['vector', 'score']]
158
  value_separation_vector, performance_value = concept_vectors[concept_vectors['color'] == 'Value'].reset_index().loc[0, ['vector', 'score']]
159
+ st.write(f'Change in {st.session_state.concept_ids} of {np.round(st.session_state.color_lambda, 2)}, in saturation of {np.round(st.session_state.saturation_lambda, 2)}, in value of {np.round(st.session_state.value_lambda, 2)}. - Performance color vector: {performance_color}, saturation vector: {performance_saturation/100}, value vector: {performance_value/100}')
160
 
161
  # ---------------------------- DISPLAY COL 2 ROW 1 ------------------------------
162
 
 
171
  st.image(img)
172
 
173
  with output_col_2:
174
+ image_updated = generate_composite_images(model, original_image_vec, [color_separation_vector, saturation_separation_vector, value_separation_vector], lambdas=[st.session_state.color_lambda, st.session_state.saturation_lambda, st.session_state.value_lambda])
175
  st.image(image_updated)