robinzixuan
committed on
Update modeling_opt.py
modeling_opt.py +3 -2
CHANGED
@@ -41,6 +41,7 @@ from transformers.utils import (
     is_flash_attn_greater_or_equal_2_10,
     logging,
     replace_return_docstrings,
+
 )
 from .configuration_opt import OPTConfig
 
@@ -294,8 +295,8 @@ class OPTAttention(nn.Module):
 
         if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
             raise ValueError(
-                f"`attn_output` should be of size {
-                (bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+                f'''`attn_output` should be of size {
+                (bsz, self.num_heads, tgt_len, self.head_dim)}, but is'''
                 f" {attn_output.size()}"
             )
 
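Note on the f-string change: switching to a triple-quoted f-string lets the replacement field span two lines, and because the newline falls inside the `{...}` braces it is parsed as expression whitespace rather than literal text, so the rendered error message is unchanged. A minimal standalone sketch, using made-up shape values that are not from the commit:

# Sketch only: bsz/num_heads/tgt_len/head_dim values are hypothetical.
bsz, num_heads, tgt_len, head_dim = 2, 12, 8, 64

# Single-line form of the message.
single = f"`attn_output` should be of size {(bsz, num_heads, tgt_len, head_dim)}, but is"

# Triple-quoted form as in the diff: the line break sits inside the
# braces, so it is part of the expression, not of the string.
triple = f'''`attn_output` should be of size {
(bsz, num_heads, tgt_len, head_dim)}, but is'''

assert single == triple
print(triple)  # `attn_output` should be of size (2, 12, 8, 64), but is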