# EQNet / phasenet / postprocess.py (author: zhuwq0; commit 0eb79a8, "init")
import json
import logging
import os
from collections import namedtuple
from datetime import datetime, timedelta
import matplotlib.pyplot as plt
import numpy as np
from detect_peaks import detect_peaks
# def extract_picks(preds, fnames=None, station_ids=None, t0=None, config=None):
# if preds.shape[-1] == 4:
# record = namedtuple("phase", ["fname", "station_id", "t0", "p_idx", "p_prob", "s_idx", "s_prob", "ps_idx", "ps_prob"])
# else:
# record = namedtuple("phase", ["fname", "station_id", "t0", "p_idx", "p_prob", "s_idx", "s_prob"])
# picks = []
# for i, pred in enumerate(preds):
# if config is None:
# mph_p, mph_s, mpd = 0.3, 0.3, 50
# else:
# mph_p, mph_s, mpd = config.min_p_prob, config.min_s_prob, config.mpd
# if (fnames is None):
# fname = f"{i:04d}"
# else:
# if isinstance(fnames[i], str):
# fname = fnames[i]
# else:
# fname = fnames[i].decode()
# if (station_ids is None):
# station_id = f"{i:04d}"
# else:
# if isinstance(station_ids[i], str):
# station_id = station_ids[i]
# else:
# station_id = station_ids[i].decode()
# if (t0 is None):
# start_time = "1970-01-01T00:00:00.000"
# else:
# if isinstance(t0[i], str):
# start_time = t0[i]
# else:
# start_time = t0[i].decode()
# p_idx, p_prob, s_idx, s_prob = [], [], [], []
# for j in range(pred.shape[1]):
# p_idx_, p_prob_ = detect_peaks(pred[:,j,1], mph=mph_p, mpd=mpd, show=False)
# s_idx_, s_prob_ = detect_peaks(pred[:,j,2], mph=mph_s, mpd=mpd, show=False)
# p_idx.append(list(p_idx_))
# p_prob.append(list(p_prob_))
# s_idx.append(list(s_idx_))
# s_prob.append(list(s_prob_))
# if pred.shape[-1] == 4:
# ps_idx, ps_prob = detect_peaks(pred[:,0,3], mph=0.3, mpd=mpd, show=False)
# picks.append(record(fname, station_id, start_time, list(p_idx), list(p_prob), list(s_idx), list(s_prob), list(ps_idx), list(ps_prob)))
# else:
# picks.append(record(fname, station_id, start_time, list(p_idx), list(p_prob), list(s_idx), list(s_prob)))
# return picks
def extract_picks(
    preds,
    file_names=None,
    begin_times=None,
    station_ids=None,
    dt=0.01,
    phases=("P", "S"),
    config=None,
    waveforms=None,
    use_amplitude=False,
):
    """Extract phase picks from prediction scores.

    Args:
        preds: array unpacked as [Nb, Nt, Ns, Nc] ("batch, time, station,
            channel"). Channel 0 is skipped (noise); channel k + 1 is scored
            as ``phases[k]``.
        file_names: one name per batch entry (list / ndarray of str or bytes),
            or a single str / bytes that is repeated for every entry.
            Defaults to "0000", "0001", ...
        begin_times: one ISO-8601 start time per batch entry (str or bytes),
            or a single str / bytes applied to every entry. Defaults to the
            Unix epoch in UTC.
        station_ids: indexed as ``station_ids[i][j]`` (batch, station); str or
            bytes. When None, or when ``station_ids[i]`` is empty, the
            station index formatted as "0000", "0001", ... is used instead.
        dt: sampling interval in seconds, used to turn sample indices into
            times and seconds into sample counts.
        phases: phase names for prediction channels 1..Nc-1.
        config: optional object providing ``min_p_prob``, ``min_s_prob``,
            ``mpd`` and ``post_sec``. Without it, thresholds of 0.3,
            ``mpd=50`` and a 4 s post window are used.
        waveforms: optional array indexed as [Nb, Nt, Ns, channel]; only read
            when ``use_amplitude`` is True.
        use_amplitude: when True and ``waveforms`` is given, attach a
            "phase_amplitude" entry to every pick.

    Returns:
        list of dicts with keys file_name, station_id, begin_time,
        phase_index, phase_time, phase_score, phase_type, dt and, optionally,
        phase_amplitude.
    """
    default_mph = 0.3
    if config is None:
        mph = {phase: default_mph for phase in phases}
        mpd = 50
        post_idx = int(4 / dt)
    else:
        mph = {"P": config.min_p_prob, "S": config.min_s_prob, "PS": default_mph}
        mpd = config.mpd
        post_idx = int(config.post_sec / dt)

    Nb, Nt, Ns, Nc = preds.shape

    def _text(value):
        return value.decode() if isinstance(value, bytes) else value

    if file_names is None:
        file_names = [f"{i:04d}" for i in range(Nb)]
    elif isinstance(file_names, (np.ndarray, list)):
        file_names = [_text(x) for x in file_names]
    else:
        # A single name is shared by the whole batch.
        file_names = [_text(file_names)] * Nb

    if begin_times is None:
        begin_times = ["1970-01-01T00:00:00.000+00:00"] * Nb
    elif isinstance(begin_times, (str, bytes)):
        # Iterating a bare string would yield single characters, so broadcast
        # a scalar start time the same way a scalar file name is broadcast.
        begin_times = [_text(begin_times)] * Nb
    else:
        begin_times = [_text(x) for x in begin_times]

    picks = []
    for i in range(Nb):
        file_name = file_names[i]
        begin_time = datetime.fromisoformat(begin_times[i])
        begin_time_str = begin_time.isoformat(timespec="milliseconds")

        for j in range(Ns):
            if (station_ids is None) or (len(station_ids[i]) == 0):
                station_id = f"{j:04d}"
            else:
                station_id = _text(station_ids[i][j])

            amp = None
            if (waveforms is not None) and use_amplitude:
                # Per-sample maximum absolute value across the last axis.
                amp = np.max(np.abs(waveforms[i, :, j, :]), axis=-1)

            for k in range(Nc - 1):  # 0-th channel is noise
                phase = phases[k]
                idxs, probs = detect_peaks(
                    preds[i, :, j, k + 1], mph=mph.get(phase, default_mph), mpd=mpd, show=False
                )
                last = len(idxs) - 1
                for l, (phase_index, phase_prob) in enumerate(zip(idxs, probs)):
                    # float() keeps timedelta happy even for NumPy float32 products.
                    pick_time = begin_time + timedelta(seconds=float(phase_index * dt))
                    pick = {
                        "file_name": file_name,
                        "station_id": station_id,
                        "begin_time": begin_time_str,
                        "phase_index": int(phase_index),
                        "phase_time": pick_time.isoformat(timespec="milliseconds"),
                        # Native float so the pick serialises like the other fields.
                        "phase_score": round(float(phase_prob), 3),
                        "phase_type": phase,
                        "dt": dt,
                    }

                    if amp is not None:
                        # Peak amplitude between this pick and whichever comes
                        # first: three post windows later or the next pick.
                        window_end = phase_index + post_idx * 3
                        if l < last:
                            window_end = min(window_end, idxs[l + 1])
                        pick["phase_amplitude"] = np.max(amp[phase_index:window_end]).item()

                    picks.append(pick)

    return picks
