Spaces:
Sleeping
Sleeping
import sys | |
import os | |
import pickle | |
import gzip | |
from pathlib import Path | |
import numpy as np | |
import torch | |
from scipy import stats | |
from gluformer.model import Gluformer | |
from utils.darts_processing import * | |
from utils.darts_dataset import * | |
import hashlib | |
from urllib.parse import urlparse | |
from huggingface_hub import hf_hub_download | |
import plotly.graph_objects as go | |
# Absolute directory that contains this module (made absolute first, then
# resolved), used as the anchor for bundled resources.
glucose = Path(os.path.dirname(os.path.abspath(__file__))).resolve()
# "files" sub-directory next to this module: config.yaml is read from here and
# generated plots are written here.
file_directory = glucose.joinpath("files")
def plot_forecast(forecasts: np.ndarray, scalers: Any, dataset_test_glufo: Any, filename: str):
    """
    Plot one forecast window against the observed values and save the figure.

    The figure contains, for every forecast step, a filled density outline
    (a Gaussian KDE fitted on values drawn around the forecast), the observed
    history plus observed future as a blue line, and the per-step median of
    the drawn values as a red line.

    :param forecasts: 3-D array indexed as ``[window, forecast step, sample]``
        (it is indexed with ``forecasts[ind, :]`` and its ``shape[1]`` and
        ``shape[2]`` are used), still in the scaled space of
        ``scalers['target']``.
    :param scalers: mapping whose ``'target'`` entry exposes ``min_``,
        ``scale_`` and ``inverse_transform`` (presumably a fitted min-max
        scaler - confirm against ``load_data``).
    :param dataset_test_glufo: inference dataset supporting ``len()``,
        ``dataset[i][0]`` (the model input window) and ``evalsample(i)``
        (the observed future for that window).
    :param filename: image file name, created inside ``file_directory``; an
        ``.html`` sibling with the same stem is written as well.
    :return: ``(where, fig)`` - the image path and the Plotly figure.
    :raises IndexError: if there is no window to plot.
    """
    target_scaler = scalers['target']

    # Map scaled values back to the original range: (scaled - min_) / scale_.
    forecasts = (forecasts - target_scaler.min_) / target_scaler.scale_

    n_windows = len(dataset_test_glufo)
    trues = [dataset_test_glufo.evalsample(i) for i in range(n_windows)]
    trues = target_scaler.inverse_transform(trues)
    trues = np.array([ts.values() for ts in trues])  # TimeSeries -> numpy arrays

    inputs = [dataset_test_glufo[i][0] for i in range(n_windows)]
    inputs = (np.array(inputs) - target_scaler.min_) / target_scaler.scale_

    # Select the window to plot. Index 10 is the historical default; clamp it
    # so that short datasets (fewer than 11 windows) still produce a plot
    # instead of failing with an out-of-range index.
    last_index = min(forecasts.shape[0], n_windows) - 1
    if last_index < 0:
        raise IndexError("no forecast windows to plot")
    ind = min(10, last_index)

    # Values drawn around the forecast (normal, sd 0.1); these feed both the
    # density outlines and the median line.
    samples = np.random.normal(
        loc=forecasts[ind, :],
        scale=0.1,
        size=(forecasts.shape[1], forecasts.shape[2]),
    )
    n_steps = samples.shape[0]

    fig = go.Figure()

    # Predictive distribution: one density outline per forecast step.
    for point in range(n_steps):
        step_samples = samples[point, :]
        kde = stats.gaussian_kde(step_samples)
        maxi, mini = 1.2 * np.max(step_samples), 0.8 * np.min(step_samples)
        y_grid = np.linspace(mini, maxi, 200)
        density = kde(y_grid)

        # Later steps are drawn more opaque.
        color = f'rgba(53, 138, 217, {(point + 1) / n_steps})'

        # The outline runs up the vertical line x == point and back down the
        # density curve, i.e. it is a closed shape. 'toself' fills that shape;
        # 'tonexty' would fill towards the previously added trace instead.
        baseline = np.full_like(y_grid, point)
        density_edge = point - density * 15
        fig.add_trace(go.Scatter(
            x=np.concatenate([baseline, density_edge[::-1]]),
            y=np.concatenate([y_grid, y_grid[::-1]]),
            fill='toself',
            fillcolor=color,
            line=dict(color='rgba(0,0,0,0)'),
            showlegend=False
        ))

    # Observed values: up to the last 12 input points followed by the observed
    # future. The x positions are derived from the actual lengths so they stay
    # aligned with the data when the history or horizon is not exactly 12.
    history = inputs[ind, -12:]
    future = trues[ind, :]
    true_values_flat = np.concatenate([history, future]).flatten()
    fig.add_trace(go.Scatter(
        x=list(range(-history.shape[0], future.shape[0])),
        y=true_values_flat.tolist(),  # Convert to list explicitly
        mode='lines+markers',
        line=dict(color='blue', width=2),
        marker=dict(size=6),
        name='True Values'
    ))

    # Median of the drawn values for each forecast step.
    median = np.quantile(samples, 0.5, axis=-1)
    fig.add_trace(go.Scatter(
        x=list(range(n_steps)),
        y=median.tolist(),  # Convert to list explicitly
        mode='lines+markers',
        line=dict(color='red', width=2),
        marker=dict(size=8),
        name='Median Forecast'
    ))

    fig.update_layout(
        title='Gluformer Prediction with Gradient for dataset',
        xaxis_title='Time (in 5 minute intervals)',
        yaxis_title='Glucose (mg/dL)',
        font=dict(size=14),
        showlegend=True,
        width=1000,
        height=600
    )

    # Save the figure as interactive HTML and as a static image.
    where = file_directory / filename
    where.parent.mkdir(parents=True, exist_ok=True)  # make sure the target directory exists
    fig.write_html(str(where.with_suffix('.html')))
    fig.write_image(str(where))
    return where, fig
