synls / app.py
danlin1128's picture
Upload 33 files
67069a4 verified
raw
history blame
6.92 kB
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from GAN.diffusion import build_model, GaussianDiffusion, DiffusionModel
import tensorflow as tf
from tensorflow.python.types.core import TensorLike
import imageio
import tempfile
import os
# Tiny constant added to the min-max denominator so that fitting on a
# constant-valued array (max == min) does not divide by zero.
EPS = 1e-18


class TSFeatureScaler:
    """Global min-max scaler that maps time-series values into [-1, 1].

    A single minimum and a single maximum are taken over every element of the
    fitted data (the statistics are global, not per feature). ``transform``
    first rescales with those bounds into [0, 1] and then stretches the result
    linearly onto [-1, 1].
    """

    def __init__(self) -> None:
        # Bounds remain None until fit() has been called.
        self.min_val = None
        self.max_val = None

    def fit(self, X: TensorLike) -> "TSFeatureScaler":
        """Learn the global bounds of ``X``.

        Args:
            X: Data to fit on, documented as shape [N, T, D]
                (N samples, T timesteps, D features). ``np.min``/``np.max``
                reduce over all elements, so the shape does not change the
                learned bounds.

        Returns:
            The scaler itself, so calls can be chained
            (``scaler.fit(X).transform(X)``).
        """
        lowest = np.min(X)
        highest = np.max(X)
        self.min_val, self.max_val = lowest, highest
        return self

    def transform(self, X: TensorLike) -> TensorLike:
        """Rescale ``X`` with the fitted bounds.

        Computes ``2 * (X - min) / (max - min + EPS) - 1``; values inside the
        fitted range end up in [-1, 1].

        Raises:
            ValueError: If ``fit`` has not been called yet.
        """
        unfitted = self.max_val is None or self.min_val is None
        if unfitted:
            raise ValueError("Scaler must be fitted before transform")
        value_range = self.max_val - self.min_val + EPS
        unit_interval = (X - self.min_val) / value_range  # step 1: into [0, 1]
        return 2.0 * unit_interval - 1.0  # step 2: onto [-1, 1]

    def fit_transform(self, X: TensorLike) -> TensorLike:
        """Fit on ``X`` and return ``X`` rescaled with the freshly learned bounds."""
        fitted = self.fit(X)
        return fitted.transform(X)
def create_animation(frames, fps=1):
    """Write a sequence of frames to a looping GIF file and return its path.

    Args:
        frames: Sequence of image frames, passed unchanged to
            ``imageio.mimsave``.
        fps: Playback rate in frames per second; must be positive.

    Returns:
        Path (str) of a newly created ``.gif`` file in the system temp
        directory. The file is not deleted here; the caller owns it.

    Raises:
        ValueError: If ``fps`` is not positive.
    """
    if fps <= 0:
        raise ValueError(f"fps must be positive, got {fps!r}")
    # Convert fps to a per-frame duration in milliseconds (1000 ms = 1 s).
    duration = int(1000 / fps)
    # mkstemp atomically creates a uniquely named file. A name derived from
    # id(frames) could collide, because ids are recycled after garbage
    # collection and concurrent requests share the same temp directory.
    fd, temp_path = tempfile.mkstemp(prefix="temp_", suffix=".gif")
    os.close(fd)  # imageio reopens the path itself
    try:
        # Save as a GIF; loop=0 requests infinite looping.
        imageio.mimsave(temp_path, frames, format='GIF', duration=duration, loop=0)
    except BaseException:
        # Best-effort removal of the placeholder file; the original error is
        # what matters, so a failed cleanup must not mask it.
        try:
            os.remove(temp_path)
        except OSError:
            pass
        raise
    return temp_path
