zhuwq0 committed on
Commit
7b07ad9
1 Parent(s): 68ba412

upload phasenet

.gitattributes CHANGED
@@ -32,3 +32,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ model/190703-214543/checkpoint filter=lfs diff=lfs merge=lfs -text
36
+ model/190703-214543/config.log filter=lfs diff=lfs merge=lfs -text
37
+ model/190703-214543/loss.log filter=lfs diff=lfs merge=lfs -text
38
+ model/190703-214543/model_95.ckpt.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text
39
+ model/190703-214543/model_95.ckpt.index filter=lfs diff=lfs merge=lfs -text
40
+ model/190703-214543/model_95.ckpt.meta filter=lfs diff=lfs merge=lfs -text
model/.DS_Store ADDED
Binary file (6.15 kB).
 
model/190703-214543/checkpoint ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1606ccb25e1533fa0398c5dbce7f3a45ac77f90b78b99f81a044294ba38a2c0c
3
+ size 83
model/190703-214543/config.log ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ed9dfa705053a5025facc9952c7da6abef19ec5f672d9e50386bf3f2d80294f2
3
+ size 345
model/190703-214543/loss.log ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ccb6f19117497571e19bec5da6012ac7af91f1bd29e931ffd0b23c6b657bb401
3
+ size 8101
model/190703-214543/model_95.ckpt.data-00000-of-00001 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9ee2c15dd78fb15de45a55ad64a446f1a0ced152ba4ac5c506d82b9194da85b4
3
+ size 3226256
model/190703-214543/model_95.ckpt.index ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f96b553b76be4ebae9a455eaf8d83cfa8c0e110f06cfba958de2568e5b6b2780
3
+ size 7223
model/190703-214543/model_95.ckpt.meta ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7ebd154a5ba0721ba8bbb627ba61b556ee60660eb34bbcd1b1f50396b07cc4ed
3
+ size 2172055
phasenet/.DS_Store ADDED
Binary file (6.15 kB).
 
phasenet/__init__.py ADDED
@@ -0,0 +1 @@
1
+ __version__ = "0.1.0"
phasenet/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (178 Bytes).
 
phasenet/__pycache__/detect_peaks.cpython-39.pyc ADDED
Binary file (6.71 kB).
 
phasenet/__pycache__/model.cpython-39.pyc ADDED
Binary file (12 kB).
 
phasenet/__pycache__/postprocess.cpython-39.pyc ADDED
Binary file (10.6 kB).
 
phasenet/app.py ADDED
@@ -0,0 +1,331 @@
1
+ import os
2
+ from collections import defaultdict, namedtuple
3
+ from datetime import datetime, timedelta
4
+ from json import dumps
5
+ from typing import Any, AnyStr, Dict, List, NamedTuple, Union, Optional
6
+
7
+ import numpy as np
8
+ import requests
9
+ import tensorflow as tf
10
+ from fastapi import FastAPI
11
+ from kafka import KafkaProducer
12
+ from pydantic import BaseModel
13
+ from scipy.interpolate import interp1d
14
+
15
+ from model import ModelConfig, UNet
16
+ from postprocess import extract_picks
17
+
18
+ tf.compat.v1.disable_eager_execution()
19
+ tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
20
+ PROJECT_ROOT = os.path.realpath(os.path.join(os.path.dirname(__file__), ".."))
21
+ JSONObject = Dict[AnyStr, Any]
22
+ JSONArray = List[Any]
23
+ JSONStructure = Union[JSONArray, JSONObject]
24
+
25
+ app = FastAPI()
26
+ X_SHAPE = [3000, 1, 3]
27
+ SAMPLING_RATE = 100
28
+
29
+ # load model
30
+ model = UNet(mode="pred")
31
+ sess_config = tf.compat.v1.ConfigProto()
32
+ sess_config.gpu_options.allow_growth = True
33
+
34
+ sess = tf.compat.v1.Session(config=sess_config)
35
+ saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
36
+ init = tf.compat.v1.global_variables_initializer()
37
+ sess.run(init)
38
+ latest_check_point = tf.train.latest_checkpoint(f"{PROJECT_ROOT}/model/190703-214543")
39
+ print(f"restoring model {latest_check_point}")
40
+ saver.restore(sess, latest_check_point)
41
+
42
+ # GAMMA API Endpoint
43
+ GAMMA_API_URL = "http://gamma-api:8001"
44
+ # GAMMA_API_URL = 'http://localhost:8001'
45
+ # GAMMA_API_URL = "http://gamma.quakeflow.com"
46
+ # GAMMA_API_URL = "http://127.0.0.1:8001"
47
+
48
+ # Kafka producer
49
+ use_kafka = False
50
+
51
+ try:
52
+ print("Connecting to k8s kafka")
53
+ BROKER_URL = "quakeflow-kafka-headless:9092"
54
+ # BROKER_URL = "34.83.137.139:9094"
55
+ producer = KafkaProducer(
56
+ bootstrap_servers=[BROKER_URL],
57
+ key_serializer=lambda x: dumps(x).encode("utf-8"),
58
+ value_serializer=lambda x: dumps(x).encode("utf-8"),
59
+ )
60
+ use_kafka = True
61
+ print("k8s kafka connection success!")
62
+ except BaseException:
63
+ print("k8s Kafka connection error")
64
+ try:
65
+ print("Connecting to local kafka")
66
+ producer = KafkaProducer(
67
+ bootstrap_servers=["localhost:9092"],
68
+ key_serializer=lambda x: dumps(x).encode("utf-8"),
69
+ value_serializer=lambda x: dumps(x).encode("utf-8"),
70
+ )
71
+ use_kafka = True
72
+ print("local kafka connection success!")
73
+ except BaseException:
74
+ print("local Kafka connection error")
75
+ print(f"Kafka status: {use_kafka}")
76
+
77
+
78
+ def normalize_batch(data, window=3000):
79
+ """
80
+ data: nsta, nt, nch
81
+ """
82
+ shift = window // 2
83
+ nsta, nt, nch = data.shape
84
+
85
+ # std in slide windows
86
+ data_pad = np.pad(data, ((0, 0), (window // 2, window // 2), (0, 0)), mode="reflect")
87
+ t = np.arange(0, nt, shift, dtype="int")
88
+ std = np.zeros([nsta, len(t) + 1, nch])
89
+ mean = np.zeros([nsta, len(t) + 1, nch])
90
+ for i in range(1, len(t)):
91
+ std[:, i, :] = np.std(data_pad[:, i * shift : i * shift + window, :], axis=1)
92
+ mean[:, i, :] = np.mean(data_pad[:, i * shift : i * shift + window, :], axis=1)
93
+
94
+ t = np.append(t, nt)
95
+ # std[:, -1, :] = np.std(data_pad[:, -window:, :], axis=1)
96
+ # mean[:, -1, :] = np.mean(data_pad[:, -window:, :], axis=1)
97
+ std[:, -1, :], mean[:, -1, :] = std[:, -2, :], mean[:, -2, :]
98
+ std[:, 0, :], mean[:, 0, :] = std[:, 1, :], mean[:, 1, :]
99
+ std[std == 0] = 1
100
+
101
+ # ## normalize data with interpolated std
102
+ t_interp = np.arange(nt, dtype="int")
103
+ std_interp = interp1d(t, std, axis=1, kind="slinear")(t_interp)
104
+ mean_interp = interp1d(t, mean, axis=1, kind="slinear")(t_interp)
105
+ data = (data - mean_interp) / std_interp
106
+
107
+ return data
108
+
109
+
110
+ def preprocess(data):
111
+ raw = data.copy()
112
+ data = normalize_batch(data)
113
+ if len(data.shape) == 3:
114
+ data = data[:, :, np.newaxis, :]
115
+ raw = raw[:, :, np.newaxis, :]
116
+ return data, raw
117
+
118
+
119
+ def calc_timestamp(timestamp, sec):
120
+ timestamp = datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S.%f") + timedelta(seconds=sec)
121
+ return timestamp.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]
122
+
123
+
124
+ def format_picks(picks, dt, amplitudes):
125
+ picks_ = []
126
+ for pick, amplitude in zip(picks, amplitudes):
127
+ for idxs, probs, amps in zip(pick.p_idx, pick.p_prob, amplitude.p_amp):
128
+ for idx, prob, amp in zip(idxs, probs, amps):
129
+ picks_.append(
130
+ {
131
+ "id": pick.fname,
132
+ "timestamp": calc_timestamp(pick.t0, float(idx) * dt),
133
+ "prob": prob,
134
+ "amp": amp,
135
+ "type": "p",
136
+ }
137
+ )
138
+ for idxs, probs, amps in zip(pick.s_idx, pick.s_prob, amplitude.s_amp):
139
+ for idx, prob, amp in zip(idxs, probs, amps):
140
+ picks_.append(
141
+ {
142
+ "id": pick.fname,
143
+ "timestamp": calc_timestamp(pick.t0, float(idx) * dt),
144
+ "prob": prob,
145
+ "amp": amp,
146
+ "type": "s",
147
+ }
148
+ )
149
+ return picks_
150
+
151
+
152
+ def format_data(data):
153
+
154
+ # chn2idx = {"ENZ": {"E":0, "N":1, "Z":2},
155
+ # "123": {"3":0, "2":1, "1":2},
156
+ # "12Z": {"1":0, "2":1, "Z":2}}
157
+ chn2idx = {"E": 0, "N": 1, "Z": 2, "3": 0, "2": 1, "1": 2}
158
+ Data = NamedTuple("data", [("id", list), ("timestamp", list), ("vec", list), ("dt", float)])
159
+
160
+ # Group by station
161
+ chn_ = defaultdict(list)
162
+ t0_ = defaultdict(list)
163
+ vv_ = defaultdict(list)
164
+ for i in range(len(data.id)):
165
+ key = data.id[i][:-1]
166
+ chn_[key].append(data.id[i][-1])
167
+ t0_[key].append(datetime.strptime(data.timestamp[i], "%Y-%m-%dT%H:%M:%S.%f").timestamp() * SAMPLING_RATE)
168
+ vv_[key].append(np.array(data.vec[i]))
169
+
170
+ # Merge to Data tuple
171
+ id_ = []
172
+ timestamp_ = []
173
+ vec_ = []
174
+ for k in chn_:
175
+ id_.append(k)
176
+ min_t0 = min(t0_[k])
177
+ timestamp_.append(datetime.fromtimestamp(min_t0 / SAMPLING_RATE).strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3])
178
+ vec = np.zeros([X_SHAPE[0], X_SHAPE[-1]])
179
+ for i in range(len(chn_[k])):
180
+ # vec[int(t0_[k][i]-min_t0):len(vv_[k][i]), chn2idx[chn_[k][i]]] = vv_[k][i][int(t0_[k][i]-min_t0):X_SHAPE[0]] - np.mean(vv_[k][i])
181
+ shift = int(t0_[k][i] - min_t0)
182
+ vec[shift : len(vv_[k][i]) + shift, chn2idx[chn_[k][i]]] = vv_[k][i][: X_SHAPE[0] - shift] - np.mean(
183
+ vv_[k][i][: X_SHAPE[0] - shift]
184
+ )
185
+ vec_.append(vec.tolist())
186
+
187
+ return Data(id=id_, timestamp=timestamp_, vec=vec_, dt=1 / SAMPLING_RATE)
188
+ # return {"id": id_, "timestamp": timestamp_, "vec": vec_, "dt":1 / SAMPLING_RATE}
189
+
190
+
191
+ def get_prediction(data, return_preds=False):
192
+
193
+ vec = np.array(data.vec)
194
+ vec, vec_raw = preprocess(vec)
195
+
196
+ feed = {model.X: vec, model.drop_rate: 0, model.is_training: False}
197
+ preds = sess.run(model.preds, feed_dict=feed)
198
+
199
+ picks = extract_picks(preds, station_ids=data.id, begin_times=data.timestamp, waveforms=vec_raw)
200
+
201
+ picks = [{k: v for k, v in pick.items() if k in ["station_id", "phase_time", "phase_score", "phase_type", "dt"]} for pick in picks]
202
+
203
+ if return_preds:
204
+ return picks, preds
205
+
206
+ return picks
207
+
208
+
209
+ class Data(BaseModel):
210
+ # id: Union[List[str], str]
211
+ # timestamp: Union[List[str], str]
212
+ # vec: Union[List[List[List[float]]], List[List[float]]]
213
+ id: List[str]
214
+ timestamp: List[str]
215
+ vec: Union[List[List[List[float]]], List[List[float]]]
216
+ dt: Optional[float] = 0.01
217
+ ## gamma
218
+ stations: Optional[List[Dict[str, Union[float, str]]]] = None
219
+ config: Optional[Dict[str, Union[List[float], List[int], List[str], float, int, str]]] = None
220
+
221
+
222
+ # @app.on_event("startup")
223
+ # def set_default_executor():
224
+ # from concurrent.futures import ThreadPoolExecutor
225
+ # import asyncio
226
+ #
227
+ # loop = asyncio.get_running_loop()
228
+ # loop.set_default_executor(
229
+ # ThreadPoolExecutor(max_workers=2)
230
+ # )
231
+
232
+
233
+ @app.post("/predict")
234
+ def predict(data: Data):
235
+
236
+ picks = get_prediction(data)
237
+
238
+ return picks
239
+
240
+
241
+ @app.post("/predict_prob")
242
+ def predict_prob(data: Data):
243
+
244
+ picks, preds = get_prediction(data, True)
245
+
246
+ return picks, preds.tolist()
247
+
248
+
249
+ @app.post("/predict_phasenet2gamma")
250
+ def predict_phasenet2gamma(data: Data):
251
+
252
+ picks = get_prediction(data)
253
+
254
+ # if use_kafka:
255
+ # print("Push picks to kafka...")
256
+ # for pick in picks:
257
+ # producer.send("phasenet_picks", key=pick["id"], value=pick)
258
+ try:
259
+ catalog = requests.post(f"{GAMMA_API_URL}/predict", json={"picks": picks,
260
+ "stations": data.stations,
261
+ "config": data.config})
262
+ print(catalog.json()["catalog"])
263
+ return catalog.json()
264
+ except Exception as error:
265
+ print(error)
266
+
267
+ return {}
268
+
269
+ @app.post("/predict_phasenet2gamma2ui")
270
+ def predict_phasenet2gamma2ui(data: Data):
271
+
272
+ picks = get_prediction(data)
273
+
274
+ try:
275
+ catalog = requests.post(f"{GAMMA_API_URL}/predict", json={"picks": picks,
276
+ "stations": data.stations,
277
+ "config": data.config})
278
+ print(catalog.json()["catalog"])
279
+ return catalog.json()
280
+ except Exception as error:
281
+ print(error)
282
+
283
+ if use_kafka:
284
+ print("Push picks to kafka...")
285
+ for pick in picks:
286
+ producer.send("phasenet_picks", key=pick["station_id"], value=pick)  # picks are keyed by "station_id" after filtering in get_prediction
287
+ print("Push waveform to kafka...")
288
+ for id, timestamp, vec in zip(data.id, data.timestamp, data.vec):
289
+ producer.send("waveform_phasenet", key=id, value={"timestamp": timestamp, "vec": vec, "dt": data.dt})
290
+
291
+ return {}
292
+
293
+
294
+ @app.post("/predict_stream_phasenet2gamma")
295
+ def predict_stream_phasenet2gamma(data: Data):
296
+
297
+ data = format_data(data)
298
+ # for i in range(len(data.id)):
299
+ # plt.clf()
300
+ # plt.subplot(311)
301
+ # plt.plot(np.array(data.vec)[i, :, 0])
302
+ # plt.subplot(312)
303
+ # plt.plot(np.array(data.vec)[i, :, 1])
304
+ # plt.subplot(313)
305
+ # plt.plot(np.array(data.vec)[i, :, 2])
306
+ # plt.savefig(f"{data.id[i]}.png")
307
+
308
+ picks = get_prediction(data)
309
+
310
+ return_value = {}
311
+ try:
312
+ catalog = requests.post(f"{GAMMA_API_URL}/predict_stream", json={"picks": picks})
313
+ print("GMMA:", catalog.json()["catalog"])
314
+ return_value = catalog.json()
315
+ except Exception as error:
316
+ print(error)
317
+
318
+ if use_kafka:
319
+ print("Push picks to kafka...")
320
+ for pick in picks:
321
+ producer.send("phasenet_picks", key=pick["station_id"], value=pick)  # picks are keyed by "station_id" after filtering in get_prediction
322
+ print("Push waveform to kafka...")
323
+ for id, timestamp, vec in zip(data.id, data.timestamp, data.vec):
324
+ producer.send("waveform_phasenet", key=id, value={"timestamp": timestamp, "vec": vec, "dt": data.dt})
325
+
326
+ return return_value
327
+
328
+
329
+ @app.get("/healthz")
330
+ def healthz():
331
+ return {"status": "ok"}
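A minimal client sketch (not part of this commit) for the /predict endpoint defined in app.py above. It assumes the app is served locally, e.g. with `uvicorn app:app --port 8000`; the station id, start time, and waveform values below are made-up placeholders.

# Hypothetical client for the /predict endpoint above (not part of this commit).
# Assumes the FastAPI app is running locally, e.g. `uvicorn app:app --port 8000`.
import numpy as np
import requests

NT, NCH = 3000, 3  # matches X_SHAPE = [3000, 1, 3] in app.py
payload = {
    "id": ["NC.STA01..HH"],                      # placeholder station id
    "timestamp": ["2021-01-01T00:00:00.000"],    # placeholder start time
    "vec": [np.random.randn(NT, NCH).tolist()],  # placeholder 3-component waveform
    "dt": 0.01,
}
resp = requests.post("http://localhost:8000/predict", json=payload)
print(resp.json())  # picks with station_id, phase_time, phase_score, phase_type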
phasenet/data_reader.py ADDED
@@ -0,0 +1,964 @@
1
+ import tensorflow as tf
2
+
3
+ tf.compat.v1.disable_eager_execution()
4
+ tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
5
+ import logging
6
+ import os
7
+
8
+ import numpy as np
9
+ import pandas as pd
10
+
11
+ pd.options.mode.chained_assignment = None
12
+ import json
13
+
14
+ # import s3fs
15
+ import h5py
16
+ import obspy
17
+ from scipy.interpolate import interp1d
18
+ from tqdm import tqdm
19
+
20
+
21
+ def py_func_decorator(output_types=None, output_shapes=None, name=None):
22
+ def decorator(func):
23
+ def call(*args, **kwargs):
24
+ nonlocal output_shapes
25
+ # flat_output_types = nest.flatten(output_types)
26
+ flat_output_types = tf.nest.flatten(output_types)
27
+ # flat_values = tf.py_func(
28
+ flat_values = tf.numpy_function(func, inp=args, Tout=flat_output_types, name=name)
29
+ if output_shapes is not None:
30
+ for v, s in zip(flat_values, output_shapes):
31
+ v.set_shape(s)
32
+ # return nest.pack_sequence_as(output_types, flat_values)
33
+ return tf.nest.pack_sequence_as(output_types, flat_values)
34
+
35
+ return call
36
+
37
+ return decorator
38
+
39
+
40
+ def dataset_map(iterator, output_types, output_shapes=None, num_parallel_calls=None, name=None, shuffle=False):
41
+ dataset = tf.data.Dataset.range(len(iterator))
42
+ if shuffle:
43
+ dataset = dataset.shuffle(len(iterator), reshuffle_each_iteration=True)
44
+
45
+ @py_func_decorator(output_types, output_shapes, name=name)
46
+ def index_to_entry(idx):
47
+ return iterator[idx]
48
+
49
+ return dataset.map(index_to_entry, num_parallel_calls=num_parallel_calls)
50
+
51
+
52
+ def normalize(data, axis=(0,)):
53
+ """data shape: (nt, nsta, nch)"""
54
+ data -= np.mean(data, axis=axis, keepdims=True)
55
+ std_data = np.std(data, axis=axis, keepdims=True)
56
+ std_data[std_data == 0] = 1
57
+ data /= std_data
58
+ # data /= (std_data + 1e-12)
59
+ return data
60
+
61
+
62
+ def normalize_long(data, axis=(0,), window=3000):
63
+ """
64
+ data: nt, nch
65
+ """
66
+ nt, nar, nch = data.shape
67
+ if window is None:
68
+ window = nt
69
+ shift = window // 2
70
+
71
+ ## std in slide windows
72
+ data_pad = np.pad(data, ((window // 2, window // 2), (0, 0), (0, 0)), mode="reflect")
73
+ t = np.arange(0, nt, shift, dtype="int")
74
+ std = np.zeros([len(t) + 1, nar, nch])
75
+ mean = np.zeros([len(t) + 1, nar, nch])
76
+ for i in range(1, len(std)):
77
+ std[i, :] = np.std(data_pad[i * shift : i * shift + window, :, :], axis=axis)
78
+ mean[i, :] = np.mean(data_pad[i * shift : i * shift + window, :, :], axis=axis)
79
+
80
+ t = np.append(t, nt)
81
+ # std[-1, :] = np.std(data_pad[-window:, :], axis=0)
82
+ # mean[-1, :] = np.mean(data_pad[-window:, :], axis=0)
83
+ std[-1, ...], mean[-1, ...] = std[-2, ...], mean[-2, ...]
84
+ std[0, ...], mean[0, ...] = std[1, ...], mean[1, ...]
85
+ # std[std == 0] = 1.0
86
+
87
+ ## normalize data with interpolated std
88
+ t_interp = np.arange(nt, dtype="int")
89
+ std_interp = interp1d(t, std, axis=0, kind="slinear")(t_interp)
90
+ # std_interp = np.exp(interp1d(t, np.log(std), axis=0, kind="slinear")(t_interp))
91
+ mean_interp = interp1d(t, mean, axis=0, kind="slinear")(t_interp)
92
+ tmp = np.sum(std_interp, axis=(0, 1))
93
+ std_interp[std_interp == 0] = 1.0
94
+ data = (data - mean_interp) / std_interp
95
+ # data = (data - mean_interp)/(std_interp + 1e-12)
96
+
97
+ ### dropout effect of < 3 channel
98
+ nonzero = np.count_nonzero(tmp)
99
+ if (nonzero < 3) and (nonzero > 0):
100
+ data *= 3.0 / nonzero
101
+
102
+ return data
103
+
104
+
105
+ def normalize_batch(data, window=3000):
106
+ """
107
+ data: nsta, nt, nch
108
+ """
109
+ nsta, nt, nar, nch = data.shape
110
+ if window is None:
111
+ window = nt
112
+ shift = window // 2
113
+
114
+ ## std in slide windows
115
+ data_pad = np.pad(data, ((0, 0), (window // 2, window // 2), (0, 0), (0, 0)), mode="reflect")
116
+ t = np.arange(0, nt, shift, dtype="int")
117
+ std = np.zeros([nsta, len(t) + 1, nar, nch])
118
+ mean = np.zeros([nsta, len(t) + 1, nar, nch])
119
+ for i in range(1, len(t)):
120
+ std[:, i, :, :] = np.std(data_pad[:, i * shift : i * shift + window, :, :], axis=1)
121
+ mean[:, i, :, :] = np.mean(data_pad[:, i * shift : i * shift + window, :, :], axis=1)
122
+
123
+ t = np.append(t, nt)
124
+ # std[:, -1, :] = np.std(data_pad[:, -window:, :], axis=1)
125
+ # mean[:, -1, :] = np.mean(data_pad[:, -window:, :], axis=1)
126
+ std[:, -1, :, :], mean[:, -1, :, :] = std[:, -2, :, :], mean[:, -2, :, :]
127
+ std[:, 0, :, :], mean[:, 0, :, :] = std[:, 1, :, :], mean[:, 1, :, :]
128
+ # std[std == 0] = 1
129
+
130
+ # ## normalize data with interplated std
131
+ t_interp = np.arange(nt, dtype="int")
132
+ std_interp = interp1d(t, std, axis=1, kind="slinear")(t_interp)
133
+ # std_interp = np.exp(interp1d(t, np.log(std), axis=1, kind="slinear")(t_interp))
134
+ mean_interp = interp1d(t, mean, axis=1, kind="slinear")(t_interp)
135
+ tmp = np.sum(std_interp, axis=(1, 2))
136
+ std_interp[std_interp == 0] = 1.0
137
+ data = (data - mean_interp) / std_interp
138
+ # data = (data - mean_interp)/(std_interp + 1e-12)
139
+
140
+ ### dropout effect of < 3 channel
141
+ nonzero = np.count_nonzero(tmp, axis=-1)
142
+ data[nonzero > 0, ...] *= 3.0 / nonzero[nonzero > 0][:, np.newaxis, np.newaxis, np.newaxis]
143
+
144
+ return data
145
+
146
+
147
+ class DataConfig:
148
+
149
+ seed = 123
150
+ use_seed = True
151
+ n_channel = 3
152
+ n_class = 3
153
+ sampling_rate = 100
154
+ dt = 1.0 / sampling_rate
155
+ X_shape = [3000, 1, n_channel]
156
+ Y_shape = [3000, 1, n_class]
157
+ min_event_gap = 3 * sampling_rate
158
+ label_shape = "gaussian"
159
+ label_width = 30
160
+ dtype = "float32"
161
+
162
+ def __init__(self, **kwargs):
163
+ for k, v in kwargs.items():
164
+ setattr(self, k, v)
165
+
166
+
167
+ class DataReader:
168
+ def __init__(self, format="numpy", config=DataConfig(), **kwargs):
169
+ self.buffer = {}
170
+ self.n_channel = config.n_channel
171
+ self.n_class = config.n_class
172
+ self.X_shape = config.X_shape
173
+ self.Y_shape = config.Y_shape
174
+ self.dt = config.dt
175
+ self.dtype = config.dtype
176
+ self.label_shape = config.label_shape
177
+ self.label_width = config.label_width
178
+ self.config = config
179
+ self.format = format
180
+ if "highpass_filter" in kwargs:
181
+ self.highpass_filter = kwargs["highpass_filter"]
182
+ if format in ["numpy", "mseed", "sac"]:
183
+ self.data_dir = kwargs["data_dir"]
184
+ try:
185
+ csv = pd.read_csv(kwargs["data_list"], header=0, sep=r"[,|\s+]", engine="python")
186
+ except:
187
+ csv = pd.read_csv(kwargs["data_list"], header=0, sep="\t")
188
+ self.data_list = csv["fname"]
189
+ self.num_data = len(self.data_list)
190
+ elif format == "hdf5":
191
+ self.h5 = h5py.File(kwargs["hdf5_file"], "r", libver="latest", swmr=True)
192
+ self.h5_data = self.h5[kwargs["hdf5_group"]]
193
+ self.data_list = list(self.h5_data.keys())
194
+ self.num_data = len(self.data_list)
195
+ elif format == "s3":
196
+ self.s3fs = s3fs.S3FileSystem(
197
+ anon=kwargs["anon"],
198
+ key=kwargs["key"],
199
+ secret=kwargs["secret"],
200
+ client_kwargs={"endpoint_url": kwargs["s3_url"]},
201
+ use_ssl=kwargs["use_ssl"],
202
+ )
203
+ self.num_data = 0
204
+ else:
205
+ raise ValueError(f"Unsupported format: {format}!")
206
+
207
+ def __len__(self):
208
+ return self.num_data
209
+
210
+ def read_numpy(self, fname):
211
+ # try:
212
+ if fname not in self.buffer:
213
+ npz = np.load(fname)
214
+ meta = {}
215
+ if len(npz["data"].shape) == 2:
216
+ meta["data"] = npz["data"][:, np.newaxis, :]
217
+ else:
218
+ meta["data"] = npz["data"]
219
+ if "p_idx" in npz.files:
220
+ if len(npz["p_idx"].shape) == 0:
221
+ meta["itp"] = [[npz["p_idx"]]]
222
+ else:
223
+ meta["itp"] = npz["p_idx"]
224
+ if "s_idx" in npz.files:
225
+ if len(npz["s_idx"].shape) == 0:
226
+ meta["its"] = [[npz["s_idx"]]]
227
+ else:
228
+ meta["its"] = npz["s_idx"]
229
+ if "itp" in npz.files:
230
+ if len(npz["itp"].shape) == 0:
231
+ meta["itp"] = [[npz["itp"]]]
232
+ else:
233
+ meta["itp"] = npz["itp"]
234
+ if "its" in npz.files:
235
+ if len(npz["its"].shape) == 0:
236
+ meta["its"] = [[npz["its"]]]
237
+ else:
238
+ meta["its"] = npz["its"]
239
+ if "station_id" in npz.files:
240
+ meta["station_id"] = npz["station_id"]
241
+ if "sta_id" in npz.files:
242
+ meta["station_id"] = npz["sta_id"]
243
+ if "t0" in npz.files:
244
+ meta["t0"] = npz["t0"]
245
+ self.buffer[fname] = meta
246
+ else:
247
+ meta = self.buffer[fname]
248
+ return meta
249
+ # except:
250
+ # logging.error("Failed reading {}".format(fname))
251
+ # return None
252
+
253
+ def read_hdf5(self, fname):
254
+ data = self.h5_data[fname][()]
255
+ attrs = self.h5_data[fname].attrs
256
+ meta = {}
257
+ if len(data.shape) == 2:
258
+ meta["data"] = data[:, np.newaxis, :]
259
+ else:
260
+ meta["data"] = data
261
+ if "p_idx" in attrs:
262
+ if len(attrs["p_idx"].shape) == 0:
263
+ meta["itp"] = [[attrs["p_idx"]]]
264
+ else:
265
+ meta["itp"] = attrs["p_idx"]
266
+ if "s_idx" in attrs:
267
+ if len(attrs["s_idx"].shape) == 0:
268
+ meta["its"] = [[attrs["s_idx"]]]
269
+ else:
270
+ meta["its"] = attrs["s_idx"]
271
+ if "itp" in attrs:
272
+ if len(attrs["itp"].shape) == 0:
273
+ meta["itp"] = [[attrs["itp"]]]
274
+ else:
275
+ meta["itp"] = attrs["itp"]
276
+ if "its" in attrs:
277
+ if len(attrs["its"].shape) == 0:
278
+ meta["its"] = [[attrs["its"]]]
279
+ else:
280
+ meta["its"] = attrs["its"]
281
+ if "t0" in attrs:
282
+ meta["t0"] = attrs["t0"]
283
+ return meta
284
+
285
+ def read_s3(self, format, fname, bucket, key, secret, s3_url, use_ssl):
286
+ with self.s3fs.open(bucket + "/" + fname, "rb") as fp:
287
+ if format == "numpy":
288
+ meta = self.read_numpy(fp)
289
+ elif format == "mseed":
290
+ meta = self.read_mseed(fp)
291
+ else:
292
+ raise ValueError(f"Format {format} not supported")
293
+ return meta
294
+
295
+ def read_mseed(self, fname):
296
+
297
+ mseed = obspy.read(fname)
298
+ mseed = mseed.detrend("spline", order=2, dspline=5 * mseed[0].stats.sampling_rate)
299
+ mseed = mseed.merge(fill_value=0)
300
+ if self.highpass_filter > 0:
301
+ mseed = mseed.filter("highpass", freq=self.highpass_filter)
302
+ starttime = min([st.stats.starttime for st in mseed])
303
+ endtime = max([st.stats.endtime for st in mseed])
304
+ mseed = mseed.trim(starttime, endtime, pad=True, fill_value=0)
305
+ if abs(mseed[0].stats.sampling_rate - self.config.sampling_rate) > 1:
306
+ logging.warning(
307
+ f"Sampling rate mismatch in {fname.split('/')[-1]}: {mseed[0].stats.sampling_rate}Hz != {self.config.sampling_rate}Hz "
308
+ )
309
+
310
+ order = ["3", "2", "1", "E", "N", "Z"]
311
+ order = {key: i for i, key in enumerate(order)}
312
+ comp2idx = {"3": 0, "2": 1, "1": 2, "E": 0, "N": 1, "Z": 2}
313
+
314
+ t0 = starttime.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]
315
+ nt = len(mseed[0].data)
316
+ data = np.zeros([nt, self.config.n_channel], dtype=self.dtype)
317
+ ids = [x.get_id() for x in mseed]
318
+
319
+ for j, id in enumerate(sorted(ids, key=lambda x: order[x[-1]])):
320
+ if len(ids) != 3:
321
+ if len(ids) > 3:
322
+ logging.warning(f"More than 3 channels {ids}!")
323
+ j = comp2idx[id[-1]]
324
+ data[:, j] = mseed.select(id=id)[0].data.astype(self.dtype)
325
+
326
+ data = data[:, np.newaxis, :]
327
+ meta = {"data": data, "t0": t0}
328
+ return meta
329
+
330
+ def read_sac(self, fname):
331
+
332
+ mseed = obspy.read(fname)
333
+ mseed = mseed.detrend("spline", order=2, dspline=5 * mseed[0].stats.sampling_rate)
334
+ mseed = mseed.merge(fill_value=0)
335
+ if self.highpass_filter > 0:
336
+ mseed = mseed.filter("highpass", freq=self.highpass_filter)
337
+ starttime = min([st.stats.starttime for st in mseed])
338
+ endtime = max([st.stats.endtime for st in mseed])
339
+ mseed = mseed.trim(starttime, endtime, pad=True, fill_value=0)
340
+ if abs(mseed[0].stats.sampling_rate - self.config.sampling_rate) > 1:
341
+ logging.warning(
342
+ f"Sampling rate mismatch in {fname.split('/')[-1]}: {mseed[0].stats.sampling_rate}Hz != {self.config.sampling_rate}Hz "
343
+ )
344
+
345
+ order = ["3", "2", "1", "E", "N", "Z"]
346
+ order = {key: i for i, key in enumerate(order)}
347
+ comp2idx = {"3": 0, "2": 1, "1": 2, "E": 0, "N": 1, "Z": 2}
348
+
349
+ t0 = starttime.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]
350
+ nt = len(mseed[0].data)
351
+ data = np.zeros([nt, self.config.n_channel], dtype=self.dtype)
352
+ ids = [x.get_id() for x in mseed]
353
+ for j, id in enumerate(sorted(ids, key=lambda x: order[x[-1]])):
354
+ if len(ids) != 3:
355
+ if len(ids) > 3:
356
+ logging.warning(f"More than 3 channels {ids}!")
357
+ j = comp2idx[id[-1]]
358
+ data[:, j] = mseed.select(id=id)[0].data.astype(self.dtype)
359
+
360
+ data = data[:, np.newaxis, :]
361
+ meta = {"data": data, "t0": t0}
362
+ return meta
363
+
364
+ def read_mseed_array(self, fname, stations, amplitude=False, remove_resp=True):
365
+
366
+ data = []
367
+ station_id = []
368
+ t0 = []
369
+ raw_amp = []
370
+
371
+ try:
372
+ mseed = obspy.read(fname)
373
+ read_success = True
374
+ except Exception as e:
375
+ read_success = False
376
+ print(e)
377
+
378
+ if read_success:
379
+ try:
380
+ mseed = mseed.merge(fill_value=0)
381
+ except Exception as e:
382
+ print(e)
383
+
384
+ for i in range(len(mseed)):
385
+ if mseed[i].stats.sampling_rate != self.config.sampling_rate:
386
+ logging.warning(
387
+ f"Resampling {mseed[i].id} from {mseed[i].stats.sampling_rate} to {self.config.sampling_rate} Hz"
388
+ )
389
+ try:
390
+ mseed[i] = mseed[i].interpolate(self.config.sampling_rate, method="linear")
391
+ except Exception as e:
392
+ print(e)
393
+ mseed[i].data = mseed[i].data.astype(float) * 0.0 ## set to zero if resampling fails
394
+
395
+ if self.highpass_filter == 0:
396
+ try:
397
+ mseed = mseed.detrend("spline", order=2, dspline=5 * mseed[0].stats.sampling_rate)
398
+ except:
399
+ logging.error(f"Error: spline detrend failed at file {fname}")
400
+ mseed = mseed.detrend("demean")
401
+ else:
402
+ mseed = mseed.filter("highpass", freq=self.highpass_filter)
403
+
404
+ starttime = min([st.stats.starttime for st in mseed])
405
+ endtime = max([st.stats.endtime for st in mseed])
406
+ mseed = mseed.trim(starttime, endtime, pad=True, fill_value=0)
407
+
408
+ order = ["3", "2", "1", "E", "N", "Z"]
409
+ order = {key: i for i, key in enumerate(order)}
410
+ comp2idx = {"3": 0, "2": 1, "1": 2, "E": 0, "N": 1, "Z": 2}
411
+
412
+ nsta = len(stations)
413
+ nt = len(mseed[0].data)
414
+ # for i in range(nsta):
415
+ for sta in stations:
416
+ trace_data = np.zeros([nt, self.config.n_channel], dtype=self.dtype)
417
+ if amplitude:
418
+ trace_amp = np.zeros([nt, self.config.n_channel], dtype=self.dtype)
419
+ empty_station = True
420
+ # sta = stations.iloc[i]["station"]
421
+ # comp = stations.iloc[i]["component"].split(",")
422
+ comp = stations[sta]["component"]
423
+ if amplitude:
424
+ # resp = stations.iloc[i]["response"].split(",")
425
+ resp = stations[sta]["response"]
426
+
427
+ for j, c in enumerate(sorted(comp, key=lambda x: order[x[-1]])):
428
+
429
+ resp_j = resp[j] if amplitude else None  # resp is only defined when amplitude=True
430
+ if len(comp) != 3: ## less than 3 component
431
+ j = comp2idx[c]
432
+
433
+ if len(mseed.select(id=sta + c)) == 0:
434
+ print(f"Empty trace: {sta+c} {starttime}")
435
+ continue
436
+ else:
437
+ empty_station = False
438
+
439
+ tmp = mseed.select(id=sta + c)[0].data.astype(self.dtype)
440
+ trace_data[: len(tmp), j] = tmp[:nt]
441
+ if amplitude:
442
+ # if stations.iloc[i]["unit"] == "m/s**2":
443
+ if stations[sta]["unit"] == "m/s**2":
444
+ tmp = mseed.select(id=sta + c)[0]
445
+ tmp = tmp.integrate()
446
+ tmp = tmp.filter("highpass", freq=1.0)
447
+ tmp = tmp.data.astype(self.dtype)
448
+ trace_amp[: len(tmp), j] = tmp[:nt]
449
+ # elif stations.iloc[i]["unit"] == "m/s":
450
+ elif stations[sta]["unit"] == "m/s":
451
+ tmp = mseed.select(id=sta + c)[0].data.astype(self.dtype)
452
+ trace_amp[: len(tmp), j] = tmp[:nt]
453
+ else:
454
+ print(
455
+ f"Error in {sta}\n{stations[sta]['unit']} should be m/s**2 or m/s!"
456
+ )
457
+ if amplitude and remove_resp:
458
+ # trace_amp[:, j] /= float(resp[j])
459
+ trace_amp[:, j] /= float(resp_j)
460
+
461
+ if not empty_station:
462
+ data.append(trace_data)
463
+ if amplitude:
464
+ raw_amp.append(trace_amp)
465
+ station_id.append(sta)
466
+ t0.append(starttime.datetime.isoformat(timespec="milliseconds"))
467
+
468
+ if len(data) > 0:
469
+ data = np.stack(data)
470
+ if len(data.shape) == 3:
471
+ data = data[:, :, np.newaxis, :]
472
+ if amplitude:
473
+ raw_amp = np.stack(raw_amp)
474
+ if len(raw_amp.shape) == 3:
475
+ raw_amp = raw_amp[:, :, np.newaxis, :]
476
+ else:
477
+ nt = 60 * 60 * self.config.sampling_rate # assume 1 hour data
478
+ data = np.zeros([1, nt, 1, self.config.n_channel], dtype=self.dtype)
479
+ if amplitude:
480
+ raw_amp = np.zeros([1, nt, 1, self.config.n_channel], dtype=self.dtype)
481
+ t0 = ["1970-01-01T00:00:00.000"]
482
+ station_id = ["None"]
483
+
484
+ if amplitude:
485
+ meta = {"data": data, "t0": t0, "station_id": station_id, "fname": fname.split("/")[-1], "raw_amp": raw_amp}
486
+ else:
487
+ meta = {"data": data, "t0": t0, "station_id": station_id, "fname": fname.split("/")[-1]}
488
+ return meta
489
+
490
+ def generate_label(self, data, phase_list, mask=None):
491
+ # target = np.zeros(self.Y_shape, dtype=self.dtype)
492
+ target = np.zeros_like(data)
493
+
494
+ if self.label_shape == "gaussian":
495
+ label_window = np.exp(
496
+ -((np.arange(-self.label_width // 2, self.label_width // 2 + 1)) ** 2)
497
+ / (2 * (self.label_width / 5) ** 2)
498
+ )
499
+ elif self.label_shape == "triangle":
500
+ label_window = 1 - np.abs(
501
+ 2 / self.label_width * (np.arange(-self.label_width // 2, self.label_width // 2 + 1))
502
+ )
503
+ else:
504
+ print(f"Label shape {self.label_shape} should be gaussian or triangle")
505
+ raise ValueError(f"Unsupported label shape: {self.label_shape}")
506
+
507
+ for i, phases in enumerate(phase_list):
508
+ for j, idx_list in enumerate(phases):
509
+ for idx in idx_list:
510
+ if np.isnan(idx):
511
+ continue
512
+ idx = int(idx)
513
+ if (idx - self.label_width // 2 >= 0) and (idx + self.label_width // 2 + 1 <= target.shape[0]):
514
+ target[idx - self.label_width // 2 : idx + self.label_width // 2 + 1, j, i + 1] = label_window
515
+
516
+ target[..., 0] = 1 - np.sum(target[..., 1:], axis=-1)
517
+ if mask is not None:
518
+ target[:, mask == 0, :] = 0
519
+
520
+ return target
521
+
522
+ def random_shift(self, sample, itp, its, itp_old=None, its_old=None, shift_range=None):
523
+ # anchor = np.round(1/2 * (min(itp[~np.isnan(itp.astype(float))]) + min(its[~np.isnan(its.astype(float))]))).astype(int)
524
+ flattern = lambda x: np.array([i for trace in x for i in trace], dtype=float)
525
+ shift_pick = lambda x, shift: [[i - shift for i in trace] for trace in x]
526
+ itp_flat = flattern(itp)
527
+ its_flat = flattern(its)
528
+ if (itp_old is None) and (its_old is None):
529
+ hi = np.round(np.median(itp_flat[~np.isnan(itp_flat)])).astype(int)
530
+ lo = -(sample.shape[0] - np.round(np.median(its_flat[~np.isnan(its_flat)])).astype(int))
531
+ if shift_range is None:
532
+ shift = np.random.randint(low=lo, high=hi + 1)
533
+ else:
534
+ shift = np.random.randint(low=max(lo, shift_range[0]), high=min(hi + 1, shift_range[1]))
535
+ else:
536
+ itp_old_flat = flattern(itp_old)
537
+ its_old_flat = flattern(its_old)
538
+ itp_ref = np.round(np.min(itp_flat[~np.isnan(itp_flat)])).astype(int)
539
+ its_ref = np.round(np.max(its_flat[~np.isnan(its_flat)])).astype(int)
540
+ itp_old_ref = np.round(np.min(itp_old_flat[~np.isnan(itp_old_flat)])).astype(int)
541
+ its_old_ref = np.round(np.max(its_old_flat[~np.isnan(its_old_flat)])).astype(int)
542
+ # min_event_gap = np.round(self.min_event_gap*(its_ref-itp_ref)).astype(int)
543
+ # min_event_gap_old = np.round(self.min_event_gap*(its_old_ref-itp_old_ref)).astype(int)
544
+ if shift_range is None:
545
+ hi = list(range(max(its_ref - itp_old_ref + self.min_event_gap, 0), itp_ref))
546
+ lo = list(range(-(sample.shape[0] - its_ref), -(max(its_old_ref - itp_ref + self.min_event_gap, 0))))
547
+ else:
548
+ lo_ = max(-(sample.shape[0] - its_ref), shift_range[0])
549
+ hi_ = min(itp_ref, shift_range[1])
550
+ hi = list(range(max(its_ref - itp_old_ref + self.min_event_gap, 0), hi_))
551
+ lo = list(range(lo_, -(max(its_old_ref - itp_ref + self.min_event_gap, 0))))
552
+ if len(hi + lo) > 0:
553
+ shift = np.random.choice(hi + lo)
554
+ else:
555
+ shift = 0
556
+
557
+ shifted_sample = np.zeros_like(sample)
558
+ if shift > 0:
559
+ shifted_sample[:-shift, ...] = sample[shift:, ...]
560
+ elif shift < 0:
561
+ shifted_sample[-shift:, ...] = sample[:shift, ...]
562
+ else:
563
+ shifted_sample[...] = sample[...]
564
+
565
+ return shifted_sample, shift_pick(itp, shift), shift_pick(its, shift), shift
566
+
567
+ def stack_events(self, sample_old, itp_old, its_old, shift_range=None, mask_old=None):
568
+
569
+ i = np.random.randint(self.num_data)
570
+ base_name = self.data_list[i]
571
+ if self.format == "numpy":
572
+ meta = self.read_numpy(os.path.join(self.data_dir, base_name))
573
+ elif self.format == "hdf5":
574
+ meta = self.read_hdf5(base_name)
575
+ if meta == -1:
576
+ return sample_old, itp_old, its_old
577
+
578
+ sample = np.copy(meta["data"])
579
+ itp = meta["itp"]
580
+ its = meta["its"]
581
+ if mask_old is not None:
582
+ mask = np.copy(meta["mask"])
583
+ sample = normalize(sample)
584
+ sample, itp, its, shift = self.random_shift(sample, itp, its, itp_old, its_old, shift_range)
585
+
586
+ if shift != 0:
587
+ sample_old += sample
588
+ # itp_old = [np.hstack([i, j]) for i,j in zip(itp_old, itp)]
589
+ # its_old = [np.hstack([i, j]) for i,j in zip(its_old, its)]
590
+ itp_old = [i + j for i, j in zip(itp_old, itp)]
591
+ its_old = [i + j for i, j in zip(its_old, its)]
592
+ if mask_old is not None:
593
+ mask_old = mask_old * mask
594
+
595
+ return sample_old, itp_old, its_old, mask_old
596
+
597
+ def cut_window(self, sample, target, itp, its, select_range):
598
+ shift_pick = lambda x, shift: [[i - shift for i in trace] for trace in x]
599
+ sample = sample[select_range[0] : select_range[1]]
600
+ target = target[select_range[0] : select_range[1]]
601
+ return (sample, target, shift_pick(itp, select_range[0]), shift_pick(its, select_range[0]))
602
+
603
+
604
+ class DataReader_train(DataReader):
605
+ def __init__(self, format="numpy", config=DataConfig(), **kwargs):
606
+
607
+ super().__init__(format=format, config=config, **kwargs)
608
+
609
+ self.min_event_gap = config.min_event_gap
610
+ self.buffer_channels = {}
611
+ self.shift_range = [-2000 + self.label_width * 2, 1000 - self.label_width * 2]
612
+ self.select_range = [5000, 8000]
613
+
614
+ def __getitem__(self, i):
615
+
616
+ base_name = self.data_list[i]
617
+ if self.format == "numpy":
618
+ meta = self.read_numpy(os.path.join(self.data_dir, base_name))
619
+ elif self.format == "hdf5":
620
+ meta = self.read_hdf5(base_name)
621
+ if meta is None:
622
+ return (np.zeros(self.X_shape, dtype=self.dtype), np.zeros(self.Y_shape, dtype=self.dtype), base_name)
623
+
624
+ sample = np.copy(meta["data"])
625
+ itp_list = meta["itp"]
626
+ its_list = meta["its"]
627
+
628
+ sample = normalize(sample)
629
+ if np.random.random() < 0.95:
630
+ sample, itp_list, its_list, _ = self.random_shift(sample, itp_list, its_list, shift_range=self.shift_range)
631
+ sample, itp_list, its_list, _ = self.stack_events(sample, itp_list, its_list, shift_range=self.shift_range)
632
+ target = self.generate_label(sample, [itp_list, its_list])
633
+ sample, target, itp_list, its_list = self.cut_window(sample, target, itp_list, its_list, self.select_range)
634
+ else:
635
+ ## noise
636
+ assert self.X_shape[0] <= min(min(itp_list))
637
+ sample = sample[: self.X_shape[0], ...]
638
+ target = np.zeros(self.Y_shape).astype(self.dtype)
639
+ itp_list = [[]]
640
+ its_list = [[]]
641
+
642
+ sample = normalize(sample)
643
+ return (sample.astype(self.dtype), target.astype(self.dtype), base_name)
644
+
645
+ def dataset(self, batch_size, num_parallel_calls=2, shuffle=True, drop_remainder=True):
646
+ dataset = dataset_map(
647
+ self,
648
+ output_types=(self.dtype, self.dtype, "string"),
649
+ output_shapes=(self.X_shape, self.Y_shape, None),
650
+ num_parallel_calls=num_parallel_calls,
651
+ shuffle=shuffle,
652
+ )
653
+ dataset = dataset.batch(batch_size, drop_remainder=drop_remainder).prefetch(batch_size * 2)
654
+ return dataset
655
+
656
+
657
+ class DataReader_test(DataReader):
658
+ def __init__(self, format="numpy", config=DataConfig(), **kwargs):
659
+
660
+ super().__init__(format=format, config=config, **kwargs)
661
+
662
+ self.select_range = [5000, 8000]
663
+
664
+ def __getitem__(self, i):
665
+
666
+ base_name = self.data_list[i]
667
+ if self.format == "numpy":
668
+ meta = self.read_numpy(os.path.join(self.data_dir, base_name))
669
+ elif self.format == "hdf5":
670
+ meta = self.read_hdf5(base_name)
671
+ if meta == -1:
672
+ return (np.zeros(self.Y_shape, dtype=self.dtype), np.zeros(self.X_shape, dtype=self.dtype), base_name)
673
+
674
+ sample = np.copy(meta["data"])
675
+ itp_list = meta["itp"]
676
+ its_list = meta["its"]
677
+
678
+ # sample, itp_list, its_list, _ = self.random_shift(sample, itp_list, its_list, shift_range=self.shift_range)
679
+ target = self.generate_label(sample, [itp_list, its_list])
680
+ sample, target, itp_list, its_list = self.cut_window(sample, target, itp_list, its_list, self.select_range)
681
+
682
+ sample = normalize(sample)
683
+ return (sample, target, base_name, itp_list, its_list)
684
+
685
+ def dataset(self, batch_size, num_parallel_calls=2, shuffle=False, drop_remainder=False):
686
+ dataset = dataset_map(
687
+ self,
688
+ output_types=(self.dtype, self.dtype, "string", "int64", "int64"),
689
+ output_shapes=(self.X_shape, self.Y_shape, None, None, None),
690
+ num_parallel_calls=num_parallel_calls,
691
+ shuffle=shuffle,
692
+ )
693
+ dataset = dataset.batch(batch_size, drop_remainder=drop_remainder).prefetch(batch_size * 2)
694
+ return dataset
695
+
696
+
697
+ class DataReader_pred(DataReader):
698
+ def __init__(self, format="numpy", amplitude=True, config=DataConfig(), **kwargs):
699
+
700
+ super().__init__(format=format, config=config, **kwargs)
701
+
702
+ self.amplitude = amplitude
703
+ self.X_shape = self.get_data_shape()
704
+
705
+ def get_data_shape(self):
706
+ base_name = self.data_list[0]
707
+ if self.format == "numpy":
708
+ meta = self.read_numpy(os.path.join(self.data_dir, base_name))
709
+ elif self.format == "mseed":
710
+ meta = self.read_mseed(os.path.join(self.data_dir, base_name))
711
+ elif self.format == "sac":
712
+ meta = self.read_sac(os.path.join(self.data_dir, base_name))
713
+ elif self.format == "hdf5":
714
+ meta = self.read_hdf5(base_name)
715
+ return meta["data"].shape
716
+
717
+ def adjust_missingchannels(self, data):
718
+ tmp = np.max(np.abs(data), axis=0, keepdims=True)
719
+ assert tmp.shape[-1] == data.shape[-1]
720
+ if np.count_nonzero(tmp) > 0:
721
+ data *= data.shape[-1] / np.count_nonzero(tmp)
722
+ return data
723
+
724
+ def __getitem__(self, i):
725
+
726
+ base_name = self.data_list[i]
727
+
728
+ if self.format == "numpy":
729
+ meta = self.read_numpy(os.path.join(self.data_dir, base_name))
730
+ elif self.format == "mseed":
731
+ meta = self.read_mseed(os.path.join(self.data_dir, base_name))
732
+ elif self.format == "sac":
733
+ meta = self.read_sac(os.path.join(self.data_dir, base_name))
734
+ elif self.format == "hdf5":
735
+ meta = self.read_hdf5(base_name)
736
+ else:
737
+ raise ValueError(f"Unsupported format: {self.format}!")
738
+ if meta == -1:
739
+ return (np.zeros(self.X_shape, dtype=self.dtype), base_name)
740
+
741
+ raw_amp = np.zeros(self.X_shape, dtype=self.dtype)
742
+ raw_amp[: meta["data"].shape[0], ...] = meta["data"][: self.X_shape[0], ...]
743
+ sample = np.zeros(self.X_shape, dtype=self.dtype)
744
+ sample[: meta["data"].shape[0], ...] = normalize_long(meta["data"])[: self.X_shape[0], ...]
745
+ if abs(meta["data"].shape[0] - self.X_shape[0]) > 1:
746
+ logging.warning(f"Data length mismatch in {base_name}: {meta['data'].shape[0]} != {self.X_shape[0]}")
747
+
748
+ if "t0" in meta:
749
+ t0 = meta["t0"]
750
+ else:
751
+ t0 = "1970-01-01T00:00:00.000"
752
+
753
+ if "station_id" in meta:
754
+ station_id = meta["station_id"].split("/")[-1].rstrip("*")
755
+ else:
756
+ # station_id = base_name.split("/")[-1].rstrip("*")
757
+ station_id = os.path.basename(base_name).rstrip("*")
758
+
759
+ if np.isnan(sample).any() or np.isinf(sample).any():
760
+ logging.warning(f"Data error: Nan or Inf found in {base_name}")
761
+ sample[np.isnan(sample)] = 0
762
+ sample[np.isinf(sample)] = 0
763
+
764
+ # sample = self.adjust_missingchannels(sample)
765
+ if self.amplitude:
766
+ return (sample[: self.X_shape[0], ...], raw_amp[: self.X_shape[0], ...], base_name, t0, station_id)
767
+ else:
768
+ return (sample[: self.X_shape[0], ...], base_name, t0, station_id)
769
+
770
+ def dataset(self, batch_size, num_parallel_calls=2, shuffle=False, drop_remainder=False):
771
+ if self.amplitude:
772
+ dataset = dataset_map(
773
+ self,
774
+ output_types=(self.dtype, self.dtype, "string", "string", "string"),
775
+ output_shapes=(self.X_shape, self.X_shape, None, None, None),
776
+ num_parallel_calls=num_parallel_calls,
777
+ shuffle=shuffle,
778
+ )
779
+ else:
780
+ dataset = dataset_map(
781
+ self,
782
+ output_types=(self.dtype, "string", "string", "string"),
783
+ output_shapes=(self.X_shape, None, None, None),
784
+ num_parallel_calls=num_parallel_calls,
785
+ shuffle=shuffle,
786
+ )
787
+ dataset = dataset.batch(batch_size, drop_remainder=drop_remainder).prefetch(batch_size * 2)
788
+ return dataset
789
+
790
+
791
+ class DataReader_mseed_array(DataReader):
792
+ def __init__(self, stations, amplitude=True, remove_resp=True, config=DataConfig(), **kwargs):
793
+
794
+ super().__init__(format="mseed", config=config, **kwargs)
795
+
796
+ # self.stations = pd.read_json(stations)
797
+ with open(stations, "r") as f:
798
+ self.stations = json.load(f)
799
+ print(pd.DataFrame.from_dict(self.stations, orient="index").to_string())
800
+
801
+ self.amplitude = amplitude
802
+ self.remove_resp = remove_resp
803
+ self.X_shape = self.get_data_shape()
804
+
805
+ def get_data_shape(self):
806
+ fname = os.path.join(self.data_dir, self.data_list[0])
807
+ meta = self.read_mseed_array(fname, self.stations, self.amplitude, self.remove_resp)
808
+ return meta["data"].shape
809
+
810
+ def __getitem__(self, i):
811
+
812
+ fp = os.path.join(self.data_dir, self.data_list[i])
813
+ # try:
814
+ meta = self.read_mseed_array(fp, self.stations, self.amplitude, self.remove_resp)
815
+ # except Exception as e:
816
+ # logging.error(f"Failed reading {fp}: {e}")
817
+ # if self.amplitude:
818
+ # return (np.zeros(self.X_shape).astype(self.dtype), np.zeros(self.X_shape).astype(self.dtype),
819
+ # [self.stations.iloc[i]["station"] for i in range(len(self.stations))], ["0" for i in range(len(self.stations))])
820
+ # else:
821
+ # return (np.zeros(self.X_shape).astype(self.dtype), ["" for i in range(len(self.stations))],
822
+ # [self.stations.iloc[i]["station"] for i in range(len(self.stations))])
823
+
824
+ sample = np.zeros([len(meta["data"]), *self.X_shape[1:]], dtype=self.dtype)
825
+ sample[:, : meta["data"].shape[1], :, :] = normalize_batch(meta["data"])[:, : self.X_shape[1], :, :]
826
+ if np.isnan(sample).any() or np.isinf(sample).any():
827
+ logging.warning(f"Data error: Nan or Inf found in {fp}")
828
+ sample[np.isnan(sample)] = 0
829
+ sample[np.isinf(sample)] = 0
830
+ t0 = meta["t0"]
831
+ base_name = meta["fname"]
832
+ station_id = meta["station_id"]
833
+ # base_name = [self.stations.iloc[i]["station"]+"."+t0[i] for i in range(len(self.stations))]
834
+ # base_name = [self.stations.iloc[i]["station"] for i in range(len(self.stations))]
835
+
836
+ if self.amplitude:
837
+ raw_amp = np.zeros([len(meta["raw_amp"]), *self.X_shape[1:]], dtype=self.dtype)
838
+ raw_amp[:, : meta["raw_amp"].shape[1], :, :] = meta["raw_amp"][:, : self.X_shape[1], :, :]
839
+ if np.isnan(raw_amp).any() or np.isinf(raw_amp).any():
840
+ logging.warning(f"Data error: Nan or Inf found in {fp}")
841
+ raw_amp[np.isnan(raw_amp)] = 0
842
+ raw_amp[np.isinf(raw_amp)] = 0
843
+ return (sample, raw_amp, base_name, t0, station_id)
844
+ else:
845
+ return (sample, base_name, t0, station_id)
846
+
847
+ def dataset(self, num_parallel_calls=1, shuffle=False):
848
+ if self.amplitude:
849
+ dataset = dataset_map(
850
+ self,
851
+ output_types=(self.dtype, self.dtype, "string", "string", "string"),
852
+ output_shapes=([None, *self.X_shape[1:]], [None, *self.X_shape[1:]], None, None, None),
853
+ num_parallel_calls=num_parallel_calls,
854
+ )
855
+ else:
856
+ dataset = dataset_map(
857
+ self,
858
+ output_types=(self.dtype, "string", "string", "string"),
859
+ output_shapes=([None, *self.X_shape[1:]], None, None, None),
860
+ num_parallel_calls=num_parallel_calls,
861
+ )
862
+ dataset = dataset.prefetch(1)
863
+ # dataset = dataset.prefetch(len(self.stations)*2)
864
+ return dataset
865
+
866
+
867
+ ###### test ########
868
+
869
+
870
+ def test_DataReader():
871
+ import os
872
+ import timeit
873
+
874
+ import matplotlib.pyplot as plt
875
+
876
+ if not os.path.exists("test_figures"):
877
+ os.mkdir("test_figures")
878
+
879
+ def plot_sample(sample, fname, label=None):
880
+ plt.clf()
881
+ plt.subplot(211)
882
+ plt.plot(sample[:, 0, -1])
883
+ if label is not None:
884
+ plt.subplot(212)
885
+ plt.plot(label[:, 0, 0])
886
+ plt.plot(label[:, 0, 1])
887
+ plt.plot(label[:, 0, 2])
888
+ plt.savefig(f"test_figures/{fname.decode()}.png")
889
+
890
+ def read(data_reader, batch=1):
891
+ start_time = timeit.default_timer()
892
+ if batch is None:
893
+ dataset = data_reader.dataset(shuffle=False)
894
+ else:
895
+ dataset = data_reader.dataset(1, shuffle=False)
896
+ sess = tf.compat.v1.Session()
897
+
898
+ print(len(data_reader))
899
+ print("-------", tf.data.Dataset.cardinality(dataset))
900
+ num = 0
901
+ x = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()
902
+ while True:
903
+ num += 1
904
+ # print(num)
905
+ try:
906
+ out = sess.run(x)
907
+ if len(out) == 2:
908
+ sample, fname = out[0], out[1]
909
+ for i in range(len(sample)):
910
+ plot_sample(sample[i], fname[i])
911
+ else:
912
+ sample, label, fname = out[0], out[1], out[2]
913
+ for i in range(len(sample)):
914
+ plot_sample(sample[i], fname[i], label[i])
915
+ except tf.errors.OutOfRangeError:
916
+ break
917
+ print("End of dataset")
918
+ print("Tensorflow Dataset:\nexecution time = ", timeit.default_timer() - start_time)
919
+
920
+ data_reader = DataReader_train(data_list="test_data/selected_phases.csv", data_dir="test_data/data/")
921
+
922
+ read(data_reader)
923
+
924
+ data_reader = DataReader_train(format="hdf5", hdf5_file="test_data/data.h5", hdf5_group="data")
925
+
926
+ read(data_reader)
927
+
928
+ data_reader = DataReader_test(data_list="test_data/selected_phases.csv", data_dir="test_data/data/")
929
+
930
+ read(data_reader)
931
+
932
+ data_reader = DataReader_test(format="hdf5", hdf5_file="test_data/data.h5", hdf5_group="data")
933
+
934
+ read(data_reader)
935
+
936
+ data_reader = DataReader_pred(format="numpy", data_list="test_data/selected_phases.csv", data_dir="test_data/data/")
937
+
938
+ read(data_reader)
939
+
940
+ data_reader = DataReader_pred(
941
+ format="mseed", data_list="test_data/mseed_station.csv", data_dir="test_data/waveforms/"
942
+ )
943
+
944
+ read(data_reader)
945
+
946
+ data_reader = DataReader_pred(
947
+ format="mseed", amplitude=True, data_list="test_data/mseed_station.csv", data_dir="test_data/waveforms/"
948
+ )
949
+
950
+ read(data_reader)
951
+
952
+ data_reader = DataReader_mseed_array(
953
+ data_list="test_data/mseed.csv",
954
+ data_dir="test_data/waveforms/",
955
+ stations="test_data/stations.csv",
956
+ remove_resp=False,
957
+ )
958
+
959
+ read(data_reader, batch=None)
960
+
961
+
962
+ if __name__ == "__main__":
963
+
964
+ test_DataReader()
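A hypothetical usage sketch (not part of this commit) showing how DataReader_pred above could feed a graph-mode tf.data pipeline; the CSV and miniSEED paths are placeholders and assume a "fname" column listing the waveform files.

# Hypothetical usage of DataReader_pred (not part of this commit); paths are placeholders.
import tensorflow as tf
from data_reader import DataReader_pred

reader = DataReader_pred(
    format="mseed",
    amplitude=False,
    data_list="demo/fname.csv",   # CSV with an "fname" column
    data_dir="demo/mseed/",       # directory holding the listed miniSEED files
    highpass_filter=0.0,          # read_mseed expects this attribute to be set
)
dataset = reader.dataset(batch_size=1)
batch = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()
with tf.compat.v1.Session() as sess:
    sample, fname, t0, station_id = sess.run(batch)
    print(fname, t0, station_id, sample.shape)  # normalized window, shape (1, nt, 1, 3)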
phasenet/detect_peaks.py ADDED
@@ -0,0 +1,207 @@
1
+ """Detect peaks in data based on their amplitude and other features."""
2
+
3
+ from __future__ import division, print_function
4
+ import warnings
5
+ import numpy as np
6
+
7
+ __author__ = "Marcos Duarte, https://github.com/demotu"
8
+ __version__ = "1.0.6"
9
+ __license__ = "MIT"
10
+
11
+
12
+
13
+ def detect_peaks(x, mph=None, mpd=1, threshold=0, edge='rising',
14
+ kpsh=False, valley=False, show=False, ax=None, title=True):
15
+
16
+ """Detect peaks in data based on their amplitude and other features.
17
+
18
+ Parameters
19
+ ----------
20
+ x : 1D array_like
21
+ data.
22
+ mph : {None, number}, optional (default = None)
23
+ detect peaks that are greater than minimum peak height (if parameter
24
+ `valley` is False) or peaks that are smaller than maximum peak height
25
+ (if parameter `valley` is True).
26
+ mpd : positive integer, optional (default = 1)
27
+ detect peaks that are at least separated by minimum peak distance (in
28
+ number of data).
29
+ threshold : positive number, optional (default = 0)
30
+ detect peaks (valleys) that are greater (smaller) than `threshold`
31
+ in relation to their immediate neighbors.
32
+ edge : {None, 'rising', 'falling', 'both'}, optional (default = 'rising')
33
+ for a flat peak, keep only the rising edge ('rising'), only the
34
+ falling edge ('falling'), both edges ('both'), or don't detect a
35
+ flat peak (None).
36
+ kpsh : bool, optional (default = False)
37
+ keep peaks with same height even if they are closer than `mpd`.
38
+ valley : bool, optional (default = False)
39
+ if True (1), detect valleys (local minima) instead of peaks.
40
+ show : bool, optional (default = False)
41
+ if True (1), plot data in matplotlib figure.
42
+ ax : a matplotlib.axes.Axes instance, optional (default = None).
43
+ title : bool or string, optional (default = True)
44
+ if True, show standard title. If False or empty string, doesn't show
45
+ any title. If string, shows string as title.
46
+
47
+ Returns
48
+ -------
49
+ ind : 1D array_like
50
+ indices of the peaks in `x` (returned together with their amplitudes `x[ind]`).
51
+
52
+ Notes
53
+ -----
54
+ The detection of valleys instead of peaks is performed internally by simply
55
+ negating the data: `ind_valleys = detect_peaks(-x)`
56
+
57
+ The function can handle NaN's
58
+
59
+ See this IPython Notebook [1]_.
60
+
61
+ References
62
+ ----------
63
+ .. [1] http://nbviewer.ipython.org/github/demotu/BMC/blob/master/notebooks/DetectPeaks.ipynb
64
+
65
+ Examples
66
+ --------
67
+ >>> from detect_peaks import detect_peaks
68
+ >>> x = np.random.randn(100)
69
+ >>> x[60:81] = np.nan
70
+ >>> # detect all peaks and plot data
71
+ >>> ind = detect_peaks(x, show=True)
72
+ >>> print(ind)
73
+
74
+ >>> x = np.sin(2*np.pi*5*np.linspace(0, 1, 200)) + np.random.randn(200)/5
75
+ >>> # set minimum peak height = 0 and minimum peak distance = 20
76
+ >>> detect_peaks(x, mph=0, mpd=20, show=True)
77
+
78
+ >>> x = [0, 1, 0, 2, 0, 3, 0, 2, 0, 1, 0]
79
+ >>> # set minimum peak distance = 2
80
+ >>> detect_peaks(x, mpd=2, show=True)
81
+
82
+ >>> x = np.sin(2*np.pi*5*np.linspace(0, 1, 200)) + np.random.randn(200)/5
83
+ >>> # detection of valleys instead of peaks
84
+ >>> detect_peaks(x, mph=-1.2, mpd=20, valley=True, show=True)
85
+
86
+ >>> x = [0, 1, 1, 0, 1, 1, 0]
87
+ >>> # detect both edges
88
+ >>> detect_peaks(x, edge='both', show=True)
89
+
90
+ >>> x = [-2, 1, -2, 2, 1, 1, 3, 0]
91
+ >>> # set threshold = 2
92
+ >>> detect_peaks(x, threshold = 2, show=True)
93
+
94
+ >>> x = [-2, 1, -2, 2, 1, 1, 3, 0]
95
+ >>> fig, axs = plt.subplots(ncols=2, nrows=1, figsize=(10, 4))
96
+ >>> detect_peaks(x, show=True, ax=axs[0], threshold=0.5, title=False)
97
+ >>> detect_peaks(x, show=True, ax=axs[1], threshold=1.5, title=False)
98
+
99
+ Version history
100
+ ---------------
101
+ '1.0.6':
102
+ Fix issue of when specifying ax object only the first plot was shown
103
+ Added a parameter to choose whether a title is shown and to pass a custom title
104
+ '1.0.5':
105
+ The sign of `mph` is inverted if parameter `valley` is True
106
+
107
+ """
108
+
109
+ x = np.atleast_1d(x).astype('float64')
110
+ if x.size < 3:
111
+ return np.array([], dtype=int), np.array([])  # keep the (indices, heights) tuple shape for short inputs
112
+ if valley:
113
+ x = -x
114
+ if mph is not None:
115
+ mph = -mph
116
+ # find indices of all peaks
117
+ dx = x[1:] - x[:-1]
118
+ # handle NaN's
119
+ indnan = np.where(np.isnan(x))[0]
120
+ if indnan.size:
121
+ x[indnan] = np.inf
122
+ dx[np.where(np.isnan(dx))[0]] = np.inf
123
+ ine, ire, ife = np.array([[], [], []], dtype=int)
124
+ if not edge:
125
+ ine = np.where((np.hstack((dx, 0)) < 0) & (np.hstack((0, dx)) > 0))[0]
126
+ else:
127
+ if edge.lower() in ['rising', 'both']:
128
+ ire = np.where((np.hstack((dx, 0)) <= 0) & (np.hstack((0, dx)) > 0))[0]
129
+ if edge.lower() in ['falling', 'both']:
130
+ ife = np.where((np.hstack((dx, 0)) < 0) & (np.hstack((0, dx)) >= 0))[0]
131
+ ind = np.unique(np.hstack((ine, ire, ife)))
132
+ # handle NaN's
133
+ if ind.size and indnan.size:
134
+ # NaN's and values close to NaN's cannot be peaks
135
+ ind = ind[np.in1d(ind, np.unique(np.hstack((indnan, indnan-1, indnan+1))), invert=True)]
136
+ # first and last values of x cannot be peaks
137
+ if ind.size and ind[0] == 0:
138
+ ind = ind[1:]
139
+ if ind.size and ind[-1] == x.size-1:
140
+ ind = ind[:-1]
141
+ # remove peaks < minimum peak height
142
+ if ind.size and mph is not None:
143
+ ind = ind[x[ind] >= mph]
144
+ # remove peaks - neighbors < threshold
145
+ if ind.size and threshold > 0:
146
+ dx = np.min(np.vstack([x[ind]-x[ind-1], x[ind]-x[ind+1]]), axis=0)
147
+ ind = np.delete(ind, np.where(dx < threshold)[0])
148
+ # detect small peaks closer than minimum peak distance
149
+ if ind.size and mpd > 1:
150
+ ind = ind[np.argsort(x[ind])][::-1] # sort ind by peak height
151
+ idel = np.zeros(ind.size, dtype=bool)
152
+ for i in range(ind.size):
153
+ if not idel[i]:
154
+ # keep peaks with the same height if kpsh is True
155
+ idel = idel | (ind >= ind[i] - mpd) & (ind <= ind[i] + mpd) \
156
+ & (x[ind[i]] > x[ind] if kpsh else True)
157
+ idel[i] = 0 # Keep current peak
158
+ # remove the small peaks and sort back the indices by their occurrence
159
+ ind = np.sort(ind[~idel])
160
+
161
+ if show:
162
+ if indnan.size:
163
+ x[indnan] = np.nan
164
+ if valley:
165
+ x = -x
166
+ if mph is not None:
167
+ mph = -mph
168
+ _plot(x, mph, mpd, threshold, edge, valley, ax, ind, title)
169
+
170
+ return ind, x[ind]
171
+
172
+
173
+ def _plot(x, mph, mpd, threshold, edge, valley, ax, ind, title):
174
+ """Plot results of the detect_peaks function, see its help."""
175
+ try:
176
+ import matplotlib.pyplot as plt
177
+ except ImportError:
178
+ print('matplotlib is not available.')
179
+ else:
180
+ if ax is None:
181
+ _, ax = plt.subplots(1, 1, figsize=(8, 4))
182
+ no_ax = True
183
+ else:
184
+ no_ax = False
185
+
186
+ ax.plot(x, 'b', lw=1)
187
+ if ind.size:
188
+ label = 'valley' if valley else 'peak'
189
+ label = label + 's' if ind.size > 1 else label
190
+ ax.plot(ind, x[ind], '+', mfc=None, mec='r', mew=2, ms=8,
191
+ label='%d %s' % (ind.size, label))
192
+ ax.legend(loc='best', framealpha=.5, numpoints=1)
193
+ ax.set_xlim(-.02*x.size, x.size*1.02-1)
194
+ ymin, ymax = x[np.isfinite(x)].min(), x[np.isfinite(x)].max()
195
+ yrange = ymax - ymin if ymax > ymin else 1
196
+ ax.set_ylim(ymin - 0.1*yrange, ymax + 0.1*yrange)
197
+ ax.set_xlabel('Data #', fontsize=14)
198
+ ax.set_ylabel('Amplitude', fontsize=14)
199
+ if title:
200
+ if not isinstance(title, str):
201
+ mode = 'Valley detection' if valley else 'Peak detection'
202
+ title = "%s (mph=%s, mpd=%d, threshold=%s, edge='%s')"% \
203
+ (mode, str(mph), mpd, str(threshold), edge)
204
+ ax.set_title(title)
205
+ # plt.grid()
206
+ if no_ax:
207
+ plt.show()
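
Note: unlike the upstream detect_peaks, this copy returns a tuple of peak indices and peak heights, which is how postprocess.extract_picks consumes it. A minimal standalone sketch (synthetic probability trace; the `phasenet.detect_peaks` import path assumes the package layout added in this commit):

    import numpy as np
    from phasenet.detect_peaks import detect_peaks

    prob = np.zeros(3000, dtype="float32")
    prob[[500, 1200]] = [0.9, 0.6]                    # two synthetic probability spikes
    idx, score = detect_peaks(prob, mph=0.3, mpd=50)  # -> [500, 1200] and [0.9, 0.6]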
phasenet/model.py ADDED
@@ -0,0 +1,489 @@
1
+ import tensorflow as tf
2
+ tf.compat.v1.disable_eager_execution()
3
+ import numpy as np
4
+ import logging
5
+ import warnings
6
+ warnings.filterwarnings('ignore', category=UserWarning)
7
+
8
+ class ModelConfig:
9
+
10
+ batch_size = 20
11
+ depths = 5
12
+ filters_root = 8
13
+ kernel_size = [7, 1]
14
+ pool_size = [4, 1]
15
+ dilation_rate = [1, 1]
16
+ class_weights = [1.0, 1.0, 1.0]
17
+ loss_type = "cross_entropy"
18
+ weight_decay = 0.0
19
+ optimizer = "adam"
20
+ momentum = 0.9
21
+ learning_rate = 0.01
22
+ decay_step = 1e9
23
+ decay_rate = 0.9
24
+ drop_rate = 0.0
25
+ summary = True
26
+
27
+ X_shape = [3000, 1, 3]
28
+ n_channel = X_shape[-1]
29
+ Y_shape = [3000, 1, 3]
30
+ n_class = Y_shape[-1]
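+ # shape convention: [time samples, dummy station dim, channels]; X holds 3-component waveforms (typically E/N/Z), Y holds 3 classes (noise, P, S)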
31
+
32
+ def __init__(self, **kwargs):
33
+ for k,v in kwargs.items():
34
+ setattr(self, k, v)
35
+
36
+ def update_args(self, args):
37
+ for k,v in vars(args).items():
38
+ setattr(self, k, v)
39
+
40
+
41
+ def crop_and_concat(net1, net2):
42
+ """
43
+ Center-crop net2 to net1's spatial size and concatenate along the channel axis; assumes size(net1) <= size(net2).
44
+ """
45
+ # net1_shape = net1.get_shape().as_list()
46
+ # net2_shape = net2.get_shape().as_list()
47
+ # # print(net1_shape)
48
+ # # print(net2_shape)
49
+ # # if net2_shape[1] >= net1_shape[1] and net2_shape[2] >= net1_shape[2]:
50
+ # offsets = [0, (net2_shape[1] - net1_shape[1]) // 2, (net2_shape[2] - net1_shape[2]) // 2, 0]
51
+ # size = [-1, net1_shape[1], net1_shape[2], -1]
52
+ # net2_resize = tf.slice(net2, offsets, size)
53
+ # return tf.concat([net1, net2_resize], 3)
54
+
55
+ ## dynamic shape
56
+ chn1 = net1.get_shape().as_list()[-1]
57
+ chn2 = net2.get_shape().as_list()[-1]
58
+ net1_shape = tf.shape(net1)
59
+ net2_shape = tf.shape(net2)
60
+ # print(net1_shape)
61
+ # print(net2_shape)
62
+ # if net2_shape[1] >= net1_shape[1] and net2_shape[2] >= net1_shape[2]:
63
+ offsets = [0, (net2_shape[1] - net1_shape[1]) // 2, (net2_shape[2] - net1_shape[2]) // 2, 0]
64
+ size = [-1, net1_shape[1], net1_shape[2], -1]
65
+ net2_resize = tf.slice(net2, offsets, size)
66
+
67
+ out = tf.concat([net1, net2_resize], 3)
68
+ out.set_shape([None, None, None, chn1+chn2])
69
+
70
+ return out
71
+
72
+ # else:
73
+ # offsets = [0, (net1_shape[1] - net2_shape[1]) // 2, (net1_shape[2] - net2_shape[2]) // 2, 0]
74
+ # size = [-1, net2_shape[1], net2_shape[2], -1]
75
+ # net1_resize = tf.slice(net1, offsets, size)
76
+ # return tf.concat([net1_resize, net2], 3)
77
+
78
+
79
+ def crop_only(net1, net2):
80
+ """
81
+ Center-crop net2 to net1's spatial size (no concatenation); assumes size(net1) <= size(net2).
82
+ """
83
+ net1_shape = net1.get_shape().as_list()
84
+ net2_shape = net2.get_shape().as_list()
85
+ # print(net1_shape)
86
+ # print(net2_shape)
87
+ # if net2_shape[1] >= net1_shape[1] and net2_shape[2] >= net1_shape[2]:
88
+ offsets = [0, (net2_shape[1] - net1_shape[1]) // 2, (net2_shape[2] - net1_shape[2]) // 2, 0]
89
+ size = [-1, net1_shape[1], net1_shape[2], -1]
90
+ net2_resize = tf.slice(net2, offsets, size)
91
+ #return tf.concat([net1, net2_resize], 3)
92
+ return net2_resize
93
+
94
+ class UNet:
95
+ def __init__(self, config=ModelConfig(), input_batch=None, mode='train'):
96
+ self.depths = config.depths
97
+ self.filters_root = config.filters_root
98
+ self.kernel_size = config.kernel_size
99
+ self.dilation_rate = config.dilation_rate
100
+ self.pool_size = config.pool_size
101
+ self.X_shape = config.X_shape
102
+ self.Y_shape = config.Y_shape
103
+ self.n_channel = config.n_channel
104
+ self.n_class = config.n_class
105
+ self.class_weights = config.class_weights
106
+ self.batch_size = config.batch_size
107
+ self.loss_type = config.loss_type
108
+ self.weight_decay = config.weight_decay
109
+ self.optimizer = config.optimizer
110
+ self.learning_rate = config.learning_rate
111
+ self.decay_step = config.decay_step
112
+ self.decay_rate = config.decay_rate
113
+ self.momentum = config.momentum
114
+ self.global_step = tf.compat.v1.get_variable(name="global_step", initializer=0, dtype=tf.int32)
115
+ self.summary_train = []
116
+ self.summary_valid = []
117
+
118
+ self.build(input_batch, mode=mode)
119
+
120
+ def add_placeholders(self, input_batch=None, mode="train"):
121
+ if input_batch is None:
122
+ # self.X = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None, self.X_shape[-3], self.X_shape[-2], self.X_shape[-1]], name='X')
123
+ # self.Y = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None, self.Y_shape[-3], self.Y_shape[-2], self.n_class], name='y')
124
+ self.X = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None, None, None, self.X_shape[-1]], name='X')
125
+ self.Y = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None, None, None, self.n_class], name='y')
126
+ else:
127
+ self.X = input_batch[0]
128
+ if mode in ["train", "valid", "test"]:
129
+ self.Y = input_batch[1]
130
+ self.input_batch = input_batch
131
+
132
+ self.is_training = tf.compat.v1.placeholder(dtype=tf.bool, name="is_training")
133
+ # self.keep_prob = tf.compat.v1.placeholder(dtype=tf.float32, name="keep_prob")
134
+ self.drop_rate = tf.compat.v1.placeholder(dtype=tf.float32, name="drop_rate")
135
+
136
+ def add_prediction_op(self):
137
+ logging.info("Model: depths {depths}, filters {filters}, "
138
+ "filter size {kernel_size[0]}x{kernel_size[1]}, "
139
+ "pool size: {pool_size[0]}x{pool_size[1]}, "
140
+ "dilation rate: {dilation_rate[0]}x{dilation_rate[1]}".format(
141
+ depths=self.depths,
142
+ filters=self.filters_root,
143
+ kernel_size=self.kernel_size,
144
+ dilation_rate=self.dilation_rate,
145
+ pool_size=self.pool_size))
146
+
147
+ if self.weight_decay > 0:
148
+ weight_decay = tf.constant(self.weight_decay, dtype=tf.float32, name="weight_constant")
149
+ self.regularizer = tf.keras.regularizers.l2(l=0.5 * (weight_decay))
150
+ else:
151
+ self.regularizer = None
152
+
153
+ self.initializer = tf.compat.v1.keras.initializers.VarianceScaling(scale=1.0, mode="fan_avg", distribution="uniform")
154
+
155
+ # down sample layers
156
+ convs = [None] * self.depths # store output of each depth
157
+
158
+ with tf.compat.v1.variable_scope("Input"):
159
+ net = self.X
160
+ net = tf.compat.v1.layers.conv2d(net,
161
+ filters=self.filters_root,
162
+ kernel_size=self.kernel_size,
163
+ activation=None,
164
+ padding='same',
165
+ dilation_rate=self.dilation_rate,
166
+ kernel_initializer=self.initializer,
167
+ kernel_regularizer=self.regularizer,
168
+ name="input_conv")
169
+ net = tf.compat.v1.layers.batch_normalization(net,
170
+ training=self.is_training,
171
+ name="input_bn")
172
+ net = tf.nn.relu(net,
173
+ name="input_relu")
174
+ # net = tf.nn.dropout(net, self.keep_prob)
175
+ net = tf.compat.v1.layers.dropout(net,
176
+ rate=self.drop_rate,
177
+ training=self.is_training,
178
+ name="input_dropout")
179
+
180
+
181
+ for depth in range(0, self.depths):
182
+ with tf.compat.v1.variable_scope("DownConv_%d" % depth):
183
+ filters = int(2**(depth) * self.filters_root)
184
+
185
+ net = tf.compat.v1.layers.conv2d(net,
186
+ filters=filters,
187
+ kernel_size=self.kernel_size,
188
+ activation=None,
189
+ use_bias=False,
190
+ padding='same',
191
+ dilation_rate=self.dilation_rate,
192
+ kernel_initializer=self.initializer,
193
+ kernel_regularizer=self.regularizer,
194
+ name="down_conv1_{}".format(depth + 1))
195
+ net = tf.compat.v1.layers.batch_normalization(net,
196
+ training=self.is_training,
197
+ name="down_bn1_{}".format(depth + 1))
198
+ net = tf.nn.relu(net,
199
+ name="down_relu1_{}".format(depth+1))
200
+ net = tf.compat.v1.layers.dropout(net,
201
+ rate=self.drop_rate,
202
+ training=self.is_training,
203
+ name="down_dropout1_{}".format(depth + 1))
204
+
205
+ convs[depth] = net
206
+
207
+ if depth < self.depths - 1:
208
+ net = tf.compat.v1.layers.conv2d(net,
209
+ filters=filters,
210
+ kernel_size=self.kernel_size,
211
+ strides=self.pool_size,
212
+ activation=None,
213
+ use_bias=False,
214
+ padding='same',
215
+ dilation_rate=self.dilation_rate,
216
+ kernel_initializer=self.initializer,
217
+ kernel_regularizer=self.regularizer,
218
+ name="down_conv3_{}".format(depth + 1))
219
+ net = tf.compat.v1.layers.batch_normalization(net,
220
+ training=self.is_training,
221
+ name="down_bn3_{}".format(depth + 1))
222
+ net = tf.nn.relu(net,
223
+ name="down_relu3_{}".format(depth+1))
224
+ net = tf.compat.v1.layers.dropout(net,
225
+ rate=self.drop_rate,
226
+ training=self.is_training,
227
+ name="down_dropout3_{}".format(depth + 1))
228
+
229
+
230
+ # up layers
231
+ for depth in range(self.depths - 2, -1, -1):
232
+ with tf.compat.v1.variable_scope("UpConv_%d" % depth):
233
+ filters = int(2**(depth) * self.filters_root)
234
+ net = tf.compat.v1.layers.conv2d_transpose(net,
235
+ filters=filters,
236
+ kernel_size=self.kernel_size,
237
+ strides=self.pool_size,
238
+ activation=None,
239
+ use_bias=False,
240
+ padding="same",
241
+ kernel_initializer=self.initializer,
242
+ kernel_regularizer=self.regularizer,
243
+ name="up_conv0_{}".format(depth+1))
244
+ net = tf.compat.v1.layers.batch_normalization(net,
245
+ training=self.is_training,
246
+ name="up_bn0_{}".format(depth + 1))
247
+ net = tf.nn.relu(net,
248
+ name="up_relu0_{}".format(depth+1))
249
+ net = tf.compat.v1.layers.dropout(net,
250
+ rate=self.drop_rate,
251
+ training=self.is_training,
252
+ name="up_dropout0_{}".format(depth + 1))
253
+
254
+
255
+ #skip connection
256
+ net = crop_and_concat(convs[depth], net)
257
+ #net = crop_only(convs[depth], net)
258
+
259
+ net = tf.compat.v1.layers.conv2d(net,
260
+ filters=filters,
261
+ kernel_size=self.kernel_size,
262
+ activation=None,
263
+ use_bias=False,
264
+ padding='same',
265
+ dilation_rate=self.dilation_rate,
266
+ kernel_initializer=self.initializer,
267
+ kernel_regularizer=self.regularizer,
268
+ name="up_conv1_{}".format(depth + 1))
269
+ net = tf.compat.v1.layers.batch_normalization(net,
270
+ training=self.is_training,
271
+ name="up_bn1_{}".format(depth + 1))
272
+ net = tf.nn.relu(net,
273
+ name="up_relu1_{}".format(depth + 1))
274
+ net = tf.compat.v1.layers.dropout(net,
275
+ rate=self.drop_rate,
276
+ training=self.is_training,
277
+ name="up_dropout1_{}".format(depth + 1))
278
+
279
+
280
+ # Output Map
281
+ with tf.compat.v1.variable_scope("Output"):
282
+ net = tf.compat.v1.layers.conv2d(net,
283
+ filters=self.n_class,
284
+ kernel_size=(1,1),
285
+ activation=None,
286
+ padding='same',
287
+ #dilation_rate=self.dilation_rate,
288
+ kernel_initializer=self.initializer,
289
+ kernel_regularizer=self.regularizer,
290
+ name="output_conv")
291
+ # net = tf.nn.relu(net,
292
+ # name="output_relu")
293
+ # net = tf.compat.v1.layers.dropout(net,
294
+ # rate=self.drop_rate,
295
+ # training=self.is_training,
296
+ # name="output_dropout")
297
+ # net = tf.compat.v1.layers.batch_normalization(net,
298
+ # training=self.is_training,
299
+ # name="output_bn")
300
+ output = net
301
+
302
+ with tf.compat.v1.variable_scope("representation"):
303
+ self.representation = convs[-1]
304
+
305
+ with tf.compat.v1.variable_scope("logits"):
306
+ self.logits = output
307
+ tmp = tf.compat.v1.summary.histogram("logits", self.logits)
308
+ self.summary_train.append(tmp)
309
+
310
+ with tf.compat.v1.variable_scope("preds"):
311
+ self.preds = tf.nn.softmax(output)
312
+ tmp = tf.compat.v1.summary.histogram("preds", self.preds)
313
+ self.summary_train.append(tmp)
314
+
315
+ def add_loss_op(self):
316
+ if self.loss_type == "cross_entropy":
317
+ with tf.compat.v1.variable_scope("cross_entropy"):
318
+ flat_logits = tf.reshape(self.logits, [-1, self.n_class], name="logits")
319
+ flat_labels = tf.reshape(self.Y, [-1, self.n_class], name="labels")
320
+ if (np.array(self.class_weights) != 1).any():
321
+ class_weights = tf.constant(np.array(self.class_weights, dtype=np.float32), name="class_weights")
322
+ weight_map = tf.multiply(flat_labels, class_weights)
323
+ weight_map = tf.reduce_sum(input_tensor=weight_map, axis=1)
324
+ loss_map = tf.nn.softmax_cross_entropy_with_logits(logits=flat_logits,
325
+ labels=flat_labels)
326
+
327
+ weighted_loss = tf.multiply(loss_map, weight_map)
328
+ loss = tf.reduce_mean(input_tensor=weighted_loss)
329
+ else:
330
+ loss = tf.reduce_mean(input_tensor=tf.nn.softmax_cross_entropy_with_logits(logits=flat_logits,
331
+ labels=flat_labels))
332
+
333
+ elif self.loss_type == "IOU":
334
+ with tf.compat.v1.variable_scope("IOU"):
335
+ eps = 1e-7
336
+ loss = 0
337
+ for i in range(1, self.n_class):
338
+ intersection = eps + tf.reduce_sum(input_tensor=self.preds[:,:,:,i] * self.Y[:,:,:,i], axis=[1,2])
339
+ union = eps + tf.reduce_sum(input_tensor=self.preds[:,:,:,i], axis=[1,2]) + tf.reduce_sum(input_tensor=self.Y[:,:,:,i], axis=[1,2])
340
+ loss += 1 - tf.reduce_mean(input_tensor=intersection / union)
341
+ elif self.loss_type == "mean_squared":
342
+ with tf.compat.v1.variable_scope("mean_squared"):
343
+ flat_logits = tf.reshape(self.logits, [-1, self.n_class], name="logits")
344
+ flat_labels = tf.reshape(self.Y, [-1, self.n_class], name="labels")
345
+ with tf.compat.v1.variable_scope("mean_squared"):
346
+ loss = tf.compat.v1.losses.mean_squared_error(labels=flat_labels, predictions=flat_logits)
347
+ else:
348
+ raise ValueError("Unknown loss function: %s" % self.loss_type)
349
+
350
+ tmp = tf.compat.v1.summary.scalar("train_loss", loss)
351
+ self.summary_train.append(tmp)
352
+ tmp = tf.compat.v1.summary.scalar("valid_loss", loss)
353
+ self.summary_valid.append(tmp)
354
+
355
+ if self.weight_decay > 0:
356
+ with tf.compat.v1.name_scope('weight_loss'):
357
+ tmp = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES)
358
+ weight_loss = tf.add_n(tmp, name="weight_loss")
359
+ self.loss = loss + weight_loss
360
+ else:
361
+ self.loss = loss
362
+
363
+ def add_training_op(self):
364
+ if self.optimizer == "momentum":
365
+ self.learning_rate_node = tf.compat.v1.train.exponential_decay(learning_rate=self.learning_rate,
366
+ global_step=self.global_step,
367
+ decay_steps=self.decay_step,
368
+ decay_rate=self.decay_rate,
369
+ staircase=True)
370
+ optimizer = tf.compat.v1.train.MomentumOptimizer(learning_rate=self.learning_rate_node,
371
+ momentum=self.momentum)
372
+ elif self.optimizer == "adam":
373
+ self.learning_rate_node = tf.compat.v1.train.exponential_decay(learning_rate=self.learning_rate,
374
+ global_step=self.global_step,
375
+ decay_steps=self.decay_step,
376
+ decay_rate=self.decay_rate,
377
+ staircase=True)
378
+
379
+ optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=self.learning_rate_node)
380
+ update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)
381
+ with tf.control_dependencies(update_ops):
382
+ self.train_op = optimizer.minimize(self.loss, global_step=self.global_step)
383
+ tmp = tf.compat.v1.summary.scalar("learning_rate", self.learning_rate_node)
384
+ self.summary_train.append(tmp)
385
+
386
+ def add_metrics_op(self):
387
+ with tf.compat.v1.variable_scope("metrics"):
388
+
389
+ Y= tf.argmax(input=self.Y, axis=-1)
390
+ confusion_matrix = tf.cast(tf.math.confusion_matrix(
391
+ labels=tf.reshape(Y, [-1]),
392
+ predictions=tf.reshape(tf.argmax(input=self.preds, axis=-1), [-1]),
393
+ num_classes=self.n_class, name='confusion_matrix'),
394
+ dtype=tf.float32)
395
+
396
+ # with tf.variable_scope("P"):
397
+ c = tf.constant(1e-7, dtype=tf.float32)
398
+ precision_P = (confusion_matrix[1,1] + c) / (tf.reduce_sum(input_tensor=confusion_matrix[:,1]) + c)
399
+ recall_P = (confusion_matrix[1,1] + c) / (tf.reduce_sum(input_tensor=confusion_matrix[1,:]) + c)
400
+ f1_P = 2 * precision_P * recall_P / (precision_P + recall_P)
401
+
402
+ tmp1 = tf.compat.v1.summary.scalar("train_precision_p", precision_P)
403
+ tmp2 = tf.compat.v1.summary.scalar("train_recall_p", recall_P)
404
+ tmp3 = tf.compat.v1.summary.scalar("train_f1_p", f1_P)
405
+ self.summary_train.extend([tmp1, tmp2, tmp3])
406
+
407
+ tmp1 = tf.compat.v1.summary.scalar("valid_precision_p", precision_P)
408
+ tmp2 = tf.compat.v1.summary.scalar("valid_recall_p", recall_P)
409
+ tmp3 = tf.compat.v1.summary.scalar("valid_f1_p", f1_P)
410
+ self.summary_valid.extend([tmp1, tmp2, tmp3])
411
+
412
+ # with tf.variable_scope("S"):
413
+ precision_S = (confusion_matrix[2,2] + c) / (tf.reduce_sum(input_tensor=confusion_matrix[:,2]) + c)
414
+ recall_S = (confusion_matrix[2,2] + c) / (tf.reduce_sum(input_tensor=confusion_matrix[2,:]) + c)
415
+ f1_S = 2 * precision_S * recall_S / (precision_S + recall_S)
416
+
417
+ tmp1 = tf.compat.v1.summary.scalar("train_precision_s", precision_S)
418
+ tmp2 = tf.compat.v1.summary.scalar("train_recall_s", recall_S)
419
+ tmp3 = tf.compat.v1.summary.scalar("train_f1_s", f1_S)
420
+ self.summary_train.extend([tmp1, tmp2, tmp3])
421
+
422
+ tmp1 = tf.compat.v1.summary.scalar("valid_precision_s", precision_S)
423
+ tmp2 = tf.compat.v1.summary.scalar("valid_recall_s", recall_S)
424
+ tmp3 = tf.compat.v1.summary.scalar("valid_f1_s", f1_S)
425
+ self.summary_valid.extend([tmp1, tmp2, tmp3])
426
+
427
+ self.precision = [precision_P, precision_S]
428
+ self.recall = [recall_P, recall_S]
429
+ self.f1 = [f1_P, f1_S]
430
+
431
+
432
+
433
+ def train_on_batch(self, sess, inputs_batch, labels_batch, summary_writer, drop_rate=0.0):
434
+ feed = {self.X: inputs_batch,
435
+ self.Y: labels_batch,
436
+ self.drop_rate: drop_rate,
437
+ self.is_training: True}
438
+
439
+ _, step_summary, step, loss = sess.run([self.train_op,
440
+ self.summary_train,
441
+ self.global_step,
442
+ self.loss],
443
+ feed_dict=feed)
444
+ summary_writer.add_summary(step_summary, step)
445
+ return loss
446
+
447
+ def valid_on_batch(self, sess, inputs_batch, labels_batch, summary_writer):
448
+ feed = {self.X: inputs_batch,
449
+ self.Y: labels_batch,
450
+ self.drop_rate: 0,
451
+ self.is_training: False}
452
+
453
+ step_summary, step, loss, preds = sess.run([self.summary_valid,
454
+ self.global_step,
455
+ self.loss,
456
+ self.preds],
457
+ feed_dict=feed)
458
+ summary_writer.add_summary(step_summary, step)
459
+ return loss, preds
460
+
461
+ def test_on_batch(self, sess, summary_writer):
462
+ feed = {self.drop_rate: 0,
463
+ self.is_training: False}
464
+ step_summary, step, loss, preds, \
465
+ X_batch, Y_batch, fname_batch, \
466
+ itp_batch, its_batch = sess.run([self.summary_valid,
467
+ self.global_step,
468
+ self.loss,
469
+ self.preds,
470
+ self.X,
471
+ self.Y,
472
+ self.input_batch[2],
473
+ self.input_batch[3],
474
+ self.input_batch[4]],
475
+ feed_dict=feed)
476
+ summary_writer.add_summary(step_summary, step)
477
+ return loss, preds, X_batch, Y_batch, fname_batch, itp_batch, its_batch
478
+
479
+
480
+ def build(self, input_batch=None, mode='train'):
481
+ self.add_placeholders(input_batch, mode)
482
+ self.add_prediction_op()
483
+ if mode in ["train", "valid", "test"]:
484
+ self.add_loss_op()
485
+ self.add_training_op()
486
+ # self.add_metrics_op()
487
+ self.summary_train = tf.compat.v1.summary.merge(self.summary_train)
488
+ self.summary_valid = tf.compat.v1.summary.merge(self.summary_valid)
489
+ return 0
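
Note: the network maps a window of 3-component waveforms to per-sample class probabilities of the same length (channels ordered noise, P, S). A minimal inference sketch under that assumption, using the checkpoint shipped in this commit (importing phasenet.model already disables eager execution; paths are relative to the repo root):

    import numpy as np
    import tensorflow as tf
    from phasenet.model import UNet

    model = UNet(mode="pred")  # builds placeholders model.X, model.drop_rate, model.is_training and the softmax output model.preds
    with tf.compat.v1.Session() as sess:
        saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
        saver.restore(sess, tf.train.latest_checkpoint("model/190703-214543"))
        x = np.random.randn(1, 3000, 1, 3).astype("float32")  # batch, time, dummy station dim, E/N/Z
        x = (x - x.mean(axis=1, keepdims=True)) / x.std(axis=1, keepdims=True)
        prob = sess.run(model.preds, feed_dict={model.X: x, model.drop_rate: 0, model.is_training: False})
        print(prob.shape)  # (1, 3000, 1, 3): per-sample probabilities of noise, P and S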
phasenet/postprocess.py ADDED
@@ -0,0 +1,383 @@
1
+ import json
2
+ import logging
3
+ import os
4
+ from collections import namedtuple
5
+ from datetime import datetime, timedelta
6
+
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+ from .detect_peaks import detect_peaks
10
+
11
+ # def extract_picks(preds, fnames=None, station_ids=None, t0=None, config=None):
12
+
13
+ # if preds.shape[-1] == 4:
14
+ # record = namedtuple("phase", ["fname", "station_id", "t0", "p_idx", "p_prob", "s_idx", "s_prob", "ps_idx", "ps_prob"])
15
+ # else:
16
+ # record = namedtuple("phase", ["fname", "station_id", "t0", "p_idx", "p_prob", "s_idx", "s_prob"])
17
+
18
+ # picks = []
19
+ # for i, pred in enumerate(preds):
20
+
21
+ # if config is None:
22
+ # mph_p, mph_s, mpd = 0.3, 0.3, 50
23
+ # else:
24
+ # mph_p, mph_s, mpd = config.min_p_prob, config.min_s_prob, config.mpd
25
+
26
+ # if (fnames is None):
27
+ # fname = f"{i:04d}"
28
+ # else:
29
+ # if isinstance(fnames[i], str):
30
+ # fname = fnames[i]
31
+ # else:
32
+ # fname = fnames[i].decode()
33
+
34
+ # if (station_ids is None):
35
+ # station_id = f"{i:04d}"
36
+ # else:
37
+ # if isinstance(station_ids[i], str):
38
+ # station_id = station_ids[i]
39
+ # else:
40
+ # station_id = station_ids[i].decode()
41
+
42
+ # if (t0 is None):
43
+ # start_time = "1970-01-01T00:00:00.000"
44
+ # else:
45
+ # if isinstance(t0[i], str):
46
+ # start_time = t0[i]
47
+ # else:
48
+ # start_time = t0[i].decode()
49
+
50
+ # p_idx, p_prob, s_idx, s_prob = [], [], [], []
51
+ # for j in range(pred.shape[1]):
52
+ # p_idx_, p_prob_ = detect_peaks(pred[:,j,1], mph=mph_p, mpd=mpd, show=False)
53
+ # s_idx_, s_prob_ = detect_peaks(pred[:,j,2], mph=mph_s, mpd=mpd, show=False)
54
+ # p_idx.append(list(p_idx_))
55
+ # p_prob.append(list(p_prob_))
56
+ # s_idx.append(list(s_idx_))
57
+ # s_prob.append(list(s_prob_))
58
+
59
+ # if pred.shape[-1] == 4:
60
+ # ps_idx, ps_prob = detect_peaks(pred[:,0,3], mph=0.3, mpd=mpd, show=False)
61
+ # picks.append(record(fname, station_id, start_time, list(p_idx), list(p_prob), list(s_idx), list(s_prob), list(ps_idx), list(ps_prob)))
62
+ # else:
63
+ # picks.append(record(fname, station_id, start_time, list(p_idx), list(p_prob), list(s_idx), list(s_prob)))
64
+
65
+ # return picks
66
+
67
+
68
+ def extract_picks(
69
+ preds,
70
+ file_names=None,
71
+ begin_times=None,
72
+ station_ids=None,
73
+ dt=0.01,
74
+ phases=["P", "S"],
75
+ config=None,
76
+ waveforms=None,
77
+ use_amplitude=False,
78
+ upload_waveform=False,
79
+ ):
80
+ """Extract picks from prediction results.
81
+ Args:
82
+ preds ([type]): [Nb, Nt, Ns, Nc] "batch, time, station, channel"
83
+ file_names ([type], optional): [Nb]. Defaults to None.
84
+ station_ids ([type], optional): [Ns]. Defaults to None.
85
+ begin_times ([type], optional): [Nb] ISO-format window start times. Defaults to None.
86
+ config ([type], optional): [description]. Defaults to None.
87
+
88
+ Returns:
89
+ picks: list of dicts with keys {file_name, station_id, begin_time, phase_index, phase_time, phase_score, phase_type, dt}
90
+ """
91
+
92
+ mph = {}
93
+ if config is None:
94
+ for x in phases:
95
+ mph[x] = 0.3
96
+ mpd = 50
97
+ ## upload waveform
98
+ pre_idx = int(1 / dt)
99
+ post_idx = int(4 / dt)
100
+ else:
101
+ mph["P"] = config.min_p_prob
102
+ mph["S"] = config.min_s_prob
103
+ mph["PS"] = 0.3
104
+ mpd = config.mpd
105
+ pre_idx = int(config.pre_sec / dt)
106
+ post_idx = int(config.post_sec / dt)
107
+
108
+ Nb, Nt, Ns, Nc = preds.shape
109
+
110
+ if file_names is None:
111
+ file_names = [f"{i:04d}" for i in range(Nb)]
112
+ elif not (isinstance(file_names, np.ndarray) or isinstance(file_names, list)):
113
+ if isinstance(file_names, bytes):
114
+ file_names = file_names.decode()
115
+ file_names = [file_names] * Nb
116
+ else:
117
+ file_names = [x.decode() if isinstance(x, bytes) else x for x in file_names]
118
+
119
+ if begin_times is None:
120
+ begin_times = ["1970-01-01T00:00:00.000+00:00"] * Nb
121
+ else:
122
+ begin_times = [x.decode() if isinstance(x, bytes) else x for x in begin_times]
123
+
124
+ picks = []
125
+ for i in range(Nb):
126
+
127
+ file_name = file_names[i]
128
+ begin_time = datetime.fromisoformat(begin_times[i])
129
+
130
+ for j in range(Ns):
131
+ if (station_ids is None) or (len(station_ids[i]) == 0):
132
+ station_id = f"{j:04d}"
133
+ else:
134
+ station_id = station_ids[i].decode() if isinstance(station_ids[i], bytes) else station_ids[i]
135
+
136
+ if (waveforms is not None) and use_amplitude:
137
+ amp = np.max(np.abs(waveforms[i, :, j, :]), axis=-1) ## amplitude over three channels
138
+ for k in range(Nc - 1): # 0-th channel noise
139
+ idxs, probs = detect_peaks(preds[i, :, j, k + 1], mph=mph[phases[k]], mpd=mpd, show=False)
140
+ for l, (phase_index, phase_prob) in enumerate(zip(idxs, probs)):
141
+ pick_time = begin_time + timedelta(seconds=phase_index * dt)
142
+ pick = {
143
+ "file_name": file_name,
144
+ "station_id": station_id,
145
+ "begin_time": begin_time.isoformat(timespec="milliseconds"),
146
+ "phase_index": int(phase_index),
147
+ "phase_time": pick_time.isoformat(timespec="milliseconds"),
148
+ "phase_score": round(phase_prob, 3),
149
+ "phase_type": phases[k],
150
+ "dt": dt,
151
+ }
152
+
153
+ ## process waveform
154
+ if waveforms is not None:
155
+ tmp = np.zeros((pre_idx + post_idx, 3))
156
+ lo = phase_index - pre_idx
157
+ hi = phase_index + post_idx
158
+ insert_idx = 0
159
+ if lo < 0:
160
+ insert_idx = -lo  # offset into tmp, computed before lo is clamped to 0
161
+ lo = 0
162
+ if hi > Nt:
163
+ hi = Nt
164
+ tmp[insert_idx : insert_idx + hi - lo, :] = waveforms[i, lo:hi, j, :]
165
+ if upload_waveform:
166
+ pick["waveform"] = tmp.tolist()
167
+ pick["_id"] = f"{pick['station_id']}_{pick['phase_time']}_{pick['phase_type']}"
168
+ if use_amplitude:
169
+ next_pick = idxs[l + 1] if l < len(idxs) - 1 else (phase_index + post_idx * 3)
170
+ pick["phase_amp"] = np.max(
171
+ amp[phase_index : min(phase_index + post_idx * 3, next_pick)]
172
+ ).item() ## peak amplitude
173
+
174
+ picks.append(pick)
175
+
176
+ return picks
177
+
178
+
179
+ def extract_amplitude(data, picks, window_p=10, window_s=5, config=None):
180
+ record = namedtuple("amplitude", ["p_amp", "s_amp"])
181
+ dt = 0.01 if config is None else config.dt
182
+ window_p = int(window_p / dt)
183
+ window_s = int(window_s / dt)
184
+ amps = []
185
+ for i, (da, pi) in enumerate(zip(data, picks)):
186
+ p_amp, s_amp = [], []
187
+ for j in range(da.shape[1]):
188
+ amp = np.max(np.abs(da[:, j, :]), axis=-1)
189
+ # amp = np.median(np.abs(da[:,j,:]), axis=-1)
190
+ # amp = np.linalg.norm(da[:,j,:], axis=-1)
191
+ tmp = []
192
+ for k in range(len(pi.p_idx[j]) - 1):
193
+ tmp.append(np.max(amp[pi.p_idx[j][k] : min(pi.p_idx[j][k] + window_p, pi.p_idx[j][k + 1])]))
194
+ if len(pi.p_idx[j]) >= 1:
195
+ tmp.append(np.max(amp[pi.p_idx[j][-1] : pi.p_idx[j][-1] + window_p]))
196
+ p_amp.append(tmp)
197
+ tmp = []
198
+ for k in range(len(pi.s_idx[j]) - 1):
199
+ tmp.append(np.max(amp[pi.s_idx[j][k] : min(pi.s_idx[j][k] + window_s, pi.s_idx[j][k + 1])]))
200
+ if len(pi.s_idx[j]) >= 1:
201
+ tmp.append(np.max(amp[pi.s_idx[j][-1] : pi.s_idx[j][-1] + window_s]))
202
+ s_amp.append(tmp)
203
+ amps.append(record(p_amp, s_amp))
204
+ return amps
205
+
206
+
207
+ def save_picks(picks, output_dir, amps=None, fname=None):
208
+ if fname is None:
209
+ fname = "picks.csv"
210
+
211
+ int2s = lambda x: ",".join(["[" + ",".join(map(str, i)) + "]" for i in x])
212
+ flt2s = lambda x: ",".join(["[" + ",".join(map("{:0.3f}".format, i)) + "]" for i in x])
213
+ sci2s = lambda x: ",".join(["[" + ",".join(map("{:0.3e}".format, i)) + "]" for i in x])
214
+ if amps is None:
215
+ if hasattr(picks[0], "ps_idx"):
216
+ with open(os.path.join(output_dir, fname), "w") as fp:
217
+ fp.write("fname\tt0\tp_idx\tp_prob\ts_idx\ts_prob\tps_idx\tps_prob\n")
218
+ for pick in picks:
219
+ fp.write(
220
+ f"{pick.fname}\t{pick.t0}\t{int2s(pick.p_idx)}\t{flt2s(pick.p_prob)}\t{int2s(pick.s_idx)}\t{flt2s(pick.s_prob)}\t{int2s(pick.ps_idx)}\t{flt2s(pick.ps_prob)}\n"
221
+ )
222
+ fp.close()
223
+ else:
224
+ with open(os.path.join(output_dir, fname), "w") as fp:
225
+ fp.write("fname\tt0\tp_idx\tp_prob\ts_idx\ts_prob\n")
226
+ for pick in picks:
227
+ fp.write(
228
+ f"{pick.fname}\t{pick.t0}\t{int2s(pick.p_idx)}\t{flt2s(pick.p_prob)}\t{int2s(pick.s_idx)}\t{flt2s(pick.s_prob)}\n"
229
+ )
230
+ fp.close()
231
+ else:
232
+ with open(os.path.join(output_dir, fname), "w") as fp:
233
+ fp.write("fname\tt0\tp_idx\tp_prob\ts_idx\ts_prob\tp_amp\ts_amp\n")
234
+ for pick, amp in zip(picks, amps):
235
+ fp.write(
236
+ f"{pick.fname}\t{pick.t0}\t{int2s(pick.p_idx)}\t{flt2s(pick.p_prob)}\t{int2s(pick.s_idx)}\t{flt2s(pick.s_prob)}\t{sci2s(amp.p_amp)}\t{sci2s(amp.s_amp)}\n"
237
+ )
238
+ fp.close()
239
+
240
+ return 0
241
+
242
+
243
+ def calc_timestamp(timestamp, sec):
244
+ timestamp = datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S.%f") + timedelta(seconds=sec)
245
+ return timestamp.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]
246
+
247
+
248
+ def save_picks_json(picks, output_dir, dt=0.01, amps=None, fname=None):
249
+ if fname is None:
250
+ fname = "picks.json"
251
+
252
+ picks_ = []
253
+ if amps is None:
254
+ for pick in picks:
255
+ for idxs, probs in zip(pick.p_idx, pick.p_prob):
256
+ for idx, prob in zip(idxs, probs):
257
+ picks_.append(
258
+ {
259
+ "id": pick.station_id,
260
+ "timestamp": calc_timestamp(pick.t0, float(idx) * dt),
261
+ "prob": prob.astype(float),
262
+ "type": "p",
263
+ }
264
+ )
265
+ for idxs, probs in zip(pick.s_idx, pick.s_prob):
266
+ for idx, prob in zip(idxs, probs):
267
+ picks_.append(
268
+ {
269
+ "id": pick.station_id,
270
+ "timestamp": calc_timestamp(pick.t0, float(idx) * dt),
271
+ "prob": prob.astype(float),
272
+ "type": "s",
273
+ }
274
+ )
275
+ else:
276
+ for pick, amplitude in zip(picks, amps):
277
+ for idxs, probs, amps in zip(pick.p_idx, pick.p_prob, amplitude.p_amp):
278
+ for idx, prob, amp in zip(idxs, probs, amps):
279
+ picks_.append(
280
+ {
281
+ "id": pick.station_id,
282
+ "timestamp": calc_timestamp(pick.t0, float(idx) * dt),
283
+ "prob": prob.astype(float),
284
+ "amp": amp.astype(float),
285
+ "type": "p",
286
+ }
287
+ )
288
+ for idxs, probs, amps in zip(pick.s_idx, pick.s_prob, amplitude.s_amp):
289
+ for idx, prob, amp in zip(idxs, probs, amps):
290
+ picks_.append(
291
+ {
292
+ "id": pick.station_id,
293
+ "timestamp": calc_timestamp(pick.t0, float(idx) * dt),
294
+ "prob": prob.astype(float),
295
+ "amp": amp.astype(float),
296
+ "type": "s",
297
+ }
298
+ )
299
+ with open(os.path.join(output_dir, fname), "w") as fp:
300
+ json.dump(picks_, fp)
301
+
302
+ return 0
303
+
304
+
305
+ def convert_true_picks(fname, itp, its, itps=None):
306
+ true_picks = []
307
+ if itps is None:
308
+ record = namedtuple("phase", ["fname", "p_idx", "s_idx"])
309
+ for i in range(len(fname)):
310
+ true_picks.append(record(fname[i].decode(), itp[i], its[i]))
311
+ else:
312
+ record = namedtuple("phase", ["fname", "p_idx", "s_idx", "ps_idx"])
313
+ for i in range(len(fname)):
314
+ true_picks.append(record(fname[i].decode(), itp[i], its[i], itps[i]))
315
+
316
+ return true_picks
317
+
318
+
319
+ def calc_metrics(nTP, nP, nT):
320
+ """
321
+ nTP: true positive
322
+ nP: number of positive picks
323
+ nT: number of true picks
324
+ """
325
+ precision = nTP / nP
326
+ recall = nTP / nT
327
+ f1 = 2 * precision * recall / (precision + recall)
328
+ return [precision, recall, f1]
329
+
330
+
331
+ def calc_performance(picks, true_picks, tol=3.0, dt=1.0):
332
+ assert len(picks) == len(true_picks)
333
+ logging.info("Total records: {}".format(len(picks)))
334
+
335
+ count = lambda picks: sum([len(x) for x in picks])
336
+ metrics = {}
337
+ for phase in true_picks[0]._fields:
338
+ if phase == "fname":
339
+ continue
340
+ true_positive, positive, true = 0, 0, 0
341
+ residual = []
342
+ for i in range(len(true_picks)):
343
+ true += count(getattr(true_picks[i], phase))
344
+ positive += count(getattr(picks[i], phase))
345
+ # print(i, phase, getattr(picks[i], phase), getattr(true_picks[i], phase))
346
+ diff = dt * (
347
+ np.array(getattr(picks[i], phase))[:, np.newaxis, :]
348
+ - np.array(getattr(true_picks[i], phase))[:, :, np.newaxis]
349
+ )
350
+ residual.extend(list(diff[np.abs(diff) <= tol]))
351
+ true_positive += np.sum(np.abs(diff) <= tol)
352
+ metrics[phase] = calc_metrics(true_positive, positive, true)
353
+
354
+ logging.info(f"{phase}-phase:")
355
+ logging.info(f"True={true}, Positive={positive}, True Positive={true_positive}")
356
+ logging.info(f"Precision={metrics[phase][0]:.3f}, Recall={metrics[phase][1]:.3f}, F1={metrics[phase][2]:.3f}")
357
+ logging.info(f"Residual mean={np.mean(residual):.4f}, std={np.std(residual):.4f}")
358
+
359
+ return metrics
360
+
361
+
362
+ def save_prob_h5(probs, fnames, output_h5):
363
+ if fnames is None:
364
+ fnames = [f"{i:04d}" for i in range(len(probs))]
365
+ elif type(fnames[0]) is bytes:
366
+ fnames = [f.decode().rstrip(".npz") for f in fnames]
367
+ else:
368
+ fnames = [f.rstrip(".npz") for f in fnames]
369
+ for prob, fname in zip(probs, fnames):
370
+ output_h5.create_dataset(fname, data=prob, dtype="float32")
371
+ return 0
372
+
373
+
374
+ def save_prob(probs, fnames, prob_dir):
375
+ if fnames is None:
376
+ fnames = [f"{i:04d}" for i in range(len(probs))]
377
+ elif type(fnames[0]) is bytes:
378
+ fnames = [f.decode().rstrip(".npz") for f in fnames]
379
+ else:
380
+ fnames = [f.rstrip(".npz") for f in fnames]
381
+ for prob, fname in zip(probs, fnames):
382
+ np.savez(os.path.join(prob_dir, fname + ".npz"), prob=prob)
383
+ return 0
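
Note: to make the pick format concrete, a small synthetic example of what extract_picks returns with the default thresholds (minimum probability 0.3, minimum peak distance 50 samples; all values below are made up):

    import numpy as np
    from phasenet.postprocess import extract_picks

    preds = np.zeros((1, 3000, 1, 3), dtype="float32")  # batch, time, station, (noise, P, S)
    preds[0, 1000, 0, 1] = 0.9                           # synthetic P probability peak
    preds[0, 1500, 0, 2] = 0.8                           # synthetic S probability peak
    picks = extract_picks(preds, begin_times=["2020-01-01T00:00:00.000+00:00"], dt=0.01)
    # picks[0] is roughly {"file_name": "0000", "station_id": "0000",
    #                      "begin_time": "2020-01-01T00:00:00.000+00:00", "phase_index": 1000,
    #                      "phase_time": "2020-01-01T00:00:10.000+00:00", "phase_score": 0.9,
    #                      "phase_type": "P", "dt": 0.01}; picks[1] is the S pick at sample 1500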
phasenet/predict.py ADDED
@@ -0,0 +1,266 @@
1
+ import argparse
2
+ import logging
3
+ import multiprocessing
4
+ import os
5
+ import pickle
6
+ import time
7
+ from functools import partial
8
+
9
+ import h5py
10
+ import numpy as np
11
+ import pandas as pd
12
+ import tensorflow as tf
13
+ from data_reader import DataReader_mseed_array, DataReader_pred
14
+ from model import ModelConfig, UNet
15
+ from postprocess import (
16
+ extract_amplitude,
17
+ extract_picks,
18
+ save_picks,
19
+ save_picks_json,
20
+ save_prob_h5,
21
+ )
22
+ from pymongo import MongoClient
23
+ from tqdm import tqdm
24
+ from visulization import plot_waveform
25
+
26
+ tf.compat.v1.disable_eager_execution()
27
+ tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
28
+
29
+ username = "root"
30
+ password = "quakeflow123"
31
+ # client = MongoClient(f"mongodb://{username}:{password}@127.0.0.1:27017")
32
+ client = MongoClient(f"mongodb://{username}:{password}@quakeflow-mongodb-headless.default.svc.cluster.local:27017")
33
+
34
+ # db = client["quakeflow"]
35
+ # collection = db["waveform"]
36
+
37
+
38
+ def upload_mongodb(picks):
39
+ db = client["quakeflow"]
40
+ collection = db["waveform"]
41
+ try:
42
+ collection.insert_many(picks)
43
+ except Exception as e:
44
+ print("Warning:", e)
45
+ collection.delete_many({"_id": {"$in": [p["_id"] for p in picks]}})
46
+ collection.insert_many(picks)
47
+
48
+
49
+ def read_args():
50
+
51
+ parser = argparse.ArgumentParser()
52
+ parser.add_argument("--batch_size", default=20, type=int, help="batch size")
53
+ parser.add_argument("--model_dir", help="Checkpoint directory (default: None)")
54
+ parser.add_argument("--data_dir", default="", help="Input file directory")
55
+ parser.add_argument("--data_list", default="", help="Input csv file")
56
+ parser.add_argument("--hdf5_file", default="", help="Input hdf5 file")
57
+ parser.add_argument("--hdf5_group", default="data", help="data group name in hdf5 file")
58
+ parser.add_argument("--result_dir", default="results", help="Output directory")
59
+ parser.add_argument("--result_fname", default="picks", help="Output file")
60
+ parser.add_argument("--highpass_filter", default=0.0, type=float, help="Highpass filter")
61
+ parser.add_argument("--min_p_prob", default=0.3, type=float, help="Probability threshold for P pick")
62
+ parser.add_argument("--min_s_prob", default=0.3, type=float, help="Probability threshold for S pick")
63
+ parser.add_argument("--mpd", default=50, type=float, help="Minimum peak distance")
64
+ parser.add_argument("--amplitude", action="store_true", help="if return amplitude value")
65
+ parser.add_argument("--format", default="numpy", help="input format")
66
+ parser.add_argument("--s3_url", default="localhost:9000", help="s3 url")
67
+ parser.add_argument("--stations", default="", help="seismic station info")
68
+ parser.add_argument("--plot_figure", action="store_true", help="If plot figure for test")
69
+ parser.add_argument("--save_prob", action="store_true", help="If save result for test")
70
+ parser.add_argument("--upload_waveform", action="store_true", help="If upload waveform to mongodb")
71
+ parser.add_argument("--pre_sec", default=1, type=float, help="Window length before pick")
72
+ parser.add_argument("--post_sec", default=4, type=float, help="Window length after pick")
73
+ args = parser.parse_args()
74
+
75
+ return args
76
+
77
+
78
+ def pred_fn(args, data_reader, figure_dir=None, prob_dir=None, log_dir=None):
79
+ current_time = time.strftime("%y%m%d-%H%M%S")
80
+ if log_dir is None:
81
+ log_dir = os.path.join(args.log_dir, "pred", current_time)
82
+ if not os.path.exists(log_dir):
83
+ os.makedirs(log_dir)
84
+ if (args.plot_figure == True) and (figure_dir is None):
85
+ figure_dir = os.path.join(log_dir, "figures")
86
+ if not os.path.exists(figure_dir):
87
+ os.makedirs(figure_dir)
88
+ if (args.save_prob == True) and (prob_dir is None):
89
+ prob_dir = os.path.join(log_dir, "probs")
90
+ if not os.path.exists(prob_dir):
91
+ os.makedirs(prob_dir)
92
+ if args.save_prob:
93
+ h5 = h5py.File(os.path.join(args.result_dir, "result.h5"), "w", libver="latest")
94
+ prob_h5 = h5.create_group("/prob")
95
+ logging.info("Pred log: %s" % log_dir)
96
+ logging.info("Dataset size: {}".format(data_reader.num_data))
97
+
98
+ with tf.compat.v1.name_scope("Input_Batch"):
99
+ if args.format == "mseed_array":
100
+ batch_size = 1
101
+ else:
102
+ batch_size = args.batch_size
103
+ dataset = data_reader.dataset(batch_size)
104
+ batch = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()
105
+
106
+ config = ModelConfig(X_shape=data_reader.X_shape)
107
+ with open(os.path.join(log_dir, "config.log"), "w") as fp:
108
+ fp.write("\n".join("%s: %s" % item for item in vars(config).items()))
109
+
110
+ model = UNet(config=config, input_batch=batch, mode="pred")
111
+ # model = UNet(config=config, mode="pred")
112
+ sess_config = tf.compat.v1.ConfigProto()
113
+ sess_config.gpu_options.allow_growth = True
114
+ # sess_config.log_device_placement = False
115
+
116
+ with tf.compat.v1.Session(config=sess_config) as sess:
117
+
118
+ saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables(), max_to_keep=5)
119
+ init = tf.compat.v1.global_variables_initializer()
120
+ sess.run(init)
121
+
122
+ latest_check_point = tf.train.latest_checkpoint(args.model_dir)
123
+ logging.info(f"restoring model {latest_check_point}")
124
+ saver.restore(sess, latest_check_point)
125
+
126
+ picks = []
127
+ amps = [] if args.amplitude else None
128
+ if args.plot_figure:
129
+ multiprocessing.set_start_method("spawn")
130
+ pool = multiprocessing.Pool(multiprocessing.cpu_count())
131
+
132
+ for _ in tqdm(range(0, data_reader.num_data, batch_size), desc="Pred"):
133
+ if args.amplitude:
134
+ pred_batch, X_batch, amp_batch, fname_batch, t0_batch, station_batch = sess.run(
135
+ [model.preds, batch[0], batch[1], batch[2], batch[3], batch[4]],
136
+ feed_dict={model.drop_rate: 0, model.is_training: False},
137
+ )
138
+ # X_batch, amp_batch, fname_batch, t0_batch = sess.run([batch[0], batch[1], batch[2], batch[3]])
139
+ else:
140
+ pred_batch, X_batch, fname_batch, t0_batch, station_batch = sess.run(
141
+ [model.preds, batch[0], batch[1], batch[2], batch[3]],
142
+ feed_dict={model.drop_rate: 0, model.is_training: False},
143
+ )
144
+ # X_batch, fname_batch, t0_batch = sess.run([model.preds, batch[0], batch[1], batch[2]])
145
+ # pred_batch = []
146
+ # for i in range(0, len(X_batch), 1):
147
+ # pred_batch.append(sess.run(model.preds, feed_dict={model.X: X_batch[i:i+1], model.drop_rate: 0, model.is_training: False}))
148
+ # pred_batch = np.vstack(pred_batch)
149
+
150
+ waveforms = None
151
+ if args.upload_waveform:
152
+ waveforms = X_batch
153
+ if args.amplitude:
154
+ waveforms = amp_batch
155
+
156
+ picks_ = extract_picks(
157
+ preds=pred_batch,
158
+ file_names=fname_batch,
159
+ station_ids=station_batch,
160
+ begin_times=t0_batch,
161
+ config=args,
162
+ waveforms=waveforms,
163
+ use_amplitude=args.amplitude,
164
+ upload_waveform=args.upload_waveform,
165
+ )
166
+
167
+ if args.upload_waveform:
168
+ upload_mongodb(picks_)
169
+ picks.extend(picks_)
170
+
171
+ if args.plot_figure:
172
+ if not (isinstance(fname_batch, np.ndarray) or isinstance(fname_batch, list)):
173
+ fname_batch = [fname_batch.decode().rstrip(".mseed") + "_" + x.decode() for x in station_batch]
174
+ else:
175
+ fname_batch = [x.decode() for x in fname_batch]
176
+ pool.starmap(
177
+ partial(
178
+ plot_waveform,
179
+ figure_dir=figure_dir,
180
+ ),
181
+ # zip(X_batch, pred_batch, [x.decode() for x in fname_batch]),
182
+ zip(X_batch, pred_batch, fname_batch),
183
+ )
184
+
185
+ if args.save_prob:
186
+ # save_prob(pred_batch, fname_batch, prob_dir=prob_dir)
187
+ if not (isinstance(fname_batch, np.ndarray) or isinstance(fname_batch, list)):
188
+ fname_batch = [fname_batch.decode().rstrip(".mseed") + "_" + x.decode() for x in station_batch]
189
+ else:
190
+ fname_batch = [x.decode() for x in fname_batch]
191
+ save_prob_h5(pred_batch, fname_batch, prob_h5)
192
+
193
+ if len(picks) > 0:
194
+ # save_picks(picks, args.result_dir, amps=amps, fname=args.result_fname+".csv")
195
+ # save_picks_json(picks, args.result_dir, dt=data_reader.dt, amps=amps, fname=args.result_fname+".json")
196
+ df = pd.DataFrame(picks)
197
+ # df["fname"] = df["file_name"]
198
+ # df["id"] = df["station_id"]
199
+ # df["timestamp"] = df["phase_time"]
200
+ # df["prob"] = df["phase_prob"]
201
+ # df["type"] = df["phase_type"]
202
+ if args.amplitude:
203
+ # df["amp"] = df["phase_amp"]
204
+ df = df[
205
+ [
206
+ "file_name",
207
+ "begin_time",
208
+ "station_id",
209
+ "phase_index",
210
+ "phase_time",
211
+ "phase_score",
212
+ "phase_amp",
213
+ "phase_type",
214
+ ]
215
+ ]
216
+ else:
217
+ df = df[
218
+ ["file_name", "begin_time", "station_id", "phase_index", "phase_time", "phase_score", "phase_type"]
219
+ ]
220
+ # if args.amplitude:
221
+ # df = df[["file_name","station_id","phase_index","phase_time","phase_prob","phase_amplitude", "phase_type","dt",]]
222
+ # else:
223
+ # df = df[["file_name","station_id","phase_index","phase_time","phase_prob","phase_type","dt"]]
224
+ df.to_csv(os.path.join(args.result_dir, args.result_fname + ".csv"), index=False)
225
+
226
+ print(
227
+ f"Done with {len(df[df['phase_type'] == 'P'])} P-picks and {len(df[df['phase_type'] == 'S'])} S-picks"
228
+ )
229
+ else:
230
+ print(f"Done with 0 P-picks and 0 S-picks")
231
+ return 0
232
+
233
+
234
+ def main(args):
235
+
236
+ logging.basicConfig(format="%(asctime)s %(message)s", level=logging.INFO)
237
+
238
+ with tf.compat.v1.name_scope("create_inputs"):
239
+
240
+ if args.format == "mseed_array":
241
+ data_reader = DataReader_mseed_array(
242
+ data_dir=args.data_dir,
243
+ data_list=args.data_list,
244
+ stations=args.stations,
245
+ amplitude=args.amplitude,
246
+ highpass_filter=args.highpass_filter,
247
+ )
248
+ else:
249
+ data_reader = DataReader_pred(
250
+ format=args.format,
251
+ data_dir=args.data_dir,
252
+ data_list=args.data_list,
253
+ hdf5_file=args.hdf5_file,
254
+ hdf5_group=args.hdf5_group,
255
+ amplitude=args.amplitude,
256
+ highpass_filter=args.highpass_filter,
257
+ )
258
+
259
+ pred_fn(args, data_reader, log_dir=args.result_dir)
260
+
261
+ return
262
+
263
+
264
+ if __name__ == "__main__":
265
+ args = read_args()
266
+ main(args)
phasenet/slide_window.py ADDED
@@ -0,0 +1,88 @@
1
+ import os
2
+ from collections import defaultdict, namedtuple
3
+ from datetime import datetime, timedelta
4
+ from json import dumps
5
+
6
+ import numpy as np
7
+ import tensorflow as tf
8
+
9
+ from model import ModelConfig, UNet
10
+ from postprocess import extract_amplitude, extract_picks
11
+ import pandas as pd
12
+ import obspy
13
+
14
+
15
+ tf.compat.v1.disable_eager_execution()
16
+ tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
17
+ PROJECT_ROOT = os.path.realpath(os.path.join(os.path.dirname(__file__), ".."))
18
+
19
+ # load model
20
+ model = UNet(mode="pred")
21
+ sess_config = tf.compat.v1.ConfigProto()
22
+ sess_config.gpu_options.allow_growth = True
23
+
24
+ sess = tf.compat.v1.Session(config=sess_config)
25
+ saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
26
+ init = tf.compat.v1.global_variables_initializer()
27
+ sess.run(init)
28
+ latest_check_point = tf.train.latest_checkpoint(f"{PROJECT_ROOT}/model/190703-214543")
29
+ print(f"restoring model {latest_check_point}")
30
+ saver.restore(sess, latest_check_point)
31
+
32
+
33
+ def calc_timestamp(timestamp, sec):
34
+ timestamp = datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S.%f") + timedelta(seconds=sec)
35
+ return timestamp.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]
36
+
37
+ def format_picks(picks, dt):
38
+ picks_ = []
39
+ for pick in picks:
40
+ for idxs, probs in zip(pick.p_idx, pick.p_prob):
41
+ for idx, prob in zip(idxs, probs):
42
+ picks_.append(
43
+ {
44
+ "id": pick.fname,
45
+ "timestamp": calc_timestamp(pick.t0, float(idx) * dt),
46
+ "prob": prob,
47
+ "type": "p",
48
+ }
49
+ )
50
+ for idxs, probs in zip(pick.s_idx, pick.s_prob):
51
+ for idx, prob in zip(idxs, probs):
52
+ picks_.append(
53
+ {
54
+ "id": pick.fname,
55
+ "timestamp": calc_timestamp(pick.t0, float(idx) * dt),
56
+ "prob": prob,
57
+ "type": "s",
58
+ }
59
+ )
60
+ return picks_
61
+
62
+
63
+ stream = obspy.read()
64
+ stream = stream.sort()  ## sort traces into E, N, Z channel order (assumed to match the NPZ training convention)
65
+ assert(len(stream) == 3)
66
+ data = []
67
+ for trace in stream:
68
+ data.append(trace.data)
69
+ data = np.array(data).T
70
+ assert(data.shape[-1] == 3)
71
+
72
+ # data_id = stream[0].get_id()[:-1]
73
+ # timestamp = stream[0].stats.starttime.datetime.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]
74
+
75
+ data = np.stack([data for i in range(10)]) ## Assume 10 windows
76
+ data = data[:,:,np.newaxis,:] ## batch, nt, dummy_dim, channel
77
+ print(f"{data.shape = }")
78
+ data = (data - data.mean(axis=1, keepdims=True))/data.std(axis=1, keepdims=True)
79
+
80
+ feed = {model.X: data, model.drop_rate: 0, model.is_training: False}
81
+ preds = sess.run(model.preds, feed_dict=feed)
82
+
83
+ ## extract_picks in postprocess.py now returns a list of dicts keyed by phase_time/phase_score/phase_type
+ picks = extract_picks(preds, dt=0.01)
85
+
86
+
87
+ picks = pd.DataFrame(picks)
88
+ print(picks)
phasenet/test_app.py ADDED
@@ -0,0 +1,47 @@
1
+ import requests
2
+ import obspy
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+ from datetime import datetime
6
+
7
+ ### Start running the model first:
8
+ ### FLASK_ENV=development FLASK_APP=app.py flask run
9
+
10
+ def read_data(mseed):
11
+ data = []
12
+ mseed = mseed.sort()
13
+ for c in ["E", "N", "Z"]:
14
+ data.append(mseed.select(channel="*"+c)[0].data)
15
+ return np.array(data).T
16
+
17
+ timestamp = lambda x: x.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]
18
+
19
+ ## prepare some test data
20
+ mseed = obspy.read()
21
+ data = []
22
+ for i in range(1):
23
+ data.append(read_data(mseed))
24
+ data = {
25
+ "id": ["test01"],
26
+ "timestamp": [timestamp(datetime.now())],
27
+ "vec": np.array(data).tolist(),
28
+ "dt": 0.01
29
+ }
30
+
31
+ ## run prediction
32
+ print(data["id"])
33
+ resp = requests.get("http://localhost:8000/predict", json=data)
34
+ picks = resp.json()["picks"]
35
+ print(resp.json())
36
+
37
+
38
+ ## plot figure
39
+ plt.figure()
40
+ plt.plot(np.array(data["vec"])[0,:,1])
41
+ ylim = plt.ylim()
42
+ plt.plot([picks[0][0][0], picks[0][0][0]], ylim, label="P-phase")
43
+ plt.text(picks[0][0][0], ylim[1]*0.9, f"{picks[0][1][0]:.2f}")
44
+ plt.plot([picks[0][2][0], picks[0][2][0]], ylim, label="S-phase")
45
+ plt.text(picks[0][2][0], ylim[1]*0.9, f"{picks[0][3][0]:.2f}")
46
+ plt.legend()
47
+ plt.savefig("test.png")
phasenet/train.py ADDED
@@ -0,0 +1,246 @@
1
+ import numpy as np
2
+ import tensorflow as tf
3
+ tf.compat.v1.disable_eager_execution()
4
+ tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
5
+ import argparse, os, time, logging
6
+ from tqdm import tqdm
7
+ import pandas as pd
8
+ import multiprocessing
9
+ from functools import partial
10
+ import pickle
11
+ from model import UNet, ModelConfig
12
+ from data_reader import DataReader_train, DataReader_test
13
+ from postprocess import extract_picks, save_picks, save_picks_json, extract_amplitude, convert_true_picks, calc_performance
14
+ from visulization import plot_waveform
15
+ from util import EMA, LMA
16
+
17
+ def read_args():
18
+
19
+ parser = argparse.ArgumentParser()
20
+ parser.add_argument("--mode", default="train", help="train/train_valid/test/debug")
21
+ parser.add_argument("--epochs", default=100, type=int, help="number of epochs (default: 10)")
22
+ parser.add_argument("--batch_size", default=20, type=int, help="batch size")
23
+ parser.add_argument("--learning_rate", default=0.01, type=float, help="learning rate")
24
+ parser.add_argument("--drop_rate", default=0.0, type=float, help="dropout rate")
25
+ parser.add_argument("--decay_step", default=-1, type=int, help="decay step")
26
+ parser.add_argument("--decay_rate", default=0.9, type=float, help="decay rate")
27
+ parser.add_argument("--momentum", default=0.9, type=float, help="momentum")
28
+ parser.add_argument("--optimizer", default="adam", help="optimizer: adam, momentum")
29
+ parser.add_argument("--summary", default=True, type=bool, help="summary")
30
+ parser.add_argument("--class_weights", nargs="+", default=[1, 1, 1], type=float, help="class weights")
31
+ parser.add_argument("--model_dir", default=None, help="Checkpoint directory (default: None)")
32
+ parser.add_argument("--load_model", action="store_true", help="Load checkpoint")
33
+ parser.add_argument("--log_dir", default="log", help="Log directory (default: log)")
34
+ parser.add_argument("--num_plots", default=10, type=int, help="Plotting training results")
35
+ parser.add_argument("--min_p_prob", default=0.3, type=float, help="Probability threshold for P pick")
36
+ parser.add_argument("--min_s_prob", default=0.3, type=float, help="Probability threshold for S pick")
37
+ parser.add_argument("--format", default="numpy", help="Input data format")
38
+ parser.add_argument("--train_dir", default="./dataset/waveform_train/", help="Input file directory")
39
+ parser.add_argument("--train_list", default="./dataset/waveform.csv", help="Input csv file")
40
+ parser.add_argument("--valid_dir", default=None, help="Input file directory")
41
+ parser.add_argument("--valid_list", default=None, help="Input csv file")
42
+ parser.add_argument("--test_dir", default=None, help="Input file directory")
43
+ parser.add_argument("--test_list", default=None, help="Input csv file")
44
+ parser.add_argument("--result_dir", default="results", help="result directory")
45
+ parser.add_argument("--plot_figure", action="store_true", help="If plot figure for test")
46
+ parser.add_argument("--save_prob", action="store_true", help="If save result for test")
47
+ args = parser.parse_args()
48
+
49
+ return args
50
+
51
+
52
+ def train_fn(args, data_reader, data_reader_valid=None):
53
+
54
+ current_time = time.strftime("%y%m%d-%H%M%S")
55
+ log_dir = os.path.join(args.log_dir, current_time)
56
+ if not os.path.exists(log_dir):
57
+ os.makedirs(log_dir)
58
+ logging.info("Training log: {}".format(log_dir))
59
+ model_dir = os.path.join(log_dir, 'models')
60
+ os.makedirs(model_dir)
61
+
62
+ figure_dir = os.path.join(log_dir, 'figures')
63
+ if not os.path.exists(figure_dir):
64
+ os.makedirs(figure_dir)
65
+
66
+ config = ModelConfig(X_shape=data_reader.X_shape, Y_shape=data_reader.Y_shape)
67
+ if args.decay_step == -1:
68
+ args.decay_step = data_reader.num_data // args.batch_size
69
+ config.update_args(args)
70
+ with open(os.path.join(log_dir, 'config.log'), 'w') as fp:
71
+ fp.write('\n'.join("%s: %s" % item for item in vars(config).items()))
72
+
73
+ with tf.compat.v1.name_scope('Input_Batch'):
74
+ dataset = data_reader.dataset(args.batch_size, shuffle=True).repeat()
75
+ batch = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()
76
+ if data_reader_valid is not None:
77
+ dataset_valid = data_reader_valid.dataset(args.batch_size, shuffle=False).repeat()
78
+ valid_batch = tf.compat.v1.data.make_one_shot_iterator(dataset_valid).get_next()
79
+
80
+ model = UNet(config, input_batch=batch)
81
+ sess_config = tf.compat.v1.ConfigProto()
82
+ sess_config.gpu_options.allow_growth = True
83
+ # sess_config.log_device_placement = False
84
+
85
+ with tf.compat.v1.Session(config=sess_config) as sess:
86
+
87
+ summary_writer = tf.compat.v1.summary.FileWriter(log_dir, sess.graph)
88
+ saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables(), max_to_keep=5)
89
+ init = tf.compat.v1.global_variables_initializer()
90
+ sess.run(init)
91
+
92
+ if args.model_dir is not None:
93
+ logging.info("restoring models...")
94
+ latest_check_point = tf.train.latest_checkpoint(args.model_dir)
95
+ saver.restore(sess, latest_check_point)
96
+
97
+ if args.plot_figure:
98
+ multiprocessing.set_start_method('spawn')
99
+ pool = multiprocessing.Pool(multiprocessing.cpu_count())
100
+
101
+ flog = open(os.path.join(log_dir, 'loss.log'), 'w')
102
+ train_loss = EMA(0.9)
103
+ best_valid_loss = np.inf
104
+ for epoch in range(args.epochs):
105
+ progressbar = tqdm(range(0, data_reader.num_data, args.batch_size), desc="{}: epoch {}".format(log_dir.split("/")[-1], epoch))
106
+ for _ in progressbar:
107
+ loss_batch, _, _ = sess.run([model.loss, model.train_op, model.global_step],
108
+ feed_dict={model.drop_rate: args.drop_rate, model.is_training: True})
109
+ train_loss(loss_batch)
110
+ progressbar.set_description("{}: epoch {}, loss={:.6f}, mean={:.6f}".format(log_dir.split("/")[-1], epoch, loss_batch, train_loss.value))
111
+ flog.write("epoch: {}, mean loss: {}\n".format(epoch, train_loss.value))
112
+
113
+ if data_reader_valid is not None:
114
+ valid_loss = LMA()
115
+ progressbar = tqdm(range(0, data_reader_valid.num_data, args.batch_size), desc="Valid:")
116
+ for _ in progressbar:
117
+ loss_batch, preds_batch, X_batch, Y_batch, fname_batch = sess.run([model.loss, model.preds, valid_batch[0], valid_batch[1], valid_batch[2]],
118
+ feed_dict={model.drop_rate: 0, model.is_training: False})
119
+ valid_loss(loss_batch)
120
+ progressbar.set_description("valid, loss={:.6f}, mean={:.6f}".format(loss_batch, valid_loss.value))
121
+ if valid_loss.value < best_valid_loss:
122
+ best_valid_loss = valid_loss.value
123
+ saver.save(sess, os.path.join(model_dir, "model_{}.ckpt".format(epoch)))
124
+ flog.write("Valid: mean loss: {}\n".format(valid_loss.value))
125
+ else:
126
+ loss_batch, preds_batch, X_batch, Y_batch, fname_batch = sess.run([model.loss, model.preds, batch[0], batch[1], batch[2]],
127
+ feed_dict={model.drop_rate: 0, model.is_training: False})
128
+ saver.save(sess, os.path.join(model_dir, "model_{}.ckpt".format(epoch)))
129
+
130
+ if args.plot_figure:
131
+ pool.starmap(
132
+ partial(
133
+ plot_waveform,
134
+ figure_dir=figure_dir,
135
+ ),
136
+ zip(X_batch, preds_batch, [x.decode() for x in fname_batch], Y_batch),
137
+ )
138
+ # plot_waveform(X_batch, preds_batch, fname_batch, label=Y_batch, figure_dir=figure_dir)
139
+ flog.flush()
140
+
141
+ flog.close()
142
+
143
+ return 0
144
+
145
+ def test_fn(args, data_reader):
146
+ current_time = time.strftime("%y%m%d-%H%M%S")
147
+ logging.info("{} log: {}".format(args.mode, current_time))
148
+ if args.model_dir is None:
149
+ logging.error(f"model_dir = None!")
150
+ return -1
151
+ if not os.path.exists(args.result_dir):
152
+ os.makedirs(args.result_dir)
153
+ figure_dir=os.path.join(args.result_dir, "figures")
154
+ if not os.path.exists(figure_dir):
155
+ os.makedirs(figure_dir)
156
+
157
+ config = ModelConfig(X_shape=data_reader.X_shape, Y_shape=data_reader.Y_shape)
158
+ config.update_args(args)
159
+ with open(os.path.join(args.result_dir, 'config.log'), 'w') as fp:
160
+ fp.write('\n'.join("%s: %s" % item for item in vars(config).items()))
161
+
162
+ with tf.compat.v1.name_scope('Input_Batch'):
163
+ dataset = data_reader.dataset(args.batch_size, shuffle=False)
164
+ batch = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()
165
+
166
+ model = UNet(config, input_batch=batch, mode='test')
167
+ sess_config = tf.compat.v1.ConfigProto()
168
+ sess_config.gpu_options.allow_growth = True
169
+ # sess_config.log_device_placement = False
170
+
171
+ with tf.compat.v1.Session(config=sess_config) as sess:
172
+
173
+ saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
174
+ init = tf.compat.v1.global_variables_initializer()
175
+ sess.run(init)
176
+
177
+ logging.info("restoring models...")
178
+ latest_check_point = tf.train.latest_checkpoint(args.model_dir)
179
+ if latest_check_point is None:
180
+ logging.error(f"No models found in model_dir: {args.model_dir}")
181
+ return -1
182
+ saver.restore(sess, latest_check_point)
183
+
184
+ flog = open(os.path.join(args.result_dir, 'loss.log'), 'w')
185
+ test_loss = LMA()
186
+ progressbar = tqdm(range(0, data_reader.num_data, args.batch_size), desc=args.mode)
187
+ picks = []
188
+ true_picks = []
189
+ for _ in progressbar:
190
+ loss_batch, preds_batch, X_batch, Y_batch, fname_batch, itp_batch, its_batch \
191
+ = sess.run([model.loss, model.preds, batch[0], batch[1], batch[2], batch[3], batch[4]],
192
+ feed_dict={model.drop_rate: 0, model.is_training: False})
193
+
194
+ test_loss(loss_batch)
195
+ progressbar.set_description("{}, loss={:.6f}, mean loss={:.6f}".format(args.mode, loss_batch, test_loss.value))
196
+
197
+ picks_ = extract_picks(preds_batch, fname_batch)
198
+ picks.extend(picks_)
199
+ true_picks.extend(convert_true_picks(fname_batch, itp_batch, its_batch))
200
+ if args.plot_figure:
201
+ for i in range(len(X_batch)):
+ plot_waveform(X_batch[i], preds_batch[i], fname_batch[i].decode(), label=Y_batch[i],
+ itp=itp_batch[i], its=its_batch[i], figure_dir=figure_dir)
203
+
204
+ save_picks(picks, args.result_dir)
205
+ metrics = calc_performance(picks, true_picks, tol=3.0, dt=data_reader.config.dt)
206
+ flog.write("mean loss: {}\n".format(test_loss.value))
207
+ flog.close()
208
+
209
+ return 0
210
+
211
+ def main(args):
212
+
213
+ logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)
214
+ coord = tf.train.Coordinator()
215
+
216
+ if (args.mode == "train") or (args.mode == "train_valid"):
217
+ with tf.compat.v1.name_scope('create_inputs'):
218
+ data_reader = DataReader_train(format=args.format,
219
+ data_dir=args.train_dir,
220
+ data_list=args.train_list)
221
+ if args.mode == "train_valid":
222
+ data_reader_valid = DataReader_train(format=args.format,
223
+ data_dir=args.valid_dir,
224
+ data_list=args.valid_list)
225
+ logging.info("Dataset size: train {}, valid {}".format(data_reader.num_data, data_reader_valid.num_data))
226
+ else:
227
+ data_reader_valid = None
228
+ logging.info("Dataset size: train {}".format(data_reader.num_data))
229
+ train_fn(args, data_reader, data_reader_valid)
230
+
231
+ elif args.mode == "test":
232
+ with tf.compat.v1.name_scope('create_inputs'):
233
+ data_reader = DataReader_test(format=args.format,
234
+ data_dir=args.test_dir,
235
+ data_list=args.test_list)
236
+ test_fn(args, data_reader)
237
+
238
+ else:
239
+ print("mode should be: train, train_valid, or test")
240
+
241
+ return
242
+
243
+
244
+ if __name__ == '__main__':
245
+ args = read_args()
246
+ main(args)
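
For reference, a minimal invocation sketch for the script above (not part of the commit): the dataset paths are placeholders, and it assumes the snippet is run from inside the phasenet/ directory so that the plain `model`, `data_reader`, and `util` imports in train.py resolve.

    # Hypothetical invocation sketch; paths and flag values are assumptions.
    import sys
    from train import read_args, main

    sys.argv = ["train.py",
                "--mode", "train",
                "--train_dir", "./dataset/waveform_train/",   # placeholder path
                "--train_list", "./dataset/waveform.csv",     # placeholder path
                "--batch_size", "20",
                "--epochs", "100"]
    main(read_args())
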
phasenet/util.py ADDED
@@ -0,0 +1,238 @@
1
+ from __future__ import division
2
+ import matplotlib
3
+ matplotlib.use('agg')
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ import os
7
+ from data_reader import DataConfig
8
+ from detect_peaks import detect_peaks
9
+ import logging
10
+
11
+ class EMA(object):
12
+ def __init__(self, alpha):
13
+ self.alpha = alpha
14
+ self.x = 0.
15
+ self.count = 0
16
+
17
+ @property
18
+ def value(self):
19
+ return self.x
20
+
21
+ def __call__(self, x):
22
+ if self.count == 0:
23
+ self.x = x
24
+ else:
25
+ self.x = self.alpha * self.x + (1 - self.alpha) * x
26
+ self.count += 1
27
+ return self.x
28
+
29
+ class LMA(object):
30
+ def __init__(self):
31
+ self.x = 0.
32
+ self.count = 0
33
+
34
+ @property
35
+ def value(self):
36
+ return self.x
37
+
38
+ def __call__(self, x):
39
+ if self.count == 0:
40
+ self.x = x
41
+ else:
42
+ self.x += (x - self.x)/(self.count+1)
43
+ self.count += 1
44
+ return self.x
45
+
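
As a side note on the two averages above (this snippet is not part of util.py): EMA keeps an exponentially weighted running value, with weight alpha on the previous estimate, while LMA is an incremental arithmetic mean; train_fn uses EMA to smooth the training loss and LMA for the validation loss. A minimal sketch of the same arithmetic on a made-up loss sequence:

    losses = [1.0, 0.5, 0.25]            # hypothetical per-batch losses

    alpha = 0.9
    ema = losses[0]                      # EMA seeds with the first value
    for x in losses[1:]:
        ema = alpha * ema + (1 - alpha) * x
    print(ema)                           # 0.88, dominated by earlier values

    lma = 0.0
    for count, x in enumerate(losses):
        lma += (x - lma) / (count + 1)   # incremental arithmetic mean
    print(lma)                           # 0.5833..., equal to sum(losses) / len(losses)
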
46
+ def detect_peaks_thread(i, pred, fname=None, result_dir=None, args=None):
47
+ if args is None:
48
+ itp, prob_p = detect_peaks(pred[i,:,0,1], mph=0.5, mpd=0.5/DataConfig().dt, show=False)
49
+ its, prob_s = detect_peaks(pred[i,:,0,2], mph=0.5, mpd=0.5/DataConfig().dt, show=False)
50
+ else:
51
+ itp, prob_p = detect_peaks(pred[i,:,0,1], mph=args.tp_prob, mpd=0.5/DataConfig().dt, show=False)
52
+ its, prob_s = detect_peaks(pred[i,:,0,2], mph=args.ts_prob, mpd=0.5/DataConfig().dt, show=False)
53
+ if (fname is not None) and (result_dir is not None):
54
+ # np.savez(os.path.join(result_dir, fname[i].decode().split('/')[-1]), pred=pred[i], itp=itp, its=its, prob_p=prob_p, prob_s=prob_s)
55
+ try:
56
+ np.savez(os.path.join(result_dir, fname[i].decode()), pred=pred[i], itp=itp, its=its, prob_p=prob_p, prob_s=prob_s)
57
+ except FileNotFoundError:
58
+ #if not os.path.exists(os.path.dirname(os.path.join(result_dir, fname[i].decode()))):
59
+ os.makedirs(os.path.dirname(os.path.join(result_dir, fname[i].decode())), exist_ok=True)
60
+ np.savez(os.path.join(result_dir, fname[i].decode()), pred=pred[i], itp=itp, its=its, prob_p=prob_p, prob_s=prob_s)
61
+ return [(itp, prob_p), (its, prob_s)]
62
+
63
+ def plot_result_thread(i, pred, X, Y=None, itp=None, its=None,
64
+ itp_pred=None, its_pred=None, fname=None, figure_dir=None):
65
+ dt = DataConfig().dt
66
+ t = np.arange(0, pred.shape[1]) * dt
67
+ box = dict(boxstyle='round', facecolor='white', alpha=1)
68
+ text_loc = [0.05, 0.77]
69
+
70
+ plt.figure(i)
71
+ plt.clf()
72
+ # fig_size = plt.gcf().get_size_inches()
73
+ # plt.gcf().set_size_inches(fig_size*[1, 1.2])
74
+ plt.subplot(411)
75
+ plt.plot(t, X[i, :, 0, 0], 'k', label='E', linewidth=0.5)
76
+ plt.autoscale(enable=True, axis='x', tight=True)
77
+ tmp_min = np.min(X[i, :, 0, 0])
78
+ tmp_max = np.max(X[i, :, 0, 0])
79
+ if (itp is not None) and (its is not None):
80
+ for j in range(len(itp[i])):
81
+ if j == 0:
82
+ plt.plot([itp[i][j]*dt, itp[i][j]*dt], [tmp_min, tmp_max], 'b', label='P', linewidth=0.5)
83
+ else:
84
+ plt.plot([itp[i][j]*dt, itp[i][j]*dt], [tmp_min, tmp_max], 'b', linewidth=0.5)
85
+ for j in range(len(its[i])):
86
+ if j == 0:
87
+ plt.plot([its[i][j]*dt, its[i][j]*dt], [tmp_min, tmp_max], 'r', label='S', linewidth=0.5)
88
+ else:
89
+ plt.plot([its[i][j]*dt, its[i][j]*dt], [tmp_min, tmp_max], 'r', linewidth=0.5)
90
+ plt.ylabel('Amplitude')
91
+ plt.legend(loc='upper right', fontsize='small')
92
+ plt.gca().set_xticklabels([])
93
+ plt.text(text_loc[0], text_loc[1], '(i)', horizontalalignment='center',
94
+ transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
95
+ plt.subplot(412)
96
+ plt.plot(t, X[i, :, 0, 1], 'k', label='N', linewidth=0.5)
97
+ plt.autoscale(enable=True, axis='x', tight=True)
98
+ tmp_min = np.min(X[i, :, 0, 1])
99
+ tmp_max = np.max(X[i, :, 0, 1])
100
+ if (itp is not None) and (its is not None):
101
+ for j in range(len(itp[i])):
102
+ plt.plot([itp[i][j]*dt, itp[i][j]*dt], [tmp_min, tmp_max], 'b', linewidth=0.5)
103
+ for j in range(len(its[i])):
104
+ plt.plot([its[i][j]*dt, its[i][j]*dt], [tmp_min, tmp_max], 'r', linewidth=0.5)
105
+ plt.ylabel('Amplitude')
106
+ plt.legend(loc='upper right', fontsize='small')
107
+ plt.gca().set_xticklabels([])
108
+ plt.text(text_loc[0], text_loc[1], '(ii)', horizontalalignment='center',
109
+ transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
110
+ plt.subplot(413)
111
+ plt.plot(t, X[i, :, 0, 2], 'k', label='Z', linewidth=0.5)
112
+ plt.autoscale(enable=True, axis='x', tight=True)
113
+ tmp_min = np.min(X[i, :, 0, 2])
114
+ tmp_max = np.max(X[i, :, 0, 2])
115
+ if (itp is not None) and (its is not None):
116
+ for j in range(len(itp[i])):
117
+ plt.plot([itp[i][j]*dt, itp[i][j]*dt], [tmp_min, tmp_max], 'b', linewidth=0.5)
118
+ for j in range(len(its[i])):
119
+ plt.plot([its[i][j]*dt, its[i][j]*dt], [tmp_min, tmp_max], 'r', linewidth=0.5)
120
+ plt.ylabel('Amplitude')
121
+ plt.legend(loc='upper right', fontsize='small')
122
+ plt.gca().set_xticklabels([])
123
+ plt.text(text_loc[0], text_loc[1], '(iii)', horizontalalignment='center',
124
+ transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
125
+ plt.subplot(414)
126
+ if Y is not None:
127
+ plt.plot(t, Y[i, :, 0, 1], 'b', label='P', linewidth=0.5)
128
+ plt.plot(t, Y[i, :, 0, 2], 'r', label='S', linewidth=0.5)
129
+ plt.plot(t, pred[i, :, 0, 1], '--g', label='$\hat{P}$', linewidth=0.5)
130
+ plt.plot(t, pred[i, :, 0, 2], '-.m', label='$\hat{S}$', linewidth=0.5)
131
+ plt.autoscale(enable=True, axis='x', tight=True)
132
+ if (itp_pred is not None) and (its_pred is not None):
133
+ for j in range(len(itp_pred)):
134
+ plt.plot([itp_pred[j]*dt, itp_pred[j]*dt], [-0.1, 1.1], '--g', linewidth=0.5)
135
+ for j in range(len(its_pred)):
136
+ plt.plot([its_pred[j]*dt, its_pred[j]*dt], [-0.1, 1.1], '-.m', linewidth=0.5)
137
+ plt.ylim([-0.05, 1.05])
138
+ plt.text(text_loc[0], text_loc[1], '(iv)', horizontalalignment='center',
139
+ transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
140
+ plt.legend(loc='upper right', fontsize='small')
141
+ plt.xlabel('Time (s)')
142
+ plt.ylabel('Probability')
143
+
144
+ plt.tight_layout()
145
+ plt.gcf().align_labels()
146
+
147
+ try:
148
+ plt.savefig(os.path.join(figure_dir,
149
+ fname[i].decode().rstrip('.npz')+'.png'),
150
+ bbox_inches='tight')
151
+ except FileNotFoundError:
152
+ #if not os.path.exists(os.path.dirname(os.path.join(figure_dir, fname[i].decode()))):
153
+ os.makedirs(os.path.dirname(os.path.join(figure_dir, fname[i].decode())), exist_ok=True)
154
+ plt.savefig(os.path.join(figure_dir,
155
+ fname[i].decode().rstrip('.npz')+'.png'),
156
+ bbox_inches='tight')
157
+ #plt.savefig(os.path.join(figure_dir,
158
+ # fname[i].decode().split('/')[-1].rstrip('.npz')+'.png'),
159
+ # bbox_inches='tight')
160
+ # plt.savefig(os.path.join(figure_dir,
161
+ # fname[i].decode().split('/')[-1].rstrip('.npz')+'.pdf'),
162
+ # bbox_inches='tight')
163
+ plt.close(i)
164
+ return 0
165
+
166
+ def postprocessing_thread(i, pred, X, Y=None, itp=None, its=None, fname=None, result_dir=None, figure_dir=None, args=None):
167
+ (itp_pred, prob_p), (its_pred, prob_s) = detect_peaks_thread(i, pred, fname, result_dir, args)
168
+ if (fname is not None) and (figure_dir is not None):
169
+ plot_result_thread(i, pred, X, Y, itp, its, itp_pred, its_pred, fname, figure_dir)
170
+ return [(itp_pred, prob_p), (its_pred, prob_s)]
171
+
172
+
173
+ def clean_queue(picks):
174
+ clean = []
175
+ for i in range(len(picks)):
176
+ tmp = []
177
+ for j in picks[i]:
178
+ if j != 0:
179
+ tmp.append(j)
180
+ clean.append(tmp)
181
+ return clean
182
+
183
+ def clean_queue_thread(picks):
184
+ tmp = []
185
+ for j in picks:
186
+ if j != 0:
187
+ tmp.append(j)
188
+ return tmp
189
+
190
+
191
+ def metrics(TP, nP, nT):
192
+ '''
193
+ TP: true positive
194
+ nP: number of positive picks
195
+ nT: number of true picks
196
+ '''
197
+ precision = TP / nP
198
+ recall = TP / nT
199
+ F1 = 2 * precision * recall / (precision + recall)
200
+ return [precision, recall, F1]
201
+
202
+ def correct_picks(picks, true_p, true_s, tol):
203
+ dt = DataConfig().dt
204
+ if len(true_p) != len(true_s):
205
+ print("The lengths of the true P and S picks are not the same")
206
+ num = len(true_p)
207
+ TP_p = 0; TP_s = 0; nP_p = 0; nP_s = 0; nT_p = 0; nT_s = 0
208
+ diff_p = []; diff_s = []
209
+ for i in range(num):
210
+ nT_p += len(true_p[i])
211
+ nT_s += len(true_s[i])
212
+ nP_p += len(picks[i][0][0])
213
+ nP_s += len(picks[i][1][0])
214
+
215
+ if len(true_p[i]) > 1 or len(true_s[i]) > 1:
216
+ print(i, picks[i], true_p[i], true_s[i])
217
+ tmp_p = np.array(picks[i][0][0]) - np.array(true_p[i])[:,np.newaxis]
218
+ tmp_s = np.array(picks[i][1][0]) - np.array(true_s[i])[:,np.newaxis]
219
+ TP_p += np.sum(np.abs(tmp_p) < tol/dt)
220
+ TP_s += np.sum(np.abs(tmp_s) < tol/dt)
221
+ diff_p.append(tmp_p[np.abs(tmp_p) < 0.5/dt])
222
+ diff_s.append(tmp_s[np.abs(tmp_s) < 0.5/dt])
223
+
224
+ return [TP_p, TP_s, nP_p, nP_s, nT_p, nT_s, diff_p, diff_s]
225
+
226
+ def calculate_metrics(picks, itp, its, tol=0.1):
227
+ TP_p, TP_s, nP_p, nP_s, nT_p, nT_s, diff_p, diff_s = correct_picks(picks, itp, its, tol)
228
+ precision_p, recall_p, f1_p = metrics(TP_p, nP_p, nT_p)
229
+ precision_s, recall_s, f1_s = metrics(TP_s, nP_s, nT_s)
230
+
231
+ logging.info("Total records: {}".format(len(picks)))
232
+ logging.info("P-phase:")
233
+ logging.info("True={}, Predict={}, TruePositive={}".format(nT_p, nP_p, TP_p))
234
+ logging.info("Precision={:.3f}, Recall={:.3f}, F1={:.3f}".format(precision_p, recall_p, f1_p))
235
+ logging.info("S-phase:")
236
+ logging.info("True={}, Predict={}, TruePositive={}".format(nT_s, nP_s, TP_s))
237
+ logging.info("Precision={:.3f}, Recall={:.3f}, F1={:.3f}".format(precision_s, recall_s, f1_s))
238
+ return [precision_p, recall_p, f1_p], [precision_s, recall_s, f1_s]
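
For context (not part of the commit), the precision/recall/F1 arithmetic implemented by metrics() above, worked on made-up counts:

    TP, nP, nT = 90, 100, 95                 # true positives, predicted picks, true picks (hypothetical)
    precision = TP / nP                      # 0.90
    recall = TP / nT                         # ~0.947
    F1 = 2 * precision * recall / (precision + recall)
    print(precision, recall, F1)             # F1 ~= 0.923
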
phasenet/visulization.py ADDED
@@ -0,0 +1,481 @@
1
+ import matplotlib
2
+ matplotlib.use("agg")
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
+ import os
6
+
7
+
8
+ def plot_residual(diff_p, diff_s, diff_ps, tol, dt):
9
+ box = dict(boxstyle='round', facecolor='white', alpha=1)
10
+ text_loc = [0.07, 0.95]
11
+ plt.figure(figsize=(8,3))
12
+ plt.subplot(1,3,1)
13
+ plt.hist(diff_p, range=(-tol, tol), bins=int(2*tol/dt)+1, facecolor='b', edgecolor='black', linewidth=1)
14
+ plt.ylabel("Number of picks")
15
+ plt.xlabel("Residual (s)")
16
+ plt.text(text_loc[0], text_loc[1], "(i)", horizontalalignment='left', verticalalignment='top',
17
+ transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
18
+ plt.title("P-phase")
19
+ plt.subplot(1,3,2)
20
+ plt.hist(diff_s, range=(-tol, tol), bins=int(2*tol/dt)+1, facecolor='b', edgecolor='black', linewidth=1)
21
+ plt.xlabel("Residual (s)")
22
+ plt.text(text_loc[0], text_loc[1], "(ii)", horizontalalignment='left', verticalalignment='top',
23
+ transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
24
+ plt.title("S-phase")
25
+ plt.subplot(1,3,3)
26
+ plt.hist(diff_ps, range=(-tol, tol), bins=int(2*tol/dt)+1, facecolor='b', edgecolor='black', linewidth=1)
27
+ plt.xlabel("Residual (s)")
28
+ plt.text(text_loc[0], text_loc[1], "(iii)", horizontalalignment='left', verticalalignment='top',
29
+ transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
30
+ plt.title("PS-phase")
31
+ plt.tight_layout()
32
+ plt.savefig("residuals.png", dpi=300)
33
+ plt.savefig("residuals.pdf")
34
+
35
+
36
+ # def plot_waveform(config, data, pred, label=None,
37
+ # itp=None, its=None, itps=None,
38
+ # itp_pred=None, its_pred=None, itps_pred=None,
39
+ # fname=None, figure_dir="./", epoch=0, max_fig=10):
40
+
41
+ # dt = config.dt if hasattr(config, "dt") else 1.0
42
+ # t = np.arange(0, pred.shape[1]) * dt
43
+ # box = dict(boxstyle='round', facecolor='white', alpha=1)
44
+ # text_loc = [0.05, 0.77]
45
+ # if fname is None:
46
+ # fname = [f"{epoch:03d}_{i:02d}" for i in range(len(data))]
47
+ # else:
48
+ # fname = [fname[i].decode().rstrip(".npz") for i in range(len(fname))]
49
+
50
+ # for i in range(min(len(data), max_fig)):
51
+ # plt.figure(i)
52
+
53
+ # plt.subplot(411)
54
+ # plt.plot(t, data[i, :, 0, 0], 'k', label='E', linewidth=0.5)
55
+ # plt.autoscale(enable=True, axis='x', tight=True)
56
+ # tmp_min = np.min(data[i, :, 0, 0])
57
+ # tmp_max = np.max(data[i, :, 0, 0])
58
+ # if (itp is not None) and (its is not None):
59
+ # for j in range(len(itp[i])):
60
+ # lb = "P" if j==0 else ""
61
+ # plt.plot([itp[i][j]*dt, itp[i][j]*dt], [tmp_min, tmp_max], 'C0', label=lb, linewidth=0.5)
62
+ # for j in range(len(its[i])):
63
+ # lb = "S" if j==0 else ""
64
+ # plt.plot([its[i][j]*dt, its[i][j]*dt], [tmp_min, tmp_max], 'C1', label=lb, linewidth=0.5)
65
+ # if (itps is not None):
66
+ # for j in range(len(itps[i])):
67
+ # lb = "PS" if j==0 else ""
68
+ # plt.plot([itps[i][j]*dt, its[i][j]*dt], [tmp_min, tmp_max], 'C2', label=lb, linewidth=0.5)
69
+ # plt.ylabel('Amplitude')
70
+ # plt.legend(loc='upper right', fontsize='small')
71
+ # plt.gca().set_xticklabels([])
72
+ # plt.text(text_loc[0], text_loc[1], '(i)', horizontalalignment='center',
73
+ # transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
74
+
75
+ # plt.subplot(412)
76
+ # plt.plot(t, data[i, :, 0, 1], 'k', label='N', linewidth=0.5)
77
+ # plt.autoscale(enable=True, axis='x', tight=True)
78
+ # tmp_min = np.min(data[i, :, 0, 1])
79
+ # tmp_max = np.max(data[i, :, 0, 1])
80
+ # if (itp is not None) and (its is not None):
81
+ # for j in range(len(itp[i])):
82
+ # lb = "P" if j==0 else ""
83
+ # plt.plot([itp[i][j]*dt, itp[i][j]*dt], [tmp_min, tmp_max], 'C0', label=lb, linewidth=0.5)
84
+ # for j in range(len(its[i])):
85
+ # lb = "S" if j==0 else ""
86
+ # plt.plot([its[i][j]*dt, its[i][j]*dt], [tmp_min, tmp_max], 'C1', label=lb, linewidth=0.5)
87
+ # if (itps is not None):
88
+ # for j in range(len(itps[i])):
89
+ # lb = "PS" if j==0 else ""
90
+ # plt.plot([itps[i][j]*dt, itps[i][j]*dt], [tmp_min, tmp_max], 'C2', label=lb, linewidth=0.5)
91
+ # plt.ylabel('Amplitude')
92
+ # plt.legend(loc='upper right', fontsize='small')
93
+ # plt.gca().set_xticklabels([])
94
+ # plt.text(text_loc[0], text_loc[1], '(ii)', horizontalalignment='center',
95
+ # transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
96
+
97
+ # plt.subplot(413)
98
+ # plt.plot(t, data[i, :, 0, 2], 'k', label='Z', linewidth=0.5)
99
+ # plt.autoscale(enable=True, axis='x', tight=True)
100
+ # tmp_min = np.min(data[i, :, 0, 2])
101
+ # tmp_max = np.max(data[i, :, 0, 2])
102
+ # if (itp is not None) and (its is not None):
103
+ # for j in range(len(itp[i])):
104
+ # lb = "P" if j==0 else ""
105
+ # plt.plot([itp[i][j]*dt, itp[i][j]*dt], [tmp_min, tmp_max], 'C0', label=lb, linewidth=0.5)
106
+ # for j in range(len(its[i])):
107
+ # lb = "S" if j==0 else ""
108
+ # plt.plot([its[i][j]*dt, its[i][j]*dt], [tmp_min, tmp_max], 'C1', label=lb, linewidth=0.5)
109
+ # if (itps is not None):
110
+ # for j in range(len(itps[i])):
111
+ # lb = "PS" if j==0 else ""
112
+ # plt.plot([itps[i][j]*dt, itps[i][j]*dt], [tmp_min, tmp_max], 'C2', label=lb, linewidth=0.5)
113
+ # plt.ylabel('Amplitude')
114
+ # plt.legend(loc='upper right', fontsize='small')
115
+ # plt.gca().set_xticklabels([])
116
+ # plt.text(text_loc[0], text_loc[1], '(iii)', horizontalalignment='center',
117
+ # transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
118
+
119
+ # plt.subplot(414)
120
+ # if label is not None:
121
+ # plt.plot(t, label[i, :, 0, 1], 'C0', label='P', linewidth=1)
122
+ # plt.plot(t, label[i, :, 0, 2], 'C1', label='S', linewidth=1)
123
+ # if label.shape[-1] == 4:
124
+ # plt.plot(t, label[i, :, 0, 3], 'C2', label='PS', linewidth=1)
125
+ # plt.plot(t, pred[i, :, 0, 1], '--C0', label='$\hat{P}$', linewidth=1)
126
+ # plt.plot(t, pred[i, :, 0, 2], '--C1', label='$\hat{S}$', linewidth=1)
127
+ # if pred.shape[-1] == 4:
128
+ # plt.plot(t, pred[i, :, 0, 3], '--C2', label='$\hat{PS}$', linewidth=1)
129
+ # plt.autoscale(enable=True, axis='x', tight=True)
130
+ # if (itp_pred is not None) and (its_pred is not None) :
131
+ # for j in range(len(itp_pred)):
132
+ # plt.plot([itp_pred[j]*dt, itp_pred[j]*dt], [-0.1, 1.1], '--C0', linewidth=1)
133
+ # for j in range(len(its_pred)):
134
+ # plt.plot([its_pred[j]*dt, its_pred[j]*dt], [-0.1, 1.1], '--C1', linewidth=1)
135
+ # if (itps_pred is not None):
136
+ # for j in range(len(itps_pred)):
137
+ # plt.plot([itps_pred[j]*dt, itps_pred[j]*dt], [-0.1, 1.1], '--C2', linewidth=1)
138
+ # plt.ylim([-0.05, 1.05])
139
+ # plt.text(text_loc[0], text_loc[1], '(iv)', horizontalalignment='center',
140
+ # transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
141
+ # plt.legend(loc='upper right', fontsize='small', ncol=2)
142
+ # plt.xlabel('Time (s)')
143
+ # plt.ylabel('Probability')
144
+ # plt.tight_layout()
145
+ # plt.gcf().align_labels()
146
+
147
+ # try:
148
+ # plt.savefig(os.path.join(figure_dir, fname[i]+'.png'), bbox_inches='tight')
149
+ # except FileNotFoundError:
150
+ # os.makedirs(os.path.dirname(os.path.join(figure_dir, fname[i])), exist_ok=True)
151
+ # plt.savefig(os.path.join(figure_dir, fname[i]+'.png'), bbox_inches='tight')
152
+
153
+ # plt.close(i)
154
+ # return 0
155
+
156
+
157
+ def plot_waveform(data, pred, fname, label=None,
158
+ itp=None, its=None, itps=None,
159
+ itp_pred=None, its_pred=None, itps_pred=None,
160
+ figure_dir="./", dt=0.01):
161
+
162
+ t = np.arange(0, pred.shape[0]) * dt
163
+ box = dict(boxstyle='round', facecolor='white', alpha=1)
164
+ text_loc = [0.05, 0.77]
165
+
166
+ plt.figure()
167
+
168
+ plt.subplot(411)
169
+ plt.plot(t, data[:, 0, 0], 'k', label='E', linewidth=0.5)
170
+ plt.autoscale(enable=True, axis='x', tight=True)
171
+ tmp_min = np.min(data[:, 0, 0])
172
+ tmp_max = np.max(data[:, 0, 0])
173
+ if (itp is not None) and (its is not None):
174
+ for j in range(len(itp)):
175
+ lb = "P" if j==0 else ""
176
+ plt.plot([itp[j]*dt, itp[j]*dt], [tmp_min, tmp_max], 'C0', label=lb, linewidth=0.5)
177
+ for j in range(len(its)):
178
+ lb = "S" if j==0 else ""
179
+ plt.plot([its[j]*dt, its[j]*dt], [tmp_min, tmp_max], 'C1', label=lb, linewidth=0.5)
180
+ if (itps is not None):
181
+ for j in range(len(itps)):
182
+ lb = "PS" if j==0 else ""
183
+ plt.plot([itps[j]*dt, itps[j]*dt], [tmp_min, tmp_max], 'C2', label=lb, linewidth=0.5)
184
+ plt.ylabel('Amplitude')
185
+ plt.legend(loc='upper right', fontsize='small')
186
+ plt.gca().set_xticklabels([])
187
+ plt.text(text_loc[0], text_loc[1], '(i)', horizontalalignment='center',
188
+ transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
189
+
190
+ plt.subplot(412)
191
+ plt.plot(t, data[:, 0, 1], 'k', label='N', linewidth=0.5)
192
+ plt.autoscale(enable=True, axis='x', tight=True)
193
+ tmp_min = np.min(data[:, 0, 1])
194
+ tmp_max = np.max(data[:, 0, 1])
195
+ if (itp is not None) and (its is not None):
196
+ for j in range(len(itp)):
197
+ lb = "P" if j==0 else ""
198
+ plt.plot([itp[j]*dt, itp[j]*dt], [tmp_min, tmp_max], 'C0', label=lb, linewidth=0.5)
199
+ for j in range(len(its)):
200
+ lb = "S" if j==0 else ""
201
+ plt.plot([its[j]*dt, its[j]*dt], [tmp_min, tmp_max], 'C1', label=lb, linewidth=0.5)
202
+ if (itps is not None):
203
+ for j in range(len(itps)):
204
+ lb = "PS" if j==0 else ""
205
+ plt.plot([itps[j]*dt, itps[j]*dt], [tmp_min, tmp_max], 'C2', label=lb, linewidth=0.5)
206
+ plt.ylabel('Amplitude')
207
+ plt.legend(loc='upper right', fontsize='small')
208
+ plt.gca().set_xticklabels([])
209
+ plt.text(text_loc[0], text_loc[1], '(ii)', horizontalalignment='center',
210
+ transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
211
+
212
+ plt.subplot(413)
213
+ plt.plot(t, data[:, 0, 2], 'k', label='Z', linewidth=0.5)
214
+ plt.autoscale(enable=True, axis='x', tight=True)
215
+ tmp_min = np.min(data[:, 0, 2])
216
+ tmp_max = np.max(data[:, 0, 2])
217
+ if (itp is not None) and (its is not None):
218
+ for j in range(len(itp)):
219
+ lb = "P" if j==0 else ""
220
+ plt.plot([itp[j]*dt, itp[j]*dt], [tmp_min, tmp_max], 'C0', label=lb, linewidth=0.5)
221
+ for j in range(len(its)):
222
+ lb = "S" if j==0 else ""
223
+ plt.plot([its[j]*dt, its[j]*dt], [tmp_min, tmp_max], 'C1', label=lb, linewidth=0.5)
224
+ if (itps is not None):
225
+ for j in range(len(itps)):
226
+ lb = "PS" if j==0 else ""
227
+ plt.plot([itps[j]*dt, itps[j]*dt], [tmp_min, tmp_max], 'C2', label=lb, linewidth=0.5)
228
+ plt.ylabel('Amplitude')
229
+ plt.legend(loc='upper right', fontsize='small')
230
+ plt.gca().set_xticklabels([])
231
+ plt.text(text_loc[0], text_loc[1], '(iii)', horizontalalignment='center',
232
+ transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
233
+
234
+ plt.subplot(414)
235
+ if label is not None:
236
+ plt.plot(t, label[:, 0, 1], 'C0', label='P', linewidth=1)
237
+ plt.plot(t, label[:, 0, 2], 'C1', label='S', linewidth=1)
238
+ if label.shape[-1] == 4:
239
+ plt.plot(t, label[:, 0, 3], 'C2', label='PS', linewidth=1)
240
+ plt.plot(t, pred[:, 0, 1], '--C0', label='$\hat{P}$', linewidth=1)
241
+ plt.plot(t, pred[:, 0, 2], '--C1', label='$\hat{S}$', linewidth=1)
242
+ if pred.shape[-1] == 4:
243
+ plt.plot(t, pred[:, 0, 3], '--C2', label='$\hat{PS}$', linewidth=1)
244
+ plt.autoscale(enable=True, axis='x', tight=True)
245
+ if (itp_pred is not None) and (its_pred is not None) :
246
+ for j in range(len(itp_pred)):
247
+ plt.plot([itp_pred[j]*dt, itp_pred[j]*dt], [-0.1, 1.1], '--C0', linewidth=1)
248
+ for j in range(len(its_pred)):
249
+ plt.plot([its_pred[j]*dt, its_pred[j]*dt], [-0.1, 1.1], '--C1', linewidth=1)
250
+ if (itps_pred is not None):
251
+ for j in range(len(itps_pred)):
252
+ plt.plot([itps_pred[j]*dt, itps_pred[j]*dt], [-0.1, 1.1], '--C2', linewidth=1)
253
+ plt.ylim([-0.05, 1.05])
254
+ plt.text(text_loc[0], text_loc[1], '(iv)', horizontalalignment='center',
255
+ transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
256
+ plt.legend(loc='upper right', fontsize='small', ncol=2)
257
+ plt.xlabel('Time (s)')
258
+ plt.ylabel('Probability')
259
+ plt.tight_layout()
260
+ plt.gcf().align_labels()
261
+
262
+ try:
263
+ plt.savefig(os.path.join(figure_dir, fname+'.png'), bbox_inches='tight')
264
+ except FileNotFoundError:
265
+ os.makedirs(os.path.dirname(os.path.join(figure_dir, fname)), exist_ok=True)
266
+ plt.savefig(os.path.join(figure_dir, fname+'.png'), bbox_inches='tight')
267
+
268
+ plt.close()
269
+ return 0
270
+
271
+
272
+ def plot_array(config, data, pred, label=None,
273
+ itp=None, its=None, itps=None,
274
+ itp_pred=None, its_pred=None, itps_pred=None,
275
+ fname=None, figure_dir="./", epoch=0):
276
+
277
+ dt = config.dt if hasattr(config, "dt") else 1.0
278
+ t = np.arange(0, pred.shape[1]) * dt
279
+ box = dict(boxstyle='round', facecolor='white', alpha=1)
280
+ text_loc = [0.05, 0.95]
281
+ if fname is None:
282
+ fname = [f"{epoch:03d}_{i:03d}" for i in range(len(data))]
283
+ else:
284
+ fname = [fname[i].decode().rstrip(".npz") for i in range(len(fname))]
285
+
286
+ for i in range(len(data)):
287
+ plt.figure(i, figsize=(10, 5))
288
+ plt.clf()
289
+
290
+ plt.subplot(121)
291
+ for j in range(data.shape[-2]):
292
+ plt.plot(t, data[i, :, j, 0]/10 + j, 'k', label='E', linewidth=0.5)
293
+ plt.autoscale(enable=True, axis='x', tight=True)
294
+ tmp_min = np.min(data[i, :, 0, 0])
295
+ tmp_max = np.max(data[i, :, 0, 0])
296
+ plt.xlabel('Time (s)')
297
+ plt.ylabel('Amplitude')
298
+ # plt.legend(loc='upper right', fontsize='small')
299
+ # plt.gca().set_xticklabels([])
300
+ plt.text(text_loc[0], text_loc[1], '(i)', horizontalalignment='center', verticalalignment="top",
301
+ transform=plt.gca().transAxes, fontsize="large", fontweight="normal", bbox=box)
302
+
303
+ plt.subplot(122)
304
+ for j in range(pred.shape[-2]):
305
+ if label is not None:
306
+ plt.plot(t, label[i, :, j, 1]+j, 'C2', label='P', linewidth=0.5)
307
+ plt.plot(t, label[i, :, j, 2]+j, 'C3', label='S', linewidth=0.5)
308
+ # plt.plot(t, label[i, :, j, 0]+j, 'C4', label='N', linewidth=0.5)
309
+ plt.plot(t, pred[i, :, j, 1]+j, 'C0', label='$\hat{P}$', linewidth=1)
310
+ plt.plot(t, pred[i, :, j, 2]+j, 'C1', label='$\hat{S}$', linewidth=1)
311
+ plt.autoscale(enable=True, axis='x', tight=True)
312
+ if (itp_pred is not None) and (its_pred is not None) and (itps_pred is not None):
313
+ for j in range(len(itp_pred)):
314
+ plt.plot([itp_pred[j]*dt, itp_pred[j]*dt], [-0.1, 1.1], '--C0', linewidth=1)
315
+ for j in range(len(its_pred)):
316
+ plt.plot([its_pred[j]*dt, its_pred[j]*dt], [-0.1, 1.1], '--C1', linewidth=1)
317
+ for j in range(len(itps_pred)):
318
+ plt.plot([itps_pred[j]*dt, itps_pred[j]*dt], [-0.1, 1.1], '--C2', linewidth=1)
319
+ # plt.ylim([-0.05, 1.05])
320
+ plt.text(text_loc[0], text_loc[1], '(ii)', horizontalalignment='center', verticalalignment="top",
321
+ transform=plt.gca().transAxes, fontsize="large", fontweight="normal", bbox=box)
322
+ # plt.legend(loc='upper right', fontsize='small', ncol=2)
323
+ plt.xlabel('Time (s)')
324
+ plt.ylabel('Probability')
325
+ plt.tight_layout()
326
+ plt.gcf().align_labels()
327
+
328
+ try:
329
+ plt.savefig(os.path.join(figure_dir, fname[i]+'.png'), bbox_inches='tight')
330
+ except FileNotFoundError:
331
+ os.makedirs(os.path.dirname(os.path.join(figure_dir, fname[i])), exist_ok=True)
332
+ plt.savefig(os.path.join(figure_dir, fname[i]+'.png'), bbox_inches='tight')
333
+
334
+ plt.close(i)
335
+ return 0
336
+
337
+
338
+ def plot_spectrogram(config, data, pred, label=None,
339
+ itp=None, its=None, itps=None,
340
+ itp_pred=None, its_pred=None, itps_pred=None,
341
+ time=None, freq=None,
342
+ fname=None, figure_dir="./", epoch=0):
343
+
344
+ # dt = config.dt
345
+ # df = config.df
346
+ # t = np.arange(0, data.shape[1]) * dt
347
+ # f = np.arange(0, data.shape[2]) * df
348
+ t, f = time, freq
349
+ dt = t[1] - t[0]
350
+ box = dict(boxstyle='round', facecolor='white', alpha=1)
351
+ text_loc = [0.05, 0.75]
352
+ if fname is None:
353
+ fname = [f"{i:03d}" for i in range(len(data))]
354
+ elif type(fname[0]) is bytes:
355
+ fname = [f.decode() for f in fname]
356
+
357
+ numbers = ["(i)", "(ii)", "(iii)", "(iv)"]
358
+ for i in range(len(data)):
359
+ fig = plt.figure(i)
360
+ # gs = fig.add_gridspec(4, 1)
361
+
362
+ for j in range(3):
363
+ # fig.add_subplot(gs[j, 0])
364
+ plt.subplot(4,1,j+1)
365
+ plt.pcolormesh(t, f, np.abs(data[i, :, :, j]+1j*data[i, :, :, j+3]).T, vmax=2*np.std(data[i, :, :, j]+1j*data[i, :, :, j+3]), cmap="jet", shading='auto')
366
+ plt.autoscale(enable=True, axis='x', tight=True)
367
+ plt.gca().set_xticklabels([])
368
+ if j == 1:
369
+ plt.ylabel('Frequency (Hz)')
370
+ plt.text(text_loc[0], text_loc[1], numbers[j], horizontalalignment='center',
371
+ transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
372
+
373
+ # fig.add_subplot(gs[-1, 0])
374
+ plt.subplot(4,1,4)
375
+ if label is not None:
376
+ plt.plot(t, label[i, :, 0, 1], '--C0', linewidth=1)
377
+ plt.plot(t, label[i, :, 0, 2], '--C3', linewidth=1)
378
+ plt.plot(t, label[i, :, 0, 3], '--C1', linewidth=1)
379
+ plt.plot(t, pred[i, :, 0, 1], 'C0', label='P', linewidth=1)
380
+ plt.plot(t, pred[i, :, 0, 2], 'C3', label='S', linewidth=1)
381
+ plt.plot(t, pred[i, :, 0, 3], 'C1', label='PS', linewidth=1)
382
+ plt.plot(t, t*0, 'k', linewidth=1)
383
+ plt.autoscale(enable=True, axis='x', tight=True)
384
+ if (itp_pred is not None) and (its_pred is not None) and (itps_pred is not None):
385
+ for j in range(len(itp_pred)):
386
+ plt.plot([itp_pred[j]*dt, itp_pred[j]*dt], [-0.1, 1.1], ':C3', linewidth=1)
387
+ for j in range(len(its_pred)):
388
+ plt.plot([its_pred[j]*dt, its_pred[j]*dt], [-0.1, 1.1], '-.C6', linewidth=1)
389
+ for j in range(len(itps_pred)):
390
+ plt.plot([itps_pred[j]*dt, itps_pred[j]*dt], [-0.1, 1.1], '--C8', linewidth=1)
391
+ plt.ylim([-0.05, 1.05])
392
+ plt.text(text_loc[0], text_loc[1], numbers[-1], horizontalalignment='center',
393
+ transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
394
+ plt.legend(loc='upper right', fontsize='small', ncol=1)
395
+ plt.xlabel('Time (s)')
396
+ plt.ylabel('Probability')
397
+ # plt.tight_layout()
398
+ plt.gcf().align_labels()
399
+
400
+ try:
401
+ plt.savefig(os.path.join(figure_dir, f'{epoch:02d}_'+fname[i]+'.png'), bbox_inches='tight')
402
+ except FileNotFoundError:
403
+ os.makedirs(os.path.dirname(os.path.join(figure_dir, fname[i])), exist_ok=True)
404
+ plt.savefig(os.path.join(figure_dir, f'{epoch:02d}_'+fname[i]+'.png'), bbox_inches='tight')
405
+
406
+ plt.close(i)
407
+ return 0
408
+
409
+
410
+ def plot_spectrogram_waveform(config, spectrogram, waveform, pred, label=None,
411
+ itp=None, its=None, itps=None, picks=None,
412
+ time=None, freq=None,
413
+ fname=None, figure_dir="./", epoch=0):
414
+
415
+ # dt = config.dt
416
+ # df = config.df
417
+ # t = np.arange(0, spectrogram.shape[1]) * dt
418
+ # f = np.arange(0, spectrogram.shape[2]) * df
419
+ t, f = time, freq
420
+ dt = t[1] - t[0]
421
+ box = dict(boxstyle='round', facecolor='white', alpha=1)
422
+ text_loc = [0.02, 0.90]
423
+ if fname is None:
424
+ fname = [f"{i:03d}" for i in range(len(spectrogram))]
425
+ elif type(fname[0]) is bytes:
426
+ fname = [f.decode() for f in fname]
427
+
428
+ numbers = ["(i)", "(ii)", "(iii)", "(iv)", "(v)", "(vi)", "(vii)"]
429
+ for i in range(len(spectrogram)):
430
+ fig = plt.figure(i, figsize=(6.4, 10))
431
+ # gs = fig.add_gridspec(4, 1)
432
+
433
+ for j in range(3):
434
+ # fig.add_subplot(gs[j, 0])
435
+ plt.subplot(7,1,j*2+1)
436
+ plt.plot(waveform[i,:,j], 'k', linewidth=0.5)
437
+ plt.autoscale(enable=True, axis='x', tight=True)
438
+ plt.gca().set_xticklabels([])
439
+ plt.ylabel('')
440
+ plt.text(text_loc[0], text_loc[1], numbers[j*2], horizontalalignment='left', verticalalignment='top',
441
+ transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
442
+
443
+ for j in range(3):
444
+ # fig.add_subplot(gs[j, 0])
445
+ plt.subplot(7,1,j*2+2)
446
+ plt.pcolormesh(t, f, np.abs(spectrogram[i, :, :, j]+1j*spectrogram[i, :, :, j+3]).T, vmax=2*np.std(spectrogram[i, :, :, j]+1j*spectrogram[i, :, :, j+3]), cmap="jet", shading='auto')
447
+ plt.autoscale(enable=True, axis='x', tight=True)
448
+ plt.gca().set_xticklabels([])
449
+ if j == 1:
450
+ plt.ylabel('Frequency (Hz) or Amplitude')
451
+ plt.text(text_loc[0], text_loc[1], numbers[j*2+1], horizontalalignment='left', verticalalignment='top',
452
+ transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
453
+
454
+ # fig.add_subplot(gs[-1, 0])
455
+ plt.subplot(7,1,7)
456
+ if label is not None:
457
+ plt.plot(t, label[i, :, 0, 1], '--C0', linewidth=1)
458
+ plt.plot(t, label[i, :, 0, 2], '--C3', linewidth=1)
459
+ plt.plot(t, label[i, :, 0, 3], '--C1', linewidth=1)
460
+ plt.plot(t, pred[i, :, 0, 1], 'C0', label='P', linewidth=1)
461
+ plt.plot(t, pred[i, :, 0, 2], 'C3', label='S', linewidth=1)
462
+ plt.plot(t, pred[i, :, 0, 3], 'C1', label='PS', linewidth=1)
463
+ plt.plot(t, t*0, 'k', linewidth=1)
464
+ plt.autoscale(enable=True, axis='x', tight=True)
465
+ plt.ylim([-0.05, 1.05])
466
+ plt.text(text_loc[0], text_loc[1], numbers[-1], horizontalalignment='left', verticalalignment='top',
467
+ transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
468
+ plt.legend(loc='upper right', fontsize='small', ncol=1)
469
+ plt.xlabel('Time (s)')
470
+ plt.ylabel('Probability')
471
+ # plt.tight_layout()
472
+ plt.gcf().align_labels()
473
+
474
+ try:
475
+ plt.savefig(os.path.join(figure_dir, f'{epoch:02d}_'+fname[i]+'.png'), bbox_inches='tight')
476
+ except FileNotFoundError:
477
+ os.makedirs(os.path.dirname(os.path.join(figure_dir, fname[i])), exist_ok=True)
478
+ plt.savefig(os.path.join(figure_dir, f'{epoch:02d}_'+fname[i]+'.png'), bbox_inches='tight')
479
+
480
+ plt.close(i)
481
+ return 0
pipeline.py CHANGED
@@ -2,6 +2,12 @@ from typing import Dict, List
2
  import numpy as np
3
  import tensorflow as tf
4
 
5
  class PreTrainedPipeline():
6
  def __init__(self, path=""):
7
  # IMPLEMENT_THIS
@@ -11,7 +17,23 @@ class PreTrainedPipeline():
11
  # raise NotImplementedError(
12
  # "Please implement PreTrainedPipeline __init__ function"
13
  # )
14
- pass
15
 
16
  def __call__(self, inputs: str) -> List[List[Dict[str, float]]]:
17
  """
@@ -27,4 +49,20 @@ class PreTrainedPipeline():
27
  # raise NotImplementedError(
28
  # "Please implement PreTrainedPipeline __call__ function"
29
  # )
30
- return [[{"label": inputs, "score":0.2}]]
2
  import numpy as np
3
  import tensorflow as tf
4
 
5
+ from phasenet.model import ModelConfig, UNet
6
+ from phasenet.postprocess import extract_picks
7
+
8
+ tf.compat.v1.disable_eager_execution()
9
+ tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
10
+
11
  class PreTrainedPipeline():
12
  def __init__(self, path=""):
13
  # IMPLEMENT_THIS
 
17
  # raise NotImplementedError(
18
  # "Please implement PreTrainedPipeline __init__ function"
19
  # )
20
+
21
+ ## load model
22
+ model = UNet(mode="pred")
23
+ sess_config = tf.compat.v1.ConfigProto()
24
+ sess_config.gpu_options.allow_growth = True
25
+
26
+ sess = tf.compat.v1.Session(config=sess_config)
27
+ saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
28
+ init = tf.compat.v1.global_variables_initializer()
29
+ sess.run(init)
30
+ latest_check_point = tf.train.latest_checkpoint(f"model/190703-214543")
31
+ print(f"restoring model {latest_check_point}")
32
+ saver.restore(sess, latest_check_point)
33
+
34
+ ##
35
+ self.sess = sess
36
+ self.model = model
37
 
38
  def __call__(self, inputs: str) -> List[List[Dict[str, float]]]:
39
  """
 
49
  # raise NotImplementedError(
50
  # "Please implement PreTrainedPipeline __call__ function"
51
  # )
52
+
53
+ vec = np.array(inputs)[np.newaxis, :, np.newaxis, :]
54
+
55
+ feed = {self.model.X: vec, self.model.drop_rate: 0, self.model.is_training: False}
56
+ preds = self.sess.run(self.model.preds, feed_dict=feed)
57
+
58
+ picks = extract_picks(preds)#, station_ids=data.id, begin_times=data.timestamp, waveforms=vec_raw)
59
+
60
+ # picks = [{k: v for k, v in pick.items() if k in ["station_id", "phase_time", "phase_score", "phase_type", "dt"]} for pick in picks]
61
+
62
+ return picks
63
+
64
+
65
+ if __name__ == "__main__":
66
+ pipeline = PreTrainedPipeline()
67
+ inputs = np.random.rand(1000, 3).tolist()
68
+ picks = pipeline(inputs)
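
A closing note on the pipeline output (not part of the commit): judging from the commented-out filter above, each element of picks is a dict that contains at least phase_time, phase_score and phase_type, so a caller could post-process the result roughly as below; the key names are taken from that comment and are therefore an assumption, and the checkpoint under model/190703-214543 must be present.

    import numpy as np
    from pipeline import PreTrainedPipeline

    pipe = PreTrainedPipeline()                     # restores model/190703-214543
    waveform = np.random.randn(3000, 3).tolist()    # placeholder 3-component record
    for pick in pipe(waveform):
        # "phase_type" and "phase_score" are assumed keys, per the commented filter above
        print(pick.get("phase_type"), pick.get("phase_score"))
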
requirements.txt CHANGED
@@ -1 +1 @@
1
- tensorflow
 
1
+ tensorflow