Spaces:
Running
Running
import gradio as gr | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from GAN.diffusion import build_model, GaussianDiffusion, DiffusionModel | |
import tensorflow as tf | |
from tensorflow.python.types.core import TensorLike | |
import imageio | |
import tempfile | |
import os | |
# Tiny offset added to the min-max denominator so that constant data
# (max == min) does not cause a division by zero.
EPS = 1e-18


class TSFeatureScaler:
    """Min-max scaler driven by a single global minimum and maximum.

    The extrema are taken over the whole input rather than per feature.
    Data is first mapped to [0, 1] and then stretched to [-1, 1].
    """

    def __init__(self) -> None:
        # Both stay None until fit() has been called.
        self.max_val = None
        self.min_val = None

    def fit(self, X: TensorLike) -> "TSFeatureScaler":
        """Record the global extrema of ``X``.

        Args:
            X: Input data, described by the original author as shaped
                [N, T, D] (N: samples, T: timesteps, D: features).

        Returns:
            The scaler itself, so calls can be chained.
        """
        # Extrema over the entire dataset, not per feature.
        lowest = np.min(X)
        highest = np.max(X)
        self.min_val, self.max_val = lowest, highest
        return self

    def transform(self, X: TensorLike) -> TensorLike:
        """Rescale ``X`` with the fitted extrema.

        Returns:
            ``X`` mapped to [0, 1] by min-max scaling and then to [-1, 1].

        Raises:
            ValueError: If the scaler has not been fitted yet.
        """
        fitted = self.min_val is not None and self.max_val is not None
        if not fitted:
            raise ValueError("Scaler must be fitted before transform")

        # Step 1: min-max scale to [0, 1].
        value_range = self.max_val - self.min_val + EPS
        unit_scaled = (X - self.min_val) / value_range

        # Step 2: stretch to [-1, 1].
        return 2.0 * unit_scaled - 1.0

    def fit_transform(self, X: TensorLike) -> TensorLike:
        """Fit the scaler on ``X`` and return the transformed ``X``."""
        fitted_scaler = self.fit(X)
        return fitted_scaler.transform(X)
def create_animation(frames, fps=1):
    """Write ``frames`` to an endlessly looping GIF file and return its path.

    Args:
        frames: Sequence of image frames accepted by ``imageio.mimsave``.
        fps: Playback speed in frames per second; must be positive.

    Returns:
        Path of the GIF file created in the system temp directory. The
        file is not deleted here; the caller decides when to remove it.

    Raises:
        ValueError: If ``fps`` is zero or negative.
    """
    if fps <= 0:
        raise ValueError(f"fps must be positive, got {fps}")

    # Convert fps to a per-frame duration in milliseconds (1000 ms = 1 s).
    duration = int(1000 / fps)

    # mkstemp hands out a name that is guaranteed to be unused. A name built
    # from id(frames) can repeat once an earlier frame list has been
    # garbage-collected, letting one request overwrite another's GIF.
    fd, temp_path = tempfile.mkstemp(prefix="temp_", suffix=".gif")
    os.close(fd)

    try:
        # loop=0 means loop forever.
        imageio.mimsave(temp_path, frames, format='GIF', duration=duration, loop=0)
    except Exception:
        # Do not leave an empty or partial file behind on failure.
        os.remove(temp_path)
        raise
    return temp_path
