Spaces:
Runtime error
Runtime error
Laura Cabayol Garcia
commited on
Commit
·
21a7d1b
1
Parent(s):
2f9f22c
AT for TEMPS
Browse files- pyproject.toml +1 -1
- temps/constants.py +4 -0
- temps/plots.py +2 -2
- tests/test_temps.py +38 -0
pyproject.toml
CHANGED
@@ -31,7 +31,7 @@ dependencies = [
|
|
31 |
"jupytext",
|
32 |
"mkdocs",
|
33 |
"typing",
|
34 |
-
"dataclasses"
|
35 |
]
|
36 |
|
37 |
classifiers = [
|
|
|
31 |
"jupytext",
|
32 |
"mkdocs",
|
33 |
"typing",
|
34 |
+
"dataclasses"
|
35 |
]
|
36 |
|
37 |
classifiers = [
|
temps/constants.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
|
3 |
+
PROJ_ROOT = Path(__file__).resolve().parents[2]
|
4 |
+
MODULES_DIR = PROJ_ROOT / 'data/models'
|
temps/plots.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
import numpy as np
|
2 |
import pandas as pd
|
3 |
import matplotlib.pyplot as plt
|
4 |
-
from temps.utils import nmad
|
5 |
from scipy import stats
|
6 |
from typing import List, Optional, Dict
|
7 |
|
@@ -277,7 +277,7 @@ def plot_crps(
|
|
277 |
crps_list_1: List[float],
|
278 |
crps_list_2: Optional[List[float]] = None,
|
279 |
crps_list_3: Optional[List[float]] = None,
|
280 |
-
|
281 |
sample: str = "specz",
|
282 |
save: bool = True,
|
283 |
) -> None:
|
|
|
1 |
import numpy as np
|
2 |
import pandas as pd
|
3 |
import matplotlib.pyplot as plt
|
4 |
+
from temps.utils import nmad, sigma68
|
5 |
from scipy import stats
|
6 |
from typing import List, Optional, Dict
|
7 |
|
|
|
277 |
crps_list_1: List[float],
|
278 |
crps_list_2: Optional[List[float]] = None,
|
279 |
crps_list_3: Optional[List[float]] = None,
|
280 |
+
labels: Optional[List[str]] = None,
|
281 |
sample: str = "specz",
|
282 |
save: bool = True,
|
283 |
) -> None:
|
tests/test_temps.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from loguru import logger
|
3 |
+
import torch
|
4 |
+
|
5 |
+
from temps.temps_arch import EncoderPhotometry, MeasureZ
|
6 |
+
from temps.temps import TempsModule
|
7 |
+
from temps.constants import MODULES_DIR
|
8 |
+
|
9 |
+
|
10 |
+
def test():
|
11 |
+
nn_features = EncoderPhotometry()
|
12 |
+
nn_features.load_state_dict(torch.load(MODULES_DIR / f'modelF_DA.pt',map_location=torch.device('cpu')))
|
13 |
+
nn_z = MeasureZ(num_gauss=6)
|
14 |
+
nn_z.load_state_dict(torch.load(MODULES_DIR / f'modelZ_DA.pt',map_location=torch.device('cpu')))
|
15 |
+
|
16 |
+
temps_module = TempsModule(nn_features, nn_z)
|
17 |
+
|
18 |
+
col = np.array([0.54804805, 1.81142339, 0.63354394, 0.7356338 , 1.3578122 ,
|
19 |
+
0.90108565])
|
20 |
+
ztrue = 0.444
|
21 |
+
|
22 |
+
z, pz, odds = temps_module.get_pz(input_data=torch.Tensor(col),
|
23 |
+
return_pz=True,
|
24 |
+
return_flag=True)
|
25 |
+
|
26 |
+
zdiff = (z - ztrue).abs().mean()
|
27 |
+
|
28 |
+
logger.info(f'zdiff: {zdiff}')
|
29 |
+
logger.info("test passed")
|
30 |
+
|
31 |
+
assert zdiff < 0.01
|
32 |
+
|
33 |
+
test()
|
34 |
+
|
35 |
+
|
36 |
+
|
37 |
+
|
38 |
+
# %%
|