KaleiNeely commited on
Commit
373ffcc
·
verified ·
1 Parent(s): b5269d8

Update modeling_rwkv5.py

Browse files
Files changed (1) hide show
  1. modeling_rwkv5.py +17 -158
modeling_rwkv5.py CHANGED
@@ -18,6 +18,7 @@ from dataclasses import dataclass
18
  from pathlib import Path
19
  from typing import List, Optional, Tuple, Union
20
 
 
21
  import torch
22
  import torch.nn.functional as F
23
  import torch.utils.checkpoint
@@ -36,6 +37,19 @@ from transformers.utils import (
36
  logging,
37
  )
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  from .configuration_rwkv5 import Rwkv5Config
40
 
41
 
@@ -44,155 +58,6 @@ logger = logging.get_logger(__name__)
44
  _CHECKPOINT_FOR_DOC = "RWKV/rwkv-5-world-1b5"
45
  _CONFIG_FOR_DOC = "Rwkv5Config"
46
 
47
- rwkv5_cuda_kernel = None
48
-
49
-
50
- # Copied from https://github.com/huggingface/transformers/blob/18cbaf13dcaca7145f5652aefb9b19734c56c3cd/src/transformers/models/rwkv/modeling_rwkv.py#L65
51
- def load_wkv5_cuda_kernel(head_size):
52
- from torch.utils.cpp_extension import load as load_kernel
53
-
54
- global rwkv5_cuda_kernel
55
-
56
- kernel_folder = Path(__file__).parent.resolve()
57
- cuda_kernel_files = [kernel_folder / f for f in ["wkv5_op.cpp", "wkv5_cuda.cu"]]
58
-
59
- # Only load the kernel if it's not been loaded yet or if we changed the context length
60
- if rwkv5_cuda_kernel is not None and rwkv5_cuda_kernel.head_size == head_size:
61
- return
62
-
63
- logger.info(f"Loading CUDA kernel for RWKV5 at head size of {head_size}.")
64
-
65
- flags = [
66
- "-res-usage",
67
- "--maxrregcount 60",
68
- "--use_fast_math",
69
- "-O3",
70
- "-Xptxas -O3",
71
- "--extra-device-vectorization",
72
- f"-D_N_={head_size}",
73
- ]
74
- rwkv5_cuda_kernel = load_kernel(
75
- name=f"wkv_{head_size}",
76
- sources=cuda_kernel_files,
77
- verbose=(logging.get_verbosity() == logging.DEBUG),
78
- extra_cuda_cflags=flags,
79
- )
80
- rwkv5_cuda_kernel.head_size = head_size
81
-
82
-
83
- class Rwkv5LinearAttention(torch.autograd.Function):
84
- @staticmethod
85
- def forward(ctx, receptance, key, value, time_decay, time_first, state):
86
- with torch.no_grad():
87
- assert receptance.dtype == torch.bfloat16
88
- assert key.dtype == torch.bfloat16
89
- assert value.dtype == torch.bfloat16
90
- assert time_decay.dtype == torch.bfloat16
91
- assert time_first.dtype == torch.bfloat16
92
- assert state.dtype == torch.float32
93
- batch, seq_length, hidden_size = key.shape
94
- num_heads = time_decay.shape[0]
95
- ctx.batch = batch
96
- ctx.seq_length = seq_length
97
- ctx.hidden_size = hidden_size
98
- ctx.num_heads = num_heads
99
- e_time_decay = (-torch.exp(time_decay.float())).contiguous()
100
- ee_time_decay = (torch.exp(e_time_decay)).contiguous()
101
- assert ee_time_decay.dtype == torch.float32
102
- ctx.save_for_backward(receptance, key, value, ee_time_decay, e_time_decay, time_first)
103
- out = torch.empty(
104
- (batch, seq_length, hidden_size),
105
- device=receptance.device,
106
- dtype=torch.bfloat16,
107
- memory_format=torch.contiguous_format,
108
- )
109
- state = state.clone()
110
- rwkv5_cuda_kernel.forward_bf16(
111
- batch,
112
- seq_length,
113
- hidden_size,
114
- num_heads,
115
- state,
116
- receptance,
117
- key,
118
- value,
119
- ee_time_decay,
120
- time_first,
121
- out,
122
- )
123
- return out, state
124
-
125
- @staticmethod
126
- def backward(ctx, gout):
127
- with torch.no_grad():
128
- assert gout.dtype == torch.bfloat16
129
- batch = ctx.batch
130
- seq_length = ctx.seq_length
131
- hidden_size = ctx.hidden_size
132
- num_heads = ctx.num_heads
133
- receptance, key, value, ee_time_decay, e_time_decay, time_first = ctx.saved_tensors
134
-
135
- global_shape = (batch, seq_length, hidden_size)
136
-
137
- # TODO dtype should not be forced here IMO
138
- greceptance = torch.empty(
139
- global_shape,
140
- device=gout.device,
141
- requires_grad=False,
142
- dtype=torch.bfloat16,
143
- memory_format=torch.contiguous_format,
144
- )
145
- g_key = torch.empty(
146
- global_shape,
147
- device=gout.device,
148
- requires_grad=False,
149
- dtype=torch.bfloat16,
150
- memory_format=torch.contiguous_format,
151
- )
152
- g_value = torch.empty(
153
- global_shape,
154
- device=gout.device,
155
- requires_grad=False,
156
- dtype=torch.bfloat16,
157
- memory_format=torch.contiguous_format,
158
- )
159
- g_time_decay = torch.empty(
160
- (batch, hidden_size),
161
- device=gout.device,
162
- requires_grad=False,
163
- dtype=torch.bfloat16,
164
- memory_format=torch.contiguous_format,
165
- )
166
- g_time_first = torch.empty(
167
- (batch, hidden_size),
168
- device=gout.device,
169
- requires_grad=False,
170
- dtype=torch.bfloat16,
171
- memory_format=torch.contiguous_format,
172
- )
173
- rwkv5_cuda_kernel.backward_bf16(
174
- batch,
175
- seq_length,
176
- hidden_size,
177
- num_heads,
178
- receptance,
179
- key,
180
- value,
181
- ee_time_decay,
182
- e_time_decay,
183
- time_first,
184
- gout,
185
- greceptance,
186
- g_key,
187
- g_value,
188
- g_time_decay,
189
- g_time_first,
190
- )
191
- head_size = hidden_size // num_heads
192
- g_time_decay = torch.sum(g_time_decay, 0).view(num_heads, head_size)
193
- g_time_first = torch.sum(g_time_first, 0).view(num_heads, head_size)
194
- return (None, None, None, None, greceptance, g_key, g_value, g_time_decay, g_time_first)
195
-
196
 
