lauracabayol commited on
Commit
696a020
·
1 Parent(s): c212435

optimized version working at low z

Browse files
insight/.ipynb_checkpoints/archive-checkpoint.py CHANGED
@@ -13,7 +13,7 @@ rcParams["font.family"] = "STIXGeneral"
13
 
14
 
15
  class archive():
16
- def __init__(self, path, aperture=2, drop_stars=True, clean_photometry=True, convert_colors=True, extinction_corr=True, only_zspec=True, reliable_zspec=True):
17
 
18
  self.aperture = aperture
19
 
@@ -39,30 +39,28 @@ class archive():
39
  hdu_list = fits.open(os.path.join(path,filename_valid))
40
  cat_test = Table(hdu_list[1].data).to_pandas()
41
 
 
 
 
42
  gold_sample = pd.read_csv(os.path.join(path,filename_gold))
43
 
44
  #cat_test = self._match_gold_sample(cat_test,gold_sample)
45
 
46
  if drop_stars==True:
47
  cat = cat[cat.mu_class_L07==1]
 
48
 
49
  if clean_photometry==True:
50
  cat = self._clean_photometry(cat)
51
  cat_test = self._clean_photometry(cat_test)
52
 
53
- self._get_loss_weights(cat)
54
 
55
  cat = cat[cat.w_Q_f_S15>0]
56
-
57
- self._set_training_data(cat, only_zspec=only_zspec, reliable_zspec=reliable_zspec, extinction_corr=extinction_corr, convert_colors=convert_colors)
58
-
59
-
60
- self._set_testing_data(cat_test, only_zspec=only_zspec, reliable_zspec='Total', extinction_corr=extinction_corr, convert_colors=convert_colors)
61
 
62
  self._get_loss_weights(cat)
63
-
64
- #self.cat_test=cat_test
65
- #self.cat_train=cat
66
 
67
  def _extract_fluxes(self,catalogue):
68
  columns_f = [f'FLUX_{x}_{self.aperture}' for x in ['G','R','I','Z','Y','J','H']]
@@ -100,13 +98,9 @@ class archive():
100
  catalogue = catalogue[catalogue.z_spec_S15>0]
101
  return catalogue
102
 
103
- def _clean_zspec_sample(self,catalogue ,kind=None):
104
- if kind==None:
105
- return catalogue
106
- elif kind=='Total':
107
- return catalogue[catalogue['reliable_S15']>0]
108
- elif kind=='Partial':
109
- return catalogue[(catalogue['w_Q_f_S15']>0.5)]
110
 
111
  def _map_weight(self,Qz):
112
  for key, value in self.weight_dict.items():
@@ -134,11 +128,11 @@ class archive():
134
  return catalogue_valid
135
 
136
 
137
- def _set_training_data(self,catalogue, only_zspec=True, reliable_zspec=True, extinction_corr=True, convert_colors=True):
138
 
139
  if only_zspec:
140
  catalogue = self._take_only_zspec(catalogue, cat_flag='Calib')
141
- catalogue = self._clean_zspec_sample(catalogue, kind=reliable_zspec)
142
 
143
  self.cat_train=catalogue
144
  f, ferr = self._extract_fluxes(catalogue)
@@ -159,11 +153,11 @@ class archive():
159
  self.target_z_train = catalogue['z_spec_S15'].values
160
  self.target_qz_train = catalogue['w_Q_f_S15'].values
161
 
162
- def _set_testing_data(self,catalogue, only_zspec=True, reliable_zspec=True, extinction_corr=True, convert_colors=True):
163
 
164
  if only_zspec:
165
  catalogue = self._take_only_zspec(catalogue, cat_flag='Valid')
166
- catalogue = self._clean_zspec_sample(catalogue, kind=reliable_zspec)
167
 
168
  self.cat_test=catalogue
169
 
 
13
 
14
 
15
  class archive():
16
+ def __init__(self, path, aperture=2, drop_stars=True, clean_photometry=True, convert_colors=True, extinction_corr=True, only_zspec=True, Qz_cut=1):
17
 
18
  self.aperture = aperture
19
 
 
39
  hdu_list = fits.open(os.path.join(path,filename_valid))
40
  cat_test = Table(hdu_list[1].data).to_pandas()
41
 
42
+ self._get_loss_weights(cat)
43
+ self._get_loss_weights(cat_test)
44
+
45
  gold_sample = pd.read_csv(os.path.join(path,filename_gold))
46
 
47
  #cat_test = self._match_gold_sample(cat_test,gold_sample)
48
 
49
  if drop_stars==True:
50
  cat = cat[cat.mu_class_L07==1]
51
+ cat_test = cat_test[cat_test.mu_class_L07==1]
52
 
53
  if clean_photometry==True:
54
  cat = self._clean_photometry(cat)
55
  cat_test = self._clean_photometry(cat_test)
56
 
 
57
 
58
  cat = cat[cat.w_Q_f_S15>0]
59
+
60
+ self._set_training_data(cat, only_zspec=only_zspec, extinction_corr=extinction_corr, convert_colors=convert_colors,Qz_cut=Qz_cut)
61
+ self._set_testing_data(cat_test, only_zspec=only_zspec, extinction_corr=extinction_corr, convert_colors=convert_colors)
 
 
62
 
63
  self._get_loss_weights(cat)
 
 
 
64
 
65
  def _extract_fluxes(self,catalogue):
66
  columns_f = [f'FLUX_{x}_{self.aperture}' for x in ['G','R','I','Z','Y','J','H']]
 
98
  catalogue = catalogue[catalogue.z_spec_S15>0]
99
  return catalogue
100
 
101
+ def _clean_zspec_sample(self,catalogue ,Qz_cut):
102
+ catalogue = catalogue[catalogue.w_Q_f_S15>=Qz_cut]
103
+ return catalogue
 
 
 
 
104
 
105
  def _map_weight(self,Qz):
106
  for key, value in self.weight_dict.items():
 
128
  return catalogue_valid
129
 
130
 
