reed
commited on
Commit
•
0e24941
1
Parent(s):
49d999c
use packaging.version
Browse files- modeling_yi.py +4 -1
modeling_yi.py
CHANGED
@@ -4,6 +4,7 @@ from typing import List, Optional, Tuple, Union
|
|
4 |
|
5 |
import torch.utils.checkpoint
|
6 |
from einops import repeat
|
|
|
7 |
from torch import nn
|
8 |
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
9 |
from transformers.activations import ACT2FN
|
@@ -32,7 +33,9 @@ except Exception:
|
|
32 |
if is_flash_attn_available:
|
33 |
from flash_attn import __version__
|
34 |
|
35 |
-
assert __version__ >=
|
|
|
|
|
36 |
|
37 |
logger = logging.get_logger(__name__)
|
38 |
|
|
|
4 |
|
5 |
import torch.utils.checkpoint
|
6 |
from einops import repeat
|
7 |
+
from packaging import version
|
8 |
from torch import nn
|
9 |
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
10 |
from transformers.activations import ACT2FN
|
|
|
33 |
if is_flash_attn_available:
|
34 |
from flash_attn import __version__
|
35 |
|
36 |
+
assert version.parse(__version__) >= version.parse(
|
37 |
+
"2.3.0"
|
38 |
+
), "please update your flash_attn version (>= 2.3.0)"
|
39 |
|
40 |
logger = logging.get_logger(__name__)
|
41 |
|