197
  def rwkv5_linear_attention_cpu(receptance, key, value, time_decay, time_first, state):
198
  input_dtype = receptance.dtype
@@ -224,24 +89,18 @@ def RWKV5_linear_attention(training, receptance, key, value, time_decay, time_fi
224
  # Launching the CUDA kernel for just one token will actually be slower (there is no for loop in the CPU version
225
  # in this case).
226
  one_token = key.size(1) == 1
227
- if not training or rwkv5_cuda_kernel is None or no_cuda or one_token:
228
  return rwkv5_linear_attention_cpu(
229
  receptance, key, value, time_decay, time_first, state
230
  )
231
  else:
232
- return Rwkv5LinearAttention.apply(receptance, key, value, time_decay, time_first, state)
233
 
234
 
235
  class Rwkv5SelfAttention(nn.Module):
236
  def __init__(self, config, layer_id=0):
237
  super().__init__()
238
  self.config = config
239
- kernel_loaded = rwkv5_cuda_kernel is not None and rwkv5_cuda_kernel.head_size == config.head_size
240
- if is_ninja_available() and is_torch_cuda_available() and not kernel_loaded:
241
- try:
242
- load_wkv5_cuda_kernel(config.head_size)
243
- except Exception:
244
- logger.info("Could not load the custom CUDA kernel for RWKV5 attention.")
245
  self.layer_id = layer_id
246
  hidden_size = config.hidden_size
247
  attention_hidden_size = config.attention_hidden_size
@@ -311,7 +170,7 @@ class Rwkv5SelfAttention(nn.Module):
311
  out = self.output(out)
312
  return out, state
313
 
314
- # Copied from rwkv exceot for the intermediate size
315
  class Rwkv5FeedForward(nn.Module):
316
  def __init__(self, config, layer_id=0):
317
  super().__init__()
 
18
  from pathlib import Path
19
  from typing import List, Optional, Tuple, Union
20
 
21
+ import pkg_resources
22
  import torch
23
  import torch.nn.functional as F
24
  import torch.utils.checkpoint
 
37
  logging,
38
  )
39
 
40
+ try:
41
+ from flash_rwkv import rwkv5_cuda_linear_attention
42
+ # Check version
43
+ required_version = pkg_resources.parse_version("0.2.1")
44
+ current_version = pkg_resources.get_distribution("flash-rwkv").parsed_version
45
+
46
+ if current_version < required_version:
47
+ raise Exception("Your version of flash-rwkv is below 0.2.1. Please use pip install --upgrade flash-rwkv to update or install the required version.")
48
+ except ImportError:
49
+ raise ImportError("The flash-rwkv package is not detected. Please install it using pip install flash-rwkv.")
50
+ except pkg_resources.DistributionNotFound:
51
+ raise ImportError("The flash-rwkv package is not detected. Please install it using pip install flash-rwkv.")
52
+
53
  from .configuration_rwkv5 import Rwkv5Config
54
 
55
 
 
58
  _CHECKPOINT_FOR_DOC = "RWKV/rwkv-5-world-1b5"
59
  _CONFIG_FOR_DOC = "Rwkv5Config"
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
  def rwkv5_linear_attention_cpu(receptance, key, value, time_decay, time_first, state):
63
  input_dtype = receptance.dtype
 
89
  # Launching the CUDA kernel for just one token will actually be slower (there is no for loop in the CPU version
90
  # in this case).
91
  one_token = key.size(1) == 1
92
+ if not training or no_cuda or one_token:
93
  return rwkv5_linear_attention_cpu(
94
  receptance, key, value, time_decay, time_first, state
95
  )
96
  else:
97
+ return rwkv5_cuda_linear_attention(receptance.float(), key.float(), value.float(), time_decay.float().flatten(), time_first.float().flatten(), state)
98
 
99
 
100
  class Rwkv5SelfAttention(nn.Module):
101
  def __init__(self, config, layer_id=0):
102
  super().__init__()
103
  self.config = config
 
 
 
 
 
 
104
  self.layer_id = layer_id
105
  hidden_size = config.hidden_size
106
  attention_hidden_size = config.attention_hidden_size
 
170
  out = self.output(out)
171
  return out, state
172
 
173
+ # Copied from rwkv except for the intermediate size
174
  class Rwkv5FeedForward(nn.Module):
175
  def __init__(self, config, layer_id=0):
176
  super().__init__()