131
+ def _set_training_data(self,catalogue, only_zspec=True, extinction_corr=True, convert_colors=True,Qz_cut=1):
132
 
133
  if only_zspec:
134
  catalogue = self._take_only_zspec(catalogue, cat_flag='Calib')
135
+ catalogue = self._clean_zspec_sample(catalogue, Qz_cut=Qz_cut)
136
 
137
  self.cat_train=catalogue
138
  f, ferr = self._extract_fluxes(catalogue)
 
153
  self.target_z_train = catalogue['z_spec_S15'].values
154
  self.target_qz_train = catalogue['w_Q_f_S15'].values
155
 
156
+ def _set_testing_data(self,catalogue, only_zspec=True, extinction_corr=True, convert_colors=True):
157
 
158
  if only_zspec:
159
  catalogue = self._take_only_zspec(catalogue, cat_flag='Valid')
160
+ catalogue = self._clean_zspec_sample(catalogue, Qz_cut=1)
161
 
162
  self.cat_test=catalogue
163
 
insight/.ipynb_checkpoints/insight-checkpoint.py CHANGED
@@ -8,34 +8,37 @@ from astropy.io import fits
8
  import os
9
  from astropy.table import Table
10
  from scipy.spatial import KDTree
 
11
 
12
  class Insight_module():
13
  """ Define class"""
14
 
15
- def __init__(self, model):
16
  self.model=model
17
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
18
 
19
- def _get_dataloaders(self, input_data, target_data, target_weights, val_fraction=0.1):
20
  input_data = torch.Tensor(input_data)
21
  target_data = torch.Tensor(target_data)
22
- target_weights = torch.Tensor(target_weights)
23
 
24
- dataset = TensorDataset(input_data, target_data, target_weights)
25
 
26
  trainig_dataset, val_dataset = torch.utils.data.random_split(dataset, [int(len(dataset)*(1-val_fraction)), int(len(dataset)*val_fraction)+1])
27
- loader_train = DataLoader(trainig_dataset, batch_size=64, shuffle = True)
28
  loader_val = DataLoader(val_dataset, batch_size=64, shuffle = True)
29
 
30
  return loader_train, loader_val
31
 
 
32
 
33
- def _loss_function(self,mean, std, logmix, true, target_weights):
34
- log_prob = logmix - 0.5*(mean - true[:,None]).pow(2) / std.pow(2) - torch.log(std)
35
 
36
- log_prob = torch.logsumexp(log_prob, 1)
 
 
37
 
38
- #log_prob = log_prob * target_weights
 
39
  loss = -log_prob.mean()
40
 
41
  return loss
@@ -43,21 +46,25 @@ class Insight_module():
43
  def _to_numpy(self,x):
44
  return x.detach().cpu().numpy()
45
 
46
- def train(self,input_data, target_data, target_weights, nepochs=10, val_fraction=0.1, lr=1e-3 ):
47
  self.model = self.model.train()
48
- loader_train, loader_val = self._get_dataloaders(input_data, target_data, target_weights, val_fraction=0.1)
49
  optimizer = optim.Adam(self.model.parameters(), lr=lr, weight_decay=1e-4)
 
 
50
 
51
  self.model = self.model.to(self.device)
52
 
53
- loss_train, loss_validation = [],[]
 
 
54
 
55
  for epoch in range(nepochs):
56
- for input_data, target_data, target_weights in loader_train:
 
57
 
58
  input_data = input_data.to(self.device)
59
  target_data = target_data.to(self.device)
60
- target_weights = target_weights.to(self.device)
61
 
62
 
63
  optimizer.zero_grad()
@@ -69,32 +76,33 @@ class Insight_module():
69
 
70
  #print(mu,sig,target_data,torch.exp(logmix_coeff))
71
 
72
- loss = self._loss_function(mu, sig, logmix_coeff, target_data,target_weights)
 
73
 
74
  loss.backward()
75
- optimizer.step()
 
 
76
 
77
- loss_train.append(loss.item())
78
 
79
- for input_data, target_data, target_weights in loader_val:
80
 
81
 
82
  input_data = input_data.to(self.device)
83
  target_data = target_data.to(self.device)
84
- target_weights = target_weights.to(self.device)
85
 
86
 
87
  mu, logsig, logmix_coeff = self.model(input_data)
88
  logsig = torch.clamp(logsig,-6,2)
89
  sig = torch.exp(logsig)
90
 
91
- loss_val = self._loss_function(mu, sig, logmix_coeff, target_data, target_weights)
92
- loss_validation.append(loss_val.item())
93
 
94
- print(f'training_loss:{loss}',f'testing_loss:{loss_val}')
95
-
96
- self.loss_train=loss_train
97
- self.loss_validation=loss_validation
98
 
99
 
100
  def get_photoz(self,input_data, target_data):
 
8
  import os
9
  from astropy.table import Table
10
  from scipy.spatial import KDTree
11
+ from scipy.special import erf
12
 
13
  class Insight_module():
14
  """ Define class"""
15
 
16
+ def __init__(self, model, batch_size):
17
  self.model=model
18
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
+ self.batch_size=batch_size
20
 
21
+ def _get_dataloaders(self, input_data, target_data, val_fraction=0.1):
22
  input_data = torch.Tensor(input_data)
23
  target_data = torch.Tensor(target_data)
 
24
 
25
+ dataset = TensorDataset(input_data, target_data)
26
 
27
  trainig_dataset, val_dataset = torch.utils.data.random_split(dataset, [int(len(dataset)*(1-val_fraction)), int(len(dataset)*val_fraction)+1])
28
+ loader_train = DataLoader(trainig_dataset, batch_size=self.batch_size, shuffle = True)
29
  loader_val = DataLoader(val_dataset, batch_size=64, shuffle = True)
30
 
31
  return loader_train, loader_val
32
 
33
+
34
 
 
 
35
 
36
+ def _loss_function(self,mean, std, logmix, true):
37
+
38
+ logerf = torch.log(erf(true.cpu()[:,None]/(np.sqrt(2)*std.detach().cpu())+1))
39
 
