Spaces:
Runtime error
Runtime error
added cv2 and show sample image
Browse files- backend/disentangle_concepts.py +10 -0
- pages/1_Disentanglement.py +8 -1
- requirements.txt +2 -1
backend/disentangle_concepts.py
CHANGED
@@ -45,3 +45,13 @@ def regenerate_images(model, z, decision_boundary, min_epsilon=-3, max_epsilon=3
|
|
45 |
img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
|
46 |
images.append(PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB'))
|
47 |
return images
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
|
46 |
images.append(PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB'))
|
47 |
return images
|
48 |
+
|
49 |
+
def generate_original_image(z, model):
|
50 |
+
device = torch.device('cpu')
|
51 |
+
G = model.to(device) # type: ignore
|
52 |
+
# Labels.
|
53 |
+
label = torch.zeros([1, G.c_dim], device=device)
|
54 |
+
z = torch.from_numpy(z.copy()).to(device)
|
55 |
+
img = G(z, label, truncation_psi=0.7, noise_mode='random')
|
56 |
+
img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
|
57 |
+
return PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB')
|
pages/1_Disentanglement.py
CHANGED
@@ -128,11 +128,18 @@ with input_col_3:
|
|
128 |
epsilon_button = st.form_submit_button('Choose the defined epsilon')
|
129 |
|
130 |
# ---------------------------- DISPLAY COL 2 ROW 1 ------------------------------
|
131 |
-
|
|
|
|
|
|
|
132 |
# input_image = original_image_dict['image']
|
133 |
# input_label = original_image_dict['label']
|
134 |
# input_id = original_image_dict['id']
|
135 |
|
|
|
|
|
|
|
|
|
136 |
|
137 |
|
138 |
# if pred_class_id == (input_id-1) and (epsilon_button or find_epsilon or setting_button):
|
|
|
128 |
epsilon_button = st.form_submit_button('Choose the defined epsilon')
|
129 |
|
130 |
# ---------------------------- DISPLAY COL 2 ROW 1 ------------------------------
|
131 |
+
|
132 |
+
model = torch.load('./data/model_files/pytorch_model.bin')
|
133 |
+
original_image_vec = annotations['z_vectors'][st.session_state.image_id]
|
134 |
+
img = generate_original_image(original_image_vec, model)
|
135 |
# input_image = original_image_dict['image']
|
136 |
# input_label = original_image_dict['label']
|
137 |
# input_id = original_image_dict['id']
|
138 |
|
139 |
+
with smoothgrad_col_3:
|
140 |
+
st.image(img)
|
141 |
+
header_col_1.write(f'Base image')
|
142 |
+
|
143 |
|
144 |
|
145 |
# if pred_class_id == (input_id-1) and (epsilon_button or find_epsilon or setting_button):
|
requirements.txt
CHANGED
@@ -12,4 +12,5 @@ torchvision==0.11.2
|
|
12 |
tqdm==4.64.1
|
13 |
transformers==4.25.1
|
14 |
scikit-learn
|
15 |
-
altair==4.0
|
|
|
|
12 |
tqdm==4.64.1
|
13 |
transformers==4.25.1
|
14 |
scikit-learn
|
15 |
+
altair==4.0
|
16 |
+
opencv-python
|