def generate_timeseries(input_file, num_samples=16):
    """Create noising and denoising GIF animations for an uploaded dataset.

    The array is loaded with ``np.load``, checked against the ``(N, 96, 3)``
    layout the network is built for, rescaled with ``TSFeatureScaler`` and
    handed to a ``DiffusionModel`` whose weights are restored from a
    checkpoint.

    The checkpoint location can be overridden with the
    ``DIFFUSION_CHECKPOINT_PATH`` environment variable; when it is unset the
    original hard-coded path is used.

    Args:
        input_file: The selected file, either an object exposing its path as
            ``.name`` or a plain path string. ``None`` means nothing was
            selected.
        num_samples: Number of leading samples passed to the model and to
            its plotting helpers.

    Returns:
        ``(noise_gif_path, denoise_gif_path)`` on success, or
        ``(None, None)`` when no file was given, the data has the wrong
        shape, or any step raised an exception (the traceback is printed).
    """
    time_len, fea_num, timesteps = 96, 3, 10
    default_checkpoint = "/Users/lindan/Dropbox/PhD/Projects/PLF/GAN/code_github/checkpoint/cp.ckpt"

    if input_file is None:
        print("Error: no input file selected")
        return None, None

    try:
        # Load the data; accept both file-like objects and plain paths.
        data_path = getattr(input_file, "name", input_file)
        real_data = np.load(data_path)
        print(f"Loaded data shape: {real_data.shape}")

        # Validate the shape before doing any further work.
        if real_data.ndim != 3 or real_data.shape[1:] != (time_len, fea_num):
            print(f"Error: expected data of shape (N, {time_len}, {fea_num})")
            return None, None

        real_data = TSFeatureScaler().fit_transform(real_data)

        # Build the network and its EMA twin from one shared configuration.
        network_kwargs = dict(
            time_len=time_len,
            fea_num=fea_num,
            d_model=16,
            n_heads=2,
            encoder_type='dual',
        )
        network = build_model(**network_kwargs)
        ema_network = build_model(**network_kwargs)
        ema_network.set_weights(network.get_weights())
        noise_util = GaussianDiffusion(timesteps=timesteps)

        print("Creating model...")
        model = DiffusionModel(
            network=network,
            ema_network=ema_network,
            timesteps=timesteps,
            gdf_util=noise_util,
            data=real_data[:num_samples],
        )

        # Load the pretrained weights.
        checkpoint_path = os.environ.get("DIFFUSION_CHECKPOINT_PATH", default_checkpoint)
        print(f"Loading weights from {checkpoint_path}")
        model.load_weights(checkpoint_path)

        # Animation of the noising process.
        print("Generating noising animation...")
        noise_frames = model.plot_noise_process_app(num_samples)
        noise_gif = create_animation(noise_frames)

        # Animation of the denoising process; the first frame is skipped.
        print("Generating denoising animation...")
        denoise_frames = model.plot_denoise_process_app(num_samples)[1:]
        denoise_gif = create_animation(denoise_frames)

        return noise_gif, denoise_gif
    except Exception as e:
        # UI boundary: report the failure and clear both outputs instead of
        # letting the exception propagate into the event handler.
        import traceback
        error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        return None, None
def update_example_gifs(num_samples):
    """Return the example GIF file names matching the chosen sample count.

    Returns:
        A ``(noising_gif, denoising_gif)`` pair of file names.
    """
    shared_suffix = f"_example_{num_samples}.gif"
    noising_name = "noising" + shared_suffix
    denoising_name = "denoising" + shared_suffix
    return noising_name, denoising_name
# Build the Gradio interface.
# NOTE(review): the nesting below was reconstructed from flattened text;
# confirm that the two rows belong at the Blocks level rather than inside
# the "container" column.
with gr.Blocks(title="Wearable Sensors Time-Series Generation") as demo:
    with gr.Column(elem_id="container"):
        # Logo
        gr.Image("logo.webp", elem_id="logo", show_label=False, container=False)
        # Title and subtitle
        gr.Markdown(
            """
            # Wearable Sensors Time-Series Generation
            <h3 style='font-weight: normal; color: #666;'>-- mainly targeted at livestock wearables sensors data</h3>
            """)
    # Output row: both images start out showing the 16-sample example GIFs.
    with gr.Row():
        with gr.Column():
            noise_gif = gr.Image(value="noising_example_16.gif", label="Noising Process", show_label=True)
        with gr.Column():
            denoise_gif = gr.Image(value="denoising_example_16.gif", label="Denoising Process", show_label=True)
    # Input row: sample count, trigger button and dataset selection.
    with gr.Row():
        with gr.Column():
            num_samples = gr.Radio(
                choices=[4, 9, 16, 25],
                value=16,
                label="Number of samples to generate"
            )
            generate_btn = gr.Button("Generate")
            # File input; the Examples component below fills it with one of
            # the example datasets.
            input_file = gr.File(label="Select example data")
            gr.Examples(
                examples=[
                    ["app_examples/example1.npy"],
                    ["app_examples/example2.npy"],
                    ["app_examples/example3.npy"],
                    ["app_examples/example4.npy"]
                ],
                inputs=input_file,
                label="Example Datasets"
            )
    # Button handler: run the model and replace both GIFs with its output.
    generate_btn.click(
        fn=generate_timeseries,
        inputs=[input_file, num_samples],
        outputs=[noise_gif, denoise_gif]
    )
    # Sample-count handler: switch both images to the example GIFs named
    # after the selected count.
    num_samples.change(
        fn=update_example_gifs,
        inputs=[num_samples],
        outputs=[noise_gif, denoise_gif]
    )
    # Inject CSS styles through an HTML <style> element.
    gr.HTML(
        """
        <style>
        #container {
        text-align: center;
        padding: 2rem 0;
        }
        #logo {
        width: 120px;
        height: 120px;
        margin: 0 auto;
        margin-bottom: 1rem;
        }
        h1 {
        font-size: 3.5rem;
        margin-bottom: 0.5rem;
        }
        h3 {
        font-size: 1.8rem;
        margin-top: 0;
        color: #666;
        }
        </style>
        """
    )
# Launch the app when run as a script.
if __name__ == "__main__":
    demo.launch(share=True)