40
+ log_prob = logmix - 0.5*(mean - true[:,None]).pow(2) / std.pow(2) - torch.log(std) #- logerf.to(self.device)
41
+ log_prob = torch.logsumexp(log_prob, 1)
42
  loss = -log_prob.mean()
43
 
44
  return loss
 
46
  def _to_numpy(self,x):
47
  return x.detach().cpu().numpy()
48
 
49
+ def train(self,input_data, target_data, nepochs=10, step_size = 100, val_fraction=0.1, lr=1e-3 ):
50
  self.model = self.model.train()
51
+ loader_train, loader_val = self._get_dataloaders(input_data, target_data, val_fraction=0.1)
52
  optimizer = optim.Adam(self.model.parameters(), lr=lr, weight_decay=1e-4)
53
+ scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma =0.1)
54
+
55
 
56
  self.model = self.model.to(self.device)
57
 
58
+ self.loss_train, self.loss_validation = [],[]
59
+
60
+
61
 
62
  for epoch in range(nepochs):
63
+ for input_data, target_data in loader_train:
64
+ _loss_train, _loss_validation = [],[]
65
 
66
  input_data = input_data.to(self.device)
67
  target_data = target_data.to(self.device)
 
68
 
69
 
70
  optimizer.zero_grad()
 
76
 
77
  #print(mu,sig,target_data,torch.exp(logmix_coeff))
78
 
79
+ loss = self._loss_function(mu, sig, logmix_coeff, target_data)
80
+ _loss_train.append(loss.item())
81
 
82
  loss.backward()
83
+ optimizer.step()
84
+
85
+ scheduler.step()
86
 
87
+ self.loss_train.append(np.mean(_loss_train))
88
 
89
+ for input_data, target_data in loader_val:
90
 
91
 
92
  input_data = input_data.to(self.device)
93
  target_data = target_data.to(self.device)
 
94
 
95
 
96
  mu, logsig, logmix_coeff = self.model(input_data)
97
  logsig = torch.clamp(logsig,-6,2)
98
  sig = torch.exp(logsig)
99
 
100
+ loss_val = self._loss_function(mu, sig, logmix_coeff, target_data)
101
+ _loss_validation.append(loss_val.item())
102
 
103
+ self.loss_validation.append(np.mean(_loss_validation))
104
+
105
+ #print(f'training_loss:{loss}',f'testing_loss:{loss_val}')
 
106
 
107
 
108
  def get_photoz(self,input_data, target_data):
insight/.ipynb_checkpoints/insight_arch-checkpoint.py CHANGED
@@ -8,64 +8,37 @@ class Photoz_network(nn.Module):
8
  nn.Linear(6, 10),
9
  nn.Dropout(dropout_prob),
10
  nn.ReLU(),
11
- nn.Linear(10, 30),
12
  nn.Dropout(dropout_prob),
13
  nn.ReLU(),
14
- nn.Linear(30, 50),
15
  nn.Dropout(dropout_prob),
16
  nn.ReLU(),
17
- nn.Linear(50, 70),
18
  nn.Dropout(dropout_prob),
19
  nn.ReLU(),
20
- nn.Linear(70, 100)
21
  )
22
 
23
  self.measure_mu = nn.Sequential(
24
- nn.Linear(100, 80),
25
  nn.Dropout(dropout_prob),
26
  nn.ReLU(),
27
- nn.Linear(80, 70),
28
- nn.Dropout(dropout_prob),
29
- nn.ReLU(),
30
- nn.Linear(70, 60),
31
- nn.Dropout(dropout_prob),
32
- nn.ReLU(),
33
- nn.Linear(60, 50),
34
- nn.Dropout(dropout_prob),
35
- nn.ReLU(),
36
- nn.Linear(50, num_gauss)
37
  )
38
 
39
  self.measure_coeffs = nn.Sequential(
40
- nn.Linear(100, 80),
41
- nn.Dropout(dropout_prob),
42
- nn.ReLU(),
43
- nn.Linear(80, 70),
44
- nn.Dropout(dropout_prob),
45
- nn.ReLU(),
46
- nn.Linear(70, 60),
47
  nn.Dropout(dropout_prob),
48
  nn.ReLU(),
49
- nn.Linear(60, 50),
50
- nn.Dropout(dropout_prob),
51
- nn.ReLU(),
52
- nn.Linear(50, num_gauss)
53
  )
54
 
55
  self.measure_sigma = nn.Sequential(
56
- nn.Linear(100, 80),
57
- nn.Dropout(dropout_prob),
58
- nn.ReLU(),
59
- nn.Linear(80, 70),
60
- nn.Dropout(dropout_prob),
61
- nn.ReLU(),
62
- nn.Linear(70, 60),
63
- nn.Dropout(dropout_prob),
64
- nn.ReLU(),
65
- nn.Linear(60, 50),
66
  nn.Dropout(dropout_prob),
67
  nn.ReLU(),
68
- nn.Linear(50, num_gauss)
69
  )
70
 
71
  def forward(self, x):
 
8
  nn.Linear(6, 10),
9
  nn.Dropout(dropout_prob),
10
  nn.ReLU(),
11
+ nn.Linear(10, 20),
12
  nn.Dropout(dropout_prob),
13
  nn.ReLU(),
14
+ nn.Linear(20, 50),
15
  nn.Dropout(dropout_prob),
16
  nn.ReLU(),
17
+ nn.Linear(50, 20),
18
  nn.Dropout(dropout_prob),
19
  nn.ReLU(),
20
+ nn.Linear(20, 10)
21
  )
22
 
23
  self.measure_mu = nn.Sequential(
24
+ nn.Linear(10, 20),
25
  nn.Dropout(dropout_prob),
26
  nn.ReLU(),
27
+ nn.Linear(20, num_gauss)
 
 
 
 
 
 
 
 
 
28
  )
29
 
30
  self.measure_coeffs = nn.Sequential(
31
+ nn.Linear(10, 20),
 
 
 
 
 
 
32
  nn.Dropout(dropout_prob),
33
  nn.ReLU(),
34
+ nn.Linear(20, num_gauss)
 
 
 
35
  )
