|
import os
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Any, AnyStr, Dict, List, NamedTuple, Optional, Union

import numpy as np
import tensorflow as tf
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from scipy.interpolate import interp1d

from postprocess import extract_picks
from model import UNet
|
|
|
# Repository root: the parent of the directory that holds this file.
PROJECT_ROOT = os.path.realpath(os.path.join(os.path.dirname(__file__), os.pardir))

# The model below is driven through tf.compat.v1.Session, so graph mode is
# selected before it is built; TF logging is limited to errors.
tf.compat.v1.disable_eager_execution()
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

# Loose JSON type aliases (not referenced by the code visible here).
JSONObject = Dict[AnyStr, Any]
JSONArray = List[Any]
JSONStructure = Union[JSONArray, JSONObject]

app = FastAPI()

# Window layout used by format_data(): X_SHAPE[0] samples in time and
# X_SHAPE[-1] components; the middle axis is a singleton.
X_SHAPE = [3000, 1, 3]
# Samples per second; format_data() reports dt = 1 / SAMPLING_RATE.
SAMPLING_RATE = 100
|
|
|
|
|
# Build the prediction graph once at import time and restore trained weights
# into a module-level session shared by get_prediction().
model = UNet(mode="pred")

sess_config = tf.compat.v1.ConfigProto()
# Let TensorFlow grow GPU memory usage as needed instead of reserving it upfront.
sess_config.gpu_options.allow_growth = True

sess = tf.compat.v1.Session(config=sess_config)
saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
init = tf.compat.v1.global_variables_initializer()
sess.run(init)

_CHECKPOINT_DIR = f"{PROJECT_ROOT}/model/190703-214543"
latest_check_point = tf.train.latest_checkpoint(_CHECKPOINT_DIR)
if latest_check_point is None:
    # latest_checkpoint() returns None when nothing is found; fail with a
    # message that names the directory instead of restoring from None.
    raise FileNotFoundError(f"no TensorFlow checkpoint found in {_CHECKPOINT_DIR}")
print(f"restoring model {latest_check_point}")
saver.restore(sess, latest_check_point)
|
|
|
|
|
def normalize_batch(data, window=3000):
    """Normalize traces by a sliding-window mean and standard deviation.

    data: nsta, nt, nch

    Statistics are taken over ``window``-sample windows around anchor points
    spaced ``window // 2`` samples apart, computed on a reflect-padded copy of
    the time axis.
    They are then linearly interpolated back to every sample. A zero standard
    deviation is replaced by 1, so flat stretches are only de-meaned.
    Returns a new array with the same shape as ``data``.
    """
    half = window // 2
    n_sta, n_t, n_ch = data.shape

    # Pad the time axis so a full window exists around each anchor.
    padded = np.pad(data, ((0, 0), (half, half), (0, 0)), mode="reflect")

    anchors = np.arange(0, n_t, half, dtype="int")
    # One knot per anchor plus a trailing knot placed at n_t.
    stat_shape = [n_sta, len(anchors) + 1, n_ch]
    win_std = np.zeros(stat_shape)
    win_mean = np.zeros(stat_shape)
    for k in range(1, len(anchors)):
        start = k * half
        segment = padded[:, start : start + window, :]
        win_std[:, k, :] = np.std(segment, axis=1)
        win_mean[:, k, :] = np.mean(segment, axis=1)
    knots = np.append(anchors, n_t)

    # The first and the trailing knot are never computed above; copy their
    # neighbours' values into them.
    for stats in (win_std, win_mean):
        stats[:, -1, :] = stats[:, -2, :]
        stats[:, 0, :] = stats[:, 1, :]
    win_std[win_std == 0] = 1

    samples = np.arange(n_t, dtype="int")
    std_curve = interp1d(knots, win_std, axis=1, kind="slinear")(samples)
    mean_curve = interp1d(knots, win_mean, axis=1, kind="slinear")(samples)
    return (data - mean_curve) / std_curve
|
|
|
|
|
def preprocess(data):
    """Normalize ``data`` and keep an untouched copy for amplitude lookups.

    Returns ``(normalized, raw)``. When the normalized array is 3-D, a
    singleton axis is inserted at position 2 in both arrays.
    """
    original = data.copy()
    normalized = normalize_batch(data)
    if normalized.ndim == 3:
        normalized = np.expand_dims(normalized, axis=2)
        original = np.expand_dims(original, axis=2)
    return normalized, original
