yangapku commited on
Commit
8da9235
1 Parent(s): a5c7dd6

update kernels

Browse files
kernels/cache_autogptq_cuda_256.cpp → cache_autogptq_cuda_256.cpp RENAMED
File without changes
kernels/cache_autogptq_cuda_kernel_256.cu → cache_autogptq_cuda_kernel_256.cu RENAMED
File without changes
kernels/cpp_kernels.py → cpp_kernels.py RENAMED
@@ -50,6 +50,6 @@ def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
50
 
51
  extra_flags = []
52
 
53
- cache_autogptq_cuda_256_sources = ["./kernels/cache_autogptq_cuda_256.cpp",
54
- "./kernels/cache_autogptq_cuda_kernel_256.cu"]
55
  cache_autogptq_cuda_256 = _cpp_extention_load_helper("cache_autogptq_cuda_256", cache_autogptq_cuda_256_sources, extra_flags)
 
50
 
51
  extra_flags = []
52
 
53
+ cache_autogptq_cuda_256_sources = ["./cache_autogptq_cuda_256.cpp",
54
+ "./cache_autogptq_cuda_kernel_256.cu"]
55
  cache_autogptq_cuda_256 = _cpp_extention_load_helper("cache_autogptq_cuda_256", cache_autogptq_cuda_256_sources, extra_flags)
modeling_qwen.py CHANGED
@@ -32,11 +32,6 @@ except ImportError:
32
  rearrange = None
33
  from torch import nn
34
 
35
- try:
36
- from kernels.cpp_kernels import cache_autogptq_cuda_256
37
- except ImportError:
38
- cache_autogptq_cuda_256 = None
39
-
40
  SUPPORT_CUDA = torch.cuda.is_available()
41
  SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported()
42
  SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7
@@ -294,14 +289,21 @@ class QWenAttention(nn.Module):
294
  self.cache_qmax = torch.tensor(torch.iinfo(torch.uint8).max, dtype=cache_dtype)
295
  self.cache_qmin = torch.tensor(torch.iinfo(torch.uint8).min, dtype=cache_dtype)
296
 
 
 
 
 
 
 
 
297
  def _attn(self, query, key, value, registered_causal_mask, attention_mask=None, head_mask=None):
298
  device = query.device
299
  if self.use_cache_quantization:
300
  qk, qk_scale, qk_zero = key
301
- if self.use_cache_kernel and cache_autogptq_cuda_256 is not None:
302
  shape = query.shape[:-1] + (qk.shape[-2],)
303
  attn_weights = torch.zeros(shape, dtype=torch.float16, device=device)
304
- cache_autogptq_cuda_256.vecquant8matmul_batched_faster_old(
305
  query.contiguous() if query.dtype == torch.float16 else query.to(torch.float16).contiguous(),
306
  qk.transpose(-1, -2).contiguous(),
307
  attn_weights,
@@ -353,10 +355,10 @@ class QWenAttention(nn.Module):
353
 
354
  if self.use_cache_quantization:
355
  qv, qv_scale, qv_zero = value
356
- if self.use_cache_kernel and cache_autogptq_cuda_256 is not None:
357
  shape = attn_weights.shape[:-1] + (query.shape[-1],)
358
  attn_output = torch.zeros(shape, dtype=torch.float16, device=device)
359
- cache_autogptq_cuda_256.vecquant8matmul_batched_column_compression_faster_old(
360
  attn_weights.contiguous() if attn_weights.dtype == torch.float16 else attn_weights.to(torch.float16).contiguous(),
361
  qv.contiguous(), # dtype: int32
362
  attn_output,
@@ -1022,15 +1024,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
1022
  if config.use_flash_attn:
1023
  _import_flash_attn()
1024
 
1025
-
1026
- if hasattr(config, 'use_cache_quantization') and config.use_cache_quantization:
1027
- config.use_flash_attn = False
1028
- if hasattr(config, 'use_cache_kernel') and config.use_cache_kernel:
1029
- try:
1030
- from kernels.cpp_kernels import cache_autogptq_cuda_256
1031
- except ImportError:
1032
- cache_autogptq_cuda_256 = None
1033
-
1034
  self.transformer = QWenModel(config)
1035
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1036
 
 
32
  rearrange = None
33
  from torch import nn
34
 
 
 
 
 
 
35
  SUPPORT_CUDA = torch.cuda.is_available()
36
  SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported()
37
  SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7
 
289
  self.cache_qmax = torch.tensor(torch.iinfo(torch.uint8).max, dtype=cache_dtype)
290
  self.cache_qmin = torch.tensor(torch.iinfo(torch.uint8).min, dtype=cache_dtype)
291
 
292
+ if config.use_cache_quantization and config.use_cache_kernel:
293
+ from .cpp_kernels import cache_autogptq_cuda_256
294
+ try:
295
+ self.cache_kernels = cache_autogptq_cuda_256
296
+ except ImportError:
297
+ self.cache_kernels = None
298
+
299
  def _attn(self, query, key, value, registered_causal_mask, attention_mask=None, head_mask=None):
300
  device = query.device
301
  if self.use_cache_quantization:
302
  qk, qk_scale, qk_zero = key
303
+ if self.use_cache_kernel and self.cache_kernels is not None:
304
  shape = query.shape[:-1] + (qk.shape[-2],)
305
  attn_weights = torch.zeros(shape, dtype=torch.float16, device=device)
306
+ self.cache_kernels.vecquant8matmul_batched_faster_old(
307
  query.contiguous() if query.dtype == torch.float16 else query.to(torch.float16).contiguous(),
308
  qk.transpose(-1, -2).contiguous(),
309
  attn_weights,
 
355
 
356
  if self.use_cache_quantization:
357
  qv, qv_scale, qv_zero = value
358
+ if self.use_cache_kernel and self.cache_kernels is not None:
359
  shape = attn_weights.shape[:-1] + (query.shape[-1],)
360
  attn_output = torch.zeros(shape, dtype=torch.float16, device=device)
361
+ self.cache_kernels.vecquant8matmul_batched_column_compression_faster_old(
362
  attn_weights.contiguous() if attn_weights.dtype == torch.float16 else attn_weights.to(torch.float16).contiguous(),
363
  qv.contiguous(), # dtype: int32
364
  attn_output,
 
1024
  if config.use_flash_attn:
1025
  _import_flash_attn()
1026
 
 
 
 
 
 
 
 
 
 
1027
  self.transformer = QWenModel(config)
1028
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1029