I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

@@ -0,0 +1,9 @@
#pragma once
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <c10/macros/Export.h>
// Use TORCH_CUDA_CPP_API or TORCH_CUDA_CU_API for exports from this folder
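// e.g. a function exported from ATen's CUDA sources would be declared roughly as
//   TORCH_CUDA_CPP_API void someExportedFunction();
// (someExportedFunction is a hypothetical name used only for illustration)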

@@ -0,0 +1,47 @@
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
namespace at::cuda {
/**
Computes ceil(a / b)
*/
template <typename T>
__host__ __device__ __forceinline__ T ATenCeilDiv(T a, T b) {
return (a + b - 1) / b;
}
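// e.g. ATenCeilDiv(10, 4) == 3 and ATenCeilDiv(8, 4) == 2.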
namespace {
// Threads per block for our apply kernel
// FIXME: use occupancy calculator instead
constexpr uint32_t AT_APPLY_THREADS_PER_BLOCK = 512;
constexpr uint32_t AT_APPLY_BLOCKS_PER_SM = 4;
template <int step = 1>
inline bool getApplyGrid(uint64_t totalElements, dim3& grid, c10::DeviceIndex curDevice, int max_threads_per_block=AT_APPLY_THREADS_PER_BLOCK) {
if (curDevice == -1) return false;
uint64_t numel_per_thread = static_cast<uint64_t>(max_threads_per_block) * static_cast<uint64_t>(step);
uint64_t numBlocks = ATenCeilDiv(totalElements, numel_per_thread);
uint64_t maxGridX = at::cuda::getDeviceProperties(curDevice)->maxGridSize[0];
if (numBlocks > maxGridX)
numBlocks = maxGridX;
grid = dim3(numBlocks);
return true;
}
constexpr int getApplyBlocksPerSM() {
return AT_APPLY_BLOCKS_PER_SM;
}
constexpr int getApplyBlockSize() {
return AT_APPLY_THREADS_PER_BLOCK;
}
inline dim3 getApplyBlock(int max_threads_per_block=AT_APPLY_THREADS_PER_BLOCK) {
return dim3(max_threads_per_block);
}
} // anonymous namespace
} // namespace at::cuda
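// A minimal usage sketch of the helpers above (the kernel and launcher names
// are hypothetical and used only for illustration; the pattern mirrors how the
// pointwise apply kernels consume getApplyBlock/getApplyGrid):
__global__ void fill_ones_example_kernel(float* data, uint64_t n) {
  uint64_t i = static_cast<uint64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (i < n) data[i] = 1.0f;
}

inline void launch_fill_ones_example(float* data, uint64_t n, c10::DeviceIndex dev) {
  const dim3 block = at::cuda::getApplyBlock();   // 512 threads per block
  dim3 grid;
  if (!at::cuda::getApplyGrid(n, grid, dev)) {    // clamps to maxGridSize[0]
    return;
  }
  fill_ones_example_kernel<<<grid, block, 0, at::cuda::getCurrentCUDAStream(dev)>>>(data, n);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}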

@@ -0,0 +1,149 @@
#pragma once
#include <climits>
#include <cstdint>
// Collection of direct PTX functions
namespace at::cuda {
template <typename T>
struct Bitfield {};
template <>
struct Bitfield<unsigned int> {
static __device__ __host__ __forceinline__
unsigned int getBitfield(unsigned int val, int pos, int len) {
#if !defined(__CUDA_ARCH__)
pos &= 0xff;
len &= 0xff;
unsigned int m = (1u << len) - 1u;
return (val >> pos) & m;
#else
unsigned int ret;
asm("bfe.u32 %0, %1, %2, %3;" : "=r"(ret) : "r"(val), "r"(pos), "r"(len));
return ret;
#endif
}
static __device__ __host__ __forceinline__
unsigned int setBitfield(unsigned int val, unsigned int toInsert, int pos, int len) {
#if !defined(__CUDA_ARCH__)
pos &= 0xff;
len &= 0xff;
unsigned int m = (1u << len) - 1u;
toInsert &= m;
toInsert <<= pos;
m <<= pos;
return (val & ~m) | toInsert;
#else
unsigned int ret;
asm("bfi.b32 %0, %1, %2, %3, %4;" :
"=r"(ret) : "r"(toInsert), "r"(val), "r"(pos), "r"(len));
return ret;
#endif
}
};
template <>
struct Bitfield<uint64_t> {
static __device__ __host__ __forceinline__
uint64_t getBitfield(uint64_t val, int pos, int len) {
#if !defined(__CUDA_ARCH__)
pos &= 0xff;
len &= 0xff;
uint64_t m = (1ull << len) - 1ull;
return (val >> pos) & m;
#else
uint64_t ret;
asm("bfe.u64 %0, %1, %2, %3;" : "=l"(ret) : "l"(val), "r"(pos), "r"(len));
return ret;
#endif
}
static __device__ __host__ __forceinline__
uint64_t setBitfield(uint64_t val, uint64_t toInsert, int pos, int len) {
#if !defined(__CUDA_ARCH__)
pos &= 0xff;
len &= 0xff;
uint64_t m = (1ull << len) - 1ull;
toInsert &= m;
toInsert <<= pos;
m <<= pos;
return (val & ~m) | toInsert;
#else
uint64_t ret;
asm("bfi.b64 %0, %1, %2, %3, %4;" :
"=l"(ret) : "l"(toInsert), "l"(val), "r"(pos), "r"(len));
return ret;
#endif
}
};
__device__ __forceinline__ int getLaneId() {
#if defined(USE_ROCM)
return __lane_id();
#else
int laneId;
asm("mov.s32 %0, %%laneid;" : "=r"(laneId) );
return laneId;
#endif
}
#if defined(USE_ROCM)
__device__ __forceinline__ unsigned long long int getLaneMaskLt() {
const std::uint64_t m = (1ull << getLaneId()) - 1ull;
return m;
}
#else
__device__ __forceinline__ unsigned getLaneMaskLt() {
unsigned mask;
asm("mov.u32 %0, %%lanemask_lt;" : "=r"(mask));
return mask;
}
#endif
#if defined (USE_ROCM)
__device__ __forceinline__ unsigned long long int getLaneMaskLe() {
std::uint64_t m = UINT64_MAX >> (sizeof(std::uint64_t) * CHAR_BIT - (getLaneId() + 1));
return m;
}
#else
__device__ __forceinline__ unsigned getLaneMaskLe() {
unsigned mask;
asm("mov.u32 %0, %%lanemask_le;" : "=r"(mask));
return mask;
}
#endif
#if defined(USE_ROCM)
__device__ __forceinline__ unsigned long long int getLaneMaskGt() {
const std::uint64_t m = getLaneMaskLe();
return m ? ~m : m;
}
#else
__device__ __forceinline__ unsigned getLaneMaskGt() {
unsigned mask;
asm("mov.u32 %0, %%lanemask_gt;" : "=r"(mask));
return mask;
}
#endif
#if defined(USE_ROCM)
__device__ __forceinline__ unsigned long long int getLaneMaskGe() {
const std::uint64_t m = getLaneMaskLt();
return ~m;
}
#else
__device__ __forceinline__ unsigned getLaneMaskGe() {
unsigned mask;
asm("mov.u32 %0, %%lanemask_ge;" : "=r"(mask));
return mask;
}
#endif
} // namespace at::cuda
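// A small host/device sketch of the Bitfield helpers above: read an 8-bit
// field starting at bit 4 and write back an incremented value
// (bump_field_example is a hypothetical function used only for illustration).
// On the host the shift/mask fallback runs; in device code the same calls
// lower to bfe.u32 / bfi.b32.
inline __host__ __device__ unsigned int bump_field_example(unsigned int word) {
  unsigned int field =
      at::cuda::Bitfield<unsigned int>::getBitfield(word, /*pos=*/4, /*len=*/8);
  return at::cuda::Bitfield<unsigned int>::setBitfield(word, field + 1u, /*pos=*/4, /*len=*/8);
}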

@@ -0,0 +1,514 @@
#pragma once
#include <cuda.h>
#include <c10/util/Half.h>
#include <c10/util/BFloat16.h>
#include <ATen/NumericUtils.h>
#if !(defined(USE_ROCM) || ((defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))))
#include <cuda_bf16.h>
#endif
template <typename T>
struct AtomicFPOp;
template <>
struct AtomicFPOp<at::Half> {
template <typename func_t>
inline __device__ at::Half operator() (at::Half *address, at::Half val, const func_t& func) {
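// The 16-bit value is updated through a 32-bit atomicCAS on the aligned
// 4-byte word that contains it; `(size_t)address & 2` selects whether the
// Half occupies the upper or lower 16 bits of that word.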
unsigned int * address_as_ui =
(unsigned int *) ((char *)address - ((size_t)address & 2));
unsigned int old = *address_as_ui;
unsigned int assumed;
at::Half hsum;
do {
assumed = old;
hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
hsum = func(hsum, val);
old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
old = atomicCAS(address_as_ui, assumed, old);
} while (assumed != old);
hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
return hsum;
}
};
template <>
struct AtomicFPOp<at::BFloat16> {
template <typename func_t>
inline __device__ at::BFloat16 operator() (at::BFloat16 *address, at::BFloat16 val, const func_t& func) {
unsigned int * address_as_ui =
(unsigned int *) ((char *)address - ((size_t)address & 2));
unsigned int old = *address_as_ui;
unsigned int assumed;
at::BFloat16 bsum;
do {
assumed = old;
bsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
bsum = func(bsum, val);
old = (size_t)address & 2 ? (old & 0xffff) | (bsum.x << 16) : (old & 0xffff0000) | bsum.x;
old = atomicCAS(address_as_ui, assumed, old);
} while (assumed != old);
bsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
return bsum;
}
};
template <>
struct AtomicFPOp<double> {
template <typename func_t>
inline __device__ double operator() (double * address, double val, const func_t& func) {
unsigned long long int* address_as_ull = (unsigned long long int*)address;
unsigned long long int old = *address_as_ull;
unsigned long long int assumed;
do {
assumed = old;
old = atomicCAS(address_as_ull, assumed, func(val, assumed));
// Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
} while (assumed != old);
return __longlong_as_double(old);
}
};
#define ATOMIC_INTEGER_IMPL(NAME) \
template <typename T, size_t n> \
struct Atomic##NAME##IntegerImpl; \
\
template<typename T> \
struct Atomic##NAME##IntegerImpl<T, 1> { \
template <typename func_t> \
inline __device__ void operator()(T *address, T val, const func_t& func) { \
size_t offset = (size_t)address & 3; \
uint32_t * address_as_ui = (uint32_t *)((char *)address - offset); \
uint32_t old = *address_as_ui; \
uint32_t shift = offset * 8; \
uint32_t old_byte; \
uint32_t newval; \
uint32_t assumed; \
\
do { \
assumed = old; \
old_byte = (old >> shift) & 0xff; \
newval = static_cast<uint8_t>(func(val, static_cast<T>(old_byte))); \
newval = (old & ~(0x000000ff << shift)) | (newval << shift); \
old = atomicCAS(address_as_ui, assumed, newval); \
} while (assumed != old); \
} \
}; \
\
template<typename T> \
struct Atomic##NAME##IntegerImpl<T, 2> { \
template <typename func_t> \
inline __device__ void operator()(T *address, T val, const func_t& func) { \
size_t offset = (size_t)address & 2; \
uint32_t * address_as_ui = (uint32_t *)((char *)address - offset); \
bool is_32_align = offset; \
uint32_t old = *address_as_ui; \
uint32_t old_bytes; \
uint32_t newval; \
uint32_t assumed; \
\
do { \
assumed = old; \
old_bytes = is_32_align ? old >> 16 : old & 0xffff; \
newval = static_cast<uint16_t>(func(val, static_cast<T>(old_bytes))); \
newval = is_32_align ? (old & 0xffff) | (newval << 16) : (old & 0xffff0000) | newval; \
old = atomicCAS(address_as_ui, assumed, newval); \
} while (assumed != old); \
} \
}; \
\
template<typename T> \
struct Atomic##NAME##IntegerImpl<T, 4> { \
template <typename func_t> \
inline __device__ void operator()(T *address, T val, const func_t& func) { \
uint32_t * address_as_ui = (uint32_t *) (address); \
uint32_t old = *address_as_ui; \
uint32_t newval; \
uint32_t assumed; \
\
do { \
assumed = old; \
newval = static_cast<uint32_t>(func(val, static_cast<T>(old))); \
old = atomicCAS(address_as_ui, assumed, newval); \
} while (assumed != old); \
} \
}; \
\
template<typename T> \
struct Atomic##NAME##IntegerImpl<T, 8> { \
template <typename func_t> \
inline __device__ void operator()(T *address, T val, const func_t& func) { \
unsigned long long * address_as_ui = (unsigned long long *) (address); \
unsigned long long old = *address_as_ui; \
unsigned long long newval; \
unsigned long long assumed; \
\
do { \
assumed = old; \
newval = static_cast<uint64_t>(func(val, static_cast<T>(old))); \
old = atomicCAS(address_as_ui, assumed, newval); \
} while (assumed != old); \
} \
};
# define GPU_ATOMIC_INTEGER(NAME, OP, DTYPE) \
inline __device__ void gpuAtomic##NAME(DTYPE *address, DTYPE val) { \
Atomic##NAME##IntegerImpl<DTYPE, sizeof(DTYPE)>()(address, \
val, \
[](DTYPE a, DTYPE b) { \
return OP; \
}); \
}
ATOMIC_INTEGER_IMPL(Add)
GPU_ATOMIC_INTEGER(Add, a || b, bool)
// Don't instantiate gpuAtomicAdd with the macro as it seems non-standard (see int32, int64)
inline __device__ void gpuAtomicAdd(uint8_t *address, uint8_t val) {
AtomicAddIntegerImpl<uint8_t, sizeof(uint8_t)>()(address,
val,
[](uint8_t a, uint8_t b) {
return a + b;
});
}
inline __device__ void gpuAtomicAdd(int8_t *address, int8_t val) {
AtomicAddIntegerImpl<int8_t, sizeof(int8_t)>()(address,
val,
[](int8_t a, int8_t b) {
return a + b;
});
}
inline __device__ void gpuAtomicAdd(int16_t *address, int16_t val) {
AtomicAddIntegerImpl<int16_t, sizeof(int16_t)>()(address,
val,
[](int16_t a, int16_t b) {
return a + b;
});
}
inline __device__ int32_t gpuAtomicAdd(int32_t *address, int32_t val) {
return atomicAdd(address, val);
}
inline __device__ void gpuAtomicAdd(int64_t *address, int64_t val) {
#if defined(USE_ROCM)
__atomic_fetch_add(address, val, __ATOMIC_RELAXED);
#else
static_assert(sizeof(unsigned long long int) == sizeof(int64_t), "bitwidth change is not allowed");
atomicAdd(reinterpret_cast<unsigned long long int *>(address), static_cast<unsigned long long int>(val));
#endif
}
inline __device__ at::Half gpuAtomicAdd(at::Half *address, at::Half val) {
#if defined(USE_ROCM) || ((defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)))
return AtomicFPOp<at::Half>()(address, val,
[](at::Half hsum, at::Half val) {
return hsum + val;
});
#else
return atomicAdd(reinterpret_cast<__half*>(address), val);
#endif
}
inline __device__ at::BFloat16 gpuAtomicAdd(at::BFloat16 *address, at::BFloat16 val) {
#if defined(USE_ROCM) || ((defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))
return AtomicFPOp<at::BFloat16>()(address, val,
[](at::BFloat16 bsum, at::BFloat16 val) {
return bsum + val;
});
#else
__nv_bfloat16 r = atomicAdd(reinterpret_cast<__nv_bfloat16*>(address), *reinterpret_cast<__nv_bfloat16*>(&val));
return *reinterpret_cast<c10::BFloat16*>(&r);
#endif
}
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600)
// from the CUDA C Programming Guide
inline __device__ double atomicAdd(double* address, double val)
#if defined(__clang__) && defined(__CUDA__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wgcc-compat"
__attribute__((enable_if(true, "")))
#pragma GCC diagnostic pop
#endif
{
return AtomicFPOp<double>()(address, val,
[](double val, unsigned long long int assumed) {
return __double_as_longlong(val + __longlong_as_double(assumed));
});
}
#elif defined(USE_ROCM) || !(defined(__CUDA_ARCH__))
/* Note [hip-clang differences to hcc]
* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
* The upcoming hip-clang compiler for ROCm differs from hcc in a few details.
* It exports the __HIP__ macro, we can hence differentiate between hcc and
* hip-clang. In the below, hcc only received support for atomicAdd with double
* typing after work week 18312. hip-clang had support from the first version.
* In general, the code-visible differences between hip-clang and hcc will be
* minimal.
*/
#if defined(USE_ROCM) && __hcc_workweek__ < 18312 && !__HIP__
// This needs to be defined for the host side pass
inline __device__ double atomicAdd(double *address, double val) { }
#endif
#endif
inline __device__ double gpuAtomicAdd(double *address, double val) {
return atomicAdd(address, val);
}
inline __device__ float gpuAtomicAdd(float *address, float val) {
return atomicAdd(address, val);
}
template<typename T>
inline __device__ void gpuAtomicAdd(c10::complex<T> *address, c10::complex<T> val) {
gpuAtomicAdd(&address->real_, val.real_);
gpuAtomicAdd(&address->imag_, val.imag_);
}
/* Note [gpuAtomicAdd vs atomicAdd]
* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
* Some extensions such as torchvision call atomicAdd()
* directly and require non-library provided data type support. Only for these, we
* continue to provide atomicAdd overloads.
*/
inline __device__ at::Half atomicAdd(at::Half *address, at::Half val) {
return gpuAtomicAdd(address, val);
}
inline __device__ at::BFloat16 atomicAdd(at::BFloat16 *address, at::BFloat16 val) {
return gpuAtomicAdd(address, val);
}
inline __device__ void atomicAdd(uint8_t *address, uint8_t val) {
gpuAtomicAdd(address, val);
}
inline __device__ void atomicAdd(int8_t *address, int8_t val) {
gpuAtomicAdd(address, val);
}
inline __device__ void atomicAdd(int16_t *address, int16_t val) {
gpuAtomicAdd(address, val);
}
inline __device__ void atomicAdd(int64_t *address, int64_t val) {
gpuAtomicAdd(address, val);
}
inline __device__ void atomicAdd(bool *address, bool val) {
gpuAtomicAdd(address, val);
}
/* Note [explicitly non-returning atomics]
* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
* AMD's MI100 (gfx908) provides an optimized fp32 atomicAdd, exposed via atomicAddNoRet().
* Due to compiler limitations, callers must opt-in to guarantee the optimized instruction.
* This non-returning atomicAddNoRet cannot be used to implement the returning atomicAdd,
* therefore we need a new API 'gpuAtomicAddNoReturn'.
*/
template<typename T>
inline __device__ void gpuAtomicAddNoReturn(c10::complex<T> *address, c10::complex<T> val) { gpuAtomicAdd(address, val); }
inline __device__ void gpuAtomicAddNoReturn(uint8_t *address, uint8_t val) { gpuAtomicAdd(address, val); }
inline __device__ void gpuAtomicAddNoReturn(int8_t *address, int8_t val) { gpuAtomicAdd(address, val); }
inline __device__ void gpuAtomicAddNoReturn(int16_t *address, int16_t val) { gpuAtomicAdd(address, val); }
inline __device__ void gpuAtomicAddNoReturn(int32_t *address, int32_t val) { gpuAtomicAdd(address, val); }
inline __device__ void gpuAtomicAddNoReturn(int64_t *address, int64_t val) { gpuAtomicAdd(address, val); }
inline __device__ void gpuAtomicAddNoReturn(bool *address, bool val) { gpuAtomicAdd(address, val); }
inline __device__ void gpuAtomicAddNoReturn(at::Half *address, at::Half val) { gpuAtomicAdd(address, val); }
inline __device__ void gpuAtomicAddNoReturn(at::BFloat16 *address, at::BFloat16 val) { gpuAtomicAdd(address, val); }
inline __device__ void gpuAtomicAddNoReturn(double *address, double val) { gpuAtomicAdd(address, val); }
/* Special case fp32 atomic. */
#if defined(USE_ROCM)
inline __device__ void gpuAtomicAddNoReturn(float *address, float val) {
#if defined(__gfx908__)
atomicAddNoRet(address, val);
#else
(void)unsafeAtomicAdd(address, val);
#endif
}
#else
inline __device__ void gpuAtomicAddNoReturn(float *address, float val) { gpuAtomicAdd(address, val); }
#endif
// Atomic multiplication implementation.
ATOMIC_INTEGER_IMPL(Mul)
GPU_ATOMIC_INTEGER(Mul, a * b, uint8_t)
GPU_ATOMIC_INTEGER(Mul, a * b, int8_t)
GPU_ATOMIC_INTEGER(Mul, a * b, int16_t)
GPU_ATOMIC_INTEGER(Mul, a * b, int32_t)
GPU_ATOMIC_INTEGER(Mul, a * b, int64_t)
inline __device__ at::Half gpuAtomicMul(at::Half * address, at::Half val) {
return AtomicFPOp<at::Half>()(address, val,
[](at::Half bsum, at::Half val) {
return bsum * val;
});
}
inline __device__ at::BFloat16 gpuAtomicMul(at::BFloat16 * address, at::BFloat16 val) {
return AtomicFPOp<at::BFloat16>()(address, val,
[](at::BFloat16 bsum, at::BFloat16 val) {
return bsum * val;
});
}
inline __device__ double gpuAtomicMul(double * address, double val) {
return AtomicFPOp<double>()(address, val,
[](double val, unsigned long long int assumed) {
return __double_as_longlong(val * __longlong_as_double(assumed));
});
}
// Don't use a templated function for this since the addition function defaults to the CUDA built-in.
inline __device__ float gpuAtomicMul (float * address, float val) {
unsigned int* address_as_ull = (unsigned int*)address;
unsigned int old = *address_as_ull;
unsigned int assumed;
do {
assumed = old;
old = atomicCAS(address_as_ull, assumed,
__float_as_int(val *
__int_as_float(assumed)));
// Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
} while (assumed != old);
return __int_as_float(old);
}
// Atomic maximum implementation.
template <typename T>
__host__ __device__ T safe_max(T a, T b) {
#if defined(__HIPCC__)
// TODO: remove this special case for HIP when issue is fixed:
// https://github.com/ROCm-Developer-Tools/HIP/issues/2209
T max = at::_isnan(a) ? a : (at::_isnan(b) ? b : std::max<T>(a, b));
#else
T max = at::_isnan(b) ? b : std::max<T>(a, b);
#endif
return max;
}
ATOMIC_INTEGER_IMPL(Max)
GPU_ATOMIC_INTEGER(Max, safe_max(a, b), uint8_t)
GPU_ATOMIC_INTEGER(Max, safe_max(a, b), int8_t)
GPU_ATOMIC_INTEGER(Max, safe_max(a, b), int16_t)
GPU_ATOMIC_INTEGER(Max, safe_max(a, b), int32_t)
GPU_ATOMIC_INTEGER(Max, safe_max(a, b), int64_t)
inline __device__ at::Half gpuAtomicMax(at::Half * address, at::Half val) {
return AtomicFPOp<at::Half>()(address, val,
[](at::Half bsum, at::Half val) {
return safe_max(bsum, val);
});
}
inline __device__ at::BFloat16 gpuAtomicMax(at::BFloat16 * address, at::BFloat16 val) {
return AtomicFPOp<at::BFloat16>()(address, val,
[](at::BFloat16 bsum, at::BFloat16 val) {
return safe_max(bsum, val);
});
}
inline __device__ double gpuAtomicMax(double * address, double val) {
return AtomicFPOp<double>()(address, val,
[](double val, unsigned long long int assumed) {
return __double_as_longlong(safe_max(val, __longlong_as_double(assumed)));
});
}
// Don't use a templated function for this; as with multiplication above, float needs its own CAS loop.
inline __device__ float gpuAtomicMax(float * address, float val) {
unsigned int* address_as_ull = (unsigned int*)address;
unsigned int old = *address_as_ull;
unsigned int assumed;
do {
assumed = old;
old = atomicCAS(address_as_ull, assumed,
__float_as_int(safe_max(val, __int_as_float(assumed))));
// Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
} while (assumed != old);
return __int_as_float(old);
}
// Atomic minimum implementation.
template <typename T>
__host__ __device__ T safe_min(T a, T b) {
#if defined(__HIPCC__)
// TODO: remove this special case for HIP when issue is fixed:
// https://github.com/ROCm-Developer-Tools/HIP/issues/2209
T min = at::_isnan(a) ? a : (at::_isnan(b) ? b : std::min<T>(a, b));
#else
T min = at::_isnan(b) ? b : std::min<T>(a, b);
#endif
return min;
}
ATOMIC_INTEGER_IMPL(Min)
GPU_ATOMIC_INTEGER(Min, safe_min(a, b), uint8_t)
GPU_ATOMIC_INTEGER(Min, safe_min(a, b), int8_t)
GPU_ATOMIC_INTEGER(Min, safe_min(a, b), int16_t)
GPU_ATOMIC_INTEGER(Min, safe_min(a, b), int32_t)
GPU_ATOMIC_INTEGER(Min, safe_min(a, b), int64_t)
inline __device__ at::Half gpuAtomicMin(at::Half * address, at::Half val) {
return AtomicFPOp<at::Half>()(address, val,
[](at::Half bsum, at::Half val) {
return safe_min(bsum, val);
});
}
inline __device__ at::BFloat16 gpuAtomicMin(at::BFloat16 * address, at::BFloat16 val) {
return AtomicFPOp<at::BFloat16>()(address, val,
[](at::BFloat16 bsum, at::BFloat16 val) {
return safe_min(bsum, val);
});
}
inline __device__ double gpuAtomicMin(double * address, double val) {
return AtomicFPOp<double>()(address, val,
[](double val, unsigned long long int assumed) {
return __double_as_longlong(safe_min(val, __longlong_as_double(assumed)));
});
}
// Don't use a templated function for this; as with multiplication above, float needs its own CAS loop.
inline __device__ float gpuAtomicMin(float * address, float val) {
unsigned int* address_as_ull = (unsigned int*)address;
unsigned int old = *address_as_ull;
unsigned int assumed;
do {
assumed = old;
old = atomicCAS(address_as_ull, assumed,
__float_as_int(safe_min(val, __int_as_float(assumed))));
// Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
} while (assumed != old);
return __int_as_float(old);
}
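// A short usage sketch of the primitives above: many threads accumulate into a
// shared sum and running maximum (atomic_stats_example_kernel is a
// hypothetical kernel used only for illustration).
__global__ void atomic_stats_example_kernel(const float* in, float* sum, float* max_val, int64_t n) {
  int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (i < n) {
    // The old value is not needed, so prefer the non-returning form; on gfx908
    // this reaches the optimized no-return fp32 atomic (see the note above).
    gpuAtomicAddNoReturn(sum, in[i]);
    gpuAtomicMax(max_val, in[i]);
  }
}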

@@ -0,0 +1,537 @@
#pragma once
#include <ATen/cuda/ApplyGridUtils.cuh>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/core/TensorBase.h>
#include <ATen/ceil_div.h>
#include <ATen/cuda/Atomic.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <c10/macros/Macros.h>
#include <ATen/native/Copy.h>
#include <math.h>
//
// This file contains pointwise operation functions and kernels that
// work on both contiguous and non-contiguous tensor arguments of
// arbitrary (up to MAX_CUTORCH_DIMS) dimensioned arguments without
// copying or temporary storage.
//
/*
NOTE [ CUDA_tensor_applyN helpers ]
The following CUDA_tensor_applyN (where N currently can be 1, 2, 3, or 4)
functions apply a pointwise operator to N tensor(s).
The calling convention is
1. The template arguments should be, sequentially,
- First N typename args specify the scalar types of each of the N tensors.
- (Optional) `int step` arg specifies the number of elements processed
together at the same time.
Default is 1.
- A usually omitted (i.e., inferred) typename arg specifies the type of the
function/functor applied on `N * step` values in each iteration of each
CUDA thread.
2. The arguments should be, sequentially,
- N tensors
- op: a function/functor that processes `N * step` values at the same time.
- If `step == 1`, it must have signature
`void(*)(scalar1_t&, scalar2_t&, ..., scalarN_t&)`, where
`scalar*_t`s are the first N typename template args, and the inputs
are the `N` values from the `N` tensors retrieved at a common index.
- Otherwise, it must have signature
void(*)(int n, scalar1_t&, scalar1_t&, ..., scalar1_t&, // repeat `step` times
scalar2_t&, scalar2_t&, ..., scalar2_t&, // repeat `step` times
...,
scalarN_t&, scalarN_t&, ..., scalarN_t&) // repeat `step` times
Different from `step == 1` case, it processes `N * step` values taken
from `step` common indices. Moreover, the first input `n` represents the
number of valid indices (it will always have `0 < n <= step`). It will
almost always be `step`, but at the boundary we may not have full `step`
elements and `n` can be a lesser value.
E.g., if `step == 4` and `N == 2`, `op` could be
[](int n, scalar1_t &u1, scalar1_t &u2, scalar1_t &u3, scalar1_t &u4,
scalar2_t &v1, scalar2_t &v2, scalar2_t &v3, scalar2_t &v4) {
// Only process u1, ..., un and v1, ..., vn.
// So if `n == 3`, `u4` and `v4` need not be considered.
}
In both cases, the references can actually be const, but at least one of
them should be non-const in order to write the output.
- (Optional, but recommended) N TensorArgType args that specify for each
tensor whether `op` reads AND writes (i.e., TensorArgType::ReadWrite),
or only reads (i.e., TensorArgType::ReadOnly).
Default is TensorArgType::ReadWrite for first Tensor, and
TensorArgType::ReadOnly for the rest.
E.g.,
to compute a = b^2 for a and b of same dtype, we can call
CUDA_tensor_apply2<scalar, scalar>(
a, b,
[] __device__ (scalar &a_val, const scalar &b_val) { a_val = b_val * b_val; }
);
to work on 2 values at the same time, we can call
CUDA_tensor_apply2<scalar1, scalar2, 2>(
a, b,
[] __device__ (int n, scalar1 &a_val1, scalar1 &a_val2,
const scalar2 &b_val1, const scalar2 &b_val2) {
// call special vectorized op here, or just do elementwise and enjoy unrolling...
// if n == 1, only process a_val1 and b_val1
}
);
*/
namespace at::cuda {
// TODO: combine with TensorArg? So far that's been for debugging, and this is functional...
enum class TensorArgType { ReadWrite, ReadOnly };
namespace {
// Rearrange dimensions for pointwise operations so that strides are in
// decreasing order as much as possible, so that kernels have better memory
// access patterns.
//
// For example, consider a binary operation on two "transposed" 2-dim tensors:
// sizes: 256 512
// aInfo->strides: 1 256
// bInfo->strides: 1 256
//
// Given this, each concurrent memory access inside kernelPointwiseApply2() is
// exactly 256 elements apart, resulting in poor performance.
//
// This function exchanges dimensions so that memory access is contiguous:
// sizes: 512 256
// aInfo->strides: 256 1
// bInfo->strides: 256 1
//
// (Actually, it becomes even better because now collapseDims() can turn each
// input into one contiguous array.)
//
// In general, given M (<=4) TensorInfo's with N dimensions, we can view each
// strides[i] (0 <= i < N) as an M-tuple. Given each pair i < j, we exchange
// strides[i] and strides[j] if
// (1) strides[i][k] < strides[j][k] for some k (0 <= k < M)
// (exchanging them will benefit input #k), and
// (2) strides[i][k] <= strides[j][k] for all k
// (exchanging them will not make any input worse).
template <typename T1, typename IndexType,
typename T2 = void, typename T3 = void, typename T4 = void>
inline void rearrangeDims(detail::TensorInfo<T1, IndexType>* aInfo,
detail::TensorInfo<T2, IndexType>* bInfo = nullptr,
detail::TensorInfo<T3, IndexType>* cInfo = nullptr,
detail::TensorInfo<T4, IndexType>* dInfo = nullptr) {
int numInfos = 1;
int dims = aInfo->dims;
IndexType *sizes[4] = { aInfo->sizes, };
IndexType *strides[4] = { aInfo->strides, };
if (bInfo != nullptr) {
++numInfos;
if (bInfo->dims != dims) return;
sizes[1] = bInfo->sizes;
strides[1] = bInfo->strides;
}
if (cInfo != nullptr) {
++numInfos;
if (cInfo->dims != dims) return;
sizes[2] = cInfo->sizes;
strides[2] = cInfo->strides;
}
if (dInfo != nullptr) {
++numInfos;
if (dInfo->dims != dims) return;
sizes[3] = dInfo->sizes;
strides[3] = dInfo->strides;
}
// Bail out if sizes do not match: we are using "deprecated pointwise
// behavior" among tensors of different shapes but same number of elements.
for (int i = 1; i < numInfos; ++i) {
for (int j = 0; j < dims; ++j) {
if (sizes[i][j] != sizes[0][j]) return;
}
}
for (int i = 0; i < dims - 1; ++i) {
// No need to consider dimensions of size 1.
if (sizes[0][i] == 1) continue;
for (int j = i + 1; j < dims; ++j) {
if (sizes[0][j] == 1) continue;
// Compare the relative sizes of strides between dim #i and dim #j.
bool hasIncreasingStrides = false;
bool hasDecreasingStrides = false;
for (int k = 0; k < numInfos; k++) {
IndexType stride_i = strides[k][i];
IndexType stride_j = strides[k][j];
if (stride_i < stride_j) {
hasIncreasingStrides = true;
} else if (stride_i > stride_j) {
hasDecreasingStrides = true;
}
}
if (hasIncreasingStrides && !hasDecreasingStrides) {
for (int k = 0; k < numInfos; k++) {
IndexType size = sizes[k][i];
sizes[k][i] = sizes[k][j];
sizes[k][j] = size;
IndexType stride = strides[k][i];
strides[k][i] = strides[k][j];
strides[k][j] = stride;
}
}
}
}
}
// The `remaining_steps` argument is used to support an Op that operates on
// multiple elements at the same time. Generally, the strategy of ApplyOpN is to
// 1. Initialize `remaining_steps = step`, where `step` is the template arg of
// CUDA_tensor_applyN helpers. The input arg `n` to `apply()` represents the
// number of elements in bounds for this call. It will almost always be equal
// to `step`, except at boundaries.
// 2. If `remaining_steps > 0` convert the current linearIndex to offset (if in
// bound), and recursively call `ApplyOpN` with `remaining_steps - 1`.
// 3. At `remaining_steps = 0`,
// if `step = 1`, call `op(tensor1_val, tensor2_val, ...)`;
// if `step > 1`, call `op(n, tensor1_val1, tensor1_val2, ..., tensor1_valstep,
// tensor2_val1, tensor2_val2, ..., tensor2_valstep,
// ...
// tensorN_val1, tensorN_val2, ..., tensorN_valstep);`
//
// See NOTE [ CUDA_tensor_applyN helpers ] above for how Op may look like.
template <typename Op,
typename scalar,
typename IndexType,
int ADims,
int remaining_steps,
typename... Offsets>
struct ApplyOp1 {
__device__ __forceinline__
static void apply(detail::TensorInfo<scalar, IndexType> &a, const Op &op, int n,
IndexType linearIndex, Offsets... aOffsets) {
// Convert `linearIndex` into an offset of `a`
const IndexType aOffset = sizeof...(Offsets) < n ?
detail::IndexToOffset<scalar, IndexType, ADims>::get(linearIndex, a) : 0;
ApplyOp1<Op, scalar, IndexType, ADims, remaining_steps - 1, const IndexType, Offsets...>::apply(
a, op, n, linearIndex + 1, aOffsets..., aOffset
);
}
};
// Specialize `step=1` case (i.e., `remaining_steps=0` and `len(Offsets)=1`).
// We don't need to pass in how many elements need to be processed in this case.
template <typename Op,
typename scalar,
typename IndexType,
int ADims,
typename Offset>
struct ApplyOp1<Op, scalar, IndexType, ADims, 0, Offset> {
__device__ __forceinline__
static void apply(detail::TensorInfo<scalar, IndexType> &a, const Op &op,
int n, IndexType linearIndex, Offset offset) {
op(a.data[offset]);
}
};
template <typename Op,
typename scalar,
typename IndexType,
int ADims,
typename... Offsets>
struct ApplyOp1<Op, scalar, IndexType, ADims, 0, Offsets...> {
__device__ __forceinline__
static void apply(detail::TensorInfo<scalar, IndexType> &a, const Op &op, int n,
IndexType linearIndex, Offsets... offsets) {
op(n, a.data[offsets]...);
}
};
template <typename Op,
typename scalar,
typename IndexType,
int ADims,
int step>
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM)
#endif
__global__ void kernelPointwiseApply1(detail::TensorInfo<scalar, IndexType> a,
IndexType totalElements, const Op op) {
for (IndexType linearIndex = (blockIdx.x * blockDim.x + threadIdx.x) * step;
linearIndex < totalElements;
linearIndex += gridDim.x * blockDim.x * step) {
ApplyOp1<Op, scalar, IndexType, ADims, step>::apply(
a, op, ::min(step, static_cast<int>(totalElements - linearIndex)), linearIndex);
}
}
template <typename Op,
typename scalar1,
typename scalar2,
typename IndexType,
int ADims,
int BDims,
int remaining_steps,
typename... Offsets>
struct ApplyOp2 {
__device__ __forceinline__
static void apply(detail::TensorInfo<scalar1, IndexType> &a,
detail::TensorInfo<scalar2, IndexType> &b,
const Op &op, int64_t n, IndexType linearIndex,
Offsets... aOffsets, Offsets... bOffsets) {
// Convert `linearIndex` into an offset of `a`
const IndexType aOffset = static_cast<int64_t>(sizeof...(Offsets)) < n ?
detail::IndexToOffset<scalar1, IndexType, ADims>::get(linearIndex, a) : 0;
// Convert `linearIndex` into an offset of `b`
const IndexType bOffset = static_cast<int64_t>(sizeof...(Offsets)) < n ?
detail::IndexToOffset<scalar2, IndexType, BDims>::get(linearIndex, b) : 0;
ApplyOp2<Op, scalar1, scalar2, IndexType, ADims, BDims, remaining_steps - 1, const IndexType, Offsets...>::apply(
a, b, op, n, linearIndex + 1, aOffsets..., aOffset, bOffsets..., bOffset
);
}
};
// Specialize `step=1` case (i.e., `remaining_steps=0` and `len(Offsets)=1`).
// We don't need to pass in how many elements need to be processed in this case.
template <typename Op,
typename scalar1,
typename scalar2,
typename IndexType,
int ADims,
int BDims,
typename Offset>
struct ApplyOp2<Op, scalar1, scalar2, IndexType, ADims, BDims, 0, Offset> {
__device__ __forceinline__
static void apply(detail::TensorInfo<scalar1, IndexType> &a,
detail::TensorInfo<scalar2, IndexType> &b,
const Op &op, int /*n*/, IndexType /*linearIndex*/,
Offset aOffset, Offset bOffset) {
op(a.data[aOffset], b.data[bOffset]);
}
};
template <typename Op,
typename scalar1,
typename scalar2,
typename IndexType,
int ADims,
int BDims,
typename... Offsets>
struct ApplyOp2<Op, scalar1, scalar2, IndexType, ADims, BDims, 0, Offsets...> {
__device__ __forceinline__
static void apply(detail::TensorInfo<scalar1, IndexType> &a,
detail::TensorInfo<scalar2, IndexType> &b,
const Op &op, int n, IndexType linearIndex,
Offsets... aOffsets, Offsets... bOffsets) {
op(n, a.data[aOffsets]..., b.data[bOffsets]...);
}
};
template <typename Op,
typename scalar1,
typename scalar2,
typename IndexType,
int ADims, int BDims,
int step,
int max_threads_per_block=AT_APPLY_THREADS_PER_BLOCK,
int min_blocks_per_sm=AT_APPLY_BLOCKS_PER_SM>
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(max_threads_per_block, min_blocks_per_sm)
#endif
__global__ void
kernelPointwiseApply2(detail::TensorInfo<scalar1, IndexType> a,
detail::TensorInfo<scalar2, IndexType> b,
IndexType totalElements,
const Op op) {
for (IndexType linearIndex = (blockIdx.x * blockDim.x + threadIdx.x) * step;
linearIndex < totalElements;
linearIndex += gridDim.x * blockDim.x * step) {
ApplyOp2<Op, scalar1, scalar2, IndexType, ADims, BDims, step>::apply(
a, b, op, ::min(step, static_cast<int>(totalElements - linearIndex)),
linearIndex);
}
}
} // anonymous namespace
template <typename scalar1, typename scalar2, int step, typename Op,
int max_threads_per_block=AT_APPLY_THREADS_PER_BLOCK,
int min_blocks_per_sm=AT_APPLY_BLOCKS_PER_SM>
inline bool CUDA_tensor_apply2(at::TensorBase a,
at::TensorBase b,
const Op op,
TensorArgType aType = TensorArgType::ReadWrite,
TensorArgType bType = TensorArgType::ReadOnly) {
TORCH_CHECK(a.device().is_cuda() && b.device().is_cuda(),
"CUDA_tensor_apply2: Expected tensors to have CUDA DeviceType, but got "
"tensors with type ", a.device().type(), " and ", b.device().type());
int64_t totalElements = a.numel();
if (totalElements != b.numel()) {
return false;
}
if (a.dim() > MAX_TENSORINFO_DIMS ||
b.dim() > MAX_TENSORINFO_DIMS) {
return false;
}
if (a.numel() == 0) {
// Empty tensor; do nothing
return true;
}
const dim3 block = getApplyBlock(max_threads_per_block);
dim3 grid;
auto curDevice = current_device();
if (curDevice == -1) return false;
if (!getApplyGrid<step>(totalElements, grid, curDevice, max_threads_per_block)) {
return false;
}
/*
Expands readable/writable tensors whose indices may be "overlapped."
This ensures that each element of the tensor is operated on once and only
once.
*/
TensorBase oldA;
TensorBase oldB;
if (aType == TensorArgType::ReadWrite && detail::maybeOverlappingIndices(a)) {
// Must perform in contiguous space
oldA = std::exchange(a, a.contiguous());
}
if (bType == TensorArgType::ReadWrite && detail::maybeOverlappingIndices(b)) {
// Must perform in contiguous space
oldB = std::exchange(b, b.contiguous());
}
// It is possible that the tensor dimensions are able to be collapsed,
// and thus we can reduce the actual code complexity of the copy by
// exploiting this knowledge statically, since the div/mod is the
// most expensive part of the operation, more so than memory accesses.
// For instance, when copying a non-contiguous to a contiguous tensor
// (or vice versa), the contiguous tensor can be collapsed to one
// dimension, and the loop to translate the linear index to the array
// index can be similarly collapsed. That is what this unrolling is for.
#define HANDLE_CASE(TYPE, A, B) \
kernelPointwiseApply2<Op, \
scalar1, \
scalar2, \
TYPE, A, B, step, \
max_threads_per_block, \
min_blocks_per_sm> \
<<<grid, block, 0, at::cuda::getCurrentCUDAStream(curDevice)>>>( \
aInfo, bInfo, static_cast<TYPE>(totalElements), op); \
C10_CUDA_KERNEL_LAUNCH_CHECK();
#define HANDLE_B_CASE(TYPE, A, B) { \
switch (B) { \
case 1: \
HANDLE_CASE(TYPE, A, 1); \
break; \
case 2: \
HANDLE_CASE(TYPE, A, 2); \
break; \
default: \
HANDLE_CASE(TYPE, A, -1); \
break; \
} \
}
#define HANDLE_A_CASE(TYPE, A, B) { \
switch (A) { \
case 1: \
HANDLE_B_CASE(TYPE, 1, B); \
break; \
case 2: \
HANDLE_B_CASE(TYPE, 2, B); \
break; \
default: \
HANDLE_B_CASE(TYPE, -1, B); \
break; \
} \
}
if (detail::canUse32BitIndexMath(a) &&
detail::canUse32BitIndexMath(b)) {
detail::TensorInfo<scalar1, unsigned int> aInfo =
detail::getTensorInfo<scalar1, unsigned int>(a);
detail::TensorInfo<scalar2, unsigned int> bInfo =
detail::getTensorInfo<scalar2, unsigned int>(b);
rearrangeDims(&aInfo, &bInfo);
aInfo.collapseDims();
bInfo.collapseDims();
HANDLE_A_CASE(unsigned int, aInfo.dims, bInfo.dims);
} else {
detail::TensorInfo<scalar1, uint64_t> aInfo =
detail::getTensorInfo<scalar1, uint64_t>(a);
detail::TensorInfo<scalar2, uint64_t> bInfo =
detail::getTensorInfo<scalar2, uint64_t>(b);
rearrangeDims(&aInfo, &bInfo);
aInfo.collapseDims();
bInfo.collapseDims();
/*
Only instantiates the all 1D special case and the fallback all nD case for
large (64-bit indexed) tensors to reduce compilation time.
*/
if (aInfo.dims == 1 && bInfo.dims == 1) {
HANDLE_CASE(uint64_t, 1, 1);
} else {
HANDLE_CASE(uint64_t, -1, -1);
}
}
#undef HANDLE_CASE
#undef HANDLE_B_CASE
#undef HANDLE_A_CASE
if (oldA.defined()) {
at::native::copy_ignoring_overlaps(oldA, a);
}
if (oldB.defined()) {
at::native::copy_ignoring_overlaps(oldB, b);
}
return true;
}
/* Provides default step = 1 to CUDA_tensor_apply2. */
template <typename scalar1, typename scalar2, typename Op,
int max_threads_per_block=AT_APPLY_THREADS_PER_BLOCK,
int min_blocks_per_sm=AT_APPLY_BLOCKS_PER_SM>
inline bool CUDA_tensor_apply2(const at::TensorBase &a,
const at::TensorBase &b,
const Op op,
TensorArgType aType = TensorArgType::ReadWrite,
TensorArgType bType = TensorArgType::ReadOnly) {
return CUDA_tensor_apply2<scalar1, scalar2, 1, Op,
max_threads_per_block, min_blocks_per_sm>(a, b, op, aType, bType);
}
} // namespace at::cuda
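// A minimal usage sketch of CUDA_tensor_apply2 above: compute a = b * b
// elementwise for two same-shaped float CUDA tensors (square_into_example is a
// hypothetical wrapper used only for illustration; it relies on nvcc's
// extended __device__ lambdas).
inline bool square_into_example(const at::TensorBase& a, const at::TensorBase& b) {
  return at::cuda::CUDA_tensor_apply2<float, float>(
      a, b,
      [] __device__ (float& a_val, const float& b_val) { a_val = b_val * b_val; });
}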

@@ -0,0 +1,358 @@
#pragma once
/*
Provides a subset of CUDA BLAS functions as templates:
gemm<Dtype>(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c,
ldc)
gemv<Dtype>(transa, m, n, alpha, a, lda, x, incx, beta, y, incy)
dot<Dtype>(n, x, incx, y, incy, result)
where Dtype is double, float, at::Half or at::BFloat16 (ROCm, NOT for dot).
The functions are available in at::cuda::blas namespace.
*/
#include <ATen/cuda/CUDAContext.h>
#include <ATen/OpMathType.h>
namespace at::cuda::blas {
// RAII guard that sets the CuBLAS pointer mode and restores it to
// its previous value when the guard is destroyed
class PointerModeGuard {
public:
PointerModeGuard(cublasHandle_t handle, cublasPointerMode_t mode) :
handle(handle) {
TORCH_CUDABLAS_CHECK(cublasGetPointerMode(handle, &previous_mode));
TORCH_CUDABLAS_CHECK(cublasSetPointerMode(handle, mode));
}
~PointerModeGuard() {
cublasSetPointerMode(handle, previous_mode);
}
private:
cublasHandle_t handle;
cublasPointerMode_t previous_mode;
};
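// Usage sketch for the guard above (the surrounding calls are placeholders):
// temporarily switch a cuBLAS handle to device pointer mode for calls whose
// alpha/beta live in GPU memory, restoring the previous mode on scope exit.
//
//   cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
//   {
//     PointerModeGuard guard(handle, CUBLAS_POINTER_MODE_DEVICE);
//     // ... cuBLAS calls that read alpha/beta from device memory ...
//   }  // previous pointer mode restored here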
/* LEVEL 3 BLAS FUNCTIONS */
#define CUDABLAS_GEMM_ARGTYPES(Dtype) \
char transa, char transb, int64_t m, int64_t n, int64_t k, at::opmath_type<Dtype> alpha, \
const Dtype *a, int64_t lda, const Dtype *b, int64_t ldb, at::opmath_type<Dtype> beta,\
Dtype *c, int64_t ldc
#define CUDABLAS_GEMM_ARGS(Dtype) transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc
template <typename Dtype>
inline void gemm(CUDABLAS_GEMM_ARGTYPES(Dtype)) {
static_assert(false&&sizeof(Dtype),"at::cuda::blas::gemm: not implemented");
}
template <>
void gemm<double>(CUDABLAS_GEMM_ARGTYPES(double));
template <>
void gemm<float>(CUDABLAS_GEMM_ARGTYPES(float));
template <>
void gemm<c10::complex<double>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<double>));
template <>
void gemm<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<float>));
template <>
void gemm<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half));
template <>
void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16));
template <typename Dtype>
inline void gemm_internal(CUDABLAS_GEMM_ARGTYPES(Dtype)) {
static_assert(false&&sizeof(Dtype),"at::cuda::blas::gemm_internal: not implemented");
}
template <>
void gemm_internal<double>(CUDABLAS_GEMM_ARGTYPES(double));
template <>
void gemm_internal<float>(CUDABLAS_GEMM_ARGTYPES(float));
template <>
void gemm_internal<c10::complex<double>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<double>));
template <>
void gemm_internal<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<float>));
template <>
void gemm_internal<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half));
template <>
void gemm_internal<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16));
enum GEMMAndBiasActivationEpilogue {
None,
RELU,
GELU,
};
// NOTE: GELU activation is not supported prior to CUDA 11.4 and will
// do nothing if passed in that case.
template <typename Dtype>
void gemm_and_bias(
bool transpose_mat1,
bool transpose_mat2,
int64_t m,
int64_t n,
int64_t k,
at::opmath_type<Dtype> alpha_val,
const Dtype* mat1_ptr,
int64_t mat1_ld,
const Dtype* mat2_ptr,
int64_t mat2_ld,
const Dtype* bias,
Dtype* result_ptr,
int64_t result_ld,
GEMMAndBiasActivationEpilogue activation = GEMMAndBiasActivationEpilogue::None);
void int8_gemm(
bool transpose_mat1,
bool transpose_mat2,
int64_t m,
int64_t n,
int64_t k,
const int8_t* mat1_ptr,
int64_t mat1_ld,
const int8_t* mat2_ptr,
int64_t mat2_ld,
int32_t* result_ptr,
int64_t result_ld);
void scaled_gemm(
char transa,
char transb,
int64_t m,
int64_t n,
int64_t k,
const void* mat1_ptr,
const void* mat1_scale_ptr,
int64_t mat1_ld,
ScalarType mat1_dtype,
const void* mat2_ptr,
const void* mat2_scale_ptr,
int64_t mat2_ld,
ScalarType mat2_dtype,
const void* bias_ptr,
ScalarType bias_dtype,
void* result_ptr,
const void* result_scale_ptr,
int64_t result_ld,
ScalarType result_dtype,
void* amax_ptr,
bool use_fast_accum);
#define CUDABLAS_BGEMM_ARGTYPES(Dtype) \
char transa, char transb, int64_t m, int64_t n, int64_t k, at::opmath_type<Dtype> alpha, \
const Dtype *a, int64_t lda, int64_t stridea, \
const Dtype *b, int64_t ldb, int64_t strideb, \
at::opmath_type<Dtype> beta, Dtype *c, int64_t ldc, int64_t stridec, int64_t num_batches
#define CUDABLAS_BGEMM_ARGS(Dtype) \
transa, transb, m, n, k, alpha, a, lda, stridea, b, ldb, strideb, beta, c, ldc, stridec, num_batches
template <typename Dtype>
inline void bgemm(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
static_assert(false&&sizeof(Dtype),"at::cuda::blas::bgemm: not implemented");
}
template <>
void bgemm<double>(CUDABLAS_BGEMM_ARGTYPES(double));
template <>
void bgemm<float>(CUDABLAS_BGEMM_ARGTYPES(float));
template <>
void bgemm<c10::complex<double>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<double>));
template <>
void bgemm<c10::complex<float>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<float>));
template <>
void bgemm<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half));
template <>
void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16));
template <typename Dtype>
inline void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
static_assert(false&&sizeof(Dtype),"at::cuda::blas::bgemm_internal: not implemented");
}
template <>
void bgemm_internal<double>(CUDABLAS_BGEMM_ARGTYPES(double));
template <>
void bgemm_internal<float>(CUDABLAS_BGEMM_ARGTYPES(float));
template <>
void bgemm_internal<c10::complex<double>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<double>));
template <>
void bgemm_internal<c10::complex<float>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<float>));
template <>
void bgemm_internal<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half));
template <>
void bgemm_internal<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16));
#define CUDABLAS_TRSM_ARGTYPES(Dtype) \
cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, \
cublasOperation_t trans, cublasDiagType_t diag, int m, int n, \
const Dtype *alpha, const Dtype *A, int lda, Dtype *B, int ldb
template <typename Dtype>
inline void trsm(CUDABLAS_TRSM_ARGTYPES(Dtype)) {
static_assert(false&&sizeof(Dtype), "at::cuda::blas::trsm: not implemented");
}
template <>
TORCH_CUDA_CU_API void trsm<float>(CUDABLAS_TRSM_ARGTYPES(float));
template <>
TORCH_CUDA_CU_API void trsm<double>(CUDABLAS_TRSM_ARGTYPES(double));
template <>
TORCH_CUDA_CU_API void trsm<c10::complex<float>>(CUDABLAS_TRSM_ARGTYPES(c10::complex<float>));
template <>
TORCH_CUDA_CU_API void trsm<c10::complex<double>>(CUDABLAS_TRSM_ARGTYPES(c10::complex<double>));
#define CUDABLAS_TRSM_BATCHED_ARGTYPES(Dtype) \
cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, \
cublasOperation_t trans, cublasDiagType_t diag, int m, int n, \
const Dtype *alpha, Dtype *A[], int lda, Dtype *B[], int ldb, \
int batchCount
template <typename Dtype>
inline void trsmBatched(CUDABLAS_TRSM_BATCHED_ARGTYPES(Dtype)) {
static_assert(false&&sizeof(Dtype), "at::cuda::blas::trsmBatched: not implemented");
}
template <>
TORCH_CUDA_CU_API void trsmBatched<float>(CUDABLAS_TRSM_BATCHED_ARGTYPES(float));
template <>
TORCH_CUDA_CU_API void trsmBatched<double>(CUDABLAS_TRSM_BATCHED_ARGTYPES(double));
template <>
TORCH_CUDA_CU_API void trsmBatched<c10::complex<float>>(CUDABLAS_TRSM_BATCHED_ARGTYPES(c10::complex<float>));
template <>
TORCH_CUDA_CU_API void trsmBatched<c10::complex<double>>(CUDABLAS_TRSM_BATCHED_ARGTYPES(c10::complex<double>));
/* LEVEL 2 BLAS FUNCTIONS */
#define CUDABLAS_GEMV_ARGTYPES(Dtype) \
char trans, int64_t m, int64_t n, Dtype alpha, const Dtype *a, int64_t lda, \
const Dtype *x, int64_t incx, Dtype beta, Dtype *y, int64_t incy
template <typename Dtype>
inline void gemv(CUDABLAS_GEMV_ARGTYPES(Dtype)) {
static_assert(false&&sizeof(Dtype), "at::cuda::blas::gemv: not implemented");
}
template <>
void gemv<double>(CUDABLAS_GEMV_ARGTYPES(double));
template <>
void gemv<float>(CUDABLAS_GEMV_ARGTYPES(float));
template <>
void gemv<c10::complex<double>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<double>));
template <>
void gemv<c10::complex<float>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<float>));
template <>
void gemv<at::Half>(CUDABLAS_GEMV_ARGTYPES(at::Half));
template <>
void gemv<at::BFloat16>(CUDABLAS_GEMV_ARGTYPES(at::BFloat16));
/* LEVEL 1 BLAS FUNCTIONS */
#define CUDABLAS_DOT_ARGTYPES(Dtype) \
cublasHandle_t handle, int n, const Dtype *x, int incx, const Dtype *y, \
int incy, Dtype *result
template <typename Dtype>
inline void dot(CUDABLAS_DOT_ARGTYPES(Dtype)) {
static_assert(false&&sizeof(Dtype),"at::cuda::blas::dot: not implemented");
}
template <>
void dot<double>(CUDABLAS_DOT_ARGTYPES(double));
template <>
void dot<float>(CUDABLAS_DOT_ARGTYPES(float));
template <>
void dot<at::Half>(CUDABLAS_DOT_ARGTYPES(at::Half));
template <>
void dot<at::BFloat16>(CUDABLAS_DOT_ARGTYPES(at::BFloat16));
template <>
void dot<c10::complex<double>>(CUDABLAS_DOT_ARGTYPES(c10::complex<double>));
template <>
void dot<c10::complex<float>>(CUDABLAS_DOT_ARGTYPES(c10::complex<float>));
template <typename Dtype>
inline void vdot(CUDABLAS_DOT_ARGTYPES(Dtype)) {
static_assert(false&&sizeof(Dtype),"at::cuda::blas::vdot: not implemented");
}
template <>
void vdot<c10::complex<float>>(CUDABLAS_DOT_ARGTYPES(c10::complex<float>));
template <>
void vdot<c10::complex<double>>(CUDABLAS_DOT_ARGTYPES(c10::complex<double>));
#define CUDABLAS_GETRS_ARGTYPES(Dtype) \
cublasHandle_t handle, cublasOperation_t trans, \
int n, int nrhs, Dtype** dA_array, int lda, int* ipiv_array, \
Dtype** dB_array, int ldb, int* info_array, int batchsize
template<class Dtype>
void getrsBatched(CUDABLAS_GETRS_ARGTYPES(Dtype)) {
static_assert(false&&sizeof(Dtype),"at::cuda::blas::getrsBatched: not implemented");
}
template<>
TORCH_CUDA_CU_API void getrsBatched<float>(CUDABLAS_GETRS_ARGTYPES(float));
template<>
TORCH_CUDA_CU_API void getrsBatched<double>(CUDABLAS_GETRS_ARGTYPES(double));
template<>
TORCH_CUDA_CU_API void getrsBatched<c10::complex<float>>(CUDABLAS_GETRS_ARGTYPES(c10::complex<float>));
template<>
TORCH_CUDA_CU_API void getrsBatched<c10::complex<double>>(CUDABLAS_GETRS_ARGTYPES(c10::complex<double>));
#define CUDABLAS_GEQRF_BATCHED_ARGTYPES(Dtype) \
cublasHandle_t handle, int m, int n, Dtype **A_array, int lda, \
Dtype **tau_array, int *info, int batchsize
template <class Dtype>
void geqrfBatched(CUDABLAS_GEQRF_BATCHED_ARGTYPES(Dtype)) {
static_assert(false&&sizeof(Dtype), "at::cuda::blas::geqrfBatched: not implemented");
}
template <>
TORCH_CUDA_CU_API void geqrfBatched<float>(CUDABLAS_GEQRF_BATCHED_ARGTYPES(float));
template <>
TORCH_CUDA_CU_API void geqrfBatched<double>(CUDABLAS_GEQRF_BATCHED_ARGTYPES(double));
template <>
TORCH_CUDA_CU_API void geqrfBatched<c10::complex<double>>(
CUDABLAS_GEQRF_BATCHED_ARGTYPES(c10::complex<double>));
template <>
TORCH_CUDA_CU_API void geqrfBatched<c10::complex<float>>(
CUDABLAS_GEQRF_BATCHED_ARGTYPES(c10::complex<float>));
#define CUDABLAS_GETRF_ARGTYPES(Dtype) \
int n, Dtype** dA_array, int ldda, int* ipiv_array, int* info_array, int batchsize
template<class Dtype>
void getrfBatched(CUDABLAS_GETRF_ARGTYPES(Dtype)) {
TORCH_CHECK(false, "at::cuda::blas::getrfBatched: not implemented");
}
template<>
TORCH_CUDA_CU_API void getrfBatched<float>(CUDABLAS_GETRF_ARGTYPES(float));
template<>
TORCH_CUDA_CU_API void getrfBatched<double>(CUDABLAS_GETRF_ARGTYPES(double));
template<>
TORCH_CUDA_CU_API void getrfBatched<c10::complex<double>>(CUDABLAS_GETRF_ARGTYPES(c10::complex<double>));
template<>
TORCH_CUDA_CU_API void getrfBatched<c10::complex<float>>(CUDABLAS_GETRF_ARGTYPES(c10::complex<float>));
#define CUDABLAS_GELS_BATCHED_ARGTYPES(Dtype) \
cublasHandle_t handle, cublasOperation_t trans, int m, int n, int nrhs, Dtype** dA_array, int ldda, Dtype** dC_array, int lddc, int* info, int *devInfoArray, int batchSize
template <class Dtype>
void gelsBatched(CUDABLAS_GELS_BATCHED_ARGTYPES(Dtype)) {
static_assert(false&&sizeof(Dtype),"at::cuda::blas::gelsBatched: not implemented");
}
template<>
TORCH_CUDA_CU_API void gelsBatched<double>(CUDABLAS_GELS_BATCHED_ARGTYPES(double));
template<>
TORCH_CUDA_CU_API void gelsBatched<float>(CUDABLAS_GELS_BATCHED_ARGTYPES(float));
template<>
TORCH_CUDA_CU_API void gelsBatched<c10::complex<double>>(CUDABLAS_GELS_BATCHED_ARGTYPES(c10::complex<double>));
template<>
TORCH_CUDA_CU_API void gelsBatched<c10::complex<float>>(CUDABLAS_GELS_BATCHED_ARGTYPES(c10::complex<float>));
} // namespace at::cuda::blas
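// An illustrative call of the gemm template above: C = A * B for column-major
// float matrices already resident on the current device (gemm_example is a
// hypothetical wrapper; the pointers and sizes are assumed to describe valid
// device buffers).
inline void gemm_example(const float* A, const float* B, float* C,
                         int64_t m, int64_t n, int64_t k) {
  // 'n'/'n' means no transpose; leading dimensions follow the column-major
  // convention: lda = m, ldb = k, ldc = m.
  at::cuda::blas::gemm<float>(
      'n', 'n', m, n, k,
      /*alpha=*/1.0f, A, /*lda=*/m, B, /*ldb=*/k,
      /*beta=*/0.0f, C, /*ldc=*/m);
}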

@@ -0,0 +1,9 @@
#pragma once
#include <ATen/cuda/CUDAContextLight.h>
// Preserved for BC, as many files depend on these includes
#include <ATen/Context.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/util/Logging.h>
#include <ATen/cuda/Exceptions.h>

@@ -0,0 +1,99 @@
#pragma once
// Light-weight version of CUDAContext.h with fewer transitive includes
#include <cstdint>
#include <cuda_runtime_api.h>
#include <cusparse.h>
#include <cublas_v2.h>
// cublasLt was introduced in CUDA 10.1, but we enable it only for CUDA 11.1,
// which also added bf16 support
#include <cublasLt.h>
#ifdef CUDART_VERSION
#include <cusolverDn.h>
#endif
#if defined(USE_CUDSS)
#include <cudss.h>
#endif
#if defined(USE_ROCM)
#include <hipsolver/hipsolver.h>
#endif
#include <c10/core/Allocator.h>
#include <c10/cuda/CUDAFunctions.h>
namespace c10 {
struct Allocator;
}
namespace at::cuda {
/*
A common CUDA interface for ATen.
This interface is distinct from CUDAHooks, which defines an interface that links
to both CPU-only and CUDA builds. That interface is intended for runtime
dispatch and should be used from files that are included in both CPU-only and
CUDA builds.
CUDAContext, on the other hand, should be preferred by files only included in
CUDA builds. It is intended to expose CUDA functionality in a consistent
manner.
This means there is some overlap between the CUDAContext and CUDAHooks, but
the choice of which to use is simple: use CUDAContext when in a CUDA-only file,
use CUDAHooks otherwise.
Note that CUDAContext simply defines an interface with no associated class.
It is expected that the modules whose functions compose this interface will
manage their own state. There is only a single CUDA context/state.
*/
/**
* DEPRECATED: use device_count() instead
*/
inline int64_t getNumGPUs() {
return c10::cuda::device_count();
}
/**
* CUDA is available if we compiled with CUDA, and there are one or more
* devices. If we compiled with CUDA but there is a driver problem, etc.,
* this function will report CUDA is not available (rather than raise an error).
*/
inline bool is_available() {
return c10::cuda::device_count() > 0;
}
TORCH_CUDA_CPP_API cudaDeviceProp* getCurrentDeviceProperties();
TORCH_CUDA_CPP_API int warp_size();
TORCH_CUDA_CPP_API cudaDeviceProp* getDeviceProperties(c10::DeviceIndex device);
TORCH_CUDA_CPP_API bool canDeviceAccessPeer(
c10::DeviceIndex device,
c10::DeviceIndex peer_device);
TORCH_CUDA_CPP_API c10::Allocator* getCUDADeviceAllocator();
/* Handles */
TORCH_CUDA_CPP_API cusparseHandle_t getCurrentCUDASparseHandle();
TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle();
TORCH_CUDA_CPP_API cublasLtHandle_t getCurrentCUDABlasLtHandle();
TORCH_CUDA_CPP_API void clearCublasWorkspaces();
#if defined(CUDART_VERSION) || defined(USE_ROCM)
TORCH_CUDA_CPP_API cusolverDnHandle_t getCurrentCUDASolverDnHandle();
#endif
#if defined(USE_CUDSS)
TORCH_CUDA_CPP_API cudssHandle_t getCurrentCudssHandle();
#endif
} // namespace at::cuda
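// A small sketch of the query pattern this header enables: size a persistent
// grid from the current device's SM count (max_resident_blocks_example is a
// hypothetical helper used only for illustration).
inline int max_resident_blocks_example(int blocks_per_sm) {
  cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
  return prop->multiProcessorCount * blocks_per_sm;
}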

@@ -0,0 +1,105 @@
#pragma once
#include <c10/core/ScalarType.h>
#include <cuda.h>
#include <library_types.h>
namespace at::cuda {
template <typename scalar_t>
cudaDataType getCudaDataType() {
static_assert(false && sizeof(scalar_t), "Cannot convert type to cudaDataType.");
return {};
}
template<> inline cudaDataType getCudaDataType<at::Half>() {
return CUDA_R_16F;
}
template<> inline cudaDataType getCudaDataType<float>() {
return CUDA_R_32F;
}
template<> inline cudaDataType getCudaDataType<double>() {
return CUDA_R_64F;
}
template<> inline cudaDataType getCudaDataType<c10::complex<c10::Half>>() {
return CUDA_C_16F;
}
template<> inline cudaDataType getCudaDataType<c10::complex<float>>() {
return CUDA_C_32F;
}
template<> inline cudaDataType getCudaDataType<c10::complex<double>>() {
return CUDA_C_64F;
}
template<> inline cudaDataType getCudaDataType<uint8_t>() {
return CUDA_R_8U;
}
template<> inline cudaDataType getCudaDataType<int8_t>() {
return CUDA_R_8I;
}
template<> inline cudaDataType getCudaDataType<int>() {
return CUDA_R_32I;
}
template<> inline cudaDataType getCudaDataType<int16_t>() {
return CUDA_R_16I;
}
template<> inline cudaDataType getCudaDataType<int64_t>() {
return CUDA_R_64I;
}
template<> inline cudaDataType getCudaDataType<at::BFloat16>() {
return CUDA_R_16BF;
}
inline cudaDataType ScalarTypeToCudaDataType(const c10::ScalarType& scalar_type) {
switch (scalar_type) {
case c10::ScalarType::Byte:
return CUDA_R_8U;
case c10::ScalarType::Char:
return CUDA_R_8I;
case c10::ScalarType::Int:
return CUDA_R_32I;
case c10::ScalarType::Half:
return CUDA_R_16F;
case c10::ScalarType::Float:
return CUDA_R_32F;
case c10::ScalarType::Double:
return CUDA_R_64F;
case c10::ScalarType::ComplexHalf:
return CUDA_C_16F;
case c10::ScalarType::ComplexFloat:
return CUDA_C_32F;
case c10::ScalarType::ComplexDouble:
return CUDA_C_64F;
case c10::ScalarType::Short:
return CUDA_R_16I;
case c10::ScalarType::Long:
return CUDA_R_64I;
case c10::ScalarType::BFloat16:
return CUDA_R_16BF;
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11080
case c10::ScalarType::Float8_e4m3fn:
return CUDA_R_8F_E4M3;
case c10::ScalarType::Float8_e5m2:
return CUDA_R_8F_E5M2;
#endif
#if defined(USE_ROCM)
#if defined(HIP_NEW_TYPE_ENUMS)
case c10::ScalarType::Float8_e4m3fnuz:
return HIP_R_8F_E4M3_FNUZ;
case c10::ScalarType::Float8_e5m2fnuz:
return HIP_R_8F_E5M2_FNUZ;
#else
case c10::ScalarType::Float8_e4m3fnuz:
return static_cast<hipDataType>(1000);
case c10::ScalarType::Float8_e5m2fnuz:
return static_cast<hipDataType>(1001);
#endif
#endif
default:
TORCH_INTERNAL_ASSERT(false, "Cannot convert ScalarType ", scalar_type, " to cudaDataType.")
}
}
} // namespace at::cuda
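// A minimal sketch relating the two mappings above: for a dtype covered by
// both paths they agree, e.g. float maps to CUDA_R_32F either way
// (float_mapping_agrees_example is a hypothetical helper for illustration).
inline bool float_mapping_agrees_example() {
  return at::cuda::getCudaDataType<float>() ==
      at::cuda::ScalarTypeToCudaDataType(c10::ScalarType::Float);
}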

@@ -0,0 +1,23 @@
#pragma once
#include <ATen/cuda/Exceptions.h>
#include <cuda.h>
#include <cuda_runtime.h>
namespace at::cuda {
inline Device getDeviceFromPtr(void* ptr) {
cudaPointerAttributes attr{};
AT_CUDA_CHECK(cudaPointerGetAttributes(&attr, ptr));
#if !defined(USE_ROCM)
TORCH_CHECK(attr.type != cudaMemoryTypeUnregistered,
"The specified pointer resides on host memory and is not registered with any CUDA device.");
#endif
return {c10::DeviceType::CUDA, static_cast<DeviceIndex>(attr.device)};
}
} // namespace at::cuda
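// Usage sketch: recover the at::Device that owns a fresh CUDA allocation
// (device_of_fresh_allocation_example is a hypothetical helper used only for
// illustration; it frees the temporary buffer before returning).
inline at::Device device_of_fresh_allocation_example(size_t bytes) {
  void* p = nullptr;
  AT_CUDA_CHECK(cudaMalloc(&p, bytes));
  at::Device dev = at::cuda::getDeviceFromPtr(p);  // reports the allocating device
  AT_CUDA_CHECK(cudaFree(p));
  return dev;
}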

@@ -0,0 +1,211 @@
#pragma once
#include <ATen/cuda/ATenCUDAGeneral.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/core/impl/GPUTrace.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/Exceptions.h>
#include <c10/util/Exception.h>
#include <cuda_runtime_api.h>
#include <cstdint>
#include <utility>
namespace at::cuda {
/*
* CUDAEvents are movable but not copyable wrappers around CUDA's events.
*
* A CUDAEvent is constructed lazily when first recorded, unless it is
* reconstructed from a cudaIpcEventHandle_t. The event has a device, and this
* device is acquired from the first recording stream. However, if reconstructed
* from a handle, the device must be specified explicitly, and if ipc_handle() is
* called before the event is ever recorded, the current device is used.
* Later streams that record the event must match this device.
*/
struct TORCH_CUDA_CPP_API CUDAEvent {
// Constructors
// Default value for `flags` is specified below - it's cudaEventDisableTiming
CUDAEvent() noexcept = default;
CUDAEvent(unsigned int flags) noexcept : flags_{flags} {}
CUDAEvent(
DeviceIndex device_index, const cudaIpcEventHandle_t* handle) : device_index_(device_index) {
CUDAGuard guard(device_index_);
AT_CUDA_CHECK(cudaIpcOpenEventHandle(&event_, *handle));
is_created_ = true;
}
// Note: event destruction done on creating device to avoid creating a
// CUDA context on other devices.
~CUDAEvent() {
try {
if (is_created_) {
CUDAGuard guard(device_index_);
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_deletion(at::kCUDA, reinterpret_cast<uintptr_t>(event_));
}
AT_CUDA_CHECK(cudaEventDestroy(event_));
}
} catch (...) { /* No throw */ }
}
CUDAEvent(const CUDAEvent&) = delete;
CUDAEvent& operator=(const CUDAEvent&) = delete;
CUDAEvent(CUDAEvent&& other) noexcept { moveHelper(std::move(other)); }
CUDAEvent& operator=(CUDAEvent&& other) noexcept {
if (this != &other) {
moveHelper(std::move(other));
}
return *this;
}
operator cudaEvent_t() const { return event(); }
// Less than operator (to allow use in sets)
friend bool operator<(const CUDAEvent& left, const CUDAEvent& right) {
return left.event_ < right.event_;
}
std::optional<at::Device> device() const {
if (is_created_) {
return at::Device(at::kCUDA, device_index_);
} else {
return {};
}
}
bool isCreated() const { return is_created_; }
DeviceIndex device_index() const {return device_index_;}
cudaEvent_t event() const { return event_; }
// Note: cudaEventQuery can be safely called from any device
bool query() const {
if (!is_created_) {
return true;
}
cudaError_t err = cudaEventQuery(event_);
if (err == cudaSuccess) {
return true;
} else if (err != cudaErrorNotReady) {
C10_CUDA_CHECK(err);
} else {
// ignore and clear the error if not ready
(void)cudaGetLastError();
}
return false;
}
void record() { record(getCurrentCUDAStream()); }
void recordOnce(const CUDAStream& stream) {
if (!was_recorded_) record(stream);
}
// Note: cudaEventRecord must be called on the same device as the event.
void record(const CUDAStream& stream) {
if (!is_created_) {
createEvent(stream.device_index());
}
TORCH_CHECK(device_index_ == stream.device_index(), "Event device ", device_index_,
" does not match recording stream's device ", stream.device_index(), ".");
CUDAGuard guard(device_index_);
AT_CUDA_CHECK(cudaEventRecord(event_, stream));
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_record(at::kCUDA,
reinterpret_cast<uintptr_t>(event_),
reinterpret_cast<uintptr_t>(stream.stream())
);
}
was_recorded_ = true;
}
// Note: cudaStreamWaitEvent must be called on the same device as the stream.
// The event has no actual GPU resources associated with it.
void block(const CUDAStream& stream) {
if (is_created_) {
CUDAGuard guard(stream.device_index());
AT_CUDA_CHECK(cudaStreamWaitEvent(stream, event_, 0));
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_wait(at::kCUDA,
reinterpret_cast<uintptr_t>(event_),
reinterpret_cast<uintptr_t>(stream.stream())
);
}
}
}
// Note: cudaEventElapsedTime can be safely called from any device
float elapsed_time(const CUDAEvent& other) const {
TORCH_CHECK(is_created_ && other.isCreated(),
"Both events must be recorded before calculating elapsed time.");
float time_ms = 0;
// We do not strictly have to set the device index to the same as our event,
// but if we don't and the current device is not initialized, it will
// create a new cuda context, which will consume a lot of memory.
CUDAGuard guard(device_index_);
// raise cudaErrorNotReady if either event is recorded but not yet completed
AT_CUDA_CHECK(cudaEventElapsedTime(&time_ms, event_, other.event_));
return time_ms;
}
// Note: cudaEventSynchronize can be safely called from any device
void synchronize() const {
if (is_created_) {
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_synchronization(at::kCUDA, reinterpret_cast<uintptr_t>(event_));
}
AT_CUDA_CHECK(cudaEventSynchronize(event_));
}
}
// Note: cudaIpcGetEventHandle must be called on the same device as the event
void ipc_handle(cudaIpcEventHandle_t * handle) {
if (!is_created_) {
// this CUDAEvent object was initially constructed from flags but event_
// is not created yet.
createEvent(getCurrentCUDAStream().device_index());
}
CUDAGuard guard(device_index_);
AT_CUDA_CHECK(cudaIpcGetEventHandle(handle, event_));
}
private:
unsigned int flags_ = cudaEventDisableTiming;
bool is_created_ = false;
bool was_recorded_ = false;
DeviceIndex device_index_ = -1;
cudaEvent_t event_{};
void createEvent(DeviceIndex device_index) {
device_index_ = device_index;
CUDAGuard guard(device_index_);
AT_CUDA_CHECK(cudaEventCreateWithFlags(&event_, flags_));
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_creation(at::kCUDA, reinterpret_cast<uintptr_t>(event_));
}
is_created_ = true;
}
void moveHelper(CUDAEvent&& other) {
std::swap(flags_, other.flags_);
std::swap(is_created_, other.is_created_);
std::swap(was_recorded_, other.was_recorded_);
std::swap(device_index_, other.device_index_);
std::swap(event_, other.event_);
}
};
} // namespace at::cuda
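
A hedged sketch of typical usage (not part of the header): record on a stream, synchronize, then measure elapsed time. Timing requires events created without cudaEventDisableTiming, so the default flag is overridden here; the timed work is illustrative only.

inline float time_section_ms() {
  auto stream = at::cuda::getCurrentCUDAStream();
  at::cuda::CUDAEvent start(cudaEventDefault);  // timing enabled
  at::cuda::CUDAEvent stop(cudaEventDefault);
  start.record(stream);
  // ... enqueue the work to be timed on `stream` ...
  stop.record(stream);
  stop.synchronize();                // wait until the recorded work has finished
  return start.elapsed_time(stop);   // milliseconds between the two records
}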

View File

@ -0,0 +1,181 @@
#pragma once
#include <ATen/Context.h>
#include <ATen/core/Generator.h>
#include <ATen/core/TensorBase.h>
#include <ATen/cuda/PhiloxCudaState.h>
#include <atomic>
#include <limits>
#include <memory>
#include <unordered_set>
namespace at {
namespace cuda {
struct CUDAGraph;
}
/**
* Note [CUDA Graph-safe RNG states]
* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
*
* Strategy:
* ~~~~~~~~~
* (It helps to look at
* cuda/detail/PhiloxCudaStateRaw.cuh and
* cuda/detail/UnpackRaw.cuh
* while you read this.)
*
* A CUDA graph containing multiple RNG ops behaves like a
* single giant kernel from the perspective of ops external
* to the graph. During graph capture, logic in CUDAGeneratorImpl
* records the total of all offset increments that occur in the
* graphed region, and records the final total as the offset for
* the entire graph.
*
* When the graph reruns, the logic that reruns it
* increments this device's CUDA generator's offset
* by that total.
*
* Meanwhile, within the graph, at capture time, instead of
* populating PhiloxCudaStates with the uint64_t offset pulled
* directly from the global state, PhiloxCudaState uses a pointer
* to a one-element stream-local int64_t device tensor
* holding an initial offset value, and a uint64_t holding an
* intra-graph offset. (The intra-graph offset starts from zero
* when capture begins.) In each consumer kernel,
* at::cuda::philox::unpack computes the offset to use for this kernel
* as intra-graph offset + *initial offset.
*
* When the graph reruns, the logic that reruns it first
* fill_s the initial offset tensor with this device's
* CUDA generator's current offset.
*
* The control flow above ensures graphed execution is bitwise
* identical to eager execution as long as RNG ops are enqueued
* from a single thread, even if RNG ops and graphs containing
* RNG ops are enqueued and run simultaneously on multiple streams.
*
* Usage:
* ~~~~~~
* PhiloxCudaState in this file, and unpack() in
* cuda/CUDAGraphsUtils.cuh allow non-divergent use of
* CUDAGeneratorImpl whether graph capture is underway or not.
*
* Each PhiloxCudaState instance should be used for one and only one
* consumer kernel.
*
* Example (see e.g. native/cuda/Dropout.cu):
*
* #include <ATen/cuda/CUDAGeneratorImpl.h>
* #include <ATen/cuda/CUDAGraphsUtils.cuh>
*
* __global__ void kernel(..., PhiloxCudaState philox_args) {
* auto seeds = at::cuda::philox::unpack(philox_args);
* IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
* curandStatePhilox4_32_10_t state;
* curand_init(std::get<0>(seeds), // seed
* idx, // per-thread subsequence
* std::get<1>(seeds), // offset in subsequence
* &state);
* ...
* }
*
* host_caller(...) {
* PhiloxCudaState rng_engine_inputs;
* {
* // See Note [Acquire lock when using random generators]
* std::lock_guard<std::mutex> lock(gen->mutex_);
*
* // gen could be HostState or DevState here! No divergent code needed!
* rng_engine_inputs = gen->philox_cuda_state(offset_increment);
* }
* kernel<<<...>>>(..., rng_engine_inputs);
* }
*
*/
struct CUDAGeneratorState : public c10::intrusive_ptr_target {
uint64_t seed_;
uint64_t philox_offset_per_thread_;
uint32_t offset_intragraph_;
bool capturing_{};
std::unordered_set<cuda::CUDAGraph*> registered_graphs_;
at::TensorBase seed_extragraph_{};
at::TensorBase offset_extragraph_{};
CUDAGeneratorState(
uint64_t seed = default_rng_seed_val,
uint64_t philox_offset_per_thread = 0,
uint32_t offset_intragraph = 0)
: seed_(seed),
philox_offset_per_thread_(philox_offset_per_thread),
offset_intragraph_(offset_intragraph) {}
void increase(uint64_t increment);
void register_graph(cuda::CUDAGraph* graph);
void unregister_graph(cuda::CUDAGraph* graph);
void capture_prologue();
// capture_epilogue returns the wholegraph_increment
uint64_t capture_epilogue();
void replay_prologue(uint64_t wholegraph_increment);
c10::intrusive_ptr<CUDAGeneratorState> clone();
};
struct TORCH_CUDA_CPP_API CUDAGeneratorImpl : public c10::GeneratorImpl {
// Constructors
CUDAGeneratorImpl(DeviceIndex device_index = -1);
CUDAGeneratorImpl(
DeviceIndex device_index,
c10::intrusive_ptr<CUDAGeneratorState> state_);
~CUDAGeneratorImpl() override = default;
// CUDAGeneratorImpl methods
std::shared_ptr<CUDAGeneratorImpl> clone() const;
void set_current_seed(uint64_t seed) override;
void set_offset(uint64_t offset) override;
uint64_t get_offset() const override;
uint64_t current_seed() const override;
uint64_t seed() override;
void set_state(const c10::TensorImpl& new_state) override;
c10::intrusive_ptr<c10::TensorImpl> get_state() const override;
void graphsafe_set_state(
const c10::intrusive_ptr<GeneratorImpl>& state) override;
c10::intrusive_ptr<c10::GeneratorImpl> graphsafe_get_state() const override;
void set_philox_offset_per_thread(uint64_t offset);
uint64_t philox_offset_per_thread() const;
void register_graph(cuda::CUDAGraph* graph);
void unregister_graph(cuda::CUDAGraph* graph);
// Generates a PhiloxCudaState with a specified increment, and increment
// current state
PhiloxCudaState philox_cuda_state(uint64_t increment);
bool reset_rnn_state() {
return !no_reset_rnn_state_.test_and_set();
}
// Temporarily accommodates call sites that use philox_engine_inputs.
// Allows incremental refactor of call sites to use philox_cuda_state.
std::pair<uint64_t, uint64_t> philox_engine_inputs(uint64_t increment);
static c10::DeviceType device_type();
private:
CUDAGeneratorImpl* clone_impl() const override;
c10::intrusive_ptr<CUDAGeneratorState> state_;
std::atomic_flag no_reset_rnn_state_;
};
namespace cuda::detail {
TORCH_CUDA_CPP_API const Generator& getDefaultCUDAGenerator(
DeviceIndex device_index = -1);
TORCH_CUDA_CPP_API Generator createCUDAGenerator(DeviceIndex device_index = -1);
} // namespace cuda::detail
} // namespace at

View File

@ -0,0 +1,89 @@
#pragma once
#include <ATen/Tensor.h>
#include <c10/core/Device.h>
#include <c10/cuda/CUDAGraphsC10Utils.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/util/flat_hash_map.h>
namespace at {
struct Generator;
struct CUDAGeneratorImpl;
struct CUDAGeneratorState;
namespace cuda {
// Standalone way to get a unique mempool id usable as a pool=... argument
// to CUDAGraph::capture_begin
TORCH_CUDA_CPP_API MempoolId_t graph_pool_handle();
struct TORCH_CUDA_CPP_API CUDAGraph {
CUDAGraph();
~CUDAGraph();
static void inc_pending_event_queries();
static void dec_pending_event_queries();
static int num_pending_event_queries();
// See Note [Explicit Registration of Generators to the CUDA Graph]
void register_generator_state(c10::intrusive_ptr<at::CUDAGeneratorState> state);
void register_generator_state(const at::Generator& generator);
void capture_begin(
MempoolId_t pool = {0, 0},
cudaStreamCaptureMode capture_mode = cudaStreamCaptureModeGlobal);
void capture_end();
void replay();
void reset();
MempoolId_t pool();
void enable_debug_mode();
void debug_dump(const std::string& debug_path);
protected:
cudaGraph_t graph_ = nullptr;
cudaGraphExec_t graph_exec_ = nullptr;
static std::atomic<int> pending_event_queries;
// internal states so reset() can do its best cleaning up
// Set to true in capture_end if cudaStreamEndCapture succeeded
// Set back to false soon after, when graph_ is consumed by cudaGraphInstantiate
// to create graph_exec_, then graph_ is deleted
bool has_graph_ = false;
// Set to true in capture_end if cudaGraphInstantiate succeeded
bool has_graph_exec_ = false;
// the ID assigned by cuda during graph capture,
// used to identify when a stream is participating in capture
CaptureId_t capture_id_ = -1;
// uuid used to request a particular private mempool from CUDACachingAllocator.
// By default, this will be set to {id_, 0}.
//
// If capture_begin is called with "pool=other_graph.pool()", this graph's mempool_id_
// will be set to the other graph's mempool_id_, and therefore share a mempool with the
// other graph.
//
// If capture_begin is called with "pool=handle" where "handle" came from graph_pool_handle(),
// it will share a mempool with any other captures that used "pool=handle".
//
// Sharing a mempool across graphs saves memory, and it's safe if you
// know you'll replay those graphs in the same order you captured them.
MempoolId_t mempool_id_;
// Stream on which capture began
at::cuda::CUDAStream capture_stream_;
// multiple generator states and their wholegraph_increments in this graph
// that are managed by the CUDA Graph
ska::flat_hash_map<c10::intrusive_ptr<at::CUDAGeneratorState>, uint64_t>
captured_generator_states_;
// Device where capture occurred. Right now, for simplicity, we require all ops
// in a capture to run on the same device, but this is a limitation of CUDAGraph,
// not CUDA itself. We can straightforwardly modify CUDAGraph to support multi-device
// captures if needed.
int capture_dev_;
};
} // namespace cuda
} // namespace at
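
A hedged sketch of capture and replay (not part of the header): the side stream, the missing warmup, and the captured work are illustrative only.

#include <c10/cuda/CUDAGuard.h>

inline void capture_and_replay() {
  // Capture must not run on the default stream; use a pool stream instead.
  c10::cuda::CUDAStream side_stream = c10::cuda::getStreamFromPool();
  c10::cuda::CUDAStreamGuard guard(side_stream);
  // (Real callers usually run a warmup iteration on the side stream first.)
  at::cuda::CUDAGraph graph;
  graph.capture_begin();
  // ... launch the work to be captured on the current stream ...
  graph.capture_end();
  graph.replay();  // re-launches everything recorded between begin and end
}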

View File

@ -0,0 +1,53 @@
#pragma once
#include <ATen/cuda/CUDAGeneratorImpl.h>
#include <ATen/cuda/CUDAEvent.h>
#include <ATen/cuda/PhiloxUtils.cuh>
#include <ATen/cuda/detail/CUDAHooks.h>
#include <ATen/detail/CUDAHooksInterface.h>
#include <c10/core/StreamGuard.h>
#include <c10/cuda/CUDAGraphsC10Utils.h>
#include <c10/cuda/CUDAGuard.h>
// c10/cuda/CUDAGraphsC10Utils.h has utils used by both c10 and aten.
// This file adds utils used by aten only.
namespace at::cuda {
using CaptureId_t = c10::cuda::CaptureId_t;
using CaptureStatus = c10::cuda::CaptureStatus;
// Use this version where you don't want to create a CUDA context if none exists.
inline CaptureStatus currentStreamCaptureStatus() {
// don't create a context if we don't have to
if (c10::cuda::hasPrimaryContext(c10::cuda::current_device())) {
return c10::cuda::currentStreamCaptureStatusMayInitCtx();
} else {
return CaptureStatus::None;
}
}
inline void assertNotCapturing(const std::string& attempt) {
auto status = currentStreamCaptureStatus();
TORCH_CHECK(status == CaptureStatus::None,
attempt,
" during CUDA graph capture. If you need this call to be captured, "
"please file an issue. "
"Current cudaStreamCaptureStatus: ",
status);
}
inline void errorIfCapturingCudnnBenchmark(const std::string& version_specific) {
auto status = currentStreamCaptureStatus();
TORCH_CHECK(status == CaptureStatus::None,
"Current cudaStreamCaptureStatus: ",
status,
"\nCapturing ",
version_specific,
"is prohibited. Possible causes of this error:\n"
"1. No warmup iterations occurred before capture.\n"
"2. The convolutions you're trying to capture use dynamic shapes, "
"in which case capturing them is generally prohibited.");
}
} // namespace at::cuda
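
A hedged usage sketch (the function and its body are hypothetical):

inline void clear_workspace_cache() {
  // Freeing cached device buffers would invalidate a graph being captured,
  // so refuse to run while the current stream is capturing.
  at::cuda::assertNotCapturing("Clearing the workspace cache");
  // ... release cached allocations here ...
}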

View File

@ -0,0 +1,75 @@
#pragma once
#include <ATen/cuda/CUDAContext.h>
#if defined(USE_ROCM)
#include <hipsparse/hipsparse-version.h>
#define HIPSPARSE_VERSION ((hipsparseVersionMajor*100000) + (hipsparseVersionMinor*100) + hipsparseVersionPatch)
#endif
// cuSparse Generic API added in CUDA 10.1
// Windows support added in CUDA 11.0
#if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && ((CUSPARSE_VERSION >= 10300) || (CUSPARSE_VERSION >= 11000 && defined(_WIN32)))
#define AT_USE_CUSPARSE_GENERIC_API() 1
#else
#define AT_USE_CUSPARSE_GENERIC_API() 0
#endif
// cuSparse Generic API descriptor pointers were changed to const in CUDA 12.0
#if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && \
(CUSPARSE_VERSION < 12000)
#define AT_USE_CUSPARSE_NON_CONST_DESCRIPTORS() 1
#else
#define AT_USE_CUSPARSE_NON_CONST_DESCRIPTORS() 0
#endif
#if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && \
(CUSPARSE_VERSION >= 12000)
#define AT_USE_CUSPARSE_CONST_DESCRIPTORS() 1
#else
#define AT_USE_CUSPARSE_CONST_DESCRIPTORS() 0
#endif
#if defined(USE_ROCM)
// hipSparse const API added in v2.4.0
#if HIPSPARSE_VERSION >= 200400
#define AT_USE_HIPSPARSE_CONST_DESCRIPTORS() 1
#define AT_USE_HIPSPARSE_NON_CONST_DESCRIPTORS() 0
#define AT_USE_HIPSPARSE_GENERIC_API() 1
#else
#define AT_USE_HIPSPARSE_CONST_DESCRIPTORS() 0
#define AT_USE_HIPSPARSE_NON_CONST_DESCRIPTORS() 1
#define AT_USE_HIPSPARSE_GENERIC_API() 1
#endif
#else // USE_ROCM
#define AT_USE_HIPSPARSE_CONST_DESCRIPTORS() 0
#define AT_USE_HIPSPARSE_NON_CONST_DESCRIPTORS() 0
#define AT_USE_HIPSPARSE_GENERIC_API() 0
#endif // USE_ROCM
// cuSparse Generic API spsv function was added in CUDA 11.3.0
#if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && (CUSPARSE_VERSION >= 11500)
#define AT_USE_CUSPARSE_GENERIC_SPSV() 1
#else
#define AT_USE_CUSPARSE_GENERIC_SPSV() 0
#endif
// cuSparse Generic API spsm function was added in CUDA 11.3.1
#if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && (CUSPARSE_VERSION >= 11600)
#define AT_USE_CUSPARSE_GENERIC_SPSM() 1
#else
#define AT_USE_CUSPARSE_GENERIC_SPSM() 0
#endif
// cuSparse Generic API sddmm function was added in CUDA 11.2.1 (cuSparse version 11400)
#if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && (CUSPARSE_VERSION >= 11400)
#define AT_USE_CUSPARSE_GENERIC_SDDMM() 1
#else
#define AT_USE_CUSPARSE_GENERIC_SDDMM() 0
#endif
// BSR triangular solve functions were added in hipSPARSE 1.11.2 (ROCm 4.5.0)
#if defined(CUDART_VERSION) || defined(USE_ROCM)
#define AT_USE_HIPSPARSE_TRIANGULAR_SOLVE() 1
#else
#define AT_USE_HIPSPARSE_TRIANGULAR_SOLVE() 0
#endif
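
A hedged sketch of how these gates are typically consumed downstream (the branch bodies are placeholders, not real ATen code):

#if AT_USE_CUSPARSE_GENERIC_SPSV()
  // compile the cusparseSpSV_*-based (generic API) triangular-solve path
#else
  // compile a fallback path, or report that the operation is unsupported
#endif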

View File

@ -0,0 +1,318 @@
#pragma once
/*
Provides a subset of cuSPARSE functions as templates:
csrgeam2<scalar_t>(...)
where scalar_t is double, float, c10::complex<double> or c10::complex<float>.
The functions are available in at::cuda::sparse namespace.
*/
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDASparse.h>
namespace at::cuda::sparse {
#define CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(scalar_t) \
cusparseHandle_t handle, int m, int n, const scalar_t *alpha, \
const cusparseMatDescr_t descrA, int nnzA, \
const scalar_t *csrSortedValA, const int *csrSortedRowPtrA, \
const int *csrSortedColIndA, const scalar_t *beta, \
const cusparseMatDescr_t descrB, int nnzB, \
const scalar_t *csrSortedValB, const int *csrSortedRowPtrB, \
const int *csrSortedColIndB, const cusparseMatDescr_t descrC, \
const scalar_t *csrSortedValC, const int *csrSortedRowPtrC, \
const int *csrSortedColIndC, size_t *pBufferSizeInBytes
template <typename scalar_t>
inline void csrgeam2_bufferSizeExt(
CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(scalar_t)) {
TORCH_INTERNAL_ASSERT(
false,
"at::cuda::sparse::csrgeam2_bufferSizeExt: not implemented for ",
typeid(scalar_t).name());
}
template <>
void csrgeam2_bufferSizeExt<float>(
CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(float));
template <>
void csrgeam2_bufferSizeExt<double>(
CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(double));
template <>
void csrgeam2_bufferSizeExt<c10::complex<float>>(
CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(c10::complex<float>));
template <>
void csrgeam2_bufferSizeExt<c10::complex<double>>(
CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(c10::complex<double>));
#define CUSPARSE_CSRGEAM2_NNZ_ARGTYPES() \
cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, \
int nnzA, const int *csrSortedRowPtrA, const int *csrSortedColIndA, \
const cusparseMatDescr_t descrB, int nnzB, const int *csrSortedRowPtrB, \
const int *csrSortedColIndB, const cusparseMatDescr_t descrC, \
int *csrSortedRowPtrC, int *nnzTotalDevHostPtr, void *workspace
template <typename scalar_t>
inline void csrgeam2Nnz(CUSPARSE_CSRGEAM2_NNZ_ARGTYPES()) {
TORCH_CUDASPARSE_CHECK(cusparseXcsrgeam2Nnz(
handle,
m,
n,
descrA,
nnzA,
csrSortedRowPtrA,
csrSortedColIndA,
descrB,
nnzB,
csrSortedRowPtrB,
csrSortedColIndB,
descrC,
csrSortedRowPtrC,
nnzTotalDevHostPtr,
workspace));
}
#define CUSPARSE_CSRGEAM2_ARGTYPES(scalar_t) \
cusparseHandle_t handle, int m, int n, const scalar_t *alpha, \
const cusparseMatDescr_t descrA, int nnzA, \
const scalar_t *csrSortedValA, const int *csrSortedRowPtrA, \
const int *csrSortedColIndA, const scalar_t *beta, \
const cusparseMatDescr_t descrB, int nnzB, \
const scalar_t *csrSortedValB, const int *csrSortedRowPtrB, \
const int *csrSortedColIndB, const cusparseMatDescr_t descrC, \
scalar_t *csrSortedValC, int *csrSortedRowPtrC, int *csrSortedColIndC, \
void *pBuffer
template <typename scalar_t>
inline void csrgeam2(CUSPARSE_CSRGEAM2_ARGTYPES(scalar_t)) {
TORCH_INTERNAL_ASSERT(
false,
"at::cuda::sparse::csrgeam2: not implemented for ",
typeid(scalar_t).name());
}
template <>
void csrgeam2<float>(CUSPARSE_CSRGEAM2_ARGTYPES(float));
template <>
void csrgeam2<double>(CUSPARSE_CSRGEAM2_ARGTYPES(double));
template <>
void csrgeam2<c10::complex<float>>(
CUSPARSE_CSRGEAM2_ARGTYPES(c10::complex<float>));
template <>
void csrgeam2<c10::complex<double>>(
CUSPARSE_CSRGEAM2_ARGTYPES(c10::complex<double>));
#define CUSPARSE_BSRMM_ARGTYPES(scalar_t) \
cusparseHandle_t handle, cusparseDirection_t dirA, \
cusparseOperation_t transA, cusparseOperation_t transB, int mb, int n, \
int kb, int nnzb, const scalar_t *alpha, \
const cusparseMatDescr_t descrA, const scalar_t *bsrValA, \
const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \
const scalar_t *B, int ldb, const scalar_t *beta, scalar_t *C, int ldc
template <typename scalar_t>
inline void bsrmm(CUSPARSE_BSRMM_ARGTYPES(scalar_t)) {
TORCH_INTERNAL_ASSERT(
false,
"at::cuda::sparse::bsrmm: not implemented for ",
typeid(scalar_t).name());
}
template <>
void bsrmm<float>(CUSPARSE_BSRMM_ARGTYPES(float));
template <>
void bsrmm<double>(CUSPARSE_BSRMM_ARGTYPES(double));
template <>
void bsrmm<c10::complex<float>>(CUSPARSE_BSRMM_ARGTYPES(c10::complex<float>));
template <>
void bsrmm<c10::complex<double>>(CUSPARSE_BSRMM_ARGTYPES(c10::complex<double>));
#define CUSPARSE_BSRMV_ARGTYPES(scalar_t) \
cusparseHandle_t handle, cusparseDirection_t dirA, \
cusparseOperation_t transA, int mb, int nb, int nnzb, \
const scalar_t *alpha, const cusparseMatDescr_t descrA, \
const scalar_t *bsrValA, const int *bsrRowPtrA, const int *bsrColIndA, \
int blockDim, const scalar_t *x, const scalar_t *beta, scalar_t *y
template <typename scalar_t>
inline void bsrmv(CUSPARSE_BSRMV_ARGTYPES(scalar_t)) {
TORCH_INTERNAL_ASSERT(
false,
"at::cuda::sparse::bsrmv: not implemented for ",
typeid(scalar_t).name());
}
template <>
void bsrmv<float>(CUSPARSE_BSRMV_ARGTYPES(float));
template <>
void bsrmv<double>(CUSPARSE_BSRMV_ARGTYPES(double));
template <>
void bsrmv<c10::complex<float>>(CUSPARSE_BSRMV_ARGTYPES(c10::complex<float>));
template <>
void bsrmv<c10::complex<double>>(CUSPARSE_BSRMV_ARGTYPES(c10::complex<double>));
#if AT_USE_HIPSPARSE_TRIANGULAR_SOLVE()
#define CUSPARSE_BSRSV2_BUFFER_ARGTYPES(scalar_t) \
cusparseHandle_t handle, cusparseDirection_t dirA, \
cusparseOperation_t transA, int mb, int nnzb, \
const cusparseMatDescr_t descrA, scalar_t *bsrValA, \
const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \
bsrsv2Info_t info, int *pBufferSizeInBytes
template <typename scalar_t>
inline void bsrsv2_bufferSize(CUSPARSE_BSRSV2_BUFFER_ARGTYPES(scalar_t)) {
TORCH_INTERNAL_ASSERT(
false,
"at::cuda::sparse::bsrsv2_bufferSize: not implemented for ",
typeid(scalar_t).name());
}
template <>
void bsrsv2_bufferSize<float>(CUSPARSE_BSRSV2_BUFFER_ARGTYPES(float));
template <>
void bsrsv2_bufferSize<double>(CUSPARSE_BSRSV2_BUFFER_ARGTYPES(double));
template <>
void bsrsv2_bufferSize<c10::complex<float>>(
CUSPARSE_BSRSV2_BUFFER_ARGTYPES(c10::complex<float>));
template <>
void bsrsv2_bufferSize<c10::complex<double>>(
CUSPARSE_BSRSV2_BUFFER_ARGTYPES(c10::complex<double>));
#define CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(scalar_t) \
cusparseHandle_t handle, cusparseDirection_t dirA, \
cusparseOperation_t transA, int mb, int nnzb, \
const cusparseMatDescr_t descrA, const scalar_t *bsrValA, \
const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \
bsrsv2Info_t info, cusparseSolvePolicy_t policy, void *pBuffer
template <typename scalar_t>
inline void bsrsv2_analysis(CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(scalar_t)) {
TORCH_INTERNAL_ASSERT(
false,
"at::cuda::sparse::bsrsv2_analysis: not implemented for ",
typeid(scalar_t).name());
}
template <>
void bsrsv2_analysis<float>(CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(float));
template <>
void bsrsv2_analysis<double>(CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(double));
template <>
void bsrsv2_analysis<c10::complex<float>>(
CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(c10::complex<float>));
template <>
void bsrsv2_analysis<c10::complex<double>>(
CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(c10::complex<double>));
#define CUSPARSE_BSRSV2_SOLVE_ARGTYPES(scalar_t) \
cusparseHandle_t handle, cusparseDirection_t dirA, \
cusparseOperation_t transA, int mb, int nnzb, const scalar_t *alpha, \
const cusparseMatDescr_t descrA, const scalar_t *bsrValA, \
const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \
bsrsv2Info_t info, const scalar_t *x, scalar_t *y, \
cusparseSolvePolicy_t policy, void *pBuffer
template <typename scalar_t>
inline void bsrsv2_solve(CUSPARSE_BSRSV2_SOLVE_ARGTYPES(scalar_t)) {
TORCH_INTERNAL_ASSERT(
false,
"at::cuda::sparse::bsrsv2_solve: not implemented for ",
typeid(scalar_t).name());
}
template <>
void bsrsv2_solve<float>(CUSPARSE_BSRSV2_SOLVE_ARGTYPES(float));
template <>
void bsrsv2_solve<double>(CUSPARSE_BSRSV2_SOLVE_ARGTYPES(double));
template <>
void bsrsv2_solve<c10::complex<float>>(
CUSPARSE_BSRSV2_SOLVE_ARGTYPES(c10::complex<float>));
template <>
void bsrsv2_solve<c10::complex<double>>(
CUSPARSE_BSRSV2_SOLVE_ARGTYPES(c10::complex<double>));
#define CUSPARSE_BSRSM2_BUFFER_ARGTYPES(scalar_t) \
cusparseHandle_t handle, cusparseDirection_t dirA, \
cusparseOperation_t transA, cusparseOperation_t transX, int mb, int n, \
int nnzb, const cusparseMatDescr_t descrA, scalar_t *bsrValA, \
const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \
bsrsm2Info_t info, int *pBufferSizeInBytes
template <typename scalar_t>
inline void bsrsm2_bufferSize(CUSPARSE_BSRSM2_BUFFER_ARGTYPES(scalar_t)) {
TORCH_INTERNAL_ASSERT(
false,
"at::cuda::sparse::bsrsm2_bufferSize: not implemented for ",
typeid(scalar_t).name());
}
template <>
void bsrsm2_bufferSize<float>(CUSPARSE_BSRSM2_BUFFER_ARGTYPES(float));
template <>
void bsrsm2_bufferSize<double>(CUSPARSE_BSRSM2_BUFFER_ARGTYPES(double));
template <>
void bsrsm2_bufferSize<c10::complex<float>>(
CUSPARSE_BSRSM2_BUFFER_ARGTYPES(c10::complex<float>));
template <>
void bsrsm2_bufferSize<c10::complex<double>>(
CUSPARSE_BSRSM2_BUFFER_ARGTYPES(c10::complex<double>));
#define CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(scalar_t) \
cusparseHandle_t handle, cusparseDirection_t dirA, \
cusparseOperation_t transA, cusparseOperation_t transX, int mb, int n, \
int nnzb, const cusparseMatDescr_t descrA, const scalar_t *bsrValA, \
const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \
bsrsm2Info_t info, cusparseSolvePolicy_t policy, void *pBuffer
template <typename scalar_t>
inline void bsrsm2_analysis(CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(scalar_t)) {
TORCH_INTERNAL_ASSERT(
false,
"at::cuda::sparse::bsrsm2_analysis: not implemented for ",
typeid(scalar_t).name());
}
template <>
void bsrsm2_analysis<float>(CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(float));
template <>
void bsrsm2_analysis<double>(CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(double));
template <>
void bsrsm2_analysis<c10::complex<float>>(
CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(c10::complex<float>));
template <>
void bsrsm2_analysis<c10::complex<double>>(
CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(c10::complex<double>));
#define CUSPARSE_BSRSM2_SOLVE_ARGTYPES(scalar_t) \
cusparseHandle_t handle, cusparseDirection_t dirA, \
cusparseOperation_t transA, cusparseOperation_t transX, int mb, int n, \
int nnzb, const scalar_t *alpha, const cusparseMatDescr_t descrA, \
const scalar_t *bsrValA, const int *bsrRowPtrA, const int *bsrColIndA, \
int blockDim, bsrsm2Info_t info, const scalar_t *B, int ldb, \
scalar_t *X, int ldx, cusparseSolvePolicy_t policy, void *pBuffer
template <typename scalar_t>
inline void bsrsm2_solve(CUSPARSE_BSRSM2_SOLVE_ARGTYPES(scalar_t)) {
TORCH_INTERNAL_ASSERT(
false,
"at::cuda::sparse::bsrsm2_solve: not implemented for ",
typeid(scalar_t).name());
}
template <>
void bsrsm2_solve<float>(CUSPARSE_BSRSM2_SOLVE_ARGTYPES(float));
template <>
void bsrsm2_solve<double>(CUSPARSE_BSRSM2_SOLVE_ARGTYPES(double));
template <>
void bsrsm2_solve<c10::complex<float>>(
CUSPARSE_BSRSM2_SOLVE_ARGTYPES(c10::complex<float>));
template <>
void bsrsm2_solve<c10::complex<double>>(
CUSPARSE_BSRSM2_SOLVE_ARGTYPES(c10::complex<double>));
#endif // AT_USE_HIPSPARSE_TRIANGULAR_SOLVE
} // namespace at::cuda::sparse
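
A hedged illustration of the ARGTYPES-macro pattern used throughout this header, with a hypothetical function: the macro keeps the generic template, which fails loudly for unsupported types, textually in sync with its explicit specializations.

#define EXAMPLE_AXPY_ARGTYPES(scalar_t) \
  cusparseHandle_t handle, int n, const scalar_t *alpha, const scalar_t *x, \
      scalar_t *y

template <typename scalar_t>
inline void example_axpy(EXAMPLE_AXPY_ARGTYPES(scalar_t)) {
  TORCH_INTERNAL_ASSERT(
      false, "example_axpy: not implemented for ", typeid(scalar_t).name());
}

template <>
void example_axpy<float>(EXAMPLE_AXPY_ARGTYPES(float));  // defined in a .cpp file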

View File

@ -0,0 +1,288 @@
#pragma once
#include <ATen/Tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDASparse.h>
#include <c10/core/ScalarType.h>
#if defined(USE_ROCM)
#include <type_traits>
#endif
namespace at::cuda::sparse {
template <typename T, cusparseStatus_t (*destructor)(T*)>
struct CuSparseDescriptorDeleter {
void operator()(T* x) {
if (x != nullptr) {
TORCH_CUDASPARSE_CHECK(destructor(x));
}
}
};
template <typename T, cusparseStatus_t (*destructor)(T*)>
class CuSparseDescriptor {
public:
T* descriptor() const {
return descriptor_.get();
}
T* descriptor() {
return descriptor_.get();
}
protected:
std::unique_ptr<T, CuSparseDescriptorDeleter<T, destructor>> descriptor_;
};
#if AT_USE_CUSPARSE_CONST_DESCRIPTORS() || AT_USE_HIPSPARSE_CONST_DESCRIPTORS()
template <typename T, cusparseStatus_t (*destructor)(const T*)>
struct ConstCuSparseDescriptorDeleter {
void operator()(T* x) {
if (x != nullptr) {
TORCH_CUDASPARSE_CHECK(destructor(x));
}
}
};
template <typename T, cusparseStatus_t (*destructor)(const T*)>
class ConstCuSparseDescriptor {
public:
T* descriptor() const {
return descriptor_.get();
}
T* descriptor() {
return descriptor_.get();
}
protected:
std::unique_ptr<T, ConstCuSparseDescriptorDeleter<T, destructor>> descriptor_;
};
#endif // AT_USE_CUSPARSE_CONST_DESCRIPTORS || AT_USE_HIPSPARSE_CONST_DESCRIPTORS
#if defined(USE_ROCM)
using cusparseMatDescr = std::remove_pointer<hipsparseMatDescr_t>::type;
using cusparseDnMatDescr = std::remove_pointer<hipsparseDnMatDescr_t>::type;
using cusparseDnVecDescr = std::remove_pointer<hipsparseDnVecDescr_t>::type;
using cusparseSpMatDescr = std::remove_pointer<hipsparseSpMatDescr_t>::type;
using cusparseSpGEMMDescr = std::remove_pointer<hipsparseSpGEMMDescr_t>::type;
#if AT_USE_HIPSPARSE_TRIANGULAR_SOLVE()
using bsrsv2Info = std::remove_pointer<bsrsv2Info_t>::type;
using bsrsm2Info = std::remove_pointer<bsrsm2Info_t>::type;
#endif
#endif
// NOTE: This is only needed for CUDA 11 and earlier, since CUDA 12 introduced
// an API for const descriptors
cusparseStatus_t destroyConstDnMat(const cusparseDnMatDescr* dnMatDescr);
class TORCH_CUDA_CPP_API CuSparseMatDescriptor
: public CuSparseDescriptor<cusparseMatDescr, &cusparseDestroyMatDescr> {
public:
CuSparseMatDescriptor() {
cusparseMatDescr_t raw_descriptor = nullptr;
TORCH_CUDASPARSE_CHECK(cusparseCreateMatDescr(&raw_descriptor));
descriptor_.reset(raw_descriptor);
}
CuSparseMatDescriptor(bool upper, bool unit) {
cusparseFillMode_t fill_mode =
upper ? CUSPARSE_FILL_MODE_UPPER : CUSPARSE_FILL_MODE_LOWER;
cusparseDiagType_t diag_type =
unit ? CUSPARSE_DIAG_TYPE_UNIT : CUSPARSE_DIAG_TYPE_NON_UNIT;
cusparseMatDescr_t raw_descriptor = nullptr;
TORCH_CUDASPARSE_CHECK(cusparseCreateMatDescr(&raw_descriptor));
TORCH_CUDASPARSE_CHECK(cusparseSetMatFillMode(raw_descriptor, fill_mode));
TORCH_CUDASPARSE_CHECK(cusparseSetMatDiagType(raw_descriptor, diag_type));
descriptor_.reset(raw_descriptor);
}
};
#if AT_USE_HIPSPARSE_TRIANGULAR_SOLVE()
class TORCH_CUDA_CPP_API CuSparseBsrsv2Info
: public CuSparseDescriptor<bsrsv2Info, &cusparseDestroyBsrsv2Info> {
public:
CuSparseBsrsv2Info() {
bsrsv2Info_t raw_descriptor = nullptr;
TORCH_CUDASPARSE_CHECK(cusparseCreateBsrsv2Info(&raw_descriptor));
descriptor_.reset(raw_descriptor);
}
};
class TORCH_CUDA_CPP_API CuSparseBsrsm2Info
: public CuSparseDescriptor<bsrsm2Info, &cusparseDestroyBsrsm2Info> {
public:
CuSparseBsrsm2Info() {
bsrsm2Info_t raw_descriptor = nullptr;
TORCH_CUDASPARSE_CHECK(cusparseCreateBsrsm2Info(&raw_descriptor));
descriptor_.reset(raw_descriptor);
}
};
#endif // AT_USE_HIPSPARSE_TRIANGULAR_SOLVE
#if AT_USE_CUSPARSE_GENERIC_API() || AT_USE_HIPSPARSE_GENERIC_API()
cusparseIndexType_t getCuSparseIndexType(const c10::ScalarType& scalar_type);
#if AT_USE_CUSPARSE_NON_CONST_DESCRIPTORS() || AT_USE_HIPSPARSE_NON_CONST_DESCRIPTORS()
class TORCH_CUDA_CPP_API CuSparseDnMatDescriptor
: public CuSparseDescriptor<cusparseDnMatDescr, &cusparseDestroyDnMat> {
public:
explicit CuSparseDnMatDescriptor(const Tensor& input, int64_t batch_offset = -1);
};
class TORCH_CUDA_CPP_API CuSparseConstDnMatDescriptor
: public CuSparseDescriptor<const cusparseDnMatDescr, &destroyConstDnMat> {
public:
explicit CuSparseConstDnMatDescriptor(const Tensor& input, int64_t batch_offset = -1);
cusparseDnMatDescr* unsafe_mutable_descriptor() const {
return const_cast<cusparseDnMatDescr*>(descriptor());
}
cusparseDnMatDescr* unsafe_mutable_descriptor() {
return const_cast<cusparseDnMatDescr*>(descriptor());
}
};
class TORCH_CUDA_CPP_API CuSparseDnVecDescriptor
: public CuSparseDescriptor<cusparseDnVecDescr, &cusparseDestroyDnVec> {
public:
explicit CuSparseDnVecDescriptor(const Tensor& input);
};
class TORCH_CUDA_CPP_API CuSparseSpMatDescriptor
: public CuSparseDescriptor<cusparseSpMatDescr, &cusparseDestroySpMat> {};
#elif AT_USE_CUSPARSE_CONST_DESCRIPTORS() || AT_USE_HIPSPARSE_CONST_DESCRIPTORS()
class TORCH_CUDA_CPP_API CuSparseDnMatDescriptor
: public ConstCuSparseDescriptor<
cusparseDnMatDescr,
&cusparseDestroyDnMat> {
public:
explicit CuSparseDnMatDescriptor(
const Tensor& input,
int64_t batch_offset = -1);
};
class TORCH_CUDA_CPP_API CuSparseConstDnMatDescriptor
: public ConstCuSparseDescriptor<
const cusparseDnMatDescr,
&destroyConstDnMat> {
public:
explicit CuSparseConstDnMatDescriptor(
const Tensor& input,
int64_t batch_offset = -1);
cusparseDnMatDescr* unsafe_mutable_descriptor() const {
return const_cast<cusparseDnMatDescr*>(descriptor());
}
cusparseDnMatDescr* unsafe_mutable_descriptor() {
return const_cast<cusparseDnMatDescr*>(descriptor());
}
};
class TORCH_CUDA_CPP_API CuSparseDnVecDescriptor
: public ConstCuSparseDescriptor<
cusparseDnVecDescr,
&cusparseDestroyDnVec> {
public:
explicit CuSparseDnVecDescriptor(const Tensor& input);
};
class TORCH_CUDA_CPP_API CuSparseSpMatDescriptor
: public ConstCuSparseDescriptor<
cusparseSpMatDescr,
&cusparseDestroySpMat> {};
#endif // AT_USE_CUSPARSE_CONST_DESCRIPTORS() || AT_USE_HIPSPARSE_CONST_DESCRIPTORS()
class TORCH_CUDA_CPP_API CuSparseSpMatCsrDescriptor
: public CuSparseSpMatDescriptor {
public:
explicit CuSparseSpMatCsrDescriptor(const Tensor& input, int64_t batch_offset = -1);
std::tuple<int64_t, int64_t, int64_t> get_size() {
int64_t rows = 0, cols = 0, nnz = 0;
TORCH_CUDASPARSE_CHECK(cusparseSpMatGetSize(
this->descriptor(),
&rows,
&cols,
&nnz));
return std::make_tuple(rows, cols, nnz);
}
void set_tensor(const Tensor& input) {
auto crow_indices = input.crow_indices();
auto col_indices = input.col_indices();
auto values = input.values();
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(crow_indices.is_contiguous());
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(col_indices.is_contiguous());
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(values.is_contiguous());
TORCH_CUDASPARSE_CHECK(cusparseCsrSetPointers(
this->descriptor(),
crow_indices.data_ptr(),
col_indices.data_ptr(),
values.data_ptr()));
}
#if AT_USE_CUSPARSE_GENERIC_SPSV()
void set_mat_fill_mode(bool upper) {
cusparseFillMode_t fill_mode =
upper ? CUSPARSE_FILL_MODE_UPPER : CUSPARSE_FILL_MODE_LOWER;
TORCH_CUDASPARSE_CHECK(cusparseSpMatSetAttribute(
this->descriptor(),
CUSPARSE_SPMAT_FILL_MODE,
&fill_mode,
sizeof(fill_mode)));
}
void set_mat_diag_type(bool unit) {
cusparseDiagType_t diag_type =
unit ? CUSPARSE_DIAG_TYPE_UNIT : CUSPARSE_DIAG_TYPE_NON_UNIT;
TORCH_CUDASPARSE_CHECK(cusparseSpMatSetAttribute(
this->descriptor(),
CUSPARSE_SPMAT_DIAG_TYPE,
&diag_type,
sizeof(diag_type)));
}
#endif
};
#if AT_USE_CUSPARSE_GENERIC_SPSV()
class TORCH_CUDA_CPP_API CuSparseSpSVDescriptor
: public CuSparseDescriptor<cusparseSpSVDescr, &cusparseSpSV_destroyDescr> {
public:
CuSparseSpSVDescriptor() {
cusparseSpSVDescr_t raw_descriptor = nullptr;
TORCH_CUDASPARSE_CHECK(cusparseSpSV_createDescr(&raw_descriptor));
descriptor_.reset(raw_descriptor);
}
};
#endif
#if AT_USE_CUSPARSE_GENERIC_SPSM()
class TORCH_CUDA_CPP_API CuSparseSpSMDescriptor
: public CuSparseDescriptor<cusparseSpSMDescr, &cusparseSpSM_destroyDescr> {
public:
CuSparseSpSMDescriptor() {
cusparseSpSMDescr_t raw_descriptor = nullptr;
TORCH_CUDASPARSE_CHECK(cusparseSpSM_createDescr(&raw_descriptor));
descriptor_.reset(raw_descriptor);
}
};
#endif
class TORCH_CUDA_CPP_API CuSparseSpGEMMDescriptor
: public CuSparseDescriptor<cusparseSpGEMMDescr, &cusparseSpGEMM_destroyDescr> {
public:
CuSparseSpGEMMDescriptor() {
cusparseSpGEMMDescr_t raw_descriptor = nullptr;
TORCH_CUDASPARSE_CHECK(cusparseSpGEMM_createDescr(&raw_descriptor));
descriptor_.reset(raw_descriptor);
}
};
#endif // AT_USE_CUSPARSE_GENERIC_API() || AT_USE_HIPSPARSE_GENERIC_API()
} // namespace at::cuda::sparse
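
A hedged usage sketch (not part of the header): the RAII wrapper creates the descriptor, hands out the raw pointer, and destroys it with the matching cusparseDestroy* call when it goes out of scope.

inline void use_mat_descriptor() {
  at::cuda::sparse::CuSparseMatDescriptor descr(/*upper=*/true, /*unit=*/false);
  cusparseMatDescr_t raw = descr.descriptor();  // valid while `descr` is alive
  // ... pass `raw` to legacy cuSPARSE routines such as bsrsv2/bsrmm ...
  (void)raw;
}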

View File

@ -0,0 +1,15 @@
#pragma once
#include <ATen/Tensor.h>
#include <c10/util/Half.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
namespace at {
template <>
inline __half* Tensor::data() const {
return reinterpret_cast<__half*>(data<Half>());
}
} // namespace at

View File

@ -0,0 +1,20 @@
#pragma once
#include <ATen/cuda/CUDAContext.h>
namespace at::cuda {
// Check if every tensor in a list of tensors matches the current
// device.
inline bool check_device(ArrayRef<Tensor> ts) {
if (ts.empty()) {
return true;
}
Device curDevice = Device(kCUDA, current_device());
for (const Tensor& t : ts) {
if (t.device() != curDevice) return false;
}
return true;
}
} // namespace at::cuda
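
A hedged usage sketch (the function and tensor names are hypothetical):

inline void launch_if_local(const at::Tensor& a, const at::Tensor& b) {
  TORCH_CHECK(at::cuda::check_device({a, b}),
              "expected all inputs to be on the current CUDA device");
  // ... safe to launch a kernel that assumes the current device ...
}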

View File

@ -0,0 +1,37 @@
#pragma once
#include <ATen/core/CachingHostAllocator.h>
#include <c10/core/Allocator.h>
#include <c10/cuda/CUDAStream.h>
namespace at::cuda {
//
// A caching allocator for CUDA host allocations (pinned memory).
//
// This provides a drop-in replacement for THCudaHostAllocator, which re-uses
// freed pinned (page-locked) memory allocations. This avoids device
// synchronizations due to cudaFreeHost calls.
//
// To ensure correct behavior, CachingHostAllocator_recordEvent must be
// called anytime a pointer from this allocator is used in a cudaMemcpyAsync
// call between host and device, and passed the corresponding context from the
// allocation. This is currently invoked by at::native::copy_kernel_cuda.
//
TORCH_CUDA_CPP_API c10::Allocator* getCachingHostAllocator();
// Records an event in the specified stream. The allocation corresponding to the
// input `ptr`/`ctx` will not be re-used until the event has occurred.
TORCH_CUDA_CPP_API bool CachingHostAllocator_recordEvent(
void* ptr,
void* ctx,
c10::cuda::CUDAStream stream);
// Releases cached pinned memory allocations via cudaFreeHost
TORCH_CUDA_CPP_API void CachingHostAllocator_emptyCache();
inline TORCH_CUDA_CPP_API at::DataPtr HostAlloc(size_t size) {
return getCachingHostAllocator()->allocate(size);
}
} // namespace at::cuda
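
A hedged sketch of the contract described above (not part of the header): the destination pointer, sizes, and host-side fill are illustrative only, and it assumes <ATen/cuda/Exceptions.h> and <cuda_runtime.h> are also included.

inline void async_h2d_copy(void* device_dst, size_t nbytes,
                           c10::cuda::CUDAStream stream) {
  at::DataPtr pinned = at::cuda::HostAlloc(nbytes);  // cached pinned allocation
  // ... fill pinned.get() with host data ...
  AT_CUDA_CHECK(cudaMemcpyAsync(device_dst, pinned.get(), nbytes,
                                cudaMemcpyHostToDevice, stream.stream()));
  // Record the stream so the block is not handed out again until the copy is done.
  at::cuda::CachingHostAllocator_recordEvent(pinned.get(), pinned.get_context(), stream);
}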

View File

@ -0,0 +1,121 @@
#pragma once
#include <cuda.h>
#include <c10/util/complex.h>
#include <c10/util/Half.h>
__device__ __forceinline__ unsigned int ACTIVE_MASK()
{
#if !defined(USE_ROCM)
return __activemask();
#else
// the mask is ignored on ROCm anyway
return 0xffffffff;
#endif
}
__device__ __forceinline__ void WARP_SYNC(unsigned mask = 0xffffffff) {
#if !defined(USE_ROCM)
return __syncwarp(mask);
#endif
}
#if defined(USE_ROCM)
__device__ __forceinline__ unsigned long long int WARP_BALLOT(int predicate)
{
return __ballot(predicate);
}
#else
__device__ __forceinline__ unsigned int WARP_BALLOT(int predicate, unsigned int mask = 0xffffffff)
{
#if !defined(USE_ROCM)
return __ballot_sync(mask, predicate);
#else
return __ballot(predicate);
#endif
}
#endif
template <typename T>
__device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
{
#if !defined(USE_ROCM)
return __shfl_xor_sync(mask, value, laneMask, width);
#else
return __shfl_xor(value, laneMask, width);
#endif
}
template <typename T>
__device__ __forceinline__ T WARP_SHFL(T value, int srcLane, int width = warpSize, unsigned int mask = 0xffffffff)
{
#if !defined(USE_ROCM)
return __shfl_sync(mask, value, srcLane, width);
#else
return __shfl(value, srcLane, width);
#endif
}
template <typename T>
__device__ __forceinline__ T WARP_SHFL_UP(T value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff)
{
#if !defined(USE_ROCM)
return __shfl_up_sync(mask, value, delta, width);
#else
return __shfl_up(value, delta, width);
#endif
}
template <typename T>
__device__ __forceinline__ T WARP_SHFL_DOWN(T value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff)
{
#if !defined(USE_ROCM)
return __shfl_down_sync(mask, value, delta, width);
#else
return __shfl_down(value, delta, width);
#endif
}
#if defined(USE_ROCM)
template<>
__device__ __forceinline__ int64_t WARP_SHFL_DOWN<int64_t>(int64_t value, unsigned int delta, int width , unsigned int mask)
{
// HIP's __shfl_down doesn't support int64_t. Trick from https://devblogs.nvidia.com/faster-parallel-reductions-kepler/
int2 a = *reinterpret_cast<int2*>(&value);
a.x = __shfl_down(a.x, delta);
a.y = __shfl_down(a.y, delta);
return *reinterpret_cast<int64_t*>(&a);
}
#endif
template<>
__device__ __forceinline__ c10::Half WARP_SHFL_DOWN<c10::Half>(c10::Half value, unsigned int delta, int width, unsigned int mask)
{
return c10::Half(WARP_SHFL_DOWN<unsigned short>(value.x, delta, width, mask), c10::Half::from_bits_t{});
}
template <typename T>
__device__ __forceinline__ c10::complex<T> WARP_SHFL_DOWN(c10::complex<T> value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff)
{
#if !defined(USE_ROCM)
return c10::complex<T>(
__shfl_down_sync(mask, value.real_, delta, width),
__shfl_down_sync(mask, value.imag_, delta, width));
#else
return c10::complex<T>(
__shfl_down(value.real_, delta, width),
__shfl_down(value.imag_, delta, width));
#endif
}
/**
* For CC 3.5+, perform a load using __ldg
*/
template <typename T>
__device__ __forceinline__ T doLdg(const T* p) {
#if __CUDA_ARCH__ >= 350 && !defined(USE_ROCM)
return __ldg(p);
#else
return *p;
#endif
}
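
A hedged sketch built on the wrappers above: the classic warp-level sum reduction (the function name is hypothetical).

template <typename T>
__device__ __forceinline__ T warpReduceSum(T val) {
  // Halve the stride each step; after the loop, lane 0 holds the warp's sum.
  for (int offset = warpSize / 2; offset > 0; offset /= 2) {
    val += WARP_SHFL_DOWN(val, offset);
  }
  return val;
}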

View File

@ -0,0 +1,44 @@
#pragma once
#include <ATen/core/TensorBase.h>
namespace at::detail {
TORCH_CUDA_CPP_API TensorBase empty_cuda(
IntArrayRef size,
ScalarType dtype,
std::optional<Device> device_opt,
std::optional<c10::MemoryFormat> memory_format_opt);
TORCH_CUDA_CPP_API TensorBase empty_cuda(
IntArrayRef size,
std::optional<ScalarType> dtype_opt,
std::optional<Layout> layout_opt,
std::optional<Device> device_opt,
std::optional<bool> pin_memory_opt,
std::optional<c10::MemoryFormat> memory_format_opt);
TORCH_CUDA_CPP_API TensorBase empty_cuda(
IntArrayRef size,
const TensorOptions &options);
TORCH_CUDA_CPP_API TensorBase empty_strided_cuda(
IntArrayRef size,
IntArrayRef stride,
ScalarType dtype,
std::optional<Device> device_opt);
TORCH_CUDA_CPP_API TensorBase empty_strided_cuda(
IntArrayRef size,
IntArrayRef stride,
std::optional<ScalarType> dtype_opt,
std::optional<Layout> layout_opt,
std::optional<Device> device_opt,
std::optional<bool> pin_memory_opt);
TORCH_CUDA_CPP_API TensorBase empty_strided_cuda(
IntArrayRef size,
IntArrayRef stride,
const TensorOptions &options);
} // namespace at::detail
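
A hedged usage sketch (the function name is hypothetical; most callers go through at::empty with TensorOptions rather than this detail helper):

inline at::TensorBase make_float_buffer(int64_t n, c10::DeviceIndex idx) {
  return at::detail::empty_cuda(
      {n},
      at::ScalarType::Float,
      c10::Device(c10::DeviceType::CUDA, idx),
      /*memory_format_opt=*/std::nullopt);
}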

View File

@ -0,0 +1,205 @@
#pragma once
#include <cublas_v2.h>
#include <cusparse.h>
#include <c10/macros/Export.h>
#ifdef CUDART_VERSION
#include <cusolver_common.h>
#endif
#if defined(USE_CUDSS)
#include <cudss.h>
#endif
#include <ATen/Context.h>
#include <c10/util/Exception.h>
#include <c10/cuda/CUDAException.h>
namespace c10 {
class CuDNNError : public c10::Error {
using Error::Error;
};
} // namespace c10
#define AT_CUDNN_FRONTEND_CHECK(EXPR, ...) \
do { \
auto error_object = EXPR; \
if (!error_object.is_good()) { \
TORCH_CHECK_WITH(CuDNNError, false, \
"cuDNN Frontend error: ", error_object.get_message()); \
} \
} while (0)
#define AT_CUDNN_CHECK_WITH_SHAPES(EXPR, ...) AT_CUDNN_CHECK(EXPR, "\n", ##__VA_ARGS__)
// See Note [CHECK macro]
#define AT_CUDNN_CHECK(EXPR, ...) \
do { \
cudnnStatus_t status = EXPR; \
if (status != CUDNN_STATUS_SUCCESS) { \
if (status == CUDNN_STATUS_NOT_SUPPORTED) { \
TORCH_CHECK_WITH(CuDNNError, false, \
"cuDNN error: ", \
cudnnGetErrorString(status), \
". This error may appear if you passed in a non-contiguous input.", ##__VA_ARGS__); \
} else { \
TORCH_CHECK_WITH(CuDNNError, false, \
"cuDNN error: ", cudnnGetErrorString(status), ##__VA_ARGS__); \
} \
} \
} while (0)
namespace at::cuda::blas {
C10_EXPORT const char* _cublasGetErrorEnum(cublasStatus_t error);
} // namespace at::cuda::blas
#define TORCH_CUDABLAS_CHECK(EXPR) \
do { \
cublasStatus_t __err = EXPR; \
TORCH_CHECK(__err == CUBLAS_STATUS_SUCCESS, \
"CUDA error: ", \
at::cuda::blas::_cublasGetErrorEnum(__err), \
" when calling `" #EXPR "`"); \
} while (0)
const char *cusparseGetErrorString(cusparseStatus_t status);
#define TORCH_CUDASPARSE_CHECK(EXPR) \
do { \
cusparseStatus_t __err = EXPR; \
TORCH_CHECK(__err == CUSPARSE_STATUS_SUCCESS, \
"CUDA error: ", \
cusparseGetErrorString(__err), \
" when calling `" #EXPR "`"); \
} while (0)
#if defined(USE_CUDSS)
namespace at::cuda::cudss {
C10_EXPORT const char* cudssGetErrorMessage(cudssStatus_t error);
} // namespace at::cuda::cudss
#define TORCH_CUDSS_CHECK(EXPR) \
do { \
cudssStatus_t __err = EXPR; \
if (__err == CUDSS_STATUS_EXECUTION_FAILED) { \
TORCH_CHECK_LINALG( \
false, \
"cudss error: ", \
at::cuda::cudss::cudssGetErrorMessage(__err), \
", when calling `" #EXPR "`", \
". This error may appear if the input matrix contains NaN. ");\
} else { \
TORCH_CHECK( \
__err == CUDSS_STATUS_SUCCESS, \
"cudss error: ", \
at::cuda::cudss::cudssGetErrorMessage(__err), \
", when calling `" #EXPR "`. "); \
} \
} while (0)
#else
#define TORCH_CUDSS_CHECK(EXPR) EXPR
#endif
// cusolver related headers are only supported on cuda now
#ifdef CUDART_VERSION
namespace at::cuda::solver {
C10_EXPORT const char* cusolverGetErrorMessage(cusolverStatus_t status);
constexpr const char* _cusolver_backend_suggestion = \
"If you keep seeing this error, you may use " \
"`torch.backends.cuda.preferred_linalg_library()` to try " \
"linear algebra operators with other supported backends. " \
"See https://pytorch.org/docs/stable/backends.html#torch.backends.cuda.preferred_linalg_library";
} // namespace at::cuda::solver
// When cuda < 11.5, cusolver raises CUSOLVER_STATUS_EXECUTION_FAILED when input contains nan.
// When cuda >= 11.5, cusolver normally finishes execution and sets info array indicating convergence issue.
#define TORCH_CUSOLVER_CHECK(EXPR) \
do { \
cusolverStatus_t __err = EXPR; \
if ((CUDA_VERSION < 11500 && \
__err == CUSOLVER_STATUS_EXECUTION_FAILED) || \
(CUDA_VERSION >= 11500 && \
__err == CUSOLVER_STATUS_INVALID_VALUE)) { \
TORCH_CHECK_LINALG( \
false, \
"cusolver error: ", \
at::cuda::solver::cusolverGetErrorMessage(__err), \
", when calling `" #EXPR "`", \
". This error may appear if the input matrix contains NaN. ", \
at::cuda::solver::_cusolver_backend_suggestion); \
} else { \
TORCH_CHECK( \
__err == CUSOLVER_STATUS_SUCCESS, \
"cusolver error: ", \
at::cuda::solver::cusolverGetErrorMessage(__err), \
", when calling `" #EXPR "`. ", \
at::cuda::solver::_cusolver_backend_suggestion); \
} \
} while (0)
#else
#define TORCH_CUSOLVER_CHECK(EXPR) EXPR
#endif
#define AT_CUDA_CHECK(EXPR) C10_CUDA_CHECK(EXPR)
// For CUDA Driver API
//
// This is here instead of in c10 because NVRTC is loaded dynamically via a stub
// in ATen, and we need to use its nvrtcGetErrorString.
// See NOTE [ USE OF NVRTC AND DRIVER API ].
#if !defined(USE_ROCM)
#define AT_CUDA_DRIVER_CHECK(EXPR) \
do { \
CUresult __err = EXPR; \
if (__err != CUDA_SUCCESS) { \
const char* err_str; \
CUresult get_error_str_err C10_UNUSED = at::globalContext().getNVRTC().cuGetErrorString(__err, &err_str); \
if (get_error_str_err != CUDA_SUCCESS) { \
AT_ERROR("CUDA driver error: unknown error"); \
} else { \
AT_ERROR("CUDA driver error: ", err_str); \
} \
} \
} while (0)
#else
#define AT_CUDA_DRIVER_CHECK(EXPR) \
do { \
CUresult __err = EXPR; \
if (__err != CUDA_SUCCESS) { \
AT_ERROR("CUDA driver error: ", static_cast<int>(__err)); \
} \
} while (0)
#endif
// For CUDA NVRTC
//
// Note: As of CUDA 10, nvrtc error code 7, NVRTC_ERROR_BUILTIN_OPERATION_FAILURE,
// incorrectly produces the error string "NVRTC unknown error."
// The following maps it correctly.
//
// This is here instead of in c10 because NVRTC is loaded dynamically via a stub
// in ATen, and we need to use its nvrtcGetErrorString.
// See NOTE [ USE OF NVRTC AND DRIVER API ].
#define AT_CUDA_NVRTC_CHECK(EXPR) \
do { \
nvrtcResult __err = EXPR; \
if (__err != NVRTC_SUCCESS) { \
if (static_cast<int>(__err) != 7) { \
AT_ERROR("CUDA NVRTC error: ", at::globalContext().getNVRTC().nvrtcGetErrorString(__err)); \
} else { \
AT_ERROR("CUDA NVRTC error: NVRTC_ERROR_BUILTIN_OPERATION_FAILURE"); \
} \
} \
} while (0)
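
A hedged usage sketch for the checking macros above (the wrapped calls are ordinary CUDA runtime / cuBLAS calls chosen for illustration; assumes <cuda_runtime.h> is included):

inline void zero_and_bind(void* dev_ptr, size_t nbytes,
                          cublasHandle_t handle, cudaStream_t stream) {
  AT_CUDA_CHECK(cudaMemsetAsync(dev_ptr, 0, nbytes, stream));  // runtime API
  TORCH_CUDABLAS_CHECK(cublasSetStream(handle, stream));       // cuBLAS API
}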

View File

@ -0,0 +1,121 @@
#pragma once
#include <cuda.h>
#include <limits.h>
#include <math.h>
#include <float.h>
// NumericLimits.cuh is a holder for numeric limits definitions of commonly used
// types. This header is very specific to ROCm HIP and may be removed in the future.
// This header is derived from the legacy THCNumerics.cuh.
// The lower_bound and upper_bound constants are the same as lowest and max for
// integral types, but are -inf and +inf for floating point types. They are
// useful in implementing min, max, etc.
namespace at {
template <typename T>
struct numeric_limits {
};
// WARNING: the following at::numeric_limits definitions are there only to support
// HIP compilation for the moment. Use std::numeric_limits if you are not
// compiling for ROCm.
// from @colesbury: "The functions on numeric_limits aren't marked with
// __device__ which is why they don't work with ROCm. CUDA allows them
// because they're constexpr."
namespace {
// ROCm doesn't like INFINITY either.
constexpr double inf = INFINITY;
}
template <>
struct numeric_limits<bool> {
static inline __host__ __device__ bool lowest() { return false; }
static inline __host__ __device__ bool max() { return true; }
static inline __host__ __device__ bool lower_bound() { return false; }
static inline __host__ __device__ bool upper_bound() { return true; }
};
template <>
struct numeric_limits<uint8_t> {
static inline __host__ __device__ uint8_t lowest() { return 0; }
static inline __host__ __device__ uint8_t max() { return UINT8_MAX; }
static inline __host__ __device__ uint8_t lower_bound() { return 0; }
static inline __host__ __device__ uint8_t upper_bound() { return UINT8_MAX; }
};
template <>
struct numeric_limits<int8_t> {
static inline __host__ __device__ int8_t lowest() { return INT8_MIN; }
static inline __host__ __device__ int8_t max() { return INT8_MAX; }
static inline __host__ __device__ int8_t lower_bound() { return INT8_MIN; }
static inline __host__ __device__ int8_t upper_bound() { return INT8_MAX; }
};
template <>
struct numeric_limits<int16_t> {
static inline __host__ __device__ int16_t lowest() { return INT16_MIN; }
static inline __host__ __device__ int16_t max() { return INT16_MAX; }
static inline __host__ __device__ int16_t lower_bound() { return INT16_MIN; }
static inline __host__ __device__ int16_t upper_bound() { return INT16_MAX; }
};
template <>
struct numeric_limits<int32_t> {
static inline __host__ __device__ int32_t lowest() { return INT32_MIN; }
static inline __host__ __device__ int32_t max() { return INT32_MAX; }
static inline __host__ __device__ int32_t lower_bound() { return INT32_MIN; }
static inline __host__ __device__ int32_t upper_bound() { return INT32_MAX; }
};
template <>
struct numeric_limits<int64_t> {
#ifdef _MSC_VER
static inline __host__ __device__ int64_t lowest() { return _I64_MIN; }
static inline __host__ __device__ int64_t max() { return _I64_MAX; }
static inline __host__ __device__ int64_t lower_bound() { return _I64_MIN; }
static inline __host__ __device__ int64_t upper_bound() { return _I64_MAX; }
#else
static inline __host__ __device__ int64_t lowest() { return INT64_MIN; }
static inline __host__ __device__ int64_t max() { return INT64_MAX; }
static inline __host__ __device__ int64_t lower_bound() { return INT64_MIN; }
static inline __host__ __device__ int64_t upper_bound() { return INT64_MAX; }
#endif
};
template <>
struct numeric_limits<at::Half> {
static inline __host__ __device__ at::Half lowest() { return at::Half(0xFBFF, at::Half::from_bits()); }
static inline __host__ __device__ at::Half max() { return at::Half(0x7BFF, at::Half::from_bits()); }
static inline __host__ __device__ at::Half lower_bound() { return at::Half(0xFC00, at::Half::from_bits()); }
static inline __host__ __device__ at::Half upper_bound() { return at::Half(0x7C00, at::Half::from_bits()); }
};
template <>
struct numeric_limits<at::BFloat16> {
static inline __host__ __device__ at::BFloat16 lowest() { return at::BFloat16(0xFF7F, at::BFloat16::from_bits()); }
static inline __host__ __device__ at::BFloat16 max() { return at::BFloat16(0x7F7F, at::BFloat16::from_bits()); }
static inline __host__ __device__ at::BFloat16 lower_bound() { return at::BFloat16(0xFF80, at::BFloat16::from_bits()); }
static inline __host__ __device__ at::BFloat16 upper_bound() { return at::BFloat16(0x7F80, at::BFloat16::from_bits()); }
};
template <>
struct numeric_limits<float> {
static inline __host__ __device__ float lowest() { return -FLT_MAX; }
static inline __host__ __device__ float max() { return FLT_MAX; }
static inline __host__ __device__ float lower_bound() { return -static_cast<float>(inf); }
static inline __host__ __device__ float upper_bound() { return static_cast<float>(inf); }
};
template <>
struct numeric_limits<double> {
static inline __host__ __device__ double lowest() { return -DBL_MAX; }
static inline __host__ __device__ double max() { return DBL_MAX; }
static inline __host__ __device__ double lower_bound() { return -inf; }
static inline __host__ __device__ double upper_bound() { return inf; }
};
} // namespace at
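
A hedged device-side sketch (the function is hypothetical): seed a running maximum with lower_bound() so that -inf is used uniformly for floating point types.

template <typename scalar_t>
__device__ scalar_t running_max(const scalar_t* data, int n) {
  scalar_t best = at::numeric_limits<scalar_t>::lower_bound();
  for (int i = 0; i < n; ++i) {
    best = data[i] > best ? data[i] : best;
  }
  return best;
}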

View File

@ -0,0 +1,11 @@
#include <c10/macros/Macros.h>
#include <cstdint>
namespace at::cuda {
namespace detail {
void init_p2p_access_cache(int64_t num_devices);
}
TORCH_CUDA_CPP_API bool get_p2p_access(int source_dev, int dest_dev);
} // namespace at::cuda

View File

@ -0,0 +1,5 @@
#pragma once
#include <cstdint>
#include <ATen/cuda/detail/PhiloxCudaStateRaw.cuh>

View File

@ -0,0 +1,4 @@
#pragma once
#include <ATen/cuda/PhiloxCudaState.h>
#include <ATen/cuda/detail/UnpackRaw.cuh>

View File

@ -0,0 +1,11 @@
#pragma once
#include <c10/core/Allocator.h>
#include <ATen/cuda/CachingHostAllocator.h>
namespace at::cuda {
inline TORCH_CUDA_CPP_API at::Allocator* getPinnedMemoryAllocator() {
return getCachingHostAllocator();
}
} // namespace at::cuda

View File

@ -0,0 +1,78 @@
#pragma once
#include <ATen/ceil_div.h>
#include <ATen/cuda/DeviceUtils.cuh>
#include <ATen/cuda/AsmUtils.cuh>
#include <c10/macros/Macros.h>
// Collection of in-kernel scan / prefix sum utilities
namespace at::cuda {
// Inclusive prefix sum for binary vars using intra-warp voting +
// shared memory
template <typename T, bool KillWARDependency, class BinaryFunction>
__device__ void inclusiveBinaryPrefixScan(T* smem, bool in, T* out, BinaryFunction binop) {
// Within-warp, we use warp voting.
#if defined (USE_ROCM)
unsigned long long int vote = WARP_BALLOT(in);
T index = __popcll(getLaneMaskLe() & vote);
T carry = __popcll(vote);
#else
T vote = WARP_BALLOT(in);
T index = __popc(getLaneMaskLe() & vote);
T carry = __popc(vote);
#endif
int warp = threadIdx.x / C10_WARP_SIZE;
// Per each warp, write out a value
if (getLaneId() == 0) {
smem[warp] = carry;
}
__syncthreads();
// Sum across warps in one thread. This appears to be faster than a
// warp shuffle scan for CC 3.0+
if (threadIdx.x == 0) {
int current = 0;
for (int i = 0; i < blockDim.x / C10_WARP_SIZE; ++i) {
T v = smem[i];
smem[i] = binop(smem[i], current);
current = binop(current, v);
}
}
__syncthreads();
  // load the carry from all preceding warps
if (warp >= 1) {
index = binop(index, smem[warp - 1]);
}
*out = index;
if (KillWARDependency) {
__syncthreads();
}
}
// Exclusive prefix sum for binary vars using intra-warp voting +
// shared memory
template <typename T, bool KillWARDependency, class BinaryFunction>
__device__ void exclusiveBinaryPrefixScan(T* smem, bool in, T* out, T* carry, BinaryFunction binop) {
inclusiveBinaryPrefixScan<T, false, BinaryFunction>(smem, in, out, binop);
// Inclusive to exclusive
*out -= (T) in;
// The outgoing carry for all threads is the last warp's sum
*carry = smem[at::ceil_div<int>(blockDim.x, C10_WARP_SIZE) - 1];
if (KillWARDependency) {
__syncthreads();
}
}
} // namespace at::cuda
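// A hedged usage sketch (not part of the original header): in-kernel stream
// compaction built on exclusiveBinaryPrefixScan. Each thread flags whether its
// element survives, the exclusive scan gives its write slot within the block,
// and `carry` gives the per-block count. Assumes blockDim.x is a multiple of
// C10_WARP_SIZE and at most 1024, and that `out` has room for blockDim.x
// elements per block.
struct AddIntsSketch {
  __device__ int operator()(int a, int b) const { return a + b; }
};

template <typename T>
__global__ void compact_positive_sketch(const T* in, T* out, int* block_counts, int n) {
  __shared__ int warp_sums[32];  // one slot per warp
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  bool keep = (i < n) && (in[i] > T(0));
  int slot = 0;
  int carry = 0;
  at::cuda::exclusiveBinaryPrefixScan<int, /*KillWARDependency=*/true>(
      warp_sums, keep, &slot, &carry, AddIntsSketch{});
  if (keep) {
    out[blockIdx.x * blockDim.x + slot] = in[i];  // per-block compacted output
  }
  if (threadIdx.x == 0) {
    block_counts[blockIdx.x] = carry;  // number of kept elements in this block
  }
}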

View File

@ -0,0 +1,13 @@
#pragma once
#include <c10/macros/Export.h>
#include <cstdint>
namespace at::cuda {
// enqueues a kernel that spins for the specified number of cycles
TORCH_CUDA_CU_API void sleep(int64_t cycles);
// flushes instruction cache for ROCm; no-op for CUDA
TORCH_CUDA_CU_API void flush_icache();
} // namespace at::cuda

View File

@ -0,0 +1,23 @@
#pragma once
#include <cstddef>
#include <c10/cuda/CUDACachingAllocator.h>
namespace at::cuda {
/// Allocator for Thrust to re-route its internal device allocations
/// to the CUDA caching allocator
class ThrustAllocator {
public:
typedef char value_type;
char* allocate(std::ptrdiff_t size) {
return static_cast<char*>(c10::cuda::CUDACachingAllocator::raw_alloc(size));
}
void deallocate(char* p, size_t size) {
c10::cuda::CUDACachingAllocator::raw_delete(p);
}
};
} // namespace at::cuda
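// A hedged usage sketch (not part of the original header): passing the
// allocator into Thrust's execution policy routes Thrust's temporary device
// buffers through the caching allocator and runs the algorithm on a chosen
// stream. The extra includes are only needed by this sketch.
#include <thrust/execution_policy.h>
#include <thrust/sort.h>

inline void sort_on_stream_sketch(float* data, std::ptrdiff_t n, cudaStream_t stream) {
  at::cuda::ThrustAllocator alloc;
  auto policy = thrust::cuda::par(alloc).on(stream);  // temporaries come from the caching allocator
  thrust::sort(policy, data, data + n);               // data is assumed to be a device pointer
}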

View File

@ -0,0 +1,405 @@
#pragma once
#include <ATen/cuda/cub.h>
#include <cstddef>
#include <type_traits>
#include <iterator>
#include <limits>
#include <ATen/cuda/cub_definitions.cuh>
#if USE_GLOBAL_CUB_WRAPPED_NAMESPACE()
#include <cub/cub.cuh>
#else
// include cub in a safe manner, see:
// https://github.com/pytorch/pytorch/pull/55292
#undef CUB_NS_POSTFIX //undef to avoid redefinition warnings
#undef CUB_NS_PREFIX
#undef CUB_NS_QUALIFIER
#define CUB_NS_PREFIX namespace at_cuda_detail {
#define CUB_NS_POSTFIX }
#define CUB_NS_QUALIFIER ::at_cuda_detail::cub
#include <cub/cub.cuh>
#undef CUB_NS_POSTFIX
#undef CUB_NS_PREFIX
#undef CUB_NS_QUALIFIER
#endif
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAStream.h>
// handle the temporary storage and 'twice' calls for cub API
#define CUB_WRAPPER(func, ...) do { \
size_t temp_storage_bytes = 0; \
func(nullptr, temp_storage_bytes, __VA_ARGS__); \
auto& caching_allocator = *::c10::cuda::CUDACachingAllocator::get(); \
auto temp_storage = caching_allocator.allocate(temp_storage_bytes); \
func(temp_storage.get(), temp_storage_bytes, __VA_ARGS__); \
AT_CUDA_CHECK(cudaGetLastError()); \
} while (false)
#ifdef USE_ROCM
#define NO_ROCM(x)
#define ROCM_HIPCUB(x) ::hipcub
#else
#define NO_ROCM(x) x
#define ROCM_HIPCUB(x) x
#endif
#if (!defined(USE_ROCM) && !CUB_SUPPORTS_NV_BFLOAT16()) || defined(USE_ROCM)
#if !defined(USE_ROCM)
namespace at_cuda_detail {
#endif
// backport https://github.com/NVIDIA/cub/pull/306 for c10::BFloat16
template <>
struct ROCM_HIPCUB(cub)::FpLimits<c10::BFloat16>
{
static __host__ __device__ __forceinline__ c10::BFloat16 Max() {
unsigned short max_word = 0x7F7F;
return reinterpret_cast<c10::BFloat16&>(max_word);
}
static __host__ __device__ __forceinline__ c10::BFloat16 Lowest() {
unsigned short lowest_word = 0xFF7F;
return reinterpret_cast<c10::BFloat16&>(lowest_word);
}
};
template <>
struct ROCM_HIPCUB(cub)::NumericTraits<c10::BFloat16>:
ROCM_HIPCUB(cub)::BaseTraits<ROCM_HIPCUB(cub)::FLOATING_POINT, true, false, unsigned short, c10::BFloat16> {};
#if !defined(USE_ROCM)
} // namespace at_cuda_detail
#endif
#endif
#if !defined(USE_ROCM)
namespace at::native {
namespace cub = ::at_cuda_detail::cub;
} // namespace at::native
#endif
namespace at::cuda::cub {
namespace detail {
template<typename T>
struct cuda_type {
using type = T;
};
template<>
struct cuda_type<c10::Half> {
using type = __half;
};
#if !defined(USE_ROCM) && CUB_SUPPORTS_NV_BFLOAT16()
template<>
struct cuda_type<c10::BFloat16> {
using type = __nv_bfloat16;
};
#elif defined(USE_ROCM)
template<>
struct cuda_type<c10::BFloat16> {
using type = hip_bfloat16;
};
#endif
} // namespace detail
template<typename key_t, typename value_t, typename OffsetIteratorT>
inline void segmented_sort_pairs(
const key_t *keys_in, key_t *keys_out,
const value_t *values_in, value_t *values_out,
int64_t num_elements, int64_t num_segments,
OffsetIteratorT begin_offsets, OffsetIteratorT end_offsets,
bool descending=false, int64_t begin_bit=0, int64_t end_bit=sizeof(key_t)*8
) {
TORCH_CHECK(num_elements <= std::numeric_limits<int>::max(),
"cub sort does not support sorting more than INT_MAX elements");
TORCH_CHECK(num_segments <= std::numeric_limits<int>::max(),
"cub sort does not support sorting more than INT_MAX elements");
using key_t_ = typename detail::cuda_type<key_t>::type;
auto allocator = c10::cuda::CUDACachingAllocator::get();
c10::DataPtr keys_out_owner;
if (keys_out == nullptr) {
keys_out_owner = allocator->allocate(num_elements * sizeof(key_t));
keys_out = reinterpret_cast<key_t *>(keys_out_owner.get());
}
const key_t_ *keys_in_ = reinterpret_cast<const key_t_*>(keys_in);
key_t_ *keys_out_ = reinterpret_cast<key_t_*>(keys_out);
if (descending) {
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSegmentedRadixSort::SortPairsDescending,
keys_in_, keys_out_, values_in, values_out,
num_elements, num_segments, begin_offsets, end_offsets,
begin_bit, end_bit, c10::cuda::getCurrentCUDAStream());
} else {
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSegmentedRadixSort::SortPairs,
keys_in_, keys_out_, values_in, values_out,
num_elements, num_segments, begin_offsets, end_offsets,
begin_bit, end_bit, c10::cuda::getCurrentCUDAStream());
}
}
#if CUB_SUPPORTS_UNIQUE_BY_KEY()
template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT, typename NumSelectedIteratorT>
inline void unique_by_key(
KeysInputIteratorT keys_in, ValuesInputIteratorT values_in,
ValuesOutputIteratorT values_out,
NumSelectedIteratorT num_selected, int64_t num_input_items)
{
// TODO: use thrust::discard_iterator to handle null keys_out when https://github.com/NVIDIA/cub/issues/406 is fixed.
using KeyT = typename std::iterator_traits<KeysInputIteratorT>::value_type;
auto allocator = c10::cuda::CUDACachingAllocator::get();
c10::DataPtr keys_out_owner;
keys_out_owner = allocator->allocate(num_input_items * sizeof(KeyT));
auto keys_out_ = static_cast<KeyT *>(keys_out_owner.get());
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSelect::UniqueByKey,
keys_in, values_in, keys_out_, values_out, num_selected, num_input_items, c10::cuda::getCurrentCUDAStream());
}
#endif
namespace impl {
template<typename InputIteratorT1, typename InputIteratorT2, typename OutputIteratorT, class ScanOpT>
C10_LAUNCH_BOUNDS_1(1)
__global__ void transform_vals(InputIteratorT1 a, InputIteratorT2 b, OutputIteratorT out, ScanOpT scan_op){
  // NOTE: `out` here is not the final scan output, but an intermediate of the accumulation type.
using acc_t = typename std::iterator_traits<OutputIteratorT>::value_type;
*out = scan_op(static_cast<acc_t>(*a), static_cast<acc_t>(*b));
}
#if !CUB_SUPPORTS_FUTURE_VALUE()
template<typename ValueT, typename InputIteratorT>
struct chained_iterator {
using iterator_category = std::random_access_iterator_tag;
using difference_type = std::ptrdiff_t;
using value_type = ValueT;
using pointer = ValueT*;
using reference = ValueT&;
InputIteratorT iter;
ValueT *first;
difference_type offset = 0;
__device__ ValueT operator[](difference_type i) {
i += offset;
if (i == 0) {
return *first;
} else {
return ValueT(iter[i - 1]);
}
}
__device__ chained_iterator operator+(difference_type i) {
return chained_iterator{iter, first, i};
}
__device__ ValueT operator*() {
return (*this)[0];
}
};
#endif
// even though cub is supposed to support tensors with int_max elements, in reality it doesn't,
// so split at int_max/2
constexpr int max_cub_size = std::numeric_limits<int>::max() / 2 + 1; // 2**30
}
// non synchronizing cub call
// even though cub is supposed to support tensors with int_max elements, in reality it doesn't,
// so split at int_max/2
template<typename InputIteratorT, typename OutputIteratorT, typename ScanOpT, int max_cub_size=impl::max_cub_size>
inline void inclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT scan_op, int64_t num_items) {
#if defined(USE_ROCM)
  // For ROCm, dispatch the whole range to hipCUB directly
CUB_WRAPPER(NO_ROCM(detail)::hipcub::DeviceScan::InclusiveScan,
input,
output,
scan_op,
num_items,
at::cuda::getCurrentCUDAStream());
C10_HIP_KERNEL_LAUNCH_CHECK();
#else
// non synchronizing cub call
// even though cub is supposed to support tensors with int_max elements, in reality it doesn't,
// so split at int_max/2
int size_cub = std::min<int64_t>(num_items, max_cub_size);
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan,
input,
output,
scan_op,
size_cub,
at::cuda::getCurrentCUDAStream());
C10_CUDA_KERNEL_LAUNCH_CHECK();
using input_t = typename std::iterator_traits<InputIteratorT>::value_type;
for (int64_t i = max_cub_size; i < num_items; i += max_cub_size) {
auto allocator = c10::cuda::CUDACachingAllocator::get();
c10::DataPtr first_elem = allocator->allocate(sizeof(input_t));
auto first_elem_ptr = reinterpret_cast<input_t *>(first_elem.get());
size_cub = std::min<int64_t>(num_items - i, max_cub_size);
impl::transform_vals<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
output + i - 1,
input + i,
first_elem_ptr,
scan_op);
C10_CUDA_KERNEL_LAUNCH_CHECK();
#if !CUB_SUPPORTS_FUTURE_VALUE()
using ArgIndexInputIterator = NO_ROCM(at_cuda_detail)::cub::ArgIndexInputIterator<InputIteratorT>;
using tuple = typename ArgIndexInputIterator::value_type;
auto input_iter_transform = [=] __device__ (const tuple &x)->input_t {
if (x.key == 0) {
return *first_elem_ptr;
} else {
return x.value;
}
};
auto input_ = NO_ROCM(at_cuda_detail)::cub::TransformInputIterator<input_t, decltype(input_iter_transform), ArgIndexInputIterator>(
ArgIndexInputIterator(input + i), input_iter_transform);
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan,
input_,
output + i,
scan_op,
size_cub,
at::cuda::getCurrentCUDAStream());
#else
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan,
input + i + 1,
output + i,
scan_op,
::at_cuda_detail::cub::FutureValue<input_t>(first_elem_ptr),
size_cub,
at::cuda::getCurrentCUDAStream());
#endif
}
#endif
}
template<typename InputIteratorT, typename OutputIteratorT, typename ScanOpT, typename InitValueT, int max_cub_size=impl::max_cub_size>
inline void exclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT scan_op, InitValueT init_value, int64_t num_items) {
#if defined(USE_ROCM)
  // For ROCm, dispatch the whole range to hipCUB directly
CUB_WRAPPER(NO_ROCM(detail)::hipcub::DeviceScan::ExclusiveScan,
input,
output,
scan_op,
init_value,
num_items,
at::cuda::getCurrentCUDAStream());
C10_HIP_KERNEL_LAUNCH_CHECK();
#else
// non synchronizing cub call
// even though cub is supposed to support tensors with int_max elements, in reality it doesn't,
// so split at int_max/2
int size_cub = std::min<int64_t>(num_items, max_cub_size);
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan,
input,
output,
scan_op,
init_value,
size_cub,
at::cuda::getCurrentCUDAStream());
C10_CUDA_KERNEL_LAUNCH_CHECK();
for (int64_t i = max_cub_size; i < num_items; i += max_cub_size) {
auto allocator = c10::cuda::CUDACachingAllocator::get();
c10::DataPtr first_elem = allocator->allocate(sizeof(InitValueT));
auto first_elem_ptr = reinterpret_cast<InitValueT *>(first_elem.get());
size_cub = std::min<int64_t>(num_items - i, max_cub_size);
impl::transform_vals<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
output + i - 1,
input + i - 1,
first_elem_ptr,
scan_op);
C10_CUDA_KERNEL_LAUNCH_CHECK();
#if !CUB_SUPPORTS_FUTURE_VALUE()
auto input_ = impl::chained_iterator<InitValueT, InputIteratorT>{
input + i, first_elem_ptr};
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan,
input_,
output + i,
scan_op,
size_cub,
at::cuda::getCurrentCUDAStream());
#else
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan,
input + i,
output + i,
scan_op,
::at_cuda_detail::cub::FutureValue<InitValueT>(first_elem_ptr),
size_cub,
at::cuda::getCurrentCUDAStream());
#endif
}
#endif
}
#if CUB_SUPPORTS_SCAN_BY_KEY()
template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT>
inline void inclusive_sum_by_key(KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, int64_t num_items) {
TORCH_CHECK(num_items <= std::numeric_limits<int>::max(),
"cub InclusiveSumByKey does not support more than INT_MAX elements");
CUB_WRAPPER(at_cuda_detail::cub::DeviceScan::InclusiveSumByKey,
keys, input, output, num_items, at_cuda_detail::cub::Equality(), at::cuda::getCurrentCUDAStream());
}
template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT, typename ScanOpT>
inline void inclusive_scan_by_key(KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, ScanOpT scan_op, int64_t num_items) {
TORCH_CHECK(num_items <= std::numeric_limits<int>::max(),
"cub InclusiveSumByKey does not support more than INT_MAX elements");
CUB_WRAPPER(at_cuda_detail::cub::DeviceScan::InclusiveScanByKey,
keys, input, output, scan_op, num_items, at_cuda_detail::cub::Equality(), at::cuda::getCurrentCUDAStream());
}
#endif
template <typename InputIteratorT, typename OutputIteratorT, typename NumSelectedIteratorT>
void unique(InputIteratorT input, OutputIteratorT output,
NumSelectedIteratorT num_selected_out, int64_t num_items) {
TORCH_CHECK(num_items <= std::numeric_limits<int>::max(),
"cub unique does not support more than INT_MAX elements");
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSelect::Unique,
input, output, num_selected_out, num_items, at::cuda::getCurrentCUDAStream());
}
template <typename InputIteratorT, typename OutputIteratorT, typename CountsOutputIteratorT,
typename LengthOutputIteratorT>
void run_length_encode(InputIteratorT input, OutputIteratorT output, CountsOutputIteratorT counts_out,
LengthOutputIteratorT length_out, int64_t num_items) {
TORCH_CHECK(num_items <= std::numeric_limits<int>::max(),
"cub run_length_encode does not support more than INT_MAX elements");
CUB_WRAPPER(
NO_ROCM(at_cuda_detail)::cub::DeviceRunLengthEncode::Encode,
input, output, counts_out, length_out, num_items,
at::cuda::getCurrentCUDAStream());
}
template <typename InputIteratorT, typename OutputIteratorT, typename ReductionOpT, typename T>
void reduce(InputIteratorT input, OutputIteratorT output, int64_t num_items, ReductionOpT op, T init) {
TORCH_CHECK(num_items <= std::numeric_limits<int>::max(),
"cub reduce does not support more than INT_MAX elements");
CUB_WRAPPER(
NO_ROCM(at_cuda_detail)::cub::DeviceReduce::Reduce,
input, output, num_items, op, init,
at::cuda::getCurrentCUDAStream());
}
} // namespace at::cuda::cub
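// A hedged usage sketch (not part of the original header): an inclusive prefix
// sum over a device buffer using the wrapper above. The functor only needs to
// be callable on the device; the wrapper splits internally at ~2**30 elements
// and chains the partial scans across chunks.
struct SumOpSketch {
  template <typename T>
  __device__ T operator()(T a, T b) const { return a + b; }
};

inline void prefix_sum_sketch(const int* d_in, int* d_out, int64_t n) {
  // d_in/d_out are assumed to be device pointers of length n
  at::cuda::cub::inclusive_scan(d_in, d_out, SumOpSketch{}, n);
}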

View File

@ -0,0 +1,87 @@
#pragma once
#include <cstdint>
#include <c10/core/ScalarType.h>
#include <ATen/cuda/CUDAConfig.h>
// NOTE: These templates are intentionally not defined in this header,
// which avoids re-compiling them for each translation unit. If you get
// a link error, you need to add an explicit instantiation for your
// types in cub.cu
namespace at::cuda::cub {
inline int get_num_bits(uint64_t max_key) {
int num_bits = 1;
while (max_key > 1) {
max_key >>= 1;
num_bits++;
}
return num_bits;
}
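// Worked examples for get_num_bits (a hedged annotation, not in the original):
//   get_num_bits(1)   == 1
//   get_num_bits(255) == 8    (0b11111111)
//   get_num_bits(256) == 9    (0b100000000)
// i.e. the number of bits needed to represent max_key, which can be passed as
// end_bit to the radix sort routines below.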
namespace detail {
// radix_sort_pairs doesn't interact with value_t other than to copy
// the data, so we can save template instantiations by reinterpreting
// it as an opaque type.
template <int N> struct alignas(N) OpaqueType { char data[N]; };
template<typename key_t, int value_size>
void radix_sort_pairs_impl(
const key_t *keys_in, key_t *keys_out,
const OpaqueType<value_size> *values_in, OpaqueType<value_size> *values_out,
int64_t n, bool descending, int64_t begin_bit, int64_t end_bit);
} // namespace detail
template<typename key_t, typename value_t>
void radix_sort_pairs(
const key_t *keys_in, key_t *keys_out,
const value_t *values_in, value_t *values_out,
int64_t n, bool descending=false, int64_t begin_bit=0, int64_t end_bit=sizeof(key_t)*8) {
static_assert(std::is_trivially_copyable_v<value_t> ||
AT_ROCM_ENABLED(), // ROCm incorrectly fails this check for vector types
"radix_sort_pairs value type must be trivially copyable");
// Make value type opaque, so all inputs of a certain size use the same template instantiation
using opaque_t = detail::OpaqueType<sizeof(value_t)>;
static_assert(sizeof(value_t) <= 8 && (sizeof(value_t) & (sizeof(value_t) - 1)) == 0,
"This size of value_t is not instantiated. Please instantiate it in cub.cu"
" and modify this check.");
static_assert(sizeof(value_t) == alignof(value_t), "Expected value_t to be size-aligned");
detail::radix_sort_pairs_impl(
keys_in, keys_out,
reinterpret_cast<const opaque_t*>(values_in),
reinterpret_cast<opaque_t*>(values_out),
n, descending, begin_bit, end_bit);
}
template<typename key_t>
void radix_sort_keys(
const key_t *keys_in, key_t *keys_out,
int64_t n, bool descending=false, int64_t begin_bit=0, int64_t end_bit=sizeof(key_t)*8);
// NOTE: Intermediate sums will be truncated to input_t precision
template <typename input_t, typename output_t>
void inclusive_sum_truncating(const input_t *input, output_t *output, int64_t n);
template <typename scalar_t>
void inclusive_sum(const scalar_t *input, scalar_t *output, int64_t n) {
return inclusive_sum_truncating(input, output, n);
}
// NOTE: Sums are done in common_type<input_t, output_t>
template <typename input_t, typename output_t>
void exclusive_sum_in_common_type(const input_t *input, output_t *output, int64_t n);
template <typename scalar_t>
void exclusive_sum(const scalar_t *input, scalar_t *output, int64_t n) {
return exclusive_sum_in_common_type(input, output, n);
}
void mask_exclusive_sum(const uint8_t *mask, int64_t *output_idx, int64_t n);
inline void mask_exclusive_sum(const bool *mask, int64_t *output_idx, int64_t n) {
return mask_exclusive_sum(
reinterpret_cast<const uint8_t*>(mask), output_idx, n);
}
} // namespace at::cuda::cub
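// A hedged usage sketch (not part of the original header): sorting float keys
// together with int64_t payload values in descending order on the current
// stream. All pointers are assumed to be device memory of length n, and the
// corresponding explicit instantiation is assumed to exist in cub.cu (see the
// note at the top of this header).
inline void sort_scores_sketch(const float* keys_in, float* keys_out,
                               const int64_t* values_in, int64_t* values_out,
                               int64_t n) {
  at::cuda::cub::radix_sort_pairs(keys_in, keys_out, values_in, values_out,
                                  n, /*descending=*/true);
}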

View File

@ -0,0 +1,53 @@
#pragma once
#if !defined(USE_ROCM)
#include <cuda.h> // for CUDA_VERSION
#endif
#if !defined(USE_ROCM)
#include <cub/version.cuh>
#else
#define CUB_VERSION 0
#endif
// cub sort support for __nv_bfloat16 is added to cub 1.13 in:
// https://github.com/NVIDIA/cub/pull/306
#if CUB_VERSION >= 101300
#define CUB_SUPPORTS_NV_BFLOAT16() true
#else
#define CUB_SUPPORTS_NV_BFLOAT16() false
#endif
// cub support for CUB_WRAPPED_NAMESPACE is added to cub 1.13.1 in:
// https://github.com/NVIDIA/cub/pull/326
// CUB_WRAPPED_NAMESPACE is defined globally in cmake/Dependencies.cmake
// starting from CUDA 11.5
#if defined(CUB_WRAPPED_NAMESPACE) || defined(THRUST_CUB_WRAPPED_NAMESPACE)
#define USE_GLOBAL_CUB_WRAPPED_NAMESPACE() true
#else
#define USE_GLOBAL_CUB_WRAPPED_NAMESPACE() false
#endif
// cub support for UniqueByKey is added to cub 1.16 in:
// https://github.com/NVIDIA/cub/pull/405
#if CUB_VERSION >= 101600
#define CUB_SUPPORTS_UNIQUE_BY_KEY() true
#else
#define CUB_SUPPORTS_UNIQUE_BY_KEY() false
#endif
// cub support for scan by key is added to cub 1.15
// in https://github.com/NVIDIA/cub/pull/376
#if CUB_VERSION >= 101500
#define CUB_SUPPORTS_SCAN_BY_KEY() 1
#else
#define CUB_SUPPORTS_SCAN_BY_KEY() 0
#endif
// cub support for cub::FutureValue is added to cub 1.15 in:
// https://github.com/NVIDIA/cub/pull/305
#if CUB_VERSION >= 101500
#define CUB_SUPPORTS_FUTURE_VALUE() true
#else
#define CUB_SUPPORTS_FUTURE_VALUE() false
#endif

View File

@ -0,0 +1,58 @@
#pragma once
#include <ATen/detail/CUDAHooksInterface.h>
#include <ATen/Generator.h>
#include <optional>
// TODO: No need to have this whole header, we can just put it all in
// the cpp file
namespace at::cuda::detail {
// Set the callback to initialize Magma, which is set by
// torch_cuda_cu. This indirection is required so magma_init is called
// in the same library where Magma will be used.
TORCH_CUDA_CPP_API void set_magma_init_fn(void (*magma_init_fn)());
// The real implementation of CUDAHooksInterface
struct CUDAHooks : public at::CUDAHooksInterface {
CUDAHooks(at::CUDAHooksArgs) {}
void initCUDA() const override;
Device getDeviceFromPtr(void* data) const override;
bool isPinnedPtr(const void* data) const override;
const Generator& getDefaultCUDAGenerator(DeviceIndex device_index = -1) const override;
bool hasCUDA() const override;
bool hasMAGMA() const override;
bool hasCuDNN() const override;
bool hasCuSOLVER() const override;
bool hasCuBLASLt() const override;
bool hasROCM() const override;
const at::cuda::NVRTC& nvrtc() const override;
DeviceIndex current_device() const override;
bool hasPrimaryContext(DeviceIndex device_index) const override;
Allocator* getCUDADeviceAllocator() const override;
Allocator* getPinnedMemoryAllocator() const override;
bool compiledWithCuDNN() const override;
bool compiledWithMIOpen() const override;
bool supportsDilatedConvolutionWithCuDNN() const override;
bool supportsDepthwiseConvolutionWithCuDNN() const override;
bool supportsBFloat16ConvolutionWithCuDNNv8() const override;
bool hasCUDART() const override;
long versionCUDART() const override;
long versionCuDNN() const override;
std::string showConfig() const override;
double batchnormMinEpsilonCuDNN() const override;
int64_t cuFFTGetPlanCacheMaxSize(DeviceIndex device_index) const override;
void cuFFTSetPlanCacheMaxSize(DeviceIndex device_index, int64_t max_size) const override;
int64_t cuFFTGetPlanCacheSize(DeviceIndex device_index) const override;
void cuFFTClearPlanCache(DeviceIndex device_index) const override;
int getNumGPUs() const override;
#ifdef USE_ROCM
bool isGPUArch(DeviceIndex device_index, const std::vector<std::string>& archs) const override;
#endif
void deviceSynchronize(DeviceIndex device_index) const override;
};
} // namespace at::cuda::detail

View File

@ -0,0 +1,151 @@
// Some stateful GPU libraries, such as cuDNN, cuBLAS, use handles to store states.
// These handles are tied to a device, and these libraries require/recommend not
// sharing handles across host threads.
//
// These libraries recommend using one handle per host thread. We may not want to do
// this because threads are relatively light-weight and can be spawned often, while
// creating and destroying handles is expensive (destroying a handle causes
// synchronization). DataParallel, for example, creates new threads for each forward pass.
//
// This file implements a handle pool mechanism. The handle pool returns handles on
// demand as threads request them. If all existing handles in the pool are in use,
// it creates a new one. As threads terminate, they release handles back into the pool.
// In this way, the handle pool never creates more handles than the high-water mark of
// active threads, so it's efficient with DataParallel.
#pragma once
#include <unordered_map>
#include <vector>
#include <utility>
#include <mutex>
#include <memory>
#include <c10/util/Exception.h>
namespace at::cuda { namespace {
template <typename Handle_t, void Create(Handle_t *), void Destroy(Handle_t)>
struct DeviceThreadHandlePool : public std::enable_shared_from_this<DeviceThreadHandlePool<Handle_t, Create, Destroy>> {
struct Handle {
Handle_t handle;
Handle(bool create = false) : handle(nullptr)
{
if(create) Create(&handle);
}
// std::vector.emplace() and push_back() may route through temporaries and call
// copy/move constructors along the way. If this is the case, we don't want
// the destructors of temporaries to call cudnnDestroy on the handle.
// We can achieve safety (for the narrow case of stashing within std::vectors)
// by making Handle moveable but not copyable, and transferring handle ownership
// to the latest constructed object. This is not a substitute for full-blown
// reference counting, but reference counting may be overkill here.
// Another alternative is to wrap the saved Handles in unique_ptrs, i.e.,
// unordered_map<int, vector<unique_ptr<Handle>>> created_handles;
Handle(const Handle& rhs) = delete;
// Following https://stackoverflow.com/questions/3279543/what-is-the-copy-and-swap-idiom
Handle(Handle&& rhs) noexcept : Handle() { std::swap(handle, rhs.handle); }
// operator= takes argument by value
Handle& operator=(Handle rhs) { std::swap(handle, rhs.handle); return *this; }
~Handle() {
if(handle) Destroy(handle);
}
};
std::mutex mutex;
// Handles are lazily created as different threads request them,
// but are never destroyed until the end of the process.
// The maximum number of handles this process will create for each device is equal
// to the high-water mark of the number of concurrently active threads that request
// handles for that device.
// When threads terminate, they release their handles back into the pool for reuse.
// Otherwise, new handles would be created every time new threads were spawned,
// resulting in poor performance for Python modules that repeatedly or frequently
// spawned new sets of threads (like DataParallel, which creates a new set of threads
// for each forward pass).
//
// To prevent potential deadlocks, we explicitly choose not to cap the number
// of handles that are created per device.
// Example of danger: If we cap the max handles at 4, and 5 threads are sharing a device,
// only 4 can make forward progress at any time. Those 4 will not release their
// handles until they exit, so the fifth cannot make progress until then. This is
// not a problem...UNLESS all 5 threads attempt some sort of synchronization at an
// intermediate point (ie, before any of them have exited). We have no way to anticipate
// or enforce that user threads will not attempt such intermediate synchronization.
// The only way to ensure safety is to avoid imposing a cap on the number of handles.
std::unordered_map<int, std::vector<Handle>> created_handles;
std::unordered_map<int, std::vector<Handle_t>> available_handles;
// PoolWindow lazily creates and caches the handles that a particular thread is using,
// so in the common case handle access doesn't incur either handle creation or a mutex lock.
class PoolWindow
{
public:
PoolWindow(std::shared_ptr<DeviceThreadHandlePool> parent): weak_parent(std::move(parent)) {}
~PoolWindow(){ release(); }
Handle_t reserve(int device)
{
// If this thread already has a handle for this device, return it
if(my_handles.find(device) != my_handles.end())
return my_handles[device];
// otherwise, either grab a handle from the pool if one is available,
// or if not, create a new one.
auto parent = weak_parent.lock();
TORCH_CHECK(parent, "Cannot create handle during program termination");
std::lock_guard<std::mutex> guard(parent->mutex);
if(parent->available_handles[device].size() > 0)
{
my_handles[device] = parent->available_handles[device].back();
parent->available_handles[device].pop_back();
}
else
{
// In local testing, I do observe that emplace_back sometimes routes through temporaries
// that incur move-constructor and destructor calls. See comments in Handle above.
parent->created_handles[device].emplace_back(true /*create*/);
my_handles[device] = parent->created_handles[device].back().handle;
}
return my_handles[device];
}
private:
// Stores the per-device handles currently owned by this thread
std::unordered_map<int, Handle_t> my_handles;
std::weak_ptr<DeviceThreadHandlePool> weak_parent;
// Called by the destructor. Releases this thread's handles back into the pool.
void release() {
if(my_handles.size() > 0) {
auto parent = weak_parent.lock();
if (!parent) {
// If this thread exits after atexit handlers have completed, the
// cuda context itself may be invalid, so we must leak the handles.
return;
}
std::lock_guard<std::mutex> guard(parent->mutex);
for(auto d_h : my_handles)
parent->available_handles[d_h.first].push_back(d_h.second);
}
}
};
// Warning:
// If you want to change this function, be aware that this function will be called
// by multiple threads and there is no mutex guarding the call of this function, so
// make sure your implementation is thread-safe.
PoolWindow *newPoolWindow() {
// The returned pointer will be owned by a thread local variable
    // so that different threads do not share the same PoolWindow.
return new PoolWindow(this->shared_from_this());
}
};
}} // namespace at::cuda::<anonymous>
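// A hedged sketch (not part of the original header) of how a concrete pool is
// typically built on this template. DummyHandleSketch and its create/destroy
// functions are illustrative stand-ins for a real library handle such as
// cublasHandle_t with cublasCreate/cublasDestroy.
using DummyHandleSketch = int*;
inline void createDummyHandleSketch(DummyHandleSketch* h) { *h = new int(0); }
inline void destroyDummyHandleSketch(DummyHandleSketch h) { delete h; }

using DummyPoolSketch = at::cuda::DeviceThreadHandlePool<
    DummyHandleSketch, createDummyHandleSketch, destroyDummyHandleSketch>;

inline DummyHandleSketch getHandleForCurrentThreadSketch(int device) {
  // the pool outlives the per-thread windows that borrow from it
  static auto pool = std::make_shared<DummyPoolSketch>();
  // one window per thread; its destructor returns borrowed handles to the pool
  thread_local std::unique_ptr<DummyPoolSketch::PoolWindow> window(pool->newPoolWindow());
  return window->reserve(device);
}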

View File

@ -0,0 +1,36 @@
#pragma once
#include <ATen/core/TensorBase.h>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <ATen/native/CanUse32BitIndexMath.h>
namespace at::cuda::detail {
TORCH_CUDA_CU_API bool maybeOverlappingIndices(const at::TensorBase &t);
using at::native::canUse32BitIndexMath;
template <typename scalar, typename IndexType>
TensorInfo<scalar, IndexType>
getTensorInfo(const at::TensorBase &t) {
IndexType sz[MAX_TENSORINFO_DIMS];
IndexType st[MAX_TENSORINFO_DIMS];
int dims = t.dim();
for (int i = 0; i < dims; ++i) {
sz[i] = t.size(i);
st[i] = t.stride(i);
}
scalar* data_ptr = nullptr;
if constexpr (std::is_const<scalar>::value) {
data_ptr = t.const_data_ptr<scalar>();
} else {
data_ptr = t.mutable_data_ptr<scalar>();
}
return TensorInfo<scalar, IndexType>(
data_ptr, dims, sz, st);
}
} // namespace at::cuda::detail

View File

@ -0,0 +1,124 @@
#pragma once
#include <assert.h>
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
#include <cuda_runtime.h>
#endif
namespace at::cuda::detail {
// A utility class to implement integer division by multiplication, given a fixed
// divisor.
//
// WARNING: The fast divider algorithm is only implemented for unsigned int;
// otherwise we default to plain integer division. For unsigned int,
// we further assume that the dividend is at most INT32_MAX. Thus,
// IntDivider must NOT be used for general integer division.
//
// This reduced range is enough for our purpose, and it allows us to
// slightly simplify the computation.
//
// (NOTE: Below, "2^k" denotes exponentiation, i.e., 1<<k.)
//
// For any N-bit unsigned integer d (> 0), we can find a "magic number" m (2^N
// <= m < 2^(N+1)) and shift s such that:
//
// \floor(n / d) = \floor((m * n) / 2^(N+s)).
//
// Given such m and s, the integer division can be then implemented as:
//
// let m' = m - 2^N // 0 <= m' < 2^N
//
// fast_integer_division(n):
// // Multiply two N-bit unsigned integers: the result is a 2N-bit unsigned
// // integer. Then take the higher N bits.
// t = (m' * n) >> N
//
// // Here we use the fact that n is less than 2^(N-1): otherwise the value
// // of (t + n) may not fit in an N-bit integer.
// return (t + n) >> s
//
// Finding such a magic number is surprisingly easy:
//
// s = \ceil(\log_2 d)
// m' = \floor(2^N * (2^s - d) / d) + 1 // Need 2N-bit integer arithmetic.
//
// See also:
// - Division by Invariant Integers Using Multiplication,
// Torbjörn Granlund and Peter L. Montgomery, 1994.
//
// - http://www.hackersdelight.org/magic.htm
//
// - http://ridiculousfish.com/blog/posts/labor-of-division-episode-i.html
// Result of div/mod operation stored together.
template <typename Value>
struct DivMod {
Value div, mod;
C10_HOST_DEVICE DivMod(Value div, Value mod) : div(div), mod(mod) { }
};
// Base case: we only have an implementation for uint32_t for now. For
// everything else, we use plain division.
template <typename Value>
struct IntDivider {
IntDivider() = default;
IntDivider(Value d) : divisor(d) { }
C10_HOST_DEVICE inline Value div(Value n) const { return n / divisor; }
C10_HOST_DEVICE inline Value mod(Value n) const { return n % divisor; }
C10_HOST_DEVICE inline DivMod<Value> divmod(Value n) const {
return DivMod<Value>(n / divisor, n % divisor);
}
Value divisor;
};
// Implement fast integer division.
template <>
struct IntDivider<unsigned int> {
static_assert(sizeof(unsigned int) == 4, "Assumes 32-bit unsigned int.");
IntDivider() = default;
IntDivider(unsigned int d) : divisor(d) {
assert(divisor >= 1 && divisor <= INT32_MAX);
// TODO: gcc/clang has __builtin_clz() but it's not portable.
for (shift = 0; shift < 32; shift++) if ((1U << shift) >= divisor) break;
uint64_t one = 1;
uint64_t magic = ((one << 32) * ((one << shift) - divisor)) / divisor + 1;
m1 = magic;
assert(m1 > 0 && m1 == magic); // m1 must fit in 32 bits.
}
C10_HOST_DEVICE inline unsigned int div(unsigned int n) const {
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
// 't' is the higher 32-bits of unsigned 32-bit multiplication of 'n' and
// 'm1'.
unsigned int t = __umulhi(n, m1);
return (t + n) >> shift;
#else
// Using uint64_t so that the addition does not overflow.
uint64_t t = ((uint64_t) n * m1) >> 32;
return (t + n) >> shift;
#endif
}
C10_HOST_DEVICE inline unsigned int mod(unsigned int n) const {
return n - div(n) * divisor;
}
C10_HOST_DEVICE inline DivMod<unsigned int> divmod(unsigned int n) const {
unsigned int q = div(n);
return DivMod<unsigned int>(q, n - q * divisor);
}
unsigned int divisor; // d above.
unsigned int m1; // Magic number: m' above.
  unsigned int shift;  // Shift amount.
};
} // namespace at::cuda::detail
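// A hedged host-side sanity sketch (not part of the original header): the
// magic-number division must agree with plain division for every dividend in
// the supported range. For d = 7: shift = 3 and m1 = (2^32 * (8 - 7)) / 7 + 1
// = 613566757, so div(100) = (((100 * m1) >> 32) + 100) >> 3 = (14 + 100) >> 3
// = 14, and mod(100) = 100 - 14 * 7 = 2.
inline bool check_int_divider_sketch(unsigned int d, unsigned int n) {
  // preconditions from the comment above: 1 <= d <= INT32_MAX, n <= INT32_MAX
  at::cuda::detail::IntDivider<unsigned int> divider(d);
  auto dm = divider.divmod(n);
  return dm.div == n / d && dm.mod == n % d;
}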

View File

@ -0,0 +1,37 @@
#pragma once
#include <limits>
#include <c10/util/Exception.h>
namespace at::cuda::detail {
// CUDA: grid stride looping
//
// int64_t _i_n_d_e_x specifically prevents overflow in the loop increment.
// If input.numel() < INT_MAX, _i_n_d_e_x < INT_MAX, except after the final
// iteration of the loop where _i_n_d_e_x += blockDim.x * gridDim.x can be
// greater than INT_MAX. But in that case _i_n_d_e_x >= n, so there are no
// further iterations and the overflowed value in i=_i_n_d_e_x is not used.
#define CUDA_KERNEL_LOOP_TYPE(i, n, index_type) \
int64_t _i_n_d_e_x = blockIdx.x * blockDim.x + threadIdx.x; \
for (index_type i=_i_n_d_e_x; _i_n_d_e_x < (n); _i_n_d_e_x+=blockDim.x * gridDim.x, i=_i_n_d_e_x)
#define CUDA_KERNEL_LOOP(i, n) CUDA_KERNEL_LOOP_TYPE(i, n, int)
// Use 1024 threads per block, which requires cuda sm_2x or above
constexpr int CUDA_NUM_THREADS = 1024;
// CUDA: number of blocks for threads.
inline int GET_BLOCKS(const int64_t N, const int64_t max_threads_per_block=CUDA_NUM_THREADS) {
TORCH_INTERNAL_ASSERT(N > 0, "CUDA kernel launch blocks must be positive, but got N=", N);
constexpr int64_t max_int = std::numeric_limits<int>::max();
// Round up division for positive number that cannot cause integer overflow
auto block_num = (N - 1) / max_threads_per_block + 1;
TORCH_INTERNAL_ASSERT(block_num <= max_int, "Can't schedule too many blocks on CUDA device");
return static_cast<int>(block_num);
}
} // namespace at::cuda::detail
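// A hedged usage sketch (not part of the original header): an elementwise
// grid-stride kernel launched with GET_BLOCKS / CUDA_NUM_THREADS. Assumes x
// and y are device pointers of length n with n > 0, and that this is compiled
// as CUDA code with the runtime headers available for cudaStream_t.
__global__ void scale_kernel_sketch(const float* x, float* y, int64_t n, float a) {
  CUDA_KERNEL_LOOP_TYPE(i, n, int64_t) {
    y[i] = a * x[i];
  }
}

inline void scale_sketch(const float* x, float* y, int64_t n, float a, cudaStream_t stream) {
  scale_kernel_sketch<<<at::cuda::detail::GET_BLOCKS(n),
                        at::cuda::detail::CUDA_NUM_THREADS, 0, stream>>>(x, y, n, a);
}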

View File

@ -0,0 +1,11 @@
#pragma once
#include <ATen/detail/CUDAHooksInterface.h>
namespace at::cuda {
// Forward-declares at::cuda::NVRTC
struct NVRTC;
namespace detail {
extern NVRTC lazyNVRTC;
} // namespace detail
} // namespace at::cuda

View File

@ -0,0 +1,119 @@
#pragma once
#include <array>
#include <cstdint>
#include <type_traits>
#include <c10/macros/Macros.h>
#include <ATen/core/Array.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/cuda/detail/IntegerDivider.cuh>
// OffsetCalculator computes, for a linear index, the offset of each of NARGS
// operands that share the same shape but may have different strides, iterating
// the tensor in column-major order. If element_sizes is nullptr the strides
// (and hence the offsets) are in bytes, otherwise they are in numbers of elements.
#if defined(USE_ROCM)
constexpr int MAX_DIMS = 16;
#else
constexpr int MAX_DIMS = 25;
#endif
template <int NARGS, typename index_t = uint32_t, bool signed_strides = false>
struct OffsetCalculator {
// We allow having negative strides to implement some operations like torch.flip
using stride_t = std::conditional_t<signed_strides,
std::make_signed_t<index_t>,
index_t>;
// The offset for each argument. Wrapper around fixed-size array.
// On CUDA, zero sized array is not allowed, so when we are handling nullary
// operators, we need to create a size 1 offset to avoid compiler failure.
// This size 1 offset is just a placeholder, and we will not use it.
using offset_type = at::detail::Array<stride_t, std::max<int>(NARGS, 1)>;
// if element_sizes is nullptr, then the strides will be in bytes, otherwise
// the strides will be in # of elements.
OffsetCalculator(int dims, const int64_t* sizes, const int64_t* const* strides, const int64_t* element_sizes=nullptr) : dims(dims) {
TORCH_CHECK(dims <= MAX_DIMS, "tensor has too many (>", MAX_DIMS, ") dims");
for (int i=0; i < dims; i++){
sizes_[i] = at::cuda::detail::IntDivider<index_t>(sizes[i]);
for (int arg = 0; arg < NARGS; arg++) {
int64_t element_size = (element_sizes == nullptr ? 1LL : element_sizes[arg]);
strides_[i][arg] = strides[arg][i] / element_size;
}
}
}
C10_HOST_DEVICE offset_type get(index_t linear_idx) const {
offset_type offsets;
#pragma unroll
for (int arg = 0; arg < NARGS; arg++) {
offsets[arg] = 0;
}
#pragma unroll
for (int dim = 0; dim < MAX_DIMS; ++dim) {
if (dim == dims) {
break;
}
auto divmod = sizes_[dim].divmod(linear_idx);
linear_idx = divmod.div;
#pragma unroll
for (int arg = 0; arg < NARGS; arg++) {
offsets[arg] += divmod.mod * strides_[dim][arg];
}
}
return offsets;
}
int dims;
at::cuda::detail::IntDivider<index_t> sizes_[MAX_DIMS];
stride_t strides_[MAX_DIMS][std::max<int>(NARGS, 1)];
};
template <int NARGS, typename index_t = uint32_t>
struct TrivialOffsetCalculator {
// The offset for each argument. Wrapper around fixed-size array.
// The offsets are in # of elements, not in bytes.
// On CUDA, zero sized array is not allowed, so when we are handling nullary
// operators, we need to create a size 1 offset to avoid compiler failure.
// This size 1 offset is just a placeholder, and we will not use it.
using offset_type = at::detail::Array<index_t, std::max<int>(NARGS, 1)>;
C10_HOST_DEVICE offset_type get(index_t linear_idx) const {
offset_type offsets;
#pragma unroll
for (int arg = 0; arg < NARGS; arg++) {
offsets[arg] = linear_idx;
}
return offsets;
}
};
// Make an OffsetCalculator with byte offsets
template<int N, bool signed_strides = false>
static OffsetCalculator<N, uint32_t, signed_strides> make_offset_calculator(const at::TensorIteratorBase& iter) {
TORCH_INTERNAL_ASSERT(N <= iter.ntensors());
std::array<const int64_t*, N> strides;
for (int i = 0; i < N; i++) {
strides[i] = iter.strides(i).data();
}
return OffsetCalculator<N, uint32_t, signed_strides>(iter.ndim(), iter.shape().data(), strides.data());
}
// Make an OffsetCalculator with element offsets
template<int N, bool signed_strides = false>
static OffsetCalculator<N, uint32_t, signed_strides> make_element_offset_calculator(
const at::TensorIteratorBase& iter) {
TORCH_INTERNAL_ASSERT(N <= iter.ntensors());
std::array<const int64_t*, N> strides;
std::array<int64_t, N> element_sizes;
for (int i = 0; i < N; i++) {
strides[i] = iter.strides(i).data();
element_sizes[i] = iter.element_size(i);
}
return OffsetCalculator<N, uint32_t, signed_strides>(
iter.ndim(), iter.shape().data(), strides.data(), element_sizes.data());
}
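// A hedged sketch (not part of the original header) of the usual consumption
// pattern: the calculator is passed to the kernel by value and turns a linear
// index into per-operand byte offsets (operand 0 is the iterator's output).
// Assumes n fits in uint32_t, matching the default index_t above; on the host
// side the calculator would come from make_offset_calculator<2>(iter).
template <typename scalar_t>
__global__ void unary_copy_kernel_sketch(int64_t n, char* out, const char* in,
                                         OffsetCalculator<2> offset_calc) {
  uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= n) return;
  auto offsets = offset_calc.get(idx);  // offsets[0] -> output, offsets[1] -> input
  *reinterpret_cast<scalar_t*>(out + offsets[0]) =
      *reinterpret_cast<const scalar_t*>(in + offsets[1]);
}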

View File

@ -0,0 +1,43 @@
// No "#pragma once" because this is a raw definition that can be copied by jit codegen.
// Eager mode clients should not include this file directly, instead,
// they should #include <ATen/cuda/PhiloxCudaState.h>, which has a #pragma once.
// Stores RNG state values. Passed as a kernel argument.
// See Note [CUDA Graph-safe RNG states].
//
// The raw definition lives in its own file so jit codegen can easily copy it.
namespace at {
struct PhiloxCudaState {
PhiloxCudaState() = default;
// Called if graph capture is not underway
PhiloxCudaState(uint64_t seed,
uint64_t offset) {
seed_.val = seed;
offset_.val = offset;
}
// Called if graph capture is underway
PhiloxCudaState(int64_t* seed,
int64_t* offset_extragraph,
uint32_t offset_intragraph) {
seed_.ptr = seed;
offset_.ptr = offset_extragraph;
offset_intragraph_ = offset_intragraph;
captured_ = true;
}
// Public members, directly accessible by at::cuda::philox::unpack.
// If we made them private with getters/setters, the getters/setters
// would have to be __device__, and we can't declare __device__ in ATen.
union Payload {
uint64_t val;
int64_t* ptr;
};
Payload seed_{};
Payload offset_{};
uint32_t offset_intragraph_ = 0;
bool captured_ = false;
};
} // namespace at
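// A hedged illustration (not part of the original file) of the two
// construction modes described above:
//   // eager mode (no capture): the seed/offset values travel by value
//   at::PhiloxCudaState eager_args(/*seed=*/42, /*offset=*/0);
//   // capture mode: pointers into graph-owned buffers, plus this kernel's
//   // fixed offset within the capture
//   at::PhiloxCudaState captured_args(seed_ptr, offset_extragraph_ptr,
//                                     /*offset_intragraph=*/16);
// at::cuda::philox::unpack() (see PhiloxUtils.cuh) reads whichever payload
// member is active based on captured_.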

View File

@ -0,0 +1,116 @@
#pragma once
#include <ATen/CollapseDims.h>
namespace at::cuda::detail {
#define MAX_TENSORINFO_DIMS 25
// CUDA kernel argument that defines tensor layout
template <typename T, typename IndexType>
struct TensorInfo {
TensorInfo();
TensorInfo(T* p,
int dim,
IndexType sz[MAX_TENSORINFO_DIMS],
IndexType st[MAX_TENSORINFO_DIMS]);
// Set the size of the given dimension to 1, as if it were a
// reduction dim (allows you to calculate offsets of the reduction
// slice)
void reduceDim(int dim);
// See note on [collapse dims].
int collapseDims(const int excludeDim = -1);
// Contiguous tensors of more than one dimension are collapsed down
// to one tensor
__host__ __device__ inline bool isContiguous() const {
return (dims == 1 && strides[0] == 1);
}
T* data;
IndexType sizes[MAX_TENSORINFO_DIMS];
IndexType strides[MAX_TENSORINFO_DIMS];
int dims;
};
template <typename T, typename IndexType>
TensorInfo<T, IndexType>::TensorInfo() {
data = nullptr;
dims = 0;
}
template <typename T, typename IndexType>
TensorInfo<T, IndexType>::TensorInfo(T* p,
int dim,
IndexType sz[MAX_TENSORINFO_DIMS],
IndexType st[MAX_TENSORINFO_DIMS]) {
data = p;
dims = dim;
TORCH_CHECK(dims < MAX_TENSORINFO_DIMS, "CUDA Tensors cannot have more than 25 dimensions");
for (int i = 0; i < dim; ++i) {
sizes[i] = sz[i];
strides[i] = st[i];
}
}
template <typename T, typename IndexType>
void
TensorInfo<T, IndexType>::reduceDim(int dim) {
TORCH_CHECK(dim < dims && dim >= 0, "expected dim between 0 and dims - 1");
sizes[dim] = 1;
}
template <typename T, typename IndexType>
int
TensorInfo<T, IndexType>::collapseDims(const int excludeDim) {
auto result = at::collapse_dims(sizes, strides, dims, excludeDim);
dims = std::get<1>(result);
return std::get<0>(result);
}
// Translate a linear index for the apply to a T* offset;
// specialized on `Dims` to reduce nvcc compilation time
template <typename T, typename IndexType, int Dims>
struct IndexToOffset {
static __host__ __device__ IndexType get(
IndexType linearId,
const TensorInfo<T, IndexType>& info) {
IndexType offset = 0;
// Uses static dims
for (int i = Dims - 1; i > 0; --i) {
IndexType curDimIndex = linearId % info.sizes[i];
IndexType curDimOffset = curDimIndex * info.strides[i];
offset += curDimOffset;
linearId /= info.sizes[i];
}
return offset + linearId * info.strides[0];
}
};
// Uses dynamic (runtime) instead of static (compiletime) dims
template <typename T, typename IndexType>
struct IndexToOffset<T, IndexType, -1> {
static inline __host__ __device__ IndexType get(
IndexType linearId,
const TensorInfo<T, IndexType>& info) {
IndexType offset = 0;
for (int i = info.dims - 1; i > 0; --i) {
IndexType curDimIndex = linearId % info.sizes[i];
IndexType curDimOffset = curDimIndex * info.strides[i];
offset += curDimOffset;
linearId /= info.sizes[i];
}
return offset + linearId * info.strides[0];
}
};
} // namespace at::cuda::detail
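// A hedged usage sketch (not part of the original header): filling a possibly
// non-contiguous tensor by mapping each linear index to its strided offset via
// the dynamic-dims specialization. `info` would normally come from
// at::cuda::detail::getTensorInfo<scalar_t, IndexType>(tensor), declared in a
// sibling detail header.
template <typename scalar_t, typename IndexType>
__global__ void fill_kernel_sketch(at::cuda::detail::TensorInfo<scalar_t, IndexType> info,
                                   IndexType total_elements, scalar_t value) {
  IndexType linear = blockIdx.x * blockDim.x + threadIdx.x;
  if (linear < total_elements) {
    IndexType offset =
        at::cuda::detail::IndexToOffset<scalar_t, IndexType, -1>::get(linear, info);
    info.data[offset] = value;
  }
}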

View File

@ -0,0 +1,28 @@
// No "#pragma once" because this is a raw definition that can be copied by jit codegen.
// Eager mode clients should not include this file directly, instead,
// they should #include <ATen/cuda/PhiloxUtils.cuh>, which has a #pragma once.
namespace at::cuda::philox {
// In-kernel call to retrieve philox seed and offset from a PhiloxCudaState instance whether
// that instance was created with graph capture underway or not.
// See Note [CUDA Graph-safe RNG states].
//
// We can't write a __device__ function in CUDAGeneratorImpl.h, because it's in ATen.
// Also, whatever call unpacks PhiloxCudaState in consumer kernels must be inlineable.
// Easiest thing that comes to mind is, define a __device__ unpack helper here, in ATen/cuda.
//
// The raw definition lives in its own file so jit codegen can easily copy it.
__host__ __device__ __forceinline__ std::tuple<uint64_t, uint64_t>
unpack(at::PhiloxCudaState arg) {
if (arg.captured_) {
// static_cast avoids "warning: invalid narrowing conversion from "long" to "unsigned long".
// *(arg.offset_.ptr) is a broadcast load of a single int64_t to the entire kernel.
// For most threads' reads it will hit in cache, so it shouldn't hurt performance.
return std::make_tuple(static_cast<uint64_t>(*arg.seed_.ptr), static_cast<uint64_t>(*(arg.offset_.ptr) + arg.offset_intragraph_));
} else {
return std::make_tuple(arg.seed_.val, arg.offset_.val);
}
}
} // namespace at::cuda::philox
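// A hedged usage sketch (left as a comment so it is not copied by jit codegen,
// and not part of the original file): a kernel consuming PhiloxCudaState via
// unpack() together with cuRAND's Philox generator. The consumer would need
// <curand_kernel.h>; the kernel name here is illustrative.
//
// __global__ void uniform_kernel_sketch(float* out, int64_t n,
//                                       at::PhiloxCudaState philox_args) {
//   auto seeds = at::cuda::philox::unpack(philox_args);
//   int64_t idx = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
//   if (idx >= n) return;
//   curandStatePhilox4_32_10_t state;
//   // one subsequence per thread, shared seed/offset for the whole launch
//   curand_init(std::get<0>(seeds), idx, std::get<1>(seeds), &state);
//   out[idx] = curand_uniform(&state);
// }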

View File

@ -0,0 +1,40 @@
#pragma once
#include <ATen/jit_macros.h>
#if AT_USE_JITERATOR()
#include <c10/macros/Export.h>
#include <c10/util/SmallVector.h>
#include <ATen/core/Tensor.h>
#include <string>
#include <vector>
namespace at::cuda {
TORCH_CUDA_CPP_API c10::SmallVector<at::Tensor> CompileAndLaunchKernel(
const std::string& code_string,
const std::string& kernel_name,
const int num_outputs,
const c10::SmallVector<at::Tensor>& tensors,
const c10::SmallVector<at::Scalar>& extra_args,
bool return_by_ref);
} // namespace at::cuda
#else
namespace at::cuda {
TORCH_CUDA_CPP_API c10::SmallVector<at::Tensor> CompileAndLaunchKernel(
const std::string& code_string,
const std::string& kernel_name,
const int num_outputs,
const c10::SmallVector<at::Tensor>& tensors,
const c10::SmallVector<at::Scalar>& extra_args,
bool return_by_ref) {
TORCH_CHECK(false, "Jiterator is not supported");
}
} // namespace at::cuda
#endif // AT_USE_JITERATOR()

View File

@ -0,0 +1,249 @@
#pragma once
#include <ATen/jit_macros.h>
#if AT_USE_JITERATOR()
#include <ATen/native/TensorIterator.h>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/native/cuda/jit_utils.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <ATen/native/cuda/JitLoops.cuh>
#include <string>
#include <variant>
#include <vector>
namespace at::native {
#define AT_FOR_8_CASES(_) \
_(1) \
_(2) \
_(3) \
_(4) \
_(5) \
_(6) \
_(7) \
_(8)
#define AT_FOR_8_CASES_WITH_COMMA(_) \
_(1) , \
_(2) , \
_(3) , \
_(4) , \
_(5) , \
_(6) , \
_(7) , \
_(8)
c10::SmallVector<std::string> get_extra_args_typenames(const c10::SmallVector<at::Scalar>& extra_args) {
c10::SmallVector<std::string> args_typenames(extra_args.size());
for (const auto i : c10::irange(extra_args.size())) {
args_typenames[i] = at::cuda::jit::typeName(extra_args[i].type());
}
return args_typenames;
}
int can_vectorize_up_to(at::ScalarType type, char* pointer) {
switch(type) {
#define DEFINE_CASE(ctype, scalartype) \
case ScalarType::scalartype : return memory::can_vectorize_up_to<ctype>(pointer);
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CASE)
#undef DEFINE_CASE
default: TORCH_INTERNAL_ASSERT(false, "Unrecognized ScalarType: ", type);
}
}
// jitted version of the above
// See Note [Jiterator], this relies on the assumptions enumerated there
int jitted_can_vectorize_up_to(const TensorIteratorBase& iter) {
const at::ScalarType common_dtype = iter.common_dtype();
const at::ScalarType result_dtype = common_dtype;
// Deals with output
int result = can_vectorize_up_to(result_dtype, static_cast<char*>(iter.data_ptr(0)));
// Incorporates input(s)
for (auto i = 1; i < iter.ntensors(); ++i) {
result = std::min<int>(result, can_vectorize_up_to(common_dtype, static_cast<char*>(iter.data_ptr(i))));
}
return result;
}
template<bool IS_INPUT, int N>
static std::unique_ptr<OffsetCalculator<N>> make_unique_offset_calculator(
const TensorIteratorBase& iter) {
  // the array size cannot be 0 (which would happen when N == 0), so clamp it to at least 1
constexpr int array_size = std::max<int>(N, 1);
TORCH_INTERNAL_ASSERT(N == (IS_INPUT ? iter.ninputs() : iter.noutputs()));
std::array<const int64_t*, array_size> strides;
int64_t element_sizes[array_size];
for (int i = 0; i < N; i++) {
int index = IS_INPUT ? i + iter.noutputs() : i;
strides[i] = iter.strides(index).data();
element_sizes[i] = iter.element_size(index);
}
return std::make_unique<OffsetCalculator<N>>(iter.ndim(), iter.shape().data(), strides.data(), element_sizes);
}
template <bool IS_INPUT>
struct OffsetCalculatorVariant {
#define DEFINE_CASE(index) std::unique_ptr<OffsetCalculator<index>>
using OffsetCalculatorTypes = std::variant<
AT_FOR_8_CASES_WITH_COMMA(DEFINE_CASE)
>;
#undef DEFINE_CASE
OffsetCalculatorVariant(const TensorIteratorBase& iter) {
int num = IS_INPUT ? iter.ninputs() : iter.noutputs();
switch(num) {
#define DEFINE_CASE(index) \
case index : v = make_unique_offset_calculator<IS_INPUT, index>(iter); break;
AT_FOR_8_CASES(DEFINE_CASE)
#undef DEFINE_CASE
default:
TORCH_CHECK(false, "OffsetCalculatorVariant is not implemented for num_tensor = ", num);
}
}
void* data_ptr() {
return std::visit([](auto & v){ return static_cast<void*>(v.get()); }, v);
}
private:
OffsetCalculatorTypes v{};
};
struct ArrayVariant {
  // works for up to 8 inputs + 8 outputs
#define DEFINE_CASE(index) at::detail::Array<char*, index>, at::detail::Array<char*, index+8>
using ArrayTypes = std::variant<
AT_FOR_8_CASES_WITH_COMMA(DEFINE_CASE)
>;
#undef DEFINE_CASE
ArrayVariant(const TensorIteratorBase& iter) {
int ntensors = iter.ntensors();
switch(ntensors) {
#define DEFINE_CASE(index) \
case index: array = at::detail::Array<char*, index>{}; break; \
case index+8: array = at::detail::Array<char*, index+8>{}; break;
AT_FOR_8_CASES(DEFINE_CASE)
#undef DEFINE_CASE
default:
TORCH_CHECK(false, "ArrayVariant is not implemented for ntensors = ", ntensors);
}
std::visit([&](auto& a) {
for (auto i = 0; i < ntensors; ++i) {
a[i] = (char*)iter.data_ptr(i);
}
}, array);
}
void* data_ptr() {
return std::visit([](auto & a){ return static_cast<void*>(&a); }, array);
}
private:
ArrayTypes array;
};
struct TrivialOffsetCalculatorVariant {
#define DEFINE_CASE(index) TrivialOffsetCalculator<index>
using TrivialOffsetCalculatorTypes = std::variant<
AT_FOR_8_CASES_WITH_COMMA(DEFINE_CASE)
>;
#undef DEFINE_CASE
TrivialOffsetCalculatorVariant(int num) {
switch(num) {
#define DEFINE_CASE(index) \
case index: v = TrivialOffsetCalculator<index>(); break;
AT_FOR_8_CASES(DEFINE_CASE)
#undef DEFINE_CASE
default:
TORCH_CHECK(false, "TrivialOffsetCalculatorVariant is not implemented for num_tensors = ", num);
}
}
void* data_ptr() {
return std::visit([](auto & v){ return static_cast<void*>(&v); }, v);
}
private:
TrivialOffsetCalculatorTypes v{};
};
struct LoadWithCastVariant {
#define DEFINE_CASE(index) std::unique_ptr<memory::LoadWithCast<index>>
using LoadWithCastPtr = std::variant<
AT_FOR_8_CASES_WITH_COMMA(DEFINE_CASE)
>;
#undef DEFINE_CASE
LoadWithCastVariant(const TensorIteratorBase& iter) {
int arity = iter.ninputs();
switch(arity) {
#define DEFINE_CASE(index) \
case index: v = std::make_unique<memory::LoadWithCast<index>>(iter); break;
AT_FOR_8_CASES(DEFINE_CASE)
#undef DEFINE_CASE
default:
TORCH_CHECK(false, "LoadWithCastVariant is not implemented for ninputs = ", arity);
}
}
void* data_ptr() {
return std::visit([](auto & v){ return static_cast<void*>(v.get()); }, v);
}
private:
LoadWithCastPtr v{};
};
struct StoreWithCastVariant {
#define DEFINE_CASE(index) std::unique_ptr<memory::StoreWithCast<index>>
using StoreWithCastPtr = std::variant<
AT_FOR_8_CASES_WITH_COMMA(DEFINE_CASE)
>;
#undef DEFINE_CASE
StoreWithCastVariant(const TensorIteratorBase& iter) {
int num = iter.noutputs();
switch(num) {
#define DEFINE_CASE(index) \
case index: v = std::make_unique<memory::StoreWithCast<index>>(iter); break;
AT_FOR_8_CASES(DEFINE_CASE)
#undef DEFINE_CASE
default:
TORCH_CHECK(false, "StoreWithCastVariant is not implemented for noutputs = ", num);
}
}
void* data_ptr() {
return std::visit([](auto & v){ return static_cast<void*>(v.get()); }, v);
}
private:
StoreWithCastPtr v{};
};
} // namespace at::native
#endif // AT_USE_JITERATOR()

View File

@ -0,0 +1,14 @@
#pragma once
#include <string>
#include <c10/macros/Export.h>
namespace at::cuda {
TORCH_CUDA_CPP_API const std::string &get_traits_string();
TORCH_CUDA_CPP_API const std::string &get_cmath_string();
TORCH_CUDA_CPP_API const std::string &get_complex_body_string();
TORCH_CUDA_CPP_API const std::string &get_complex_half_body_string();
TORCH_CUDA_CPP_API const std::string &get_complex_math_string();
} // namespace at::cuda

View File

@ -0,0 +1,397 @@
// Original TunableOp is from onnxruntime.
// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h
// https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
//
// Adapting TunableOp into PyTorch
// Copyright (c) Advanced Micro Devices, Inc.
//
#pragma once
#include <string>
#include <ATen/cuda/tunable/TunableOp.h>
#include <ATen/cuda/Exceptions.h>
#include <c10/util/StringUtil.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/allclose.h>
#include <ATen/ops/from_blob.h>
#endif
namespace at::cuda::tunable {
enum class BlasOp {
N = 0,
T = 1
};
inline std::string BlasOpToString(BlasOp op) {
switch (op) {
case BlasOp::N:
return "N";
case BlasOp::T:
return "T";
}
TORCH_CHECK(false, "unrecognized BlasOp");
return "N";
}
namespace detail {
static bool NumericalCheck(ScalarType dtype, void* c, void* other_c, int64_t size) {
auto options = at::TensorOptions().dtype(dtype).device(at::kCUDA);
// comparison done as 1D tensor
at::Tensor ref = at::from_blob(c, {size}, options);
at::Tensor oth = at::from_blob(other_c, {size}, options);
at::Tensor ref_float = ref.to(at::kFloat);
at::Tensor oth_float = oth.to(at::kFloat);
std::vector<double> atols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
std::vector<double> rtols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
double last_succeed_atol = 1;
double last_succeed_rtol = 1;
for (auto& atol : atols) {
for (auto& rtol : rtols) {
if (at::allclose(ref_float, oth_float, rtol, atol)) {
last_succeed_atol = atol;
last_succeed_rtol = rtol;
}
}
}
if (last_succeed_atol == 1) {
return false;
}
else {
TUNABLE_LOG3("├──verify numerics: atol=", last_succeed_atol, ", rtol=", last_succeed_rtol);
}
return true;
}
}
template <typename T>
struct GemmParams : OpParams {
GemmParams() {
duplicate_inputs_ = false;
}
std::string Signature() const override {
return c10::str(transa, transb, "_", m, "_", n, "_", k);
}
size_t GetSizeA() const {
return sizeof(T) * lda * ((transa == 'n' || transa == 'N') ? k : m);
}
size_t GetSizeB() const {
return sizeof(T) * ldb * ((transb == 'n' || transb == 'N') ? n : k);
}
size_t GetSizeC() const {
return sizeof(T) * ldc * n;
}
size_t GetSize(bool duplicate_inputs) const {
size_t size = GetSizeC();
if (duplicate_inputs) {
size += GetSizeA();
size += GetSizeB();
}
return size;
}
GemmParams* DeepCopy(bool duplicate_inputs) const {
GemmParams* copy = new GemmParams;
*copy = *this;
c10::DeviceIndex device = 0;
AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
size_t c_size = GetSizeC();
copy->c = static_cast<T*>(c10::cuda::CUDACachingAllocator::raw_alloc(c_size));
AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync(
copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true));
if (duplicate_inputs) {
size_t a_size = GetSizeA();
size_t b_size = GetSizeB();
copy->a = static_cast<const T*>(c10::cuda::CUDACachingAllocator::raw_alloc(a_size));
copy->b = static_cast<const T*>(c10::cuda::CUDACachingAllocator::raw_alloc(b_size));
copy->duplicate_inputs_ = true;
}
return copy;
}
// only call on object returned by DeepCopy
void Delete() {
c10::cuda::CUDACachingAllocator::raw_delete(c);
if (duplicate_inputs_) {
c10::cuda::CUDACachingAllocator::raw_delete(const_cast<T*>(a));
c10::cuda::CUDACachingAllocator::raw_delete(const_cast<T*>(b));
}
}
TuningStatus NumericalCheck(GemmParams<T> *other) {
auto c_dtype = c10::CppTypeToScalarType<T>::value;
return detail::NumericalCheck(c_dtype, c, other->c, ldc*n) ? OK : FAIL;
}
char transa;
char transb;
int64_t m;
int64_t n;
int64_t k;
at::opmath_type<T> alpha;
const T* a;
int64_t lda;
const T* b;
int64_t ldb;
at::opmath_type<T> beta;
T* c;
int64_t ldc;
private:
bool duplicate_inputs_;
};
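// Illustrative sketch of filling GemmParams for a column-major C = alpha*op(A)*op(B) + beta*C.
// The pointers and sizes below are hypothetical and the block is guarded out of compilation;
// it only shows how the BLAS-style fields relate to the GetSize*() helpers above.
#if 0
inline void example_fill_gemm_params(const float* A, const float* B, float* C) {
  GemmParams<float> params;
  params.transa = 'N';          // A used as-is: m x k, column-major
  params.transb = 'N';          // B used as-is: k x n, column-major
  params.m = 128;
  params.n = 64;
  params.k = 256;
  params.alpha = 1.0f;
  params.a = A;   params.lda = 128;   // leading dimension >= m when transa == 'N'
  params.b = B;   params.ldb = 256;   // leading dimension >= k when transb == 'N'
  params.beta = 0.0f;
  params.c = C;   params.ldc = 128;   // leading dimension >= m
  // With transa == 'N', GetSizeA() == sizeof(float) * lda * k == 4 * 128 * 256 bytes,
  // i.e. a full column-major m x k buffer padded to lda rows.
}
#endif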
template <typename T>
struct GemmAndBiasParams : OpParams {
std::string Signature() const override {
return c10::str(transa, transb, "_", m, "_", n, "_", k);
}
size_t GetSize(bool duplicate_inputs) const {
size_t size = sizeof(T) * ldc * n;
if (duplicate_inputs) {
size += sizeof(T) * lda * ((transa == 'n' || transa == 'N') ? k : m);
size += sizeof(T) * ldb * ((transb == 'n' || transb == 'N') ? n : k);
}
return size;
}
GemmAndBiasParams* DeepCopy(bool duplicate_inputs) const {
GemmAndBiasParams* copy = new GemmAndBiasParams;
*copy = *this;
c10::DeviceIndex device = 0;
AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
size_t c_size = ldc * n * sizeof(T);
copy->c = static_cast<T*>(c10::cuda::CUDACachingAllocator::raw_alloc(c_size));
AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync(
copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true));
if (duplicate_inputs) {
size_t a_size = sizeof(T) * lda * ((transa == 'n' || transa == 'N') ? k : m);
size_t b_size = sizeof(T) * ldb * ((transb == 'n' || transb == 'N') ? n : k);
copy->a = static_cast<const T*>(c10::cuda::CUDACachingAllocator::raw_alloc(a_size));
copy->b = static_cast<const T*>(c10::cuda::CUDACachingAllocator::raw_alloc(b_size));
copy->duplicate_inputs_ = true;
}
return copy;
}
// only call on object returned by DeepCopy
void Delete() {
c10::cuda::CUDACachingAllocator::raw_delete(c);
if (duplicate_inputs_) {
c10::cuda::CUDACachingAllocator::raw_delete(const_cast<T*>(a));
c10::cuda::CUDACachingAllocator::raw_delete(const_cast<T*>(b));
}
}
TuningStatus NumericalCheck(GemmAndBiasParams<T> *other) {
auto c_dtype = c10::CppTypeToScalarType<T>::value;
return detail::NumericalCheck(c_dtype, c, other->c, ldc*n) ? OK : FAIL;
}
char transa;
char transb;
int64_t m;
int64_t n;
int64_t k;
at::opmath_type<T> alpha;
const T* a;
int64_t lda;
const T* b;
int64_t ldb;
T* c;
int64_t ldc;
const T* bias;
at::cuda::blas::GEMMAndBiasActivationEpilogue activation;
private:
bool duplicate_inputs_;
};
template <typename T>
struct GemmStridedBatchedParams : OpParams {
GemmStridedBatchedParams() {
duplicate_inputs_ = false;
}
std::string Signature() const override {
return c10::str(transa, transb, "_", m, "_", n, "_", k, "_B_", batch);
}
size_t GetSizeA() const {
return sizeof(T) * std::min(lda, stride_a) * ((transa == 'n' || transa == 'N') ? k : m) * batch;
}
size_t GetSizeB() const {
return sizeof(T) * std::min(ldb, stride_b) * ((transb == 'n' || transb == 'N') ? n : k) * batch;
}
size_t GetSizeC() const {
return sizeof(T) * std::min(ldc, stride_c) * n * batch;
}
size_t GetSize(bool duplicate_inputs) const {
size_t size = GetSizeC();
if (duplicate_inputs) {
size += GetSizeA();
size += GetSizeB();
}
return size;
}
GemmStridedBatchedParams* DeepCopy(bool duplicate_inputs) const {
GemmStridedBatchedParams* copy = new GemmStridedBatchedParams;
*copy = *this;
c10::DeviceIndex device = 0;
AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
size_t c_size = GetSizeC();
copy->c = static_cast<T*>(c10::cuda::CUDACachingAllocator::raw_alloc(c_size));
AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync(
copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true));
if (duplicate_inputs) {
size_t a_size = GetSizeA();
size_t b_size = GetSizeB();
copy->a = static_cast<const T*>(c10::cuda::CUDACachingAllocator::raw_alloc(a_size));
copy->b = static_cast<const T*>(c10::cuda::CUDACachingAllocator::raw_alloc(b_size));
copy->duplicate_inputs_ = true;
}
return copy;
}
// only call on object returned by DeepCopy
void Delete() {
c10::cuda::CUDACachingAllocator::raw_delete(c);
if (duplicate_inputs_) {
c10::cuda::CUDACachingAllocator::raw_delete(const_cast<T*>(a));
c10::cuda::CUDACachingAllocator::raw_delete(const_cast<T*>(b));
}
}
TuningStatus NumericalCheck(GemmStridedBatchedParams<T> *other) {
auto c_dtype = c10::CppTypeToScalarType<T>::value;
return detail::NumericalCheck(c_dtype, c, other->c, batch*stride_c) ? OK : FAIL;
}
char transa;
char transb;
int64_t m;
int64_t n;
int64_t k;
at::opmath_type<T> alpha;
const T* a;
int64_t lda;
int64_t stride_a;
const T* b;
int64_t ldb;
int64_t stride_b;
at::opmath_type<T> beta;
T* c;
int64_t ldc;
int64_t stride_c;
int64_t batch;
private:
bool duplicate_inputs_;
};
template <typename T>
struct ScaledGemmParams : OpParams {
ScaledGemmParams() {
duplicate_inputs_ = false;
}
std::string Signature() const override {
return c10::str(transa, transb, "_", m, "_", n, "_", k);
}
size_t GetSizeA() const {
return sizeof(T) * lda * ((transa == 'n' || transa == 'N') ? k : m);
}
size_t GetSizeB() const {
return sizeof(T) * ldb * ((transb == 'n' || transb == 'N') ? n : k);
}
size_t GetSizeC() const {
return sizeof(T) * ldc * n;
}
size_t GetSize(bool duplicate_inputs) const {
size_t size = GetSizeC();
if (duplicate_inputs) {
size += GetSizeA();
size += GetSizeB();
}
return size;
}
ScaledGemmParams* DeepCopy(bool duplicate_inputs) const {
ScaledGemmParams* copy = new ScaledGemmParams;
*copy = *this;
c10::DeviceIndex device = 0;
AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
size_t c_size = GetSizeC();
copy->c = c10::cuda::CUDACachingAllocator::raw_alloc(c_size);
AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync(
copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true));
if (duplicate_inputs) {
size_t a_size = GetSizeA();
size_t b_size = GetSizeB();
copy->a = c10::cuda::CUDACachingAllocator::raw_alloc(a_size);
copy->b = c10::cuda::CUDACachingAllocator::raw_alloc(b_size);
copy->duplicate_inputs_ = true;
}
return copy;
}
// only call on object returned by DeepCopy
void Delete() {
c10::cuda::CUDACachingAllocator::raw_delete(c);
if (duplicate_inputs_) {
c10::cuda::CUDACachingAllocator::raw_delete(const_cast<void*>(a));
c10::cuda::CUDACachingAllocator::raw_delete(const_cast<void*>(b));
}
}
TuningStatus NumericalCheck(ScaledGemmParams<T> *other) {
return detail::NumericalCheck(c_dtype, c, other->c, ldc*n) ? OK : FAIL;
}
char transa;
char transb;
int64_t m;
int64_t n;
int64_t k;
const void* a;
const void* a_scale_ptr;
int64_t lda;
ScalarType a_dtype;
const void* b;
const void* b_scale_ptr;
int64_t ldb;
ScalarType b_dtype;
const void* bias_ptr;
ScalarType bias_dtype;
void* c;
const void* c_scale_ptr;
int64_t ldc;
ScalarType c_dtype;
void* amax_ptr;
bool use_fast_accum;
private:
bool duplicate_inputs_;
};
} // namespace at::cuda::tunable

View File

@ -0,0 +1,611 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDADataType.h>
#include <ATen/cuda/tunable/TunableOp.h>
#include <ATen/cuda/tunable/GemmCommon.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/util/StringUtil.h>
#include <hipblaslt/hipblaslt.h>
#include <hipblaslt/hipblaslt-ext.hpp>
#define TORCH_HIPBLASLT_CHECK(EXPR) \
do { \
hipblasStatus_t __err = EXPR; \
TORCH_CHECK(__err == HIPBLAS_STATUS_SUCCESS, \
"hipblaslt error: ", \
hipblasStatusToString(__err), \
" when calling `" #EXPR "`"); \
} while (0)
namespace at::cuda::tunable {
template <typename T>
constexpr hipblasDatatype_t HipDataTypeFor();
template <>
constexpr hipblasDatatype_t HipDataTypeFor<float>() {
return HIP_R_32F;
}
template <>
constexpr hipblasDatatype_t HipDataTypeFor<Half>() {
return HIP_R_16F;
}
template <>
constexpr hipblasDatatype_t HipDataTypeFor<BFloat16>() {
return HIP_R_16BF;
}
template <>
constexpr hipblasDatatype_t HipDataTypeFor<double>() {
return HIP_R_64F;
}
template <>
constexpr hipblasDatatype_t HipDataTypeFor<c10::Float8_e4m3fnuz>() {
return HIP_R_8F_E4M3_FNUZ;
}
template <>
constexpr hipblasDatatype_t HipDataTypeFor<c10::Float8_e5m2fnuz>() {
return HIP_R_8F_E5M2_FNUZ;
}
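// The overload families below give the templated HipblasltGemmOp a uniform way to read
// batch counts, strides, scale pointers, bias and activation out of any of the four
// param structs; param types that lack a field simply report a neutral value
// (1, 0.0, or nullptr).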
template <typename T>
int GetBatchFromParams(const GemmParams<T>* params) {
return 1;
}
template <typename T>
int GetBatchFromParams(const GemmAndBiasParams<T>* params) {
return 1;
}
template <typename T>
int GetBatchFromParams(const GemmStridedBatchedParams<T>* params) {
return params->batch;
}
template <typename T>
int GetBatchFromParams(const ScaledGemmParams<T>* params) {
return 1;
}
template <typename T>
int GetStrideAFromParams(const GemmParams<T>* params) {
return 1;
}
template <typename T>
int GetStrideAFromParams(const GemmAndBiasParams<T>* params) {
return 1;
}
template <typename T>
int GetStrideAFromParams(const GemmStridedBatchedParams<T>* params) {
return params->stride_a;
}
template <typename T>
int GetStrideAFromParams(const ScaledGemmParams<T>* params) {
return 1;
}
template <typename T>
int GetStrideBFromParams(const GemmParams<T>* params) {
return 1;
}
template <typename T>
int GetStrideBFromParams(const GemmAndBiasParams<T>* params) {
return 1;
}
template <typename T>
int GetStrideBFromParams(const GemmStridedBatchedParams<T>* params) {
return params->stride_b;
}
template <typename T>
int GetStrideBFromParams(const ScaledGemmParams<T>* params) {
return 1;
}
template <typename T>
int GetStrideCFromParams(const GemmParams<T>* params) {
return 1;
}
template <typename T>
int GetStrideCFromParams(const GemmAndBiasParams<T>* params) {
return 1;
}
template <typename T>
int GetStrideCFromParams(const GemmStridedBatchedParams<T>* params) {
return params->stride_c;
}
template <typename T>
int GetStrideCFromParams(const ScaledGemmParams<T>* params) {
return 1;
}
template <typename T>
float GetAlphaFromParams(const GemmParams<T>* params) {
return params->alpha;
}
template <typename T>
float GetAlphaFromParams(const GemmAndBiasParams<T>* params) {
return params->alpha;
}
template <typename T>
float GetAlphaFromParams(const GemmStridedBatchedParams<T>* params) {
return params->alpha;
}
template <typename T>
float GetAlphaFromParams(const ScaledGemmParams<T>* params) {
return 1.0;
}
template <typename T>
float GetBetaFromParams(const GemmParams<T>* params) {
return params->beta;
}
template <typename T>
float GetBetaFromParams(const GemmAndBiasParams<T>* params) {
return 0.0;
}
template <typename T>
float GetBetaFromParams(const GemmStridedBatchedParams<T>* params) {
return params->beta;
}
template <typename T>
float GetBetaFromParams(const ScaledGemmParams<T>* params) {
return 0.0;
}
template <typename T>
const void* GetAScalePointerFromParams(const GemmParams<T>* params) {
return nullptr;
}
template <typename T>
const void* GetAScalePointerFromParams(const GemmAndBiasParams<T>* params) {
return nullptr;
}
template <typename T>
const void* GetAScalePointerFromParams(const GemmStridedBatchedParams<T>* params) {
return nullptr;
}
template <typename T>
const void* GetAScalePointerFromParams(const ScaledGemmParams<T>* params) {
return params->a_scale_ptr;
}
template <typename T>
const void* GetBScalePointerFromParams(const GemmParams<T>* params) {
return nullptr;
}
template <typename T>
const void* GetBScalePointerFromParams(const GemmAndBiasParams<T>* params) {
return nullptr;
}
template <typename T>
const void* GetBScalePointerFromParams(const GemmStridedBatchedParams<T>* params) {
return nullptr;
}
template <typename T>
const void* GetBScalePointerFromParams(const ScaledGemmParams<T>* params) {
return params->b_scale_ptr;
}
template <typename T>
const void* GetDScalePointerFromParams(const GemmParams<T>* params) {
return nullptr;
}
template <typename T>
const void* GetDScalePointerFromParams(const GemmAndBiasParams<T>* params) {
return nullptr;
}
template <typename T>
const void* GetDScalePointerFromParams(const GemmStridedBatchedParams<T>* params) {
return nullptr;
}
template <typename T>
const void* GetDScalePointerFromParams(const ScaledGemmParams<T>* params) {
return params->c_scale_ptr;
}
template <typename T>
const void* GetBiasPointerFromParams(const GemmParams<T>* params) {
return nullptr;
}
template <typename T>
const void* GetBiasPointerFromParams(const GemmAndBiasParams<T>* params) {
return params->bias;
}
template <typename T>
const void* GetBiasPointerFromParams(const GemmStridedBatchedParams<T>* params) {
return nullptr;
}
template <typename T>
const void* GetBiasPointerFromParams(const ScaledGemmParams<T>* params) {
return params->bias_ptr;
}
template <typename T>
hipDataType GetBiasTypeFromParams(const GemmParams<T>* params) {
return HIP_R_32F;
}
template <typename T>
hipDataType GetBiasTypeFromParams(const GemmAndBiasParams<T>* params) {
return HipDataTypeFor<T>();
}
template <typename T>
hipDataType GetBiasTypeFromParams(const GemmStridedBatchedParams<T>* params) {
return HIP_R_32F;
}
template <typename T>
hipDataType GetBiasTypeFromParams(const ScaledGemmParams<T>* params) {
return at::cuda::ScalarTypeToCudaDataType(params->bias_dtype);
}
template <typename T>
at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const GemmParams<T>* params) {
return at::cuda::blas::GEMMAndBiasActivationEpilogue::None;
}
template <typename T>
at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const GemmAndBiasParams<T>* params) {
return params->activation;
}
template <typename T>
at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const GemmStridedBatchedParams<T>* params) {
return at::cuda::blas::GEMMAndBiasActivationEpilogue::None;
}
template <typename T>
at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const ScaledGemmParams<T>* params) {
return at::cuda::blas::GEMMAndBiasActivationEpilogue::None;
}
static hipblasOperation_t _hipblasOpFromChar(char op) {
switch (op) {
case 'n':
case 'N':
return HIPBLAS_OP_N;
case 't':
case 'T':
return HIPBLAS_OP_T;
case 'c':
case 'C':
return HIPBLAS_OP_C;
}
AT_ERROR(
"_hipblasOpFromChar input should be 't', 'n' or 'c' but got `", op, "`");
}
static char _charFromhipblasOp(hipblasOperation_t op) {
switch (op) {
case HIPBLAS_OP_N:
return 'N';
case HIPBLAS_OP_T:
return 'T';
case HIPBLAS_OP_C:
return 'C';
}
AT_ERROR(
"_charFromhipblasOp input should be HIPBLAS_OP_N/T/C but got `", op, "`");
}
static hipblasOperation_t MapLayoutToHipBlasLt(BlasOp layout) {
if (layout == BlasOp::N) {
return HIPBLAS_OP_N;
}
return HIPBLAS_OP_T;
}
static size_t GetHipblasltWorkspaceSize() {
static const char * env = getenv("HIPBLASLT_WORKSPACE_SIZE");
// 256MB is max workspace size allowed for hipblaslt
// hipblaslt-bench uses 32MB
// recommendation from hipblaslt author was 76MB
size_t workspace_size = 32*1024; // default to 32 MiB (value is in KiB; converted to bytes on return)
if (env) {
try {
workspace_size = std::stoi(env);
} catch(std::invalid_argument const& e) {
TORCH_WARN("invalid HIPBLASLT_WORKSPACE_SIZE,",
" using default workspace size of ", workspace_size, " KiB.");
} catch(std::out_of_range const& e) {
TORCH_WARN("HIPBLASLT_WORKSPACE_SIZE out of range,",
" using default workspace size of ", workspace_size, " KiB.");
}
}
return workspace_size * 1024;
}
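// The environment variable is interpreted in KiB, e.g. HIPBLASLT_WORKSPACE_SIZE=77824
// requests 77824 KiB (76 MiB); the function itself returns the size in bytes.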
template <typename T, cublasStatus_t (*destructor)(T*)>
struct HipBlasLtDeleter {
void operator()(T* x) {
if (x != nullptr) {
TORCH_CUDABLAS_CHECK(destructor(x));
}
}
};
template <typename T, hipblasStatus_t (*destructor)(T*)>
class HipBlasLtDescriptor {
public:
T* descriptor() const {
return descriptor_.get();
}
T* descriptor() {
return descriptor_.get();
}
protected:
std::unique_ptr<T, HipBlasLtDeleter<T, destructor>> descriptor_;
};
class HipBlasLtMatmulDescriptor : public HipBlasLtDescriptor<
hipblasLtMatmulDescOpaque_t,
&hipblasLtMatmulDescDestroy> {
public:
HipBlasLtMatmulDescriptor(
hipblasComputeType_t compute_type,
hipDataType scale_type) {
hipblasLtMatmulDesc_t raw_descriptor = nullptr;
TORCH_HIPBLASLT_CHECK(
hipblasLtMatmulDescCreate(&raw_descriptor, compute_type, scale_type));
descriptor_.reset(raw_descriptor);
}
template <typename T>
inline void setAttribute(hipblasLtMatmulDescAttributes_t attr, const T value) {
TORCH_HIPBLASLT_CHECK(::hipblasLtMatmulDescSetAttribute(descriptor(), attr, &value, sizeof(T)));
}
};
template <typename AT, typename BT, typename CT, BlasOp ALayout, BlasOp BLayout, typename ParamsT>
class HipblasltGemmOp : public Callable<ParamsT> {
public:
HipblasltGemmOp(hipblasLtMatmulAlgo_t algo) : algo_{algo} {}
TuningStatus Call(const ParamsT* params) override {
hipblasOperation_t transa_outer = MapLayoutToHipBlasLt(ALayout);
hipblasOperation_t transb_outer = MapLayoutToHipBlasLt(BLayout);
auto a_datatype = HipDataTypeFor<AT>();
auto b_datatype = HipDataTypeFor<BT>();
auto in_out_datatype = HipDataTypeFor<CT>();
auto opa = _hipblasOpFromChar(params->transa);
auto opb = _hipblasOpFromChar(params->transb);
TORCH_CHECK(transa_outer == opa && transb_outer == opb, "trans mismatch, shouldn't happen");
float alpha = GetAlphaFromParams<CT>(params);
float beta = GetBetaFromParams<CT>(params);
hipblasLtMatrixLayout_t mat_a, mat_b, mat_c;
if (opa == HIPBLAS_OP_N) {
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_a, a_datatype, params->m, params->k, params->lda));
}
else {
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_a, a_datatype, params->k, params->m, params->lda));
}
if (opb == HIPBLAS_OP_N) {
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_b, b_datatype, params->k, params->n, params->ldb));
}
else {
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_b, b_datatype, params->n, params->k, params->ldb));
}
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_c, in_out_datatype, params->m, params->n, params->ldc));
// specific to batched gemm
int batch = GetBatchFromParams<CT>(params);
if (batch > 1) {
int64_t stride_a = GetStrideAFromParams<CT>(params);
int64_t stride_b = GetStrideBFromParams<CT>(params);
int64_t stride_c = GetStrideCFromParams<CT>(params);
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
mat_a, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
mat_a, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_a, sizeof(stride_a)));
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
mat_b, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
mat_b, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_b, sizeof(stride_b)));
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
mat_c, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
mat_c, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_c, sizeof(stride_c)));
}
HipBlasLtMatmulDescriptor matmul(HIPBLAS_COMPUTE_32F, HIP_R_32F);
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_TRANSA, opa);
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_TRANSB, opb);
// specific to scaled gemm
const void* mat1_scale_ptr = GetAScalePointerFromParams<CT>(params);
const void* mat2_scale_ptr = GetBScalePointerFromParams<CT>(params);
const void* result_scale_ptr = GetDScalePointerFromParams<CT>(params);
if (mat1_scale_ptr && mat2_scale_ptr) {
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER, mat1_scale_ptr);
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER, mat2_scale_ptr);
}
if (result_scale_ptr) {
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr);
}
const void* bias_ptr = GetBiasPointerFromParams<CT>(params);
auto bias_datatype = GetBiasTypeFromParams<CT>(params);
if (bias_ptr) {
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_BIAS_POINTER, bias_ptr);
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, bias_datatype);
auto activation = GetActivationFromParams<CT>(params);
if (activation == at::cuda::blas::GEMMAndBiasActivationEpilogue::RELU) {
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, HIPBLASLT_EPILOGUE_RELU_BIAS);
}
else if (activation == at::cuda::blas::GEMMAndBiasActivationEpilogue::GELU) {
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, HIPBLASLT_EPILOGUE_GELU_BIAS);
}
else {
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, HIPBLASLT_EPILOGUE_BIAS);
}
}
size_t workspace_size = GetHipblasltWorkspaceSize();
auto op_handle = at::cuda::getCurrentCUDABlasLtHandle();
size_t ret_workspace_size = 0;
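// ask hipBLASLt whether this algo can run the problem and how much workspace it needs;
// reject candidates that are unsupported or whose required workspace does not fit the
// configured budget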
auto status = hipblaslt_ext::matmulIsAlgoSupported(op_handle,
matmul.descriptor(),
&alpha,
mat_a,
mat_b,
&beta,
mat_c,
mat_c,
algo_,
ret_workspace_size);
if (status == HIPBLAS_STATUS_SUCCESS) {
if (ret_workspace_size >= workspace_size) {
return FAIL;
}
}
else {
return FAIL;
}
void* workspace_buffer = nullptr;
if (workspace_size > 0) {
workspace_buffer = c10::cuda::CUDACachingAllocator::raw_alloc(workspace_size);
}
TORCH_HIPBLASLT_CHECK(hipblasLtMatmul(op_handle,
matmul.descriptor(),
&alpha,
params->a,
mat_a,
params->b,
mat_b,
&beta,
params->c,
mat_c,
params->c,
mat_c,
&algo_,
workspace_buffer,
workspace_size,
at::cuda::getCurrentCUDAStream()));
//TORCH_HIPBLASLT_CHECK(hipblasLtMatmulDescDestroy(matmul));
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_a));
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_b));
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_c));
if (workspace_size > 0) {
c10::cuda::CUDACachingAllocator::raw_delete(workspace_buffer);
}
return OK;
}
private:
hipblasLtMatmulAlgo_t algo_;
};
template <typename AT, typename BT, typename CT, BlasOp ALayout, BlasOp BLayout, typename ParamsT>
auto GetHipBlasLtTypeStringAndOps() {
hipblasOperation_t transa_outer = MapLayoutToHipBlasLt(ALayout);
hipblasOperation_t transb_outer = MapLayoutToHipBlasLt(BLayout);
auto a_datatype = HipDataTypeFor<AT>();
auto b_datatype = HipDataTypeFor<BT>();
auto in_out_datatype = HipDataTypeFor<CT>();
std::vector<hipblasLtMatmulHeuristicResult_t> heuristic_result;
hipblasLtHandle_t handle;
TORCH_HIPBLASLT_CHECK(hipblasLtCreate(&handle));
TORCH_HIPBLASLT_CHECK(hipblaslt_ext::getAllAlgos(handle,
hipblaslt_ext::GemmType::HIPBLASLT_GEMM,
transa_outer,
transb_outer,
a_datatype,
b_datatype,
in_out_datatype,
in_out_datatype,
HIPBLAS_COMPUTE_32F,
heuristic_result));
TORCH_HIPBLASLT_CHECK(hipblasLtDestroy(handle));
// Sort heuristic_result by algo index to make sure the order of returned algos is deterministic.
std::sort(heuristic_result.begin(),
heuristic_result.end(),
[](hipblasLtMatmulHeuristicResult_t& a, hipblasLtMatmulHeuristicResult_t& b) {
return hipblaslt_ext::getIndexFromAlgo(a.algo) < hipblaslt_ext::getIndexFromAlgo(b.algo);
});
int returned_algo_count = heuristic_result.size();
std::vector<std::pair<std::string, std::unique_ptr<Callable<ParamsT>>>> ret;
for (int i = 0; i < returned_algo_count; i++) {
auto algo = heuristic_result[i].algo;
int algo_index = hipblaslt_ext::getIndexFromAlgo(algo);
auto callable = std::make_unique<HipblasltGemmOp<AT, BT, CT, ALayout, BLayout, ParamsT>>(algo);
std::string type_string = c10::str(
"Gemm_Hipblaslt_", _charFromhipblasOp(transa_outer), _charFromhipblasOp(transb_outer), "_", algo_index);
ret.emplace_back(type_string, std::move(callable));
}
return ret;
}
template <typename T, BlasOp ALayout, BlasOp BLayout>
auto GetHipBlasLtGemmTypeStringAndOps() {
return GetHipBlasLtTypeStringAndOps<T, T, T, ALayout, BLayout, GemmParams<T>>();
}
template <typename T, BlasOp ALayout, BlasOp BLayout>
auto GetHipBlasLtGemmAndBiasTypeStringAndOps() {
return GetHipBlasLtTypeStringAndOps<T, T, T, ALayout, BLayout, GemmAndBiasParams<T>>();
}
template <typename T, BlasOp ALayout, BlasOp BLayout>
auto GetHipBlasLtGemmStridedBatchedTypeStringAndOps() {
return GetHipBlasLtTypeStringAndOps<T, T, T, ALayout, BLayout, GemmStridedBatchedParams<T>>();
}
template <typename AT, typename BT, typename CT, BlasOp ALayout, BlasOp BLayout>
auto GetHipBlasLtScaledGemmTypeStringAndOps() {
return GetHipBlasLtTypeStringAndOps<AT, BT, CT, ALayout, BLayout, ScaledGemmParams<CT>>();
}
#undef TORCH_HIPBLASLT_CHECK
} // namespace at::cuda::tunable

View File

@ -0,0 +1,275 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/tunable/TunableOp.h>
#include <ATen/cuda/tunable/GemmCommon.h>
#include <c10/util/StringUtil.h>
#define ROCBLAS_BETA_FEATURES_API
#include <rocblas/rocblas.h>
#define TORCH_ROCBLAS_CHECK(EXPR) \
do { \
rocblas_status __err = EXPR; \
TORCH_CHECK(__err == rocblas_status_success, \
"rocblas error: ", \
rocblas_status_to_string(__err), \
" when calling `" #EXPR "`"); \
} while (0)
namespace at::cuda::tunable {
template <typename T>
constexpr rocblas_datatype RocBlasDataTypeFor();
template <>
constexpr rocblas_datatype RocBlasDataTypeFor<float>() {
return rocblas_datatype_f32_r;
}
template <>
constexpr rocblas_datatype RocBlasDataTypeFor<double>() {
return rocblas_datatype_f64_r;
}
template <>
constexpr rocblas_datatype RocBlasDataTypeFor<Half>() {
return rocblas_datatype_f16_r;
}
template <>
constexpr rocblas_datatype RocBlasDataTypeFor<BFloat16>() {
return rocblas_datatype_bf16_r;
}
template <>
constexpr rocblas_datatype RocBlasDataTypeFor<c10::complex<float>>() {
return rocblas_datatype_f32_c;
}
template <>
constexpr rocblas_datatype RocBlasDataTypeFor<c10::complex<double>>() {
return rocblas_datatype_f64_c;
}
template <typename T>
constexpr rocblas_datatype RocBlasComputeTypeFor();
template <>
constexpr rocblas_datatype RocBlasComputeTypeFor<float>() {
return rocblas_datatype_f32_r;
}
template <>
constexpr rocblas_datatype RocBlasComputeTypeFor<double>() {
return rocblas_datatype_f64_r;
}
template <>
constexpr rocblas_datatype RocBlasComputeTypeFor<Half>() {
// Note that we're returning the _compute_ type for a given datatype.
// As of 12/2022, using compute type FP16 for 16-bit floats was much
// slower than using compute type FP32. So we use FP32 compute even for
// FP16 datatypes. This is how GEMM is implemented even in the function
// rocblasGemmHelper (see fpgeneric.h)
return rocblas_datatype_f32_r;
}
template <>
constexpr rocblas_datatype RocBlasComputeTypeFor<BFloat16>() {
// Note that we're returning the _compute_ type for a given datatype.
// As of 12/2022, using compute type FP16 for 16-bit floats was much
// slower than using compute type FP32. So we use FP32 compute even for
// BF16 datatypes. This is how GEMM is implemented even in the function
// rocblasGemmHelper (see fpgeneric.h)
return rocblas_datatype_f32_r;
}
template <>
constexpr rocblas_datatype RocBlasComputeTypeFor<c10::complex<float>>() {
return rocblas_datatype_f32_c;
}
template <>
constexpr rocblas_datatype RocBlasComputeTypeFor<c10::complex<double>>() {
return rocblas_datatype_f64_c;
}
template <typename T>
auto DoCastForHalfOrBfloat16(const T fp) {
return fp;
}
template <>
inline auto DoCastForHalfOrBfloat16<Half>(const Half fp) {
// alpha and beta must match the compute type; for Half the compute type is float.
float h = fp;
return h;
}
template <>
inline auto DoCastForHalfOrBfloat16<BFloat16>(const BFloat16 fp) {
// alpha and beta must match the compute type; for BFloat16 the compute type is float.
float h = fp;
return h;
}
static rocblas_operation _rocblasOpFromChar(char op) {
switch (op) {
case 'n':
case 'N':
return rocblas_operation_none;
case 't':
case 'T':
return rocblas_operation_transpose;
case 'c':
case 'C':
return rocblas_operation_conjugate_transpose;
}
AT_ERROR(
"_rocblasOpFromChar input should be 't', 'n' or 'c' but got `", op, "`");
}
template <typename T>
class RocblasGemmOp : public Callable<GemmParams<T>> {
public:
RocblasGemmOp(int solution) : solution_{solution} {}
TuningStatus Call(const GemmParams<T>* params) override {
auto input_output_type = RocBlasDataTypeFor<T>();
auto compute_type = RocBlasComputeTypeFor<T>();
auto h_a = DoCastForHalfOrBfloat16(params->alpha);
auto h_b = DoCastForHalfOrBfloat16(params->beta);
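// rocblas_gemm_ex takes separate C (input) and D (output) matrices; passing params->c
// for both computes the GEMM in place, matching the standard BLAS-style interface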
auto status = rocblas_gemm_ex(
(rocblas_handle)at::cuda::getCurrentCUDABlasHandle(),
_rocblasOpFromChar(params->transa),
_rocblasOpFromChar(params->transb),
params->m, params->n, params->k,
&h_a,
params->a, input_output_type, params->lda,
params->b, input_output_type, params->ldb,
&h_b,
params->c, input_output_type, params->ldc,
params->c, input_output_type, params->ldc,
compute_type,
rocblas_gemm_algo_solution_index,
solution_,
rocblas_gemm_flags_none);
if (status != rocblas_status_success) {
return FAIL;
}
return OK;
}
private:
int solution_;
};
template <typename T>
auto GetRocBlasGemmTypeStringAndOps() {
rocblas_handle handle = (rocblas_handle)at::cuda::getCurrentCUDABlasHandle();
int solution_size;
auto input_output_type = RocBlasDataTypeFor<T>();
auto compute_type = RocBlasComputeTypeFor<T>();
// Get the number of available solutions
TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle,
input_output_type,
input_output_type,
compute_type,
rocblas_gemm_flags_none,
nullptr,
&solution_size));
std::vector<int> solutions(solution_size);
// Get the list of available solutions
TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle,
input_output_type,
input_output_type,
compute_type,
rocblas_gemm_flags_none,
solutions.data(),
&solution_size));
// Sort the solutions in ascending order to make the solution vector deterministic across runs
std::sort(solutions.begin(), solutions.end());
std::vector<std::pair<std::string, std::unique_ptr<Callable<GemmParams<T>>>>> ret;
for (size_t i = 0; i < solutions.size(); ++i) {
auto callable = std::make_unique<RocblasGemmOp<T>>(solutions[i]);
ret.emplace_back(std::make_pair(c10::str("Gemm_Rocblas_", solutions[i]), std::move(callable)));
}
return ret;
}
template <typename T>
class RocblasGemmStridedBatchedOp : public Callable<GemmStridedBatchedParams<T>> {
public:
RocblasGemmStridedBatchedOp(int solution) : solution_{solution} {}
TuningStatus Call(const GemmStridedBatchedParams<T>* params) override {
auto input_output_type = RocBlasDataTypeFor<T>();
auto compute_type = RocBlasComputeTypeFor<T>();
auto h_a = DoCastForHalfOrBfloat16(params->alpha);
auto h_b = DoCastForHalfOrBfloat16(params->beta);
auto status = rocblas_gemm_strided_batched_ex(
(rocblas_handle)at::cuda::getCurrentCUDABlasHandle(),
_rocblasOpFromChar(params->transa),
_rocblasOpFromChar(params->transb),
params->m, params->n, params->k,
&h_a,
params->a, input_output_type, params->lda, params->stride_a,
params->b, input_output_type, params->ldb, params->stride_b,
&h_b,
params->c, input_output_type, params->ldc, params->stride_c,
params->c, input_output_type, params->ldc, params->stride_c,
params->batch,
compute_type,
rocblas_gemm_algo_solution_index,
solution_,
rocblas_gemm_flags_none);
if (status != rocblas_status_success) {
return FAIL;
}
return OK;
}
private:
int solution_;
};
template <typename T>
auto GetRocBlasGemmStridedBatchedTypeStringAndOps() {
rocblas_handle handle = (rocblas_handle)at::cuda::getCurrentCUDABlasHandle();
int solution_size;
auto input_output_type = RocBlasDataTypeFor<T>();
auto compute_type = RocBlasComputeTypeFor<T>();
// Get the number of available solutions
TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle,
input_output_type,
input_output_type,
compute_type,
rocblas_gemm_flags_none,
nullptr,
&solution_size));
std::vector<int> solutions(solution_size);
// Get the list of available solutions
TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle,
input_output_type,
input_output_type,
compute_type,
rocblas_gemm_flags_none,
solutions.data(),
&solution_size));
// Sort the solutions in ascending order to make the solution vector deterministic across runs
std::sort(solutions.begin(), solutions.end());
std::vector<std::pair<std::string, std::unique_ptr<Callable<GemmStridedBatchedParams<T>>>>> ret;
for (size_t i = 0; i < solutions.size(); ++i) {
auto callable = std::make_unique<RocblasGemmStridedBatchedOp<T>>(solutions[i]);
ret.emplace_back(std::make_pair(c10::str("Gemm_Rocblas_", solutions[i]), std::move(callable)));
}
return ret;
}
} // namespace at::cuda::tunable

View File

@ -0,0 +1,34 @@
// Original TunableOp is from onnxruntime.
// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h
// https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
//
// Adapting TunableOp into PyTorch
// Copyright (c) Advanced Micro Devices, Inc.
//
#pragma once
#include <cuda_runtime.h>
#include <ATen/cuda/tunable/Tunable.h>
namespace at::cuda::tunable {
class StreamTimer : public ITimer {
public:
StreamTimer();
virtual ~StreamTimer() override;
void Start() override;
void End() override;
float Duration() override;
private:
cudaEvent_t start_;
cudaEvent_t end_;
};
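// A plausible definition of the interface above, consistent with the cudaEvent_t members;
// the actual implementation lives in the corresponding .cpp file and may differ. Shown only
// as a sketch and guarded out of compilation (assumes AT_CUDA_CHECK and
// at::cuda::getCurrentCUDAStream() are available at the point of definition).
#if 0
StreamTimer::StreamTimer() {
  AT_CUDA_CHECK(cudaEventCreate(&start_));
  AT_CUDA_CHECK(cudaEventCreate(&end_));
}

StreamTimer::~StreamTimer() {
  // not checked: destructors should not throw
  cudaEventDestroy(start_);
  cudaEventDestroy(end_);
}

void StreamTimer::Start() {
  // record on the current stream so only work queued after this point is timed
  AT_CUDA_CHECK(cudaEventRecord(start_, at::cuda::getCurrentCUDAStream()));
}

void StreamTimer::End() {
  AT_CUDA_CHECK(cudaEventRecord(end_, at::cuda::getCurrentCUDAStream()));
  AT_CUDA_CHECK(cudaEventSynchronize(end_));
}

float StreamTimer::Duration() {
  float ms = 0;
  // cudaEventElapsedTime reports the elapsed milliseconds between the two recorded events
  AT_CUDA_CHECK(cudaEventElapsedTime(&ms, start_, end_));
  return ms;
}
#endif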
} // namespace at::cuda::tunable

View File

@ -0,0 +1,246 @@
// Original TunableOp is from onnxruntime.
// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h
// https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
//
// Adapting TunableOp into PyTorch
// Copyright (c) Advanced Micro Devices, Inc.
//
#pragma once
#include <c10/util/CallOnce.h>
#include <fstream>
#include <functional>
#include <iostream>
#include <memory>
#include <mutex>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>
namespace at::cuda::tunable {
namespace detail {
struct MaybeDelete {
bool owns_pointer;
void operator()(std::ostream* os) const { if (owns_pointer) delete os; }
};
using OstreamPtr = std::unique_ptr<std::ostream, MaybeDelete>;
static OstreamPtr get_stream(std::string filename) {
if (filename.compare("out") == 0) {
return OstreamPtr { &std::cout, MaybeDelete {false} };
}
else if (filename.compare("err") == 0) {
return OstreamPtr { &std::cerr, MaybeDelete {false} };
}
else {
return OstreamPtr { new std::ofstream {filename.c_str()}, MaybeDelete {true} };
}
}
} // namespace detail
static void TunableLog(int level, const std::string& msg) {
static const char *env_file = getenv("PYTORCH_TUNABLEOP_VERBOSE_FILENAME");
static const char *env_verbose = getenv("PYTORCH_TUNABLEOP_VERBOSE");
static int level_user = env_verbose ? atoi(env_verbose) : 0;
static auto streamptr = detail::get_stream(env_file ? env_file : "err");
if (level_user >= level) {
(*streamptr) << msg << std::endl;
}
}
#define TUNABLE_LOGV(LEVEL, ...) TunableLog(LEVEL, c10::str(__VA_ARGS__))
#define TUNABLE_LOG1(...) TUNABLE_LOGV(1, __VA_ARGS__)
#define TUNABLE_LOG2(...) TUNABLE_LOGV(2, __VA_ARGS__)
#define TUNABLE_LOG3(...) TUNABLE_LOGV(3, __VA_ARGS__)
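// Verbosity is driven by the environment: PYTORCH_TUNABLEOP_VERBOSE=1 enables TUNABLE_LOG1,
// =2 additionally enables TUNABLE_LOG2, and =3 enables all three levels. Output goes to
// stderr unless PYTORCH_TUNABLEOP_VERBOSE_FILENAME names "out", "err", or a file path.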
enum TORCH_CUDA_CPP_API TuningStatus {
OK = 0,
FAIL = 1,
UNSUPPORTED = 2,
};
// Mapping from params signature to kernel id
class TORCH_CUDA_CPP_API ResultEntry {
public:
explicit ResultEntry(const std::string& key, double time) : key_(key), time_(time) {}
bool operator==(const ResultEntry& other) { return key_ == other.key_; }
bool operator!=(const ResultEntry& other) { return key_ != other.key_; }
operator std::string () { return key_; }
std::string GetKey() const { return key_; }
double GetTime() const { return time_; }
friend std::ostream& operator<<(std::ostream& stream, const ResultEntry& entry);
static ResultEntry Null() { return ResultEntry("Null", 0.0); }
static ResultEntry Default() { return ResultEntry("Default", 0.0); }
private:
std::string key_;
double time_;
};
typedef std::unordered_map<std::string, ResultEntry> KernelMap;
typedef std::unordered_map<std::string, KernelMap> ResultsMap;
struct TORCH_CUDA_CPP_API TuningResults {
// Validates if these results are compatible with the libraries
std::unordered_map<std::string, std::string> validators;
// Mapping from Callable signature to Callable's tuning result
ResultsMap results;
};
class TORCH_CUDA_CPP_API TuningResultsManager {
public:
TuningResultsManager() = default;
~TuningResultsManager() = default;
KernelMap Lookup(const std::string& op_signature);
ResultEntry Lookup(const std::string& op_signature, const std::string& params_signature);
inline void AddImpl(const std::string& op_signature,
const std::string& params_signature,
ResultEntry best,
KernelMap& kernel_map);
void Add(const std::string& op_signature,
const std::string& params_signature,
ResultEntry best);
void Delete(const std::string& op_signature, const std::string& params_signature);
inline void DisjointMergeImpl(
const std::string& op_signature,
const KernelMap& kernel_map,
/*out*/ ResultsMap& results);
void Load(const ResultsMap& results_to_load);
ResultsMap Dump();
void DisjointMerge(const std::string& op_signature, const KernelMap& kernel_map);
size_t GetSize();
private:
std::mutex lock_;
ResultsMap results_;
};
class TORCH_CUDA_CPP_API TuningResultsValidator {
public:
using GetFunc = std::function<std::string()>;
using ValidateFunc = std::function<TuningStatus(const std::string&)>;
using GetValidateFuncs = std::unordered_map<std::string, std::pair<GetFunc, ValidateFunc>>;
TuningResultsValidator();
~TuningResultsValidator() = default;
std::unordered_map<std::string, std::string> GetAllValidators() const;
TuningStatus ValidateAll(const std::unordered_map<std::string, std::string>& to_validate) const;
void RegisterValidator(const std::string& key, const GetFunc& gf, const ValidateFunc& vf);
protected:
std::string GetPyTorchVersion() const;
TuningStatus ValidatePyTorchVersion(const std::string& value) const;
public:
static constexpr const std::array mandatory_keys{"PT_VERSION"};
private:
GetValidateFuncs validators_;
};
class TORCH_CUDA_CPP_API TuningContext {
public:
TuningContext();
~TuningContext();
TuningContext(TuningContext &) = delete;
TuningContext(TuningContext &&) = delete;
TuningContext &operator=(TuningContext &) = delete;
TuningContext &operator=(TuningContext &&) = delete;
void EnableTunableOp(bool value);
bool IsTunableOpEnabled() const;
void EnableTuning(bool value);
bool IsTuningEnabled() const;
void EnableNumericsCheck(bool value);
bool IsNumericsCheckEnabled() const;
void SetMaxTuningDurationMs(int max_duration_ms);
int GetMaxTuningDurationMs() const;
void SetMaxTuningIterations(int max_iter);
int GetMaxTuningIterations() const;
void SetMaxWarmupDurationMs(int max_duration_ms);
int GetMaxWarmupDurationMs() const;
void SetMaxWarmupIterations(int max_iter);
int GetMaxWarmupIterations() const;
void EnableICacheFlush(bool value);
bool IsICacheFlushEnabled() const;
void SetRotatingBufferSize(int size);
int GetRotatingBufferSize() const;
TuningResultsManager& GetTuningResultsManager();
TuningResultsValidator& GetTuningResultsValidator();
TuningResults GetTuningResults();
TuningStatus LoadTuningResults(const TuningResults& tr);
void SetFilename(const std::string& filename, bool insert_device_ordinal=false);
std::string GetFilename() const;
void WriteFileOnExit(bool value);
bool ReadFile(const std::string& filename={});
bool WriteFile(const std::string& filename={});
private:
bool enable_;
bool tuning_enable_;
bool manager_initialized_;
bool write_file_on_exit_;
bool numerics_check_enable_;
int max_tuning_duration_ms_;
int max_tuning_iterations_;
int max_warmup_duration_ms_;
int max_warmup_iterations_;
bool icache_flush_;
int rotating_buffer_size_;
mutable TuningResultsManager manager_;
mutable c10::once_flag manager_init_once_;
TuningResultsValidator validator_;
std::string filename_;
size_t results_count_from_input_file_;
};
TORCH_CUDA_CPP_API TuningContext* getTuningContext();
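// Illustrative sketch (guarded out of compilation) of configuring tuning programmatically
// through the context declared above; the helper name and the values are arbitrary examples.
#if 0
inline void example_configure_tunableop() {
  TuningContext* ctx = getTuningContext();
  ctx->EnableTunableOp(true);        // dispatch through tuned results
  ctx->EnableTuning(true);           // allow new tunings, not just lookups of prior results
  ctx->SetMaxTuningIterations(30);   // cap the benchmarking work per candidate
  ctx->SetFilename("my_tunableop_results.csv", /*insert_device_ordinal=*/true);
  ctx->WriteFileOnExit(true);        // persist results when the process exits
}
#endif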
class ITimer {
public:
ITimer() = default;
virtual ~ITimer() = default;
virtual void Start() = 0;
virtual void End() = 0;
/// Computes the elapsed time in milliseconds between Start() and End()
virtual float Duration() = 0;
};
} // namespace at::cuda::tunable

View File

@ -0,0 +1,307 @@
// Original TunableOp is from onnxruntime.
// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h
// https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
//
// Adapting TunableOp into PyTorch
// Copyright (c) Advanced Micro Devices, Inc.
//
#pragma once
#include <ATen/cuda/tunable/GemmCommon.h>
#ifdef USE_ROCM
#include <ATen/cuda/tunable/GemmHipblaslt.h>
#include <ATen/cuda/tunable/GemmRocblas.h>
#endif
#include <ATen/cuda/tunable/StreamTimer.h>
#include <ATen/cuda/tunable/TunableOp.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e4m3fnuz.h>
#include <c10/util/Float8_e5m2.h>
#include <c10/util/Float8_e5m2fnuz.h>
#include <c10/util/StringUtil.h>
namespace at::cuda::tunable {
template <typename T>
class DefaultGemmOp : public Callable<GemmParams<T>> {
public:
TuningStatus Call(const GemmParams<T>* params) override {
at::cuda::blas::gemm_internal<T>(
params->transa, params->transb,
params->m, params->n, params->k,
params->alpha,
params->a, params->lda,
params->b, params->ldb,
params->beta,
params->c, params->ldc);
return OK;
}
};
static bool _transposeBoolFromChar(char op) {
return op == 't' || op == 'T';
}
template <typename T>
class DefaultGemmAndBiasOp : public Callable<GemmAndBiasParams<T>> {
public:
TuningStatus Call(const GemmAndBiasParams<T>* params) override {
at::cuda::blas::gemm_and_bias<T>(
_transposeBoolFromChar(params->transa),
_transposeBoolFromChar(params->transb),
params->m, params->n, params->k,
params->alpha,
params->a, params->lda,
params->b, params->ldb,
params->bias,
params->c, params->ldc,
params->activation);
return OK;
}
};
template <typename T>
class DefaultGemmStridedBatchedOp : public Callable<GemmStridedBatchedParams<T>> {
public:
TuningStatus Call(const GemmStridedBatchedParams<T>* params) override {
at::cuda::blas::bgemm_internal<T>(
params->transa, params->transb,
params->m, params->n, params->k,
params->alpha,
params->a, params->lda, params->stride_a,
params->b, params->ldb, params->stride_b,
params->beta,
params->c, params->ldc, params->stride_c,
params->batch);
return OK;
}
};
template <typename T>
class DefaultScaledGemmOp : public Callable<ScaledGemmParams<T>> {
public:
TuningStatus Call(const ScaledGemmParams<T>* params) override {
at::cuda::blas::scaled_gemm(
params->transa,
params->transb,
params->m,
params->n,
params->k,
params->a,
params->a_scale_ptr,
params->lda,
params->a_dtype,
params->b,
params->b_scale_ptr,
params->ldb,
params->b_dtype,
params->bias_ptr,
params->bias_dtype,
params->c,
params->c_scale_ptr,
params->ldc,
params->c_dtype,
params->amax_ptr,
params->use_fast_accum);
return OK;
}
};
template <typename T>
inline bool IsZero(T v) {
return v == 0.0f;
}
template <>
inline bool IsZero(BFloat16 v) {
return v.x == 0;
}
template <>
inline bool IsZero(Half v) {
return float(v) == 0.0f;
}
template <>
inline bool IsZero(c10::complex<double> v) {
return v == 0.0;
}
template <>
inline bool IsZero(c10::complex<float> v) {
return v == 0.0f;
}
template <typename T>
inline std::string TypeName(T v) {
return "unknown";
}
template <>
inline std::string TypeName(float v) {
return "float";
}
template <>
inline std::string TypeName(double v) {
return "double";
}
template <>
inline std::string TypeName(BFloat16 v) {
return "BFloat16";
}
template <>
inline std::string TypeName(Half v) {
return "Half";
}
template <>
inline std::string TypeName(Float8_e4m3fn v) {
return "Float8_e4m3fn";
}
template <>
inline std::string TypeName(Float8_e5m2 v) {
return "Float8_e5m2";
}
template <>
inline std::string TypeName(Float8_e4m3fnuz v) {
return "Float8_e4m3fnuz";
}
template <>
inline std::string TypeName(Float8_e5m2fnuz v) {
return "Float8_e5m2fnuz";
}
template <>
inline std::string TypeName(c10::complex<double> v) {
return "c10::complex<double>";
}
template <>
inline std::string TypeName(c10::complex<float> v) {
return "c10::complex<float>";
}
template <typename T, BlasOp ALayout, BlasOp BLayout>
class GemmTunableOp : public TunableOp<GemmParams<T>, StreamTimer> {
public:
GemmTunableOp() {
this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmOp<T>>());
#ifdef USE_ROCM
static const char *env_rocblas = std::getenv("PYTORCH_TUNABLEOP_ROCBLAS_ENABLED");
if (env_rocblas == nullptr || strcmp(env_rocblas, "1") == 0) {
for (auto&& [name, op] : GetRocBlasGemmTypeStringAndOps<T>()) {
this->RegisterOp(std::move(name), std::move(op));
}
}
static const char *env_hipblaslt = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
if (env_hipblaslt == nullptr || strcmp(env_hipblaslt, "1") == 0) {
// disallow tuning of hipblaslt with c10::complex
if constexpr (
!std::is_same_v<T, c10::complex<float>> &&
!std::is_same_v<T, c10::complex<double>>) {
for (auto&& [name, op] : GetHipBlasLtGemmTypeStringAndOps<T, ALayout, BLayout>()) {
this->RegisterOp(std::move(name), std::move(op));
}
}
}
#endif
}
std::string Signature() override {
return c10::str("GemmTunableOp_", TypeName<T>(T{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout));
}
};
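// Illustrative sketch (guarded out of compilation): routing a single GEMM through the
// tunable op. The pointers and sizes are hypothetical; in PyTorch this dispatch is normally
// performed by the at::cuda::blas wrappers rather than called directly.
#if 0
inline void example_tunable_gemm(const float* A, const float* B, float* C,
                                 int64_t m, int64_t n, int64_t k) {
  GemmParams<float> params;
  params.transa = 'T';               // op(A) = A^T, A stored k x m column-major
  params.transb = 'N';               // op(B) = B,   B stored k x n column-major
  params.m = m;  params.n = n;  params.k = k;
  params.alpha = 1.0f;  params.beta = 0.0f;
  params.a = A;  params.lda = k;
  params.b = B;  params.ldb = k;
  params.c = C;  params.ldc = m;
  // one static op per (type, layout) combination so cached tuning results are reused
  static GemmTunableOp<float, BlasOp::T, BlasOp::N> gemm;
  gemm(&params);                     // looks up or tunes, then calls the selected backend
}
#endif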
template <typename T, BlasOp ALayout, BlasOp BLayout>
class GemmAndBiasTunableOp : public TunableOp<GemmAndBiasParams<T>, StreamTimer> {
public:
GemmAndBiasTunableOp() {
this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmAndBiasOp<T>>());
#ifdef USE_ROCM
static const char *env_hipblaslt = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
if (env_hipblaslt == nullptr || strcmp(env_hipblaslt, "1") == 0) {
// disallow tuning of hipblaslt with c10::complex
if constexpr (
!std::is_same_v<T, c10::complex<float>> &&
!std::is_same_v<T, c10::complex<double>>) {
for (auto&& [name, op] : GetHipBlasLtGemmAndBiasTypeStringAndOps<T, ALayout, BLayout>()) {
this->RegisterOp(std::move(name), std::move(op));
}
}
}
#endif
}
std::string Signature() override {
return c10::str("GemmAndBiasTunableOp_", TypeName<T>(T{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout));
}
};
template <typename T, BlasOp ALayout, BlasOp BLayout>
class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>, StreamTimer> {
public:
GemmStridedBatchedTunableOp() {
this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmStridedBatchedOp<T>>());
#ifdef USE_ROCM
static const char *env_rocblas = std::getenv("PYTORCH_TUNABLEOP_ROCBLAS_ENABLED");
if (env_rocblas == nullptr || strcmp(env_rocblas, "1") == 0) {
for (auto&& [name, op] : GetRocBlasGemmStridedBatchedTypeStringAndOps<T>()) {
this->RegisterOp(std::move(name), std::move(op));
}
}
static const char *env_hipblaslt = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
if (env_hipblaslt == nullptr || strcmp(env_hipblaslt, "1") == 0) {
// disallow tuning of hipblaslt with c10::complex
if constexpr (
!std::is_same_v<T, c10::complex<float>> &&
!std::is_same_v<T, c10::complex<double>>) {
for (auto&& [name, op] : GetHipBlasLtGemmStridedBatchedTypeStringAndOps<T, ALayout, BLayout>()) {
this->RegisterOp(std::move(name), std::move(op));
}
}
}
#endif
}
std::string Signature() override {
return c10::str("GemmStridedBatchedTunableOp_", TypeName<T>(T{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout));
}
};
template <typename AT, typename BT, typename CT, BlasOp ALayout, BlasOp BLayout>
class ScaledGemmTunableOp : public TunableOp<ScaledGemmParams<CT>, StreamTimer> {
public:
ScaledGemmTunableOp() {
this->RegisterOp(std::string("Default"), std::make_unique<DefaultScaledGemmOp<CT>>());
#ifdef USE_ROCM
for (auto&& [name, op] : GetHipBlasLtScaledGemmTypeStringAndOps<AT, BT, CT, ALayout, BLayout>()) {
this->RegisterOp(std::move(name), std::move(op));
}
#endif
}
std::string Signature() override {
return c10::str("ScaledGemmTunableOp",
"_", TypeName<AT>(AT{}),
"_", TypeName<BT>(BT{}),
"_", TypeName<CT>(CT{}),
"_", BlasOpToString(ALayout), BlasOpToString(BLayout));
}
};
} // namespace at::cuda::tunable

View File

@ -0,0 +1,286 @@
// Original TunableOp is from onnxruntime.
// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h
// https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
//
// Adapting TunableOp into PyTorch
// Copyright (c) Advanced Micro Devices, Inc.
//
#pragma once
#include <ATen/cuda/tunable/Tunable.h>
#include <ATen/cuda/Sleep.h>
#include <c10/cuda/CUDACachingAllocator.h>
#ifndef _WIN32
#include <cxxabi.h>
#endif
#include <string>
#include <type_traits>
#include <unordered_map>
#include <vector>
namespace at::cuda::tunable {
template <typename ParamsT>
class Callable {
public:
Callable() = default;
Callable(Callable&&) = default;
virtual ~Callable() = default;
virtual TuningStatus Call(const ParamsT*) {
return FAIL;
}
virtual TuningStatus IsSupported(const ParamsT* params) {
return Call(params);
}
};
template <typename ParamsT, typename TimerT>
class TunableOp {
public:
TunableOp() = default;
TunableOp(TunableOp&&) = default;
virtual ~TunableOp() = default;
TuningStatus operator()(const ParamsT* params) {
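// resolution order: if TunableOp is disabled, always use the Default implementation;
// if enabled, use a previously recorded result when one exists; otherwise, when tuning
// is also enabled, benchmark all registered candidates and cache the winner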
ResultEntry result = ResultEntry::Null();
TuningContext* ctx = getTuningContext();
if (ctx->IsTunableOpEnabled()) {
auto& mgr = ctx->GetTuningResultsManager();
auto op_sig = Signature();
auto params_sig = params->Signature();
result = mgr.Lookup(op_sig, params_sig);
// If no previous tuning result was found, do the tuning, but only if tuning is enabled
if (result == ResultEntry::Null() && ctx->IsTuningEnabled()) {
result = FindFastest(params);
mgr.Add(op_sig, params_sig, result);
}
}
else {
result = ResultEntry::Default();
}
if (result == ResultEntry::Null()) {
TUNABLE_LOG2("no result, using default");
result = ResultEntry::Default();
}
auto iter = ops_.find(result);
TORCH_CHECK(iter != ops_.end());
return iter->second->Call(params);
}
virtual std::string Signature() {
// According to C++17 standard https://wg21.link/n4659 section 15.7.4
// > if the operand of typeid refers to the
// > object under construction or destruction, typeid yields the std::type_info object representing the constructor
// > or destructor's class.
// So delay the op signature generation.
c10::call_once(signature_init_once_, [this]() { signature_ = CreateSignature(); });
return signature_;
}
protected:
void RegisterOp(const std::string& name, std::unique_ptr<Callable<ParamsT>> op) {
this->op_names_.emplace_back(name);
this->ops_.emplace(name, std::move(op));
}
private:
static void WarmUp(Callable<ParamsT> *op, const std::vector<ParamsT*> &param, size_t num_iter, size_t &offset) {
TuningContext* ctx = getTuningContext();
bool do_flush = ctx->IsICacheFlushEnabled();
for (size_t i = 0; i < num_iter; i++) {
if (do_flush) {
at::cuda::flush_icache();
}
TORCH_CHECK(op->Call(param[(i+offset++)%param.size()]) == OK);
}
}
static double Profile(Callable<ParamsT> *op, const std::vector<ParamsT*> &param, size_t num_iter, size_t &offset) {
TuningContext* ctx = getTuningContext();
bool do_flush = ctx->IsICacheFlushEnabled();
TimerT timer{};
timer.Start();
for (size_t i = 0; i < num_iter; i++) {
if (do_flush) {
at::cuda::flush_icache();
}
TORCH_CHECK(op->Call(param[(i+offset++)%param.size()]) == OK);
}
timer.End();
return timer.Duration() / num_iter;
}
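// Both helpers advance the shared offset so that, when buffer rotation is enabled,
// consecutive calls cycle through the duplicated parameter copies instead of re-touching
// the same (likely cached) buffers.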
protected:
virtual ResultEntry FindFastest(const ParamsT* params) {
TuningContext* ctx = getTuningContext();
auto op_sig = Signature();
auto params_sig = params->Signature();
TUNABLE_LOG2("finding fastest for ", op_sig, '(', params_sig, ')', " out of ", op_names_.size(), " candidates");
auto min_duration_ms = std::numeric_limits<double>::infinity();
std::string id_name = "Default";
ParamsT* reference_params = nullptr;
// the numerics check option can change at runtime (it is not cached in a static), so query it once per tuned operator rather than per candidate
bool do_numerics_check = ctx->IsNumericsCheckEnabled();
// calculate a reference answer for the numerical check
if (do_numerics_check) {
reference_params = params->DeepCopy(false);
TORCH_CHECK(ops_[ResultEntry::Default()]->Call(reference_params) == OK);
}
// need copies of params to reuse
// make as many copies as will fill the requested rotating buffer size, if requested
// rotating_size guaranteed to be >= 0 even though GetRotatingBufferSize() returns int
size_t rotating_size = ctx->GetRotatingBufferSize();
bool use_buffer_rotation = (rotating_size > 0);
size_t param_size = params->GetSize(use_buffer_rotation);
size_t param_count = (rotating_size / param_size) + 1;
constexpr size_t MB = 1024*1024;
if (use_buffer_rotation) {
TUNABLE_LOG2("Rotating buffer ", rotating_size/MB, " MiB. ",
"Needed Size: ", param_size/MB, " MiB. ",
"Needed number of param copies: ", param_count);
}
TORCH_CHECK(param_count > 0);
std::vector<ParamsT*> reusable_params(param_count);
for (size_t i = 0; i < param_count; i++) {
reusable_params[i] = params->DeepCopy(use_buffer_rotation);
}
// for rotating buffer
size_t offset = 0;
for (size_t i = 0; i < op_names_.size(); i++) {
auto* candidate = ops_[op_names_[i]].get(); // borrow pointer
if (do_numerics_check) {
ParamsT* numerical_params = params->DeepCopy(false);
auto status = candidate->Call(numerical_params);
if (status != OK) {
numerical_params->Delete();
TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
continue;
}
status = reference_params->NumericalCheck(numerical_params);
numerical_params->Delete();
if (status != OK) {
TUNABLE_LOG3("├──numerics check failed for id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
continue;
}
}
else {
auto status = candidate->Call(reusable_params[0]);
if (status != OK) {
TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
continue;
}
}
// collect a small profile
constexpr const int approx_num_iter = 3;
auto approx_duration = Profile(candidate, reusable_params, approx_num_iter, offset);
// bail if too slow
if (approx_duration > 2 * min_duration_ms) {
TUNABLE_LOG3("├──skip slow instance id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
continue;
}
// for warmup, did the user set a max duration, max iterations, or both?
// warmup is allowed to be skipped by setting either iterations or duration to 0
double max_warmup_duration = ctx->GetMaxWarmupDurationMs();
int max_warmup_iter = ctx->GetMaxWarmupIterations();
int warmup_iter = 1; // default
if (max_warmup_duration >= 0) {
int duration_iters = max_warmup_duration / approx_duration;
if (max_warmup_iter >= 0) {
warmup_iter = std::min(max_warmup_iter, duration_iters);
}
else {
warmup_iter = duration_iters;
}
}
else if (max_warmup_iter >= 0) {
warmup_iter = max_warmup_iter;
}
// for tuning, did the user set a max duration, max iterations, or both?
double max_tuning_duration = ctx->GetMaxTuningDurationMs();
int max_tuning_iter = ctx->GetMaxTuningIterations();
int tuning_iter = 100; // default
if (max_tuning_duration > 0) {
int duration_iters = max_tuning_duration / approx_duration;
if (max_tuning_iter > 0) {
tuning_iter = std::min(max_tuning_iter, duration_iters);
}
else {
tuning_iter = duration_iters;
}
}
else if (max_tuning_iter > 0) {
tuning_iter = max_tuning_iter;
}
// tuning must run at least 1 iteration
tuning_iter = std::max(1, tuning_iter);
// do the full warmup followed by tuning
double warmup_ms = warmup_iter * approx_duration;
double tuning_ms = tuning_iter * approx_duration;
TUNABLE_LOG3("├──tuning using "
"warmup iters ", warmup_iter, " [", warmup_ms, " ms] "
"and tuning iters ", tuning_iter, " [", tuning_ms, " ms] ",
"instance id=", i, ", ", op_sig, "(", params_sig, ") ", op_names_[i]);
TUNABLE_LOG3("├──offset at ", offset);
WarmUp(candidate, reusable_params, warmup_iter, offset);
auto duration_ms = Profile(candidate, reusable_params, tuning_iter, offset);
if (duration_ms < min_duration_ms) {
TUNABLE_LOG3("├──found better instance id=", i, ". " , duration_ms, "ms. ", op_names_[i]);
min_duration_ms = duration_ms;
id_name = op_names_[i];
}
}
for (size_t i = 0; i < reusable_params.size(); i++) {
reusable_params[i]->Delete();
}
if (reference_params) {
reference_params->Delete();
}
TUNABLE_LOG2("└──found fastest for ", op_sig, '(', params_sig, ") ", id_name);
return ResultEntry(id_name, min_duration_ms);
}
private:
std::string CreateSignature() {
#ifndef _WIN32
const auto* name = typeid(*this).name();
char buf[256];
size_t buf_len = 256;
abi::__cxa_demangle(name, buf, &buf_len, nullptr);
buf[255] = '\0';
return buf;
#else
return typeid(*this).name();
#endif
}
mutable c10::once_flag signature_init_once_;
std::string signature_;
std::unordered_map<std::string, std::unique_ptr<Callable<ParamsT>>> ops_;
std::vector<std::string> op_names_;
};
struct OpParams {
OpParams() {}
virtual ~OpParams() = default;
virtual std::string Signature() const = 0;
};
} // namespace at::cuda::tunable