36
 
37
  self.measure_sigma = nn.Sequential(
38
+ nn.Linear(10, 20),
 
 
 
 
 
 
 
 
 
39
  nn.Dropout(dropout_prob),
40
  nn.ReLU(),
41
+ nn.Linear(20, num_gauss)
42
  )
43
 
44
  def forward(self, x):
insight/.ipynb_checkpoints/utils-checkpoint.py CHANGED
@@ -9,9 +9,9 @@ def nmad(data):
9
  def sigma68(data): return 0.5*(pd.Series(data).quantile(q = 0.84) - pd.Series(data).quantile(q = 0.16))
10
 
11
 
12
- def plot_photoz(df, nbins,xvariable,metric, type_bin='bin'):
13
  bin_edges = stats.mstats.mquantiles(df[xvariable].values, np.linspace(0.1,1,nbins))
14
- ydata,xlab = [],[]
15
 
16
 
17
  for k in range(len(bin_edges)-1):
@@ -21,26 +21,37 @@ def plot_photoz(df, nbins,xvariable,metric, type_bin='bin'):
21
  mean_mag = (edge_max + edge_min) / 2
22
 
23
  if type_bin=='bin':
24
- df_plot = df_test[(df_test.imag > edge_min) & (df_test.imag < edge_max)]
25
  elif type_bin=='cum':
26
- df_plot = df_test[(df_test.imag < edge_max)]
27
  else:
28
  raise ValueError("Only type_bin=='bin' for binned and 'cum' for cumulative are supported")
29
 
30
 
31
- xlab.append(mean_mag)
32
  if metric=='sig68':
33
  ydata.append(sigma68(df_plot.zwerr))
 
34
  elif metric=='bias':
35
- ydata.append(np.mean(df_plot.zwerr))
 
36
  elif metric=='nmad':
37
  ydata.append(nmad(df_plot.zwerr))
 
38
  elif metric=='outliers':
39
- ydata.append(len(df_plot[np.abs(df_plot.zwerr)>0.15])/len(df_plot))
40
-
41
- plt.plot(xlab,ydata, ls = '-', marker = '.', color = 'navy',lw = 1, label = '')
42
- plt.ylabel(f'{metric}$[\Delta z]$', fontsize = 18)
43
- plt.xlabel(f'{xvariable}', fontsize = 16)
 
 
 
 
 
 
 
 
44
 
45
  plt.xticks(fontsize = 14)
46
  plt.yticks(fontsize = 14)
 
9
  def sigma68(data): return 0.5*(pd.Series(data).quantile(q = 0.84) - pd.Series(data).quantile(q = 0.16))
10
 
11
 
12
+ def plot_photoz_estimates(df, nbins,xvariable,metric, type_bin='bin'):
13
  bin_edges = stats.mstats.mquantiles(df[xvariable].values, np.linspace(0.1,1,nbins))
14
+ ydata,xdata = [],[]
15
 
16
 
17
  for k in range(len(bin_edges)-1):
 
21
  mean_mag = (edge_max + edge_min) / 2
22
 
23
  if type_bin=='bin':
24
+ df_plot = df_test[(df_test[xvariable] > edge_min) & (df_test[xvariable] < edge_max)]
25
  elif type_bin=='cum':
26
+ df_plot = df_test[(df_test[xvariable] < edge_max)]
27
  else:
28
  raise ValueError("Only type_bin=='bin' for binned and 'cum' for cumulative are supported")
29
 
30
 
31
+ xdata.append(mean_mag)
32
  if metric=='sig68':
33
  ydata.append(sigma68(df_plot.zwerr))
34
+ ylab=r'$\sigma_{\rm NMAD} [\Delta z]$'
35
  elif metric=='bias':
36
+ ydata.append(np.median(df_plot.zwerr))
37
+ ylab=r'Median $[\Delta z]$'
38
  elif metric=='nmad':
39
  ydata.append(nmad(df_plot.zwerr))
40
+ ylab=r'$\sigma_{\rm NMAD} [\Delta z]$'
41
  elif metric=='outliers':
42
+ ydata.append(len(df_plot[np.abs(df_plot.zwerr)>0.15])/len(df_plot) *100)
43
+ ylab=r'$\eta$ [%]'
44
+
45
+ if xvariable=='VISmag':
46
+ xlab='VIS'
47
+ elif xvariable=='zs':
48
+ xlab=r'$z_{\rm spec}$'
49
+ elif xvariable=='z':
50
+ xlab=r'$z$'
51
+
52
+ plt.plot(xdata,ydata, ls = '-', marker = '.', color = 'navy',lw = 1, label = '')
53
+ plt.ylabel(f'{ylab}', fontsize = 18)
54
+ plt.xlabel(f'{xlab}', fontsize = 16)
55
 
56
  plt.xticks(fontsize = 14)
57
  plt.yticks(fontsize = 14)
insight/__pycache__/archive.cpython-310.pyc CHANGED
Binary files a/insight/__pycache__/archive.cpython-310.pyc and b/insight/__pycache__/archive.cpython-310.pyc differ
 
insight/__pycache__/archive.cpython-39.pyc CHANGED
Binary files a/insight/__pycache__/archive.cpython-39.pyc and b/insight/__pycache__/archive.cpython-39.pyc differ
 
insight/__pycache__/insight.cpython-310.pyc CHANGED
Binary files a/insight/__pycache__/insight.cpython-310.pyc and b/insight/__pycache__/insight.cpython-310.pyc differ
 
insight/__pycache__/insight.cpython-39.pyc CHANGED
Binary files a/insight/__pycache__/insight.cpython-39.pyc and b/insight/__pycache__/insight.cpython-39.pyc differ
 
insight/__pycache__/insight_arch.cpython-310.pyc CHANGED
Binary files a/insight/__pycache__/insight_arch.cpython-310.pyc and b/insight/__pycache__/insight_arch.cpython-310.pyc differ
 
insight/__pycache__/insight_arch.cpython-39.pyc CHANGED
Binary files a/insight/__pycache__/insight_arch.cpython-39.pyc and b/insight/__pycache__/insight_arch.cpython-39.pyc differ
 
