Spaces:
Runtime error
Runtime error
Commit
·
7b08294
1
Parent(s):
dabc757
remove da target redsfhit
Browse files- temps/temps.py +7 -15
temps/temps.py
CHANGED
@@ -31,17 +31,15 @@ class Temps_module():
|
|
31 |
|
32 |
|
33 |
|
34 |
-
def _get_dataloaders(self, input_data, target_data, input_data_DA,
|
35 |
input_data = torch.Tensor(input_data)
|
36 |
target_data = torch.Tensor(target_data)
|
37 |
if input_data_DA is not None:
|
38 |
input_data_DA = torch.Tensor(input_data_DA)
|
39 |
-
target_data_DA = torch.Tensor(target_data_DA)
|
40 |
else:
|
41 |
input_data_DA = input_data.clone()
|
42 |
-
target_data_DA = target_data.clone()
|
43 |
|
44 |
-
dataset = TensorDataset(input_data, input_data_DA, target_data
|
45 |
trainig_dataset, val_dataset = torch.utils.data.random_split(dataset, [int(len(dataset)*(1-val_fraction)), int(len(dataset)*val_fraction)+1])
|
46 |
loader_train = DataLoader(trainig_dataset, batch_size=self.batch_size, shuffle = True)
|
47 |
loader_val = DataLoader(val_dataset, batch_size=64, shuffle = True)
|
@@ -76,7 +74,6 @@ class Temps_module():
|
|
76 |
def train(self,input_data,
|
77 |
input_data_DA,
|
78 |
target_data,
|
79 |
-
target_data_DA,
|
80 |
nepochs=10,
|
81 |
step_size = 100,
|
82 |
val_fraction=0.1,
|
@@ -85,7 +82,7 @@ class Temps_module():
|
|
85 |
self.modelZ = self.modelZ.train()
|
86 |
self.modelF = self.modelF.train()
|
87 |
|
88 |
-
loader_train, loader_val = self._get_dataloaders(input_data, target_data, input_data_DA,
|
89 |
optimizerZ = optim.Adam(self.modelZ.parameters(), lr=lr, weight_decay=weight_decay)
|
90 |
optimizerF = optim.Adam(self.modelF.parameters(), lr=lr, weight_decay=weight_decay)
|
91 |
|
@@ -98,7 +95,7 @@ class Temps_module():
|
|
98 |
self.loss_train, self.loss_validation = [],[]
|
99 |
|
100 |
for epoch in range(nepochs):
|
101 |
-
for input_data, input_data_da, target_data
|
102 |
_loss_train, _loss_validation = [],[]
|
103 |
|
104 |
input_data = input_data.to(self.device)
|
@@ -106,8 +103,8 @@ class Temps_module():
|
|
106 |
|
107 |
if self.da:
|
108 |
input_data_da = input_data_da.to(self.device)
|
109 |
-
target_data_DA = target_data_DA.to(self.device)
|
110 |
|
|
|
111 |
optimizerF.zero_grad()
|
112 |
optimizerZ.zero_grad()
|
113 |
|
@@ -120,12 +117,7 @@ class Temps_module():
|
|
120 |
sig = torch.exp(logsig)
|
121 |
|
122 |
lossZ = self._loss_function(mu, sig, logmix_coeff, target_data)
|
123 |
-
|
124 |
-
#mu, logsig, logmix_coeff = self.modelZ(features_DA)
|
125 |
-
#logsig = torch.clamp(logsig,-6,2)
|
126 |
-
#sig = torch.exp(logsig)
|
127 |
-
|
128 |
-
#lossZ_DA = self._loss_function(mu, sig, logmix_coeff, target_data_DA)
|
129 |
|
130 |
if self.da:
|
131 |
lossDA = maximum_mean_discrepancy(features, features_DA, kernel_type='rbf')
|
@@ -145,7 +137,7 @@ class Temps_module():
|
|
145 |
|
146 |
self.loss_train.append(np.mean(_loss_train))
|
147 |
|
148 |
-
for input_data, _, target_data
|
149 |
|
150 |
input_data = input_data.to(self.device)
|
151 |
target_data = target_data.to(self.device)
|
|
|
31 |
|
32 |
|
33 |
|
34 |
+
def _get_dataloaders(self, input_data, target_data, input_data_DA, val_fraction=0.1):
|
35 |
input_data = torch.Tensor(input_data)
|
36 |
target_data = torch.Tensor(target_data)
|
37 |
if input_data_DA is not None:
|
38 |
input_data_DA = torch.Tensor(input_data_DA)
|
|
|
39 |
else:
|
40 |
input_data_DA = input_data.clone()
|
|
|
41 |
|
42 |
+
dataset = TensorDataset(input_data, input_data_DA, target_data)
|
43 |
trainig_dataset, val_dataset = torch.utils.data.random_split(dataset, [int(len(dataset)*(1-val_fraction)), int(len(dataset)*val_fraction)+1])
|
44 |
loader_train = DataLoader(trainig_dataset, batch_size=self.batch_size, shuffle = True)
|
45 |
loader_val = DataLoader(val_dataset, batch_size=64, shuffle = True)
|
|
|
74 |
def train(self,input_data,
|
75 |
input_data_DA,
|
76 |
target_data,
|
|
|
77 |
nepochs=10,
|
78 |
step_size = 100,
|
79 |
val_fraction=0.1,
|
|
|
82 |
self.modelZ = self.modelZ.train()
|
83 |
self.modelF = self.modelF.train()
|
84 |
|
85 |
+
loader_train, loader_val = self._get_dataloaders(input_data, target_data, input_data_DA, val_fraction=0.1)
|
86 |
optimizerZ = optim.Adam(self.modelZ.parameters(), lr=lr, weight_decay=weight_decay)
|
87 |
optimizerF = optim.Adam(self.modelF.parameters(), lr=lr, weight_decay=weight_decay)
|
88 |
|
|
|
95 |
self.loss_train, self.loss_validation = [],[]
|
96 |
|
97 |
for epoch in range(nepochs):
|
98 |
+
for input_data, input_data_da, target_data in loader_train:
|
99 |
_loss_train, _loss_validation = [],[]
|
100 |
|
101 |
input_data = input_data.to(self.device)
|
|
|
103 |
|
104 |
if self.da:
|
105 |
input_data_da = input_data_da.to(self.device)
|
|
|
106 |
|
107 |
+
|
108 |
optimizerF.zero_grad()
|
109 |
optimizerZ.zero_grad()
|
110 |
|
|
|
117 |
sig = torch.exp(logsig)
|
118 |
|
119 |
lossZ = self._loss_function(mu, sig, logmix_coeff, target_data)
|
120 |
+
|
|
|
|
|
|
|
|
|
|
|
121 |
|
122 |
if self.da:
|
123 |
lossDA = maximum_mean_discrepancy(features, features_DA, kernel_type='rbf')
|
|
|
137 |
|
138 |
self.loss_train.append(np.mean(_loss_train))
|
139 |
|
140 |
+
for input_data, _, target_data in loader_val:
|
141 |
|
142 |
input_data = input_data.to(self.device)
|
143 |
target_data = target_data.to(self.device)
|