ZeroRVC / prelude.py
JacobLinCool's picture
feat: infer
3a010aa
raw
history blame
938 Bytes
import os
# Default location of the fairseq file that prelude() patches; hard-coded to the
# Python 3.10 site-packages layout this project targets.
_FAIRSEQ_MODEL_UTILS = (
    "/usr/local/lib/python3.10/site-packages/fairseq/models/model_utils.py"
)
# Unannotated signature as shipped by fairseq, and its Tensor-annotated
# replacement (so TorchScript does not have to infer the type of `x`).
_UNTYPED_SIGNATURE = "def expand_2d_or_3d_tensor(x, trg_dim: int, padding_idx: int):"
_TYPED_SIGNATURE = (
    "def expand_2d_or_3d_tensor(x: Tensor, trg_dim: int, padding_idx: int) -> Tensor:"
)


def prelude(*, model_utils_path: str = _FAIRSEQ_MODEL_UTILS) -> None:
    """Prepare the process environment before torch / fairseq are used.

    1. Sets ``PYTORCH_JIT`` to ``"0v"`` in ``os.environ``.
    2. If *model_utils_path* is an existing file, replaces every occurrence of
       the untyped ``expand_2d_or_3d_tensor`` signature with a
       ``Tensor``-annotated one. The file is written only when it actually
       contains the untyped signature. Repeated calls are therefore no-ops, and
       an already-patched (or read-only) install is left untouched.

    Args:
        model_utils_path: Path of the fairseq ``model_utils.py`` to patch.
            Defaults to the hard-coded Python 3.10 site-packages location.

    Raises:
        OSError: If the file needs patching but cannot be read or written.
    """
    # NOTE(review): "0v" is intentional. PyTorch appears to accept it as a
    # verbose "off" (JIT disabled, notice printed). The variable is read when
    # torch.jit is imported, so this presumably must run before importing
    # torch — confirm at call sites.
    os.environ["PYTORCH_JIT"] = "0v"

    if not os.path.isfile(model_utils_path):
        return

    # newline="" round-trips the file's own line endings unchanged; Python
    # source files are UTF-8 by default (PEP 3120), independent of locale.
    with open(model_utils_path, "r", encoding="utf-8", newline="") as f:
        source = f.read()

    # Substring replacement keeps any indentation / trailing text on the line.
    patched = source.replace(_UNTYPED_SIGNATURE, _TYPED_SIGNATURE)
    if patched == source:
        # Already patched or signature absent: don't rewrite installed code.
        return

    with open(model_utils_path, "w", encoding="utf-8", newline="") as f:
        f.write(patched)