def extract_amplitude(data, picks, window_p=10, window_s=5, config=None):
    """Measure the peak amplitude following every P and S pick.

    Args:
        data: iterable of waveform arrays, each indexed as
            [time, station, channel].
        picks: iterable of records (paired with ``data``) exposing ``p_idx``
            and ``s_idx``, each indexable by station and holding sample
            indices.
        window_p: length in seconds of the search window after a P pick.
        window_s: length in seconds of the search window after an S pick.
        config: optional object providing ``dt``; 0.01 s is used without it.

    Returns:
        list of ``amplitude(p_amp, s_amp)`` namedtuples, one per record; each
        field is a list (per station) of lists (per pick) of peak values.
    """
    amplitude = namedtuple("amplitude", ["p_amp", "s_amp"])
    sample_dt = config.dt if config is not None else 0.01
    n_win_p = int(window_p / sample_dt)
    n_win_s = int(window_s / sample_dt)

    def _peaks(envelope, onsets, width):
        # Each window starts at its pick and ends after `width` samples or at
        # the following pick, whichever is earlier; the last pick always gets
        # the full width.
        total = len(onsets)
        peaks = []
        for pos in range(total):
            start = onsets[pos]
            stop = start + width
            if pos + 1 < total:
                stop = min(stop, onsets[pos + 1])
            peaks.append(np.max(envelope[start:stop]))
        return peaks

    results = []
    for waveform, record in zip(data, picks):
        p_peaks = []
        s_peaks = []
        for station in range(waveform.shape[1]):
            # Per-sample maximum absolute value over the channel axis.
            envelope = np.max(np.abs(waveform[:, station, :]), axis=-1)
            p_peaks.append(_peaks(envelope, record.p_idx[station], n_win_p))
            s_peaks.append(_peaks(envelope, record.s_idx[station], n_win_s))
        results.append(amplitude(p_peaks, s_peaks))
    return results
