# EQNet / phasenet / data_reader.py
import tensorflow as tf
tf.compat.v1.disable_eager_execution()
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
import logging
import os
import numpy as np
import pandas as pd
pd.options.mode.chained_assignment = None
import json
import random
from collections import defaultdict
try:
    import s3fs  # only required for format == "s3"
except ImportError:
    s3fs = None
import h5py
import obspy
from scipy.interpolate import interp1d
from tqdm import tqdm
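# The two helpers below adapt an index-able Python reader (anything with __len__ and
# __getitem__) into a tf.data pipeline: item access goes through tf.numpy_function,
# so file I/O and preprocessing stay in plain Python/NumPy while TensorFlow handles
# shuffling, parallel mapping, and batching.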
def py_func_decorator(output_types=None, output_shapes=None, name=None):
def decorator(func):
def call(*args, **kwargs):
nonlocal output_shapes
# flat_output_types = nest.flatten(output_types)
flat_output_types = tf.nest.flatten(output_types)
# flat_values = tf.py_func(
flat_values = tf.numpy_function(func, inp=args, Tout=flat_output_types, name=name)
if output_shapes is not None:
for v, s in zip(flat_values, output_shapes):
v.set_shape(s)
# return nest.pack_sequence_as(output_types, flat_values)
return tf.nest.pack_sequence_as(output_types, flat_values)
return call
return decorator
def dataset_map(iterator, output_types, output_shapes=None, num_parallel_calls=None, name=None, shuffle=False):
dataset = tf.data.Dataset.range(len(iterator))
if shuffle:
dataset = dataset.shuffle(len(iterator), reshuffle_each_iteration=True)
@py_func_decorator(output_types, output_shapes, name=name)
def index_to_entry(idx):
return iterator[idx]
return dataset.map(index_to_entry, num_parallel_calls=num_parallel_calls)
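# Minimal usage sketch: `reader` is a hypothetical object with __len__ and __getitem__
# returning (waveform, label, fname) NumPy/str values, which is how the DataReader
# classes below plug into tf.data.
#
#   dataset = dataset_map(
#       reader,
#       output_types=("float32", "float32", "string"),
#       output_shapes=([3000, 1, 3], [3000, 1, 3], None),
#       num_parallel_calls=2,
#       shuffle=True,
#   )
#   dataset = dataset.batch(16, drop_remainder=True)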
def normalize(data, axis=(0,)):
"""data shape: (nt, nsta, nch)"""
data -= np.mean(data, axis=axis, keepdims=True)
std_data = np.std(data, axis=axis, keepdims=True)
std_data[std_data == 0] = 1
data /= std_data
# data /= (std_data + 1e-12)
return data
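# Example (sketch): normalize a single-station window of shape (nt, nsta, nch); the
# reduction over axis 0 standardizes each station/channel trace independently.
#
#   x = np.random.randn(3000, 1, 3).astype("float32")
#   x = normalize(x)  # zero mean, unit std along the time axis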
def normalize_long(data, axis=(0,), window=3000):
"""
    data shape: (nt, nsta, nch)
"""
nt, nar, nch = data.shape
if window is None:
window = nt
shift = window // 2
dtype = data.dtype
    ## mean/std in sliding windows
data_pad = np.pad(data, ((window // 2, window // 2), (0, 0), (0, 0)), mode="reflect")
t = np.arange(0, nt, shift, dtype="int")
std = np.zeros([len(t) + 1, nar, nch])
mean = np.zeros([len(t) + 1, nar, nch])
for i in range(1, len(std)):
std[i, :] = np.std(data_pad[i * shift : i * shift + window, :, :], axis=axis)
mean[i, :] = np.mean(data_pad[i * shift : i * shift + window, :, :], axis=axis)
t = np.append(t, nt)
# std[-1, :] = np.std(data_pad[-window:, :], axis=0)
# mean[-1, :] = np.mean(data_pad[-window:, :], axis=0)
std[-1, ...], mean[-1, ...] = std[-2, ...], mean[-2, ...]
std[0, ...], mean[0, ...] = std[1, ...], mean[1, ...]
# std[std == 0] = 1.0
    ## normalize data with interpolated mean/std
t_interp = np.arange(nt, dtype="int")
std_interp = interp1d(t, std, axis=0, kind="slinear")(t_interp)
# std_interp = np.exp(interp1d(t, np.log(std), axis=0, kind="slinear")(t_interp))
mean_interp = interp1d(t, mean, axis=0, kind="slinear")(t_interp)
tmp = np.sum(std_interp, axis=(0, 1))
std_interp[std_interp == 0] = 1.0
data = (data - mean_interp) / std_interp
# data = (data - mean_interp)/(std_interp + 1e-12)
    ### compensate when fewer than 3 channels contain data
nonzero = np.count_nonzero(tmp)
if (nonzero < 3) and (nonzero > 0):
data *= 3.0 / nonzero
return data.astype(dtype)
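# normalize_long and normalize_batch compute mean and std in half-overlapping windows
# (default 3000 samples), interpolate them to every time step with scipy's interp1d,
# and normalize with the interpolated curves, so long continuous records are not
# dominated by a single high-amplitude event. normalize_batch applies the same idea
# per station across a batch.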
def normalize_batch(data, window=3000):
"""
    data shape: (nsta, nt, nar, nch); nar is typically 1
"""
nsta, nt, nar, nch = data.shape
if window is None:
window = nt
shift = window // 2
    ## mean/std in sliding windows
data_pad = np.pad(data, ((0, 0), (window // 2, window // 2), (0, 0), (0, 0)), mode="reflect")
t = np.arange(0, nt, shift, dtype="int")
std = np.zeros([nsta, len(t) + 1, nar, nch])
mean = np.zeros([nsta, len(t) + 1, nar, nch])
for i in range(1, len(t)):
std[:, i, :, :] = np.std(data_pad[:, i * shift : i * shift + window, :, :], axis=1)
mean[:, i, :, :] = np.mean(data_pad[:, i * shift : i * shift + window, :, :], axis=1)
t = np.append(t, nt)
# std[:, -1, :] = np.std(data_pad[:, -window:, :], axis=1)
# mean[:, -1, :] = np.mean(data_pad[:, -window:, :], axis=1)
std[:, -1, :, :], mean[:, -1, :, :] = std[:, -2, :, :], mean[:, -2, :, :]
std[:, 0, :, :], mean[:, 0, :, :] = std[:, 1, :, :], mean[:, 1, :, :]
# std[std == 0] = 1
    ## normalize data with interpolated mean/std
t_interp = np.arange(nt, dtype="int")
std_interp = interp1d(t, std, axis=1, kind="slinear")(t_interp)
# std_interp = np.exp(interp1d(t, np.log(std), axis=1, kind="slinear")(t_interp))
mean_interp = interp1d(t, mean, axis=1, kind="slinear")(t_interp)
tmp = np.sum(std_interp, axis=(1, 2))
std_interp[std_interp == 0] = 1.0
data = (data - mean_interp) / std_interp
# data = (data - mean_interp)/(std_interp + 1e-12)
    ### compensate when fewer than 3 channels contain data
nonzero = np.count_nonzero(tmp, axis=-1)
data[nonzero > 0, ...] *= 3.0 / nonzero[nonzero > 0][:, np.newaxis, np.newaxis, np.newaxis]
return data
class DataConfig:
seed = 123
use_seed = True
n_channel = 3
n_class = 3
sampling_rate = 100
dt = 1.0 / sampling_rate
X_shape = [3000, 1, n_channel]
Y_shape = [3000, 1, n_class]
min_event_gap = 3 * sampling_rate
label_shape = "gaussian"
label_width = 30
dtype = "float32"
def __init__(self, **kwargs):
for k, v in kwargs.items():
setattr(self, k, v)
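# DataConfig fields can be overridden per run, e.g. DataConfig(label_width=20,
# X_shape=[3000, 1, 3]); __init__ simply setattr's each keyword onto the instance,
# so derived class attributes such as dt and min_event_gap are not recomputed.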
class DataReader:
def __init__(
self, format="numpy", config=DataConfig(), response_xml=None, sampling_rate=100, highpass_filter=0, **kwargs
):
self.buffer = {}
self.n_channel = config.n_channel
self.n_class = config.n_class
self.X_shape = config.X_shape
self.Y_shape = config.Y_shape
self.dt = config.dt
self.dtype = config.dtype
self.label_shape = config.label_shape
self.label_width = config.label_width
self.config = config
self.format = format
# if "highpass_filter" in kwargs:
# self.highpass_filter = kwargs["highpass_filter"]
self.highpass_filter = highpass_filter
# self.response_xml = response_xml
if response_xml is not None:
self.response = obspy.read_inventory(response_xml)
else:
self.response = None
self.sampling_rate = sampling_rate
if format in ["numpy", "mseed", "sac"]:
self.data_dir = kwargs["data_dir"]
try:
                csv = pd.read_csv(kwargs["data_list"], header=0, sep=r"[,|\s+]", engine="python")
            except Exception:
                csv = pd.read_csv(kwargs["data_list"], header=0, sep="\t")
self.data_list = csv["fname"]
self.num_data = len(self.data_list)
elif format == "hdf5":
self.h5 = h5py.File(kwargs["hdf5_file"], "r", libver="latest", swmr=True)
self.h5_data = self.h5[kwargs["hdf5_group"]]
self.data_list = list(self.h5_data.keys())
self.num_data = len(self.data_list)
elif format == "s3":
self.s3fs = s3fs.S3FileSystem(
anon=kwargs["anon"],
key=kwargs["key"],
secret=kwargs["secret"],
client_kwargs={"endpoint_url": kwargs["s3_url"]},
use_ssl=kwargs["use_ssl"],
)
self.num_data = 0
else:
raise (f"{format} not support!")
def __len__(self):
return self.num_data
def read_numpy(self, fname):
# try:
if fname not in self.buffer:
npz = np.load(fname)
meta = {}
if len(npz["data"].shape) == 2:
meta["data"] = npz["data"][:, np.newaxis, :]
else:
meta["data"] = npz["data"]
if "p_idx" in npz.files:
if len(npz["p_idx"].shape) == 0:
meta["itp"] = [[npz["p_idx"]]]
else:
meta["itp"] = npz["p_idx"]
if "s_idx" in npz.files:
if len(npz["s_idx"].shape) == 0:
meta["its"] = [[npz["s_idx"]]]
else:
meta["its"] = npz["s_idx"]
if "itp" in npz.files:
if len(npz["itp"].shape) == 0:
meta["itp"] = [[npz["itp"]]]
else:
meta["itp"] = npz["itp"]
if "its" in npz.files:
if len(npz["its"].shape) == 0:
meta["its"] = [[npz["its"]]]
else:
meta["its"] = npz["its"]
if "station_id" in npz.files:
meta["station_id"] = npz["station_id"]
if "sta_id" in npz.files:
meta["station_id"] = npz["sta_id"]
if "t0" in npz.files:
meta["t0"] = npz["t0"]
self.buffer[fname] = meta
else:
meta = self.buffer[fname]
return meta
# except:
# logging.error("Failed reading {}".format(fname))
# return None
def read_hdf5(self, fname):
data = self.h5_data[fname][()]
attrs = self.h5_data[fname].attrs
meta = {}
if len(data.shape) == 2:
meta["data"] = data[:, np.newaxis, :]
else:
meta["data"] = data
if "p_idx" in attrs:
if len(attrs["p_idx"].shape) == 0:
meta["itp"] = [[attrs["p_idx"]]]
else:
meta["itp"] = attrs["p_idx"]
if "s_idx" in attrs:
if len(attrs["s_idx"].shape) == 0:
meta["its"] = [[attrs["s_idx"]]]
else:
meta["its"] = attrs["s_idx"]
if "itp" in attrs:
if len(attrs["itp"].shape) == 0:
meta["itp"] = [[attrs["itp"]]]
else:
meta["itp"] = attrs["itp"]
if "its" in attrs:
if len(attrs["its"].shape) == 0:
meta["its"] = [[attrs["its"]]]
else:
meta["its"] = attrs["its"]
if "t0" in attrs:
meta["t0"] = attrs["t0"]
return meta
def read_s3(self, format, fname, bucket, key, secret, s3_url, use_ssl):
with self.s3fs.open(bucket + "/" + fname, "rb") as fp:
if format == "numpy":
meta = self.read_numpy(fp)
elif format == "mseed":
meta = self.read_mseed(fp)
else:
raise (f"Format {format} not supported")
return meta
def read_mseed(self, fname, response=None, highpass_filter=0.0, sampling_rate=100, return_single_station=True):
try:
stream = obspy.read(fname)
stream = stream.merge(fill_value="latest")
if response is not None:
# response = obspy.read_inventory(response_xml)
stream = stream.remove_sensitivity(response)
except Exception as e:
print(f"Error reading {fname}:\n{e}")
return {}
tmp_stream = obspy.Stream()
for trace in stream:
if len(trace.data) < 10:
continue
            ## resample to the target sampling rate
if abs(trace.stats.sampling_rate - sampling_rate) > 0.1:
logging.warning(f"Resampling {trace.id} from {trace.stats.sampling_rate} to {sampling_rate} Hz")
try:
trace = trace.interpolate(sampling_rate, method="linear")
except Exception as e:
print(f"Error resampling {trace.id}:\n{e}")
trace = trace.detrend("demean")
## highpass filtering > 1Hz
if highpass_filter > 0.0:
trace = trace.filter("highpass", freq=highpass_filter)
tmp_stream.append(trace)
if len(tmp_stream) == 0:
return {}
stream = tmp_stream
begin_time = min([st.stats.starttime for st in stream])
end_time = max([st.stats.endtime for st in stream])
stream = stream.trim(begin_time, end_time, pad=True, fill_value=0)
comp = ["3", "2", "1", "E", "N", "U", "V", "Z"]
order = {key: i for i, key in enumerate(comp)}
comp2idx = {
"3": 0,
"2": 1,
"1": 2,
"E": 0,
"N": 1,
"Z": 2,
"U": 0,
"V": 1,
} ## only for cases less than 3 components
station_ids = defaultdict(list)
for tr in stream:
station_ids[tr.id[:-1]].append(tr.id[-1])
if tr.id[-1] not in comp:
print(f"Unknown component {tr.id[-1]}")
station_keys = sorted(list(station_ids.keys()))
nx = len(station_ids)
nt = len(stream[0].data)
data = np.zeros([3, nt, nx], dtype=np.float32)
for i, sta in enumerate(station_keys):
for j, c in enumerate(sorted(station_ids[sta], key=lambda x: order[x])):
if len(station_ids[sta]) != 3: ## less than 3 component
j = comp2idx[c]
if len(stream.select(id=sta + c)) == 0:
print(f"Empty trace: {sta+c} {begin_time}")
continue
trace = stream.select(id=sta + c)[0]
                ## acceleration to velocity
if sta[-1] == "N":
trace = trace.integrate().filter("highpass", freq=1.0)
tmp = trace.data.astype("float32")
data[j, : len(tmp), i] = tmp[:nt]
# if return_single_station and (len(station_keys) > 1):
# print(f"Warning: {fname} has multiple stations, returning only the first one {station_keys[0]}")
# data = data[:, :, 0:1]
# station_keys = station_keys[0:1]
meta = {
"data": data.transpose([1, 2, 0]),
"t0": begin_time.datetime.isoformat(timespec="milliseconds"),
"station_id": station_keys,
}
return meta
def read_sac(self, fname):
mseed = obspy.read(fname)
mseed = mseed.detrend("spline", order=2, dspline=5 * mseed[0].stats.sampling_rate)
mseed = mseed.merge(fill_value=0)
if self.highpass_filter > 0:
mseed = mseed.filter("highpass", freq=self.highpass_filter)
starttime = min([st.stats.starttime for st in mseed])
endtime = max([st.stats.endtime for st in mseed])
mseed = mseed.trim(starttime, endtime, pad=True, fill_value=0)
if abs(mseed[0].stats.sampling_rate - self.config.sampling_rate) > 1:
logging.warning(
f"Sampling rate mismatch in {fname.split('/')[-1]}: {mseed[0].stats.sampling_rate}Hz != {self.config.sampling_rate}Hz "
)
order = ["3", "2", "1", "E", "N", "Z"]
order = {key: i for i, key in enumerate(order)}
comp2idx = {"3": 0, "2": 1, "1": 2, "E": 0, "N": 1, "Z": 2}
t0 = starttime.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]
nt = len(mseed[0].data)
data = np.zeros([nt, self.config.n_channel], dtype=self.dtype)
ids = [x.get_id() for x in mseed]
for j, id in enumerate(sorted(ids, key=lambda x: order[x[-1]])):
if len(ids) != 3:
if len(ids) > 3:
logging.warning(f"More than 3 channels {ids}!")
j = comp2idx[id[-1]]
data[:, j] = mseed.select(id=id)[0].data.astype(self.dtype)
data = data[:, np.newaxis, :]
meta = {"data": data, "t0": t0}
return meta
def read_mseed_array(self, fname, stations, amplitude=False, remove_resp=True):
data = []
station_id = []
t0 = []
raw_amp = []
try:
mseed = obspy.read(fname)
read_success = True
except Exception as e:
read_success = False
print(e)
if read_success:
try:
mseed = mseed.merge(fill_value=0)
except Exception as e:
print(e)
for i in range(len(mseed)):
if mseed[i].stats.sampling_rate != self.config.sampling_rate:
logging.warning(
f"Resampling {mseed[i].id} from {mseed[i].stats.sampling_rate} to {self.config.sampling_rate} Hz"
)
try:
mseed[i] = mseed[i].interpolate(self.config.sampling_rate, method="linear")
except Exception as e:
print(e)
mseed[i].data = mseed[i].data.astype(float) * 0.0 ## set to zero if resampling fails
if self.highpass_filter == 0:
try:
mseed = mseed.detrend("spline", order=2, dspline=5 * mseed[0].stats.sampling_rate)
except:
logging.error(f"Error: spline detrend failed at file {fname}")
mseed = mseed.detrend("demean")
else:
mseed = mseed.filter("highpass", freq=self.highpass_filter)
starttime = min([st.stats.starttime for st in mseed])
endtime = max([st.stats.endtime for st in mseed])
mseed = mseed.trim(starttime, endtime, pad=True, fill_value=0)
order = ["3", "2", "1", "E", "N", "Z"]
order = {key: i for i, key in enumerate(order)}
comp2idx = {"3": 0, "2": 1, "1": 2, "E": 0, "N": 1, "Z": 2}
nsta = len(stations)
nt = len(mseed[0].data)
# for i in range(nsta):
for sta in stations:
trace_data = np.zeros([nt, self.config.n_channel], dtype=self.dtype)
if amplitude:
trace_amp = np.zeros([nt, self.config.n_channel], dtype=self.dtype)
empty_station = True
# sta = stations.iloc[i]["station"]
# comp = stations.iloc[i]["component"].split(",")
comp = stations[sta]["component"]
if amplitude:
# resp = stations.iloc[i]["response"].split(",")
resp = stations[sta]["response"]
for j, c in enumerate(sorted(comp, key=lambda x: order[x[-1]])):
                    resp_j = resp[j] if amplitude else None
if len(comp) != 3: ## less than 3 component
j = comp2idx[c]
if len(mseed.select(id=sta + c)) == 0:
print(f"Empty trace: {sta+c} {starttime}")
continue
else:
empty_station = False
tmp = mseed.select(id=sta + c)[0].data.astype(self.dtype)
trace_data[: len(tmp), j] = tmp[:nt]
if amplitude:
# if stations.iloc[i]["unit"] == "m/s**2":
if stations[sta]["unit"] == "m/s**2":
tmp = mseed.select(id=sta + c)[0]
tmp = tmp.integrate()
tmp = tmp.filter("highpass", freq=1.0)
tmp = tmp.data.astype(self.dtype)
trace_amp[: len(tmp), j] = tmp[:nt]
# elif stations.iloc[i]["unit"] == "m/s":
elif stations[sta]["unit"] == "m/s":
tmp = mseed.select(id=sta + c)[0].data.astype(self.dtype)
trace_amp[: len(tmp), j] = tmp[:nt]
else:
                                print(
                                    f"Error in {sta}\n{stations[sta]['unit']} should be m/s**2 or m/s!"
                                )
if amplitude and remove_resp:
# trace_amp[:, j] /= float(resp[j])
trace_amp[:, j] /= float(resp_j)
if not empty_station:
data.append(trace_data)
if amplitude:
raw_amp.append(trace_amp)
station_id.append([sta])
t0.append(starttime.datetime.isoformat(timespec="milliseconds"))
if len(data) > 0:
data = np.stack(data)
if len(data.shape) == 3:
data = data[:, :, np.newaxis, :]
if amplitude:
raw_amp = np.stack(raw_amp)
if len(raw_amp.shape) == 3:
raw_amp = raw_amp[:, :, np.newaxis, :]
else:
nt = 60 * 60 * self.config.sampling_rate # assume 1 hour data
data = np.zeros([1, nt, 1, self.config.n_channel], dtype=self.dtype)
if amplitude:
raw_amp = np.zeros([1, nt, 1, self.config.n_channel], dtype=self.dtype)
t0 = ["1970-01-01T00:00:00.000"]
station_id = ["None"]
if amplitude:
meta = {"data": data, "t0": t0, "station_id": station_id, "fname": fname.split("/")[-1], "raw_amp": raw_amp}
else:
meta = {"data": data, "t0": t0, "station_id": station_id, "fname": fname.split("/")[-1]}
return meta
def generate_label(self, data, phase_list, mask=None):
# target = np.zeros(self.Y_shape, dtype=self.dtype)
target = np.zeros_like(data)
if self.label_shape == "gaussian":
label_window = np.exp(
-((np.arange(-self.label_width // 2, self.label_width // 2 + 1)) ** 2)
/ (2 * (self.label_width / 5) ** 2)
)
elif self.label_shape == "triangle":
label_window = 1 - np.abs(
2 / self.label_width * (np.arange(-self.label_width // 2, self.label_width // 2 + 1))
)
        else:
            raise ValueError(f"Label shape {self.label_shape} should be gaussian or triangle")
for i, phases in enumerate(phase_list):
for j, idx_list in enumerate(phases):
for idx in idx_list:
if np.isnan(idx):
continue
idx = int(idx)
if (idx - self.label_width // 2 >= 0) and (idx + self.label_width // 2 + 1 <= target.shape[0]):
target[idx - self.label_width // 2 : idx + self.label_width // 2 + 1, j, i + 1] = label_window
target[..., 0] = 1 - np.sum(target[..., 1:], axis=-1)
if mask is not None:
target[:, mask == 0, :] = 0
return target
def random_shift(self, sample, itp, its, itp_old=None, its_old=None, shift_range=None):
# anchor = np.round(1/2 * (min(itp[~np.isnan(itp.astype(float))]) + min(its[~np.isnan(its.astype(float))]))).astype(int)
        # flatten nested pick lists into a 1-D float array (NaNs preserved)
        flatten = lambda x: np.array([i for trace in x for i in trace], dtype=float)
        shift_pick = lambda x, shift: [[i - shift for i in trace] for trace in x]
        itp_flat = flatten(itp)
        its_flat = flatten(its)
if (itp_old is None) and (its_old is None):
hi = np.round(np.median(itp_flat[~np.isnan(itp_flat)])).astype(int)
lo = -(sample.shape[0] - np.round(np.median(its_flat[~np.isnan(its_flat)])).astype(int))
if shift_range is None:
shift = np.random.randint(low=lo, high=hi + 1)
else:
shift = np.random.randint(low=max(lo, shift_range[0]), high=min(hi + 1, shift_range[1]))
else:
            itp_old_flat = flatten(itp_old)
            its_old_flat = flatten(its_old)
itp_ref = np.round(np.min(itp_flat[~np.isnan(itp_flat)])).astype(int)
its_ref = np.round(np.max(its_flat[~np.isnan(its_flat)])).astype(int)
itp_old_ref = np.round(np.min(itp_old_flat[~np.isnan(itp_old_flat)])).astype(int)
its_old_ref = np.round(np.max(its_old_flat[~np.isnan(its_old_flat)])).astype(int)
# min_event_gap = np.round(self.min_event_gap*(its_ref-itp_ref)).astype(int)
# min_event_gap_old = np.round(self.min_event_gap*(its_old_ref-itp_old_ref)).astype(int)
if shift_range is None:
hi = list(range(max(its_ref - itp_old_ref + self.min_event_gap, 0), itp_ref))
lo = list(range(-(sample.shape[0] - its_ref), -(max(its_old_ref - itp_ref + self.min_event_gap, 0))))
else:
lo_ = max(-(sample.shape[0] - its_ref), shift_range[0])
hi_ = min(itp_ref, shift_range[1])
hi = list(range(max(its_ref - itp_old_ref + self.min_event_gap, 0), hi_))
lo = list(range(lo_, -(max(its_old_ref - itp_ref + self.min_event_gap, 0))))
if len(hi + lo) > 0:
shift = np.random.choice(hi + lo)
else:
shift = 0
shifted_sample = np.zeros_like(sample)
if shift > 0:
shifted_sample[:-shift, ...] = sample[shift:, ...]
elif shift < 0:
shifted_sample[-shift:, ...] = sample[:shift, ...]
else:
shifted_sample[...] = sample[...]
return shifted_sample, shift_pick(itp, shift), shift_pick(its, shift), shift
def stack_events(self, sample_old, itp_old, its_old, shift_range=None, mask_old=None):
i = np.random.randint(self.num_data)
base_name = self.data_list[i]
if self.format == "numpy":
meta = self.read_numpy(os.path.join(self.data_dir, base_name))
elif self.format == "hdf5":
meta = self.read_hdf5(base_name)
if meta == -1:
return sample_old, itp_old, its_old
sample = np.copy(meta["data"])
itp = meta["itp"]
its = meta["its"]
if mask_old is not None:
mask = np.copy(meta["mask"])
sample = normalize(sample)
sample, itp, its, shift = self.random_shift(sample, itp, its, itp_old, its_old, shift_range)
if shift != 0:
sample_old += sample
# itp_old = [np.hstack([i, j]) for i,j in zip(itp_old, itp)]
# its_old = [np.hstack([i, j]) for i,j in zip(its_old, its)]
itp_old = [i + j for i, j in zip(itp_old, itp)]
its_old = [i + j for i, j in zip(its_old, its)]
if mask_old is not None:
mask_old = mask_old * mask
return sample_old, itp_old, its_old, mask_old
def cut_window(self, sample, target, itp, its, select_range):
shift_pick = lambda x, shift: [[i - shift for i in trace] for trace in x]
sample = sample[select_range[0] : select_range[1]]
target = target[select_range[0] : select_range[1]]
return (sample, target, shift_pick(itp, select_range[0]), shift_pick(its, select_range[0]))
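# generate_label places a Gaussian (sigma = label_width / 5) or triangle window of
# label_width + 1 samples at every P/S pick; channel 0 holds the residual "noise"
# probability 1 - sum(picks). random_shift, stack_events, and cut_window are the
# augmentations used by DataReader_train below.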
class DataReader_train(DataReader):
def __init__(self, format="numpy", config=DataConfig(), **kwargs):
super().__init__(format=format, config=config, **kwargs)
self.min_event_gap = config.min_event_gap
self.buffer_channels = {}
self.shift_range = [-2000 + self.label_width * 2, 1000 - self.label_width * 2]
self.select_range = [5000, 8000]
def __getitem__(self, i):
base_name = self.data_list[i]
if self.format == "numpy":
meta = self.read_numpy(os.path.join(self.data_dir, base_name))
elif self.format == "hdf5":
meta = self.read_hdf5(base_name)
        if meta is None:
return (np.zeros(self.X_shape, dtype=self.dtype), np.zeros(self.Y_shape, dtype=self.dtype), base_name)
sample = np.copy(meta["data"])
itp_list = meta["itp"]
its_list = meta["its"]
sample = normalize(sample)
if np.random.random() < 0.95:
sample, itp_list, its_list, _ = self.random_shift(sample, itp_list, its_list, shift_range=self.shift_range)
sample, itp_list, its_list, _ = self.stack_events(sample, itp_list, its_list, shift_range=self.shift_range)
target = self.generate_label(sample, [itp_list, its_list])
sample, target, itp_list, its_list = self.cut_window(sample, target, itp_list, its_list, self.select_range)
else:
## noise
assert self.X_shape[0] <= min(min(itp_list))
sample = sample[: self.X_shape[0], ...]
target = np.zeros(self.Y_shape).astype(self.dtype)
itp_list = [[]]
its_list = [[]]
sample = normalize(sample)
return (sample.astype(self.dtype), target.astype(self.dtype), base_name)
def dataset(self, batch_size, num_parallel_calls=2, shuffle=True, drop_remainder=True):
dataset = dataset_map(
self,
output_types=(self.dtype, self.dtype, "string"),
output_shapes=(self.X_shape, self.Y_shape, None),
num_parallel_calls=num_parallel_calls,
shuffle=shuffle,
)
dataset = dataset.batch(batch_size, drop_remainder=drop_remainder).prefetch(batch_size * 2)
return dataset
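# Usage sketch (hypothetical file paths), mirroring the session-based test at the
# bottom of this file:
#
#   reader = DataReader_train(data_list="picks.csv", data_dir="waveforms/")
#   dataset = reader.dataset(batch_size=16)
#   batch = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()
#   with tf.compat.v1.Session() as sess:
#       sample, target, fname = sess.run(batch)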
class DataReader_test(DataReader):
def __init__(self, format="numpy", config=DataConfig(), **kwargs):
super().__init__(format=format, config=config, **kwargs)
self.select_range = [5000, 8000]
def __getitem__(self, i):
base_name = self.data_list[i]
if self.format == "numpy":
meta = self.read_numpy(os.path.join(self.data_dir, base_name))
elif self.format == "hdf5":
meta = self.read_hdf5(base_name)
if meta == -1:
return (np.zeros(self.Y_shape, dtype=self.dtype), np.zeros(self.X_shape, dtype=self.dtype), base_name)
sample = np.copy(meta["data"])
itp_list = meta["itp"]
its_list = meta["its"]
# sample, itp_list, its_list, _ = self.random_shift(sample, itp_list, its_list, shift_range=self.shift_range)
target = self.generate_label(sample, [itp_list, its_list])
sample, target, itp_list, its_list = self.cut_window(sample, target, itp_list, its_list, self.select_range)
sample = normalize(sample)
return (sample, target, base_name, itp_list, its_list)
def dataset(self, batch_size, num_parallel_calls=2, shuffle=False, drop_remainder=False):
dataset = dataset_map(
self,
output_types=(self.dtype, self.dtype, "string", "int64", "int64"),
output_shapes=(self.X_shape, self.Y_shape, None, None, None),
num_parallel_calls=num_parallel_calls,
shuffle=shuffle,
)
dataset = dataset.batch(batch_size, drop_remainder=drop_remainder).prefetch(batch_size * 2)
return dataset
class DataReader_pred(DataReader):
def __init__(self, format="numpy", amplitude=True, config=DataConfig(), **kwargs):
super().__init__(format=format, config=config, **kwargs)
self.amplitude = amplitude
def adjust_missingchannels(self, data):
tmp = np.max(np.abs(data), axis=0, keepdims=True)
assert tmp.shape[-1] == data.shape[-1]
if np.count_nonzero(tmp) > 0:
data *= data.shape[-1] / np.count_nonzero(tmp)
return data
def __getitem__(self, i):
base_name = self.data_list[i]
if self.format == "numpy":
meta = self.read_numpy(os.path.join(self.data_dir, base_name))
elif (self.format == "mseed") or (self.format == "sac"):
meta = self.read_mseed(
os.path.join(self.data_dir, base_name),
response=self.response,
sampling_rate=self.sampling_rate,
highpass_filter=self.highpass_filter,
return_single_station=True,
)
elif self.format == "hdf5":
meta = self.read_hdf5(base_name)
else:
raise (f"{self.format} does not support!")
if "data" in meta:
raw_amp = meta["data"].copy()
sample = normalize_long(meta["data"])
else:
raw_amp = np.zeros([3000, 1, 3], dtype=np.float32)
sample = np.zeros([3000, 1, 3], dtype=np.float32)
if "t0" in meta:
t0 = meta["t0"]
else:
t0 = "1970-01-01T00:00:00.000"
if "station_id" in meta:
station_id = meta["station_id"]
else:
# station_id = base_name.split("/")[-1].rstrip("*")
station_id = os.path.basename(base_name).rstrip("*")
if np.isnan(sample).any() or np.isinf(sample).any():
logging.warning(f"Data error: Nan or Inf found in {base_name}")
sample[np.isnan(sample)] = 0
sample[np.isinf(sample)] = 0
# sample = self.adjust_missingchannels(sample)
if self.amplitude:
return (sample, raw_amp, base_name, t0, station_id)
else:
return (sample, base_name, t0, station_id)
def dataset(self, batch_size, num_parallel_calls=2, shuffle=False, drop_remainder=False):
if self.amplitude:
dataset = dataset_map(
self,
output_types=(self.dtype, self.dtype, "string", "string", "string"),
output_shapes=([None, None, 3], [None, None, 3], None, None, None),
num_parallel_calls=num_parallel_calls,
shuffle=shuffle,
)
else:
dataset = dataset_map(
self,
output_types=(self.dtype, "string", "string", "string"),
output_shapes=([None, None, 3], None, None, None),
num_parallel_calls=num_parallel_calls,
shuffle=shuffle,
)
dataset = dataset.batch(batch_size, drop_remainder=drop_remainder).prefetch(batch_size * 2)
return dataset
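# Prediction usage sketch (hypothetical paths); with amplitude=True each element is
# (normalized waveform, raw waveform, fname, t0, station_id), otherwise the raw
# waveform is omitted:
#
#   reader = DataReader_pred(format="mseed", data_list="fname.csv", data_dir="mseed/")
#   dataset = reader.dataset(batch_size=1)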
class DataReader_mseed_array(DataReader):
def __init__(self, stations, amplitude=True, remove_resp=True, config=DataConfig(), **kwargs):
super().__init__(format="mseed", config=config, **kwargs)
# self.stations = pd.read_json(stations)
with open(stations, "r") as f:
self.stations = json.load(f)
print(pd.DataFrame.from_dict(self.stations, orient="index").to_string())
self.amplitude = amplitude
self.remove_resp = remove_resp
self.X_shape = self.get_data_shape()
def get_data_shape(self):
fname = os.path.join(self.data_dir, self.data_list[0])
meta = self.read_mseed_array(fname, self.stations, self.amplitude, self.remove_resp)
return meta["data"].shape
def __getitem__(self, i):
fp = os.path.join(self.data_dir, self.data_list[i])
# try:
meta = self.read_mseed_array(fp, self.stations, self.amplitude, self.remove_resp)
# except Exception as e:
# logging.error(f"Failed reading {fp}: {e}")
# if self.amplitude:
# return (np.zeros(self.X_shape).astype(self.dtype), np.zeros(self.X_shape).astype(self.dtype),
# [self.stations.iloc[i]["station"] for i in range(len(self.stations))], ["0" for i in range(len(self.stations))])
# else:
# return (np.zeros(self.X_shape).astype(self.dtype), ["" for i in range(len(self.stations))],
# [self.stations.iloc[i]["station"] for i in range(len(self.stations))])
sample = np.zeros([len(meta["data"]), *self.X_shape[1:]], dtype=self.dtype)
sample[:, : meta["data"].shape[1], :, :] = normalize_batch(meta["data"])[:, : self.X_shape[1], :, :]
if np.isnan(sample).any() or np.isinf(sample).any():
logging.warning(f"Data error: Nan or Inf found in {fp}")
sample[np.isnan(sample)] = 0
sample[np.isinf(sample)] = 0
t0 = meta["t0"]
base_name = meta["fname"]
station_id = meta["station_id"]
# base_name = [self.stations.iloc[i]["station"]+"."+t0[i] for i in range(len(self.stations))]
# base_name = [self.stations.iloc[i]["station"] for i in range(len(self.stations))]
if self.amplitude:
raw_amp = np.zeros([len(meta["raw_amp"]), *self.X_shape[1:]], dtype=self.dtype)
raw_amp[:, : meta["raw_amp"].shape[1], :, :] = meta["raw_amp"][:, : self.X_shape[1], :, :]
if np.isnan(raw_amp).any() or np.isinf(raw_amp).any():
logging.warning(f"Data error: Nan or Inf found in {fp}")
raw_amp[np.isnan(raw_amp)] = 0
raw_amp[np.isinf(raw_amp)] = 0
return (sample, raw_amp, base_name, t0, station_id)
else:
return (sample, base_name, t0, station_id)
def dataset(self, num_parallel_calls=1, shuffle=False):
if self.amplitude:
dataset = dataset_map(
self,
output_types=(self.dtype, self.dtype, "string", "string", "string"),
output_shapes=([None, *self.X_shape[1:]], [None, *self.X_shape[1:]], None, None, None),
num_parallel_calls=num_parallel_calls,
)
else:
dataset = dataset_map(
self,
output_types=(self.dtype, "string", "string", "string"),
output_shapes=([None, *self.X_shape[1:]], None, None, None),
num_parallel_calls=num_parallel_calls,
)
dataset = dataset.prefetch(1)
# dataset = dataset.prefetch(len(self.stations)*2)
return dataset
###### test ########
def test_DataReader():
import os
import timeit
import matplotlib.pyplot as plt
if not os.path.exists("test_figures"):
os.mkdir("test_figures")
def plot_sample(sample, fname, label=None):
plt.clf()
plt.subplot(211)
plt.plot(sample[:, 0, -1])
if label is not None:
plt.subplot(212)
plt.plot(label[:, 0, 0])
plt.plot(label[:, 0, 1])
plt.plot(label[:, 0, 2])
plt.savefig(f"test_figures/{fname.decode()}.png")
def read(data_reader, batch=1):
start_time = timeit.default_timer()
if batch is None:
dataset = data_reader.dataset(shuffle=False)
else:
            dataset = data_reader.dataset(batch, shuffle=False)
sess = tf.compat.v1.Session()
print(len(data_reader))
print("-------", tf.data.Dataset.cardinality(dataset))
num = 0
x = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()
while True:
num += 1
# print(num)
try:
out = sess.run(x)
if len(out) == 2:
sample, fname = out[0], out[1]
for i in range(len(sample)):
plot_sample(sample[i], fname[i])
else:
sample, label, fname = out[0], out[1], out[2]
for i in range(len(sample)):
plot_sample(sample[i], fname[i], label[i])
except tf.errors.OutOfRangeError:
break
print("End of dataset")
print("Tensorflow Dataset:\nexecution time = ", timeit.default_timer() - start_time)
data_reader = DataReader_train(data_list="test_data/selected_phases.csv", data_dir="test_data/data/")
read(data_reader)
data_reader = DataReader_train(format="hdf5", hdf5="test_data/data.h5", group="data")
read(data_reader)
data_reader = DataReader_test(data_list="test_data/selected_phases.csv", data_dir="test_data/data/")
read(data_reader)
data_reader = DataReader_test(format="hdf5", hdf5="test_data/data.h5", group="data")
read(data_reader)
data_reader = DataReader_pred(format="numpy", data_list="test_data/selected_phases.csv", data_dir="test_data/data/")
read(data_reader)
data_reader = DataReader_pred(
format="mseed", data_list="test_data/mseed_station.csv", data_dir="test_data/waveforms/"
)
read(data_reader)
data_reader = DataReader_pred(
format="mseed", amplitude=True, data_list="test_data/mseed_station.csv", data_dir="test_data/waveforms/"
)
read(data_reader)
data_reader = DataReader_mseed_array(
data_list="test_data/mseed.csv",
data_dir="test_data/waveforms/",
stations="test_data/stations.csv",
remove_resp=False,
)
read(data_reader, batch=None)
if __name__ == "__main__":
test_DataReader()