|
import torch
|
|
import torch.nn as nn
|
|
from torch.utils.data import DataLoader
|
|
import numpy as np
|
|
from math import *
|
|
|
|
import neuromancer.psl as psl
|
|
from neuromancer.system import Node, System
|
|
from neuromancer.modules import blocks
|
|
from neuromancer.dataset import DictDataset
|
|
from neuromancer.constraint import variable
|
|
from neuromancer.loss import PenaltyLoss
|
|
from neuromancer.problem import Problem
|
|
from neuromancer.trainer import Trainer
|
|
from neuromancer.plot import pltCL
|
|
|
|
from dataloader import get_data
|
|
|
|
|
|
|
|
|
|
# Plant: the psl 'LinearSimpleSingleZone' system, used both to generate data and
# as the differentiable model inside the closed loop below.
# NOTE(review): the global name `sys` shadows the standard-library module of the
# same name; it is kept because the rest of this script refers to it.
sys = psl.systems['LinearSimpleSingleZone']()

# Problem dimensions taken from the plant object. Later in the script they size
# the state x, control u, disturbance d, observed disturbance and output y tensors.
nx, nu, nd, nd_obs, ny = sys.nx, sys.nu, sys.nD, sys.nD_obs, sys.ny

# One reference channel per output (used for the ymin / ymax bound signals).
nref = ny
|
|
|
|
|
|
# Actuator limits as tensors; they bound the policy output (MLP_bounds below)
# and are drawn as control limits in the final plot.
umin, umax = torch.tensor(sys.umin), torch.tensor(sys.umax)

# Training configuration: rollout horizon per sample, number of samples requested
# from get_data, and minibatch size.
nsteps, n_samples, batch_size = 100, 2000, 64

# Sampling distribution handed to get_data; presumably the range of the sampled
# lower comfort bound -- confirm against dataloader.get_data.
xmin_range = torch.distributions.Uniform(18., 20.)

# Build the training split first, then the validation ("dev") split.
train_loader = get_data(sys, nsteps, n_samples, xmin_range, batch_size, name="train")
dev_loader = get_data(sys, nsteps, n_samples, xmin_range, batch_size, name="dev")
|
|
|
|
|
|
|
|
# Plant matrices as tensors (note: the input matrix is read from sys.Beta).
A, B, C, E = (torch.tensor(mat) for mat in (sys.A, sys.Beta, sys.C, sys.E))


def xnext(x, u, d):
    """Advance the linear plant one step: x_next = x A^T + u B^T + d E^T."""
    return x @ A.T + u @ B.T + d @ E.T


state_model = Node(xnext, ['x', 'u', 'd'], ['x'], name='SSM')


def ynext(x):
    """Output equation of the plant: y = x C^T."""
    return x @ C.T


output_model = Node(ynext, ['x'], ['y'], name='y=Cx')


def dist_model(d):
    """Select the disturbance columns listed in sys.d_idx (the observed part)."""
    # assumes d is 2-D (batch, nd) per rollout step -- TODO confirm in System
    return d[:, sys.d_idx]


# NOTE(review): the node/key names 'patient_cond_change' / 'patient_obs' look
# carried over from another example; here they hold the selected disturbances.
patient_cond_change = Node(dist_model, ['d'], ['patient_obs'], name='patient_cond_change')

# Policy network: sees the output, both bounds and the observed disturbances,
# and emits a control clipped to [umin, umax].
net = blocks.MLP_bounds(
    insize=ny + 2 * nref + nd_obs,
    outsize=nu,
    hsizes=[32, 32],
    nonlin=nn.GELU,
    min=umin,
    max=umax,
)
policy = Node(net, ['y', 'ymin', 'ymax', 'patient_obs'], ['u'], name='policy')

# Node order per step: observe disturbances -> act -> advance state -> measure.
closed_loop_system = System(
    [patient_cond_change, policy, state_model, output_model],
    nsteps=nsteps,
    name='closed_loop_system',
)
closed_loop_system.show()
|
|
|
|
|
|
|
|
# Symbolic handles onto the rollout trajectories, used to write the loss terms.
y, u, ymin, ymax = (variable(key) for key in ('y', 'u', 'ymin', 'ymax'))

# Objective terms: control effort and control-rate smoothing.
action_loss = 0.01 * (u == 0.0)
du_loss = 0.1 * (u[:,:-1,:] - u[:,1:,:] == 0.0)
# NOTE(review): `abs` is the builtin (`from math import *` does not export one)
# applied to the same control difference as du_loss, so this term looks
# redundant with du_loss; its 'insulin_...' label also seems carried over from
# another example -- confirm intent before relying on it.
action_limit_loss = 0.02 * (abs(u[:, 1:, :] - u[:, :-1, :])==0.0)

# Comfort-band penalties (weight 50) keeping the output between ymin and ymax.
state_lower_bound_penalty = 50.*(y > ymin)
state_upper_bound_penalty = 50.*(y < ymax)

# Display/logging names for each term.
for loss_term, label in (
    (action_loss, 'control_loss'),
    (du_loss, 'regularization_loss'),
    (action_limit_loss, 'insulin_constraint_loss'),
    (state_lower_bound_penalty, 'x_min'),
    (state_upper_bound_penalty, 'x_max'),
):
    loss_term.name = label

objectives = [action_loss, du_loss, action_limit_loss]
constraints = [state_lower_bound_penalty, state_upper_bound_penalty]

# The trainable graph is just the closed-loop system; the objectives and the
# constraint penalties are combined by PenaltyLoss.
nodes = [closed_loop_system]
loss = PenaltyLoss(objectives, constraints)
problem = Problem(nodes, loss)
problem.show()
|
|
|
|
|
|
# Optimize every parameter registered in the problem (i.e. the policy network).
optimizer = torch.optim.AdamW(problem.parameters(), lr=1e-3)

# Trainer settings. Presumably `warmup` is the number of epochs before
# best-model tracking / early stopping starts -- confirm in neuromancer.Trainer.
trainer_config = {
    'optimizer': optimizer,
    'epochs': 200,
    'train_metric': 'train_loss',
    'eval_metric': 'dev_loss',
    'warmup': 50,
}
trainer = Trainer(problem, train_loader, dev_loader, **trainer_config)

# train() returns a state dict (presumably the best weights on the dev metric),
# which is loaded back into the trained model.
best_model = trainer.train()
trainer.model.load_state_dict(best_model)
|
|
|
|
|
|
|
|
|
|
# Closed-loop evaluation: roll the trained policy out over a much longer horizon
# than was used in training, then plot the trajectories against the bounds.
nsteps_test = 3000

# Piecewise-constant lower bound (8 random steps within [18, 21]), one column
# per reference channel; the upper bound is the lower bound plus a fixed band.
# NOTE(review): the band used for training is set inside dataloader.get_data
# (not visible here); confirm that 6.0 matches it, otherwise the policy is
# evaluated off its training distribution.
np_refs = psl.signals.step(nsteps_test+1, nref, min=18, max=21, randsteps=8)
ymin_val = torch.tensor(np_refs, dtype=torch.float32).reshape(1, nsteps_test+1, nref)
ymax_val = ymin_val+6.0

# Disturbance trajectory and initial state, each given a leading batch dim of 1.
torch_dist = torch.tensor(sys.get_D(nsteps_test+1)).unsqueeze(0)
x0 = torch.tensor(sys.get_x0()).reshape(1, 1, nx)

data = {'x': x0,
        # Initial measurement from the model's own output equation y = C x
        # (same as the 'y=Cx' node) instead of assuming y is the last state.
        # C is cast to x0's dtype so this matches the original's dtype.
        'y': x0 @ C.T.to(x0.dtype),
        'ymin': ymin_val,
        'ymax': ymax_val,
        'd': torch_dist}
closed_loop_system.nsteps = nsteps_test

# Pure inference: skip building an autograd graph across the long rollout.
with torch.no_grad():
    trajectories = closed_loop_system(data)

# Constant actuator limits and the time-varying output bounds for plotting.
Umin = umin * np.ones([nsteps_test, nu])
Umax = umax * np.ones([nsteps_test, nu])
Ymin = trajectories['ymin'].detach().reshape(nsteps_test+1, nref)
Ymax = trajectories['ymax'].detach().reshape(nsteps_test+1, nref)

# States/outputs/disturbances have nsteps_test+1 samples; controls have nsteps_test.
pltCL(Y=trajectories['y'].detach().reshape(nsteps_test+1, ny),
      R=Ymax,
      X=trajectories['x'].detach().reshape(nsteps_test+1, nx),
      D=trajectories['d'].detach().reshape(nsteps_test+1, nd),
      U=trajectories['u'].detach().reshape(nsteps_test, nu),
      Umin=Umin, Umax=Umax, Ymin=Ymin, Ymax=Ymax)