|
#pragma once

#include <hip/hip_runtime.h>
#include <hipblas/hipblas.h>
#include <hip/hip_fp16.h>

// needed below for int8_t/uint8_t and std::numeric_limits
#include <cstdint>
#include <limits>

#ifdef __HIP_PLATFORM_AMD__
#include "rocblas/rocblas.h"
#endif // __HIP_PLATFORM_AMD__
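
// Map CUDA names to their HIP/hipBLAS equivalents so that CUDA-oriented
// source compiles unchanged on AMD. The cuBLAS compute-type constants alias
// plain hipBLAS data types because hipblasGemmEx (pre-compute-type API)
// takes its compute type as a hipblasDatatype_t.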
|
#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
#define CUBLAS_OP_N HIPBLAS_OP_N
#define CUBLAS_OP_T HIPBLAS_OP_T
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
// any value works here since cublasSetMathMode is stubbed out below
#define CUBLAS_TF32_TENSOR_OP_MATH 0
#define CUDA_R_16F HIPBLAS_R_16F
#define CUDA_R_32F HIPBLAS_R_32F
// the sync mask has no HIP equivalent; all lanes of the wavefront participate
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
#define cublasComputeType_t hipblasDatatype_t
#define cublasCreate hipblasCreate
#define cublasDestroy hipblasDestroy
#define cublasGemmEx hipblasGemmEx
#define cublasGemmBatchedEx hipblasGemmBatchedEx
#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
#define cublasHandle_t hipblasHandle_t
// hipBLAS has no equivalent math-mode API, so this is a successful no-op
#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
#define cublasSetStream hipblasSetStream
#define cublasSgemm hipblasSgemm
#define cublasStatus_t hipblasStatus_t
#define cublasOperation_t hipblasOperation_t
#define cudaDataType_t hipblasDatatype_t
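
// Illustrative: a cuBLAS-style mixed-precision GEMM routes to hipblasGemmEx
// unchanged (handle, sizes, and pointers are hypothetical):
//
//   cublasGemmEx(handle, CUBLAS_OP_T, CUBLAS_OP_N, m, n, k,
//                &alpha, A, CUDA_R_16F, lda,
//                        B, CUDA_R_16F, ldb,
//                &beta,  C, CUDA_R_32F, ldc,
//                CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT);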
|
#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
#define cudaDeviceProp hipDeviceProp_t
#define cudaDeviceSynchronize hipDeviceSynchronize
#define cudaError_t hipError_t
#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
#define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled
#define cudaEventCreateWithFlags hipEventCreateWithFlags
#define cudaEventDisableTiming hipEventDisableTiming
#define cudaEventRecord hipEventRecord
#define cudaEventSynchronize hipEventSynchronize
#define cudaEvent_t hipEvent_t
#define cudaEventDestroy hipEventDestroy
#define cudaFree hipFree
#define cudaFreeHost hipHostFree
#define cudaGetDevice hipGetDevice
#define cudaGetDeviceCount hipGetDeviceCount
#define cudaGetDeviceProperties hipGetDeviceProperties
#define cudaGetErrorString hipGetErrorString
#define cudaGetLastError hipGetLastError
#define cudaHostRegister hipHostRegister
#define cudaHostRegisterPortable hipHostRegisterPortable
#define cudaHostRegisterReadOnly hipHostRegisterReadOnly
#define cudaHostUnregister hipHostUnregister
#define cudaLaunchHostFunc hipLaunchHostFunc
#define cudaMalloc hipMalloc
#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
#define cudaMemcpy hipMemcpy
#define cudaMemcpyAsync hipMemcpyAsync
#define cudaMemcpyPeerAsync hipMemcpyPeerAsync
#define cudaMemcpy2DAsync hipMemcpy2DAsync
#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
#define cudaMemcpyKind hipMemcpyKind
#define cudaMemset hipMemset
#define cudaMemsetAsync hipMemsetAsync
#define cudaMemGetInfo hipMemGetInfo
#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
#define cudaSetDevice hipSetDevice
#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
#define cudaStreamDestroy hipStreamDestroy
#define cudaStreamFireAndForget hipStreamFireAndForget
#define cudaStreamNonBlocking hipStreamNonBlocking
#define cudaStreamPerThread hipStreamPerThread
#define cudaStreamSynchronize hipStreamSynchronize
#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
#define cudaStream_t hipStream_t
#define cudaSuccess hipSuccess

// HIP device code has no __trap(); abort() plus unreachable is the closest equivalent
#define __trap() do { abort(); __builtin_unreachable(); } while(0)

// the cuBLAS status codes map one-to-one onto hipBLAS status codes
// (CUBLAS_STATUS_SUCCESS is already defined in the constant block above)
#define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
#define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED
#define CUBLAS_STATUS_INVALID_VALUE HIPBLAS_STATUS_INVALID_VALUE
#define CUBLAS_STATUS_ARCH_MISMATCH HIPBLAS_STATUS_ARCH_MISMATCH
#define CUBLAS_STATUS_MAPPING_ERROR HIPBLAS_STATUS_MAPPING_ERROR
#define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED
#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
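
// Illustrative: with this header included, CUDA-style host code resolves to
// HIP calls (buffer name and sizes are hypothetical):
//
//   float * buf = nullptr;
//   cudaMalloc(&buf, 1024*sizeof(float));            // -> hipMalloc
//   cudaMemsetAsync(buf, 0, 1024*sizeof(float), 0);  // -> hipMemsetAsync
//   cudaStreamSynchronize(0);                        // -> hipStreamSynchronize
//   cudaFree(buf);                                   // -> hipFree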

// HIP does not define __CUDA_ARCH__; use a high value so that the
// __CUDA_ARCH__ guards in shared kernel code are taken
#define __CUDA_ARCH__ 1300

#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
    defined(__gfx1150__) || defined(__gfx1151__)
#define RDNA3
#endif

#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \
    defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)
#define RDNA2
#endif

#if defined(__gfx1010__) || defined(__gfx1012__)
#define RDNA1
#endif
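
// Illustrative: downstream kernels can specialize on these generation
// markers, e.g. (hypothetical tuning branch):
//
//   #if defined(RDNA2) || defined(RDNA3)
//       // RDNA-tuned tile sizes / wave32 path
//   #else
//       // defaults for CDNA and older GCN targets
//   #endif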

#ifndef __has_builtin
#define __has_builtin(x) 0
#endif

typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));

// emulation of the CUDA SIMD intrinsic: per-byte saturating subtraction of
// two int8x4 values packed into 32-bit integers
static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
    const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
    const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
#if __has_builtin(__builtin_elementwise_sub_sat)
    const int8x4_t c = __builtin_elementwise_sub_sat(va, vb);
    return reinterpret_cast<const int &>(c);
#else
    int8x4_t c;
    int16_t tmp;
#pragma unroll
    for (int i = 0; i < 4; i++) {
        tmp = va[i] - vb[i];
        if (tmp > std::numeric_limits<int8_t>::max()) tmp = std::numeric_limits<int8_t>::max();
        if (tmp < std::numeric_limits<int8_t>::min()) tmp = std::numeric_limits<int8_t>::min();
        c[i] = tmp;
    }
    return reinterpret_cast<int &>(c);
#endif // __has_builtin(__builtin_elementwise_sub_sat)
}

// note: unlike CUDA's __vsub4, this fallback saturates instead of wrapping
// on overflow
static __device__ __forceinline__ int __vsub4(const int a, const int b) {
    return __vsubss4(a, b);
}
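
// Worked example for the saturating path: for a byte lane holding 0x80 (-128)
// in a and 0x01 in b, -128 - 1 = -129 clamps to -128 (0x80); without
// saturation the wrapped int8 result would be +127 (0x7F).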

// per-byte equality compare: each result byte is 0xff where the input bytes
// match and 0x00 where they differ
static __device__ __forceinline__ unsigned int __vcmpeq4(unsigned int a, unsigned int b) {
    const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
    const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
    unsigned int c;
    uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
#pragma unroll
    for (int i = 0; i < 4; ++i) {
        vc[i] = va[i] == vb[i] ? 0xff : 0x00;
    }
    return c;
}

// per-byte inequality compare: the complement of __vcmpeq4
static __device__ __forceinline__ unsigned int __vcmpne4(unsigned int a, unsigned int b) {
    const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
    const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
    unsigned int c;
    uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
#pragma unroll
    for (int i = 0; i < 4; ++i) {
        vc[i] = va[i] == vb[i] ? 0x00 : 0xff;
    }
    return c;
}
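
// Worked example: __vcmpeq4(0x01020304, 0x01FF0304) yields 0xFF00FFFF; only
// the lane that differs (0x02 vs 0xFF) is 0x00. __vcmpne4 on the same inputs
// yields the complement, 0x00FF0000.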

#if defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000
// HIP versions before 5.6 lack a half2 overload of __shfl_xor; emulate it by
// shuffling the raw 32 bits through a union
static __device__ __forceinline__ half2 __shfl_xor(half2 var, int laneMask, int width) {
    typedef union half2_b32 {
        half2 val;
        int   b32;
    } half2_b32_t;
    half2_b32_t tmp;
    tmp.val = var;
    tmp.b32 = __shfl_xor(tmp.b32, laneMask, width);
    return tmp.val;
}
#endif // defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000
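
// Illustrative: a CUDA-style butterfly reduction compiles through the
// __shfl_xor_sync macro above (warp_reduce_sum is a hypothetical helper,
// not part of this header):
//
//   static __device__ float warp_reduce_sum(float x) {
//   #pragma unroll
//       for (int offset = 16; offset > 0; offset >>= 1) {
//           x += __shfl_xor_sync(0xffffffff, x, offset, 32);
//       }
//       return x;
//   }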
|
|