Laura Cabayol Garcia commited on
Commit
21a7d1b
·
1 Parent(s): 2f9f22c

AT for TEMPS

Browse files
Files changed (4) hide show
  1. pyproject.toml +1 -1
  2. temps/constants.py +4 -0
  3. temps/plots.py +2 -2
  4. tests/test_temps.py +38 -0
pyproject.toml CHANGED
@@ -31,7 +31,7 @@ dependencies = [
31
  "jupytext",
32
  "mkdocs",
33
  "typing",
34
- "dataclasses",
35
  ]
36
 
37
  classifiers = [
 
31
  "jupytext",
32
  "mkdocs",
33
  "typing",
34
+ "dataclasses"
35
  ]
36
 
37
  classifiers = [
temps/constants.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ PROJ_ROOT = Path(__file__).resolve().parents[2]
4
+ MODULES_DIR = PROJ_ROOT / 'data/models'
temps/plots.py CHANGED
@@ -1,7 +1,7 @@
1
  import numpy as np
2
  import pandas as pd
3
  import matplotlib.pyplot as plt
4
- from temps.utils import nmad
5
  from scipy import stats
6
  from typing import List, Optional, Dict
7
 
@@ -277,7 +277,7 @@ def plot_crps(
277
  crps_list_1: List[float],
278
  crps_list_2: Optional[List[float]] = None,
279
  crps_list_3: Optional[List[float]] = None,
280
- label: Optional[List[str]] = None,
281
  sample: str = "specz",
282
  save: bool = True,
283
  ) -> None:
 
1
  import numpy as np
2
  import pandas as pd
3
  import matplotlib.pyplot as plt
4
+ from temps.utils import nmad, sigma68
5
  from scipy import stats
6
  from typing import List, Optional, Dict
7
 
 
277
  crps_list_1: List[float],
278
  crps_list_2: Optional[List[float]] = None,
279
  crps_list_3: Optional[List[float]] = None,
280
+ labels: Optional[List[str]] = None,
281
  sample: str = "specz",
282
  save: bool = True,
283
  ) -> None:
tests/test_temps.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from loguru import logger
3
+ import torch
4
+
5
+ from temps.temps_arch import EncoderPhotometry, MeasureZ
6
+ from temps.temps import TempsModule
7
+ from temps.constants import MODULES_DIR
8
+
9
+
10
+ def test():
11
+ nn_features = EncoderPhotometry()
12
+ nn_features.load_state_dict(torch.load(MODULES_DIR / f'modelF_DA.pt',map_location=torch.device('cpu')))
13
+ nn_z = MeasureZ(num_gauss=6)
14
+ nn_z.load_state_dict(torch.load(MODULES_DIR / f'modelZ_DA.pt',map_location=torch.device('cpu')))
15
+
16
+ temps_module = TempsModule(nn_features, nn_z)
17
+
18
+ col = np.array([0.54804805, 1.81142339, 0.63354394, 0.7356338 , 1.3578122 ,
19
+ 0.90108565])
20
+ ztrue = 0.444
21
+
22
+ z, pz, odds = temps_module.get_pz(input_data=torch.Tensor(col),
23
+ return_pz=True,
24
+ return_flag=True)
25
+
26
+ zdiff = (z - ztrue).abs().mean()
27
+
28
+ logger.info(f'zdiff: {zdiff}')
29
+ logger.info("test passed")
30
+
31
+ assert zdiff < 0.01
32
+
33
+ test()
34
+
35
+
36
+
37
+
38
+ # %%