ms180 committed
Commit 2f2a5dd
1 Parent(s): 57b1f30

Update finetune.py

Files changed (1)
  1. finetune.py +5 -6
finetune.py CHANGED
@@ -61,7 +61,7 @@ def get_dataset(data_path, data_info, test_count=10):
 
 
 class CustomFinetuneModel(ESPnetS2TModel):
-    def __init__(self, model, log_every=500):
+    def __init__(self, model, tempdir_path, log_every=500):
         super().__init__(
             vocab_size=model.vocab_size,
             token_list=model.token_list,
@@ -94,6 +94,7 @@ class CustomFinetuneModel(ESPnetS2TModel):
             'loss': 0.0,
             'acc': 0.0
         }
+        self.tempdir_path = tempdir_path
 
     def forward(self, *args, **kwargs):
         out = super().forward(*args, **kwargs)
@@ -104,7 +105,7 @@ class CustomFinetuneModel(ESPnetS2TModel):
         if self.iter_count % self.log_every == 0:
             loss = self.log_stats['loss'] / self.log_every
             acc = self.log_stats['acc'] / self.log_every
-            print(f"[{self.iter_count}] - loss: {loss:.3f} - acc: {acc:.3f}")
+            log(self.tempdir_path, f"[{self.iter_count}] - loss: {loss:.3f} - acc: {acc:.3f}")
             self.log_stats['loss'] = 0.0
             self.log_stats['acc'] = 0.0
 
@@ -159,8 +160,6 @@ def finetune_model(lang, task, tempdir_path, log_every, max_epoch, scheduler, wa
     finetune_config['scheduler_conf']['warmup_steps'] = warmup_steps
     finetune_config['multiple_iterator'] = False
     finetune_config['num_iters_per_epoch'] = None
-    finetune_config['multiprocessing_distributed'] = False
-    finetune_config['distributed'] = False
 
     def build_model_fn(args):
         model, _ = S2TTask.build_model_from_file(
@@ -169,8 +168,8 @@ def finetune_model(lang, task, tempdir_path, log_every, max_epoch, scheduler, wa
             device="cuda" if torch.cuda.is_available() else "cpu",
         )
         model.train()
-        print(f'Trainable parameters: {count_parameters(model)}')
-        model = CustomFinetuneModel(model, log_every=log_every)
+        log(tempdir_path, f'Trainable parameters: {count_parameters(model)}')
+        model = CustomFinetuneModel(model, tempdir_path, log_every=log_every)
         return model
 
     trainer = ez.Trainer(
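
Note: this commit replaces the bare print calls with a log(tempdir_path, message) helper that is referenced but not shown in the diff. A minimal sketch of what such a helper could look like, assuming it appends each message to a text file inside the temporary working directory (the finetune.log file name and the timestamp format are assumptions, not taken from this commit):

import os
from datetime import datetime

def log(tempdir_path, message):
    # Hypothetical sketch: the real helper is defined elsewhere in finetune.py.
    # Append the message to a log file inside the temp directory and echo it
    # to stdout, so training progress is also preserved on disk.
    log_path = os.path.join(tempdir_path, "finetune.log")  # assumed file name
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    with open(log_path, "a") as f:
        f.write(f"[{timestamp}] {message}\n")
    print(message)

Threading tempdir_path through CustomFinetuneModel.__init__ rather than relying on a global is what motivates the extra constructor argument in this diff: the wrapper can then log from forward() without knowing anything about the surrounding finetune_model call.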