added disentanglement of W vector
Files changed:
- backend/disentangle_concepts.py +85 -10
- data/annotated_files/{annotations_seeds0000-1000.pkl → seeds0000-50000.pkl} +2 -2 (renamed)
- data/annotated_files/sim_seeds0000-10000.csv +0 -3 (deleted)
- data/annotated_files/{annotations_parallel_seeds0000-10000.pkl → sim_seeds0000-50000.csv} +2 -2 (renamed)
- data/model_files/pytorch_model.bin +0 -3 (deleted)
- pages/1_Disentanglement.py +21 -8
- pages/2_Concepts_comparison.py +13 -4
- view_predictions.ipynb +0 -0
backend/disentangle_concepts.py
CHANGED
@@ -6,7 +6,7 @@ import torch
 from umap import UMAP
 import PIL
 
-def get_separation_space(type_bin, annotations, df, samples=100, method='LR', C=0.1):
+def get_separation_space(type_bin, annotations, df, samples=100, method='LR', C=0.1, latent_space='Z'):
     """
     The get_separation_space function takes in a type_bin, annotations, and df.
     It then samples 100 of the most representative abstracts for that type_bin and 100 of the least representative abstracts for that type_bin.
@@ -22,10 +22,16 @@ def get_separation_space(type_bin, annotations, df, samples=100, method='LR', C=0.1):
     :return: The weights of the linear classifier
     :doc-author: Trelent
     """
+
+    if latent_space == 'Z':
+        col = 'z_vectors'
+    else:
+        col = 'w_vectors'
+
     abstracts = np.array([float(ann) for ann in df[type_bin]])
     abstract_idxs = list(np.argsort(abstracts))[:samples]
     repr_idxs = list(np.argsort(abstracts))[-samples:]
-    X = np.array([annotations['z_vectors'][i] for i in abstract_idxs+repr_idxs])
+    X = np.array([annotations[col][i] for i in abstract_idxs+repr_idxs])
     X = X.reshape((2*samples, 512))
     y = np.array([1]*samples + [0]*samples)
     x_train, x_val, y_train, y_val = train_test_split(X, y, test_size=0.2)
@@ -45,7 +51,7 @@ def get_separation_space(type_bin, annotations, df, samples=100, method='LR', C=0.1):
     return clf.coef_ / np.linalg.norm(clf.coef_), imp_features, imp_nodes, np.round(clf.score(x_val, y_val),2)
 
 
-def regenerate_images(model, z, decision_boundary, min_epsilon=-3, max_epsilon=3, count=5):
+def regenerate_images(model, z, decision_boundary, min_epsilon=-3, max_epsilon=3, count=5, latent_space='Z'):
     """
     The regenerate_images function takes a model, z, and decision_boundary as input. It then
     constructs an inverse rotation/translation matrix and passes it to the generator. The generator
@@ -69,6 +75,7 @@ def regenerate_images(model, z, decision_boundary, min_epsilon=-3, max_epsilon=3, count=5):
     # Labels.
     label = torch.zeros([1, G.c_dim], device=device)
 
+
     z = torch.from_numpy(z.copy()).to(device)
     decision_boundary = torch.from_numpy(decision_boundary.copy()).to(device)
 
@@ -84,14 +91,19 @@ def regenerate_images(model, z, decision_boundary, min_epsilon=-3, max_epsilon=3, count=5):
         #m = make_transform(translate, rotate)
         #m = np.linalg.inv(m)
         #G.synthesis.input.transform.copy_(torch.from_numpy(m))
-        img = G(z_0, label, truncation_psi=0.7, noise_mode='const')
-
+        if latent_space == 'Z':
+            img = G(z_0, label, truncation_psi=0.7, noise_mode='const')
+
+        else:
+            W = z_0.expand((14, -1)).unsqueeze(0)
+            img = G.synthesis(W, noise_mode='const')
+
         img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
         images.append(PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB'))
 
     return images, lambdas
 
-def generate_original_image(z, model):
+def generate_original_image(z, model, latent_space='Z'):
     """
     The generate_original_image function takes in a latent vector and the model,
     and returns an image generated from that latent vector.
@@ -106,13 +118,19 @@ def generate_original_image(z, model):
     G = model.to(device) # type: ignore
     # Labels.
     label = torch.zeros([1, G.c_dim], device=device)
-    z = torch.from_numpy(z.copy()).to(device)
-    img = G(z, label, truncation_psi=0.7, noise_mode='const')
+    if latent_space == 'Z':
+        z = torch.from_numpy(z.copy()).to(device)
+        img = G(z, label, truncation_psi=0.7, noise_mode='const')
+    else:
+        W = torch.from_numpy(np.repeat(z, 14, axis=0).reshape(1, 14, z.shape[1]).copy()).to(device)
+        print(W.shape)
+        img = G.synthesis(W, noise_mode='const')
+
     img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
     return PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB')
 
 
-def get_concepts_vectors(concepts, annotations, df, samples=100, method='LR', C=0.1):
+def get_concepts_vectors(concepts, annotations, df, samples=100, method='LR', C=0.1, latent_space='Z'):
     """
     The get_concepts_vectors function takes in a list of concepts, a dictionary of annotations, and the dataframe containing all the images.
     It returns two things:
@@ -132,7 +150,7 @@ def get_concepts_vectors(concepts, annotations, df, samples=100, method='LR', C=0.1):
     performances = []
     vectors = np.zeros((len(concepts), 512))
     for i, conc in enumerate(concepts):
-        vec, _, imp_nodes, performance = get_separation_space(conc, annotations, df, samples=samples, method=method, C=C)
+        vec, _, imp_nodes, performance = get_separation_space(conc, annotations, df, samples=samples, method=method, C=C, latent_space=latent_space)
         vectors[i,:] = vec
         performances.append(performance)
         important_nodes.append(set(imp_nodes))
@@ -148,3 +166,60 @@ def get_concepts_vectors(concepts, annotations, df, samples=100, method='LR', C=0.1):
     nodes_in_common = set.intersection(*important_nodes)
     return vectors, nodes_in_common, performances
 
+
+def get_verification_score(concept, decision_boundary, model, annotations, samples=100, latent_space='Z'):
+    import open_clip
+    import os
+    import random
+    from tqdm import tqdm
+    os.environ["CUDA_VISIBLE_DEVICES"] = ""
+
+
+    model_clip, _, preprocess = open_clip.create_model_and_transforms('ViT-L-14', pretrained='laion2b_s32b_b82k')
+    tokenizer = open_clip.get_tokenizer('ViT-L-14')
+
+    # Prepare the text queries
+    #@markdown _in the form pre_prompt {label}_:
+    pre_prompt = "Artwork, " #@param {type:"string"}
+    text_descriptions = [f"{pre_prompt}{label}" for label in [concept]]
+    text_tokens = tokenizer(text_descriptions)
+
+
+    listlen = len(annotations['fname'])
+    items = random.sample(range(listlen), samples)
+    changes = []
+    for iterator in tqdm(items):
+        chunk_imgs = []
+        chunk_ids = []
+
+        if latent_space == 'Z':
+            z = annotations['z_vectors'][iterator]
+        else:
+            z = annotations['w_vectors'][iterator]
+        images, lambdas = regenerate_images(model, z, decision_boundary, min_epsilon=0, max_epsilon=1, count=2, latent_space=latent_space)
+        for im, l in zip(images, lambdas):
+
+            chunk_imgs.append(preprocess(im.convert("RGB")))
+            chunk_ids.append(l)
+
+        image_input = torch.tensor(np.stack(chunk_imgs))
+
+        with torch.no_grad(), torch.cuda.amp.autocast():
+            text_features = model_clip.encode_text(text_tokens).float()
+            image_features = model_clip.encode_image(image_input).float()
+
+        # Rescale features
+        image_features /= image_features.norm(dim=-1, keepdim=True)
+        text_features /= text_features.norm(dim=-1, keepdim=True)
+
+        # Analyze features
+        text_probs = (100.0 * image_features.cpu().numpy() @ text_features.cpu().numpy().T)#.softmax(dim=-1)
+
+        change = max(text_probs[1][0].item() - text_probs[0][0].item(), 0)
+        changes.append(change)
+
+    return np.round(np.mean(np.array(changes)), 4)
+
+
+
+
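For orientation, here is a minimal sketch of how the pieces added above are meant to fit together when disentangling in W space. It is a sketch under assumptions, not app code: it presumes the repo-layout import path backend.disentangle_concepts, a StyleGAN generator already loaded as model, and the annotations dict (with 'z_vectors'/'w_vectors' keys) and ann_df dataframe that the Streamlit pages build; the concept label 'Abstract' is only an illustration.

    from backend.disentangle_concepts import (
        get_separation_space,
        regenerate_images,
        get_verification_score,
    )

    # Fit a unit-norm linear direction separating 'Abstract' from
    # non-'Abstract' images in W space; accuracy is the probe's
    # validation score.
    sep_vec, n_features, imp_nodes, accuracy = get_separation_space(
        'Abstract', annotations, ann_df, latent_space='W')

    # Move one seed image along the direction and render the traversal.
    w = annotations['w_vectors'][0]
    images, lambdas = regenerate_images(
        model, w, sep_vec, min_epsilon=-3, max_epsilon=3, count=5,
        latent_space='W')

    # CLIP check: does stepping along the direction increase similarity
    # to the concept prompt ("Artwork, Abstract")?
    score = get_verification_score(
        'Abstract', sep_vec, model, annotations, samples=10,
        latent_space='W')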
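The two W-space branches above broadcast a single 512-dimensional w vector to the (1, 14, 512) layout that G.synthesis expects, but by different routes: numpy repeat-plus-reshape in generate_original_image versus torch expand in regenerate_images. A small self-contained check that the two routes agree, assuming the 14-layer count hard-coded in the diff:

    import numpy as np
    import torch

    w = np.random.randn(1, 512).astype(np.float32)  # one w vector, as stored in 'w_vectors'

    # generate_original_image route: repeat along axis 0, then reshape.
    W1 = torch.from_numpy(np.repeat(w, 14, axis=0).reshape(1, 14, 512).copy())

    # regenerate_images route: expand the (1, 512) tensor, then add a batch dim.
    W2 = torch.from_numpy(w.copy()).expand((14, -1)).unsqueeze(0)

    assert W1.shape == W2.shape == (1, 14, 512)
    assert torch.allclose(W1, W2)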
data/annotated_files/{annotations_seeds0000-1000.pkl → seeds0000-50000.pkl}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:cd1bd97b8ff508b1d4a7ef43323530368ace65b35d12d84a914913f541187298
+size 314939226
data/annotated_files/sim_seeds0000-10000.csv
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:4e82d206b3aa231c00176a24c8de33a6299e92e65b23013a40538146b8d24ff8
-size 5645518
data/annotated_files/{annotations_parallel_seeds0000-10000.pkl → sim_seeds0000-50000.csv}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:c3faa3d75c2da1dbb2c5d90aeddee256e1f3324b24b902a54115d9b6aad0ae9d
+size 21965577
data/model_files/pytorch_model.bin
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:27d6840c1f9f11a0af97f6f1ff3809f7f3641d1e4ea7bc893ad15d9e4341caed
-size 120944973
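All four data-file changes above are Git LFS pointer diffs: LFS-tracked files are stored in the repository as three-line pointers (spec version, sha256 oid, byte size), so renames and deletions only touch the pointer while the binary itself lives in LFS storage. The new seeds0000-50000.pkl pointer, for instance, records a 314939226-byte object.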
pages/1_Disentanglement.py
CHANGED
@@ -34,11 +34,11 @@ with st.expander("See more instruction", expanded=False):
     st.write(instruction_text)
 
 
-annotations_file = './data/annotated_files/seeds0000-
+annotations_file = './data/annotated_files/seeds0000-50000.pkl'
 with open(annotations_file, 'rb') as f:
     annotations = pickle.load(f)
 
-ann_df = pd.read_csv('./data/annotated_files/sim_seeds0000-
+ann_df = pd.read_csv('./data/annotated_files/sim_seeds0000-50000.csv')
 concepts = './data/concepts.txt'
 
 with open(concepts) as f:
@@ -48,6 +48,8 @@ if 'image_id' not in st.session_state:
     st.session_state.image_id = 0
 if 'concept_id' not in st.session_state:
     st.session_state.concept_id = 'Abstract'
+if 'space_id' not in st.session_state:
+    st.session_state.space_id = 'Z'
 
 # def on_change_random_input():
 #     st.session_state.image_id = st.session_state.image_id
@@ -65,7 +67,12 @@ with input_col_1:
     # concept_id = chosen_text_id_input.text_input('Concept:', value=st.session_state.concept_id)
     concept_id = st.selectbox('Concept:', tuple(labels))
 
-
+    st.write('**Choose a latent space to disentangle**')
+    # chosen_text_id_input = st.empty()
+    # concept_id = chosen_text_id_input.text_input('Concept:', value=st.session_state.concept_id)
+    space_id = st.selectbox('Space:', tuple(['Z', 'W']))
+
+    choose_text_button = st.form_submit_button('Choose the defined concept and space to disentangle')
     # random_text = st.form_submit_button('Select a random concept')
 
     # if random_text:
@@ -76,6 +83,8 @@ with input_col_1:
     if choose_text_button:
         concept_id = str(concept_id)
         st.session_state.concept_id = concept_id
+        space_id = str(space_id)
+        st.session_state.space_id = space_id
         # st.write(image_id, st.session_state.image_id)
 
 # ---------------------------- SET UP OUTPUT ------------------------------
@@ -101,10 +110,10 @@ smoothgrad_col_1, smoothgrad_col_2, smoothgrad_col_3, smoothgrad_col_4, smoothgr
 
 # ---------------------------- DISPLAY COL 1 ROW 1 ------------------------------
 with output_col_1:
-    separation_vector, number_important_features, imp_nodes, performance = get_separation_space(concept_id, annotations, ann_df)
+    separation_vector, number_important_features, imp_nodes, performance = get_separation_space(concept_id, annotations, ann_df, latent_space=st.session_state.space_id)
     # st.write(f'Class ID {input_id} - {input_label}: {pred_prob*100:.3f}% confidence')
     st.write('Concept vector', separation_vector)
-    header_col_1.write(f'Concept {concept_id} - Number of relevant nodes: {number_important_features} - Val classification performance: {performance}')# - Nodes {",".join(list(imp_nodes))}')
+    header_col_1.write(f'Concept {concept_id} - Space {st.session_state.space_id} - Number of relevant nodes: {number_important_features} - Val classification performance: {performance}')# - Nodes {",".join(list(imp_nodes))}')
 
 # ----------------------------- INPUT column 2 & 3 ----------------------------
 with input_col_2:
@@ -141,8 +150,12 @@ with input_col_3:
 with dnnlib.util.open_url('./data/model_files/network-snapshot-010600.pkl') as f:
     model = legacy.load_network_pkl(f)['G_ema'].to('cpu') # type: ignore
 
-original_image_vec = annotations['z_vectors'][st.session_state.image_id]
-img = generate_original_image(original_image_vec, model)
+if st.session_state.space_id == 'Z':
+    original_image_vec = annotations['z_vectors'][st.session_state.image_id]
+else:
+    original_image_vec = annotations['w_vectors'][st.session_state.image_id]
+
+img = generate_original_image(original_image_vec, model, latent_space=st.session_state.space_id)
 # input_image = original_image_dict['image']
 # input_label = original_image_dict['label']
 # input_id = original_image_dict['id']
@@ -152,7 +165,7 @@ with smoothgrad_col_3:
     smooth_head_3.write(f'Base image')
 
 
-    images, lambdas = regenerate_images(model, original_image_vec, separation_vector, min_epsilon=-(int(epsilon)), max_epsilon=int(epsilon))
+    images, lambdas = regenerate_images(model, original_image_vec, separation_vector, min_epsilon=-(int(epsilon)), max_epsilon=int(epsilon), latent_space=st.session_state.space_id)
 
 with smoothgrad_col_1:
     st.image(images[0])
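Both pages use the same Streamlit pattern for the new control: the Space selectbox lives inside the existing form, and its value is copied into st.session_state only when the submit button fires, so the choice survives Streamlit's top-to-bottom reruns. A minimal self-contained sketch of that pattern; the form key 'concept_form' is illustrative, not from the app:

    import streamlit as st

    if 'space_id' not in st.session_state:
        st.session_state.space_id = 'Z'  # default latent space

    with st.form('concept_form'):
        space_id = st.selectbox('Space:', ('Z', 'W'))
        submitted = st.form_submit_button('Choose the defined concept and space to disentangle')

    if submitted:
        # Persist the choice across reruns.
        st.session_state.space_id = str(space_id)

    st.write('Current latent space:', st.session_state.space_id)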
pages/2_Concepts_comparison.py
CHANGED
@@ -39,7 +39,8 @@ if 'image_id' not in st.session_state:
     st.session_state.image_id = 0
 if 'concept_ids' not in st.session_state:
     st.session_state.concept_ids = ['Abstract', 'Representational']
-
+if 'space_id' not in st.session_state:
+    st.session_state.space_id = 'Z'
 # def on_change_random_input():
 #     st.session_state.image_id = st.session_state.image_id
 
@@ -63,9 +64,17 @@ with input_col_1:
     #     concept_id = random.choice(labels)
     #     st.session_state.concept_id = concept_id
     #     chosen_text_id_input.text_input('Concept:', value=st.session_state.concept_id)
-
+    st.write('**Choose a latent space to disentangle**')
+    # chosen_text_id_input = st.empty()
+    # concept_id = chosen_text_id_input.text_input('Concept:', value=st.session_state.concept_id)
+    space_id = st.selectbox('Space:', tuple(['Z', 'W']))
+
+    choose_text_button = st.form_submit_button('Choose the defined concept and space to disentangle')
+
     if choose_text_button:
         st.session_state.concept_ids = list(concept_ids)
+        space_id = str(space_id)
+        st.session_state.space_id = space_id
         # st.write(image_id, st.session_state.image_id)
 
 # ---------------------------- SET UP OUTPUT ------------------------------
@@ -91,10 +100,10 @@ smoothgrad_col_1, smoothgrad_col_2, smoothgrad_col_3, smoothgrad_col_4, smoothgr
 
 # ---------------------------- DISPLAY COL 1 ROW 1 ------------------------------
 with output_col_1:
-    vectors, nodes_in_common, performances = get_concepts_vectors(concept_ids, annotations, ann_df)
+    vectors, nodes_in_common, performances = get_concepts_vectors(concept_ids, annotations, ann_df, latent_space=space_id)
     # st.write(f'Class ID {input_id} - {input_label}: {pred_prob*100:.3f}% confidence')
     #st.write('Concept vector', separation_vector)
-    header_col_1.write(f'Concepts {", ".join(concept_ids)} - Relevant nodes in common: {nodes_in_common} - Performance of the concept vectors: {performances}')# - Nodes {",".join(list(imp_nodes))}')
+    header_col_1.write(f'Concepts {", ".join(concept_ids)} - Latent space {space_id} - Relevant nodes in common: {nodes_in_common} - Performance of the concept vectors: {performances}')# - Nodes {",".join(list(imp_nodes))}')
 
     edges = []
     for i in range(len(concept_ids)):
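The edges list built after this hunk connects the selected concepts; because get_separation_space returns clf.coef_ scaled to unit norm, the dot product of two rows of vectors is their cosine similarity. A hedged sketch of one way such pairwise edges can be scored; the helper name is illustrative, and the page's actual loop body is not shown in this diff:

    import numpy as np

    def concept_edges(vectors, concept_ids):
        # vectors: (n_concepts, 512) array of unit-norm concept directions.
        edges = []
        for i in range(len(concept_ids)):
            for j in range(i + 1, len(concept_ids)):
                sim = float(vectors[i] @ vectors[j])  # cosine similarity
                edges.append((concept_ids[i], concept_ids[j], round(sim, 3)))
        return edges

    # e.g. concept_edges(vectors, ['Abstract', 'Representational'])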
view_predictions.ipynb
CHANGED
The diff for this file is too large to render. See the raw diff.