Spaces:
Runtime error
Runtime error
Commit
·
a57776c
1
Parent(s):
e50129b
changes to enable training with any catalog
Browse files- notebooks/NMAD.py +10 -5
- pyproject.toml +1 -0
- temps/archive.py +57 -115
- temps/temps.py +3 -1
notebooks/NMAD.py
CHANGED
@@ -33,6 +33,7 @@ import os
|
|
33 |
from astropy.io import fits
|
34 |
from astropy.table import Table
|
35 |
import torch
|
|
|
36 |
|
37 |
# %%
|
38 |
#matplotlib settings
|
@@ -63,6 +64,8 @@ eval_methods=True
|
|
63 |
#define here the directory containing the photometric catalogues
|
64 |
parent_dir = Path('/data/astro/scratch/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5')
|
65 |
modules_dir = Path('../data/models/')
|
|
|
|
|
66 |
|
67 |
# %%
|
68 |
filename_valid='euclid_cosmos_DC2_S1_v2.1_valid_matched.fits'
|
@@ -83,9 +86,11 @@ VISmag = cat['MAG_VIS']
|
|
83 |
zsflag = cat['reliable_S15']
|
84 |
|
85 |
# %%
|
86 |
-
photoz_archive = Archive(
|
87 |
-
|
88 |
-
|
|
|
|
|
89 |
|
90 |
# %% [markdown]
|
91 |
# ### EVALUATE USING TRAINED MODELS
|
@@ -97,9 +102,9 @@ if eval_methods:
|
|
97 |
for il, lab in enumerate(['z','L15','DA']):
|
98 |
|
99 |
nn_features = EncoderPhotometry()
|
100 |
-
nn_features.load_state_dict(modules_dir / f'modelF_{lab}.pt',map_location=torch.device('cpu')))
|
101 |
nn_z = MeasureZ(num_gauss=6)
|
102 |
-
nn_z.load_state_dict(modules_dir / f'modelZ_{lab}.pt',map_location=torch.device('cpu')))
|
103 |
|
104 |
temps_module = TempsModule(nn_features, nn_z)
|
105 |
|
|
|
33 |
from astropy.io import fits
|
34 |
from astropy.table import Table
|
35 |
import torch
|
36 |
+
from pathlib import Path
|
37 |
|
38 |
# %%
|
39 |
#matplotlib settings
|
|
|
64 |
#define here the directory containing the photometric catalogues
|
65 |
parent_dir = Path('/data/astro/scratch/lcabayol/insight/data/Euclid_EXT_MER_PHZ_DC2_v1.5')
|
66 |
modules_dir = Path('../data/models/')
|
67 |
+
filename_calib = 'euclid_cosmos_DC2_S1_v2.1_calib_clean.fits'
|
68 |
+
filename_valid = 'euclid_cosmos_DC2_S1_v2.1_valid_matched.fits'
|
69 |
|
70 |
# %%
|
71 |
filename_valid='euclid_cosmos_DC2_S1_v2.1_valid_matched.fits'
|
|
|
86 |
zsflag = cat['reliable_S15']
|
87 |
|
88 |
# %%
|
89 |
+
photoz_archive = Archive(path_calib = parent_dir/filename_calib,
|
90 |
+
path_valid = parent_dir/filename_valid,
|
91 |
+
only_zspec=False)
|
92 |
+
f = photoz_archive._extract_fluxes(catalogue= cat)
|
93 |
+
col = photoz_archive._to_colors(f)
|
94 |
|
95 |
# %% [markdown]
|
96 |
# ### EVALUATE USING TRAINED MODELS
|
|
|
102 |
for il, lab in enumerate(['z','L15','DA']):
|
103 |
|
104 |
nn_features = EncoderPhotometry()
|
105 |
+
nn_features.load_state_dict(torch.load(modules_dir / f'modelF_{lab}.pt',map_location=torch.device('cpu')))
|
106 |
nn_z = MeasureZ(num_gauss=6)
|
107 |
+
nn_z.load_state_dict(torch.load(modules_dir / f'modelZ_{lab}.pt',map_location=torch.device('cpu')))
|
108 |
|
109 |
temps_module = TempsModule(nn_features, nn_z)
|
110 |
|
pyproject.toml
CHANGED
@@ -28,6 +28,7 @@ dependencies = [
|
|
28 |
"pathlib",
|
29 |
"astropy",
|
30 |
"gradio",
|
|
|
31 |
]
|
32 |
|
33 |
classifiers = [
|
|
|
28 |
"pathlib",
|
29 |
"astropy",
|
30 |
"gradio",
|
31 |
+
"jupytext"
|
32 |
]
|
33 |
|
34 |
classifiers = [
|
temps/archive.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import numpy as np
|
2 |
import pandas as pd
|
3 |
from astropy.io import fits
|
|
|
4 |
from scipy.spatial import KDTree
|
5 |
from matplotlib import pyplot as plt
|
6 |
from matplotlib import rcParams
|
@@ -12,37 +13,47 @@ rcParams["mathtext.fontset"] = "stix"
|
|
12 |
rcParams["font.family"] = "STIXGeneral"
|
13 |
|
14 |
class Archive:
|
15 |
-
def __init__(self,
|
16 |
-
|
|
|
17 |
drop_stars=True,
|
18 |
clean_photometry=True,
|
19 |
convert_colors=True,
|
20 |
extinction_corr=True,
|
21 |
only_zspec=True,
|
22 |
-
|
23 |
-
|
|
|
|
|
|
|
|
|
24 |
|
25 |
-
|
26 |
logger.info("Starting archive")
|
27 |
-
self.aperture = aperture
|
28 |
-
self.all_apertures = all_apertures
|
29 |
self.flags_kept = flags_kept
|
|
|
|
|
|
|
30 |
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
|
43 |
-
|
44 |
-
with fits.open(path_valid) as hdu_list:
|
45 |
-
cat_test = Table(hdu_list[1].data).to_pandas()
|
46 |
|
47 |
# Store the catalogs for later use
|
48 |
self.cat = cat
|
@@ -85,57 +96,18 @@ class Archive:
|
|
85 |
|
86 |
|
87 |
def _extract_fluxes(self,catalogue):
|
88 |
-
|
89 |
-
|
90 |
-
columns_ferr = [f'FLUXERR_{x}_{a}' for a in [1,2,3] for x in ['G','R','I','Z','Y','J','H'] ]
|
91 |
-
else:
|
92 |
-
columns_f = [f'FLUX_{x}_{self.aperture}' for x in ['G','R','I','Z','Y','J','H']]
|
93 |
-
columns_ferr = [f'FLUXERR_{x}_{self.aperture}' for x in ['G','R','I','Z','Y','J','H']]
|
94 |
-
|
95 |
-
f = catalogue[columns_f].values
|
96 |
-
ferr = catalogue[columns_ferr].values
|
97 |
-
return f, ferr
|
98 |
-
|
99 |
-
def _extract_magnitudes(self,catalogue):
|
100 |
-
if self.all_apertures:
|
101 |
-
columns_m = [f'MAG_{x}_{a}' for a in [1,2,3] for x in ['G','R','I','Z','Y','J','H']]
|
102 |
-
columns_merr = [f'MAGERR_{x}_{a}' for a in [1,2,3] for x in ['G','R','I','Z','Y','J','H'] ]
|
103 |
-
else:
|
104 |
-
columns_m = [f'MAG_{x}_{self.aperture}' for x in ['G','R','I','Z','Y','J','H']]
|
105 |
-
columns_merr = [f'MAGERR_{x}_{self.aperture}' for x in ['G','R','I','Z','Y','J','H']]
|
106 |
|
107 |
-
|
108 |
-
merr = catalogue[columns_merr].values
|
109 |
-
return m, merr
|
110 |
-
|
111 |
-
def _to_colors(self, flux, fluxerr):
|
112 |
""" Convert fluxes to colors"""
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
for a in range(3):
|
117 |
-
lim1 = 7*a
|
118 |
-
lim2 = 7*(a+1)
|
119 |
-
c = flux[:,lim1:(lim2-1)] / flux[:,(lim1+1):lim2]
|
120 |
-
cerr = np.sqrt((fluxerr[:,lim1:(lim2-1)]/ flux[:,(lim1+1):lim2])**2 + (flux[:,lim1:(lim2-1)] / flux[:,(lim1+1):lim2]**2)**2 * fluxerr[:,(lim1+1):lim2]**2)
|
121 |
-
|
122 |
-
if a==0:
|
123 |
-
color = c
|
124 |
-
color_err = cerr
|
125 |
-
else:
|
126 |
-
color = np.concatenate((color,c),axis=1)
|
127 |
-
color_err = np.concatenate((color_err,cerr),axis=1)
|
128 |
-
|
129 |
-
else:
|
130 |
-
color = flux[:,:-1] / flux[:,1:]
|
131 |
-
|
132 |
-
color_err = np.sqrt((fluxerr[:,:-1]/ flux[:,1:])**2 + (flux[:,:-1] / flux[:,1:]**2)**2 * fluxerr[:,1:]**2)
|
133 |
-
return color,color_err
|
134 |
|
135 |
def _set_combiend_target(self, catalogue):
|
136 |
-
catalogue['target_z'] = catalogue.apply(lambda row: row['
|
137 |
-
if row['
|
138 |
-
else row['
|
139 |
|
140 |
return catalogue
|
141 |
|
@@ -148,13 +120,7 @@ class Archive:
|
|
148 |
|
149 |
def _correct_extinction(self,catalogue, f, return_ext_corr=False):
|
150 |
"""Corrects for extinction"""
|
151 |
-
|
152 |
-
if self.all_apertures:
|
153 |
-
ext_correction = catalogue[ext_correction_cols].values
|
154 |
-
ext_correction = np.concatenate((ext_correction,ext_correction,ext_correction),axis=1)
|
155 |
-
else:
|
156 |
-
ext_correction = catalogue[ext_correction_cols].values
|
157 |
-
|
158 |
f = f * ext_correction
|
159 |
if return_ext_corr:
|
160 |
return f, ext_correction
|
@@ -164,14 +130,14 @@ class Archive:
|
|
164 |
def _select_only_zspec(self,catalogue,cat_flag=None):
|
165 |
"""Selects only galaxies with spectroscopic redshift"""
|
166 |
if cat_flag=='Calib':
|
167 |
-
catalogue = catalogue[catalogue.
|
168 |
elif cat_flag=='Valid':
|
169 |
-
catalogue = catalogue[catalogue.
|
170 |
return catalogue
|
171 |
|
172 |
def _exclude_only_zspec(self,catalogue):
|
173 |
"""Selects only galaxies without spectroscopic redshift"""
|
174 |
-
catalogue = catalogue[(catalogue.
|
175 |
return catalogue
|
176 |
|
177 |
def _select_L15_sample(self,catalogue):
|
@@ -187,7 +153,7 @@ class Archive:
|
|
187 |
if cat_flag=='Calib':
|
188 |
catalogue = catalogue[catalogue.target_z>0]
|
189 |
elif cat_flag=='Valid':
|
190 |
-
catalogue = catalogue[catalogue.
|
191 |
return catalogue
|
192 |
|
193 |
def _clean_zspec_sample(self,catalogue ,flags_kept=[3,3.1,3.4,3.5,4]):
|
@@ -222,7 +188,7 @@ class Archive:
|
|
222 |
def _set_training_data(self,catalogue, catalogue_da, only_zspec=True, extinction_corr=True, convert_colors=True):
|
223 |
|
224 |
cat_da = self._exclude_only_zspec(catalogue_da)
|
225 |
-
target_z_train_DA = cat_da['
|
226 |
|
227 |
|
228 |
if only_zspec:
|
@@ -235,11 +201,10 @@ class Archive:
|
|
235 |
|
236 |
|
237 |
self.cat_train=catalogue
|
238 |
-
f
|
239 |
-
|
240 |
-
f_DA, ferr_DA = self._extract_fluxes(cat_da)
|
241 |
idx = np.random.randint(0, len(f_DA), len(f))
|
242 |
-
f_DA
|
243 |
target_z_train_DA = target_z_train_DA[idx]
|
244 |
self.target_z_train_DA = target_z_train_DA
|
245 |
|
@@ -250,21 +215,17 @@ class Archive:
|
|
250 |
|
251 |
if convert_colors==True:
|
252 |
logger.info("Converting to colors")
|
253 |
-
col
|
254 |
-
col_DA
|
255 |
|
256 |
self.phot_train = col
|
257 |
-
self.photerr_train = colerr
|
258 |
self.phot_train_DA = col_DA
|
259 |
-
self.photerr_train_DA = colerr_DA
|
260 |
else:
|
261 |
self.phot_train = f
|
262 |
-
self.photerr_train = ferr
|
263 |
self.phot_train_DA = f_DA
|
264 |
-
self.photerr_train_DA = ferr_DA
|
265 |
|
266 |
if only_zspec==True:
|
267 |
-
self.target_z_train = catalogue['
|
268 |
else:
|
269 |
self.target_z_train = catalogue['target_z'].values
|
270 |
|
@@ -275,7 +236,7 @@ class Archive:
|
|
275 |
if target=='specz':
|
276 |
catalogue = self._select_only_zspec(catalogue, cat_flag='Valid')
|
277 |
catalogue = self._clean_zspec_sample(catalogue)
|
278 |
-
self.target_z_test = catalogue['
|
279 |
|
280 |
elif target=='L15':
|
281 |
catalogue = self._select_L15_sample(catalogue)
|
@@ -284,45 +245,26 @@ class Archive:
|
|
284 |
|
285 |
self.cat_test=catalogue
|
286 |
|
287 |
-
f
|
288 |
|
289 |
if extinction_corr==True:
|
290 |
f = self._correct_extinction(catalogue,f)
|
291 |
|
292 |
if convert_colors==True:
|
293 |
-
col
|
294 |
self.phot_test = col
|
295 |
-
self.photerr_test = colerr
|
296 |
else:
|
297 |
self.phot_test = f
|
298 |
-
self.photerr_test = ferr
|
299 |
|
300 |
|
301 |
self.VIS_mag_test = catalogue['MAG_VIS'].values
|
302 |
|
303 |
|
304 |
def get_training_data(self):
|
305 |
-
return self.phot_train, self.
|
306 |
|
307 |
def get_testing_data(self):
|
308 |
-
return self.phot_test, self.
|
309 |
|
310 |
def get_VIS_mag(self, catalogue):
|
311 |
return catalogue[['MAG_VIS']].values
|
312 |
-
|
313 |
-
def plot_zdistribution(self, plot_test=False, bins=50):
|
314 |
-
_,_,specz = photoz_archive.get_training_data()
|
315 |
-
plt.hist(specz, bins = bins, hisstype='step', color='navy', label=r'Training sample')
|
316 |
-
|
317 |
-
if plot_test:
|
318 |
-
_,_,specz_test = photoz_archive.get_training_data()
|
319 |
-
plt.hist(specz, bins = bins, hisstype='step', color='goldenrod', label=r'Test sample',ls='--')
|
320 |
-
|
321 |
-
|
322 |
-
plt.xticks(fontsize=12)
|
323 |
-
plt.yticks(fontsize=12)
|
324 |
-
|
325 |
-
plt.xlabel(r'Redshift', fontsize=14)
|
326 |
-
plt.ylabel('Counts', fontsize=14)
|
327 |
-
|
328 |
-
plt.show()
|
|
|
1 |
import numpy as np
|
2 |
import pandas as pd
|
3 |
from astropy.io import fits
|
4 |
+
from astropy.table import Table
|
5 |
from scipy.spatial import KDTree
|
6 |
from matplotlib import pyplot as plt
|
7 |
from matplotlib import rcParams
|
|
|
13 |
rcParams["font.family"] = "STIXGeneral"
|
14 |
|
15 |
class Archive:
|
16 |
+
def __init__(self,
|
17 |
+
path_calib,
|
18 |
+
path_valid=None,
|
19 |
drop_stars=True,
|
20 |
clean_photometry=True,
|
21 |
convert_colors=True,
|
22 |
extinction_corr=True,
|
23 |
only_zspec=True,
|
24 |
+
columns_photometry = ['FLUX_G_2','FLUX_R_2','FLUX_I_2','FLUX_Z_2','FLUX_Y_2','FLUX_J_2','FLUX_H_2'],
|
25 |
+
columns_ebv = ['EB_V_corr_FLUX_G','EB_V_corr_FLUX_R','EB_V_corr_FLUX_I','EB_V_corr_FLUX_Z','EB_V_corr_FLUX_Y','EB_V_corr_FLUX_J','EB_V_corr_FLUX_H'],
|
26 |
+
photoz_name="photo_z_L15",
|
27 |
+
specz_name="z_spec_S15",
|
28 |
+
target_test='specz',
|
29 |
+
flags_kept=[3, 3.1, 3.4, 3.5, 4]):
|
30 |
|
|
|
31 |
logger.info("Starting archive")
|
|
|
|
|
32 |
self.flags_kept = flags_kept
|
33 |
+
self.columns_photometry=columns_photometry
|
34 |
+
self.columns_ebv=columns_ebv
|
35 |
+
|
36 |
|
37 |
+
if path_calib.suffix == ".fits":
|
38 |
+
with fits.open(path_calib) as hdu_list:
|
39 |
+
cat = Table(hdu_list[1].data).to_pandas()
|
40 |
+
if path_valid != None:
|
41 |
+
with fits.open(path_valid) as hdu_list:
|
42 |
+
cat_test = Table(hdu_list[1].data).to_pandas()
|
43 |
+
|
44 |
+
elif path_calib.suffix == ".csv":
|
45 |
+
cat = pd.read_csv(path_calib)
|
46 |
+
if path_valid != None:
|
47 |
+
cat_test = pd.read_csv(path_valid)
|
48 |
+
else:
|
49 |
+
raise ValueError("Unsupported file format. Please provide a .fits or .csv file.")
|
50 |
+
|
51 |
+
cat = cat.rename(columns ={f"{specz_name}":"specz",
|
52 |
+
f"{photoz_name}":"photo_z"})
|
53 |
+
cat_test = cat_test.rename(columns ={f"{specz_name}":"specz",
|
54 |
+
f"{photoz_name}":"photo_z"})
|
55 |
|
56 |
+
cat = cat[(cat['specz'] > 0) | (cat['photo_z'] > 0)]
|
|
|
|
|
57 |
|
58 |
# Store the catalogs for later use
|
59 |
self.cat = cat
|
|
|
96 |
|
97 |
|
98 |
def _extract_fluxes(self,catalogue):
|
99 |
+
f = catalogue[self.columns_photometry].values
|
100 |
+
return f
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
101 |
|
102 |
+
def _to_colors(self, flux):
|
|
|
|
|
|
|
|
|
103 |
""" Convert fluxes to colors"""
|
104 |
+
color = flux[:,:-1] / flux[:,1:]
|
105 |
+
return color
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
106 |
|
107 |
def _set_combiend_target(self, catalogue):
|
108 |
+
catalogue['target_z'] = catalogue.apply(lambda row: row['specz']
|
109 |
+
if row['specz'] > 0
|
110 |
+
else row['photo_z'], axis=1)
|
111 |
|
112 |
return catalogue
|
113 |
|
|
|
120 |
|
121 |
def _correct_extinction(self,catalogue, f, return_ext_corr=False):
|
122 |
"""Corrects for extinction"""
|
123 |
+
ext_correction = catalogue[self.columns_ebv].values
|
|
|
|
|
|
|
|
|
|
|
|
|
124 |
f = f * ext_correction
|
125 |
if return_ext_corr:
|
126 |
return f, ext_correction
|
|
|
130 |
def _select_only_zspec(self,catalogue,cat_flag=None):
|
131 |
"""Selects only galaxies with spectroscopic redshift"""
|
132 |
if cat_flag=='Calib':
|
133 |
+
catalogue = catalogue[catalogue.specz>0]
|
134 |
elif cat_flag=='Valid':
|
135 |
+
catalogue = catalogue[catalogue.specz>0]
|
136 |
return catalogue
|
137 |
|
138 |
def _exclude_only_zspec(self,catalogue):
|
139 |
"""Selects only galaxies without spectroscopic redshift"""
|
140 |
+
catalogue = catalogue[(catalogue.specz<0)&(catalogue.photo_z>0)&(catalogue.photo_z<4)]
|
141 |
return catalogue
|
142 |
|
143 |
def _select_L15_sample(self,catalogue):
|
|
|
153 |
if cat_flag=='Calib':
|
154 |
catalogue = catalogue[catalogue.target_z>0]
|
155 |
elif cat_flag=='Valid':
|
156 |
+
catalogue = catalogue[catalogue.specz>0]
|
157 |
return catalogue
|
158 |
|
159 |
def _clean_zspec_sample(self,catalogue ,flags_kept=[3,3.1,3.4,3.5,4]):
|
|
|
188 |
def _set_training_data(self,catalogue, catalogue_da, only_zspec=True, extinction_corr=True, convert_colors=True):
|
189 |
|
190 |
cat_da = self._exclude_only_zspec(catalogue_da)
|
191 |
+
target_z_train_DA = cat_da['photo_z'].values
|
192 |
|
193 |
|
194 |
if only_zspec:
|
|
|
201 |
|
202 |
|
203 |
self.cat_train=catalogue
|
204 |
+
f = self._extract_fluxes(catalogue)
|
205 |
+
f_DA = self._extract_fluxes(cat_da)
|
|
|
206 |
idx = np.random.randint(0, len(f_DA), len(f))
|
207 |
+
f_DA = f_DA[idx]
|
208 |
target_z_train_DA = target_z_train_DA[idx]
|
209 |
self.target_z_train_DA = target_z_train_DA
|
210 |
|
|
|
215 |
|
216 |
if convert_colors==True:
|
217 |
logger.info("Converting to colors")
|
218 |
+
col = self._to_colors(f)
|
219 |
+
col_DA = self._to_colors(f_DA)
|
220 |
|
221 |
self.phot_train = col
|
|
|
222 |
self.phot_train_DA = col_DA
|
|
|
223 |
else:
|
224 |
self.phot_train = f
|
|
|
225 |
self.phot_train_DA = f_DA
|
|
|
226 |
|
227 |
if only_zspec==True:
|
228 |
+
self.target_z_train = catalogue['specz'].values
|
229 |
else:
|
230 |
self.target_z_train = catalogue['target_z'].values
|
231 |
|
|
|
236 |
if target=='specz':
|
237 |
catalogue = self._select_only_zspec(catalogue, cat_flag='Valid')
|
238 |
catalogue = self._clean_zspec_sample(catalogue)
|
239 |
+
self.target_z_test = catalogue['specz'].values
|
240 |
|
241 |
elif target=='L15':
|
242 |
catalogue = self._select_L15_sample(catalogue)
|
|
|
245 |
|
246 |
self.cat_test=catalogue
|
247 |
|
248 |
+
f = self._extract_fluxes(catalogue)
|
249 |
|
250 |
if extinction_corr==True:
|
251 |
f = self._correct_extinction(catalogue,f)
|
252 |
|
253 |
if convert_colors==True:
|
254 |
+
col = self._to_colors(f)
|
255 |
self.phot_test = col
|
|
|
256 |
else:
|
257 |
self.phot_test = f
|
|
|
258 |
|
259 |
|
260 |
self.VIS_mag_test = catalogue['MAG_VIS'].values
|
261 |
|
262 |
|
263 |
def get_training_data(self):
|
264 |
+
return self.phot_train, self.target_z_train, self.VIS_mag_train, self.phot_train_DA, self.target_z_train_DA
|
265 |
|
266 |
def get_testing_data(self):
|
267 |
+
return self.phot_test, self.target_z_test, self.VIS_mag_test
|
268 |
|
269 |
def get_VIS_mag(self, catalogue):
|
270 |
return catalogue[['MAG_VIS']].values
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
temps/temps.py
CHANGED
@@ -5,6 +5,8 @@ from torch.utils.data import DataLoader, TensorDataset
|
|
5 |
from torch.optim import lr_scheduler
|
6 |
from loguru import logger
|
7 |
import pandas as pd
|
|
|
|
|
8 |
from tqdm import tqdm # Import tqdm for progress bars
|
9 |
|
10 |
from temps.utils import maximum_mean_discrepancy
|
@@ -47,7 +49,7 @@ class TempsModule:
|
|
47 |
dataset = TensorDataset(input_data, input_data_da, target_data)
|
48 |
train_dataset, val_dataset = torch.utils.data.random_split(
|
49 |
dataset,
|
50 |
-
[int(len(dataset) * (1 - val_fraction)), int(len(dataset) * val_fraction)],
|
51 |
)
|
52 |
loader_train = DataLoader(
|
53 |
train_dataset, batch_size=self.batch_size, shuffle=True
|
|
|
5 |
from torch.optim import lr_scheduler
|
6 |
from loguru import logger
|
7 |
import pandas as pd
|
8 |
+
from scipy.stats import norm
|
9 |
+
|
10 |
from tqdm import tqdm # Import tqdm for progress bars
|
11 |
|
12 |
from temps.utils import maximum_mean_discrepancy
|
|
|
49 |
dataset = TensorDataset(input_data, input_data_da, target_data)
|
50 |
train_dataset, val_dataset = torch.utils.data.random_split(
|
51 |
dataset,
|
52 |
+
[int(len(dataset) * (1 - val_fraction)), int(len(dataset) * val_fraction)+1],
|
53 |
)
|
54 |
loader_train = DataLoader(
|
55 |
train_dataset, batch_size=self.batch_size, shuffle=True
|