Update model_helper.py
Browse files- model_helper.py +3 -3
model_helper.py
CHANGED
@@ -22,7 +22,7 @@ from model.ymt3 import YourMT3
|
|
22 |
|
23 |
|
24 |
|
25 |
-
def load_model_checkpoint(args=None):
|
26 |
parser = argparse.ArgumentParser(description="YourMT3")
|
27 |
# General
|
28 |
parser.add_argument('exp_id', type=str, help='A unique identifier for the experiment is used to resume training. The "@" symbol can be used to load a specific checkpoint.')
|
@@ -104,7 +104,7 @@ def load_model_checkpoint(args=None):
|
|
104 |
print(f"Task: {tm.task_name}, Max Shift Steps: {tm.max_shift_steps}")
|
105 |
|
106 |
# Use GPU if available
|
107 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
108 |
|
109 |
# Model
|
110 |
model = YourMT3(
|
@@ -120,7 +120,7 @@ def load_model_checkpoint(args=None):
|
|
120 |
state_dict = checkpoint['state_dict']
|
121 |
new_state_dict = {k: v for k, v in state_dict.items() if 'pitchshift' not in k}
|
122 |
model.load_state_dict(new_state_dict, strict=False)
|
123 |
-
return model.eval()
|
124 |
|
125 |
|
126 |
def transcribe(model, audio_info):
|
|
|
22 |
|
23 |
|
24 |
|
25 |
+
def load_model_checkpoint(args=None, device='cpu'):
|
26 |
parser = argparse.ArgumentParser(description="YourMT3")
|
27 |
# General
|
28 |
parser.add_argument('exp_id', type=str, help='A unique identifier for the experiment is used to resume training. The "@" symbol can be used to load a specific checkpoint.')
|
|
|
104 |
print(f"Task: {tm.task_name}, Max Shift Steps: {tm.max_shift_steps}")
|
105 |
|
106 |
# Use GPU if available
|
107 |
+
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
108 |
|
109 |
# Model
|
110 |
model = YourMT3(
|
|
|
120 |
state_dict = checkpoint['state_dict']
|
121 |
new_state_dict = {k: v for k, v in state_dict.items() if 'pitchshift' not in k}
|
122 |
model.load_state_dict(new_state_dict, strict=False)
|
123 |
+
return model.eval() # load checkpoint on cpu first
|
124 |
|
125 |
|
126 |
def transcribe(model, audio_info):
|