|
from typing import Dict, List |
|
import numpy as np |
|
import tensorflow as tf |
|
import os |
|
import json |
|
from phasenet.model import ModelConfig, UNet |
|
from phasenet.postprocess import extract_picks |
|
|
|
# The pipeline below feeds placeholders through Session.run(feed_dict=...),
# which needs TF1-style graph mode, so eager execution is switched off at
# import time, before any graph is built.
_tf_v1 = tf.compat.v1
_tf_v1.disable_eager_execution()
# Suppress TensorFlow log messages below ERROR severity.
_tf_v1.logging.set_verbosity(_tf_v1.logging.ERROR)
del _tf_v1
|
|
|
class PreTrainedPipeline:
    """Inference wrapper around a pretrained PhaseNet ``UNet`` checkpoint.

    The TF1-style graph and session are built once in ``__init__``. Each call
    parses a JSON-encoded waveform array, normalizes it, runs the network and
    returns the picks extracted from the predictions.
    """

    def __init__(self, path="", checkpoint_dir="model/190703-214543"):
        """Build the UNet graph and restore the latest checkpoint.

        Args:
            path: Base directory containing ``checkpoint_dir``. The default
                ``""`` resolves ``checkpoint_dir`` relative to the current
                working directory.
            checkpoint_dir: Checkpoint directory, relative to ``path``.

        Raises:
            ValueError: If no checkpoint is found in ``path/checkpoint_dir``.
        """
        ckpt_dir = os.path.join(path, checkpoint_dir)
        # Locate the checkpoint before opening a Session so that a missing
        # model fails fast with a clear message and no session is leaked.
        latest_check_point = tf.train.latest_checkpoint(ckpt_dir)
        if latest_check_point is None:
            raise ValueError(f"no checkpoint found in {ckpt_dir!r}")

        # Start from a fresh default graph so that constructing the pipeline
        # again does not pile up variables in the same graph.
        tf.compat.v1.reset_default_graph()
        model = UNet(mode="pred")

        sess_config = tf.compat.v1.ConfigProto()
        # Grow GPU memory on demand rather than reserving it all up front.
        sess_config.gpu_options.allow_growth = True
        sess = tf.compat.v1.Session(config=sess_config)
        try:
            saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
            sess.run(tf.compat.v1.global_variables_initializer())
            print(f"restoring model {latest_check_point}")
            saver.restore(sess, latest_check_point)
        except Exception:
            # Do not leave a half-initialized session open on failure.
            sess.close()
            raise

        self.sess = sess
        self.model = model

    def __call__(self, inputs: str) -> List[List[Dict[str, object]]]:
        """Run phase picking on one JSON-encoded waveform array.

        Args:
            inputs: JSON string encoding a nested numeric array. A 2-D array
                is laid out as in the ``__main__`` demo, with rows as samples
                and columns as channels. 3-D input gets a leading batch axis.
                4-D input is used as-is (assumed already batched).

        Returns:
            A one-element list wrapping the list of picks. Each pick is a
            dict with the keys ``phase_index``, ``phase_score`` and
            ``phase_type``, copied from the output of ``extract_picks``. The
            value types are whatever ``extract_picks`` produces.

        Raises:
            ValueError: If ``inputs`` is not valid JSON, or the decoded array
                has an unsupported number of dimensions.
        """
        vec = np.asarray(json.loads(inputs))
        vec = self.reshape_input(vec)
        vec = self.normalize(vec)

        feed = {self.model.X: vec, self.model.drop_rate: 0, self.model.is_training: False}
        preds = self.sess.run(self.model.preds, feed_dict=feed)

        picks = extract_picks(preds)
        # Expose only these fields of each pick to the caller.
        keys = ("phase_index", "phase_score", "phase_type")
        picks = [{k: x[k] for k in keys} for x in picks]
        return [picks]

    def normalize(self, vec):
        """Standardize ``vec`` to zero mean and unit std along axis 1.

        After ``reshape_input``, axis 1 is the sample axis, so each channel
        is normalized over time. Slices with zero standard deviation use a
        divisor of 1, which means they are only mean-centred; this avoids a
        division by zero.

        Returns:
            A new array. ``vec`` itself is not modified.
        """
        mu = np.mean(vec, axis=1, keepdims=True)
        std = np.std(vec, axis=1, keepdims=True)
        std[std == 0] = 1.0
        return (vec - mu) / std

    def reshape_input(self, vec):
        """Bring ``vec`` to the 4-D batched layout fed to the model.

        - 2-D ``(samples, channels)`` becomes ``(1, samples, 1, channels)``.
        - 3-D input gets a leading batch axis.
        - 4-D input is returned unchanged.

        Raises:
            ValueError: For any other number of dimensions. Previously such
                input failed later with a less clear error.
        """
        vec = np.asarray(vec)
        if vec.ndim == 2:
            return vec[np.newaxis, :, np.newaxis, :]
        if vec.ndim == 3:
            return vec[np.newaxis, :, :, :]
        if vec.ndim == 4:
            return vec
        raise ValueError(
            f"expected a 2-D, 3-D or 4-D waveform array, got shape {vec.shape}"
        )
|
|
|
|
|
if __name__ == "__main__":
    # Demo: pick phases on the stream returned by a bare obspy.read()
    # (presumably ObsPy's bundled example data).
    import obspy

    stream = obspy.read()
    # One row per trace, then transpose to (samples, traces).
    samples = np.array([trace.data for trace in stream]).T

    pipe = PreTrainedPipeline()
    payload = json.dumps(samples.tolist())
    print(pipe(payload))
|
|