def generate_filename_from_url(url: str, extension: str = "png") -> str:
    """
    Build a file name from a URL (or path-like string).

    The name is made of the last ``/``-separated segment of the URL path with
    every ``.`` replaced by ``_``, followed by the MD5 hex digest of the full
    URL and the requested extension.

    :param url: URL or path the name is derived from.
    :param extension: file extension to append, without the leading dot.
    :return: ``"<segment>_<md5 of url>.<extension>"``.
    """
    # Only the path component is considered, so query strings and fragments
    # do not end up in the readable part of the name.
    tail = urlparse(url).path.rsplit('/', 1)[-1]
    stem = tail.replace('.', '_')
    # The digest is taken over the whole URL, so different URLs that share a
    # last segment still get different names.
    digest = hashlib.md5(url.encode('utf-8')).hexdigest()
    return f"{stem}_{digest}.{extension}"
def predict_glucose_tool(file) -> go.Figure:
    """
    Forecast glucose from a CSV file and return the resulting plot.

    The function downloads the hard-coded Gluformer weights from the Hugging
    Face Hub, loads the data with ``load_data`` (using ``config.yaml`` from
    ``file_directory``), runs the model on the test split and plots the
    forecast with ``plot_forecast``, which also writes the plot files.

    :param file: location of the CSV file; it is passed to ``load_data`` as
        ``str(file)`` (per the original notes this is the temporary path
        handed over by gradio).
    :return: the Plotly figure produced by ``plot_forecast``.
    """
    url = file

    # Model weights: repository and file name are fixed.
    repo_id = "Livia-Zaharia/gluformer_models"
    weights_file = "gluformer_1samples_10000epochs_10heads_32batch_geluactivation_livia_mini_weights.pth"
    model_path = hf_hub_download(repo_id=repo_id, filename=weights_file)

    formatter, series, scalers = load_data(
        url=str(url),
        config_path=file_directory / "config.yaml",
        use_covs=True,
        cov_type='dual',
        use_static_covs=True,
    )
    filename = generate_filename_from_url(url)

    # Architecture settings; they are also stored on the formatter.
    gluformer_cfg = {
        'in_len': 96,  # example input length, adjust as necessary
        'd_model': 512,  # model dimension
        'n_heads': 10,  # number of attention heads
        'd_fcn': 1024,  # fully connected layer dimension
        'num_enc_layers': 2,  # number of encoder layers
        'num_dec_layers': 2,  # number of decoder layers
        'length_pred': 12  # prediction length, adjust as necessary
    }
    formatter.params['gluformer'] = gluformer_cfg

    # Feature counts are taken from the last training series.
    num_dynamic_features = series['train']['future'][-1].n_components
    num_static_features = series['train']['static'][-1].n_components

    in_len = gluformer_cfg['in_len']
    # The forecast length comes from the formatter's top-level parameters,
    # not from the 'gluformer' entry defined above.
    pred_len = formatter.params['length_pred']

    glufo = Gluformer(
        d_model=gluformer_cfg['d_model'],
        n_heads=gluformer_cfg['n_heads'],
        d_fcn=gluformer_cfg['d_fcn'],
        r_drop=0.2,
        activ='gelu',
        num_enc_layers=gluformer_cfg['num_enc_layers'],
        num_dec_layers=gluformer_cfg['num_dec_layers'],
        distil=True,
        len_seq=in_len,
        label_len=in_len // 3,
        len_pred=pred_len,
        num_dynamic_features=num_dynamic_features,
        num_static_features=num_static_features
    )

    # Use the GPU when one is available, otherwise stay on the CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    state_dict = torch.load(str(model_path), map_location=torch.device(device), weights_only=True)
    glufo.load_state_dict(state_dict)

    # Dataset used for inference on the test split.
    dataset_test_glufo = SamplingDatasetInferenceDual(
        target_series=series['test']['target'],
        covariates=series['test']['future'],
        input_chunk_length=in_len,
        output_chunk_length=pred_len,
        use_static_covariates=True,
        array_output_only=True
    )

    # Only the first element returned by predict is used.
    forecasts, _ = glufo.predict(
        dataset_test_glufo,
        batch_size=16,
        num_samples=10,
        device=device
    )

    # plot_forecast also returns the saved path, which is not needed here.
    _, figure = plot_forecast(forecasts, scalers, dataset_test_glufo, filename)
    return figure
if __name__ == "__main__":
    # predict_glucose_tool has a required ``file`` argument, so calling it
    # with no arguments always failed with a TypeError. Take the CSV location
    # from the command line instead.
    if len(sys.argv) != 2:
        sys.exit(f"usage: {sys.argv[0]} <path-to-csv>")
    predict_glucose_tool(sys.argv[1])