def save_picks(picks, output_dir, amps=None, fname=None):
    """Write per-record picks to a tab-separated text file.

    Args:
        picks: sequence of records exposing ``fname``, ``t0``, ``p_idx``,
            ``p_prob``, ``s_idx`` and ``s_prob`` (and optionally ``ps_idx`` /
            ``ps_prob``). The index / probability fields are nested
            sequences: one inner sequence per station.
        output_dir: directory the file is written to.
        amps: optional sequence, paired with ``picks``, of records exposing
            ``p_amp`` and ``s_amp`` in the same nested layout. When given,
            amplitude columns are written and any ``ps_*`` fields are ignored.
        fname: output file name; defaults to "picks.csv".

    Returns:
        0 on completion. An empty ``picks`` produces a header-only file.
    """
    if fname is None:
        fname = "picks.csv"

    def _nested(groups, fmt):
        # [[1, 2], [3]] -> "[1,2],[3]"
        return ",".join("[" + ",".join(fmt(v) for v in group) + "]" for group in groups)

    def _base(pick):
        return [
            f"{pick.fname}",
            f"{pick.t0}",
            _nested(pick.p_idx, str),
            _nested(pick.p_prob, "{:0.3f}".format),
            _nested(pick.s_idx, str),
            _nested(pick.s_prob, "{:0.3f}".format),
        ]

    header = ["fname", "t0", "p_idx", "p_prob", "s_idx", "s_prob"]
    if amps is not None:
        header += ["p_amp", "s_amp"]
        rows = (
            _base(pick) + [_nested(amp.p_amp, "{:0.3e}".format), _nested(amp.s_amp, "{:0.3e}".format)]
            for pick, amp in zip(picks, amps)
        )
    elif len(picks) > 0 and hasattr(picks[0], "ps_idx"):
        # The first record decides whether the PS columns are present.
        header += ["ps_idx", "ps_prob"]
        rows = (
            _base(pick) + [_nested(pick.ps_idx, str), _nested(pick.ps_prob, "{:0.3f}".format)]
            for pick in picks
        )
    else:
        rows = (_base(pick) for pick in picks)

    # The context manager closes the file; no explicit close() is needed.
    with open(os.path.join(output_dir, fname), "w") as fp:
        fp.write("\t".join(header) + "\n")
        fp.writelines("\t".join(row) + "\n" for row in rows)
    return 0
def calc_timestamp(timestamp, sec):
    """Shift a "%Y-%m-%dT%H:%M:%S.%f" timestamp string by ``sec`` seconds.

    Returns the shifted time in the same layout, truncated to milliseconds.
    """
    layout = "%Y-%m-%dT%H:%M:%S.%f"
    shifted = datetime.strptime(timestamp, layout) + timedelta(seconds=sec)
    rendered = shifted.strftime(layout)
    # %f renders six digits (microseconds); drop the last three of them.
    return rendered[:-3]
