Spaces:
Runtime error
Runtime error
Hugo Flores Garcia
commited on
Commit
·
8ad07e8
1
Parent(s):
1f1ec83
fix temp again
Browse files- scripts/exp/fine_tune.py +1 -0
- vampnet/interface.py +1 -1
- vampnet/modules/transformer.py +2 -2
scripts/exp/fine_tune.py
CHANGED
@@ -53,6 +53,7 @@ def fine_tune(audio_files_or_folders: List[str], name: str):
|
|
53 |
|
54 |
"Interface.coarse2fine_ckpt": f"./models/vampnet/c2f.pth",
|
55 |
"Interface.coarse2fine_lora_ckpt": f"./runs/{name}/c2f/latest/lora.pth",
|
|
|
56 |
|
57 |
"Interface.codec_ckpt": "./models/vampnet/codec.pth",
|
58 |
"AudioLoader.sources": [audio_files_or_folders],
|
|
|
53 |
|
54 |
"Interface.coarse2fine_ckpt": f"./models/vampnet/c2f.pth",
|
55 |
"Interface.coarse2fine_lora_ckpt": f"./runs/{name}/c2f/latest/lora.pth",
|
56 |
+
"Interface.wavebeat_ckpt": "./models/wavebeat.pth",
|
57 |
|
58 |
"Interface.codec_ckpt": "./models/vampnet/codec.pth",
|
59 |
"AudioLoader.sources": [audio_files_or_folders],
|
vampnet/interface.py
CHANGED
@@ -65,7 +65,7 @@ class Interface(torch.nn.Module):
|
|
65 |
):
|
66 |
super().__init__()
|
67 |
assert codec_ckpt is not None, "must provide a codec checkpoint"
|
68 |
-
self.codec = DAC.load(codec_ckpt)
|
69 |
self.codec.eval()
|
70 |
self.codec.to(device)
|
71 |
|
|
|
65 |
):
|
66 |
super().__init__()
|
67 |
assert codec_ckpt is not None, "must provide a codec checkpoint"
|
68 |
+
self.codec = DAC.load(Path(codec_ckpt))
|
69 |
self.codec.eval()
|
70 |
self.codec.to(device)
|
71 |
|
vampnet/modules/transformer.py
CHANGED
@@ -581,7 +581,7 @@ class VampNet(at.ml.BaseModel):
|
|
581 |
sampling_steps: int = 24,
|
582 |
start_tokens: Optional[torch.Tensor] = None,
|
583 |
mask: Optional[torch.Tensor] = None,
|
584 |
-
temperature:
|
585 |
typical_filtering=False,
|
586 |
typical_mass=0.2,
|
587 |
typical_min_tokens=1,
|
@@ -592,7 +592,7 @@ class VampNet(at.ml.BaseModel):
|
|
592 |
#####################
|
593 |
# resolve temperature #
|
594 |
#####################
|
595 |
-
|
596 |
logging.debug(f"temperature: {temperature}")
|
597 |
|
598 |
|
|
|
581 |
sampling_steps: int = 24,
|
582 |
start_tokens: Optional[torch.Tensor] = None,
|
583 |
mask: Optional[torch.Tensor] = None,
|
584 |
+
temperature: float = 2.5,
|
585 |
typical_filtering=False,
|
586 |
typical_mass=0.2,
|
587 |
typical_min_tokens=1,
|
|
|
592 |
#####################
|
593 |
# resolve temperature #
|
594 |
#####################
|
595 |
+
|
596 |
logging.debug(f"temperature: {temperature}")
|
597 |
|
598 |
|