Spaces:
Runtime error
Runtime error
Commit
·
af2bb4b
1
Parent(s):
fc92339
latest version
Browse files- insight/insight.py +132 -22
insight/insight.py
CHANGED
@@ -9,20 +9,29 @@ import os
|
|
9 |
from astropy.table import Table
|
10 |
from scipy.spatial import KDTree
|
11 |
from scipy.special import erf
|
|
|
12 |
|
13 |
class Insight_module():
|
14 |
""" Define class"""
|
15 |
|
16 |
-
def __init__(self, model, batch_size):
|
17 |
self.model=model
|
18 |
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
19 |
self.batch_size=batch_size
|
|
|
20 |
|
21 |
-
|
|
|
|
|
22 |
input_data = torch.Tensor(input_data)
|
23 |
target_data = torch.Tensor(target_data)
|
24 |
-
|
25 |
-
|
|
|
|
|
|
|
|
|
|
|
26 |
|
27 |
trainig_dataset, val_dataset = torch.utils.data.random_split(dataset, [int(len(dataset)*(1-val_fraction)), int(len(dataset)*val_fraction)+1])
|
28 |
loader_train = DataLoader(trainig_dataset, batch_size=self.batch_size, shuffle = True)
|
@@ -34,17 +43,18 @@ class Insight_module():
|
|
34 |
|
35 |
|
36 |
def _loss_function(self,mean, std, logmix, true):
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
log_prob = logmix - 0.5*(mean - true[:,None]).pow(2) / std.pow(2) - torch.log(std) #- logerf.to(self.device)
|
41 |
log_prob = torch.logsumexp(log_prob, 1)
|
42 |
loss = -log_prob.mean()
|
|
|
43 |
|
44 |
return loss
|
45 |
|
46 |
def _to_numpy(self,x):
|
47 |
return x.detach().cpu().numpy()
|
|
|
|
|
48 |
|
49 |
def train(self,input_data, target_data, nepochs=10, step_size = 100, val_fraction=0.1, lr=1e-3 ):
|
50 |
self.model = self.model.train()
|
@@ -74,7 +84,6 @@ class Insight_module():
|
|
74 |
sig = torch.exp(logsig)
|
75 |
|
76 |
|
77 |
-
#print(mu,sig,target_data,torch.exp(logmix_coeff))
|
78 |
|
79 |
loss = self._loss_function(mu, sig, logmix_coeff, target_data)
|
80 |
_loss_train.append(loss.item())
|
@@ -102,32 +111,117 @@ class Insight_module():
|
|
102 |
|
103 |
self.loss_validation.append(np.mean(_loss_validation))
|
104 |
|
105 |
-
|
106 |
|
107 |
-
|
108 |
-
|
|
|
|
|
109 |
self.model = self.model.eval()
|
110 |
self.model = self.model.to(self.device)
|
111 |
|
112 |
input_data = input_data.to(self.device)
|
113 |
target_data = target_data.to(self.device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
114 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
for ii in range(len(input_data)):
|
|
|
|
|
116 |
|
117 |
-
|
118 |
-
logsig = torch.clamp(logsig,-6,2)
|
119 |
-
sig = torch.exp(logsig)
|
120 |
|
121 |
-
mix_coeff = torch.exp(logmix_coeff)
|
122 |
|
123 |
-
|
124 |
-
zerr = torch.sqrt( (mix_coeff * sig**2).sum(1) + (mix_coeff * (mu - target_data[:,None])**2).sum(1))
|
125 |
|
126 |
-
|
127 |
-
return self._to_numpy(z),self._to_numpy(zerr)
|
128 |
|
129 |
|
130 |
-
|
|
|
|
|
131 |
|
132 |
def plot_photoz(self, df, nbins,xvariable,metric, type_bin='bin'):
|
133 |
bin_edges = stats.mstats.mquantiles(df[xvariable].values, np.linspace(0.1,1,nbins))
|
@@ -170,5 +264,21 @@ class Insight_module():
|
|
170 |
plt.show()
|
171 |
|
172 |
|
173 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
174 |
|
|
|
9 |
from astropy.table import Table
|
10 |
from scipy.spatial import KDTree
|
11 |
from scipy.special import erf
|
12 |
+
from scipy.stats import norm
|
13 |
|
14 |
class Insight_module():
|
15 |
""" Define class"""
|
16 |
|
17 |
+
def __init__(self, model, batch_size=100, rejection_param=1):
    """Store the wrapped model and the run-time settings.

    Args:
        model: Network object kept on ``self.model``.
        batch_size: Mini-batch size kept on ``self.batch_size``.
        rejection_param: Value kept on ``self.rejection_parameter``.
    """
    self.model = model

    # Prefer the GPU when torch can see one, otherwise stay on the CPU.
    if torch.cuda.is_available():
        device_name = 'cuda'
    else:
        device_name = 'cpu'
    self.device = torch.device(device_name)

    self.batch_size = batch_size
    self.rejection_parameter = rejection_param
|
22 |
|
23 |
+
|
24 |
+
|
25 |
+
def _get_dataloaders(self, input_data, target_data, additional_data=None, val_fraction=0.1):
|
26 |
input_data = torch.Tensor(input_data)
|
27 |
target_data = torch.Tensor(target_data)
|
28 |
+
|
29 |
+
if additional_data is None:
|
30 |
+
dataset = TensorDataset(input_data, target_data)
|
31 |
+
else:
|
32 |
+
additional_data = torch.Tensor(additional_data)
|
33 |
+
dataset = TensorDataset(input_data, target_data,additional_data)
|
34 |
+
|
35 |
|
36 |
trainig_dataset, val_dataset = torch.utils.data.random_split(dataset, [int(len(dataset)*(1-val_fraction)), int(len(dataset)*val_fraction)+1])
|
37 |
loader_train = DataLoader(trainig_dataset, batch_size=self.batch_size, shuffle = True)
|
|
|
43 |
|
44 |
|
45 |
def _loss_function(self, mean, std, logmix, true):
    """Negative mean log-likelihood of ``true`` under a Gaussian mixture.

    Args:
        mean: Component means, indexed ``[sample, component]``.
        std: Component standard deviations, same layout as ``mean``.
        logmix: Log mixture weights, same layout as ``mean``.
        true: One target value per sample; broadcast against the components.

    Returns:
        Scalar tensor: minus the batch mean of the per-sample log-sum-exp
        of the component log-probabilities (the constant Gaussian
        normalisation term is not included).
    """
    residual = mean - true[:, None]
    quadratic = 0.5 * residual.pow(2) / std.pow(2)
    component_log_prob = logmix - quadratic - torch.log(std)

    # Combine the components in log space, one value per sample.
    sample_log_prob = torch.logsumexp(component_log_prob, 1)
    return -sample_log_prob.mean()
|
53 |
|
54 |
def _to_numpy(self, x):
    """Return ``x`` as a NumPy array, detached from autograd and on the CPU."""
    host_tensor = x.detach().cpu()
    return host_tensor.numpy()
|
56 |
+
|
57 |
+
|
58 |
|
59 |
def train(self,input_data, target_data, nepochs=10, step_size = 100, val_fraction=0.1, lr=1e-3 ):
|
60 |
self.model = self.model.train()
|
|
|
84 |
sig = torch.exp(logsig)
|
85 |
|
86 |
|
|
|
87 |
|
88 |
loss = self._loss_function(mu, sig, logmix_coeff, target_data)
|
89 |
_loss_train.append(loss.item())
|
|
|
111 |
|
112 |
self.loss_validation.append(np.mean(_loss_validation))
|
113 |
|
114 |
+
print(f'training_loss:{loss}',f'testing_loss:{loss_val}')
|
115 |
|
116 |
+
|
117 |
+
|
118 |
+
|
119 |
+
def get_pz(self, input_data, target_data, return_pz=False):
    """Predict a point estimate ``z``, its scatter and optionally the full p(z).

    The model is called once on ``input_data`` and must return
    ``(mu, logsig, logmix_coeff)``, each indexed ``[object, component]``.
    ``logsig`` is clamped to ``[-6, 2]`` before exponentiation.

    Args:
        input_data: Tensor of network inputs; moved to ``self.device``.
        target_data: Tensor with one reference value per object; it enters
            the second term of ``zerr``.
        return_pz: When truthy, also return the mixture density evaluated
            on ``np.linspace(0, 4, 1000)``.

    Returns:
        ``(z, zerr)`` as NumPy arrays, plus an ``(n_objects, 1000)`` density
        array when ``return_pz`` is truthy.
    """
    self.model = self.model.eval()
    self.model = self.model.to(self.device)

    input_data = input_data.to(self.device)
    target_data = target_data.to(self.device)

    # Inference only: no autograd graph is needed for these outputs.
    with torch.no_grad():
        mu, logsig, logmix_coeff = self.model(input_data)
        sig = torch.exp(torch.clamp(logsig, -6, 2))
        mix_coeff = torch.exp(logmix_coeff)

        z = (mix_coeff * mu).sum(1)
        # NOTE(review): the spread term is measured around target_data, not
        # around the predicted mean z, so zerr depends on the reference
        # value. Kept as-is for callers -- confirm this is intended.
        zerr = torch.sqrt(
            (mix_coeff * sig**2).sum(1)
            + (mix_coeff * (mu - target_data[:, None])**2).sum(1)
        )

    z, zerr = self._to_numpy(z), self._to_numpy(zerr)
    if not return_pz:
        return z, zerr

    mu, mix_coeff, sig = (self._to_numpy(t) for t in (mu, mix_coeff, sig))

    grid = np.linspace(0, 4, 1000)
    pdf_mixture = np.zeros(shape=(mu.shape[0], len(grid)))
    # One vectorised pass per component; the component count comes from the
    # model output instead of a fixed 6. Looping over components (rather
    # than broadcasting all of them at once) keeps temporaries at
    # (n_objects, n_grid).
    for k in range(mu.shape[1]):
        pdf_mixture += mix_coeff[:, k, None] * norm.pdf(
            grid[None, :], mu[:, k, None], sig[:, k, None]
        )

    return z, zerr, pdf_mixture
|
151 |
+
|
152 |
+
def pit(self, input_data, target_data):
    """Evaluate the predicted mixture CDF at each target value (PIT).

    The model is called once on ``input_data`` and must return
    ``(mu, logsig, logmix_coeff)``, each indexed ``[object, component]``.
    ``logsig`` is clamped to ``[-6, 2]`` before exponentiation.

    Args:
        input_data: Tensor of network inputs; moved to ``self.device``.
        target_data: One reference value per object (tensor, array or
            sequence). Entries beyond the number of inputs are ignored.

    Returns:
        List with one value per input object: the weight-summed normal CDF
        of the mixture components at that object's target.

    Raises:
        IndexError: If fewer targets than inputs are supplied.
    """
    self.model = self.model.eval()
    self.model = self.model.to(self.device)

    input_data = input_data.to(self.device)

    # Inference only: no autograd graph is needed for these outputs.
    with torch.no_grad():
        mu, logsig, logmix_coeff = self.model(input_data)
        sig = torch.exp(torch.clamp(logsig, -6, 2))
        mix_coeff = torch.exp(logmix_coeff)

    mu, mix_coeff, sig = (t.detach().cpu().numpy() for t in (mu, mix_coeff, sig))

    # Normalise the targets to a flat float array so tensors, arrays and
    # plain sequences all broadcast the same way.
    if isinstance(target_data, torch.Tensor):
        target_data = target_data.detach().cpu().numpy()
    truth = np.asarray(target_data, dtype=float).reshape(-1)
    if truth.shape[0] < mu.shape[0]:
        raise IndexError("target_data has fewer entries than input_data")
    truth = truth[:mu.shape[0]]

    # All objects and components at once instead of one scipy call per object.
    component_cdf = norm.cdf(truth[:, None], mu, sig)
    pit_values = (mix_coeff * component_cdf).sum(1)

    return list(pit_values)
|
176 |
+
|
177 |
+
def crps(self, input_data, target_data):
    """Score each predicted mixture against its target with a CRPS-style sum.

    The mixture density is evaluated on ``np.linspace(0, 4, 1000)``,
    normalised to unit sum on that grid and accumulated into a CDF. The
    score is the squared difference between that CDF and a unit step at
    the target, summed over the grid and divided by the number of grid
    points.

    Args:
        input_data: Tensor of network inputs; moved to ``self.device``.
        target_data: One reference value per object (tensor, array or
            sequence).

    Returns:
        NumPy array with one score per object.
    """
    grid = np.linspace(0, 4, 1000)

    self.model = self.model.eval()
    self.model = self.model.to(self.device)

    input_data = input_data.to(self.device)

    # Inference only: no autograd graph is needed for these outputs.
    with torch.no_grad():
        mu, logsig, logmix_coeff = self.model(input_data)
        sig = torch.exp(torch.clamp(logsig, -6, 2))
        mix_coeff = torch.exp(logmix_coeff)

    mu, mix_coeff, sig = (t.detach().cpu().numpy() for t in (mu, mix_coeff, sig))

    # Normalise the targets to a flat float array so tensors and column
    # vectors broadcast against the grid the same way as 1-D arrays.
    if isinstance(target_data, torch.Tensor):
        target_data = target_data.detach().cpu().numpy()
    truth = np.asarray(target_data, dtype=float).reshape(-1)

    pdf_mixture = np.zeros(shape=(mu.shape[0], grid.size))
    # One vectorised pass per component; the component count comes from the
    # model output instead of a fixed 6.
    for k in range(mu.shape[1]):
        pdf_mixture += mix_coeff[:, k, None] * norm.pdf(
            grid[None, :], mu[:, k, None], sig[:, k, None]
        )

    pdf_mixture = pdf_mixture / pdf_mixture.sum(1)[:, None]
    cdf_mixture = np.cumsum(pdf_mixture, 1)

    # Unit step at the target: 0 below it, 1 at or above it.
    step = np.where(grid[None, :] - truth[:, None] < 0, 0, 1)

    # NOTE(review): dividing by the number of grid points gives a grid
    # average, not a dz-weighted integral (they differ by roughly the grid
    # span). Kept for backward compatibility -- confirm intended scaling.
    crps_value = ((cdf_mixture - step) ** 2).sum(1) / grid.size

    return crps_value
|
224 |
+
|
225 |
|
226 |
def plot_photoz(self, df, nbins,xvariable,metric, type_bin='bin'):
|
227 |
bin_edges = stats.mstats.mquantiles(df[xvariable].values, np.linspace(0.1,1,nbins))
|
|
|
264 |
plt.show()
|
265 |
|
266 |
|
267 |
+
def plot_pz(self, m, pz, specz):
    """Show one p(z) curve with its reference value marked.

    Args:
        m: Index selecting the entry of ``pz`` and of ``specz`` to draw.
        pz: Indexable of curves; ``pz[m]`` is plotted against
            ``np.linspace(0, 4, 1000)``.
        specz: Indexable of values; ``specz[m]`` is drawn as a dashed
            vertical line.
    """
    _fig, ax = plt.subplots(figsize=(8, 6))

    # Density curve on the fixed 0-4 grid.
    z_grid = np.linspace(0, 4, 1000)
    density = pz[m]
    ax.plot(z_grid, density, label='PDF', color='navy')

    # Dashed marker at the reference value for this object.
    reference = specz[m]
    ax.axvline(reference, color='black', linestyle='--', label=r'$z_{\rm s}$')

    # Axis titles and legend.
    ax.set_xlabel(r'$z$', fontsize=18)
    ax.set_ylabel('Probability Density', fontsize=16)
    ax.legend(fontsize=18)

    plt.show()
|
284 |
|