insight/__pycache__/utils.cpython-310.pyc CHANGED
Binary files a/insight/__pycache__/utils.cpython-310.pyc and b/insight/__pycache__/utils.cpython-310.pyc differ
 
insight/__pycache__/utils.cpython-39.pyc CHANGED
Binary files a/insight/__pycache__/utils.cpython-39.pyc and b/insight/__pycache__/utils.cpython-39.pyc differ
 
insight/archive.py CHANGED
@@ -13,7 +13,7 @@ rcParams["font.family"] = "STIXGeneral"
13
 
14
 
15
  class archive():
16
- def __init__(self, path, aperture=2, drop_stars=True, clean_photometry=True, convert_colors=True, extinction_corr=True, only_zspec=True, reliable_zspec=True):
17
 
18
  self.aperture = aperture
19
 
@@ -39,30 +39,28 @@ class archive():
39
  hdu_list = fits.open(os.path.join(path,filename_valid))
40
  cat_test = Table(hdu_list[1].data).to_pandas()
41
 
 
 
 
42
  gold_sample = pd.read_csv(os.path.join(path,filename_gold))
43
 
44
  #cat_test = self._match_gold_sample(cat_test,gold_sample)
45
 
46
  if drop_stars==True:
47
  cat = cat[cat.mu_class_L07==1]
 
48
 
49
  if clean_photometry==True:
50
  cat = self._clean_photometry(cat)
51
  cat_test = self._clean_photometry(cat_test)
52
 
53
- self._get_loss_weights(cat)
54
 
55
  cat = cat[cat.w_Q_f_S15>0]
56
-
57
- self._set_training_data(cat, only_zspec=only_zspec, reliable_zspec=reliable_zspec, extinction_corr=extinction_corr, convert_colors=convert_colors)
58
-
59
-
60
- self._set_testing_data(cat_test, only_zspec=only_zspec, reliable_zspec='Total', extinction_corr=extinction_corr, convert_colors=convert_colors)
61
 
62
  self._get_loss_weights(cat)
63
-
64
- #self.cat_test=cat_test
65
- #self.cat_train=cat
66
 
67
  def _extract_fluxes(self,catalogue):
68
  columns_f = [f'FLUX_{x}_{self.aperture}' for x in ['G','R','I','Z','Y','J','H']]
@@ -100,13 +98,9 @@ class archive():
100
  catalogue = catalogue[catalogue.z_spec_S15>0]
101
  return catalogue
102
 
103
- def _clean_zspec_sample(self,catalogue ,kind=None):
104
- if kind==None:
105
- return catalogue
106
- elif kind=='Total':
107
- return catalogue[catalogue['reliable_S15']>0]
108
- elif kind=='Partial':
109
- return catalogue[(catalogue['w_Q_f_S15']>0.5)]
110
 
111
  def _map_weight(self,Qz):
112
  for key, value in self.weight_dict.items():
@@ -134,11 +128,11 @@ class archive():
134
  return catalogue_valid
135
 
136
 
137
- def _set_training_data(self,catalogue, only_zspec=True, reliable_zspec=True, extinction_corr=True, convert_colors=True):
138
 
139
  if only_zspec:
140
  catalogue = self._take_only_zspec(catalogue, cat_flag='Calib')
141
- catalogue = self._clean_zspec_sample(catalogue, kind=reliable_zspec)
142
 
143
  self.cat_train=catalogue
144
  f, ferr = self._extract_fluxes(catalogue)
@@ -159,11 +153,11 @@ class archive():
159
  self.target_z_train = catalogue['z_spec_S15'].values
160
  self.target_qz_train = catalogue['w_Q_f_S15'].values
161
 
162
- def _set_testing_data(self,catalogue, only_zspec=True, reliable_zspec=True, extinction_corr=True, convert_colors=True):
163
 
164
  if only_zspec:
165
  catalogue = self._take_only_zspec(catalogue, cat_flag='Valid')
166
- catalogue = self._clean_zspec_sample(catalogue, kind=reliable_zspec)
167
 
168
  self.cat_test=catalogue
169
 
 
13
 
14
 
15
  class archive():
16
+ def __init__(self, path, aperture=2, drop_stars=True, clean_photometry=True, convert_colors=True, extinction_corr=True, only_zspec=True, Qz_cut=1):
17
 
18
  self.aperture = aperture
19
 
 
39
  hdu_list = fits.open(os.path.join(path,filename_valid))
40
  cat_test = Table(hdu_list[1].data).to_pandas()
41
 
42
+ self._get_loss_weights(cat)
43
+ self._get_loss_weights(cat_test)
44
+
45
  gold_sample = pd.read_csv(os.path.join(path,filename_gold))
46
 
47
  #cat_test = self._match_gold_sample(cat_test,gold_sample)
48
 
49
  if drop_stars==True:
50
  cat = cat[cat.mu_class_L07==1]
51
+ cat_test = cat_test[cat_test.mu_class_L07==1]
52
 
53
  if clean_photometry==True:
54
  cat = self._clean_photometry(cat)
55
  cat_test = self._clean_photometry(cat_test)
56
 
 
57
 
58
  cat = cat[cat.w_Q_f_S15>0]
59
+
60
+ self._set_training_data(cat, only_zspec=only_zspec, extinction_corr=extinction_corr, convert_colors=convert_colors,Qz_cut=Qz_cut)
61
+ self._set_testing_data(cat_test, only_zspec=only_zspec, extinction_corr=extinction_corr, convert_colors=convert_colors)
 
 
62
 
63
  self._get_loss_weights(cat)
 
 
 
64
 
65
  def _extract_fluxes(self,catalogue):
66
  columns_f = [f'FLUX_{x}_{self.aperture}' for x in ['G','R','I','Z','Y','J','H']]
 
98
  catalogue = catalogue[catalogue.z_spec_S15>0]
99
  return catalogue
100
 
