Livia_Zaharia committed on
Commit
64e42c0
·
1 Parent(s): 7935ff0

Works locally and accepts a user-uploaded, preprocessed CSV file as input

Browse files
__pycache__/plot_predictions.cpython-311.pyc DELETED
Binary file (9.5 kB)
 
__pycache__/routes.cpython-311.pyc DELETED
Binary file (2.33 kB)
 
__pycache__/tools.cpython-311.pyc DELETED
Binary file (13.3 kB)
 
app.py CHANGED
@@ -2,7 +2,7 @@ import gradio as gr
2
  from tools import *
3
 
4
 
5
- def gradio_output():
6
- return (predict_glucose_tool())
7
 
8
- gr.Interface(fn=gradio_output,inputs=None,outputs="image").launch()
 
2
  from tools import *
3
 
4
 
5
+ def gradio_output(file):
6
+ return (predict_glucose_tool(file))
7
 
8
+ gr.Interface(fn=gradio_output,inputs=gr.File(label="Upload CSV File"),outputs="plot").launch()
data_formatter/__pycache__/base.cpython-311.pyc DELETED
Binary file (16.4 kB)
 
data_formatter/base.py CHANGED
@@ -7,6 +7,7 @@ import sklearn.preprocessing
7
  import data_formatter.types as types
8
  import data_formatter.utils as utils
9
 
 
10
  DataTypes = types.DataTypes
11
  InputTypes = types.InputTypes
12
 
@@ -44,6 +45,7 @@ class DataFormatter:
44
  print('Loading data...')
45
  self.params['index_col'] = False if self.params['index_col'] == -1 else self.params['index_col']
46
  # read data table
 
47
  self.data = pd.read_csv(self.params['data_csv_path'], index_col=self.params['index_col'])
48
 
49
  # drop columns / rows
 
7
  import data_formatter.types as types
8
  import data_formatter.utils as utils
9
 
10
+
11
  DataTypes = types.DataTypes
12
  InputTypes = types.InputTypes
13
 
 
45
  print('Loading data...')
46
  self.params['index_col'] = False if self.params['index_col'] == -1 else self.params['index_col']
47
  # read data table
48
+
49
  self.data = pd.read_csv(self.params['data_csv_path'], index_col=self.params['index_col'])
50
 
51
  # drop columns / rows
environment.yaml CHANGED
@@ -1,11 +1,10 @@
1
- name: glucose_genie
2
  channels:
3
  - conda-forge
4
  - defaults
5
  dependencies:
6
  - python=3.11
7
  - gradio
8
- - seaborn
9
  - pytorch
10
  - optuna
11
  - tensorboard
@@ -17,3 +16,10 @@ dependencies:
17
  - pmdarima==2.0.4
18
  - numpy==1.26.4
19
  - peft
 
 
 
 
 
 
 
 
1
+ name: glucose_hf
2
  channels:
3
  - conda-forge
4
  - defaults
5
  dependencies:
6
  - python=3.11
7
  - gradio
 
8
  - pytorch
9
  - optuna
10
  - tensorboard
 
16
  - pmdarima==2.0.4
17
  - numpy==1.26.4
18
  - peft
19
+ - transformers
20
+ - datasets
21
+ - python-multipart
22
+ - plotly
23
+ - kaleido
24
+
25
+
gluformer/__pycache__/model.cpython-311.pyc DELETED
Binary file (15.9 kB)
 
requirements.txt CHANGED
@@ -1,5 +1,4 @@
1
  gradio
2
- seaborn
3
  torch
4
  optuna
5
  numpy==1.26.4
@@ -10,3 +9,8 @@ typer
10
  darts==0.29.0
11
  pmdarima==2.0.4
12
  peft
 
 
 
 
 
 
1
  gradio
 
2
  torch
3
  optuna
4
  numpy==1.26.4
 
9
  darts==0.29.0
10
  pmdarima==2.0.4
11
  peft
12
+ transformers
13
+ datasets
14
+ python-multipart
15
+ plotly
16
+ kaleido
tools.py CHANGED
@@ -3,25 +3,16 @@ import os
3
  import pickle
4
  import gzip
5
  from pathlib import Path
6
-
7
- import seaborn as sns
8
  import numpy as np
9
- import matplotlib.pyplot as plt
10
- import matplotlib.colors as mcolors
11
- from matplotlib.figure import Figure
12
  import torch
13
  from scipy import stats
14
-
15
  from gluformer.model import Gluformer
16
  from utils.darts_processing import *
17
  from utils.darts_dataset import *
18
-
19
-
20
  import hashlib
21
  from urllib.parse import urlparse
22
-
23
- import numpy as np
24
- import typer
25
 
26
 
27
  glucose = Path(os.path.abspath(__file__)).parent.resolve()
@@ -29,7 +20,7 @@ file_directory = glucose / "files"
29
 
30
 
31
  def plot_forecast(forecasts: np.ndarray, scalers: Any, dataset_test_glufo: Any, filename: str):
32
- filename=filename
33
  forecasts = (forecasts - scalers['target'].min_) / scalers['target'].scale_
34
 
35
  trues = [dataset_test_glufo.evalsample(i) for i in range(len(dataset_test_glufo))]
@@ -41,25 +32,18 @@ def plot_forecast(forecasts: np.ndarray, scalers: Any, dataset_test_glufo: Any,
41
  inputs = [dataset_test_glufo[i][0] for i in range(len(dataset_test_glufo))]
42
  inputs = (np.array(inputs) - scalers['target'].min_) / scalers['target'].scale_
43
 
44
- # Plot settings
45
- colors = ['#00264c', '#0a2c62', '#14437f', '#1f5a9d', '#2973bb', '#358ad9', '#4d9af4', '#7bb7ff', '#add5ff', '#e6f3ff']
46
- cmap = mcolors.LinearSegmentedColormap.from_list('my_colormap', colors)
47
- sns.set_theme(style="whitegrid")
48
-
49
- # Generate the plot
50
- fig, ax = plt.subplots(figsize=(10, 6))
51
-
52
-
53
  # Select a specific sample to plot
54
- ind = 30 # Example index
55
 
56
  samples = np.random.normal(
57
  loc=forecasts[ind, :], # Mean (center) of the distribution
58
  scale=0.1, # Standard deviation (spread) of the distribution
59
  size=(forecasts.shape[1], forecasts.shape[2])
60
  )
61
- #samples = samples.reshape(samples.shape[0], samples.shape[1], -1)
62
- #print ("samples",samples.shape)
 
 
63
 
64
  # Plot predictive distribution
65
  for point in range(samples.shape[0]):
@@ -67,38 +51,64 @@ def plot_forecast(forecasts: np.ndarray, scalers: Any, dataset_test_glufo: Any,
67
  maxi, mini = 1.2 * np.max(samples[point, :]), 0.8 * np.min(samples[point, :])
68
  y_grid = np.linspace(mini, maxi, 200)
69
  x = kde(y_grid)
70
- ax.fill_betweenx(y_grid, x1=point, x2=point - x * 15,
71
- alpha=0.7,
72
- edgecolor='black',
73
- color=cmap(point / samples.shape[0]))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
  # Plot median
76
  forecast = samples[:, :]
77
  median = np.quantile(forecast, 0.5, axis=-1)
78
- ax.plot(np.arange(12), median, color='red', marker='o')
79
-
80
- # Plot true values
81
- ax.plot(np.arange(-12, 12), np.concatenate([inputs[ind, -12:], trues[ind, :]]), color='blue')
82
-
83
- # Add labels and title
84
- ax.set_xlabel('Time (in 5 minute intervals)')
85
- ax.set_ylabel('Glucose (mg/dL)')
86
- ax.set_title(f'Gluformer Prediction with Gradient for dateset')
87
 
88
- # Adjust font sizes
89
- ax.xaxis.label.set_fontsize(16)
90
- ax.yaxis.label.set_fontsize(16)
91
- ax.title.set_fontsize(18)
92
- for item in ax.get_xticklabels() + ax.get_yticklabels():
93
- item.set_fontsize(14)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
  # Save figure
96
- plt.tight_layout()
97
- where = file_directory /filename
98
- plt.savefig(str(where), dpi=300, bbox_inches='tight')
99
-
100
- return where,ax
101
 
 
102
 
103
 
104
  def generate_filename_from_url(url: str, extension: str = "png") -> str:
@@ -120,18 +130,21 @@ def generate_filename_from_url(url: str, extension: str = "png") -> str:
120
 
121
 
122
 
123
- def predict_glucose_tool(url: str= 'https://huggingface.co/datasets/Livia-Zaharia/glucose_processed/blob/main/livia_mini.csv',
124
- model: str = 'https://huggingface.co/Livia-Zaharia/gluformer_models/blob/main/gluformer_1samples_10000epochs_10heads_32batch_geluactivation_livia_mini_weights.pth'
125
- ) -> Figure:
126
  """
127
  Function to predict future glucose of user. It receives URL with users csv. It will run an ML and will return URL with predictions that user can open on her own..
128
- :param url: of the csv file with glucose values
129
- :param model: model that is used to predict the glucose
130
  :param explain if it should give both url and explanation
131
  :param if the person is diabetic when doing prediction and explanation
132
  :return:
133
  """
134
 
 
 
 
 
 
135
  formatter, series, scalers = load_data(url=str(url), config_path=file_directory / "config.yaml", use_covs=True,
136
  cov_type='dual',
137
  use_static_covs=True)
@@ -141,7 +154,7 @@ def predict_glucose_tool(url: str= 'https://huggingface.co/datasets/Livia-Zahari
141
  formatter.params['gluformer'] = {
142
  'in_len': 96, # example input length, adjust as necessary
143
  'd_model': 512, # model dimension
144
- 'n_heads': 10, # number of attention heads##############################################################################
145
  'd_fcn': 1024, # fully connected layer dimension
146
  'num_enc_layers': 2, # number of encoder layers
147
  'num_dec_layers': 2, # number of decoder layers
@@ -166,11 +179,9 @@ def predict_glucose_tool(url: str= 'https://huggingface.co/datasets/Livia-Zahari
166
  num_dynamic_features=num_dynamic_features,
167
  num_static_features=num_static_features
168
  )
169
- weights = gr.Interface.load(model)
170
- assert f"weights for {model} should exist", weights.exists()
171
 
172
  device = "cuda" if torch.cuda.is_available() else "cpu"
173
- glufo.load_state_dict(torch.load(str(weights), map_location=torch.device(device), weights_only=False))
174
 
175
  # Define dataset for inference
176
  dataset_test_glufo = SamplingDatasetInferenceDual(
@@ -184,9 +195,9 @@ def predict_glucose_tool(url: str= 'https://huggingface.co/datasets/Livia-Zahari
184
 
185
  forecasts, _ = glufo.predict(
186
  dataset_test_glufo,
187
- batch_size=16,####################################################
188
  num_samples=10,
189
- device='cpu'
190
  )
191
  figure_path, result = plot_forecast(forecasts, scalers, dataset_test_glufo,filename)
192
 
 
3
  import pickle
4
  import gzip
5
  from pathlib import Path
 
 
6
  import numpy as np
 
 
 
7
  import torch
8
  from scipy import stats
 
9
  from gluformer.model import Gluformer
10
  from utils.darts_processing import *
11
  from utils.darts_dataset import *
 
 
12
  import hashlib
13
  from urllib.parse import urlparse
14
+ from huggingface_hub import hf_hub_download
15
+ import plotly.graph_objects as go
 
16
 
17
 
18
  glucose = Path(os.path.abspath(__file__)).parent.resolve()
 
20
 
21
 
22
  def plot_forecast(forecasts: np.ndarray, scalers: Any, dataset_test_glufo: Any, filename: str):
23
+
24
  forecasts = (forecasts - scalers['target'].min_) / scalers['target'].scale_
25
 
26
  trues = [dataset_test_glufo.evalsample(i) for i in range(len(dataset_test_glufo))]
 
32
  inputs = [dataset_test_glufo[i][0] for i in range(len(dataset_test_glufo))]
33
  inputs = (np.array(inputs) - scalers['target'].min_) / scalers['target'].scale_
34
 
 
 
 
 
 
 
 
 
 
35
  # Select a specific sample to plot
36
+ ind = 10 # Example index
37
 
38
  samples = np.random.normal(
39
  loc=forecasts[ind, :], # Mean (center) of the distribution
40
  scale=0.1, # Standard deviation (spread) of the distribution
41
  size=(forecasts.shape[1], forecasts.shape[2])
42
  )
43
+
44
+
45
+ # Create figure
46
+ fig = go.Figure()
47
 
48
  # Plot predictive distribution
49
  for point in range(samples.shape[0]):
 
51
  maxi, mini = 1.2 * np.max(samples[point, :]), 0.8 * np.min(samples[point, :])
52
  y_grid = np.linspace(mini, maxi, 200)
53
  x = kde(y_grid)
54
+
55
+ # Create gradient color
56
+ color = f'rgba(53, 138, 217, {(point + 1) / samples.shape[0]})'
57
+
58
+ # Add filled area
59
+ fig.add_trace(go.Scatter(
60
+ x=np.concatenate([np.full_like(y_grid, point), np.full_like(y_grid, point - x * 15)[::-1]]),
61
+ y=np.concatenate([y_grid, y_grid[::-1]]),
62
+ fill='tonexty',
63
+ fillcolor=color,
64
+ line=dict(color='rgba(0,0,0,0)'),
65
+ showlegend=False
66
+ ))
67
+
68
+
69
+ true_values = np.concatenate([inputs[ind, -12:], trues[ind, :]])
70
+ true_values_flat=true_values.flatten()
71
+
72
+ fig.add_trace(go.Scatter(
73
+ x=list(range(-12, 12)),
74
+ y=true_values_flat.tolist(), # Convert to list explicitly
75
+ mode='lines+markers',
76
+ line=dict(color='blue', width=2),
77
+ marker=dict(size=6),
78
+ name='True Values'
79
+ ))
80
 
81
  # Plot median
82
  forecast = samples[:, :]
83
  median = np.quantile(forecast, 0.5, axis=-1)
 
 
 
 
 
 
 
 
 
84
 
85
+ fig.add_trace(go.Scatter(
86
+ x=list(range(12)),
87
+ y=median.tolist(), # Convert to list explicitly
88
+ mode='lines+markers',
89
+ line=dict(color='red', width=2),
90
+ marker=dict(size=8),
91
+ name='Median Forecast'
92
+ ))
93
+
94
+
95
+ # Update layout
96
+ fig.update_layout(
97
+ title='Gluformer Prediction with Gradient for dataset',
98
+ xaxis_title='Time (in 5 minute intervals)',
99
+ yaxis_title='Glucose (mg/dL)',
100
+ font=dict(size=14),
101
+ showlegend=True,
102
+ width=1000,
103
+ height=600
104
+ )
105
 
106
  # Save figure
107
+ where = file_directory / filename
108
+ fig.write_html(str(where.with_suffix('.html')))
109
+ fig.write_image(str(where))
 
 
110
 
111
+ return where, fig
112
 
113
 
114
  def generate_filename_from_url(url: str, extension: str = "png") -> str:
 
130
 
131
 
132
 
133
+ def predict_glucose_tool(file) -> go.Figure:
 
 
134
  """
135
  Function to predict future glucose of user. It receives URL with users csv. It will run an ML and will return URL with predictions that user can open on her own..
136
+ :param file: it is the csv file imported as a string path to the temporary location gradio allows
137
+ :param model: model that is used to predict the glucose- was hardcoded
138
  :param explain if it should give both url and explanation
139
  :param if the person is diabetic when doing prediction and explanation
140
  :return:
141
  """
142
 
143
+ url = file
144
+ model="Livia-Zaharia/gluformer_models"
145
+ model_path = hf_hub_download(repo_id= model, filename="gluformer_1samples_10000epochs_10heads_32batch_geluactivation_livia_mini_weights.pth")
146
+
147
+
148
  formatter, series, scalers = load_data(url=str(url), config_path=file_directory / "config.yaml", use_covs=True,
149
  cov_type='dual',
150
  use_static_covs=True)
 
154
  formatter.params['gluformer'] = {
155
  'in_len': 96, # example input length, adjust as necessary
156
  'd_model': 512, # model dimension
157
+ 'n_heads': 10, # number of attention heads########################
158
  'd_fcn': 1024, # fully connected layer dimension
159
  'num_enc_layers': 2, # number of encoder layers
160
  'num_dec_layers': 2, # number of decoder layers
 
179
  num_dynamic_features=num_dynamic_features,
180
  num_static_features=num_static_features
181
  )
 
 
182
 
183
  device = "cuda" if torch.cuda.is_available() else "cpu"
184
+ glufo.load_state_dict(torch.load(str(model_path), map_location=torch.device(device), weights_only=True))
185
 
186
  # Define dataset for inference
187
  dataset_test_glufo = SamplingDatasetInferenceDual(
 
195
 
196
  forecasts, _ = glufo.predict(
197
  dataset_test_glufo,
198
+ batch_size=16,#######
199
  num_samples=10,
200
+ device=device
201
  )
202
  figure_path, result = plot_forecast(forecasts, scalers, dataset_test_glufo,filename)
203
 
utils/__pycache__/darts_processing.cpython-311.pyc DELETED
Binary file (17.2 kB)
 
utils/darts_processing.py CHANGED
@@ -165,7 +165,6 @@ def load_data(url: str,
165
  config["data_csv_path"] = url
166
 
167
  formatter = DataFormatter(config)
168
- #assert dataset is not None, 'dataset must be specified in the load_data call'
169
  assert use_covs is not None, 'use_covs must be specified in the load_data call'
170
 
171
  # convert to series
 
165
  config["data_csv_path"] = url
166
 
167
  formatter = DataFormatter(config)
 
168
  assert use_covs is not None, 'use_covs must be specified in the load_data call'
169
 
170
  # convert to series