PhaseNet / pipeline.py
zhuwq0's picture
upload phasenet
7b07ad9
raw
history blame
No virus
2.66 kB
import os
from typing import Dict, List

import numpy as np
import tensorflow as tf

from phasenet.model import ModelConfig, UNet
from phasenet.postprocess import extract_picks
# The pipeline below drives the model through TF1 graph-mode APIs
# (tf.compat.v1.Session + feed_dict), so eager execution is switched off at
# import time, before any model graph is built.
tf.compat.v1.disable_eager_execution()
# Only let TensorFlow's v1 logger emit ERROR-level (and above) messages.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
class PreTrainedPipeline:
    """Inference wrapper around a pre-trained PhaseNet ``UNet`` model.

    The model graph is built and the checkpoint restored once, in
    ``__init__``; each call then runs one forward pass and hands the raw
    network output to ``phasenet.postprocess.extract_picks``.
    """

    # Checkpoint directory, resolved relative to the ``path`` given to __init__.
    CHECKPOINT_SUBDIR = "model/190703-214543"

    def __init__(self, path=""):
        """Build the UNet graph and restore its trained weights.

        Args:
            path: Directory containing the ``model/`` checkpoint folder. The
                default ``""`` resolves the checkpoint relative to the current
                working directory.

        Raises:
            ValueError: If no checkpoint exists in the checkpoint directory.
        """
        checkpoint_dir = os.path.join(path, self.CHECKPOINT_SUBDIR)
        # Look the checkpoint up first so a bad path fails fast, before the
        # graph and session are built, with a message naming the directory.
        latest_check_point = tf.train.latest_checkpoint(checkpoint_dir)
        if latest_check_point is None:
            raise ValueError(f"no checkpoint found in {checkpoint_dir!r}")

        # Build everything in a private graph: the Saver collects *all* global
        # variables, so sharing the default graph with another pipeline in the
        # same process would mix their variables.
        self.graph = tf.Graph()
        with self.graph.as_default():
            model = UNet(mode="pred")
            saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
            init = tf.compat.v1.global_variables_initializer()

        sess_config = tf.compat.v1.ConfigProto()
        # Allocate GPU memory on demand rather than reserving it all up front.
        sess_config.gpu_options.allow_growth = True
        sess = tf.compat.v1.Session(graph=self.graph, config=sess_config)
        sess.run(init)

        print(f"restoring model {latest_check_point}")
        saver.restore(sess, latest_check_point)

        self.sess = sess
        self.model = model

    def close(self):
        """Release the TensorFlow session held by this pipeline."""
        self.sess.close()

    def __call__(self, inputs: List[List[float]]) -> List[List[Dict[str, float]]]:
        """Run the model on one multi-channel waveform and extract picks.

        Args:
            inputs: Waveform samples as a 2-D array-like of shape
                ``(n_samples, n_channels)``; the ``__main__`` demo passes a
                ``1000 x 3`` nested list.

        Returns:
            The value returned by ``extract_picks`` for the network output.
            NOTE(review): the exact structure is defined by
            ``phasenet.postprocess.extract_picks`` - confirm there; the return
            annotation above is inherited from the pipeline template.

        Raises:
            ValueError: If ``inputs`` is not a non-empty 2-D array.
        """
        waveform = np.asarray(inputs)
        if waveform.ndim != 2 or waveform.size == 0:
            raise ValueError(
                "inputs must be a non-empty 2-D (n_samples, n_channels) array, "
                f"got shape {waveform.shape}"
            )
        # Insert singleton axes at positions 0 and 2:
        # (n_samples, n_channels) -> (1, n_samples, 1, n_channels).
        vec = waveform[np.newaxis, :, np.newaxis, :]
        # Inference mode: no dropout, training flag off.
        feed = {self.model.X: vec, self.model.drop_rate: 0, self.model.is_training: False}
        preds = self.sess.run(self.model.preds, feed_dict=feed)
        # NOTE(review): an earlier draft hinted at also passing station_ids,
        # begin_times and waveforms to extract_picks.
        return extract_picks(preds)
if __name__ == "__main__":
    # Smoke test: feed 1000 samples of uniform random 3-channel noise through
    # a pipeline built with the default (working-directory) checkpoint path.
    pipeline = PreTrainedPipeline()
    noise = np.random.rand(1000, 3)
    picks = pipeline(noise.tolist())