danlin1128 commited on
Commit
9922e80
·
verified ·
1 Parent(s): cc7f62d

Update GAN/diffusion.py

Browse files
Files changed (1) hide show
  1. GAN/diffusion.py +27 -3
GAN/diffusion.py CHANGED
@@ -5,6 +5,7 @@ from tqdm.auto import tqdm
5
  import tensorflow as tf
6
  from tensorflow import keras
7
  from tensorflow.keras import layers
 
8
  from GAN.utils import linear_beta_schedule, cosine_beta_schedule
9
  import matplotlib.pyplot as plt
10
  import os
@@ -152,6 +153,26 @@ def TimeMLP(units, activation_fn=keras.activations.swish):
152
 
153
  return apply
154
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  # Kernel initializer to use
156
  def kernel_init(scale):
157
  scale = max(scale, 1e-10)
@@ -256,10 +277,13 @@ def build_model(time_len, fea_num, d_model=16, n_heads=2, encoder_type='dual'):
256
 
257
  # Input layers
258
  x_input = layers.Input(shape=(time_len, fea_num))
259
- time_input = layers.Input(shape=())
260
 
261
- # Time step embeddings
262
- time_emb = get_time_embedding(time_input, d_model)
 
 
 
 
263
 
264
  encoded_features = []
265
 
 
5
  import tensorflow as tf
6
  from tensorflow import keras
7
  from tensorflow.keras import layers
8
+ from tensorflow.keras.layers import Layer
9
  from GAN.utils import linear_beta_schedule, cosine_beta_schedule
10
  import matplotlib.pyplot as plt
11
  import os
 
153
 
154
  return apply
155
 
156
+ # 创建一个新的层来包装时间嵌入操作
157
+ class TimeEmbeddingLayer(Layer):
158
+ def __init__(self, d_model, **kwargs):
159
+ super().__init__(**kwargs)
160
+ self.d_model = d_model
161
+
162
+ def call(self, timesteps):
163
+ # 扩展维度
164
+ timesteps = tf.expand_dims(timesteps, -1)
165
+
166
+ # 计算不同频率的正弦信号
167
+ dim = self.d_model // 2
168
+ inv_freq = 1.0 / (10000 ** (tf.range(0, dim, dtype=tf.float32) / dim))
169
+
170
+ # 计算正弦波
171
+ sinusoid_inp = tf.einsum("i,j->ij", tf.squeeze(timesteps), inv_freq)
172
+ pos_emb = tf.concat([tf.sin(sinusoid_inp), tf.cos(sinusoid_inp)], axis=-1)
173
+
174
+ return pos_emb
175
+
176
  # Kernel initializer to use
177
  def kernel_init(scale):
178
  scale = max(scale, 1e-10)
 
277
 
278
  # Input layers
279
  x_input = layers.Input(shape=(time_len, fea_num))
 
280
 
281
+ #time_input = layers.Input(shape=())
282
+ #time_emb = get_time_embedding(time_input, d_model)
283
+
284
+ time_input = Input(shape=(1,), name='time_input')
285
+ feature_input = Input(shape=(time_len, fea_num), name='feature_input')
286
+ time_emb = TimeEmbeddingLayer(d_model)(time_input)
287
 
288
  encoded_features = []
289