Spaces:
Runtime error
Runtime error
updated model
Browse files
backend/disentangle_concepts.py
CHANGED
@@ -1,22 +1,34 @@
|
|
1 |
import numpy as np
|
2 |
from sklearn.svm import SVC
|
|
|
3 |
from sklearn.model_selection import train_test_split
|
4 |
import torch
|
|
|
5 |
import PIL
|
6 |
|
7 |
-
def get_separation_space(type_bin, annotations, df):
|
8 |
abstracts = np.array([float(ann) for ann in df[type_bin]])
|
9 |
-
abstract_idxs = list(np.argsort(abstracts))[:
|
10 |
-
repr_idxs = list(np.argsort(abstracts))[-
|
11 |
X = np.array([annotations['z_vectors'][i] for i in abstract_idxs+repr_idxs])
|
12 |
-
X = X.reshape((
|
13 |
-
y = np.array([1]*
|
14 |
x_train, x_val, y_train, y_val = train_test_split(X, y, test_size=0.2)
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
def regenerate_images(model, z, decision_boundary, min_epsilon=-3, max_epsilon=3, count=5):
|
22 |
device = torch.device('cpu')
|
@@ -55,4 +67,24 @@ def generate_original_image(z, model):
|
|
55 |
z = torch.from_numpy(z.copy()).to(device)
|
56 |
img = G(z, label, truncation_psi=0.7, noise_mode='random')
|
57 |
img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
|
58 |
-
return PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import numpy as np
|
2 |
from sklearn.svm import SVC
|
3 |
+
from sklearn.linear_model import LogisticRegression
|
4 |
from sklearn.model_selection import train_test_split
|
5 |
import torch
|
6 |
+
from umap import UMAP
|
7 |
import PIL
|
8 |
|
9 |
+
def get_separation_space(type_bin, annotations, df, samples=100, method='LR', C=0.1):
|
10 |
abstracts = np.array([float(ann) for ann in df[type_bin]])
|
11 |
+
abstract_idxs = list(np.argsort(abstracts))[:samples]
|
12 |
+
repr_idxs = list(np.argsort(abstracts))[-samples:]
|
13 |
X = np.array([annotations['z_vectors'][i] for i in abstract_idxs+repr_idxs])
|
14 |
+
X = X.reshape((2*samples, 512))
|
15 |
+
y = np.array([1]*samples + [0]*samples)
|
16 |
x_train, x_val, y_train, y_val = train_test_split(X, y, test_size=0.2)
|
17 |
+
if method == 'SVM':
|
18 |
+
svc = SVC(gamma='auto', kernel='linear', random_state=0, C=C)
|
19 |
+
svc.fit(x_train, y_train)
|
20 |
+
print('Val performance SVM', svc.score(x_val, y_val))
|
21 |
+
imp_features = (np.abs(svc.coef_) > 0.2).sum()
|
22 |
+
imp_nodes = np.where(np.abs(svc.coef_) > 0.2)[1]
|
23 |
+
return svc.coef_, imp_features, imp_nodes
|
24 |
+
elif method == 'LR':
|
25 |
+
clf = LogisticRegression(random_state=0, C=C)
|
26 |
+
clf.fit(x_train, y_train)
|
27 |
+
print('Val performance logistic regression', clf.score(x_val, y_val))
|
28 |
+
imp_features = (np.abs(clf.coef_) > 0.2).sum()
|
29 |
+
imp_nodes = np.where(np.abs(clf.coef_) > 0.2)[1]
|
30 |
+
return clf.coef_, imp_features, imp_nodes
|
31 |
+
|
32 |
|
33 |
def regenerate_images(model, z, decision_boundary, min_epsilon=-3, max_epsilon=3, count=5):
|
34 |
device = torch.device('cpu')
|
|
|
67 |
z = torch.from_numpy(z.copy()).to(device)
|
68 |
img = G(z, label, truncation_psi=0.7, noise_mode='random')
|
69 |
img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
|
70 |
+
return PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB')
|
71 |
+
|
72 |
+
|
73 |
+
def get_concepts_vectors(concepts, annotations, df, samples=100, method='LR', C=0.1):
|
74 |
+
important_nodes = []
|
75 |
+
vectors = np.zeros((len(concepts), 512))
|
76 |
+
for i, conc in enumerate(concepts):
|
77 |
+
vec, _, imp_nodes = get_separation_space(conc, annotations, df, samples=samples, method=method, C=C)
|
78 |
+
vectors[i,:] = vec
|
79 |
+
important_nodes.append(set(imp_nodes))
|
80 |
+
|
81 |
+
reducer = UMAP(n_neighbors=3, # default 15, The size of local neighborhood (in terms of number of neighboring sample points) used for manifold approximation.
|
82 |
+
n_components=3, # default 2, The dimension of the space to embed into.
|
83 |
+
min_dist=0.1, # default 0.1, The effective minimum distance between embedded points.
|
84 |
+
spread=2.0, # default 1.0, The effective scale of embedded points. In combination with ``min_dist`` this determines how clustered/clumped the embedded points are.
|
85 |
+
random_state=0, # default: None, If int, random_state is the seed used by the random number generator;
|
86 |
+
)
|
87 |
+
|
88 |
+
projection = reducer.fit_transform(vectors)
|
89 |
+
nodes_in_common = set.intersection(*important_nodes)
|
90 |
+
return vectors, projection, nodes_in_common
|
data/annotated_files/seeds0000-100000.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7b3a4fd155fa86df0953ad1cb660d50729189606de307fcee09fd893ba047228
|
3 |
+
size 420351795
|
data/annotated_files/sim_seeds0000-100000.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4e501641d051743b0f1eec385bf7cb2d769e3cb15f1fffc08dce6d38c1f2bbf8
|
3 |
+
size 14059984
|
data/model_files/network-snapshot-010600.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9a46e8aecd50191b82632b5de7bf3b9e219a59564c54994dd203f016b7a8270e
|
3 |
+
size 357344749
|
nx.html
ADDED
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<html>
|
2 |
+
<head>
|
3 |
+
<meta charset="utf-8">
|
4 |
+
|
5 |
+
<script src="lib/bindings/utils.js"></script>
|
6 |
+
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/vis-network/9.1.2/dist/dist/vis-network.min.css" integrity="sha512-WgxfT5LWjfszlPHXRmBWHkV2eceiWTOBvrKCNbdgDYTHrT2AeLCGbF4sZlZw3UMN3WtL0tGUoIAKsu8mllg/XA==" crossorigin="anonymous" referrerpolicy="no-referrer" />
|
7 |
+
<script src="https://cdnjs.cloudflare.com/ajax/libs/vis-network/9.1.2/dist/vis-network.min.js" integrity="sha512-LnvoEWDFrqGHlHmDD2101OrLcbsfkrzoSpvtSQtxK3RMnRV0eOkhhBN2dXHKRrUU8p2DGRTk35n4O8nWSVe1mQ==" crossorigin="anonymous" referrerpolicy="no-referrer"></script>
|
8 |
+
|
9 |
+
|
10 |
+
<center>
|
11 |
+
<h1></h1>
|
12 |
+
</center>
|
13 |
+
|
14 |
+
<!-- <link rel="stylesheet" href="../node_modules/vis/dist/vis.min.css" type="text/css" />
|
15 |
+
<script type="text/javascript" src="../node_modules/vis/dist/vis.js"> </script>-->
|
16 |
+
<link
|
17 |
+
href="https://cdn.jsdelivr.net/npm/bootstrap@5.0.0-beta3/dist/css/bootstrap.min.css"
|
18 |
+
rel="stylesheet"
|
19 |
+
integrity="sha384-eOJMYsd53ii+scO/bJGFsiCZc+5NDVN2yr8+0RDqr0Ql0h+rP48ckxlpbzKgwra6"
|
20 |
+
crossorigin="anonymous"
|
21 |
+
/>
|
22 |
+
<script
|
23 |
+
src="https://cdn.jsdelivr.net/npm/bootstrap@5.0.0-beta3/dist/js/bootstrap.bundle.min.js"
|
24 |
+
integrity="sha384-JEW9xMcG8R+pH31jmWH6WWP0WintQrMb4s7ZOdauHnUtxwoG2vI5DkLtS3qm9Ekf"
|
25 |
+
crossorigin="anonymous"
|
26 |
+
></script>
|
27 |
+
|
28 |
+
|
29 |
+
<center>
|
30 |
+
<h1></h1>
|
31 |
+
</center>
|
32 |
+
<style type="text/css">
|
33 |
+
|
34 |
+
#mynetwork {
|
35 |
+
width: 100%;
|
36 |
+
height: 750px;
|
37 |
+
background-color: #ffffff;
|
38 |
+
border: 1px solid lightgray;
|
39 |
+
position: relative;
|
40 |
+
float: left;
|
41 |
+
}
|
42 |
+
|
43 |
+
|
44 |
+
|
45 |
+
|
46 |
+
|
47 |
+
|
48 |
+
</style>
|
49 |
+
</head>
|
50 |
+
|
51 |
+
|
52 |
+
<body>
|
53 |
+
<div class="card" style="width: 100%">
|
54 |
+
|
55 |
+
|
56 |
+
<div id="mynetwork" class="card-body"></div>
|
57 |
+
</div>
|
58 |
+
|
59 |
+
|
60 |
+
|
61 |
+
|
62 |
+
<script type="text/javascript">
|
63 |
+
|
64 |
+
// initialize global variables.
|
65 |
+
var edges;
|
66 |
+
var nodes;
|
67 |
+
var allNodes;
|
68 |
+
var allEdges;
|
69 |
+
var nodeColors;
|
70 |
+
var originalNodes;
|
71 |
+
var network;
|
72 |
+
var container;
|
73 |
+
var options, data;
|
74 |
+
var filter = {
|
75 |
+
item : '',
|
76 |
+
property : '',
|
77 |
+
value : []
|
78 |
+
};
|
79 |
+
|
80 |
+
|
81 |
+
|
82 |
+
|
83 |
+
|
84 |
+
// This method is responsible for drawing the graph, returns the drawn network
|
85 |
+
function drawGraph() {
|
86 |
+
var container = document.getElementById('mynetwork');
|
87 |
+
|
88 |
+
|
89 |
+
|
90 |
+
// parsing and collecting nodes and edges from the python
|
91 |
+
nodes = new vis.DataSet([{"color": "#97c2fc", "id": "Op Art", "label": "Op Art", "shape": "dot", "title": "Op Art"}, {"color": "#97c2fc", "id": "Minimalism", "label": "Minimalism", "shape": "dot", "title": "Minimalism"}, {"color": "#97c2fc", "id": "Surrealism", "label": "Surrealism", "shape": "dot", "title": "Surrealism"}, {"color": "#97c2fc", "id": "Baroque", "label": "Baroque", "shape": "dot", "title": "Baroque"}, {"color": "#97c2fc", "id": "Lithography", "label": "Lithography", "shape": "dot", "title": "Lithography"}, {"color": "#97c2fc", "id": "Woodcut", "label": "Woodcut", "shape": "dot", "title": "Woodcut"}, {"color": "#97c2fc", "id": "etching", "label": "etching", "shape": "dot", "title": "etching"}, {"color": "#97c2fc", "id": "Intaglio", "label": "Intaglio", "shape": "dot", "title": "Intaglio"}]);
|
92 |
+
edges = new vis.DataSet([{"from": "Op Art", "title": "Op Art to Minimalism similarity 0.432", "to": "Minimalism", "value": 0.432}, {"from": "Op Art", "title": "Op Art to Surrealism similarity -0.086", "to": "Surrealism", "value": -0.086}, {"from": "Op Art", "title": "Op Art to Baroque similarity -0.047", "to": "Baroque", "value": -0.047}, {"from": "Op Art", "title": "Op Art to Lithography similarity 0.054", "to": "Lithography", "value": 0.054}, {"from": "Op Art", "title": "Op Art to Woodcut similarity 0.125", "to": "Woodcut", "value": 0.125}, {"from": "Op Art", "title": "Op Art to etching similarity 0.117", "to": "etching", "value": 0.117}, {"from": "Op Art", "title": "Op Art to Intaglio similarity 0.094", "to": "Intaglio", "value": 0.094}, {"from": "Minimalism", "title": "Minimalism to Surrealism similarity -0.042", "to": "Surrealism", "value": -0.042}, {"from": "Minimalism", "title": "Minimalism to Baroque similarity -0.052", "to": "Baroque", "value": -0.052}, {"from": "Minimalism", "title": "Minimalism to Lithography similarity 0.046", "to": "Lithography", "value": 0.046}, {"from": "Minimalism", "title": "Minimalism to Woodcut similarity 0.069", "to": "Woodcut", "value": 0.069}, {"from": "Minimalism", "title": "Minimalism to etching similarity 0.1", "to": "etching", "value": 0.1}, {"from": "Minimalism", "title": "Minimalism to Intaglio similarity 0.03", "to": "Intaglio", "value": 0.03}, {"from": "Surrealism", "title": "Surrealism to Baroque similarity 0.067", "to": "Baroque", "value": 0.067}, {"from": "Surrealism", "title": "Surrealism to Lithography similarity -0.235", "to": "Lithography", "value": -0.235}, {"from": "Surrealism", "title": "Surrealism to Woodcut similarity -0.16", "to": "Woodcut", "value": -0.16}, {"from": "Surrealism", "title": "Surrealism to etching similarity -0.171", "to": "etching", "value": -0.171}, {"from": "Surrealism", "title": "Surrealism to Intaglio similarity -0.076", "to": "Intaglio", "value": -0.076}, {"from": "Baroque", "title": "Baroque to Lithography similarity -0.125", "to": "Lithography", "value": -0.125}, {"from": "Baroque", "title": "Baroque to Woodcut similarity -0.022", "to": "Woodcut", "value": -0.022}, {"from": "Baroque", "title": "Baroque to etching similarity -0.102", "to": "etching", "value": -0.102}, {"from": "Baroque", "title": "Baroque to Intaglio similarity -0.046", "to": "Intaglio", "value": -0.046}, {"from": "Lithography", "title": "Lithography to Woodcut similarity 0.258", "to": "Woodcut", "value": 0.258}, {"from": "Lithography", "title": "Lithography to etching similarity 0.268", "to": "etching", "value": 0.268}, {"from": "Lithography", "title": "Lithography to Intaglio similarity 0.123", "to": "Intaglio", "value": 0.123}, {"from": "Woodcut", "title": "Woodcut to etching similarity 0.21", "to": "etching", "value": 0.21}, {"from": "Woodcut", "title": "Woodcut to Intaglio similarity 0.209", "to": "Intaglio", "value": 0.209}, {"from": "etching", "title": "etching to Intaglio similarity 0.178", "to": "Intaglio", "value": 0.178}]);
|
93 |
+
|
94 |
+
nodeColors = {};
|
95 |
+
allNodes = nodes.get({ returnType: "Object" });
|
96 |
+
for (nodeId in allNodes) {
|
97 |
+
nodeColors[nodeId] = allNodes[nodeId].color;
|
98 |
+
}
|
99 |
+
allEdges = edges.get({ returnType: "Object" });
|
100 |
+
// adding nodes and edges to the graph
|
101 |
+
data = {nodes: nodes, edges: edges};
|
102 |
+
|
103 |
+
var options = {
|
104 |
+
"configure": {
|
105 |
+
"enabled": false
|
106 |
+
},
|
107 |
+
"edges": {
|
108 |
+
"color": {
|
109 |
+
"inherit": true
|
110 |
+
},
|
111 |
+
"smooth": {
|
112 |
+
"enabled": true,
|
113 |
+
"type": "dynamic"
|
114 |
+
}
|
115 |
+
},
|
116 |
+
"interaction": {
|
117 |
+
"dragNodes": true,
|
118 |
+
"hideEdgesOnDrag": false,
|
119 |
+
"hideNodesOnDrag": false
|
120 |
+
},
|
121 |
+
"physics": {
|
122 |
+
"enabled": true,
|
123 |
+
"stabilization": {
|
124 |
+
"enabled": true,
|
125 |
+
"fit": true,
|
126 |
+
"iterations": 1000,
|
127 |
+
"onlyDynamicEdges": false,
|
128 |
+
"updateInterval": 50
|
129 |
+
}
|
130 |
+
}
|
131 |
+
};
|
132 |
+
|
133 |
+
|
134 |
+
|
135 |
+
|
136 |
+
|
137 |
+
|
138 |
+
network = new vis.Network(container, data, options);
|
139 |
+
|
140 |
+
|
141 |
+
|
142 |
+
|
143 |
+
|
144 |
+
|
145 |
+
|
146 |
+
|
147 |
+
|
148 |
+
|
149 |
+
return network;
|
150 |
+
|
151 |
+
}
|
152 |
+
drawGraph();
|
153 |
+
</script>
|
154 |
+
</body>
|
155 |
+
</html>
|
pages/1_Disentanglement.py
CHANGED
@@ -9,6 +9,7 @@ from matplotlib.backends.backend_agg import RendererAgg
|
|
9 |
|
10 |
from backend.disentangle_concepts import *
|
11 |
import torch_utils
|
|
|
12 |
|
13 |
_lock = RendererAgg.lock
|
14 |
|
@@ -32,11 +33,11 @@ with st.expander("See more instruction", expanded=False):
|
|
32 |
st.write(instruction_text)
|
33 |
|
34 |
|
35 |
-
annotations_file = './data/annotated_files/
|
36 |
with open(annotations_file, 'rb') as f:
|
37 |
annotations = pickle.load(f)
|
38 |
|
39 |
-
ann_df = pd.read_csv('./data/annotated_files/sim_seeds0000-
|
40 |
concepts = './data/concepts.txt'
|
41 |
|
42 |
with open(concepts) as f:
|
@@ -117,7 +118,7 @@ with input_col_2:
|
|
117 |
random_id = st.form_submit_button('Generate a random image')
|
118 |
|
119 |
if random_id:
|
120 |
-
image_id = random.randint(0,
|
121 |
st.session_state.image_id = image_id
|
122 |
chosen_image_id_input.number_input('Image ID:', format='%d', step=1, value=st.session_state.image_id)
|
123 |
|
@@ -135,7 +136,10 @@ with input_col_3:
|
|
135 |
|
136 |
# ---------------------------- DISPLAY COL 2 ROW 1 ------------------------------
|
137 |
|
138 |
-
model = torch.load('./data/model_files/pytorch_model.bin', map_location=torch.device('cpu'))
|
|
|
|
|
|
|
139 |
original_image_vec = annotations['z_vectors'][st.session_state.image_id]
|
140 |
img = generate_original_image(original_image_vec, model)
|
141 |
# input_image = original_image_dict['image']
|
|
|
9 |
|
10 |
from backend.disentangle_concepts import *
|
11 |
import torch_utils
|
12 |
+
import dnnlib
|
13 |
|
14 |
_lock = RendererAgg.lock
|
15 |
|
|
|
33 |
st.write(instruction_text)
|
34 |
|
35 |
|
36 |
+
annotations_file = './data/annotated_files/seeds0000-100000.pkl'
|
37 |
with open(annotations_file, 'rb') as f:
|
38 |
annotations = pickle.load(f)
|
39 |
|
40 |
+
ann_df = pd.read_csv('./data/annotated_files/sim_seeds0000-100000.csv')
|
41 |
concepts = './data/concepts.txt'
|
42 |
|
43 |
with open(concepts) as f:
|
|
|
118 |
random_id = st.form_submit_button('Generate a random image')
|
119 |
|
120 |
if random_id:
|
121 |
+
image_id = random.randint(0, 100000)
|
122 |
st.session_state.image_id = image_id
|
123 |
chosen_image_id_input.number_input('Image ID:', format='%d', step=1, value=st.session_state.image_id)
|
124 |
|
|
|
136 |
|
137 |
# ---------------------------- DISPLAY COL 2 ROW 1 ------------------------------
|
138 |
|
139 |
+
#model = torch.load('./data/model_files/pytorch_model.bin', map_location=torch.device('cpu'))
|
140 |
+
with dnnlib.util.open_url('./data/model_files/network-snapshot-010600.pkl') as f:
|
141 |
+
model = legacy.load_network_pkl(f)['G_ema'].to('cpu') # type: ignore
|
142 |
+
|
143 |
original_image_vec = annotations['z_vectors'][st.session_state.image_id]
|
144 |
img = generate_original_image(original_image_vec, model)
|
145 |
# input_image = original_image_dict['image']
|
view_predictions.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|