zhuwq0
commited on
Commit
•
0eb79a8
0
Parent(s):
init
Browse files- Dockerfile +25 -0
- LICENSE +21 -0
- docs/README.md +144 -0
- docs/data.mseed +0 -0
- docs/example_batch_prediction.ipynb +211 -0
- docs/example_fastapi.ipynb +0 -0
- docs/example_gradio.ipynb +0 -0
- docs/test_api.py +37 -0
- env.yml +17 -0
- mkdocs.yml +18 -0
- model/190703-214543/checkpoint +3 -0
- model/190703-214543/config.log +3 -0
- model/190703-214543/loss.log +3 -0
- model/190703-214543/model_95.ckpt.data-00000-of-00001 +3 -0
- model/190703-214543/model_95.ckpt.index +3 -0
- model/190703-214543/model_95.ckpt.meta +3 -0
- phasenet/__init__.py +1 -0
- phasenet/app.py +341 -0
- phasenet/data_reader.py +1010 -0
- phasenet/detect_peaks.py +207 -0
- phasenet/model.py +489 -0
- phasenet/postprocess.py +377 -0
- phasenet/predict.py +262 -0
- phasenet/slide_window.py +88 -0
- phasenet/test_app.py +47 -0
- phasenet/train.py +246 -0
- phasenet/util.py +238 -0
- phasenet/visulization.py +481 -0
- requirements.txt +7 -0
- setup.py +116 -0
Dockerfile
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM tensorflow/tensorflow
|
2 |
+
|
3 |
+
# Create the environment:
|
4 |
+
# COPY env.yml /app
|
5 |
+
# RUN conda env create --name cs329s --file=env.yml
|
6 |
+
# Make RUN commands use the new environment:
|
7 |
+
# SHELL ["conda", "run", "-n", "cs329s", "/bin/bash", "-c"]
|
8 |
+
|
9 |
+
RUN pip install tqdm obspy pandas
|
10 |
+
RUN pip install uvicorn fastapi
|
11 |
+
|
12 |
+
WORKDIR /opt
|
13 |
+
|
14 |
+
# Copy files
|
15 |
+
COPY phasenet /opt/phasenet
|
16 |
+
COPY model /opt/model
|
17 |
+
|
18 |
+
# Expose API port
|
19 |
+
EXPOSE 8000
|
20 |
+
|
21 |
+
ENV PYTHONUNBUFFERED=1
|
22 |
+
|
23 |
+
# Start API server
|
24 |
+
#ENTRYPOINT ["conda", "run", "--no-capture-output", "-n", "cs329s", "uvicorn", "--app-dir", "phasenet", "app:app", "--reload", "--port", "8000", "--host", "0.0.0.0"]
|
25 |
+
ENTRYPOINT ["uvicorn", "--app-dir", "phasenet", "app:app", "--reload", "--port", "7860", "--host", "0.0.0.0"]
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2021 Weiqiang Zhu
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
docs/README.md
ADDED
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# PhaseNet: A Deep-Neural-Network-Based Seismic Arrival Time Picking Method
|
2 |
+
|
3 |
+
[![](https://github.com/AI4EPS/PhaseNet/workflows/documentation/badge.svg)](https://ai4eps.github.io/PhaseNet)
|
4 |
+
|
5 |
+
## 1. Install [miniconda](https://docs.conda.io/en/latest/miniconda.html) and requirements
|
6 |
+
- Download PhaseNet repository
|
7 |
+
```bash
|
8 |
+
git clone https://github.com/wayneweiqiang/PhaseNet.git
|
9 |
+
cd PhaseNet
|
10 |
+
```
|
11 |
+
- Install to default environment
|
12 |
+
```bash
|
13 |
+
conda env update -f=env.yml -n base
|
14 |
+
```
|
15 |
+
- Install to "phasenet" virtual envirionment
|
16 |
+
```bash
|
17 |
+
conda env create -f env.yml
|
18 |
+
conda activate phasenet
|
19 |
+
```
|
20 |
+
|
21 |
+
## 2. Pre-trained model
|
22 |
+
Located in directory: **model/190703-214543**
|
23 |
+
|
24 |
+
## 3. Related papers
|
25 |
+
- Zhu, Weiqiang, and Gregory C. Beroza. "PhaseNet: A Deep-Neural-Network-Based Seismic Arrival Time Picking Method." arXiv preprint arXiv:1803.03211 (2018).
|
26 |
+
- Liu, Min, et al. "Rapid characterization of the July 2019 Ridgecrest, California, earthquake sequence from raw seismic data using machine‐learning phase picker." Geophysical Research Letters 47.4 (2020): e2019GL086189.
|
27 |
+
- Park, Yongsoo, et al. "Machine‐learning‐based analysis of the Guy‐Greenbrier, Arkansas earthquakes: A tale of two sequences." Geophysical Research Letters 47.6 (2020): e2020GL087032.
|
28 |
+
- Chai, Chengping, et al. "Using a deep neural network and transfer learning to bridge scales for seismic phase picking." Geophysical Research Letters 47.16 (2020): e2020GL088651.
|
29 |
+
- Tan, Yen Joe, et al. "Machine‐Learning‐Based High‐Resolution Earthquake Catalog Reveals How Complex Fault Structures Were Activated during the 2016–2017 Central Italy Sequence." The Seismic Record 1.1 (2021): 11-19.
|
30 |
+
|
31 |
+
## 4. Batch prediction
|
32 |
+
See examples in the [notebook](https://github.com/wayneweiqiang/PhaseNet/blob/master/docs/example_batch_prediction.ipynb): [example_batch_prediction.ipynb](example_batch_prediction.ipynb)
|
33 |
+
|
34 |
+
|
35 |
+
PhaseNet currently supports four data formats: mseed, sac, hdf5, and numpy. The test data can be downloaded here:
|
36 |
+
```
|
37 |
+
wget https://github.com/wayneweiqiang/PhaseNet/releases/download/test_data/test_data.zip
|
38 |
+
unzip test_data.zip
|
39 |
+
```
|
40 |
+
|
41 |
+
- For mseed format:
|
42 |
+
```
|
43 |
+
python phasenet/predict.py --model=model/190703-214543 --data_list=test_data/mseed.csv --data_dir=test_data/mseed --format=mseed --amplitude --response_xml=test_data/stations.xml --batch_size=1 --sampling_rate=100 --plot_figure
|
44 |
+
```
|
45 |
+
```
|
46 |
+
python phasenet/predict.py --model=model/190703-214543 --data_list=test_data/mseed2.csv --data_dir=test_data/mseed --format=mseed --amplitude --response_xml=test_data/stations.xml --batch_size=1 --sampling_rate=100 --plot_figure
|
47 |
+
```
|
48 |
+
|
49 |
+
- For sac format:
|
50 |
+
```
|
51 |
+
python phasenet/predict.py --model=model/190703-214543 --data_list=test_data/sac.csv --data_dir=test_data/sac --format=sac --batch_size=1 --plot_figure
|
52 |
+
```
|
53 |
+
|
54 |
+
- For numpy format:
|
55 |
+
```
|
56 |
+
python phasenet/predict.py --model=model/190703-214543 --data_list=test_data/npz.csv --data_dir=test_data/npz --format=numpy --plot_figure
|
57 |
+
```
|
58 |
+
|
59 |
+
- For hdf5 format:
|
60 |
+
```
|
61 |
+
python phasenet/predict.py --model=model/190703-214543 --hdf5_file=test_data/data.h5 --hdf5_group=data --format=hdf5 --plot_figure
|
62 |
+
```
|
63 |
+
|
64 |
+
- For a seismic array (used by [QuakeFlow](https://github.com/wayneweiqiang/QuakeFlow)):
|
65 |
+
```
|
66 |
+
python phasenet/predict.py --model=model/190703-214543 --data_list=test_data/mseed_array.csv --data_dir=test_data/mseed_array --stations=test_data/stations.json --format=mseed_array --amplitude
|
67 |
+
```
|
68 |
+
|
69 |
+
Notes:
|
70 |
+
|
71 |
+
1. The reason for using "--batch_size=1" is because the mseed or sac files usually are not the same length. If you want to use a larger batch size for a good prediction speed, you need to cut the data to the same length.
|
72 |
+
|
73 |
+
2. Remove the "--plot_figure" argument for large datasets, because plotting can be very slow.
|
74 |
+
|
75 |
+
Optional arguments:
|
76 |
+
```
|
77 |
+
usage: predict.py [-h] [--batch_size BATCH_SIZE] [--model_dir MODEL_DIR]
|
78 |
+
[--data_dir DATA_DIR] [--data_list DATA_LIST]
|
79 |
+
[--hdf5_file HDF5_FILE] [--hdf5_group HDF5_GROUP]
|
80 |
+
[--result_dir RESULT_DIR] [--result_fname RESULT_FNAME]
|
81 |
+
[--min_p_prob MIN_P_PROB] [--min_s_prob MIN_S_PROB]
|
82 |
+
[--mpd MPD] [--amplitude] [--format FORMAT]
|
83 |
+
[--s3_url S3_URL] [--stations STATIONS] [--plot_figure]
|
84 |
+
[--save_prob]
|
85 |
+
|
86 |
+
optional arguments:
|
87 |
+
-h, --help show this help message and exit
|
88 |
+
--batch_size BATCH_SIZE
|
89 |
+
batch size
|
90 |
+
--model_dir MODEL_DIR
|
91 |
+
Checkpoint directory (default: None)
|
92 |
+
--data_dir DATA_DIR Input file directory
|
93 |
+
--data_list DATA_LIST
|
94 |
+
Input csv file
|
95 |
+
--hdf5_file HDF5_FILE
|
96 |
+
Input hdf5 file
|
97 |
+
--hdf5_group HDF5_GROUP
|
98 |
+
data group name in hdf5 file
|
99 |
+
--result_dir RESULT_DIR
|
100 |
+
Output directory
|
101 |
+
--result_fname RESULT_FNAME
|
102 |
+
Output file
|
103 |
+
--min_p_prob MIN_P_PROB
|
104 |
+
Probability threshold for P pick
|
105 |
+
--min_s_prob MIN_S_PROB
|
106 |
+
Probability threshold for S pick
|
107 |
+
--mpd MPD Minimum peak distance
|
108 |
+
--amplitude if return amplitude value
|
109 |
+
--format FORMAT input format
|
110 |
+
--stations STATIONS seismic station info
|
111 |
+
--plot_figure If plot figure for test
|
112 |
+
--save_prob If save result for test
|
113 |
+
```
|
114 |
+
|
115 |
+
- The output picks are saved to "results/picks.csv" on default
|
116 |
+
|
117 |
+
|file_name |begin_time |station_id|phase_index|phase_time |phase_score|phase_amp |phase_type|
|
118 |
+
|-----------------|-----------------------|----------|-----------|-----------------------|-----------|----------------------|----------|
|
119 |
+
|2020-10-01T00:00*|2020-10-01T00:00:00.003|CI.BOM..HH|14734 |2020-10-01T00:02:27.343|0.708 |2.4998866231208325e-14|P |
|
120 |
+
|2020-10-01T00:00*|2020-10-01T00:00:00.003|CI.BOM..HH|15487 |2020-10-01T00:02:34.873|0.416 |2.4998866231208325e-14|S |
|
121 |
+
|2020-10-01T00:00*|2020-10-01T00:00:00.003|CI.COA..HH|319 |2020-10-01T00:00:03.193|0.762 |3.708662269972206e-14 |P |
|
122 |
+
|
123 |
+
Notes:
|
124 |
+
1. The *phase_index* means which data point is the pick in the original sequence. So *phase_time* = *begin_time* + *phase_index* / *sampling rate*. The default *sampling_rate* is 100Hz
|
125 |
+
|
126 |
+
|
127 |
+
## 5. QuakeFlow example
|
128 |
+
A complete earthquake detection workflow can be found in the [QuakeFlow](https://wayneweiqiang.github.io/QuakeFlow/) project.
|
129 |
+
|
130 |
+
## 6. Interactive example
|
131 |
+
See details in the [notebook](https://github.com/wayneweiqiang/PhaseNet/blob/master/docs/example_gradio.ipynb): [example_interactive.ipynb](example_gradio.ipynb)
|
132 |
+
|
133 |
+
## 7. Training
|
134 |
+
- Download a small sample dataset:
|
135 |
+
```bash
|
136 |
+
wget https://github.com/wayneweiqiang/PhaseNet/releases/download/test_data/test_data.zip
|
137 |
+
unzip test_data.zip
|
138 |
+
```
|
139 |
+
- Start training from the pre-trained model
|
140 |
+
```
|
141 |
+
python phasenet/train.py --model_dir=model/190703-214543/ --train_dir=test_data/npz --train_list=test_data/npz.csv --plot_figure --epochs=10 --batch_size=10
|
142 |
+
```
|
143 |
+
- Check results in the **log** folder
|
144 |
+
|
docs/data.mseed
ADDED
Binary file (73.7 kB). View file
|
|
docs/example_batch_prediction.ipynb
ADDED
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {},
|
6 |
+
"source": [
|
7 |
+
"# Batch Prediction\n",
|
8 |
+
"\n",
|
9 |
+
"## 1. Download demo data\n",
|
10 |
+
"\n",
|
11 |
+
"```\n",
|
12 |
+
"cd PhaseNet\n",
|
13 |
+
"wget https://github.com/wayneweiqiang/PhaseNet/releases/download/test_data/test_data.zip\n",
|
14 |
+
"unzip test_data.zip\n",
|
15 |
+
"```\n",
|
16 |
+
"\n",
|
17 |
+
"## 2. Run batch prediction \n",
|
18 |
+
"\n",
|
19 |
+
"PhaseNet currently supports four data formats: mseed, sac, hdf5, and numpy. \n",
|
20 |
+
"\n",
|
21 |
+
"- For mseed format:\n",
|
22 |
+
"```\n",
|
23 |
+
"python phasenet/predict.py --model=model/190703-214543 --data_list=test_data/mseed.csv --data_dir=test_data/mseed --format=mseed --plot_figure\n",
|
24 |
+
"```\n",
|
25 |
+
"\n",
|
26 |
+
"- For sac format:\n",
|
27 |
+
"```\n",
|
28 |
+
"python phasenet/predict.py --model=model/190703-214543 --data_list=test_data/sac.csv --data_dir=test_data/sac --format=sac --plot_figure\n",
|
29 |
+
"```\n",
|
30 |
+
"\n",
|
31 |
+
"- For numpy format:\n",
|
32 |
+
"```\n",
|
33 |
+
"python phasenet/predict.py --model=model/190703-214543 --data_list=test_data/npz.csv --data_dir=test_data/npz --format=numpy --plot_figure\n",
|
34 |
+
"```\n",
|
35 |
+
"\n",
|
36 |
+
"- For hdf5 format:\n",
|
37 |
+
"```\n",
|
38 |
+
"python phasenet/predict.py --model=model/190703-214543 --hdf5_file=test_data/data.h5 --hdf5_group=data --format=hdf5 --plot_figure\n",
|
39 |
+
"```\n",
|
40 |
+
"\n",
|
41 |
+
"- For a seismic array (used by [QuakeFlow](https://github.com/wayneweiqiang/QuakeFlow)):\n",
|
42 |
+
"```\n",
|
43 |
+
"python phasenet/predict.py --model=model/190703-214543 --data_list=test_data/mseed_array.csv --data_dir=test_data/mseed_array --stations=test_data/stations.json --format=mseed_array --amplitude\n",
|
44 |
+
"```\n",
|
45 |
+
"```\n",
|
46 |
+
"python phasenet/predict.py --model=model/190703-214543 --data_list=test_data/mseed2.csv --data_dir=test_data/mseed --stations=test_data/stations.json --format=mseed_array --amplitude\n",
|
47 |
+
"```\n",
|
48 |
+
"\n",
|
49 |
+
"Notes: \n",
|
50 |
+
"1. Remove the \"--plot_figure\" argument for large datasets, because plotting can be very slow.\n",
|
51 |
+
"\n",
|
52 |
+
"Optional arguments:\n",
|
53 |
+
"```\n",
|
54 |
+
"usage: predict.py [-h] [--batch_size BATCH_SIZE] [--model_dir MODEL_DIR]\n",
|
55 |
+
" [--data_dir DATA_DIR] [--data_list DATA_LIST]\n",
|
56 |
+
" [--hdf5_file HDF5_FILE] [--hdf5_group HDF5_GROUP]\n",
|
57 |
+
" [--result_dir RESULT_DIR] [--result_fname RESULT_FNAME]\n",
|
58 |
+
" [--min_p_prob MIN_P_PROB] [--min_s_prob MIN_S_PROB]\n",
|
59 |
+
" [--mpd MPD] [--amplitude] [--format FORMAT]\n",
|
60 |
+
" [--s3_url S3_URL] [--stations STATIONS] [--plot_figure]\n",
|
61 |
+
" [--save_prob]\n",
|
62 |
+
"\n",
|
63 |
+
"optional arguments:\n",
|
64 |
+
" -h, --help show this help message and exit\n",
|
65 |
+
" --batch_size BATCH_SIZE\n",
|
66 |
+
" batch size\n",
|
67 |
+
" --model_dir MODEL_DIR\n",
|
68 |
+
" Checkpoint directory (default: None)\n",
|
69 |
+
" --data_dir DATA_DIR Input file directory\n",
|
70 |
+
" --data_list DATA_LIST\n",
|
71 |
+
" Input csv file\n",
|
72 |
+
" --hdf5_file HDF5_FILE\n",
|
73 |
+
" Input hdf5 file\n",
|
74 |
+
" --hdf5_group HDF5_GROUP\n",
|
75 |
+
" data group name in hdf5 file\n",
|
76 |
+
" --result_dir RESULT_DIR\n",
|
77 |
+
" Output directory\n",
|
78 |
+
" --result_fname RESULT_FNAME\n",
|
79 |
+
" Output file\n",
|
80 |
+
" --min_p_prob MIN_P_PROB\n",
|
81 |
+
" Probability threshold for P pick\n",
|
82 |
+
" --min_s_prob MIN_S_PROB\n",
|
83 |
+
" Probability threshold for S pick\n",
|
84 |
+
" --mpd MPD Minimum peak distance\n",
|
85 |
+
" --amplitude if return amplitude value\n",
|
86 |
+
" --format FORMAT input format\n",
|
87 |
+
" --stations STATIONS seismic station info\n",
|
88 |
+
" --plot_figure If plot figure for test\n",
|
89 |
+
" --save_prob If save result for test\n",
|
90 |
+
"```\n",
|
91 |
+
"\n",
|
92 |
+
"## 3. Output picks\n",
|
93 |
+
"- The output picks are saved to \"results/picks.csv\" on default\n",
|
94 |
+
"\n",
|
95 |
+
"|file_name |begin_time |station_id|phase_index|phase_time |phase_score|phase_amp |phase_type|\n",
|
96 |
+
"|-----------------|-----------------------|----------|-----------|-----------------------|-----------|----------------------|----------|\n",
|
97 |
+
"|2020-10-01T00:00*|2020-10-01T00:00:00.003|CI.BOM..HH|14734 |2020-10-01T00:02:27.343|0.708 |2.4998866231208325e-14|P |\n",
|
98 |
+
"|2020-10-01T00:00*|2020-10-01T00:00:00.003|CI.BOM..HH|15487 |2020-10-01T00:02:34.873|0.416 |2.4998866231208325e-14|S |\n",
|
99 |
+
"|2020-10-01T00:00*|2020-10-01T00:00:00.003|CI.COA..HH|319 |2020-10-01T00:00:03.193|0.762 |3.708662269972206e-14 |P |\n",
|
100 |
+
"\n",
|
101 |
+
"Notes:\n",
|
102 |
+
"1. The *phase_index* means which data point is the pick in the original sequence. So *phase_time* = *begin_time* + *phase_index* / *sampling rate*. The default *sampling_rate* is 100Hz \n"
|
103 |
+
]
|
104 |
+
},
|
105 |
+
{
|
106 |
+
"cell_type": "markdown",
|
107 |
+
"metadata": {},
|
108 |
+
"source": [
|
109 |
+
"## 3. Read P/S picks\n",
|
110 |
+
"\n",
|
111 |
+
"PhaseNet currently outputs two format: **CSV** and **JSON**"
|
112 |
+
]
|
113 |
+
},
|
114 |
+
{
|
115 |
+
"cell_type": "code",
|
116 |
+
"execution_count": 1,
|
117 |
+
"metadata": {},
|
118 |
+
"outputs": [],
|
119 |
+
"source": [
|
120 |
+
"import pandas as pd\n",
|
121 |
+
"import json\n",
|
122 |
+
"import os\n",
|
123 |
+
"PROJECT_ROOT = os.path.realpath(os.path.join(os.path.abspath(''), \"..\"))"
|
124 |
+
]
|
125 |
+
},
|
126 |
+
{
|
127 |
+
"cell_type": "code",
|
128 |
+
"execution_count": 2,
|
129 |
+
"metadata": {},
|
130 |
+
"outputs": [
|
131 |
+
{
|
132 |
+
"name": "stdout",
|
133 |
+
"output_type": "stream",
|
134 |
+
"text": [
|
135 |
+
"fname NC.MCV..EH.0361339.npz\n",
|
136 |
+
"t0 1970-01-01T00:00:00.000\n",
|
137 |
+
"p_idx [5999, 9015]\n",
|
138 |
+
"p_prob [0.987, 0.981]\n",
|
139 |
+
"s_idx [6181, 9205]\n",
|
140 |
+
"s_prob [0.553, 0.873]\n",
|
141 |
+
"Name: 1, dtype: object\n",
|
142 |
+
"fname NN.LHV..EH.0384064.npz\n",
|
143 |
+
"t0 1970-01-01T00:00:00.000\n",
|
144 |
+
"p_idx []\n",
|
145 |
+
"p_prob []\n",
|
146 |
+
"s_idx []\n",
|
147 |
+
"s_prob []\n",
|
148 |
+
"Name: 0, dtype: object\n"
|
149 |
+
]
|
150 |
+
}
|
151 |
+
],
|
152 |
+
"source": [
|
153 |
+
"picks_csv = pd.read_csv(os.path.join(PROJECT_ROOT, \"results/picks.csv\"), sep=\"\\t\")\n",
|
154 |
+
"picks_csv.loc[:, 'p_idx'] = picks_csv[\"p_idx\"].apply(lambda x: x.strip(\"[]\").split(\",\"))\n",
|
155 |
+
"picks_csv.loc[:, 'p_prob'] = picks_csv[\"p_prob\"].apply(lambda x: x.strip(\"[]\").split(\",\"))\n",
|
156 |
+
"picks_csv.loc[:, 's_idx'] = picks_csv[\"s_idx\"].apply(lambda x: x.strip(\"[]\").split(\",\"))\n",
|
157 |
+
"picks_csv.loc[:, 's_prob'] = picks_csv[\"s_prob\"].apply(lambda x: x.strip(\"[]\").split(\",\"))\n",
|
158 |
+
"print(picks_csv.iloc[1])\n",
|
159 |
+
"print(picks_csv.iloc[0])"
|
160 |
+
]
|
161 |
+
},
|
162 |
+
{
|
163 |
+
"cell_type": "code",
|
164 |
+
"execution_count": 3,
|
165 |
+
"metadata": {},
|
166 |
+
"outputs": [
|
167 |
+
{
|
168 |
+
"name": "stdout",
|
169 |
+
"output_type": "stream",
|
170 |
+
"text": [
|
171 |
+
"{'id': 'NC.MCV..EH.0361339.npz', 'timestamp': '1970-01-01T00:01:30.150', 'prob': 0.9811667799949646, 'type': 'p'}\n",
|
172 |
+
"{'id': 'NC.MCV..EH.0361339.npz', 'timestamp': '1970-01-01T00:00:59.990', 'prob': 0.9872905611991882, 'type': 'p'}\n"
|
173 |
+
]
|
174 |
+
}
|
175 |
+
],
|
176 |
+
"source": [
|
177 |
+
"with open(os.path.join(PROJECT_ROOT, \"results/picks.json\")) as fp:\n",
|
178 |
+
" picks_json = json.load(fp) \n",
|
179 |
+
"print(picks_json[1])\n",
|
180 |
+
"print(picks_json[0])"
|
181 |
+
]
|
182 |
+
}
|
183 |
+
],
|
184 |
+
"metadata": {
|
185 |
+
"kernelspec": {
|
186 |
+
"display_name": "Python 3.10.4 64-bit",
|
187 |
+
"language": "python",
|
188 |
+
"name": "python3"
|
189 |
+
},
|
190 |
+
"language_info": {
|
191 |
+
"codemirror_mode": {
|
192 |
+
"name": "ipython",
|
193 |
+
"version": 3
|
194 |
+
},
|
195 |
+
"file_extension": ".py",
|
196 |
+
"mimetype": "text/x-python",
|
197 |
+
"name": "python",
|
198 |
+
"nbconvert_exporter": "python",
|
199 |
+
"pygments_lexer": "ipython3",
|
200 |
+
"version": "3.10.4"
|
201 |
+
},
|
202 |
+
"orig_nbformat": 4,
|
203 |
+
"vscode": {
|
204 |
+
"interpreter": {
|
205 |
+
"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
|
206 |
+
}
|
207 |
+
}
|
208 |
+
},
|
209 |
+
"nbformat": 4,
|
210 |
+
"nbformat_minor": 2
|
211 |
+
}
|
docs/example_fastapi.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
docs/example_gradio.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
docs/test_api.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# %%
|
2 |
+
from gradio_client import Client
|
3 |
+
import obspy
|
4 |
+
import numpy as np
|
5 |
+
import json
|
6 |
+
import pandas as pd
|
7 |
+
|
8 |
+
# %%
|
9 |
+
|
10 |
+
waveform = obspy.read()
|
11 |
+
array = np.array([x.data for x in waveform]).T
|
12 |
+
|
13 |
+
# pipeline = PreTrainedPipeline()
|
14 |
+
inputs = array.tolist()
|
15 |
+
inputs = json.dumps(inputs)
|
16 |
+
# picks = pipeline(inputs)
|
17 |
+
# print(picks)
|
18 |
+
|
19 |
+
# %%
|
20 |
+
client = Client("ai4eps/phasenet")
|
21 |
+
output, file = client.predict(["test_test.mseed"])
|
22 |
+
# %%
|
23 |
+
with open(output, "r") as f:
|
24 |
+
picks = json.load(f)["data"]
|
25 |
+
|
26 |
+
# %%
|
27 |
+
picks = pd.read_csv(file)
|
28 |
+
|
29 |
+
|
30 |
+
# %%
|
31 |
+
job = client.submit(["test_test.mseed", "test_test.mseed"], api_name="/predict") # This is not blocking
|
32 |
+
|
33 |
+
print(job.status())
|
34 |
+
|
35 |
+
# %%
|
36 |
+
output, file = job.result()
|
37 |
+
|
env.yml
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: phasenet
|
2 |
+
channels:
|
3 |
+
- defaults
|
4 |
+
- conda-forge
|
5 |
+
dependencies:
|
6 |
+
- python
|
7 |
+
- numpy
|
8 |
+
- scipy
|
9 |
+
- matplotlib
|
10 |
+
- pandas
|
11 |
+
- scikit-learn
|
12 |
+
- tqdm
|
13 |
+
- obspy
|
14 |
+
- uvicorn
|
15 |
+
- fastapi
|
16 |
+
- tensorflow
|
17 |
+
- keras
|
mkdocs.yml
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
site_name: "PhaseNet"
|
2 |
+
site_description: 'PhaseNet: a deep-neural-network-based seismic arrival-time picking method'
|
3 |
+
site_author: 'Weiqiang Zhu'
|
4 |
+
docs_dir: docs/
|
5 |
+
repo_name: 'AI4EPS/PhaseNet'
|
6 |
+
repo_url: 'https://github.com/ai4eps/PhaseNet'
|
7 |
+
nav:
|
8 |
+
- Overview: README.md
|
9 |
+
- Interactive Example: example_gradio.ipynb
|
10 |
+
- Batch Prediction: example_batch_prediction.ipynb
|
11 |
+
theme:
|
12 |
+
name: 'material'
|
13 |
+
plugins:
|
14 |
+
- mkdocs-jupyter
|
15 |
+
extra:
|
16 |
+
analytics:
|
17 |
+
provider: google
|
18 |
+
property: G-RZQ9LRPL0S
|
model/190703-214543/checkpoint
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1606ccb25e1533fa0398c5dbce7f3a45ac77f90b78b99f81a044294ba38a2c0c
|
3 |
+
size 83
|
model/190703-214543/config.log
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ed9dfa705053a5025facc9952c7da6abef19ec5f672d9e50386bf3f2d80294f2
|
3 |
+
size 345
|
model/190703-214543/loss.log
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ccb6f19117497571e19bec5da6012ac7af91f1bd29e931ffd0b23c6b657bb401
|
3 |
+
size 8101
|
model/190703-214543/model_95.ckpt.data-00000-of-00001
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9ee2c15dd78fb15de45a55ad64a446f1a0ced152ba4ac5c506d82b9194da85b4
|
3 |
+
size 3226256
|
model/190703-214543/model_95.ckpt.index
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f96b553b76be4ebae9a455eaf8d83cfa8c0e110f06cfba958de2568e5b6b2780
|
3 |
+
size 7223
|
model/190703-214543/model_95.ckpt.meta
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7ebd154a5ba0721ba8bbb627ba61b556ee60660eb34bbcd1b1f50396b07cc4ed
|
3 |
+
size 2172055
|
phasenet/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
__version__ = "0.1.0"
|
phasenet/app.py
ADDED
@@ -0,0 +1,341 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from collections import defaultdict, namedtuple
|
3 |
+
from datetime import datetime, timedelta
|
4 |
+
from json import dumps
|
5 |
+
from typing import Any, AnyStr, Dict, List, NamedTuple, Union, Optional
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import requests
|
9 |
+
import tensorflow as tf
|
10 |
+
from fastapi import FastAPI, WebSocket
|
11 |
+
from kafka import KafkaProducer
|
12 |
+
from pydantic import BaseModel
|
13 |
+
from scipy.interpolate import interp1d
|
14 |
+
|
15 |
+
from model import ModelConfig, UNet
|
16 |
+
from postprocess import extract_picks
|
17 |
+
|
18 |
+
tf.compat.v1.disable_eager_execution()
|
19 |
+
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
|
20 |
+
PROJECT_ROOT = os.path.realpath(os.path.join(os.path.dirname(__file__), ".."))
|
21 |
+
JSONObject = Dict[AnyStr, Any]
|
22 |
+
JSONArray = List[Any]
|
23 |
+
JSONStructure = Union[JSONArray, JSONObject]
|
24 |
+
|
25 |
+
app = FastAPI()
|
26 |
+
X_SHAPE = [3000, 1, 3]
|
27 |
+
SAMPLING_RATE = 100
|
28 |
+
|
29 |
+
# load model
|
30 |
+
model = UNet(mode="pred")
|
31 |
+
sess_config = tf.compat.v1.ConfigProto()
|
32 |
+
sess_config.gpu_options.allow_growth = True
|
33 |
+
|
34 |
+
sess = tf.compat.v1.Session(config=sess_config)
|
35 |
+
saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
|
36 |
+
init = tf.compat.v1.global_variables_initializer()
|
37 |
+
sess.run(init)
|
38 |
+
latest_check_point = tf.train.latest_checkpoint(f"{PROJECT_ROOT}/model/190703-214543")
|
39 |
+
print(f"restoring model {latest_check_point}")
|
40 |
+
saver.restore(sess, latest_check_point)
|
41 |
+
|
42 |
+
# GAMMA API Endpoint
|
43 |
+
GAMMA_API_URL = "http://gamma-api:8001"
|
44 |
+
# GAMMA_API_URL = 'http://localhost:8001'
|
45 |
+
# GAMMA_API_URL = "http://gamma.quakeflow.com"
|
46 |
+
# GAMMA_API_URL = "http://127.0.0.1:8001"
|
47 |
+
|
48 |
+
# Kafak producer
|
49 |
+
use_kafka = False
|
50 |
+
|
51 |
+
try:
|
52 |
+
print("Connecting to k8s kafka")
|
53 |
+
BROKER_URL = "quakeflow-kafka-headless:9092"
|
54 |
+
# BROKER_URL = "34.83.137.139:9094"
|
55 |
+
producer = KafkaProducer(
|
56 |
+
bootstrap_servers=[BROKER_URL],
|
57 |
+
key_serializer=lambda x: dumps(x).encode("utf-8"),
|
58 |
+
value_serializer=lambda x: dumps(x).encode("utf-8"),
|
59 |
+
)
|
60 |
+
use_kafka = True
|
61 |
+
print("k8s kafka connection success!")
|
62 |
+
except BaseException:
|
63 |
+
print("k8s Kafka connection error")
|
64 |
+
try:
|
65 |
+
print("Connecting to local kafka")
|
66 |
+
producer = KafkaProducer(
|
67 |
+
bootstrap_servers=["localhost:9092"],
|
68 |
+
key_serializer=lambda x: dumps(x).encode("utf-8"),
|
69 |
+
value_serializer=lambda x: dumps(x).encode("utf-8"),
|
70 |
+
)
|
71 |
+
use_kafka = True
|
72 |
+
print("local kafka connection success!")
|
73 |
+
except BaseException:
|
74 |
+
print("local Kafka connection error")
|
75 |
+
print(f"Kafka status: {use_kafka}")
|
76 |
+
|
77 |
+
|
78 |
+
def normalize_batch(data, window=3000):
|
79 |
+
"""
|
80 |
+
data: nsta, nt, nch
|
81 |
+
"""
|
82 |
+
shift = window // 2
|
83 |
+
nsta, nt, nch = data.shape
|
84 |
+
|
85 |
+
# std in slide windows
|
86 |
+
data_pad = np.pad(data, ((0, 0), (window // 2, window // 2), (0, 0)), mode="reflect")
|
87 |
+
t = np.arange(0, nt, shift, dtype="int")
|
88 |
+
std = np.zeros([nsta, len(t) + 1, nch])
|
89 |
+
mean = np.zeros([nsta, len(t) + 1, nch])
|
90 |
+
for i in range(1, len(t)):
|
91 |
+
std[:, i, :] = np.std(data_pad[:, i * shift : i * shift + window, :], axis=1)
|
92 |
+
mean[:, i, :] = np.mean(data_pad[:, i * shift : i * shift + window, :], axis=1)
|
93 |
+
|
94 |
+
t = np.append(t, nt)
|
95 |
+
# std[:, -1, :] = np.std(data_pad[:, -window:, :], axis=1)
|
96 |
+
# mean[:, -1, :] = np.mean(data_pad[:, -window:, :], axis=1)
|
97 |
+
std[:, -1, :], mean[:, -1, :] = std[:, -2, :], mean[:, -2, :]
|
98 |
+
std[:, 0, :], mean[:, 0, :] = std[:, 1, :], mean[:, 1, :]
|
99 |
+
std[std == 0] = 1
|
100 |
+
|
101 |
+
# ## normalize data with interplated std
|
102 |
+
t_interp = np.arange(nt, dtype="int")
|
103 |
+
std_interp = interp1d(t, std, axis=1, kind="slinear")(t_interp)
|
104 |
+
mean_interp = interp1d(t, mean, axis=1, kind="slinear")(t_interp)
|
105 |
+
data = (data - mean_interp) / std_interp
|
106 |
+
|
107 |
+
return data
|
108 |
+
|
109 |
+
|
110 |
+
def preprocess(data):
|
111 |
+
raw = data.copy()
|
112 |
+
data = normalize_batch(data)
|
113 |
+
if len(data.shape) == 3:
|
114 |
+
data = data[:, :, np.newaxis, :]
|
115 |
+
raw = raw[:, :, np.newaxis, :]
|
116 |
+
return data, raw
|
117 |
+
|
118 |
+
|
119 |
+
def calc_timestamp(timestamp, sec):
|
120 |
+
timestamp = datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S.%f") + timedelta(seconds=sec)
|
121 |
+
return timestamp.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]
|
122 |
+
|
123 |
+
|
124 |
+
def format_picks(picks, dt, amplitudes):
|
125 |
+
picks_ = []
|
126 |
+
for pick, amplitude in zip(picks, amplitudes):
|
127 |
+
for idxs, probs, amps in zip(pick.p_idx, pick.p_prob, amplitude.p_amp):
|
128 |
+
for idx, prob, amp in zip(idxs, probs, amps):
|
129 |
+
picks_.append(
|
130 |
+
{
|
131 |
+
"id": pick.fname,
|
132 |
+
"timestamp": calc_timestamp(pick.t0, float(idx) * dt),
|
133 |
+
"prob": prob,
|
134 |
+
"amp": amp,
|
135 |
+
"type": "p",
|
136 |
+
}
|
137 |
+
)
|
138 |
+
for idxs, probs, amps in zip(pick.s_idx, pick.s_prob, amplitude.s_amp):
|
139 |
+
for idx, prob, amp in zip(idxs, probs, amps):
|
140 |
+
picks_.append(
|
141 |
+
{
|
142 |
+
"id": pick.fname,
|
143 |
+
"timestamp": calc_timestamp(pick.t0, float(idx) * dt),
|
144 |
+
"prob": prob,
|
145 |
+
"amp": amp,
|
146 |
+
"type": "s",
|
147 |
+
}
|
148 |
+
)
|
149 |
+
return picks_
|
150 |
+
|
151 |
+
|
152 |
+
def format_data(data):
|
153 |
+
# chn2idx = {"ENZ": {"E":0, "N":1, "Z":2},
|
154 |
+
# "123": {"3":0, "2":1, "1":2},
|
155 |
+
# "12Z": {"1":0, "2":1, "Z":2}}
|
156 |
+
chn2idx = {"E": 0, "N": 1, "Z": 2, "3": 0, "2": 1, "1": 2}
|
157 |
+
Data = NamedTuple("data", [("id", list), ("timestamp", list), ("vec", list), ("dt", float)])
|
158 |
+
|
159 |
+
# Group by station
|
160 |
+
chn_ = defaultdict(list)
|
161 |
+
t0_ = defaultdict(list)
|
162 |
+
vv_ = defaultdict(list)
|
163 |
+
for i in range(len(data.id)):
|
164 |
+
key = data.id[i][:-1]
|
165 |
+
chn_[key].append(data.id[i][-1])
|
166 |
+
t0_[key].append(datetime.strptime(data.timestamp[i], "%Y-%m-%dT%H:%M:%S.%f").timestamp() * SAMPLING_RATE)
|
167 |
+
vv_[key].append(np.array(data.vec[i]))
|
168 |
+
|
169 |
+
# Merge to Data tuple
|
170 |
+
id_ = []
|
171 |
+
timestamp_ = []
|
172 |
+
vec_ = []
|
173 |
+
for k in chn_:
|
174 |
+
id_.append(k)
|
175 |
+
min_t0 = min(t0_[k])
|
176 |
+
timestamp_.append(datetime.fromtimestamp(min_t0 / SAMPLING_RATE).strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3])
|
177 |
+
vec = np.zeros([X_SHAPE[0], X_SHAPE[-1]])
|
178 |
+
for i in range(len(chn_[k])):
|
179 |
+
# vec[int(t0_[k][i]-min_t0):len(vv_[k][i]), chn2idx[chn_[k][i]]] = vv_[k][i][int(t0_[k][i]-min_t0):X_SHAPE[0]] - np.mean(vv_[k][i])
|
180 |
+
shift = int(t0_[k][i] - min_t0)
|
181 |
+
vec[shift : len(vv_[k][i]) + shift, chn2idx[chn_[k][i]]] = vv_[k][i][: X_SHAPE[0] - shift] - np.mean(
|
182 |
+
vv_[k][i][: X_SHAPE[0] - shift]
|
183 |
+
)
|
184 |
+
vec_.append(vec.tolist())
|
185 |
+
|
186 |
+
return Data(id=id_, timestamp=timestamp_, vec=vec_, dt=1 / SAMPLING_RATE)
|
187 |
+
# return {"id": id_, "timestamp": timestamp_, "vec": vec_, "dt":1 / SAMPLING_RATE}
|
188 |
+
|
189 |
+
|
190 |
+
def get_prediction(data, return_preds=False):
|
191 |
+
vec = np.array(data.vec)
|
192 |
+
vec, vec_raw = preprocess(vec)
|
193 |
+
|
194 |
+
feed = {model.X: vec, model.drop_rate: 0, model.is_training: False}
|
195 |
+
preds = sess.run(model.preds, feed_dict=feed)
|
196 |
+
|
197 |
+
picks = extract_picks(preds, station_ids=data.id, begin_times=data.timestamp, waveforms=vec_raw)
|
198 |
+
|
199 |
+
picks = [
|
200 |
+
{k: v for k, v in pick.items() if k in ["station_id", "phase_time", "phase_score", "phase_type", "dt"]}
|
201 |
+
for pick in picks
|
202 |
+
]
|
203 |
+
|
204 |
+
if return_preds:
|
205 |
+
return picks, preds
|
206 |
+
|
207 |
+
return picks
|
208 |
+
|
209 |
+
|
210 |
+
class Data(BaseModel):
|
211 |
+
# id: Union[List[str], str]
|
212 |
+
# timestamp: Union[List[str], str]
|
213 |
+
# vec: Union[List[List[List[float]]], List[List[float]]]
|
214 |
+
id: List[str]
|
215 |
+
timestamp: List[Union[str, float, datetime]]
|
216 |
+
vec: Union[List[List[List[float]]], List[List[float]]]
|
217 |
+
|
218 |
+
dt: Optional[float] = 0.01
|
219 |
+
## gamma
|
220 |
+
stations: Optional[List[Dict[str, Union[float, str]]]] = None
|
221 |
+
config: Optional[Dict[str, Union[List[float], List[int], List[str], float, int, str]]] = None
|
222 |
+
|
223 |
+
|
224 |
+
# @app.on_event("startup")
|
225 |
+
# def set_default_executor():
|
226 |
+
# from concurrent.futures import ThreadPoolExecutor
|
227 |
+
# import asyncio
|
228 |
+
#
|
229 |
+
# loop = asyncio.get_running_loop()
|
230 |
+
# loop.set_default_executor(
|
231 |
+
# ThreadPoolExecutor(max_workers=2)
|
232 |
+
# )
|
233 |
+
|
234 |
+
|
235 |
+
@app.post("/predict")
|
236 |
+
def predict(data: Data):
|
237 |
+
picks = get_prediction(data)
|
238 |
+
|
239 |
+
return picks
|
240 |
+
|
241 |
+
|
242 |
+
@app.websocket("/ws")
|
243 |
+
async def websocket_endpoint(websocket: WebSocket):
|
244 |
+
await websocket.accept()
|
245 |
+
while True:
|
246 |
+
data = await websocket.receive_json()
|
247 |
+
# data = json.loads(data)
|
248 |
+
data = Data(**data)
|
249 |
+
picks = get_prediction(data)
|
250 |
+
await websocket.send_json(picks)
|
251 |
+
print("PhaseNet Updating...")
|
252 |
+
|
253 |
+
|
254 |
+
@app.post("/predict_prob")
|
255 |
+
def predict(data: Data):
|
256 |
+
picks, preds = get_prediction(data, True)
|
257 |
+
|
258 |
+
return picks, preds.tolist()
|
259 |
+
|
260 |
+
|
261 |
+
@app.post("/predict_phasenet2gamma")
|
262 |
+
def predict(data: Data):
|
263 |
+
picks = get_prediction(data)
|
264 |
+
|
265 |
+
# if use_kafka:
|
266 |
+
# print("Push picks to kafka...")
|
267 |
+
# for pick in picks:
|
268 |
+
# producer.send("phasenet_picks", key=pick["id"], value=pick)
|
269 |
+
try:
|
270 |
+
catalog = requests.post(
|
271 |
+
f"{GAMMA_API_URL}/predict", json={"picks": picks, "stations": data.stations, "config": data.config}
|
272 |
+
)
|
273 |
+
print(catalog.json()["catalog"])
|
274 |
+
return catalog.json()
|
275 |
+
except Exception as error:
|
276 |
+
print(error)
|
277 |
+
|
278 |
+
return {}
|
279 |
+
|
280 |
+
|
281 |
+
@app.post("/predict_phasenet2gamma2ui")
|
282 |
+
def predict(data: Data):
|
283 |
+
picks = get_prediction(data)
|
284 |
+
|
285 |
+
try:
|
286 |
+
catalog = requests.post(
|
287 |
+
f"{GAMMA_API_URL}/predict", json={"picks": picks, "stations": data.stations, "config": data.config}
|
288 |
+
)
|
289 |
+
print(catalog.json()["catalog"])
|
290 |
+
return catalog.json()
|
291 |
+
except Exception as error:
|
292 |
+
print(error)
|
293 |
+
|
294 |
+
if use_kafka:
|
295 |
+
print("Push picks to kafka...")
|
296 |
+
for pick in picks:
|
297 |
+
producer.send("phasenet_picks", key=pick["id"], value=pick)
|
298 |
+
print("Push waveform to kafka...")
|
299 |
+
for id, timestamp, vec in zip(data.id, data.timestamp, data.vec):
|
300 |
+
producer.send("waveform_phasenet", key=id, value={"timestamp": timestamp, "vec": vec, "dt": data.dt})
|
301 |
+
|
302 |
+
return {}
|
303 |
+
|
304 |
+
|
305 |
+
@app.post("/predict_stream_phasenet2gamma")
|
306 |
+
def predict(data: Data):
|
307 |
+
data = format_data(data)
|
308 |
+
# for i in range(len(data.id)):
|
309 |
+
# plt.clf()
|
310 |
+
# plt.subplot(311)
|
311 |
+
# plt.plot(np.array(data.vec)[i, :, 0])
|
312 |
+
# plt.subplot(312)
|
313 |
+
# plt.plot(np.array(data.vec)[i, :, 1])
|
314 |
+
# plt.subplot(313)
|
315 |
+
# plt.plot(np.array(data.vec)[i, :, 2])
|
316 |
+
# plt.savefig(f"{data.id[i]}.png")
|
317 |
+
|
318 |
+
picks = get_prediction(data)
|
319 |
+
|
320 |
+
return_value = {}
|
321 |
+
try:
|
322 |
+
catalog = requests.post(f"{GAMMA_API_URL}/predict_stream", json={"picks": picks})
|
323 |
+
print("GMMA:", catalog.json()["catalog"])
|
324 |
+
return_value = catalog.json()
|
325 |
+
except Exception as error:
|
326 |
+
print(error)
|
327 |
+
|
328 |
+
if use_kafka:
|
329 |
+
print("Push picks to kafka...")
|
330 |
+
for pick in picks:
|
331 |
+
producer.send("phasenet_picks", key=pick["id"], value=pick)
|
332 |
+
print("Push waveform to kafka...")
|
333 |
+
for id, timestamp, vec in zip(data.id, data.timestamp, data.vec):
|
334 |
+
producer.send("waveform_phasenet", key=id, value={"timestamp": timestamp, "vec": vec, "dt": data.dt})
|
335 |
+
|
336 |
+
return return_value
|
337 |
+
|
338 |
+
|
339 |
+
@app.get("/healthz")
|
340 |
+
def healthz():
|
341 |
+
return {"status": "ok"}
|
phasenet/data_reader.py
ADDED
@@ -0,0 +1,1010 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
|
3 |
+
tf.compat.v1.disable_eager_execution()
|
4 |
+
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
|
5 |
+
import logging
|
6 |
+
import os
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import pandas as pd
|
10 |
+
|
11 |
+
pd.options.mode.chained_assignment = None
|
12 |
+
import json
|
13 |
+
import random
|
14 |
+
from collections import defaultdict
|
15 |
+
|
16 |
+
# import s3fs
|
17 |
+
import h5py
|
18 |
+
import obspy
|
19 |
+
from scipy.interpolate import interp1d
|
20 |
+
from tqdm import tqdm
|
21 |
+
|
22 |
+
|
23 |
+
def py_func_decorator(output_types=None, output_shapes=None, name=None):
|
24 |
+
def decorator(func):
|
25 |
+
def call(*args, **kwargs):
|
26 |
+
nonlocal output_shapes
|
27 |
+
# flat_output_types = nest.flatten(output_types)
|
28 |
+
flat_output_types = tf.nest.flatten(output_types)
|
29 |
+
# flat_values = tf.py_func(
|
30 |
+
flat_values = tf.numpy_function(func, inp=args, Tout=flat_output_types, name=name)
|
31 |
+
if output_shapes is not None:
|
32 |
+
for v, s in zip(flat_values, output_shapes):
|
33 |
+
v.set_shape(s)
|
34 |
+
# return nest.pack_sequence_as(output_types, flat_values)
|
35 |
+
return tf.nest.pack_sequence_as(output_types, flat_values)
|
36 |
+
|
37 |
+
return call
|
38 |
+
|
39 |
+
return decorator
|
40 |
+
|
41 |
+
|
42 |
+
def dataset_map(iterator, output_types, output_shapes=None, num_parallel_calls=None, name=None, shuffle=False):
|
43 |
+
dataset = tf.data.Dataset.range(len(iterator))
|
44 |
+
if shuffle:
|
45 |
+
dataset = dataset.shuffle(len(iterator), reshuffle_each_iteration=True)
|
46 |
+
|
47 |
+
@py_func_decorator(output_types, output_shapes, name=name)
|
48 |
+
def index_to_entry(idx):
|
49 |
+
return iterator[idx]
|
50 |
+
|
51 |
+
return dataset.map(index_to_entry, num_parallel_calls=num_parallel_calls)
|
52 |
+
|
53 |
+
|
54 |
+
def normalize(data, axis=(0,)):
|
55 |
+
"""data shape: (nt, nsta, nch)"""
|
56 |
+
data -= np.mean(data, axis=axis, keepdims=True)
|
57 |
+
std_data = np.std(data, axis=axis, keepdims=True)
|
58 |
+
std_data[std_data == 0] = 1
|
59 |
+
data /= std_data
|
60 |
+
# data /= (std_data + 1e-12)
|
61 |
+
return data
|
62 |
+
|
63 |
+
|
64 |
+
def normalize_long(data, axis=(0,), window=3000):
|
65 |
+
"""
|
66 |
+
data: nt, nch
|
67 |
+
"""
|
68 |
+
nt, nar, nch = data.shape
|
69 |
+
if window is None:
|
70 |
+
window = nt
|
71 |
+
shift = window // 2
|
72 |
+
|
73 |
+
dtype = data.dtype
|
74 |
+
## std in slide windows
|
75 |
+
data_pad = np.pad(data, ((window // 2, window // 2), (0, 0), (0, 0)), mode="reflect")
|
76 |
+
t = np.arange(0, nt, shift, dtype="int")
|
77 |
+
std = np.zeros([len(t) + 1, nar, nch])
|
78 |
+
mean = np.zeros([len(t) + 1, nar, nch])
|
79 |
+
for i in range(1, len(std)):
|
80 |
+
std[i, :] = np.std(data_pad[i * shift : i * shift + window, :, :], axis=axis)
|
81 |
+
mean[i, :] = np.mean(data_pad[i * shift : i * shift + window, :, :], axis=axis)
|
82 |
+
|
83 |
+
t = np.append(t, nt)
|
84 |
+
# std[-1, :] = np.std(data_pad[-window:, :], axis=0)
|
85 |
+
# mean[-1, :] = np.mean(data_pad[-window:, :], axis=0)
|
86 |
+
std[-1, ...], mean[-1, ...] = std[-2, ...], mean[-2, ...]
|
87 |
+
std[0, ...], mean[0, ...] = std[1, ...], mean[1, ...]
|
88 |
+
# std[std == 0] = 1.0
|
89 |
+
|
90 |
+
## normalize data with interplated std
|
91 |
+
t_interp = np.arange(nt, dtype="int")
|
92 |
+
std_interp = interp1d(t, std, axis=0, kind="slinear")(t_interp)
|
93 |
+
# std_interp = np.exp(interp1d(t, np.log(std), axis=0, kind="slinear")(t_interp))
|
94 |
+
mean_interp = interp1d(t, mean, axis=0, kind="slinear")(t_interp)
|
95 |
+
tmp = np.sum(std_interp, axis=(0, 1))
|
96 |
+
std_interp[std_interp == 0] = 1.0
|
97 |
+
data = (data - mean_interp) / std_interp
|
98 |
+
# data = (data - mean_interp)/(std_interp + 1e-12)
|
99 |
+
|
100 |
+
### dropout effect of < 3 channel
|
101 |
+
nonzero = np.count_nonzero(tmp)
|
102 |
+
if (nonzero < 3) and (nonzero > 0):
|
103 |
+
data *= 3.0 / nonzero
|
104 |
+
|
105 |
+
return data.astype(dtype)
|
106 |
+
|
107 |
+
|
108 |
+
def normalize_batch(data, window=3000):
|
109 |
+
"""
|
110 |
+
data: nsta, nt, nch
|
111 |
+
"""
|
112 |
+
nsta, nt, nar, nch = data.shape
|
113 |
+
if window is None:
|
114 |
+
window = nt
|
115 |
+
shift = window // 2
|
116 |
+
|
117 |
+
## std in slide windows
|
118 |
+
data_pad = np.pad(data, ((0, 0), (window // 2, window // 2), (0, 0), (0, 0)), mode="reflect")
|
119 |
+
t = np.arange(0, nt, shift, dtype="int")
|
120 |
+
std = np.zeros([nsta, len(t) + 1, nar, nch])
|
121 |
+
mean = np.zeros([nsta, len(t) + 1, nar, nch])
|
122 |
+
for i in range(1, len(t)):
|
123 |
+
std[:, i, :, :] = np.std(data_pad[:, i * shift : i * shift + window, :, :], axis=1)
|
124 |
+
mean[:, i, :, :] = np.mean(data_pad[:, i * shift : i * shift + window, :, :], axis=1)
|
125 |
+
|
126 |
+
t = np.append(t, nt)
|
127 |
+
# std[:, -1, :] = np.std(data_pad[:, -window:, :], axis=1)
|
128 |
+
# mean[:, -1, :] = np.mean(data_pad[:, -window:, :], axis=1)
|
129 |
+
std[:, -1, :, :], mean[:, -1, :, :] = std[:, -2, :, :], mean[:, -2, :, :]
|
130 |
+
std[:, 0, :, :], mean[:, 0, :, :] = std[:, 1, :, :], mean[:, 1, :, :]
|
131 |
+
# std[std == 0] = 1
|
132 |
+
|
133 |
+
# ## normalize data with interplated std
|
134 |
+
t_interp = np.arange(nt, dtype="int")
|
135 |
+
std_interp = interp1d(t, std, axis=1, kind="slinear")(t_interp)
|
136 |
+
# std_interp = np.exp(interp1d(t, np.log(std), axis=1, kind="slinear")(t_interp))
|
137 |
+
mean_interp = interp1d(t, mean, axis=1, kind="slinear")(t_interp)
|
138 |
+
tmp = np.sum(std_interp, axis=(1, 2))
|
139 |
+
std_interp[std_interp == 0] = 1.0
|
140 |
+
data = (data - mean_interp) / std_interp
|
141 |
+
# data = (data - mean_interp)/(std_interp + 1e-12)
|
142 |
+
|
143 |
+
### dropout effect of < 3 channel
|
144 |
+
nonzero = np.count_nonzero(tmp, axis=-1)
|
145 |
+
data[nonzero > 0, ...] *= 3.0 / nonzero[nonzero > 0][:, np.newaxis, np.newaxis, np.newaxis]
|
146 |
+
|
147 |
+
return data
|
148 |
+
|
149 |
+
|
150 |
+
class DataConfig:
|
151 |
+
seed = 123
|
152 |
+
use_seed = True
|
153 |
+
n_channel = 3
|
154 |
+
n_class = 3
|
155 |
+
sampling_rate = 100
|
156 |
+
dt = 1.0 / sampling_rate
|
157 |
+
X_shape = [3000, 1, n_channel]
|
158 |
+
Y_shape = [3000, 1, n_class]
|
159 |
+
min_event_gap = 3 * sampling_rate
|
160 |
+
label_shape = "gaussian"
|
161 |
+
label_width = 30
|
162 |
+
dtype = "float32"
|
163 |
+
|
164 |
+
def __init__(self, **kwargs):
|
165 |
+
for k, v in kwargs.items():
|
166 |
+
setattr(self, k, v)
|
167 |
+
|
168 |
+
|
169 |
+
class DataReader:
|
170 |
+
def __init__(
|
171 |
+
self, format="numpy", config=DataConfig(), response_xml=None, sampling_rate=100, highpass_filter=0, **kwargs
|
172 |
+
):
|
173 |
+
self.buffer = {}
|
174 |
+
self.n_channel = config.n_channel
|
175 |
+
self.n_class = config.n_class
|
176 |
+
self.X_shape = config.X_shape
|
177 |
+
self.Y_shape = config.Y_shape
|
178 |
+
self.dt = config.dt
|
179 |
+
self.dtype = config.dtype
|
180 |
+
self.label_shape = config.label_shape
|
181 |
+
self.label_width = config.label_width
|
182 |
+
self.config = config
|
183 |
+
self.format = format
|
184 |
+
# if "highpass_filter" in kwargs:
|
185 |
+
# self.highpass_filter = kwargs["highpass_filter"]
|
186 |
+
self.highpass_filter = highpass_filter
|
187 |
+
# self.response_xml = response_xml
|
188 |
+
if response_xml is not None:
|
189 |
+
self.response = obspy.read_inventory(response_xml)
|
190 |
+
else:
|
191 |
+
self.response = None
|
192 |
+
self.sampling_rate = sampling_rate
|
193 |
+
if format in ["numpy", "mseed", "sac"]:
|
194 |
+
self.data_dir = kwargs["data_dir"]
|
195 |
+
try:
|
196 |
+
csv = pd.read_csv(kwargs["data_list"], header=0, sep="[,|\s+]", engine="python")
|
197 |
+
except:
|
198 |
+
csv = pd.read_csv(kwargs["data_list"], header=0, sep="\t")
|
199 |
+
self.data_list = csv["fname"]
|
200 |
+
self.num_data = len(self.data_list)
|
201 |
+
elif format == "hdf5":
|
202 |
+
self.h5 = h5py.File(kwargs["hdf5_file"], "r", libver="latest", swmr=True)
|
203 |
+
self.h5_data = self.h5[kwargs["hdf5_group"]]
|
204 |
+
self.data_list = list(self.h5_data.keys())
|
205 |
+
self.num_data = len(self.data_list)
|
206 |
+
elif format == "s3":
|
207 |
+
self.s3fs = s3fs.S3FileSystem(
|
208 |
+
anon=kwargs["anon"],
|
209 |
+
key=kwargs["key"],
|
210 |
+
secret=kwargs["secret"],
|
211 |
+
client_kwargs={"endpoint_url": kwargs["s3_url"]},
|
212 |
+
use_ssl=kwargs["use_ssl"],
|
213 |
+
)
|
214 |
+
self.num_data = 0
|
215 |
+
else:
|
216 |
+
raise (f"{format} not support!")
|
217 |
+
|
218 |
+
def __len__(self):
|
219 |
+
return self.num_data
|
220 |
+
|
221 |
+
def read_numpy(self, fname):
|
222 |
+
# try:
|
223 |
+
if fname not in self.buffer:
|
224 |
+
npz = np.load(fname)
|
225 |
+
meta = {}
|
226 |
+
if len(npz["data"].shape) == 2:
|
227 |
+
meta["data"] = npz["data"][:, np.newaxis, :]
|
228 |
+
else:
|
229 |
+
meta["data"] = npz["data"]
|
230 |
+
if "p_idx" in npz.files:
|
231 |
+
if len(npz["p_idx"].shape) == 0:
|
232 |
+
meta["itp"] = [[npz["p_idx"]]]
|
233 |
+
else:
|
234 |
+
meta["itp"] = npz["p_idx"]
|
235 |
+
if "s_idx" in npz.files:
|
236 |
+
if len(npz["s_idx"].shape) == 0:
|
237 |
+
meta["its"] = [[npz["s_idx"]]]
|
238 |
+
else:
|
239 |
+
meta["its"] = npz["s_idx"]
|
240 |
+
if "itp" in npz.files:
|
241 |
+
if len(npz["itp"].shape) == 0:
|
242 |
+
meta["itp"] = [[npz["itp"]]]
|
243 |
+
else:
|
244 |
+
meta["itp"] = npz["itp"]
|
245 |
+
if "its" in npz.files:
|
246 |
+
if len(npz["its"].shape) == 0:
|
247 |
+
meta["its"] = [[npz["its"]]]
|
248 |
+
else:
|
249 |
+
meta["its"] = npz["its"]
|
250 |
+
if "station_id" in npz.files:
|
251 |
+
meta["station_id"] = npz["station_id"]
|
252 |
+
if "sta_id" in npz.files:
|
253 |
+
meta["station_id"] = npz["sta_id"]
|
254 |
+
if "t0" in npz.files:
|
255 |
+
meta["t0"] = npz["t0"]
|
256 |
+
self.buffer[fname] = meta
|
257 |
+
else:
|
258 |
+
meta = self.buffer[fname]
|
259 |
+
return meta
|
260 |
+
# except:
|
261 |
+
# logging.error("Failed reading {}".format(fname))
|
262 |
+
# return None
|
263 |
+
|
264 |
+
def read_hdf5(self, fname):
|
265 |
+
data = self.h5_data[fname][()]
|
266 |
+
attrs = self.h5_data[fname].attrs
|
267 |
+
meta = {}
|
268 |
+
if len(data.shape) == 2:
|
269 |
+
meta["data"] = data[:, np.newaxis, :]
|
270 |
+
else:
|
271 |
+
meta["data"] = data
|
272 |
+
if "p_idx" in attrs:
|
273 |
+
if len(attrs["p_idx"].shape) == 0:
|
274 |
+
meta["itp"] = [[attrs["p_idx"]]]
|
275 |
+
else:
|
276 |
+
meta["itp"] = attrs["p_idx"]
|
277 |
+
if "s_idx" in attrs:
|
278 |
+
if len(attrs["s_idx"].shape) == 0:
|
279 |
+
meta["its"] = [[attrs["s_idx"]]]
|
280 |
+
else:
|
281 |
+
meta["its"] = attrs["s_idx"]
|
282 |
+
if "itp" in attrs:
|
283 |
+
if len(attrs["itp"].shape) == 0:
|
284 |
+
meta["itp"] = [[attrs["itp"]]]
|
285 |
+
else:
|
286 |
+
meta["itp"] = attrs["itp"]
|
287 |
+
if "its" in attrs:
|
288 |
+
if len(attrs["its"].shape) == 0:
|
289 |
+
meta["its"] = [[attrs["its"]]]
|
290 |
+
else:
|
291 |
+
meta["its"] = attrs["its"]
|
292 |
+
if "t0" in attrs:
|
293 |
+
meta["t0"] = attrs["t0"]
|
294 |
+
return meta
|
295 |
+
|
296 |
+
def read_s3(self, format, fname, bucket, key, secret, s3_url, use_ssl):
|
297 |
+
with self.s3fs.open(bucket + "/" + fname, "rb") as fp:
|
298 |
+
if format == "numpy":
|
299 |
+
meta = self.read_numpy(fp)
|
300 |
+
elif format == "mseed":
|
301 |
+
meta = self.read_mseed(fp)
|
302 |
+
else:
|
303 |
+
raise (f"Format {format} not supported")
|
304 |
+
return meta
|
305 |
+
|
306 |
+
def read_mseed(self, fname, response=None, highpass_filter=0.0, sampling_rate=100, return_single_station=True):
|
307 |
+
try:
|
308 |
+
stream = obspy.read(fname)
|
309 |
+
stream = stream.merge(fill_value="latest")
|
310 |
+
if response is not None:
|
311 |
+
# response = obspy.read_inventory(response_xml)
|
312 |
+
stream = stream.remove_sensitivity(response)
|
313 |
+
except Exception as e:
|
314 |
+
print(f"Error reading {fname}:\n{e}")
|
315 |
+
return {}
|
316 |
+
tmp_stream = obspy.Stream()
|
317 |
+
for trace in stream:
|
318 |
+
if len(trace.data) < 10:
|
319 |
+
continue
|
320 |
+
|
321 |
+
## interpolate to 100 Hz
|
322 |
+
if abs(trace.stats.sampling_rate - sampling_rate) > 0.1:
|
323 |
+
logging.warning(f"Resampling {trace.id} from {trace.stats.sampling_rate} to {sampling_rate} Hz")
|
324 |
+
try:
|
325 |
+
trace = trace.interpolate(sampling_rate, method="linear")
|
326 |
+
except Exception as e:
|
327 |
+
print(f"Error resampling {trace.id}:\n{e}")
|
328 |
+
|
329 |
+
trace = trace.detrend("demean")
|
330 |
+
|
331 |
+
## highpass filtering > 1Hz
|
332 |
+
if highpass_filter > 0.0:
|
333 |
+
trace = trace.filter("highpass", freq=highpass_filter)
|
334 |
+
|
335 |
+
tmp_stream.append(trace)
|
336 |
+
|
337 |
+
if len(tmp_stream) == 0:
|
338 |
+
return {}
|
339 |
+
stream = tmp_stream
|
340 |
+
|
341 |
+
begin_time = min([st.stats.starttime for st in stream])
|
342 |
+
end_time = max([st.stats.endtime for st in stream])
|
343 |
+
stream = stream.trim(begin_time, end_time, pad=True, fill_value=0)
|
344 |
+
|
345 |
+
comp = ["3", "2", "1", "E", "N", "U", "V", "Z"]
|
346 |
+
order = {key: i for i, key in enumerate(comp)}
|
347 |
+
comp2idx = {
|
348 |
+
"3": 0,
|
349 |
+
"2": 1,
|
350 |
+
"1": 2,
|
351 |
+
"E": 0,
|
352 |
+
"N": 1,
|
353 |
+
"Z": 2,
|
354 |
+
"U": 0,
|
355 |
+
"V": 1,
|
356 |
+
} ## only used when fewer than 3 components are present
|
357 |
+
|
358 |
+
station_ids = defaultdict(list)
|
359 |
+
for tr in stream:
|
360 |
+
station_ids[tr.id[:-1]].append(tr.id[-1])
|
361 |
+
if tr.id[-1] not in comp:
|
362 |
+
print(f"Unknown component {tr.id[-1]}")
|
363 |
+
|
364 |
+
station_keys = sorted(list(station_ids.keys()))
|
365 |
+
|
366 |
+
nx = len(station_ids)
|
367 |
+
nt = len(stream[0].data)
|
368 |
+
data = np.zeros([3, nt, nx], dtype=np.float32)
|
369 |
+
for i, sta in enumerate(station_keys):
|
370 |
+
for j, c in enumerate(sorted(station_ids[sta], key=lambda x: order[x])):
|
371 |
+
if len(station_ids[sta]) != 3: ## fewer than 3 components
|
372 |
+
j = comp2idx[c]
|
373 |
+
|
374 |
+
if len(stream.select(id=sta + c)) == 0:
|
375 |
+
print(f"Empty trace: {sta+c} {begin_time}")
|
376 |
+
continue
|
377 |
+
|
378 |
+
trace = stream.select(id=sta + c)[0]
|
379 |
+
|
380 |
+
## acceleration to velocity
|
381 |
+
if sta[-1] == "N":
|
382 |
+
trace = trace.integrate().filter("highpass", freq=1.0)
|
383 |
+
|
384 |
+
tmp = trace.data.astype("float32")
|
385 |
+
data[j, : len(tmp), i] = tmp[:nt]
|
386 |
+
|
387 |
+
# if return_single_station and (len(station_keys) > 1):
|
388 |
+
# print(f"Warning: {fname} has multiple stations, returning only the first one {station_keys[0]}")
|
389 |
+
# data = data[:, :, 0:1]
|
390 |
+
# station_keys = station_keys[0:1]
|
391 |
+
|
392 |
+
meta = {
|
393 |
+
"data": data.transpose([1, 2, 0]),
|
394 |
+
"t0": begin_time.datetime.isoformat(timespec="milliseconds"),
|
395 |
+
"station_id": station_keys,
|
396 |
+
}
|
397 |
+
return meta
|
398 |
+
|
399 |
+
def read_sac(self, fname):
|
400 |
+
mseed = obspy.read(fname)
|
401 |
+
mseed = mseed.detrend("spline", order=2, dspline=5 * mseed[0].stats.sampling_rate)
|
402 |
+
mseed = mseed.merge(fill_value=0)
|
403 |
+
if self.highpass_filter > 0:
|
404 |
+
mseed = mseed.filter("highpass", freq=self.highpass_filter)
|
405 |
+
starttime = min([st.stats.starttime for st in mseed])
|
406 |
+
endtime = max([st.stats.endtime for st in mseed])
|
407 |
+
mseed = mseed.trim(starttime, endtime, pad=True, fill_value=0)
|
408 |
+
if abs(mseed[0].stats.sampling_rate - self.config.sampling_rate) > 1:
|
409 |
+
logging.warning(
|
410 |
+
f"Sampling rate mismatch in {fname.split('/')[-1]}: {mseed[0].stats.sampling_rate}Hz != {self.config.sampling_rate}Hz "
|
411 |
+
)
|
412 |
+
|
413 |
+
order = ["3", "2", "1", "E", "N", "Z"]
|
414 |
+
order = {key: i for i, key in enumerate(order)}
|
415 |
+
comp2idx = {"3": 0, "2": 1, "1": 2, "E": 0, "N": 1, "Z": 2}
|
416 |
+
|
417 |
+
t0 = starttime.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]
|
418 |
+
nt = len(mseed[0].data)
|
419 |
+
data = np.zeros([nt, self.config.n_channel], dtype=self.dtype)
|
420 |
+
ids = [x.get_id() for x in mseed]
|
421 |
+
for j, id in enumerate(sorted(ids, key=lambda x: order[x[-1]])):
|
422 |
+
if len(ids) != 3:
|
423 |
+
if len(ids) > 3:
|
424 |
+
logging.warning(f"More than 3 channels {ids}!")
|
425 |
+
j = comp2idx[id[-1]]
|
426 |
+
data[:, j] = mseed.select(id=id)[0].data.astype(self.dtype)
|
427 |
+
|
428 |
+
data = data[:, np.newaxis, :]
|
429 |
+
meta = {"data": data, "t0": t0}
|
430 |
+
return meta
|
431 |
+
|
432 |
+
def read_mseed_array(self, fname, stations, amplitude=False, remove_resp=True):
|
433 |
+
data = []
|
434 |
+
station_id = []
|
435 |
+
t0 = []
|
436 |
+
raw_amp = []
|
437 |
+
|
438 |
+
try:
|
439 |
+
mseed = obspy.read(fname)
|
440 |
+
read_success = True
|
441 |
+
except Exception as e:
|
442 |
+
read_success = False
|
443 |
+
print(e)
|
444 |
+
|
445 |
+
if read_success:
|
446 |
+
try:
|
447 |
+
mseed = mseed.merge(fill_value=0)
|
448 |
+
except Exception as e:
|
449 |
+
print(e)
|
450 |
+
|
451 |
+
for i in range(len(mseed)):
|
452 |
+
if mseed[i].stats.sampling_rate != self.config.sampling_rate:
|
453 |
+
logging.warning(
|
454 |
+
f"Resampling {mseed[i].id} from {mseed[i].stats.sampling_rate} to {self.config.sampling_rate} Hz"
|
455 |
+
)
|
456 |
+
try:
|
457 |
+
mseed[i] = mseed[i].interpolate(self.config.sampling_rate, method="linear")
|
458 |
+
except Exception as e:
|
459 |
+
print(e)
|
460 |
+
mseed[i].data = mseed[i].data.astype(float) * 0.0 ## set to zero if resampling fails
|
461 |
+
|
462 |
+
if self.highpass_filter == 0:
|
463 |
+
try:
|
464 |
+
mseed = mseed.detrend("spline", order=2, dspline=5 * mseed[0].stats.sampling_rate)
|
465 |
+
except:
|
466 |
+
logging.error(f"Error: spline detrend failed at file {fname}")
|
467 |
+
mseed = mseed.detrend("demean")
|
468 |
+
else:
|
469 |
+
mseed = mseed.filter("highpass", freq=self.highpass_filter)
|
470 |
+
|
471 |
+
starttime = min([st.stats.starttime for st in mseed])
|
472 |
+
endtime = max([st.stats.endtime for st in mseed])
|
473 |
+
mseed = mseed.trim(starttime, endtime, pad=True, fill_value=0)
|
474 |
+
|
475 |
+
order = ["3", "2", "1", "E", "N", "Z"]
|
476 |
+
order = {key: i for i, key in enumerate(order)}
|
477 |
+
comp2idx = {"3": 0, "2": 1, "1": 2, "E": 0, "N": 1, "Z": 2}
|
478 |
+
|
479 |
+
nsta = len(stations)
|
480 |
+
nt = len(mseed[0].data)
|
481 |
+
# for i in range(nsta):
|
482 |
+
for sta in stations:
|
483 |
+
trace_data = np.zeros([nt, self.config.n_channel], dtype=self.dtype)
|
484 |
+
if amplitude:
|
485 |
+
trace_amp = np.zeros([nt, self.config.n_channel], dtype=self.dtype)
|
486 |
+
empty_station = True
|
487 |
+
# sta = stations.iloc[i]["station"]
|
488 |
+
# comp = stations.iloc[i]["component"].split(",")
|
489 |
+
comp = stations[sta]["component"]
|
490 |
+
if amplitude:
|
491 |
+
# resp = stations.iloc[i]["response"].split(",")
|
492 |
+
resp = stations[sta]["response"]
|
493 |
+
|
494 |
+
for j, c in enumerate(sorted(comp, key=lambda x: order[x[-1]])):
|
495 |
+
resp_j = resp[j]
|
496 |
+
if len(comp) != 3: ## fewer than 3 components
|
497 |
+
j = comp2idx[c]
|
498 |
+
|
499 |
+
if len(mseed.select(id=sta + c)) == 0:
|
500 |
+
print(f"Empty trace: {sta+c} {starttime}")
|
501 |
+
continue
|
502 |
+
else:
|
503 |
+
empty_station = False
|
504 |
+
|
505 |
+
tmp = mseed.select(id=sta + c)[0].data.astype(self.dtype)
|
506 |
+
trace_data[: len(tmp), j] = tmp[:nt]
|
507 |
+
if amplitude:
|
508 |
+
# if stations.iloc[i]["unit"] == "m/s**2":
|
509 |
+
if stations[sta]["unit"] == "m/s**2":
|
510 |
+
tmp = mseed.select(id=sta + c)[0]
|
511 |
+
tmp = tmp.integrate()
|
512 |
+
tmp = tmp.filter("highpass", freq=1.0)
|
513 |
+
tmp = tmp.data.astype(self.dtype)
|
514 |
+
trace_amp[: len(tmp), j] = tmp[:nt]
|
515 |
+
# elif stations.iloc[i]["unit"] == "m/s":
|
516 |
+
elif stations[sta]["unit"] == "m/s":
|
517 |
+
tmp = mseed.select(id=sta + c)[0].data.astype(self.dtype)
|
518 |
+
trace_amp[: len(tmp), j] = tmp[:nt]
|
519 |
+
else:
|
520 |
+
print(
|
521 |
+
f"Error in {stations.iloc[i]['station']}\n{stations.iloc[i]['unit']} should be m/s**2 or m/s!"
|
522 |
+
)
|
523 |
+
if amplitude and remove_resp:
|
524 |
+
# trace_amp[:, j] /= float(resp[j])
|
525 |
+
trace_amp[:, j] /= float(resp_j)
|
526 |
+
|
527 |
+
if not empty_station:
|
528 |
+
data.append(trace_data)
|
529 |
+
if amplitude:
|
530 |
+
raw_amp.append(trace_amp)
|
531 |
+
station_id.append([sta])
|
532 |
+
t0.append(starttime.datetime.isoformat(timespec="milliseconds"))
|
533 |
+
|
534 |
+
if len(data) > 0:
|
535 |
+
data = np.stack(data)
|
536 |
+
if len(data.shape) == 3:
|
537 |
+
data = data[:, :, np.newaxis, :]
|
538 |
+
if amplitude:
|
539 |
+
raw_amp = np.stack(raw_amp)
|
540 |
+
if len(raw_amp.shape) == 3:
|
541 |
+
raw_amp = raw_amp[:, :, np.newaxis, :]
|
542 |
+
else:
|
543 |
+
nt = 60 * 60 * self.config.sampling_rate # assume 1 hour data
|
544 |
+
data = np.zeros([1, nt, 1, self.config.n_channel], dtype=self.dtype)
|
545 |
+
if amplitude:
|
546 |
+
raw_amp = np.zeros([1, nt, 1, self.config.n_channel], dtype=self.dtype)
|
547 |
+
t0 = ["1970-01-01T00:00:00.000"]
|
548 |
+
station_id = ["None"]
|
549 |
+
|
550 |
+
if amplitude:
|
551 |
+
meta = {"data": data, "t0": t0, "station_id": station_id, "fname": fname.split("/")[-1], "raw_amp": raw_amp}
|
552 |
+
else:
|
553 |
+
meta = {"data": data, "t0": t0, "station_id": station_id, "fname": fname.split("/")[-1]}
|
554 |
+
return meta
|
555 |
+
|
556 |
+
def generate_label(self, data, phase_list, mask=None):
|
557 |
+
# target = np.zeros(self.Y_shape, dtype=self.dtype)
|
558 |
+
target = np.zeros_like(data)
|
559 |
+
|
560 |
+
if self.label_shape == "gaussian":
|
561 |
+
label_window = np.exp(
|
562 |
+
-((np.arange(-self.label_width // 2, self.label_width // 2 + 1)) ** 2)
|
563 |
+
/ (2 * (self.label_width / 5) ** 2)
|
564 |
+
)
|
565 |
+
elif self.label_shape == "triangle":
|
566 |
+
label_window = 1 - np.abs(
|
567 |
+
2 / self.label_width * (np.arange(-self.label_width // 2, self.label_width // 2 + 1))
|
568 |
+
)
|
569 |
+
else:
|
570 |
+
print(f"Label shape {self.label_shape} should be guassian or triangle")
|
571 |
+
raise ValueError(f"Unknown label shape: {self.label_shape}")
|
572 |
+
|
573 |
+
for i, phases in enumerate(phase_list):
|
574 |
+
for j, idx_list in enumerate(phases):
|
575 |
+
for idx in idx_list:
|
576 |
+
if np.isnan(idx):
|
577 |
+
continue
|
578 |
+
idx = int(idx)
|
579 |
+
if (idx - self.label_width // 2 >= 0) and (idx + self.label_width // 2 + 1 <= target.shape[0]):
|
580 |
+
target[idx - self.label_width // 2 : idx + self.label_width // 2 + 1, j, i + 1] = label_window
|
581 |
+
|
582 |
+
target[..., 0] = 1 - np.sum(target[..., 1:], axis=-1)
|
583 |
+
if mask is not None:
|
584 |
+
target[:, mask == 0, :] = 0
|
585 |
+
|
586 |
+
return target
|
587 |
+
|
588 |
+
def random_shift(self, sample, itp, its, itp_old=None, its_old=None, shift_range=None):
|
589 |
+
# anchor = np.round(1/2 * (min(itp[~np.isnan(itp.astype(float))]) + min(its[~np.isnan(its.astype(float))]))).astype(int)
|
590 |
+
flattern = lambda x: np.array([i for trace in x for i in trace], dtype=float)
|
591 |
+
shift_pick = lambda x, shift: [[i - shift for i in trace] for trace in x]
|
592 |
+
itp_flat = flattern(itp)
|
593 |
+
its_flat = flattern(its)
|
594 |
+
if (itp_old is None) and (its_old is None):
|
595 |
+
hi = np.round(np.median(itp_flat[~np.isnan(itp_flat)])).astype(int)
|
596 |
+
lo = -(sample.shape[0] - np.round(np.median(its_flat[~np.isnan(its_flat)])).astype(int))
|
597 |
+
if shift_range is None:
|
598 |
+
shift = np.random.randint(low=lo, high=hi + 1)
|
599 |
+
else:
|
600 |
+
shift = np.random.randint(low=max(lo, shift_range[0]), high=min(hi + 1, shift_range[1]))
|
601 |
+
else:
|
602 |
+
itp_old_flat = flattern(itp_old)
|
603 |
+
its_old_flat = flattern(its_old)
|
604 |
+
itp_ref = np.round(np.min(itp_flat[~np.isnan(itp_flat)])).astype(int)
|
605 |
+
its_ref = np.round(np.max(its_flat[~np.isnan(its_flat)])).astype(int)
|
606 |
+
itp_old_ref = np.round(np.min(itp_old_flat[~np.isnan(itp_old_flat)])).astype(int)
|
607 |
+
its_old_ref = np.round(np.max(its_old_flat[~np.isnan(its_old_flat)])).astype(int)
|
608 |
+
# min_event_gap = np.round(self.min_event_gap*(its_ref-itp_ref)).astype(int)
|
609 |
+
# min_event_gap_old = np.round(self.min_event_gap*(its_old_ref-itp_old_ref)).astype(int)
|
610 |
+
if shift_range is None:
|
611 |
+
hi = list(range(max(its_ref - itp_old_ref + self.min_event_gap, 0), itp_ref))
|
612 |
+
lo = list(range(-(sample.shape[0] - its_ref), -(max(its_old_ref - itp_ref + self.min_event_gap, 0))))
|
613 |
+
else:
|
614 |
+
lo_ = max(-(sample.shape[0] - its_ref), shift_range[0])
|
615 |
+
hi_ = min(itp_ref, shift_range[1])
|
616 |
+
hi = list(range(max(its_ref - itp_old_ref + self.min_event_gap, 0), hi_))
|
617 |
+
lo = list(range(lo_, -(max(its_old_ref - itp_ref + self.min_event_gap, 0))))
|
618 |
+
if len(hi + lo) > 0:
|
619 |
+
shift = np.random.choice(hi + lo)
|
620 |
+
else:
|
621 |
+
shift = 0
|
622 |
+
|
623 |
+
shifted_sample = np.zeros_like(sample)
|
624 |
+
if shift > 0:
|
625 |
+
shifted_sample[:-shift, ...] = sample[shift:, ...]
|
626 |
+
elif shift < 0:
|
627 |
+
shifted_sample[-shift:, ...] = sample[:shift, ...]
|
628 |
+
else:
|
629 |
+
shifted_sample[...] = sample[...]
|
630 |
+
|
631 |
+
return shifted_sample, shift_pick(itp, shift), shift_pick(its, shift), shift
|
632 |
+
|
633 |
+
def stack_events(self, sample_old, itp_old, its_old, shift_range=None, mask_old=None):
|
634 |
+
i = np.random.randint(self.num_data)
|
635 |
+
base_name = self.data_list[i]
|
636 |
+
if self.format == "numpy":
|
637 |
+
meta = self.read_numpy(os.path.join(self.data_dir, base_name))
|
638 |
+
elif self.format == "hdf5":
|
639 |
+
meta = self.read_hdf5(base_name)
|
640 |
+
if meta == -1:
|
641 |
+
return sample_old, itp_old, its_old, mask_old
|
642 |
+
|
643 |
+
sample = np.copy(meta["data"])
|
644 |
+
itp = meta["itp"]
|
645 |
+
its = meta["its"]
|
646 |
+
if mask_old is not None:
|
647 |
+
mask = np.copy(meta["mask"])
|
648 |
+
sample = normalize(sample)
|
649 |
+
sample, itp, its, shift = self.random_shift(sample, itp, its, itp_old, its_old, shift_range)
|
650 |
+
|
651 |
+
if shift != 0:
|
652 |
+
sample_old += sample
|
653 |
+
# itp_old = [np.hstack([i, j]) for i,j in zip(itp_old, itp)]
|
654 |
+
# its_old = [np.hstack([i, j]) for i,j in zip(its_old, its)]
|
655 |
+
itp_old = [i + j for i, j in zip(itp_old, itp)]
|
656 |
+
its_old = [i + j for i, j in zip(its_old, its)]
|
657 |
+
if mask_old is not None:
|
658 |
+
mask_old = mask_old * mask
|
659 |
+
|
660 |
+
return sample_old, itp_old, its_old, mask_old
|
661 |
+
|
662 |
+
def cut_window(self, sample, target, itp, its, select_range):
|
663 |
+
shift_pick = lambda x, shift: [[i - shift for i in trace] for trace in x]
|
664 |
+
sample = sample[select_range[0] : select_range[1]]
|
665 |
+
target = target[select_range[0] : select_range[1]]
|
666 |
+
return (sample, target, shift_pick(itp, select_range[0]), shift_pick(its, select_range[0]))
|
667 |
+
|
668 |
+
|
669 |
+
class DataReader_train(DataReader):
|
670 |
+
def __init__(self, format="numpy", config=DataConfig(), **kwargs):
|
671 |
+
super().__init__(format=format, config=config, **kwargs)
|
672 |
+
|
673 |
+
self.min_event_gap = config.min_event_gap
|
674 |
+
self.buffer_channels = {}
|
675 |
+
self.shift_range = [-2000 + self.label_width * 2, 1000 - self.label_width * 2]
|
676 |
+
self.select_range = [5000, 8000]
|
677 |
+
|
678 |
+
def __getitem__(self, i):
|
679 |
+
base_name = self.data_list[i]
|
680 |
+
if self.format == "numpy":
|
681 |
+
meta = self.read_numpy(os.path.join(self.data_dir, base_name))
|
682 |
+
elif self.format == "hdf5":
|
683 |
+
meta = self.read_hdf5(base_name)
|
684 |
+
if meta is None:
|
685 |
+
return (np.zeros(self.X_shape, dtype=self.dtype), np.zeros(self.Y_shape, dtype=self.dtype), base_name)
|
686 |
+
|
687 |
+
sample = np.copy(meta["data"])
|
688 |
+
itp_list = meta["itp"]
|
689 |
+
its_list = meta["its"]
|
690 |
+
|
691 |
+
sample = normalize(sample)
|
692 |
+
if np.random.random() < 0.95:
|
693 |
+
sample, itp_list, its_list, _ = self.random_shift(sample, itp_list, its_list, shift_range=self.shift_range)
|
694 |
+
sample, itp_list, its_list, _ = self.stack_events(sample, itp_list, its_list, shift_range=self.shift_range)
|
695 |
+
target = self.generate_label(sample, [itp_list, its_list])
|
696 |
+
sample, target, itp_list, its_list = self.cut_window(sample, target, itp_list, its_list, self.select_range)
|
697 |
+
else:
|
698 |
+
## noise
|
699 |
+
assert self.X_shape[0] <= min(min(itp_list))
|
700 |
+
sample = sample[: self.X_shape[0], ...]
|
701 |
+
target = np.zeros(self.Y_shape).astype(self.dtype)
|
702 |
+
itp_list = [[]]
|
703 |
+
its_list = [[]]
|
704 |
+
|
705 |
+
sample = normalize(sample)
|
706 |
+
return (sample.astype(self.dtype), target.astype(self.dtype), base_name)
|
707 |
+
|
708 |
+
def dataset(self, batch_size, num_parallel_calls=2, shuffle=True, drop_remainder=True):
|
709 |
+
dataset = dataset_map(
|
710 |
+
self,
|
711 |
+
output_types=(self.dtype, self.dtype, "string"),
|
712 |
+
output_shapes=(self.X_shape, self.Y_shape, None),
|
713 |
+
num_parallel_calls=num_parallel_calls,
|
714 |
+
shuffle=shuffle,
|
715 |
+
)
|
716 |
+
dataset = dataset.batch(batch_size, drop_remainder=drop_remainder).prefetch(batch_size * 2)
|
717 |
+
return dataset
|
718 |
+
|
719 |
+
|
720 |
+
class DataReader_test(DataReader):
|
721 |
+
def __init__(self, format="numpy", config=DataConfig(), **kwargs):
|
722 |
+
super().__init__(format=format, config=config, **kwargs)
|
723 |
+
|
724 |
+
self.select_range = [5000, 8000]
|
725 |
+
|
726 |
+
def __getitem__(self, i):
|
727 |
+
base_name = self.data_list[i]
|
728 |
+
if self.format == "numpy":
|
729 |
+
meta = self.read_numpy(os.path.join(self.data_dir, base_name))
|
730 |
+
elif self.format == "hdf5":
|
731 |
+
meta = self.read_hdf5(base_name)
|
732 |
+
if meta == -1:
|
733 |
+
return (np.zeros(self.Y_shape, dtype=self.dtype), np.zeros(self.X_shape, dtype=self.dtype), base_name)
|
734 |
+
|
735 |
+
sample = np.copy(meta["data"])
|
736 |
+
itp_list = meta["itp"]
|
737 |
+
its_list = meta["its"]
|
738 |
+
|
739 |
+
# sample, itp_list, its_list, _ = self.random_shift(sample, itp_list, its_list, shift_range=self.shift_range)
|
740 |
+
target = self.generate_label(sample, [itp_list, its_list])
|
741 |
+
sample, target, itp_list, its_list = self.cut_window(sample, target, itp_list, its_list, self.select_range)
|
742 |
+
|
743 |
+
sample = normalize(sample)
|
744 |
+
return (sample, target, base_name, itp_list, its_list)
|
745 |
+
|
746 |
+
def dataset(self, batch_size, num_parallel_calls=2, shuffle=False, drop_remainder=False):
|
747 |
+
dataset = dataset_map(
|
748 |
+
self,
|
749 |
+
output_types=(self.dtype, self.dtype, "string", "int64", "int64"),
|
750 |
+
output_shapes=(self.X_shape, self.Y_shape, None, None, None),
|
751 |
+
num_parallel_calls=num_parallel_calls,
|
752 |
+
shuffle=shuffle,
|
753 |
+
)
|
754 |
+
dataset = dataset.batch(batch_size, drop_remainder=drop_remainder).prefetch(batch_size * 2)
|
755 |
+
return dataset
|
756 |
+
|
757 |
+
|
758 |
+
class DataReader_pred(DataReader):
|
759 |
+
def __init__(self, format="numpy", amplitude=True, config=DataConfig(), **kwargs):
|
760 |
+
super().__init__(format=format, config=config, **kwargs)
|
761 |
+
|
762 |
+
self.amplitude = amplitude
|
763 |
+
|
764 |
+
def adjust_missingchannels(self, data):
|
765 |
+
tmp = np.max(np.abs(data), axis=0, keepdims=True)
|
766 |
+
assert tmp.shape[-1] == data.shape[-1]
|
767 |
+
if np.count_nonzero(tmp) > 0:
|
768 |
+
data *= data.shape[-1] / np.count_nonzero(tmp)
|
769 |
+
return data
|
770 |
+
|
771 |
+
def __getitem__(self, i):
|
772 |
+
base_name = self.data_list[i]
|
773 |
+
|
774 |
+
if self.format == "numpy":
|
775 |
+
meta = self.read_numpy(os.path.join(self.data_dir, base_name))
|
776 |
+
elif (self.format == "mseed") or (self.format == "sac"):
|
777 |
+
meta = self.read_mseed(
|
778 |
+
os.path.join(self.data_dir, base_name),
|
779 |
+
response=self.response,
|
780 |
+
sampling_rate=self.sampling_rate,
|
781 |
+
highpass_filter=self.highpass_filter,
|
782 |
+
return_single_station=True,
|
783 |
+
)
|
784 |
+
elif self.format == "hdf5":
|
785 |
+
meta = self.read_hdf5(base_name)
|
786 |
+
else:
|
787 |
+
raise (f"{self.format} does not support!")
|
788 |
+
|
789 |
+
if "data" in meta:
|
790 |
+
raw_amp = meta["data"].copy()
|
791 |
+
sample = normalize_long(meta["data"])
|
792 |
+
else:
|
793 |
+
raw_amp = np.zeros([3000, 1, 3], dtype=np.float32)
|
794 |
+
sample = np.zeros([3000, 1, 3], dtype=np.float32)
|
795 |
+
|
796 |
+
if "t0" in meta:
|
797 |
+
t0 = meta["t0"]
|
798 |
+
else:
|
799 |
+
t0 = "1970-01-01T00:00:00.000"
|
800 |
+
|
801 |
+
if "station_id" in meta:
|
802 |
+
station_id = meta["station_id"]
|
803 |
+
else:
|
804 |
+
# station_id = base_name.split("/")[-1].rstrip("*")
|
805 |
+
station_id = os.path.basename(base_name).rstrip("*")
|
806 |
+
|
807 |
+
if np.isnan(sample).any() or np.isinf(sample).any():
|
808 |
+
logging.warning(f"Data error: Nan or Inf found in {base_name}")
|
809 |
+
sample[np.isnan(sample)] = 0
|
810 |
+
sample[np.isinf(sample)] = 0
|
811 |
+
|
812 |
+
# sample = self.adjust_missingchannels(sample)
|
813 |
+
|
814 |
+
if self.amplitude:
|
815 |
+
return (sample, raw_amp, base_name, t0, station_id)
|
816 |
+
else:
|
817 |
+
return (sample, base_name, t0, station_id)
|
818 |
+
|
819 |
+
def dataset(self, batch_size, num_parallel_calls=2, shuffle=False, drop_remainder=False):
|
820 |
+
if self.amplitude:
|
821 |
+
dataset = dataset_map(
|
822 |
+
self,
|
823 |
+
output_types=(self.dtype, self.dtype, "string", "string", "string"),
|
824 |
+
output_shapes=([None, None, 3], [None, None, 3], None, None, None),
|
825 |
+
num_parallel_calls=num_parallel_calls,
|
826 |
+
shuffle=shuffle,
|
827 |
+
)
|
828 |
+
else:
|
829 |
+
dataset = dataset_map(
|
830 |
+
self,
|
831 |
+
output_types=(self.dtype, "string", "string", "string"),
|
832 |
+
output_shapes=([None, None, 3], None, None, None),
|
833 |
+
num_parallel_calls=num_parallel_calls,
|
834 |
+
shuffle=shuffle,
|
835 |
+
)
|
836 |
+
dataset = dataset.batch(batch_size, drop_remainder=drop_remainder).prefetch(batch_size * 2)
|
837 |
+
return dataset
|
838 |
+
|
839 |
+
|
840 |
+
class DataReader_mseed_array(DataReader):
|
841 |
+
def __init__(self, stations, amplitude=True, remove_resp=True, config=DataConfig(), **kwargs):
|
842 |
+
super().__init__(format="mseed", config=config, **kwargs)
|
843 |
+
|
844 |
+
# self.stations = pd.read_json(stations)
|
845 |
+
with open(stations, "r") as f:
|
846 |
+
self.stations = json.load(f)
|
847 |
+
print(pd.DataFrame.from_dict(self.stations, orient="index").to_string())
|
848 |
+
|
849 |
+
self.amplitude = amplitude
|
850 |
+
self.remove_resp = remove_resp
|
851 |
+
self.X_shape = self.get_data_shape()
|
852 |
+
|
853 |
+
def get_data_shape(self):
|
854 |
+
fname = os.path.join(self.data_dir, self.data_list[0])
|
855 |
+
meta = self.read_mseed_array(fname, self.stations, self.amplitude, self.remove_resp)
|
856 |
+
return meta["data"].shape
|
857 |
+
|
858 |
+
def __getitem__(self, i):
|
859 |
+
fp = os.path.join(self.data_dir, self.data_list[i])
|
860 |
+
# try:
|
861 |
+
meta = self.read_mseed_array(fp, self.stations, self.amplitude, self.remove_resp)
|
862 |
+
# except Exception as e:
|
863 |
+
# logging.error(f"Failed reading {fp}: {e}")
|
864 |
+
# if self.amplitude:
|
865 |
+
# return (np.zeros(self.X_shape).astype(self.dtype), np.zeros(self.X_shape).astype(self.dtype),
|
866 |
+
# [self.stations.iloc[i]["station"] for i in range(len(self.stations))], ["0" for i in range(len(self.stations))])
|
867 |
+
# else:
|
868 |
+
# return (np.zeros(self.X_shape).astype(self.dtype), ["" for i in range(len(self.stations))],
|
869 |
+
# [self.stations.iloc[i]["station"] for i in range(len(self.stations))])
|
870 |
+
|
871 |
+
sample = np.zeros([len(meta["data"]), *self.X_shape[1:]], dtype=self.dtype)
|
872 |
+
sample[:, : meta["data"].shape[1], :, :] = normalize_batch(meta["data"])[:, : self.X_shape[1], :, :]
|
873 |
+
if np.isnan(sample).any() or np.isinf(sample).any():
|
874 |
+
logging.warning(f"Data error: Nan or Inf found in {fp}")
|
875 |
+
sample[np.isnan(sample)] = 0
|
876 |
+
sample[np.isinf(sample)] = 0
|
877 |
+
t0 = meta["t0"]
|
878 |
+
base_name = meta["fname"]
|
879 |
+
station_id = meta["station_id"]
|
880 |
+
# base_name = [self.stations.iloc[i]["station"]+"."+t0[i] for i in range(len(self.stations))]
|
881 |
+
# base_name = [self.stations.iloc[i]["station"] for i in range(len(self.stations))]
|
882 |
+
|
883 |
+
if self.amplitude:
|
884 |
+
raw_amp = np.zeros([len(meta["raw_amp"]), *self.X_shape[1:]], dtype=self.dtype)
|
885 |
+
raw_amp[:, : meta["raw_amp"].shape[1], :, :] = meta["raw_amp"][:, : self.X_shape[1], :, :]
|
886 |
+
if np.isnan(raw_amp).any() or np.isinf(raw_amp).any():
|
887 |
+
logging.warning(f"Data error: Nan or Inf found in {fp}")
|
888 |
+
raw_amp[np.isnan(raw_amp)] = 0
|
889 |
+
raw_amp[np.isinf(raw_amp)] = 0
|
890 |
+
return (sample, raw_amp, base_name, t0, station_id)
|
891 |
+
else:
|
892 |
+
return (sample, base_name, t0, station_id)
|
893 |
+
|
894 |
+
def dataset(self, num_parallel_calls=1, shuffle=False):
|
895 |
+
if self.amplitude:
|
896 |
+
dataset = dataset_map(
|
897 |
+
self,
|
898 |
+
output_types=(self.dtype, self.dtype, "string", "string", "string"),
|
899 |
+
output_shapes=([None, *self.X_shape[1:]], [None, *self.X_shape[1:]], None, None, None),
|
900 |
+
num_parallel_calls=num_parallel_calls,
|
901 |
+
)
|
902 |
+
else:
|
903 |
+
dataset = dataset_map(
|
904 |
+
self,
|
905 |
+
output_types=(self.dtype, "string", "string", "string"),
|
906 |
+
output_shapes=([None, *self.X_shape[1:]], None, None, None),
|
907 |
+
num_parallel_calls=num_parallel_calls,
|
908 |
+
)
|
909 |
+
dataset = dataset.prefetch(1)
|
910 |
+
# dataset = dataset.prefetch(len(self.stations)*2)
|
911 |
+
return dataset
|
912 |
+
|
913 |
+
|
914 |
+
###### test ########
|
915 |
+
|
916 |
+
|
917 |
+
def test_DataReader():
|
918 |
+
import os
|
919 |
+
import timeit
|
920 |
+
|
921 |
+
import matplotlib.pyplot as plt
|
922 |
+
|
923 |
+
if not os.path.exists("test_figures"):
|
924 |
+
os.mkdir("test_figures")
|
925 |
+
|
926 |
+
def plot_sample(sample, fname, label=None):
|
927 |
+
plt.clf()
|
928 |
+
plt.subplot(211)
|
929 |
+
plt.plot(sample[:, 0, -1])
|
930 |
+
if label is not None:
|
931 |
+
plt.subplot(212)
|
932 |
+
plt.plot(label[:, 0, 0])
|
933 |
+
plt.plot(label[:, 0, 1])
|
934 |
+
plt.plot(label[:, 0, 2])
|
935 |
+
plt.savefig(f"test_figures/{fname.decode()}.png")
|
936 |
+
|
937 |
+
def read(data_reader, batch=1):
|
938 |
+
start_time = timeit.default_timer()
|
939 |
+
if batch is None:
|
940 |
+
dataset = data_reader.dataset(shuffle=False)
|
941 |
+
else:
|
942 |
+
dataset = data_reader.dataset(1, shuffle=False)
|
943 |
+
sess = tf.compat.v1.Session()
|
944 |
+
|
945 |
+
print(len(data_reader))
|
946 |
+
print("-------", tf.data.Dataset.cardinality(dataset))
|
947 |
+
num = 0
|
948 |
+
x = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()
|
949 |
+
while True:
|
950 |
+
num += 1
|
951 |
+
# print(num)
|
952 |
+
try:
|
953 |
+
out = sess.run(x)
|
954 |
+
if len(out) == 2:
|
955 |
+
sample, fname = out[0], out[1]
|
956 |
+
for i in range(len(sample)):
|
957 |
+
plot_sample(sample[i], fname[i])
|
958 |
+
else:
|
959 |
+
sample, label, fname = out[0], out[1], out[2]
|
960 |
+
for i in range(len(sample)):
|
961 |
+
plot_sample(sample[i], fname[i], label[i])
|
962 |
+
except tf.errors.OutOfRangeError:
|
963 |
+
break
|
964 |
+
print("End of dataset")
|
965 |
+
print("Tensorflow Dataset:\nexecution time = ", timeit.default_timer() - start_time)
|
966 |
+
|
967 |
+
data_reader = DataReader_train(data_list="test_data/selected_phases.csv", data_dir="test_data/data/")
|
968 |
+
|
969 |
+
read(data_reader)
|
970 |
+
|
971 |
+
data_reader = DataReader_train(format="hdf5", hdf5="test_data/data.h5", group="data")
|
972 |
+
|
973 |
+
read(data_reader)
|
974 |
+
|
975 |
+
data_reader = DataReader_test(data_list="test_data/selected_phases.csv", data_dir="test_data/data/")
|
976 |
+
|
977 |
+
read(data_reader)
|
978 |
+
|
979 |
+
data_reader = DataReader_test(format="hdf5", hdf5="test_data/data.h5", group="data")
|
980 |
+
|
981 |
+
read(data_reader)
|
982 |
+
|
983 |
+
data_reader = DataReader_pred(format="numpy", data_list="test_data/selected_phases.csv", data_dir="test_data/data/")
|
984 |
+
|
985 |
+
read(data_reader)
|
986 |
+
|
987 |
+
data_reader = DataReader_pred(
|
988 |
+
format="mseed", data_list="test_data/mseed_station.csv", data_dir="test_data/waveforms/"
|
989 |
+
)
|
990 |
+
|
991 |
+
read(data_reader)
|
992 |
+
|
993 |
+
data_reader = DataReader_pred(
|
994 |
+
format="mseed", amplitude=True, data_list="test_data/mseed_station.csv", data_dir="test_data/waveforms/"
|
995 |
+
)
|
996 |
+
|
997 |
+
read(data_reader)
|
998 |
+
|
999 |
+
data_reader = DataReader_mseed_array(
|
1000 |
+
data_list="test_data/mseed.csv",
|
1001 |
+
data_dir="test_data/waveforms/",
|
1002 |
+
stations="test_data/stations.csv",
|
1003 |
+
remove_resp=False,
|
1004 |
+
)
|
1005 |
+
|
1006 |
+
read(data_reader, batch=None)
|
1007 |
+
|
1008 |
+
|
1009 |
+
if __name__ == "__main__":
|
1010 |
+
test_DataReader()
|
phasenet/detect_peaks.py
ADDED
@@ -0,0 +1,207 @@
|
1 |
+
"""Detect peaks in data based on their amplitude and other features."""
|
2 |
+
|
3 |
+
from __future__ import division, print_function
|
4 |
+
import warnings
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
__author__ = "Marcos Duarte, https://github.com/demotu"
|
8 |
+
__version__ = "1.0.6"
|
9 |
+
__license__ = "MIT"
|
10 |
+
|
11 |
+
|
12 |
+
|
13 |
+
def detect_peaks(x, mph=None, mpd=1, threshold=0, edge='rising',
|
14 |
+
kpsh=False, valley=False, show=False, ax=None, title=True):
|
15 |
+
|
16 |
+
"""Detect peaks in data based on their amplitude and other features.
|
17 |
+
|
18 |
+
Parameters
|
19 |
+
----------
|
20 |
+
x : 1D array_like
|
21 |
+
data.
|
22 |
+
mph : {None, number}, optional (default = None)
|
23 |
+
detect peaks that are greater than minimum peak height (if parameter
|
24 |
+
`valley` is False) or peaks that are smaller than maximum peak height
|
25 |
+
(if parameter `valley` is True).
|
26 |
+
mpd : positive integer, optional (default = 1)
|
27 |
+
detect peaks that are at least separated by minimum peak distance (in
|
28 |
+
number of data).
|
29 |
+
threshold : positive number, optional (default = 0)
|
30 |
+
detect peaks (valleys) that are greater (smaller) than `threshold`
|
31 |
+
in relation to their immediate neighbors.
|
32 |
+
edge : {None, 'rising', 'falling', 'both'}, optional (default = 'rising')
|
33 |
+
for a flat peak, keep only the rising edge ('rising'), only the
|
34 |
+
falling edge ('falling'), both edges ('both'), or don't detect a
|
35 |
+
flat peak (None).
|
36 |
+
kpsh : bool, optional (default = False)
|
37 |
+
keep peaks with same height even if they are closer than `mpd`.
|
38 |
+
valley : bool, optional (default = False)
|
39 |
+
if True (1), detect valleys (local minima) instead of peaks.
|
40 |
+
show : bool, optional (default = False)
|
41 |
+
if True (1), plot data in matplotlib figure.
|
42 |
+
ax : a matplotlib.axes.Axes instance, optional (default = None).
|
43 |
+
title : bool or string, optional (default = True)
|
44 |
+
if True, show standard title. If False or empty string, doesn't show
|
45 |
+
any title. If string, shows string as title.
|
46 |
+
|
47 |
+
Returns
|
48 |
+
-------
|
49 |
+
ind : 1D array_like
|
50 |
+
indices of the peaks in `x`.
|
51 |
+
|
52 |
+
Notes
|
53 |
+
-----
|
54 |
+
The detection of valleys instead of peaks is performed internally by simply
|
55 |
+
negating the data: `ind_valleys = detect_peaks(-x)`
|
56 |
+
|
57 |
+
The function can handle NaN's
|
58 |
+
|
59 |
+
See this IPython Notebook [1]_.
|
60 |
+
|
61 |
+
References
|
62 |
+
----------
|
63 |
+
.. [1] http://nbviewer.ipython.org/github/demotu/BMC/blob/master/notebooks/DetectPeaks.ipynb
|
64 |
+
|
65 |
+
Examples
|
66 |
+
--------
|
67 |
+
>>> from detect_peaks import detect_peaks
|
68 |
+
>>> x = np.random.randn(100)
|
69 |
+
>>> x[60:81] = np.nan
|
70 |
+
>>> # detect all peaks and plot data
|
71 |
+
>>> ind = detect_peaks(x, show=True)
|
72 |
+
>>> print(ind)
|
73 |
+
|
74 |
+
>>> x = np.sin(2*np.pi*5*np.linspace(0, 1, 200)) + np.random.randn(200)/5
|
75 |
+
>>> # set minimum peak height = 0 and minimum peak distance = 20
|
76 |
+
>>> detect_peaks(x, mph=0, mpd=20, show=True)
|
77 |
+
|
78 |
+
>>> x = [0, 1, 0, 2, 0, 3, 0, 2, 0, 1, 0]
|
79 |
+
>>> # set minimum peak distance = 2
|
80 |
+
>>> detect_peaks(x, mpd=2, show=True)
|
81 |
+
|
82 |
+
>>> x = np.sin(2*np.pi*5*np.linspace(0, 1, 200)) + np.random.randn(200)/5
|
83 |
+
>>> # detection of valleys instead of peaks
|
84 |
+
>>> detect_peaks(x, mph=-1.2, mpd=20, valley=True, show=True)
|
85 |
+
|
86 |
+
>>> x = [0, 1, 1, 0, 1, 1, 0]
|
87 |
+
>>> # detect both edges
|
88 |
+
>>> detect_peaks(x, edge='both', show=True)
|
89 |
+
|
90 |
+
>>> x = [-2, 1, -2, 2, 1, 1, 3, 0]
|
91 |
+
>>> # set threshold = 2
|
92 |
+
>>> detect_peaks(x, threshold = 2, show=True)
|
93 |
+
|
94 |
+
>>> x = [-2, 1, -2, 2, 1, 1, 3, 0]
|
95 |
+
>>> fig, axs = plt.subplots(ncols=2, nrows=1, figsize=(10, 4))
|
96 |
+
>>> detect_peaks(x, show=True, ax=axs[0], threshold=0.5, title=False)
|
97 |
+
>>> detect_peaks(x, show=True, ax=axs[1], threshold=1.5, title=False)
|
98 |
+
|
99 |
+
Version history
|
100 |
+
---------------
|
101 |
+
'1.0.6':
|
102 |
+
Fix issue of when specifying ax object only the first plot was shown
|
103 |
+
Add parameter to choose if a title is shown and input a title
|
104 |
+
'1.0.5':
|
105 |
+
The sign of `mph` is inverted if parameter `valley` is True
|
106 |
+
|
107 |
+
"""
|
108 |
+
|
109 |
+
x = np.atleast_1d(x).astype('float64')
|
110 |
+
if x.size < 3:
|
111 |
+
return np.array([], dtype=int)
|
112 |
+
if valley:
|
113 |
+
x = -x
|
114 |
+
if mph is not None:
|
115 |
+
mph = -mph
|
116 |
+
# find indices of all peaks
|
117 |
+
dx = x[1:] - x[:-1]
|
118 |
+
# handle NaN's
|
119 |
+
indnan = np.where(np.isnan(x))[0]
|
120 |
+
if indnan.size:
|
121 |
+
x[indnan] = np.inf
|
122 |
+
dx[np.where(np.isnan(dx))[0]] = np.inf
|
123 |
+
ine, ire, ife = np.array([[], [], []], dtype=int)
|
124 |
+
if not edge:
|
125 |
+
ine = np.where((np.hstack((dx, 0)) < 0) & (np.hstack((0, dx)) > 0))[0]
|
126 |
+
else:
|
127 |
+
if edge.lower() in ['rising', 'both']:
|
128 |
+
ire = np.where((np.hstack((dx, 0)) <= 0) & (np.hstack((0, dx)) > 0))[0]
|
129 |
+
if edge.lower() in ['falling', 'both']:
|
130 |
+
ife = np.where((np.hstack((dx, 0)) < 0) & (np.hstack((0, dx)) >= 0))[0]
|
131 |
+
ind = np.unique(np.hstack((ine, ire, ife)))
|
132 |
+
# handle NaN's
|
133 |
+
if ind.size and indnan.size:
|
134 |
+
# NaN's and values close to NaN's cannot be peaks
|
135 |
+
ind = ind[np.in1d(ind, np.unique(np.hstack((indnan, indnan-1, indnan+1))), invert=True)]
|
136 |
+
# first and last values of x cannot be peaks
|
137 |
+
if ind.size and ind[0] == 0:
|
138 |
+
ind = ind[1:]
|
139 |
+
if ind.size and ind[-1] == x.size-1:
|
140 |
+
ind = ind[:-1]
|
141 |
+
# remove peaks < minimum peak height
|
142 |
+
if ind.size and mph is not None:
|
143 |
+
ind = ind[x[ind] >= mph]
|
144 |
+
# remove peaks - neighbors < threshold
|
145 |
+
if ind.size and threshold > 0:
|
146 |
+
dx = np.min(np.vstack([x[ind]-x[ind-1], x[ind]-x[ind+1]]), axis=0)
|
147 |
+
ind = np.delete(ind, np.where(dx < threshold)[0])
|
148 |
+
# detect small peaks closer than minimum peak distance
|
149 |
+
if ind.size and mpd > 1:
|
150 |
+
ind = ind[np.argsort(x[ind])][::-1] # sort ind by peak height
|
151 |
+
idel = np.zeros(ind.size, dtype=bool)
|
152 |
+
for i in range(ind.size):
|
153 |
+
if not idel[i]:
|
154 |
+
# keep peaks with the same height if kpsh is True
|
155 |
+
idel = idel | (ind >= ind[i] - mpd) & (ind <= ind[i] + mpd) \
|
156 |
+
& (x[ind[i]] > x[ind] if kpsh else True)
|
157 |
+
idel[i] = 0 # Keep current peak
|
158 |
+
# remove the small peaks and sort back the indices by their occurrence
|
159 |
+
ind = np.sort(ind[~idel])
|
160 |
+
|
161 |
+
if show:
|
162 |
+
if indnan.size:
|
163 |
+
x[indnan] = np.nan
|
164 |
+
if valley:
|
165 |
+
x = -x
|
166 |
+
if mph is not None:
|
167 |
+
mph = -mph
|
168 |
+
_plot(x, mph, mpd, threshold, edge, valley, ax, ind, title)
|
169 |
+
|
170 |
+
return ind, x[ind]
|
171 |
+
|
172 |
+
|
173 |
+
def _plot(x, mph, mpd, threshold, edge, valley, ax, ind, title):
|
174 |
+
"""Plot results of the detect_peaks function, see its help."""
|
175 |
+
try:
|
176 |
+
import matplotlib.pyplot as plt
|
177 |
+
except ImportError:
|
178 |
+
print('matplotlib is not available.')
|
179 |
+
else:
|
180 |
+
if ax is None:
|
181 |
+
_, ax = plt.subplots(1, 1, figsize=(8, 4))
|
182 |
+
no_ax = True
|
183 |
+
else:
|
184 |
+
no_ax = False
|
185 |
+
|
186 |
+
ax.plot(x, 'b', lw=1)
|
187 |
+
if ind.size:
|
188 |
+
label = 'valley' if valley else 'peak'
|
189 |
+
label = label + 's' if ind.size > 1 else label
|
190 |
+
ax.plot(ind, x[ind], '+', mfc=None, mec='r', mew=2, ms=8,
|
191 |
+
label='%d %s' % (ind.size, label))
|
192 |
+
ax.legend(loc='best', framealpha=.5, numpoints=1)
|
193 |
+
ax.set_xlim(-.02*x.size, x.size*1.02-1)
|
194 |
+
ymin, ymax = x[np.isfinite(x)].min(), x[np.isfinite(x)].max()
|
195 |
+
yrange = ymax - ymin if ymax > ymin else 1
|
196 |
+
ax.set_ylim(ymin - 0.1*yrange, ymax + 0.1*yrange)
|
197 |
+
ax.set_xlabel('Data #', fontsize=14)
|
198 |
+
ax.set_ylabel('Amplitude', fontsize=14)
|
199 |
+
if title:
|
200 |
+
if not isinstance(title, str):
|
201 |
+
mode = 'Valley detection' if valley else 'Peak detection'
|
202 |
+
title = "%s (mph=%s, mpd=%d, threshold=%s, edge='%s')"% \
|
203 |
+
(mode, str(mph), mpd, str(threshold), edge)
|
204 |
+
ax.set_title(title)
|
205 |
+
# plt.grid()
|
206 |
+
if no_ax:
|
207 |
+
plt.show()
|
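
As a standalone illustration (not part of the commit), the modified detect_peaks above returns both the pick indices and their amplitudes; a quick run on a synthetic probability-like trace, assuming the repository root is importable, looks like this:

# Standalone illustration of the modified detect_peaks, which returns (indices, amplitudes).
import numpy as np
from phasenet.detect_peaks import detect_peaks  # assumes the repository root is on PYTHONPATH

t = np.arange(1000)
x = np.zeros(1000)
for center in (250, 550):  # two synthetic "probability" bumps, 300 samples apart
    x += np.exp(-((t - center) ** 2) / (2 * 10.0 ** 2))

# mph: minimum peak height, mpd: minimum spacing between picks (in samples)
idx, amp = detect_peaks(x, mph=0.3, mpd=50)
print(idx, amp)  # roughly [250 550] with amplitudes close to 1.0
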
phasenet/model.py
ADDED
@@ -0,0 +1,489 @@
|
1 |
+
import tensorflow as tf
|
2 |
+
tf.compat.v1.disable_eager_execution()
|
3 |
+
import numpy as np
|
4 |
+
import logging
|
5 |
+
import warnings
|
6 |
+
warnings.filterwarnings('ignore', category=UserWarning)
|
7 |
+
|
8 |
+
class ModelConfig:
|
9 |
+
|
10 |
+
batch_size = 20
|
11 |
+
depths = 5
|
12 |
+
filters_root = 8
|
13 |
+
kernel_size = [7, 1]
|
14 |
+
pool_size = [4, 1]
|
15 |
+
dilation_rate = [1, 1]
|
16 |
+
class_weights = [1.0, 1.0, 1.0]
|
17 |
+
loss_type = "cross_entropy"
|
18 |
+
weight_decay = 0.0
|
19 |
+
optimizer = "adam"
|
20 |
+
momentum = 0.9
|
21 |
+
learning_rate = 0.01
|
22 |
+
decay_step = 1e9
|
23 |
+
decay_rate = 0.9
|
24 |
+
drop_rate = 0.0
|
25 |
+
summary = True
|
26 |
+
|
27 |
+
X_shape = [3000, 1, 3]
|
28 |
+
n_channel = X_shape[-1]
|
29 |
+
Y_shape = [3000, 1, 3]
|
30 |
+
n_class = Y_shape[-1]
|
31 |
+
|
32 |
+
def __init__(self, **kwargs):
|
33 |
+
for k,v in kwargs.items():
|
34 |
+
setattr(self, k, v)
|
35 |
+
|
36 |
+
def update_args(self, args):
|
37 |
+
for k,v in vars(args).items():
|
38 |
+
setattr(self, k, v)
|
39 |
+
|
40 |
+
|
41 |
+
def crop_and_concat(net1, net2):
|
42 |
+
"""
|
43 |
+
the size(net1) <= size(net2)
|
44 |
+
"""
|
45 |
+
# net1_shape = net1.get_shape().as_list()
|
46 |
+
# net2_shape = net2.get_shape().as_list()
|
47 |
+
# # print(net1_shape)
|
48 |
+
# # print(net2_shape)
|
49 |
+
# # if net2_shape[1] >= net1_shape[1] and net2_shape[2] >= net1_shape[2]:
|
50 |
+
# offsets = [0, (net2_shape[1] - net1_shape[1]) // 2, (net2_shape[2] - net1_shape[2]) // 2, 0]
|
51 |
+
# size = [-1, net1_shape[1], net1_shape[2], -1]
|
52 |
+
# net2_resize = tf.slice(net2, offsets, size)
|
53 |
+
# return tf.concat([net1, net2_resize], 3)
|
54 |
+
|
55 |
+
## dynamic shape
|
56 |
+
chn1 = net1.get_shape().as_list()[-1]
|
57 |
+
chn2 = net2.get_shape().as_list()[-1]
|
58 |
+
net1_shape = tf.shape(net1)
|
59 |
+
net2_shape = tf.shape(net2)
|
60 |
+
# print(net1_shape)
|
61 |
+
# print(net2_shape)
|
62 |
+
# if net2_shape[1] >= net1_shape[1] and net2_shape[2] >= net1_shape[2]:
|
63 |
+
offsets = [0, (net2_shape[1] - net1_shape[1]) // 2, (net2_shape[2] - net1_shape[2]) // 2, 0]
|
64 |
+
size = [-1, net1_shape[1], net1_shape[2], -1]
|
65 |
+
net2_resize = tf.slice(net2, offsets, size)
|
66 |
+
|
67 |
+
out = tf.concat([net1, net2_resize], 3)
|
68 |
+
out.set_shape([None, None, None, chn1+chn2])
|
69 |
+
|
70 |
+
return out
|
71 |
+
|
72 |
+
# else:
|
73 |
+
# offsets = [0, (net1_shape[1] - net2_shape[1]) // 2, (net1_shape[2] - net2_shape[2]) // 2, 0]
|
74 |
+
# size = [-1, net2_shape[1], net2_shape[2], -1]
|
75 |
+
# net1_resize = tf.slice(net1, offsets, size)
|
76 |
+
# return tf.concat([net1_resize, net2], 3)
|
77 |
+
|
78 |
+
|
79 |
+
def crop_only(net1, net2):
|
80 |
+
"""
|
81 |
+
the size(net1) <= size(net2)
|
82 |
+
"""
|
83 |
+
net1_shape = net1.get_shape().as_list()
|
84 |
+
net2_shape = net2.get_shape().as_list()
|
85 |
+
# print(net1_shape)
|
86 |
+
# print(net2_shape)
|
87 |
+
# if net2_shape[1] >= net1_shape[1] and net2_shape[2] >= net1_shape[2]:
|
88 |
+
offsets = [0, (net2_shape[1] - net1_shape[1]) // 2, (net2_shape[2] - net1_shape[2]) // 2, 0]
|
89 |
+
size = [-1, net1_shape[1], net1_shape[2], -1]
|
90 |
+
net2_resize = tf.slice(net2, offsets, size)
|
91 |
+
#return tf.concat([net1, net2_resize], 3)
|
92 |
+
return net2_resize
|
93 |
+
|
94 |
+
class UNet:
|
95 |
+
def __init__(self, config=ModelConfig(), input_batch=None, mode='train'):
|
96 |
+
self.depths = config.depths
|
97 |
+
self.filters_root = config.filters_root
|
98 |
+
self.kernel_size = config.kernel_size
|
99 |
+
self.dilation_rate = config.dilation_rate
|
100 |
+
self.pool_size = config.pool_size
|
101 |
+
self.X_shape = config.X_shape
|
102 |
+
self.Y_shape = config.Y_shape
|
103 |
+
self.n_channel = config.n_channel
|
104 |
+
self.n_class = config.n_class
|
105 |
+
self.class_weights = config.class_weights
|
106 |
+
self.batch_size = config.batch_size
|
107 |
+
self.loss_type = config.loss_type
|
108 |
+
self.weight_decay = config.weight_decay
|
109 |
+
self.optimizer = config.optimizer
|
110 |
+
self.learning_rate = config.learning_rate
|
111 |
+
self.decay_step = config.decay_step
|
112 |
+
self.decay_rate = config.decay_rate
|
113 |
+
self.momentum = config.momentum
|
114 |
+
self.global_step = tf.compat.v1.get_variable(name="global_step", initializer=0, dtype=tf.int32)
|
115 |
+
self.summary_train = []
|
116 |
+
self.summary_valid = []
|
117 |
+
|
118 |
+
self.build(input_batch, mode=mode)
|
119 |
+
|
120 |
+
def add_placeholders(self, input_batch=None, mode="train"):
|
121 |
+
if input_batch is None:
|
122 |
+
# self.X = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None, self.X_shape[-3], self.X_shape[-2], self.X_shape[-1]], name='X')
|
123 |
+
# self.Y = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None, self.Y_shape[-3], self.Y_shape[-2], self.n_class], name='y')
|
124 |
+
self.X = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None, None, None, self.X_shape[-1]], name='X')
|
125 |
+
self.Y = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None, None, None, self.n_class], name='y')
|
126 |
+
else:
|
127 |
+
self.X = input_batch[0]
|
128 |
+
if mode in ["train", "valid", "test"]:
|
129 |
+
self.Y = input_batch[1]
|
130 |
+
self.input_batch = input_batch
|
131 |
+
|
132 |
+
self.is_training = tf.compat.v1.placeholder(dtype=tf.bool, name="is_training")
|
133 |
+
# self.keep_prob = tf.compat.v1.placeholder(dtype=tf.float32, name="keep_prob")
|
134 |
+
self.drop_rate = tf.compat.v1.placeholder(dtype=tf.float32, name="drop_rate")
|
135 |
+
|
136 |
+
def add_prediction_op(self):
|
137 |
+
logging.info("Model: depths {depths}, filters {filters}, "
|
138 |
+
"filter size {kernel_size[0]}x{kernel_size[1]}, "
|
139 |
+
"pool size: {pool_size[0]}x{pool_size[1]}, "
|
140 |
+
"dilation rate: {dilation_rate[0]}x{dilation_rate[1]}".format(
|
141 |
+
depths=self.depths,
|
142 |
+
filters=self.filters_root,
|
143 |
+
kernel_size=self.kernel_size,
|
144 |
+
dilation_rate=self.dilation_rate,
|
145 |
+
pool_size=self.pool_size))
|
146 |
+
|
147 |
+
if self.weight_decay > 0:
|
148 |
+
weight_decay = tf.constant(self.weight_decay, dtype=tf.float32, name="weight_constant")
|
149 |
+
self.regularizer = tf.keras.regularizers.l2(l=0.5 * (weight_decay))
|
150 |
+
else:
|
151 |
+
self.regularizer = None
|
152 |
+
|
153 |
+
self.initializer = tf.compat.v1.keras.initializers.VarianceScaling(scale=1.0, mode="fan_avg", distribution="uniform")
|
154 |
+
|
155 |
+
# down sample layers
|
156 |
+
convs = [None] * self.depths # store output of each depth
|
157 |
+
|
158 |
+
with tf.compat.v1.variable_scope("Input"):
|
159 |
+
net = self.X
|
160 |
+
net = tf.compat.v1.layers.conv2d(net,
|
161 |
+
filters=self.filters_root,
|
162 |
+
kernel_size=self.kernel_size,
|
163 |
+
activation=None,
|
164 |
+
padding='same',
|
165 |
+
dilation_rate=self.dilation_rate,
|
166 |
+
kernel_initializer=self.initializer,
|
167 |
+
kernel_regularizer=self.regularizer,
|
168 |
+
name="input_conv")
|
169 |
+
net = tf.compat.v1.layers.batch_normalization(net,
|
170 |
+
training=self.is_training,
|
171 |
+
name="input_bn")
|
172 |
+
net = tf.nn.relu(net,
|
173 |
+
name="input_relu")
|
174 |
+
# net = tf.nn.dropout(net, self.keep_prob)
|
175 |
+
net = tf.compat.v1.layers.dropout(net,
|
176 |
+
rate=self.drop_rate,
|
177 |
+
training=self.is_training,
|
178 |
+
name="input_dropout")
|
179 |
+
|
180 |
+
|
181 |
+
for depth in range(0, self.depths):
|
182 |
+
with tf.compat.v1.variable_scope("DownConv_%d" % depth):
|
183 |
+
filters = int(2**(depth) * self.filters_root)
|
184 |
+
|
185 |
+
net = tf.compat.v1.layers.conv2d(net,
|
186 |
+
filters=filters,
|
187 |
+
kernel_size=self.kernel_size,
|
188 |
+
activation=None,
|
189 |
+
use_bias=False,
|
190 |
+
padding='same',
|
191 |
+
dilation_rate=self.dilation_rate,
|
192 |
+
kernel_initializer=self.initializer,
|
193 |
+
kernel_regularizer=self.regularizer,
|
194 |
+
name="down_conv1_{}".format(depth + 1))
|
195 |
+
net = tf.compat.v1.layers.batch_normalization(net,
|
196 |
+
training=self.is_training,
|
197 |
+
name="down_bn1_{}".format(depth + 1))
|
198 |
+
net = tf.nn.relu(net,
|
199 |
+
name="down_relu1_{}".format(depth+1))
|
200 |
+
net = tf.compat.v1.layers.dropout(net,
|
201 |
+
rate=self.drop_rate,
|
202 |
+
training=self.is_training,
|
203 |
+
name="down_dropout1_{}".format(depth + 1))
|
204 |
+
|
205 |
+
convs[depth] = net
|
206 |
+
|
207 |
+
if depth < self.depths - 1:
|
208 |
+
net = tf.compat.v1.layers.conv2d(net,
|
209 |
+
filters=filters,
|
210 |
+
kernel_size=self.kernel_size,
|
211 |
+
strides=self.pool_size,
|
212 |
+
activation=None,
|
213 |
+
use_bias=False,
|
214 |
+
padding='same',
|
215 |
+
dilation_rate=self.dilation_rate,
|
216 |
+
kernel_initializer=self.initializer,
|
217 |
+
kernel_regularizer=self.regularizer,
|
218 |
+
name="down_conv3_{}".format(depth + 1))
|
219 |
+
net = tf.compat.v1.layers.batch_normalization(net,
|
220 |
+
training=self.is_training,
|
221 |
+
name="down_bn3_{}".format(depth + 1))
|
222 |
+
net = tf.nn.relu(net,
|
223 |
+
name="down_relu3_{}".format(depth+1))
|
224 |
+
net = tf.compat.v1.layers.dropout(net,
|
225 |
+
rate=self.drop_rate,
|
226 |
+
training=self.is_training,
|
227 |
+
name="down_dropout3_{}".format(depth + 1))
|
228 |
+
|
229 |
+
|
230 |
+
# up layers
|
231 |
+
for depth in range(self.depths - 2, -1, -1):
|
232 |
+
with tf.compat.v1.variable_scope("UpConv_%d" % depth):
|
233 |
+
filters = int(2**(depth) * self.filters_root)
|
234 |
+
net = tf.compat.v1.layers.conv2d_transpose(net,
|
235 |
+
filters=filters,
|
236 |
+
kernel_size=self.kernel_size,
|
237 |
+
strides=self.pool_size,
|
238 |
+
activation=None,
|
239 |
+
use_bias=False,
|
240 |
+
padding="same",
|
241 |
+
kernel_initializer=self.initializer,
|
242 |
+
kernel_regularizer=self.regularizer,
|
243 |
+
name="up_conv0_{}".format(depth+1))
|
244 |
+
net = tf.compat.v1.layers.batch_normalization(net,
|
245 |
+
training=self.is_training,
|
246 |
+
name="up_bn0_{}".format(depth + 1))
|
247 |
+
net = tf.nn.relu(net,
|
248 |
+
name="up_relu0_{}".format(depth+1))
|
249 |
+
net = tf.compat.v1.layers.dropout(net,
|
250 |
+
rate=self.drop_rate,
|
251 |
+
training=self.is_training,
|
252 |
+
name="up_dropout0_{}".format(depth + 1))
|
253 |
+
|
254 |
+
|
255 |
+
#skip connection
|
256 |
+
net = crop_and_concat(convs[depth], net)
|
257 |
+
#net = crop_only(convs[depth], net)
|
258 |
+
|
259 |
+
net = tf.compat.v1.layers.conv2d(net,
|
260 |
+
filters=filters,
|
261 |
+
kernel_size=self.kernel_size,
|
262 |
+
activation=None,
|
263 |
+
use_bias=False,
|
264 |
+
padding='same',
|
265 |
+
dilation_rate=self.dilation_rate,
|
266 |
+
kernel_initializer=self.initializer,
|
267 |
+
kernel_regularizer=self.regularizer,
|
268 |
+
name="up_conv1_{}".format(depth + 1))
|
269 |
+
net = tf.compat.v1.layers.batch_normalization(net,
|
270 |
+
training=self.is_training,
|
271 |
+
name="up_bn1_{}".format(depth + 1))
|
272 |
+
net = tf.nn.relu(net,
|
273 |
+
name="up_relu1_{}".format(depth + 1))
|
274 |
+
net = tf.compat.v1.layers.dropout(net,
|
275 |
+
rate=self.drop_rate,
|
276 |
+
training=self.is_training,
|
277 |
+
name="up_dropout1_{}".format(depth + 1))
|
278 |
+
|
279 |
+
|
280 |
+
# Output Map
|
281 |
+
with tf.compat.v1.variable_scope("Output"):
|
282 |
+
net = tf.compat.v1.layers.conv2d(net,
|
283 |
+
filters=self.n_class,
|
284 |
+
kernel_size=(1,1),
|
285 |
+
activation=None,
|
286 |
+
padding='same',
|
287 |
+
#dilation_rate=self.dilation_rate,
|
288 |
+
kernel_initializer=self.initializer,
|
289 |
+
kernel_regularizer=self.regularizer,
|
290 |
+
name="output_conv")
|
291 |
+
# net = tf.nn.relu(net,
|
292 |
+
# name="output_relu")
|
293 |
+
# net = tf.compat.v1.layers.dropout(net,
|
294 |
+
# rate=self.drop_rate,
|
295 |
+
# training=self.is_training,
|
296 |
+
# name="output_dropout")
|
297 |
+
# net = tf.compat.v1.layers.batch_normalization(net,
|
298 |
+
# training=self.is_training,
|
299 |
+
# name="output_bn")
|
300 |
+
output = net
|
301 |
+
|
302 |
+
with tf.compat.v1.variable_scope("representation"):
|
303 |
+
self.representation = convs[-1]
|
304 |
+
|
305 |
+
with tf.compat.v1.variable_scope("logits"):
|
306 |
+
self.logits = output
|
307 |
+
tmp = tf.compat.v1.summary.histogram("logits", self.logits)
|
308 |
+
self.summary_train.append(tmp)
|
309 |
+
|
310 |
+
with tf.compat.v1.variable_scope("preds"):
|
311 |
+
self.preds = tf.nn.softmax(output)
|
312 |
+
tmp = tf.compat.v1.summary.histogram("preds", self.preds)
|
313 |
+
self.summary_train.append(tmp)
|
314 |
+
|
315 |
+
def add_loss_op(self):
|
316 |
+
if self.loss_type == "cross_entropy":
|
317 |
+
with tf.compat.v1.variable_scope("cross_entropy"):
|
318 |
+
flat_logits = tf.reshape(self.logits, [-1, self.n_class], name="logits")
|
319 |
+
flat_labels = tf.reshape(self.Y, [-1, self.n_class], name="labels")
|
320 |
+
if (np.array(self.class_weights) != 1).any():
|
321 |
+
class_weights = tf.constant(np.array(self.class_weights, dtype=np.float32), name="class_weights")
|
322 |
+
weight_map = tf.multiply(flat_labels, class_weights)
|
323 |
+
weight_map = tf.reduce_sum(input_tensor=weight_map, axis=1)
|
324 |
+
loss_map = tf.nn.softmax_cross_entropy_with_logits(logits=flat_logits,
|
325 |
+
labels=flat_labels)
|
326 |
+
|
327 |
+
weighted_loss = tf.multiply(loss_map, weight_map)
|
328 |
+
loss = tf.reduce_mean(input_tensor=weighted_loss)
|
329 |
+
else:
|
330 |
+
loss = tf.reduce_mean(input_tensor=tf.nn.softmax_cross_entropy_with_logits(logits=flat_logits,
|
331 |
+
labels=flat_labels))
|
332 |
+
|
333 |
+
elif self.loss_type == "IOU":
|
334 |
+
with tf.compat.v1.variable_scope("IOU"):
|
335 |
+
eps = 1e-7
|
336 |
+
loss = 0
|
337 |
+
for i in range(1, self.n_class):
|
338 |
+
intersection = eps + tf.reduce_sum(input_tensor=self.preds[:,:,:,i] * self.Y[:,:,:,i], axis=[1,2])
|
339 |
+
union = eps + tf.reduce_sum(input_tensor=self.preds[:,:,:,i], axis=[1,2]) + tf.reduce_sum(input_tensor=self.Y[:,:,:,i], axis=[1,2])
|
340 |
+
loss += 1 - tf.reduce_mean(input_tensor=intersection / union)
|
341 |
+
elif self.loss_type == "mean_squared":
|
342 |
+
with tf.compat.v1.variable_scope("mean_squared"):
|
343 |
+
flat_logits = tf.reshape(self.logits, [-1, self.n_class], name="logits")
|
344 |
+
flat_labels = tf.reshape(self.Y, [-1, self.n_class], name="labels")
|
345 |
+
with tf.compat.v1.variable_scope("mean_squared"):
|
346 |
+
loss = tf.compat.v1.losses.mean_squared_error(labels=flat_labels, predictions=flat_logits)
|
347 |
+
else:
|
348 |
+
raise ValueError("Unknown loss function: " % self.loss_type)
|
349 |
+
|
350 |
+
tmp = tf.compat.v1.summary.scalar("train_loss", loss)
|
351 |
+
self.summary_train.append(tmp)
|
352 |
+
tmp = tf.compat.v1.summary.scalar("valid_loss", loss)
|
353 |
+
self.summary_valid.append(tmp)
|
354 |
+
|
355 |
+
if self.weight_decay > 0:
|
356 |
+
with tf.compat.v1.name_scope('weight_loss'):
|
357 |
+
tmp = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES)
|
358 |
+
weight_loss = tf.add_n(tmp, name="weight_loss")
|
359 |
+
self.loss = loss + weight_loss
|
360 |
+
else:
|
361 |
+
self.loss = loss
|
362 |
+
|
363 |
+
def add_training_op(self):
|
364 |
+
if self.optimizer == "momentum":
|
365 |
+
self.learning_rate_node = tf.compat.v1.train.exponential_decay(learning_rate=self.learning_rate,
|
366 |
+
global_step=self.global_step,
|
367 |
+
decay_steps=self.decay_step,
|
368 |
+
decay_rate=self.decay_rate,
|
369 |
+
staircase=True)
|
370 |
+
optimizer = tf.compat.v1.train.MomentumOptimizer(learning_rate=self.learning_rate_node,
|
371 |
+
momentum=self.momentum)
|
372 |
+
elif self.optimizer == "adam":
|
373 |
+
self.learning_rate_node = tf.compat.v1.train.exponential_decay(learning_rate=self.learning_rate,
|
374 |
+
global_step=self.global_step,
|
375 |
+
decay_steps=self.decay_step,
|
376 |
+
decay_rate=self.decay_rate,
|
377 |
+
staircase=True)
|
378 |
+
|
379 |
+
optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=self.learning_rate_node)
|
380 |
+
update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)
|
381 |
+
with tf.control_dependencies(update_ops):
|
382 |
+
self.train_op = optimizer.minimize(self.loss, global_step=self.global_step)
|
383 |
+
tmp = tf.compat.v1.summary.scalar("learning_rate", self.learning_rate_node)
|
384 |
+
self.summary_train.append(tmp)
|
385 |
+
|
386 |
+
def add_metrics_op(self):
|
387 |
+
with tf.compat.v1.variable_scope("metrics"):
|
388 |
+
|
389 |
+
Y= tf.argmax(input=self.Y, axis=-1)
|
390 |
+
confusion_matrix = tf.cast(tf.math.confusion_matrix(
|
391 |
+
labels=tf.reshape(Y, [-1]),
|
392 |
+
predictions=tf.reshape(self.preds, [-1]),
|
393 |
+
num_classes=self.n_class, name='confusion_matrix'),
|
394 |
+
dtype=tf.float32)
|
395 |
+
|
396 |
+
# with tf.variable_scope("P"):
|
397 |
+
c = tf.constant(1e-7, dtype=tf.float32)
|
398 |
+
precision_P = (confusion_matrix[1,1] + c) / (tf.reduce_sum(input_tensor=confusion_matrix[:,1]) + c)
|
399 |
+
recall_P = (confusion_matrix[1,1] + c) / (tf.reduce_sum(input_tensor=confusion_matrix[1,:]) + c)
|
400 |
+
f1_P = 2 * precision_P * recall_P / (precision_P + recall_P)
|
401 |
+
|
402 |
+
tmp1 = tf.compat.v1.summary.scalar("train_precision_p", precision_P)
|
403 |
+
tmp2 = tf.compat.v1.summary.scalar("train_recall_p", recall_P)
|
404 |
+
tmp3 = tf.compat.v1.summary.scalar("train_f1_p", f1_P)
|
405 |
+
self.summary_train.extend([tmp1, tmp2, tmp3])
|
406 |
+
|
407 |
+
tmp1 = tf.compat.v1.summary.scalar("valid_precision_p", precision_P)
|
408 |
+
tmp2 = tf.compat.v1.summary.scalar("valid_recall_p", recall_P)
|
409 |
+
tmp3 = tf.compat.v1.summary.scalar("valid_f1_p", f1_P)
|
410 |
+
self.summary_valid.extend([tmp1, tmp2, tmp3])
|
411 |
+
|
412 |
+
# with tf.variable_scope("S"):
|
413 |
+
precision_S = (confusion_matrix[2,2] + c) / (tf.reduce_sum(input_tensor=confusion_matrix[:,2]) + c)
|
414 |
+
recall_S = (confusion_matrix[2,2] + c) / (tf.reduce_sum(input_tensor=confusion_matrix[2,:]) + c)
|
415 |
+
f1_S = 2 * precision_S * recall_S / (precision_S + recall_S)
|
416 |
+
|
417 |
+
tmp1 = tf.compat.v1.summary.scalar("train_precision_s", precision_S)
|
418 |
+
tmp2 = tf.compat.v1.summary.scalar("train_recall_s", recall_S)
|
419 |
+
tmp3 = tf.compat.v1.summary.scalar("train_f1_s", f1_S)
|
420 |
+
self.summary_train.extend([tmp1, tmp2, tmp3])
|
421 |
+
|
422 |
+
tmp1 = tf.compat.v1.summary.scalar("valid_precision_s", precision_S)
|
423 |
+
tmp2 = tf.compat.v1.summary.scalar("valid_recall_s", recall_S)
|
424 |
+
tmp3 = tf.compat.v1.summary.scalar("valid_f1_s", f1_S)
|
425 |
+
self.summary_valid.extend([tmp1, tmp2, tmp3])
|
426 |
+
|
427 |
+
self.precision = [precision_P, precision_S]
|
428 |
+
self.recall = [recall_P, recall_S]
|
429 |
+
self.f1 = [f1_P, f1_S]
|
430 |
+
|
431 |
+
|
432 |
+
|
433 |
+
def train_on_batch(self, sess, inputs_batch, labels_batch, summary_writer, drop_rate=0.0):
|
434 |
+
feed = {self.X: inputs_batch,
|
435 |
+
self.Y: labels_batch,
|
436 |
+
self.drop_rate: drop_rate,
|
437 |
+
self.is_training: True}
|
438 |
+
|
439 |
+
_, step_summary, step, loss = sess.run([self.train_op,
|
440 |
+
self.summary_train,
|
441 |
+
self.global_step,
|
442 |
+
self.loss],
|
443 |
+
feed_dict=feed)
|
444 |
+
summary_writer.add_summary(step_summary, step)
|
445 |
+
return loss
|
446 |
+
|
447 |
+
def valid_on_batch(self, sess, inputs_batch, labels_batch, summary_writer):
|
448 |
+
feed = {self.X: inputs_batch,
|
449 |
+
self.Y: labels_batch,
|
450 |
+
self.drop_rate: 0,
|
451 |
+
self.is_training: False}
|
452 |
+
|
453 |
+
step_summary, step, loss, preds = sess.run([self.summary_valid,
|
454 |
+
self.global_step,
|
455 |
+
self.loss,
|
456 |
+
self.preds],
|
457 |
+
feed_dict=feed)
|
458 |
+
summary_writer.add_summary(step_summary, step)
|
459 |
+
return loss, preds
|
460 |
+
|
461 |
+
def test_on_batch(self, sess, summary_writer):
|
462 |
+
feed = {self.drop_rate: 0,
|
463 |
+
self.is_training: False}
|
464 |
+
step_summary, step, loss, preds, \
|
465 |
+
X_batch, Y_batch, fname_batch, \
|
466 |
+
itp_batch, its_batch = sess.run([self.summary_valid,
|
467 |
+
self.global_step,
|
468 |
+
self.loss,
|
469 |
+
self.preds,
|
470 |
+
self.X,
|
471 |
+
self.Y,
|
472 |
+
self.input_batch[2],
|
473 |
+
self.input_batch[3],
|
474 |
+
self.input_batch[4]],
|
475 |
+
feed_dict=feed)
|
476 |
+
summary_writer.add_summary(step_summary, step)
|
477 |
+
return loss, preds, X_batch, Y_batch, fname_batch, itp_batch, its_batch
|
478 |
+
|
479 |
+
|
480 |
+
def build(self, input_batch=None, mode='train'):
|
481 |
+
self.add_placeholders(input_batch, mode)
|
482 |
+
self.add_prediction_op()
|
483 |
+
if mode in ["train", "valid", "test"]:
|
484 |
+
self.add_loss_op()
|
485 |
+
self.add_training_op()
|
486 |
+
# self.add_metrics_op()
|
487 |
+
self.summary_train = tf.compat.v1.summary.merge(self.summary_train)
|
488 |
+
self.summary_valid = tf.compat.v1.summary.merge(self.summary_valid)
|
489 |
+
return 0
|
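For reference, a minimal NumPy sketch of the IoU-style loss term computed in add_loss_op above (a rough illustration only; the function name, eps value, and which classes the loop covers are assumptions, not taken from model.py):

import numpy as np

def iou_style_loss(preds, Y, eps=1e-7):
    # preds, Y: [batch, time, station, class] probability arrays
    loss = 0.0
    for i in range(preds.shape[-1]):  # class loop; model.py decides which classes are included
        intersection = eps + np.sum(preds[..., i] * Y[..., i], axis=(1, 2))
        union = eps + np.sum(preds[..., i], axis=(1, 2)) + np.sum(Y[..., i], axis=(1, 2))
        loss += 1.0 - np.mean(intersection / union)
    return loss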
phasenet/postprocess.py
ADDED
@@ -0,0 +1,377 @@
1 |
+
import json
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
from collections import namedtuple
|
5 |
+
from datetime import datetime, timedelta
|
6 |
+
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
import numpy as np
|
9 |
+
from detect_peaks import detect_peaks
|
10 |
+
|
11 |
+
# def extract_picks(preds, fnames=None, station_ids=None, t0=None, config=None):
|
12 |
+
|
13 |
+
# if preds.shape[-1] == 4:
|
14 |
+
# record = namedtuple("phase", ["fname", "station_id", "t0", "p_idx", "p_prob", "s_idx", "s_prob", "ps_idx", "ps_prob"])
|
15 |
+
# else:
|
16 |
+
# record = namedtuple("phase", ["fname", "station_id", "t0", "p_idx", "p_prob", "s_idx", "s_prob"])
|
17 |
+
|
18 |
+
# picks = []
|
19 |
+
# for i, pred in enumerate(preds):
|
20 |
+
|
21 |
+
# if config is None:
|
22 |
+
# mph_p, mph_s, mpd = 0.3, 0.3, 50
|
23 |
+
# else:
|
24 |
+
# mph_p, mph_s, mpd = config.min_p_prob, config.min_s_prob, config.mpd
|
25 |
+
|
26 |
+
# if (fnames is None):
|
27 |
+
# fname = f"{i:04d}"
|
28 |
+
# else:
|
29 |
+
# if isinstance(fnames[i], str):
|
30 |
+
# fname = fnames[i]
|
31 |
+
# else:
|
32 |
+
# fname = fnames[i].decode()
|
33 |
+
|
34 |
+
# if (station_ids is None):
|
35 |
+
# station_id = f"{i:04d}"
|
36 |
+
# else:
|
37 |
+
# if isinstance(station_ids[i], str):
|
38 |
+
# station_id = station_ids[i]
|
39 |
+
# else:
|
40 |
+
# station_id = station_ids[i].decode()
|
41 |
+
|
42 |
+
# if (t0 is None):
|
43 |
+
# start_time = "1970-01-01T00:00:00.000"
|
44 |
+
# else:
|
45 |
+
# if isinstance(t0[i], str):
|
46 |
+
# start_time = t0[i]
|
47 |
+
# else:
|
48 |
+
# start_time = t0[i].decode()
|
49 |
+
|
50 |
+
# p_idx, p_prob, s_idx, s_prob = [], [], [], []
|
51 |
+
# for j in range(pred.shape[1]):
|
52 |
+
# p_idx_, p_prob_ = detect_peaks(pred[:,j,1], mph=mph_p, mpd=mpd, show=False)
|
53 |
+
# s_idx_, s_prob_ = detect_peaks(pred[:,j,2], mph=mph_s, mpd=mpd, show=False)
|
54 |
+
# p_idx.append(list(p_idx_))
|
55 |
+
# p_prob.append(list(p_prob_))
|
56 |
+
# s_idx.append(list(s_idx_))
|
57 |
+
# s_prob.append(list(s_prob_))
|
58 |
+
|
59 |
+
# if pred.shape[-1] == 4:
|
60 |
+
# ps_idx, ps_prob = detect_peaks(pred[:,0,3], mph=0.3, mpd=mpd, show=False)
|
61 |
+
# picks.append(record(fname, station_id, start_time, list(p_idx), list(p_prob), list(s_idx), list(s_prob), list(ps_idx), list(ps_prob)))
|
62 |
+
# else:
|
63 |
+
# picks.append(record(fname, station_id, start_time, list(p_idx), list(p_prob), list(s_idx), list(s_prob)))
|
64 |
+
|
65 |
+
# return picks
|
66 |
+
|
67 |
+
|
68 |
+
def extract_picks(
|
69 |
+
preds,
|
70 |
+
file_names=None,
|
71 |
+
begin_times=None,
|
72 |
+
station_ids=None,
|
73 |
+
dt=0.01,
|
74 |
+
phases=["P", "S"],
|
75 |
+
config=None,
|
76 |
+
waveforms=None,
|
77 |
+
use_amplitude=False,
|
78 |
+
):
|
79 |
+
"""Extract picks from prediction results.
|
80 |
+
Args:
|
81 |
+
preds ([type]): [Nb, Nt, Ns, Nc] "batch, time, station, channel"
|
82 |
+
file_names ([type], optional): [Nb]. Defaults to None.
|
83 |
+
station_ids ([type], optional): [Ns]. Defaults to None.
|
84 |
+
begin_times ([type], optional): [Nb]. Defaults to None.
|
85 |
+
config ([type], optional): [description]. Defaults to None.
|
86 |
+
|
87 |
+
Returns:
|
88 |
+
picks [type]: {file_name, station_id, pick_time, pick_prob, pick_type}
|
89 |
+
"""
|
90 |
+
|
91 |
+
mph = {}
|
92 |
+
if config is None:
|
93 |
+
for x in phases:
|
94 |
+
mph[x] = 0.3
|
95 |
+
mpd = 50
|
96 |
+
pre_idx = int(1 / dt)
|
97 |
+
post_idx = int(4 / dt)
|
98 |
+
else:
|
99 |
+
mph["P"] = config.min_p_prob
|
100 |
+
mph["S"] = config.min_s_prob
|
101 |
+
mph["PS"] = 0.3
|
102 |
+
mpd = config.mpd
|
103 |
+
pre_idx = int(config.pre_sec / dt)
|
104 |
+
post_idx = int(config.post_sec / dt)
|
105 |
+
|
106 |
+
Nb, Nt, Ns, Nc = preds.shape
|
107 |
+
|
108 |
+
if file_names is None:
|
109 |
+
file_names = [f"{i:04d}" for i in range(Nb)]
|
110 |
+
elif not (isinstance(file_names, np.ndarray) or isinstance(file_names, list)):
|
111 |
+
if isinstance(file_names, bytes):
|
112 |
+
file_names = file_names.decode()
|
113 |
+
file_names = [file_names] * Nb
|
114 |
+
else:
|
115 |
+
file_names = [x.decode() if isinstance(x, bytes) else x for x in file_names]
|
116 |
+
|
117 |
+
if begin_times is None:
|
118 |
+
begin_times = ["1970-01-01T00:00:00.000+00:00"] * Nb
|
119 |
+
else:
|
120 |
+
begin_times = [x.decode() if isinstance(x, bytes) else x for x in begin_times]
|
121 |
+
|
122 |
+
picks = []
|
123 |
+
for i in range(Nb):
|
124 |
+
file_name = file_names[i]
|
125 |
+
begin_time = datetime.fromisoformat(begin_times[i])
|
126 |
+
|
127 |
+
for j in range(Ns):
|
128 |
+
if (station_ids is None) or (len(station_ids[i]) == 0):
|
129 |
+
station_id = f"{j:04d}"
|
130 |
+
else:
|
131 |
+
station_id = station_ids[i][j].decode() if isinstance(station_ids[i][j], bytes) else station_ids[i][j]
|
132 |
+
|
133 |
+
if (waveforms is not None) and use_amplitude:
|
134 |
+
amp = np.max(np.abs(waveforms[i, :, j, :]), axis=-1) ## amplitude over three channels
|
135 |
+
for k in range(Nc - 1): # 0-th channel noise
|
136 |
+
idxs, probs = detect_peaks(preds[i, :, j, k + 1], mph=mph[phases[k]], mpd=mpd, show=False)
|
137 |
+
for l, (phase_index, phase_prob) in enumerate(zip(idxs, probs)):
|
138 |
+
pick_time = begin_time + timedelta(seconds=phase_index * dt)
|
139 |
+
pick = {
|
140 |
+
"file_name": file_name,
|
141 |
+
"station_id": station_id,
|
142 |
+
"begin_time": begin_time.isoformat(timespec="milliseconds"),
|
143 |
+
"phase_index": int(phase_index),
|
144 |
+
"phase_time": pick_time.isoformat(timespec="milliseconds"),
|
145 |
+
"phase_score": round(phase_prob, 3),
|
146 |
+
"phase_type": phases[k],
|
147 |
+
"dt": dt,
|
148 |
+
}
|
149 |
+
|
150 |
+
## process waveform
|
151 |
+
if waveforms is not None:
|
152 |
+
tmp = np.zeros((pre_idx + post_idx, 3))
|
153 |
+
lo = phase_index - pre_idx
|
154 |
+
hi = phase_index + post_idx
|
155 |
+
insert_idx = 0
|
156 |
+
if lo < 0:
|
157 |
+
insert_idx = -lo
|
158 |
+
lo = 0
|
159 |
+
if hi > Nt:
|
160 |
+
hi = Nt
|
161 |
+
tmp[insert_idx : insert_idx + hi - lo, :] = waveforms[i, lo:hi, j, :]
|
162 |
+
if use_amplitude:
|
163 |
+
next_pick = idxs[l + 1] if l < len(idxs) - 1 else (phase_index + post_idx * 3)
|
164 |
+
pick["phase_amplitude"] = np.max(
|
165 |
+
amp[phase_index : min(phase_index + post_idx * 3, next_pick)]
|
166 |
+
).item() ## peak amplitude
|
167 |
+
|
168 |
+
picks.append(pick)
|
169 |
+
|
170 |
+
return picks
|
171 |
+
|
172 |
+
|
173 |
+
def extract_amplitude(data, picks, window_p=10, window_s=5, config=None):
|
174 |
+
record = namedtuple("amplitude", ["p_amp", "s_amp"])
|
175 |
+
dt = 0.01 if config is None else config.dt
|
176 |
+
window_p = int(window_p / dt)
|
177 |
+
window_s = int(window_s / dt)
|
178 |
+
amps = []
|
179 |
+
for i, (da, pi) in enumerate(zip(data, picks)):
|
180 |
+
p_amp, s_amp = [], []
|
181 |
+
for j in range(da.shape[1]):
|
182 |
+
amp = np.max(np.abs(da[:, j, :]), axis=-1)
|
183 |
+
# amp = np.median(np.abs(da[:,j,:]), axis=-1)
|
184 |
+
# amp = np.linalg.norm(da[:,j,:], axis=-1)
|
185 |
+
tmp = []
|
186 |
+
for k in range(len(pi.p_idx[j]) - 1):
|
187 |
+
tmp.append(np.max(amp[pi.p_idx[j][k] : min(pi.p_idx[j][k] + window_p, pi.p_idx[j][k + 1])]))
|
188 |
+
if len(pi.p_idx[j]) >= 1:
|
189 |
+
tmp.append(np.max(amp[pi.p_idx[j][-1] : pi.p_idx[j][-1] + window_p]))
|
190 |
+
p_amp.append(tmp)
|
191 |
+
tmp = []
|
192 |
+
for k in range(len(pi.s_idx[j]) - 1):
|
193 |
+
tmp.append(np.max(amp[pi.s_idx[j][k] : min(pi.s_idx[j][k] + window_s, pi.s_idx[j][k + 1])]))
|
194 |
+
if len(pi.s_idx[j]) >= 1:
|
195 |
+
tmp.append(np.max(amp[pi.s_idx[j][-1] : pi.s_idx[j][-1] + window_s]))
|
196 |
+
s_amp.append(tmp)
|
197 |
+
amps.append(record(p_amp, s_amp))
|
198 |
+
return amps
|
199 |
+
|
200 |
+
|
201 |
+
def save_picks(picks, output_dir, amps=None, fname=None):
|
202 |
+
if fname is None:
|
203 |
+
fname = "picks.csv"
|
204 |
+
|
205 |
+
int2s = lambda x: ",".join(["[" + ",".join(map(str, i)) + "]" for i in x])
|
206 |
+
flt2s = lambda x: ",".join(["[" + ",".join(map("{:0.3f}".format, i)) + "]" for i in x])
|
207 |
+
sci2s = lambda x: ",".join(["[" + ",".join(map("{:0.3e}".format, i)) + "]" for i in x])
|
208 |
+
if amps is None:
|
209 |
+
if hasattr(picks[0], "ps_idx"):
|
210 |
+
with open(os.path.join(output_dir, fname), "w") as fp:
|
211 |
+
fp.write("fname\tt0\tp_idx\tp_prob\ts_idx\ts_prob\tps_idx\tps_prob\n")
|
212 |
+
for pick in picks:
|
213 |
+
fp.write(
|
214 |
+
f"{pick.fname}\t{pick.t0}\t{int2s(pick.p_idx)}\t{flt2s(pick.p_prob)}\t{int2s(pick.s_idx)}\t{flt2s(pick.s_prob)}\t{int2s(pick.ps_idx)}\t{flt2s(pick.ps_prob)}\n"
|
215 |
+
)
|
216 |
+
fp.close()
|
217 |
+
else:
|
218 |
+
with open(os.path.join(output_dir, fname), "w") as fp:
|
219 |
+
fp.write("fname\tt0\tp_idx\tp_prob\ts_idx\ts_prob\n")
|
220 |
+
for pick in picks:
|
221 |
+
fp.write(
|
222 |
+
f"{pick.fname}\t{pick.t0}\t{int2s(pick.p_idx)}\t{flt2s(pick.p_prob)}\t{int2s(pick.s_idx)}\t{flt2s(pick.s_prob)}\n"
|
223 |
+
)
|
224 |
+
fp.close()
|
225 |
+
else:
|
226 |
+
with open(os.path.join(output_dir, fname), "w") as fp:
|
227 |
+
fp.write("fname\tt0\tp_idx\tp_prob\ts_idx\ts_prob\tp_amp\ts_amp\n")
|
228 |
+
for pick, amp in zip(picks, amps):
|
229 |
+
fp.write(
|
230 |
+
f"{pick.fname}\t{pick.t0}\t{int2s(pick.p_idx)}\t{flt2s(pick.p_prob)}\t{int2s(pick.s_idx)}\t{flt2s(pick.s_prob)}\t{sci2s(amp.p_amp)}\t{sci2s(amp.s_amp)}\n"
|
231 |
+
)
|
232 |
+
fp.close()
|
233 |
+
|
234 |
+
return 0
|
235 |
+
|
236 |
+
|
237 |
+
def calc_timestamp(timestamp, sec):
|
238 |
+
timestamp = datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S.%f") + timedelta(seconds=sec)
|
239 |
+
return timestamp.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]
|
240 |
+
|
241 |
+
|
242 |
+
def save_picks_json(picks, output_dir, dt=0.01, amps=None, fname=None):
|
243 |
+
if fname is None:
|
244 |
+
fname = "picks.json"
|
245 |
+
|
246 |
+
picks_ = []
|
247 |
+
if amps is None:
|
248 |
+
for pick in picks:
|
249 |
+
for idxs, probs in zip(pick.p_idx, pick.p_prob):
|
250 |
+
for idx, prob in zip(idxs, probs):
|
251 |
+
picks_.append(
|
252 |
+
{
|
253 |
+
"id": pick.station_id,
|
254 |
+
"timestamp": calc_timestamp(pick.t0, float(idx) * dt),
|
255 |
+
"prob": prob.astype(float),
|
256 |
+
"type": "p",
|
257 |
+
}
|
258 |
+
)
|
259 |
+
for idxs, probs in zip(pick.s_idx, pick.s_prob):
|
260 |
+
for idx, prob in zip(idxs, probs):
|
261 |
+
picks_.append(
|
262 |
+
{
|
263 |
+
"id": pick.station_id,
|
264 |
+
"timestamp": calc_timestamp(pick.t0, float(idx) * dt),
|
265 |
+
"prob": prob.astype(float),
|
266 |
+
"type": "s",
|
267 |
+
}
|
268 |
+
)
|
269 |
+
else:
|
270 |
+
for pick, amplitude in zip(picks, amps):
|
271 |
+
for idxs, probs, amps in zip(pick.p_idx, pick.p_prob, amplitude.p_amp):
|
272 |
+
for idx, prob, amp in zip(idxs, probs, amps):
|
273 |
+
picks_.append(
|
274 |
+
{
|
275 |
+
"id": pick.station_id,
|
276 |
+
"timestamp": calc_timestamp(pick.t0, float(idx) * dt),
|
277 |
+
"prob": prob.astype(float),
|
278 |
+
"amp": amp.astype(float),
|
279 |
+
"type": "p",
|
280 |
+
}
|
281 |
+
)
|
282 |
+
for idxs, probs, amps in zip(pick.s_idx, pick.s_prob, amplitude.s_amp):
|
283 |
+
for idx, prob, amp in zip(idxs, probs, amps):
|
284 |
+
picks_.append(
|
285 |
+
{
|
286 |
+
"id": pick.station_id,
|
287 |
+
"timestamp": calc_timestamp(pick.t0, float(idx) * dt),
|
288 |
+
"prob": prob.astype(float),
|
289 |
+
"amp": amp.astype(float),
|
290 |
+
"type": "s",
|
291 |
+
}
|
292 |
+
)
|
293 |
+
with open(os.path.join(output_dir, fname), "w") as fp:
|
294 |
+
json.dump(picks_, fp)
|
295 |
+
|
296 |
+
return 0
|
297 |
+
|
298 |
+
|
299 |
+
def convert_true_picks(fname, itp, its, itps=None):
|
300 |
+
true_picks = []
|
301 |
+
if itps is None:
|
302 |
+
record = namedtuple("phase", ["fname", "p_idx", "s_idx"])
|
303 |
+
for i in range(len(fname)):
|
304 |
+
true_picks.append(record(fname[i].decode(), itp[i], its[i]))
|
305 |
+
else:
|
306 |
+
record = namedtuple("phase", ["fname", "p_idx", "s_idx", "ps_idx"])
|
307 |
+
for i in range(len(fname)):
|
308 |
+
true_picks.append(record(fname[i].decode(), itp[i], its[i], itps[i]))
|
309 |
+
|
310 |
+
return true_picks
|
311 |
+
|
312 |
+
|
313 |
+
def calc_metrics(nTP, nP, nT):
|
314 |
+
"""
|
315 |
+
nTP: true positive
|
316 |
+
nP: number of positive picks
|
317 |
+
nT: number of true picks
|
318 |
+
"""
|
319 |
+
precision = nTP / nP
|
320 |
+
recall = nTP / nT
|
321 |
+
f1 = 2 * precision * recall / (precision + recall)
|
322 |
+
return [precision, recall, f1]
|
323 |
+
|
324 |
+
|
325 |
+
def calc_performance(picks, true_picks, tol=3.0, dt=1.0):
|
326 |
+
assert len(picks) == len(true_picks)
|
327 |
+
logging.info("Total records: {}".format(len(picks)))
|
328 |
+
|
329 |
+
count = lambda picks: sum([len(x) for x in picks])
|
330 |
+
metrics = {}
|
331 |
+
for phase in true_picks[0]._fields:
|
332 |
+
if phase == "fname":
|
333 |
+
continue
|
334 |
+
true_positive, positive, true = 0, 0, 0
|
335 |
+
residual = []
|
336 |
+
for i in range(len(true_picks)):
|
337 |
+
true += count(getattr(true_picks[i], phase))
|
338 |
+
positive += count(getattr(picks[i], phase))
|
339 |
+
# print(i, phase, getattr(picks[i], phase), getattr(true_picks[i], phase))
|
340 |
+
diff = dt * (
|
341 |
+
np.array(getattr(picks[i], phase))[:, np.newaxis, :]
|
342 |
+
- np.array(getattr(true_picks[i], phase))[:, :, np.newaxis]
|
343 |
+
)
|
344 |
+
residual.extend(list(diff[np.abs(diff) <= tol]))
|
345 |
+
true_positive += np.sum(np.abs(diff) <= tol)
|
346 |
+
metrics[phase] = calc_metrics(true_positive, positive, true)
|
347 |
+
|
348 |
+
logging.info(f"{phase}-phase:")
|
349 |
+
logging.info(f"True={true}, Positive={positive}, True Positive={true_positive}")
|
350 |
+
logging.info(f"Precision={metrics[phase][0]:.3f}, Recall={metrics[phase][1]:.3f}, F1={metrics[phase][2]:.3f}")
|
351 |
+
logging.info(f"Residual mean={np.mean(residual):.4f}, std={np.std(residual):.4f}")
|
352 |
+
|
353 |
+
return metrics
|
354 |
+
|
355 |
+
|
356 |
+
def save_prob_h5(probs, fnames, output_h5):
|
357 |
+
if fnames is None:
|
358 |
+
fnames = [f"{i:04d}" for i in range(len(probs))]
|
359 |
+
elif type(fnames[0]) is bytes:
|
360 |
+
fnames = [f.decode().rstrip(".npz") for f in fnames]
|
361 |
+
else:
|
362 |
+
fnames = [f.rstrip(".npz") for f in fnames]
|
363 |
+
for prob, fname in zip(probs, fnames):
|
364 |
+
output_h5.create_dataset(fname, data=prob, dtype="float32")
|
365 |
+
return 0
|
366 |
+
|
367 |
+
|
368 |
+
def save_prob(probs, fnames, prob_dir):
|
369 |
+
if fnames is None:
|
370 |
+
fnames = [f"{i:04d}" for i in range(len(probs))]
|
371 |
+
elif type(fnames[0]) is bytes:
|
372 |
+
fnames = [f.decode().rstrip(".npz") for f in fnames]
|
373 |
+
else:
|
374 |
+
fnames = [f.rstrip(".npz") for f in fnames]
|
375 |
+
for prob, fname in zip(probs, fnames):
|
376 |
+
np.savez(os.path.join(prob_dir, fname + ".npz"), prob=prob)
|
377 |
+
return 0
|
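A rough usage sketch of extract_picks (not part of the repository; the dummy prediction array, spike positions, and import path are illustrative assumptions):

import numpy as np
from postprocess import extract_picks  # assumes the phasenet/ directory is on sys.path

# dummy predictions: 1 window, 3000 samples, 1 station, 3 channels (noise, P, S)
preds = np.random.uniform(0.0, 0.1, size=(1, 3000, 1, 3))
preds[0, 1500, 0, 1] = 0.9  # fake P peak
preds[0, 2200, 0, 2] = 0.8  # fake S peak

picks = extract_picks(preds, dt=0.01)  # default thresholds of 0.3 apply when config is None
for p in picks:
    print(p["station_id"], p["phase_type"], p["phase_time"], p["phase_score"])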
phasenet/predict.py
ADDED
@@ -0,0 +1,262 @@
1 |
+
import argparse
|
2 |
+
import logging
|
3 |
+
import multiprocessing
|
4 |
+
import os
|
5 |
+
import pickle
|
6 |
+
import time
|
7 |
+
from functools import partial
|
8 |
+
|
9 |
+
import h5py
|
10 |
+
import numpy as np
|
11 |
+
import pandas as pd
|
12 |
+
import tensorflow as tf
|
13 |
+
from data_reader import DataReader_mseed_array, DataReader_pred
|
14 |
+
from postprocess import (
|
15 |
+
extract_amplitude,
|
16 |
+
extract_picks,
|
17 |
+
save_picks,
|
18 |
+
save_picks_json,
|
19 |
+
save_prob_h5,
|
20 |
+
)
|
21 |
+
from tqdm import tqdm
|
22 |
+
from visulization import plot_waveform
|
23 |
+
|
24 |
+
from model import ModelConfig, UNet
|
25 |
+
|
26 |
+
tf.compat.v1.disable_eager_execution()
|
27 |
+
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
|
28 |
+
|
29 |
+
|
30 |
+
def read_args():
|
31 |
+
parser = argparse.ArgumentParser()
|
32 |
+
parser.add_argument("--batch_size", default=20, type=int, help="batch size")
|
33 |
+
parser.add_argument("--model_dir", help="Checkpoint directory (default: None)")
|
34 |
+
parser.add_argument("--data_dir", default="", help="Input file directory")
|
35 |
+
parser.add_argument("--data_list", default="", help="Input csv file")
|
36 |
+
parser.add_argument("--hdf5_file", default="", help="Input hdf5 file")
|
37 |
+
parser.add_argument("--hdf5_group", default="data", help="data group name in hdf5 file")
|
38 |
+
parser.add_argument("--result_dir", default="results", help="Output directory")
|
39 |
+
parser.add_argument("--result_fname", default="picks", help="Output file")
|
40 |
+
parser.add_argument("--min_p_prob", default=0.3, type=float, help="Probability threshold for P pick")
|
41 |
+
parser.add_argument("--min_s_prob", default=0.3, type=float, help="Probability threshold for S pick")
|
42 |
+
parser.add_argument("--mpd", default=50, type=float, help="Minimum peak distance")
|
43 |
+
parser.add_argument("--amplitude", action="store_true", help="if return amplitude value")
|
44 |
+
parser.add_argument("--format", default="numpy", help="input format")
|
45 |
+
parser.add_argument("--s3_url", default="localhost:9000", help="s3 url")
|
46 |
+
parser.add_argument("--stations", default="", help="seismic station info")
|
47 |
+
parser.add_argument("--plot_figure", action="store_true", help="If plot figure for test")
|
48 |
+
parser.add_argument("--save_prob", action="store_true", help="If save result for test")
|
49 |
+
parser.add_argument("--pre_sec", default=1, type=float, help="Window length before pick")
|
50 |
+
parser.add_argument("--post_sec", default=4, type=float, help="Window length after pick")
|
51 |
+
|
52 |
+
parser.add_argument("--highpass_filter", default=0.0, type=float, help="Highpass filter")
|
53 |
+
parser.add_argument("--response_xml", default=None, type=str, help="response xml file")
|
54 |
+
parser.add_argument("--sampling_rate", default=100, type=float, help="sampling rate")
|
55 |
+
args = parser.parse_args()
|
56 |
+
|
57 |
+
return args
|
58 |
+
|
59 |
+
|
60 |
+
def pred_fn(args, data_reader, figure_dir=None, prob_dir=None, log_dir=None):
|
61 |
+
current_time = time.strftime("%y%m%d-%H%M%S")
|
62 |
+
if log_dir is None:
|
63 |
+
log_dir = os.path.join(args.log_dir, "pred", current_time)
|
64 |
+
if not os.path.exists(log_dir):
|
65 |
+
os.makedirs(log_dir)
|
66 |
+
if (args.plot_figure == True) and (figure_dir is None):
|
67 |
+
figure_dir = os.path.join(log_dir, "figures")
|
68 |
+
if not os.path.exists(figure_dir):
|
69 |
+
os.makedirs(figure_dir)
|
70 |
+
if (args.save_prob == True) and (prob_dir is None):
|
71 |
+
prob_dir = os.path.join(log_dir, "probs")
|
72 |
+
if not os.path.exists(prob_dir):
|
73 |
+
os.makedirs(prob_dir)
|
74 |
+
if args.save_prob:
|
75 |
+
h5 = h5py.File(os.path.join(args.result_dir, "result.h5"), "w", libver="latest")
|
76 |
+
prob_h5 = h5.create_group("/prob")
|
77 |
+
logging.info("Pred log: %s" % log_dir)
|
78 |
+
logging.info("Dataset size: {}".format(data_reader.num_data))
|
79 |
+
|
80 |
+
with tf.compat.v1.name_scope("Input_Batch"):
|
81 |
+
if args.format == "mseed_array":
|
82 |
+
batch_size = 1
|
83 |
+
else:
|
84 |
+
batch_size = args.batch_size
|
85 |
+
dataset = data_reader.dataset(batch_size)
|
86 |
+
batch = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()
|
87 |
+
|
88 |
+
config = ModelConfig(X_shape=data_reader.X_shape)
|
89 |
+
with open(os.path.join(log_dir, "config.log"), "w") as fp:
|
90 |
+
fp.write("\n".join("%s: %s" % item for item in vars(config).items()))
|
91 |
+
|
92 |
+
model = UNet(config=config, input_batch=batch, mode="pred")
|
93 |
+
# model = UNet(config=config, mode="pred")
|
94 |
+
sess_config = tf.compat.v1.ConfigProto()
|
95 |
+
sess_config.gpu_options.allow_growth = True
|
96 |
+
# sess_config.log_device_placement = False
|
97 |
+
|
98 |
+
with tf.compat.v1.Session(config=sess_config) as sess:
|
99 |
+
saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables(), max_to_keep=5)
|
100 |
+
init = tf.compat.v1.global_variables_initializer()
|
101 |
+
sess.run(init)
|
102 |
+
|
103 |
+
latest_check_point = tf.train.latest_checkpoint(args.model_dir)
|
104 |
+
logging.info(f"restoring model {latest_check_point}")
|
105 |
+
saver.restore(sess, latest_check_point)
|
106 |
+
|
107 |
+
picks = []
|
108 |
+
amps = [] if args.amplitude else None
|
109 |
+
if args.plot_figure:
|
110 |
+
multiprocessing.set_start_method("spawn")
|
111 |
+
pool = multiprocessing.Pool(multiprocessing.cpu_count())
|
112 |
+
|
113 |
+
for _ in tqdm(range(0, data_reader.num_data, batch_size), desc="Pred"):
|
114 |
+
if args.amplitude:
|
115 |
+
pred_batch, X_batch, amp_batch, fname_batch, t0_batch, station_batch = sess.run(
|
116 |
+
[model.preds, batch[0], batch[1], batch[2], batch[3], batch[4]],
|
117 |
+
feed_dict={model.drop_rate: 0, model.is_training: False},
|
118 |
+
)
|
119 |
+
# X_batch, amp_batch, fname_batch, t0_batch = sess.run([batch[0], batch[1], batch[2], batch[3]])
|
120 |
+
else:
|
121 |
+
pred_batch, X_batch, fname_batch, t0_batch, station_batch = sess.run(
|
122 |
+
[model.preds, batch[0], batch[1], batch[2], batch[3]],
|
123 |
+
feed_dict={model.drop_rate: 0, model.is_training: False},
|
124 |
+
)
|
125 |
+
# X_batch, fname_batch, t0_batch = sess.run([model.preds, batch[0], batch[1], batch[2]])
|
126 |
+
# pred_batch = []
|
127 |
+
# for i in range(0, len(X_batch), 1):
|
128 |
+
# pred_batch.append(sess.run(model.preds, feed_dict={model.X: X_batch[i:i+1], model.drop_rate: 0, model.is_training: False}))
|
129 |
+
# pred_batch = np.vstack(pred_batch)
|
130 |
+
|
131 |
+
waveforms = None
|
132 |
+
if args.amplitude:
|
133 |
+
waveforms = amp_batch
|
134 |
+
|
135 |
+
picks_ = extract_picks(
|
136 |
+
preds=pred_batch,
|
137 |
+
file_names=fname_batch,
|
138 |
+
station_ids=station_batch,
|
139 |
+
begin_times=t0_batch,
|
140 |
+
config=args,
|
141 |
+
waveforms=waveforms,
|
142 |
+
use_amplitude=args.amplitude,
|
143 |
+
dt=1.0 / args.sampling_rate,
|
144 |
+
)
|
145 |
+
|
146 |
+
picks.extend(picks_)
|
147 |
+
|
148 |
+
## save pick per file
|
149 |
+
if len(fname_batch) == 1:
|
150 |
+
df = pd.DataFrame(picks_)
|
151 |
+
df = df[df["phase_index"] > 10]
|
152 |
+
if not os.path.exists(os.path.join(args.result_dir, "picks")):
|
153 |
+
os.makedirs(os.path.join(args.result_dir, "picks"))
|
154 |
+
df = df[
|
155 |
+
[
|
156 |
+
"station_id",
|
157 |
+
"begin_time",
|
158 |
+
"phase_index",
|
159 |
+
"phase_time",
|
160 |
+
"phase_score",
|
161 |
+
"phase_type",
|
162 |
+
"phase_amplitude",
|
163 |
+
"dt",
|
164 |
+
]
|
165 |
+
]
|
166 |
+
df.to_csv(
|
167 |
+
os.path.join(
|
168 |
+
args.result_dir, "picks", fname_batch[0].decode().split("/")[-1].rstrip(".mseed") + ".csv"
|
169 |
+
),
|
170 |
+
index=False,
|
171 |
+
)
|
172 |
+
|
173 |
+
if args.plot_figure:
|
174 |
+
if not (isinstance(fname_batch, np.ndarray) or isinstance(fname_batch, list)):
|
175 |
+
fname_batch = [fname_batch.decode().rstrip(".mseed") + "_" + x.decode() for x in station_batch]
|
176 |
+
else:
|
177 |
+
fname_batch = [x.decode() for x in fname_batch]
|
178 |
+
pool.starmap(
|
179 |
+
partial(
|
180 |
+
plot_waveform,
|
181 |
+
figure_dir=figure_dir,
|
182 |
+
),
|
183 |
+
# zip(X_batch, pred_batch, [x.decode() for x in fname_batch]),
|
184 |
+
zip(X_batch, pred_batch, fname_batch),
|
185 |
+
)
|
186 |
+
|
187 |
+
if args.save_prob:
|
188 |
+
# save_prob(pred_batch, fname_batch, prob_dir=prob_dir)
|
189 |
+
if not (isinstance(fname_batch, np.ndarray) or isinstance(fname_batch, list)):
|
190 |
+
fname_batch = [fname_batch.decode().rstrip(".mseed") + "_" + x.decode() for x in station_batch]
|
191 |
+
else:
|
192 |
+
fname_batch = [x.decode() for x in fname_batch]
|
193 |
+
save_prob_h5(pred_batch, fname_batch, prob_h5)
|
194 |
+
|
195 |
+
if len(picks) > 0:
|
196 |
+
# save_picks(picks, args.result_dir, amps=amps, fname=args.result_fname+".csv")
|
197 |
+
# save_picks_json(picks, args.result_dir, dt=data_reader.dt, amps=amps, fname=args.result_fname+".json")
|
198 |
+
df = pd.DataFrame(picks)
|
199 |
+
# df["fname"] = df["file_name"]
|
200 |
+
# df["id"] = df["station_id"]
|
201 |
+
# df["timestamp"] = df["phase_time"]
|
202 |
+
# df["prob"] = df["phase_prob"]
|
203 |
+
# df["type"] = df["phase_type"]
|
204 |
+
|
205 |
+
base_columns = [
|
206 |
+
"station_id",
|
207 |
+
"begin_time",
|
208 |
+
"phase_index",
|
209 |
+
"phase_time",
|
210 |
+
"phase_score",
|
211 |
+
"phase_type",
|
212 |
+
"file_name",
|
213 |
+
]
|
214 |
+
if args.amplitude:
|
215 |
+
base_columns.append("phase_amplitude")
|
216 |
+
base_columns.append("phase_amp")
|
217 |
+
df["phase_amp"] = df["phase_amplitude"]
|
218 |
+
|
219 |
+
df = df[base_columns]
|
220 |
+
df.to_csv(os.path.join(args.result_dir, args.result_fname + ".csv"), index=False)
|
221 |
+
|
222 |
+
print(
|
223 |
+
f"Done with {len(df[df['phase_type'] == 'P'])} P-picks and {len(df[df['phase_type'] == 'S'])} S-picks"
|
224 |
+
)
|
225 |
+
else:
|
226 |
+
print(f"Done with 0 P-picks and 0 S-picks")
|
227 |
+
return 0
|
228 |
+
|
229 |
+
|
230 |
+
def main(args):
|
231 |
+
logging.basicConfig(format="%(asctime)s %(message)s", level=logging.INFO)
|
232 |
+
|
233 |
+
with tf.compat.v1.name_scope("create_inputs"):
|
234 |
+
if args.format == "mseed_array":
|
235 |
+
data_reader = DataReader_mseed_array(
|
236 |
+
data_dir=args.data_dir,
|
237 |
+
data_list=args.data_list,
|
238 |
+
stations=args.stations,
|
239 |
+
amplitude=args.amplitude,
|
240 |
+
highpass_filter=args.highpass_filter,
|
241 |
+
)
|
242 |
+
else:
|
243 |
+
data_reader = DataReader_pred(
|
244 |
+
format=args.format,
|
245 |
+
data_dir=args.data_dir,
|
246 |
+
data_list=args.data_list,
|
247 |
+
hdf5_file=args.hdf5_file,
|
248 |
+
hdf5_group=args.hdf5_group,
|
249 |
+
amplitude=args.amplitude,
|
250 |
+
highpass_filter=args.highpass_filter,
|
251 |
+
response_xml=args.response_xml,
|
252 |
+
sampling_rate=args.sampling_rate,
|
253 |
+
)
|
254 |
+
|
255 |
+
pred_fn(args, data_reader, log_dir=args.result_dir)
|
256 |
+
|
257 |
+
return
|
258 |
+
|
259 |
+
|
260 |
+
if __name__ == "__main__":
|
261 |
+
args = read_args()
|
262 |
+
main(args)
|
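A small sketch of inspecting the picks CSV written by pred_fn (the path reflects the default --result_dir/--result_fname values; the 0.5 score threshold is an arbitrary example):

import pandas as pd

picks = pd.read_csv("results/picks.csv", parse_dates=["phase_time"])
confident = picks[picks["phase_score"] >= 0.5]  # keep only higher-confidence picks
print(confident.groupby("phase_type").size())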
phasenet/slide_window.py
ADDED
@@ -0,0 +1,88 @@
1 |
+
import os
|
2 |
+
from collections import defaultdict, namedtuple
|
3 |
+
from datetime import datetime, timedelta
|
4 |
+
from json import dumps
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import tensorflow as tf
|
8 |
+
|
9 |
+
from model import ModelConfig, UNet
|
10 |
+
from postprocess import extract_amplitude, extract_picks
|
11 |
+
import pandas as pd
|
12 |
+
import obspy
|
13 |
+
|
14 |
+
|
15 |
+
tf.compat.v1.disable_eager_execution()
|
16 |
+
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
|
17 |
+
PROJECT_ROOT = os.path.realpath(os.path.join(os.path.dirname(__file__), ".."))
|
18 |
+
|
19 |
+
# load model
|
20 |
+
model = UNet(mode="pred")
|
21 |
+
sess_config = tf.compat.v1.ConfigProto()
|
22 |
+
sess_config.gpu_options.allow_growth = True
|
23 |
+
|
24 |
+
sess = tf.compat.v1.Session(config=sess_config)
|
25 |
+
saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
|
26 |
+
init = tf.compat.v1.global_variables_initializer()
|
27 |
+
sess.run(init)
|
28 |
+
latest_check_point = tf.train.latest_checkpoint(f"{PROJECT_ROOT}/model/190703-214543")
|
29 |
+
print(f"restoring model {latest_check_point}")
|
30 |
+
saver.restore(sess, latest_check_point)
|
31 |
+
|
32 |
+
|
33 |
+
def calc_timestamp(timestamp, sec):
|
34 |
+
timestamp = datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S.%f") + timedelta(seconds=sec)
|
35 |
+
return timestamp.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]
|
36 |
+
|
37 |
+
def format_picks(picks, dt):
|
38 |
+
picks_ = []
|
39 |
+
for pick in picks:
|
40 |
+
for idxs, probs in zip(pick.p_idx, pick.p_prob):
|
41 |
+
for idx, prob in zip(idxs, probs):
|
42 |
+
picks_.append(
|
43 |
+
{
|
44 |
+
"id": pick.fname,
|
45 |
+
"timestamp": calc_timestamp(pick.t0, float(idx) * dt),
|
46 |
+
"prob": prob,
|
47 |
+
"type": "p",
|
48 |
+
}
|
49 |
+
)
|
50 |
+
for idxs, probs in zip(pick.s_idx, pick.s_prob):
|
51 |
+
for idx, prob in zip(idxs, probs):
|
52 |
+
picks_.append(
|
53 |
+
{
|
54 |
+
"id": pick.fname,
|
55 |
+
"timestamp": calc_timestamp(pick.t0, float(idx) * dt),
|
56 |
+
"prob": prob,
|
57 |
+
"type": "s",
|
58 |
+
}
|
59 |
+
)
|
60 |
+
return picks_
|
61 |
+
|
62 |
+
|
63 |
+
stream = obspy.read()
|
64 |
+
stream = stream.sort() ## Assume it is NPZ sorted
|
65 |
+
assert(len(stream) == 3)
|
66 |
+
data = []
|
67 |
+
for trace in stream:
|
68 |
+
data.append(trace.data)
|
69 |
+
data = np.array(data).T
|
70 |
+
assert(data.shape[-1] == 3)
|
71 |
+
|
72 |
+
# data_id = stream[0].get_id()[:-1]
|
73 |
+
# timestamp = stream[0].stats.starttime.datetime.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]
|
74 |
+
|
75 |
+
data = np.stack([data for i in range(10)]) ## Assume 10 windows
|
76 |
+
data = data[:,:,np.newaxis,:] ## batch, nt, dummy_dim, channel
|
77 |
+
print(f"{data.shape = }")
|
78 |
+
data = (data - data.mean(axis=1, keepdims=True))/data.std(axis=1, keepdims=True)
|
79 |
+
|
80 |
+
feed = {model.X: data, model.drop_rate: 0, model.is_training: False}
|
81 |
+
preds = sess.run(model.preds, feed_dict=feed)
|
82 |
+
|
83 |
+
picks = extract_picks(preds, file_names=None, station_ids=None, begin_times=None)
|
84 |
+
# picks = format_picks(picks, dt=0.01)  # not needed: extract_picks already returns a list of dicts
|
85 |
+
|
86 |
+
|
87 |
+
picks = pd.DataFrame(picks)
|
88 |
+
print(picks)
|
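slide_window.py above fakes a batch by stacking the same trace 10 times; a rough sketch of actually cutting a long waveform into fixed-length windows (the 3000-sample window, stride, and zero padding are assumptions, not values fixed by this script):

import numpy as np

def make_windows(data, nt=3000, stride=3000):
    # data: [num_samples, 3] waveform; returns [num_windows, nt, 1, 3], normalized per window
    windows = []
    for i0 in range(0, data.shape[0], stride):
        w = data[i0 : i0 + nt]
        if w.shape[0] < nt:  # zero-pad the last, shorter window
            w = np.pad(w, ((0, nt - w.shape[0]), (0, 0)))
        windows.append(w)
    windows = np.stack(windows)[:, :, np.newaxis, :]
    std = windows.std(axis=1, keepdims=True)
    std[std == 0] = 1.0
    return (windows - windows.mean(axis=1, keepdims=True)) / std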
phasenet/test_app.py
ADDED
@@ -0,0 +1,47 @@
1 |
+
import requests
|
2 |
+
import obspy
|
3 |
+
import numpy as np
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
from datetime import datetime
|
6 |
+
|
7 |
+
### Start running the model first:
|
8 |
+
### FLASK_ENV=development FLASK_APP=app.py flask run
|
9 |
+
|
10 |
+
def read_data(mseed):
|
11 |
+
data = []
|
12 |
+
mseed = mseed.sort()
|
13 |
+
for c in ["E", "N", "Z"]:
|
14 |
+
data.append(mseed.select(channel="*"+c)[0].data)
|
15 |
+
return np.array(data).T
|
16 |
+
|
17 |
+
timestamp = lambda x: x.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]
|
18 |
+
|
19 |
+
## prepare some test data
|
20 |
+
mseed = obspy.read()
|
21 |
+
data = []
|
22 |
+
for i in range(1):
|
23 |
+
data.append(read_data(mseed))
|
24 |
+
data = {
|
25 |
+
"id": ["test01"],
|
26 |
+
"timestamp": [timestamp(datetime.now())],
|
27 |
+
"vec": np.array(data).tolist(),
|
28 |
+
"dt": 0.01
|
29 |
+
}
|
30 |
+
|
31 |
+
## run prediction
|
32 |
+
print(data["id"])
|
33 |
+
resp = requests.get("http://localhost:8000/predict", json=data)
|
34 |
+
# picks = resp.json()["picks"]
|
35 |
+
print(resp.json())
|
36 |
+
|
37 |
+
|
38 |
+
## plot figure
|
39 |
+
plt.figure()
|
40 |
+
plt.plot(np.array(data["data"])[0,:,1])
|
41 |
+
ylim = plt.ylim()
|
42 |
+
plt.plot([picks[0][0][0], picks[0][0][0]], ylim, label="P-phase")
|
43 |
+
plt.text(picks[0][0][0], ylim[1]*0.9, f"{picks[0][1][0]:.2f}")
|
44 |
+
plt.plot([picks[0][2][0], picks[0][2][0]], ylim, label="S-phase")
|
45 |
+
plt.text(picks[0][2][0], ylim[1]*0.9, f"{picks[0][3][0]:.2f}")
|
46 |
+
plt.legend()
|
47 |
+
plt.savefig("test.png")
|
phasenet/train.py
ADDED
@@ -0,0 +1,246 @@
1 |
+
import numpy as np
|
2 |
+
import tensorflow as tf
|
3 |
+
tf.compat.v1.disable_eager_execution()
|
4 |
+
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
|
5 |
+
import argparse, os, time, logging
|
6 |
+
from tqdm import tqdm
|
7 |
+
import pandas as pd
|
8 |
+
import multiprocessing
|
9 |
+
from functools import partial
|
10 |
+
import pickle
|
11 |
+
from model import UNet, ModelConfig
|
12 |
+
from data_reader import DataReader_train, DataReader_test
|
13 |
+
from postprocess import extract_picks, save_picks, save_picks_json, extract_amplitude, convert_true_picks, calc_performance
|
14 |
+
from visulization import plot_waveform
|
15 |
+
from util import EMA, LMA
|
16 |
+
|
17 |
+
def read_args():
|
18 |
+
|
19 |
+
parser = argparse.ArgumentParser()
|
20 |
+
parser.add_argument("--mode", default="train", help="train/train_valid/test/debug")
|
21 |
+
parser.add_argument("--epochs", default=100, type=int, help="number of epochs (default: 10)")
|
22 |
+
parser.add_argument("--batch_size", default=20, type=int, help="batch size")
|
23 |
+
parser.add_argument("--learning_rate", default=0.01, type=float, help="learning rate")
|
24 |
+
parser.add_argument("--drop_rate", default=0.0, type=float, help="dropout rate")
|
25 |
+
parser.add_argument("--decay_step", default=-1, type=int, help="decay step")
|
26 |
+
parser.add_argument("--decay_rate", default=0.9, type=float, help="decay rate")
|
27 |
+
parser.add_argument("--momentum", default=0.9, type=float, help="momentum")
|
28 |
+
parser.add_argument("--optimizer", default="adam", help="optimizer: adam, momentum")
|
29 |
+
parser.add_argument("--summary", default=True, type=bool, help="summary")
|
30 |
+
parser.add_argument("--class_weights", nargs="+", default=[1, 1, 1], type=float, help="class weights")
|
31 |
+
parser.add_argument("--model_dir", default=None, help="Checkpoint directory (default: None)")
|
32 |
+
parser.add_argument("--load_model", action="store_true", help="Load checkpoint")
|
33 |
+
parser.add_argument("--log_dir", default="log", help="Log directory (default: log)")
|
34 |
+
parser.add_argument("--num_plots", default=10, type=int, help="Plotting training results")
|
35 |
+
parser.add_argument("--min_p_prob", default=0.3, type=float, help="Probability threshold for P pick")
|
36 |
+
parser.add_argument("--min_s_prob", default=0.3, type=float, help="Probability threshold for S pick")
|
37 |
+
parser.add_argument("--format", default="numpy", help="Input data format")
|
38 |
+
parser.add_argument("--train_dir", default="./dataset/waveform_train/", help="Input file directory")
|
39 |
+
parser.add_argument("--train_list", default="./dataset/waveform.csv", help="Input csv file")
|
40 |
+
parser.add_argument("--valid_dir", default=None, help="Input file directory")
|
41 |
+
parser.add_argument("--valid_list", default=None, help="Input csv file")
|
42 |
+
parser.add_argument("--test_dir", default=None, help="Input file directory")
|
43 |
+
parser.add_argument("--test_list", default=None, help="Input csv file")
|
44 |
+
parser.add_argument("--result_dir", default="results", help="result directory")
|
45 |
+
parser.add_argument("--plot_figure", action="store_true", help="If plot figure for test")
|
46 |
+
parser.add_argument("--save_prob", action="store_true", help="If save result for test")
|
47 |
+
args = parser.parse_args()
|
48 |
+
|
49 |
+
return args
|
50 |
+
|
51 |
+
|
52 |
+
def train_fn(args, data_reader, data_reader_valid=None):
|
53 |
+
|
54 |
+
current_time = time.strftime("%y%m%d-%H%M%S")
|
55 |
+
log_dir = os.path.join(args.log_dir, current_time)
|
56 |
+
if not os.path.exists(log_dir):
|
57 |
+
os.makedirs(log_dir)
|
58 |
+
logging.info("Training log: {}".format(log_dir))
|
59 |
+
model_dir = os.path.join(log_dir, 'models')
|
60 |
+
os.makedirs(model_dir)
|
61 |
+
|
62 |
+
figure_dir = os.path.join(log_dir, 'figures')
|
63 |
+
if not os.path.exists(figure_dir):
|
64 |
+
os.makedirs(figure_dir)
|
65 |
+
|
66 |
+
config = ModelConfig(X_shape=data_reader.X_shape, Y_shape=data_reader.Y_shape)
|
67 |
+
if args.decay_step == -1:
|
68 |
+
args.decay_step = data_reader.num_data // args.batch_size
|
69 |
+
config.update_args(args)
|
70 |
+
with open(os.path.join(log_dir, 'config.log'), 'w') as fp:
|
71 |
+
fp.write('\n'.join("%s: %s" % item for item in vars(config).items()))
|
72 |
+
|
73 |
+
with tf.compat.v1.name_scope('Input_Batch'):
|
74 |
+
dataset = data_reader.dataset(args.batch_size, shuffle=True).repeat()
|
75 |
+
batch = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()
|
76 |
+
if data_reader_valid is not None:
|
77 |
+
dataset_valid = data_reader_valid.dataset(args.batch_size, shuffle=False).repeat()
|
78 |
+
valid_batch = tf.compat.v1.data.make_one_shot_iterator(dataset_valid).get_next()
|
79 |
+
|
80 |
+
model = UNet(config, input_batch=batch)
|
81 |
+
sess_config = tf.compat.v1.ConfigProto()
|
82 |
+
sess_config.gpu_options.allow_growth = True
|
83 |
+
# sess_config.log_device_placement = False
|
84 |
+
|
85 |
+
with tf.compat.v1.Session(config=sess_config) as sess:
|
86 |
+
|
87 |
+
summary_writer = tf.compat.v1.summary.FileWriter(log_dir, sess.graph)
|
88 |
+
saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables(), max_to_keep=5)
|
89 |
+
init = tf.compat.v1.global_variables_initializer()
|
90 |
+
sess.run(init)
|
91 |
+
|
92 |
+
if args.model_dir is not None:
|
93 |
+
logging.info("restoring models...")
|
94 |
+
latest_check_point = tf.train.latest_checkpoint(args.model_dir)
|
95 |
+
saver.restore(sess, latest_check_point)
|
96 |
+
|
97 |
+
if args.plot_figure:
|
98 |
+
multiprocessing.set_start_method('spawn')
|
99 |
+
pool = multiprocessing.Pool(multiprocessing.cpu_count())
|
100 |
+
|
101 |
+
flog = open(os.path.join(log_dir, 'loss.log'), 'w')
|
102 |
+
train_loss = EMA(0.9)
|
103 |
+
best_valid_loss = np.inf
|
104 |
+
for epoch in range(args.epochs):
|
105 |
+
progressbar = tqdm(range(0, data_reader.num_data, args.batch_size), desc="{}: epoch {}".format(log_dir.split("/")[-1], epoch))
|
106 |
+
for _ in progressbar:
|
107 |
+
loss_batch, _, _ = sess.run([model.loss, model.train_op, model.global_step],
|
108 |
+
feed_dict={model.drop_rate: args.drop_rate, model.is_training: True})
|
109 |
+
train_loss(loss_batch)
|
110 |
+
progressbar.set_description("{}: epoch {}, loss={:.6f}, mean={:.6f}".format(log_dir.split("/")[-1], epoch, loss_batch, train_loss.value))
|
111 |
+
flog.write("epoch: {}, mean loss: {}\n".format(epoch, train_loss.value))
|
112 |
+
|
113 |
+
if data_reader_valid is not None:
|
114 |
+
valid_loss = LMA()
|
115 |
+
progressbar = tqdm(range(0, data_reader_valid.num_data, args.batch_size), desc="Valid:")
|
116 |
+
for _ in progressbar:
|
117 |
+
loss_batch, preds_batch, X_batch, Y_batch, fname_batch = sess.run([model.loss, model.preds, valid_batch[0], valid_batch[1], valid_batch[2]],
|
118 |
+
feed_dict={model.drop_rate: 0, model.is_training: False})
|
119 |
+
valid_loss(loss_batch)
|
120 |
+
progressbar.set_description("valid, loss={:.6f}, mean={:.6f}".format(loss_batch, valid_loss.value))
|
121 |
+
if valid_loss.value < best_valid_loss:
|
122 |
+
best_valid_loss = valid_loss.value
|
123 |
+
saver.save(sess, os.path.join(model_dir, "model_{}.ckpt".format(epoch)))
|
124 |
+
flog.write("Valid: mean loss: {}\n".format(valid_loss.value))
|
125 |
+
else:
|
126 |
+
loss_batch, preds_batch, X_batch, Y_batch, fname_batch = sess.run([model.loss, model.preds, batch[0], batch[1], batch[2]],
|
127 |
+
feed_dict={model.drop_rate: 0, model.is_training: False})
|
128 |
+
saver.save(sess, os.path.join(model_dir, "model_{}.ckpt".format(epoch)))
|
129 |
+
|
130 |
+
if args.plot_figure:
|
131 |
+
pool.starmap(
|
132 |
+
partial(
|
133 |
+
plot_waveform,
|
134 |
+
figure_dir=figure_dir,
|
135 |
+
),
|
136 |
+
zip(X_batch, preds_batch, [x.decode() for x in fname_batch], Y_batch),
|
137 |
+
)
|
138 |
+
# plot_waveform(X_batch, preds_batch, fname_batch, label=Y_batch, figure_dir=figure_dir)
|
139 |
+
flog.flush()
|
140 |
+
|
141 |
+
flog.close()
|
142 |
+
|
143 |
+
return 0
|
144 |
+
|
145 |
+
def test_fn(args, data_reader):
|
146 |
+
current_time = time.strftime("%y%m%d-%H%M%S")
|
147 |
+
logging.info("{} log: {}".format(args.mode, current_time))
|
148 |
+
if args.model_dir is None:
|
149 |
+
logging.error(f"model_dir = None!")
|
150 |
+
return -1
|
151 |
+
if not os.path.exists(args.result_dir):
|
152 |
+
os.makedirs(args.result_dir)
|
153 |
+
figure_dir=os.path.join(args.result_dir, "figures")
|
154 |
+
if not os.path.exists(figure_dir):
|
155 |
+
os.makedirs(figure_dir)
|
156 |
+
|
157 |
+
config = ModelConfig(X_shape=data_reader.X_shape, Y_shape=data_reader.Y_shape)
|
158 |
+
config.update_args(args)
|
159 |
+
with open(os.path.join(args.result_dir, 'config.log'), 'w') as fp:
|
160 |
+
fp.write('\n'.join("%s: %s" % item for item in vars(config).items()))
|
161 |
+
|
162 |
+
with tf.compat.v1.name_scope('Input_Batch'):
|
163 |
+
dataset = data_reader.dataset(args.batch_size, shuffle=False)
|
164 |
+
batch = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()
|
165 |
+
|
166 |
+
model = UNet(config, input_batch=batch, mode='test')
|
167 |
+
sess_config = tf.compat.v1.ConfigProto()
|
168 |
+
sess_config.gpu_options.allow_growth = True
|
169 |
+
# sess_config.log_device_placement = False
|
170 |
+
|
171 |
+
with tf.compat.v1.Session(config=sess_config) as sess:
|
172 |
+
|
173 |
+
saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
|
174 |
+
init = tf.compat.v1.global_variables_initializer()
|
175 |
+
sess.run(init)
|
176 |
+
|
177 |
+
logging.info("restoring models...")
|
178 |
+
latest_check_point = tf.train.latest_checkpoint(args.model_dir)
|
179 |
+
if latest_check_point is None:
|
180 |
+
logging.error(f"No models found in model_dir: {args.model_dir}")
|
181 |
+
return -1
|
182 |
+
saver.restore(sess, latest_check_point)
|
183 |
+
|
184 |
+
flog = open(os.path.join(args.result_dir, 'loss.log'), 'w')
|
185 |
+
test_loss = LMA()
|
186 |
+
progressbar = tqdm(range(0, data_reader.num_data, args.batch_size), desc=args.mode)
|
187 |
+
picks = []
|
188 |
+
true_picks = []
|
189 |
+
for _ in progressbar:
|
190 |
+
loss_batch, preds_batch, X_batch, Y_batch, fname_batch, itp_batch, its_batch \
|
191 |
+
= sess.run([model.loss, model.preds, batch[0], batch[1], batch[2], batch[3], batch[4]],
|
192 |
+
feed_dict={model.drop_rate: 0, model.is_training: False})
|
193 |
+
|
194 |
+
test_loss(loss_batch)
|
195 |
+
progressbar.set_description("{}, loss={:.6f}, mean loss={:6f}".format(args.mode, loss_batch, test_loss.value))
|
196 |
+
|
197 |
+
picks_ = extract_picks(preds_batch, fname_batch)
|
198 |
+
picks.extend(picks_)
|
199 |
+
true_picks.extend(convert_true_picks(fname_batch, itp_batch, its_batch))
|
200 |
+
if args.plot_figure:
|
201 |
+
plot_waveform(data_reader.config, X_batch, preds_batch, label=Y_batch, fname=fname_batch,
|
202 |
+
itp=itp_batch, its=its_batch, figure_dir=figure_dir)
|
203 |
+
|
204 |
+
save_picks(picks, args.result_dir)
|
205 |
+
metrics = calc_performance(picks, true_picks, tol=3.0, dt=data_reader.config.dt)
|
206 |
+
flog.write("mean loss: {}\n".format(test_loss))
|
207 |
+
flog.close()
|
208 |
+
|
209 |
+
return 0
|
210 |
+
|
211 |
+
def main(args):
|
212 |
+
|
213 |
+
logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)
|
214 |
+
coord = tf.train.Coordinator()
|
215 |
+
|
216 |
+
if (args.mode == "train") or (args.mode == "train_valid"):
|
217 |
+
with tf.compat.v1.name_scope('create_inputs'):
|
218 |
+
data_reader = DataReader_train(format=args.format,
|
219 |
+
data_dir=args.train_dir,
|
220 |
+
data_list=args.train_list)
|
221 |
+
if args.mode == "train_valid":
|
222 |
+
data_reader_valid = DataReader_train(format=args.format,
|
223 |
+
data_dir=args.valid_dir,
|
224 |
+
data_list=args.valid_list)
|
225 |
+
logging.info("Dataset size: train {}, valid {}".format(data_reader.num_data, data_reader_valid.num_data))
|
226 |
+
else:
|
227 |
+
data_reader_valid = None
|
228 |
+
logging.info("Dataset size: train {}".format(data_reader.num_data))
|
229 |
+
train_fn(args, data_reader, data_reader_valid)
|
230 |
+
|
231 |
+
elif args.mode == "test":
|
232 |
+
with tf.compat.v1.name_scope('create_inputs'):
|
233 |
+
data_reader = DataReader_test(format=args.format,
|
234 |
+
data_dir=args.test_dir,
|
235 |
+
data_list=args.test_list)
|
236 |
+
test_fn(args, data_reader)
|
237 |
+
|
238 |
+
else:
|
239 |
+
print("mode should be: train, train_valid, or test")
|
240 |
+
|
241 |
+
return
|
242 |
+
|
243 |
+
|
244 |
+
if __name__ == '__main__':
|
245 |
+
args = read_args()
|
246 |
+
main(args)
|
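train.py smooths the reported training loss with EMA and the validation loss with LMA (both defined in util.py below); a tiny usage illustration with made-up loss values:

from util import EMA, LMA  # assumes the phasenet/ directory is on sys.path

ema, lma = EMA(0.9), LMA()
for loss in [1.0, 0.8, 0.9, 0.5]:  # made-up batch losses
    ema(loss)
    lma(loss)
print(ema.value)  # exponential moving average: alpha=0.9 keeps most of the previous value
print(lma.value)  # running (cumulative) mean: 0.8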
phasenet/util.py
ADDED
@@ -0,0 +1,238 @@
1 |
+
from __future__ import division
|
2 |
+
import matplotlib
|
3 |
+
matplotlib.use('agg')
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
import numpy as np
|
6 |
+
import os
|
7 |
+
from data_reader import DataConfig
|
8 |
+
from detect_peaks import detect_peaks
|
9 |
+
import logging
|
10 |
+
|
11 |
+
class EMA(object):
|
12 |
+
def __init__(self, alpha):
|
13 |
+
self.alpha = alpha
|
14 |
+
self.x = 0.
|
15 |
+
self.count = 0
|
16 |
+
|
17 |
+
@property
|
18 |
+
def value(self):
|
19 |
+
return self.x
|
20 |
+
|
21 |
+
def __call__(self, x):
|
22 |
+
if self.count == 0:
|
23 |
+
self.x = x
|
24 |
+
else:
|
25 |
+
self.x = self.alpha * self.x + (1 - self.alpha) * x
|
26 |
+
self.count += 1
|
27 |
+
return self.x
|
28 |
+
|
29 |
+
class LMA(object):
|
30 |
+
def __init__(self):
|
31 |
+
self.x = 0.
|
32 |
+
self.count = 0
|
33 |
+
|
34 |
+
@property
|
35 |
+
def value(self):
|
36 |
+
return self.x
|
37 |
+
|
38 |
+
def __call__(self, x):
|
39 |
+
if self.count == 0:
|
40 |
+
self.x = x
|
41 |
+
else:
|
42 |
+
self.x += (x - self.x)/(self.count+1)
|
43 |
+
self.count += 1
|
44 |
+
return self.x
|
45 |
+
|
46 |
+
def detect_peaks_thread(i, pred, fname=None, result_dir=None, args=None):
|
47 |
+
if args is None:
|
48 |
+
itp, prob_p = detect_peaks(pred[i,:,0,1], mph=0.5, mpd=0.5/DataConfig().dt, show=False)
|
49 |
+
its, prob_s = detect_peaks(pred[i,:,0,2], mph=0.5, mpd=0.5/DataConfig().dt, show=False)
|
50 |
+
else:
|
51 |
+
itp, prob_p = detect_peaks(pred[i,:,0,1], mph=args.tp_prob, mpd=0.5/DataConfig().dt, show=False)
|
52 |
+
its, prob_s = detect_peaks(pred[i,:,0,2], mph=args.ts_prob, mpd=0.5/DataConfig().dt, show=False)
|
53 |
+
if (fname is not None) and (result_dir is not None):
|
54 |
+
# np.savez(os.path.join(result_dir, fname[i].decode().split('/')[-1]), pred=pred[i], itp=itp, its=its, prob_p=prob_p, prob_s=prob_s)
|
55 |
+
try:
|
56 |
+
np.savez(os.path.join(result_dir, fname[i].decode()), pred=pred[i], itp=itp, its=its, prob_p=prob_p, prob_s=prob_s)
|
57 |
+
except FileNotFoundError:
|
58 |
+
#if not os.path.exists(os.path.dirname(os.path.join(result_dir, fname[i].decode()))):
|
59 |
+
os.makedirs(os.path.dirname(os.path.join(result_dir, fname[i].decode())), exist_ok=True)
|
60 |
+
np.savez(os.path.join(result_dir, fname[i].decode()), pred=pred[i], itp=itp, its=its, prob_p=prob_p, prob_s=prob_s)
|
61 |
+
return [(itp, prob_p), (its, prob_s)]
|
62 |
+
|
63 |
+
def plot_result_thread(i, pred, X, Y=None, itp=None, its=None,
|
64 |
+
itp_pred=None, its_pred=None, fname=None, figure_dir=None):
|
65 |
+
dt = DataConfig().dt
|
66 |
+
t = np.arange(0, pred.shape[1]) * dt
|
67 |
+
box = dict(boxstyle='round', facecolor='white', alpha=1)
|
68 |
+
text_loc = [0.05, 0.77]
|
69 |
+
|
70 |
+
plt.figure(i)
|
71 |
+
plt.clf()
|
72 |
+
# fig_size = plt.gcf().get_size_inches()
|
73 |
+
# plt.gcf().set_size_inches(fig_size*[1, 1.2])
|
74 |
+
plt.subplot(411)
|
75 |
+
plt.plot(t, X[i, :, 0, 0], 'k', label='E', linewidth=0.5)
|
76 |
+
plt.autoscale(enable=True, axis='x', tight=True)
|
77 |
+
tmp_min = np.min(X[i, :, 0, 0])
|
78 |
+
tmp_max = np.max(X[i, :, 0, 0])
|
79 |
+
if (itp is not None) and (its is not None):
|
80 |
+
for j in range(len(itp[i])):
|
81 |
+
if j == 0:
|
82 |
+
plt.plot([itp[i][j]*dt, itp[i][j]*dt], [tmp_min, tmp_max], 'b', label='P', linewidth=0.5)
|
83 |
+
else:
|
84 |
+
plt.plot([itp[i][j]*dt, itp[i][j]*dt], [tmp_min, tmp_max], 'b', linewidth=0.5)
|
85 |
+
for j in range(len(its[i])):
|
86 |
+
if j == 0:
|
87 |
+
plt.plot([its[i][j]*dt, its[i][j]*dt], [tmp_min, tmp_max], 'r', label='S', linewidth=0.5)
|
88 |
+
else:
|
89 |
+
plt.plot([its[i][j]*dt, its[i][j]*dt], [tmp_min, tmp_max], 'r', linewidth=0.5)
|
90 |
+
plt.ylabel('Amplitude')
|
91 |
+
plt.legend(loc='upper right', fontsize='small')
|
92 |
+
plt.gca().set_xticklabels([])
|
93 |
+
plt.text(text_loc[0], text_loc[1], '(i)', horizontalalignment='center',
|
94 |
+
transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
|
95 |
+
plt.subplot(412)
|
96 |
+
plt.plot(t, X[i, :, 0, 1], 'k', label='N', linewidth=0.5)
|
97 |
+
plt.autoscale(enable=True, axis='x', tight=True)
|
98 |
+
tmp_min = np.min(X[i, :, 0, 1])
|
99 |
+
tmp_max = np.max(X[i, :, 0, 1])
|
100 |
+
if (itp is not None) and (its is not None):
|
101 |
+
for j in range(len(itp[i])):
|
102 |
+
plt.plot([itp[i][j]*dt, itp[i][j]*dt], [tmp_min, tmp_max], 'b', linewidth=0.5)
|
103 |
+
for j in range(len(its[i])):
|
104 |
+
plt.plot([its[i][j]*dt, its[i][j]*dt], [tmp_min, tmp_max], 'r', linewidth=0.5)
|
105 |
+
plt.ylabel('Amplitude')
|
106 |
+
plt.legend(loc='upper right', fontsize='small')
|
107 |
+
plt.gca().set_xticklabels([])
|
108 |
+
plt.text(text_loc[0], text_loc[1], '(ii)', horizontalalignment='center',
|
109 |
+
transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
|
110 |
+
plt.subplot(413)
|
111 |
+
plt.plot(t, X[i, :, 0, 2], 'k', label='Z', linewidth=0.5)
|
112 |
+
plt.autoscale(enable=True, axis='x', tight=True)
|
113 |
+
tmp_min = np.min(X[i, :, 0, 2])
|
114 |
+
tmp_max = np.max(X[i, :, 0, 2])
|
115 |
+
if (itp is not None) and (its is not None):
|
116 |
+
for j in range(len(itp[i])):
|
117 |
+
plt.plot([itp[i][j]*dt, itp[i][j]*dt], [tmp_min, tmp_max], 'b', linewidth=0.5)
|
118 |
+
for j in range(len(its[i])):
|
119 |
+
plt.plot([its[i][j]*dt, its[i][j]*dt], [tmp_min, tmp_max], 'r', linewidth=0.5)
|
120 |
+
plt.ylabel('Amplitude')
|
121 |
+
plt.legend(loc='upper right', fontsize='small')
|
122 |
+
plt.gca().set_xticklabels([])
|
123 |
+
plt.text(text_loc[0], text_loc[1], '(iii)', horizontalalignment='center',
|
124 |
+
transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
|
125 |
+
plt.subplot(414)
|
126 |
+
if Y is not None:
|
127 |
+
plt.plot(t, Y[i, :, 0, 1], 'b', label='P', linewidth=0.5)
|
128 |
+
plt.plot(t, Y[i, :, 0, 2], 'r', label='S', linewidth=0.5)
|
129 |
+
plt.plot(t, pred[i, :, 0, 1], '--g', label='$\hat{P}$', linewidth=0.5)
|
130 |
+
plt.plot(t, pred[i, :, 0, 2], '-.m', label='$\hat{S}$', linewidth=0.5)
|
131 |
+
plt.autoscale(enable=True, axis='x', tight=True)
|
132 |
+
if (itp_pred is not None) and (its_pred is not None):
|
133 |
+
for j in range(len(itp_pred)):
|
134 |
+
plt.plot([itp_pred[j]*dt, itp_pred[j]*dt], [-0.1, 1.1], '--g', linewidth=0.5)
|
135 |
+
for j in range(len(its_pred)):
|
136 |
+
plt.plot([its_pred[j]*dt, its_pred[j]*dt], [-0.1, 1.1], '-.m', linewidth=0.5)
|
137 |
+
plt.ylim([-0.05, 1.05])
|
138 |
+
plt.text(text_loc[0], text_loc[1], '(iv)', horizontalalignment='center',
|
139 |
+
transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
|
140 |
+
plt.legend(loc='upper right', fontsize='small')
|
141 |
+
plt.xlabel('Time (s)')
|
142 |
+
plt.ylabel('Probability')
|
143 |
+
|
144 |
+
plt.tight_layout()
|
145 |
+
plt.gcf().align_labels()
|
146 |
+
|
147 |
+
try:
|
148 |
+
plt.savefig(os.path.join(figure_dir,
|
149 |
+
fname[i].decode().rstrip('.npz')+'.png'),
|
150 |
+
bbox_inches='tight')
|
151 |
+
except FileNotFoundError:
|
152 |
+
#if not os.path.exists(os.path.dirname(os.path.join(figure_dir, fname[i].decode()))):
|
153 |
+
os.makedirs(os.path.dirname(os.path.join(figure_dir, fname[i].decode())), exist_ok=True)
|
154 |
+
plt.savefig(os.path.join(figure_dir,
|
155 |
+
fname[i].decode().rstrip('.npz')+'.png'),
|
156 |
+
bbox_inches='tight')
|
157 |
+
#plt.savefig(os.path.join(figure_dir,
|
158 |
+
# fname[i].decode().split('/')[-1].rstrip('.npz')+'.png'),
|
159 |
+
# bbox_inches='tight')
|
160 |
+
# plt.savefig(os.path.join(figure_dir,
|
161 |
+
# fname[i].decode().split('/')[-1].rstrip('.npz')+'.pdf'),
|
162 |
+
# bbox_inches='tight')
|
163 |
+
plt.close(i)
|
164 |
+
return 0
|
165 |
+
|
166 |
+
def postprocessing_thread(i, pred, X, Y=None, itp=None, its=None, fname=None, result_dir=None, figure_dir=None, args=None):
|
167 |
+
(itp_pred, prob_p), (its_pred, prob_s) = detect_peaks_thread(i, pred, fname, result_dir, args)
|
168 |
+
if (fname is not None) and (figure_dir is not None):
|
169 |
+
plot_result_thread(i, pred, X, Y, itp, its, itp_pred, its_pred, fname, figure_dir)
|
170 |
+
return [(itp_pred, prob_p), (its_pred, prob_s)]
|
171 |
+
|
172 |
+
|
173 |
+
def clean_queue(picks):
|
174 |
+
clean = []
|
175 |
+
for i in range(len(picks)):
|
176 |
+
tmp = []
|
177 |
+
for j in picks[i]:
|
178 |
+
if j != 0:
|
179 |
+
tmp.append(j)
|
180 |
+
clean.append(tmp)
|
181 |
+
return clean
|
182 |
+
|
183 |
+
def clean_queue_thread(picks):
|
184 |
+
tmp = []
|
185 |
+
for j in picks:
|
186 |
+
if j != 0:
|
187 |
+
tmp.append(j)
|
188 |
+
return tmp
|
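Both helpers simply drop the zero entries used to pad "no pick" slots; a tiny example of the intended behaviour:

picks = [[0, 1203, 0], [0, 0], [587]]
clean_queue(picks)                 # -> [[1203], [], [587]]
clean_queue_thread([0, 1203, 0])   # -> [1203]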
189 |
+
|
190 |
+
|
191 |
+
def metrics(TP, nP, nT):
|
192 |
+
'''
|
193 |
+
TP: true positive
|
194 |
+
nP: number of predicted picks
|
195 |
+
nT: number of true picks
|
196 |
+
'''
|
197 |
+
precision = TP / nP
|
198 |
+
recall = TP / nT
|
199 |
+
F1 = 2* precision * recall / (precision + recall)
|
200 |
+
return [precision, recall, F1]
|
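As a quick worked example: with TP=90 true positives out of nP=100 predicted picks and nT=120 true picks, precision = 0.90, recall = 0.75, and F1 = 2·0.90·0.75 / (0.90 + 0.75) ≈ 0.818. Note the ratios are undefined when nP or nT is zero.

precision, recall, f1 = metrics(TP=90, nP=100, nT=120)
# precision = 0.90, recall = 0.75, f1 ≈ 0.818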
201 |
+
|
202 |
+
def correct_picks(picks, true_p, true_s, tol):
|
203 |
+
dt = DataConfig().dt
|
204 |
+
if len(true_p) != len(true_s):
|
205 |
+
print("The length of true P and S pickers are not the same")
|
206 |
+
num = len(true_p)
|
207 |
+
TP_p = 0; TP_s = 0; nP_p = 0; nP_s = 0; nT_p = 0; nT_s = 0
|
208 |
+
diff_p = []; diff_s = []
|
209 |
+
for i in range(num):
|
210 |
+
nT_p += len(true_p[i])
|
211 |
+
nT_s += len(true_s[i])
|
212 |
+
nP_p += len(picks[i][0][0])
|
213 |
+
nP_s += len(picks[i][1][0])
|
214 |
+
|
215 |
+
if len(true_p[i]) > 1 or len(true_s[i]) > 1:
|
216 |
+
print(i, picks[i], true_p[i], true_s[i])
|
217 |
+
tmp_p = np.array(picks[i][0][0]) - np.array(true_p[i])[:,np.newaxis]
|
218 |
+
tmp_s = np.array(picks[i][1][0]) - np.array(true_s[i])[:,np.newaxis]
|
219 |
+
TP_p += np.sum(np.abs(tmp_p) < tol/dt)
|
220 |
+
TP_s += np.sum(np.abs(tmp_s) < tol/dt)
|
221 |
+
diff_p.append(tmp_p[np.abs(tmp_p) < 0.5/dt])
|
222 |
+
diff_s.append(tmp_s[np.abs(tmp_s) < 0.5/dt])
|
223 |
+
|
224 |
+
return [TP_p, TP_s, nP_p, nP_s, nT_p, nT_s, diff_p, diff_s]
|
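The broadcasting step np.array(picks[i][0][0]) - np.array(true_p[i])[:, np.newaxis] forms every predicted-minus-true difference at once, so a pick counts as a true positive when it falls within tol/dt samples of some true pick. A small hedged illustration with hypothetical indices:

import numpy as np
pred_idx = np.array([1005, 2500])           # hypothetical predicted P indices (samples)
true_idx = np.array([1000])                 # hypothetical true P indices (samples)
diff = pred_idx - true_idx[:, np.newaxis]   # shape (n_true, n_pred): [[5, 1500]]
TP = np.sum(np.abs(diff) < 0.1 / 0.01)      # tol = 0.1 s, dt = 0.01 s -> 10 samples; TP = 1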
225 |
+
|
226 |
+
def calculate_metrics(picks, itp, its, tol=0.1):
|
227 |
+
TP_p, TP_s, nP_p, nP_s, nT_p, nT_s, diff_p, diff_s = correct_picks(picks, itp, its, tol)
|
228 |
+
precision_p, recall_p, f1_p = metrics(TP_p, nP_p, nT_p)
|
229 |
+
precision_s, recall_s, f1_s = metrics(TP_s, nP_s, nT_s)
|
230 |
+
|
231 |
+
logging.info("Total records: {}".format(len(picks)))
|
232 |
+
logging.info("P-phase:")
|
233 |
+
logging.info("True={}, Predict={}, TruePositive={}".format(nT_p, nP_p, TP_p))
|
234 |
+
logging.info("Precision={:.3f}, Recall={:.3f}, F1={:.3f}".format(precision_p, recall_p, f1_p))
|
235 |
+
logging.info("S-phase:")
|
236 |
+
logging.info("True={}, Predict={}, TruePositive={}".format(nT_s, nP_s, TP_s))
|
237 |
+
logging.info("Precision={:.3f}, Recall={:.3f}, F1={:.3f}".format(precision_s, recall_s, f1_s))
|
238 |
+
return [precision_p, recall_p, f1_p], [precision_s, recall_s, f1_s]
|
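A hedged sketch of the layouts calculate_metrics expects, with one record and hypothetical sample indices (DataConfig().dt is assumed to be 0.01 s, so tol=0.1 corresponds to 10 samples):

picks = [[([1003], [0.92]), ([1504], [0.88])]]   # per record: [(P indices, P probs), (S indices, S probs)]
itp, its = [[1000]], [[1500]]                    # true P/S sample indices per record
calculate_metrics(picks, itp, its, tol=0.1)      # -> perfect precision/recall/F1 in this toy case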
phasenet/visulization.py
ADDED
@@ -0,0 +1,481 @@
1 |
+
import matplotlib
|
2 |
+
matplotlib.use("agg")
|
3 |
+
import matplotlib.pyplot as plt
|
4 |
+
import numpy as np
|
5 |
+
import os
|
6 |
+
|
7 |
+
|
8 |
+
def plot_residual(diff_p, diff_s, diff_ps, tol, dt):
|
9 |
+
box = dict(boxstyle='round', facecolor='white', alpha=1)
|
10 |
+
text_loc = [0.07, 0.95]
|
11 |
+
plt.figure(figsize=(8,3))
|
12 |
+
plt.subplot(1,3,1)
|
13 |
+
plt.hist(diff_p, range=(-tol, tol), bins=int(2*tol/dt)+1, facecolor='b', edgecolor='black', linewidth=1)
|
14 |
+
plt.ylabel("Number of picks")
|
15 |
+
plt.xlabel("Residual (s)")
|
16 |
+
plt.text(text_loc[0], text_loc[1], "(i)", horizontalalignment='left', verticalalignment='top',
|
17 |
+
transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
|
18 |
+
plt.title("P-phase")
|
19 |
+
plt.subplot(1,3,2)
|
20 |
+
plt.hist(diff_s, range=(-tol, tol), bins=int(2*tol/dt)+1, facecolor='b', edgecolor='black', linewidth=1)
|
21 |
+
plt.xlabel("Residual (s)")
|
22 |
+
plt.text(text_loc[0], text_loc[1], "(ii)", horizontalalignment='left', verticalalignment='top',
|
23 |
+
transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
|
24 |
+
plt.title("S-phase")
|
25 |
+
plt.subplot(1,3,3)
|
26 |
+
plt.hist(diff_ps, range=(-tol, tol), bins=int(2*tol/dt)+1, facecolor='b', edgecolor='black', linewidth=1)
|
27 |
+
plt.xlabel("Residual (s)")
|
28 |
+
plt.text(text_loc[0], text_loc[1], "(iii)", horizontalalignment='left', verticalalignment='top',
|
29 |
+
transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
|
30 |
+
plt.title("PS-phase")
|
31 |
+
plt.tight_layout()
|
32 |
+
plt.savefig("residuals.png", dpi=300)
|
33 |
+
plt.savefig("residuals.pdf")
|
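A hedged usage sketch for plot_residual, assuming the residual arrays are already expressed in seconds (the histogram range is ±tol seconds) and using synthetic values:

import numpy as np
diff_p = np.random.normal(0, 0.02, 500)     # synthetic P residuals (s)
diff_s = np.random.normal(0, 0.03, 500)     # synthetic S residuals (s)
diff_ps = np.random.normal(0, 0.04, 500)    # synthetic PS residuals (s)
plot_residual(diff_p, diff_s, diff_ps, tol=0.1, dt=0.01)   # writes residuals.png / residuals.pdf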
34 |
+
|
35 |
+
|
36 |
+
# def plot_waveform(config, data, pred, label=None,
|
37 |
+
# itp=None, its=None, itps=None,
|
38 |
+
# itp_pred=None, its_pred=None, itps_pred=None,
|
39 |
+
# fname=None, figure_dir="./", epoch=0, max_fig=10):
|
40 |
+
|
41 |
+
# dt = config.dt if hasattr(config, "dt") else 1.0
|
42 |
+
# t = np.arange(0, pred.shape[1]) * dt
|
43 |
+
# box = dict(boxstyle='round', facecolor='white', alpha=1)
|
44 |
+
# text_loc = [0.05, 0.77]
|
45 |
+
# if fname is None:
|
46 |
+
# fname = [f"{epoch:03d}_{i:02d}" for i in range(len(data))]
|
47 |
+
# else:
|
48 |
+
# fname = [fname[i].decode().rstrip(".npz") for i in range(len(fname))]
|
49 |
+
|
50 |
+
# for i in range(min(len(data), max_fig)):
|
51 |
+
# plt.figure(i)
|
52 |
+
|
53 |
+
# plt.subplot(411)
|
54 |
+
# plt.plot(t, data[i, :, 0, 0], 'k', label='E', linewidth=0.5)
|
55 |
+
# plt.autoscale(enable=True, axis='x', tight=True)
|
56 |
+
# tmp_min = np.min(data[i, :, 0, 0])
|
57 |
+
# tmp_max = np.max(data[i, :, 0, 0])
|
58 |
+
# if (itp is not None) and (its is not None):
|
59 |
+
# for j in range(len(itp[i])):
|
60 |
+
# lb = "P" if j==0 else ""
|
61 |
+
# plt.plot([itp[i][j]*dt, itp[i][j]*dt], [tmp_min, tmp_max], 'C0', label=lb, linewidth=0.5)
|
62 |
+
# for j in range(len(its[i])):
|
63 |
+
# lb = "S" if j==0 else ""
|
64 |
+
# plt.plot([its[i][j]*dt, its[i][j]*dt], [tmp_min, tmp_max], 'C1', label=lb, linewidth=0.5)
|
65 |
+
# if (itps is not None):
|
66 |
+
# for j in range(len(itps[i])):
|
67 |
+
# lb = "PS" if j==0 else ""
|
68 |
+
# plt.plot([itps[i][j]*dt, its[i][j]*dt], [tmp_min, tmp_max], 'C2', label=lb, linewidth=0.5)
|
69 |
+
# plt.ylabel('Amplitude')
|
70 |
+
# plt.legend(loc='upper right', fontsize='small')
|
71 |
+
# plt.gca().set_xticklabels([])
|
72 |
+
# plt.text(text_loc[0], text_loc[1], '(i)', horizontalalignment='center',
|
73 |
+
# transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
|
74 |
+
|
75 |
+
# plt.subplot(412)
|
76 |
+
# plt.plot(t, data[i, :, 0, 1], 'k', label='N', linewidth=0.5)
|
77 |
+
# plt.autoscale(enable=True, axis='x', tight=True)
|
78 |
+
# tmp_min = np.min(data[i, :, 0, 1])
|
79 |
+
# tmp_max = np.max(data[i, :, 0, 1])
|
80 |
+
# if (itp is not None) and (its is not None):
|
81 |
+
# for j in range(len(itp[i])):
|
82 |
+
# lb = "P" if j==0 else ""
|
83 |
+
# plt.plot([itp[i][j]*dt, itp[i][j]*dt], [tmp_min, tmp_max], 'C0', label=lb, linewidth=0.5)
|
84 |
+
# for j in range(len(its[i])):
|
85 |
+
# lb = "S" if j==0 else ""
|
86 |
+
# plt.plot([its[i][j]*dt, its[i][j]*dt], [tmp_min, tmp_max], 'C1', label=lb, linewidth=0.5)
|
87 |
+
# if (itps is not None):
|
88 |
+
# for j in range(len(itps[i])):
|
89 |
+
# lb = "PS" if j==0 else ""
|
90 |
+
# plt.plot([itps[i][j]*dt, itps[i][j]*dt], [tmp_min, tmp_max], 'C2', label=lb, linewidth=0.5)
|
91 |
+
# plt.ylabel('Amplitude')
|
92 |
+
# plt.legend(loc='upper right', fontsize='small')
|
93 |
+
# plt.gca().set_xticklabels([])
|
94 |
+
# plt.text(text_loc[0], text_loc[1], '(ii)', horizontalalignment='center',
|
95 |
+
# transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
|
96 |
+
|
97 |
+
# plt.subplot(413)
|
98 |
+
# plt.plot(t, data[i, :, 0, 2], 'k', label='Z', linewidth=0.5)
|
99 |
+
# plt.autoscale(enable=True, axis='x', tight=True)
|
100 |
+
# tmp_min = np.min(data[i, :, 0, 2])
|
101 |
+
# tmp_max = np.max(data[i, :, 0, 2])
|
102 |
+
# if (itp is not None) and (its is not None):
|
103 |
+
# for j in range(len(itp[i])):
|
104 |
+
# lb = "P" if j==0 else ""
|
105 |
+
# plt.plot([itp[i][j]*dt, itp[i][j]*dt], [tmp_min, tmp_max], 'C0', label=lb, linewidth=0.5)
|
106 |
+
# for j in range(len(its[i])):
|
107 |
+
# lb = "S" if j==0 else ""
|
108 |
+
# plt.plot([its[i][j]*dt, its[i][j]*dt], [tmp_min, tmp_max], 'C1', label=lb, linewidth=0.5)
|
109 |
+
# if (itps is not None):
|
110 |
+
# for j in range(len(itps[i])):
|
111 |
+
# lb = "PS" if j==0 else ""
|
112 |
+
# plt.plot([itps[i][j]*dt, itps[i][j]*dt], [tmp_min, tmp_max], 'C2', label=lb, linewidth=0.5)
|
113 |
+
# plt.ylabel('Amplitude')
|
114 |
+
# plt.legend(loc='upper right', fontsize='small')
|
115 |
+
# plt.gca().set_xticklabels([])
|
116 |
+
# plt.text(text_loc[0], text_loc[1], '(iii)', horizontalalignment='center',
|
117 |
+
# transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
|
118 |
+
|
119 |
+
# plt.subplot(414)
|
120 |
+
# if label is not None:
|
121 |
+
# plt.plot(t, label[i, :, 0, 1], 'C0', label='P', linewidth=1)
|
122 |
+
# plt.plot(t, label[i, :, 0, 2], 'C1', label='S', linewidth=1)
|
123 |
+
# if label.shape[-1] == 4:
|
124 |
+
# plt.plot(t, label[i, :, 0, 3], 'C2', label='PS', linewidth=1)
|
125 |
+
# plt.plot(t, pred[i, :, 0, 1], '--C0', label='$\hat{P}$', linewidth=1)
|
126 |
+
# plt.plot(t, pred[i, :, 0, 2], '--C1', label='$\hat{S}$', linewidth=1)
|
127 |
+
# if pred.shape[-1] == 4:
|
128 |
+
# plt.plot(t, pred[i, :, 0, 3], '--C2', label='$\hat{PS}$', linewidth=1)
|
129 |
+
# plt.autoscale(enable=True, axis='x', tight=True)
|
130 |
+
# if (itp_pred is not None) and (its_pred is not None) :
|
131 |
+
# for j in range(len(itp_pred)):
|
132 |
+
# plt.plot([itp_pred[j]*dt, itp_pred[j]*dt], [-0.1, 1.1], '--C0', linewidth=1)
|
133 |
+
# for j in range(len(its_pred)):
|
134 |
+
# plt.plot([its_pred[j]*dt, its_pred[j]*dt], [-0.1, 1.1], '--C1', linewidth=1)
|
135 |
+
# if (itps_pred is not None):
|
136 |
+
# for j in range(len(itps_pred)):
|
137 |
+
# plt.plot([itps_pred[j]*dt, itps_pred[j]*dt], [-0.1, 1.1], '--C2', linewidth=1)
|
138 |
+
# plt.ylim([-0.05, 1.05])
|
139 |
+
# plt.text(text_loc[0], text_loc[1], '(iv)', horizontalalignment='center',
|
140 |
+
# transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
|
141 |
+
# plt.legend(loc='upper right', fontsize='small', ncol=2)
|
142 |
+
# plt.xlabel('Time (s)')
|
143 |
+
# plt.ylabel('Probability')
|
144 |
+
# plt.tight_layout()
|
145 |
+
# plt.gcf().align_labels()
|
146 |
+
|
147 |
+
# try:
|
148 |
+
# plt.savefig(os.path.join(figure_dir, fname[i]+'.png'), bbox_inches='tight')
|
149 |
+
# except FileNotFoundError:
|
150 |
+
# os.makedirs(os.path.dirname(os.path.join(figure_dir, fname[i])), exist_ok=True)
|
151 |
+
# plt.savefig(os.path.join(figure_dir, fname[i]+'.png'), bbox_inches='tight')
|
152 |
+
|
153 |
+
# plt.close(i)
|
154 |
+
# return 0
|
155 |
+
|
156 |
+
|
157 |
+
def plot_waveform(data, pred, fname, label=None,
|
158 |
+
itp=None, its=None, itps=None,
|
159 |
+
itp_pred=None, its_pred=None, itps_pred=None,
|
160 |
+
figure_dir="./", dt=0.01):
|
161 |
+
|
162 |
+
t = np.arange(0, pred.shape[0]) * dt
|
163 |
+
box = dict(boxstyle='round', facecolor='white', alpha=1)
|
164 |
+
text_loc = [0.05, 0.77]
|
165 |
+
|
166 |
+
plt.figure()
|
167 |
+
|
168 |
+
plt.subplot(411)
|
169 |
+
plt.plot(t, data[:, 0, 0], 'k', label='E', linewidth=0.5)
|
170 |
+
plt.autoscale(enable=True, axis='x', tight=True)
|
171 |
+
tmp_min = np.min(data[:, 0, 0])
|
172 |
+
tmp_max = np.max(data[:, 0, 0])
|
173 |
+
if (itp is not None) and (its is not None):
|
174 |
+
for j in range(len(itp)):
|
175 |
+
lb = "P" if j==0 else ""
|
176 |
+
plt.plot([itp[j]*dt, itp[j]*dt], [tmp_min, tmp_max], 'C0', label=lb, linewidth=0.5)
|
177 |
+
for j in range(len(its)):
|
178 |
+
lb = "S" if j==0 else ""
|
179 |
+
plt.plot([its[j]*dt, its[j]*dt], [tmp_min, tmp_max], 'C1', label=lb, linewidth=0.5)
|
180 |
+
if (itps is not None):
|
181 |
+
for j in range(len(itps)):
|
182 |
+
lb = "PS" if j==0 else ""
|
183 |
+
plt.plot([itps[j]*dt, itps[j]*dt], [tmp_min, tmp_max], 'C2', label=lb, linewidth=0.5)
|
184 |
+
plt.ylabel('Amplitude')
|
185 |
+
plt.legend(loc='upper right', fontsize='small')
|
186 |
+
plt.gca().set_xticklabels([])
|
187 |
+
plt.text(text_loc[0], text_loc[1], '(i)', horizontalalignment='center',
|
188 |
+
transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
|
189 |
+
|
190 |
+
plt.subplot(412)
|
191 |
+
plt.plot(t, data[:, 0, 1], 'k', label='N', linewidth=0.5)
|
192 |
+
plt.autoscale(enable=True, axis='x', tight=True)
|
193 |
+
tmp_min = np.min(data[:, 0, 1])
|
194 |
+
tmp_max = np.max(data[:, 0, 1])
|
195 |
+
if (itp is not None) and (its is not None):
|
196 |
+
for j in range(len(itp)):
|
197 |
+
lb = "P" if j==0 else ""
|
198 |
+
plt.plot([itp[j]*dt, itp[j]*dt], [tmp_min, tmp_max], 'C0', label=lb, linewidth=0.5)
|
199 |
+
for j in range(len(its)):
|
200 |
+
lb = "S" if j==0 else ""
|
201 |
+
plt.plot([its[j]*dt, its[j]*dt], [tmp_min, tmp_max], 'C1', label=lb, linewidth=0.5)
|
202 |
+
if (itps is not None):
|
203 |
+
for j in range(len(itps)):
|
204 |
+
lb = "PS" if j==0 else ""
|
205 |
+
plt.plot([itps[j]*dt, itps[j]*dt], [tmp_min, tmp_max], 'C2', label=lb, linewidth=0.5)
|
206 |
+
plt.ylabel('Amplitude')
|
207 |
+
plt.legend(loc='upper right', fontsize='small')
|
208 |
+
plt.gca().set_xticklabels([])
|
209 |
+
plt.text(text_loc[0], text_loc[1], '(ii)', horizontalalignment='center',
|
210 |
+
transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
|
211 |
+
|
212 |
+
plt.subplot(413)
|
213 |
+
plt.plot(t, data[:, 0, 2], 'k', label='Z', linewidth=0.5)
|
214 |
+
plt.autoscale(enable=True, axis='x', tight=True)
|
215 |
+
tmp_min = np.min(data[:, 0, 2])
|
216 |
+
tmp_max = np.max(data[:, 0, 2])
|
217 |
+
if (itp is not None) and (its is not None):
|
218 |
+
for j in range(len(itp)):
|
219 |
+
lb = "P" if j==0 else ""
|
220 |
+
plt.plot([itp[j]*dt, itp[j]*dt], [tmp_min, tmp_max], 'C0', label=lb, linewidth=0.5)
|
221 |
+
for j in range(len(its)):
|
222 |
+
lb = "S" if j==0 else ""
|
223 |
+
plt.plot([its[j]*dt, its[j]*dt], [tmp_min, tmp_max], 'C1', label=lb, linewidth=0.5)
|
224 |
+
if (itps is not None):
|
225 |
+
for j in range(len(itps)):
|
226 |
+
lb = "PS" if j==0 else ""
|
227 |
+
plt.plot([itps[j]*dt, itps[j]*dt], [tmp_min, tmp_max], 'C2', label=lb, linewidth=0.5)
|
228 |
+
plt.ylabel('Amplitude')
|
229 |
+
plt.legend(loc='upper right', fontsize='small')
|
230 |
+
plt.gca().set_xticklabels([])
|
231 |
+
plt.text(text_loc[0], text_loc[1], '(iii)', horizontalalignment='center',
|
232 |
+
transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
|
233 |
+
|
234 |
+
plt.subplot(414)
|
235 |
+
if label is not None:
|
236 |
+
plt.plot(t, label[:, 0, 1], 'C0', label='P', linewidth=1)
|
237 |
+
plt.plot(t, label[:, 0, 2], 'C1', label='S', linewidth=1)
|
238 |
+
if label.shape[-1] == 4:
|
239 |
+
plt.plot(t, label[:, 0, 3], 'C2', label='PS', linewidth=1)
|
240 |
+
plt.plot(t, pred[:, 0, 1], '--C0', label='$\hat{P}$', linewidth=1)
|
241 |
+
plt.plot(t, pred[:, 0, 2], '--C1', label='$\hat{S}$', linewidth=1)
|
242 |
+
if pred.shape[-1] == 4:
|
243 |
+
plt.plot(t, pred[:, 0, 3], '--C2', label='$\hat{PS}$', linewidth=1)
|
244 |
+
plt.autoscale(enable=True, axis='x', tight=True)
|
245 |
+
if (itp_pred is not None) and (its_pred is not None) :
|
246 |
+
for j in range(len(itp_pred)):
|
247 |
+
plt.plot([itp_pred[j]*dt, itp_pred[j]*dt], [-0.1, 1.1], '--C0', linewidth=1)
|
248 |
+
for j in range(len(its_pred)):
|
249 |
+
plt.plot([its_pred[j]*dt, its_pred[j]*dt], [-0.1, 1.1], '--C1', linewidth=1)
|
250 |
+
if (itps_pred is not None):
|
251 |
+
for j in range(len(itps_pred)):
|
252 |
+
plt.plot([itps_pred[j]*dt, itps_pred[j]*dt], [-0.1, 1.1], '--C2', linewidth=1)
|
253 |
+
plt.ylim([-0.05, 1.05])
|
254 |
+
plt.text(text_loc[0], text_loc[1], '(iv)', horizontalalignment='center',
|
255 |
+
transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
|
256 |
+
plt.legend(loc='upper right', fontsize='small', ncol=2)
|
257 |
+
plt.xlabel('Time (s)')
|
258 |
+
plt.ylabel('Probability')
|
259 |
+
plt.tight_layout()
|
260 |
+
plt.gcf().align_labels()
|
261 |
+
|
262 |
+
try:
|
263 |
+
plt.savefig(os.path.join(figure_dir, fname+'.png'), bbox_inches='tight')
|
264 |
+
except FileNotFoundError:
|
265 |
+
os.makedirs(os.path.dirname(os.path.join(figure_dir, fname)), exist_ok=True)
|
266 |
+
plt.savefig(os.path.join(figure_dir, fname+'.png'), bbox_inches='tight')
|
267 |
+
|
268 |
+
plt.close()
|
269 |
+
return 0
|
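A hedged usage sketch for the per-trace plot_waveform above; the shapes mirror how the function indexes its inputs (data[:, 0, component], pred[:, 0, class]), and the file name and directory are hypothetical:

import numpy as np
nt = 3000
data = np.random.randn(nt, 1, 3)                  # E/N/Z waveform, shape (nt, 1, 3)
pred = np.zeros((nt, 1, 3)); pred[:, 0, 0] = 1.0  # noise/P/S probabilities, shape (nt, 1, 3)
plot_waveform(data, pred, fname="demo_trace", figure_dir="./figures", dt=0.01)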
270 |
+
|
271 |
+
|
272 |
+
def plot_array(config, data, pred, label=None,
|
273 |
+
itp=None, its=None, itps=None,
|
274 |
+
itp_pred=None, its_pred=None, itps_pred=None,
|
275 |
+
fname=None, figure_dir="./", epoch=0):
|
276 |
+
|
277 |
+
dt = config.dt if hasattr(config, "dt") else 1.0
|
278 |
+
t = np.arange(0, pred.shape[1]) * dt
|
279 |
+
box = dict(boxstyle='round', facecolor='white', alpha=1)
|
280 |
+
text_loc = [0.05, 0.95]
|
281 |
+
if fname is None:
|
282 |
+
fname = [f"{epoch:03d}_{i:03d}" for i in range(len(data))]
|
283 |
+
else:
|
284 |
+
fname = [fname[i].decode().rstrip(".npz") for i in range(len(fname))]
|
285 |
+
|
286 |
+
for i in range(len(data)):
|
287 |
+
plt.figure(i, figsize=(10, 5))
|
288 |
+
plt.clf()
|
289 |
+
|
290 |
+
plt.subplot(121)
|
291 |
+
for j in range(data.shape[-2]):
|
292 |
+
plt.plot(t, data[i, :, j, 0]/10 + j, 'k', label='E', linewidth=0.5)
|
293 |
+
plt.autoscale(enable=True, axis='x', tight=True)
|
294 |
+
tmp_min = np.min(data[i, :, 0, 0])
|
295 |
+
tmp_max = np.max(data[i, :, 0, 0])
|
296 |
+
plt.xlabel('Time (s)')
|
297 |
+
plt.ylabel('Amplitude')
|
298 |
+
# plt.legend(loc='upper right', fontsize='small')
|
299 |
+
# plt.gca().set_xticklabels([])
|
300 |
+
plt.text(text_loc[0], text_loc[1], '(i)', horizontalalignment='center', verticalalignment="top",
|
301 |
+
transform=plt.gca().transAxes, fontsize="large", fontweight="normal", bbox=box)
|
302 |
+
|
303 |
+
plt.subplot(122)
|
304 |
+
for j in range(pred.shape[-2]):
|
305 |
+
if label is not None:
|
306 |
+
plt.plot(t, label[i, :, j, 1]+j, 'C2', label='P', linewidth=0.5)
|
307 |
+
plt.plot(t, label[i, :, j, 2]+j, 'C3', label='S', linewidth=0.5)
|
308 |
+
# plt.plot(t, label[i, :, j, 0]+j, 'C4', label='N', linewidth=0.5)
|
309 |
+
plt.plot(t, pred[i, :, j, 1]+j, 'C0', label='$\hat{P}$', linewidth=1)
|
310 |
+
plt.plot(t, pred[i, :, j, 2]+j, 'C1', label='$\hat{S}$', linewidth=1)
|
311 |
+
plt.autoscale(enable=True, axis='x', tight=True)
|
312 |
+
if (itp_pred is not None) and (its_pred is not None) and (itps_pred is not None):
|
313 |
+
for j in range(len(itp_pred)):
|
314 |
+
plt.plot([itp_pred[j]*dt, itp_pred[j]*dt], [-0.1, 1.1], '--C0', linewidth=1)
|
315 |
+
for j in range(len(its_pred)):
|
316 |
+
plt.plot([its_pred[j]*dt, its_pred[j]*dt], [-0.1, 1.1], '--C1', linewidth=1)
|
317 |
+
for j in range(len(itps_pred)):
|
318 |
+
plt.plot([itps_pred[j]*dt, itps_pred[j]*dt], [-0.1, 1.1], '--C2', linewidth=1)
|
319 |
+
# plt.ylim([-0.05, 1.05])
|
320 |
+
plt.text(text_loc[0], text_loc[1], '(ii)', horizontalalignment='center', verticalalignment="top",
|
321 |
+
transform=plt.gca().transAxes, fontsize="large", fontweight="normal", bbox=box)
|
322 |
+
# plt.legend(loc='upper right', fontsize='small', ncol=2)
|
323 |
+
plt.xlabel('Time (s)')
|
324 |
+
plt.ylabel('Probability')
|
325 |
+
plt.tight_layout()
|
326 |
+
plt.gcf().align_labels()
|
327 |
+
|
328 |
+
try:
|
329 |
+
plt.savefig(os.path.join(figure_dir, fname[i]+'.png'), bbox_inches='tight')
|
330 |
+
except FileNotFoundError:
|
331 |
+
os.makedirs(os.path.dirname(os.path.join(figure_dir, fname[i])), exist_ok=True)
|
332 |
+
plt.savefig(os.path.join(figure_dir, fname[i]+'.png'), bbox_inches='tight')
|
333 |
+
|
334 |
+
plt.close(i)
|
335 |
+
return 0
|
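plot_array stacks stations vertically by offsetting each trace with its station index (/10 + j for waveforms, + j for probabilities); a hedged sketch of a matching input, where the SimpleNamespace stands in for the real config object:

import numpy as np
from types import SimpleNamespace
nt, nsta = 3000, 8
data = np.random.randn(1, nt, nsta, 3)            # (batch, time, station, component)
pred = np.zeros((1, nt, nsta, 3))                 # (batch, time, station, class)
plot_array(SimpleNamespace(dt=0.01), data, pred, figure_dir="./figures")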
336 |
+
|
337 |
+
|
338 |
+
def plot_spectrogram(config, data, pred, label=None,
|
339 |
+
itp=None, its=None, itps=None,
|
340 |
+
itp_pred=None, its_pred=None, itps_pred=None,
|
341 |
+
time=None, freq=None,
|
342 |
+
fname=None, figure_dir="./", epoch=0):
|
343 |
+
|
344 |
+
# dt = config.dt
|
345 |
+
# df = config.df
|
346 |
+
# t = np.arange(0, data.shape[1]) * dt
|
347 |
+
# f = np.arange(0, data.shape[2]) * df
|
348 |
+
t, f = time, freq
|
349 |
+
dt = t[1] - t[0]
|
350 |
+
box = dict(boxstyle='round', facecolor='white', alpha=1)
|
351 |
+
text_loc = [0.05, 0.75]
|
352 |
+
if fname is None:
|
353 |
+
fname = [f"{i:03d}" for i in range(len(data))]
|
354 |
+
elif type(fname[0]) is bytes:
|
355 |
+
fname = [f.decode() for f in fname]
|
356 |
+
|
357 |
+
numbers = ["(i)", "(ii)", "(iii)", "(iv)"]
|
358 |
+
for i in range(len(data)):
|
359 |
+
fig = plt.figure(i)
|
360 |
+
# gs = fig.add_gridspec(4, 1)
|
361 |
+
|
362 |
+
for j in range(3):
|
363 |
+
# fig.add_subplot(gs[j, 0])
|
364 |
+
plt.subplot(4,1,j+1)
|
365 |
+
plt.pcolormesh(t, f, np.abs(data[i, :, :, j]+1j*data[i, :, :, j+3]).T, vmax=2*np.std(data[i, :, :, j]+1j*data[i, :, :, j+3]), cmap="jet", shading='auto')
|
366 |
+
plt.autoscale(enable=True, axis='x', tight=True)
|
367 |
+
plt.gca().set_xticklabels([])
|
368 |
+
if j == 1:
|
369 |
+
plt.ylabel('Frequency (Hz)')
|
370 |
+
plt.text(text_loc[0], text_loc[1], numbers[j], horizontalalignment='center',
|
371 |
+
transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
|
372 |
+
|
373 |
+
# fig.add_subplot(gs[-1, 0])
|
374 |
+
plt.subplot(4,1,4)
|
375 |
+
if label is not None:
|
376 |
+
plt.plot(t, label[i, :, 0, 1], '--C0', linewidth=1)
|
377 |
+
plt.plot(t, label[i, :, 0, 2], '--C3', linewidth=1)
|
378 |
+
plt.plot(t, label[i, :, 0, 3], '--C1', linewidth=1)
|
379 |
+
plt.plot(t, pred[i, :, 0, 1], 'C0', label='P', linewidth=1)
|
380 |
+
plt.plot(t, pred[i, :, 0, 2], 'C3', label='S', linewidth=1)
|
381 |
+
plt.plot(t, pred[i, :, 0, 3], 'C1', label='PS', linewidth=1)
|
382 |
+
plt.plot(t, t*0, 'k', linewidth=1)
|
383 |
+
plt.autoscale(enable=True, axis='x', tight=True)
|
384 |
+
if (itp_pred is not None) and (its_pred is not None) and (itps_pred is not None):
|
385 |
+
for j in range(len(itp_pred)):
|
386 |
+
plt.plot([itp_pred[j]*dt, itp_pred[j]*dt], [-0.1, 1.1], ':C3', linewidth=1)
|
387 |
+
for j in range(len(its_pred)):
|
388 |
+
plt.plot([its_pred[j]*dt, its_pred[j]*dt], [-0.1, 1.1], '-.C6', linewidth=1)
|
389 |
+
for j in range(len(itps_pred)):
|
390 |
+
plt.plot([itps_pred[j]*dt, itps_pred[j]*dt], [-0.1, 1.1], '--C8', linewidth=1)
|
391 |
+
plt.ylim([-0.05, 1.05])
|
392 |
+
plt.text(text_loc[0], text_loc[1], numbers[-1], horizontalalignment='center',
|
393 |
+
transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
|
394 |
+
plt.legend(loc='upper right', fontsize='small', ncol=1)
|
395 |
+
plt.xlabel('Time (s)')
|
396 |
+
plt.ylabel('Probability')
|
397 |
+
# plt.tight_layout()
|
398 |
+
plt.gcf().align_labels()
|
399 |
+
|
400 |
+
try:
|
401 |
+
plt.savefig(os.path.join(figure_dir, f'{epoch:02d}_'+fname[i]+'.png'), bbox_inches='tight')
|
402 |
+
except FileNotFoundError:
|
403 |
+
os.makedirs(os.path.dirname(os.path.join(figure_dir, fname[i])), exist_ok=True)
|
404 |
+
plt.savefig(os.path.join(figure_dir, f'{epoch:02d}_'+fname[i]+'.png'), bbox_inches='tight')
|
405 |
+
|
406 |
+
plt.close(i)
|
407 |
+
return 0
|
408 |
+
|
409 |
+
|
410 |
+
def plot_spectrogram_waveform(config, spectrogram, waveform, pred, label=None,
|
411 |
+
itp=None, its=None, itps=None, picks=None,
|
412 |
+
time=None, freq=None,
|
413 |
+
fname=None, figure_dir="./", epoch=0):
|
414 |
+
|
415 |
+
# dt = config.dt
|
416 |
+
# df = config.df
|
417 |
+
# t = np.arange(0, spectrogram.shape[1]) * dt
|
418 |
+
# f = np.arange(0, spectrogram.shape[2]) * df
|
419 |
+
t, f = time, freq
|
420 |
+
dt = t[1] - t[0]
|
421 |
+
box = dict(boxstyle='round', facecolor='white', alpha=1)
|
422 |
+
text_loc = [0.02, 0.90]
|
423 |
+
if fname is None:
|
424 |
+
fname = [f"{i:03d}" for i in range(len(spectrogram))]
|
425 |
+
elif type(fname[0]) is bytes:
|
426 |
+
fname = [f.decode() for f in fname]
|
427 |
+
|
428 |
+
numbers = ["(i)", "(ii)", "(iii)", "(iv)", "(v)", "(vi)", "(vii)"]
|
429 |
+
for i in range(len(spectrogram)):
|
430 |
+
fig = plt.figure(i, figsize=(6.4, 10))
|
431 |
+
# gs = fig.add_gridspec(4, 1)
|
432 |
+
|
433 |
+
for j in range(3):
|
434 |
+
# fig.add_subplot(gs[j, 0])
|
435 |
+
plt.subplot(7,1,j*2+1)
|
436 |
+
plt.plot(waveform[i,:,j], 'k', linewidth=0.5)
|
437 |
+
plt.autoscale(enable=True, axis='x', tight=True)
|
438 |
+
plt.gca().set_xticklabels([])
|
439 |
+
plt.ylabel('')
|
440 |
+
plt.text(text_loc[0], text_loc[1], numbers[j*2], horizontalalignment='left', verticalalignment='top',
|
441 |
+
transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
|
442 |
+
|
443 |
+
for j in range(3):
|
444 |
+
# fig.add_subplot(gs[j, 0])
|
445 |
+
plt.subplot(7,1,j*2+2)
|
446 |
+
plt.pcolormesh(t, f, np.abs(spectrogram[i, :, :, j]+1j*spectrogram[i, :, :, j+3]).T, vmax=2*np.std(spectrogram[i, :, :, j]+1j*spectrogram[i, :, :, j+3]), cmap="jet", shading='auto')
|
447 |
+
plt.autoscale(enable=True, axis='x', tight=True)
|
448 |
+
plt.gca().set_xticklabels([])
|
449 |
+
if j == 1:
|
450 |
+
plt.ylabel('Frequency (Hz) or Amplitude')
|
451 |
+
plt.text(text_loc[0], text_loc[1], numbers[j*2+1], horizontalalignment='left', verticalalignment='top',
|
452 |
+
transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
|
453 |
+
|
454 |
+
# fig.add_subplot(gs[-1, 0])
|
455 |
+
plt.subplot(7,1,7)
|
456 |
+
if label is not None:
|
457 |
+
plt.plot(t, label[i, :, 0, 1], '--C0', linewidth=1)
|
458 |
+
plt.plot(t, label[i, :, 0, 2], '--C3', linewidth=1)
|
459 |
+
plt.plot(t, label[i, :, 0, 3], '--C1', linewidth=1)
|
460 |
+
plt.plot(t, pred[i, :, 0, 1], 'C0', label='P', linewidth=1)
|
461 |
+
plt.plot(t, pred[i, :, 0, 2], 'C3', label='S', linewidth=1)
|
462 |
+
plt.plot(t, pred[i, :, 0, 3], 'C1', label='PS', linewidth=1)
|
463 |
+
plt.plot(t, t*0, 'k', linewidth=1)
|
464 |
+
plt.autoscale(enable=True, axis='x', tight=True)
|
465 |
+
plt.ylim([-0.05, 1.05])
|
466 |
+
plt.text(text_loc[0], text_loc[1], numbers[-1], horizontalalignment='left', verticalalignment='top',
|
467 |
+
transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
|
468 |
+
plt.legend(loc='upper right', fontsize='small', ncol=1)
|
469 |
+
plt.xlabel('Time (s)')
|
470 |
+
plt.ylabel('Probability')
|
471 |
+
# plt.tight_layout()
|
472 |
+
plt.gcf().align_labels()
|
473 |
+
|
474 |
+
try:
|
475 |
+
plt.savefig(os.path.join(figure_dir, f'{epoch:02d}_'+fname[i]+'.png'), bbox_inches='tight')
|
476 |
+
except FileNotFoundError:
|
477 |
+
os.makedirs(os.path.dirname(os.path.join(figure_dir, fname[i])), exist_ok=True)
|
478 |
+
plt.savefig(os.path.join(figure_dir, f'{epoch:02d}_'+fname[i]+'.png'), bbox_inches='tight')
|
479 |
+
|
480 |
+
plt.close(i)
|
481 |
+
return 0
|
requirements.txt
ADDED
@@ -0,0 +1,7 @@
1 |
+
tensorflow
|
2 |
+
matplotlib
|
3 |
+
pandas
|
4 |
+
tqdm
|
5 |
+
scipy
|
6 |
+
obspy
|
7 |
+
|
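None of these dependencies are version-pinned; the usual way to install them is:

pip install -r requirements.txt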
setup.py
ADDED
@@ -0,0 +1,116 @@
1 |
+
import io
|
2 |
+
import os
|
3 |
+
import re
|
4 |
+
import sys
|
5 |
+
from shutil import rmtree
|
6 |
+
from typing import Tuple, List
|
7 |
+
|
8 |
+
from setuptools import Command, find_packages, setup
|
9 |
+
|
10 |
+
# Package meta-data.
|
11 |
+
name = "PhaseNet"
|
12 |
+
description = "PhaseNet"
|
13 |
+
url = ""
|
14 |
+
email = "wayne.weiqiang@gmail.com"
|
15 |
+
author = "Weiqiang Zhu"
|
16 |
+
requires_python = ">=3.6.0"
|
17 |
+
current_dir = os.path.abspath(os.path.dirname(__file__))
|
18 |
+
|
19 |
+
|
20 |
+
def get_version():
|
21 |
+
version_file = os.path.join(current_dir, "phasenet", "__init__.py")
|
22 |
+
with io.open(version_file, encoding="utf-8") as f:
|
23 |
+
return re.search(r'^__version__ = [\'"]([^\'"]*)[\'"]', f.read(), re.M).group(1)
|
24 |
+
|
25 |
+
|
26 |
+
# What packages are required for this module to be executed?
|
27 |
+
try:
|
28 |
+
with open(os.path.join(current_dir, "requirements.txt"), encoding="utf-8") as f:
|
29 |
+
required = f.read().split("\n")
|
30 |
+
except FileNotFoundError:
|
31 |
+
required = []
|
32 |
+
|
33 |
+
# What packages are optional?
|
34 |
+
extras = {"test": ["pytest"]}
|
35 |
+
|
36 |
+
version = get_version()
|
37 |
+
|
38 |
+
about = {"__version__": version}
|
39 |
+
|
40 |
+
|
41 |
+
def get_test_requirements():
|
42 |
+
requirements = ["pytest"]
|
43 |
+
if sys.version_info < (3, 3):
|
44 |
+
requirements.append("mock")
|
45 |
+
return requirements
|
46 |
+
|
47 |
+
|
48 |
+
def get_long_description():
|
49 |
+
# base_dir = os.path.abspath(os.path.dirname(__file__))
|
50 |
+
# with io.open(os.path.join(base_dir, "README.md"), encoding="utf-8") as f:
|
51 |
+
# return f.read()
|
52 |
+
return ""
|
53 |
+
|
54 |
+
|
55 |
+
class UploadCommand(Command):
|
56 |
+
"""Support setup.py upload."""
|
57 |
+
|
58 |
+
description = "Build and publish the package."
|
59 |
+
user_options: List[Tuple] = []
|
60 |
+
|
61 |
+
@staticmethod
|
62 |
+
def status(s):
|
63 |
+
"""Print things in bold."""
|
64 |
+
print(s)
|
65 |
+
|
66 |
+
def initialize_options(self):
|
67 |
+
pass
|
68 |
+
|
69 |
+
def finalize_options(self):
|
70 |
+
pass
|
71 |
+
|
72 |
+
def run(self):
|
73 |
+
try:
|
74 |
+
self.status("Removing previous builds...")
|
75 |
+
rmtree(os.path.join(current_dir, "dist"))
|
76 |
+
except OSError:
|
77 |
+
pass
|
78 |
+
|
79 |
+
self.status("Building Source and Wheel (universal) distribution...")
|
80 |
+
os.system(f"{sys.executable} setup.py sdist bdist_wheel --universal")
|
81 |
+
|
82 |
+
self.status("Uploading the package to PyPI via Twine...")
|
83 |
+
os.system("twine upload dist/*")
|
84 |
+
|
85 |
+
self.status("Pushing git tags...")
|
86 |
+
os.system("git tag v{}".format(about["__version__"]))
|
87 |
+
os.system("git push --tags")
|
88 |
+
|
89 |
+
sys.exit()
|
90 |
+
|
91 |
+
|
92 |
+
setup(
|
93 |
+
name=name,
|
94 |
+
version=version,
|
95 |
+
description=description,
|
96 |
+
long_description=get_long_description(),
|
97 |
+
long_description_content_type="text/markdown",
|
98 |
+
author="Weiqiang Zhu",
|
99 |
+
author_email = "wayne.weiqiang@gmail.com",
|
100 |
+
license="GPL-3.0",
|
101 |
+
url=url,
|
102 |
+
packages=find_packages(exclude=["tests", "docs", "dataset", "model", "log"]),
|
103 |
+
install_requires=required,
|
104 |
+
extras_require=extras,
|
105 |
+
classifiers=[
|
106 |
+
"License :: OSI Approved :: BSD License",
|
107 |
+
"Intended Audience :: Developers",
|
108 |
+
"Intended Audience :: Science/Research",
|
109 |
+
"Operating System :: OS Independent",
|
110 |
+
"Programming Language :: Python",
|
111 |
+
"Programming Language :: Python :: 3",
|
112 |
+
"Topic :: Software Development :: Libraries",
|
113 |
+
"Topic :: Software Development :: Libraries :: Python Modules",
|
114 |
+
],
|
115 |
+
cmdclass={"upload": UploadCommand},
|
116 |
+
)
|
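A hedged sketch of how this setup script is typically invoked; the extras name and the custom command come from the definitions above:

pip install -e .              # editable install pulling in requirements.txt
pip install -e ".[test]"      # additionally install the optional pytest extra
python setup.py upload        # run UploadCommand: build, upload via twine, push git tag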