File size: 938 Bytes
3a010aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import os


def prelude(
    fairseq_code="/usr/local/lib/python3.10/site-packages/fairseq/models/model_utils.py",
):
    """Prepare the process environment before torch / fairseq are imported.

    Sets ``PYTORCH_JIT=0v`` and, if the installed fairseq module exists,
    patches the signature of ``expand_2d_or_3d_tensor`` so that its argument
    and return value are explicitly annotated as ``Tensor`` (a workaround for
    ``torch.jit.script``).

    Args:
        fairseq_code: Path to fairseq's ``models/model_utils.py``. Defaults to
            the location used by a python3.10 install under ``/usr/local``.
            If the path does not exist, only the environment variable is set.

    The patch is idempotent. The file is rewritten only when the unpatched
    signature is actually present. The rewrite goes through a temporary file
    plus ``os.replace``, so an interrupted run cannot leave a truncated module
    behind.
    """
    import shutil
    import tempfile

    # "0v" disables TorchScript; the "v" variant also makes torch print a
    # notice. torch reads this when torch.jit is first imported, so this
    # presumably has to run before torch is imported.
    os.environ["PYTORCH_JIT"] = "0v"

    unpatched = "def expand_2d_or_3d_tensor(x, trg_dim: int, padding_idx: int):"
    patched = (
        "def expand_2d_or_3d_tensor(x: Tensor, trg_dim: int, padding_idx: int)"
        " -> Tensor:"
    )
    # NOTE(review): the patched signature assumes `Tensor` is already imported
    # in the target module — confirm against the installed fairseq version.

    if not os.path.exists(fairseq_code):
        return

    # Patch the real file even when fairseq_code is a symlink.
    target = os.path.realpath(fairseq_code)

    # newline="" keeps the file's original line endings on the round trip.
    with open(target, "r", encoding="utf-8", newline="") as f:
        source = f.read()
    if unpatched not in source:
        # Already patched, or a fairseq version without this signature.
        return

    # Swap only the signature text, so indentation, trailing comments and the
    # line ending on that line are preserved.
    new_source = source.replace(unpatched, patched)

    try:
        fd, tmp_path = tempfile.mkstemp(dir=os.path.dirname(target), suffix=".tmp")
    except OSError:
        # Directory not writable (the file itself may still be): fall back to
        # rewriting in place, as the original implementation did.
        with open(target, "w", encoding="utf-8", newline="") as f:
            f.write(new_source)
        return

    try:
        with os.fdopen(fd, "w", encoding="utf-8", newline="") as f:
            f.write(new_source)
        # mkstemp creates the file with mode 0600; keep the module's own mode.
        shutil.copymode(target, tmp_path)
        os.replace(tmp_path, target)
    except BaseException:
        # Do not leave the temp file behind; re-raise the original error.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise