I am done

commit 40e2a747cf
parent 720dc28c09
Date: 2024-10-30 22:14:35 +01:00

36901 changed files with 5011519 additions and 0 deletions


@@ -0,0 +1,29 @@
// Copyright © 2022 Apple Inc.
#pragma once
#include <ATen/core/TensorBase.h>
namespace at::detail {
C10_EXPORT TensorBase empty_mps(
IntArrayRef size,
std::optional<ScalarType> dtype_opt,
std::optional<Layout> layout_opt,
std::optional<Device> device_opt,
std::optional<bool> pin_memory_opt,
std::optional<c10::MemoryFormat> memory_format_opt);
C10_EXPORT TensorBase empty_mps(
IntArrayRef size, const TensorOptions &options);
C10_EXPORT TensorBase empty_strided_mps(
IntArrayRef size,
IntArrayRef stride,
ScalarType dtype,
std::optional<Device> device_opt);
C10_EXPORT TensorBase empty_strided_mps(
IntArrayRef size,
IntArrayRef stride,
const TensorOptions &options);
} // namespace at::detail
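
For orientation, here is a minimal, hypothetical call site for the factories declared above. It assumes the usual ATen TensorOptions API and that this header is included; the function name is made up and nothing like it is part of this commit.

// Illustrative only: allocate an uninitialized 2x3 float tensor on the MPS device
// via the TensorOptions overload of empty_mps() declared above.
#include <ATen/ATen.h>

at::TensorBase make_mps_scratch() {
  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kMPS);
  return at::detail::empty_mps({2, 3}, options);
}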


@@ -0,0 +1,535 @@
#pragma once
namespace at::mps {
static const char * indexing_metal_shaders = R"INDEX_METAL(
#include <metal_stdlib>
#include <metal_atomic>
using namespace metal;
struct IndexAB {
constant int64_t* indexArray;
};
template<typename T, typename OffsetsT>
kernel void index_select(
constant IndexAB * indexAB [[buffer(0)]],
constant void * indexSizes [[buffer(1)]],
constant void * indexStrides [[buffer(2)]],
constant OffsetsT * offsets [[buffer(3)]],
constant void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
constant uint32_t & num_indices [[buffer(6)]],
uint thread_index [[thread_position_in_grid]]) {
constant int64_t * index_sizes = (constant int64_t *)indexSizes;
constant int64_t * index_strides = (constant int64_t *)indexStrides;
int64_t offset = 0;
for (uint32_t i = 0; i < num_indices; i++) {
constant int64_t* indexArray = indexAB[i].indexArray;
int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
if (index < 0) {
index += index_sizes[i];
}
offset += index * index_strides[i];
}
device T * out = (device T*)((device char*)outputData + offsets[thread_index].x);
constant T * in = (constant T*)((constant char*)inputData + offsets[thread_index].y + offset);
*out = *in;
}
template<typename T, typename OffsetsT>
void index_put_impl(
constant IndexAB * indexAB,
constant int64_t * index_sizes,
constant int64_t * index_strides,
constant OffsetsT * offsets,
constant void * inputData,
device void * outputData,
constant uint32_t & num_indices,
uint thread_index) {
int64_t offset = 0;
for (uint32_t i = 0; i < num_indices; i++) {
constant int64_t* indexArray = indexAB[i].indexArray;
int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
if (index < 0) {
index += index_sizes[i];
}
offset += index * index_strides[i];
}
device T * out = (device T*)((device char*)outputData + offsets[thread_index].x + offset);
constant T * in = (constant T*)((constant char*)inputData + offsets[thread_index].y);
*out = *in;
}
template<typename T, typename OffsetsT>
kernel void index_put_serial(
constant IndexAB * indexAB [[buffer(0)]],
constant void * indexSizes [[buffer(1)]],
constant void * indexStrides [[buffer(2)]],
constant OffsetsT * offsets [[buffer(3)]],
constant void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
constant uint32_t & num_indices [[buffer(6)]],
constant uint * numIters [[buffer(7)]],
uint thread_index [[thread_position_in_grid]]) {
constant int64_t * index_sizes = (constant int64_t *)indexSizes;
constant int64_t * index_strides = (constant int64_t *)indexStrides;
for (uint iter_i = 0; iter_i < *numIters; iter_i++) {
index_put_impl<T>(indexAB, index_sizes, index_strides, offsets, inputData, outputData, num_indices, iter_i);
}
}
template<typename T, typename OffsetsT>
kernel void index_put(
constant IndexAB * indexAB [[buffer(0)]],
constant void * indexSizes [[buffer(1)]],
constant void * indexStrides [[buffer(2)]],
constant OffsetsT * offsets [[buffer(3)]],
constant void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
constant uint32_t & num_indices [[buffer(6)]],
uint thread_index [[thread_position_in_grid]]) {
constant int64_t * index_sizes = (constant int64_t *)indexSizes;
constant int64_t * index_strides = (constant int64_t *)indexStrides;
index_put_impl<T>(indexAB, index_sizes, index_strides, offsets, inputData, outputData, num_indices, thread_index);
}
#define REGISTER_INDEX_OP(DTYPE_SIZE, IDX_SIZE, DTYPE, INDEX_OP_TYPE, IDX_DTYPE) \
template \
[[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE "_" #IDX_SIZE)]] \
kernel void index_ ## INDEX_OP_TYPE<DTYPE, IDX_DTYPE>( \
constant IndexAB * indexAB [[buffer(0)]], \
constant void * indexSizes [[buffer(1)]], \
constant void * indexStrides [[buffer(2)]], \
constant IDX_DTYPE * offsets [[buffer(3)]], \
constant void * inputData [[buffer(4)]], \
device void * outputData [[buffer(5)]], \
constant uint32_t & num_indices [[buffer(6)]], \
uint thread_index [[thread_position_in_grid]]);
#define REGISTER_INDEX_OP_ALL_DTYPES(INDEX_OP_TYPE) \
REGISTER_INDEX_OP(8bit, idx32, char, INDEX_OP_TYPE, uint3); \
REGISTER_INDEX_OP(8bit, idx64, char, INDEX_OP_TYPE, ulong3); \
REGISTER_INDEX_OP(16bit, idx32, short, INDEX_OP_TYPE, uint3); \
REGISTER_INDEX_OP(16bit, idx64, short, INDEX_OP_TYPE, ulong3); \
REGISTER_INDEX_OP(32bit, idx32, int, INDEX_OP_TYPE, uint3); \
REGISTER_INDEX_OP(32bit, idx64, int, INDEX_OP_TYPE, ulong3); \
REGISTER_INDEX_OP(64bit, idx32, long, INDEX_OP_TYPE, uint3); \
REGISTER_INDEX_OP(64bit, idx64, long, INDEX_OP_TYPE, ulong3);
REGISTER_INDEX_OP_ALL_DTYPES(select);
REGISTER_INDEX_OP_ALL_DTYPES(put);
#define REGISTER_SINGLE_THREADED_INDEX_OP(DTYPE_SIZE, IDX_SIZE, DTYPE, INDEX_OP_TYPE, IDX_DTYPE) \
template \
[[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE "_" #IDX_SIZE)]] \
kernel void index_ ## INDEX_OP_TYPE<DTYPE, IDX_DTYPE>( \
constant IndexAB * indexAB [[buffer(0)]], \
constant void * indexSizes [[buffer(1)]], \
constant void * indexStrides [[buffer(2)]], \
constant IDX_DTYPE * offsets [[buffer(3)]], \
constant void * inputData [[buffer(4)]], \
device void * outputData [[buffer(5)]], \
constant uint32_t & num_indices [[buffer(6)]], \
constant uint * numIters [[buffer(7)]], \
uint thread_index [[thread_position_in_grid]]);
#define REGISTER_SINGLE_THREADED_INDEX_OP_ALL_DTYPES(INDEX_OP_TYPE) \
REGISTER_SINGLE_THREADED_INDEX_OP(8bit, idx32, char, INDEX_OP_TYPE, uint3); \
REGISTER_SINGLE_THREADED_INDEX_OP(8bit, idx64, char, INDEX_OP_TYPE, ulong3); \
REGISTER_SINGLE_THREADED_INDEX_OP(16bit, idx32, short, INDEX_OP_TYPE, uint3); \
REGISTER_SINGLE_THREADED_INDEX_OP(16bit, idx64, short, INDEX_OP_TYPE, ulong3); \
REGISTER_SINGLE_THREADED_INDEX_OP(32bit, idx32, int, INDEX_OP_TYPE, uint3); \
REGISTER_SINGLE_THREADED_INDEX_OP(32bit, idx64, int, INDEX_OP_TYPE, ulong3); \
REGISTER_SINGLE_THREADED_INDEX_OP(64bit, idx32, long, INDEX_OP_TYPE, uint3); \
REGISTER_SINGLE_THREADED_INDEX_OP(64bit, idx64, long, INDEX_OP_TYPE, ulong3);
REGISTER_SINGLE_THREADED_INDEX_OP_ALL_DTYPES(put_serial);
template<typename StridesT, typename DataT>
kernel void kernel_index_offsets(constant StridesT * strides [[buffer(0)]],
device DataT * data_offsets [[buffer(1)]],
constant uint * iter_shape [[buffer(2)]],
constant uint & num_dimensions [[buffer(3)]],
uint thread_index [[thread_position_in_grid]]) {
data_offsets[thread_index] = 0;
uint32_t idx = thread_index;
for (uint32_t dim = 0; dim < num_dimensions; dim++) {
uint32_t remainder = idx % iter_shape[dim];
idx /= iter_shape[dim];
data_offsets[thread_index] += remainder * DataT(strides[dim]);
}
}
template
[[host_name("kernel_index_offsets_32")]]
kernel void kernel_index_offsets<packed_uint3, uint3>(
constant packed_uint3 * strides [[buffer(0)]],
device uint3 * data_offsets [[buffer(1)]],
constant uint * iter_shape [[buffer(2)]],
constant uint & num_dimensions [[buffer(3)]],
uint thread_index [[thread_position_in_grid]]);
template
[[host_name("kernel_index_offsets_64")]]
kernel void kernel_index_offsets<packed_uint3, ulong3>(
constant packed_uint3 * strides [[buffer(0)]],
device ulong3 * data_offsets [[buffer(1)]],
constant uint * iter_shape [[buffer(2)]],
constant uint & num_dimensions [[buffer(3)]],
uint thread_index [[thread_position_in_grid]]);
template<typename T, typename E, typename OffsetsT>
kernel void index_put_accumulate_native_dtypes(
constant IndexAB * indexAB [[buffer(0)]],
constant void * indexSizes [[buffer(1)]],
constant void * indexStrides [[buffer(2)]],
constant OffsetsT * offsets [[buffer(3)]],
constant void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
constant uint32_t & num_indices [[buffer(6)]],
uint thread_index [[thread_position_in_grid]]) {
constant int64_t * index_sizes = (constant int64_t *)indexSizes;
constant int64_t * index_strides = (constant int64_t *)indexStrides;
int64_t offset = 0;
for (uint32_t i = 0; i < num_indices; i++) {
constant int64_t* indexArray = indexAB[i].indexArray;
int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
if (index < 0) {
index += index_sizes[i];
}
offset += index * index_strides[i];
}
device T * out = (device T*)((device char*)outputData + offsets[thread_index].x + offset);
constant E * in = (constant E*)((constant char*)inputData + offsets[thread_index].y);
atomic_fetch_add_explicit(out, *in, memory_order_relaxed);
}
template<typename T>
__attribute__((__always_inline__)) void atomic_fetch_add_relaxed(device void * addr, T value) {
device atomic_uint* uintAddr = (device atomic_uint*)addr;
uint expected = atomic_load_explicit(uintAddr, memory_order_relaxed);
T updated = as_type<T>(expected) + value;
while (!atomic_compare_exchange_weak_explicit(uintAddr, &expected, as_type<uint>(updated), memory_order_relaxed, memory_order_relaxed)) {
updated = as_type<T>(expected) + value;
}
}
template<typename T, typename OffsetsT>
kernel void atomic_index_put_accumulate(
constant IndexAB * indexAB [[buffer(0)]],
constant void * indexSizes [[buffer(1)]],
constant void * indexStrides [[buffer(2)]],
constant OffsetsT * offsets [[buffer(3)]],
constant void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
constant uint32_t & num_indices [[buffer(6)]],
uint thread_index [[thread_position_in_grid]]) {
constant int64_t * index_sizes = (constant int64_t *)indexSizes;
constant int64_t * index_strides = (constant int64_t *)indexStrides;
int64_t offset = 0;
for (uint32_t i = 0; i < num_indices; i++) {
constant int64_t* indexArray = indexAB[i].indexArray;
int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
if (index < 0) {
index += index_sizes[i];
}
offset += index * index_strides[i];
}
device void * out = (device void*)((device char*)outputData + offsets[thread_index].x + offset);
constant T * in = (constant T*)((constant char*)inputData + offsets[thread_index].y);
atomic_fetch_add_relaxed<T>(out, *in);
}
template
[[host_name("index_put_accumulate_32bit_float_idx32")]]
kernel void atomic_index_put_accumulate<float, uint3>(
constant IndexAB * indexAB [[buffer(0)]],
constant void * indexSizes [[buffer(1)]],
constant void * indexStrides [[buffer(2)]],
constant uint3 * offsets [[buffer(3)]],
constant void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
constant uint32_t & num_indices [[buffer(6)]],
uint thread_index [[thread_position_in_grid]]);
template
[[host_name("index_put_accumulate_32bit_float_idx64")]]
kernel void atomic_index_put_accumulate<float, ulong3>(
constant IndexAB * indexAB [[buffer(0)]],
constant void * indexSizes [[buffer(1)]],
constant void * indexStrides [[buffer(2)]],
constant ulong3 * offsets [[buffer(3)]],
constant void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
constant uint32_t & num_indices [[buffer(6)]],
uint thread_index [[thread_position_in_grid]]);
template
[[host_name("index_put_accumulate_32bit_int_idx32")]]
kernel void index_put_accumulate_native_dtypes<atomic_int, int, uint3>(
constant IndexAB * indexAB [[buffer(0)]],
constant void * indexSizes [[buffer(1)]],
constant void * indexStrides [[buffer(2)]],
constant uint3 * offsets [[buffer(3)]],
constant void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
constant uint32_t & num_indices [[buffer(6)]],
uint thread_index [[thread_position_in_grid]]);
template
[[host_name("index_put_accumulate_32bit_int_idx64")]]
kernel void index_put_accumulate_native_dtypes<atomic_int, int, ulong3>(
constant IndexAB * indexAB [[buffer(0)]],
constant void * indexSizes [[buffer(1)]],
constant void * indexStrides [[buffer(2)]],
constant ulong3 * offsets [[buffer(3)]],
constant void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
constant uint32_t & num_indices [[buffer(6)]],
uint thread_index [[thread_position_in_grid]]);
)INDEX_METAL";
static const char *SCATTER_OPS_TEMPLATE = R"METAL_SCATTER(
struct __attribute__ ((packed)) packed_uint5{{
uint32_t x; uint32_t y; uint32_t z; uint32_t w; uint32_t u;
}};
template<typename Y, typename X>
Y cast(const X x);
template<>
{1} cast<{1}, {0}>(const {0} x) {{
return {2};
}}
kernel void scatter_kernel_5(uint linear_index [[thread_position_in_grid]],
constant void * src_ [[buffer(0)]],
device void * dst_ [[buffer(1)]],
constant packed_uint5 & size [[buffer(2)]],
constant packed_uint5 & stride [[buffer(3)]],
constant uint32_t & numel [[buffer(4)]]) {{
if (linear_index >= numel) return;
constant {0} * src = (constant {0} *)src_;
device {1} * dst = (device {1} *)dst_;
packed_uint5 local_index;
local_index.x = linear_index / (size.u * size.w * size.z * size.y) % size.x;
local_index.y = linear_index / (size.u * size.w * size.z) % size.y;
local_index.z = linear_index / (size.u * size.w) % size.z;
local_index.w = linear_index / size.u % size.w;
local_index.u = linear_index % size.u;
packed_uint5 strided_index;
strided_index.x = local_index.x * stride.x;
strided_index.y = local_index.y * stride.y;
strided_index.z = local_index.z * stride.z;
strided_index.w = local_index.w * stride.w;
strided_index.u = local_index.u * stride.u;
dst[strided_index.x + strided_index.y + strided_index.z + strided_index.w + strided_index.u] = cast<{1}>(src[linear_index]);
}}
kernel void scatter_kernel_4(uint linear_index [[thread_position_in_grid]],
constant void * src_ [[buffer(0)]],
device void * dst_ [[buffer(1)]],
constant packed_uint4 & size [[buffer(2)]],
constant packed_uint4 & stride [[buffer(3)]],
constant uint32_t & numel [[buffer(4)]]) {{
if (linear_index >= numel) return;
constant {0} * src = (constant {0} *)src_;
device {1} * dst = (device {1} *)dst_;
packed_uint4 local_index;
local_index.x = linear_index / (size[3] * size[2] * size[1]) % size[0];
local_index.y = linear_index / (size[3] * size[2]) % size[1];
local_index.z = linear_index / size[3] % size[2];
local_index.w = linear_index % size[3];
const packed_uint4 strided_index = local_index * stride;
dst[strided_index.x + strided_index.y + strided_index.z + strided_index.w] = cast<{1}>(src[linear_index]);
}}
kernel void scatter_kernel_3(uint linear_index [[thread_position_in_grid]],
constant void * src_ [[buffer(0)]],
device void * dst_ [[buffer(1)]],
constant packed_uint3 & size [[buffer(2)]],
constant packed_uint3 & stride [[buffer(3)]],
constant uint32_t & numel [[buffer(4)]]) {{
if (linear_index >= numel) return;
constant {0} * src = (constant {0} *)src_;
device {1} * dst = (device {1} *)dst_;
packed_uint3 local_index;
local_index.x = linear_index / (size[2] * size[1]) % size[0];
local_index.y = linear_index / size[2] % size[1];
local_index.z = linear_index % size[2];
const packed_uint3 strided_index = local_index * stride;
dst[strided_index.x + strided_index.y + strided_index.z] = cast<{1}>(src[linear_index]);
}}
kernel void scatter_kernel_2(uint linear_index [[thread_position_in_grid]],
constant void * src_ [[buffer(0)]],
device void * dst_ [[buffer(1)]],
constant packed_uint2 & size [[buffer(2)]],
constant packed_uint2 & stride [[buffer(3)]],
constant uint32_t & numel [[buffer(4)]]) {{
if (linear_index >= numel) return;
constant {0} * src = (constant {0} *)src_;
device {1} * dst = (device {1} *)dst_;
packed_uint2 local_index;
local_index.x = linear_index / size[1] % size[0];
local_index.y = linear_index % size[1];
const packed_uint2 strided_index = local_index * stride;
dst[strided_index.x + strided_index.y] = cast<{1}>(src[linear_index]);
}}
kernel void scatter_kernel_1(uint linear_index [[thread_position_in_grid]],
constant void * src_ [[buffer(0)]],
device void * dst_ [[buffer(1)]],
constant int & size [[buffer(2)]],
constant int & stride [[buffer(3)]],
constant uint32_t & numel [[buffer(4)]]) {{
if (linear_index >= numel) return;
constant {0} * src = (constant {0} *)src_;
device {1} * dst = (device {1} *)dst_;
const int local_index = linear_index % size;
const int strided_index = local_index * stride;
dst[strided_index] = cast<{1}>(src[linear_index]);
}}
)METAL_SCATTER";
static const char *GATHER_OPS_TEMPLATE = R"METAL_GATHER(
struct __attribute__ ((packed)) packed_uint5{{
uint32_t x; uint32_t y; uint32_t z; uint32_t w; uint32_t u;
}};
template<typename Y, typename X>
Y cast(const X x);
template<>
{1} cast<{1}, {0}>(const {0} x) {{
return {2};
}}
kernel void gather_kernel_5(uint linear_index [[thread_position_in_grid]],
constant void * src_ [[buffer(0)]],
device void * dst_ [[buffer(1)]],
constant packed_uint5 & size [[buffer(2)]],
constant packed_uint5 & stride [[buffer(3)]],
constant uint32_t & numel [[buffer(4)]]) {{
if (linear_index >= numel) return;
constant {0} * src = (constant {0} *)src_;
device {1} * dst = (device {1} *)dst_;
packed_uint5 local_index;
local_index.x = linear_index / (size.u * size.w * size.z * size.y) % size.x;
local_index.y = linear_index / (size.u * size.w * size.z) % size.y;
local_index.z = linear_index / (size.u * size.w) % size.z;
local_index.w = linear_index / size.u % size.w;
local_index.u = linear_index % size.u;
packed_uint5 strided_index;
strided_index.x = local_index.x * stride.x;
strided_index.y = local_index.y * stride.y;
strided_index.z = local_index.z * stride.z;
strided_index.w = local_index.w * stride.w;
strided_index.u = local_index.u * stride.u;
dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y + strided_index.z + strided_index.w + strided_index.u]);
}}
kernel void gather_kernel_4(uint linear_index [[thread_position_in_grid]],
constant void * src_ [[buffer(0)]],
device void * dst_ [[buffer(1)]],
constant packed_uint4 & size [[buffer(2)]],
constant packed_uint4 & stride [[buffer(3)]],
constant uint32_t & numel [[buffer(4)]]) {{
if (linear_index >= numel) return;
constant {0} * src = (constant {0} *)src_;
device {1} * dst = (device {1} *)dst_;
packed_uint4 local_index;
local_index.x = linear_index / (size[3] * size[2] * size[1]) % size[0];
local_index.y = linear_index / (size[3] * size[2]) % size[1];
local_index.z = linear_index / size[3] % size[2];
local_index.w = linear_index % size[3];
const packed_uint4 strided_index = local_index * stride;
dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y + strided_index.z + strided_index.w]);
}}
kernel void gather_kernel_3(uint linear_index [[thread_position_in_grid]],
constant void * src_ [[buffer(0)]],
device void * dst_ [[buffer(1)]],
constant packed_uint3 & size [[buffer(2)]],
constant packed_uint3 & stride [[buffer(3)]],
constant uint32_t & numel [[buffer(4)]]) {{
if (linear_index >= numel) return;
constant {0} * src = (constant {0} *)src_;
device {1} * dst = (device {1} *)dst_;
packed_uint3 local_index;
local_index.x = linear_index / (size[2] * size[1]) % size[0];
local_index.y = linear_index / size[2] % size[1];
local_index.z = linear_index % size[2];
const packed_uint3 strided_index = local_index * stride;
dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y + strided_index.z]);
}}
kernel void gather_kernel_2(uint linear_index [[thread_position_in_grid]],
constant void * src_ [[buffer(0)]],
device void * dst_ [[buffer(1)]],
constant packed_uint2 & size [[buffer(2)]],
constant packed_uint2 & stride [[buffer(3)]],
constant uint32_t & numel [[buffer(4)]]) {{
if (linear_index >= numel) return;
constant {0} * src = (constant {0} *)src_;
device {1} * dst = (device {1} *)dst_;
packed_uint2 local_index;
local_index.x = linear_index / size[1] % size[0];
local_index.y = linear_index % size[1];
const packed_uint2 strided_index = local_index * stride;
dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y]);
}}
kernel void gather_kernel_1(uint linear_index [[thread_position_in_grid]],
constant void * src_ [[buffer(0)]],
device void * dst_ [[buffer(1)]],
constant int & size [[buffer(2)]],
constant int & stride [[buffer(3)]],
constant uint32_t & numel [[buffer(4)]]) {{
if (linear_index >= numel) return;
constant {0} * src = (constant {0} *)src_;
device {1} * dst = (device {1} *)dst_;
const int local_index = linear_index % size;
const int strided_index = local_index * stride;
dst[linear_index] = cast<{1}>(src[strided_index]);
}}
)METAL_GATHER";
} // namespace at::mps
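
The scatter/gather sources above are written as format templates: {0} stands for the source element type, {1} for the destination element type, {2} for the cast expression, and doubled braces escape literal braces. This hunk does not show which formatter performs the substitution on the host side, so the helper below is only a hedged sketch of what that instantiation step would produce; the function names are hypothetical.

#include <cstddef>
#include <string>

// Hypothetical stand-in for the host-side formatter: substitute the positional
// placeholders and un-escape the doubled braces in SCATTER_OPS_TEMPLATE /
// GATHER_OPS_TEMPLATE before handing the source to the Metal compiler.
static void replace_all(std::string& s, const std::string& from, const std::string& to) {
  for (std::size_t pos = 0; (pos = s.find(from, pos)) != std::string::npos; pos += to.size()) {
    s.replace(pos, from.size(), to);
  }
}

std::string instantiate_template(const char* tmpl,
                                 const std::string& src_type,
                                 const std::string& dst_type,
                                 const std::string& cast_expr) {
  std::string out = tmpl;
  replace_all(out, "{0}", src_type);
  replace_all(out, "{1}", dst_type);
  replace_all(out, "{2}", cast_expr);
  replace_all(out, "{{", "{");
  replace_all(out, "}}", "}");
  return out;
}

// e.g. instantiate_template(GATHER_OPS_TEMPLATE, "float", "half", "static_cast<half>(x)")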


@@ -0,0 +1,403 @@
// Copyright © 2022 Apple Inc.
#pragma once
#include <ATen/mps/MPSAllocatorInterface.h>
#include <ATen/mps/MPSEvent.h>
#include <ATen/mps/MPSStream.h>
#include <cstdio>
#include <mutex>
#include <set>
#include <unordered_set>
#include <mach/vm_page_size.h>
#include <c10/util/flat_hash_map.h>
// This implementation is based on CUDACachingAllocator.
// It utilizes Metal Heaps to improve buffer allocation performance.
// Do not include this header. Use MPSAllocatorInterface.h instead.
// TODO: Unify the logic with CUDACachingAllocator and remove redundant code.
namespace at::mps::HeapAllocator {
static const size_t kMaxSmallAlloc = MB(1); // largest "small" allocation is 1 MiB
static const size_t kMinLargeAlloc = MB(10); // allocations between 1 and 10 MiB may use kLargeHeap
static const size_t kRoundLarge = MB(2); // round up large allocations to 2 MiB
static const size_t kSmallHeap = MB(8); // "small" allocations are packed in 8 MiB heaps
static const size_t kLargeHeap = MB(32); // "large" allocations may be packed in 32 MiB heaps
static const size_t kXLargeHeapD = MB(128); // "extra large" allocations on Discrete devices may be packed in 128 MiB heaps
static const size_t kXLargeHeapU = MB(1024); // "extra large" allocations on Unified devices may be packed in 1 GiB heaps
static const size_t kMaxScalarAlloc = (sizeof(int64_t)); // largest "scalar" allocation
// buffer pools could be customized with a combination of usage flags
enum UsageFlags : uint32_t {
PRIVATE = 0,
SMALL = (1 << 0), // small heaps have sizes of kSmallHeap, and large ones kLargeHeap
SHARED = (1 << 1), // shared pools allocated on devices with unified memory; otherwise, private between host/device
MANAGED = (1 << 2), // managed storage mode
HAZARD = (1 << 3), // enables Automatic Hazard Tracking for the resources allocated on the pool
SCALAR = (1 << 4), // used to import CPU scalar values to GPU and use them in MPS Stream
};
// debug verbosity flags
enum DebugVerbosity : uint32_t {
SILENT = 0,
PROFILING = (1 << 0), // print generic profiling data for total system memory usage
ALLOCATIONS = (1 << 1), // print buffer allocations
RECYCLES = (1 << 2), // print buffer recycling
RELEASES = (1 << 3), // print buffer releases
LARGE_ONLY = (1 << 4), // only log large buffer pool transactions
};
struct HeapBlock;
struct BufferBlock {
id<MTLBuffer> buffer;
void* cpu_ptr = nullptr; // stores the pointer to CPU mapping of a Shared MTLBuffer
size_t size; // size after alignment
size_t requested_size; // requested size (before alignment)
// buffer shape is used for retrieving base of views in cached graphs
std::vector<int64_t> shape;
bool in_use = false;
HeapBlock* heap;
id_t buf_id;
// counter used to mark least-recently-used buffers as candidates for garbage collection
uint32_t gc_count = 0;
uint32_t use_count = 0;
// counter to assign unique ids to buffer blocks
static uint64_t buffer_counter;
// Metal events used to sync GPU/CPU operations on the shared-storage buffers
MPSEventPtr event;
BufferBlock(size_t Size, size_t RequestedSize = 0, const id<MTLBuffer> Buffer = nullptr,
HeapBlock* Heap = nullptr) :
buffer(Buffer), size(Size), requested_size(RequestedSize),
heap(Heap), buf_id(Buffer ? ++buffer_counter : 0) { }
static bool Comparator(const BufferBlock* a, const BufferBlock* b) {
return (a->size != b->size) ? a->size < b->size : (uintptr_t)a->buffer < (uintptr_t)b->buffer;
}
static size_t alignUp(size_t Size, size_t Alignment) {
assert(((Alignment - 1) & Alignment) == 0);
return ((Size + Alignment - 1) & ~(Alignment - 1));
}
uint32_t retainCount() const { return [buffer retainCount]; }
};
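
A couple of worked values for the alignment helper above; these compile-time checks are illustrative only and not part of the header.

#include <cstddef>

// BufferBlock::alignUp rounds Size up to the next multiple of a power-of-two
// Alignment; its assert rejects non-power-of-two alignments, since the test
// ((Alignment - 1) & Alignment) == 0 only holds for powers of two.
//   alignUp(256, 256)  == 256   (already aligned)
//   alignUp(1000, 256) == 1024  (rounded up to the next 256-byte boundary)
static_assert(((std::size_t(256) + 255) & ~std::size_t(255)) == 256, "aligned size is unchanged");
static_assert(((std::size_t(1000) + 255) & ~std::size_t(255)) == 1024, "rounds up to next boundary");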
typedef bool (*BufferComparison)(const BufferBlock*, const BufferBlock*);
struct BufferPool;
struct AllocParams {
AllocParams(size_t Alloc_Size, size_t Requested_Size, BufferPool* Pool) :
search_key(Alloc_Size), pool(Pool), requested_size(Requested_Size) { }
size_t size() const { return search_key.size; }
BufferBlock search_key;
BufferPool* pool;
BufferBlock* buffer_block = nullptr;
size_t requested_size;
// true if we exceed the low watermark limit. In this case
// we apply strategies to relieve the pressure before allocation.
bool has_memory_pressure = false;
// true if we're allocating on a unified memory device
bool has_unified_memory = true;
};
struct HeapBlock {
id<MTLHeap> heap;
struct { size_t total, available; } size;
BufferPool* pool;
unsigned int n_buffers = 0;
id_t heap_id;
// indicates whether this heap is split to sub-allocate several buffers (otherwise it backs a single buffer)
bool is_split;
// counter to assign unique ids to heap blocks
static uint64_t heap_counter;
HeapBlock(size_t Size, const id<MTLHeap> Heap = nullptr, BufferPool *Pool = nullptr) :
heap(Heap), size({.total = Size, .available = Size}), pool(Pool),
heap_id(Heap ? ++heap_counter : 0), is_split(true) { }
static MTLResourceOptions getOptions(uint32_t usage) {
// TODO: check the caching performance of write-combined mode
MTLResourceOptions options = MTLResourceCPUCacheModeDefaultCache;
if (usage & UsageFlags::MANAGED)
options |= MTLResourceStorageModeManaged;
else if (usage & UsageFlags::SHARED)
options |= MTLResourceStorageModeShared;
else
options |= MTLResourceStorageModePrivate;
options |= (usage & UsageFlags::HAZARD) ? MTLResourceHazardTrackingModeTracked : MTLResourceHazardTrackingModeUntracked;
return options;
}
static HeapBlock* createHeapBlock(AllocParams& params, id<MTLDevice> device, uint32_t usage) {
HeapBlock *heapBlock = nullptr;
bool is_split = true;
const size_t size = params.size();
MTLHeapDescriptor *d = [MTLHeapDescriptor new];
if (d) {
const size_t kXLargeHeap = params.has_unified_memory ? kXLargeHeapU : kXLargeHeapD;
if (size <= kMaxSmallAlloc) {
d.size = kSmallHeap;
} else if (size < kMinLargeAlloc) {
d.size = kLargeHeap;
} else if (size < kXLargeHeap / 2 && !params.has_memory_pressure) {
d.size = kXLargeHeap;
} else {
d.size = kRoundLarge * ((size + kRoundLarge - 1) / kRoundLarge);
is_split = false;
}
d.storageMode = (usage & UsageFlags::SHARED) ? MTLStorageModeShared : MTLStorageModePrivate;
d.cpuCacheMode = MTLCPUCacheModeDefaultCache;
// this automatically handles Metal buffer access synchronizations at the
// cost of slightly lower performance.
d.hazardTrackingMode = (usage & UsageFlags::HAZARD) ? MTLHazardTrackingModeTracked : MTLHazardTrackingModeUntracked;
d.resourceOptions = getOptions(usage);
d.type = MTLHeapTypeAutomatic;
id<MTLHeap> heap = [device newHeapWithDescriptor: d];
if (heap) {
[heap setPurgeableState:MTLPurgeableStateNonVolatile];
const size_t heap_size = heapAvailableSize(heap);
heapBlock = new HeapBlock(heap_size, heap, params.pool);
if (heapBlock) {
heapBlock->is_split = is_split;
}
}
[d release];
}
return heapBlock;
}
static bool Comparator(const HeapBlock* a, const HeapBlock* b) {
return (a->size.available != b->size.available) ? a->size.available < b->size.available :
(uintptr_t)a->heap < (uintptr_t)b->heap;
}
static NSUInteger heapAvailableSize(id<MTLHeap> heap, size_t Alignment = vm_page_size) {
return [heap maxAvailableSizeWithAlignment:Alignment];
}
NSUInteger Size() {
return [heap size];
}
id<MTLBuffer> newMTLBuffer(size_t length, uint32_t usage) {
id<MTLBuffer> buf = [heap newBufferWithLength:length options:getOptions(usage)];
if (buf) {
updateAvailableSize();
n_buffers++;
}
return buf;
}
// returns the retainCount before releasing the buffer
uint32_t releaseMTLBuffer(id<MTLBuffer>& buffer) {
const uint32_t retainCount = [buffer retainCount];
[buffer release];
buffer = nil;
updateAvailableSize();
n_buffers--;
return retainCount;
}
// returns the retainCount before releasing the heap
uint32_t releaseMTLHeap() {
const uint32_t retainCount = [heap retainCount];
TORCH_INTERNAL_ASSERT(!n_buffers); // assert if heap isn't empty
[heap setPurgeableState:MTLPurgeableStateEmpty];
[heap release];
heap = nil;
size.available = 0;
return retainCount;
}
uint32_t retainCount() const { return [heap retainCount]; }
void updateAvailableSize() { size.available = heapAvailableSize(heap); }
};
typedef bool (*HeapComparison)(const HeapBlock*, const HeapBlock*);
struct BufferPool {
enum class Kind {
PRIVATE_SMALL,
PRIVATE_LARGE,
SHARED_SMALL,
SHARED_LARGE,
SCALAR,
};
BufferPool(const id<MTLDevice> Device, uint32_t Usage) :
device(Device), usage(Usage),
heaps(HeapBlock::Comparator), available_buffers(BufferBlock::Comparator) { }
const id<MTLDevice> device;
// usage flags to customize the pool for various purposes (see UsageFlags enum)
const uint32_t usage;
// total number of buffers in the pool
uint32_t n_buffers = 0;
// total allocations size on this pool
size_t allocated_size = 0;
// total memory available in the pool
size_t available_size = 0;
// list of heaps ordered by their "available" (not total) memory size
std::set<HeapBlock*, HeapComparison> heaps;
// list of only "available" buffers in the pool (i.e., buffers not in-use)
std::set<BufferBlock*, BufferComparison> available_buffers;
// list of buffers that are in a state of "limbo" where they've already been freed
// from PyTorch-side, but were not returned to pool due to still being
// in-use by command buffers with retainCount > 1. In this state, the buffer is
// neither ready to be recycled nor able to be returned to the pool as available.
// These buffers will be returned to pool once the command buffer's
// completionHandler callbacks are called.
std::unordered_set<BufferBlock*> buffers_pending_free;
// list of heaps pending size update
std::unordered_set<HeapBlock*> heaps_pending_update;
};
class MPSHeapAllocatorImpl {
public:
explicit MPSHeapAllocatorImpl() :
m_device(at::mps::MPSDevice::getInstance()->device()),
m_max_buffer_size([m_device maxBufferLength]),
m_stream(getDefaultMPSStream()),
m_event_pool(getMPSEventPool()) {
init_allocator();
}
~MPSHeapAllocatorImpl() {
emptyCache();
}
// interface exposed to at::Allocator
id<MTLBuffer> malloc(size_t size, uint32_t usage);
// frees a buffer and returns it into buffer pool
void free(void* ptr);
// releases all the cached buffers and their associated heaps
void emptyCache();
// free inactive buffers that are pending to be freed
void freeInactiveBuffers();
// returns true if buffer was allocated from the shared pool
bool isSharedBuffer(const void* ptr);
// get the requested unaligned size of an MTLBuffer
ssize_t getUnalignedBufferSize(const void* ptr);
// set the shape of a base tensor from a view tensor
void setBufferShape(const void* ptr, const IntArrayRef& shape);
// retrieve the shape of a base tensor from a view tensor
IntArrayRef getBufferShape(const void* ptr);
// get the unique ID of the buffer
id_t getBufferId(const void* ptr);
// allocate a buffer from a specialized pool to import CPU scalars into GPU
id<MTLBuffer> allocScalarBufferWithValue(void* value, size_t size);
// returns a CPU mapping of the input buffer and its retainCount,
// but only if it has Shared storage mode and was allocated by MPSAllocator
std::pair<const void*, uint32_t> getSharedBufferPtr(const void* buffer);
// records events for a list of MTLBuffers (a list is used so the mutex is locked only once)
// returns true if any event was recorded (i.e., the passed buffers exist and are shared-storage)
bool recordEvents(c10::ArrayRef<const void*> buffers);
// waits for the event to signal the completion of GPU execution
// on the passed shared buffers (list is used to lock the mutex once)
// returns true if actually waited on any event
bool waitForEvents(c10::ArrayRef<const void*> buffers);
// indicates how far (in megabytes) the current total allocations are from the
// low watermark limit, which is used to detect whether we're under memory pressure.
// Returns zero if we've already reached the low watermark limit.
ssize_t getLowWatermarkValue();
// (see m_low_watermark_ratio for description)
void setLowWatermarkRatio(double ratio);
// (see m_high_watermark_ratio for description)
void setHighWatermarkRatio(double ratio);
// (see m_low_watermark_limit for description)
size_t getLowWatermarkLimit() const { return m_low_watermark_limit; }
// (see m_max_total_allowed_size for description)
size_t getHighWatermarkLimit() const { return m_max_total_allowed_size; }
// (see m_total_allocated_memory for description)
size_t getTotalAllocatedMemory() const { return m_total_allocated_memory; }
// (see m_current_allocated_memory for description)
size_t getCurrentAllocatedMemory() const { return m_current_allocated_memory; }
// total GPU memory allocated in the process by the Metal driver, including
// implicit allocations from the MPS/MPSGraph frameworks and MPSHeapAllocatorImpl.
size_t getDriverAllocatedMemory() const { return current_allocated_size(); }
// recommended Max memory for Metal
size_t getRecommendedMaxMemory() const { return max_device_size(); }
// (see enum DebugVerbosity for description)
uint32_t getDebugVerbosity() const { return m_debug_verbosity; }
// returns the device that we allocate from
inline id<MTLDevice> Device() const { return m_device; }
// TODO: make a common function to do size unit conversions in PyTorch.
inline std::string format_size(uint64_t size) const;
private:
// (see m_high_watermark_ratio for description)
constexpr static double default_high_watermark_ratio = 1.7;
// we set the allowed upper bound to twice the size of recommendedMaxWorkingSetSize.
constexpr static double default_high_watermark_upper_bound = 2.0;
// (see m_low_watermark_ratio for description)
// on unified memory, we could allocate beyond the recommendedMaxWorkingSetSize
constexpr static double default_low_watermark_ratio_unified = 1.4;
constexpr static double default_low_watermark_ratio_discrete = 1.0;
const id<MTLDevice> m_device;
std::recursive_mutex m_mutex;
// allocated buffers by device pointer
ska::flat_hash_map<const void*, BufferBlock*> m_allocated_buffers;
// using a container for pools to simplify iterating them
ska::flat_hash_map<BufferPool::Kind, std::unique_ptr<BufferPool>> m_pools;
// total memory allocated by HeapAllocator (including blocks in pools)
size_t m_total_allocated_memory = 0;
// currently active memory allocations in use (i.e., blocks not in pools)
size_t m_current_allocated_memory = 0;
// max buffer size allowed by Metal
size_t m_max_buffer_size = 0;
// maximum total size allowed to be allocated
size_t m_max_total_allowed_size = 0;
// high watermark ratio is a hard limit for the total allowed allocations
// 0. : disables high watermark limit (may cause system failure if system-wide OOM occurs)
// 1. : recommended maximum allocation size (i.e., device.recommendedMaxWorkingSetSize)
// >1.: allows limits beyond the device.recommendedMaxWorkingSetSize
// e.g., value 0.95 means we allocate up to 95% of recommended maximum
// allocation size; beyond that, the allocations would fail with OOM error.
double m_high_watermark_ratio;
// low watermark ratio is a soft limit to attempt limiting memory allocations up to the lower watermark
// level by garbage collection or committing command buffers more frequently (a.k.a, adaptive commit).
// Value between 0 and m_high_watermark_ratio (setting 0.0 disables adaptive commit and garbage collection)
// e.g., value 0.9 means we 'attempt' to limit allocations up to 90% of recommended maximum
// allocation size.
double m_low_watermark_ratio;
// low watermark size limit (in Bytes) at the time we initialize the allocator
size_t m_low_watermark_limit;
// use "PYTORCH_DEBUG_MPS_ALLOCATOR" env-var to set debug verbosity
uint32_t m_debug_verbosity;
// default MPS stream
MPSStream* m_stream;
// we hold a reference to MPSEventPool so it is destroyed only after MPSAllocator
std::shared_ptr<MPSEventPool> m_event_pool;
void init_allocator();
void init_buffer_pools();
HeapBlock* get_free_heap(AllocParams& params);
bool get_free_buffer(AllocParams& params);
BufferBlock* get_allocated_buffer_block(const void* ptr);
BufferBlock* alloc_buffer_block(size_t size, uint32_t usage);
bool alloc_buffer(AllocParams& params);
void free_buffer(BufferBlock* buffer_block);
// returns true if the container heap is also released
bool release_buffer(BufferBlock* buffer_block, bool remove_empty_heap = true);
void release_buffers(BufferPool& pool);
bool release_available_cached_buffers(AllocParams& params);
bool release_cached_buffers();
// free unused cached blocks to reclaim GPU memory if memory pressure is high
void garbage_collect_cached_buffers(AllocParams& params);
// returns the suitable buffer pool type for the usage or
// requested/allocated sizes
BufferPool& get_pool(size_t requested_size, size_t aligned_size, uint32_t usage);
// returns the aligned allocation size that is optimized
// for the buffers to get reused frequently
size_t get_allocation_size(size_t size, uint32_t usage) const;
// maximum size of device memory available for allocation in current process
// Note: the recommendedMaxWorkingSetSize is typically 75% of the total system memory.
size_t max_device_size() const { return [m_device recommendedMaxWorkingSetSize]; }
// there are implicit allocations from MPS backend, so we need to query the 'device' for
// total allocated size instead of manually tracking in MPSAllocator
size_t current_allocated_size() const { return [m_device currentAllocatedSize]; }
bool trigger_memory_callbacks(BufferBlock* buffer_block, IMpsAllocatorCallback::EventType event) const {
for (const auto& name : MPSAllocatorCallbacksRegistry()->Keys()) {
MPSAllocatorCallbacksRegistry()->Create(name)->executeMPSAllocatorCallback(buffer_block ? buffer_block->buffer : nullptr, event);
}
return true;
}
};
} // namespace at::mps::HeapAllocator
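
To summarize the heap-sizing policy encoded in HeapBlock::createHeapBlock above, here is a standalone restatement of its branch logic using the constants declared at the top of this header. It is a reading aid only, not a separate API; the function name is made up.

#include <cstddef>

// Mirrors the descriptor-size selection in createHeapBlock():
//   request <= 1 MiB                         -> 8 MiB small heap   (kSmallHeap)
//   1 MiB < request < 10 MiB                 -> 32 MiB large heap  (kLargeHeap)
//   request < kXLargeHeap/2, no mem pressure -> 128 MiB (discrete) or 1 GiB (unified) heap
//   otherwise                                -> dedicated (non-split) heap, rounded up to 2 MiB
std::size_t pick_heap_size(std::size_t request, bool unified_memory, bool memory_pressure) {
  constexpr std::size_t MiB = 1048576UL;
  const std::size_t xlarge = (unified_memory ? 1024 : 128) * MiB;
  if (request <= 1 * MiB)  return 8 * MiB;
  if (request < 10 * MiB)  return 32 * MiB;
  if (request < xlarge / 2 && !memory_pressure) return xlarge;
  const std::size_t round = 2 * MiB;
  return round * ((request + round - 1) / round);
}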


@@ -0,0 +1,64 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include <c10/core/Allocator.h>
#include <c10/util/Registry.h>
#include <ATen/core/ATen_fwd.h>
#define MB(x) (x * 1048576UL)
namespace at::mps {
// this is a public interface to access MPSAllocator.
// Do not declare methods that would depend on MPS or Metal frameworks.
class IMPSAllocator : public c10::Allocator {
public:
// see the comments in MPSAllocator.h for the description of these methods.
virtual void emptyCache() const = 0;
virtual void freeInactiveBuffers() const = 0;
virtual ssize_t getUnalignedBufferSize(const void* ptr) const = 0;
virtual IntArrayRef getBufferShape(const void* ptr) const = 0;
virtual id_t getBufferId(const void* ptr) const = 0;
virtual void setBufferShape(const void* ptr, const IntArrayRef& shape) const = 0;
virtual bool isSharedBuffer(const void* ptr) const = 0;
virtual bool isSharedStorageSupported() const = 0;
virtual c10::DataPtr allocScalarBufferWithValue(void* value, size_t size) const = 0;
virtual std::string formatSize(size_t size) const = 0;
virtual void setLowWatermarkRatio(double ratio) const = 0;
virtual void setHighWatermarkRatio(double ratio) const = 0;
virtual ssize_t getLowWatermarkValue() const = 0;
virtual size_t getLowWatermarkLimit() const = 0;
virtual size_t getHighWatermarkLimit() const = 0;
virtual size_t getTotalAllocatedMemory() const = 0;
virtual size_t getCurrentAllocatedMemory() const = 0;
virtual size_t getDriverAllocatedMemory() const = 0;
virtual size_t getRecommendedMaxMemory() const = 0;
virtual std::pair<const void*, uint32_t> getSharedBufferPtr(const void* ptr) const = 0;
virtual bool recordEvents(c10::ArrayRef<const void*> buffers) const = 0;
virtual bool waitForEvents(c10::ArrayRef<const void*> buffers) const = 0;
};
class IMpsAllocatorCallback {
public:
enum class EventType {
ALLOCATED, // buffer got allocated to be used immediately
RECYCLED, // buffer pulled from free list to be reused
FREED, // buffer put to free list for future recycling
RELEASED, // buffer memory released
ALLOCATION_FAILED // buffer allocation failed
};
virtual ~IMpsAllocatorCallback() = default;
virtual void executeMPSAllocatorCallback(void* ptr, EventType event) = 0;
};
// MPS allocator will execute every registered callback when a block of memory is freed.
C10_DECLARE_REGISTRY(MPSAllocatorCallbacksRegistry, IMpsAllocatorCallback);
#define REGISTER_MPS_ALLOCATOR_CALLBACK(name, ...) \
C10_REGISTER_CLASS(MPSAllocatorCallbacksRegistry, name, __VA_ARGS__);
IMPSAllocator* getIMPSAllocator(bool sharedAllocator = false);
bool isMPSPinnedPtr(const void* data);
} // namespace at::mps
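
As an illustration of the callback registry declared above, here is a hedged sketch of registering an observer. The class name and registry key are hypothetical; nothing like it is registered by this commit.

#include <atomic>
#include <cstdint>

namespace at::mps {

// Hypothetical observer that counts failed MPS allocations via the callback hook.
struct AllocationFailureCounter : public IMpsAllocatorCallback {
  void executeMPSAllocatorCallback(void* /*ptr*/, EventType event) override {
    if (event == EventType::ALLOCATION_FAILED) {
      failures.fetch_add(1, std::memory_order_relaxed);
    }
  }
  static inline std::atomic<uint64_t> failures{0};
};

// Register under the made-up key "allocation_failure_counter"; the allocator
// then invokes it on allocator events (see EventType above).
REGISTER_MPS_ALLOCATOR_CALLBACK(allocation_failure_counter, AllocationFailureCounter);

} // namespace at::mps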


@@ -0,0 +1,84 @@
// Copyright © 2022 Apple Inc.
#pragma once
#include <c10/core/Allocator.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#ifdef __OBJC__
#include <Foundation/Foundation.h>
#include <Metal/Metal.h>
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
typedef id<MTLDevice> MTLDevice_t;
typedef id<MTLLibrary> MTLLibrary_t;
typedef id<MTLComputePipelineState> MTLComputePipelineState_t;
typedef id<MTLLibrary> MTLLibrary_t;
#else
typedef void* MTLDevice;
typedef void* MTLDevice_t;
typedef void* MTLLibrary_t;
typedef void* MTLComputePipelineState_t;
typedef void* MTLLibrary_t;
#endif
namespace at::mps {
// Helper enum to check if a MPSGraph op is supported in a given macOS version
enum class MacOSVersion : uint32_t {
MACOS_VER_13_1_PLUS = 0,
MACOS_VER_13_2_PLUS,
MACOS_VER_13_3_PLUS,
MACOS_VER_14_0_PLUS,
MACOS_VER_14_4_PLUS,
MACOS_VER_15_0_PLUS,
};
//-----------------------------------------------------------------
// MPSDevice
//
// MPSDevice is a singleton class that returns the default device
//-----------------------------------------------------------------
class TORCH_API MPSDevice {
public:
/**
* MPSDevice should not be cloneable.
*/
MPSDevice(MPSDevice& other) = delete;
/**
* MPSDevice should not be assignable.
*/
void operator=(const MPSDevice&) = delete;
/**
* Gets single instance of the Device.
*/
static MPSDevice* getInstance();
/**
* Returns the single device.
*/
MTLDevice_t device() {
return _mtl_device;
}
/**
* Returns whether the running macOS is at least the given version (Ventura or newer).
*/
bool isMacOS13Plus(MacOSVersion version) const;
MTLComputePipelineState_t metalIndexingPSO(const std::string &kernel);
MTLLibrary_t getMetalIndexingLibrary();
~MPSDevice();
private:
static MPSDevice* _device;
MTLDevice_t _mtl_device;
MTLLibrary_t _mtl_indexing_library;
MPSDevice();
};
TORCH_API bool is_available();
TORCH_API bool is_macos_13_or_newer(MacOSVersion version);
TORCH_API at::Allocator* GetMPSAllocator(bool useSharedAllocator = false);
} // namespace at::mps
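
A small, hypothetical usage sketch of the availability helpers exported by this header; the surrounding function is illustrative only.

// Guard a macOS-14-only code path behind the helpers declared above.
void run_op_or_fallback() {
  if (!at::mps::is_available()) {
    return;  // no usable Metal device
  }
  if (at::mps::is_macos_13_or_newer(at::mps::MacOSVersion::MACOS_VER_14_0_PLUS)) {
    // ... dispatch the newer implementation ...
  } else {
    // ... fall back to an older decomposition ...
  }
}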


@@ -0,0 +1,100 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include <ATen/mps/MPSStream.h>
#include <ctime>
#include <stack>
namespace at::mps {
// NOTE: don't create instances of this class directly.
// Use MPSEventPool to acquire instances of MPSEvent.
class MPSEvent {
public:
explicit MPSEvent(id_t ID, MPSStream* stream, bool enable_timing);
~MPSEvent();
// records an event on the stream
void record(bool needsLock, bool syncEvent = false);
// makes all future work submitted to the stream wait for this event.
bool wait(bool needsLock, bool syncEvent = false);
// schedules a notifyListener callback for the event.
bool notify(bool needsLock, MTLSharedEventNotificationBlock block);
// checks if events are already signaled.
bool query() const;
// blocks the CPU thread until all the GPU work that was scheduled
// prior to recording this event has completed.
bool synchronize();
// resets this event with new parameters in case it gets reused from the event pool
void reset(MPSStream* stream, bool enable_timing);
// returns the unique ID of the event instance
id_t getID() const { return m_id; }
// returns the completion timestamp of the event
uint64_t getCompletionTime() const { return m_completion_time; }
// if already recorded, waits for cpu_sync_cv to be signaled
void waitForCpuSync();
private:
id_t m_id;
// enables measuring the completion time of the notifyListener of this event
bool m_enable_timing;
uint64_t m_signalCounter = 0;
MPSStream* m_stream = nullptr;
MTLSharedEvent_t m_event = nullptr;
MTLSharedEventListener* m_listener = nullptr;
// used to sync the events created on this Stream with CPU
std::mutex m_cpu_sync_mutex{};
std::condition_variable m_cpu_sync_cv{};
// CondVar predicate to sync the events created on this Stream with CPU
bool m_cpu_sync_completed = false;
// used to compute elapsed time
uint64_t m_completion_time = 0;
void recordLocked(bool syncEvent);
bool waitLocked(bool syncEvent);
bool notifyLocked(MTLSharedEventNotificationBlock block);
void notifyCpuSync();
static uint64_t getTime() {
return clock_gettime_nsec_np(CLOCK_MONOTONIC_RAW);
}
};
typedef std::unique_ptr<MPSEvent, std::function<void(MPSEvent*)>> MPSEventPtr;
class MPSEventPool {
public:
explicit MPSEventPool(MPSStream* default_stream);
~MPSEventPool();
MPSEventPtr acquireEvent(bool enable_timing, MPSStream* stream);
void emptyCache();
// these are mainly used for MPSHooks and torch.mps.Event() bindings
id_t acquireEvent(bool enable_timing);
void releaseEvent(id_t event_id);
void recordEvent(id_t event_id, bool syncEvent);
void waitForEvent(id_t event_id, bool syncEvent);
void synchronizeEvent(id_t event_id);
bool queryEvent(id_t event_id);
// returns elapsed time between two recorded events in milliseconds
double elapsedTime(id_t start_event_id, id_t end_event_id);
private:
MPSStream* m_default_stream = nullptr;
std::recursive_mutex m_mutex;
std::stack<std::unique_ptr<MPSEvent>> m_pool{};
// dictionary to associate event IDs with event objects
// used to retain in-use events out of the pool
// for torch.mps.Event() bindings.
std::unordered_map<id_t, MPSEventPtr> m_in_use_events{};
uint64_t m_event_counter = 0;
std::function<void(MPSEvent*)> m_default_deleter;
MPSEvent* getInUseEvent(id_t event_id, bool locked = true);
};
// shared_ptr is used so that MPSEventPool is destroyed only after its dependent instances
std::shared_ptr<MPSEventPool> getMPSEventPool();
} // namespace at::mps
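
For illustration, the ID-based pool API above (the one the comments say backs the torch.mps.Event bindings) could be used roughly as follows to time a region of GPU work. Error handling is omitted and the function is hypothetical.

// Hypothetical timing helper built on MPSEventPool's ID-based API.
double time_region_ms() {
  auto pool = at::mps::getMPSEventPool();
  const auto start_id = pool->acquireEvent(/*enable_timing=*/true);
  const auto end_id   = pool->acquireEvent(/*enable_timing=*/true);

  pool->recordEvent(start_id, /*syncEvent=*/true);
  // ... enqueue the GPU work to be measured ...
  pool->recordEvent(end_id, /*syncEvent=*/true);

  pool->synchronizeEvent(end_id);                          // wait for completion on the CPU
  const double ms = pool->elapsedTime(start_id, end_id);   // milliseconds between the records

  pool->releaseEvent(start_id);
  pool->releaseEvent(end_id);
  return ms;
}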


@@ -0,0 +1,52 @@
// Copyright © 2022 Apple Inc.
#pragma once
#include <ATen/core/Generator.h>
#include <ATen/core/PhiloxRNGEngine.h>
#include <c10/core/GeneratorImpl.h>
#include <optional>
namespace at {
namespace mps::detail {
constexpr uint32_t PHILOX_STATE_N = 7;
struct rng_data_pod {
std::array<uint32_t, PHILOX_STATE_N> state{1};
uint64_t seed = default_rng_seed_val;
};
TORCH_API const Generator& getDefaultMPSGenerator();
TORCH_API Generator createMPSGenerator(uint64_t seed_val = default_rng_seed_val);
} // namespace mps::detail
struct TORCH_API MPSGeneratorImpl : public c10::GeneratorImpl {
// Constructors
MPSGeneratorImpl(uint64_t seed_in = default_rng_seed_val);
~MPSGeneratorImpl() override = default;
// MPSGeneratorImpl methods
std::shared_ptr<MPSGeneratorImpl> clone() const;
void set_current_seed(uint64_t seed) override;
void set_offset(uint64_t offset) override;
uint64_t get_offset() const override;
uint64_t current_seed() const override;
uint64_t seed() override;
void set_state(const c10::TensorImpl& new_state) override;
c10::intrusive_ptr<c10::TensorImpl> get_state() const override;
void update_philox_counters();
void set_engine(at::Philox4_32 engine) { engine_ = engine; };
at::Philox4_32 engine() { return engine_; };
uint32_t* state_data() { return data_.state.data(); }
static DeviceType device_type() { return DeviceType::MPS; };
private:
mps::detail::rng_data_pod data_;
at::Philox4_32 engine_;
MPSGeneratorImpl* clone_impl() const override;
};
} // namespace at
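
A minimal, hypothetical use of the generator factory declared above; the function name is made up and not part of this commit.

#include <cstdint>

// Illustrative only: build a private MPS generator with a fixed seed and read it back.
void seed_example() {
  at::Generator gen = at::mps::detail::createMPSGenerator(/*seed_val=*/1234);
  const uint64_t seed = gen.current_seed();  // == 1234
  (void)seed;
}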


@@ -0,0 +1,179 @@
// Copyright © 2022 Apple Inc.
#pragma once
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <ATen/Context.h>
#include <ATen/mps/MPSStream.h>
#include <ATen/mps/MPSEvent.h>
#ifdef __OBJC__
#include <Foundation/Foundation.h>
#include <Metal/Metal.h>
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
#endif
#include <ATen/Tensor.h>
#include <c10/core/MemoryFormat.h>
#include <c10/core/Storage.h>
#include <c10/core/TensorImpl.h>
#include <sys/_types/_size_t.h>
#include <memory>
#include <c10/core/UndefinedTensorImpl.h>
#include <c10/util/intrusive_ptr.h>
namespace at::mps {
typedef MPSEvent* mpsEvent_t;
// TODO: Move the MPSGuardImpl to inherit from NoOpDeviceGuardImpl
// https://github.com/pytorch/pytorch/issues/77170
struct TORCH_API MPSGuardImpl final : public c10::impl::DeviceGuardImplInterface {
static constexpr c10::DeviceType static_type = c10::DeviceType::MPS;
// constructor
MPSGuardImpl() {}
explicit MPSGuardImpl(c10::DeviceType t) {
TORCH_INTERNAL_ASSERT(t == c10::DeviceType::MPS);
}
// returns the type
c10::DeviceType type() const override {
return c10::DeviceType::MPS;
}
Device exchangeDevice(Device d) const override {
return Device(c10::DeviceType::MPS, 0);
}
Device getDevice() const override {
return Device(c10::DeviceType::MPS, 0);
}
std::optional<Device> uncheckedGetDevice() const noexcept {
return Device(c10::DeviceType::MPS, 0);
}
void setDevice(Device d) const override {
TORCH_INTERNAL_ASSERT(d.is_mps());
}
void uncheckedSetDevice(Device d) const noexcept override {
// TODO: Currently setting only device 0
}
Stream getStream(Device d) const noexcept override {
return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0));
}
Stream getNewStream(Device, int priority = 0) const override {
(void)priority;
return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0));
}
Stream getDefaultStream(Device d) const override {
return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0));
}
// NB: These do NOT set the current device
Stream exchangeStream(Stream s) const noexcept override {
return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0));
}
DeviceIndex deviceCount() const noexcept override {
if (at::hasMPS()) {
//TODO: extend it for multi-device case
return 1;
} else {
return 0;
}
}
// Event-related functions
void createEvent(
mpsEvent_t* event,
const EventFlag flag) const;
void destroyEvent(
void* event,
const DeviceIndex device_index) const noexcept override;
void record(
void** event,
const Stream& stream,
const DeviceIndex device_index,
const EventFlag flag) const override;
void block(
void* event,
const Stream& stream) const override;
bool queryEvent(void* event) const override;
};
/// A variant of OptionalDeviceGuard that is specialized for MPS.
struct OptionalMPSGuard {
explicit OptionalMPSGuard() : guard_() {}
explicit OptionalMPSGuard(std::optional<Device> device_opt)
: guard_(device_opt) {}
/// Set the current MPS device to the passed device index, if it is not
/// nullopt
explicit OptionalMPSGuard(std::optional<DeviceIndex> device_index_opt)
: guard_(device_index_opt) {}
// Copy is not allowed
OptionalMPSGuard(const OptionalMPSGuard&) = delete;
OptionalMPSGuard& operator=(const OptionalMPSGuard&) = delete;
OptionalMPSGuard(OptionalMPSGuard&& other) = delete;
OptionalMPSGuard& operator=(OptionalMPSGuard&& other) = delete;
/// Sets the MPS device to the given device, initializing the guard if it
/// is not already initialized. Errors if the given device is not an MPS
/// device.
void set_device(Device device) {
guard_.set_device(device);
}
/// Sets the MPS device to the given device, initializing the guard if it is
/// not already initialized. Errors if the given device is not an MPS device.
void reset_device(Device device) {
guard_.reset_device(device);
}
/// Sets the MPS device to the given device index, initializing the guard if
/// it is not already initialized.
void set_index(DeviceIndex device_index) {
guard_.set_index(device_index);
}
/// Returns the device that was set immediately prior to initialization of the
/// guard, or nullopt if the guard is uninitialized.
std::optional<Device> original_device() const {
return guard_.original_device();
}
/// Returns the most recent device that was set using this device guard,
/// either from construction, or via set_device, if the guard is initialized,
/// or nullopt if the guard is uninitialized.
std::optional<Device> current_device() const {
return guard_.current_device();
}
/// Restore the original MPS device, resetting this guard to uninitialized
/// state.
void reset() {
guard_.reset();
}
private:
c10::impl::InlineOptionalDeviceGuard<MPSGuardImpl> guard_;
};
C10_REGISTER_GUARD_IMPL(MPS, MPSGuardImpl);
} // namespace at::mps
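
Finally, a hedged sketch of how the OptionalMPSGuard above is meant to be used, mirroring the other OptionalDeviceGuard variants; the function is illustrative and not part of this commit.

// Scope-guard a block of work onto a tensor's MPS device; the original device
// (if any was set) is restored when the guard goes out of scope.
void with_mps_device_of(const at::Tensor& t) {
  at::mps::OptionalMPSGuard guard;
  if (t.is_mps()) {
    guard.set_device(t.device());  // errors if the device is not an MPS device
  }
  // ... work that should run with the tensor's MPS device current ...
}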


@@ -0,0 +1,60 @@
// Copyright © 2022 Apple Inc.
#pragma once
#include <ATen/detail/MPSHooksInterface.h>
#include <ATen/Generator.h>
#include <ATen/mps/MPSEvent.h>
#include <optional>
namespace at::mps {
// The real implementation of MPSHooksInterface
struct MPSHooks : public at::MPSHooksInterface {
MPSHooks(at::MPSHooksArgs) {}
void initMPS() const override;
// MPSDevice interface
bool hasMPS() const override;
bool isOnMacOSorNewer(unsigned major, unsigned minor) const override;
// MPSGeneratorImpl interface
const Generator& getDefaultMPSGenerator() const override;
// MPSStream interface
void deviceSynchronize() const override;
void commitStream() const override;
void* getCommandBuffer() const override;
void* getDispatchQueue() const override;
// MPSAllocator interface
Allocator* getMPSDeviceAllocator() const override;
void emptyCache() const override;
size_t getCurrentAllocatedMemory() const override;
size_t getDriverAllocatedMemory() const override;
size_t getRecommendedMaxMemory() const override;
void setMemoryFraction(double ratio) const override;
bool isPinnedPtr(const void* data) const override;
Allocator* getPinnedMemoryAllocator() const override;
// MPSProfiler interface
void profilerStartTrace(const std::string& mode, bool waitUntilCompleted) const override;
void profilerStopTrace() const override;
// MPSEvent interface
uint32_t acquireEvent(bool enable_timing) const override;
void releaseEvent(uint32_t event_id) const override;
void recordEvent(uint32_t event_id) const override;
void waitForEvent(uint32_t event_id) const override;
void synchronizeEvent(uint32_t event_id) const override;
bool queryEvent(uint32_t event_id) const override;
double elapsedTimeOfEvents(uint32_t start_event_id, uint32_t end_event_id) const override;
// Compatibility with Accelerator API
bool hasPrimaryContext(DeviceIndex device_index) const override {
// When MPS is available, it is always in use for the one device.
return true;
}
};
} // namespace at::mps


@@ -0,0 +1,402 @@
// Copyright © 2022 Apple Inc.
#pragma once
#include <ATen/Tensor.h>
#include <ATen/mps/MPSStream.h>
#include <ATen/mps/MPSAllocatorInterface.h>
#include <os/signpost.h>
#include <os/log.h>
#include <atomic>
#include <ctime>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
namespace at::mps {
namespace Profiler {
struct BaseInfo {
// profiling info types
enum class Type {
GRAPH,
KERNEL,
COPY,
CPU_FALLBACK,
};
BaseInfo(Type infoType, uint64_t Id, const uintptr_t Handle) :
type(infoType), profileId(Id), handle(Handle) { }
virtual ~BaseInfo() = default;
// type of profiling info
Type type;
// unique profile ID for execution instances of operations or copies
uint64_t profileId;
// ID generated by os_signpost
// since it's possible to use event and interval-based signposts at the
// same time, we need separate IDs for each.
os_signpost_id_t eventSignpostId = 0, intervalSignpostId = 0;
// accumulated GPU time in ms (obtained from CompletionHandler's "GPUEndTime - GPUStartTime")
std::atomic<double> totalGpuTime{0.0};
// accumulated Scheduling time in ms (obtained from CompletionHandler's "KernelEndTime - KernelStartTime")
std::atomic<double> totalSchedulingTime{0.0};
// indicates if the operation or copy execution has completed
std::atomic_bool completed{false};
// handle used to identify the profile info's instance (usually the pointer)
const uintptr_t handle;
virtual const std::string toString(double gpuTime = 0, double schedulingTime = 0) const;
// builds a string for a tensor (format: Device:ScalarType[tensor.sizes()])
static std::string buildTensorString(const Tensor& tensor, bool includeBufferId = false) {
if (tensor.defined()) {
std::stringstream tensorStr;
auto deviceType = tensor.device().type();
tensorStr << c10::DeviceTypeName(deviceType);
// see comments for INCLUDE_BUFFER_ID
if (includeBufferId && deviceType == at::kMPS) {
id<MTLBuffer> buffer = __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
tensorStr << "(buf#" << (getIMPSAllocator()->getBufferId(buffer))
<< ":" << buffer.retainCount << ")";
}
tensorStr << ":"
<< tensor.scalar_type() << tensor.sizes();
return tensorStr.str();
} else {
return "undefined";
}
}
static uint64_t getTime() {
return clock_gettime_nsec_np(CLOCK_MONOTONIC_RAW);
}
};
struct OperationInfo : BaseInfo {
OperationInfo(const void* Handle, bool IsGraph, uint64_t Id, const std::string& StrKey) :
BaseInfo(IsGraph ? Type::GRAPH : Type::KERNEL, Id, uintptr_t(Handle)), strKey(StrKey) { }
uint64_t runCount = 0;
std::string strKey;
const std::string toString(double gpuTime = 0, double schedulingTime = 0) const override;
// builds a string for a kernel
static std::string buildKernelString(const std::string& kernelName,
const TensorList& tensors,
bool includeBufferId = false) {
std::stringstream kernelStr;
kernelStr << kernelName;
for (const Tensor& tensor: tensors) {
kernelStr << ":" << BaseInfo::buildTensorString(tensor, includeBufferId);
}
return kernelStr.str();
}
};
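// A minimal sketch of how buildKernelString() above could be used to build the
// profiler key for a kernel launch; the kernel name "add_dense" and the two
// input tensors are placeholders. The result has the form
// "add_dense:MPS:Float[2, 3]:MPS:Float[2, 3]".
#if 0
inline std::string exampleProfileKey(const at::Tensor& a, const at::Tensor& b) {
  return OperationInfo::buildKernelString("add_dense", {a, b}, /*includeBufferId=*/false);
}
#endif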
struct CpuFbInfo : BaseInfo {
CpuFbInfo(uint64_t Id, const std::string& OpName) :
BaseInfo(Type::CPU_FALLBACK, Id, 0), opName(OpName) { }
uint64_t runCount = 0;
// the current and total overhead of copies in bytes required to convert the Op's
// input tensors from MPS to CPU and then its outputs from CPU back to MPS
size_t currentCopyOverhead = 0;
size_t totalCopyOverhead = 0;
std::string opName;
std::string strKey;
uint64_t startTime = 0;
const std::string toString(double gpuTime = 0, double schedulingTime = 0) const override;
void updateCopyOverhead(const TensorList& tensors) {
currentCopyOverhead = 0;
for (const Tensor& tensor: tensors) {
if (tensor.defined()) {
currentCopyOverhead += tensor.nbytes();
}
}
totalCopyOverhead += currentCopyOverhead;
}
};
struct CopyInfo : BaseInfo {
enum class Kind {
MPS_TO_MPS,
MPS_TO_CPU,
CPU_TO_MPS,
};
CopyInfo(const void* Handle, size_t Length, uint64_t Id, bool IsNonBlocking, bool UsesBlitter) :
BaseInfo(Type::COPY, Id, uintptr_t(Handle)), kind(Kind::MPS_TO_MPS),
length(Length), isNonBlocking(IsNonBlocking), usesBlitter(UsesBlitter) { }
Kind kind;
size_t length;
bool isNonBlocking;
bool usesBlitter;
std::string srcStrKey;
std::string dstStrKey;
// for copies that don't use blitters, we measure CPU time
uint64_t startTime = 0;
const std::string toString(double gpuTime = 0, double schedulingTime = 0) const override;
static std::string buildTensorString(const void* buffer, const OptionalTensorRef tensor, bool includeBufferId = false);
static bool isStorageOnMPS(const void* buffer, const OptionalTensorRef tensor) {
if (tensor.has_value()) {
return tensor->device().type() == at::kMPS;
}
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(buffer);
// getUnalignedBufferSize() returns -1 if the input buffer is not on the MPS device
return getIMPSAllocator()->getUnalignedBufferSize(buffer) >= 0;
}
static Kind getCopyKind(const void* srcBuffer, const void* dstBuffer,
const OptionalTensorRef srcTensor, const OptionalTensorRef dstTensor) {
const bool isSrcOnMPS = isStorageOnMPS(srcBuffer, srcTensor);
const bool isDstOnMPS = isStorageOnMPS(dstBuffer, dstTensor);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isSrcOnMPS || isDstOnMPS);
if (isSrcOnMPS && !isDstOnMPS) {
return Kind::MPS_TO_CPU;
} else if (!isSrcOnMPS && isDstOnMPS) {
return Kind::CPU_TO_MPS;
}
return Kind::MPS_TO_MPS;
}
};
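// A hedged sketch of the getCopyKind() helper above: classifying a copy between
// two defined tensors by passing their storage pointers alongside optional
// tensor refs, mirroring the signature. The function and tensor names here are
// placeholders for illustration only.
#if 0
inline CopyInfo::Kind classifyCopy(const at::Tensor& src, const at::Tensor& dst) {
  return CopyInfo::getCopyKind(src.storage().data(), dst.storage().data(),
                               OptionalTensorRef(src), OptionalTensorRef(dst));
}
#endif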
struct CopyStat : CopyInfo {
explicit CopyStat(std::string CopyKindStr) :
CopyInfo(nullptr, 0, 0, false, false), kindStr(std::move(CopyKindStr)) {}
// total number of copies
size_t totalCount = 0;
// number of Scalar copies (i.e., smaller than sizeof(int64_t))
size_t scalarsCount = 0;
// number of blocking copies (i.e., require syncing to GPU)
size_t blockingCount = 0;
// number of copies that used memcpy(), instead of Metal Blit Encoder
size_t memcpyCount = 0;
// accumulated GPU time in ms for the scalar copies
std::atomic<double> scalarsGpuTime{0.0};
// copy kind in string type
std::string kindStr;
};
class MPSProfiler {
public:
// lower 16 bits used for profiler options
enum ProfileOptions : uint32_t {
OPTIONS_NONE = 0,
// ALL_* means all signpost types (RUN_OPERATION|BLIT_COPY|CPU_FALLBACK, etc.)
// (a convenience so the bit flags don't have to be OR-ed manually)
// trace all signpost types using events
ALL_SIGNPOST_EVENTS = (1 << 0),
// trace all signpost types using intervals
ALL_SIGNPOST_INTERVALS = (1 << 1),
// always wait for command buffer to finish executing after each commit
WAIT_UNTIL_COMPLETED = (1 << 2),
// for interval-based signposts, include the scheduling portion of
// Graph/Kernel/Copy executions as well.
// if this flag is disabled, only the "GPU run time" is included in the interval,
// and not the scheduling time.
INCLUDE_SCHEDULE_INTERVAL = (1 << 3),
// use these if you need to trace signpost types individually (rarely required)
// trace signposts using intervals
USE_INTERVALS = (1 << 4),
// trace signposts by emitting events
USE_EVENTS = (1 << 5),
// used as a sanity check (change this when a new option is added)
OPTIONS_COUNT = (USE_EVENTS << 1) - 1,
};
// when adding new types, #define the type string in MPSProfiler.mm as well.
// upper 16 bits used for event types
enum SignpostTypes : uint32_t {
SIGNPOST_NONE = 0,
// trace signposts for PyTorch operation executions
RUN_OPERATION = (1 << 16),
// trace signposts for blitter copies
BLIT_COPY = (1 << 17),
// trace signposts for ops that fall back on CPU
CPU_FALLBACK = (1 << 18),
// used as a sanity check (change this when a new type is added)
SIGNPOST_COUNT = (CPU_FALLBACK << 1) - 1,
};
enum LogOptions : uint32_t {
LOG_NONE = 0,
// Info logging options during execution
// -------------------------------------
// prints operation info (id/key/run_count) during execution
OPERATION_INFO = (1 << 0),
// prints copy info (src/dst tensors/buffers, size, etc.) during execution
COPY_INFO = (1 << 1),
// prints CPU Fallback info (id/runCount/opName/copyOverhead) during execution
CPU_FALLBACK_INFO = (1 << 2),
// Profiling Statistics logging options when process terminates
// ------------------------------------------------------------
// prints all stats (OPERATION_STATS, COPY_STATS, CPU_FALLBACK_STATS) before process terminates
// a convenience so the following stats bit flags don't have to be combined manually
ALL_STATS = (1 << 3),
// prints operation stats (GPU times, run count, etc.) before process terminates
OPERATION_STATS = (1 << 4),
// prints copies stats (GPU times, copy kinds, sizes, etc.) before process terminates
COPY_STATS = (1 << 5),
// prints CPU Fallback stats (CPU times, run times, size of MPS<->CPU copies
// for tensors, etc.) before process terminates
CPU_FALLBACK_STATS = (1 << 6),
// Metadata format options when logging the info
// ---------------------------------------------
// if enabled, includes GPU run time in metadata (i.e., GPUEndTime-GPUStartTime
// from Metal Command Buffers) (e.g., [GPU=0.324 ms])
INCLUDE_GPU_TIME = (1 << 7),
// if enabled, includes GPU scheduling time in metadata separately
// (i.e., KernelEndTime-KernelStartTime from Metal Command Buffers)
// e.g., [GPU=0.324 ms, KRNL=0.036 ms]
INCLUDE_KERNEL_TIME = (1 << 8),
// if enabled, includes the unique buffer ID in metadata for the storage
// of a tensor that was allocated on MPSAllocator. This is useful (along with
// the EV "PYTORCH_DEBUG_MPS_ALLOCATOR") to identify buffers that are involved
// with various operations.
INCLUDE_BUFFER_ID = (1 << 9),
// used as a sanity check (change this when a new option is added)
LOG_COUNT = (INCLUDE_BUFFER_ID << 1) - 1,
};
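// Hypothetical sketch of how these bit-flag enums compose: per the comments
// above, ProfileOptions live in the lower 16 bits and SignpostTypes in the
// upper 16 bits, while LogOptions form their own word; flags are OR-ed together
// and later tested with a bitwise AND. The combination below is only an illustration.
#if 0
static bool logsOperationInfo(uint32_t log_options) {
  return (log_options & LogOptions::OPERATION_INFO) != 0;
}
// e.g. request operation info together with GPU times:
//   uint32_t log_options = LogOptions::OPERATION_INFO | LogOptions::INCLUDE_GPU_TIME;
#endif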
explicit MPSProfiler();
~MPSProfiler();
// the handle is either "MPSGraph*" or "id<MTLComputePipelineState>" for Metal Kernels
// the beginProfile*() functions return a profileId which is unique per graph/kernel/copy
uint64_t beginProfileKernel(const void* handle, const std::string& strKey, bool isGraph);
uint64_t beginProfileKernel(const void* handle, const std::string& kernelName, const TensorList& tensors);
uint64_t beginProfileCopy(const void* srcBuffer, const void* dstBuffer,
const OptionalTensorRef srcTensor,
const OptionalTensorRef dstTensor,
size_t length, bool isNonBlocking, bool usesBlitter = true);
uint64_t beginProfileCPUFallback(const std::string& opName, const TensorList& tensors);
void beginProfileGPUInterval(const void* handle);
void endProfileCopy(uint64_t profileId, SyncType syncType);
void endProfileKernel(const void* handle, SyncType syncType = SyncType::NONE);
void endProfileCPUFallback(const std::string& opName);
// these are used to hook into the Python bindings for the torch.mps.profiler module.
// this enables generating OS Signpost traces from MPSProfiler on demand
// at runtime (instead of through environment variables).
// The "mode" can be "interval", "event", or both as "interval,event"
// for interval-based and/or event-based signpost tracing.
void StartTrace(const std::string& mode, bool waitUntilCompleted);
void StopTrace();
// Abstractions for GPU trace capturing
bool isCaptureEnabled() const;
bool isCapturing() const;
void startCapture(const std::string& name, MPSStream* stream = nullptr);
void stopCapture(MPSStream* stream = nullptr);
// convenience functions to indicate whether signpost tracing or
// logging is enabled for the given SignpostTypes
bool isOperationProfilingEnabled() const {
return (m_signpost_types & SignpostTypes::RUN_OPERATION) ||
(m_log_options & (LogOptions::OPERATION_INFO | LogOptions::OPERATION_STATS));
}
bool isCopyProfilingEnabled() const {
return (m_signpost_types & SignpostTypes::BLIT_COPY) ||
(m_log_options & (LogOptions::COPY_INFO | LogOptions::COPY_STATS));
}
bool isCPUFallbackProfilingEnabled() const {
return (m_signpost_types & SignpostTypes::CPU_FALLBACK) ||
(m_log_options & (LogOptions::CPU_FALLBACK_INFO | LogOptions::CPU_FALLBACK_STATS));
}
bool isSignpostTracingEnabled() const {
return (m_signpost_types != SignpostTypes::SIGNPOST_NONE);
}
private:
// indicates which signpost types are enabled and traced by the MPS profiler.
uint32_t m_signpost_types = 0;
uint32_t m_profile_options = 0;
uint32_t m_log_options = 0;
uint64_t m_kernel_counter = 0;
uint64_t m_graph_counter = 0;
uint64_t m_cpu_fb_counter = 0;
uint64_t m_copy_counter = 0;
// technically, it's possible to trace both events and intervals at the same time
// so we use separate os_log categories for them
os_log_t m_os_log_events;
os_log_t m_os_log_intervals;
// stats logging could run either from the destructor or from a signal handler,
// so this is used to check if logging has already started.
std::atomic_bool hasLoggedStats{false};
// indicates there are pending completionHandler callbacks that haven't been called yet.
std::atomic_bool hasPendingCompletionHandlers{false};
// used to capture sigint signal to log profiling stats
static struct sigaction currentSigint, previousSigint;
// We use the following lists for two reasons:
// 1- for interval-based signposts, the "begin" point won't be in the same function
//    as the "end" point, where we need to be able to retrieve the signpost's info
// 2- operation info may need to be logged when the process ends (via LogOptions::OPERATION_INFO).
// the pointer key for this map is either "MPSGraph*" or "id<MTLComputePipelineState>" for Metal Kernels
// this list is retained and could be logged along with aggregate profiling numbers when the process ends.
std::unordered_map<uintptr_t, std::unique_ptr<OperationInfo>> m_op_info_list{};
// the string key for this map is the op name that we fall back to execute on CPU
// this list is retained and could be logged along with aggregate profiling numbers when the process ends.
std::unordered_map<std::string, std::unique_ptr<CpuFbInfo>> m_cpu_fb_info_list{};
// this list contains the info for copies, and its key is the unique profileId
// which is generated from m_copy_counter
// The copyInfo list is not retained.
std::unordered_map<uint64_t, std::unique_ptr<CopyInfo>> m_copy_info_list{};
// a short list that contains copy stats
std::unordered_map<CopyInfo::Kind, std::unique_ptr<CopyStat>> m_copy_stat_list{};
mutable MTLCaptureManager *captureManager = nil;
unsigned captureCount = 0;
void initialize();
void beginProfileExecution(BaseInfo& info, bool cpuExecution = false);
void endProfileExecution(BaseInfo& info, os_signpost_id_t event_signpost_id,
os_signpost_id_t interval_signpost_id,
double gpuTime, double schedulingTime);
void addProfilerScheduledHandler(BaseInfo& info);
void addProfilerCompletedHandler(BaseInfo& info, SyncType syncType);
void emitSignpostEvent(SignpostTypes signpost_type, os_signpost_id_t signpost_id,
const std::string& msg) const;
void beginSignpostInterval(SignpostTypes signpost_type, os_signpost_id_t signpost_id,
const std::string& msg) const;
void endSignpostInterval(SignpostTypes signpost_type, os_signpost_id_t signpost_id) const;
void updateCopyStats(const CopyInfo& copyInfo, double gpuTime, double schedulingTime);
// returns true if logging the profiling info "during the execution" is enabled
bool isProfileInfoLoggingEnabled(BaseInfo::Type infoType, bool isExecutionEnded);
// logs all the profiling stats that are enabled
void logProfilingStats();
// logs kernel profiling stats when the process ends.
void logOperationsProfilingStats(std::FILE* f) const;
// logs CPU Fallback profiling stats when the process ends.
void logCPUFallbackProfilingStats(std::FILE* f) const;
// logs copy profiling stats when the process ends.
void logCopyProfilingStats(std::FILE* f) const;
os_signpost_id_t generateSignpostId(os_signpost_type_t signpostType, const void* ptr = nullptr);
static SignpostTypes getSignpostType(BaseInfo::Type infoType);
static void handleIntSignal(int signal);
};
} // namespace Profiler
Profiler::MPSProfiler& getMPSProfiler();
} // namespace at::mps
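// A hedged usage sketch of the profiler singleton declared above, assuming it
// is called from the MPS backend while dispatching a Metal kernel. The handle,
// kernel name, and the work in the middle are placeholders; per the comments
// above, the handle would normally be an MPSGraph* or id<MTLComputePipelineState>.
#if 0
void profileKernelLaunch(const void* kernelHandle, const at::TensorList& inputs) {
  auto& profiler = at::mps::getMPSProfiler();
  const bool profiled = profiler.isOperationProfilingEnabled();
  if (profiled) {
    profiler.beginProfileKernel(kernelHandle, "example_kernel", inputs);
  }
  // ... encode and dispatch the kernel on the current MPSStream ...
  if (profiled) {
    profiler.endProfileKernel(kernelHandle, at::mps::SyncType::COMMIT);
  }
}
#endif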

View File

@ -0,0 +1,133 @@
// Copyright © 2022 Apple Inc.
#pragma once
#include <cstdint>
#include <utility>
#include <c10/core/DeviceGuard.h>
#include <c10/util/Exception.h>
#include <c10/core/Stream.h>
#include <ATen/mps/MPSDevice.h>
#ifdef __OBJC__
#include <Foundation/Foundation.h>
#include <Metal/Metal.h>
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
#include <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>
typedef id<MTLCommandQueue> MTLCommandQueue_t;
typedef id<MTLCommandBuffer> MTLCommandBuffer_t;
typedef id<MTLComputeCommandEncoder> MTLComputeCommandEncoder_t;
typedef id<MTLSharedEvent> MTLSharedEvent_t;
typedef id<MTLDevice> MTLDevice_t;
#else
typedef void* MTLCommandQueue_t;
typedef void* MTLCommandQueue;
typedef void* MTLCommandBuffer_t;
typedef void* MTLCommandBuffer;
typedef void* MTLComputeCommandEncoder_t;
typedef void* MTLSharedEvent_t;
typedef void* dispatch_queue_t;
typedef void* MTLDevice_t;
#define nil NULL
#endif
namespace at::mps {
//-----------------------------------------------------------------
// MPSStream
//-----------------------------------------------------------------
enum class SyncType {
NONE, // no commit to command buffer
COMMIT, // commit and flush the command buffer
COMMIT_AND_WAIT, // flush and wait for command buffer execution to finish
COMMIT_AND_CONTINUE,// commit and continue with a new underlying command buffer
COMMIT_ADAPTIVE, // commit adaptively based on available memory
};
class TORCH_API MPSStream
{
public:
enum Unchecked { UNCHECKED };
/// Construct an MPSStream from a Stream. This construction is checked,
/// and will raise an error if the Stream is not, in fact, an MPS stream.
explicit MPSStream(Stream stream);
~MPSStream();
MTLCommandQueue_t commandQueue() const { return _commandQueue; };
dispatch_queue_t queue() const { return _serialQueue; }
MPSCommandBuffer* commandBuffer();
MTLComputeCommandEncoder_t commandEncoder();
void endKernelCoalescing();
void synchronize(SyncType syncType);
void fill(id<MTLBuffer> buffer, uint8_t value, size_t length, size_t offset, SyncType syncType = SyncType::NONE);
void copy(id<MTLBuffer> srcBuffer, id<MTLBuffer> dstBuffer,
size_t length, size_t srcOffset, size_t dstOffset,
uint64_t profileId, SyncType syncType = SyncType::NONE);
void copy_and_sync(id<MTLBuffer> srcBuffer, id<MTLBuffer> dstBuffer,
size_t length, size_t srcOffset, size_t dstOffset,
bool non_blocking, uint64_t profileId);
void executeMPSGraph(MPSGraph* mpsGraph, NSDictionary* feeds, NSDictionary* results, SyncType syncType = SyncType::NONE);
void addCompletedHandler(MTLCommandBufferHandler block);
/// Get the MPS device index that this stream is associated with.
c10::DeviceIndex device_index() const { return _stream.device_index(); }
MTLCommandQueue_t stream() const { return _commandQueue; };
MTLDevice_t device() const { return [_commandQueue device];}
/// Explicit conversion to Stream.
Stream unwrap() const { return _stream; }
private:
Stream _stream;
MTLCommandQueue_t _commandQueue = nil;
MPSCommandBuffer* _commandBuffer = nil;
MPSCommandBuffer* _prevCommandBuffer = nil;
MTLComputeCommandEncoder_t _commandEncoder = nil;
MPSGraphExecutionDescriptor *_executionDescriptor = nil;
MPSGraphCompilationDescriptor *_compilationDescriptor = nil;
dispatch_queue_t _serialQueue = nullptr;
// CommitAndContinue is enabled by default
bool _enableCommitAndContinue = true;
// use synchronize() to access any of these commit functions outside MPSStream
void commit();
void commitAndWait();
void commitAndContinue();
void flush();
};
/**
* Get the current MPS stream
*/
TORCH_API MPSStream* getCurrentMPSStream();
/**
* Get the default MPS stream
*/
TORCH_API MPSStream* getDefaultMPSStream();
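// A small usage sketch of the stream API above, assuming an Objective-C++
// translation unit inside the MPS backend; the buffer and length are
// placeholders. SyncType::COMMIT_AND_WAIT makes the call block until the blit
// has finished executing on the GPU.
#if 0
void zeroBuffer(id<MTLBuffer> buffer, size_t length) {
  MPSStream* stream = getCurrentMPSStream();
  stream->fill(buffer, /*value=*/0, length, /*offset=*/0, SyncType::COMMIT_AND_WAIT);
}
#endif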
//-----------------------------------------------------------------
// MPSStreamImpl
//-----------------------------------------------------------------
class TORCH_API MPSStreamImpl
{
public:
/**
* Gets the single instance of the MPSStream.
*/
static MPSStream* getInstance();
private:
static MPSStream* _stream;
MPSStreamImpl();
};
} // namespace at::mps