zhuwq0 committed on
Commit
0eb79a8
0 Parent(s):
Dockerfile ADDED
@@ -0,0 +1,25 @@
1
+ FROM tensorflow/tensorflow
2
+
3
+ # Create the environment:
4
+ # COPY env.yml /app
5
+ # RUN conda env create --name cs329s --file=env.yml
6
+ # Make RUN commands use the new environment:
7
+ # SHELL ["conda", "run", "-n", "cs329s", "/bin/bash", "-c"]
8
+
9
+ RUN pip install tqdm obspy pandas
10
+ RUN pip install uvicorn fastapi
11
+
12
+ WORKDIR /opt
13
+
14
+ # Copy files
15
+ COPY phasenet /opt/phasenet
16
+ COPY model /opt/model
17
+
18
+ # Expose API port
19
+ EXPOSE 8000
20
+
21
+ ENV PYTHONUNBUFFERED=1
22
+
23
+ # Start API server
24
+ #ENTRYPOINT ["conda", "run", "--no-capture-output", "-n", "cs329s", "uvicorn", "--app-dir", "phasenet", "app:app", "--reload", "--port", "8000", "--host", "0.0.0.0"]
25
+ ENTRYPOINT ["uvicorn", "--app-dir", "phasenet", "app:app", "--reload", "--port", "7860", "--host", "0.0.0.0"]
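A minimal sketch (editor's illustration, not part of the commit) of calling the API this image serves. It assumes the container is run with the port published (e.g. `docker run -p 7860:7860 ...`) and follows the `Data` schema defined in `phasenet/app.py` below (`id`, `timestamp`, `vec`, `dt`); the station id and random waveform are placeholders.

```python
# Sketch: POST a waveform window to the PhaseNet FastAPI /predict endpoint.
# Assumes the container from this Dockerfile is reachable on localhost:7860.
import numpy as np
import requests

vec = np.random.randn(3000, 3).tolist()  # placeholder: 3000 samples x 3 components

payload = {
    "id": ["CI.BOM..HH"],                      # one entry per station
    "timestamp": ["2020-10-01T00:00:00.000"],  # window start time
    "vec": [vec],                              # shape: (n_stations, 3000, 3)
    "dt": 0.01,                                # 100 Hz sampling
}

resp = requests.post("http://localhost:7860/predict", json=payload)
print(resp.json())  # list of picks: station_id, phase_time, phase_score, phase_type
```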
LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2021 Weiqiang Zhu
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
docs/README.md ADDED
@@ -0,0 +1,144 @@
1
+ # PhaseNet: A Deep-Neural-Network-Based Seismic Arrival Time Picking Method
2
+
3
+ [![](https://github.com/AI4EPS/PhaseNet/workflows/documentation/badge.svg)](https://ai4eps.github.io/PhaseNet)
4
+
5
+ ## 1. Install [miniconda](https://docs.conda.io/en/latest/miniconda.html) and requirements
6
+ - Download PhaseNet repository
7
+ ```bash
8
+ git clone https://github.com/wayneweiqiang/PhaseNet.git
9
+ cd PhaseNet
10
+ ```
11
+ - Install to default environment
12
+ ```bash
13
+ conda env update -f=env.yml -n base
14
+ ```
15
+ - Install to "phasenet" virtual environment
16
+ ```bash
17
+ conda env create -f env.yml
18
+ conda activate phasenet
19
+ ```
20
+
21
+ ## 2. Pre-trained model
22
+ Located in directory: **model/190703-214543**
23
+
24
+ ## 3. Related papers
25
+ - Zhu, Weiqiang, and Gregory C. Beroza. "PhaseNet: A Deep-Neural-Network-Based Seismic Arrival Time Picking Method." arXiv preprint arXiv:1803.03211 (2018).
26
+ - Liu, Min, et al. "Rapid characterization of the July 2019 Ridgecrest, California, earthquake sequence from raw seismic data using machine‐learning phase picker." Geophysical Research Letters 47.4 (2020): e2019GL086189.
27
+ - Park, Yongsoo, et al. "Machine‐learning‐based analysis of the Guy‐Greenbrier, Arkansas earthquakes: A tale of two sequences." Geophysical Research Letters 47.6 (2020): e2020GL087032.
28
+ - Chai, Chengping, et al. "Using a deep neural network and transfer learning to bridge scales for seismic phase picking." Geophysical Research Letters 47.16 (2020): e2020GL088651.
29
+ - Tan, Yen Joe, et al. "Machine‐Learning‐Based High‐Resolution Earthquake Catalog Reveals How Complex Fault Structures Were Activated during the 2016–2017 Central Italy Sequence." The Seismic Record 1.1 (2021): 11-19.
30
+
31
+ ## 4. Batch prediction
32
+ See examples in the [notebook](https://github.com/wayneweiqiang/PhaseNet/blob/master/docs/example_batch_prediction.ipynb): [example_batch_prediction.ipynb](example_batch_prediction.ipynb)
33
+
34
+
35
+ PhaseNet currently supports four data formats: mseed, sac, hdf5, and numpy. The test data can be downloaded here:
36
+ ```
37
+ wget https://github.com/wayneweiqiang/PhaseNet/releases/download/test_data/test_data.zip
38
+ unzip test_data.zip
39
+ ```
40
+
41
+ - For mseed format:
42
+ ```
43
+ python phasenet/predict.py --model=model/190703-214543 --data_list=test_data/mseed.csv --data_dir=test_data/mseed --format=mseed --amplitude --response_xml=test_data/stations.xml --batch_size=1 --sampling_rate=100 --plot_figure
44
+ ```
45
+ ```
46
+ python phasenet/predict.py --model=model/190703-214543 --data_list=test_data/mseed2.csv --data_dir=test_data/mseed --format=mseed --amplitude --response_xml=test_data/stations.xml --batch_size=1 --sampling_rate=100 --plot_figure
47
+ ```
48
+
49
+ - For sac format:
50
+ ```
51
+ python phasenet/predict.py --model=model/190703-214543 --data_list=test_data/sac.csv --data_dir=test_data/sac --format=sac --batch_size=1 --plot_figure
52
+ ```
53
+
54
+ - For numpy format (see the input-preparation sketch after this list):
55
+ ```
56
+ python phasenet/predict.py --model=model/190703-214543 --data_list=test_data/npz.csv --data_dir=test_data/npz --format=numpy --plot_figure
57
+ ```
58
+
59
+ - For hdf5 format:
60
+ ```
61
+ python phasenet/predict.py --model=model/190703-214543 --hdf5_file=test_data/data.h5 --hdf5_group=data --format=hdf5 --plot_figure
62
+ ```
63
+
64
+ - For a seismic array (used by [QuakeFlow](https://github.com/wayneweiqiang/QuakeFlow)):
65
+ ```
66
+ python phasenet/predict.py --model=model/190703-214543 --data_list=test_data/mseed_array.csv --data_dir=test_data/mseed_array --stations=test_data/stations.json --format=mseed_array --amplitude
67
+ ```
68
+
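The numpy reader in `phasenet/data_reader.py` expects each `.npz` file to contain a `data` array of shape (nt, 3) (or (nt, 1, 3)), optionally with `t0` and `station_id`, and the csv passed to `--data_list` to have an `fname` column. A minimal sketch of preparing such inputs (the file names here are hypothetical):

```python
# Sketch: build one numpy-format input file and its data list for --format=numpy.
import numpy as np
import pandas as pd

nt = 3000  # 30 s at the default 100 Hz sampling rate
waveform = np.random.randn(nt, 3).astype("float32")  # placeholder 3-component data

np.savez(
    "test_data/npz/XX.STA..BH.000001.npz",
    data=waveform,
    t0="2020-10-01T00:00:00.000",  # optional window start time
    station_id="XX.STA..BH",       # optional station id
)

pd.DataFrame({"fname": ["XX.STA..BH.000001.npz"]}).to_csv("test_data/npz.csv", index=False)
```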
69
+ Notes:
70
+
71
+ 1. "--batch_size=1" is used because mseed or sac files usually are not the same length. If you want a larger batch size for faster prediction, cut the data to the same length first (see the sketch after these notes).
72
+
73
+ 2. Remove the "--plot_figure" argument for large datasets, because plotting can be very slow.
74
+
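A minimal sketch (editor's illustration) of cutting miniSEED files to a common window with ObsPy so that a larger `--batch_size` can be used, as mentioned in note 1; the time window and glob pattern are placeholders:

```python
# Sketch: trim all miniSEED files to the same one-hour window before prediction.
import glob
import obspy

starttime = obspy.UTCDateTime("2020-10-01T00:00:00")
endtime = starttime + 3600  # one hour

for fname in glob.glob("test_data/mseed/*.mseed"):  # adjust to your file names
    st = obspy.read(fname)
    st = st.trim(starttime, endtime, pad=True, fill_value=0)
    st.write(fname, format="MSEED")
```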
75
+ Optional arguments:
76
+ ```
77
+ usage: predict.py [-h] [--batch_size BATCH_SIZE] [--model_dir MODEL_DIR]
78
+ [--data_dir DATA_DIR] [--data_list DATA_LIST]
79
+ [--hdf5_file HDF5_FILE] [--hdf5_group HDF5_GROUP]
80
+ [--result_dir RESULT_DIR] [--result_fname RESULT_FNAME]
81
+ [--min_p_prob MIN_P_PROB] [--min_s_prob MIN_S_PROB]
82
+ [--mpd MPD] [--amplitude] [--format FORMAT]
83
+ [--s3_url S3_URL] [--stations STATIONS] [--plot_figure]
84
+ [--save_prob]
85
+
86
+ optional arguments:
87
+ -h, --help show this help message and exit
88
+ --batch_size BATCH_SIZE
89
+ batch size
90
+ --model_dir MODEL_DIR
91
+ Checkpoint directory (default: None)
92
+ --data_dir DATA_DIR Input file directory
93
+ --data_list DATA_LIST
94
+ Input csv file
95
+ --hdf5_file HDF5_FILE
96
+ Input hdf5 file
97
+ --hdf5_group HDF5_GROUP
98
+ data group name in hdf5 file
99
+ --result_dir RESULT_DIR
100
+ Output directory
101
+ --result_fname RESULT_FNAME
102
+ Output file
103
+ --min_p_prob MIN_P_PROB
104
+ Probability threshold for P pick
105
+ --min_s_prob MIN_S_PROB
106
+ Probability threshold for S pick
107
+ --mpd MPD Minimum peak distance
108
+ --amplitude if return amplitude value
109
+ --format FORMAT input format
110
+ --stations STATIONS seismic station info
111
+ --plot_figure If plot figure for test
112
+ --save_prob If save result for test
113
+ ```
114
+
115
+ - The output picks are saved to "results/picks.csv" by default
116
+
117
+ |file_name |begin_time |station_id|phase_index|phase_time |phase_score|phase_amp |phase_type|
118
+ |-----------------|-----------------------|----------|-----------|-----------------------|-----------|----------------------|----------|
119
+ |2020-10-01T00:00*|2020-10-01T00:00:00.003|CI.BOM..HH|14734 |2020-10-01T00:02:27.343|0.708 |2.4998866231208325e-14|P |
120
+ |2020-10-01T00:00*|2020-10-01T00:00:00.003|CI.BOM..HH|15487 |2020-10-01T00:02:34.873|0.416 |2.4998866231208325e-14|S |
121
+ |2020-10-01T00:00*|2020-10-01T00:00:00.003|CI.COA..HH|319 |2020-10-01T00:00:03.193|0.762 |3.708662269972206e-14 |P |
122
+
123
+ Notes:
124
+ 1. The *phase_index* gives the position of the pick (in samples) within the original sequence, so *phase_time* = *begin_time* + *phase_index* / *sampling_rate*. The default *sampling_rate* is 100 Hz (see the sketch below).
125
+
126
+
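A minimal sketch of loading the picks and recomputing *phase_time* from *begin_time* and *phase_index*, assuming the default output path and 100 Hz data:

```python
# Sketch: verify phase_time = begin_time + phase_index / sampling_rate.
import pandas as pd

picks = pd.read_csv("results/picks.csv")  # pass sep="\t" if your version writes tab-separated output
sampling_rate = 100.0

row = picks.iloc[0]
phase_time = pd.Timestamp(row["begin_time"]) + pd.Timedelta(seconds=row["phase_index"] / sampling_rate)
print(row["station_id"], row["phase_type"], phase_time, row["phase_score"])
```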
127
+ ## 5. QuakeFlow example
128
+ A complete earthquake detection workflow can be found in the [QuakeFlow](https://wayneweiqiang.github.io/QuakeFlow/) project.
129
+
130
+ ## 6. Interactive example
131
+ See details in the [notebook](https://github.com/wayneweiqiang/PhaseNet/blob/master/docs/example_gradio.ipynb): [example_interactive.ipynb](example_gradio.ipynb)
132
+
133
+ ## 7. Training
134
+ - Download a small sample dataset:
135
+ ```bash
136
+ wget https://github.com/wayneweiqiang/PhaseNet/releases/download/test_data/test_data.zip
137
+ unzip test_data.zip
138
+ ```
139
+ - Start training from the pre-trained model
140
+ ```
141
+ python phasenet/train.py --model_dir=model/190703-214543/ --train_dir=test_data/npz --train_list=test_data/npz.csv --plot_figure --epochs=10 --batch_size=10
142
+ ```
143
+ - Check results in the **log** folder
144
+
docs/data.mseed ADDED
Binary file (73.7 kB).
 
docs/example_batch_prediction.ipynb ADDED
@@ -0,0 +1,211 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# Batch Prediction\n",
8
+ "\n",
9
+ "## 1. Download demo data\n",
10
+ "\n",
11
+ "```\n",
12
+ "cd PhaseNet\n",
13
+ "wget https://github.com/wayneweiqiang/PhaseNet/releases/download/test_data/test_data.zip\n",
14
+ "unzip test_data.zip\n",
15
+ "```\n",
16
+ "\n",
17
+ "## 2. Run batch prediction \n",
18
+ "\n",
19
+ "PhaseNet currently supports four data formats: mseed, sac, hdf5, and numpy. \n",
20
+ "\n",
21
+ "- For mseed format:\n",
22
+ "```\n",
23
+ "python phasenet/predict.py --model=model/190703-214543 --data_list=test_data/mseed.csv --data_dir=test_data/mseed --format=mseed --plot_figure\n",
24
+ "```\n",
25
+ "\n",
26
+ "- For sac format:\n",
27
+ "```\n",
28
+ "python phasenet/predict.py --model=model/190703-214543 --data_list=test_data/sac.csv --data_dir=test_data/sac --format=sac --plot_figure\n",
29
+ "```\n",
30
+ "\n",
31
+ "- For numpy format:\n",
32
+ "```\n",
33
+ "python phasenet/predict.py --model=model/190703-214543 --data_list=test_data/npz.csv --data_dir=test_data/npz --format=numpy --plot_figure\n",
34
+ "```\n",
35
+ "\n",
36
+ "- For hdf5 format:\n",
37
+ "```\n",
38
+ "python phasenet/predict.py --model=model/190703-214543 --hdf5_file=test_data/data.h5 --hdf5_group=data --format=hdf5 --plot_figure\n",
39
+ "```\n",
40
+ "\n",
41
+ "- For a seismic array (used by [QuakeFlow](https://github.com/wayneweiqiang/QuakeFlow)):\n",
42
+ "```\n",
43
+ "python phasenet/predict.py --model=model/190703-214543 --data_list=test_data/mseed_array.csv --data_dir=test_data/mseed_array --stations=test_data/stations.json --format=mseed_array --amplitude\n",
44
+ "```\n",
45
+ "```\n",
46
+ "python phasenet/predict.py --model=model/190703-214543 --data_list=test_data/mseed2.csv --data_dir=test_data/mseed --stations=test_data/stations.json --format=mseed_array --amplitude\n",
47
+ "```\n",
48
+ "\n",
49
+ "Notes: \n",
50
+ "1. Remove the \"--plot_figure\" argument for large datasets, because plotting can be very slow.\n",
51
+ "\n",
52
+ "Optional arguments:\n",
53
+ "```\n",
54
+ "usage: predict.py [-h] [--batch_size BATCH_SIZE] [--model_dir MODEL_DIR]\n",
55
+ " [--data_dir DATA_DIR] [--data_list DATA_LIST]\n",
56
+ " [--hdf5_file HDF5_FILE] [--hdf5_group HDF5_GROUP]\n",
57
+ " [--result_dir RESULT_DIR] [--result_fname RESULT_FNAME]\n",
58
+ " [--min_p_prob MIN_P_PROB] [--min_s_prob MIN_S_PROB]\n",
59
+ " [--mpd MPD] [--amplitude] [--format FORMAT]\n",
60
+ " [--s3_url S3_URL] [--stations STATIONS] [--plot_figure]\n",
61
+ " [--save_prob]\n",
62
+ "\n",
63
+ "optional arguments:\n",
64
+ " -h, --help show this help message and exit\n",
65
+ " --batch_size BATCH_SIZE\n",
66
+ " batch size\n",
67
+ " --model_dir MODEL_DIR\n",
68
+ " Checkpoint directory (default: None)\n",
69
+ " --data_dir DATA_DIR Input file directory\n",
70
+ " --data_list DATA_LIST\n",
71
+ " Input csv file\n",
72
+ " --hdf5_file HDF5_FILE\n",
73
+ " Input hdf5 file\n",
74
+ " --hdf5_group HDF5_GROUP\n",
75
+ " data group name in hdf5 file\n",
76
+ " --result_dir RESULT_DIR\n",
77
+ " Output directory\n",
78
+ " --result_fname RESULT_FNAME\n",
79
+ " Output file\n",
80
+ " --min_p_prob MIN_P_PROB\n",
81
+ " Probability threshold for P pick\n",
82
+ " --min_s_prob MIN_S_PROB\n",
83
+ " Probability threshold for S pick\n",
84
+ " --mpd MPD Minimum peak distance\n",
85
+ " --amplitude if return amplitude value\n",
86
+ " --format FORMAT input format\n",
87
+ " --stations STATIONS seismic station info\n",
88
+ " --plot_figure If plot figure for test\n",
89
+ " --save_prob If save result for test\n",
90
+ "```\n",
91
+ "\n",
92
+ "## 3. Output picks\n",
93
+ "- The output picks are saved to \"results/picks.csv\" on default\n",
94
+ "\n",
95
+ "|file_name |begin_time |station_id|phase_index|phase_time |phase_score|phase_amp |phase_type|\n",
96
+ "|-----------------|-----------------------|----------|-----------|-----------------------|-----------|----------------------|----------|\n",
97
+ "|2020-10-01T00:00*|2020-10-01T00:00:00.003|CI.BOM..HH|14734 |2020-10-01T00:02:27.343|0.708 |2.4998866231208325e-14|P |\n",
98
+ "|2020-10-01T00:00*|2020-10-01T00:00:00.003|CI.BOM..HH|15487 |2020-10-01T00:02:34.873|0.416 |2.4998866231208325e-14|S |\n",
99
+ "|2020-10-01T00:00*|2020-10-01T00:00:00.003|CI.COA..HH|319 |2020-10-01T00:00:03.193|0.762 |3.708662269972206e-14 |P |\n",
100
+ "\n",
101
+ "Notes:\n",
102
+ "1. The *phase_index* means which data point is the pick in the original sequence. So *phase_time* = *begin_time* + *phase_index* / *sampling rate*. The default *sampling_rate* is 100Hz \n"
103
+ ]
104
+ },
105
+ {
106
+ "cell_type": "markdown",
107
+ "metadata": {},
108
+ "source": [
109
+ "## 3. Read P/S picks\n",
110
+ "\n",
111
+ "PhaseNet currently outputs two format: **CSV** and **JSON**"
112
+ ]
113
+ },
114
+ {
115
+ "cell_type": "code",
116
+ "execution_count": 1,
117
+ "metadata": {},
118
+ "outputs": [],
119
+ "source": [
120
+ "import pandas as pd\n",
121
+ "import json\n",
122
+ "import os\n",
123
+ "PROJECT_ROOT = os.path.realpath(os.path.join(os.path.abspath(''), \"..\"))"
124
+ ]
125
+ },
126
+ {
127
+ "cell_type": "code",
128
+ "execution_count": 2,
129
+ "metadata": {},
130
+ "outputs": [
131
+ {
132
+ "name": "stdout",
133
+ "output_type": "stream",
134
+ "text": [
135
+ "fname NC.MCV..EH.0361339.npz\n",
136
+ "t0 1970-01-01T00:00:00.000\n",
137
+ "p_idx [5999, 9015]\n",
138
+ "p_prob [0.987, 0.981]\n",
139
+ "s_idx [6181, 9205]\n",
140
+ "s_prob [0.553, 0.873]\n",
141
+ "Name: 1, dtype: object\n",
142
+ "fname NN.LHV..EH.0384064.npz\n",
143
+ "t0 1970-01-01T00:00:00.000\n",
144
+ "p_idx []\n",
145
+ "p_prob []\n",
146
+ "s_idx []\n",
147
+ "s_prob []\n",
148
+ "Name: 0, dtype: object\n"
149
+ ]
150
+ }
151
+ ],
152
+ "source": [
153
+ "picks_csv = pd.read_csv(os.path.join(PROJECT_ROOT, \"results/picks.csv\"), sep=\"\\t\")\n",
154
+ "picks_csv.loc[:, 'p_idx'] = picks_csv[\"p_idx\"].apply(lambda x: x.strip(\"[]\").split(\",\"))\n",
155
+ "picks_csv.loc[:, 'p_prob'] = picks_csv[\"p_prob\"].apply(lambda x: x.strip(\"[]\").split(\",\"))\n",
156
+ "picks_csv.loc[:, 's_idx'] = picks_csv[\"s_idx\"].apply(lambda x: x.strip(\"[]\").split(\",\"))\n",
157
+ "picks_csv.loc[:, 's_prob'] = picks_csv[\"s_prob\"].apply(lambda x: x.strip(\"[]\").split(\",\"))\n",
158
+ "print(picks_csv.iloc[1])\n",
159
+ "print(picks_csv.iloc[0])"
160
+ ]
161
+ },
162
+ {
163
+ "cell_type": "code",
164
+ "execution_count": 3,
165
+ "metadata": {},
166
+ "outputs": [
167
+ {
168
+ "name": "stdout",
169
+ "output_type": "stream",
170
+ "text": [
171
+ "{'id': 'NC.MCV..EH.0361339.npz', 'timestamp': '1970-01-01T00:01:30.150', 'prob': 0.9811667799949646, 'type': 'p'}\n",
172
+ "{'id': 'NC.MCV..EH.0361339.npz', 'timestamp': '1970-01-01T00:00:59.990', 'prob': 0.9872905611991882, 'type': 'p'}\n"
173
+ ]
174
+ }
175
+ ],
176
+ "source": [
177
+ "with open(os.path.join(PROJECT_ROOT, \"results/picks.json\")) as fp:\n",
178
+ " picks_json = json.load(fp) \n",
179
+ "print(picks_json[1])\n",
180
+ "print(picks_json[0])"
181
+ ]
182
+ }
183
+ ],
184
+ "metadata": {
185
+ "kernelspec": {
186
+ "display_name": "Python 3.10.4 64-bit",
187
+ "language": "python",
188
+ "name": "python3"
189
+ },
190
+ "language_info": {
191
+ "codemirror_mode": {
192
+ "name": "ipython",
193
+ "version": 3
194
+ },
195
+ "file_extension": ".py",
196
+ "mimetype": "text/x-python",
197
+ "name": "python",
198
+ "nbconvert_exporter": "python",
199
+ "pygments_lexer": "ipython3",
200
+ "version": "3.10.4"
201
+ },
202
+ "orig_nbformat": 4,
203
+ "vscode": {
204
+ "interpreter": {
205
+ "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
206
+ }
207
+ }
208
+ },
209
+ "nbformat": 4,
210
+ "nbformat_minor": 2
211
+ }
docs/example_fastapi.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
docs/example_gradio.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
docs/test_api.py ADDED
@@ -0,0 +1,37 @@
1
+ # %%
2
+ from gradio_client import Client
3
+ import obspy
4
+ import numpy as np
5
+ import json
6
+ import pandas as pd
7
+
8
+ # %%
9
+
10
+ waveform = obspy.read()
11
+ array = np.array([x.data for x in waveform]).T
12
+
13
+ # pipeline = PreTrainedPipeline()
14
+ inputs = array.tolist()
15
+ inputs = json.dumps(inputs)
16
+ # picks = pipeline(inputs)
17
+ # print(picks)
18
+
19
+ # %%
20
+ client = Client("ai4eps/phasenet")
21
+ output, file = client.predict(["test_test.mseed"])
22
+ # %%
23
+ with open(output, "r") as f:
24
+ picks = json.load(f)["data"]
25
+
26
+ # %%
27
+ picks = pd.read_csv(file)
28
+
29
+
30
+ # %%
31
+ job = client.submit(["test_test.mseed", "test_test.mseed"], api_name="/predict") # This is not blocking
32
+
33
+ print(job.status())
34
+
35
+ # %%
36
+ output, file = job.result()
37
+
env.yml ADDED
@@ -0,0 +1,17 @@
1
+ name: phasenet
2
+ channels:
3
+ - defaults
4
+ - conda-forge
5
+ dependencies:
6
+ - python
7
+ - numpy
8
+ - scipy
9
+ - matplotlib
10
+ - pandas
11
+ - scikit-learn
12
+ - tqdm
13
+ - obspy
14
+ - uvicorn
15
+ - fastapi
16
+ - tensorflow
17
+ - keras
mkdocs.yml ADDED
@@ -0,0 +1,18 @@
1
+ site_name: "PhaseNet"
2
+ site_description: 'PhaseNet: a deep-neural-network-based seismic arrival-time picking method'
3
+ site_author: 'Weiqiang Zhu'
4
+ docs_dir: docs/
5
+ repo_name: 'AI4EPS/PhaseNet'
6
+ repo_url: 'https://github.com/ai4eps/PhaseNet'
7
+ nav:
8
+ - Overview: README.md
9
+ - Interactive Example: example_gradio.ipynb
10
+ - Batch Prediction: example_batch_prediction.ipynb
11
+ theme:
12
+ name: 'material'
13
+ plugins:
14
+ - mkdocs-jupyter
15
+ extra:
16
+ analytics:
17
+ provider: google
18
+ property: G-RZQ9LRPL0S
model/190703-214543/checkpoint ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1606ccb25e1533fa0398c5dbce7f3a45ac77f90b78b99f81a044294ba38a2c0c
3
+ size 83
model/190703-214543/config.log ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ed9dfa705053a5025facc9952c7da6abef19ec5f672d9e50386bf3f2d80294f2
3
+ size 345
model/190703-214543/loss.log ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ccb6f19117497571e19bec5da6012ac7af91f1bd29e931ffd0b23c6b657bb401
3
+ size 8101
model/190703-214543/model_95.ckpt.data-00000-of-00001 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9ee2c15dd78fb15de45a55ad64a446f1a0ced152ba4ac5c506d82b9194da85b4
3
+ size 3226256
model/190703-214543/model_95.ckpt.index ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f96b553b76be4ebae9a455eaf8d83cfa8c0e110f06cfba958de2568e5b6b2780
3
+ size 7223
model/190703-214543/model_95.ckpt.meta ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7ebd154a5ba0721ba8bbb627ba61b556ee60660eb34bbcd1b1f50396b07cc4ed
3
+ size 2172055
phasenet/__init__.py ADDED
@@ -0,0 +1 @@
1
+ __version__ = "0.1.0"
phasenet/app.py ADDED
@@ -0,0 +1,341 @@
1
+ import os
2
+ from collections import defaultdict, namedtuple
3
+ from datetime import datetime, timedelta
4
+ from json import dumps
5
+ from typing import Any, AnyStr, Dict, List, NamedTuple, Union, Optional
6
+
7
+ import numpy as np
8
+ import requests
9
+ import tensorflow as tf
10
+ from fastapi import FastAPI, WebSocket
11
+ from kafka import KafkaProducer
12
+ from pydantic import BaseModel
13
+ from scipy.interpolate import interp1d
14
+
15
+ from model import ModelConfig, UNet
16
+ from postprocess import extract_picks
17
+
18
+ tf.compat.v1.disable_eager_execution()
19
+ tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
20
+ PROJECT_ROOT = os.path.realpath(os.path.join(os.path.dirname(__file__), ".."))
21
+ JSONObject = Dict[AnyStr, Any]
22
+ JSONArray = List[Any]
23
+ JSONStructure = Union[JSONArray, JSONObject]
24
+
25
+ app = FastAPI()
26
+ X_SHAPE = [3000, 1, 3]
27
+ SAMPLING_RATE = 100
28
+
29
+ # load model
30
+ model = UNet(mode="pred")
31
+ sess_config = tf.compat.v1.ConfigProto()
32
+ sess_config.gpu_options.allow_growth = True
33
+
34
+ sess = tf.compat.v1.Session(config=sess_config)
35
+ saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
36
+ init = tf.compat.v1.global_variables_initializer()
37
+ sess.run(init)
38
+ latest_check_point = tf.train.latest_checkpoint(f"{PROJECT_ROOT}/model/190703-214543")
39
+ print(f"restoring model {latest_check_point}")
40
+ saver.restore(sess, latest_check_point)
41
+
42
+ # GAMMA API Endpoint
43
+ GAMMA_API_URL = "http://gamma-api:8001"
44
+ # GAMMA_API_URL = 'http://localhost:8001'
45
+ # GAMMA_API_URL = "http://gamma.quakeflow.com"
46
+ # GAMMA_API_URL = "http://127.0.0.1:8001"
47
+
48
+ # Kafka producer
49
+ use_kafka = False
50
+
51
+ try:
52
+ print("Connecting to k8s kafka")
53
+ BROKER_URL = "quakeflow-kafka-headless:9092"
54
+ # BROKER_URL = "34.83.137.139:9094"
55
+ producer = KafkaProducer(
56
+ bootstrap_servers=[BROKER_URL],
57
+ key_serializer=lambda x: dumps(x).encode("utf-8"),
58
+ value_serializer=lambda x: dumps(x).encode("utf-8"),
59
+ )
60
+ use_kafka = True
61
+ print("k8s kafka connection success!")
62
+ except BaseException:
63
+ print("k8s Kafka connection error")
64
+ try:
65
+ print("Connecting to local kafka")
66
+ producer = KafkaProducer(
67
+ bootstrap_servers=["localhost:9092"],
68
+ key_serializer=lambda x: dumps(x).encode("utf-8"),
69
+ value_serializer=lambda x: dumps(x).encode("utf-8"),
70
+ )
71
+ use_kafka = True
72
+ print("local kafka connection success!")
73
+ except BaseException:
74
+ print("local Kafka connection error")
75
+ print(f"Kafka status: {use_kafka}")
76
+
77
+
78
+ def normalize_batch(data, window=3000):
79
+ """
80
+ data: nsta, nt, nch
81
+ """
82
+ shift = window // 2
83
+ nsta, nt, nch = data.shape
84
+
85
+ # std in slide windows
86
+ data_pad = np.pad(data, ((0, 0), (window // 2, window // 2), (0, 0)), mode="reflect")
87
+ t = np.arange(0, nt, shift, dtype="int")
88
+ std = np.zeros([nsta, len(t) + 1, nch])
89
+ mean = np.zeros([nsta, len(t) + 1, nch])
90
+ for i in range(1, len(t)):
91
+ std[:, i, :] = np.std(data_pad[:, i * shift : i * shift + window, :], axis=1)
92
+ mean[:, i, :] = np.mean(data_pad[:, i * shift : i * shift + window, :], axis=1)
93
+
94
+ t = np.append(t, nt)
95
+ # std[:, -1, :] = np.std(data_pad[:, -window:, :], axis=1)
96
+ # mean[:, -1, :] = np.mean(data_pad[:, -window:, :], axis=1)
97
+ std[:, -1, :], mean[:, -1, :] = std[:, -2, :], mean[:, -2, :]
98
+ std[:, 0, :], mean[:, 0, :] = std[:, 1, :], mean[:, 1, :]
99
+ std[std == 0] = 1
100
+
101
+ # ## normalize data with interpolated std
102
+ t_interp = np.arange(nt, dtype="int")
103
+ std_interp = interp1d(t, std, axis=1, kind="slinear")(t_interp)
104
+ mean_interp = interp1d(t, mean, axis=1, kind="slinear")(t_interp)
105
+ data = (data - mean_interp) / std_interp
106
+
107
+ return data
108
+
109
+
110
+ def preprocess(data):
111
+ raw = data.copy()
112
+ data = normalize_batch(data)
113
+ if len(data.shape) == 3:
114
+ data = data[:, :, np.newaxis, :]
115
+ raw = raw[:, :, np.newaxis, :]
116
+ return data, raw
117
+
118
+
119
+ def calc_timestamp(timestamp, sec):
120
+ timestamp = datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S.%f") + timedelta(seconds=sec)
121
+ return timestamp.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]
122
+
123
+
124
+ def format_picks(picks, dt, amplitudes):
125
+ picks_ = []
126
+ for pick, amplitude in zip(picks, amplitudes):
127
+ for idxs, probs, amps in zip(pick.p_idx, pick.p_prob, amplitude.p_amp):
128
+ for idx, prob, amp in zip(idxs, probs, amps):
129
+ picks_.append(
130
+ {
131
+ "id": pick.fname,
132
+ "timestamp": calc_timestamp(pick.t0, float(idx) * dt),
133
+ "prob": prob,
134
+ "amp": amp,
135
+ "type": "p",
136
+ }
137
+ )
138
+ for idxs, probs, amps in zip(pick.s_idx, pick.s_prob, amplitude.s_amp):
139
+ for idx, prob, amp in zip(idxs, probs, amps):
140
+ picks_.append(
141
+ {
142
+ "id": pick.fname,
143
+ "timestamp": calc_timestamp(pick.t0, float(idx) * dt),
144
+ "prob": prob,
145
+ "amp": amp,
146
+ "type": "s",
147
+ }
148
+ )
149
+ return picks_
150
+
151
+
152
+ def format_data(data):
153
+ # chn2idx = {"ENZ": {"E":0, "N":1, "Z":2},
154
+ # "123": {"3":0, "2":1, "1":2},
155
+ # "12Z": {"1":0, "2":1, "Z":2}}
156
+ chn2idx = {"E": 0, "N": 1, "Z": 2, "3": 0, "2": 1, "1": 2}
157
+ Data = NamedTuple("data", [("id", list), ("timestamp", list), ("vec", list), ("dt", float)])
158
+
159
+ # Group by station
160
+ chn_ = defaultdict(list)
161
+ t0_ = defaultdict(list)
162
+ vv_ = defaultdict(list)
163
+ for i in range(len(data.id)):
164
+ key = data.id[i][:-1]
165
+ chn_[key].append(data.id[i][-1])
166
+ t0_[key].append(datetime.strptime(data.timestamp[i], "%Y-%m-%dT%H:%M:%S.%f").timestamp() * SAMPLING_RATE)
167
+ vv_[key].append(np.array(data.vec[i]))
168
+
169
+ # Merge to Data tuple
170
+ id_ = []
171
+ timestamp_ = []
172
+ vec_ = []
173
+ for k in chn_:
174
+ id_.append(k)
175
+ min_t0 = min(t0_[k])
176
+ timestamp_.append(datetime.fromtimestamp(min_t0 / SAMPLING_RATE).strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3])
177
+ vec = np.zeros([X_SHAPE[0], X_SHAPE[-1]])
178
+ for i in range(len(chn_[k])):
179
+ # vec[int(t0_[k][i]-min_t0):len(vv_[k][i]), chn2idx[chn_[k][i]]] = vv_[k][i][int(t0_[k][i]-min_t0):X_SHAPE[0]] - np.mean(vv_[k][i])
180
+ shift = int(t0_[k][i] - min_t0)
181
+ vec[shift : len(vv_[k][i]) + shift, chn2idx[chn_[k][i]]] = vv_[k][i][: X_SHAPE[0] - shift] - np.mean(
182
+ vv_[k][i][: X_SHAPE[0] - shift]
183
+ )
184
+ vec_.append(vec.tolist())
185
+
186
+ return Data(id=id_, timestamp=timestamp_, vec=vec_, dt=1 / SAMPLING_RATE)
187
+ # return {"id": id_, "timestamp": timestamp_, "vec": vec_, "dt":1 / SAMPLING_RATE}
188
+
189
+
190
+ def get_prediction(data, return_preds=False):
191
+ vec = np.array(data.vec)
192
+ vec, vec_raw = preprocess(vec)
193
+
194
+ feed = {model.X: vec, model.drop_rate: 0, model.is_training: False}
195
+ preds = sess.run(model.preds, feed_dict=feed)
196
+
197
+ picks = extract_picks(preds, station_ids=data.id, begin_times=data.timestamp, waveforms=vec_raw)
198
+
199
+ picks = [
200
+ {k: v for k, v in pick.items() if k in ["station_id", "phase_time", "phase_score", "phase_type", "dt"]}
201
+ for pick in picks
202
+ ]
203
+
204
+ if return_preds:
205
+ return picks, preds
206
+
207
+ return picks
208
+
209
+
210
+ class Data(BaseModel):
211
+ # id: Union[List[str], str]
212
+ # timestamp: Union[List[str], str]
213
+ # vec: Union[List[List[List[float]]], List[List[float]]]
214
+ id: List[str]
215
+ timestamp: List[Union[str, float, datetime]]
216
+ vec: Union[List[List[List[float]]], List[List[float]]]
217
+
218
+ dt: Optional[float] = 0.01
219
+ ## gamma
220
+ stations: Optional[List[Dict[str, Union[float, str]]]] = None
221
+ config: Optional[Dict[str, Union[List[float], List[int], List[str], float, int, str]]] = None
222
+
223
+
224
+ # @app.on_event("startup")
225
+ # def set_default_executor():
226
+ # from concurrent.futures import ThreadPoolExecutor
227
+ # import asyncio
228
+ #
229
+ # loop = asyncio.get_running_loop()
230
+ # loop.set_default_executor(
231
+ # ThreadPoolExecutor(max_workers=2)
232
+ # )
233
+
234
+
235
+ @app.post("/predict")
236
+ def predict(data: Data):
237
+ picks = get_prediction(data)
238
+
239
+ return picks
240
+
241
+
242
+ @app.websocket("/ws")
243
+ async def websocket_endpoint(websocket: WebSocket):
244
+ await websocket.accept()
245
+ while True:
246
+ data = await websocket.receive_json()
247
+ # data = json.loads(data)
248
+ data = Data(**data)
249
+ picks = get_prediction(data)
250
+ await websocket.send_json(picks)
251
+ print("PhaseNet Updating...")
252
+
253
+
254
+ @app.post("/predict_prob")
255
+ def predict(data: Data):
256
+ picks, preds = get_prediction(data, True)
257
+
258
+ return picks, preds.tolist()
259
+
260
+
261
+ @app.post("/predict_phasenet2gamma")
262
+ def predict(data: Data):
263
+ picks = get_prediction(data)
264
+
265
+ # if use_kafka:
266
+ # print("Push picks to kafka...")
267
+ # for pick in picks:
268
+ # producer.send("phasenet_picks", key=pick["id"], value=pick)
269
+ try:
270
+ catalog = requests.post(
271
+ f"{GAMMA_API_URL}/predict", json={"picks": picks, "stations": data.stations, "config": data.config}
272
+ )
273
+ print(catalog.json()["catalog"])
274
+ return catalog.json()
275
+ except Exception as error:
276
+ print(error)
277
+
278
+ return {}
279
+
280
+
281
+ @app.post("/predict_phasenet2gamma2ui")
282
+ def predict(data: Data):
283
+ picks = get_prediction(data)
284
+
285
+ try:
286
+ catalog = requests.post(
287
+ f"{GAMMA_API_URL}/predict", json={"picks": picks, "stations": data.stations, "config": data.config}
288
+ )
289
+ print(catalog.json()["catalog"])
290
+ return catalog.json()
291
+ except Exception as error:
292
+ print(error)
293
+
294
+ if use_kafka:
295
+ print("Push picks to kafka...")
296
+ for pick in picks:
297
+ producer.send("phasenet_picks", key=pick["id"], value=pick)
298
+ print("Push waveform to kafka...")
299
+ for id, timestamp, vec in zip(data.id, data.timestamp, data.vec):
300
+ producer.send("waveform_phasenet", key=id, value={"timestamp": timestamp, "vec": vec, "dt": data.dt})
301
+
302
+ return {}
303
+
304
+
305
+ @app.post("/predict_stream_phasenet2gamma")
306
+ def predict(data: Data):
307
+ data = format_data(data)
308
+ # for i in range(len(data.id)):
309
+ # plt.clf()
310
+ # plt.subplot(311)
311
+ # plt.plot(np.array(data.vec)[i, :, 0])
312
+ # plt.subplot(312)
313
+ # plt.plot(np.array(data.vec)[i, :, 1])
314
+ # plt.subplot(313)
315
+ # plt.plot(np.array(data.vec)[i, :, 2])
316
+ # plt.savefig(f"{data.id[i]}.png")
317
+
318
+ picks = get_prediction(data)
319
+
320
+ return_value = {}
321
+ try:
322
+ catalog = requests.post(f"{GAMMA_API_URL}/predict_stream", json={"picks": picks})
323
+ print("GMMA:", catalog.json()["catalog"])
324
+ return_value = catalog.json()
325
+ except Exception as error:
326
+ print(error)
327
+
328
+ if use_kafka:
329
+ print("Push picks to kafka...")
330
+ for pick in picks:
331
+ producer.send("phasenet_picks", key=pick["id"], value=pick)
332
+ print("Push waveform to kafka...")
333
+ for id, timestamp, vec in zip(data.id, data.timestamp, data.vec):
334
+ producer.send("waveform_phasenet", key=id, value={"timestamp": timestamp, "vec": vec, "dt": data.dt})
335
+
336
+ return return_value
337
+
338
+
339
+ @app.get("/healthz")
340
+ def healthz():
341
+ return {"status": "ok"}
phasenet/data_reader.py ADDED
@@ -0,0 +1,1010 @@
1
+ import tensorflow as tf
2
+
3
+ tf.compat.v1.disable_eager_execution()
4
+ tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
5
+ import logging
6
+ import os
7
+
8
+ import numpy as np
9
+ import pandas as pd
10
+
11
+ pd.options.mode.chained_assignment = None
12
+ import json
13
+ import random
14
+ from collections import defaultdict
15
+
16
+ # import s3fs
17
+ import h5py
18
+ import obspy
19
+ from scipy.interpolate import interp1d
20
+ from tqdm import tqdm
21
+
22
+
23
+ def py_func_decorator(output_types=None, output_shapes=None, name=None):
24
+ def decorator(func):
25
+ def call(*args, **kwargs):
26
+ nonlocal output_shapes
27
+ # flat_output_types = nest.flatten(output_types)
28
+ flat_output_types = tf.nest.flatten(output_types)
29
+ # flat_values = tf.py_func(
30
+ flat_values = tf.numpy_function(func, inp=args, Tout=flat_output_types, name=name)
31
+ if output_shapes is not None:
32
+ for v, s in zip(flat_values, output_shapes):
33
+ v.set_shape(s)
34
+ # return nest.pack_sequence_as(output_types, flat_values)
35
+ return tf.nest.pack_sequence_as(output_types, flat_values)
36
+
37
+ return call
38
+
39
+ return decorator
40
+
41
+
42
+ def dataset_map(iterator, output_types, output_shapes=None, num_parallel_calls=None, name=None, shuffle=False):
43
+ dataset = tf.data.Dataset.range(len(iterator))
44
+ if shuffle:
45
+ dataset = dataset.shuffle(len(iterator), reshuffle_each_iteration=True)
46
+
47
+ @py_func_decorator(output_types, output_shapes, name=name)
48
+ def index_to_entry(idx):
49
+ return iterator[idx]
50
+
51
+ return dataset.map(index_to_entry, num_parallel_calls=num_parallel_calls)
52
+
53
+
54
+ def normalize(data, axis=(0,)):
55
+ """data shape: (nt, nsta, nch)"""
56
+ data -= np.mean(data, axis=axis, keepdims=True)
57
+ std_data = np.std(data, axis=axis, keepdims=True)
58
+ std_data[std_data == 0] = 1
59
+ data /= std_data
60
+ # data /= (std_data + 1e-12)
61
+ return data
62
+
63
+
64
+ def normalize_long(data, axis=(0,), window=3000):
65
+ """
66
+ data: nt, nch
67
+ """
68
+ nt, nar, nch = data.shape
69
+ if window is None:
70
+ window = nt
71
+ shift = window // 2
72
+
73
+ dtype = data.dtype
74
+ ## std in slide windows
75
+ data_pad = np.pad(data, ((window // 2, window // 2), (0, 0), (0, 0)), mode="reflect")
76
+ t = np.arange(0, nt, shift, dtype="int")
77
+ std = np.zeros([len(t) + 1, nar, nch])
78
+ mean = np.zeros([len(t) + 1, nar, nch])
79
+ for i in range(1, len(std)):
80
+ std[i, :] = np.std(data_pad[i * shift : i * shift + window, :, :], axis=axis)
81
+ mean[i, :] = np.mean(data_pad[i * shift : i * shift + window, :, :], axis=axis)
82
+
83
+ t = np.append(t, nt)
84
+ # std[-1, :] = np.std(data_pad[-window:, :], axis=0)
85
+ # mean[-1, :] = np.mean(data_pad[-window:, :], axis=0)
86
+ std[-1, ...], mean[-1, ...] = std[-2, ...], mean[-2, ...]
87
+ std[0, ...], mean[0, ...] = std[1, ...], mean[1, ...]
88
+ # std[std == 0] = 1.0
89
+
90
+ ## normalize data with interpolated std
91
+ t_interp = np.arange(nt, dtype="int")
92
+ std_interp = interp1d(t, std, axis=0, kind="slinear")(t_interp)
93
+ # std_interp = np.exp(interp1d(t, np.log(std), axis=0, kind="slinear")(t_interp))
94
+ mean_interp = interp1d(t, mean, axis=0, kind="slinear")(t_interp)
95
+ tmp = np.sum(std_interp, axis=(0, 1))
96
+ std_interp[std_interp == 0] = 1.0
97
+ data = (data - mean_interp) / std_interp
98
+ # data = (data - mean_interp)/(std_interp + 1e-12)
99
+
100
+ ### dropout effect of < 3 channel
101
+ nonzero = np.count_nonzero(tmp)
102
+ if (nonzero < 3) and (nonzero > 0):
103
+ data *= 3.0 / nonzero
104
+
105
+ return data.astype(dtype)
106
+
107
+
108
+ def normalize_batch(data, window=3000):
109
+ """
110
+ data: nsta, nt, nch
111
+ """
112
+ nsta, nt, nar, nch = data.shape
113
+ if window is None:
114
+ window = nt
115
+ shift = window // 2
116
+
117
+ ## std in slide windows
118
+ data_pad = np.pad(data, ((0, 0), (window // 2, window // 2), (0, 0), (0, 0)), mode="reflect")
119
+ t = np.arange(0, nt, shift, dtype="int")
120
+ std = np.zeros([nsta, len(t) + 1, nar, nch])
121
+ mean = np.zeros([nsta, len(t) + 1, nar, nch])
122
+ for i in range(1, len(t)):
123
+ std[:, i, :, :] = np.std(data_pad[:, i * shift : i * shift + window, :, :], axis=1)
124
+ mean[:, i, :, :] = np.mean(data_pad[:, i * shift : i * shift + window, :, :], axis=1)
125
+
126
+ t = np.append(t, nt)
127
+ # std[:, -1, :] = np.std(data_pad[:, -window:, :], axis=1)
128
+ # mean[:, -1, :] = np.mean(data_pad[:, -window:, :], axis=1)
129
+ std[:, -1, :, :], mean[:, -1, :, :] = std[:, -2, :, :], mean[:, -2, :, :]
130
+ std[:, 0, :, :], mean[:, 0, :, :] = std[:, 1, :, :], mean[:, 1, :, :]
131
+ # std[std == 0] = 1
132
+
133
+ # ## normalize data with interpolated std
134
+ t_interp = np.arange(nt, dtype="int")
135
+ std_interp = interp1d(t, std, axis=1, kind="slinear")(t_interp)
136
+ # std_interp = np.exp(interp1d(t, np.log(std), axis=1, kind="slinear")(t_interp))
137
+ mean_interp = interp1d(t, mean, axis=1, kind="slinear")(t_interp)
138
+ tmp = np.sum(std_interp, axis=(1, 2))
139
+ std_interp[std_interp == 0] = 1.0
140
+ data = (data - mean_interp) / std_interp
141
+ # data = (data - mean_interp)/(std_interp + 1e-12)
142
+
143
+ ### dropout effect of < 3 channel
144
+ nonzero = np.count_nonzero(tmp, axis=-1)
145
+ data[nonzero > 0, ...] *= 3.0 / nonzero[nonzero > 0][:, np.newaxis, np.newaxis, np.newaxis]
146
+
147
+ return data
148
+
149
+
150
+ class DataConfig:
151
+ seed = 123
152
+ use_seed = True
153
+ n_channel = 3
154
+ n_class = 3
155
+ sampling_rate = 100
156
+ dt = 1.0 / sampling_rate
157
+ X_shape = [3000, 1, n_channel]
158
+ Y_shape = [3000, 1, n_class]
159
+ min_event_gap = 3 * sampling_rate
160
+ label_shape = "gaussian"
161
+ label_width = 30
162
+ dtype = "float32"
163
+
164
+ def __init__(self, **kwargs):
165
+ for k, v in kwargs.items():
166
+ setattr(self, k, v)
167
+
168
+
169
+ class DataReader:
170
+ def __init__(
171
+ self, format="numpy", config=DataConfig(), response_xml=None, sampling_rate=100, highpass_filter=0, **kwargs
172
+ ):
173
+ self.buffer = {}
174
+ self.n_channel = config.n_channel
175
+ self.n_class = config.n_class
176
+ self.X_shape = config.X_shape
177
+ self.Y_shape = config.Y_shape
178
+ self.dt = config.dt
179
+ self.dtype = config.dtype
180
+ self.label_shape = config.label_shape
181
+ self.label_width = config.label_width
182
+ self.config = config
183
+ self.format = format
184
+ # if "highpass_filter" in kwargs:
185
+ # self.highpass_filter = kwargs["highpass_filter"]
186
+ self.highpass_filter = highpass_filter
187
+ # self.response_xml = response_xml
188
+ if response_xml is not None:
189
+ self.response = obspy.read_inventory(response_xml)
190
+ else:
191
+ self.response = None
192
+ self.sampling_rate = sampling_rate
193
+ if format in ["numpy", "mseed", "sac"]:
194
+ self.data_dir = kwargs["data_dir"]
195
+ try:
196
+ csv = pd.read_csv(kwargs["data_list"], header=0, sep="[,|\s+]", engine="python")
197
+ except:
198
+ csv = pd.read_csv(kwargs["data_list"], header=0, sep="\t")
199
+ self.data_list = csv["fname"]
200
+ self.num_data = len(self.data_list)
201
+ elif format == "hdf5":
202
+ self.h5 = h5py.File(kwargs["hdf5_file"], "r", libver="latest", swmr=True)
203
+ self.h5_data = self.h5[kwargs["hdf5_group"]]
204
+ self.data_list = list(self.h5_data.keys())
205
+ self.num_data = len(self.data_list)
206
+ elif format == "s3":
207
+ self.s3fs = s3fs.S3FileSystem(
208
+ anon=kwargs["anon"],
209
+ key=kwargs["key"],
210
+ secret=kwargs["secret"],
211
+ client_kwargs={"endpoint_url": kwargs["s3_url"]},
212
+ use_ssl=kwargs["use_ssl"],
213
+ )
214
+ self.num_data = 0
215
+ else:
216
+ raise (f"{format} not support!")
217
+
218
+ def __len__(self):
219
+ return self.num_data
220
+
221
+ def read_numpy(self, fname):
222
+ # try:
223
+ if fname not in self.buffer:
224
+ npz = np.load(fname)
225
+ meta = {}
226
+ if len(npz["data"].shape) == 2:
227
+ meta["data"] = npz["data"][:, np.newaxis, :]
228
+ else:
229
+ meta["data"] = npz["data"]
230
+ if "p_idx" in npz.files:
231
+ if len(npz["p_idx"].shape) == 0:
232
+ meta["itp"] = [[npz["p_idx"]]]
233
+ else:
234
+ meta["itp"] = npz["p_idx"]
235
+ if "s_idx" in npz.files:
236
+ if len(npz["s_idx"].shape) == 0:
237
+ meta["its"] = [[npz["s_idx"]]]
238
+ else:
239
+ meta["its"] = npz["s_idx"]
240
+ if "itp" in npz.files:
241
+ if len(npz["itp"].shape) == 0:
242
+ meta["itp"] = [[npz["itp"]]]
243
+ else:
244
+ meta["itp"] = npz["itp"]
245
+ if "its" in npz.files:
246
+ if len(npz["its"].shape) == 0:
247
+ meta["its"] = [[npz["its"]]]
248
+ else:
249
+ meta["its"] = npz["its"]
250
+ if "station_id" in npz.files:
251
+ meta["station_id"] = npz["station_id"]
252
+ if "sta_id" in npz.files:
253
+ meta["station_id"] = npz["sta_id"]
254
+ if "t0" in npz.files:
255
+ meta["t0"] = npz["t0"]
256
+ self.buffer[fname] = meta
257
+ else:
258
+ meta = self.buffer[fname]
259
+ return meta
260
+ # except:
261
+ # logging.error("Failed reading {}".format(fname))
262
+ # return None
263
+
264
+ def read_hdf5(self, fname):
265
+ data = self.h5_data[fname][()]
266
+ attrs = self.h5_data[fname].attrs
267
+ meta = {}
268
+ if len(data.shape) == 2:
269
+ meta["data"] = data[:, np.newaxis, :]
270
+ else:
271
+ meta["data"] = data
272
+ if "p_idx" in attrs:
273
+ if len(attrs["p_idx"].shape) == 0:
274
+ meta["itp"] = [[attrs["p_idx"]]]
275
+ else:
276
+ meta["itp"] = attrs["p_idx"]
277
+ if "s_idx" in attrs:
278
+ if len(attrs["s_idx"].shape) == 0:
279
+ meta["its"] = [[attrs["s_idx"]]]
280
+ else:
281
+ meta["its"] = attrs["s_idx"]
282
+ if "itp" in attrs:
283
+ if len(attrs["itp"].shape) == 0:
284
+ meta["itp"] = [[attrs["itp"]]]
285
+ else:
286
+ meta["itp"] = attrs["itp"]
287
+ if "its" in attrs:
288
+ if len(attrs["its"].shape) == 0:
289
+ meta["its"] = [[attrs["its"]]]
290
+ else:
291
+ meta["its"] = attrs["its"]
292
+ if "t0" in attrs:
293
+ meta["t0"] = attrs["t0"]
294
+ return meta
295
+
296
+ def read_s3(self, format, fname, bucket, key, secret, s3_url, use_ssl):
297
+ with self.s3fs.open(bucket + "/" + fname, "rb") as fp:
298
+ if format == "numpy":
299
+ meta = self.read_numpy(fp)
300
+ elif format == "mseed":
301
+ meta = self.read_mseed(fp)
302
+ else:
303
+ raise (f"Format {format} not supported")
304
+ return meta
305
+
306
+ def read_mseed(self, fname, response=None, highpass_filter=0.0, sampling_rate=100, return_single_station=True):
307
+ try:
308
+ stream = obspy.read(fname)
309
+ stream = stream.merge(fill_value="latest")
310
+ if response is not None:
311
+ # response = obspy.read_inventory(response_xml)
312
+ stream = stream.remove_sensitivity(response)
313
+ except Exception as e:
314
+ print(f"Error reading {fname}:\n{e}")
315
+ return {}
316
+ tmp_stream = obspy.Stream()
317
+ for trace in stream:
318
+ if len(trace.data) < 10:
319
+ continue
320
+
321
+ ## interpolate to 100 Hz
322
+ if abs(trace.stats.sampling_rate - sampling_rate) > 0.1:
323
+ logging.warning(f"Resampling {trace.id} from {trace.stats.sampling_rate} to {sampling_rate} Hz")
324
+ try:
325
+ trace = trace.interpolate(sampling_rate, method="linear")
326
+ except Exception as e:
327
+ print(f"Error resampling {trace.id}:\n{e}")
328
+
329
+ trace = trace.detrend("demean")
330
+
331
+ ## highpass filtering > 1Hz
332
+ if highpass_filter > 0.0:
333
+ trace = trace.filter("highpass", freq=highpass_filter)
334
+
335
+ tmp_stream.append(trace)
336
+
337
+ if len(tmp_stream) == 0:
338
+ return {}
339
+ stream = tmp_stream
340
+
341
+ begin_time = min([st.stats.starttime for st in stream])
342
+ end_time = max([st.stats.endtime for st in stream])
343
+ stream = stream.trim(begin_time, end_time, pad=True, fill_value=0)
344
+
345
+ comp = ["3", "2", "1", "E", "N", "U", "V", "Z"]
346
+ order = {key: i for i, key in enumerate(comp)}
347
+ comp2idx = {
348
+ "3": 0,
349
+ "2": 1,
350
+ "1": 2,
351
+ "E": 0,
352
+ "N": 1,
353
+ "Z": 2,
354
+ "U": 0,
355
+ "V": 1,
356
+ } ## only for cases less than 3 components
357
+
358
+ station_ids = defaultdict(list)
359
+ for tr in stream:
360
+ station_ids[tr.id[:-1]].append(tr.id[-1])
361
+ if tr.id[-1] not in comp:
362
+ print(f"Unknown component {tr.id[-1]}")
363
+
364
+ station_keys = sorted(list(station_ids.keys()))
365
+
366
+ nx = len(station_ids)
367
+ nt = len(stream[0].data)
368
+ data = np.zeros([3, nt, nx], dtype=np.float32)
369
+ for i, sta in enumerate(station_keys):
370
+ for j, c in enumerate(sorted(station_ids[sta], key=lambda x: order[x])):
371
+ if len(station_ids[sta]) != 3: ## less than 3 component
372
+ j = comp2idx[c]
373
+
374
+ if len(stream.select(id=sta + c)) == 0:
375
+ print(f"Empty trace: {sta+c} {begin_time}")
376
+ continue
377
+
378
+ trace = stream.select(id=sta + c)[0]
379
+
380
+ ## acceleration to velocity
381
+ if sta[-1] == "N":
382
+ trace = trace.integrate().filter("highpass", freq=1.0)
383
+
384
+ tmp = trace.data.astype("float32")
385
+ data[j, : len(tmp), i] = tmp[:nt]
386
+
387
+ # if return_single_station and (len(station_keys) > 1):
388
+ # print(f"Warning: {fname} has multiple stations, returning only the first one {station_keys[0]}")
389
+ # data = data[:, :, 0:1]
390
+ # station_keys = station_keys[0:1]
391
+
392
+ meta = {
393
+ "data": data.transpose([1, 2, 0]),
394
+ "t0": begin_time.datetime.isoformat(timespec="milliseconds"),
395
+ "station_id": station_keys,
396
+ }
397
+ return meta
398
+
399
+ def read_sac(self, fname):
400
+ mseed = obspy.read(fname)
401
+ mseed = mseed.detrend("spline", order=2, dspline=5 * mseed[0].stats.sampling_rate)
402
+ mseed = mseed.merge(fill_value=0)
403
+ if self.highpass_filter > 0:
404
+ mseed = mseed.filter("highpass", freq=self.highpass_filter)
405
+ starttime = min([st.stats.starttime for st in mseed])
406
+ endtime = max([st.stats.endtime for st in mseed])
407
+ mseed = mseed.trim(starttime, endtime, pad=True, fill_value=0)
408
+ if abs(mseed[0].stats.sampling_rate - self.config.sampling_rate) > 1:
409
+ logging.warning(
410
+ f"Sampling rate mismatch in {fname.split('/')[-1]}: {mseed[0].stats.sampling_rate}Hz != {self.config.sampling_rate}Hz "
411
+ )
412
+
413
+ order = ["3", "2", "1", "E", "N", "Z"]
414
+ order = {key: i for i, key in enumerate(order)}
415
+ comp2idx = {"3": 0, "2": 1, "1": 2, "E": 0, "N": 1, "Z": 2}
416
+
417
+ t0 = starttime.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]
418
+ nt = len(mseed[0].data)
419
+ data = np.zeros([nt, self.config.n_channel], dtype=self.dtype)
420
+ ids = [x.get_id() for x in mseed]
421
+ for j, id in enumerate(sorted(ids, key=lambda x: order[x[-1]])):
422
+ if len(ids) != 3:
423
+ if len(ids) > 3:
424
+ logging.warning(f"More than 3 channels {ids}!")
425
+ j = comp2idx[id[-1]]
426
+ data[:, j] = mseed.select(id=id)[0].data.astype(self.dtype)
427
+
428
+ data = data[:, np.newaxis, :]
429
+ meta = {"data": data, "t0": t0}
430
+ return meta
431
+
432
+ def read_mseed_array(self, fname, stations, amplitude=False, remove_resp=True):
433
+ data = []
434
+ station_id = []
435
+ t0 = []
436
+ raw_amp = []
437
+
438
+ try:
439
+ mseed = obspy.read(fname)
440
+ read_success = True
441
+ except Exception as e:
442
+ read_success = False
443
+ print(e)
444
+
445
+ if read_success:
446
+ try:
447
+ mseed = mseed.merge(fill_value=0)
448
+ except Exception as e:
449
+ print(e)
450
+
451
+ for i in range(len(mseed)):
452
+ if mseed[i].stats.sampling_rate != self.config.sampling_rate:
453
+ logging.warning(
454
+ f"Resampling {mseed[i].id} from {mseed[i].stats.sampling_rate} to {self.config.sampling_rate} Hz"
455
+ )
456
+ try:
457
+ mseed[i] = mseed[i].interpolate(self.config.sampling_rate, method="linear")
458
+ except Exception as e:
459
+ print(e)
460
+ mseed[i].data = mseed[i].data.astype(float) * 0.0 ## set to zero if resampling fails
461
+
462
+ if self.highpass_filter == 0:
463
+ try:
464
+ mseed = mseed.detrend("spline", order=2, dspline=5 * mseed[0].stats.sampling_rate)
465
+ except:
466
+ logging.error(f"Error: spline detrend failed at file {fname}")
467
+ mseed = mseed.detrend("demean")
468
+ else:
469
+ mseed = mseed.filter("highpass", freq=self.highpass_filter)
470
+
471
+ starttime = min([st.stats.starttime for st in mseed])
472
+ endtime = max([st.stats.endtime for st in mseed])
473
+ mseed = mseed.trim(starttime, endtime, pad=True, fill_value=0)
474
+
475
+ order = ["3", "2", "1", "E", "N", "Z"]
476
+ order = {key: i for i, key in enumerate(order)}
477
+ comp2idx = {"3": 0, "2": 1, "1": 2, "E": 0, "N": 1, "Z": 2}
478
+
479
+ nsta = len(stations)
480
+ nt = len(mseed[0].data)
481
+ # for i in range(nsta):
482
+ for sta in stations:
483
+ trace_data = np.zeros([nt, self.config.n_channel], dtype=self.dtype)
484
+ if amplitude:
485
+ trace_amp = np.zeros([nt, self.config.n_channel], dtype=self.dtype)
486
+ empty_station = True
487
+ # sta = stations.iloc[i]["station"]
488
+ # comp = stations.iloc[i]["component"].split(",")
489
+ comp = stations[sta]["component"]
490
+ if amplitude:
491
+ # resp = stations.iloc[i]["response"].split(",")
492
+ resp = stations[sta]["response"]
493
+
494
+ for j, c in enumerate(sorted(comp, key=lambda x: order[x[-1]])):
495
+ resp_j = resp[j]
496
+ if len(comp) != 3: ## less than 3 component
497
+ j = comp2idx[c]
498
+
499
+ if len(mseed.select(id=sta + c)) == 0:
500
+ print(f"Empty trace: {sta+c} {starttime}")
501
+ continue
502
+ else:
503
+ empty_station = False
504
+
505
+ tmp = mseed.select(id=sta + c)[0].data.astype(self.dtype)
506
+ trace_data[: len(tmp), j] = tmp[:nt]
507
+ if amplitude:
508
+ # if stations.iloc[i]["unit"] == "m/s**2":
509
+ if stations[sta]["unit"] == "m/s**2":
510
+ tmp = mseed.select(id=sta + c)[0]
511
+ tmp = tmp.integrate()
512
+ tmp = tmp.filter("highpass", freq=1.0)
513
+ tmp = tmp.data.astype(self.dtype)
514
+ trace_amp[: len(tmp), j] = tmp[:nt]
515
+ # elif stations.iloc[i]["unit"] == "m/s":
516
+ elif stations[sta]["unit"] == "m/s":
517
+ tmp = mseed.select(id=sta + c)[0].data.astype(self.dtype)
518
+ trace_amp[: len(tmp), j] = tmp[:nt]
519
+ else:
520
+ print(
521
+ f"Error in {stations.iloc[i]['station']}\n{stations.iloc[i]['unit']} should be m/s**2 or m/s!"
522
+ )
523
+ if amplitude and remove_resp:
524
+ # trace_amp[:, j] /= float(resp[j])
525
+ trace_amp[:, j] /= float(resp_j)
526
+
527
+ if not empty_station:
528
+ data.append(trace_data)
529
+ if amplitude:
530
+ raw_amp.append(trace_amp)
531
+ station_id.append([sta])
532
+ t0.append(starttime.datetime.isoformat(timespec="milliseconds"))
533
+
534
+ if len(data) > 0:
535
+ data = np.stack(data)
536
+ if len(data.shape) == 3:
537
+ data = data[:, :, np.newaxis, :]
538
+ if amplitude:
539
+ raw_amp = np.stack(raw_amp)
540
+ if len(raw_amp.shape) == 3:
541
+ raw_amp = raw_amp[:, :, np.newaxis, :]
542
+ else:
543
+ nt = 60 * 60 * self.config.sampling_rate # assume 1 hour data
544
+ data = np.zeros([1, nt, 1, self.config.n_channel], dtype=self.dtype)
545
+ if amplitude:
546
+ raw_amp = np.zeros([1, nt, 1, self.config.n_channel], dtype=self.dtype)
547
+ t0 = ["1970-01-01T00:00:00.000"]
548
+ station_id = ["None"]
549
+
550
+ if amplitude:
551
+ meta = {"data": data, "t0": t0, "station_id": station_id, "fname": fname.split("/")[-1], "raw_amp": raw_amp}
552
+ else:
553
+ meta = {"data": data, "t0": t0, "station_id": station_id, "fname": fname.split("/")[-1]}
554
+ return meta
555
+
556
+ def generate_label(self, data, phase_list, mask=None):
557
+ # target = np.zeros(self.Y_shape, dtype=self.dtype)
558
+ target = np.zeros_like(data)
559
+
560
+ if self.label_shape == "gaussian":
561
+ label_window = np.exp(
562
+ -((np.arange(-self.label_width // 2, self.label_width // 2 + 1)) ** 2)
563
+ / (2 * (self.label_width / 5) ** 2)
564
+ )
565
+ elif self.label_shape == "triangle":
566
+ label_window = 1 - np.abs(
567
+ 2 / self.label_width * (np.arange(-self.label_width // 2, self.label_width // 2 + 1))
568
+ )
569
+ else:
570
+ print(f"Label shape {self.label_shape} should be guassian or triangle")
571
+ raise
572
+
573
+ for i, phases in enumerate(phase_list):
574
+ for j, idx_list in enumerate(phases):
575
+ for idx in idx_list:
576
+ if np.isnan(idx):
577
+ continue
578
+ idx = int(idx)
579
+ if (idx - self.label_width // 2 >= 0) and (idx + self.label_width // 2 + 1 <= target.shape[0]):
580
+ target[idx - self.label_width // 2 : idx + self.label_width // 2 + 1, j, i + 1] = label_window
581
+
582
+ target[..., 0] = 1 - np.sum(target[..., 1:], axis=-1)
583
+ if mask is not None:
584
+ target[:, mask == 0, :] = 0
585
+
586
+ return target
587
+
588
+ def random_shift(self, sample, itp, its, itp_old=None, its_old=None, shift_range=None):
589
+ # anchor = np.round(1/2 * (min(itp[~np.isnan(itp.astype(float))]) + min(its[~np.isnan(its.astype(float))]))).astype(int)
590
+ flattern = lambda x: np.array([i for trace in x for i in trace], dtype=float)
591
+ shift_pick = lambda x, shift: [[i - shift for i in trace] for trace in x]
592
+ itp_flat = flattern(itp)
593
+ its_flat = flattern(its)
594
+ if (itp_old is None) and (its_old is None):
595
+ hi = np.round(np.median(itp_flat[~np.isnan(itp_flat)])).astype(int)
596
+ lo = -(sample.shape[0] - np.round(np.median(its_flat[~np.isnan(its_flat)])).astype(int))
597
+ if shift_range is None:
598
+ shift = np.random.randint(low=lo, high=hi + 1)
599
+ else:
600
+ shift = np.random.randint(low=max(lo, shift_range[0]), high=min(hi + 1, shift_range[1]))
601
+ else:
602
+ itp_old_flat = flatten(itp_old)
603
+ its_old_flat = flatten(its_old)
604
+ itp_ref = np.round(np.min(itp_flat[~np.isnan(itp_flat)])).astype(int)
605
+ its_ref = np.round(np.max(its_flat[~np.isnan(its_flat)])).astype(int)
606
+ itp_old_ref = np.round(np.min(itp_old_flat[~np.isnan(itp_old_flat)])).astype(int)
607
+ its_old_ref = np.round(np.max(its_old_flat[~np.isnan(its_old_flat)])).astype(int)
608
+ # min_event_gap = np.round(self.min_event_gap*(its_ref-itp_ref)).astype(int)
609
+ # min_event_gap_old = np.round(self.min_event_gap*(its_old_ref-itp_old_ref)).astype(int)
610
+ if shift_range is None:
611
+ hi = list(range(max(its_ref - itp_old_ref + self.min_event_gap, 0), itp_ref))
612
+ lo = list(range(-(sample.shape[0] - its_ref), -(max(its_old_ref - itp_ref + self.min_event_gap, 0))))
613
+ else:
614
+ lo_ = max(-(sample.shape[0] - its_ref), shift_range[0])
615
+ hi_ = min(itp_ref, shift_range[1])
616
+ hi = list(range(max(its_ref - itp_old_ref + self.min_event_gap, 0), hi_))
617
+ lo = list(range(lo_, -(max(its_old_ref - itp_ref + self.min_event_gap, 0))))
618
+ if len(hi + lo) > 0:
619
+ shift = np.random.choice(hi + lo)
620
+ else:
621
+ shift = 0
622
+
623
+ shifted_sample = np.zeros_like(sample)
624
+ if shift > 0:
625
+ shifted_sample[:-shift, ...] = sample[shift:, ...]
626
+ elif shift < 0:
627
+ shifted_sample[-shift:, ...] = sample[:shift, ...]
628
+ else:
629
+ shifted_sample[...] = sample[...]
630
+
631
+ return shifted_sample, shift_pick(itp, shift), shift_pick(its, shift), shift
632
+
633
+ def stack_events(self, sample_old, itp_old, its_old, shift_range=None, mask_old=None):
634
+ i = np.random.randint(self.num_data)
635
+ base_name = self.data_list[i]
636
+ if self.format == "numpy":
637
+ meta = self.read_numpy(os.path.join(self.data_dir, base_name))
638
+ elif self.format == "hdf5":
639
+ meta = self.read_hdf5(base_name)
640
+ if meta == -1:
641
+ return sample_old, itp_old, its_old, mask_old
642
+
643
+ sample = np.copy(meta["data"])
644
+ itp = meta["itp"]
645
+ its = meta["its"]
646
+ if mask_old is not None:
647
+ mask = np.copy(meta["mask"])
648
+ sample = normalize(sample)
649
+ sample, itp, its, shift = self.random_shift(sample, itp, its, itp_old, its_old, shift_range)
650
+
651
+ if shift != 0:
652
+ sample_old += sample
653
+ # itp_old = [np.hstack([i, j]) for i,j in zip(itp_old, itp)]
654
+ # its_old = [np.hstack([i, j]) for i,j in zip(its_old, its)]
655
+ itp_old = [i + j for i, j in zip(itp_old, itp)]
656
+ its_old = [i + j for i, j in zip(its_old, its)]
657
+ if mask_old is not None:
658
+ mask_old = mask_old * mask
659
+
660
+ return sample_old, itp_old, its_old, mask_old
661
+
662
+ def cut_window(self, sample, target, itp, its, select_range):
663
+ shift_pick = lambda x, shift: [[i - shift for i in trace] for trace in x]
664
+ sample = sample[select_range[0] : select_range[1]]
665
+ target = target[select_range[0] : select_range[1]]
666
+ return (sample, target, shift_pick(itp, select_range[0]), shift_pick(its, select_range[0]))
667
+
668
+
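+ # Usage sketch (paths taken from test_DataReader() below; adjust to your data): the training reader
+ # is built from a CSV list plus a data directory and wrapped into a tf.data pipeline, e.g.
+ #   reader = DataReader_train(data_list="test_data/selected_phases.csv", data_dir="test_data/data/")
+ #   dataset = reader.dataset(batch_size=20, shuffle=True)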
669
+ class DataReader_train(DataReader):
670
+ def __init__(self, format="numpy", config=DataConfig(), **kwargs):
671
+ super().__init__(format=format, config=config, **kwargs)
672
+
673
+ self.min_event_gap = config.min_event_gap
674
+ self.buffer_channels = {}
675
+ self.shift_range = [-2000 + self.label_width * 2, 1000 - self.label_width * 2]
676
+ self.select_range = [5000, 8000]
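+ # 3000-sample training window (8000 - 5000), matching the model's default X_shape/Y_shape of [3000, 1, 3].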
677
+
678
+ def __getitem__(self, i):
679
+ base_name = self.data_list[i]
680
+ if self.format == "numpy":
681
+ meta = self.read_numpy(os.path.join(self.data_dir, base_name))
682
+ elif self.format == "hdf5":
683
+ meta = self.read_hdf5(base_name)
684
+ if meta is None:
685
+ return (np.zeros(self.X_shape, dtype=self.dtype), np.zeros(self.Y_shape, dtype=self.dtype), base_name)
686
+
687
+ sample = np.copy(meta["data"])
688
+ itp_list = meta["itp"]
689
+ its_list = meta["its"]
690
+
691
+ sample = normalize(sample)
692
+ if np.random.random() < 0.95:
693
+ sample, itp_list, its_list, _ = self.random_shift(sample, itp_list, its_list, shift_range=self.shift_range)
694
+ sample, itp_list, its_list, _ = self.stack_events(sample, itp_list, its_list, shift_range=self.shift_range)
695
+ target = self.generate_label(sample, [itp_list, its_list])
696
+ sample, target, itp_list, its_list = self.cut_window(sample, target, itp_list, its_list, self.select_range)
697
+ else:
698
+ ## noise
699
+ assert self.X_shape[0] <= min(min(itp_list))
700
+ sample = sample[: self.X_shape[0], ...]
701
+ target = np.zeros(self.Y_shape).astype(self.dtype)
702
+ itp_list = [[]]
703
+ its_list = [[]]
704
+
705
+ sample = normalize(sample)
706
+ return (sample.astype(self.dtype), target.astype(self.dtype), base_name)
707
+
708
+ def dataset(self, batch_size, num_parallel_calls=2, shuffle=True, drop_remainder=True):
709
+ dataset = dataset_map(
710
+ self,
711
+ output_types=(self.dtype, self.dtype, "string"),
712
+ output_shapes=(self.X_shape, self.Y_shape, None),
713
+ num_parallel_calls=num_parallel_calls,
714
+ shuffle=shuffle,
715
+ )
716
+ dataset = dataset.batch(batch_size, drop_remainder=drop_remainder).prefetch(batch_size * 2)
717
+ return dataset
718
+
719
+
720
+ class DataReader_test(DataReader):
721
+ def __init__(self, format="numpy", config=DataConfig(), **kwargs):
722
+ super().__init__(format=format, config=config, **kwargs)
723
+
724
+ self.select_range = [5000, 8000]
725
+
726
+ def __getitem__(self, i):
727
+ base_name = self.data_list[i]
728
+ if self.format == "numpy":
729
+ meta = self.read_numpy(os.path.join(self.data_dir, base_name))
730
+ elif self.format == "hdf5":
731
+ meta = self.read_hdf5(base_name)
732
+ if meta == -1:
733
+ return (np.zeros(self.X_shape, dtype=self.dtype), np.zeros(self.Y_shape, dtype=self.dtype), base_name)
734
+
735
+ sample = np.copy(meta["data"])
736
+ itp_list = meta["itp"]
737
+ its_list = meta["its"]
738
+
739
+ # sample, itp_list, its_list, _ = self.random_shift(sample, itp_list, its_list, shift_range=self.shift_range)
740
+ target = self.generate_label(sample, [itp_list, its_list])
741
+ sample, target, itp_list, its_list = self.cut_window(sample, target, itp_list, its_list, self.select_range)
742
+
743
+ sample = normalize(sample)
744
+ return (sample, target, base_name, itp_list, its_list)
745
+
746
+ def dataset(self, batch_size, num_parallel_calls=2, shuffle=False, drop_remainder=False):
747
+ dataset = dataset_map(
748
+ self,
749
+ output_types=(self.dtype, self.dtype, "string", "int64", "int64"),
750
+ output_shapes=(self.X_shape, self.Y_shape, None, None, None),
751
+ num_parallel_calls=num_parallel_calls,
752
+ shuffle=shuffle,
753
+ )
754
+ dataset = dataset.batch(batch_size, drop_remainder=drop_remainder).prefetch(batch_size * 2)
755
+ return dataset
756
+
757
+
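+ # Prediction-time reader: items are variable-length windows (output_shapes [None, None, 3] in
+ # dataset() below), and when amplitude=True the un-normalized waveform is returned alongside the
+ # normalized sample so that phase amplitudes can be measured in postprocessing.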
758
+ class DataReader_pred(DataReader):
759
+ def __init__(self, format="numpy", amplitude=True, config=DataConfig(), **kwargs):
760
+ super().__init__(format=format, config=config, **kwargs)
761
+
762
+ self.amplitude = amplitude
763
+
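+ # Rescales records with dead (all-zero) channels so the overall amplitude stays comparable:
+ # data are multiplied by n_channels / n_nonzero_channels. Currently unused; the call in
+ # __getitem__ below is commented out.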
764
+ def adjust_missingchannels(self, data):
765
+ tmp = np.max(np.abs(data), axis=0, keepdims=True)
766
+ assert tmp.shape[-1] == data.shape[-1]
767
+ if np.count_nonzero(tmp) > 0:
768
+ data *= data.shape[-1] / np.count_nonzero(tmp)
769
+ return data
770
+
771
+ def __getitem__(self, i):
772
+ base_name = self.data_list[i]
773
+
774
+ if self.format == "numpy":
775
+ meta = self.read_numpy(os.path.join(self.data_dir, base_name))
776
+ elif (self.format == "mseed") or (self.format == "sac"):
777
+ meta = self.read_mseed(
778
+ os.path.join(self.data_dir, base_name),
779
+ response=self.response,
780
+ sampling_rate=self.sampling_rate,
781
+ highpass_filter=self.highpass_filter,
782
+ return_single_station=True,
783
+ )
784
+ elif self.format == "hdf5":
785
+ meta = self.read_hdf5(base_name)
786
+ else:
787
+ raise ValueError(f"Format {self.format} is not supported!")
788
+
789
+ if "data" in meta:
790
+ raw_amp = meta["data"].copy()
791
+ sample = normalize_long(meta["data"])
792
+ else:
793
+ raw_amp = np.zeros([3000, 1, 3], dtype=np.float32)
794
+ sample = np.zeros([3000, 1, 3], dtype=np.float32)
795
+
796
+ if "t0" in meta:
797
+ t0 = meta["t0"]
798
+ else:
799
+ t0 = "1970-01-01T00:00:00.000"
800
+
801
+ if "station_id" in meta:
802
+ station_id = meta["station_id"]
803
+ else:
804
+ # station_id = base_name.split("/")[-1].rstrip("*")
805
+ station_id = os.path.basename(base_name).rstrip("*")
806
+
807
+ if np.isnan(sample).any() or np.isinf(sample).any():
808
+ logging.warning(f"Data error: Nan or Inf found in {base_name}")
809
+ sample[np.isnan(sample)] = 0
810
+ sample[np.isinf(sample)] = 0
811
+
812
+ # sample = self.adjust_missingchannels(sample)
813
+
814
+ if self.amplitude:
815
+ return (sample, raw_amp, base_name, t0, station_id)
816
+ else:
817
+ return (sample, base_name, t0, station_id)
818
+
819
+ def dataset(self, batch_size, num_parallel_calls=2, shuffle=False, drop_remainder=False):
820
+ if self.amplitude:
821
+ dataset = dataset_map(
822
+ self,
823
+ output_types=(self.dtype, self.dtype, "string", "string", "string"),
824
+ output_shapes=([None, None, 3], [None, None, 3], None, None, None),
825
+ num_parallel_calls=num_parallel_calls,
826
+ shuffle=shuffle,
827
+ )
828
+ else:
829
+ dataset = dataset_map(
830
+ self,
831
+ output_types=(self.dtype, "string", "string", "string"),
832
+ output_shapes=([None, None, 3], None, None, None),
833
+ num_parallel_calls=num_parallel_calls,
834
+ shuffle=shuffle,
835
+ )
836
+ dataset = dataset.batch(batch_size, drop_remainder=drop_remainder).prefetch(batch_size * 2)
837
+ return dataset
838
+
839
+
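+ # Array reader: one item is a whole mseed file covering every station listed in the `stations`
+ # json, so a single __getitem__ yields a [n_station, nt, 1, 3] batch rather than a single trace.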
840
+ class DataReader_mseed_array(DataReader):
841
+ def __init__(self, stations, amplitude=True, remove_resp=True, config=DataConfig(), **kwargs):
842
+ super().__init__(format="mseed", config=config, **kwargs)
843
+
844
+ # self.stations = pd.read_json(stations)
845
+ with open(stations, "r") as f:
846
+ self.stations = json.load(f)
847
+ print(pd.DataFrame.from_dict(self.stations, orient="index").to_string())
848
+
849
+ self.amplitude = amplitude
850
+ self.remove_resp = remove_resp
851
+ self.X_shape = self.get_data_shape()
852
+
853
+ def get_data_shape(self):
854
+ fname = os.path.join(self.data_dir, self.data_list[0])
855
+ meta = self.read_mseed_array(fname, self.stations, self.amplitude, self.remove_resp)
856
+ return meta["data"].shape
857
+
858
+ def __getitem__(self, i):
859
+ fp = os.path.join(self.data_dir, self.data_list[i])
860
+ # try:
861
+ meta = self.read_mseed_array(fp, self.stations, self.amplitude, self.remove_resp)
862
+ # except Exception as e:
863
+ # logging.error(f"Failed reading {fp}: {e}")
864
+ # if self.amplitude:
865
+ # return (np.zeros(self.X_shape).astype(self.dtype), np.zeros(self.X_shape).astype(self.dtype),
866
+ # [self.stations.iloc[i]["station"] for i in range(len(self.stations))], ["0" for i in range(len(self.stations))])
867
+ # else:
868
+ # return (np.zeros(self.X_shape).astype(self.dtype), ["" for i in range(len(self.stations))],
869
+ # [self.stations.iloc[i]["station"] for i in range(len(self.stations))])
870
+
871
+ sample = np.zeros([len(meta["data"]), *self.X_shape[1:]], dtype=self.dtype)
872
+ sample[:, : meta["data"].shape[1], :, :] = normalize_batch(meta["data"])[:, : self.X_shape[1], :, :]
873
+ if np.isnan(sample).any() or np.isinf(sample).any():
874
+ logging.warning(f"Data error: Nan or Inf found in {fp}")
875
+ sample[np.isnan(sample)] = 0
876
+ sample[np.isinf(sample)] = 0
877
+ t0 = meta["t0"]
878
+ base_name = meta["fname"]
879
+ station_id = meta["station_id"]
880
+ # base_name = [self.stations.iloc[i]["station"]+"."+t0[i] for i in range(len(self.stations))]
881
+ # base_name = [self.stations.iloc[i]["station"] for i in range(len(self.stations))]
882
+
883
+ if self.amplitude:
884
+ raw_amp = np.zeros([len(meta["raw_amp"]), *self.X_shape[1:]], dtype=self.dtype)
885
+ raw_amp[:, : meta["raw_amp"].shape[1], :, :] = meta["raw_amp"][:, : self.X_shape[1], :, :]
886
+ if np.isnan(raw_amp).any() or np.isinf(raw_amp).any():
887
+ logging.warning(f"Data error: Nan or Inf found in {fp}")
888
+ raw_amp[np.isnan(raw_amp)] = 0
889
+ raw_amp[np.isinf(raw_amp)] = 0
890
+ return (sample, raw_amp, base_name, t0, station_id)
891
+ else:
892
+ return (sample, base_name, t0, station_id)
893
+
894
+ def dataset(self, num_parallel_calls=1, shuffle=False):
895
+ if self.amplitude:
896
+ dataset = dataset_map(
897
+ self,
898
+ output_types=(self.dtype, self.dtype, "string", "string", "string"),
899
+ output_shapes=([None, *self.X_shape[1:]], [None, *self.X_shape[1:]], None, None, None),
900
+ num_parallel_calls=num_parallel_calls,
901
+ )
902
+ else:
903
+ dataset = dataset_map(
904
+ self,
905
+ output_types=(self.dtype, "string", "string", "string"),
906
+ output_shapes=([None, *self.X_shape[1:]], None, None, None),
907
+ num_parallel_calls=num_parallel_calls,
908
+ )
909
+ dataset = dataset.prefetch(1)
910
+ # dataset = dataset.prefetch(len(self.stations)*2)
911
+ return dataset
912
+
913
+
914
+ ###### test ########
915
+
916
+
917
+ def test_DataReader():
918
+ import os
919
+ import timeit
920
+
921
+ import matplotlib.pyplot as plt
922
+
923
+ if not os.path.exists("test_figures"):
924
+ os.mkdir("test_figures")
925
+
926
+ def plot_sample(sample, fname, label=None):
927
+ plt.clf()
928
+ plt.subplot(211)
929
+ plt.plot(sample[:, 0, -1])
930
+ if label is not None:
931
+ plt.subplot(212)
932
+ plt.plot(label[:, 0, 0])
933
+ plt.plot(label[:, 0, 1])
934
+ plt.plot(label[:, 0, 2])
935
+ plt.savefig(f"test_figures/{fname.decode()}.png")
936
+
937
+ def read(data_reader, batch=1):
938
+ start_time = timeit.default_timer()
939
+ if batch is None:
940
+ dataset = data_reader.dataset(shuffle=False)
941
+ else:
942
+ dataset = data_reader.dataset(1, shuffle=False)
943
+ sess = tf.compat.v1.Session()
944
+
945
+ print(len(data_reader))
946
+ print("-------", tf.data.Dataset.cardinality(dataset))
947
+ num = 0
948
+ x = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()
949
+ while True:
950
+ num += 1
951
+ # print(num)
952
+ try:
953
+ out = sess.run(x)
954
+ if len(out) == 2:
955
+ sample, fname = out[0], out[1]
956
+ for i in range(len(sample)):
957
+ plot_sample(sample[i], fname[i])
958
+ else:
959
+ sample, label, fname = out[0], out[1], out[2]
960
+ for i in range(len(sample)):
961
+ plot_sample(sample[i], fname[i], label[i])
962
+ except tf.errors.OutOfRangeError:
963
+ break
964
+ print("End of dataset")
965
+ print("Tensorflow Dataset:\nexecution time = ", timeit.default_timer() - start_time)
966
+
967
+ data_reader = DataReader_train(data_list="test_data/selected_phases.csv", data_dir="test_data/data/")
968
+
969
+ read(data_reader)
970
+
971
+ data_reader = DataReader_train(format="hdf5", hdf5="test_data/data.h5", group="data")
972
+
973
+ read(data_reader)
974
+
975
+ data_reader = DataReader_test(data_list="test_data/selected_phases.csv", data_dir="test_data/data/")
976
+
977
+ read(data_reader)
978
+
979
+ data_reader = DataReader_test(format="hdf5", hdf5="test_data/data.h5", group="data")
980
+
981
+ read(data_reader)
982
+
983
+ data_reader = DataReader_pred(format="numpy", data_list="test_data/selected_phases.csv", data_dir="test_data/data/")
984
+
985
+ read(data_reader)
986
+
987
+ data_reader = DataReader_pred(
988
+ format="mseed", data_list="test_data/mseed_station.csv", data_dir="test_data/waveforms/"
989
+ )
990
+
991
+ read(data_reader)
992
+
993
+ data_reader = DataReader_pred(
994
+ format="mseed", amplitude=True, data_list="test_data/mseed_station.csv", data_dir="test_data/waveforms/"
995
+ )
996
+
997
+ read(data_reader)
998
+
999
+ data_reader = DataReader_mseed_array(
1000
+ data_list="test_data/mseed.csv",
1001
+ data_dir="test_data/waveforms/",
1002
+ stations="test_data/stations.csv",
1003
+ remove_resp=False,
1004
+ )
1005
+
1006
+ read(data_reader, batch=None)
1007
+
1008
+
1009
+ if __name__ == "__main__":
1010
+ test_DataReader()
phasenet/detect_peaks.py ADDED
@@ -0,0 +1,207 @@
1
+ """Detect peaks in data based on their amplitude and other features."""
2
+
3
+ from __future__ import division, print_function
4
+ import warnings
5
+ import numpy as np
6
+
7
+ __author__ = "Marcos Duarte, https://github.com/demotu"
8
+ __version__ = "1.0.6"
9
+ __license__ = "MIT"
10
+
11
+
12
+
13
+ def detect_peaks(x, mph=None, mpd=1, threshold=0, edge='rising',
14
+ kpsh=False, valley=False, show=False, ax=None, title=True):
15
+
16
+ """Detect peaks in data based on their amplitude and other features.
17
+
18
+ Parameters
19
+ ----------
20
+ x : 1D array_like
21
+ data.
22
+ mph : {None, number}, optional (default = None)
23
+ detect peaks that are greater than minimum peak height (if parameter
24
+ `valley` is False) or peaks that are smaller than maximum peak height
25
+ (if parameter `valley` is True).
26
+ mpd : positive integer, optional (default = 1)
27
+ detect peaks that are at least separated by minimum peak distance (in
28
+ number of data).
29
+ threshold : positive number, optional (default = 0)
30
+ detect peaks (valleys) that are greater (smaller) than `threshold`
31
+ in relation to their immediate neighbors.
32
+ edge : {None, 'rising', 'falling', 'both'}, optional (default = 'rising')
33
+ for a flat peak, keep only the rising edge ('rising'), only the
34
+ falling edge ('falling'), both edges ('both'), or don't detect a
35
+ flat peak (None).
36
+ kpsh : bool, optional (default = False)
37
+ keep peaks with same height even if they are closer than `mpd`.
38
+ valley : bool, optional (default = False)
39
+ if True (1), detect valleys (local minima) instead of peaks.
40
+ show : bool, optional (default = False)
41
+ if True (1), plot data in matplotlib figure.
42
+ ax : a matplotlib.axes.Axes instance, optional (default = None).
43
+ title : bool or string, optional (default = True)
44
+ if True, show standard title. If False or empty string, doesn't show
45
+ any title. If string, shows string as title.
46
+
47
+ Returns
48
+ -------
49
+ ind : 1D array_like
50
+ indices of the peaks in `x`.
51
+
52
+ Notes
53
+ -----
54
+ The detection of valleys instead of peaks is performed internally by simply
55
+ negating the data: `ind_valleys = detect_peaks(-x)`
56
+
57
+ The function can handle NaN's
58
+
59
+ See this IPython Notebook [1]_.
60
+
61
+ References
62
+ ----------
63
+ .. [1] http://nbviewer.ipython.org/github/demotu/BMC/blob/master/notebooks/DetectPeaks.ipynb
64
+
65
+ Examples
66
+ --------
67
+ >>> from detect_peaks import detect_peaks
68
+ >>> x = np.random.randn(100)
69
+ >>> x[60:81] = np.nan
70
+ >>> # detect all peaks and plot data
71
+ >>> ind = detect_peaks(x, show=True)
72
+ >>> print(ind)
73
+
74
+ >>> x = np.sin(2*np.pi*5*np.linspace(0, 1, 200)) + np.random.randn(200)/5
75
+ >>> # set minimum peak height = 0 and minimum peak distance = 20
76
+ >>> detect_peaks(x, mph=0, mpd=20, show=True)
77
+
78
+ >>> x = [0, 1, 0, 2, 0, 3, 0, 2, 0, 1, 0]
79
+ >>> # set minimum peak distance = 2
80
+ >>> detect_peaks(x, mpd=2, show=True)
81
+
82
+ >>> x = np.sin(2*np.pi*5*np.linspace(0, 1, 200)) + np.random.randn(200)/5
83
+ >>> # detection of valleys instead of peaks
84
+ >>> detect_peaks(x, mph=-1.2, mpd=20, valley=True, show=True)
85
+
86
+ >>> x = [0, 1, 1, 0, 1, 1, 0]
87
+ >>> # detect both edges
88
+ >>> detect_peaks(x, edge='both', show=True)
89
+
90
+ >>> x = [-2, 1, -2, 2, 1, 1, 3, 0]
91
+ >>> # set threshold = 2
92
+ >>> detect_peaks(x, threshold = 2, show=True)
93
+
94
+ >>> x = [-2, 1, -2, 2, 1, 1, 3, 0]
95
+ >>> fig, axs = plt.subplots(ncols=2, nrows=1, figsize=(10, 4))
96
+ >>> detect_peaks(x, show=True, ax=axs[0], threshold=0.5, title=False)
97
+ >>> detect_peaks(x, show=True, ax=axs[1], threshold=1.5, title=False)
98
+
99
+ Version history
100
+ ---------------
101
+ '1.0.6':
102
+ Fix issue of when specifying ax object only the first plot was shown
103
+ Add parameter to choose if a title is shown and input a title
104
+ '1.0.5':
105
+ The sign of `mph` is inverted if parameter `valley` is True
106
+
107
+ """
108
+
109
+ x = np.atleast_1d(x).astype('float64')
110
+ if x.size < 3:
111
+ return np.array([], dtype=int), np.array([])
112
+ if valley:
113
+ x = -x
114
+ if mph is not None:
115
+ mph = -mph
116
+ # find indices of all peaks
117
+ dx = x[1:] - x[:-1]
118
+ # handle NaN's
119
+ indnan = np.where(np.isnan(x))[0]
120
+ if indnan.size:
121
+ x[indnan] = np.inf
122
+ dx[np.where(np.isnan(dx))[0]] = np.inf
123
+ ine, ire, ife = np.array([[], [], []], dtype=int)
124
+ if not edge:
125
+ ine = np.where((np.hstack((dx, 0)) < 0) & (np.hstack((0, dx)) > 0))[0]
126
+ else:
127
+ if edge.lower() in ['rising', 'both']:
128
+ ire = np.where((np.hstack((dx, 0)) <= 0) & (np.hstack((0, dx)) > 0))[0]
129
+ if edge.lower() in ['falling', 'both']:
130
+ ife = np.where((np.hstack((dx, 0)) < 0) & (np.hstack((0, dx)) >= 0))[0]
131
+ ind = np.unique(np.hstack((ine, ire, ife)))
132
+ # handle NaN's
133
+ if ind.size and indnan.size:
134
+ # NaN's and values close to NaN's cannot be peaks
135
+ ind = ind[np.in1d(ind, np.unique(np.hstack((indnan, indnan-1, indnan+1))), invert=True)]
136
+ # first and last values of x cannot be peaks
137
+ if ind.size and ind[0] == 0:
138
+ ind = ind[1:]
139
+ if ind.size and ind[-1] == x.size-1:
140
+ ind = ind[:-1]
141
+ # remove peaks < minimum peak height
142
+ if ind.size and mph is not None:
143
+ ind = ind[x[ind] >= mph]
144
+ # remove peaks - neighbors < threshold
145
+ if ind.size and threshold > 0:
146
+ dx = np.min(np.vstack([x[ind]-x[ind-1], x[ind]-x[ind+1]]), axis=0)
147
+ ind = np.delete(ind, np.where(dx < threshold)[0])
148
+ # detect small peaks closer than minimum peak distance
149
+ if ind.size and mpd > 1:
150
+ ind = ind[np.argsort(x[ind])][::-1] # sort ind by peak height
151
+ idel = np.zeros(ind.size, dtype=bool)
152
+ for i in range(ind.size):
153
+ if not idel[i]:
154
+ # keep peaks with the same height if kpsh is True
155
+ idel = idel | (ind >= ind[i] - mpd) & (ind <= ind[i] + mpd) \
156
+ & (x[ind[i]] > x[ind] if kpsh else True)
157
+ idel[i] = 0 # Keep current peak
158
+ # remove the small peaks and sort back the indices by their occurrence
159
+ ind = np.sort(ind[~idel])
160
+
161
+ if show:
162
+ if indnan.size:
163
+ x[indnan] = np.nan
164
+ if valley:
165
+ x = -x
166
+ if mph is not None:
167
+ mph = -mph
168
+ _plot(x, mph, mpd, threshold, edge, valley, ax, ind, title)
169
+
170
+ return ind, x[ind]
171
+
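+ # Note: the docstring above describes only `ind`, but this copy also returns the peak heights
+ # x[ind]; postprocess.extract_picks() unpacks both values as (phase_index, phase_prob).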
172
+
173
+ def _plot(x, mph, mpd, threshold, edge, valley, ax, ind, title):
174
+ """Plot results of the detect_peaks function, see its help."""
175
+ try:
176
+ import matplotlib.pyplot as plt
177
+ except ImportError:
178
+ print('matplotlib is not available.')
179
+ else:
180
+ if ax is None:
181
+ _, ax = plt.subplots(1, 1, figsize=(8, 4))
182
+ no_ax = True
183
+ else:
184
+ no_ax = False
185
+
186
+ ax.plot(x, 'b', lw=1)
187
+ if ind.size:
188
+ label = 'valley' if valley else 'peak'
189
+ label = label + 's' if ind.size > 1 else label
190
+ ax.plot(ind, x[ind], '+', mfc=None, mec='r', mew=2, ms=8,
191
+ label='%d %s' % (ind.size, label))
192
+ ax.legend(loc='best', framealpha=.5, numpoints=1)
193
+ ax.set_xlim(-.02*x.size, x.size*1.02-1)
194
+ ymin, ymax = x[np.isfinite(x)].min(), x[np.isfinite(x)].max()
195
+ yrange = ymax - ymin if ymax > ymin else 1
196
+ ax.set_ylim(ymin - 0.1*yrange, ymax + 0.1*yrange)
197
+ ax.set_xlabel('Data #', fontsize=14)
198
+ ax.set_ylabel('Amplitude', fontsize=14)
199
+ if title:
200
+ if not isinstance(title, str):
201
+ mode = 'Valley detection' if valley else 'Peak detection'
202
+ title = "%s (mph=%s, mpd=%d, threshold=%s, edge='%s')"% \
203
+ (mode, str(mph), mpd, str(threshold), edge)
204
+ ax.set_title(title)
205
+ # plt.grid()
206
+ if no_ax:
207
+ plt.show()
phasenet/model.py ADDED
@@ -0,0 +1,489 @@
1
+ import tensorflow as tf
2
+ tf.compat.v1.disable_eager_execution()
3
+ import numpy as np
4
+ import logging
5
+ import warnings
6
+ warnings.filterwarnings('ignore', category=UserWarning)
7
+
8
+ class ModelConfig:
9
+
10
+ batch_size = 20
11
+ depths = 5
12
+ filters_root = 8
13
+ kernel_size = [7, 1]
14
+ pool_size = [4, 1]
15
+ dilation_rate = [1, 1]
16
+ class_weights = [1.0, 1.0, 1.0]
17
+ loss_type = "cross_entropy"
18
+ weight_decay = 0.0
19
+ optimizer = "adam"
20
+ momentum = 0.9
21
+ learning_rate = 0.01
22
+ decay_step = 1e9
23
+ decay_rate = 0.9
24
+ drop_rate = 0.0
25
+ summary = True
26
+
27
+ X_shape = [3000, 1, 3]
28
+ n_channel = X_shape[-1]
29
+ Y_shape = [3000, 1, 3]
30
+ n_class = Y_shape[-1]
31
+
32
+ def __init__(self, **kwargs):
33
+ for k,v in kwargs.items():
34
+ setattr(self, k, v)
35
+
36
+ def update_args(self, args):
37
+ for k,v in vars(args).items():
38
+ setattr(self, k, v)
39
+
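+ # Usage sketch (values are illustrative): keyword arguments override the class defaults, and any
+ # unknown name is simply attached to the instance, e.g.
+ #   config = ModelConfig(X_shape=[3000, 1, 3], drop_rate=0.05)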
40
+
41
+ def crop_and_concat(net1, net2):
42
+ """
43
+ Center-crop net2 to the (dynamic) spatial size of net1 and concatenate along the channel axis; assumes size(net1) <= size(net2).
44
+ """
45
+ # net1_shape = net1.get_shape().as_list()
46
+ # net2_shape = net2.get_shape().as_list()
47
+ # # print(net1_shape)
48
+ # # print(net2_shape)
49
+ # # if net2_shape[1] >= net1_shape[1] and net2_shape[2] >= net1_shape[2]:
50
+ # offsets = [0, (net2_shape[1] - net1_shape[1]) // 2, (net2_shape[2] - net1_shape[2]) // 2, 0]
51
+ # size = [-1, net1_shape[1], net1_shape[2], -1]
52
+ # net2_resize = tf.slice(net2, offsets, size)
53
+ # return tf.concat([net1, net2_resize], 3)
54
+
55
+ ## dynamic shape
56
+ chn1 = net1.get_shape().as_list()[-1]
57
+ chn2 = net2.get_shape().as_list()[-1]
58
+ net1_shape = tf.shape(net1)
59
+ net2_shape = tf.shape(net2)
60
+ # print(net1_shape)
61
+ # print(net2_shape)
62
+ # if net2_shape[1] >= net1_shape[1] and net2_shape[2] >= net1_shape[2]:
63
+ offsets = [0, (net2_shape[1] - net1_shape[1]) // 2, (net2_shape[2] - net1_shape[2]) // 2, 0]
64
+ size = [-1, net1_shape[1], net1_shape[2], -1]
65
+ net2_resize = tf.slice(net2, offsets, size)
66
+
67
+ out = tf.concat([net1, net2_resize], 3)
68
+ out.set_shape([None, None, None, chn1+chn2])
69
+
70
+ return out
71
+
72
+ # else:
73
+ # offsets = [0, (net1_shape[1] - net2_shape[1]) // 2, (net1_shape[2] - net2_shape[2]) // 2, 0]
74
+ # size = [-1, net2_shape[1], net2_shape[2], -1]
75
+ # net1_resize = tf.slice(net1, offsets, size)
76
+ # return tf.concat([net1_resize, net2], 3)
77
+
78
+
79
+ def crop_only(net1, net2):
80
+ """
81
+ Center-crop net2 to the spatial size of net1 without concatenation; assumes size(net1) <= size(net2).
82
+ """
83
+ net1_shape = net1.get_shape().as_list()
84
+ net2_shape = net2.get_shape().as_list()
85
+ # print(net1_shape)
86
+ # print(net2_shape)
87
+ # if net2_shape[1] >= net1_shape[1] and net2_shape[2] >= net1_shape[2]:
88
+ offsets = [0, (net2_shape[1] - net1_shape[1]) // 2, (net2_shape[2] - net1_shape[2]) // 2, 0]
89
+ size = [-1, net1_shape[1], net1_shape[2], -1]
90
+ net2_resize = tf.slice(net2, offsets, size)
91
+ #return tf.concat([net1, net2_resize], 3)
92
+ return net2_resize
93
+
94
+ class UNet:
95
+ def __init__(self, config=ModelConfig(), input_batch=None, mode='train'):
96
+ self.depths = config.depths
97
+ self.filters_root = config.filters_root
98
+ self.kernel_size = config.kernel_size
99
+ self.dilation_rate = config.dilation_rate
100
+ self.pool_size = config.pool_size
101
+ self.X_shape = config.X_shape
102
+ self.Y_shape = config.Y_shape
103
+ self.n_channel = config.n_channel
104
+ self.n_class = config.n_class
105
+ self.class_weights = config.class_weights
106
+ self.batch_size = config.batch_size
107
+ self.loss_type = config.loss_type
108
+ self.weight_decay = config.weight_decay
109
+ self.optimizer = config.optimizer
110
+ self.learning_rate = config.learning_rate
111
+ self.decay_step = config.decay_step
112
+ self.decay_rate = config.decay_rate
113
+ self.momentum = config.momentum
114
+ self.global_step = tf.compat.v1.get_variable(name="global_step", initializer=0, dtype=tf.int32)
115
+ self.summary_train = []
116
+ self.summary_valid = []
117
+
118
+ self.build(input_batch, mode=mode)
119
+
120
+ def add_placeholders(self, input_batch=None, mode="train"):
121
+ if input_batch is None:
122
+ # self.X = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None, self.X_shape[-3], self.X_shape[-2], self.X_shape[-1]], name='X')
123
+ # self.Y = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None, self.Y_shape[-3], self.Y_shape[-2], self.n_class], name='y')
124
+ self.X = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None, None, None, self.X_shape[-1]], name='X')
125
+ self.Y = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None, None, None, self.n_class], name='y')
126
+ else:
127
+ self.X = input_batch[0]
128
+ if mode in ["train", "valid", "test"]:
129
+ self.Y = input_batch[1]
130
+ self.input_batch = input_batch
131
+
132
+ self.is_training = tf.compat.v1.placeholder(dtype=tf.bool, name="is_training")
133
+ # self.keep_prob = tf.compat.v1.placeholder(dtype=tf.float32, name="keep_prob")
134
+ self.drop_rate = tf.compat.v1.placeholder(dtype=tf.float32, name="drop_rate")
135
+
136
+ def add_prediction_op(self):
137
+ logging.info("Model: depths {depths}, filters {filters}, "
138
+ "filter size {kernel_size[0]}x{kernel_size[1]}, "
139
+ "pool size: {pool_size[0]}x{pool_size[1]}, "
140
+ "dilation rate: {dilation_rate[0]}x{dilation_rate[1]}".format(
141
+ depths=self.depths,
142
+ filters=self.filters_root,
143
+ kernel_size=self.kernel_size,
144
+ dilation_rate=self.dilation_rate,
145
+ pool_size=self.pool_size))
146
+
147
+ if self.weight_decay > 0:
148
+ weight_decay = tf.constant(self.weight_decay, dtype=tf.float32, name="weight_constant")
149
+ self.regularizer = tf.keras.regularizers.l2(l=0.5 * (weight_decay))
150
+ else:
151
+ self.regularizer = None
152
+
153
+ self.initializer = tf.compat.v1.keras.initializers.VarianceScaling(scale=1.0, mode="fan_avg", distribution="uniform")
154
+
155
+ # down sample layers
156
+ convs = [None] * self.depths # store output of each depth
157
+
158
+ with tf.compat.v1.variable_scope("Input"):
159
+ net = self.X
160
+ net = tf.compat.v1.layers.conv2d(net,
161
+ filters=self.filters_root,
162
+ kernel_size=self.kernel_size,
163
+ activation=None,
164
+ padding='same',
165
+ dilation_rate=self.dilation_rate,
166
+ kernel_initializer=self.initializer,
167
+ kernel_regularizer=self.regularizer,
168
+ name="input_conv")
169
+ net = tf.compat.v1.layers.batch_normalization(net,
170
+ training=self.is_training,
171
+ name="input_bn")
172
+ net = tf.nn.relu(net,
173
+ name="input_relu")
174
+ # net = tf.nn.dropout(net, self.keep_prob)
175
+ net = tf.compat.v1.layers.dropout(net,
176
+ rate=self.drop_rate,
177
+ training=self.is_training,
178
+ name="input_dropout")
179
+
180
+
181
+ for depth in range(0, self.depths):
182
+ with tf.compat.v1.variable_scope("DownConv_%d" % depth):
183
+ filters = int(2**(depth) * self.filters_root)
184
+
185
+ net = tf.compat.v1.layers.conv2d(net,
186
+ filters=filters,
187
+ kernel_size=self.kernel_size,
188
+ activation=None,
189
+ use_bias=False,
190
+ padding='same',
191
+ dilation_rate=self.dilation_rate,
192
+ kernel_initializer=self.initializer,
193
+ kernel_regularizer=self.regularizer,
194
+ name="down_conv1_{}".format(depth + 1))
195
+ net = tf.compat.v1.layers.batch_normalization(net,
196
+ training=self.is_training,
197
+ name="down_bn1_{}".format(depth + 1))
198
+ net = tf.nn.relu(net,
199
+ name="down_relu1_{}".format(depth+1))
200
+ net = tf.compat.v1.layers.dropout(net,
201
+ rate=self.drop_rate,
202
+ training=self.is_training,
203
+ name="down_dropout1_{}".format(depth + 1))
204
+
205
+ convs[depth] = net
206
+
207
+ if depth < self.depths - 1:
208
+ net = tf.compat.v1.layers.conv2d(net,
209
+ filters=filters,
210
+ kernel_size=self.kernel_size,
211
+ strides=self.pool_size,
212
+ activation=None,
213
+ use_bias=False,
214
+ padding='same',
215
+ dilation_rate=self.dilation_rate,
216
+ kernel_initializer=self.initializer,
217
+ kernel_regularizer=self.regularizer,
218
+ name="down_conv3_{}".format(depth + 1))
219
+ net = tf.compat.v1.layers.batch_normalization(net,
220
+ training=self.is_training,
221
+ name="down_bn3_{}".format(depth + 1))
222
+ net = tf.nn.relu(net,
223
+ name="down_relu3_{}".format(depth+1))
224
+ net = tf.compat.v1.layers.dropout(net,
225
+ rate=self.drop_rate,
226
+ training=self.is_training,
227
+ name="down_dropout3_{}".format(depth + 1))
228
+
229
+
230
+ # up layers
231
+ for depth in range(self.depths - 2, -1, -1):
232
+ with tf.compat.v1.variable_scope("UpConv_%d" % depth):
233
+ filters = int(2**(depth) * self.filters_root)
234
+ net = tf.compat.v1.layers.conv2d_transpose(net,
235
+ filters=filters,
236
+ kernel_size=self.kernel_size,
237
+ strides=self.pool_size,
238
+ activation=None,
239
+ use_bias=False,
240
+ padding="same",
241
+ kernel_initializer=self.initializer,
242
+ kernel_regularizer=self.regularizer,
243
+ name="up_conv0_{}".format(depth+1))
244
+ net = tf.compat.v1.layers.batch_normalization(net,
245
+ training=self.is_training,
246
+ name="up_bn0_{}".format(depth + 1))
247
+ net = tf.nn.relu(net,
248
+ name="up_relu0_{}".format(depth+1))
249
+ net = tf.compat.v1.layers.dropout(net,
250
+ rate=self.drop_rate,
251
+ training=self.is_training,
252
+ name="up_dropout0_{}".format(depth + 1))
253
+
254
+
255
+ #skip connection
256
+ net = crop_and_concat(convs[depth], net)
257
+ #net = crop_only(convs[depth], net)
258
+
259
+ net = tf.compat.v1.layers.conv2d(net,
260
+ filters=filters,
261
+ kernel_size=self.kernel_size,
262
+ activation=None,
263
+ use_bias=False,
264
+ padding='same',
265
+ dilation_rate=self.dilation_rate,
266
+ kernel_initializer=self.initializer,
267
+ kernel_regularizer=self.regularizer,
268
+ name="up_conv1_{}".format(depth + 1))
269
+ net = tf.compat.v1.layers.batch_normalization(net,
270
+ training=self.is_training,
271
+ name="up_bn1_{}".format(depth + 1))
272
+ net = tf.nn.relu(net,
273
+ name="up_relu1_{}".format(depth + 1))
274
+ net = tf.compat.v1.layers.dropout(net,
275
+ rate=self.drop_rate,
276
+ training=self.is_training,
277
+ name="up_dropout1_{}".format(depth + 1))
278
+
279
+
280
+ # Output Map
281
+ with tf.compat.v1.variable_scope("Output"):
282
+ net = tf.compat.v1.layers.conv2d(net,
283
+ filters=self.n_class,
284
+ kernel_size=(1,1),
285
+ activation=None,
286
+ padding='same',
287
+ #dilation_rate=self.dilation_rate,
288
+ kernel_initializer=self.initializer,
289
+ kernel_regularizer=self.regularizer,
290
+ name="output_conv")
291
+ # net = tf.nn.relu(net,
292
+ # name="output_relu")
293
+ # net = tf.compat.v1.layers.dropout(net,
294
+ # rate=self.drop_rate,
295
+ # training=self.is_training,
296
+ # name="output_dropout")
297
+ # net = tf.compat.v1.layers.batch_normalization(net,
298
+ # training=self.is_training,
299
+ # name="output_bn")
300
+ output = net
301
+
302
+ with tf.compat.v1.variable_scope("representation"):
303
+ self.representation = convs[-1]
304
+
305
+ with tf.compat.v1.variable_scope("logits"):
306
+ self.logits = output
307
+ tmp = tf.compat.v1.summary.histogram("logits", self.logits)
308
+ self.summary_train.append(tmp)
309
+
310
+ with tf.compat.v1.variable_scope("preds"):
311
+ self.preds = tf.nn.softmax(output)
312
+ tmp = tf.compat.v1.summary.histogram("preds", self.preds)
313
+ self.summary_train.append(tmp)
314
+
315
+ def add_loss_op(self):
316
+ if self.loss_type == "cross_entropy":
317
+ with tf.compat.v1.variable_scope("cross_entropy"):
318
+ flat_logits = tf.reshape(self.logits, [-1, self.n_class], name="logits")
319
+ flat_labels = tf.reshape(self.Y, [-1, self.n_class], name="labels")
320
+ if (np.array(self.class_weights) != 1).any():
321
+ class_weights = tf.constant(np.array(self.class_weights, dtype=np.float32), name="class_weights")
322
+ weight_map = tf.multiply(flat_labels, class_weights)
323
+ weight_map = tf.reduce_sum(input_tensor=weight_map, axis=1)
324
+ loss_map = tf.nn.softmax_cross_entropy_with_logits(logits=flat_logits,
325
+ labels=flat_labels)
326
+
327
+ weighted_loss = tf.multiply(loss_map, weight_map)
328
+ loss = tf.reduce_mean(input_tensor=weighted_loss)
329
+ else:
330
+ loss = tf.reduce_mean(input_tensor=tf.nn.softmax_cross_entropy_with_logits(logits=flat_logits,
331
+ labels=flat_labels))
332
+
333
+ elif self.loss_type == "IOU":
334
+ with tf.compat.v1.variable_scope("IOU"):
335
+ eps = 1e-7
336
+ loss = 0
337
+ for i in range(1, self.n_class):
338
+ intersection = eps + tf.reduce_sum(input_tensor=self.preds[:,:,:,i] * self.Y[:,:,:,i], axis=[1,2])
339
+ union = eps + tf.reduce_sum(input_tensor=self.preds[:,:,:,i], axis=[1,2]) + tf.reduce_sum(input_tensor=self.Y[:,:,:,i], axis=[1,2])
340
+ loss += 1 - tf.reduce_mean(input_tensor=intersection / union)
341
+ elif self.loss_type == "mean_squared":
342
+ with tf.compat.v1.variable_scope("mean_squared"):
343
+ flat_logits = tf.reshape(self.logits, [-1, self.n_class], name="logits")
344
+ flat_labels = tf.reshape(self.Y, [-1, self.n_class], name="labels")
345
+ with tf.compat.v1.variable_scope("mean_squared"):
346
+ loss = tf.compat.v1.losses.mean_squared_error(labels=flat_labels, predictions=flat_logits)
347
+ else:
348
+ raise ValueError("Unknown loss function: %s" % self.loss_type)
349
+
350
+ tmp = tf.compat.v1.summary.scalar("train_loss", loss)
351
+ self.summary_train.append(tmp)
352
+ tmp = tf.compat.v1.summary.scalar("valid_loss", loss)
353
+ self.summary_valid.append(tmp)
354
+
355
+ if self.weight_decay > 0:
356
+ with tf.compat.v1.name_scope('weight_loss'):
357
+ tmp = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES)
358
+ weight_loss = tf.add_n(tmp, name="weight_loss")
359
+ self.loss = loss + weight_loss
360
+ else:
361
+ self.loss = loss
362
+
363
+ def add_training_op(self):
364
+ if self.optimizer == "momentum":
365
+ self.learning_rate_node = tf.compat.v1.train.exponential_decay(learning_rate=self.learning_rate,
366
+ global_step=self.global_step,
367
+ decay_steps=self.decay_step,
368
+ decay_rate=self.decay_rate,
369
+ staircase=True)
370
+ optimizer = tf.compat.v1.train.MomentumOptimizer(learning_rate=self.learning_rate_node,
371
+ momentum=self.momentum)
372
+ elif self.optimizer == "adam":
373
+ self.learning_rate_node = tf.compat.v1.train.exponential_decay(learning_rate=self.learning_rate,
374
+ global_step=self.global_step,
375
+ decay_steps=self.decay_step,
376
+ decay_rate=self.decay_rate,
377
+ staircase=True)
378
+
379
+ optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=self.learning_rate_node)
380
+ update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)
381
+ with tf.control_dependencies(update_ops):
382
+ self.train_op = optimizer.minimize(self.loss, global_step=self.global_step)
383
+ tmp = tf.compat.v1.summary.scalar("learning_rate", self.learning_rate_node)
384
+ self.summary_train.append(tmp)
385
+
386
+ def add_metrics_op(self):
387
+ with tf.compat.v1.variable_scope("metrics"):
388
+
389
+ Y = tf.argmax(input=self.Y, axis=-1)
390
+ confusion_matrix = tf.cast(tf.math.confusion_matrix(
391
+ labels=tf.reshape(Y, [-1]),
392
+ predictions=tf.reshape(tf.argmax(input=self.preds, axis=-1), [-1]),
393
+ num_classes=self.n_class, name='confusion_matrix'),
394
+ dtype=tf.float32)
395
+
396
+ # with tf.variable_scope("P"):
397
+ c = tf.constant(1e-7, dtype=tf.float32)
398
+ precision_P = (confusion_matrix[1,1] + c) / (tf.reduce_sum(input_tensor=confusion_matrix[:,1]) + c)
399
+ recall_P = (confusion_matrix[1,1] + c) / (tf.reduce_sum(input_tensor=confusion_matrix[1,:]) + c)
400
+ f1_P = 2 * precision_P * recall_P / (precision_P + recall_P)
401
+
402
+ tmp1 = tf.compat.v1.summary.scalar("train_precision_p", precision_P)
403
+ tmp2 = tf.compat.v1.summary.scalar("train_recall_p", recall_P)
404
+ tmp3 = tf.compat.v1.summary.scalar("train_f1_p", f1_P)
405
+ self.summary_train.extend([tmp1, tmp2, tmp3])
406
+
407
+ tmp1 = tf.compat.v1.summary.scalar("valid_precision_p", precision_P)
408
+ tmp2 = tf.compat.v1.summary.scalar("valid_recall_p", recall_P)
409
+ tmp3 = tf.compat.v1.summary.scalar("valid_f1_p", f1_P)
410
+ self.summary_valid.extend([tmp1, tmp2, tmp3])
411
+
412
+ # with tf.variable_scope("S"):
413
+ precision_S = (confusion_matrix[2,2] + c) / (tf.reduce_sum(input_tensor=confusion_matrix[:,2]) + c)
414
+ recall_S = (confusion_matrix[2,2] + c) / (tf.reduce_sum(input_tensor=confusion_matrix[2,:]) + c)
415
+ f1_S = 2 * precision_S * recall_S / (precision_S + recall_S)
416
+
417
+ tmp1 = tf.compat.v1.summary.scalar("train_precision_s", precision_S)
418
+ tmp2 = tf.compat.v1.summary.scalar("train_recall_s", recall_S)
419
+ tmp3 = tf.compat.v1.summary.scalar("train_f1_s", f1_S)
420
+ self.summary_train.extend([tmp1, tmp2, tmp3])
421
+
422
+ tmp1 = tf.compat.v1.summary.scalar("valid_precision_s", precision_S)
423
+ tmp2 = tf.compat.v1.summary.scalar("valid_recall_s", recall_S)
424
+ tmp3 = tf.compat.v1.summary.scalar("valid_f1_s", f1_S)
425
+ self.summary_valid.extend([tmp1, tmp2, tmp3])
426
+
427
+ self.precision = [precision_P, precision_S]
428
+ self.recall = [recall_P, recall_S]
429
+ self.f1 = [f1_P, f1_S]
430
+
431
+
432
+
433
+ def train_on_batch(self, sess, inputs_batch, labels_batch, summary_writer, drop_rate=0.0):
434
+ feed = {self.X: inputs_batch,
435
+ self.Y: labels_batch,
436
+ self.drop_rate: drop_rate,
437
+ self.is_training: True}
438
+
439
+ _, step_summary, step, loss = sess.run([self.train_op,
440
+ self.summary_train,
441
+ self.global_step,
442
+ self.loss],
443
+ feed_dict=feed)
444
+ summary_writer.add_summary(step_summary, step)
445
+ return loss
446
+
447
+ def valid_on_batch(self, sess, inputs_batch, labels_batch, summary_writer):
448
+ feed = {self.X: inputs_batch,
449
+ self.Y: labels_batch,
450
+ self.drop_rate: 0,
451
+ self.is_training: False}
452
+
453
+ step_summary, step, loss, preds = sess.run([self.summary_valid,
454
+ self.global_step,
455
+ self.loss,
456
+ self.preds],
457
+ feed_dict=feed)
458
+ summary_writer.add_summary(step_summary, step)
459
+ return loss, preds
460
+
461
+ def test_on_batch(self, sess, summary_writer):
462
+ feed = {self.drop_rate: 0,
463
+ self.is_training: False}
464
+ step_summary, step, loss, preds, \
465
+ X_batch, Y_batch, fname_batch, \
466
+ itp_batch, its_batch = sess.run([self.summary_valid,
467
+ self.global_step,
468
+ self.loss,
469
+ self.preds,
470
+ self.X,
471
+ self.Y,
472
+ self.input_batch[2],
473
+ self.input_batch[3],
474
+ self.input_batch[4]],
475
+ feed_dict=feed)
476
+ summary_writer.add_summary(step_summary, step)
477
+ return loss, preds, X_batch, Y_batch, fname_batch, itp_batch, its_batch
478
+
479
+
480
+ def build(self, input_batch=None, mode='train'):
481
+ self.add_placeholders(input_batch, mode)
482
+ self.add_prediction_op()
483
+ if mode in ["train", "valid", "test"]:
484
+ self.add_loss_op()
485
+ self.add_training_op()
486
+ # self.add_metrics_op()
487
+ self.summary_train = tf.compat.v1.summary.merge(self.summary_train)
488
+ self.summary_valid = tf.compat.v1.summary.merge(self.summary_valid)
489
+ return 0
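+ # Note: for any `mode` other than "train"/"valid"/"test", build() skips the loss and optimizer ops,
+ # leaving only the placeholders and the softmax `preds`, which is how the network is used for inference.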
phasenet/postprocess.py ADDED
@@ -0,0 +1,377 @@
1
+ import json
2
+ import logging
3
+ import os
4
+ from collections import namedtuple
5
+ from datetime import datetime, timedelta
6
+
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+ from detect_peaks import detect_peaks
10
+
11
+ # def extract_picks(preds, fnames=None, station_ids=None, t0=None, config=None):
12
+
13
+ # if preds.shape[-1] == 4:
14
+ # record = namedtuple("phase", ["fname", "station_id", "t0", "p_idx", "p_prob", "s_idx", "s_prob", "ps_idx", "ps_prob"])
15
+ # else:
16
+ # record = namedtuple("phase", ["fname", "station_id", "t0", "p_idx", "p_prob", "s_idx", "s_prob"])
17
+
18
+ # picks = []
19
+ # for i, pred in enumerate(preds):
20
+
21
+ # if config is None:
22
+ # mph_p, mph_s, mpd = 0.3, 0.3, 50
23
+ # else:
24
+ # mph_p, mph_s, mpd = config.min_p_prob, config.min_s_prob, config.mpd
25
+
26
+ # if (fnames is None):
27
+ # fname = f"{i:04d}"
28
+ # else:
29
+ # if isinstance(fnames[i], str):
30
+ # fname = fnames[i]
31
+ # else:
32
+ # fname = fnames[i].decode()
33
+
34
+ # if (station_ids is None):
35
+ # station_id = f"{i:04d}"
36
+ # else:
37
+ # if isinstance(station_ids[i], str):
38
+ # station_id = station_ids[i]
39
+ # else:
40
+ # station_id = station_ids[i].decode()
41
+
42
+ # if (t0 is None):
43
+ # start_time = "1970-01-01T00:00:00.000"
44
+ # else:
45
+ # if isinstance(t0[i], str):
46
+ # start_time = t0[i]
47
+ # else:
48
+ # start_time = t0[i].decode()
49
+
50
+ # p_idx, p_prob, s_idx, s_prob = [], [], [], []
51
+ # for j in range(pred.shape[1]):
52
+ # p_idx_, p_prob_ = detect_peaks(pred[:,j,1], mph=mph_p, mpd=mpd, show=False)
53
+ # s_idx_, s_prob_ = detect_peaks(pred[:,j,2], mph=mph_s, mpd=mpd, show=False)
54
+ # p_idx.append(list(p_idx_))
55
+ # p_prob.append(list(p_prob_))
56
+ # s_idx.append(list(s_idx_))
57
+ # s_prob.append(list(s_prob_))
58
+
59
+ # if pred.shape[-1] == 4:
60
+ # ps_idx, ps_prob = detect_peaks(pred[:,0,3], mph=0.3, mpd=mpd, show=False)
61
+ # picks.append(record(fname, station_id, start_time, list(p_idx), list(p_prob), list(s_idx), list(s_prob), list(ps_idx), list(ps_prob)))
62
+ # else:
63
+ # picks.append(record(fname, station_id, start_time, list(p_idx), list(p_prob), list(s_idx), list(s_prob)))
64
+
65
+ # return picks
66
+
67
+
68
+ def extract_picks(
69
+ preds,
70
+ file_names=None,
71
+ begin_times=None,
72
+ station_ids=None,
73
+ dt=0.01,
74
+ phases=["P", "S"],
75
+ config=None,
76
+ waveforms=None,
77
+ use_amplitude=False,
78
+ ):
79
+ """Extract picks from prediction results.
80
+ Args:
81
+ preds ([type]): [Nb, Nt, Ns, Nc] "batch, time, station, channel"
82
+ file_names ([type], optional): [Nb]. Defaults to None.
83
+ station_ids ([type], optional): [Ns]. Defaults to None.
84
+ begin_times ([type], optional): [Nb]. Defaults to None.
85
+ config ([type], optional): [description]. Defaults to None.
86
+
87
+ Returns:
88
+ picks [type]: list of dicts {file_name, station_id, begin_time, phase_index, phase_time, phase_score, phase_type, dt}
89
+ """
90
+
91
+ mph = {}
92
+ if config is None:
93
+ for x in phases:
94
+ mph[x] = 0.3
95
+ mpd = 50
96
+ pre_idx = int(1 / dt)
97
+ post_idx = int(4 / dt)
98
+ else:
99
+ mph["P"] = config.min_p_prob
100
+ mph["S"] = config.min_s_prob
101
+ mph["PS"] = 0.3
102
+ mpd = config.mpd
103
+ pre_idx = int(config.pre_sec / dt)
104
+ post_idx = int(config.post_sec / dt)
105
+
106
+ Nb, Nt, Ns, Nc = preds.shape
107
+
108
+ if file_names is None:
109
+ file_names = [f"{i:04d}" for i in range(Nb)]
110
+ elif not (isinstance(file_names, np.ndarray) or isinstance(file_names, list)):
111
+ if isinstance(file_names, bytes):
112
+ file_names = file_names.decode()
113
+ file_names = [file_names] * Nb
114
+ else:
115
+ file_names = [x.decode() if isinstance(x, bytes) else x for x in file_names]
116
+
117
+ if begin_times is None:
118
+ begin_times = ["1970-01-01T00:00:00.000+00:00"] * Nb
119
+ else:
120
+ begin_times = [x.decode() if isinstance(x, bytes) else x for x in begin_times]
121
+
122
+ picks = []
123
+ for i in range(Nb):
124
+ file_name = file_names[i]
125
+ begin_time = datetime.fromisoformat(begin_times[i])
126
+
127
+ for j in range(Ns):
128
+ if (station_ids is None) or (len(station_ids[i]) == 0):
129
+ station_id = f"{j:04d}"
130
+ else:
131
+ station_id = station_ids[i][j].decode() if isinstance(station_ids[i][j], bytes) else station_ids[i][j]
132
+
133
+ if (waveforms is not None) and use_amplitude:
134
+ amp = np.max(np.abs(waveforms[i, :, j, :]), axis=-1) ## amplitude over three channels
135
+ for k in range(Nc - 1): # 0-th channel noise
136
+ idxs, probs = detect_peaks(preds[i, :, j, k + 1], mph=mph[phases[k]], mpd=mpd, show=False)
137
+ for l, (phase_index, phase_prob) in enumerate(zip(idxs, probs)):
138
+ pick_time = begin_time + timedelta(seconds=phase_index * dt)
139
+ pick = {
140
+ "file_name": file_name,
141
+ "station_id": station_id,
142
+ "begin_time": begin_time.isoformat(timespec="milliseconds"),
143
+ "phase_index": int(phase_index),
144
+ "phase_time": pick_time.isoformat(timespec="milliseconds"),
145
+ "phase_score": round(phase_prob, 3),
146
+ "phase_type": phases[k],
147
+ "dt": dt,
148
+ }
149
+
150
+ ## process waveform
151
+ if waveforms is not None:
152
+ tmp = np.zeros((pre_idx + post_idx, 3))
153
+ lo = phase_index - pre_idx
154
+ hi = phase_index + post_idx
155
+ insert_idx = 0
156
+ if lo < 0:
157
+ insert_idx = -lo
158
+ lo = 0
159
+ if hi > Nt:
160
+ hi = Nt
161
+ tmp[insert_idx : insert_idx + hi - lo, :] = waveforms[i, lo:hi, j, :]
162
+ if use_amplitude:
163
+ next_pick = idxs[l + 1] if l < len(idxs) - 1 else (phase_index + post_idx * 3)
164
+ pick["phase_amplitude"] = np.max(
165
+ amp[phase_index : min(phase_index + post_idx * 3, next_pick)]
166
+ ).item() ## peak amplitude
167
+
168
+ picks.append(pick)
169
+
170
+ return picks
171
+
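+ # Illustrative example (the values below are made up): each pick appended above is a plain dict like
+ #   {"file_name": "NC.MCB..HH.mseed", "station_id": "NC.MCB..HH",
+ #    "begin_time": "2020-01-01T00:00:00.000", "phase_index": 3482,
+ #    "phase_time": "2020-01-01T00:00:34.820", "phase_score": 0.982, "phase_type": "P", "dt": 0.01}
+ # with an extra "phase_amplitude" key when waveforms are given and use_amplitude is True.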
172
+
173
+ def extract_amplitude(data, picks, window_p=10, window_s=5, config=None):
174
+ record = namedtuple("amplitude", ["p_amp", "s_amp"])
175
+ dt = 0.01 if config is None else config.dt
176
+ window_p = int(window_p / dt)
177
+ window_s = int(window_s / dt)
178
+ amps = []
179
+ for i, (da, pi) in enumerate(zip(data, picks)):
180
+ p_amp, s_amp = [], []
181
+ for j in range(da.shape[1]):
182
+ amp = np.max(np.abs(da[:, j, :]), axis=-1)
183
+ # amp = np.median(np.abs(da[:,j,:]), axis=-1)
184
+ # amp = np.linalg.norm(da[:,j,:], axis=-1)
185
+ tmp = []
186
+ for k in range(len(pi.p_idx[j]) - 1):
187
+ tmp.append(np.max(amp[pi.p_idx[j][k] : min(pi.p_idx[j][k] + window_p, pi.p_idx[j][k + 1])]))
188
+ if len(pi.p_idx[j]) >= 1:
189
+ tmp.append(np.max(amp[pi.p_idx[j][-1] : pi.p_idx[j][-1] + window_p]))
190
+ p_amp.append(tmp)
191
+ tmp = []
192
+ for k in range(len(pi.s_idx[j]) - 1):
193
+ tmp.append(np.max(amp[pi.s_idx[j][k] : min(pi.s_idx[j][k] + window_s, pi.s_idx[j][k + 1])]))
194
+ if len(pi.s_idx[j]) >= 1:
195
+ tmp.append(np.max(amp[pi.s_idx[j][-1] : pi.s_idx[j][-1] + window_s]))
196
+ s_amp.append(tmp)
197
+ amps.append(record(p_amp, s_amp))
198
+ return amps
199
+
200
+
201
+ def save_picks(picks, output_dir, amps=None, fname=None):
202
+ if fname is None:
203
+ fname = "picks.csv"
204
+
205
+ int2s = lambda x: ",".join(["[" + ",".join(map(str, i)) + "]" for i in x])
206
+ flt2s = lambda x: ",".join(["[" + ",".join(map("{:0.3f}".format, i)) + "]" for i in x])
207
+ sci2s = lambda x: ",".join(["[" + ",".join(map("{:0.3e}".format, i)) + "]" for i in x])
208
+ if amps is None:
209
+ if hasattr(picks[0], "ps_idx"):
210
+ with open(os.path.join(output_dir, fname), "w") as fp:
211
+ fp.write("fname\tt0\tp_idx\tp_prob\ts_idx\ts_prob\tps_idx\tps_prob\n")
212
+ for pick in picks:
213
+ fp.write(
214
+ f"{pick.fname}\t{pick.t0}\t{int2s(pick.p_idx)}\t{flt2s(pick.p_prob)}\t{int2s(pick.s_idx)}\t{flt2s(pick.s_prob)}\t{int2s(pick.ps_idx)}\t{flt2s(pick.ps_prob)}\n"
215
+ )
216
+ fp.close()
217
+ else:
218
+ with open(os.path.join(output_dir, fname), "w") as fp:
219
+ fp.write("fname\tt0\tp_idx\tp_prob\ts_idx\ts_prob\n")
220
+ for pick in picks:
221
+ fp.write(
222
+ f"{pick.fname}\t{pick.t0}\t{int2s(pick.p_idx)}\t{flt2s(pick.p_prob)}\t{int2s(pick.s_idx)}\t{flt2s(pick.s_prob)}\n"
223
+ )
224
+ fp.close()
225
+ else:
226
+ with open(os.path.join(output_dir, fname), "w") as fp:
227
+ fp.write("fname\tt0\tp_idx\tp_prob\ts_idx\ts_prob\tp_amp\ts_amp\n")
228
+ for pick, amp in zip(picks, amps):
229
+ fp.write(
230
+ f"{pick.fname}\t{pick.t0}\t{int2s(pick.p_idx)}\t{flt2s(pick.p_prob)}\t{int2s(pick.s_idx)}\t{flt2s(pick.s_prob)}\t{sci2s(amp.p_amp)}\t{sci2s(amp.s_amp)}\n"
231
+ )
232
+ fp.close()
233
+
234
+ return 0
235
+
236
+
237
+ def calc_timestamp(timestamp, sec):
238
+ timestamp = datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S.%f") + timedelta(seconds=sec)
239
+ return timestamp.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]
240
+
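+ # Example: calc_timestamp("1970-01-01T00:00:00.000", 1.5) returns "1970-01-01T00:00:01.500"
+ # (millisecond precision).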
241
+
242
+ def save_picks_json(picks, output_dir, dt=0.01, amps=None, fname=None):
243
+ if fname is None:
244
+ fname = "picks.json"
245
+
246
+ picks_ = []
247
+ if amps is None:
248
+ for pick in picks:
249
+ for idxs, probs in zip(pick.p_idx, pick.p_prob):
250
+ for idx, prob in zip(idxs, probs):
251
+ picks_.append(
252
+ {
253
+ "id": pick.station_id,
254
+ "timestamp": calc_timestamp(pick.t0, float(idx) * dt),
255
+ "prob": prob.astype(float),
256
+ "type": "p",
257
+ }
258
+ )
259
+ for idxs, probs in zip(pick.s_idx, pick.s_prob):
260
+ for idx, prob in zip(idxs, probs):
261
+ picks_.append(
262
+ {
263
+ "id": pick.station_id,
264
+ "timestamp": calc_timestamp(pick.t0, float(idx) * dt),
265
+ "prob": prob.astype(float),
266
+ "type": "s",
267
+ }
268
+ )
269
+ else:
270
+ for pick, amplitude in zip(picks, amps):
271
+ for idxs, probs, amps in zip(pick.p_idx, pick.p_prob, amplitude.p_amp):
272
+ for idx, prob, amp in zip(idxs, probs, amps):
273
+ picks_.append(
274
+ {
275
+ "id": pick.station_id,
276
+ "timestamp": calc_timestamp(pick.t0, float(idx) * dt),
277
+ "prob": prob.astype(float),
278
+ "amp": amp.astype(float),
279
+ "type": "p",
280
+ }
281
+ )
282
+ for idxs, probs, amps in zip(pick.s_idx, pick.s_prob, amplitude.s_amp):
283
+ for idx, prob, amp in zip(idxs, probs, amps):
284
+ picks_.append(
285
+ {
286
+ "id": pick.station_id,
287
+ "timestamp": calc_timestamp(pick.t0, float(idx) * dt),
288
+ "prob": prob.astype(float),
289
+ "amp": amp.astype(float),
290
+ "type": "s",
291
+ }
292
+ )
293
+ with open(os.path.join(output_dir, fname), "w") as fp:
294
+ json.dump(picks_, fp)
295
+
296
+ return 0
297
+
298
+
299
+ def convert_true_picks(fname, itp, its, itps=None):
300
+ true_picks = []
301
+ if itps is None:
302
+ record = namedtuple("phase", ["fname", "p_idx", "s_idx"])
303
+ for i in range(len(fname)):
304
+ true_picks.append(record(fname[i].decode(), itp[i], its[i]))
305
+ else:
306
+ record = namedtuple("phase", ["fname", "p_idx", "s_idx", "ps_idx"])
307
+ for i in range(len(fname)):
308
+ true_picks.append(record(fname[i].decode(), itp[i], its[i], itps[i]))
309
+
310
+ return true_picks
311
+
312
+
313
+ def calc_metrics(nTP, nP, nT):
314
+ """
315
+ nTP: true positive
316
+ nP: number of positive picks
317
+ nT: number of true picks
318
+ """
319
+ precision = nTP / nP
320
+ recall = nTP / nT
321
+ f1 = 2 * precision * recall / (precision + recall)
322
+ return [precision, recall, f1]
323
+
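+ # Example (illustrative numbers): calc_metrics(nTP=80, nP=100, nT=90) gives
+ # precision = 0.800, recall ~ 0.889, F1 ~ 0.842.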
324
+
325
+ def calc_performance(picks, true_picks, tol=3.0, dt=1.0):
326
+ assert len(picks) == len(true_picks)
327
+ logging.info("Total records: {}".format(len(picks)))
328
+
329
+ count = lambda picks: sum([len(x) for x in picks])
330
+ metrics = {}
331
+ for phase in true_picks[0]._fields:
332
+ if phase == "fname":
333
+ continue
334
+ true_positive, positive, true = 0, 0, 0
335
+ residual = []
336
+ for i in range(len(true_picks)):
337
+ true += count(getattr(true_picks[i], phase))
338
+ positive += count(getattr(picks[i], phase))
339
+ # print(i, phase, getattr(picks[i], phase), getattr(true_picks[i], phase))
340
+ diff = dt * (
341
+ np.array(getattr(picks[i], phase))[:, np.newaxis, :]
342
+ - np.array(getattr(true_picks[i], phase))[:, :, np.newaxis]
343
+ )
344
+ residual.extend(list(diff[np.abs(diff) <= tol]))
345
+ true_positive += np.sum(np.abs(diff) <= tol)
346
+ metrics[phase] = calc_metrics(true_positive, positive, true)
347
+
348
+ logging.info(f"{phase}-phase:")
349
+ logging.info(f"True={true}, Positive={positive}, True Positive={true_positive}")
350
+ logging.info(f"Precision={metrics[phase][0]:.3f}, Recall={metrics[phase][1]:.3f}, F1={metrics[phase][2]:.3f}")
351
+ logging.info(f"Residual mean={np.mean(residual):.4f}, std={np.std(residual):.4f}")
352
+
353
+ return metrics
354
+
355
+
356
+ def save_prob_h5(probs, fnames, output_h5):
357
+ if fnames is None:
358
+ fnames = [f"{i:04d}" for i in range(len(probs))]
359
+ elif type(fnames[0]) is bytes:
360
+ fnames = [f.decode().rstrip(".npz") for f in fnames]
361
+ else:
362
+ fnames = [f.rstrip(".npz") for f in fnames]
363
+ for prob, fname in zip(probs, fnames):
364
+ output_h5.create_dataset(fname, data=prob, dtype="float32")
365
+ return 0
366
+
367
+
368
+ def save_prob(probs, fnames, prob_dir):
369
+ if fnames is None:
370
+ fnames = [f"{i:04d}" for i in range(len(probs))]
371
+ elif type(fnames[0]) is bytes:
372
+ fnames = [f.decode().rstrip(".npz") for f in fnames]
373
+ else:
374
+ fnames = [f.rstrip(".npz") for f in fnames]
375
+ for prob, fname in zip(probs, fnames):
376
+ np.savez(os.path.join(prob_dir, fname + ".npz"), prob=prob)
377
+ return 0
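
As a reference for the metric helpers defined above, here is a minimal worked example of what `calc_metrics` computes; the counts are made up purely for illustration and are not part of this commit.

```python
# Hypothetical counts, for illustration only.
nTP, nP, nT = 90, 100, 95                           # true positives, predicted picks, true picks
precision = nTP / nP                                # 0.900
recall = nTP / nT                                   # ~0.947
f1 = 2 * precision * recall / (precision + recall)  # ~0.923
# calc_metrics(nTP, nP, nT) above returns [precision, recall, f1] with these values.
```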
phasenet/predict.py ADDED
@@ -0,0 +1,262 @@
1
+ import argparse
2
+ import logging
3
+ import multiprocessing
4
+ import os
5
+ import pickle
6
+ import time
7
+ from functools import partial
8
+
9
+ import h5py
10
+ import numpy as np
11
+ import pandas as pd
12
+ import tensorflow as tf
13
+ from data_reader import DataReader_mseed_array, DataReader_pred
14
+ from postprocess import (
15
+ extract_amplitude,
16
+ extract_picks,
17
+ save_picks,
18
+ save_picks_json,
19
+ save_prob_h5,
20
+ )
21
+ from tqdm import tqdm
22
+ from visulization import plot_waveform
23
+
24
+ from model import ModelConfig, UNet
25
+
26
+ tf.compat.v1.disable_eager_execution()
27
+ tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
28
+
29
+
30
+ def read_args():
31
+ parser = argparse.ArgumentParser()
32
+ parser.add_argument("--batch_size", default=20, type=int, help="batch size")
33
+ parser.add_argument("--model_dir", help="Checkpoint directory (default: None)")
34
+ parser.add_argument("--data_dir", default="", help="Input file directory")
35
+ parser.add_argument("--data_list", default="", help="Input csv file")
36
+ parser.add_argument("--hdf5_file", default="", help="Input hdf5 file")
37
+ parser.add_argument("--hdf5_group", default="data", help="data group name in hdf5 file")
38
+ parser.add_argument("--result_dir", default="results", help="Output directory")
39
+ parser.add_argument("--result_fname", default="picks", help="Output file")
40
+ parser.add_argument("--min_p_prob", default=0.3, type=float, help="Probability threshold for P pick")
41
+ parser.add_argument("--min_s_prob", default=0.3, type=float, help="Probability threshold for S pick")
42
+ parser.add_argument("--mpd", default=50, type=float, help="Minimum peak distance")
43
+ parser.add_argument("--amplitude", action="store_true", help="if return amplitude value")
44
+ parser.add_argument("--format", default="numpy", help="input format")
45
+ parser.add_argument("--s3_url", default="localhost:9000", help="s3 url")
46
+ parser.add_argument("--stations", default="", help="seismic station info")
47
+ parser.add_argument("--plot_figure", action="store_true", help="If plot figure for test")
48
+ parser.add_argument("--save_prob", action="store_true", help="If save result for test")
49
+ parser.add_argument("--pre_sec", default=1, type=float, help="Window length before pick")
50
+ parser.add_argument("--post_sec", default=4, type=float, help="Window length after pick")
51
+
52
+ parser.add_argument("--highpass_filter", default=0.0, type=float, help="Highpass filter")
53
+ parser.add_argument("--response_xml", default=None, type=str, help="response xml file")
54
+ parser.add_argument("--sampling_rate", default=100, type=float, help="sampling rate")
55
+ args = parser.parse_args()
56
+
57
+ return args
58
+
59
+
60
+ def pred_fn(args, data_reader, figure_dir=None, prob_dir=None, log_dir=None):
61
+ current_time = time.strftime("%y%m%d-%H%M%S")
62
+ if log_dir is None:
63
+ log_dir = os.path.join(args.log_dir, "pred", current_time)
64
+ if not os.path.exists(log_dir):
65
+ os.makedirs(log_dir)
66
+ if (args.plot_figure == True) and (figure_dir is None):
67
+ figure_dir = os.path.join(log_dir, "figures")
68
+ if not os.path.exists(figure_dir):
69
+ os.makedirs(figure_dir)
70
+ if (args.save_prob == True) and (prob_dir is None):
71
+ prob_dir = os.path.join(log_dir, "probs")
72
+ if not os.path.exists(prob_dir):
73
+ os.makedirs(prob_dir)
74
+ if args.save_prob:
75
+ h5 = h5py.File(os.path.join(args.result_dir, "result.h5"), "w", libver="latest")
76
+ prob_h5 = h5.create_group("/prob")
77
+ logging.info("Pred log: %s" % log_dir)
78
+ logging.info("Dataset size: {}".format(data_reader.num_data))
79
+
80
+ with tf.compat.v1.name_scope("Input_Batch"):
81
+ if args.format == "mseed_array":
82
+ batch_size = 1
83
+ else:
84
+ batch_size = args.batch_size
85
+ dataset = data_reader.dataset(batch_size)
86
+ batch = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()
87
+
88
+ config = ModelConfig(X_shape=data_reader.X_shape)
89
+ with open(os.path.join(log_dir, "config.log"), "w") as fp:
90
+ fp.write("\n".join("%s: %s" % item for item in vars(config).items()))
91
+
92
+ model = UNet(config=config, input_batch=batch, mode="pred")
93
+ # model = UNet(config=config, mode="pred")
94
+ sess_config = tf.compat.v1.ConfigProto()
95
+ sess_config.gpu_options.allow_growth = True
96
+ # sess_config.log_device_placement = False
97
+
98
+ with tf.compat.v1.Session(config=sess_config) as sess:
99
+ saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables(), max_to_keep=5)
100
+ init = tf.compat.v1.global_variables_initializer()
101
+ sess.run(init)
102
+
103
+ latest_check_point = tf.train.latest_checkpoint(args.model_dir)
104
+ logging.info(f"restoring model {latest_check_point}")
105
+ saver.restore(sess, latest_check_point)
106
+
107
+ picks = []
108
+ amps = [] if args.amplitude else None
109
+ if args.plot_figure:
110
+ multiprocessing.set_start_method("spawn")
111
+ pool = multiprocessing.Pool(multiprocessing.cpu_count())
112
+
113
+ for _ in tqdm(range(0, data_reader.num_data, batch_size), desc="Pred"):
114
+ if args.amplitude:
115
+ pred_batch, X_batch, amp_batch, fname_batch, t0_batch, station_batch = sess.run(
116
+ [model.preds, batch[0], batch[1], batch[2], batch[3], batch[4]],
117
+ feed_dict={model.drop_rate: 0, model.is_training: False},
118
+ )
119
+ # X_batch, amp_batch, fname_batch, t0_batch = sess.run([batch[0], batch[1], batch[2], batch[3]])
120
+ else:
121
+ pred_batch, X_batch, fname_batch, t0_batch, station_batch = sess.run(
122
+ [model.preds, batch[0], batch[1], batch[2], batch[3]],
123
+ feed_dict={model.drop_rate: 0, model.is_training: False},
124
+ )
125
+ # X_batch, fname_batch, t0_batch = sess.run([model.preds, batch[0], batch[1], batch[2]])
126
+ # pred_batch = []
127
+ # for i in range(0, len(X_batch), 1):
128
+ # pred_batch.append(sess.run(model.preds, feed_dict={model.X: X_batch[i:i+1], model.drop_rate: 0, model.is_training: False}))
129
+ # pred_batch = np.vstack(pred_batch)
130
+
131
+ waveforms = None
132
+ if args.amplitude:
133
+ waveforms = amp_batch
134
+
135
+ picks_ = extract_picks(
136
+ preds=pred_batch,
137
+ file_names=fname_batch,
138
+ station_ids=station_batch,
139
+ begin_times=t0_batch,
140
+ config=args,
141
+ waveforms=waveforms,
142
+ use_amplitude=args.amplitude,
143
+ dt=1.0 / args.sampling_rate,
144
+ )
145
+
146
+ picks.extend(picks_)
147
+
148
+ ## save pick per file
149
+ if len(fname_batch) == 1:
150
+ df = pd.DataFrame(picks_)
151
+ df = df[df["phase_index"] > 10]
152
+ if not os.path.exists(os.path.join(args.result_dir, "picks")):
153
+ os.makedirs(os.path.join(args.result_dir, "picks"))
154
+ df = df[
155
+ [
156
+ "station_id",
157
+ "begin_time",
158
+ "phase_index",
159
+ "phase_time",
160
+ "phase_score",
161
+ "phase_type",
162
+ "phase_amplitude",
163
+ "dt",
164
+ ]
165
+ ]
166
+ df.to_csv(
167
+ os.path.join(
168
+ args.result_dir, "picks", fname_batch[0].decode().split("/")[-1].rstrip(".mseed") + ".csv"
169
+ ),
170
+ index=False,
171
+ )
172
+
173
+ if args.plot_figure:
174
+ if not (isinstance(fname_batch, np.ndarray) or isinstance(fname_batch, list)):
175
+ fname_batch = [fname_batch.decode().rstrip(".mseed") + "_" + x.decode() for x in station_batch]
176
+ else:
177
+ fname_batch = [x.decode() for x in fname_batch]
178
+ pool.starmap(
179
+ partial(
180
+ plot_waveform,
181
+ figure_dir=figure_dir,
182
+ ),
183
+ # zip(X_batch, pred_batch, [x.decode() for x in fname_batch]),
184
+ zip(X_batch, pred_batch, fname_batch),
185
+ )
186
+
187
+ if args.save_prob:
188
+ # save_prob(pred_batch, fname_batch, prob_dir=prob_dir)
189
+ if not (isinstance(fname_batch, np.ndarray) or isinstance(fname_batch, list)):
190
+ fname_batch = [fname_batch.decode().rstrip(".mseed") + "_" + x.decode() for x in station_batch]
191
+ else:
192
+ fname_batch = [x.decode() for x in fname_batch]
193
+ save_prob_h5(pred_batch, fname_batch, prob_h5)
194
+
195
+ if len(picks) > 0:
196
+ # save_picks(picks, args.result_dir, amps=amps, fname=args.result_fname+".csv")
197
+ # save_picks_json(picks, args.result_dir, dt=data_reader.dt, amps=amps, fname=args.result_fname+".json")
198
+ df = pd.DataFrame(picks)
199
+ # df["fname"] = df["file_name"]
200
+ # df["id"] = df["station_id"]
201
+ # df["timestamp"] = df["phase_time"]
202
+ # df["prob"] = df["phase_prob"]
203
+ # df["type"] = df["phase_type"]
204
+
205
+ base_columns = [
206
+ "station_id",
207
+ "begin_time",
208
+ "phase_index",
209
+ "phase_time",
210
+ "phase_score",
211
+ "phase_type",
212
+ "file_name",
213
+ ]
214
+ if args.amplitude:
215
+ base_columns.append("phase_amplitude")
216
+ base_columns.append("phase_amp")
217
+ df["phase_amp"] = df["phase_amplitude"]
218
+
219
+ df = df[base_columns]
220
+ df.to_csv(os.path.join(args.result_dir, args.result_fname + ".csv"), index=False)
221
+
222
+ print(
223
+ f"Done with {len(df[df['phase_type'] == 'P'])} P-picks and {len(df[df['phase_type'] == 'S'])} S-picks"
224
+ )
225
+ else:
226
+ print(f"Done with 0 P-picks and 0 S-picks")
227
+ return 0
228
+
229
+
230
+ def main(args):
231
+ logging.basicConfig(format="%(asctime)s %(message)s", level=logging.INFO)
232
+
233
+ with tf.compat.v1.name_scope("create_inputs"):
234
+ if args.format == "mseed_array":
235
+ data_reader = DataReader_mseed_array(
236
+ data_dir=args.data_dir,
237
+ data_list=args.data_list,
238
+ stations=args.stations,
239
+ amplitude=args.amplitude,
240
+ highpass_filter=args.highpass_filter,
241
+ )
242
+ else:
243
+ data_reader = DataReader_pred(
244
+ format=args.format,
245
+ data_dir=args.data_dir,
246
+ data_list=args.data_list,
247
+ hdf5_file=args.hdf5_file,
248
+ hdf5_group=args.hdf5_group,
249
+ amplitude=args.amplitude,
250
+ highpass_filter=args.highpass_filter,
251
+ response_xml=args.response_xml,
252
+ sampling_rate=args.sampling_rate,
253
+ )
254
+
255
+ pred_fn(args, data_reader, log_dir=args.result_dir)
256
+
257
+ return
258
+
259
+
260
+ if __name__ == "__main__":
261
+ args = read_args()
262
+ main(args)
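
For orientation, a hedged example of how this script might be invoked from the command line, using only flags defined in `read_args` above; the data paths are placeholders, while `model/190703-214543` is the pre-trained model directory referenced elsewhere in this commit: `python phasenet/predict.py --model_dir=model/190703-214543 --data_dir=<waveform_dir> --data_list=<list.csv> --result_dir=results --amplitude`.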
phasenet/slide_window.py ADDED
@@ -0,0 +1,88 @@
1
+ import os
2
+ from collections import defaultdict, namedtuple
3
+ from datetime import datetime, timedelta
4
+ from json import dumps
5
+
6
+ import numpy as np
7
+ import tensorflow as tf
8
+
9
+ from model import ModelConfig, UNet
10
+ from postprocess import extract_amplitude, extract_picks
11
+ import pandas as pd
12
+ import obspy
13
+
14
+
15
+ tf.compat.v1.disable_eager_execution()
16
+ tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
17
+ PROJECT_ROOT = os.path.realpath(os.path.join(os.path.dirname(__file__), ".."))
18
+
19
+ # load model
20
+ model = UNet(mode="pred")
21
+ sess_config = tf.compat.v1.ConfigProto()
22
+ sess_config.gpu_options.allow_growth = True
23
+
24
+ sess = tf.compat.v1.Session(config=sess_config)
25
+ saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
26
+ init = tf.compat.v1.global_variables_initializer()
27
+ sess.run(init)
28
+ latest_check_point = tf.train.latest_checkpoint(f"{PROJECT_ROOT}/model/190703-214543")
29
+ print(f"restoring model {latest_check_point}")
30
+ saver.restore(sess, latest_check_point)
31
+
32
+
33
+ def calc_timestamp(timestamp, sec):
34
+ timestamp = datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S.%f") + timedelta(seconds=sec)
35
+ return timestamp.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]
36
+
37
+ def format_picks(picks, dt):
38
+ picks_ = []
39
+ for pick in picks:
40
+ for idxs, probs in zip(pick.p_idx, pick.p_prob):
41
+ for idx, prob in zip(idxs, probs):
42
+ picks_.append(
43
+ {
44
+ "id": pick.fname,
45
+ "timestamp": calc_timestamp(pick.t0, float(idx) * dt),
46
+ "prob": prob,
47
+ "type": "p",
48
+ }
49
+ )
50
+ for idxs, probs in zip(pick.s_idx, pick.s_prob):
51
+ for idx, prob in zip(idxs, probs):
52
+ picks_.append(
53
+ {
54
+ "id": pick.fname,
55
+ "timestamp": calc_timestamp(pick.t0, float(idx) * dt),
56
+ "prob": prob,
57
+ "type": "s",
58
+ }
59
+ )
60
+ return picks_
61
+
62
+
63
+ stream = obspy.read()
64
+ stream = stream.sort() ## Assume sorting yields E, N, Z channel order
65
+ assert(len(stream) == 3)
66
+ data = []
67
+ for trace in stream:
68
+ data.append(trace.data)
69
+ data = np.array(data).T
70
+ assert(data.shape[-1] == 3)
71
+
72
+ # data_id = stream[0].get_id()[:-1]
73
+ # timestamp = stream[0].stats.starttime.datetime.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]
74
+
75
+ data = np.stack([data for i in range(10)]) ## Assume 10 windows
76
+ data = data[:,:,np.newaxis,:] ## batch, nt, dummy_dim, channel
77
+ print(f"{data.shape = }")
78
+ data = (data - data.mean(axis=1, keepdims=True))/data.std(axis=1, keepdims=True)
79
+
80
+ feed = {model.X: data, model.drop_rate: 0, model.is_training: False}
81
+ preds = sess.run(model.preds, feed_dict=feed)
82
+
83
+ picks = extract_picks(preds, fnames=None, station_ids=None, t0=None)
84
+ picks = format_picks(picks, dt=0.01)
85
+
86
+
87
+ picks = pd.DataFrame(picks)
88
+ print(picks)
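
The script above feeds the network a 4-D array of shape (batch, nt, 1, 3): windows, time samples, a dummy spatial axis, and the three E/N/Z channels, each window normalized by its own mean and standard deviation. Below is a minimal sketch of that preprocessing on synthetic data; the array sizes and variable names are illustrative, not part of this commit.

```python
import numpy as np

nt, n_windows = 3000, 10                 # samples per window, number of windows (illustrative)
waveform = np.random.randn(nt, 3)        # stand-in for a three-component trace
data = np.stack([waveform] * n_windows)  # (batch, nt, channel)
data = data[:, :, np.newaxis, :]         # (batch, nt, 1, channel), the layout the model expects
data = (data - data.mean(axis=1, keepdims=True)) / data.std(axis=1, keepdims=True)
print(data.shape)                        # (10, 3000, 1, 3)
```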
phasenet/test_app.py ADDED
@@ -0,0 +1,47 @@
1
+ import requests
2
+ import obspy
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+ from datetime import datetime
6
+
7
+ ### Start running the model first:
8
+ ### uvicorn --app-dir phasenet app:app --port 8000
9
+
10
+ def read_data(mseed):
11
+ data = []
12
+ mseed = mseed.sort()
13
+ for c in ["E", "N", "Z"]:
14
+ data.append(mseed.select(channel="*"+c)[0].data)
15
+ return np.array(data).T
16
+
17
+ timestamp = lambda x: x.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]
18
+
19
+ ## prepare some test data
20
+ mseed = obspy.read()
21
+ data = []
22
+ for i in range(1):
23
+ data.append(read_data(mseed))
24
+ data = {
25
+ "id": ["test01"],
26
+ "timestamp": [timestamp(datetime.now())],
27
+ "vec": np.array(data).tolist(),
28
+ "dt": 0.01
29
+ }
30
+
31
+ ## run prediction
32
+ print(data["id"])
33
+ resp = requests.get("http://localhost:8000/predict", json=data)
34
+ picks = resp.json()["picks"]
35
+ print(resp.json())
36
+
37
+
38
+ ## plot figure
39
+ plt.figure()
40
+ plt.plot(np.array(data["vec"])[0,:,1])
41
+ ylim = plt.ylim()
42
+ plt.plot([picks[0][0][0], picks[0][0][0]], ylim, label="P-phase")
43
+ plt.text(picks[0][0][0], ylim[1]*0.9, f"{picks[0][1][0]:.2f}")
44
+ plt.plot([picks[0][2][0], picks[0][2][0]], ylim, label="S-phase")
45
+ plt.text(picks[0][2][0], ylim[1]*0.9, f"{picks[0][3][0]:.2f}")
46
+ plt.legend()
47
+ plt.savefig("test.png")
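
The request body assembled above follows a simple JSON schema. A hedged sketch of one payload is shown below; the values are placeholders, and only the field names mirror the dict built in this script.

```python
payload = {
    "id": ["test01"],                          # one identifier per waveform
    "timestamp": ["2021-01-01T00:00:00.000"],  # start time of each waveform (placeholder)
    "vec": [[[0.0, 0.0, 0.0]] * 3000],         # (n_waveforms, nt, 3) nested lists of samples
    "dt": 0.01,                                # sampling interval in seconds
}
```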
phasenet/train.py ADDED
@@ -0,0 +1,246 @@
1
+ import numpy as np
2
+ import tensorflow as tf
3
+ tf.compat.v1.disable_eager_execution()
4
+ tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
5
+ import argparse, os, time, logging
6
+ from tqdm import tqdm
7
+ import pandas as pd
8
+ import multiprocessing
9
+ from functools import partial
10
+ import pickle
11
+ from model import UNet, ModelConfig
12
+ from data_reader import DataReader_train, DataReader_test
13
+ from postprocess import extract_picks, save_picks, save_picks_json, extract_amplitude, convert_true_picks, calc_performance
14
+ from visulization import plot_waveform
15
+ from util import EMA, LMA
16
+
17
+ def read_args():
18
+
19
+ parser = argparse.ArgumentParser()
20
+ parser.add_argument("--mode", default="train", help="train/train_valid/test/debug")
21
+ parser.add_argument("--epochs", default=100, type=int, help="number of epochs (default: 10)")
22
+ parser.add_argument("--batch_size", default=20, type=int, help="batch size")
23
+ parser.add_argument("--learning_rate", default=0.01, type=float, help="learning rate")
24
+ parser.add_argument("--drop_rate", default=0.0, type=float, help="dropout rate")
25
+ parser.add_argument("--decay_step", default=-1, type=int, help="decay step")
26
+ parser.add_argument("--decay_rate", default=0.9, type=float, help="decay rate")
27
+ parser.add_argument("--momentum", default=0.9, type=float, help="momentum")
28
+ parser.add_argument("--optimizer", default="adam", help="optimizer: adam, momentum")
29
+ parser.add_argument("--summary", default=True, type=bool, help="summary")
30
+ parser.add_argument("--class_weights", nargs="+", default=[1, 1, 1], type=float, help="class weights")
31
+ parser.add_argument("--model_dir", default=None, help="Checkpoint directory (default: None)")
32
+ parser.add_argument("--load_model", action="store_true", help="Load checkpoint")
33
+ parser.add_argument("--log_dir", default="log", help="Log directory (default: log)")
34
+ parser.add_argument("--num_plots", default=10, type=int, help="Plotting training results")
35
+ parser.add_argument("--min_p_prob", default=0.3, type=float, help="Probability threshold for P pick")
36
+ parser.add_argument("--min_s_prob", default=0.3, type=float, help="Probability threshold for S pick")
37
+ parser.add_argument("--format", default="numpy", help="Input data format")
38
+ parser.add_argument("--train_dir", default="./dataset/waveform_train/", help="Input file directory")
39
+ parser.add_argument("--train_list", default="./dataset/waveform.csv", help="Input csv file")
40
+ parser.add_argument("--valid_dir", default=None, help="Input file directory")
41
+ parser.add_argument("--valid_list", default=None, help="Input csv file")
42
+ parser.add_argument("--test_dir", default=None, help="Input file directory")
43
+ parser.add_argument("--test_list", default=None, help="Input csv file")
44
+ parser.add_argument("--result_dir", default="results", help="result directory")
45
+ parser.add_argument("--plot_figure", action="store_true", help="If plot figure for test")
46
+ parser.add_argument("--save_prob", action="store_true", help="If save result for test")
47
+ args = parser.parse_args()
48
+
49
+ return args
50
+
51
+
52
+ def train_fn(args, data_reader, data_reader_valid=None):
53
+
54
+ current_time = time.strftime("%y%m%d-%H%M%S")
55
+ log_dir = os.path.join(args.log_dir, current_time)
56
+ if not os.path.exists(log_dir):
57
+ os.makedirs(log_dir)
58
+ logging.info("Training log: {}".format(log_dir))
59
+ model_dir = os.path.join(log_dir, 'models')
60
+ os.makedirs(model_dir)
61
+
62
+ figure_dir = os.path.join(log_dir, 'figures')
63
+ if not os.path.exists(figure_dir):
64
+ os.makedirs(figure_dir)
65
+
66
+ config = ModelConfig(X_shape=data_reader.X_shape, Y_shape=data_reader.Y_shape)
67
+ if args.decay_step == -1:
68
+ args.decay_step = data_reader.num_data // args.batch_size
69
+ config.update_args(args)
70
+ with open(os.path.join(log_dir, 'config.log'), 'w') as fp:
71
+ fp.write('\n'.join("%s: %s" % item for item in vars(config).items()))
72
+
73
+ with tf.compat.v1.name_scope('Input_Batch'):
74
+ dataset = data_reader.dataset(args.batch_size, shuffle=True).repeat()
75
+ batch = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()
76
+ if data_reader_valid is not None:
77
+ dataset_valid = data_reader_valid.dataset(args.batch_size, shuffle=False).repeat()
78
+ valid_batch = tf.compat.v1.data.make_one_shot_iterator(dataset_valid).get_next()
79
+
80
+ model = UNet(config, input_batch=batch)
81
+ sess_config = tf.compat.v1.ConfigProto()
82
+ sess_config.gpu_options.allow_growth = True
83
+ # sess_config.log_device_placement = False
84
+
85
+ with tf.compat.v1.Session(config=sess_config) as sess:
86
+
87
+ summary_writer = tf.compat.v1.summary.FileWriter(log_dir, sess.graph)
88
+ saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables(), max_to_keep=5)
89
+ init = tf.compat.v1.global_variables_initializer()
90
+ sess.run(init)
91
+
92
+ if args.model_dir is not None:
93
+ logging.info("restoring models...")
94
+ latest_check_point = tf.train.latest_checkpoint(args.model_dir)
95
+ saver.restore(sess, latest_check_point)
96
+
97
+ if args.plot_figure:
98
+ multiprocessing.set_start_method('spawn')
99
+ pool = multiprocessing.Pool(multiprocessing.cpu_count())
100
+
101
+ flog = open(os.path.join(log_dir, 'loss.log'), 'w')
102
+ train_loss = EMA(0.9)
103
+ best_valid_loss = np.inf
104
+ for epoch in range(args.epochs):
105
+ progressbar = tqdm(range(0, data_reader.num_data, args.batch_size), desc="{}: epoch {}".format(log_dir.split("/")[-1], epoch))
106
+ for _ in progressbar:
107
+ loss_batch, _, _ = sess.run([model.loss, model.train_op, model.global_step],
108
+ feed_dict={model.drop_rate: args.drop_rate, model.is_training: True})
109
+ train_loss(loss_batch)
110
+ progressbar.set_description("{}: epoch {}, loss={:.6f}, mean={:.6f}".format(log_dir.split("/")[-1], epoch, loss_batch, train_loss.value))
111
+ flog.write("epoch: {}, mean loss: {}\n".format(epoch, train_loss.value))
112
+
113
+ if data_reader_valid is not None:
114
+ valid_loss = LMA()
115
+ progressbar = tqdm(range(0, data_reader_valid.num_data, args.batch_size), desc="Valid:")
116
+ for _ in progressbar:
117
+ loss_batch, preds_batch, X_batch, Y_batch, fname_batch = sess.run([model.loss, model.preds, valid_batch[0], valid_batch[1], valid_batch[2]],
118
+ feed_dict={model.drop_rate: 0, model.is_training: False})
119
+ valid_loss(loss_batch)
120
+ progressbar.set_description("valid, loss={:.6f}, mean={:.6f}".format(loss_batch, valid_loss.value))
121
+ if valid_loss.value < best_valid_loss:
122
+ best_valid_loss = valid_loss.value
123
+ saver.save(sess, os.path.join(model_dir, "model_{}.ckpt".format(epoch)))
124
+ flog.write("Valid: mean loss: {}\n".format(valid_loss.value))
125
+ else:
126
+ loss_batch, preds_batch, X_batch, Y_batch, fname_batch = sess.run([model.loss, model.preds, batch[0], batch[1], batch[2]],
127
+ feed_dict={model.drop_rate: 0, model.is_training: False})
128
+ saver.save(sess, os.path.join(model_dir, "model_{}.ckpt".format(epoch)))
129
+
130
+ if args.plot_figure:
131
+ pool.starmap(
132
+ partial(
133
+ plot_waveform,
134
+ figure_dir=figure_dir,
135
+ ),
136
+ zip(X_batch, preds_batch, [x.decode() for x in fname_batch], Y_batch),
137
+ )
138
+ # plot_waveform(X_batch, preds_batch, fname_batch, label=Y_batch, figure_dir=figure_dir)
139
+ flog.flush()
140
+
141
+ flog.close()
142
+
143
+ return 0
144
+
145
+ def test_fn(args, data_reader):
146
+ current_time = time.strftime("%y%m%d-%H%M%S")
147
+ logging.info("{} log: {}".format(args.mode, current_time))
148
+ if args.model_dir is None:
149
+ logging.error(f"model_dir = None!")
150
+ return -1
151
+ if not os.path.exists(args.result_dir):
152
+ os.makedirs(args.result_dir)
153
+ figure_dir=os.path.join(args.result_dir, "figures")
154
+ if not os.path.exists(figure_dir):
155
+ os.makedirs(figure_dir)
156
+
157
+ config = ModelConfig(X_shape=data_reader.X_shape, Y_shape=data_reader.Y_shape)
158
+ config.update_args(args)
159
+ with open(os.path.join(args.result_dir, 'config.log'), 'w') as fp:
160
+ fp.write('\n'.join("%s: %s" % item for item in vars(config).items()))
161
+
162
+ with tf.compat.v1.name_scope('Input_Batch'):
163
+ dataset = data_reader.dataset(args.batch_size, shuffle=False)
164
+ batch = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()
165
+
166
+ model = UNet(config, input_batch=batch, mode='test')
167
+ sess_config = tf.compat.v1.ConfigProto()
168
+ sess_config.gpu_options.allow_growth = True
169
+ # sess_config.log_device_placement = False
170
+
171
+ with tf.compat.v1.Session(config=sess_config) as sess:
172
+
173
+ saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
174
+ init = tf.compat.v1.global_variables_initializer()
175
+ sess.run(init)
176
+
177
+ logging.info("restoring models...")
178
+ latest_check_point = tf.train.latest_checkpoint(args.model_dir)
179
+ if latest_check_point is None:
180
+ logging.error(f"No models found in model_dir: {args.model_dir}")
181
+ return -1
182
+ saver.restore(sess, latest_check_point)
183
+
184
+ flog = open(os.path.join(args.result_dir, 'loss.log'), 'w')
185
+ test_loss = LMA()
186
+ progressbar = tqdm(range(0, data_reader.num_data, args.batch_size), desc=args.mode)
187
+ picks = []
188
+ true_picks = []
189
+ for _ in progressbar:
190
+ loss_batch, preds_batch, X_batch, Y_batch, fname_batch, itp_batch, its_batch \
191
+ = sess.run([model.loss, model.preds, batch[0], batch[1], batch[2], batch[3], batch[4]],
192
+ feed_dict={model.drop_rate: 0, model.is_training: False})
193
+
194
+ test_loss(loss_batch)
195
+ progressbar.set_description("{}, loss={:.6f}, mean loss={:6f}".format(args.mode, loss_batch, test_loss.value))
196
+
197
+ picks_ = extract_picks(preds_batch, fname_batch)
198
+ picks.extend(picks_)
199
+ true_picks.extend(convert_true_picks(fname_batch, itp_batch, its_batch))
200
+ if args.plot_figure:
201
+ plot_waveform(data_reader.config, X_batch, preds_batch, label=Y_batch, fname=fname_batch,
202
+ itp=itp_batch, its=its_batch, figure_dir=figure_dir)
203
+
204
+ save_picks(picks, args.result_dir)
205
+ metrics = calc_performance(picks, true_picks, tol=3.0, dt=data_reader.config.dt)
206
+ flog.write("mean loss: {}\n".format(test_loss.value))
207
+ flog.close()
208
+
209
+ return 0
210
+
211
+ def main(args):
212
+
213
+ logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)
214
+ coord = tf.train.Coordinator()
215
+
216
+ if (args.mode == "train") or (args.mode == "train_valid"):
217
+ with tf.compat.v1.name_scope('create_inputs'):
218
+ data_reader = DataReader_train(format=args.format,
219
+ data_dir=args.train_dir,
220
+ data_list=args.train_list)
221
+ if args.mode == "train_valid":
222
+ data_reader_valid = DataReader_train(format=args.format,
223
+ data_dir=args.valid_dir,
224
+ data_list=args.valid_list)
225
+ logging.info("Dataset size: train {}, valid {}".format(data_reader.num_data, data_reader_valid.num_data))
226
+ else:
227
+ data_reader_valid = None
228
+ logging.info("Dataset size: train {}".format(data_reader.num_data))
229
+ train_fn(args, data_reader, data_reader_valid)
230
+
231
+ elif args.mode == "test":
232
+ with tf.compat.v1.name_scope('create_inputs'):
233
+ data_reader = DataReader_test(format=args.format,
234
+ data_dir=args.test_dir,
235
+ data_list=args.test_list)
236
+ test_fn(args, data_reader)
237
+
238
+ else:
239
+ print("mode should be: train, train_valid, or test")
240
+
241
+ return
242
+
243
+
244
+ if __name__ == '__main__':
245
+ args = read_args()
246
+ main(args)
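
For orientation, a hedged example of how this training script might be launched, using only the flags and default paths defined in `read_args` above: `python phasenet/train.py --mode=train --train_dir=./dataset/waveform_train/ --train_list=./dataset/waveform.csv --batch_size=20 --epochs=100 --plot_figure`.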
phasenet/util.py ADDED
@@ -0,0 +1,238 @@
1
+ from __future__ import division
2
+ import matplotlib
3
+ matplotlib.use('agg')
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ import os
7
+ from data_reader import DataConfig
8
+ from detect_peaks import detect_peaks
9
+ import logging
10
+
11
+ class EMA(object):
12
+ def __init__(self, alpha):
13
+ self.alpha = alpha
14
+ self.x = 0.
15
+ self.count = 0
16
+
17
+ @property
18
+ def value(self):
19
+ return self.x
20
+
21
+ def __call__(self, x):
22
+ if self.count == 0:
23
+ self.x = x
24
+ else:
25
+ self.x = self.alpha * self.x + (1 - self.alpha) * x
26
+ self.count += 1
27
+ return self.x
28
+
29
+ class LMA(object):
30
+ def __init__(self):
31
+ self.x = 0.
32
+ self.count = 0
33
+
34
+ @property
35
+ def value(self):
36
+ return self.x
37
+
38
+ def __call__(self, x):
39
+ if self.count == 0:
40
+ self.x = x
41
+ else:
42
+ self.x += (x - self.x)/(self.count+1)
43
+ self.count += 1
44
+ return self.x
45
+
46
+ def detect_peaks_thread(i, pred, fname=None, result_dir=None, args=None):
47
+ if args is None:
48
+ itp, prob_p = detect_peaks(pred[i,:,0,1], mph=0.5, mpd=0.5/DataConfig().dt, show=False)
49
+ its, prob_s = detect_peaks(pred[i,:,0,2], mph=0.5, mpd=0.5/DataConfig().dt, show=False)
50
+ else:
51
+ itp, prob_p = detect_peaks(pred[i,:,0,1], mph=args.tp_prob, mpd=0.5/DataConfig().dt, show=False)
52
+ its, prob_s = detect_peaks(pred[i,:,0,2], mph=args.ts_prob, mpd=0.5/DataConfig().dt, show=False)
53
+ if (fname is not None) and (result_dir is not None):
54
+ # np.savez(os.path.join(result_dir, fname[i].decode().split('/')[-1]), pred=pred[i], itp=itp, its=its, prob_p=prob_p, prob_s=prob_s)
55
+ try:
56
+ np.savez(os.path.join(result_dir, fname[i].decode()), pred=pred[i], itp=itp, its=its, prob_p=prob_p, prob_s=prob_s)
57
+ except FileNotFoundError:
58
+ #if not os.path.exists(os.path.dirname(os.path.join(result_dir, fname[i].decode()))):
59
+ os.makedirs(os.path.dirname(os.path.join(result_dir, fname[i].decode())), exist_ok=True)
60
+ np.savez(os.path.join(result_dir, fname[i].decode()), pred=pred[i], itp=itp, its=its, prob_p=prob_p, prob_s=prob_s)
61
+ return [(itp, prob_p), (its, prob_s)]
62
+
63
+ def plot_result_thread(i, pred, X, Y=None, itp=None, its=None,
64
+ itp_pred=None, its_pred=None, fname=None, figure_dir=None):
65
+ dt = DataConfig().dt
66
+ t = np.arange(0, pred.shape[1]) * dt
67
+ box = dict(boxstyle='round', facecolor='white', alpha=1)
68
+ text_loc = [0.05, 0.77]
69
+
70
+ plt.figure(i)
71
+ plt.clf()
72
+ # fig_size = plt.gcf().get_size_inches()
73
+ # plt.gcf().set_size_inches(fig_size*[1, 1.2])
74
+ plt.subplot(411)
75
+ plt.plot(t, X[i, :, 0, 0], 'k', label='E', linewidth=0.5)
76
+ plt.autoscale(enable=True, axis='x', tight=True)
77
+ tmp_min = np.min(X[i, :, 0, 0])
78
+ tmp_max = np.max(X[i, :, 0, 0])
79
+ if (itp is not None) and (its is not None):
80
+ for j in range(len(itp[i])):
81
+ if j == 0:
82
+ plt.plot([itp[i][j]*dt, itp[i][j]*dt], [tmp_min, tmp_max], 'b', label='P', linewidth=0.5)
83
+ else:
84
+ plt.plot([itp[i][j]*dt, itp[i][j]*dt], [tmp_min, tmp_max], 'b', linewidth=0.5)
85
+ for j in range(len(its[i])):
86
+ if j == 0:
87
+ plt.plot([its[i][j]*dt, its[i][j]*dt], [tmp_min, tmp_max], 'r', label='S', linewidth=0.5)
88
+ else:
89
+ plt.plot([its[i][j]*dt, its[i][j]*dt], [tmp_min, tmp_max], 'r', linewidth=0.5)
90
+ plt.ylabel('Amplitude')
91
+ plt.legend(loc='upper right', fontsize='small')
92
+ plt.gca().set_xticklabels([])
93
+ plt.text(text_loc[0], text_loc[1], '(i)', horizontalalignment='center',
94
+ transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
95
+ plt.subplot(412)
96
+ plt.plot(t, X[i, :, 0, 1], 'k', label='N', linewidth=0.5)
97
+ plt.autoscale(enable=True, axis='x', tight=True)
98
+ tmp_min = np.min(X[i, :, 0, 1])
99
+ tmp_max = np.max(X[i, :, 0, 1])
100
+ if (itp is not None) and (its is not None):
101
+ for j in range(len(itp[i])):
102
+ plt.plot([itp[i][j]*dt, itp[i][j]*dt], [tmp_min, tmp_max], 'b', linewidth=0.5)
103
+ for j in range(len(its[i])):
104
+ plt.plot([its[i][j]*dt, its[i][j]*dt], [tmp_min, tmp_max], 'r', linewidth=0.5)
105
+ plt.ylabel('Amplitude')
106
+ plt.legend(loc='upper right', fontsize='small')
107
+ plt.gca().set_xticklabels([])
108
+ plt.text(text_loc[0], text_loc[1], '(ii)', horizontalalignment='center',
109
+ transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
110
+ plt.subplot(413)
111
+ plt.plot(t, X[i, :, 0, 2], 'k', label='Z', linewidth=0.5)
112
+ plt.autoscale(enable=True, axis='x', tight=True)
113
+ tmp_min = np.min(X[i, :, 0, 2])
114
+ tmp_max = np.max(X[i, :, 0, 2])
115
+ if (itp is not None) and (its is not None):
116
+ for j in range(len(itp[i])):
117
+ plt.plot([itp[i][j]*dt, itp[i][j]*dt], [tmp_min, tmp_max], 'b', linewidth=0.5)
118
+ for j in range(len(its[i])):
119
+ plt.plot([its[i][j]*dt, its[i][j]*dt], [tmp_min, tmp_max], 'r', linewidth=0.5)
120
+ plt.ylabel('Amplitude')
121
+ plt.legend(loc='upper right', fontsize='small')
122
+ plt.gca().set_xticklabels([])
123
+ plt.text(text_loc[0], text_loc[1], '(iii)', horizontalalignment='center',
124
+ transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
125
+ plt.subplot(414)
126
+ if Y is not None:
127
+ plt.plot(t, Y[i, :, 0, 1], 'b', label='P', linewidth=0.5)
128
+ plt.plot(t, Y[i, :, 0, 2], 'r', label='S', linewidth=0.5)
129
+ plt.plot(t, pred[i, :, 0, 1], '--g', label='$\hat{P}$', linewidth=0.5)
130
+ plt.plot(t, pred[i, :, 0, 2], '-.m', label='$\hat{S}$', linewidth=0.5)
131
+ plt.autoscale(enable=True, axis='x', tight=True)
132
+ if (itp_pred is not None) and (its_pred is not None):
133
+ for j in range(len(itp_pred)):
134
+ plt.plot([itp_pred[j]*dt, itp_pred[j]*dt], [-0.1, 1.1], '--g', linewidth=0.5)
135
+ for j in range(len(its_pred)):
136
+ plt.plot([its_pred[j]*dt, its_pred[j]*dt], [-0.1, 1.1], '-.m', linewidth=0.5)
137
+ plt.ylim([-0.05, 1.05])
138
+ plt.text(text_loc[0], text_loc[1], '(iv)', horizontalalignment='center',
139
+ transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
140
+ plt.legend(loc='upper right', fontsize='small')
141
+ plt.xlabel('Time (s)')
142
+ plt.ylabel('Probability')
143
+
144
+ plt.tight_layout()
145
+ plt.gcf().align_labels()
146
+
147
+ try:
148
+ plt.savefig(os.path.join(figure_dir,
149
+ fname[i].decode().rstrip('.npz')+'.png'),
150
+ bbox_inches='tight')
151
+ except FileNotFoundError:
152
+ #if not os.path.exists(os.path.dirname(os.path.join(figure_dir, fname[i].decode()))):
153
+ os.makedirs(os.path.dirname(os.path.join(figure_dir, fname[i].decode())), exist_ok=True)
154
+ plt.savefig(os.path.join(figure_dir,
155
+ fname[i].decode().rstrip('.npz')+'.png'),
156
+ bbox_inches='tight')
157
+ #plt.savefig(os.path.join(figure_dir,
158
+ # fname[i].decode().split('/')[-1].rstrip('.npz')+'.png'),
159
+ # bbox_inches='tight')
160
+ # plt.savefig(os.path.join(figure_dir,
161
+ # fname[i].decode().split('/')[-1].rstrip('.npz')+'.pdf'),
162
+ # bbox_inches='tight')
163
+ plt.close(i)
164
+ return 0
165
+
166
+ def postprocessing_thread(i, pred, X, Y=None, itp=None, its=None, fname=None, result_dir=None, figure_dir=None, args=None):
167
+ (itp_pred, prob_p), (its_pred, prob_s) = detect_peaks_thread(i, pred, fname, result_dir, args)
168
+ if (fname is not None) and (figure_dir is not None):
169
+ plot_result_thread(i, pred, X, Y, itp, its, itp_pred, its_pred, fname, figure_dir)
170
+ return [(itp_pred, prob_p), (its_pred, prob_s)]
171
+
172
+
173
+ def clean_queue(picks):
174
+ clean = []
175
+ for i in range(len(picks)):
176
+ tmp = []
177
+ for j in picks[i]:
178
+ if j != 0:
179
+ tmp.append(j)
180
+ clean.append(tmp)
181
+ return clean
182
+
183
+ def clean_queue_thread(picks):
184
+ tmp = []
185
+ for j in picks:
186
+ if j != 0:
187
+ tmp.append(j)
188
+ return tmp
189
+
190
+
191
+ def metrics(TP, nP, nT):
192
+ '''
193
+ TP: true positive
194
+ nP: number of positive picks
195
+ nT: number of true picks
196
+ '''
197
+ precision = TP / nP
198
+ recall = TP / nT
199
+ F1 = 2* precision * recall / (precision + recall)
200
+ return [precision, recall, F1]
201
+
202
+ def correct_picks(picks, true_p, true_s, tol):
203
+ dt = DataConfig().dt
204
+ if len(true_p) != len(true_s):
205
+ print("The length of true P and S pickers are not the same")
206
+ num = len(true_p)
207
+ TP_p = 0; TP_s = 0; nP_p = 0; nP_s = 0; nT_p = 0; nT_s = 0
208
+ diff_p = []; diff_s = []
209
+ for i in range(num):
210
+ nT_p += len(true_p[i])
211
+ nT_s += len(true_s[i])
212
+ nP_p += len(picks[i][0][0])
213
+ nP_s += len(picks[i][1][0])
214
+
215
+ if len(true_p[i]) > 1 or len(true_s[i]) > 1:
216
+ print(i, picks[i], true_p[i], true_s[i])
217
+ tmp_p = np.array(picks[i][0][0]) - np.array(true_p[i])[:,np.newaxis]
218
+ tmp_s = np.array(picks[i][1][0]) - np.array(true_s[i])[:,np.newaxis]
219
+ TP_p += np.sum(np.abs(tmp_p) < tol/dt)
220
+ TP_s += np.sum(np.abs(tmp_s) < tol/dt)
221
+ diff_p.append(tmp_p[np.abs(tmp_p) < 0.5/dt])
222
+ diff_s.append(tmp_s[np.abs(tmp_s) < 0.5/dt])
223
+
224
+ return [TP_p, TP_s, nP_p, nP_s, nT_p, nT_s, diff_p, diff_s]
225
+
226
+ def calculate_metrics(picks, itp, its, tol=0.1):
227
+ TP_p, TP_s, nP_p, nP_s, nT_p, nT_s, diff_p, diff_s = correct_picks(picks, itp, its, tol)
228
+ precision_p, recall_p, f1_p = metrics(TP_p, nP_p, nT_p)
229
+ precision_s, recall_s, f1_s = metrics(TP_s, nP_s, nT_s)
230
+
231
+ logging.info("Total records: {}".format(len(picks)))
232
+ logging.info("P-phase:")
233
+ logging.info("True={}, Predict={}, TruePositive={}".format(nT_p, nP_p, TP_p))
234
+ logging.info("Precision={:.3f}, Recall={:.3f}, F1={:.3f}".format(precision_p, recall_p, f1_p))
235
+ logging.info("S-phase:")
236
+ logging.info("True={}, Predict={}, TruePositive={}".format(nT_s, nP_s, TP_s))
237
+ logging.info("Precision={:.3f}, Recall={:.3f}, F1={:.3f}".format(precision_s, recall_s, f1_s))
238
+ return [precision_p, recall_p, f1_p], [precision_s, recall_s, f1_s]
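
Note that `correct_picks` compares pick positions in samples, so the tolerance in seconds is converted via `tol/dt`. A quick sanity check of that conversion, assuming the default tolerance above and the usual 0.01 s sampling interval:

```python
tol, dt = 0.1, 0.01    # tolerance (s) and assumed sampling interval (s)
max_offset = tol / dt  # 10 samples: |predicted_index - true_index| < 10 counts as a true positive
```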
phasenet/visulization.py ADDED
@@ -0,0 +1,481 @@
1
+ import matplotlib
2
+ matplotlib.use("agg")
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
+ import os
6
+
7
+
8
+ def plot_residual(diff_p, diff_s, diff_ps, tol, dt):
9
+ box = dict(boxstyle='round', facecolor='white', alpha=1)
10
+ text_loc = [0.07, 0.95]
11
+ plt.figure(figsize=(8,3))
12
+ plt.subplot(1,3,1)
13
+ plt.hist(diff_p, range=(-tol, tol), bins=int(2*tol/dt)+1, facecolor='b', edgecolor='black', linewidth=1)
14
+ plt.ylabel("Number of picks")
15
+ plt.xlabel("Residual (s)")
16
+ plt.text(text_loc[0], text_loc[1], "(i)", horizontalalignment='left', verticalalignment='top',
17
+ transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
18
+ plt.title("P-phase")
19
+ plt.subplot(1,3,2)
20
+ plt.hist(diff_s, range=(-tol, tol), bins=int(2*tol/dt)+1, facecolor='b', edgecolor='black', linewidth=1)
21
+ plt.xlabel("Residual (s)")
22
+ plt.text(text_loc[0], text_loc[1], "(ii)", horizontalalignment='left', verticalalignment='top',
23
+ transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
24
+ plt.title("S-phase")
25
+ plt.subplot(1,3,3)
26
+ plt.hist(diff_ps, range=(-tol, tol), bins=int(2*tol/dt)+1, facecolor='b', edgecolor='black', linewidth=1)
27
+ plt.xlabel("Residual (s)")
28
+ plt.text(text_loc[0], text_loc[1], "(iii)", horizontalalignment='left', verticalalignment='top',
29
+ transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
30
+ plt.title("PS-phase")
31
+ plt.tight_layout()
32
+ plt.savefig("residuals.png", dpi=300)
33
+ plt.savefig("residuals.pdf")
34
+
35
+
36
+ # def plot_waveform(config, data, pred, label=None,
37
+ # itp=None, its=None, itps=None,
38
+ # itp_pred=None, its_pred=None, itps_pred=None,
39
+ # fname=None, figure_dir="./", epoch=0, max_fig=10):
40
+
41
+ # dt = config.dt if hasattr(config, "dt") else 1.0
42
+ # t = np.arange(0, pred.shape[1]) * dt
43
+ # box = dict(boxstyle='round', facecolor='white', alpha=1)
44
+ # text_loc = [0.05, 0.77]
45
+ # if fname is None:
46
+ # fname = [f"{epoch:03d}_{i:02d}" for i in range(len(data))]
47
+ # else:
48
+ # fname = [fname[i].decode().rstrip(".npz") for i in range(len(fname))]
49
+
50
+ # for i in range(min(len(data), max_fig)):
51
+ # plt.figure(i)
52
+
53
+ # plt.subplot(411)
54
+ # plt.plot(t, data[i, :, 0, 0], 'k', label='E', linewidth=0.5)
55
+ # plt.autoscale(enable=True, axis='x', tight=True)
56
+ # tmp_min = np.min(data[i, :, 0, 0])
57
+ # tmp_max = np.max(data[i, :, 0, 0])
58
+ # if (itp is not None) and (its is not None):
59
+ # for j in range(len(itp[i])):
60
+ # lb = "P" if j==0 else ""
61
+ # plt.plot([itp[i][j]*dt, itp[i][j]*dt], [tmp_min, tmp_max], 'C0', label=lb, linewidth=0.5)
62
+ # for j in range(len(its[i])):
63
+ # lb = "S" if j==0 else ""
64
+ # plt.plot([its[i][j]*dt, its[i][j]*dt], [tmp_min, tmp_max], 'C1', label=lb, linewidth=0.5)
65
+ # if (itps is not None):
66
+ # for j in range(len(itps[i])):
67
+ # lb = "PS" if j==0 else ""
68
+ # plt.plot([itps[i][j]*dt, its[i][j]*dt], [tmp_min, tmp_max], 'C2', label=lb, linewidth=0.5)
69
+ # plt.ylabel('Amplitude')
70
+ # plt.legend(loc='upper right', fontsize='small')
71
+ # plt.gca().set_xticklabels([])
72
+ # plt.text(text_loc[0], text_loc[1], '(i)', horizontalalignment='center',
73
+ # transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
74
+
75
+ # plt.subplot(412)
76
+ # plt.plot(t, data[i, :, 0, 1], 'k', label='N', linewidth=0.5)
77
+ # plt.autoscale(enable=True, axis='x', tight=True)
78
+ # tmp_min = np.min(data[i, :, 0, 1])
79
+ # tmp_max = np.max(data[i, :, 0, 1])
80
+ # if (itp is not None) and (its is not None):
81
+ # for j in range(len(itp[i])):
82
+ # lb = "P" if j==0 else ""
83
+ # plt.plot([itp[i][j]*dt, itp[i][j]*dt], [tmp_min, tmp_max], 'C0', label=lb, linewidth=0.5)
84
+ # for j in range(len(its[i])):
85
+ # lb = "S" if j==0 else ""
86
+ # plt.plot([its[i][j]*dt, its[i][j]*dt], [tmp_min, tmp_max], 'C1', label=lb, linewidth=0.5)
87
+ # if (itps is not None):
88
+ # for j in range(len(itps[i])):
89
+ # lb = "PS" if j==0 else ""
90
+ # plt.plot([itps[i][j]*dt, itps[i][j]*dt], [tmp_min, tmp_max], 'C2', label=lb, linewidth=0.5)
91
+ # plt.ylabel('Amplitude')
92
+ # plt.legend(loc='upper right', fontsize='small')
93
+ # plt.gca().set_xticklabels([])
94
+ # plt.text(text_loc[0], text_loc[1], '(ii)', horizontalalignment='center',
95
+ # transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
96
+
97
+ # plt.subplot(413)
98
+ # plt.plot(t, data[i, :, 0, 2], 'k', label='Z', linewidth=0.5)
99
+ # plt.autoscale(enable=True, axis='x', tight=True)
100
+ # tmp_min = np.min(data[i, :, 0, 2])
101
+ # tmp_max = np.max(data[i, :, 0, 2])
102
+ # if (itp is not None) and (its is not None):
103
+ # for j in range(len(itp[i])):
104
+ # lb = "P" if j==0 else ""
105
+ # plt.plot([itp[i][j]*dt, itp[i][j]*dt], [tmp_min, tmp_max], 'C0', label=lb, linewidth=0.5)
106
+ # for j in range(len(its[i])):
107
+ # lb = "S" if j==0 else ""
108
+ # plt.plot([its[i][j]*dt, its[i][j]*dt], [tmp_min, tmp_max], 'C1', label=lb, linewidth=0.5)
109
+ # if (itps is not None):
110
+ # for j in range(len(itps[i])):
111
+ # lb = "PS" if j==0 else ""
112
+ # plt.plot([itps[i][j]*dt, itps[i][j]*dt], [tmp_min, tmp_max], 'C2', label=lb, linewidth=0.5)
113
+ # plt.ylabel('Amplitude')
114
+ # plt.legend(loc='upper right', fontsize='small')
115
+ # plt.gca().set_xticklabels([])
116
+ # plt.text(text_loc[0], text_loc[1], '(iii)', horizontalalignment='center',
117
+ # transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
118
+
119
+ # plt.subplot(414)
120
+ # if label is not None:
121
+ # plt.plot(t, label[i, :, 0, 1], 'C0', label='P', linewidth=1)
122
+ # plt.plot(t, label[i, :, 0, 2], 'C1', label='S', linewidth=1)
123
+ # if label.shape[-1] == 4:
124
+ # plt.plot(t, label[i, :, 0, 3], 'C2', label='PS', linewidth=1)
125
+ # plt.plot(t, pred[i, :, 0, 1], '--C0', label='$\hat{P}$', linewidth=1)
126
+ # plt.plot(t, pred[i, :, 0, 2], '--C1', label='$\hat{S}$', linewidth=1)
127
+ # if pred.shape[-1] == 4:
128
+ # plt.plot(t, pred[i, :, 0, 3], '--C2', label='$\hat{PS}$', linewidth=1)
129
+ # plt.autoscale(enable=True, axis='x', tight=True)
130
+ # if (itp_pred is not None) and (its_pred is not None) :
131
+ # for j in range(len(itp_pred)):
132
+ # plt.plot([itp_pred[j]*dt, itp_pred[j]*dt], [-0.1, 1.1], '--C0', linewidth=1)
133
+ # for j in range(len(its_pred)):
134
+ # plt.plot([its_pred[j]*dt, its_pred[j]*dt], [-0.1, 1.1], '--C1', linewidth=1)
135
+ # if (itps_pred is not None):
136
+ # for j in range(len(itps_pred)):
137
+ # plt.plot([itps_pred[j]*dt, itps_pred[j]*dt], [-0.1, 1.1], '--C2', linewidth=1)
138
+ # plt.ylim([-0.05, 1.05])
139
+ # plt.text(text_loc[0], text_loc[1], '(iv)', horizontalalignment='center',
140
+ # transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
141
+ # plt.legend(loc='upper right', fontsize='small', ncol=2)
142
+ # plt.xlabel('Time (s)')
143
+ # plt.ylabel('Probability')
144
+ # plt.tight_layout()
145
+ # plt.gcf().align_labels()
146
+
147
+ # try:
148
+ # plt.savefig(os.path.join(figure_dir, fname[i]+'.png'), bbox_inches='tight')
149
+ # except FileNotFoundError:
150
+ # os.makedirs(os.path.dirname(os.path.join(figure_dir, fname[i])), exist_ok=True)
151
+ # plt.savefig(os.path.join(figure_dir, fname[i]+'.png'), bbox_inches='tight')
152
+
153
+ # plt.close(i)
154
+ # return 0
155
+
156
+
157
+ def plot_waveform(data, pred, fname, label=None,
158
+ itp=None, its=None, itps=None,
159
+ itp_pred=None, its_pred=None, itps_pred=None,
160
+ figure_dir="./", dt=0.01):
161
+
162
+ t = np.arange(0, pred.shape[0]) * dt
163
+ box = dict(boxstyle='round', facecolor='white', alpha=1)
164
+ text_loc = [0.05, 0.77]
165
+
166
+ plt.figure()
167
+
168
+ plt.subplot(411)
169
+ plt.plot(t, data[:, 0, 0], 'k', label='E', linewidth=0.5)
170
+ plt.autoscale(enable=True, axis='x', tight=True)
171
+ tmp_min = np.min(data[:, 0, 0])
172
+ tmp_max = np.max(data[:, 0, 0])
173
+ if (itp is not None) and (its is not None):
174
+ for j in range(len(itp)):
175
+ lb = "P" if j==0 else ""
176
+ plt.plot([itp[j]*dt, itp[j]*dt], [tmp_min, tmp_max], 'C0', label=lb, linewidth=0.5)
177
+ for j in range(len(its)):
178
+ lb = "S" if j==0 else ""
179
+ plt.plot([its[j]*dt, its[j]*dt], [tmp_min, tmp_max], 'C1', label=lb, linewidth=0.5)
180
+ if (itps is not None):
181
+ for j in range(len(itps)):
182
+ plt.plot([itps[j]*dt, itps[j]*dt], [tmp_min, tmp_max], 'C2', label=lb, linewidth=0.5)
183
+ plt.plot([itps[j]*dt, its[j]*dt], [tmp_min, tmp_max], 'C2', label=lb, linewidth=0.5)
184
+ plt.ylabel('Amplitude')
185
+ plt.legend(loc='upper right', fontsize='small')
186
+ plt.gca().set_xticklabels([])
187
+ plt.text(text_loc[0], text_loc[1], '(i)', horizontalalignment='center',
188
+ transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
189
+
190
+ plt.subplot(412)
191
+ plt.plot(t, data[:, 0, 1], 'k', label='N', linewidth=0.5)
192
+ plt.autoscale(enable=True, axis='x', tight=True)
193
+ tmp_min = np.min(data[:, 0, 1])
194
+ tmp_max = np.max(data[:, 0, 1])
195
+ if (itp is not None) and (its is not None):
196
+ for j in range(len(itp)):
197
+ lb = "P" if j==0 else ""
198
+ plt.plot([itp[j]*dt, itp[j]*dt], [tmp_min, tmp_max], 'C0', label=lb, linewidth=0.5)
199
+ for j in range(len(its)):
200
+ lb = "S" if j==0 else ""
201
+ plt.plot([its[j]*dt, its[j]*dt], [tmp_min, tmp_max], 'C1', label=lb, linewidth=0.5)
202
+ if (itps is not None):
203
+ for j in range(len(itps)):
204
+ lb = "PS" if j==0 else ""
205
+ plt.plot([itps[j]*dt, itps[j]*dt], [tmp_min, tmp_max], 'C2', label=lb, linewidth=0.5)
206
+ plt.ylabel('Amplitude')
207
+ plt.legend(loc='upper right', fontsize='small')
208
+ plt.gca().set_xticklabels([])
209
+ plt.text(text_loc[0], text_loc[1], '(ii)', horizontalalignment='center',
210
+ transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
211
+
212
+ plt.subplot(413)
213
+ plt.plot(t, data[:, 0, 2], 'k', label='Z', linewidth=0.5)
214
+ plt.autoscale(enable=True, axis='x', tight=True)
215
+ tmp_min = np.min(data[:, 0, 2])
216
+ tmp_max = np.max(data[:, 0, 2])
217
+ if (itp is not None) and (its is not None):
218
+ for j in range(len(itp)):
219
+ lb = "P" if j==0 else ""
220
+ plt.plot([itp[j]*dt, itp[j]*dt], [tmp_min, tmp_max], 'C0', label=lb, linewidth=0.5)
221
+ for j in range(len(its)):
222
+ lb = "S" if j==0 else ""
223
+ plt.plot([its[j]*dt, its[j]*dt], [tmp_min, tmp_max], 'C1', label=lb, linewidth=0.5)
224
+ if (itps is not None):
225
+ for j in range(len(itps)):
226
+ lb = "PS" if j==0 else ""
227
+ plt.plot([itps[j]*dt, itps[j]*dt], [tmp_min, tmp_max], 'C2', label=lb, linewidth=0.5)
228
+ plt.ylabel('Amplitude')
229
+ plt.legend(loc='upper right', fontsize='small')
230
+ plt.gca().set_xticklabels([])
231
+ plt.text(text_loc[0], text_loc[1], '(iii)', horizontalalignment='center',
232
+ transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
233
+
234
+ plt.subplot(414)
235
+ if label is not None:
236
+ plt.plot(t, label[:, 0, 1], 'C0', label='P', linewidth=1)
237
+ plt.plot(t, label[:, 0, 2], 'C1', label='S', linewidth=1)
238
+ if label.shape[-1] == 4:
239
+ plt.plot(t, label[:, 0, 3], 'C2', label='PS', linewidth=1)
240
+ plt.plot(t, pred[:, 0, 1], '--C0', label='$\hat{P}$', linewidth=1)
241
+ plt.plot(t, pred[:, 0, 2], '--C1', label='$\hat{S}$', linewidth=1)
242
+ if pred.shape[-1] == 4:
243
+ plt.plot(t, pred[:, 0, 3], '--C2', label='$\hat{PS}$', linewidth=1)
244
+ plt.autoscale(enable=True, axis='x', tight=True)
245
+ if (itp_pred is not None) and (its_pred is not None) :
246
+ for j in range(len(itp_pred)):
247
+ plt.plot([itp_pred[j]*dt, itp_pred[j]*dt], [-0.1, 1.1], '--C0', linewidth=1)
248
+ for j in range(len(its_pred)):
249
+ plt.plot([its_pred[j]*dt, its_pred[j]*dt], [-0.1, 1.1], '--C1', linewidth=1)
250
+ if (itps_pred is not None):
251
+ for j in range(len(itps_pred)):
252
+ plt.plot([itps_pred[j]*dt, itps_pred[j]*dt], [-0.1, 1.1], '--C2', linewidth=1)
253
+ plt.ylim([-0.05, 1.05])
254
+ plt.text(text_loc[0], text_loc[1], '(iv)', horizontalalignment='center',
255
+ transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
256
+ plt.legend(loc='upper right', fontsize='small', ncol=2)
257
+ plt.xlabel('Time (s)')
258
+ plt.ylabel('Probability')
259
+ plt.tight_layout()
260
+ plt.gcf().align_labels()
261
+
262
+ try:
263
+ plt.savefig(os.path.join(figure_dir, fname+'.png'), bbox_inches='tight')
264
+ except FileNotFoundError:
265
+ os.makedirs(os.path.dirname(os.path.join(figure_dir, fname)), exist_ok=True)
266
+ plt.savefig(os.path.join(figure_dir, fname+'.png'), bbox_inches='tight')
267
+
268
+ plt.close()
269
+ return 0
270
+
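
A minimal, hedged sketch of calling the `plot_waveform` above on synthetic arrays; the shapes follow the indexing used inside the function, and the random data and file name are placeholders only.

```python
import numpy as np
from visulization import plot_waveform  # module name as spelled in this repository

nt = 3000
data = np.random.randn(nt, 1, 3)  # (samples, dummy axis, E/N/Z channels)
pred = np.zeros((nt, 1, 3))       # (samples, dummy axis, [noise, P, S]) probabilities
plot_waveform(data, pred, fname="example", figure_dir="./", dt=0.01)  # writes ./example.png
```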
271
+
272
+ def plot_array(config, data, pred, label=None,
273
+ itp=None, its=None, itps=None,
274
+ itp_pred=None, its_pred=None, itps_pred=None,
275
+ fname=None, figure_dir="./", epoch=0):
276
+
277
+ dt = config.dt if hasattr(config, "dt") else 1.0
278
+ t = np.arange(0, pred.shape[1]) * dt
279
+ box = dict(boxstyle='round', facecolor='white', alpha=1)
280
+ text_loc = [0.05, 0.95]
281
+ if fname is None:
282
+ fname = [f"{epoch:03d}_{i:03d}" for i in range(len(data))]
283
+ else:
284
+ fname = [fname[i].decode().rstrip(".npz") for i in range(len(fname))]
285
+
286
+ for i in range(len(data)):
287
+ plt.figure(i, figsize=(10, 5))
288
+ plt.clf()
289
+
290
+ plt.subplot(121)
291
+ for j in range(data.shape[-2]):
292
+ plt.plot(t, data[i, :, j, 0]/10 + j, 'k', label='E', linewidth=0.5)
293
+ plt.autoscale(enable=True, axis='x', tight=True)
294
+ tmp_min = np.min(data[i, :, 0, 0])
295
+ tmp_max = np.max(data[i, :, 0, 0])
296
+ plt.xlabel('Time (s)')
297
+ plt.ylabel('Amplitude')
298
+ # plt.legend(loc='upper right', fontsize='small')
299
+ # plt.gca().set_xticklabels([])
300
+ plt.text(text_loc[0], text_loc[1], '(i)', horizontalalignment='center', verticalalignment="top",
301
+ transform=plt.gca().transAxes, fontsize="large", fontweight="normal", bbox=box)
302
+
303
+ plt.subplot(122)
304
+ for j in range(pred.shape[-2]):
305
+ if label is not None:
306
+ plt.plot(t, label[i, :, j, 1]+j, 'C2', label='P', linewidth=0.5)
307
+ plt.plot(t, label[i, :, j, 2]+j, 'C3', label='S', linewidth=0.5)
308
+ # plt.plot(t, label[i, :, j, 0]+j, 'C4', label='N', linewidth=0.5)
309
+ plt.plot(t, pred[i, :, j, 1]+j, 'C0', label='$\hat{P}$', linewidth=1)
310
+ plt.plot(t, pred[i, :, j, 2]+j, 'C1', label='$\hat{S}$', linewidth=1)
311
+ plt.autoscale(enable=True, axis='x', tight=True)
312
+ if (itp_pred is not None) and (its_pred is not None) and (itps_pred is not None):
313
+ for j in range(len(itp_pred)):
314
+ plt.plot([itp_pred[j]*dt, itp_pred[j]*dt], [-0.1, 1.1], '--C0', linewidth=1)
315
+ for j in range(len(its_pred)):
316
+ plt.plot([its_pred[j]*dt, its_pred[j]*dt], [-0.1, 1.1], '--C1', linewidth=1)
317
+ for j in range(len(itps_pred)):
318
+ plt.plot([itps_pred[j]*dt, itps_pred[j]*dt], [-0.1, 1.1], '--C2', linewidth=1)
319
+ # plt.ylim([-0.05, 1.05])
320
+ plt.text(text_loc[0], text_loc[1], '(ii)', horizontalalignment='center', verticalalignment="top",
321
+ transform=plt.gca().transAxes, fontsize="large", fontweight="normal", bbox=box)
322
+ # plt.legend(loc='upper right', fontsize='small', ncol=2)
323
+ plt.xlabel('Time (s)')
324
+ plt.ylabel('Probability')
325
+ plt.tight_layout()
326
+ plt.gcf().align_labels()
327
+
328
+ try:
329
+ plt.savefig(os.path.join(figure_dir, fname[i]+'.png'), bbox_inches='tight')
330
+ except FileNotFoundError:
331
+ os.makedirs(os.path.dirname(os.path.join(figure_dir, fname[i])), exist_ok=True)
332
+ plt.savefig(os.path.join(figure_dir, fname[i]+'.png'), bbox_inches='tight')
333
+
334
+ plt.close(i)
335
+ return 0
336
+
337
+
338
+ def plot_spectrogram(config, data, pred, label=None,
339
+ itp=None, its=None, itps=None,
340
+ itp_pred=None, its_pred=None, itps_pred=None,
341
+ time=None, freq=None,
342
+ fname=None, figure_dir="./", epoch=0):
343
+
344
+ # dt = config.dt
345
+ # df = config.df
346
+ # t = np.arange(0, data.shape[1]) * dt
347
+ # f = np.arange(0, data.shape[2]) * df
348
+ t, f = time, freq
349
+ dt = t[1] - t[0]
350
+ box = dict(boxstyle='round', facecolor='white', alpha=1)
351
+ text_loc = [0.05, 0.75]
352
+ if fname is None:
353
+ fname = [f"{i:03d}" for i in range(len(data))]
354
+ elif type(fname[0]) is bytes:
355
+ fname = [f.decode() for f in fname]
356
+
357
+ numbers = ["(i)", "(ii)", "(iii)", "(iv)"]
358
+ for i in range(len(data)):
359
+ fig = plt.figure(i)
360
+ # gs = fig.add_gridspec(4, 1)
361
+
362
+ for j in range(3):
363
+ # fig.add_subplot(gs[j, 0])
364
+ plt.subplot(4,1,j+1)
365
+ plt.pcolormesh(t, f, np.abs(data[i, :, :, j]+1j*data[i, :, :, j+3]).T, vmax=2*np.std(data[i, :, :, j]+1j*data[i, :, :, j+3]), cmap="jet", shading='auto')
366
+ plt.autoscale(enable=True, axis='x', tight=True)
367
+ plt.gca().set_xticklabels([])
368
+ if j == 1:
369
+ plt.ylabel('Frequency (Hz)')
370
+ plt.text(text_loc[0], text_loc[1], numbers[j], horizontalalignment='center',
371
+ transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
372
+
373
+ # fig.add_subplot(gs[-1, 0])
374
+ plt.subplot(4,1,4)
375
+ if label is not None:
376
+ plt.plot(t, label[i, :, 0, 1], '--C0', linewidth=1)
377
+ plt.plot(t, label[i, :, 0, 2], '--C3', linewidth=1)
378
+ plt.plot(t, label[i, :, 0, 3], '--C1', linewidth=1)
379
+ plt.plot(t, pred[i, :, 0, 1], 'C0', label='P', linewidth=1)
380
+ plt.plot(t, pred[i, :, 0, 2], 'C3', label='S', linewidth=1)
381
+ plt.plot(t, pred[i, :, 0, 3], 'C1', label='PS', linewidth=1)
382
+ plt.plot(t, t*0, 'k', linewidth=1)
383
+ plt.autoscale(enable=True, axis='x', tight=True)
384
+ if (itp_pred is not None) and (its_pred is not None) and (itps_pred is not None):
385
+ for j in range(len(itp_pred)):
386
+ plt.plot([itp_pred[j]*dt, itp_pred[j]*dt], [-0.1, 1.1], ':C3', linewidth=1)
387
+ for j in range(len(its_pred)):
388
+ plt.plot([its_pred[j]*dt, its_pred[j]*dt], [-0.1, 1.1], '-.C6', linewidth=1)
389
+ for j in range(len(itps_pred)):
390
+ plt.plot([itps_pred[j]*dt, itps_pred[j]*dt], [-0.1, 1.1], '--C8', linewidth=1)
391
+ plt.ylim([-0.05, 1.05])
392
+ plt.text(text_loc[0], text_loc[1], numbers[-1], horizontalalignment='center',
393
+ transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
394
+ plt.legend(loc='upper right', fontsize='small', ncol=1)
395
+ plt.xlabel('Time (s)')
396
+ plt.ylabel('Probability')
397
+ # plt.tight_layout()
398
+ plt.gcf().align_labels()
399
+
400
+ try:
401
+ plt.savefig(os.path.join(figure_dir, f'{epoch:02d}_'+fname[i]+'.png'), bbox_inches='tight')
402
+ except FileNotFoundError:
403
+ os.makedirs(os.path.dirname(os.path.join(figure_dir, f'{epoch:02d}_'+fname[i])), exist_ok=True)  # create the same directory the savefig call below targets
404
+ plt.savefig(os.path.join(figure_dir, f'{epoch:02d}_'+fname[i]+'.png'), bbox_inches='tight')
405
+
406
+ plt.close(i)
407
+ return 0
408
+
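+ # A minimal usage sketch for plot_spectrogram (argument names below are illustrative; shapes are
+ # inferred from the indexing above: spectrograms stacked as 3 real + 3 imaginary channels,
+ # predictions/labels with a singleton station axis and 4 classes):
+ #   plot_spectrogram(config, data, pred, label=label,
+ #                    time=t, freq=f,            # 1-D axes used for pcolormesh; dt = t[1] - t[0]
+ #                    fname=fnames, figure_dir="figures", epoch=0)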
409
+
410
+ def plot_spectrogram_waveform(config, spectrogram, waveform, pred, label=None,
411
+ itp=None, its=None, itps=None, picks=None,
412
+ time=None, freq=None,
413
+ fname=None, figure_dir="./", epoch=0):
414
+
415
+ # dt = config.dt
416
+ # df = config.df
417
+ # t = np.arange(0, spectrogram.shape[1]) * dt
418
+ # f = np.arange(0, spectrogram.shape[2]) * df
419
+ t, f = time, freq
420
+ dt = t[1] - t[0]
421
+ box = dict(boxstyle='round', facecolor='white', alpha=1)
422
+ text_loc = [0.02, 0.90]
423
+ if fname is None:
424
+ fname = [f"{i:03d}" for i in range(len(spectrogram))]
425
+ elif type(fname[0]) is bytes:
426
+ fname = [f.decode() for f in fname]
427
+
428
+ numbers = ["(i)", "(ii)", "(iii)", "(iv)", "(v)", "(vi)", "(vii)"]
429
+ for i in range(len(spectrogram)):
430
+ fig = plt.figure(i, figsize=(6.4, 10))
431
+ # gs = fig.add_gridspec(4, 1)
432
+
433
+ for j in range(3):
434
+ # fig.add_subplot(gs[j, 0])
435
+ plt.subplot(7,1,j*2+1)
436
+ plt.plot(waveform[i,:,j], 'k', linewidth=0.5)
437
+ plt.autoscale(enable=True, axis='x', tight=True)
438
+ plt.gca().set_xticklabels([])
439
+ plt.ylabel('')
440
+ plt.text(text_loc[0], text_loc[1], numbers[j*2], horizontalalignment='left', verticalalignment='top',
441
+ transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
442
+
443
+ for j in range(3):
444
+ # fig.add_subplot(gs[j, 0])
445
+ plt.subplot(7,1,j*2+2)
446
+ plt.pcolormesh(t, f, np.abs(spectrogram[i, :, :, j]+1j*spectrogram[i, :, :, j+3]).T, vmax=2*np.std(spectrogram[i, :, :, j]+1j*spectrogram[i, :, :, j+3]), cmap="jet", shading='auto')
447
+ plt.autoscale(enable=True, axis='x', tight=True)
448
+ plt.gca().set_xticklabels([])
449
+ if j == 1:
450
+ plt.ylabel('Frequency (Hz) or Amplitude')
451
+ plt.text(text_loc[0], text_loc[1], numbers[j*2+1], horizontalalignment='left', verticalalignment='top',
452
+ transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
453
+
454
+ # fig.add_subplot(gs[-1, 0])
455
+ plt.subplot(7,1,7)
456
+ if label is not None:
457
+ plt.plot(t, label[i, :, 0, 1], '--C0', linewidth=1)
458
+ plt.plot(t, label[i, :, 0, 2], '--C3', linewidth=1)
459
+ plt.plot(t, label[i, :, 0, 3], '--C1', linewidth=1)
460
+ plt.plot(t, pred[i, :, 0, 1], 'C0', label='P', linewidth=1)
461
+ plt.plot(t, pred[i, :, 0, 2], 'C3', label='S', linewidth=1)
462
+ plt.plot(t, pred[i, :, 0, 3], 'C1', label='PS', linewidth=1)
463
+ plt.plot(t, t*0, 'k', linewidth=1)
464
+ plt.autoscale(enable=True, axis='x', tight=True)
465
+ plt.ylim([-0.05, 1.05])
466
+ plt.text(text_loc[0], text_loc[1], numbers[-1], horizontalalignment='left', verticalalignment='top',
467
+ transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
468
+ plt.legend(loc='upper right', fontsize='small', ncol=1)
469
+ plt.xlabel('Time (s)')
470
+ plt.ylabel('Probability')
471
+ # plt.tight_layout()
472
+ plt.gcf().align_labels()
473
+
474
+ try:
475
+ plt.savefig(os.path.join(figure_dir, f'{epoch:02d}_'+fname[i]+'.png'), bbox_inches='tight')
476
+ except FileNotFoundError:
477
+ os.makedirs(os.path.dirname(os.path.join(figure_dir, f'{epoch:02d}_'+fname[i])), exist_ok=True)  # create the same directory the savefig call below targets
478
+ plt.savefig(os.path.join(figure_dir, f'{epoch:02d}_'+fname[i]+'.png'), bbox_inches='tight')
479
+
480
+ plt.close(i)
481
+ return 0
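+ # A minimal usage sketch for plot_spectrogram_waveform (names are illustrative; shapes follow the
+ # indexing above: waveform (batch, npts, 3), spectrogram (batch, nt, nf, 6) as real+imaginary
+ # pairs, pred/label (batch, nt, 1, 4)):
+ #   plot_spectrogram_waveform(config, spectrogram, waveform, pred, label=label,
+ #                             time=t, freq=f, fname=fnames, figure_dir="figures", epoch=0)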
requirements.txt ADDED
@@ -0,0 +1,7 @@
1
+ tensorflow
2
+ matplotlib
3
+ pandas
4
+ tqdm
5
+ scipy
6
+ obspy
7
+
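+ # Runtime dependencies; setup.py below reads this file for install_requires.
+ # Manual install (a standard alternative): pip install -r requirements.txt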
setup.py ADDED
@@ -0,0 +1,116 @@
1
+ import io
2
+ import os
3
+ import re
4
+ import sys
5
+ from shutil import rmtree
6
+ from typing import Tuple, List
7
+
8
+ from setuptools import Command, find_packages, setup
9
+
10
+ # Package meta-data.
11
+ name = "PhaseNet"
12
+ description = "PhaseNet"
13
+ url = ""
14
+ email = "wayne.weiqiang@gmail.com"
15
+ author = "Weiqiang Zhu"
16
+ requires_python = ">=3.6.0"
17
+ current_dir = os.path.abspath(os.path.dirname(__file__))
18
+
19
+
20
+ def get_version():
21
+ version_file = os.path.join(current_dir, "phasenet", "__init__.py")
22
+ with io.open(version_file, encoding="utf-8") as f:
23
+ return re.search(r'^__version__ = [\'"]([^\'"]*)[\'"]', f.read(), re.M).group(1)
24
+
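+ # Illustrative example (assuming phasenet/__init__.py contains a line such as
+ # __version__ = "1.1.0"): get_version() returns the string "1.1.0".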
25
+
26
+ # What packages are required for this module to be executed?
27
+ try:
28
+ with open(os.path.join(current_dir, "requirements.txt"), encoding="utf-8") as f:
29
+ required = [line.strip() for line in f if line.strip()]  # drop blank lines so install_requires has no empty entries
30
+ except FileNotFoundError:
31
+ required = []
32
+
33
+ # What packages are optional?
34
+ extras = {"test": ["pytest"]}
35
+
36
+ version = get_version()
37
+
38
+ about = {"__version__": version}
39
+
40
+
41
+ def get_test_requirements():
42
+ requirements = ["pytest"]
43
+ if sys.version_info < (3, 3):
44
+ requirements.append("mock")
45
+ return requirements
46
+
47
+
48
+ def get_long_description():
49
+ # base_dir = os.path.abspath(os.path.dirname(__file__))
50
+ # with io.open(os.path.join(base_dir, "README.md"), encoding="utf-8") as f:
51
+ # return f.read()
52
+ return ""
53
+
54
+
55
+ class UploadCommand(Command):
56
+ """Support setup.py upload."""
57
+
58
+ description = "Build and publish the package."
59
+ user_options: List[Tuple] = []
60
+
61
+ @staticmethod
62
+ def status(s):
63
+ """Print things in bold."""
64
+ print(s)
65
+
66
+ def initialize_options(self):
67
+ pass
68
+
69
+ def finalize_options(self):
70
+ pass
71
+
72
+ def run(self):
73
+ try:
74
+ self.status("Removing previous builds...")
75
+ rmtree(os.path.join(current_dir, "dist"))
76
+ except OSError:
77
+ pass
78
+
79
+ self.status("Building Source and Wheel (universal) distribution...")
80
+ os.system(f"{sys.executable} setup.py sdist bdist_wheel --universal")
81
+
82
+ self.status("Uploading the package to PyPI via Twine...")
83
+ os.system("twine upload dist/*")
84
+
85
+ self.status("Pushing git tags...")
86
+ os.system("git tag v{}".format(about["__version__"]))
87
+ os.system("git push --tags")
88
+
89
+ sys.exit()
90
+
91
+
92
+ setup(
93
+ name=name,
94
+ version=version,
95
+ description=description,
96
+ long_description=get_long_description(),
97
+ long_description_content_type="text/markdown",
98
+ author="Weiqiang Zhu",
99
+ author_email = "wayne.weiqiang@gmail.com",
100
+ license="GPL-3.0",
101
+ url=url,
102
+ packages=find_packages(exclude=["tests", "docs", "dataset", "model", "log"]),
103
+ python_requires=requires_python,
+ install_requires=required,
104
+ extras_require=extras,
105
+ classifiers=[
106
+ "License :: OSI Approved :: BSD License",
107
+ "Intended Audience :: Developers",
108
+ "Intended Audience :: Science/Research",
109
+ "Operating System :: OS Independent",
110
+ "Programming Language :: Python",
111
+ "Programming Language :: Python :: 3",
112
+ "Topic :: Software Development :: Libraries",
113
+ "Topic :: Software Development :: Libraries :: Python Modules",
114
+ ],
115
+ cmdclass={"upload": UploadCommand},
116
+ )
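+ # Typical workflows (a sketch, not part of the packaging metadata above):
+ #   pip install -e .          # editable install; dependencies come from requirements.txt
+ #   pip install -e ".[test]"  # also pull in the optional pytest extra
+ #   python setup.py upload    # UploadCommand: rebuild sdist/wheel, upload via twine, push a git tag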