101
+ def _clean_zspec_sample(self,catalogue ,Qz_cut):
102
+ catalogue = catalogue[catalogue.w_Q_f_S15>=Qz_cut]
103
+ return catalogue
 
 
 
 
104
 
105
  def _map_weight(self,Qz):
106
  for key, value in self.weight_dict.items():
 
128
  return catalogue_valid
129
 
130
 
131
+ def _set_training_data(self,catalogue, only_zspec=True, extinction_corr=True, convert_colors=True,Qz_cut=1):
132
 
133
  if only_zspec:
134
  catalogue = self._take_only_zspec(catalogue, cat_flag='Calib')
135
+ catalogue = self._clean_zspec_sample(catalogue, Qz_cut=Qz_cut)
136
 
137
  self.cat_train=catalogue
138
  f, ferr = self._extract_fluxes(catalogue)
 
153
  self.target_z_train = catalogue['z_spec_S15'].values
154
  self.target_qz_train = catalogue['w_Q_f_S15'].values
155
 
156
+ def _set_testing_data(self,catalogue, only_zspec=True, extinction_corr=True, convert_colors=True):
157
 
158
  if only_zspec:
159
  catalogue = self._take_only_zspec(catalogue, cat_flag='Valid')
160
+ catalogue = self._clean_zspec_sample(catalogue, Qz_cut=1)
161
 
162
  self.cat_test=catalogue
163
 
insight/insight.py CHANGED
@@ -8,34 +8,37 @@ from astropy.io import fits
8
  import os
9
  from astropy.table import Table
10
  from scipy.spatial import KDTree
 
11
 
12
  class Insight_module():
13
  """ Define class"""
14
 
15
- def __init__(self, model):
16
  self.model=model
17
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
18
 
19
- def _get_dataloaders(self, input_data, target_data, target_weights, val_fraction=0.1):
20
  input_data = torch.Tensor(input_data)
21
  target_data = torch.Tensor(target_data)
22
- target_weights = torch.Tensor(target_weights)
23
 
24
- dataset = TensorDataset(input_data, target_data, target_weights)
25
 
26
  trainig_dataset, val_dataset = torch.utils.data.random_split(dataset, [int(len(dataset)*(1-val_fraction)), int(len(dataset)*val_fraction)+1])
27
- loader_train = DataLoader(trainig_dataset, batch_size=64, shuffle = True)
28
  loader_val = DataLoader(val_dataset, batch_size=64, shuffle = True)
29
 
30
  return loader_train, loader_val
31
 
 
32
 
33
- def _loss_function(self,mean, std, logmix, true, target_weights):
34
- log_prob = logmix - 0.5*(mean - true[:,None]).pow(2) / std.pow(2) - torch.log(std)
35
 
36
- log_prob = torch.logsumexp(log_prob, 1)
 
 
37
 
38
- #log_prob = log_prob * target_weights
 
39
  loss = -log_prob.mean()
40
 
41
  return loss
@@ -43,21 +46,25 @@ class Insight_module():
43
  def _to_numpy(self,x):
44
  return x.detach().cpu().numpy()
45
 
46
- def train(self,input_data, target_data, target_weights, nepochs=10, val_fraction=0.1, lr=1e-3 ):
47
  self.model = self.model.train()
48
- loader_train, loader_val = self._get_dataloaders(input_data, target_data, target_weights, val_fraction=0.1)
49
  optimizer = optim.Adam(self.model.parameters(), lr=lr, weight_decay=1e-4)
 
 
50
 
51
  self.model = self.model.to(self.device)
52
 
53
- loss_train, loss_validation = [],[]
 
 
54
 
55
  for epoch in range(nepochs):
56
- for input_data, target_data, target_weights in loader_train:
 
57
 
58
  input_data = input_data.to(self.device)
59
  target_data = target_data.to(self.device)
60
- target_weights = target_weights.to(self.device)
61
 
62
 
63
  optimizer.zero_grad()
@@ -69,32 +76,33 @@ class Insight_module():
69
 
70
  #print(mu,sig,target_data,torch.exp(logmix_coeff))
71
 
72
- loss = self._loss_function(mu, sig, logmix_coeff, target_data,target_weights)
 
73
 
74
  loss.backward()
75
- optimizer.step()
 
 
76
 
77
- loss_train.append(loss.item())
78
 
79
- for input_data, target_data, target_weights in loader_val:
80
 
81
 
82
  input_data = input_data.to(self.device)
83
  target_data = target_data.to(self.device)
84
- target_weights = target_weights.to(self.device)
85
 
86
 
87
  mu, logsig, logmix_coeff = self.model(input_data)
88
  logsig = torch.clamp(logsig,-6,2)
89
  sig = torch.exp(logsig)
90
 
91
- loss_val = self._loss_function(mu, sig, logmix_coeff, target_data, target_weights)
92
- loss_validation.append(loss_val.item())
93
 
94
- print(f'training_loss:{loss}',f'testing_loss:{loss_val}')
95
-
96
- self.loss_train=loss_train
97
- self.loss_validation=loss_validation
98
 
99
 
100
  def get_photoz(self,input_data, target_data):
 
8
  import os
9
  from astropy.table import Table
10
  from scipy.spatial import KDTree
11
+ from scipy.special import erf
12
 
13
  class Insight_module():
14
  """ Define class"""
15
 
16
+ def __init__(self, model, batch_size):
17
  self.model=model
18
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
+ self.batch_size=batch_size
20
 
21
+ def _get_dataloaders(self, input_data, target_data, val_fraction=0.1):
22
  input_data = torch.Tensor(input_data)
23
  target_data = torch.Tensor(target_data)
 
24
 
25
+ dataset = TensorDataset(input_data, target_data)
26
 
27
  trainig_dataset, val_dataset = torch.utils.data.random_split(dataset, [int(len(dataset)*(1-val_fraction)), int(len(dataset)*val_fraction)+1])
28
+ loader_train = DataLoader(trainig_dataset, batch_size=self.batch_size, shuffle = True)
29
  loader_val = DataLoader(val_dataset, batch_size=64, shuffle = True)
30
 
31
  return loader_train, loader_val
32
 
