Spaces:
Sleeping
Sleeping
Update finetune.py
Browse files- finetune.py +5 -6
finetune.py
CHANGED
@@ -61,7 +61,7 @@ def get_dataset(data_path, data_info, test_count=10):
|
|
61 |
|
62 |
|
63 |
class CustomFinetuneModel(ESPnetS2TModel):
|
64 |
-
def __init__(self, model, log_every=500):
|
65 |
super().__init__(
|
66 |
vocab_size=model.vocab_size,
|
67 |
token_list=model.token_list,
|
@@ -94,6 +94,7 @@ class CustomFinetuneModel(ESPnetS2TModel):
|
|
94 |
'loss': 0.0,
|
95 |
'acc': 0.0
|
96 |
}
|
|
|
97 |
|
98 |
def forward(self, *args, **kwargs):
|
99 |
out = super().forward(*args, **kwargs)
|
@@ -104,7 +105,7 @@ class CustomFinetuneModel(ESPnetS2TModel):
|
|
104 |
if self.iter_count % self.log_every == 0:
|
105 |
loss = self.log_stats['loss'] / self.log_every
|
106 |
acc = self.log_stats['acc'] / self.log_every
|
107 |
-
|
108 |
self.log_stats['loss'] = 0.0
|
109 |
self.log_stats['acc'] = 0.0
|
110 |
|
@@ -159,8 +160,6 @@ def finetune_model(lang, task, tempdir_path, log_every, max_epoch, scheduler, wa
|
|
159 |
finetune_config['scheduler_conf']['warmup_steps'] = warmup_steps
|
160 |
finetune_config['multiple_iterator'] = False
|
161 |
finetune_config['num_iters_per_epoch'] = None
|
162 |
-
finetune_config['multiprocessing_distributed'] = False
|
163 |
-
finetune_config['distributed'] = False
|
164 |
|
165 |
def build_model_fn(args):
|
166 |
model, _ = S2TTask.build_model_from_file(
|
@@ -169,8 +168,8 @@ def finetune_model(lang, task, tempdir_path, log_every, max_epoch, scheduler, wa
|
|
169 |
device="cuda" if torch.cuda.is_available() else "cpu",
|
170 |
)
|
171 |
model.train()
|
172 |
-
|
173 |
-
model = CustomFinetuneModel(model, log_every=log_every)
|
174 |
return model
|
175 |
|
176 |
trainer = ez.Trainer(
|
|
|
61 |
|
62 |
|
63 |
class CustomFinetuneModel(ESPnetS2TModel):
|
64 |
+
def __init__(self, model, tempdir_path, log_every=500):
|
65 |
super().__init__(
|
66 |
vocab_size=model.vocab_size,
|
67 |
token_list=model.token_list,
|
|
|
94 |
'loss': 0.0,
|
95 |
'acc': 0.0
|
96 |
}
|
97 |
+
self.tempdir_path = tempdir_path
|
98 |
|
99 |
def forward(self, *args, **kwargs):
|
100 |
out = super().forward(*args, **kwargs)
|
|
|
105 |
if self.iter_count % self.log_every == 0:
|
106 |
loss = self.log_stats['loss'] / self.log_every
|
107 |
acc = self.log_stats['acc'] / self.log_every
|
108 |
+
log(self.tempdir_path, f"[{self.iter_count}] - loss: {loss:.3f} - acc: {acc:.3f}")
|
109 |
self.log_stats['loss'] = 0.0
|
110 |
self.log_stats['acc'] = 0.0
|
111 |
|
|
|
160 |
finetune_config['scheduler_conf']['warmup_steps'] = warmup_steps
|
161 |
finetune_config['multiple_iterator'] = False
|
162 |
finetune_config['num_iters_per_epoch'] = None
|
|
|
|
|
163 |
|
164 |
def build_model_fn(args):
|
165 |
model, _ = S2TTask.build_model_from_file(
|
|
|
168 |
device="cuda" if torch.cuda.is_available() else "cpu",
|
169 |
)
|
170 |
model.train()
|
171 |
+
log(tempdir_path, f'Trainable parameters: {count_parameters(model)}')
|
172 |
+
model = CustomFinetuneModel(model, tempdir_path, log_every=log_every)
|
173 |
return model
|
174 |
|
175 |
trainer = ez.Trainer(
|