def save_picks_json(picks, output_dir, dt=0.01, amps=None, fname=None):
    """Flatten per-record picks into one JSON list and write it to disk.

    Args:
        picks: iterable of records exposing ``station_id``, ``t0`` (a
            timestamp string accepted by ``calc_timestamp``) and the nested
            fields ``p_idx``, ``p_prob``, ``s_idx``, ``s_prob`` (one inner
            sequence per station).
        output_dir: directory the file is written to.
        dt: sampling interval in seconds; a pick's time offset is
            ``index * dt``.
        amps: optional iterable, paired with ``picks``, of records exposing
            ``p_amp`` and ``s_amp`` in the same nested layout. When given,
            every entry also carries an "amp" value.
        fname: output file name; defaults to "picks.json".

    Returns:
        0 on completion.
    """
    if fname is None:
        fname = "picks.json"

    if amps is None:
        paired = ((pick, None) for pick in picks)
    else:
        paired = zip(picks, amps)

    records = []
    for pick, amplitude in paired:
        # All P picks of a record are written before its S picks.
        for phase in ("p", "s"):
            columns = [getattr(pick, phase + "_idx"), getattr(pick, phase + "_prob")]
            if amplitude is not None:
                columns.append(getattr(amplitude, phase + "_amp"))
            for per_station in zip(*columns):
                for values in zip(*per_station):
                    record = {
                        "id": pick.station_id,
                        "timestamp": calc_timestamp(pick.t0, float(values[0]) * dt),
                        # float() accepts NumPy scalars and plain Python numbers alike.
                        "prob": float(values[1]),
                    }
                    if amplitude is not None:
                        record["amp"] = float(values[2])
                    record["type"] = phase
                    records.append(record)

    with open(os.path.join(output_dir, fname), "w") as fp:
        json.dump(records, fp)
    return 0
def convert_true_picks(fname, itp, its, itps=None):
    """Bundle reference pick indices into one "phase" record per file.

    Args:
        fname: sequence of file names, as bytes or str (bytes are decoded).
        itp: per-file P indices, indexed in step with ``fname``.
        its: per-file S indices, indexed in step with ``fname``.
        itps: optional per-file PS indices; when given, each record gains a
            ``ps_idx`` field.

    Returns:
        list of namedtuples with fields (fname, p_idx, s_idx[, ps_idx]).
    """
    fields = ["fname", "p_idx", "s_idx"]
    sources = [itp, its]
    if itps is not None:
        fields.append("ps_idx")
        sources.append(itps)
    record = namedtuple("phase", fields)

    true_picks = []
    for i in range(len(fname)):
        name = fname[i]
        # Names may arrive as raw bytes or already as text.
        if isinstance(name, bytes):
            name = name.decode()
        true_picks.append(record(name, *(source[i] for source in sources)))
    return true_picks
def calc_metrics(nTP, nP, nT):
    """Compute precision, recall and F1 from pick counts.

    Args:
        nTP: number of true positives.
        nP: number of positive (predicted) picks.
        nT: number of true picks.

    Returns:
        [precision, recall, f1]. A ratio whose denominator is zero is
        reported as 0.0 instead of raising or producing NaN.
    """
    precision = nTP / nP if nP else 0.0
    recall = nTP / nT if nT else 0.0
    total = precision + recall
    f1 = 2 * precision * recall / total if total else 0.0
    return [precision, recall, f1]