33
+
34
 
 
 
35
 
36
+ def _loss_function(self,mean, std, logmix, true):
37
+
38
+ logerf = torch.log(erf(true.cpu()[:,None]/(np.sqrt(2)*std.detach().cpu())+1))
39
 
40
+ log_prob = logmix - 0.5*(mean - true[:,None]).pow(2) / std.pow(2) - torch.log(std) #- logerf.to(self.device)
41
+ log_prob = torch.logsumexp(log_prob, 1)
42
  loss = -log_prob.mean()
43
 
44
  return loss
 
46
  def _to_numpy(self,x):
47
  return x.detach().cpu().numpy()
48
 
49
+ def train(self,input_data, target_data, nepochs=10, step_size = 100, val_fraction=0.1, lr=1e-3 ):
50
  self.model = self.model.train()
51
+ loader_train, loader_val = self._get_dataloaders(input_data, target_data, val_fraction=0.1)
52
  optimizer = optim.Adam(self.model.parameters(), lr=lr, weight_decay=1e-4)
53
+ scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma =0.1)
54
+
55
 
56
  self.model = self.model.to(self.device)
57
 
58
+ self.loss_train, self.loss_validation = [],[]
59
+
60
+
61
 
62
  for epoch in range(nepochs):
63
+ for input_data, target_data in loader_train:
64
+ _loss_train, _loss_validation = [],[]
65
 
66
  input_data = input_data.to(self.device)
67
  target_data = target_data.to(self.device)
 
68
 
69
 
70
  optimizer.zero_grad()
 
76
 
77
  #print(mu,sig,target_data,torch.exp(logmix_coeff))
78
 
79
+ loss = self._loss_function(mu, sig, logmix_coeff, target_data)
80
+ _loss_train.append(loss.item())
81
 
82
  loss.backward()
83
+ optimizer.step()
84
+
85
+ scheduler.step()
86
 
87
+ self.loss_train.append(np.mean(_loss_train))
88
 
89
+ for input_data, target_data in loader_val:
90
 
91
 
92
  input_data = input_data.to(self.device)
93
  target_data = target_data.to(self.device)
 
94
 
95
 
96
  mu, logsig, logmix_coeff = self.model(input_data)
97
  logsig = torch.clamp(logsig,-6,2)
98
  sig = torch.exp(logsig)
99
 
100
+ loss_val = self._loss_function(mu, sig, logmix_coeff, target_data)
101
+ _loss_validation.append(loss_val.item())
102
 
103
+ self.loss_validation.append(np.mean(_loss_validation))
104
+
105
+ #print(f'training_loss:{loss}',f'testing_loss:{loss_val}')
 
106
 
107
 
108
  def get_photoz(self,input_data, target_data):
insight/insight_arch.py CHANGED
@@ -8,64 +8,37 @@ class Photoz_network(nn.Module):
8
  nn.Linear(6, 10),
9
  nn.Dropout(dropout_prob),
10
  nn.ReLU(),
11
- nn.Linear(10, 30),
12
  nn.Dropout(dropout_prob),
13
  nn.ReLU(),
14
- nn.Linear(30, 50),
15
  nn.Dropout(dropout_prob),
16
  nn.ReLU(),
17
- nn.Linear(50, 70),
18
  nn.Dropout(dropout_prob),
19
  nn.ReLU(),
20
- nn.Linear(70, 100)
21
  )
22
 
23
  self.measure_mu = nn.Sequential(
24
- nn.Linear(100, 80),
25
  nn.Dropout(dropout_prob),
26
  nn.ReLU(),
27
- nn.Linear(80, 70),
28
- nn.Dropout(dropout_prob),
29
- nn.ReLU(),
30
- nn.Linear(70, 60),
31
- nn.Dropout(dropout_prob),
32
- nn.ReLU(),
33
- nn.Linear(60, 50),
34
- nn.Dropout(dropout_prob),
35
- nn.ReLU(),
36
- nn.Linear(50, num_gauss)
37
  )
38
 
39
  self.measure_coeffs = nn.Sequential(
40
- nn.Linear(100, 80),
41
- nn.Dropout(dropout_prob),
42
- nn.ReLU(),
43
- nn.Linear(80, 70),
44
- nn.Dropout(dropout_prob),
45
- nn.ReLU(),
46
- nn.Linear(70, 60),
47
  nn.Dropout(dropout_prob),
48
  nn.ReLU(),
49
- nn.Linear(60, 50),
50
- nn.Dropout(dropout_prob),
51
- nn.ReLU(),
52
- nn.Linear(50, num_gauss)
53
  )
54
 
55
  self.measure_sigma = nn.Sequential(
56
- nn.Linear(100, 80),
57
- nn.Dropout(dropout_prob),
58
- nn.ReLU(),
59
- nn.Linear(80, 70),
60
- nn.Dropout(dropout_prob),
61
- nn.ReLU(),
62
- nn.Linear(70, 60),
63
- nn.Dropout(dropout_prob),
64
- nn.ReLU(),
65
- nn.Linear(60, 50),
66
  nn.Dropout(dropout_prob),
67
  nn.ReLU(),
68
- nn.Linear(50, num_gauss)
69
  )
70
 
71
  def forward(self, x):
 
8
  nn.Linear(6, 10),
9
  nn.Dropout(dropout_prob),
10
  nn.ReLU(),
11
+ nn.Linear(10, 20),
12
  nn.Dropout(dropout_prob),
13
  nn.ReLU(),
14
+ nn.Linear(20, 50),
15
  nn.Dropout(dropout_prob),
16
  nn.ReLU(),
17
+ nn.Linear(50, 20),
18
  nn.Dropout(dropout_prob),
19
  nn.ReLU(),
20
+ nn.Linear(20, 10)
21
  )
22
 
23
  self.measure_mu = nn.Sequential(
24
+ nn.Linear(10, 20),
25
  nn.Dropout(dropout_prob),
26
  nn.ReLU(),
27
+ nn.Linear(20, num_gauss)
 
 
 
 
 
 
 
 
 
28
  )
29
 
