reed commited on
Commit
0e24941
1 Parent(s): 49d999c

use packaging.version

Browse files
Files changed (1) hide show
  1. modeling_yi.py +4 -1
modeling_yi.py CHANGED
@@ -4,6 +4,7 @@ from typing import List, Optional, Tuple, Union
4
 
5
  import torch.utils.checkpoint
6
  from einops import repeat
 
7
  from torch import nn
8
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
9
  from transformers.activations import ACT2FN
@@ -32,7 +33,9 @@ except Exception:
32
  if is_flash_attn_available:
33
  from flash_attn import __version__
34
 
35
- assert __version__ >= "2.3.0", "please update your flash_attn version (>= 2.3.0)"
 
 
36
 
37
  logger = logging.get_logger(__name__)
38
 
 
4
 
5
  import torch.utils.checkpoint
6
  from einops import repeat
7
+ from packaging import version
8
  from torch import nn
9
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
10
  from transformers.activations import ACT2FN
 
33
  if is_flash_attn_available:
34
  from flash_attn import __version__
35
 
36
+ assert version.parse(__version__) >= version.parse(
37
+ "2.3.0"
38
+ ), "please update your flash_attn version (>= 2.3.0)"
39
 
40
  logger = logging.get_logger(__name__)
41