File size: 2,655 Bytes
68ba412
 
 
 
7b07ad9
 
 
 
 
 
68ba412
 
 
 
 
 
 
 
 
7b07ad9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68ba412
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7b07ad9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import os
from typing import  Dict, List

import numpy as np
import tensorflow as tf

from phasenet.model import ModelConfig, UNet
from phasenet.postprocess import extract_picks

# The pipeline below drives the model through tf.compat.v1 Session.run with a
# feed_dict, which requires graph mode, so eager execution is switched off
# globally at import time.
tf.compat.v1.disable_eager_execution()
# Only surface TensorFlow messages at ERROR level or above.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

class PreTrainedPipeline:
    """Inference wrapper around a pre-trained PhaseNet ``UNet`` checkpoint.

    The graph is built and the checkpoint restored once in ``__init__``; each
    ``__call__`` runs one forward pass and hands the network predictions to
    ``extract_picks``.
    """

    def __init__(self, path="", checkpoint_dir="model/190703-214543"):
        """Build the UNet graph, open a session and restore the checkpoint.

        Args:
            path: Directory the pipeline was loaded from. ``checkpoint_dir`` is
                resolved relative to it. The default ``""`` (or ``None``)
                resolves against the current working directory, as before.
            checkpoint_dir: Checkpoint directory relative to ``path``; an
                absolute path is used as-is.

        Raises:
            ValueError: if no checkpoint is found in the resolved directory.
        """
        # Build the inference graph in the default TF graph.
        model = UNet(mode="pred")
        sess_config = tf.compat.v1.ConfigProto()
        # Allocate GPU memory on demand rather than reserving it all up front.
        sess_config.gpu_options.allow_growth = True

        sess = tf.compat.v1.Session(config=sess_config)
        saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
        sess.run(tf.compat.v1.global_variables_initializer())

        # Resolve the checkpoint relative to ``path`` instead of the CWD, so the
        # pipeline works when loaded from another directory.
        ckpt_dir = os.path.join(path or "", checkpoint_dir)
        latest_check_point = tf.train.latest_checkpoint(ckpt_dir)
        if latest_check_point is None:
            # latest_checkpoint returns None when nothing is found; fail with a
            # clear message and release the session instead of leaking it.
            sess.close()
            raise ValueError(f"No TensorFlow checkpoint found in {ckpt_dir!r}")
        print(f"restoring model {latest_check_point}")
        saver.restore(sess, latest_check_point)

        self.sess = sess
        self.model = model

    def __call__(self, inputs: str) -> List[List[Dict[str, float]]]:
        """Run phase picking on a single waveform.

        Args:
            inputs: Samples that ``np.array`` turns into a 2-D array. The
                module's ``__main__`` example passes a ``(1000, 3)`` nested
                list. NOTE(review): presumably (n_samples, n_channels) with
                three components; confirm against the model config. Despite
                the ``str`` annotation (kept for compatibility), a string is
                not a valid input.

        Returns:
            The result of ``extract_picks`` on the network predictions.
            NOTE(review): the ``List[List[Dict[...]]]`` annotation comes from a
            template and has not been verified against ``extract_picks``.
        """
        # (n, c) -> (1, n, 1, c): add a batch axis and a singleton axis
        # between time and channels to match the model's input placeholder.
        vec = np.array(inputs)[np.newaxis, :, np.newaxis, :]

        # Inference mode: no dropout, training flag off.
        feed = {self.model.X: vec, self.model.drop_rate: 0, self.model.is_training: False}
        preds = self.sess.run(self.model.preds, feed_dict=feed)

        return extract_picks(preds)


if __name__ == "__main__":
    # Smoke run: pick phases on 1000 samples of 3-channel uniform random noise
    # and show the result (previously it was computed and discarded).
    pipeline = PreTrainedPipeline()
    inputs = np.random.rand(1000, 3).tolist()
    picks = pipeline(inputs)
    print(picks)