30
  self.measure_coeffs = nn.Sequential(
31
+ nn.Linear(10, 20),
 
 
 
 
 
 
32
  nn.Dropout(dropout_prob),
33
  nn.ReLU(),
34
+ nn.Linear(20, num_gauss)
 
 
 
35
  )
36
 
37
  self.measure_sigma = nn.Sequential(
38
+ nn.Linear(10, 20),
 
 
 
 
 
 
 
 
 
39
  nn.Dropout(dropout_prob),
40
  nn.ReLU(),
41
+ nn.Linear(20, num_gauss)
42
  )
43
 
44
  def forward(self, x):
insight/utils.py CHANGED
@@ -8,10 +8,9 @@ def nmad(data):
8
 
9
  def sigma68(data): return 0.5*(pd.Series(data).quantile(q = 0.84) - pd.Series(data).quantile(q = 0.16))
10
 
11
-
12
- def plot_photoz(df, nbins,xvariable,metric, type_bin='bin'):
13
  bin_edges = stats.mstats.mquantiles(df[xvariable].values, np.linspace(0.1,1,nbins))
14
- ydata,xlab = [],[]
15
 
16
 
17
  for k in range(len(bin_edges)-1):
@@ -21,26 +20,37 @@ def plot_photoz(df, nbins,xvariable,metric, type_bin='bin'):
21
  mean_mag = (edge_max + edge_min) / 2
22
 
23
  if type_bin=='bin':
24
- df_plot = df_test[(df_test.imag > edge_min) & (df_test.imag < edge_max)]
25
  elif type_bin=='cum':
26
- df_plot = df_test[(df_test.imag < edge_max)]
27
  else:
28
  raise ValueError("Only type_bin=='bin' for binned and 'cum' for cumulative are supported")
29
 
30
 
31
- xlab.append(mean_mag)
32
  if metric=='sig68':
33
  ydata.append(sigma68(df_plot.zwerr))
 
34
  elif metric=='bias':
35
- ydata.append(np.mean(df_plot.zwerr))
 
36
  elif metric=='nmad':
37
  ydata.append(nmad(df_plot.zwerr))
 
38
  elif metric=='outliers':
39
- ydata.append(len(df_plot[np.abs(df_plot.zwerr)>0.15])/len(df_plot))
 
 
 
 
 
 
 
 
40
 
41
- plt.plot(xlab,ydata, ls = '-', marker = '.', color = 'navy',lw = 1, label = '')
42
- plt.ylabel(f'{metric}$[\Delta z]$', fontsize = 18)
43
- plt.xlabel(f'{xvariable}', fontsize = 16)
44
 
45
  plt.xticks(fontsize = 14)
46
  plt.yticks(fontsize = 14)
@@ -48,4 +58,23 @@ def plot_photoz(df, nbins,xvariable,metric, type_bin='bin'):
48
  plt.grid(False)
49
 
50
  plt.show()
51
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  def sigma68(data): return 0.5*(pd.Series(data).quantile(q = 0.84) - pd.Series(data).quantile(q = 0.16))
10
 
11
+ def plot_photoz_estimates(df, nbins,xvariable,metric, type_bin='bin'):
 
12
  bin_edges = stats.mstats.mquantiles(df[xvariable].values, np.linspace(0.1,1,nbins))
13
+ ydata,xdata = [],[]
14
 
15
 
16
  for k in range(len(bin_edges)-1):
 
20
  mean_mag = (edge_max + edge_min) / 2
21
 
22
  if type_bin=='bin':
23
+ df_plot = df[(df[xvariable] > edge_min) & (df[xvariable] < edge_max)]
24
  elif type_bin=='cum':
25
+ df_plot = df[(df[xvariable] < edge_max)]
26
  else:
27
  raise ValueError("Only type_bin=='bin' for binned and 'cum' for cumulative are supported")
28
 
29
 
30
+ xdata.append(mean_mag)
31
  if metric=='sig68':
32
  ydata.append(sigma68(df_plot.zwerr))
33
+ ylab=r'$\sigma_{\rm NMAD} [\Delta z]$'
34
  elif metric=='bias':
35
+ ydata.append(np.median(df_plot.zwerr))
36
+ ylab=r'Median $[\Delta z]$'
37
  elif metric=='nmad':
38
  ydata.append(nmad(df_plot.zwerr))
39
+ ylab=r'$\sigma_{\rm NMAD} [\Delta z]$'
40
  elif metric=='outliers':
41
+ ydata.append(len(df_plot[np.abs(df_plot.zwerr)>0.15])/len(df_plot) *100)
42
+ ylab=r'$\eta$ [%]'
43
+
44
+ if xvariable=='VISmag':
45
+ xlab='VIS'
46
+ elif xvariable=='zs':
47
+ xlab=r'$z_{\rm spec}$'
48
+ elif xvariable=='z':
49
+ xlab=r'$z$'
50
 
51
+ plt.plot(xdata,ydata, ls = '-', marker = '.', color = 'navy',lw = 1, label = '')
52
+ plt.ylabel(f'{ylab}', fontsize = 18)
53
+ plt.xlabel(f'{xlab}', fontsize = 16)
54
 
55
  plt.xticks(fontsize = 14)
56
  plt.yticks(fontsize = 14)
 
58
  plt.grid(False)
59
 
60
  plt.show()
61
+
62
+ return
63
+
64
+
65
+ def plot_nz(df, bins=np.arange(0,5,0.2)):
66
+ kwargs=dict( bins=bins,alpha=0.5)
67
+ plt.hist(df.zs.values, color='grey', ls='-' ,**kwargs)
68
+ counts, _, =np.histogram(df.z.values, bins=bins)
69
+
70
+ plt.plot((bins[:-1]+bins[1:])*0.5,counts, color ='purple')
71
+
72
+ #plt.legend(fontsize=14)
73
+ plt.xlabel(r'Redshift', fontsize=14)
74
+ plt.ylabel(r'Counts', fontsize=14)
75
+ plt.yscale('log')
76
+
77
+ plt.show()
78
+
79
+ return
80
+