def generate_timeseries(input_file, num_samples=16, checkpoint_path=None):
    """Run the diffusion model on an uploaded dataset and render two GIFs.

    The uploaded ``.npy`` array is validated, globally min-max scaled to
    [-1, 1], and its first ``num_samples`` series are given to a
    ``DiffusionModel`` whose weights are restored from a checkpoint. The
    model's noising and denoising visualisations are then written out as GIF
    files.

    Args:
        input_file: Value of the ``gr.File`` component. Depending on the
            Gradio version / ``type`` setting this is either a path (str or
            ``os.PathLike``) or a file-like object exposing the path as
            ``.name``; both are accepted.
        num_samples: Number of series to feed to the model and visualise.
        checkpoint_path: Optional path for ``model.load_weights``. When
            omitted, the ``SYNLS_CHECKPOINT`` environment variable is used,
            falling back to the path the app was originally developed with.

    Returns:
        ``(noise_gif_path, denoise_gif_path)`` on success, or
        ``(None, None)`` on any failure (a diagnostic is printed to stdout).
    """
    try:
        if input_file is None:
            print("No input file selected")
            return None, None
        # Accept both a plain path and a file-like wrapper with ``.name``;
        # unconditionally reading ``.name`` fails when Gradio passes a str.
        if isinstance(input_file, (str, os.PathLike)):
            data_path = input_file
        else:
            data_path = input_file.name

        # Load the data
        real_data = np.load(data_path)
        print(f"Loaded data shape: {real_data.shape}")

        # Model geometry. The uploaded data must match it, so a single pair
        # of values drives both the shape check and the network build.
        time_len, fea_num = 96, 3
        if real_data.ndim != 3 or real_data.shape[1:] != (time_len, fea_num):
            print(f"Unexpected data shape {real_data.shape}; "
                  f"expected (N, {time_len}, {fea_num})")
            return None, None

        # Global min-max scaling to [-1, 1]. Statistics come from the whole
        # uploaded array, not just the slice handed to the model below.
        real_data = TSFeatureScaler().fit_transform(real_data)

        if checkpoint_path is None:
            checkpoint_path = os.environ.get(
                "SYNLS_CHECKPOINT",
                "/Users/lindan/Dropbox/PhD/Projects/PLF/GAN/code_github/checkpoint/cp.ckpt",
            )

        # Build the network and its ``ema_network`` twin with identical
        # architecture; the twin starts from the same weights (presumably an
        # exponential-moving-average copy — confirm in GAN.diffusion).
        net_config = dict(
            time_len=time_len,
            fea_num=fea_num,
            d_model=16,
            n_heads=2,
            encoder_type='dual',
        )
        network = build_model(**net_config)
        ema_network = build_model(**net_config)
        ema_network.set_weights(network.get_weights())

        timesteps = 10
        noise_util = GaussianDiffusion(timesteps=timesteps)
        print("Creating model...")
        model = DiffusionModel(
            network=network,
            ema_network=ema_network,
            timesteps=timesteps,
            gdf_util=noise_util,
            data=real_data[:num_samples],
        )

        # Load pretrained weights
        print(f"Loading weights from {checkpoint_path}")
        model.load_weights(checkpoint_path)

        # Animation of the forward (noising) process
        print("Generating noising animation...")
        noise_frames = model.plot_noise_process_app(num_samples)
        noise_gif = create_animation(noise_frames)

        # Animation of the reverse (denoising) process. The first returned
        # frame is dropped; NOTE(review): reason not visible here — confirm.
        print("Generating denoising animation...")
        denoise_frames = model.plot_denoise_process_app(num_samples)[1:]
        denoise_gif = create_animation(denoise_frames)
        return noise_gif, denoise_gif
    except Exception as e:
        # UI boundary: report the full traceback and leave both outputs empty.
        import traceback
        error_msg = f"Error: {e}\n{traceback.format_exc()}"
        print(error_msg)
        return None, None
def update_example_gifs(num_samples):
    """Return the (noising, denoising) example GIF filenames for a sample count.

    Args:
        num_samples: Selected sample count; it is interpolated verbatim into
            both filenames.

    Returns:
        Tuple ``(noising_path, denoising_path)``.
    """
    noising_path = f"noising_example_{num_samples}.gif"
    denoising_path = f"denoising_example_{num_samples}.gif"
    return noising_path, denoising_path
# Build the Gradio interface. Components render in the order they are created
# inside the Blocks context, so the statements below are order-sensitive.
with gr.Blocks(title="Wearable Sensors Time-Series Generation") as demo:
    with gr.Column(elem_id="container"):
        # Logo (sized and centred by the #logo CSS rule in the gr.HTML block below)
        gr.Image("logo.webp", elem_id="logo", show_label=False, container=False)
        # Title and subtitle
        gr.Markdown(
            """
# Wearable Sensors Time-Series Generation
<h3 style='font-weight: normal; color: #666;'>-- mainly targeted at livestock wearables sensors data</h3>
""")
    with gr.Row():
        # Output panes. They start on the bundled 16-sample example GIFs,
        # matching the Radio's default value=16 below.
        with gr.Column():
            noise_gif = gr.Image(value="noising_example_16.gif", label="Noising Process", show_label=True)
        with gr.Column():
            denoise_gif = gr.Image(value="denoising_example_16.gif", label="Denoising Process", show_label=True)
    with gr.Row():
        with gr.Column():
            num_samples = gr.Radio(
                choices=[4, 9, 16, 25],
                value=16,
                label="Number of samples to generate"
            )
            generate_btn = gr.Button("Generate")
            # File input for the dataset; the gr.Examples entries below are
            # bound to it (inputs=input_file) so a bundled example can be picked.
            input_file = gr.File(label="Select example data")
            gr.Examples(
                examples=[
                    ["app_examples/example1.npy"],
                    ["app_examples/example2.npy"],
                    ["app_examples/example3.npy"],
                    ["app_examples/example4.npy"]
                ],
                inputs=input_file,
                label="Example Datasets"
            )
    # Generate button: run generate_timeseries(input_file, num_samples) and
    # show the two returned GIFs in the output panes.
    generate_btn.click(
        fn=generate_timeseries,
        inputs=[input_file, num_samples],
        outputs=[noise_gif, denoise_gif]
    )
    # Changing the sample count swaps both panes to the bundled example GIFs
    # for that count (filenames from update_example_gifs).
    num_samples.change(
        fn=update_example_gifs,
        inputs=[num_samples],
        outputs=[noise_gif, denoise_gif]
    )
    # Page CSS, injected as a <style> tag through an HTML component.
    gr.HTML(
        """
<style>
#container {
text-align: center;
padding: 2rem 0;
}
#logo {
width: 120px;
height: 120px;
margin: 0 auto;
margin-bottom: 1rem;
}
h1 {
font-size: 3.5rem;
margin-bottom: 0.5rem;
}
h3 {
font-size: 1.8rem;
margin-top: 0;
color: #666;
}
</style>
"""
    )
# Launch the app when this file is run as a script.
if __name__ == "__main__":
    demo.launch(share=True)