Spaces:
Runtime error
Runtime error
updates
Browse files
backend/disentangle_concepts.py
CHANGED
@@ -31,7 +31,7 @@ def generate_composite_images(model, z, decision_boundaries, lambdas, latent_spa
|
|
31 |
|
32 |
z = torch.from_numpy(z.copy()).to(device)
|
33 |
repetitions = 16
|
34 |
-
z_0 = z
|
35 |
|
36 |
for decision_boundary, lmbd in zip(decision_boundaries, lambdas):
|
37 |
decision_boundary = torch.from_numpy(decision_boundary.copy()).to(device)
|
|
|
31 |
|
32 |
z = torch.from_numpy(z.copy()).to(device)
|
33 |
repetitions = 16
|
34 |
+
z_0 = z
|
35 |
|
36 |
for decision_boundary, lmbd in zip(decision_boundaries, lambdas):
|
37 |
decision_boundary = torch.from_numpy(decision_boundary.copy()).to(device)
|
pages/1_Textiles_Disentanglement.py
CHANGED
@@ -47,7 +47,14 @@ if 'color_ids' not in st.session_state:
|
|
47 |
st.session_state.concept_ids = COLORS_LIST[-1]
|
48 |
if 'space_id' not in st.session_state:
|
49 |
st.session_state.space_id = 'W'
|
50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
# def on_change_random_input():
|
52 |
# st.session_state.image_id = st.session_state.image_id
|
53 |
|
@@ -87,9 +94,12 @@ with input_col_2:
|
|
87 |
color_lambda = chosen_color_lambda_input.number_input('Lambda:', min_value=0, step=1, value=7)
|
88 |
color_lambda_button = st.form_submit_button('Choose the defined lambda')
|
89 |
|
90 |
-
if colors_button:
|
91 |
st.session_state.concept_ids = type_col
|
92 |
-
st.session_state.
|
|
|
|
|
|
|
93 |
|
94 |
with input_col_3:
|
95 |
with st.form('text_form'):
|
@@ -104,6 +114,10 @@ with input_col_3:
|
|
104 |
value_lambda = chosen_value_lambda_input.number_input('Lambda:', min_value=0, step=1, key=1)
|
105 |
value_lambda_button = st.form_submit_button('Choose the defined lambda for Value')
|
106 |
|
|
|
|
|
|
|
|
|
107 |
# with input_col_4:
|
108 |
# with st.form('Network specifics:'):
|
109 |
# st.write('**Choose a latent space to use**')
|
@@ -142,7 +156,7 @@ with header_col_2:
|
|
142 |
color_separation_vector, performance_color = concept_vectors[concept_vectors['color'] == st.session_state.concept_ids].reset_index().loc[0, ['vector', 'score']]
|
143 |
saturation_separation_vector, performance_saturation = concept_vectors[concept_vectors['color'] == 'Saturation'].reset_index().loc[0, ['vector', 'score']]
|
144 |
value_separation_vector, performance_value = concept_vectors[concept_vectors['color'] == 'Value'].reset_index().loc[0, ['vector', 'score']]
|
145 |
-
st.write(f'Change in {st.session_state.concept_ids} of {np.round(color_lambda, 2)}, in saturation of {np.round(saturation_lambda, 2)}, in value of {np.round(value_lambda, 2)}. - Performance color vector: {performance_color}, saturation vector: {performance_saturation}, value vector: {performance_value}')
|
146 |
|
147 |
# ---------------------------- DISPLAY COL 2 ROW 1 ------------------------------
|
148 |
|
@@ -157,5 +171,5 @@ with output_col_1:
|
|
157 |
st.image(img)
|
158 |
|
159 |
with output_col_2:
|
160 |
-
image_updated = generate_composite_images(model, original_image_vec, [
|
161 |
st.image(image_updated)
|
|
|
47 |
st.session_state.concept_ids = COLORS_LIST[-1]
|
48 |
if 'space_id' not in st.session_state:
|
49 |
st.session_state.space_id = 'W'
|
50 |
+
if 'color_lambda' not in st.session_state:
|
51 |
+
st.session_state.color_lambda = 7
|
52 |
+
if 'saturation_lambda' not in st.session_state:
|
53 |
+
st.session_state.saturation_lambda = 0
|
54 |
+
if 'value_lambda' not in st.session_state:
|
55 |
+
st.session_state.value_lambda = 0
|
56 |
+
|
57 |
+
|
58 |
# def on_change_random_input():
|
59 |
# st.session_state.image_id = st.session_state.image_id
|
60 |
|
|
|
94 |
color_lambda = chosen_color_lambda_input.number_input('Lambda:', min_value=0, step=1, value=7)
|
95 |
color_lambda_button = st.form_submit_button('Choose the defined lambda')
|
96 |
|
97 |
+
if colors_button or color_lambda_button:
|
98 |
st.session_state.concept_ids = type_col
|
99 |
+
st.session_state.color_lambda = color_lambda
|
100 |
+
|
101 |
+
|
102 |
+
|
103 |
|
104 |
with input_col_3:
|
105 |
with st.form('text_form'):
|
|
|
114 |
value_lambda = chosen_value_lambda_input.number_input('Lambda:', min_value=0, step=1, key=1)
|
115 |
value_lambda_button = st.form_submit_button('Choose the defined lambda for Value')
|
116 |
|
117 |
+
if saturation_lambda_button or value_lambda_button:
|
118 |
+
st.session_state.saturation_lambda = int(saturation_lambda)
|
119 |
+
st.session_state.value_lambda = int(value_lambda)
|
120 |
+
|
121 |
# with input_col_4:
|
122 |
# with st.form('Network specifics:'):
|
123 |
# st.write('**Choose a latent space to use**')
|
|
|
156 |
color_separation_vector, performance_color = concept_vectors[concept_vectors['color'] == st.session_state.concept_ids].reset_index().loc[0, ['vector', 'score']]
|
157 |
saturation_separation_vector, performance_saturation = concept_vectors[concept_vectors['color'] == 'Saturation'].reset_index().loc[0, ['vector', 'score']]
|
158 |
value_separation_vector, performance_value = concept_vectors[concept_vectors['color'] == 'Value'].reset_index().loc[0, ['vector', 'score']]
|
159 |
+
st.write(f'Change in {st.session_state.concept_ids} of {np.round(st.session_state.color_lambda, 2)}, in saturation of {np.round(st.session_state.saturation_lambda, 2)}, in value of {np.round(st.session_state.value_lambda, 2)}. - Performance color vector: {performance_color}, saturation vector: {performance_saturation/100}, value vector: {performance_value/100}')
|
160 |
|
161 |
# ---------------------------- DISPLAY COL 2 ROW 1 ------------------------------
|
162 |
|
|
|
171 |
st.image(img)
|
172 |
|
173 |
with output_col_2:
|
174 |
+
image_updated = generate_composite_images(model, original_image_vec, [color_separation_vector, saturation_separation_vector, value_separation_vector], lambdas=[st.session_state.color_lambda, st.session_state.saturation_lambda, st.session_state.value_lambda])
|
175 |
st.image(image_updated)
|