|
|
|
|
|
def calc_timestamp(timestamp, sec):
    """Shift a ``%Y-%m-%dT%H:%M:%S.%f`` timestamp string by ``sec`` seconds.

    The result uses the same format truncated to millisecond precision.
    """
    time_fmt = "%Y-%m-%dT%H:%M:%S.%f"
    shifted = datetime.strptime(timestamp, time_fmt) + timedelta(seconds=sec)
    # %f renders six microsecond digits; dropping three leaves milliseconds.
    return shifted.strftime(time_fmt)[:-3]
|
|
|
|
|
def format_picks(picks, dt, amplitudes):
    """Flatten pick objects and their amplitudes into a list of dicts.

    For each ``(pick, amplitude)`` pair, all P entries are emitted first, then
    all S entries. Each entry carries ``id`` (``pick.fname``), ``timestamp``
    (``pick.t0`` shifted by ``idx * dt`` seconds), ``prob``, ``amp`` and
    ``type`` (``"p"`` or ``"s"``).
    """
    formatted = []
    for pick, amplitude in zip(picks, amplitudes):
        phases = (
            ("p", pick.p_idx, pick.p_prob, amplitude.p_amp),
            ("s", pick.s_idx, pick.s_prob, amplitude.s_amp),
        )
        for phase, idx_groups, prob_groups, amp_groups in phases:
            for group in zip(idx_groups, prob_groups, amp_groups):
                for idx, prob, amp in zip(*group):
                    formatted.append(
                        {
                            "id": pick.fname,
                            "timestamp": calc_timestamp(pick.t0, float(idx) * dt),
                            "prob": prob,
                            "amp": amp,
                            "type": phase,
                        }
                    )
    return formatted
|
|
|
|
|
def format_data(data):
    """Group per-channel traces by station into fixed-length 3-component windows.

    ``data`` must provide parallel sequences ``id``, ``timestamp`` and ``vec``.
    ``id[i][:-1]`` is the station key and ``id[i][-1]`` the component code
    (E/N/Z or 3/2/1). ``timestamp[i]`` is the trace start time in
    ``%Y-%m-%dT%H:%M:%S.%f`` format, and ``vec[i]`` holds the samples.

    Returns a ``data`` NamedTuple with one entry per station key:

    - ``id``: the key.
    - ``timestamp``: the earliest start among that station's traces, at
      millisecond precision.
    - ``vec``: an ``X_SHAPE[0] x X_SHAPE[-1]`` nested list. Each component is
      de-meaned over the part that fits and placed at its sample offset from
      the earliest start; samples past the window are dropped.
    - ``dt``: ``1 / SAMPLING_RATE``.

    Raises KeyError for an unknown component code and ValueError when a
    timestamp does not match the expected format.
    """
    time_fmt = "%Y-%m-%dT%H:%M:%S.%f"
    chn2idx = {"E": 0, "N": 1, "Z": 2, "3": 0, "2": 1, "1": 2}
    n_samples = X_SHAPE[0]
    # Local result type, named so it does not shadow the module-level Data model.
    FormattedData = NamedTuple("data", [("id", list), ("timestamp", list), ("vec", list), ("dt", float)])

    chn_ = defaultdict(list)
    t0_ = defaultdict(list)
    vv_ = defaultdict(list)
    for i, trace_id in enumerate(data.id):
        key = trace_id[:-1]
        chn_[key].append(trace_id[-1])
        # Keep start times as datetimes so sample offsets and the reported start
        # are exact (no float-epoch round trip, no local-time/DST conversion).
        t0_[key].append(datetime.strptime(data.timestamp[i], time_fmt))
        vv_[key].append(np.array(data.vec[i]))

    id_, timestamp_, vec_ = [], [], []
    for key, channels in chn_.items():
        start = min(t0_[key])
        id_.append(key)
        timestamp_.append(start.strftime(time_fmt)[:-3])
        vec = np.zeros([n_samples, X_SHAPE[-1]])
        for channel, t0, trace in zip(channels, t0_[key], vv_[key]):
            col = chn2idx[channel]
            # Round rather than truncate so floating-point error cannot move a
            # trace one sample early.
            shift = round((t0 - start).total_seconds() * SAMPLING_RATE)
            segment = trace[: max(n_samples - shift, 0)]
            if segment.size == 0:
                # The trace starts at or after the end of the window: nothing fits.
                continue
            vec[shift : shift + len(segment), col] = segment - np.mean(segment)
        vec_.append(vec.tolist())

    return FormattedData(id=id_, timestamp=timestamp_, vec=vec_, dt=1 / SAMPLING_RATE)
