PhaseNet / pipeline.py
zhuwq0's picture
update
1b6e4e8
raw
history blame
No virus
3.49 kB
from typing import Dict, List
import numpy as np
import tensorflow as tf
import os
import json
from phasenet.model import ModelConfig, UNet
from phasenet.postprocess import extract_picks
# The pipeline below builds and runs the model with tf.compat.v1 graph /
# Session APIs, so switch TensorFlow to graph mode, and only emit log
# messages at ERROR level or above.
_tf_v1 = tf.compat.v1
_tf_v1.disable_eager_execution()
_tf_v1.logging.set_verbosity(_tf_v1.logging.ERROR)
class PreTrainedPipeline:
    """Inference wrapper around a pretrained PhaseNet ``UNet`` in TF1 graph mode.

    The graph, session and weights are set up once in ``__init__``. Each call
    decodes a JSON-encoded waveform array, reshapes and normalizes it, runs the
    model, and returns the picks produced by ``extract_picks``.
    """

    def __init__(self, path="", checkpoint_dir="model/190703-214543"):
        """Build the model graph and restore the pretrained weights.

        This is called once, so all heavy I/O happens here.

        Args:
            path: Base directory containing ``checkpoint_dir``. The default
                ``""`` resolves relative to the current working directory.
            checkpoint_dir: Checkpoint directory, relative to ``path``.

        Raises:
            ValueError: If no checkpoint can be found in the directory.
        """
        ckpt_dir = os.path.join(path, checkpoint_dir)
        latest_check_point = tf.train.latest_checkpoint(ckpt_dir)
        if latest_check_point is None:
            # Fail fast with a descriptive message instead of passing None
            # to Saver.restore; nothing has been allocated yet at this point.
            raise ValueError(f"no TensorFlow checkpoint found in {ckpt_dir!r}")

        # Build the model into a freshly cleared default graph.
        tf.compat.v1.reset_default_graph()
        model = UNet(mode="pred")

        sess_config = tf.compat.v1.ConfigProto()
        # Grow GPU memory usage on demand instead of reserving it all upfront.
        sess_config.gpu_options.allow_growth = True
        sess = tf.compat.v1.Session(config=sess_config)

        saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
        sess.run(tf.compat.v1.global_variables_initializer())
        print(f"restoring model {latest_check_point}")
        saver.restore(sess, latest_check_point)

        self.sess = sess
        self.model = model

    def __call__(self, inputs: str) -> List[List[Dict[str, object]]]:
        """Run phase picking on a JSON-encoded waveform array.

        Args:
            inputs: JSON string encoding a nested numeric list with shape
                (nt, nch), (nt, nsta, nch) or (nb, nt, nsta, nch); see
                ``reshape_input``.

        Returns:
            A one-element list wrapping the list of picks. Each pick is a dict
            with exactly the keys ``"phase_index"``, ``"phase_score"`` and
            ``"phase_type"``, copied from the output of ``extract_picks``.

        Raises:
            ValueError: If the decoded array has an unsupported rank.
        """
        vec = np.asarray(json.loads(inputs))
        vec = self.reshape_input(vec)  # (nb, nt, nsta, nch)
        vec = self.normalize(vec)
        feed = {
            self.model.X: vec,
            self.model.drop_rate: 0,
            self.model.is_training: False,
        }
        preds = self.sess.run(self.model.preds, feed_dict=feed)
        picks = extract_picks(preds)
        # Keep only the fields returned to the caller; extract_picks may
        # provide more.
        picks = [
            {key: pick[key] for key in ("phase_index", "phase_score", "phase_type")}
            for pick in picks
        ]
        return [picks]

    def normalize(self, vec):
        """Standardize ``vec`` to zero mean and unit std along axis 1.

        With the (nb, nt, nsta, nch) layout produced by ``reshape_input``,
        axis 1 is the ``nt`` axis. A slice with zero standard deviation is
        divided by 1 instead, so it becomes all zeros rather than NaN.
        """
        mu = np.mean(vec, axis=1, keepdims=True)
        std = np.std(vec, axis=1, keepdims=True)
        std[std == 0] = 1.0
        return (vec - mu) / std

    def reshape_input(self, vec):
        """Promote ``vec`` to the 4-D (nb, nt, nsta, nch) layout.

        - 2-D (nt, nch)       -> (1, nt, 1, nch)
        - 3-D (nt, nsta, nch) -> (1, nt, nsta, nch)
        - 4-D                 -> returned unchanged

        Raises:
            ValueError: For any other rank. Such arrays cannot be normalized
                along axis 1 or fed to the model, so reject them up front.
        """
        ndim = len(vec.shape)
        if ndim == 2:
            return vec[np.newaxis, :, np.newaxis, :]
        if ndim == 3:
            return vec[np.newaxis, :, :, :]
        if ndim == 4:
            return vec
        raise ValueError(
            f"expected a 2-D, 3-D or 4-D waveform array, got shape {tuple(vec.shape)}"
        )
if __name__ == "__main__":
    # Smoke test: run the pipeline end to end on the stream returned by
    # obspy.read() with no arguments.
    import obspy

    stream = obspy.read()
    # Stack each trace's samples, then transpose so there is one column per trace.
    samples = np.array([trace.data for trace in stream]).T
    pipeline = PreTrainedPipeline()
    picks = pipeline(json.dumps(samples.tolist()))
    print(picks)