lauracabayol committed on
Commit
af2bb4b
·
1 Parent(s): fc92339

latest version

Browse files
Files changed (1) hide show
  1. insight/insight.py +132 -22
insight/insight.py CHANGED
@@ -9,20 +9,29 @@ import os
9
  from astropy.table import Table
10
  from scipy.spatial import KDTree
11
  from scipy.special import erf
 
12
 
13
  class Insight_module():
14
  """ Define class"""
15
 
16
- def __init__(self, model, batch_size):
17
  self.model=model
18
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
  self.batch_size=batch_size
 
20
 
21
- def _get_dataloaders(self, input_data, target_data, val_fraction=0.1):
 
 
22
  input_data = torch.Tensor(input_data)
23
  target_data = torch.Tensor(target_data)
24
-
25
- dataset = TensorDataset(input_data, target_data)
 
 
 
 
 
26
 
27
  trainig_dataset, val_dataset = torch.utils.data.random_split(dataset, [int(len(dataset)*(1-val_fraction)), int(len(dataset)*val_fraction)+1])
28
  loader_train = DataLoader(trainig_dataset, batch_size=self.batch_size, shuffle = True)
@@ -34,17 +43,18 @@ class Insight_module():
34
 
35
 
36
  def _loss_function(self,mean, std, logmix, true):
37
-
38
- logerf = torch.log(erf(true.cpu()[:,None]/(np.sqrt(2)*std.detach().cpu())+1))
39
-
40
- log_prob = logmix - 0.5*(mean - true[:,None]).pow(2) / std.pow(2) - torch.log(std) #- logerf.to(self.device)
41
  log_prob = torch.logsumexp(log_prob, 1)
42
  loss = -log_prob.mean()
 
43
 
44
  return loss
45
 
46
  def _to_numpy(self,x):
47
  return x.detach().cpu().numpy()
 
 
48
 
49
  def train(self,input_data, target_data, nepochs=10, step_size = 100, val_fraction=0.1, lr=1e-3 ):
50
  self.model = self.model.train()
@@ -74,7 +84,6 @@ class Insight_module():
74
  sig = torch.exp(logsig)
75
 
76
 
77
- #print(mu,sig,target_data,torch.exp(logmix_coeff))
78
 
79
  loss = self._loss_function(mu, sig, logmix_coeff, target_data)
80
  _loss_train.append(loss.item())
@@ -102,32 +111,117 @@ class Insight_module():
102
 
103
  self.loss_validation.append(np.mean(_loss_validation))
104
 
105
- #print(f'training_loss:{loss}',f'testing_loss:{loss_val}')
106
 
107
-
108
- def get_photoz(self,input_data, target_data):
 
 
109
  self.model = self.model.eval()
110
  self.model = self.model.to(self.device)
111
 
112
  input_data = input_data.to(self.device)