|
|
|
|
|
|
|
def get_prediction(data, return_preds=False):
    """Run the restored model on ``data`` and extract phase picks.

    ``data.vec`` is converted with ``np.array`` and normalized by
    ``preprocess``. The model output is handed to ``extract_picks`` together
    with ``data.id``, ``data.timestamp`` and the un-normalized waveforms.
    Each pick is reduced to the keys ``station_id``, ``phase_time``,
    ``phase_score``, ``phase_type`` and ``dt``.

    Returns the pick list, or ``(picks, preds)`` when ``return_preds`` is true.
    """
    normalized, raw = preprocess(np.array(data.vec))
    preds = sess.run(
        model.preds,
        feed_dict={model.X: normalized, model.drop_rate: 0, model.is_training: False},
    )

    extracted = extract_picks(preds, station_ids=data.id, begin_times=data.timestamp, waveforms=raw)
    wanted = ("station_id", "phase_time", "phase_score", "phase_type", "dt")
    picks = [{key: value for key, value in pick.items() if key in wanted} for pick in extracted]

    return (picks, preds) if return_preds else picks
|
|
|
|
|
class Data(BaseModel):
    # Request body shared by /predict, /predict_prob and the /ws stream.
    # NOTE(review): get_prediction() feeds np.array(vec) through
    # preprocess()/normalize_batch(), which unpack it as (nsta, nt, nch). Only
    # the 3-D form of ``vec`` works on that path; confirm whether the 2-D form
    # is meant for another caller.

    # Per-station ids; passed to extract_picks() as ``station_ids``.
    id: List[List[str]]
    # Per-station start times; passed to extract_picks() as ``begin_times``.
    timestamp: List[Union[str, float, datetime]]
    # Waveform samples (see the shape note above).
    vec: Union[List[List[List[float]]], List[List[float]]]

    # Sample interval; not read by get_prediction(). Presumably seconds
    # (0.01 matches SAMPLING_RATE = 100); confirm.
    dt: Optional[float] = 0.01

    # Optional station metadata and settings; not read by any code visible here.
    stations: Optional[List[Dict[str, Union[float, str]]]] = None
    config: Optional[Dict[str, Union[List[float], List[int], List[str], float, int, str]]] = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/predict")
def predict(data: Data):
    # Picks only; /predict_prob additionally returns the raw model output.
    result = get_prediction(data, return_preds=False)
    return result
|
|
|
|
|
@app.post("/predict_prob", operation_id="predict_predict_prob_post")
def predict_prob(data: Data):
    """Return ``[picks, probabilities]`` for ``data``.

    ``probabilities`` is the raw model output converted to nested lists.
    """
    # A distinct function name keeps this handler from rebinding the module
    # name ``predict`` used by /predict. operation_id pins the OpenAPI
    # operationId FastAPI derived from the previous name, so generated clients
    # keep working.
    picks, preds = get_prediction(data, return_preds=True)
    return picks, preds.tolist()
|
|
|
|
|
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Serve predictions over a WebSocket until the client disconnects.

    Each incoming JSON message is validated as ``Data`` and passed to
    ``get_prediction``. The resulting picks are sent back as JSON.
    """
    await websocket.accept()
    try:
        while True:
            payload = await websocket.receive_json()
            # NOTE(review): a payload that fails validation raises here and
            # closes the connection.
            data = Data(**payload)
            # get_prediction blocks on sess.run; run it in a worker thread so
            # the event loop keeps serving other clients meanwhile.
            picks = await run_in_threadpool(get_prediction, data)
            await websocket.send_json(picks)
            print("PhaseNet Updating...")
    except WebSocketDisconnect:
        # Client hung up: a normal end of the stream, not an error.
        return
|
|
|
|
|
@app.get("/healthz")
def healthz():
    # Liveness check: always reports ok once the app is serving requests.
    status = {"status": "ok"}
    return status
|
|
|
|
|
@app.get("/")
def greet_json():
    # Root endpoint: a fixed greeting payload.
    return dict(Hello="PhaseNet!")
|
|