def calc_performance(picks, true_picks, tol=3.0, dt=1.0):
    """Score predicted picks against reference picks, phase by phase.

    Args:
        picks: sequence of predicted records; for every non-"fname" field of
            the reference records it must expose an attribute of the same
            name holding one sequence of indices per station.
        true_picks: sequence of reference namedtuples (same length as
            ``picks``) in the same per-station layout.
        tol: largest absolute residual (after scaling by ``dt``) that still
            counts as a match.
        dt: factor applied to index differences before comparing with ``tol``.

    Returns:
        dict mapping each phase field name to the result of ``calc_metrics``.
        Empty when there are no records.

    Raises:
        ValueError: if a record's predicted and reference picks cover a
            different number of stations.
    """
    assert len(picks) == len(true_picks)
    logging.info("Total records: {}".format(len(picks)))

    metrics = {}
    if len(true_picks) == 0:
        # No record to read the phase field names from.
        return metrics

    def count(per_station):
        return sum(len(x) for x in per_station)

    for phase in true_picks[0]._fields:
        if phase == "fname":
            continue
        true_positive, positive, true = 0, 0, 0
        residual = []
        for pred_record, true_record in zip(picks, true_picks):
            pred_stations = getattr(pred_record, phase)
            true_stations = getattr(true_record, phase)
            true += count(true_stations)
            positive += count(pred_stations)
            if len(pred_stations) != len(true_stations):
                raise ValueError(
                    f"{phase}: {len(pred_stations)} predicted vs {len(true_stations)} true stations"
                )
            # Compare station by station so stations may hold different
            # numbers of picks (one rectangular array would not allow that).
            for pred_idx, true_idx in zip(pred_stations, true_stations):
                # diff[t, p] = dt * (predicted[p] - true[t]) for every pair.
                diff = dt * (
                    np.asarray(pred_idx, dtype=float)[np.newaxis, :]
                    - np.asarray(true_idx, dtype=float)[:, np.newaxis]
                )
                matched = np.abs(diff) <= tol
                residual.extend(list(diff[matched]))
                true_positive += np.sum(matched)
        metrics[phase] = calc_metrics(true_positive, positive, true)

        # Avoid NumPy's empty-slice warning when nothing matched.
        if residual:
            residual_mean, residual_std = np.mean(residual), np.std(residual)
        else:
            residual_mean = residual_std = float("nan")

        logging.info(f"{phase}-phase:")
        logging.info(f"True={true}, Positive={positive}, True Positive={true_positive}")
        logging.info(f"Precision={metrics[phase][0]:.3f}, Recall={metrics[phase][1]:.3f}, F1={metrics[phase][2]:.3f}")
        logging.info(f"Residual mean={residual_mean:.4f}, std={residual_std:.4f}")
    return metrics
def save_prob_h5(probs, fnames, output_h5):
    """Store each probability array as a float32 dataset in ``output_h5``.

    Args:
        probs: sequence of arrays to store.
        fnames: matching names (str or bytes), or None to use "0000",
            "0001", ... A trailing ".npz" extension is removed from a name.
        output_h5: object providing ``create_dataset(name, data=, dtype=)``.

    Returns:
        0 on completion.
    """

    def _dataset_name(name):
        if isinstance(name, bytes):
            name = name.decode()
        # Remove the extension as a suffix; str.rstrip(".npz") would instead
        # strip any trailing run of the characters ".", "n", "p" and "z".
        if name.endswith(".npz"):
            name = name[: -len(".npz")]
        return name

    if fnames is None:
        names = [f"{i:04d}" for i in range(len(probs))]
    else:
        names = [_dataset_name(f) for f in fnames]

    for prob, name in zip(probs, names):
        output_h5.create_dataset(name, data=prob, dtype="float32")
    return 0
def save_prob(probs, fnames, prob_dir):
    """Save each probability array to ``<prob_dir>/<name>.npz`` under key "prob".

    Args:
        probs: sequence of arrays to store.
        fnames: matching names (str or bytes), or None to use "0000",
            "0001", ... A trailing ".npz" extension is removed before the
            output name is rebuilt.
        prob_dir: directory the files are written to.

    Returns:
        0 on completion.
    """

    def _stem(name):
        if isinstance(name, bytes):
            name = name.decode()
        # Remove the extension as a suffix; str.rstrip(".npz") would instead
        # strip any trailing run of the characters ".", "n", "p" and "z".
        if name.endswith(".npz"):
            name = name[: -len(".npz")]
        return name

    if fnames is None:
        stems = [f"{i:04d}" for i in range(len(probs))]
    else:
        stems = [_stem(f) for f in fnames]

    for prob, stem in zip(probs, stems):
        np.savez(os.path.join(prob_dir, stem + ".npz"), prob=prob)
    return 0