113
  target_data = target_data.to(self.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  for ii in range(len(input_data)):
 
 
116
 
117
- mu, logsig, logmix_coeff = self.model(input_data)
118
- logsig = torch.clamp(logsig,-6,2)
119
- sig = torch.exp(logsig)
120
 
121
- mix_coeff = torch.exp(logmix_coeff)
122
 
123
- z = (mix_coeff * mu).sum(1)
124
- zerr = torch.sqrt( (mix_coeff * sig**2).sum(1) + (mix_coeff * (mu - target_data[:,None])**2).sum(1))
125
 
126
-
127
- return self._to_numpy(z),self._to_numpy(zerr)
128
 
129
 
130
- #return model
 
 
131
 
132
  def plot_photoz(self, df, nbins,xvariable,metric, type_bin='bin'):
133
  bin_edges = stats.mstats.mquantiles(df[xvariable].values, np.linspace(0.1,1,nbins))
@@ -170,5 +264,21 @@ class Insight_module():
170
  plt.show()
171
 
172
 
173
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
 
 
9
  from astropy.table import Table
10
  from scipy.spatial import KDTree
11
  from scipy.special import erf
12
+ from scipy.stats import norm
13
 
14
  class Insight_module():
15
  """ Define class"""
16
 
17
+ def __init__(self, model, batch_size=100,rejection_param=1):
18
  self.model=model
19
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
20
  self.batch_size=batch_size
21
+ self.rejection_parameter=rejection_param
22
 
23
+
24
+
25
+ def _get_dataloaders(self, input_data, target_data, additional_data=None, val_fraction=0.1):
26
  input_data = torch.Tensor(input_data)
27
  target_data = torch.Tensor(target_data)
28
+
29
+ if additional_data is None:
30
+ dataset = TensorDataset(input_data, target_data)
31
+ else:
32
+ additional_data = torch.Tensor(additional_data)
33
+ dataset = TensorDataset(input_data, target_data,additional_data)
34
+
35
 
36
  trainig_dataset, val_dataset = torch.utils.data.random_split(dataset, [int(len(dataset)*(1-val_fraction)), int(len(dataset)*val_fraction)+1])
37
  loader_train = DataLoader(trainig_dataset, batch_size=self.batch_size, shuffle = True)
 
43
 
44
 
45
  def _loss_function(self,mean, std, logmix, true):
46
+
47
+ log_prob = logmix - 0.5*(mean - true[:,None]).pow(2) / std.pow(2) - torch.log(std)
 
 
48
  log_prob = torch.logsumexp(log_prob, 1)
49
  loss = -log_prob.mean()
50
+
51
 
52
  return loss
53
 
54
  def _to_numpy(self,x):
55
  return x.detach().cpu().numpy()
56
+
57
+
58
 
59
  def train(self,input_data, target_data, nepochs=10, step_size = 100, val_fraction=0.1, lr=1e-3 ):
60
  self.model = self.model.train()
 
84
  sig = torch.exp(logsig)
85
 
86
 
 
87
 
88
  loss = self._loss_function(mu, sig, logmix_coeff, target_data)
89
  _loss_train.append(loss.item())
 
111
 
112
  self.loss_validation.append(np.mean(_loss_validation))
113
 
114
+ print(f'training_loss:{loss}',f'testing_loss:{loss_val}')
115
 
116
+
117
+
118
+
119
+ def get_pz(self,input_data, target_data, return_pz=False):
120
  self.model = self.model.eval()
121
  self.model = self.model.to(self.device)
122
 
123
  input_data = input_data.to(self.device)
124
  target_data = target_data.to(self.device)
125
+
126
+
127
+
128
+ mu, logsig, logmix_coeff = self.model(input_data)
129
+ logsig = torch.clamp(logsig,-6,2)
130
+ sig = torch.exp(logsig)
131
+
132
+ mix_coeff = torch.exp(logmix_coeff)
133
+
134
+ z = (mix_coeff * mu).sum(1)
135
+ zerr = torch.sqrt( (mix_coeff * sig**2).sum(1) + (mix_coeff * (mu - target_data[:,None])**2).sum(1))
136
+
137
+ mu, mix_coeff, sig = mu.detach().cpu().numpy(), mix_coeff.detach().cpu().numpy(), sig.detach().cpu().numpy()
138
+
139
+
140
+ if return_pz==True:
141
+ x = np.linspace(0, 4, 1000)
142
+ pdf_mixture = np.zeros(shape=(len(target_data), len(x)))
143
+ for ii in range(len(input_data)):
144
+ for i in range(6):
145
+ pdf_mixture[ii] += mix_coeff[ii,i] * norm.pdf(x, mu[ii,i], sig[ii,i])
146
+
147
+ return self._to_numpy(z),self._to_numpy(zerr), pdf_mixture
148
+
149
+ else:
150
+ return self._to_numpy(z),self._to_numpy(zerr)
151
+
152
+ def pit(self, input_data, target_data):
153
+
154
+ pit_list = []
155
+
156
+ self.model = self.model.eval()
157
+ self.model = self.model.to(self.device)
158
+
159
+ input_data = input_data.to(self.device)
160
+
161
+
162
+ mu, logsig, logmix_coeff = self.model(input_data)
163
+ logsig = torch.clamp(logsig,-6,2)
164
+ sig = torch.exp(logsig)
165
 
166
+ mix_coeff = torch.exp(logmix_coeff)
167
+
168
+ mu, mix_coeff, sig = mu.detach().cpu().numpy(), mix_coeff.detach().cpu().numpy(), sig.detach().cpu().numpy()
169
+
170
+ for ii in range(len(input_data)):
171
+ pit = (mix_coeff[ii] * norm.cdf(target_data[ii]*np.ones(mu[ii].shape),mu[ii], sig[ii])).sum()
172
+ pit_list.append(pit)
173
+
174
+
175
+ return pit_list
176
+
177
+ def crps(self, input_data, target_data):
178
+
179
+ def measure_crps(cdf, t):
180
+ zgrid = np.linspace(0,4,1000)
181
+ Deltaz = zgrid[None,:] - t[:,None]
182
+ DeltaZ_heaviside = np.where(Deltaz < 0,0,1)
183
+ integral = (cdf-DeltaZ_heaviside)**2
184
+ crps_value = integral.sum(1) / 1000
185
+
186
+ return crps_value
187
+
188
+
189
+ crps_list = []
190
+
191
+ self.model = self.model.eval()
192
+ self.model = self.model.to(self.device)
193
+
194
+ input_data = input_data.to(self.device)
195
+
196
+
197
+ mu, logsig, logmix_coeff = self.model(input_data)
198
+ logsig = torch.clamp(logsig,-6,2)
199
+ sig = torch.exp(logsig)
200
+
201
+ mix_coeff = torch.exp(logmix_coeff)
202
+
203
+
204
+ mu, mix_coeff, sig = mu.detach().cpu().numpy(), mix_coeff.detach().cpu().numpy(), sig.detach().cpu().numpy()
205
+
206
+ z = (mix_coeff * mu).sum(1)
207
+
208
+ x = np.linspace(0, 4, 1000)
209
+ pdf_mixture = np.zeros(shape=(len(target_data), len(x)))
210
  for ii in range(len(input_data)):
211
+ for i in range(6):
212
+ pdf_mixture[ii] += mix_coeff[ii,i] * norm.pdf(x, mu[ii,i], sig[ii,i])
213
 
214
+ pdf_mixture = pdf_mixture / pdf_mixture.sum(1)[:,None]
 
 
215
 
 
216
 
217
+ cdf_mixture = np.cumsum(pdf_mixture,1)
 
218
 
219
+ crps_value = measure_crps(cdf_mixture, target_data)
 
220
 
221
 
222
+
223
+ return crps_value
224
+
225
 
226
  def plot_photoz(self, df, nbins,xvariable,metric, type_bin='bin'):
227
  bin_edges = stats.mstats.mquantiles(df[xvariable].values, np.linspace(0.1,1,nbins))
 
264
  plt.show()
265
 
266
 
267
def plot_pz(self, m, pz, specz):
    """Plot the predicted PDF of object ``m`` with its reference value marked.

    Args:
        m: Index of the object to plot.
        pz: Per-object PDFs; ``pz[m]`` is drawn against a grid from 0 to 4.
        specz: Per-object reference values; ``specz[m]`` is drawn as a
            dashed vertical line.
    """
    fig, ax = plt.subplots(figsize=(8, 6))

    # Size the grid from the PDF row itself so it always matches, rather
    # than assuming exactly 1000 samples.
    zgrid = np.linspace(0, 4, len(pz[m]))
    ax.plot(zgrid, pz[m], label='PDF', color='navy')

    # Mark the reference value.
    ax.axvline(specz[m], color='black', linestyle='--', label=r'$z_{\rm s}$')

    ax.set_xlabel(r'$z$', fontsize = 18)
    ax.set_ylabel('Probability Density', fontsize=16)
    ax.legend(fontsize = 18)

    plt.show()
284