I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

@@ -0,0 +1,37 @@
#pragma once
#if !defined(_MSC_VER) && __cplusplus < 201703L
#error A C++17 (or later) compatible compiler is required to use ATen.
#endif
#include <ATen/Context.h>
#include <ATen/Device.h>
#include <ATen/DeviceGuard.h>
#include <ATen/DimVector.h>
#include <ATen/Dispatch.h>
#include <ATen/Formatting.h>
#include <ATen/Functions.h>
#include <ATen/NamedTensor.h>
#include <ATen/ScalarOps.h>
#include <ATen/Tensor.h>
#include <ATen/TensorGeometry.h>
#include <ATen/TensorIndexing.h>
#include <ATen/TensorOperators.h>
#include <ATen/Version.h>
#include <ATen/core/ATenGeneral.h>
#include <ATen/core/Generator.h>
#include <ATen/core/Reduction.h>
#include <ATen/core/Scalar.h>
#include <ATen/core/UnsafeFromTH.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type.h>
#include <c10/core/Allocator.h>
#include <c10/core/InferenceMode.h>
#include <c10/core/Layout.h>
#include <c10/core/Storage.h>
#include <c10/core/TensorOptions.h>
#include <c10/util/Exception.h>
// TODO: try to remove this
// There is some back story, see https://github.com/pytorch/pytorch/issues/48684
#include <ATen/NativeFunctions.h>

@@ -0,0 +1,173 @@
#pragma once
#include <ATen/Config.h>
#include <c10/core/DeviceType.h>
#include <c10/core/ScalarType.h>
#include <c10/util/BFloat16.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e4m3fnuz.h>
#include <c10/util/Float8_e5m2.h>
#include <c10/util/Float8_e5m2fnuz.h>
#include <c10/util/Half.h>
// Defines the accumulation type for a scalar type.
// Example:
// using accscalar_t = acc_type<scalar_t, /*is_cuda*/true>;
//
// Accumulation types are an important concept in numeric computing
// because you frequently want to perform intermediate computations
// at a higher precision than the input and output precision, to avoid
// compounding internal rounding errors. Accumulation is the most
// well-known intermediate computation (it is of great importance for
// sum reduction and matrix multiply, for example), but in PyTorch
// acc_type ends up getting used for all sorts of other intermediate
// computations, so it perhaps would be more accurately (ahem) called an
// "accurate" type. acc_type is especially important for reduced
// precision operations like float16 and bfloat16, where relatively
// benign looking inputs can easily end up overflowing/underflowing.
//
// acc_type is parametrized by whether or not you are running on CUDA,
// because on CUDA double precision operations are expensive, and so by
// default we don't actually want to use double as an acc_type on CUDA.
// A lot of things are typed out below, but
// basically, the table is generated by a few rules:
//
// If bool:
// Use 'bool' as acc_type.
// If floating point:
// If CUDA, use 'float' as acc_type (unless scalar_t is double),
// otherwise (CPU) use 'double'
// If integral:
// Use 'int64_t' as acc_type
//
// You're not forced to use this template; if you happen to know
// something specific about your use case, you can specify your own
// desired behavior. This template, however, will give you a reasonable
// default that will work for all dtypes supported in PyTorch.
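//
// For illustration, per the specializations below:
//
//   using cpu_half_acc  = at::acc_type<at::Half, /*is_cuda=*/false>; // float
//   using cpu_float_acc = at::acc_type<float,    /*is_cuda=*/false>; // double
//   using gpu_float_acc = at::acc_type<float,    /*is_cuda=*/true>;  // float
//   using cpu_int_acc   = at::acc_type<int32_t,  /*is_cuda=*/false>; // int64_t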
#if defined(__CUDACC__)
#include <cuda.h>
#include <cuda_fp16.h>
#elif defined(__HIPCC__)
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
#endif
namespace at {
template <typename T, c10::DeviceType D>
struct AccumulateTypeDevice {};
template <typename T, bool>
struct AccumulateType {};
template <typename T>
struct AccumulateType<T, false> {
using type = typename AccumulateTypeDevice<T, c10::DeviceType::CPU>::type;
};
template <typename T>
struct AccumulateType<T, true> {
using type = typename AccumulateTypeDevice<T, c10::DeviceType::CUDA>::type;
};
template <typename T, c10::DeviceType device>
using acc_type_device = typename AccumulateTypeDevice<T, device>::type;
template <typename T, bool is_cuda>
using acc_type = typename AccumulateType<T, is_cuda>::type;
#define ACC_TYPE(t, acc_t, device_type) \
template <> \
struct AccumulateTypeDevice<t, device_type> { \
using type = acc_t; \
};
#define MPS_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::MPS)
#define XPU_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::XPU)
#define CUDA_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::CUDA)
#define CPU_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::CPU)
MPS_ACC_TYPE(BFloat16, float);
MPS_ACC_TYPE(Half, float);
MPS_ACC_TYPE(Float8_e5m2, float);
MPS_ACC_TYPE(Float8_e4m3fn, float);
MPS_ACC_TYPE(Float8_e5m2fnuz, float);
MPS_ACC_TYPE(Float8_e4m3fnuz, float);
MPS_ACC_TYPE(float, float);
MPS_ACC_TYPE(double, float);
MPS_ACC_TYPE(int8_t, int64_t);
MPS_ACC_TYPE(uint8_t, int64_t);
MPS_ACC_TYPE(char, int64_t);
MPS_ACC_TYPE(int16_t, int64_t);
MPS_ACC_TYPE(int32_t, int64_t);
MPS_ACC_TYPE(int64_t, int64_t);
MPS_ACC_TYPE(bool, bool);
MPS_ACC_TYPE(c10::complex<Half>, c10::complex<float>);
MPS_ACC_TYPE(c10::complex<float>, c10::complex<float>);
MPS_ACC_TYPE(c10::complex<double>, c10::complex<float>);
XPU_ACC_TYPE(BFloat16, float);
XPU_ACC_TYPE(Half, float);
XPU_ACC_TYPE(Float8_e5m2, float);
XPU_ACC_TYPE(Float8_e4m3fn, float);
XPU_ACC_TYPE(Float8_e5m2fnuz, float);
XPU_ACC_TYPE(Float8_e4m3fnuz, float);
XPU_ACC_TYPE(float, float);
XPU_ACC_TYPE(double, double);
XPU_ACC_TYPE(int8_t, int64_t);
XPU_ACC_TYPE(uint8_t, int64_t);
XPU_ACC_TYPE(char, int64_t);
XPU_ACC_TYPE(int16_t, int64_t);
XPU_ACC_TYPE(int32_t, int64_t);
XPU_ACC_TYPE(int64_t, int64_t);
XPU_ACC_TYPE(bool, bool);
XPU_ACC_TYPE(c10::complex<Half>, c10::complex<float>);
XPU_ACC_TYPE(c10::complex<float>, c10::complex<float>);
XPU_ACC_TYPE(c10::complex<double>, c10::complex<double>);
#if defined(__CUDACC__) || defined(__HIPCC__)
CUDA_ACC_TYPE(half, float);
#endif
CUDA_ACC_TYPE(BFloat16, float);
CUDA_ACC_TYPE(Half, float);
CUDA_ACC_TYPE(Float8_e5m2, float);
CUDA_ACC_TYPE(Float8_e4m3fn, float);
CUDA_ACC_TYPE(Float8_e5m2fnuz, float);
CUDA_ACC_TYPE(Float8_e4m3fnuz, float);
CUDA_ACC_TYPE(float, float);
CUDA_ACC_TYPE(double, double);
CUDA_ACC_TYPE(int8_t, int64_t);
CUDA_ACC_TYPE(uint8_t, int64_t);
CUDA_ACC_TYPE(char, int64_t);
CUDA_ACC_TYPE(int16_t, int64_t);
CUDA_ACC_TYPE(int32_t, int64_t);
CUDA_ACC_TYPE(int64_t, int64_t);
CUDA_ACC_TYPE(bool, bool);
CUDA_ACC_TYPE(c10::complex<Half>, c10::complex<float>);
CUDA_ACC_TYPE(c10::complex<float>, c10::complex<float>);
CUDA_ACC_TYPE(c10::complex<double>, c10::complex<double>);
CPU_ACC_TYPE(BFloat16, float);
CPU_ACC_TYPE(Half, float);
CPU_ACC_TYPE(Float8_e5m2, float);
CPU_ACC_TYPE(Float8_e4m3fn, float);
CPU_ACC_TYPE(Float8_e5m2fnuz, float);
CPU_ACC_TYPE(Float8_e4m3fnuz, float);
CPU_ACC_TYPE(float, double);
CPU_ACC_TYPE(double, double);
CPU_ACC_TYPE(int8_t, int64_t);
CPU_ACC_TYPE(uint8_t, int64_t);
CPU_ACC_TYPE(char, int64_t);
CPU_ACC_TYPE(int16_t, int64_t);
CPU_ACC_TYPE(int32_t, int64_t);
CPU_ACC_TYPE(int64_t, int64_t);
CPU_ACC_TYPE(bool, bool);
CPU_ACC_TYPE(c10::complex<Half>, c10::complex<float>);
CPU_ACC_TYPE(c10::complex<float>, c10::complex<double>);
CPU_ACC_TYPE(c10::complex<double>, c10::complex<double>);
TORCH_API c10::ScalarType toAccumulateType(
c10::ScalarType type,
c10::DeviceType device);
TORCH_API c10::ScalarType toAccumulateType(c10::ScalarType type, bool is_cuda);
} // namespace at

@@ -0,0 +1,2 @@
#pragma once
#include <c10/util/ArrayRef.h>

@@ -0,0 +1,2 @@
#pragma once
#include <c10/core/Backend.h>

@@ -0,0 +1,2 @@
#pragma once
#include <ATen/core/Backtrace.h>

@@ -0,0 +1,27 @@
#pragma once
#include <c10/util/Exception.h>
#include <ostream>
#include <string>
namespace at {
enum class BlasBackend : int8_t { Cublas, Cublaslt };
inline std::string BlasBackendToString(at::BlasBackend backend) {
switch (backend) {
case BlasBackend::Cublas:
return "at::BlasBackend::Cublas";
case BlasBackend::Cublaslt:
return "at::BlasBackend::Cublaslt";
default:
TORCH_CHECK(false, "Unknown blas backend");
}
}
inline std::ostream& operator<<(std::ostream& stream, at::BlasBackend backend) {
return stream << BlasBackendToString(backend);
}
} // namespace at

@@ -0,0 +1,343 @@
#pragma once
#include <ATen/CollapseDims.h>
#include <ATen/Parallel.h>
#include <ATen/TensorUtils.h>
#include <c10/util/irange.h>
#include <cstring>
#include <limits>
namespace at {
/*
* The basic strategy for apply is as follows:
*
* 1. Starting with the outermost index, loop until we reach a dimension where
* the data is no longer contiguous, i.e. the stride at that dimension is not
* equal to the size of the tensor defined by the outer dimensions. Let's call
* this outer (contiguous) tensor A. Note that if the Tensor is contiguous, then
* A is equal to the entire Tensor. Let's call the inner tensor B.
*
* 2. We loop through the indices in B, starting at its outermost dimension. For
* example, if B is a 2x2 matrix, then we do:
*
* B[0][0]
* B[0][1]
* B[1][0]
* B[1][1]
*
* We set the offset into the underlying storage as (storageOffset + stride_B *
* index_B), i.e. basically we compute the offset into the storage as we would
* normally for a Tensor. But because we are guaranteed the subsequent data is
* contiguous in memory, we can simply loop for sizeof(A) iterations and perform
* the operation, without having to follow the order described by the strides of
* A.
*
* 3. As an optimization, we merge dimensions of A that are contiguous in
* memory. For example, if A is a 3x3x3x3 tensor narrowed from a 3x3x4x3 tensor,
* then the first two dimensions can be merged for the purposes of APPLY,
* reducing the number of nested loops.
*/
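/*
 * Worked example for step 3: narrowing a contiguous 3x3x4x3 tensor down to
 * 3x3x3x3 leaves strides (36, 12, 3, 1). Dimensions 0 and 1 can be merged
 * because 36 == 3 * 12 (giving size 9, stride 12), and dimensions 2 and 3
 * can be merged because 3 == 3 * 1 (giving size 9, stride 1), but dimensions
 * 1 and 2 cannot be merged since 12 != 3 * 3 due to the narrow. collapse_dims
 * therefore reduces the loop nest from four levels to two.
 */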
inline Tensor sort_strides(Tensor& tensor_) {
IntArrayRef strides = tensor_.strides();
std::vector<int64_t> indices;
indices.reserve(tensor_.ndimension());
for (const auto i : c10::irange(tensor_.ndimension())) {
indices.push_back(i);
}
std::sort(indices.begin(), indices.end(), [&strides](int64_t i1, int64_t i2) {
return strides[i1] > strides[i2];
});
Tensor tensor = tensor_.permute(indices);
return tensor;
}
template <typename T, int N>
struct strided_tensor_iter_fixed {
public:
T* data_ = NULL;
int64_t dim_ = 0;
int64_t counter_[N] = {0};
int64_t sizes_[N] = {0};
int64_t strides_[N] = {0};
strided_tensor_iter_fixed(strided_tensor_iter_fixed const&) = delete;
void operator=(strided_tensor_iter_fixed const& x) = delete;
strided_tensor_iter_fixed(strided_tensor_iter_fixed&&) = default;
strided_tensor_iter_fixed(
Tensor& tensor,
C10_UNUSED bool sort_strides = false)
: data_(tensor.data_ptr<T>()) {
std::memset(counter_, 0, sizeof(int64_t) * N);
if (tensor.dim() > 0) {
std::memcpy(
sizes_, tensor.sizes().data(), tensor.dim() * sizeof(int64_t));
std::memcpy(
strides_, tensor.strides().data(), tensor.dim() * sizeof(int64_t));
}
dim_ = std::get<1>(collapse_dims(sizes_, strides_, tensor.ndimension()));
}
};
template <typename T>
struct strided_tensor_iter {
private:
public:
T* data_ = NULL;
int64_t dim_;
std::vector<int64_t> counter_;
std::vector<int64_t> sizes_;
std::vector<int64_t> strides_;
strided_tensor_iter(strided_tensor_iter const&) = delete;
void operator=(strided_tensor_iter const& x) = delete;
strided_tensor_iter(strided_tensor_iter&&) = default;
strided_tensor_iter(Tensor& tensor)
: data_(tensor.data_ptr<T>()),
dim_(tensor.ndimension()),
counter_(dim_, 0),
sizes_(tensor.sizes().vec()),
strides_(tensor.strides().vec()) {
dim_ = std::get<1>(collapse_dims(sizes_.data(), strides_.data(), dim_));
}
};
inline bool _all_equal_numel(at::ArrayRef<Tensor> tensors) {
if (tensors.empty())
return true;
int64_t all_numel = tensors[0].numel();
for (const auto i : c10::irange(1, tensors.size())) {
if (tensors[i].numel() != all_numel)
return false;
}
return true;
}
inline std::string _all_equal_numel_error(at::ArrayRef<Tensor> tensors) {
std::ostringstream oss;
oss << "inconsistent tensor size, expected ";
for (size_t i = 0; i < tensors.size() - 1; i++) {
oss << tensors[i].sizes() << ", ";
}
oss << "and " << tensors[tensors.size() - 1].sizes()
<< " to have the same number of elements, but got ";
for (size_t i = 0; i < tensors.size() - 1; i++) {
oss << tensors[i].numel() << ", ";
}
oss << "and " << tensors[tensors.size() - 1].numel()
<< " elements respectively";
return oss.str();
}
inline bool _apply_preamble(ArrayRef<Tensor> tensors) {
checkDeviceType("CPU_tensor_apply", tensors, kCPU);
checkLayout("CPU_tensor_apply", tensors, kStrided);
if (!_all_equal_numel(tensors))
AT_ERROR(_all_equal_numel_error(tensors));
// An empty tensor has no elements
for (auto& t : tensors)
if (t.numel() == 0)
return false;
return true;
}
inline int64_t _max_dim_tensors(ArrayRef<Tensor> tensors) {
int64_t dim = 0;
for (auto& t : tensors)
dim = std::max(dim, t.ndimension());
return dim;
}
inline void iterate(int64_t /*size*/) {}
template <typename Arg, typename... Args>
inline void iterate(int64_t size, Arg& iter, Args&... iter_tail) {
iter.counter_[iter.dim_ - 1] += size;
iter.data_ = iter.data_ + size * iter.strides_[iter.dim_ - 1];
iterate(size, iter_tail...);
}
inline bool iterate_continue() {
  return true;
}
template <typename Arg, typename... Args>
inline bool iterate_continue(Arg& iter, Args&... iter_tail) {
return iter.counter_[iter.dim_ - 1] < iter.sizes_[iter.dim_ - 1] &&
iterate_continue(iter_tail...);
}
inline int64_t max_iterate_size() {
  return std::numeric_limits<int64_t>::max();
}
template <typename Arg, typename... Args>
inline int64_t max_iterate_size(Arg& iter, Args&... iter_tail) {
return std::min(
(iter.sizes_[iter.dim_ - 1] - iter.counter_[iter.dim_ - 1]),
max_iterate_size(iter_tail...));
}
inline void iterate_overflow() {}
template <typename Arg, typename... Args>
inline void iterate_overflow(Arg& iter, Args&... iter_tail) {
if (iter.counter_[iter.dim_ - 1] == iter.sizes_[iter.dim_ - 1]) {
for (int64_t i = iter.dim_ - 1; i > 0; i--) {
if (iter.counter_[i] == iter.sizes_[i]) {
iter.counter_[i] = 0;
iter.counter_[i - 1]++;
iter.data_ = iter.data_ - (iter.sizes_[i] * iter.strides_[i]) +
iter.strides_[i - 1];
}
}
}
iterate_overflow(iter_tail...);
}
inline void forward(int64_t /*offset*/) {}
template <typename Arg, typename... Args>
inline void forward(int64_t offset, Arg& iter, Args&... iter_tail) {
int64_t multi = offset;
for (int64_t i = iter.dim_ - 1; i >= 0; i--) {
int64_t inc = multi % iter.sizes_[i];
multi = multi / iter.sizes_[i];
iter.data_ = iter.data_ + inc * iter.strides_[i];
iter.counter_[i] += inc;
}
forward(offset, iter_tail...);
}
inline int64_t max_dim() {
return 0;
}
template <typename Arg, typename... Args>
inline int64_t max_dim(Arg& iter, Args&... iter_tail) {
return std::max(iter.dim_, max_dim(iter_tail...));
}
inline void apply_op() {}
template <typename Op, typename... Args>
inline void apply_op(
int64_t numel,
int64_t offset,
const Op& op,
Args... iters) {
// For 0-dim tensors
if (numel == 1 && max_dim(iters...) == 0) {
op(*iters.data_...);
return;
}
if (offset > 0)
forward(offset, iters...);
// Splitting this into chunks helps the compiler create faster assembly
for (int64_t i = 0; i < numel;) {
for (; iterate_continue(iters...) && i < numel;) {
op(*iters.data_...);
iterate(1, iters...);
i++;
}
iterate_overflow(iters...);
}
}
/*
  Apply a pointwise operator to a sequence of tensors.

  The calling convention for op is a function/functor that takes the same
  number of references of type scalar as the number of given tensors (the
  iterators are dereferenced before op is invoked). For example, to compute
  a = b * c, op would be of the form:
  [](scalar& a_val, const scalar& b_val, const scalar& c_val) {
    a_val = b_val * c_val;
  };
  An illustrative end-to-end call is sketched after CPU_tensor_apply2 below.
*/
template <typename scalar1, typename scalar2, typename Op>
inline void CPU_tensor_apply2(Tensor tensor1, Tensor tensor2, const Op op) {
if (!_apply_preamble({tensor1, tensor2}))
return;
if (_max_dim_tensors({tensor1, tensor2}) <= 8) {
apply_op(
tensor1.numel(),
0,
op,
strided_tensor_iter_fixed<scalar1, 8>(tensor1),
strided_tensor_iter_fixed<scalar2, 8>(tensor2));
} else {
apply_op(
tensor1.numel(),
0,
op,
strided_tensor_iter<scalar1>(tensor1),
strided_tensor_iter<scalar2>(tensor2));
}
}
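/*
  Illustrative use of CPU_tensor_apply2 (a sketch; assumes two float CPU
  tensors with matching numel):

    at::Tensor x = at::rand({4, 4});
    at::Tensor y = at::empty_like(x);
    at::CPU_tensor_apply2<float, float>(
        y, x, [](float& y_val, const float& x_val) { y_val = 2.0f * x_val; });
*/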
template <typename scalar1, typename scalar2, typename scalar3, typename Op>
inline void CPU_tensor_apply3(
Tensor tensor1,
Tensor tensor2,
Tensor tensor3,
const Op op) {
if (!_apply_preamble({tensor1, tensor2, tensor3}))
return;
if (_max_dim_tensors({tensor1, tensor2, tensor3}) <= 8) {
apply_op(
tensor1.numel(),
0,
op,
strided_tensor_iter_fixed<scalar1, 8>(tensor1),
strided_tensor_iter_fixed<scalar2, 8>(tensor2),
strided_tensor_iter_fixed<scalar3, 8>(tensor3));
} else {
apply_op(
tensor1.numel(),
0,
op,
strided_tensor_iter<scalar1>(tensor1),
strided_tensor_iter<scalar2>(tensor2),
strided_tensor_iter<scalar3>(tensor3));
}
}
template <
typename scalar1,
typename scalar2,
typename scalar3,
typename scalar4,
typename Op>
inline void CPU_tensor_apply4(
Tensor tensor1,
Tensor tensor2,
Tensor tensor3,
Tensor tensor4,
const Op op) {
if (!_apply_preamble({tensor1, tensor2, tensor3, tensor4}))
return;
if (_max_dim_tensors({tensor1, tensor2, tensor3, tensor4}) <= 8) {
apply_op(
tensor1.numel(),
0,
op,
strided_tensor_iter_fixed<scalar1, 8>(tensor1),
strided_tensor_iter_fixed<scalar2, 8>(tensor2),
strided_tensor_iter_fixed<scalar3, 8>(tensor3),
strided_tensor_iter_fixed<scalar4, 8>(tensor4));
} else {
apply_op(
tensor1.numel(),
0,
op,
strided_tensor_iter<scalar1>(tensor1),
strided_tensor_iter<scalar2>(tensor2),
strided_tensor_iter<scalar3>(tensor3),
strided_tensor_iter<scalar4>(tensor4));
}
}
} // namespace at

@@ -0,0 +1,33 @@
#pragma once
#include <c10/core/Allocator.h>
#include <c10/util/Exception.h>
#include <functional>
// This file creates a fake allocator that just throws exceptions if
// it is actually used.
// The state passed to the allocator is the std::function<void(void*)> called
// when the blob is released by ATen.
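//
// For example (a sketch; `my_free` is a hypothetical external deleter), the
// state handed to cpu_fixed_free below would be created as:
//
//   auto* on_release = new std::function<void(void*)>(
//       [](void* ptr) { my_free(ptr); });  // my_free: hypothetical deleter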
namespace at {
static void cpu_fixed_malloc(void*, ptrdiff_t) {
  AT_ERROR("attempting to resize a tensor view of an external blob");
}
static void cpu_fixed_realloc(void*, void*, ptrdiff_t) {
  AT_ERROR("attempting to resize a tensor view of an external blob");
}
static void cpu_fixed_free(void* state, void* allocation) {
auto on_release = static_cast<std::function<void(void*)>*>(state);
(*on_release)(allocation);
delete on_release;
}
static Allocator CPU_fixed_allocator = {
cpu_fixed_malloc,
cpu_fixed_realloc,
cpu_fixed_free};
} // namespace at

@@ -0,0 +1,29 @@
#include <ATen/core/TensorBody.h>
// TODO Undo all logic introduced for Note [Avoiding Include Cycles In Static Dispatch]
// Code introduced to avoid cyclic dependency in static dispatch is no longer
// needed as static dispatch logic is moved from TensorBody.h, which caused cycles in the first place,
// to Operators.cpp for supporting multiple backends with multiple kernels.
//
// Note [Avoiding Include Cycles In Static Dispatch]
// In order to avoid #include cycles in the static dispatch build, we've carefully split out
// the static function definition files into {DispatchKey}Functions.h and {DispatchKey}Functions_inl.h.
//
// Without this split, the include cycle looks like TensorBody.h -> CPUFunctions.h -> TensorBody.h.
// - TensorBody.h #includes CPUFunctions.h in the static dispatch build, because the tensor methods
// all need to call into the fastpath C++ API defined in CPUFunctions.h. The methods are also all
// directly inlined into TensorBody.h.
// - CPUFunctions.h #includes TensorBody.h because it contains function declarations for the entire C++ API,
// which include functions that have defaultable std::optional<Tensor> arguments.
// That requires knowing the full Tensor class definition.
//
// We break the cycle by doing the following:
// - Split out CPUFunctions.h into two files: CPUFunctions.h and CPUFunctions_inl.h
// - CPUFunctions.h is a dummy file that just includes the Tensor class and includes CPUFunctions_inl.h,
//   which in turn includes everything else
// - (only in the static dispatch build) TensorBody.h makes sure to finish defining the Tensor class,
//   and then it includes CPUFunctions_inl.h.
// - All other files that want the CPU fastpath functions can include CPUFunctions.h directly.
// - This also means that in the static dispatch build, CPUFunctions.h only needs to
//   #include TensorBody.h, and it will automatically bring in CPUFunctions_inl.h.
#include <ATen/CPUFunctions_inl.h>
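// Illustrative usage (assuming the generated at::cpu namespace provided by the
// _inl header): a file that wants the CPU fastpath can simply write
//
//   #include <ATen/CPUFunctions.h>
//   // ... and call the fastpath API, e.g. at::cpu::add(self, other);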

@@ -0,0 +1,540 @@
#pragma once
// @generated by torchgen/gen.py from DispatchKeyFunctions_inl.h
// NB: The implementing C++ file is RegisterDispatchKey.cpp
// The only #includes we need are for custom classes that have defaults in the C++ API
#include <c10/core/MemoryFormat.h>
#include <c10/core/Scalar.h>
#include <ATen/core/Reduction.h>
#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
#error This change adds a dependency on all pytorch operators, meaning the \
file will need to be re-compiled every time an operator is changed or added. \
Consider including a specific operator from \
<ATen/ops/{my_operator}_cpu_dispatch.h>. \
See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
#endif
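// For example, following the guidance in the #error above, a translation unit
// that only needs `add` under TORCH_ASSERT_ONLY_METHOD_OPERATORS would include
// just its per-operator header instead of this file:
//
//   #include <ATen/ops/add_cpu_dispatch.h>  // declares the at::cpu add fastpath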
#include <ATen/ops/_adaptive_avg_pool2d_cpu_dispatch.h>
#include <ATen/ops/_adaptive_avg_pool2d_backward_cpu_dispatch.h>
#include <ATen/ops/_adaptive_avg_pool3d_cpu_dispatch.h>
#include <ATen/ops/_adaptive_avg_pool3d_backward_cpu_dispatch.h>
#include <ATen/ops/_add_relu_cpu_dispatch.h>
#include <ATen/ops/_addmm_activation_cpu_dispatch.h>
#include <ATen/ops/_aminmax_cpu_dispatch.h>
#include <ATen/ops/_amp_foreach_non_finite_check_and_unscale_cpu_dispatch.h>
#include <ATen/ops/_amp_update_scale_cpu_dispatch.h>
#include <ATen/ops/_assert_async_cpu_dispatch.h>
#include <ATen/ops/_batch_norm_with_update_cpu_dispatch.h>
#include <ATen/ops/_cdist_backward_cpu_dispatch.h>
#include <ATen/ops/_cdist_forward_cpu_dispatch.h>
#include <ATen/ops/_cholesky_solve_helper_cpu_dispatch.h>
#include <ATen/ops/_compute_linear_combination_cpu_dispatch.h>
#include <ATen/ops/_convert_indices_from_coo_to_csr_cpu_dispatch.h>
#include <ATen/ops/_convert_indices_from_csr_to_coo_cpu_dispatch.h>
#include <ATen/ops/_convert_weight_to_int4pack_cpu_dispatch.h>
#include <ATen/ops/_ctc_loss_cpu_dispatch.h>
#include <ATen/ops/_ctc_loss_backward_cpu_dispatch.h>
#include <ATen/ops/_cummax_helper_cpu_dispatch.h>
#include <ATen/ops/_cummin_helper_cpu_dispatch.h>
#include <ATen/ops/_dirichlet_grad_cpu_dispatch.h>
#include <ATen/ops/_efficientzerotensor_cpu_dispatch.h>
#include <ATen/ops/_embedding_bag_cpu_dispatch.h>
#include <ATen/ops/_embedding_bag_backward_cpu_dispatch.h>
#include <ATen/ops/_embedding_bag_dense_backward_cpu_dispatch.h>
#include <ATen/ops/_embedding_bag_forward_only_cpu_dispatch.h>
#include <ATen/ops/_embedding_bag_per_sample_weights_backward_cpu_dispatch.h>
#include <ATen/ops/_empty_affine_quantized_cpu_dispatch.h>
#include <ATen/ops/_empty_per_channel_affine_quantized_cpu_dispatch.h>
#include <ATen/ops/_fake_quantize_learnable_per_channel_affine_cpu_dispatch.h>
#include <ATen/ops/_fake_quantize_learnable_per_channel_affine_backward_cpu_dispatch.h>
#include <ATen/ops/_fake_quantize_learnable_per_tensor_affine_cpu_dispatch.h>
#include <ATen/ops/_fake_quantize_learnable_per_tensor_affine_backward_cpu_dispatch.h>
#include <ATen/ops/_fake_quantize_per_tensor_affine_cachemask_tensor_qparams_cpu_dispatch.h>
#include <ATen/ops/_fft_c2c_cpu_dispatch.h>
#include <ATen/ops/_fft_c2r_cpu_dispatch.h>
#include <ATen/ops/_fft_r2c_cpu_dispatch.h>
#include <ATen/ops/_foobar_cpu_dispatch.h>
#include <ATen/ops/_functional_assert_async_cpu_dispatch.h>
#include <ATen/ops/_fused_adagrad_cpu_dispatch.h>
#include <ATen/ops/_fused_adam_cpu_dispatch.h>
#include <ATen/ops/_fused_adamw_cpu_dispatch.h>
#include <ATen/ops/_fused_moving_avg_obs_fq_helper_cpu_dispatch.h>
#include <ATen/ops/_fused_sdp_choice_cpu_dispatch.h>
#include <ATen/ops/_fused_sgd_cpu_dispatch.h>
#include <ATen/ops/_histogramdd_bin_edges_cpu_dispatch.h>
#include <ATen/ops/_histogramdd_from_bin_cts_cpu_dispatch.h>
#include <ATen/ops/_histogramdd_from_bin_tensors_cpu_dispatch.h>
#include <ATen/ops/_index_put_impl_cpu_dispatch.h>
#include <ATen/ops/_int_mm_cpu_dispatch.h>
#include <ATen/ops/_jagged_to_padded_dense_forward_cpu_dispatch.h>
#include <ATen/ops/_linalg_det_cpu_dispatch.h>
#include <ATen/ops/_linalg_eigh_cpu_dispatch.h>
#include <ATen/ops/_linalg_eigvals_cpu_dispatch.h>
#include <ATen/ops/_linalg_slogdet_cpu_dispatch.h>
#include <ATen/ops/_linalg_solve_ex_cpu_dispatch.h>
#include <ATen/ops/_linalg_svd_cpu_dispatch.h>
#include <ATen/ops/_local_scalar_dense_cpu_dispatch.h>
#include <ATen/ops/_log_softmax_cpu_dispatch.h>
#include <ATen/ops/_log_softmax_backward_data_cpu_dispatch.h>
#include <ATen/ops/_logcumsumexp_cpu_dispatch.h>
#include <ATen/ops/_make_dep_token_cpu_dispatch.h>
#include <ATen/ops/_make_per_channel_quantized_tensor_cpu_dispatch.h>
#include <ATen/ops/_make_per_tensor_quantized_tensor_cpu_dispatch.h>
#include <ATen/ops/_masked_softmax_cpu_dispatch.h>
#include <ATen/ops/_masked_softmax_backward_cpu_dispatch.h>
#include <ATen/ops/_native_batch_norm_legit_cpu_dispatch.h>
#include <ATen/ops/_native_multi_head_attention_cpu_dispatch.h>
#include <ATen/ops/_nested_compute_contiguous_strides_offsets_cpu_dispatch.h>
#include <ATen/ops/_nested_from_padded_cpu_dispatch.h>
#include <ATen/ops/_nested_tensor_from_mask_cpu_dispatch.h>
#include <ATen/ops/_nested_tensor_from_mask_left_aligned_cpu_dispatch.h>
#include <ATen/ops/_nested_view_from_buffer_cpu_dispatch.h>
#include <ATen/ops/_padded_dense_to_jagged_forward_cpu_dispatch.h>
#include <ATen/ops/_pdist_backward_cpu_dispatch.h>
#include <ATen/ops/_pdist_forward_cpu_dispatch.h>
#include <ATen/ops/_prelu_kernel_cpu_dispatch.h>
#include <ATen/ops/_prelu_kernel_backward_cpu_dispatch.h>
#include <ATen/ops/_reshape_alias_cpu_dispatch.h>
#include <ATen/ops/_sample_dirichlet_cpu_dispatch.h>
#include <ATen/ops/_scaled_dot_product_flash_attention_for_cpu_cpu_dispatch.h>
#include <ATen/ops/_scaled_dot_product_flash_attention_for_cpu_backward_cpu_dispatch.h>
#include <ATen/ops/_segment_reduce_backward_cpu_dispatch.h>
#include <ATen/ops/_slow_conv2d_backward_cpu_dispatch.h>
#include <ATen/ops/_slow_conv2d_forward_cpu_dispatch.h>
#include <ATen/ops/_softmax_cpu_dispatch.h>
#include <ATen/ops/_softmax_backward_data_cpu_dispatch.h>
#include <ATen/ops/_spdiags_cpu_dispatch.h>
#include <ATen/ops/_stack_cpu_dispatch.h>
#include <ATen/ops/_standard_gamma_cpu_dispatch.h>
#include <ATen/ops/_standard_gamma_grad_cpu_dispatch.h>
#include <ATen/ops/_test_functorch_fallback_cpu_dispatch.h>
#include <ATen/ops/_test_optional_filled_intlist_cpu_dispatch.h>
#include <ATen/ops/_test_optional_floatlist_cpu_dispatch.h>
#include <ATen/ops/_test_optional_intlist_cpu_dispatch.h>
#include <ATen/ops/_to_sparse_cpu_dispatch.h>
#include <ATen/ops/_to_sparse_bsc_cpu_dispatch.h>
#include <ATen/ops/_to_sparse_bsr_cpu_dispatch.h>
#include <ATen/ops/_to_sparse_csc_cpu_dispatch.h>
#include <ATen/ops/_to_sparse_csr_cpu_dispatch.h>
#include <ATen/ops/_transform_bias_rescale_qkv_cpu_dispatch.h>
#include <ATen/ops/_transformer_encoder_layer_fwd_cpu_dispatch.h>
#include <ATen/ops/_unique_cpu_dispatch.h>
#include <ATen/ops/_unique2_cpu_dispatch.h>
#include <ATen/ops/_upsample_bicubic2d_aa_cpu_dispatch.h>
#include <ATen/ops/_upsample_bicubic2d_aa_backward_cpu_dispatch.h>
#include <ATen/ops/_upsample_bilinear2d_aa_cpu_dispatch.h>
#include <ATen/ops/_upsample_bilinear2d_aa_backward_cpu_dispatch.h>
#include <ATen/ops/_upsample_nearest_exact1d_cpu_dispatch.h>
#include <ATen/ops/_upsample_nearest_exact1d_backward_cpu_dispatch.h>
#include <ATen/ops/_upsample_nearest_exact2d_cpu_dispatch.h>
#include <ATen/ops/_upsample_nearest_exact2d_backward_cpu_dispatch.h>
#include <ATen/ops/_upsample_nearest_exact3d_cpu_dispatch.h>
#include <ATen/ops/_upsample_nearest_exact3d_backward_cpu_dispatch.h>
#include <ATen/ops/_validate_compressed_sparse_indices_cpu_dispatch.h>
#include <ATen/ops/_weight_int4pack_mm_cpu_dispatch.h>
#include <ATen/ops/_weight_int8pack_mm_cpu_dispatch.h>
#include <ATen/ops/_weight_norm_interface_cpu_dispatch.h>
#include <ATen/ops/_weight_norm_interface_backward_cpu_dispatch.h>
#include <ATen/ops/abs_cpu_dispatch.h>
#include <ATen/ops/acos_cpu_dispatch.h>
#include <ATen/ops/acosh_cpu_dispatch.h>
#include <ATen/ops/adaptive_avg_pool2d_cpu_dispatch.h>
#include <ATen/ops/adaptive_avg_pool3d_cpu_dispatch.h>
#include <ATen/ops/adaptive_avg_pool3d_backward_cpu_dispatch.h>
#include <ATen/ops/adaptive_max_pool2d_cpu_dispatch.h>
#include <ATen/ops/adaptive_max_pool2d_backward_cpu_dispatch.h>
#include <ATen/ops/adaptive_max_pool3d_cpu_dispatch.h>
#include <ATen/ops/adaptive_max_pool3d_backward_cpu_dispatch.h>
#include <ATen/ops/add_cpu_dispatch.h>
#include <ATen/ops/addbmm_cpu_dispatch.h>
#include <ATen/ops/addcdiv_cpu_dispatch.h>
#include <ATen/ops/addcmul_cpu_dispatch.h>
#include <ATen/ops/addmm_cpu_dispatch.h>
#include <ATen/ops/addmv_cpu_dispatch.h>
#include <ATen/ops/addr_cpu_dispatch.h>
#include <ATen/ops/all_cpu_dispatch.h>
#include <ATen/ops/amax_cpu_dispatch.h>
#include <ATen/ops/amin_cpu_dispatch.h>
#include <ATen/ops/aminmax_cpu_dispatch.h>
#include <ATen/ops/angle_cpu_dispatch.h>
#include <ATen/ops/any_cpu_dispatch.h>
#include <ATen/ops/arange_cpu_dispatch.h>
#include <ATen/ops/argmax_cpu_dispatch.h>
#include <ATen/ops/argmin_cpu_dispatch.h>
#include <ATen/ops/as_strided_cpu_dispatch.h>
#include <ATen/ops/asin_cpu_dispatch.h>
#include <ATen/ops/asinh_cpu_dispatch.h>
#include <ATen/ops/atan_cpu_dispatch.h>
#include <ATen/ops/atan2_cpu_dispatch.h>
#include <ATen/ops/atanh_cpu_dispatch.h>
#include <ATen/ops/avg_pool2d_cpu_dispatch.h>
#include <ATen/ops/avg_pool2d_backward_cpu_dispatch.h>
#include <ATen/ops/avg_pool3d_cpu_dispatch.h>
#include <ATen/ops/avg_pool3d_backward_cpu_dispatch.h>
#include <ATen/ops/baddbmm_cpu_dispatch.h>
#include <ATen/ops/batch_norm_backward_cpu_dispatch.h>
#include <ATen/ops/batch_norm_update_stats_cpu_dispatch.h>
#include <ATen/ops/bernoulli_cpu_dispatch.h>
#include <ATen/ops/binary_cross_entropy_cpu_dispatch.h>
#include <ATen/ops/binary_cross_entropy_backward_cpu_dispatch.h>
#include <ATen/ops/bincount_cpu_dispatch.h>
#include <ATen/ops/binomial_cpu_dispatch.h>
#include <ATen/ops/bitwise_and_cpu_dispatch.h>
#include <ATen/ops/bitwise_left_shift_cpu_dispatch.h>
#include <ATen/ops/bitwise_not_cpu_dispatch.h>
#include <ATen/ops/bitwise_or_cpu_dispatch.h>
#include <ATen/ops/bitwise_right_shift_cpu_dispatch.h>
#include <ATen/ops/bitwise_xor_cpu_dispatch.h>
#include <ATen/ops/bmm_cpu_dispatch.h>
#include <ATen/ops/bucketize_cpu_dispatch.h>
#include <ATen/ops/cat_cpu_dispatch.h>
#include <ATen/ops/cauchy_cpu_dispatch.h>
#include <ATen/ops/ceil_cpu_dispatch.h>
#include <ATen/ops/channel_shuffle_cpu_dispatch.h>
#include <ATen/ops/cholesky_cpu_dispatch.h>
#include <ATen/ops/cholesky_inverse_cpu_dispatch.h>
#include <ATen/ops/clamp_cpu_dispatch.h>
#include <ATen/ops/clamp_max_cpu_dispatch.h>
#include <ATen/ops/clamp_min_cpu_dispatch.h>
#include <ATen/ops/col2im_cpu_dispatch.h>
#include <ATen/ops/complex_cpu_dispatch.h>
#include <ATen/ops/conj_physical_cpu_dispatch.h>
#include <ATen/ops/copysign_cpu_dispatch.h>
#include <ATen/ops/cos_cpu_dispatch.h>
#include <ATen/ops/cosh_cpu_dispatch.h>
#include <ATen/ops/count_nonzero_cpu_dispatch.h>
#include <ATen/ops/cumprod_cpu_dispatch.h>
#include <ATen/ops/cumsum_cpu_dispatch.h>
#include <ATen/ops/dequantize_cpu_dispatch.h>
#include <ATen/ops/digamma_cpu_dispatch.h>
#include <ATen/ops/div_cpu_dispatch.h>
#include <ATen/ops/dot_cpu_dispatch.h>
#include <ATen/ops/elu_cpu_dispatch.h>
#include <ATen/ops/elu_backward_cpu_dispatch.h>
#include <ATen/ops/embedding_dense_backward_cpu_dispatch.h>
#include <ATen/ops/embedding_renorm_cpu_dispatch.h>
#include <ATen/ops/empty_cpu_dispatch.h>
#include <ATen/ops/empty_strided_cpu_dispatch.h>
#include <ATen/ops/eq_cpu_dispatch.h>
#include <ATen/ops/equal_cpu_dispatch.h>
#include <ATen/ops/erf_cpu_dispatch.h>
#include <ATen/ops/erfc_cpu_dispatch.h>
#include <ATen/ops/erfinv_cpu_dispatch.h>
#include <ATen/ops/exp_cpu_dispatch.h>
#include <ATen/ops/exp2_cpu_dispatch.h>
#include <ATen/ops/expm1_cpu_dispatch.h>
#include <ATen/ops/exponential_cpu_dispatch.h>
#include <ATen/ops/eye_cpu_dispatch.h>
#include <ATen/ops/fake_quantize_per_channel_affine_cachemask_cpu_dispatch.h>
#include <ATen/ops/fake_quantize_per_tensor_affine_cachemask_cpu_dispatch.h>
#include <ATen/ops/fill_cpu_dispatch.h>
#include <ATen/ops/flip_cpu_dispatch.h>
#include <ATen/ops/floor_cpu_dispatch.h>
#include <ATen/ops/floor_divide_cpu_dispatch.h>
#include <ATen/ops/fmax_cpu_dispatch.h>
#include <ATen/ops/fmin_cpu_dispatch.h>
#include <ATen/ops/fmod_cpu_dispatch.h>
#include <ATen/ops/frac_cpu_dispatch.h>
#include <ATen/ops/fractional_max_pool2d_cpu_dispatch.h>
#include <ATen/ops/fractional_max_pool2d_backward_cpu_dispatch.h>
#include <ATen/ops/fractional_max_pool3d_cpu_dispatch.h>
#include <ATen/ops/fractional_max_pool3d_backward_cpu_dispatch.h>
#include <ATen/ops/frexp_cpu_dispatch.h>
#include <ATen/ops/from_file_cpu_dispatch.h>
#include <ATen/ops/gather_cpu_dispatch.h>
#include <ATen/ops/gcd_cpu_dispatch.h>
#include <ATen/ops/ge_cpu_dispatch.h>
#include <ATen/ops/gelu_cpu_dispatch.h>
#include <ATen/ops/gelu_backward_cpu_dispatch.h>
#include <ATen/ops/geometric_cpu_dispatch.h>
#include <ATen/ops/geqrf_cpu_dispatch.h>
#include <ATen/ops/glu_cpu_dispatch.h>
#include <ATen/ops/glu_backward_cpu_dispatch.h>
#include <ATen/ops/glu_backward_jvp_cpu_dispatch.h>
#include <ATen/ops/glu_jvp_cpu_dispatch.h>
#include <ATen/ops/grid_sampler_2d_cpu_dispatch.h>
#include <ATen/ops/grid_sampler_2d_backward_cpu_dispatch.h>
#include <ATen/ops/grid_sampler_3d_cpu_dispatch.h>
#include <ATen/ops/grid_sampler_3d_backward_cpu_dispatch.h>
#include <ATen/ops/gt_cpu_dispatch.h>
#include <ATen/ops/hardshrink_cpu_dispatch.h>
#include <ATen/ops/hardshrink_backward_cpu_dispatch.h>
#include <ATen/ops/hardsigmoid_cpu_dispatch.h>
#include <ATen/ops/hardsigmoid_backward_cpu_dispatch.h>
#include <ATen/ops/hardswish_cpu_dispatch.h>
#include <ATen/ops/hardswish_backward_cpu_dispatch.h>
#include <ATen/ops/hardtanh_cpu_dispatch.h>
#include <ATen/ops/hardtanh_backward_cpu_dispatch.h>
#include <ATen/ops/heaviside_cpu_dispatch.h>
#include <ATen/ops/histc_cpu_dispatch.h>
#include <ATen/ops/histogram_cpu_dispatch.h>
#include <ATen/ops/huber_loss_cpu_dispatch.h>
#include <ATen/ops/huber_loss_backward_cpu_dispatch.h>
#include <ATen/ops/hypot_cpu_dispatch.h>
#include <ATen/ops/i0_cpu_dispatch.h>
#include <ATen/ops/igamma_cpu_dispatch.h>
#include <ATen/ops/igammac_cpu_dispatch.h>
#include <ATen/ops/im2col_cpu_dispatch.h>
#include <ATen/ops/index_cpu_dispatch.h>
#include <ATen/ops/index_add_cpu_dispatch.h>
#include <ATen/ops/index_copy_cpu_dispatch.h>
#include <ATen/ops/index_fill_cpu_dispatch.h>
#include <ATen/ops/index_reduce_cpu_dispatch.h>
#include <ATen/ops/index_select_cpu_dispatch.h>
#include <ATen/ops/is_set_to_cpu_dispatch.h>
#include <ATen/ops/isin_cpu_dispatch.h>
#include <ATen/ops/isnan_cpu_dispatch.h>
#include <ATen/ops/isneginf_cpu_dispatch.h>
#include <ATen/ops/isposinf_cpu_dispatch.h>
#include <ATen/ops/kthvalue_cpu_dispatch.h>
#include <ATen/ops/lcm_cpu_dispatch.h>
#include <ATen/ops/le_cpu_dispatch.h>
#include <ATen/ops/leaky_relu_cpu_dispatch.h>
#include <ATen/ops/leaky_relu_backward_cpu_dispatch.h>
#include <ATen/ops/lerp_cpu_dispatch.h>
#include <ATen/ops/lgamma_cpu_dispatch.h>
#include <ATen/ops/linalg_cholesky_ex_cpu_dispatch.h>
#include <ATen/ops/linalg_cross_cpu_dispatch.h>
#include <ATen/ops/linalg_eig_cpu_dispatch.h>
#include <ATen/ops/linalg_eigvals_cpu_dispatch.h>
#include <ATen/ops/linalg_householder_product_cpu_dispatch.h>
#include <ATen/ops/linalg_inv_ex_cpu_dispatch.h>
#include <ATen/ops/linalg_ldl_factor_ex_cpu_dispatch.h>
#include <ATen/ops/linalg_ldl_solve_cpu_dispatch.h>
#include <ATen/ops/linalg_lstsq_cpu_dispatch.h>
#include <ATen/ops/linalg_lu_cpu_dispatch.h>
#include <ATen/ops/linalg_lu_factor_ex_cpu_dispatch.h>
#include <ATen/ops/linalg_lu_solve_cpu_dispatch.h>
#include <ATen/ops/linalg_matrix_exp_cpu_dispatch.h>
#include <ATen/ops/linalg_qr_cpu_dispatch.h>
#include <ATen/ops/linalg_solve_triangular_cpu_dispatch.h>
#include <ATen/ops/linalg_vector_norm_cpu_dispatch.h>
#include <ATen/ops/linspace_cpu_dispatch.h>
#include <ATen/ops/log_cpu_dispatch.h>
#include <ATen/ops/log10_cpu_dispatch.h>
#include <ATen/ops/log1p_cpu_dispatch.h>
#include <ATen/ops/log2_cpu_dispatch.h>
#include <ATen/ops/log_normal_cpu_dispatch.h>
#include <ATen/ops/log_sigmoid_backward_cpu_dispatch.h>
#include <ATen/ops/log_sigmoid_forward_cpu_dispatch.h>
#include <ATen/ops/logaddexp_cpu_dispatch.h>
#include <ATen/ops/logaddexp2_cpu_dispatch.h>
#include <ATen/ops/logical_and_cpu_dispatch.h>
#include <ATen/ops/logical_not_cpu_dispatch.h>
#include <ATen/ops/logical_or_cpu_dispatch.h>
#include <ATen/ops/logical_xor_cpu_dispatch.h>
#include <ATen/ops/logit_cpu_dispatch.h>
#include <ATen/ops/logit_backward_cpu_dispatch.h>
#include <ATen/ops/logspace_cpu_dispatch.h>
#include <ATen/ops/lshift_cpu_dispatch.h>
#include <ATen/ops/lt_cpu_dispatch.h>
#include <ATen/ops/lu_unpack_cpu_dispatch.h>
#include <ATen/ops/masked_fill_cpu_dispatch.h>
#include <ATen/ops/masked_scatter_cpu_dispatch.h>
#include <ATen/ops/masked_select_cpu_dispatch.h>
#include <ATen/ops/max_cpu_dispatch.h>
#include <ATen/ops/max_pool2d_with_indices_cpu_dispatch.h>
#include <ATen/ops/max_pool2d_with_indices_backward_cpu_dispatch.h>
#include <ATen/ops/max_pool3d_with_indices_cpu_dispatch.h>
#include <ATen/ops/max_pool3d_with_indices_backward_cpu_dispatch.h>
#include <ATen/ops/max_unpool2d_cpu_dispatch.h>
#include <ATen/ops/max_unpool3d_cpu_dispatch.h>
#include <ATen/ops/maximum_cpu_dispatch.h>
#include <ATen/ops/mean_cpu_dispatch.h>
#include <ATen/ops/median_cpu_dispatch.h>
#include <ATen/ops/min_cpu_dispatch.h>
#include <ATen/ops/minimum_cpu_dispatch.h>
#include <ATen/ops/mish_cpu_dispatch.h>
#include <ATen/ops/mish_backward_cpu_dispatch.h>
#include <ATen/ops/mkldnn_rnn_layer_cpu_dispatch.h>
#include <ATen/ops/mkldnn_rnn_layer_backward_cpu_dispatch.h>
#include <ATen/ops/mm_cpu_dispatch.h>
#include <ATen/ops/mode_cpu_dispatch.h>
#include <ATen/ops/mse_loss_cpu_dispatch.h>
#include <ATen/ops/mse_loss_backward_cpu_dispatch.h>
#include <ATen/ops/mul_cpu_dispatch.h>
#include <ATen/ops/multi_margin_loss_cpu_dispatch.h>
#include <ATen/ops/multi_margin_loss_backward_cpu_dispatch.h>
#include <ATen/ops/multilabel_margin_loss_backward_cpu_dispatch.h>
#include <ATen/ops/multilabel_margin_loss_forward_cpu_dispatch.h>
#include <ATen/ops/multinomial_cpu_dispatch.h>
#include <ATen/ops/mvlgamma_cpu_dispatch.h>
#include <ATen/ops/nan_to_num_cpu_dispatch.h>
#include <ATen/ops/nanmedian_cpu_dispatch.h>
#include <ATen/ops/nansum_cpu_dispatch.h>
#include <ATen/ops/narrow_copy_cpu_dispatch.h>
#include <ATen/ops/native_batch_norm_cpu_dispatch.h>
#include <ATen/ops/native_batch_norm_backward_cpu_dispatch.h>
#include <ATen/ops/native_channel_shuffle_cpu_dispatch.h>
#include <ATen/ops/native_dropout_cpu_dispatch.h>
#include <ATen/ops/native_dropout_backward_cpu_dispatch.h>
#include <ATen/ops/native_group_norm_cpu_dispatch.h>
#include <ATen/ops/native_group_norm_backward_cpu_dispatch.h>
#include <ATen/ops/native_layer_norm_cpu_dispatch.h>
#include <ATen/ops/native_layer_norm_backward_cpu_dispatch.h>
#include <ATen/ops/ne_cpu_dispatch.h>
#include <ATen/ops/neg_cpu_dispatch.h>
#include <ATen/ops/nextafter_cpu_dispatch.h>
#include <ATen/ops/nll_loss2d_backward_cpu_dispatch.h>
#include <ATen/ops/nll_loss2d_forward_cpu_dispatch.h>
#include <ATen/ops/nll_loss_backward_cpu_dispatch.h>
#include <ATen/ops/nll_loss_forward_cpu_dispatch.h>
#include <ATen/ops/nonzero_cpu_dispatch.h>
#include <ATen/ops/nonzero_static_cpu_dispatch.h>
#include <ATen/ops/norm_cpu_dispatch.h>
#include <ATen/ops/normal_cpu_dispatch.h>
#include <ATen/ops/ormqr_cpu_dispatch.h>
#include <ATen/ops/pixel_shuffle_cpu_dispatch.h>
#include <ATen/ops/pixel_unshuffle_cpu_dispatch.h>
#include <ATen/ops/poisson_cpu_dispatch.h>
#include <ATen/ops/polar_cpu_dispatch.h>
#include <ATen/ops/polygamma_cpu_dispatch.h>
#include <ATen/ops/pow_cpu_dispatch.h>
#include <ATen/ops/prod_cpu_dispatch.h>
#include <ATen/ops/put_cpu_dispatch.h>
#include <ATen/ops/quantize_per_channel_cpu_dispatch.h>
#include <ATen/ops/quantize_per_tensor_cpu_dispatch.h>
#include <ATen/ops/quantize_per_tensor_dynamic_cpu_dispatch.h>
#include <ATen/ops/random_cpu_dispatch.h>
#include <ATen/ops/randperm_cpu_dispatch.h>
#include <ATen/ops/range_cpu_dispatch.h>
#include <ATen/ops/reciprocal_cpu_dispatch.h>
#include <ATen/ops/reflection_pad1d_cpu_dispatch.h>
#include <ATen/ops/reflection_pad1d_backward_cpu_dispatch.h>
#include <ATen/ops/reflection_pad2d_cpu_dispatch.h>
#include <ATen/ops/reflection_pad2d_backward_cpu_dispatch.h>
#include <ATen/ops/reflection_pad3d_cpu_dispatch.h>
#include <ATen/ops/reflection_pad3d_backward_cpu_dispatch.h>
#include <ATen/ops/relu_cpu_dispatch.h>
#include <ATen/ops/remainder_cpu_dispatch.h>
#include <ATen/ops/renorm_cpu_dispatch.h>
#include <ATen/ops/repeat_interleave_cpu_dispatch.h>
#include <ATen/ops/replication_pad1d_cpu_dispatch.h>
#include <ATen/ops/replication_pad1d_backward_cpu_dispatch.h>
#include <ATen/ops/replication_pad2d_cpu_dispatch.h>
#include <ATen/ops/replication_pad2d_backward_cpu_dispatch.h>
#include <ATen/ops/replication_pad3d_cpu_dispatch.h>
#include <ATen/ops/replication_pad3d_backward_cpu_dispatch.h>
#include <ATen/ops/resize_cpu_dispatch.h>
#include <ATen/ops/roll_cpu_dispatch.h>
#include <ATen/ops/round_cpu_dispatch.h>
#include <ATen/ops/rrelu_with_noise_cpu_dispatch.h>
#include <ATen/ops/rshift_cpu_dispatch.h>
#include <ATen/ops/rsqrt_cpu_dispatch.h>
#include <ATen/ops/rsub_cpu_dispatch.h>
#include <ATen/ops/scatter_cpu_dispatch.h>
#include <ATen/ops/scatter_add_cpu_dispatch.h>
#include <ATen/ops/scatter_reduce_cpu_dispatch.h>
#include <ATen/ops/searchsorted_cpu_dispatch.h>
#include <ATen/ops/segment_reduce_cpu_dispatch.h>
#include <ATen/ops/set_cpu_dispatch.h>
#include <ATen/ops/sgn_cpu_dispatch.h>
#include <ATen/ops/sigmoid_cpu_dispatch.h>
#include <ATen/ops/sigmoid_backward_cpu_dispatch.h>
#include <ATen/ops/sign_cpu_dispatch.h>
#include <ATen/ops/signbit_cpu_dispatch.h>
#include <ATen/ops/silu_cpu_dispatch.h>
#include <ATen/ops/silu_backward_cpu_dispatch.h>
#include <ATen/ops/sin_cpu_dispatch.h>
#include <ATen/ops/sinc_cpu_dispatch.h>
#include <ATen/ops/sinh_cpu_dispatch.h>
#include <ATen/ops/slow_conv3d_forward_cpu_dispatch.h>
#include <ATen/ops/slow_conv_dilated2d_cpu_dispatch.h>
#include <ATen/ops/slow_conv_dilated3d_cpu_dispatch.h>
#include <ATen/ops/slow_conv_transpose2d_cpu_dispatch.h>
#include <ATen/ops/slow_conv_transpose3d_cpu_dispatch.h>
#include <ATen/ops/smooth_l1_loss_cpu_dispatch.h>
#include <ATen/ops/smooth_l1_loss_backward_cpu_dispatch.h>
#include <ATen/ops/softplus_cpu_dispatch.h>
#include <ATen/ops/softplus_backward_cpu_dispatch.h>
#include <ATen/ops/softshrink_cpu_dispatch.h>
#include <ATen/ops/softshrink_backward_cpu_dispatch.h>
#include <ATen/ops/sort_cpu_dispatch.h>
#include <ATen/ops/special_airy_ai_cpu_dispatch.h>
#include <ATen/ops/special_bessel_j0_cpu_dispatch.h>
#include <ATen/ops/special_bessel_j1_cpu_dispatch.h>
#include <ATen/ops/special_bessel_y0_cpu_dispatch.h>
#include <ATen/ops/special_bessel_y1_cpu_dispatch.h>
#include <ATen/ops/special_chebyshev_polynomial_t_cpu_dispatch.h>
#include <ATen/ops/special_chebyshev_polynomial_u_cpu_dispatch.h>
#include <ATen/ops/special_chebyshev_polynomial_v_cpu_dispatch.h>
#include <ATen/ops/special_chebyshev_polynomial_w_cpu_dispatch.h>
#include <ATen/ops/special_entr_cpu_dispatch.h>
#include <ATen/ops/special_erfcx_cpu_dispatch.h>
#include <ATen/ops/special_hermite_polynomial_h_cpu_dispatch.h>
#include <ATen/ops/special_hermite_polynomial_he_cpu_dispatch.h>
#include <ATen/ops/special_i0e_cpu_dispatch.h>
#include <ATen/ops/special_i1_cpu_dispatch.h>
#include <ATen/ops/special_i1e_cpu_dispatch.h>
#include <ATen/ops/special_laguerre_polynomial_l_cpu_dispatch.h>
#include <ATen/ops/special_legendre_polynomial_p_cpu_dispatch.h>
#include <ATen/ops/special_log_ndtr_cpu_dispatch.h>
#include <ATen/ops/special_modified_bessel_i0_cpu_dispatch.h>
#include <ATen/ops/special_modified_bessel_i1_cpu_dispatch.h>
#include <ATen/ops/special_modified_bessel_k0_cpu_dispatch.h>
#include <ATen/ops/special_modified_bessel_k1_cpu_dispatch.h>
#include <ATen/ops/special_ndtri_cpu_dispatch.h>
#include <ATen/ops/special_scaled_modified_bessel_k0_cpu_dispatch.h>
#include <ATen/ops/special_scaled_modified_bessel_k1_cpu_dispatch.h>
#include <ATen/ops/special_shifted_chebyshev_polynomial_t_cpu_dispatch.h>
#include <ATen/ops/special_shifted_chebyshev_polynomial_u_cpu_dispatch.h>
#include <ATen/ops/special_shifted_chebyshev_polynomial_v_cpu_dispatch.h>
#include <ATen/ops/special_shifted_chebyshev_polynomial_w_cpu_dispatch.h>
#include <ATen/ops/special_spherical_bessel_j0_cpu_dispatch.h>
#include <ATen/ops/special_xlog1py_cpu_dispatch.h>
#include <ATen/ops/special_zeta_cpu_dispatch.h>
#include <ATen/ops/sqrt_cpu_dispatch.h>
#include <ATen/ops/sspaddmm_cpu_dispatch.h>
#include <ATen/ops/std_cpu_dispatch.h>
#include <ATen/ops/std_mean_cpu_dispatch.h>
#include <ATen/ops/sub_cpu_dispatch.h>
#include <ATen/ops/sum_cpu_dispatch.h>
#include <ATen/ops/take_cpu_dispatch.h>
#include <ATen/ops/tan_cpu_dispatch.h>
#include <ATen/ops/tanh_cpu_dispatch.h>
#include <ATen/ops/tanh_backward_cpu_dispatch.h>
#include <ATen/ops/threshold_cpu_dispatch.h>
#include <ATen/ops/threshold_backward_cpu_dispatch.h>
#include <ATen/ops/to_mkldnn_cpu_dispatch.h>
#include <ATen/ops/topk_cpu_dispatch.h>
#include <ATen/ops/trace_cpu_dispatch.h>
#include <ATen/ops/triangular_solve_cpu_dispatch.h>
#include <ATen/ops/tril_cpu_dispatch.h>
#include <ATen/ops/tril_indices_cpu_dispatch.h>
#include <ATen/ops/triu_cpu_dispatch.h>
#include <ATen/ops/triu_indices_cpu_dispatch.h>
#include <ATen/ops/trunc_cpu_dispatch.h>
#include <ATen/ops/unfold_cpu_dispatch.h>
#include <ATen/ops/unfold_backward_cpu_dispatch.h>
#include <ATen/ops/uniform_cpu_dispatch.h>
#include <ATen/ops/unique_consecutive_cpu_dispatch.h>
#include <ATen/ops/unique_dim_cpu_dispatch.h>
#include <ATen/ops/unique_dim_consecutive_cpu_dispatch.h>
#include <ATen/ops/upsample_bicubic2d_cpu_dispatch.h>
#include <ATen/ops/upsample_bicubic2d_backward_cpu_dispatch.h>
#include <ATen/ops/upsample_bilinear2d_cpu_dispatch.h>
#include <ATen/ops/upsample_bilinear2d_backward_cpu_dispatch.h>
#include <ATen/ops/upsample_linear1d_cpu_dispatch.h>
#include <ATen/ops/upsample_linear1d_backward_cpu_dispatch.h>
#include <ATen/ops/upsample_nearest1d_cpu_dispatch.h>
#include <ATen/ops/upsample_nearest1d_backward_cpu_dispatch.h>
#include <ATen/ops/upsample_nearest2d_cpu_dispatch.h>
#include <ATen/ops/upsample_nearest2d_backward_cpu_dispatch.h>
#include <ATen/ops/upsample_nearest3d_cpu_dispatch.h>
#include <ATen/ops/upsample_nearest3d_backward_cpu_dispatch.h>
#include <ATen/ops/upsample_trilinear3d_cpu_dispatch.h>
#include <ATen/ops/upsample_trilinear3d_backward_cpu_dispatch.h>
#include <ATen/ops/var_cpu_dispatch.h>
#include <ATen/ops/var_mean_cpu_dispatch.h>
#include <ATen/ops/vdot_cpu_dispatch.h>
#include <ATen/ops/view_cpu_dispatch.h>
#include <ATen/ops/view_as_complex_cpu_dispatch.h>
#include <ATen/ops/view_as_real_cpu_dispatch.h>
#include <ATen/ops/where_cpu_dispatch.h>
#include <ATen/ops/xlogy_cpu_dispatch.h>
#include <ATen/ops/zero_cpu_dispatch.h>

@@ -0,0 +1,49 @@
#pragma once
#include <ATen/core/Generator.h>
#include <ATen/core/MT19937RNGEngine.h>
#include <c10/core/GeneratorImpl.h>
#include <optional>
namespace at {
struct TORCH_API CPUGeneratorImpl : public c10::GeneratorImpl {
// Constructors
CPUGeneratorImpl(uint64_t seed_in = default_rng_seed_val);
~CPUGeneratorImpl() override = default;
// CPUGeneratorImpl methods
std::shared_ptr<CPUGeneratorImpl> clone() const;
void set_current_seed(uint64_t seed) override;
void set_offset(uint64_t offset) override;
uint64_t get_offset() const override;
uint64_t current_seed() const override;
uint64_t seed() override;
void set_state(const c10::TensorImpl& new_state) override;
c10::intrusive_ptr<c10::TensorImpl> get_state() const override;
static c10::DeviceType device_type();
uint32_t random();
uint64_t random64();
std::optional<float> next_float_normal_sample();
std::optional<double> next_double_normal_sample();
void set_next_float_normal_sample(std::optional<float> randn);
void set_next_double_normal_sample(std::optional<double> randn);
at::mt19937 engine();
void set_engine(at::mt19937 engine);
private:
CPUGeneratorImpl* clone_impl() const override;
at::mt19937 engine_;
std::optional<float> next_float_normal_sample_;
std::optional<double> next_double_normal_sample_;
};
namespace detail {
TORCH_API const Generator& getDefaultCPUGenerator();
TORCH_API Generator
createCPUGenerator(uint64_t seed_val = default_rng_seed_val);
} // namespace detail
} // namespace at

@@ -0,0 +1,29 @@
#include <ATen/core/TensorBody.h>
// TODO Undo all logic introduced for Note [Avoiding Include Cycles In Static Dispatch]
// Code introduced to avoid cyclic dependency in static dispatch is no longer
// needed as static dispatch logic is moved from TensorBody.h, which caused cycles in the first place,
// to Operators.cpp for supporting multiple backends with multiple kernels.
//
// Note [Avoiding Include Cycles In Static Dispatch]
// In order to avoid #include cycles in the static dispatch build, we've carefully split out
// the static function definition files into {DispatchKey}Functions.h and {DispatchKey}Functions_inl.h.
//
// Without this split, the include cycle looks like TensorBody.h -> CPUFunctions.h -> TensorBody.h.
// - TensorBody.h #includes CPUFunctions.h in the static dispatch build, because the tensor methods
// all need to call into the fastpath C++ API defined in CPUFunctions.h. The methods are also all
// directly inlined into TensorBody.h.
// - CPUFunctions.h #includes TensorBody.h because it contains function declarations for the entire C++ API,
// which include functions that have defaultable std::optional<Tensor> arguments.
// That requires knowing the full Tensor class definition.
//
// We break the cycle by doing the following:
// - Split out CPUFunctions.h into two files: CPUFunctions.h and CPUFunctions_inl.h
// - CPUFunctions.h is a dummy file that just includes the Tensor class and includes CPUFunctions_inl.h,
//   which in turn includes everything else
// - (only in the static dispatch build) TensorBody.h makes sure to finish defining the Tensor class,
//   and then it includes CPUFunctions_inl.h.
// - All other files that want the CPU fastpath functions can include CPUFunctions.h directly.
// - This also means that in the static dispatch build, CPUFunctions.h only needs to
//   #include TensorBody.h, and it will automatically bring in CPUFunctions_inl.h.
#include <ATen/CUDAFunctions_inl.h>

@@ -0,0 +1,623 @@
#pragma once
// @generated by torchgen/gen.py from DispatchKeyFunctions_inl.h
// NB: The implementing C++ file is RegisterDispatchKey.cpp
// The only #includes we need are for custom classes that have defaults in the C++ API
#include <c10/core/MemoryFormat.h>
#include <c10/core/Scalar.h>
#include <ATen/core/Reduction.h>
#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
#error This change adds a dependency on all pytorch operators, meaning the \
file will need to be re-compiled every time an operator is changed or added. \
Consider including a specific operator from \
<ATen/ops/{my_operator}_cuda_dispatch.h>. \
See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
#endif
#include <ATen/ops/_adaptive_avg_pool2d_cuda_dispatch.h>
#include <ATen/ops/_adaptive_avg_pool2d_backward_cuda_dispatch.h>
#include <ATen/ops/_adaptive_avg_pool3d_cuda_dispatch.h>
#include <ATen/ops/_adaptive_avg_pool3d_backward_cuda_dispatch.h>
#include <ATen/ops/_addmm_activation_cuda_dispatch.h>
#include <ATen/ops/_aminmax_cuda_dispatch.h>
#include <ATen/ops/_amp_foreach_non_finite_check_and_unscale_cuda_dispatch.h>
#include <ATen/ops/_amp_update_scale_cuda_dispatch.h>
#include <ATen/ops/_assert_async_cuda_dispatch.h>
#include <ATen/ops/_batch_norm_with_update_cuda_dispatch.h>
#include <ATen/ops/_cdist_backward_cuda_dispatch.h>
#include <ATen/ops/_cdist_forward_cuda_dispatch.h>
#include <ATen/ops/_cholesky_solve_helper_cuda_dispatch.h>
#include <ATen/ops/_chunk_cat_cuda_dispatch.h>
#include <ATen/ops/_compute_linear_combination_cuda_dispatch.h>
#include <ATen/ops/_conv_depthwise2d_cuda_dispatch.h>
#include <ATen/ops/_convert_indices_from_coo_to_csr_cuda_dispatch.h>
#include <ATen/ops/_convert_indices_from_csr_to_coo_cuda_dispatch.h>
#include <ATen/ops/_convert_weight_to_int4pack_cuda_dispatch.h>
#include <ATen/ops/_cslt_compress_cuda_dispatch.h>
#include <ATen/ops/_cslt_sparse_mm_cuda_dispatch.h>
#include <ATen/ops/_cslt_sparse_mm_search_cuda_dispatch.h>
#include <ATen/ops/_ctc_loss_cuda_dispatch.h>
#include <ATen/ops/_ctc_loss_backward_cuda_dispatch.h>
#include <ATen/ops/_cudnn_ctc_loss_cuda_dispatch.h>
#include <ATen/ops/_cudnn_init_dropout_state_cuda_dispatch.h>
#include <ATen/ops/_cudnn_rnn_cuda_dispatch.h>
#include <ATen/ops/_cudnn_rnn_backward_cuda_dispatch.h>
#include <ATen/ops/_cudnn_rnn_flatten_weight_cuda_dispatch.h>
#include <ATen/ops/_cummax_helper_cuda_dispatch.h>
#include <ATen/ops/_cummin_helper_cuda_dispatch.h>
#include <ATen/ops/_dirichlet_grad_cuda_dispatch.h>
#include <ATen/ops/_efficient_attention_backward_cuda_dispatch.h>
#include <ATen/ops/_efficient_attention_forward_cuda_dispatch.h>
#include <ATen/ops/_efficientzerotensor_cuda_dispatch.h>
#include <ATen/ops/_embedding_bag_cuda_dispatch.h>
#include <ATen/ops/_embedding_bag_backward_cuda_dispatch.h>
#include <ATen/ops/_embedding_bag_dense_backward_cuda_dispatch.h>
#include <ATen/ops/_embedding_bag_forward_only_cuda_dispatch.h>
#include <ATen/ops/_embedding_bag_per_sample_weights_backward_cuda_dispatch.h>
#include <ATen/ops/_fake_quantize_learnable_per_channel_affine_cuda_dispatch.h>
#include <ATen/ops/_fake_quantize_learnable_per_channel_affine_backward_cuda_dispatch.h>
#include <ATen/ops/_fake_quantize_learnable_per_tensor_affine_cuda_dispatch.h>
#include <ATen/ops/_fake_quantize_learnable_per_tensor_affine_backward_cuda_dispatch.h>
#include <ATen/ops/_fake_quantize_per_tensor_affine_cachemask_tensor_qparams_cuda_dispatch.h>
#include <ATen/ops/_fft_c2c_cuda_dispatch.h>
#include <ATen/ops/_fft_c2r_cuda_dispatch.h>
#include <ATen/ops/_fft_r2c_cuda_dispatch.h>
#include <ATen/ops/_fill_mem_eff_dropout_mask_cuda_dispatch.h>
#include <ATen/ops/_flash_attention_backward_cuda_dispatch.h>
#include <ATen/ops/_flash_attention_forward_cuda_dispatch.h>
#include <ATen/ops/_foreach_abs_cuda_dispatch.h>
#include <ATen/ops/_foreach_acos_cuda_dispatch.h>
#include <ATen/ops/_foreach_add_cuda_dispatch.h>
#include <ATen/ops/_foreach_addcdiv_cuda_dispatch.h>
#include <ATen/ops/_foreach_addcmul_cuda_dispatch.h>
#include <ATen/ops/_foreach_asin_cuda_dispatch.h>
#include <ATen/ops/_foreach_atan_cuda_dispatch.h>
#include <ATen/ops/_foreach_ceil_cuda_dispatch.h>
#include <ATen/ops/_foreach_clamp_max_cuda_dispatch.h>
#include <ATen/ops/_foreach_clamp_min_cuda_dispatch.h>
#include <ATen/ops/_foreach_copy_cuda_dispatch.h>
#include <ATen/ops/_foreach_cos_cuda_dispatch.h>
#include <ATen/ops/_foreach_cosh_cuda_dispatch.h>
#include <ATen/ops/_foreach_div_cuda_dispatch.h>
#include <ATen/ops/_foreach_erf_cuda_dispatch.h>
#include <ATen/ops/_foreach_erfc_cuda_dispatch.h>
#include <ATen/ops/_foreach_exp_cuda_dispatch.h>
#include <ATen/ops/_foreach_expm1_cuda_dispatch.h>
#include <ATen/ops/_foreach_floor_cuda_dispatch.h>
#include <ATen/ops/_foreach_frac_cuda_dispatch.h>
#include <ATen/ops/_foreach_lerp_cuda_dispatch.h>
#include <ATen/ops/_foreach_lgamma_cuda_dispatch.h>
#include <ATen/ops/_foreach_log_cuda_dispatch.h>
#include <ATen/ops/_foreach_log10_cuda_dispatch.h>
#include <ATen/ops/_foreach_log1p_cuda_dispatch.h>
#include <ATen/ops/_foreach_log2_cuda_dispatch.h>
#include <ATen/ops/_foreach_max_cuda_dispatch.h>
#include <ATen/ops/_foreach_maximum_cuda_dispatch.h>
#include <ATen/ops/_foreach_minimum_cuda_dispatch.h>
#include <ATen/ops/_foreach_mul_cuda_dispatch.h>
#include <ATen/ops/_foreach_neg_cuda_dispatch.h>
#include <ATen/ops/_foreach_norm_cuda_dispatch.h>
#include <ATen/ops/_foreach_pow_cuda_dispatch.h>
#include <ATen/ops/_foreach_reciprocal_cuda_dispatch.h>
#include <ATen/ops/_foreach_round_cuda_dispatch.h>
#include <ATen/ops/_foreach_sigmoid_cuda_dispatch.h>
#include <ATen/ops/_foreach_sign_cuda_dispatch.h>
#include <ATen/ops/_foreach_sin_cuda_dispatch.h>
#include <ATen/ops/_foreach_sinh_cuda_dispatch.h>
#include <ATen/ops/_foreach_sqrt_cuda_dispatch.h>
#include <ATen/ops/_foreach_sub_cuda_dispatch.h>
#include <ATen/ops/_foreach_tan_cuda_dispatch.h>
#include <ATen/ops/_foreach_tanh_cuda_dispatch.h>
#include <ATen/ops/_foreach_trunc_cuda_dispatch.h>
#include <ATen/ops/_foreach_zero_cuda_dispatch.h>
#include <ATen/ops/_fused_adam_cuda_dispatch.h>
#include <ATen/ops/_fused_adamw_cuda_dispatch.h>
#include <ATen/ops/_fused_dropout_cuda_dispatch.h>
#include <ATen/ops/_fused_moving_avg_obs_fq_helper_cuda_dispatch.h>
#include <ATen/ops/_fused_sdp_choice_cuda_dispatch.h>
#include <ATen/ops/_fused_sgd_cuda_dispatch.h>
#include <ATen/ops/_index_put_impl_cuda_dispatch.h>
#include <ATen/ops/_int_mm_cuda_dispatch.h>
#include <ATen/ops/_jagged_to_padded_dense_forward_cuda_dispatch.h>
#include <ATen/ops/_linalg_det_cuda_dispatch.h>
#include <ATen/ops/_linalg_eigh_cuda_dispatch.h>
#include <ATen/ops/_linalg_eigvals_cuda_dispatch.h>
#include <ATen/ops/_linalg_slogdet_cuda_dispatch.h>
#include <ATen/ops/_linalg_solve_ex_cuda_dispatch.h>
#include <ATen/ops/_linalg_svd_cuda_dispatch.h>
#include <ATen/ops/_local_scalar_dense_cuda_dispatch.h>
#include <ATen/ops/_log_softmax_cuda_dispatch.h>
#include <ATen/ops/_log_softmax_backward_data_cuda_dispatch.h>
#include <ATen/ops/_logcumsumexp_cuda_dispatch.h>
#include <ATen/ops/_make_per_channel_quantized_tensor_cuda_dispatch.h>
#include <ATen/ops/_make_per_tensor_quantized_tensor_cuda_dispatch.h>
#include <ATen/ops/_masked_scale_cuda_dispatch.h>
#include <ATen/ops/_masked_softmax_cuda_dispatch.h>
#include <ATen/ops/_masked_softmax_backward_cuda_dispatch.h>
#include <ATen/ops/_mixed_dtypes_linear_cuda_dispatch.h>
#include <ATen/ops/_native_batch_norm_legit_cuda_dispatch.h>
#include <ATen/ops/_native_multi_head_attention_cuda_dispatch.h>
#include <ATen/ops/_nested_compute_contiguous_strides_offsets_cuda_dispatch.h>
#include <ATen/ops/_nested_from_padded_cuda_dispatch.h>
#include <ATen/ops/_nested_tensor_from_mask_cuda_dispatch.h>
#include <ATen/ops/_nested_tensor_from_mask_left_aligned_cuda_dispatch.h>
#include <ATen/ops/_nested_view_from_buffer_cuda_dispatch.h>
#include <ATen/ops/_padded_dense_to_jagged_forward_cuda_dispatch.h>
#include <ATen/ops/_pdist_backward_cuda_dispatch.h>
#include <ATen/ops/_pdist_forward_cuda_dispatch.h>
#include <ATen/ops/_prelu_kernel_cuda_dispatch.h>
#include <ATen/ops/_prelu_kernel_backward_cuda_dispatch.h>
#include <ATen/ops/_reshape_alias_cuda_dispatch.h>
#include <ATen/ops/_sample_dirichlet_cuda_dispatch.h>
#include <ATen/ops/_scaled_dot_product_cudnn_attention_cuda_dispatch.h>
#include <ATen/ops/_scaled_dot_product_cudnn_attention_backward_cuda_dispatch.h>
#include <ATen/ops/_scaled_dot_product_efficient_attention_cuda_dispatch.h>
#include <ATen/ops/_scaled_dot_product_efficient_attention_backward_cuda_dispatch.h>
#include <ATen/ops/_scaled_dot_product_flash_attention_cuda_dispatch.h>
#include <ATen/ops/_scaled_dot_product_flash_attention_backward_cuda_dispatch.h>
#include <ATen/ops/_scaled_mm_cuda_dispatch.h>
#include <ATen/ops/_segment_reduce_backward_cuda_dispatch.h>
#include <ATen/ops/_slow_conv2d_backward_cuda_dispatch.h>
#include <ATen/ops/_slow_conv2d_forward_cuda_dispatch.h>
#include <ATen/ops/_softmax_cuda_dispatch.h>
#include <ATen/ops/_softmax_backward_data_cuda_dispatch.h>
#include <ATen/ops/_sparse_semi_structured_addmm_cuda_dispatch.h>
#include <ATen/ops/_sparse_semi_structured_apply_cuda_dispatch.h>
#include <ATen/ops/_sparse_semi_structured_apply_dense_cuda_dispatch.h>
#include <ATen/ops/_sparse_semi_structured_linear_cuda_dispatch.h>
#include <ATen/ops/_sparse_semi_structured_mm_cuda_dispatch.h>
#include <ATen/ops/_sparse_semi_structured_tile_cuda_dispatch.h>
#include <ATen/ops/_standard_gamma_cuda_dispatch.h>
#include <ATen/ops/_standard_gamma_grad_cuda_dispatch.h>
#include <ATen/ops/_thnn_fused_gru_cell_cuda_dispatch.h>
#include <ATen/ops/_thnn_fused_gru_cell_backward_cuda_dispatch.h>
#include <ATen/ops/_thnn_fused_lstm_cell_cuda_dispatch.h>
#include <ATen/ops/_thnn_fused_lstm_cell_backward_impl_cuda_dispatch.h>
#include <ATen/ops/_to_sparse_cuda_dispatch.h>
#include <ATen/ops/_to_sparse_bsc_cuda_dispatch.h>
#include <ATen/ops/_to_sparse_bsr_cuda_dispatch.h>
#include <ATen/ops/_to_sparse_csc_cuda_dispatch.h>
#include <ATen/ops/_to_sparse_csr_cuda_dispatch.h>
#include <ATen/ops/_to_sparse_semi_structured_cuda_dispatch.h>
#include <ATen/ops/_transform_bias_rescale_qkv_cuda_dispatch.h>
#include <ATen/ops/_transformer_encoder_layer_fwd_cuda_dispatch.h>
#include <ATen/ops/_triton_multi_head_attention_cuda_dispatch.h>
#include <ATen/ops/_triton_scaled_dot_attention_cuda_dispatch.h>
#include <ATen/ops/_unique_cuda_dispatch.h>
#include <ATen/ops/_unique2_cuda_dispatch.h>
#include <ATen/ops/_upsample_bicubic2d_aa_cuda_dispatch.h>
#include <ATen/ops/_upsample_bicubic2d_aa_backward_cuda_dispatch.h>
#include <ATen/ops/_upsample_bilinear2d_aa_cuda_dispatch.h>
#include <ATen/ops/_upsample_bilinear2d_aa_backward_cuda_dispatch.h>
#include <ATen/ops/_upsample_nearest_exact1d_cuda_dispatch.h>
#include <ATen/ops/_upsample_nearest_exact1d_backward_cuda_dispatch.h>
#include <ATen/ops/_upsample_nearest_exact2d_cuda_dispatch.h>
#include <ATen/ops/_upsample_nearest_exact2d_backward_cuda_dispatch.h>
#include <ATen/ops/_upsample_nearest_exact3d_cuda_dispatch.h>
#include <ATen/ops/_upsample_nearest_exact3d_backward_cuda_dispatch.h>
#include <ATen/ops/_use_cudnn_ctc_loss_cuda_dispatch.h>
#include <ATen/ops/_validate_compressed_sparse_indices_cuda_dispatch.h>
#include <ATen/ops/_weight_int4pack_mm_cuda_dispatch.h>
#include <ATen/ops/_weight_norm_interface_cuda_dispatch.h>
#include <ATen/ops/_weight_norm_interface_backward_cuda_dispatch.h>
#include <ATen/ops/abs_cuda_dispatch.h>
#include <ATen/ops/acos_cuda_dispatch.h>
#include <ATen/ops/acosh_cuda_dispatch.h>
#include <ATen/ops/adaptive_avg_pool2d_cuda_dispatch.h>
#include <ATen/ops/adaptive_avg_pool3d_cuda_dispatch.h>
#include <ATen/ops/adaptive_avg_pool3d_backward_cuda_dispatch.h>
#include <ATen/ops/adaptive_max_pool2d_cuda_dispatch.h>
#include <ATen/ops/adaptive_max_pool2d_backward_cuda_dispatch.h>
#include <ATen/ops/adaptive_max_pool3d_cuda_dispatch.h>
#include <ATen/ops/adaptive_max_pool3d_backward_cuda_dispatch.h>
#include <ATen/ops/add_cuda_dispatch.h>
#include <ATen/ops/addbmm_cuda_dispatch.h>
#include <ATen/ops/addcdiv_cuda_dispatch.h>
#include <ATen/ops/addcmul_cuda_dispatch.h>
#include <ATen/ops/addmm_cuda_dispatch.h>
#include <ATen/ops/addmv_cuda_dispatch.h>
#include <ATen/ops/addr_cuda_dispatch.h>
#include <ATen/ops/all_cuda_dispatch.h>
#include <ATen/ops/amax_cuda_dispatch.h>
#include <ATen/ops/amin_cuda_dispatch.h>
#include <ATen/ops/aminmax_cuda_dispatch.h>
#include <ATen/ops/angle_cuda_dispatch.h>
#include <ATen/ops/any_cuda_dispatch.h>
#include <ATen/ops/arange_cuda_dispatch.h>
#include <ATen/ops/argmax_cuda_dispatch.h>
#include <ATen/ops/argmin_cuda_dispatch.h>
#include <ATen/ops/as_strided_cuda_dispatch.h>
#include <ATen/ops/asin_cuda_dispatch.h>
#include <ATen/ops/asinh_cuda_dispatch.h>
#include <ATen/ops/atan_cuda_dispatch.h>
#include <ATen/ops/atan2_cuda_dispatch.h>
#include <ATen/ops/atanh_cuda_dispatch.h>
#include <ATen/ops/avg_pool2d_cuda_dispatch.h>
#include <ATen/ops/avg_pool2d_backward_cuda_dispatch.h>
#include <ATen/ops/avg_pool3d_cuda_dispatch.h>
#include <ATen/ops/avg_pool3d_backward_cuda_dispatch.h>
#include <ATen/ops/baddbmm_cuda_dispatch.h>
#include <ATen/ops/batch_norm_backward_cuda_dispatch.h>
#include <ATen/ops/batch_norm_backward_elemt_cuda_dispatch.h>
#include <ATen/ops/batch_norm_backward_reduce_cuda_dispatch.h>
#include <ATen/ops/batch_norm_elemt_cuda_dispatch.h>
#include <ATen/ops/batch_norm_gather_stats_cuda_dispatch.h>
#include <ATen/ops/batch_norm_gather_stats_with_counts_cuda_dispatch.h>
#include <ATen/ops/batch_norm_stats_cuda_dispatch.h>
#include <ATen/ops/batch_norm_update_stats_cuda_dispatch.h>
#include <ATen/ops/bernoulli_cuda_dispatch.h>
#include <ATen/ops/binary_cross_entropy_cuda_dispatch.h>
#include <ATen/ops/binary_cross_entropy_backward_cuda_dispatch.h>
#include <ATen/ops/bincount_cuda_dispatch.h>
#include <ATen/ops/binomial_cuda_dispatch.h>
#include <ATen/ops/bitwise_and_cuda_dispatch.h>
#include <ATen/ops/bitwise_left_shift_cuda_dispatch.h>
#include <ATen/ops/bitwise_not_cuda_dispatch.h>
#include <ATen/ops/bitwise_or_cuda_dispatch.h>
#include <ATen/ops/bitwise_right_shift_cuda_dispatch.h>
#include <ATen/ops/bitwise_xor_cuda_dispatch.h>
#include <ATen/ops/bmm_cuda_dispatch.h>
#include <ATen/ops/bucketize_cuda_dispatch.h>
#include <ATen/ops/cat_cuda_dispatch.h>
#include <ATen/ops/cauchy_cuda_dispatch.h>
#include <ATen/ops/ceil_cuda_dispatch.h>
#include <ATen/ops/channel_shuffle_cuda_dispatch.h>
#include <ATen/ops/cholesky_cuda_dispatch.h>
#include <ATen/ops/cholesky_inverse_cuda_dispatch.h>
#include <ATen/ops/clamp_cuda_dispatch.h>
#include <ATen/ops/clamp_max_cuda_dispatch.h>
#include <ATen/ops/clamp_min_cuda_dispatch.h>
#include <ATen/ops/col2im_cuda_dispatch.h>
#include <ATen/ops/complex_cuda_dispatch.h>
#include <ATen/ops/conj_physical_cuda_dispatch.h>
#include <ATen/ops/conv_depthwise3d_cuda_dispatch.h>
#include <ATen/ops/convolution_backward_cuda_dispatch.h>
#include <ATen/ops/copysign_cuda_dispatch.h>
#include <ATen/ops/cos_cuda_dispatch.h>
#include <ATen/ops/cosh_cuda_dispatch.h>
#include <ATen/ops/count_nonzero_cuda_dispatch.h>
#include <ATen/ops/cudnn_affine_grid_generator_cuda_dispatch.h>
#include <ATen/ops/cudnn_affine_grid_generator_backward_cuda_dispatch.h>
#include <ATen/ops/cudnn_batch_norm_cuda_dispatch.h>
#include <ATen/ops/cudnn_batch_norm_backward_cuda_dispatch.h>
#include <ATen/ops/cudnn_convolution_cuda_dispatch.h>
#include <ATen/ops/cudnn_convolution_add_relu_cuda_dispatch.h>
#include <ATen/ops/cudnn_convolution_relu_cuda_dispatch.h>
#include <ATen/ops/cudnn_convolution_transpose_cuda_dispatch.h>
#include <ATen/ops/cudnn_grid_sampler_cuda_dispatch.h>
#include <ATen/ops/cudnn_grid_sampler_backward_cuda_dispatch.h>
#include <ATen/ops/cumprod_cuda_dispatch.h>
#include <ATen/ops/cumsum_cuda_dispatch.h>
#include <ATen/ops/dequantize_cuda_dispatch.h>
#include <ATen/ops/digamma_cuda_dispatch.h>
#include <ATen/ops/div_cuda_dispatch.h>
#include <ATen/ops/dot_cuda_dispatch.h>
#include <ATen/ops/elu_cuda_dispatch.h>
#include <ATen/ops/elu_backward_cuda_dispatch.h>
#include <ATen/ops/embedding_dense_backward_cuda_dispatch.h>
#include <ATen/ops/embedding_renorm_cuda_dispatch.h>
#include <ATen/ops/empty_cuda_dispatch.h>
#include <ATen/ops/empty_strided_cuda_dispatch.h>
#include <ATen/ops/eq_cuda_dispatch.h>
#include <ATen/ops/equal_cuda_dispatch.h>
#include <ATen/ops/erf_cuda_dispatch.h>
#include <ATen/ops/erfc_cuda_dispatch.h>
#include <ATen/ops/erfinv_cuda_dispatch.h>
#include <ATen/ops/exp_cuda_dispatch.h>
#include <ATen/ops/exp2_cuda_dispatch.h>
#include <ATen/ops/expm1_cuda_dispatch.h>
#include <ATen/ops/exponential_cuda_dispatch.h>
#include <ATen/ops/eye_cuda_dispatch.h>
#include <ATen/ops/fake_quantize_per_channel_affine_cachemask_cuda_dispatch.h>
#include <ATen/ops/fake_quantize_per_tensor_affine_cachemask_cuda_dispatch.h>
#include <ATen/ops/fill_cuda_dispatch.h>
#include <ATen/ops/flip_cuda_dispatch.h>
#include <ATen/ops/floor_cuda_dispatch.h>
#include <ATen/ops/floor_divide_cuda_dispatch.h>
#include <ATen/ops/fmax_cuda_dispatch.h>
#include <ATen/ops/fmin_cuda_dispatch.h>
#include <ATen/ops/fmod_cuda_dispatch.h>
#include <ATen/ops/frac_cuda_dispatch.h>
#include <ATen/ops/fractional_max_pool2d_cuda_dispatch.h>
#include <ATen/ops/fractional_max_pool2d_backward_cuda_dispatch.h>
#include <ATen/ops/fractional_max_pool3d_cuda_dispatch.h>
#include <ATen/ops/fractional_max_pool3d_backward_cuda_dispatch.h>
#include <ATen/ops/frexp_cuda_dispatch.h>
#include <ATen/ops/gather_cuda_dispatch.h>
#include <ATen/ops/gcd_cuda_dispatch.h>
#include <ATen/ops/ge_cuda_dispatch.h>
#include <ATen/ops/gelu_cuda_dispatch.h>
#include <ATen/ops/gelu_backward_cuda_dispatch.h>
#include <ATen/ops/geometric_cuda_dispatch.h>
#include <ATen/ops/geqrf_cuda_dispatch.h>
#include <ATen/ops/glu_cuda_dispatch.h>
#include <ATen/ops/glu_backward_cuda_dispatch.h>
#include <ATen/ops/glu_backward_jvp_cuda_dispatch.h>
#include <ATen/ops/glu_jvp_cuda_dispatch.h>
#include <ATen/ops/grid_sampler_2d_cuda_dispatch.h>
#include <ATen/ops/grid_sampler_2d_backward_cuda_dispatch.h>
#include <ATen/ops/grid_sampler_3d_cuda_dispatch.h>
#include <ATen/ops/grid_sampler_3d_backward_cuda_dispatch.h>
#include <ATen/ops/gt_cuda_dispatch.h>
#include <ATen/ops/hardshrink_cuda_dispatch.h>
#include <ATen/ops/hardshrink_backward_cuda_dispatch.h>
#include <ATen/ops/hardsigmoid_cuda_dispatch.h>
#include <ATen/ops/hardsigmoid_backward_cuda_dispatch.h>
#include <ATen/ops/hardswish_cuda_dispatch.h>
#include <ATen/ops/hardswish_backward_cuda_dispatch.h>
#include <ATen/ops/hardtanh_cuda_dispatch.h>
#include <ATen/ops/hardtanh_backward_cuda_dispatch.h>
#include <ATen/ops/heaviside_cuda_dispatch.h>
#include <ATen/ops/histc_cuda_dispatch.h>
#include <ATen/ops/huber_loss_cuda_dispatch.h>
#include <ATen/ops/huber_loss_backward_cuda_dispatch.h>
#include <ATen/ops/hypot_cuda_dispatch.h>
#include <ATen/ops/i0_cuda_dispatch.h>
#include <ATen/ops/igamma_cuda_dispatch.h>
#include <ATen/ops/igammac_cuda_dispatch.h>
#include <ATen/ops/im2col_cuda_dispatch.h>
#include <ATen/ops/index_cuda_dispatch.h>
#include <ATen/ops/index_add_cuda_dispatch.h>
#include <ATen/ops/index_copy_cuda_dispatch.h>
#include <ATen/ops/index_fill_cuda_dispatch.h>
#include <ATen/ops/index_reduce_cuda_dispatch.h>
#include <ATen/ops/index_select_cuda_dispatch.h>
#include <ATen/ops/is_set_to_cuda_dispatch.h>
#include <ATen/ops/isin_cuda_dispatch.h>
#include <ATen/ops/isnan_cuda_dispatch.h>
#include <ATen/ops/isneginf_cuda_dispatch.h>
#include <ATen/ops/isposinf_cuda_dispatch.h>
#include <ATen/ops/kthvalue_cuda_dispatch.h>
#include <ATen/ops/lcm_cuda_dispatch.h>
#include <ATen/ops/le_cuda_dispatch.h>
#include <ATen/ops/leaky_relu_cuda_dispatch.h>
#include <ATen/ops/leaky_relu_backward_cuda_dispatch.h>
#include <ATen/ops/lerp_cuda_dispatch.h>
#include <ATen/ops/lgamma_cuda_dispatch.h>
#include <ATen/ops/linalg_cholesky_ex_cuda_dispatch.h>
#include <ATen/ops/linalg_cross_cuda_dispatch.h>
#include <ATen/ops/linalg_eig_cuda_dispatch.h>
#include <ATen/ops/linalg_eigvals_cuda_dispatch.h>
#include <ATen/ops/linalg_householder_product_cuda_dispatch.h>
#include <ATen/ops/linalg_inv_ex_cuda_dispatch.h>
#include <ATen/ops/linalg_ldl_factor_ex_cuda_dispatch.h>
#include <ATen/ops/linalg_ldl_solve_cuda_dispatch.h>
#include <ATen/ops/linalg_lstsq_cuda_dispatch.h>
#include <ATen/ops/linalg_lu_cuda_dispatch.h>
#include <ATen/ops/linalg_lu_factor_ex_cuda_dispatch.h>
#include <ATen/ops/linalg_lu_solve_cuda_dispatch.h>
#include <ATen/ops/linalg_matrix_exp_cuda_dispatch.h>
#include <ATen/ops/linalg_qr_cuda_dispatch.h>
#include <ATen/ops/linalg_solve_triangular_cuda_dispatch.h>
#include <ATen/ops/linalg_vector_norm_cuda_dispatch.h>
#include <ATen/ops/linspace_cuda_dispatch.h>
#include <ATen/ops/log_cuda_dispatch.h>
#include <ATen/ops/log10_cuda_dispatch.h>
#include <ATen/ops/log1p_cuda_dispatch.h>
#include <ATen/ops/log2_cuda_dispatch.h>
#include <ATen/ops/log_normal_cuda_dispatch.h>
#include <ATen/ops/log_sigmoid_backward_cuda_dispatch.h>
#include <ATen/ops/log_sigmoid_forward_cuda_dispatch.h>
#include <ATen/ops/logaddexp_cuda_dispatch.h>
#include <ATen/ops/logaddexp2_cuda_dispatch.h>
#include <ATen/ops/logical_and_cuda_dispatch.h>
#include <ATen/ops/logical_not_cuda_dispatch.h>
#include <ATen/ops/logical_or_cuda_dispatch.h>
#include <ATen/ops/logical_xor_cuda_dispatch.h>
#include <ATen/ops/logit_cuda_dispatch.h>
#include <ATen/ops/logit_backward_cuda_dispatch.h>
#include <ATen/ops/logspace_cuda_dispatch.h>
#include <ATen/ops/lshift_cuda_dispatch.h>
#include <ATen/ops/lt_cuda_dispatch.h>
#include <ATen/ops/lu_unpack_cuda_dispatch.h>
#include <ATen/ops/masked_fill_cuda_dispatch.h>
#include <ATen/ops/masked_scatter_cuda_dispatch.h>
#include <ATen/ops/masked_select_cuda_dispatch.h>
#include <ATen/ops/max_cuda_dispatch.h>
#include <ATen/ops/max_pool2d_with_indices_cuda_dispatch.h>
#include <ATen/ops/max_pool2d_with_indices_backward_cuda_dispatch.h>
#include <ATen/ops/max_pool3d_with_indices_cuda_dispatch.h>
#include <ATen/ops/max_pool3d_with_indices_backward_cuda_dispatch.h>
#include <ATen/ops/max_unpool2d_cuda_dispatch.h>
#include <ATen/ops/max_unpool3d_cuda_dispatch.h>
#include <ATen/ops/maximum_cuda_dispatch.h>
#include <ATen/ops/mean_cuda_dispatch.h>
#include <ATen/ops/median_cuda_dispatch.h>
#include <ATen/ops/min_cuda_dispatch.h>
#include <ATen/ops/minimum_cuda_dispatch.h>
#include <ATen/ops/miopen_batch_norm_cuda_dispatch.h>
#include <ATen/ops/miopen_batch_norm_backward_cuda_dispatch.h>
#include <ATen/ops/miopen_convolution_cuda_dispatch.h>
#include <ATen/ops/miopen_convolution_add_relu_cuda_dispatch.h>
#include <ATen/ops/miopen_convolution_relu_cuda_dispatch.h>
#include <ATen/ops/miopen_convolution_transpose_cuda_dispatch.h>
#include <ATen/ops/miopen_depthwise_convolution_cuda_dispatch.h>
#include <ATen/ops/miopen_rnn_cuda_dispatch.h>
#include <ATen/ops/miopen_rnn_backward_cuda_dispatch.h>
#include <ATen/ops/mish_cuda_dispatch.h>
#include <ATen/ops/mish_backward_cuda_dispatch.h>
#include <ATen/ops/mm_cuda_dispatch.h>
#include <ATen/ops/mode_cuda_dispatch.h>
#include <ATen/ops/mse_loss_cuda_dispatch.h>
#include <ATen/ops/mse_loss_backward_cuda_dispatch.h>
#include <ATen/ops/mul_cuda_dispatch.h>
#include <ATen/ops/multi_margin_loss_cuda_dispatch.h>
#include <ATen/ops/multi_margin_loss_backward_cuda_dispatch.h>
#include <ATen/ops/multilabel_margin_loss_backward_cuda_dispatch.h>
#include <ATen/ops/multilabel_margin_loss_forward_cuda_dispatch.h>
#include <ATen/ops/multinomial_cuda_dispatch.h>
#include <ATen/ops/mvlgamma_cuda_dispatch.h>
#include <ATen/ops/nan_to_num_cuda_dispatch.h>
#include <ATen/ops/nanmedian_cuda_dispatch.h>
#include <ATen/ops/nansum_cuda_dispatch.h>
#include <ATen/ops/native_batch_norm_cuda_dispatch.h>
#include <ATen/ops/native_batch_norm_backward_cuda_dispatch.h>
#include <ATen/ops/native_dropout_cuda_dispatch.h>
#include <ATen/ops/native_dropout_backward_cuda_dispatch.h>
#include <ATen/ops/native_group_norm_cuda_dispatch.h>
#include <ATen/ops/native_group_norm_backward_cuda_dispatch.h>
#include <ATen/ops/native_layer_norm_cuda_dispatch.h>
#include <ATen/ops/native_layer_norm_backward_cuda_dispatch.h>
#include <ATen/ops/ne_cuda_dispatch.h>
#include <ATen/ops/neg_cuda_dispatch.h>
#include <ATen/ops/nextafter_cuda_dispatch.h>
#include <ATen/ops/nll_loss2d_backward_cuda_dispatch.h>
#include <ATen/ops/nll_loss2d_forward_cuda_dispatch.h>
#include <ATen/ops/nll_loss_backward_cuda_dispatch.h>
#include <ATen/ops/nll_loss_forward_cuda_dispatch.h>
#include <ATen/ops/nonzero_cuda_dispatch.h>
#include <ATen/ops/norm_cuda_dispatch.h>
#include <ATen/ops/normal_cuda_dispatch.h>
#include <ATen/ops/ormqr_cuda_dispatch.h>
#include <ATen/ops/poisson_cuda_dispatch.h>
#include <ATen/ops/polar_cuda_dispatch.h>
#include <ATen/ops/polygamma_cuda_dispatch.h>
#include <ATen/ops/pow_cuda_dispatch.h>
#include <ATen/ops/prod_cuda_dispatch.h>
#include <ATen/ops/put_cuda_dispatch.h>
#include <ATen/ops/quantize_per_channel_cuda_dispatch.h>
#include <ATen/ops/quantize_per_tensor_cuda_dispatch.h>
#include <ATen/ops/quantize_per_tensor_dynamic_cuda_dispatch.h>
#include <ATen/ops/random_cuda_dispatch.h>
#include <ATen/ops/randperm_cuda_dispatch.h>
#include <ATen/ops/range_cuda_dispatch.h>
#include <ATen/ops/reciprocal_cuda_dispatch.h>
#include <ATen/ops/record_stream_cuda_dispatch.h>
#include <ATen/ops/reflection_pad1d_cuda_dispatch.h>
#include <ATen/ops/reflection_pad1d_backward_cuda_dispatch.h>
#include <ATen/ops/reflection_pad2d_cuda_dispatch.h>
#include <ATen/ops/reflection_pad2d_backward_cuda_dispatch.h>
#include <ATen/ops/reflection_pad3d_cuda_dispatch.h>
#include <ATen/ops/reflection_pad3d_backward_cuda_dispatch.h>
#include <ATen/ops/relu_cuda_dispatch.h>
#include <ATen/ops/remainder_cuda_dispatch.h>
#include <ATen/ops/renorm_cuda_dispatch.h>
#include <ATen/ops/repeat_interleave_cuda_dispatch.h>
#include <ATen/ops/replication_pad1d_cuda_dispatch.h>
#include <ATen/ops/replication_pad1d_backward_cuda_dispatch.h>
#include <ATen/ops/replication_pad2d_cuda_dispatch.h>
#include <ATen/ops/replication_pad2d_backward_cuda_dispatch.h>
#include <ATen/ops/replication_pad3d_cuda_dispatch.h>
#include <ATen/ops/replication_pad3d_backward_cuda_dispatch.h>
#include <ATen/ops/resize_cuda_dispatch.h>
#include <ATen/ops/roll_cuda_dispatch.h>
#include <ATen/ops/round_cuda_dispatch.h>
#include <ATen/ops/rrelu_with_noise_cuda_dispatch.h>
#include <ATen/ops/rshift_cuda_dispatch.h>
#include <ATen/ops/rsqrt_cuda_dispatch.h>
#include <ATen/ops/rsub_cuda_dispatch.h>
#include <ATen/ops/scatter_cuda_dispatch.h>
#include <ATen/ops/scatter_add_cuda_dispatch.h>
#include <ATen/ops/scatter_reduce_cuda_dispatch.h>
#include <ATen/ops/searchsorted_cuda_dispatch.h>
#include <ATen/ops/segment_reduce_cuda_dispatch.h>
#include <ATen/ops/set_cuda_dispatch.h>
#include <ATen/ops/sgn_cuda_dispatch.h>
#include <ATen/ops/sigmoid_cuda_dispatch.h>
#include <ATen/ops/sigmoid_backward_cuda_dispatch.h>
#include <ATen/ops/sign_cuda_dispatch.h>
#include <ATen/ops/signbit_cuda_dispatch.h>
#include <ATen/ops/silu_cuda_dispatch.h>
#include <ATen/ops/silu_backward_cuda_dispatch.h>
#include <ATen/ops/sin_cuda_dispatch.h>
#include <ATen/ops/sinc_cuda_dispatch.h>
#include <ATen/ops/sinh_cuda_dispatch.h>
#include <ATen/ops/slow_conv_dilated2d_cuda_dispatch.h>
#include <ATen/ops/slow_conv_dilated3d_cuda_dispatch.h>
#include <ATen/ops/slow_conv_transpose2d_cuda_dispatch.h>
#include <ATen/ops/slow_conv_transpose3d_cuda_dispatch.h>
#include <ATen/ops/smooth_l1_loss_cuda_dispatch.h>
#include <ATen/ops/smooth_l1_loss_backward_cuda_dispatch.h>
#include <ATen/ops/softplus_cuda_dispatch.h>
#include <ATen/ops/softplus_backward_cuda_dispatch.h>
#include <ATen/ops/softshrink_cuda_dispatch.h>
#include <ATen/ops/softshrink_backward_cuda_dispatch.h>
#include <ATen/ops/sort_cuda_dispatch.h>
#include <ATen/ops/special_airy_ai_cuda_dispatch.h>
#include <ATen/ops/special_bessel_j0_cuda_dispatch.h>
#include <ATen/ops/special_bessel_j1_cuda_dispatch.h>
#include <ATen/ops/special_bessel_y0_cuda_dispatch.h>
#include <ATen/ops/special_bessel_y1_cuda_dispatch.h>
#include <ATen/ops/special_chebyshev_polynomial_t_cuda_dispatch.h>
#include <ATen/ops/special_chebyshev_polynomial_u_cuda_dispatch.h>
#include <ATen/ops/special_chebyshev_polynomial_v_cuda_dispatch.h>
#include <ATen/ops/special_chebyshev_polynomial_w_cuda_dispatch.h>
#include <ATen/ops/special_entr_cuda_dispatch.h>
#include <ATen/ops/special_erfcx_cuda_dispatch.h>
#include <ATen/ops/special_hermite_polynomial_h_cuda_dispatch.h>
#include <ATen/ops/special_hermite_polynomial_he_cuda_dispatch.h>
#include <ATen/ops/special_i0e_cuda_dispatch.h>
#include <ATen/ops/special_i1_cuda_dispatch.h>
#include <ATen/ops/special_i1e_cuda_dispatch.h>
#include <ATen/ops/special_laguerre_polynomial_l_cuda_dispatch.h>
#include <ATen/ops/special_legendre_polynomial_p_cuda_dispatch.h>
#include <ATen/ops/special_log_ndtr_cuda_dispatch.h>
#include <ATen/ops/special_modified_bessel_i0_cuda_dispatch.h>
#include <ATen/ops/special_modified_bessel_i1_cuda_dispatch.h>
#include <ATen/ops/special_modified_bessel_k0_cuda_dispatch.h>
#include <ATen/ops/special_modified_bessel_k1_cuda_dispatch.h>
#include <ATen/ops/special_ndtri_cuda_dispatch.h>
#include <ATen/ops/special_scaled_modified_bessel_k0_cuda_dispatch.h>
#include <ATen/ops/special_scaled_modified_bessel_k1_cuda_dispatch.h>
#include <ATen/ops/special_shifted_chebyshev_polynomial_t_cuda_dispatch.h>
#include <ATen/ops/special_shifted_chebyshev_polynomial_u_cuda_dispatch.h>
#include <ATen/ops/special_shifted_chebyshev_polynomial_v_cuda_dispatch.h>
#include <ATen/ops/special_shifted_chebyshev_polynomial_w_cuda_dispatch.h>
#include <ATen/ops/special_spherical_bessel_j0_cuda_dispatch.h>
#include <ATen/ops/special_xlog1py_cuda_dispatch.h>
#include <ATen/ops/special_zeta_cuda_dispatch.h>
#include <ATen/ops/split_with_sizes_copy_cuda_dispatch.h>
#include <ATen/ops/sqrt_cuda_dispatch.h>
#include <ATen/ops/sspaddmm_cuda_dispatch.h>
#include <ATen/ops/std_cuda_dispatch.h>
#include <ATen/ops/std_mean_cuda_dispatch.h>
#include <ATen/ops/sub_cuda_dispatch.h>
#include <ATen/ops/sum_cuda_dispatch.h>
#include <ATen/ops/take_cuda_dispatch.h>
#include <ATen/ops/tan_cuda_dispatch.h>
#include <ATen/ops/tanh_cuda_dispatch.h>
#include <ATen/ops/tanh_backward_cuda_dispatch.h>
#include <ATen/ops/threshold_cuda_dispatch.h>
#include <ATen/ops/threshold_backward_cuda_dispatch.h>
#include <ATen/ops/topk_cuda_dispatch.h>
#include <ATen/ops/trace_cuda_dispatch.h>
#include <ATen/ops/triangular_solve_cuda_dispatch.h>
#include <ATen/ops/tril_cuda_dispatch.h>
#include <ATen/ops/tril_indices_cuda_dispatch.h>
#include <ATen/ops/triu_cuda_dispatch.h>
#include <ATen/ops/triu_indices_cuda_dispatch.h>
#include <ATen/ops/trunc_cuda_dispatch.h>
#include <ATen/ops/unfold_cuda_dispatch.h>
#include <ATen/ops/unfold_backward_cuda_dispatch.h>
#include <ATen/ops/uniform_cuda_dispatch.h>
#include <ATen/ops/unique_consecutive_cuda_dispatch.h>
#include <ATen/ops/unique_dim_cuda_dispatch.h>
#include <ATen/ops/unique_dim_consecutive_cuda_dispatch.h>
#include <ATen/ops/upsample_bicubic2d_cuda_dispatch.h>
#include <ATen/ops/upsample_bicubic2d_backward_cuda_dispatch.h>
#include <ATen/ops/upsample_bilinear2d_cuda_dispatch.h>
#include <ATen/ops/upsample_bilinear2d_backward_cuda_dispatch.h>
#include <ATen/ops/upsample_linear1d_cuda_dispatch.h>
#include <ATen/ops/upsample_linear1d_backward_cuda_dispatch.h>
#include <ATen/ops/upsample_nearest1d_cuda_dispatch.h>
#include <ATen/ops/upsample_nearest1d_backward_cuda_dispatch.h>
#include <ATen/ops/upsample_nearest2d_cuda_dispatch.h>
#include <ATen/ops/upsample_nearest2d_backward_cuda_dispatch.h>
#include <ATen/ops/upsample_nearest3d_cuda_dispatch.h>
#include <ATen/ops/upsample_nearest3d_backward_cuda_dispatch.h>
#include <ATen/ops/upsample_trilinear3d_cuda_dispatch.h>
#include <ATen/ops/upsample_trilinear3d_backward_cuda_dispatch.h>
#include <ATen/ops/var_cuda_dispatch.h>
#include <ATen/ops/var_mean_cuda_dispatch.h>
#include <ATen/ops/vdot_cuda_dispatch.h>
#include <ATen/ops/view_cuda_dispatch.h>
#include <ATen/ops/view_as_complex_cuda_dispatch.h>
#include <ATen/ops/view_as_real_cuda_dispatch.h>
#include <ATen/ops/where_cuda_dispatch.h>
#include <ATen/ops/xlogy_cuda_dispatch.h>
#include <ATen/ops/zero_cuda_dispatch.h>

View File

@@ -0,0 +1,24 @@
#pragma once
#include <ATen/ATen.h>
namespace at::caching {
// Some systems (just cudagraphs currently) will persist a static tensor output
// whose TensorImpl does not change across iterations. For these tensors, caching
// dtype conversions is invalid. Additionally, there will be an extra reference
// count on these cached tensors that would prevent buffer inplacing and other
// checks on tensor uniqueness. If we are not using these systems, the enabled
// flag will be false and we will avoid the hash lookup.
TORCH_API bool is_cached_tensor(const at::Tensor& t);
TORCH_API void add_cached_tensor(const at::Tensor& t);
TORCH_API void remove_cached_tensor(const at::Tensor& t);
TORCH_API void set_cached_tensors_enabled(bool enable);
// For gradient buffer stealing, we will adjust the use count of tensors
// which are persisted by cudagraphs, just as we need to adjust the reference
// count of tensors with hooks.
TORCH_API size_t adjusted_use_count(const at::Tensor& t);
} // namespace at::caching
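A minimal usage sketch of the API declared above (not part of the header); the include path, the caller function, and when to toggle the enabled flag are assumptions for illustration:
#include <ATen/ATen.h>
#include <ATen/CachedTensorUtils.h>  // assumed install path for the header above
// Hypothetical caller: registers a cudagraph-persisted output with the cache and
// queries its use count with the cache's extra reference accounted for.
void register_static_output(const at::Tensor& static_out) {
  at::caching::set_cached_tensors_enabled(true);
  at::caching::add_cached_tensor(static_out);
  TORCH_INTERNAL_ASSERT(at::caching::is_cached_tensor(static_out));
  // Use count adjusted for the cudagraph-persisted reference (see the note above),
  // so checks on tensor uniqueness such as buffer inplacing can still be made.
  size_t uses = at::caching::adjusted_use_count(static_out);
  (void)uses;
  at::caching::remove_cached_tensor(static_out);
}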

View File

@@ -0,0 +1,94 @@
#include <c10/util/Exception.h>
#include <utility>
namespace at {
/*
[collapse dims] Updates sizes and strides to reflect a "collapse" of
the info, possibly excluding the optional excludeDim. A "collapsed" version
of the info is the fewest dims that order the tensor's elements in the same
way as the original info. If excludeDim is specified, the collapse is the
fewest dims that order the tensor's elements as the original and preserve the
excluded dimension, unless the tensor collapses to a point.
This function returns a pair of values.
1) The (new) index of the preserved dimension if excludeDim is
specified. 0 if the tensor is collapsed to a point. -1
otherwise.
2) The new number of dimensions.
*/
template <typename T>
inline std::pair<int64_t, int64_t> collapse_dims(
T* sizes,
T* strides,
int64_t dims,
const int excludeDim = -1) {
TORCH_CHECK(
excludeDim >= -1 && excludeDim < dims,
"expected excluded dim between -1 and dims - 1");
int64_t stopDim = (excludeDim == -1) ? dims : excludeDim;
int64_t newIndex = -1;
int64_t oldIndex = 0;
int64_t remappedExcludedDim = -1;
while (oldIndex < dims) {
// Finds a dimension to collapse into
for (; oldIndex < stopDim; ++oldIndex) {
if (sizes[oldIndex] == 1) {
continue;
}
++newIndex;
sizes[newIndex] = sizes[oldIndex];
strides[newIndex] = strides[oldIndex];
++oldIndex;
break;
}
// Collapses dims
for (; oldIndex < stopDim; ++oldIndex) {
if (sizes[oldIndex] == 1) {
continue;
}
if (strides[newIndex] == sizes[oldIndex] * strides[oldIndex]) {
sizes[newIndex] *= sizes[oldIndex];
strides[newIndex] = strides[oldIndex];
} else {
++newIndex;
sizes[newIndex] = sizes[oldIndex];
strides[newIndex] = strides[oldIndex];
}
}
// Handles excludeDim being set (oldIndex == excludeDim)
if (oldIndex != dims) {
// Preserves excluded dimension
++newIndex;
sizes[newIndex] = sizes[oldIndex];
strides[newIndex] = strides[oldIndex];
remappedExcludedDim = newIndex;
// Restarts iteration after excludeDim
++oldIndex;
stopDim = dims;
}
}
// Handles special case of all dims size 1
if (newIndex == -1 || (newIndex == 0 && sizes[0] == 1)) {
dims = 1;
sizes[0] = 1;
strides[0] = 1;
return std::pair<int64_t, int64_t>(0, 1);
}
dims = newIndex + 1;
return std::pair<int64_t, int64_t>(remappedExcludedDim, dims);
}
} // namespace at
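A small worked example of collapse_dims on the sizes/strides of a contiguous 2x3x4 layout (not part of the header); the <ATen/CollapseDims.h> include path is assumed:
#include <ATen/CollapseDims.h>
#include <cstdint>
#include <iostream>
int main() {
  // Contiguous 2x3x4 layout: every dim orders elements the same way,
  // so the shape collapses into a single dim of size 24 with stride 1.
  int64_t sizes[3] = {2, 3, 4};
  int64_t strides[3] = {12, 4, 1};
  auto [excluded, ndim] = at::collapse_dims(sizes, strides, /*dims=*/3);
  std::cout << ndim << " " << sizes[0] << " " << strides[0] << "\n";  // prints: 1 24 1
  std::cout << excluded << "\n";  // prints: -1 (no excludeDim was passed)
  return 0;
}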

View File

@@ -0,0 +1,29 @@
#include <ATen/core/TensorBody.h>
// TODO Undo all logic introduced for Note [Avoiding Include Cycles In Static Dispatch]
// The code introduced to avoid a cyclic dependency in static dispatch is no longer
// needed, as the static dispatch logic has moved from TensorBody.h, which caused the cycles
// in the first place, to Operators.cpp in order to support multiple backends with multiple kernels.
//
// Note [Avoiding Include Cycles In Static Dispatch]
// In order to avoid #include cycles in the static dispatch build, we've carefully split out
// the static function definition files into {DispatchKey}Functions.h and {DispatchKey}Functions_inl.h.
//
// Without this split, the include cycle looks like TensorBody.h -> CPUFunctions.h -> TensorBody.h.
// - TensorBody.h #includes CPUFunctions.h in the static dispatch build, because the tensor methods
// all need to call into the fastpath C++ API defined in CPUFunctions.h. The methods are also all
// directly inlined into TensorBody.h.
// - CPUFunctions.h #includes TensorBody.h because it contains function declarations for the entire C++ API,
// which include functions that have defaultable std::optional<Tensor> arguments.
// That requires knowing the full Tensor class definition.
//
// We break the cycle by doing the following:
// - Split the old CPUFunctions.h into two files: CPUFunctions.h and CPUFunctions_inl.h
// - CPUFunctions.h is a dummy file that just includes the Tensor class and includes CPUFunctions_inl.h
// - CPUFunctions_inl.h includes everything else
// - (only in the static dispatch build) TensorBody.h makes sure to finish defining the Tensor class,
// and then it includes CPUFunctions_inl.h.
// - All other files that want the cpu fastpath functions can include CPUFunctions.h directly.
// - This also means that in the static dispatch build, CPUFunctions.h only needs to
// #include TensorBody.h, and it will automatically bring in CPUFunctions_inl.h.
#include <ATen/CompositeExplicitAutogradFunctions_inl.h>
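A sketch of the split the note describes, spelled out with the CPU names the note uses; this is illustrative, not a verbatim copy of the generated headers:
// CPUFunctions.h -- the thin entry point that other files include directly
//   #include <ATen/core/TensorBody.h>    // brings in the full Tensor class first
//   #include <ATen/CPUFunctions_inl.h>   // then the fastpath C++ API declarations
//
// TensorBody.h (static dispatch build only), after the Tensor class is fully defined:
//   #include <ATen/CPUFunctions_inl.h>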

View File

@@ -0,0 +1,553 @@
#pragma once
// @generated by torchgen/gen.py from DispatchKeyFunctions_inl.h
// NB: The implementing C++ file is RegisterDispatchKey.cpp
// The only #includes we need are for custom classes that have defaults in the C++ API
#include <c10/core/MemoryFormat.h>
#include <c10/core/Scalar.h>
#include <ATen/core/Reduction.h>
#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
#error This change adds a dependency on all pytorch operators, meaning the \
file will need to be re-compiled every time an operator is changed or added. \
Consider including a specific operator from \
<ATen/ops/{my_operator}_compositeexplicitautograd_dispatch.h>. \
See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
#endif
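// An illustrative (hypothetical) translation unit following the advice above:
// it defines TORCH_ASSERT_ONLY_METHOD_OPERATORS and pulls in only the operator
// header it actually needs instead of this aggregate header.
//
//   #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
//   #include <ATen/core/Tensor.h>
//   #include <ATen/ops/add_compositeexplicitautograd_dispatch.h>  // just the op(s) used
//
//   // ...call only the at::compositeexplicitautograd:: overloads declared there...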
#include <ATen/ops/_adaptive_avg_pool2d_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_adaptive_avg_pool2d_backward_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_adaptive_avg_pool3d_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_adaptive_avg_pool3d_backward_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_add_relu_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_aminmax_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_amp_foreach_non_finite_check_and_unscale_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_amp_update_scale_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_assert_scalar_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_batch_norm_no_update_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_batch_norm_with_update_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_cdist_backward_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_cdist_forward_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_cholesky_solve_helper_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_chunk_cat_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_coalesce_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_coalesced_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_conj_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_conj_copy_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_conj_physical_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_convolution_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_copy_from_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_copy_from_and_resize_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_ctc_loss_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_ctc_loss_backward_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_cudnn_ctc_loss_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_cudnn_init_dropout_state_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_cudnn_rnn_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_cudnn_rnn_backward_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_cudnn_rnn_flatten_weight_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_dirichlet_grad_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_efficientzerotensor_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_embedding_bag_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_embedding_bag_dense_backward_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_embedding_bag_forward_only_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_embedding_bag_per_sample_weights_backward_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_empty_affine_quantized_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_empty_per_channel_affine_quantized_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_euclidean_dist_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_fake_quantize_learnable_per_channel_affine_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_fake_quantize_learnable_per_tensor_affine_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_fake_quantize_per_tensor_affine_cachemask_tensor_qparams_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_foobar_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_foreach_abs_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_foreach_acos_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_foreach_add_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_foreach_addcdiv_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_foreach_addcmul_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_foreach_asin_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_foreach_atan_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_foreach_ceil_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_foreach_clamp_max_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_foreach_clamp_min_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_foreach_copy_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_foreach_cos_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_foreach_cosh_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_foreach_div_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_foreach_erf_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_foreach_erfc_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_foreach_exp_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_foreach_expm1_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_foreach_floor_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_foreach_frac_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_foreach_lerp_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_foreach_lgamma_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_foreach_log_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_foreach_log10_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_foreach_log1p_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_foreach_log2_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_foreach_max_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_foreach_maximum_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_foreach_minimum_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_foreach_mul_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_foreach_neg_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_foreach_norm_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_foreach_pow_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_foreach_reciprocal_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_foreach_round_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_foreach_sigmoid_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_foreach_sign_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_foreach_sin_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_foreach_sinh_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_foreach_sqrt_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_foreach_sub_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_foreach_tan_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_foreach_tanh_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_foreach_trunc_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_foreach_zero_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_functional_assert_scalar_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_functional_sym_constrain_range_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_functional_sym_constrain_range_for_size_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_fused_adagrad_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_fused_adam_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_fused_adamw_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_fused_dropout_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_fused_moving_avg_obs_fq_helper_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_fused_sgd_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_fw_primal_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_fw_primal_copy_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_grid_sampler_2d_cpu_fallback_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_has_same_storage_numel_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_histogramdd_bin_edges_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_histogramdd_from_bin_cts_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_histogramdd_from_bin_tensors_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_index_put_impl_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_indices_copy_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_is_all_true_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_is_any_true_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_lazy_clone_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_linalg_check_errors_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_lstm_mps_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_make_dual_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_make_dual_copy_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_make_per_channel_quantized_tensor_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_make_per_tensor_quantized_tensor_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_masked_scale_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_masked_softmax_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_masked_softmax_backward_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_mkldnn_reshape_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_mkldnn_transpose_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_mps_convolution_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_mps_convolution_transpose_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_native_batch_norm_legit_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_native_batch_norm_legit_no_training_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_native_multi_head_attention_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_neg_view_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_neg_view_copy_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_nested_from_padded_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_nested_from_padded_and_nested_example_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_nested_get_values_copy_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_nested_tensor_from_mask_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_nested_tensor_from_tensor_list_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_nested_tensor_size_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_nested_tensor_storage_offsets_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_nested_tensor_strides_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_nested_view_from_buffer_copy_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_nested_view_from_jagged_copy_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_new_zeros_with_same_feature_meta_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_nnpack_spatial_convolution_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_pack_padded_sequence_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_pdist_backward_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_pdist_forward_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_pin_memory_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_print_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_reshape_alias_copy_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_reshape_copy_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_resize_output_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_safe_softmax_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_sample_dirichlet_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_scaled_dot_product_fused_attention_overrideable_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_scaled_dot_product_fused_attention_overrideable_backward_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_segment_reduce_backward_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_slow_conv2d_backward_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_sparse_addmm_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_sparse_broadcast_to_copy_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_sparse_compressed_tensor_with_dims_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_sparse_coo_tensor_with_dims_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_sparse_coo_tensor_with_dims_and_tensors_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_sparse_csr_prod_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_sparse_csr_sum_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_sparse_log_softmax_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_sparse_log_softmax_backward_data_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_sparse_mask_projection_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_sparse_softmax_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_sparse_softmax_backward_data_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_sparse_sparse_matmul_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_sparse_sum_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_sparse_sum_backward_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_spdiags_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_stack_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_standard_gamma_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_standard_gamma_grad_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_test_autograd_multiple_dispatch_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_test_autograd_multiple_dispatch_view_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_test_autograd_multiple_dispatch_view_copy_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_test_functorch_fallback_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_test_optional_filled_intlist_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_test_optional_floatlist_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_test_optional_intlist_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_test_parallel_materialize_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_test_warn_in_autograd_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_thnn_fused_gru_cell_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_thnn_fused_gru_cell_backward_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_thnn_fused_lstm_cell_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_thnn_fused_lstm_cell_backward_impl_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_to_copy_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_to_dense_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_to_sparse_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_to_sparse_bsc_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_to_sparse_bsr_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_to_sparse_csc_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_to_sparse_csr_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_transform_bias_rescale_qkv_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_transformer_encoder_layer_fwd_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_trilinear_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_triton_multi_head_attention_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_triton_scaled_dot_attention_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_unique_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_unique2_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_unsafe_index_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_unsafe_index_put_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_unsafe_masked_index_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_unsafe_masked_index_put_accumulate_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_unsafe_view_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_values_copy_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_weight_norm_interface_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/_weight_norm_interface_backward_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/abs_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/add_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/addr_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/affine_grid_generator_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/alias_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/alias_copy_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/all_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/allclose_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/any_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/arange_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/as_strided_copy_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/as_strided_scatter_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/bartlett_window_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/batch_norm_backward_elemt_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/batch_norm_backward_reduce_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/batch_norm_gather_stats_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/batch_norm_gather_stats_with_counts_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/batch_norm_stats_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/batch_norm_update_stats_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/bernoulli_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/binary_cross_entropy_with_logits_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/bincount_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/binomial_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/bitwise_and_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/bitwise_left_shift_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/bitwise_or_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/bitwise_right_shift_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/bitwise_xor_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/blackman_window_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/block_diag_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/bucketize_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/cauchy_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/ccol_indices_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/ccol_indices_copy_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/celu_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/channel_shuffle_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/cholesky_solve_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/clone_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/col_indices_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/col_indices_copy_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/complex_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/conj_physical_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/constant_pad_nd_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/conv_depthwise3d_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/conv_tbc_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/convolution_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/convolution_backward_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/convolution_backward_overrideable_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/convolution_overrideable_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/copy_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/copy_sparse_to_sparse_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/copysign_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/count_nonzero_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/crow_indices_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/crow_indices_copy_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/cudnn_affine_grid_generator_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/cudnn_affine_grid_generator_backward_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/cudnn_batch_norm_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/cudnn_batch_norm_backward_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/cudnn_convolution_add_relu_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/cudnn_convolution_relu_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/cudnn_convolution_transpose_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/cudnn_grid_sampler_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/cudnn_grid_sampler_backward_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/cummax_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/cummin_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/deg2rad_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/dense_dim_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/dequantize_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/detach_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/detach_copy_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/diag_embed_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/diagonal_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/diagonal_backward_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/diagonal_copy_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/diagonal_scatter_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/dist_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/div_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/dot_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/embedding_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/embedding_dense_backward_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/embedding_renorm_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/empty_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/empty_like_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/empty_permuted_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/empty_quantized_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/empty_strided_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/expand_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/expand_copy_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/exponential_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/eye_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/fake_quantize_per_channel_affine_cachemask_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/fake_quantize_per_tensor_affine_cachemask_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/fft_fftfreq_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/fft_rfftfreq_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/fill_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/flip_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/floor_divide_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/fmod_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/frexp_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/from_file_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/full_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/full_like_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/geometric_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/glu_backward_jvp_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/glu_jvp_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/grid_sampler_2d_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/grid_sampler_2d_backward_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/grid_sampler_3d_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/grid_sampler_3d_backward_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/hamming_window_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/hann_window_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/hardswish_backward_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/huber_loss_backward_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/index_fill_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/index_put_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/indices_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/indices_copy_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/int_repr_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/is_coalesced_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/is_pinned_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/is_same_size_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/isinf_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/isnan_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/kaiser_window_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/kthvalue_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/lift_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/lift_fresh_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/lift_fresh_copy_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/linalg_lstsq_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/linalg_matrix_exp_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/linalg_pinv_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/linear_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/linear_backward_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/linspace_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/log_normal_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/log_softmax_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/logcumsumexp_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/logical_and_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/logical_not_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/logical_or_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/logical_xor_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/logspace_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/logsumexp_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/lshift_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/lstm_mps_backward_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/masked_fill_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/masked_scatter_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/masked_scatter_backward_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/matmul_backward_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/max_pool2d_backward_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/mean_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/median_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/miopen_batch_norm_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/miopen_batch_norm_backward_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/miopen_convolution_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/miopen_convolution_transpose_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/miopen_depthwise_convolution_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/miopen_rnn_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/miopen_rnn_backward_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/mkldnn_adaptive_avg_pool2d_backward_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/mkldnn_convolution_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/mkldnn_linear_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/mkldnn_linear_backward_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/mkldnn_linear_backward_input_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/mkldnn_linear_backward_weights_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/mkldnn_max_pool2d_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/mkldnn_max_pool2d_backward_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/mkldnn_max_pool3d_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/mkldnn_max_pool3d_backward_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/mkldnn_reorder_conv2d_weight_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/mkldnn_reorder_conv3d_weight_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/mkldnn_rnn_layer_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/mkldnn_rnn_layer_backward_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/mode_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/mps_convolution_backward_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/mps_convolution_transpose_backward_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/mul_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/mv_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/mvlgamma_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/nan_to_num_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/nanmedian_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/native_batch_norm_backward_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/native_dropout_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/native_dropout_backward_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/native_group_norm_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/native_group_norm_backward_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/native_layer_norm_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/native_layer_norm_backward_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/native_norm_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/new_empty_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/new_empty_strided_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/new_full_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/new_ones_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/new_zeros_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/norm_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/normal_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/ones_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/ones_like_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/permute_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/permute_copy_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/pixel_shuffle_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/pixel_unshuffle_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/poisson_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/polar_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/polygamma_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/prod_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/put_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/q_per_channel_scales_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/q_per_channel_zero_points_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/quantize_per_channel_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/quantize_per_tensor_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/quantize_per_tensor_dynamic_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/quantized_batch_norm_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/quantized_max_pool1d_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/quantized_max_pool2d_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/quantized_max_pool3d_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/rad2deg_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/rand_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/rand_like_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/randint_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/randint_like_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/randn_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/randn_like_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/random_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/randperm_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/range_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/relu_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/remainder_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/repeat_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/repeat_interleave_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/resize_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/resize_as_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/resize_as_sparse_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/roll_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/rot90_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/row_indices_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/row_indices_copy_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/rrelu_with_noise_backward_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/rshift_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/rsub_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/scalar_tensor_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/segment_reduce_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/select_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/select_backward_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/select_copy_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/select_scatter_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/set_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/slice_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/slice_backward_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/slice_copy_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/slice_inverse_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/slice_scatter_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/slow_conv_dilated2d_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/slow_conv_dilated3d_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/smooth_l1_loss_backward_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/soft_margin_loss_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/soft_margin_loss_backward_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/softmax_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/sort_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/sparse_compressed_tensor_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/sparse_coo_tensor_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/sparse_dim_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/sparse_mask_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/sparse_resize_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/sparse_resize_and_clear_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/special_chebyshev_polynomial_t_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/special_chebyshev_polynomial_u_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/special_chebyshev_polynomial_v_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/special_chebyshev_polynomial_w_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/special_hermite_polynomial_h_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/special_hermite_polynomial_he_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/special_laguerre_polynomial_l_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/special_legendre_polynomial_p_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/special_shifted_chebyshev_polynomial_t_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/special_shifted_chebyshev_polynomial_u_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/special_shifted_chebyshev_polynomial_v_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/special_shifted_chebyshev_polynomial_w_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/special_xlog1py_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/special_zeta_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/split_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/split_copy_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/split_with_sizes_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/split_with_sizes_copy_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/squeeze_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/squeeze_copy_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/stack_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/std_mean_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/sub_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/sum_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/sym_constrain_range_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/sym_constrain_range_for_size_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/t_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/t_copy_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/to_mkldnn_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/to_padded_tensor_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/trace_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/transpose_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/transpose_copy_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/tril_indices_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/triu_indices_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/unbind_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/unbind_copy_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/unfold_backward_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/unfold_copy_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/uniform_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/unique_consecutive_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/unique_dim_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/unique_dim_consecutive_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/unsafe_split_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/unsafe_split_with_sizes_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/unsqueeze_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/unsqueeze_copy_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/values_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/values_copy_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/var_mean_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/vdot_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/view_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/view_as_complex_copy_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/view_as_real_copy_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/view_copy_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/xlogy_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/zero_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/zeros_compositeexplicitautograd_dispatch.h>
#include <ATen/ops/zeros_like_compositeexplicitautograd_dispatch.h>

View File

@ -0,0 +1,29 @@
#include <ATen/core/TensorBody.h>
// TODO Undo all logic introduced for Note [Avoiding Include Cycles In Static Dispatch]
// Code introduced to avoid cyclic dependency in static dispatch is no longer
// needed as static dispatch logic is moved from TensorBody.h, which caused cycles in the first place,
// to Operators.cpp for supporting multiple backends with multiple kernels.
//
// Note [Avoiding Include Cycles In Static Dispatch]
// In order to avoid #include cycles in the static dispatch build, we've carefully split out
// the static function definition files into {DispatchKey}Functions.h and {DispatchKey}Functions_inl.h.
//
// Without this split, the include cycle looks like TensorBody.h -> CPUFunctions.h -> TensorBody.h.
// - TensorBody.h #includes CPUFunctions.h in the static dispatch build, because the tensor methods
// all need to call into the fastpath C++ API defined in CPUFunctions.h. The methods are also all
// directly inlined into TensorBody.h.
// - CPUFunctions.h #includes TensorBody.h because it contains function declarations for the entire C++ API,
// which include functions that have defaultable std::optional<Tensor> arguments.
// That requires knowing the full Tensor class definition.
//
// We break the cycle by doing the following:
// - Split out CPUFunctions.h into two files: CPUFunctions.h and CPUFunctions_inl.h
// - CPUFunctions.h is a dummy file that just includes the Tensor class and includes CPUFunctions_inl.h,
// - CPUFunctions_inl.h includes everything else
// - (only in the static dispatch build) TensorBody.h makes sure to finish defining the Tensor class,
// and then it includes CPUFunctions_inl.h.
// - All other files that want the cpu fastpath functions can include CPUFunctions.h directly.
// - This also means that in the static dispatch build, CPUFunctions.h only needs to
// #include TensorBody.h, and it will automatically bring in CPUFunctions_inl.h.
#include <ATen/CompositeExplicitAutogradNonFunctionalFunctions_inl.h>
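// A minimal sketch of the {DispatchKey}Functions.h / {DispatchKey}Functions_inl.h split the
// note above describes, using hypothetical names (FooFunctions.h, FooFunctions_inl.h, foo_add);
// it is only an illustration of the pattern, not one of the generated ATen headers.
//
//   // FooFunctions.h -- the header everything else includes directly.
//   #pragma once
//   #include <ATen/core/TensorBody.h>  // finish defining at::Tensor first
//   #include <FooFunctions_inl.h>      // then pull in the fastpath declarations
//
//   // FooFunctions_inl.h -- declarations only; it may assume at::Tensor is already complete.
//   #pragma once
//   namespace at { namespace foo {
//   at::Tensor foo_add(const at::Tensor& self, const at::Tensor& other);
//   }} // namespace at::foo
//
//   // In the static dispatch build, TensorBody.h finishes defining the Tensor class and then
//   // includes FooFunctions_inl.h itself, so inlined tensor methods can call foo_add without
//   // re-including FooFunctions.h and recreating the TensorBody.h -> FooFunctions.h cycle.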

View File

@ -0,0 +1,323 @@
#pragma once
// @generated by torchgen/gen.py from DispatchKeyFunctions_inl.h
// NB: The implementing C++ file is RegisterDispatchKey.cpp
// The only #includes we need are for custom classes that have defaults in the C++ API
#include <c10/core/MemoryFormat.h>
#include <c10/core/Scalar.h>
#include <ATen/core/Reduction.h>
#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
#error This change adds a dependency on all pytorch operators, meaning the \
file will need to be re-compiled every time an operator is changed or added. \
Consider including a specific operator from \
<ATen/ops/{my_operator}_compositeexplicitautogradnonfunctional_dispatch.h>. \
See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
#endif
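// For example (illustrative usage), a file that only needs the declarations for add under this
// dispatch key can include the single per-operator header below instead of this umbrella header:
//
//   #include <ATen/ops/add_compositeexplicitautogradnonfunctional_dispatch.h>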
#include <ATen/ops/_addmm_activation_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/_conj_copy_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/_convert_indices_from_coo_to_csr_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/_convert_indices_from_csr_to_coo_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/_fw_primal_copy_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/_indices_copy_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/_linalg_det_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/_linalg_eigh_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/_linalg_slogdet_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/_linalg_solve_ex_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/_linalg_svd_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/_log_softmax_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/_log_softmax_backward_data_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/_make_dual_copy_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/_neg_view_copy_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/_nested_get_values_copy_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/_nested_view_from_buffer_copy_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/_nested_view_from_jagged_copy_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/_reshape_alias_copy_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/_softmax_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/_softmax_backward_data_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/_sparse_broadcast_to_copy_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/_test_autograd_multiple_dispatch_view_copy_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/_trilinear_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/_upsample_bicubic2d_aa_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/_upsample_bicubic2d_aa_backward_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/_upsample_bilinear2d_aa_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/_upsample_bilinear2d_aa_backward_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/_upsample_nearest_exact1d_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/_upsample_nearest_exact1d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/_upsample_nearest_exact2d_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/_upsample_nearest_exact2d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/_upsample_nearest_exact3d_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/_upsample_nearest_exact3d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/_values_copy_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/acos_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/acosh_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/adaptive_max_pool2d_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/adaptive_max_pool2d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/adaptive_max_pool3d_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/adaptive_max_pool3d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/add_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/addcdiv_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/addcmul_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/addmm_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/addmv_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/alias_copy_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/all_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/amax_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/amin_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/aminmax_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/any_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/argmax_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/argmin_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/as_strided_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/as_strided_copy_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/as_strided_scatter_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/asin_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/asinh_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/atan_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/atan2_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/atanh_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/avg_pool2d_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/avg_pool2d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/avg_pool3d_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/avg_pool3d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/baddbmm_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/bernoulli_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/bitwise_and_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/bitwise_left_shift_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/bitwise_not_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/bitwise_or_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/bitwise_right_shift_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/bitwise_xor_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/bmm_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/cat_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/ccol_indices_copy_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/ceil_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/clamp_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/clamp_max_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/clamp_min_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/col_indices_copy_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/copy_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/copysign_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/cos_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/cosh_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/crow_indices_copy_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/cumprod_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/cumsum_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/detach_copy_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/diag_embed_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/diagonal_copy_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/diagonal_scatter_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/digamma_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/div_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/elu_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/elu_backward_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/eq_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/erf_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/erfc_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/erfinv_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/exp_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/exp2_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/expand_copy_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/expm1_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/floor_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/fmax_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/fmin_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/fmod_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/frac_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/fractional_max_pool2d_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/fractional_max_pool2d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/fractional_max_pool3d_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/gather_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/gcd_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/ge_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/gelu_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/gelu_backward_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/glu_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/gt_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/hardshrink_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/hardshrink_backward_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/hardsigmoid_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/hardsigmoid_backward_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/heaviside_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/hypot_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/i0_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/igamma_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/igammac_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/index_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/index_add_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/index_copy_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/index_reduce_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/indices_copy_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/isin_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/isneginf_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/isposinf_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/lcm_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/le_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/leaky_relu_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/leaky_relu_backward_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/lerp_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/lgamma_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/lift_fresh_copy_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/linalg_cholesky_ex_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/linalg_cross_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/linalg_inv_ex_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/linalg_ldl_factor_ex_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/linalg_ldl_solve_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/linalg_lu_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/linalg_lu_factor_ex_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/linalg_lu_solve_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/linalg_pinv_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/linalg_qr_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/linalg_vector_norm_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/log_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/log10_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/log1p_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/log2_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/logaddexp_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/logaddexp2_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/logit_backward_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/logsumexp_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/lt_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/lu_unpack_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/max_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/max_pool2d_with_indices_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/max_pool2d_with_indices_backward_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/maximum_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/mean_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/min_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/minimum_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/mish_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/mm_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/mse_loss_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/mul_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/narrow_copy_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/ne_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/neg_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/new_empty_strided_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/nextafter_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/nll_loss_backward_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/nll_loss_forward_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/norm_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/permute_copy_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/pixel_shuffle_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/pixel_unshuffle_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/polygamma_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/pow_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/prod_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/reciprocal_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/reflection_pad1d_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/reflection_pad1d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/reflection_pad3d_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/reflection_pad3d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/remainder_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/renorm_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/replication_pad1d_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/replication_pad1d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/replication_pad2d_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/replication_pad3d_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/round_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/row_indices_copy_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/rsqrt_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/scatter_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/scatter_add_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/scatter_reduce_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/select_backward_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/select_copy_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/select_scatter_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/sgn_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/sigmoid_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/sigmoid_backward_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/sign_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/signbit_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/silu_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/silu_backward_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/sin_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/sinc_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/sinh_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/slice_copy_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/slice_scatter_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/slow_conv_transpose2d_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/smooth_l1_loss_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/softplus_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/softplus_backward_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/softshrink_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/softshrink_backward_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/sort_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/special_airy_ai_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/special_bessel_j0_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/special_bessel_j1_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/special_bessel_y0_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/special_bessel_y1_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/special_chebyshev_polynomial_t_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/special_chebyshev_polynomial_u_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/special_chebyshev_polynomial_v_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/special_chebyshev_polynomial_w_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/special_entr_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/special_erfcx_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/special_hermite_polynomial_h_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/special_hermite_polynomial_he_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/special_i0e_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/special_i1_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/special_i1e_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/special_laguerre_polynomial_l_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/special_legendre_polynomial_p_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/special_log_ndtr_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/special_modified_bessel_i0_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/special_modified_bessel_i1_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/special_modified_bessel_k0_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/special_modified_bessel_k1_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/special_ndtri_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/special_scaled_modified_bessel_k0_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/special_scaled_modified_bessel_k1_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/special_shifted_chebyshev_polynomial_t_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/special_shifted_chebyshev_polynomial_u_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/special_shifted_chebyshev_polynomial_v_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/special_shifted_chebyshev_polynomial_w_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/special_spherical_bessel_j0_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/special_xlog1py_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/special_zeta_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/split_copy_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/split_with_sizes_copy_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/sqrt_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/squeeze_copy_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/sub_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/sum_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/t_copy_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/tan_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/tanh_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/tanh_backward_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/threshold_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/threshold_backward_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/topk_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/transpose_copy_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/triangular_solve_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/tril_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/triu_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/trunc_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/unbind_copy_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/unfold_copy_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/unsqueeze_copy_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/upsample_bicubic2d_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/upsample_bicubic2d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/upsample_bilinear2d_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/upsample_bilinear2d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/upsample_linear1d_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/upsample_linear1d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/upsample_nearest1d_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/upsample_nearest1d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/upsample_nearest2d_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/upsample_nearest2d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/upsample_nearest3d_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/upsample_nearest3d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/upsample_trilinear3d_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/upsample_trilinear3d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/values_copy_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/view_as_complex_copy_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/view_as_real_copy_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/view_copy_compositeexplicitautogradnonfunctional_dispatch.h>
#include <ATen/ops/xlogy_compositeexplicitautogradnonfunctional_dispatch.h>

View File

@ -0,0 +1,29 @@
#include <ATen/core/TensorBody.h>
// TODO Undo all logic introduced for Note [Avoiding Include Cycles In Static Dispatch]
// Code introduced to avoid cyclic dependency in static dispatch is no longer
// needed as static dispatch logic is moved from TensorBody.h, which caused cycles in the first place,
// to Operators.cpp for supporting multiple backends with multiple kernels.
//
// Note [Avoiding Include Cycles In Static Dispatch]
// In order to avoid #include cycles in the static dispatch build, we've carefully split out
// the static function definition files into {DispatchKey}Functions.h and {DispatchKey}Functions_inl.h.
//
// Without this split, the include cycle looks like TensorBody.h -> CPUFunctions.h -> TensorBody.h.
// - TensorBody.h #includes CPUFunctions.h in the static dispatch build, because the tensor methods
// all need to call into the fastpath C++ API defined in CPUFunctions.h. The methods are also all
// directly inlined into TensorBody.h.
// - CPUFunctions.h #includes TensorBody.h because it contains function declarations for the entire C++ API,
// which include functions that have defaultable std::optional<Tensor> arguments.
// That requires knowing the full Tensor class definition.
//
// We break the cycle by doing the following:
// - Split out CPUFunctions.h into two files: CPUFunctions.h and CPUFunctions_inl.h
// - CPUFunctions.h is a dummy file that just includes the Tensor class and includes CPUFunctions_inl.h,
// - CPUFunctions_inl.h includes everything else
// - (only in the static dispatch build) TensorBody.h makes sure to finish defining the Tensor class,
// and then it includes CPUFunctions_inl.h.
// - All other files that want the cpu fastpath functions can include CPUFunctions.h directly.
// - This also means that in the static dispatch build, CPUFunctions.h only needs to
// #include TensorBody.h, and it will automatically bring in CPUFunctions_inl.h.
#include <ATen/CompositeImplicitAutogradFunctions_inl.h>

View File

@ -0,0 +1,502 @@
#pragma once
// @generated by torchgen/gen.py from DispatchKeyFunctions_inl.h
// NB: The implementing C++ file is RegisterDispatchKey.cpp
// The only #includes we need are for custom classes that have defaults in the C++ API
#include <c10/core/MemoryFormat.h>
#include <c10/core/Scalar.h>
#include <ATen/core/Reduction.h>
#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
#error This change adds a dependency on all pytorch operators, meaning the \
file will need to be re-compiled every time an operator is changed or added. \
Consider including a specific operator from \
<ATen/ops/{my_operator}_compositeimplicitautograd_dispatch.h>. \
See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
#endif
#include <ATen/ops/_add_batch_dim_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_assert_tensor_metadata_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_autocast_to_full_precision_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_autocast_to_reduced_precision_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_backward_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_batch_norm_impl_index_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_batch_norm_impl_index_backward_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_cast_Byte_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_cast_Char_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_cast_Double_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_cast_Float_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_cast_Half_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_cast_Int_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_cast_Long_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_cast_Short_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_choose_qparams_per_tensor_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_convolution_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_convolution_double_backward_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_convolution_mode_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_cufft_clear_plan_cache_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_cufft_get_plan_cache_max_size_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_cufft_get_plan_cache_size_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_cufft_set_plan_cache_max_size_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_debug_has_internal_overlap_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_dim_arange_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_embedding_bag_sparse_backward_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_gather_sparse_backward_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_grid_sampler_2d_cpu_fallback_backward_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_has_compatible_shallow_copy_type_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_is_zerotensor_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_lu_with_info_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_nnpack_available_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_pack_padded_sequence_backward_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_pad_circular_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_pad_enum_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_pad_packed_sequence_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_propagate_xla_data_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_remove_batch_dim_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_reshape_from_tensor_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_rowwise_prune_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_saturate_weight_to_fp16_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_scaled_dot_product_attention_math_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_shape_as_tensor_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_sobol_engine_draw_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_sobol_engine_ff_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_sobol_engine_initialize_state_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_sobol_engine_scramble_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_sparse_bsc_tensor_unsafe_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_sparse_bsr_tensor_unsafe_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_sparse_compressed_tensor_unsafe_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_sparse_coo_tensor_unsafe_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_sparse_csc_tensor_unsafe_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_sparse_csr_tensor_unsafe_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_sparse_log_softmax_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_sparse_mm_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_sparse_softmax_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_sparse_sum_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_test_ambiguous_defaults_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_test_autograd_multiple_dispatch_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_test_check_tensor_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_test_serialization_subcmul_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_test_string_default_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_thnn_differentiable_gru_cell_backward_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_thnn_differentiable_lstm_cell_backward_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_thnn_fused_lstm_cell_backward_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_to_cpu_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_unpack_dual_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_upsample_bicubic2d_aa_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_upsample_bilinear2d_aa_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_upsample_nearest_exact1d_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_upsample_nearest_exact2d_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_upsample_nearest_exact3d_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_use_cudnn_rnn_flatten_weight_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_validate_sparse_bsc_tensor_args_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_validate_sparse_bsr_tensor_args_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_validate_sparse_compressed_tensor_args_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_validate_sparse_coo_tensor_args_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_validate_sparse_csc_tensor_args_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_validate_sparse_csr_tensor_args_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_version_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_weight_norm_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_weight_norm_differentiable_backward_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_wrapped_linear_prepack_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/_wrapped_quantized_linear_prepacked_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/absolute_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/adaptive_avg_pool1d_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/adaptive_avg_pool2d_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/adaptive_avg_pool3d_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/adaptive_max_pool1d_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/adjoint_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/affine_grid_generator_backward_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/align_as_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/align_tensors_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/align_to_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/all_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/alpha_dropout_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/and_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/any_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/arccos_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/arccosh_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/arcsin_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/arcsinh_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/arctan_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/arctan2_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/arctanh_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/argsort_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/argwhere_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/atleast_1d_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/atleast_2d_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/atleast_3d_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/avg_pool1d_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/batch_norm_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/bilinear_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/broadcast_tensors_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/broadcast_to_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/can_cast_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/cartesian_prod_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/cat_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/cdist_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/chain_matmul_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/chalf_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/choose_qparams_optimized_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/chunk_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/clip_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/coalesce_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/column_stack_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/combinations_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/concat_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/concatenate_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/conj_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/conj_physical_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/contiguous_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/conv1d_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/conv2d_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/conv3d_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/conv_tbc_backward_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/conv_transpose1d_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/conv_transpose2d_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/conv_transpose3d_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/corrcoef_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/cosine_embedding_loss_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/cosine_similarity_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/cov_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/cross_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/cross_entropy_loss_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/ctc_loss_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/cudnn_is_acceptable_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/cummax_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/cummaxmin_backward_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/cummin_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/cumprod_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/cumprod_backward_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/cumsum_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/cumulative_trapezoid_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/data_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/det_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/diag_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/diagflat_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/diagonal_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/diff_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/divide_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/dropout_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/dsplit_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/dstack_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/einsum_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/embedding_backward_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/embedding_bag_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/embedding_sparse_backward_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/empty_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/expand_as_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/fake_quantize_per_channel_affine_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/fake_quantize_per_channel_affine_cachemask_backward_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/fake_quantize_per_tensor_affine_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/fake_quantize_per_tensor_affine_cachemask_backward_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/fbgemm_linear_fp16_weight_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/fbgemm_linear_fp16_weight_fp32_activation_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/fbgemm_linear_int8_weight_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/fbgemm_linear_int8_weight_fp32_activation_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/fbgemm_linear_quantize_weight_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/fbgemm_pack_gemm_matrix_fp16_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/fbgemm_pack_quantized_matrix_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/feature_alpha_dropout_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/feature_dropout_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/fft_fft_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/fft_fft2_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/fft_fftn_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/fft_fftshift_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/fft_hfft_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/fft_hfft2_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/fft_hfftn_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/fft_ifft_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/fft_ifft2_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/fft_ifftn_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/fft_ifftshift_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/fft_ihfft_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/fft_ihfft2_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/fft_ihfftn_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/fft_irfft_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/fft_irfft2_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/fft_irfftn_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/fft_rfft_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/fft_rfft2_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/fft_rfftn_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/fill_diagonal_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/fix_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/flatten_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/flatten_dense_tensors_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/fliplr_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/flipud_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/float_power_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/frobenius_norm_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/fused_moving_avg_obs_fake_quant_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/gather_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/gather_backward_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/ger_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/gradient_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/greater_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/greater_equal_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/grid_sampler_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/group_norm_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/gru_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/gru_cell_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/hinge_embedding_loss_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/histogramdd_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/hsplit_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/hstack_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/imag_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/index_add_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/index_copy_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/index_fill_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/index_select_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/index_select_backward_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/infinitely_differentiable_gelu_backward_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/inner_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/instance_norm_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/inverse_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/is_complex_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/is_conj_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/is_distributed_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/is_floating_point_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/is_inference_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/is_leaf_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/is_neg_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/is_nonzero_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/is_signed_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/is_vulkan_available_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/isclose_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/isfinite_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/isreal_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/istft_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/item_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/kl_div_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/kron_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/kthvalue_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/l1_loss_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/layer_norm_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/ldexp_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/less_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/less_equal_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/linalg_cholesky_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/linalg_cond_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/linalg_det_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/linalg_diagonal_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/linalg_eigh_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/linalg_eigvals_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/linalg_eigvalsh_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/linalg_inv_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/linalg_ldl_factor_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/linalg_lu_factor_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/linalg_matmul_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/linalg_matrix_norm_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/linalg_matrix_power_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/linalg_matrix_rank_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/linalg_multi_dot_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/linalg_norm_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/linalg_pinv_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/linalg_slogdet_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/linalg_solve_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/linalg_solve_ex_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/linalg_svd_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/linalg_svdvals_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/linalg_tensorinv_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/linalg_tensorsolve_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/linalg_vander_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/linalg_vecdot_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/linear_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/log_sigmoid_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/log_softmax_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/logcumsumexp_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/logdet_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/logsumexp_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/lstm_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/lstm_cell_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/lu_solve_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/mH_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/mT_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/margin_ranking_loss_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/masked_select_backward_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/matmul_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/matrix_H_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/matrix_exp_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/matrix_exp_backward_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/matrix_power_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/max_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/max_pool1d_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/max_pool1d_with_indices_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/max_pool2d_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/max_pool3d_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/mean_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/median_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/meshgrid_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/min_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/mish_backward_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/mode_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/moveaxis_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/movedim_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/msort_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/multilabel_margin_loss_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/multiply_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/nanmean_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/nanmedian_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/nanquantile_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/narrow_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/native_channel_shuffle_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/negative_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/nested_to_padded_tensor_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/nll_loss_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/nll_loss2d_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/nll_loss_nd_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/nonzero_numpy_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/norm_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/norm_except_dim_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/not_equal_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/nuclear_norm_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/numpy_T_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/one_hot_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/or_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/orgqr_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/outer_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/output_nr_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/pad_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/pad_sequence_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/pairwise_distance_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/pdist_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/pin_memory_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/pinverse_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/poisson_nll_loss_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/positive_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/prelu_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/prod_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/promote_types_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/qr_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/quantile_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/quantized_gru_cell_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/quantized_lstm_cell_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/quantized_rnn_relu_cell_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/quantized_rnn_tanh_cell_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/rand_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/randn_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/ravel_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/real_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/refine_names_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/relu6_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/rename_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/repeat_interleave_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/requires_grad_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/reshape_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/reshape_as_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/resolve_conj_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/resolve_neg_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/result_type_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/retain_grad_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/retains_grad_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/rms_norm_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/rnn_relu_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/rnn_relu_cell_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/rnn_tanh_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/rnn_tanh_cell_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/row_stack_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/rrelu_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/scaled_dot_product_attention_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/scatter_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/scatter_add_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/select_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/selu_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/set_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/set_data_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/silu_backward_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/size_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/slogdet_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/slow_conv3d_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/smm_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/softmax_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/sort_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/sparse_bsc_tensor_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/sparse_bsr_tensor_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/sparse_coo_tensor_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/sparse_csc_tensor_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/sparse_csr_tensor_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/special_digamma_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/special_erf_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/special_erfc_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/special_erfinv_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/special_exp2_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/special_expit_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/special_expm1_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/special_gammainc_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/special_gammaincc_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/special_gammaln_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/special_i0_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/special_log1p_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/special_log_softmax_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/special_logit_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/special_logsumexp_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/special_multigammaln_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/special_ndtr_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/special_polygamma_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/special_psi_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/special_round_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/special_sinc_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/special_softmax_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/special_xlogy_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/split_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/square_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/squeeze_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/sspaddmm_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/std_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/std_mean_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/stft_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/stride_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/subtract_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/sum_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/sum_to_size_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/svd_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/swapaxes_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/swapdims_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/sym_numel_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/sym_size_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/sym_storage_offset_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/sym_stride_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/take_along_dim_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/tensor_split_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/tensordot_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/thnn_conv2d_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/tile_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/to_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/to_dense_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/to_dense_backward_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/to_mkldnn_backward_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/to_sparse_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/to_sparse_bsc_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/to_sparse_bsr_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/to_sparse_csc_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/to_sparse_csr_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/trace_backward_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/transpose_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/trapezoid_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/trapz_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/triplet_margin_loss_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/true_divide_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/type_as_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/unbind_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/unflatten_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/unflatten_dense_tensors_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/unsafe_chunk_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/upsample_bicubic2d_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/upsample_bilinear2d_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/upsample_linear1d_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/upsample_nearest1d_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/upsample_nearest2d_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/upsample_nearest3d_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/upsample_trilinear3d_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/value_selecting_reduction_backward_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/vander_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/var_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/var_mean_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/view_as_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/vsplit_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/vstack_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/where_compositeimplicitautograd_dispatch.h>
#include <ATen/ops/xor_compositeimplicitautograd_dispatch.h>

View File

@ -0,0 +1,29 @@
#include <ATen/core/TensorBody.h>
// TODO Undo all logic introduced for Note [Avoiding Include Cycles In Static Dispatch]
// The code introduced to avoid a cyclic dependency in static dispatch is no longer
// needed: the static dispatch logic has been moved out of TensorBody.h (which caused
// the cycles in the first place) and into Operators.cpp, to support multiple backends
// with multiple kernels.
//
// Note [Avoiding Include Cycles In Static Dispatch]
// In order to avoid #include cycles in the static dispatch build, we've carefully split out
// the static function definition files into {DispatchKey}Functions.h and {DispatchKey}Functions_inl.h.
//
// Without this split, the include cycle looks like TensorBody.h -> CPUFunctions.h -> TensorBody.h.
// - TensorBody.h #includes CPUFunctions.h in the static dispatch build, because the tensor methods
// all need to call into the fastpath C++ API defined in CPUFunctions.h. The methods are also all
// directly inlined into TensorBody.h.
// - CPUFunctions.h #includes TensorBody.h because it contains function declarations for the entire C++ API,
// which include functions that have defaultable std::optional<Tensor> arguments.
// That requires knowing the full Tensor class definition.
//
// We break the cycle by doing the following:
// - Split out CPUFunctions.h into two files: CPUFunctions.h and CPUFunctions_inl.h
// - CPUFunctions.h is a dummy file that just includes the Tensor class and includes CPUFunctions_inl.h.
// - CPUFunctions_inl.h includes everything else
// - (only in the static dispatch build) TensorBody.h makes sure to finish defining the Tensor class,
// and then it includes CPUFunctions_inl.h.
// - All other files that want the cpu fastpath functions can include CPUFunctions.h directly.
// - This also means that in the static dispatch build, CPUFunctions.h only needs to
// #include TensorBody.h, and it will automatically bring in CPUFunctions_inl.h.
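//
// Illustrative sketch (not part of the original note) of the resulting include
// graph in the static dispatch build, using the CPU key as in the example above:
//
//   TensorBody.h      --> CPUFunctions_inl.h  (only after Tensor is fully defined)
//   CPUFunctions.h    --> TensorBody.h        (and thereby CPUFunctions_inl.h)
//   other consumers   --> CPUFunctions.h
//
// so the TensorBody.h -> CPUFunctions.h -> TensorBody.h cycle no longer occurs.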
#include <ATen/CompositeImplicitAutogradNestedTensorFunctions_inl.h>

View File

@ -0,0 +1,25 @@
#pragma once
// @generated by torchgen/gen.py from DispatchKeyFunctions_inl.h
// NB: The implementing C++ file is RegisterDispatchKey.cpp
// The only #includes we need are for custom classes that have defaults in the C++ API
#include <c10/core/MemoryFormat.h>
#include <c10/core/Scalar.h>
#include <ATen/core/Reduction.h>
#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
#error This change adds a dependency on all pytorch operators, meaning the \
file will need to be re-compiled every time an operator is changed or added. \
Consider including a specific operator from \
<ATen/ops/{my_operator}_compositeimplicitautogradnestedtensor_dispatch.h>. \
See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
#endif
#include <ATen/ops/randn_like_compositeimplicitautogradnestedtensor_dispatch.h>
#include <ATen/ops/reshape_compositeimplicitautogradnestedtensor_dispatch.h>
#include <ATen/ops/reshape_as_compositeimplicitautogradnestedtensor_dispatch.h>
#include <ATen/ops/zeros_like_compositeimplicitautogradnestedtensor_dispatch.h>

View File

@ -0,0 +1,21 @@
#pragma once
// Test these using #if AT_MKL_ENABLED(), not #ifdef, so that it's
// obvious if you forgot to include Config.h
// c.f. https://stackoverflow.com/questions/33759787/generating-an-error-if-checked-boolean-macro-is-not-defined
//
// DO NOT put the macros for CUDA libraries in this file; they belong in cuda/CUDAConfig.h
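//
// A minimal sketch of the intended pattern (illustrative only; the MKL-specific
// branch shown here is hypothetical and not part of this header):
//
//   #include <ATen/Config.h>
//
//   #if AT_MKL_ENABLED()
//     // MKL-backed implementation
//   #else
//     // portable fallback
//   #endif
//
// If Config.h is accidentally omitted, AT_MKL_ENABLED() is an undefined
// identifier, the #if condition expands to "0()", and the preprocessor emits an
// error, whereas a plain #ifdef would silently take the "disabled" branch.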
#define AT_MKLDNN_ENABLED() 1
#define AT_MKLDNN_ACL_ENABLED() 0
#define AT_MKL_ENABLED() 1
#define AT_MKL_SEQUENTIAL() 0
#define AT_POCKETFFT_ENABLED() 0
#define AT_NNPACK_ENABLED() 0
#define CAFFE2_STATIC_LINK_CUDA() 0
#define AT_BUILD_WITH_BLAS() 1
#define AT_BUILD_WITH_LAPACK() 1
#define AT_PARALLEL_OPENMP 1
#define AT_PARALLEL_NATIVE 0
#define AT_BLAS_F2C() 0
#define AT_BLAS_USE_CBLAS_DOT() 0

View File

@ -0,0 +1,610 @@
#pragma once
#include <ATen/BlasBackend.h>
#include <ATen/CPUGeneratorImpl.h>
#include <ATen/DeviceAccelerator.h>
#include <ATen/LinalgBackend.h>
#include <ATen/core/ATenGeneral.h>
#include <ATen/core/DeprecatedTypeProperties.h>
#include <ATen/core/Generator.h>
#include <ATen/core/LegacyTypeDispatch.h>
#include <ATen/detail/AcceleratorHooksInterface.h>
#include <ATen/detail/CUDAHooksInterface.h>
#include <ATen/detail/HIPHooksInterface.h>
#include <ATen/detail/IPUHooksInterface.h>
#include <ATen/detail/MAIAHooksInterface.h>
#include <ATen/detail/MPSHooksInterface.h>
#include <ATen/detail/MTIAHooksInterface.h>
#include <ATen/detail/PrivateUse1HooksInterface.h>
#include <ATen/detail/XPUHooksInterface.h>
#include <c10/core/QEngine.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/util/CallOnce.h>
#include <c10/util/Exception.h>
#include <c10/util/env.h>
#include <c10/util/irange.h>
#include <cstdint>
#include <mutex>
namespace at {
class Tensor;
enum class TORCH_API Float32MatmulPrecision { HIGHEST, HIGH, MEDIUM };
class TORCH_API Context {
public:
Context();
const Generator& defaultGenerator(Device device) {
c10::DeviceType device_type = device.type();
initCUDAIfNeeded(device_type);
initHIPIfNeeded(device_type);
if (device_type == at::kCPU) {
return at::detail::getDefaultCPUGenerator();
} else if (device_type == at::kCUDA) {
return at::detail::getCUDAHooks().getDefaultCUDAGenerator(device.index());
} else if (device_type == at::kMPS) {
return at::detail::getMPSHooks().getDefaultMPSGenerator();
} else if (device_type == at::kXPU) {
return at::detail::getXPUHooks().getDefaultXPUGenerator(device.index());
} else if (device_type == at::kIPU) {
return at::detail::getIPUHooks().getDefaultIPUGenerator(device.index());
} else if (device_type == at::kPrivateUse1) {
return at::detail::getPrivateUse1Hooks().getDefaultGenerator(
device.index());
} else {
AT_ERROR(c10::DeviceTypeName(device_type), " device type not enabled.");
}
}
const AcceleratorHooksInterface& getAcceleratorHooksInterface(
std::optional<c10::DeviceType> opt_device_type = std::nullopt) {
c10::DeviceType device_type = opt_device_type.has_value()
? opt_device_type.value()
: at::getAccelerator(true).value();
if (device_type == at::kCUDA) {
return at::detail::getCUDAHooks();
} else if (device_type == at::kXPU) {
return at::detail::getXPUHooks();
} else if (device_type == at::kMPS) {
return at::detail::getMPSHooks();
} else if (device_type == at::kPrivateUse1) {
return at::detail::getPrivateUse1Hooks();
} else if (device_type == at::kMTIA) {
return at::detail::getMTIAHooks();
} else if (device_type == at::kHIP) {
return at::detail::getHIPHooks();
} else {
AT_ERROR(
c10::DeviceTypeName(device_type), " device type not an accelerator.");
}
}
Device getDeviceFromPtr(void* data, c10::DeviceType device_type) {
initCUDAIfNeeded(device_type);
initHIPIfNeeded(device_type);
initXPUIfNeeded(device_type);
if (device_type == at::kCPU) {
return c10::DeviceType::CPU;
} else if (device_type == at::kCUDA) {
return at::detail::getCUDAHooks().getDeviceFromPtr(data);
} else if (device_type == at::kXPU) {
return at::detail::getXPUHooks().getDeviceFromPtr(data);
} else if (device_type == at::kPrivateUse1) {
return at::detail::getPrivateUse1Hooks().getDeviceFromPtr(data);
} else {
AT_ERROR(c10::DeviceTypeName(device_type), " device type not enabled.");
}
}
bool isPinnedPtr(
const void* data,
std::optional<c10::DeviceType> device_type = std::nullopt) {
auto opt_device_type =
device_type.has_value() ? device_type : at::getAccelerator();
if (!opt_device_type.has_value() || // there is no accelerator
!at::isAccelerator(
opt_device_type.value())) { // passed device not an accelerator
return false;
}
return getAcceleratorHooksInterface(opt_device_type.value())
.isPinnedPtr(data);
}
Allocator* getPinnedMemoryAllocator(
std::optional<c10::DeviceType> device_type = std::nullopt) {
return getAcceleratorHooksInterface(device_type).getPinnedMemoryAllocator();
}
static bool hasOpenMP();
static bool hasMKL();
static bool hasLAPACK();
static bool hasMKLDNN();
static bool hasMAGMA() {
return detail::getCUDAHooks().hasMAGMA();
}
static bool hasCUDA() {
return detail::getCUDAHooks().hasCUDA();
}
static bool hasMTIA() {
return detail::getMTIAHooks().hasMTIA();
}
static bool hasCUDART() {
return detail::getCUDAHooks().hasCUDART();
}
static long versionCUDART() {
return detail::getCUDAHooks().versionCUDART();
}
static bool hasCuDNN() {
return detail::getCUDAHooks().hasCuDNN();
}
static long versionCuDNN() {
return detail::getCUDAHooks().versionCuDNN();
}
static bool hasCuSOLVER() {
return detail::getCUDAHooks().hasCuSOLVER();
}
static bool hasCuBLASLt() {
return detail::getCUDAHooks().hasCuBLASLt();
}
static bool hasHIP() {
return detail::getHIPHooks().hasHIP();
}
static bool hasMPS() {
return detail::getMPSHooks().hasMPS();
}
static bool hasIPU() {
return c10::impl::hasDeviceGuardImpl(c10::DeviceType::IPU);
}
static bool hasXLA() {
return c10::impl::hasDeviceGuardImpl(c10::DeviceType::XLA);
}
static bool hasXPU() {
return detail::getXPUHooks().hasXPU();
}
static bool hasLazy() {
return c10::impl::hasDeviceGuardImpl(c10::DeviceType::Lazy);
}
static bool hasMAIA() {
return c10::impl::hasDeviceGuardImpl(c10::DeviceType::MAIA);
}
// defined in header so that getNonVariableType has ability to inline
// call_once check. getNonVariableType is called fairly frequently
void lazyInitCUDA() {
c10::call_once(thc_init, [&] { detail::getCUDAHooks().initCUDA(); });
}
void lazyInitHIP() {
c10::call_once(thh_init, [&] { detail::getHIPHooks().initHIP(); });
}
void lazyInitXPU() {
c10::call_once(thx_init, [&] { detail::getXPUHooks().initXPU(); });
}
void lazyInitMTIA() {
c10::call_once(th_mtia_init, [&] { detail::getMTIAHooks().initMTIA(); });
}
void lazyInitPrivateUse1() {
c10::call_once(thp_init, [&] {
if (isPrivateUse1HooksRegistered()) {
at::detail::getPrivateUse1Hooks().initPrivateUse1();
}
});
}
static const at::cuda::NVRTC& getNVRTC() {
return detail::getCUDAHooks().nvrtc();
}
static bool setFlushDenormal(bool on);
// NB: This method reports *purely* whether or not a user requested
// that CuDNN be enabled; it doesn't say anything about whether
// CuDNN is actually usable. Use cudnn_is_acceptable to test that
// instead.
bool userEnabledCuDNN() const;
void setUserEnabledCuDNN(bool e);
bool userEnabledMkldnn() const;
void setUserEnabledMkldnn(bool e);
bool benchmarkCuDNN() const;
void setBenchmarkCuDNN(bool);
int benchmarkLimitCuDNN() const;
void setBenchmarkLimitCuDNN(int);
bool deterministicCuDNN() const;
void setDeterministicCuDNN(bool);
bool deterministicMkldnn() const;
void setDeterministicMkldnn(bool);
bool userEnabledNNPACK() const;
void setUserEnabledNNPACK(bool e);
// Note [Disabling Fused SDP Kernels]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Flash and Memory Efficient SDP kernels are enabled by default.
// However, they can be disabled by calling
// at::globalContext().setSDPUseFlash(false) (and, analogously,
// setSDPUseMemEfficient(false) for the memory-efficient kernel).
// This is useful for debugging. For example, if you want to compare
// the performance of the flash SDP kernels with the unfused kernel,
// you can disable the flash SDP kernels. Conversely, by disabling
// the math SDP kernel you can force your code to use the flash kernels.
// The math SDP kernel can be disabled by calling
// at::globalContext().setSDPUseMath(false).
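// For example, a debugging sketch (illustrative only, using the setters declared
// below) that forces the unfused math path for an A/B comparison could be:
//
//   at::globalContext().setSDPUseFlash(false);
//   at::globalContext().setSDPUseMemEfficient(false);
//   at::globalContext().setSDPUseCuDNN(false);
//   at::globalContext().setSDPUseMath(true);  // keep the fallback enabled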
void setSDPUseFlash(bool);
bool userEnabledFlashSDP() const;
void setSDPUseMemEfficient(bool);
bool userEnabledMemEfficientSDP() const;
void setSDPUseMath(bool);
bool userEnabledMathSDP() const;
void setSDPUseCuDNN(bool);
bool userEnabledCuDNNSDP() const;
void setAllowFP16BF16ReductionMathSDP(bool);
bool allowFP16BF16ReductionMathSDP() const;
void setSDPUseOverrideable(bool);
bool userEnabledOverrideableSDP() const;
at::LinalgBackend linalgPreferredBackend() const;
void setLinalgPreferredBackend(at::LinalgBackend);
at::BlasBackend blasPreferredBackend();
void setBlasPreferredBackend(at::BlasBackend);
// Note [Enabling Deterministic Operations]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Operations in PyTorch that normally act nondeterministically, but have an
// alternate deterministic implementation, should satisfy the following
// requirements:
//
// * Include this comment: "See Note [Enabling Deterministic Operations]"
//
// * Check the value of `at::globalContext().deterministicAlgorithms()` to
// toggle
// between nondeterministic and deterministic implementations.
//
// * Have an entry in the list of PyTorch operations that toggle between
// nondeterministic
// and deterministic implementations, in the docstring of
// `use_deterministic_algorithms()` in torch/__init__.py
//
// `example_func()` below shows an example of toggling between
// nondeterministic and deterministic implementations:
//
// void example_func() {
// // See Note [Enabling Deterministic Operations]
// if (at::globalContext().deterministicAlgorithms()) {
// example_func_deterministic();
// } else {
// example_func_nondeterministic();
// }
// }
bool deterministicAlgorithms() const;
bool deterministicAlgorithmsWarnOnly() const;
void setDeterministicAlgorithms(bool, bool);
bool deterministicFillUninitializedMemory() const;
void setDeterministicFillUninitializedMemory(bool);
// Note [Writing Nondeterministic Operations]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Operations in PyTorch that act nondeterministically and do not have an
// alternate deterministic implementation should satisfy the following
// requirements:
//
// * Include this comment: "See Note [Writing Nondeterministic Operations]"
//
// * Include a comment explaining why the operation is nondeterministic.
//
// * Throw an error when `Context::deterministicAlgorithms()` is true. Most
// of the time, this should be accomplished by calling
//   `at::globalContext().alertNotDeterministic()`. However, if the
// nondeterministic behavior is caused by the CuBLAS workspace
// configuration in CUDA >= 10.2,
// `at::globalContext().alertCuBLASConfigNotDeterministic()` should be
// called instead (in this case, a comment explaining why the operation is
// nondeterministic is not necessary). See below for details on these
// methods.
//
// * Have an entry in the list of nondeterministic PyTorch operations in the
// docstring of `use_deterministic_algorithms()` in torch/__init__.py
//
// * Have a test function in `test/test_torch.py` whose name begins with
// `test_nondeterministic_alert_`. Alternatively, if CuBLAS workspace
// configuration is the reason for nondeterminism, the operation should be
// included in the `test_cublas_config_nondeterministic_alert` test. Any new
// tests should ideally follow a pattern similar to the existing ones.
//
// `example_func()` below shows an example of the comments and error-throwing
// code for a nondeterministic operation:
//
// void example_func() {
// // See Note [Writing Nondeterministic Operations]
// // Nondeterministic because <reason>
//     at::globalContext().alertNotDeterministic("example_func");
// ...
// }
// Throws an error if `Context::deterministicAlgorithms()` is true
static void alertNotDeterministic(c10::string_view const& caller);
// Throws an error if `Context::deterministicAlgorithms()` is true, CUDA
// >= 10.2, and CUBLAS_WORKSPACE_CONFIG is not set to either ":16:8" or
// ":4096:8". For more details:
// https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility
void alertCuBLASConfigNotDeterministic() const;
void setFloat32MatmulPrecision(const std::string& s);
bool allowTF32CuDNN() const;
void setAllowTF32CuDNN(bool);
bool allowTF32CuBLAS() const;
void setAllowTF32CuBLAS(bool);
Float32MatmulPrecision float32MatmulPrecision() const;
void setFloat32MatmulPrecision(Float32MatmulPrecision p);
bool allowFP16ReductionCuBLAS() const;
void setAllowFP16ReductionCuBLAS(bool);
bool allowBF16ReductionCuBLAS() const;
void setAllowBF16ReductionCuBLAS(bool);
at::QEngine qEngine() const;
void setQEngine(at::QEngine e);
static const std::vector<at::QEngine>& supportedQEngines();
static bool isXNNPACKAvailable();
void setCheckSparseTensorInvariants(bool e);
bool checkSparseTensorInvariants() const;
// This method is used to release the original weight after pre-packing.
// It should be called once before loading/running the model.
// NB: By default it is set to true for mobile builds.
void setReleaseWeightsWhenPrepacking(bool e);
bool releaseWeightsWhenPrepacking() const;
void setDisplayVmapFallbackWarnings(bool enabled);
bool areVmapFallbackWarningsEnabled() const;
void setDefaultMobileCPUAllocator();
void unsetDefaultMobileCPUAllocator();
bool allowFP16ReductionCPU() const;
void setAllowFP16ReductionCPU(bool);
private:
void initCUDAIfNeeded(c10::DeviceType p) {
if (p == c10::DeviceType::CUDA) {
lazyInitCUDA();
}
}
void initHIPIfNeeded(c10::DeviceType p) {
if (p == c10::DeviceType::HIP) {
lazyInitHIP();
}
}
void initXPUIfNeeded(c10::DeviceType p) {
if (p == c10::DeviceType::XPU) {
lazyInitXPU();
}
}
static bool checkCuBLASConfigDeterministic();
c10::once_flag thc_init;
c10::once_flag thh_init;
c10::once_flag thx_init;
c10::once_flag th_mtia_init;
c10::once_flag thp_init;
bool enabled_cudnn = true;
bool deterministic_cudnn = false;
bool deterministic_mkldnn = false;
bool _deterministic_algorithms = false;
bool _deterministic_algorithms_warn_only = false;
bool _deterministic_fill_uninitialized_memory = true;
bool enabled_flashSDP = true;
bool enabled_mem_efficientSDP = true;
bool enabled_mathSDP = true;
bool enabled_cudnnSDP = true;
bool enabled_overrideable = true;
bool allow_fp16_bf16_reduction_mathSDP = false;
#ifdef USE_ROCM
bool benchmark_cudnn = true;
#else
bool benchmark_cudnn = false;
#endif
Float32MatmulPrecision float32_matmul_precision =
c10::utils::check_env("TORCH_ALLOW_TF32_CUBLAS_OVERRIDE") == true
? at::Float32MatmulPrecision::HIGH
: at::Float32MatmulPrecision::HIGHEST;
int benchmark_limit_cudnn = 10;
bool allow_tf32_cudnn = true;
bool allow_fp16_reduction_cublas = true;
bool allow_bf16_reduction_cublas = true;
bool enabled_mkldnn = true;
bool enabled_nnpack = true;
at::LinalgBackend linalg_preferred_backend =
c10::utils::check_env("TORCH_LINALG_PREFER_CUSOLVER") == true
? at::LinalgBackend::Cusolver
: at::LinalgBackend::Default;
at::BlasBackend blas_preferred_backend =
#ifdef USE_ROCM
(c10::utils::check_env("TORCH_BLAS_PREFER_HIPBLASLT") != false)
#else
(c10::utils::check_env("TORCH_BLAS_PREFER_CUBLASLT") == true)
#endif
? at::BlasBackend::Cublaslt
: at::BlasBackend::Cublas;
#ifdef C10_MOBILE
bool release_original_weights = true;
#else
bool release_original_weights = false;
#endif
bool display_vmap_fallback_warnings_ = false;
std::optional<at::QEngine> quantized_engine = std::nullopt;
bool enable_sparse_tensor_invariant_checks = false;
bool allow_fp16_reduction_cpu = false;
Allocator* prev_allocator_ptr_{nullptr};
};
TORCH_API Context& globalContext();
inline void init() {
globalContext();
}
TORCH_API Allocator* getCPUAllocator();
inline DeprecatedTypeProperties& getDeprecatedTypeProperties(
Backend p,
ScalarType s) {
return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
p, s);
}
inline DeprecatedTypeProperties& CPU(ScalarType s) {
return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
Backend::CPU, s);
}
inline DeprecatedTypeProperties& CUDA(ScalarType s) {
return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
Backend::CUDA, s);
}
inline DeprecatedTypeProperties& HIP(ScalarType s) {
return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
Backend::HIP, s);
}
inline DeprecatedTypeProperties& MPS(ScalarType s) {
return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
Backend::MPS, s);
}
inline bool hasCUDA() {
return globalContext().hasCUDA();
}
inline bool hasMTIA() {
return globalContext().hasMTIA();
}
inline bool hasHIP() {
return globalContext().hasHIP();
}
inline bool hasIPU() {
return globalContext().hasIPU();
}
inline bool hasXLA() {
return globalContext().hasXLA();
}
inline bool hasMPS() {
return globalContext().hasMPS();
}
inline bool hasMAIA() {
return globalContext().hasMAIA();
}
inline bool hasXPU() {
return globalContext().hasXPU();
}
// Despite its name, this function returns the number of *CUDA* GPUs.
inline size_t getNumGPUs() {
// WARNING: DO NOT ADD LOGIC TO HANDLE OTHER DEVICE TYPES TO THIS
// FUNCTION. If you are interested in interrogating the number of
// devices for a specific device type, add that function to the
// relevant library (e.g., similar to at::cuda::device_count())
if (hasCUDA() && hasHIP()) {
throw std::runtime_error(
"Enabling both CUDA and HIP in ATen is not supported, as HIP masquerades "
"to be CUDA (e.g., when you say CUDA, on a HIP build of ATen, this actually "
"means HIP. Rebuild PyTorch with one or the other disabled.");
} else if (hasCUDA()) {
return detail::getCUDAHooks().getNumGPUs();
} else if (hasHIP()) {
return detail::getHIPHooks().getNumGPUs();
} else {
return 0;
}
}
inline bool hasOpenMP() {
return globalContext().hasOpenMP();
}
inline bool hasMKL() {
return globalContext().hasMKL();
}
inline bool hasLAPACK() {
return globalContext().hasLAPACK();
}
inline bool hasMAGMA() {
return globalContext().hasMAGMA();
}
inline bool hasMKLDNN() {
return globalContext().hasMKLDNN();
}
inline void manual_seed(uint64_t seed) {
auto gen = globalContext().defaultGenerator(c10::DeviceType::CPU);
{
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen.mutex());
gen.set_current_seed(seed);
}
// NB: Sometimes we build with CUDA, but we don't have any GPUs
// available. In that case, we must not seed CUDA; it will fail!
const auto cuda_num_gpus = detail::getCUDAHooks().getNumGPUs();
if (hasCUDA() && cuda_num_gpus > 0) {
for (const auto i : c10::irange(cuda_num_gpus)) {
auto cuda_gen = globalContext().defaultGenerator(
Device(at::kCUDA, static_cast<c10::DeviceIndex>(i)));
{
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(cuda_gen.mutex());
cuda_gen.set_current_seed(seed);
}
}
}
const auto xpu_num_gpus = detail::getXPUHooks().getNumGPUs();
if (hasXPU() && xpu_num_gpus) {
for (const auto i : c10::irange(xpu_num_gpus)) {
auto xpu_gen = globalContext().defaultGenerator(
Device(at::kXPU, static_cast<c10::DeviceIndex>(i)));
{
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(xpu_gen.mutex());
xpu_gen.set_current_seed(seed);
}
}
}
if (hasMPS()) {
auto mps_gen = globalContext().defaultGenerator(c10::DeviceType::MPS);
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(mps_gen.mutex());
mps_gen.set_current_seed(seed);
}
}
// When the global flag `allow_tf32` is set to true, cuBLAS handles are
// automatically configured to use math mode CUBLAS_TF32_TENSOR_OP_MATH.
// For some operators, such as addmv, TF32 offers no performance improvement
// but causes precision loss. To help this case, this class implements
// a RAII guard that can be used to quickly disable TF32 within its scope.
//
// Usage:
// NoTF32Guard disable_tf32;
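//
// For example (an illustrative sketch; `bias`, `mat` and `vec` are assumed
// float CUDA Tensors, and addmv stands in for any TF32-sensitive op):
//
//   {
//     at::NoTF32Guard disable_tf32;          // TF32 disabled within this scope
//     auto out = at::addmv(bias, mat, vec);  // cuBLAS runs without TF32 math mode
//   }                                        // previous behavior restored on scope exit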
struct TORCH_API NoTF32Guard {
NoTF32Guard();
~NoTF32Guard();
static bool should_disable_tf32();
private:
bool changed = false;
};
struct TORCH_API ROCmBackwardPassGuard {
ROCmBackwardPassGuard();
~ROCmBackwardPassGuard();
static bool is_backward_pass();
};
} // namespace at

View File

@ -0,0 +1,25 @@
#pragma once
#include <ATen/ATen.h>
#include <ATen/Tensor.h>
#include <ATen/dlpack.h>
// This converter will:
// 1) take a Tensor object and wrap it in a DLPack tensor
// 2) take a DLPack tensor and convert it to an ATen Tensor
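//
// A minimal round-trip sketch (illustrative only):
//
//   at::Tensor t = at::rand({2, 3});
//   DLManagedTensor* managed = at::toDLPack(t);  // (1) wrap as a DLPack tensor
//   at::Tensor back = at::fromDLPack(managed);   // (2) view the same storage as an ATen Tensor
//
// Per the DLPack convention, the consumer of a DLManagedTensor is responsible
// for invoking its deleter exactly once; fromDLPack takes over that duty for
// the tensor it wraps.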
namespace at {
TORCH_API ScalarType toScalarType(const DLDataType& dtype);
TORCH_API DLManagedTensor* toDLPack(const Tensor& src);
TORCH_API Tensor fromDLPack(DLManagedTensor* src);
C10_DEPRECATED_MESSAGE("Please migrate to a non-const variant")
inline Tensor fromDLPack(const DLManagedTensor* src) {
return fromDLPack(const_cast<DLManagedTensor*>(src));
}
TORCH_API Tensor
fromDLPack(DLManagedTensor* src, std::function<void(void*)> deleter);
TORCH_API DLDataType getDLDataType(const Tensor& t);
TORCH_API DLDevice getDLContext(const Tensor& tensor, const int64_t& device_id);
} // namespace at

View File

@ -0,0 +1,2 @@
#pragma once
#include <c10/core/Device.h>

View File

@ -0,0 +1,27 @@
#pragma once
#include <c10/core/DeviceType.h>
#include <c10/macros/Macros.h>
#include <ATen/detail/MTIAHooksInterface.h>
#include <optional>
// This file defines the top level Accelerator concept for PyTorch.
// A device is an accelerator per the definition here if:
// - It is mutually exclusive with all other accelerators
// - It performs asynchronous compute via a Stream/Event system
// - It provides a set of common APIs as defined by AcceleratorHooksInterface
//
// As of today, accelerator devices are (in no particular order):
// CUDA, MTIA, XPU, HIP, MPS, PrivateUse1
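//
// A minimal usage sketch (illustrative only):
//
//   if (auto acc = at::getAccelerator(/*checked=*/false)) {
//     // e.g. prints "CUDA" on a CUDA build with a visible device
//     std::cout << c10::DeviceTypeName(*acc) << std::endl;
//   } else {
//     // no accelerator backend is available; fall back to CPU
//   }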
namespace at {
// Ensures that only one accelerator is available (at
// compile time if possible) and returns it.
// When checked is true, the returned optional always has a value.
TORCH_API std::optional<c10::DeviceType> getAccelerator(bool checked = false);
TORCH_API bool isAccelerator(c10::DeviceType d);
} // namespace at

View File

@ -0,0 +1,41 @@
#pragma once
#include <ATen/core/IListRef.h>
#include <ATen/core/Tensor.h>
#include <c10/core/DeviceGuard.h>
#include <c10/core/ScalarType.h> // TensorList whyyyyy
namespace at {
// Are you here because you're wondering why DeviceGuard(tensor) no
// longer works? For code organization reasons, we have temporarily(?)
// removed this constructor from DeviceGuard. The new way to
// spell it is:
//
// OptionalDeviceGuard guard(device_of(tensor));
/// Return the Device of a Tensor, if the Tensor is defined.
inline std::optional<Device> device_of(const Tensor& t) {
if (t.defined()) {
return std::make_optional(t.device());
} else {
return std::nullopt;
}
}
inline std::optional<Device> device_of(const std::optional<Tensor>& t) {
return t.has_value() ? device_of(t.value()) : std::nullopt;
}
/// Return the Device of a TensorList, if the list is non-empty and
/// the first Tensor is defined. (This function implicitly assumes
/// that all tensors in the list have the same device.)
inline std::optional<Device> device_of(ITensorListRef t) {
if (!t.empty()) {
return device_of(t.front());
} else {
return std::nullopt;
}
}
} // namespace at

View File

@ -0,0 +1,2 @@
#pragma once
#include <ATen/core/DimVector.h>

View File

@ -0,0 +1 @@
#include <ATen/core/Dimname.h>

View File

@ -0,0 +1,808 @@
#pragma once
#include <ATen/core/DeprecatedTypeProperties.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <c10/util/Half.h>
#include <c10/util/Metaprogramming.h>
#include <c10/util/complex.h>
#include <c10/util/string_view.h>
#ifdef __CUDACC__
#include <cuda.h> // For CUDA_VERSION
#endif
#ifdef TEMPLATE_SELECTIVE_BUILD
#include <ATen/selected_mobile_ops.h>
#else
namespace at {
/**
* The method should_include_kernel_dtype() returns true/false
* based on whether the switching code for a specific dtype should be
* included based on build time constants generated from tracing model
* execution. This method will be implemented via code-generation and
* included in this file when code-gen is ready.
*/
inline constexpr bool should_include_kernel_dtype(
const char* /*kernel_tag_str*/,
at::ScalarType /*scalar_type*/
) {
return true;
}
} // namespace at
#endif
/**
* In the Facebook internal build (using BUCK), this macro is enabled by
* passing in -c pt.enable_record_kernel_dtype=1 when building the tracer
* binary.
*/
#if defined ENABLE_RECORD_KERNEL_FUNCTION_DTYPE
namespace at {
namespace detail {
TORCH_API void record_kernel_function_dtype(std::string name);
}
} // namespace at
#define RECORD_KERNEL_FUNCTION_DTYPE(NAME, enum_type) \
at::detail::record_kernel_function_dtype( \
std::string(NAME) + "$" + toString(enum_type));
#else
#define RECORD_KERNEL_FUNCTION_DTYPE(NAME, enum_type)
#endif
#define AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type) \
do { \
if constexpr (!at::should_include_kernel_dtype( \
at_dispatch_name, enum_type)) { \
AT_ERROR( \
"dtype '", \
toString(enum_type), \
"' not selected for kernel tag ", \
at_dispatch_name); \
} \
} while (0)
#define AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, HINT, ...) \
case enum_type: { \
AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type); \
using HINT C10_UNUSED = c10::impl::ScalarTypeToCPPTypeT<enum_type>; \
return __VA_ARGS__(); \
}
#define AT_DISPATCH_CASE(enum_type, ...) \
AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, scalar_t, __VA_ARGS__)
#define AT_DISPATCH_CASE_QINT(enum_type, scalar_type, ...) \
case enum_type: { \
AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type); \
using scalar_t = scalar_type; \
using underlying_t C10_UNUSED = typename scalar_t::underlying; \
const auto& SCALAR_TYPE C10_UNUSED = enum_type; \
const auto& UNDERLYING_TYPE C10_UNUSED = toUnderlying(enum_type); \
return __VA_ARGS__(); \
}
#define AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
enum_type, scalar_type, bitwidth, qmin, qmax, ...) \
case enum_type: { \
AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type); \
using scalar_t = scalar_type; \
using underlying_t C10_UNUSED = typename scalar_t::underlying; \
const auto& SCALAR_TYPE C10_UNUSED = enum_type; \
const auto& UNDERLYING_TYPE C10_UNUSED = toUnderlying(enum_type); \
C10_UNUSED int bit_width = bitwidth; \
C10_UNUSED int64_t quant_min = qmin; \
C10_UNUSED int64_t quant_max = qmax; \
return __VA_ARGS__(); \
}
namespace detail {
inline at::ScalarType scalar_type(at::ScalarType s) {
return s;
}
C10_DEPRECATED_MESSAGE(
"passing at::DeprecatedTypeProperties to an AT_DISPATCH macro is deprecated, "
"pass an at::ScalarType instead")
inline at::ScalarType scalar_type(const at::DeprecatedTypeProperties& t) {
return t.scalarType();
}
C10_DEPRECATED_MESSAGE(
"AT_DISPATCH_ALL_TYPES_AND_HALF is deprecated, "
"use AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, ...) instead")
inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF() {}
C10_DEPRECATED_MESSAGE(
"AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX is deprecated, "
"use AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(at::ScalarType::Half, ...) "
"instead")
inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
} // namespace detail
// The AT_DISPATCH_* family of macros provides the ability to
// conveniently generate specializations of a kernel over all of the
// dtypes we care about in PyTorch. We call it "dispatch" because
// we are "dispatching" to the correct, dtype-specific kernel.
//
// A standard usage looks like:
//
// AT_DISPATCH_ALL_TYPES(self.scalar_type(), "op_name", [&] {
// // Your code here, with 'scalar_t' now defined to
// // be the dtype in question
// });
//
// There are many variations of this macro, so it's important to
// understand exactly /which/ dtypes you want to get instantiated, as
// well as what the "default" set is.
//
// The default set of dtypes that are instantiated (e.g., by
// AT_DISPATCH_ALL_TYPES) are floating point types (float, double),
// and integral types (int32_t, int64_t, int16_t, int8_t, uint8_t),
// but NOT booleans (bool), half-precision floats (Half) or
// complex numbers (c10::complex<float>, c10::complex<double>).
// This "cut" is somewhat historical (the default types are the
// ones that TH historically supported), but it also reflects the
// fact that the non-default types are "poorly" behaved (booleans
// are NOT integers mod 2, half precision operations ~essentially
// don't exist on CPU, complex numbers are an experimental application).
//
// Here are the questions you should generally ask to decide which
// dispatch you want:
//
// 1. Is this an integral or floating point specific operation?
// (If so, you'll want one of the FLOATING or INTEGRAL macros.)
//
// 2. Should half be supported? (If you're on CPU, the answer is almost
// definitely no. If you do want support, use one of the AND_HALF
// macros)
//
// Much rarer situations:
//
// 3. Should bool be supported? (You often have to write your kernel
// differently if arithmetic operations are involved.) If so,
// Use AT_DISPATCH_ALL_TYPES_AND along with ScalarType::Bool
//
// 4. Should complex be supported? The answer is almost always no,
// unless you are working on "generic" code that should work on
// all dtypes.
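//
// For instance (an illustrative sketch; `iter` and `my_op_impl` are
// hypothetical), a floating-point kernel that should also cover the
// reduced-precision types could be dispatched as:
//
//   AT_DISPATCH_FLOATING_TYPES_AND2(
//       at::ScalarType::Half, at::ScalarType::BFloat16,
//       iter.dtype(), "my_op", [&] {
//         my_op_impl<scalar_t>(iter);
//       });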
//
// Parameters:
// -----------
//
// 1. The NAME argument is a "tag" that is used to trace and then
// conditionally compile fragments of the case statements such
// that the kernel functions are specialized only for the dtypes
// that are needed. The NAME parameter *must* be a build time
// const char* (can't be std::string, etc...)
//
// Please ensure that the NAME is unique for every implementation
// or you run the risk of over-including code for the kernel
// functions. There is no risk of missing out on any code, so
// it's mostly a risk of a Type-2 error, and not a Type-1 error.
//
// Switch-like syntax:
// -------------------
// There is also a switch-case like syntax which is useful if a kernel
// needs to be specialized for particular scalar types
//
// AT_DISPATCH_SWITCH(self.scalar_type(), "op_name",
// AT_DISPATCH_CASE_INTEGRAL_TYPES([&] {
// op_integral<scalar_t>(iter);
// })
// AT_DISPATCH_CASE_FLOATING_TYPES([&] {
// op_floating<scalar_t>(iter);
// })
// AT_DISPATCH_CASE(kBool, [&] {
// op_bool(iter);
// })
// );
//
// For each AT_DISPATCH_FOO macro, there is a corresponding
// AT_DISPATCH_CASE_FOO macro which can be used inside of an
// AT_DISPATCH_SWITCH block.
// NB: the the_type variable is not used, but we have kept it for
// backwards compatibility. It's probably not used by anyone anymore,
// but we're just being safe (and it doesn't hurt). Note we must
// use it to silence warnings about an unused store.
#define AT_DISPATCH_SWITCH(TYPE, NAME, ...) \
[&] { \
const auto& the_type = TYPE; \
constexpr const char* at_dispatch_name = NAME; \
/* don't use TYPE again in case it is an expensive or side-effect op */ \
at::ScalarType _st = ::detail::scalar_type(the_type); \
RECORD_KERNEL_FUNCTION_DTYPE(at_dispatch_name, _st); \
switch (_st) { \
__VA_ARGS__ \
default: \
AT_ERROR( \
'"', \
at_dispatch_name, \
"\" not implemented for '", \
toString(_st), \
"'"); \
} \
}()
#define AT_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)
#define AT_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
#define AT_DISPATCH_CASE_FLOATING_TYPES_AND_HALF(...) \
AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
#define AT_DISPATCH_FLOATING_TYPES_AND_HALF(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, AT_DISPATCH_CASE_FLOATING_TYPES_AND_HALF(__VA_ARGS__))
#define AT_DISPATCH_CASE_REDUCED_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
#define AT_DISPATCH_REDUCED_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, AT_DISPATCH_CASE_REDUCED_FLOATING_TYPES(__VA_ARGS__))
#define AT_DISPATCH_CASE_FLOATING_TYPES_AND(SCALARTYPE, ...) \
AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
#define AT_DISPATCH_FLOATING_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, \
NAME, \
AT_DISPATCH_CASE_FLOATING_TYPES_AND(SCALARTYPE, __VA_ARGS__))
#define AT_DISPATCH_CASE_FLOATING_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, ...) \
AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)
#define AT_DISPATCH_FLOATING_TYPES_AND2( \
SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, \
NAME, \
AT_DISPATCH_CASE_FLOATING_TYPES_AND2( \
SCALARTYPE1, SCALARTYPE2, __VA_ARGS__))
#define AT_DISPATCH_CASE_FLOATING_TYPES_AND3( \
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...) \
AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)
#define AT_DISPATCH_FLOATING_TYPES_AND3( \
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, \
NAME, \
AT_DISPATCH_CASE_FLOATING_TYPES_AND3( \
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__))
#define AT_DISPATCH_CASE_FLOATING_TYPES_AND4( \
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, ...) \
AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__)
#define AT_DISPATCH_FLOATING_TYPES_AND4( \
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, \
NAME, \
AT_DISPATCH_CASE_FLOATING_TYPES_AND4( \
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, __VA_ARGS__))
#define AT_DISPATCH_CASE_COMPLEX_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::ComplexDouble, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::ComplexFloat, __VA_ARGS__)
#define AT_DISPATCH_COMPLEX_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__))
#define AT_DISPATCH_CASE_COMPLEX_TYPES_AND(SCALARTYPE, ...) \
AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
#define AT_DISPATCH_COMPLEX_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, AT_DISPATCH_CASE_COMPLEX_TYPES_AND(SCALARTYPE, __VA_ARGS__))
#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(...) \
AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \
AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__)
#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__))
#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND1(SCALARTYPE, ...) \
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1( \
SCALARTYPE, TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, \
NAME, \
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND1( \
SCALARTYPE, __VA_ARGS__))
#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND2( \
SCALARTYPE1, SCALARTYPE2, ...) \
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)
#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( \
SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, \
NAME, \
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND2( \
SCALARTYPE1, SCALARTYPE2, __VA_ARGS__))
#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND3( \
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...) \
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)
#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3( \
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, \
NAME, \
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND3( \
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__))
#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND4( \
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, ...) \
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__)
#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4( \
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, \
NAME, \
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND4( \
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, __VA_ARGS__))
#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND5( \
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5, ...) \
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__)
#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND5( \
SCALARTYPE1, \
SCALARTYPE2, \
SCALARTYPE3, \
SCALARTYPE4, \
SCALARTYPE5, \
TYPE, \
NAME, \
...) \
AT_DISPATCH_SWITCH( \
TYPE, \
NAME, \
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND5( \
SCALARTYPE1, \
SCALARTYPE2, \
SCALARTYPE3, \
SCALARTYPE4, \
SCALARTYPE5, \
__VA_ARGS__))
#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND6( \
SCALARTYPE1, \
SCALARTYPE2, \
SCALARTYPE3, \
SCALARTYPE4, \
SCALARTYPE5, \
SCALARTYPE6, \
...) \
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE6, __VA_ARGS__)
#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND6( \
SCALARTYPE1, \
SCALARTYPE2, \
SCALARTYPE3, \
SCALARTYPE4, \
SCALARTYPE5, \
SCALARTYPE6, \
TYPE, \
NAME, \
...) \
AT_DISPATCH_SWITCH( \
TYPE, \
NAME, \
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND6( \
SCALARTYPE1, \
SCALARTYPE2, \
SCALARTYPE3, \
SCALARTYPE4, \
SCALARTYPE5, \
SCALARTYPE6, \
__VA_ARGS__))
#define AT_DISPATCH_CASE_INTEGRAL_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__)
#define AT_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
#define AT_DISPATCH_CASE_INTEGRAL_TYPES_AND(SCALARTYPE, ...) \
AT_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
#define AT_DISPATCH_INTEGRAL_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, \
NAME, \
AT_DISPATCH_CASE_INTEGRAL_TYPES_AND(SCALARTYPE, __VA_ARGS__))
#define AT_DISPATCH_CASE_ALL_TYPES(...) \
AT_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__) \
AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)
#define AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__))
#define AT_DISPATCH_CASE_QINT_TYPES(...) \
AT_DISPATCH_CASE_QINT(at::kQInt8, at::qint8, __VA_ARGS__) \
AT_DISPATCH_CASE_QINT(at::kQUInt8, at::quint8, __VA_ARGS__) \
AT_DISPATCH_CASE_QINT(at::kQInt32, at::qint32, __VA_ARGS__)
#define AT_DISPATCH_QINT_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_QINT_TYPES(__VA_ARGS__))
#define AT_DISPATCH_CASE_QINT_TYPES_AND(SCALARTYPE, ...) \
AT_DISPATCH_CASE_QINT_TYPES(__VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
#define AT_DISPATCH_QINT_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, AT_DISPATCH_CASE_QINT_TYPES_AND(SCALARTYPE, __VA_ARGS__))
#define AT_DISPATCH_CASE_QINT_BYTE_TYPES(...) \
AT_DISPATCH_CASE_QINT(at::kQInt8, at::qint8, __VA_ARGS__) \
AT_DISPATCH_CASE_QINT(at::kQUInt8, at::quint8, __VA_ARGS__)
#define AT_DISPATCH_QINT_BYTE_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_QINT_BYTE_TYPES(__VA_ARGS__))
#define AT_DISPATCH_CASE_QINT_AND_SUB_BYTE_TYPES(...) \
AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
at::kQInt8, at::qint8, CHAR_BIT, SCHAR_MIN, SCHAR_MAX, __VA_ARGS__) \
AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
at::kQUInt8, at::quint8, CHAR_BIT, 0, UCHAR_MAX, __VA_ARGS__) \
AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
at::kQInt32, \
at::qint32, \
CHAR_BIT * sizeof(int), \
INT_MIN, \
INT_MAX, \
__VA_ARGS__) \
AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
at::kQUInt4x2, at::quint4x2, 4, 0, 15, __VA_ARGS__) \
AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
at::kQUInt2x4, at::quint2x4, 2, 0, 3, __VA_ARGS__)
#define AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, AT_DISPATCH_CASE_QINT_AND_SUB_BYTE_TYPES(__VA_ARGS__))
#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(...) \
AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__) \
AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__)
#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__))
#define AT_DISPATCH_CASE_ALL_TYPES_AND(SCALARTYPE, ...) \
AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
#define AT_DISPATCH_ALL_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, AT_DISPATCH_CASE_ALL_TYPES_AND(SCALARTYPE, __VA_ARGS__))
#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE, ...) \
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE, TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, \
NAME, \
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE, __VA_ARGS__))
#define AT_DISPATCH_CASE_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, ...) \
AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)
#define AT_DISPATCH_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, \
NAME, \
AT_DISPATCH_CASE_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, __VA_ARGS__))
#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND2( \
SCALARTYPE1, SCALARTYPE2, ...) \
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)
#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( \
SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, \
NAME, \
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND2( \
SCALARTYPE1, SCALARTYPE2, __VA_ARGS__))
#define AT_DISPATCH_CASE_ALL_TYPES_AND3( \
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...) \
AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)
#define AT_DISPATCH_ALL_TYPES_AND3( \
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, \
NAME, \
AT_DISPATCH_CASE_ALL_TYPES_AND3( \
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__))
#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND3( \
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...) \
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)
#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( \
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, \
NAME, \
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND3( \
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__))
#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4( \
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, ...) \
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__)
#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, \
NAME, \
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4( \
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, __VA_ARGS__))
#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND5( \
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5, ...) \
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__)
#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND5( \
SCALARTYPE1, \
SCALARTYPE2, \
SCALARTYPE3, \
SCALARTYPE4, \
SCALARTYPE5, \
TYPE, \
NAME, \
...) \
AT_DISPATCH_SWITCH( \
TYPE, \
NAME, \
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND5( \
SCALARTYPE1, \
SCALARTYPE2, \
SCALARTYPE3, \
SCALARTYPE4, \
SCALARTYPE5, \
__VA_ARGS__))
#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND6( \
SCALARTYPE1, \
SCALARTYPE2, \
SCALARTYPE3, \
SCALARTYPE4, \
SCALARTYPE5, \
SCALARTYPE6, \
...) \
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE6, __VA_ARGS__)
#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6( \
SCALARTYPE1, \
SCALARTYPE2, \
SCALARTYPE3, \
SCALARTYPE4, \
SCALARTYPE5, \
SCALARTYPE6, \
TYPE, \
NAME, \
...) \
AT_DISPATCH_SWITCH( \
TYPE, \
NAME, \
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND6( \
SCALARTYPE1, \
SCALARTYPE2, \
SCALARTYPE3, \
SCALARTYPE4, \
SCALARTYPE5, \
SCALARTYPE6, \
__VA_ARGS__))
#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND7( \
SCALARTYPE1, \
SCALARTYPE2, \
SCALARTYPE3, \
SCALARTYPE4, \
SCALARTYPE5, \
SCALARTYPE6, \
SCALARTYPE7, \
...) \
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE6, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE7, __VA_ARGS__)
#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND7( \
SCALARTYPE1, \
SCALARTYPE2, \
SCALARTYPE3, \
SCALARTYPE4, \
SCALARTYPE5, \
SCALARTYPE6, \
SCALARTYPE7, \
TYPE, \
NAME, \
...) \
AT_DISPATCH_SWITCH( \
TYPE, \
NAME, \
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND7( \
SCALARTYPE1, \
SCALARTYPE2, \
SCALARTYPE3, \
SCALARTYPE4, \
SCALARTYPE5, \
SCALARTYPE6, \
SCALARTYPE7, \
__VA_ARGS__))
#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND8( \
SCALARTYPE1, \
SCALARTYPE2, \
SCALARTYPE3, \
SCALARTYPE4, \
SCALARTYPE5, \
SCALARTYPE6, \
SCALARTYPE7, \
SCALARTYPE8, \
...) \
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE6, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE7, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE8, __VA_ARGS__)
#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND8( \
SCALARTYPE1, \
SCALARTYPE2, \
SCALARTYPE3, \
SCALARTYPE4, \
SCALARTYPE5, \
SCALARTYPE6, \
SCALARTYPE7, \
SCALARTYPE8, \
TYPE, \
NAME, \
...) \
AT_DISPATCH_SWITCH( \
TYPE, \
NAME, \
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND8( \
SCALARTYPE1, \
SCALARTYPE2, \
SCALARTYPE3, \
SCALARTYPE4, \
SCALARTYPE5, \
SCALARTYPE6, \
SCALARTYPE7, \
SCALARTYPE8, \
__VA_ARGS__))
#define AT_DISPATCH_CASE_BIT_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Bits1x8, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Bits2x4, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Bits4x2, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Bits8, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Bits16, __VA_ARGS__)
#define AT_DISPATCH_BIT_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_BIT_TYPES(__VA_ARGS__))
#define AT_DISPATCH_INDEX_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, \
NAME, \
AT_PRIVATE_CASE_TYPE_USING_HINT( \
at::ScalarType::Int, index_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE_USING_HINT( \
at::ScalarType::Long, index_t, __VA_ARGS__))
// ----------------------------------------------------------------------------
// DEPRECATED MACROS, DON'T USE THESE
// ----------------------------------------------------------------------------
#define AT_DISPATCH_ALL_TYPES_AND_HALF(TYPE, NAME, ...) \
detail::deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF(); \
AT_DISPATCH_SWITCH( \
TYPE, \
NAME, \
AT_DISPATCH_CASE_ALL_TYPES_AND(at::ScalarType::Half, __VA_ARGS__))

View File

@ -0,0 +1,186 @@
#include <ATen/Dispatch.h>
// This is a new implementation of the AT_DISPATCH macro family from
// ATen/Dispatch.h
//
// The intended usage is:
//
// ScalarType scalar_type;
//
// AT_DISPATCH_V2(
// scalar_type,
// "debug string",
// AT_WRAP([&] {
// ... code to specialize with scalar_t ...
// }),
// kHalf,
// AT_EXPAND(AT_ALL_TYPES),
// ... as many types arguments as needed ...
// )
//
// For example, given an old style:
//
// AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
// kComplexHalf,
// kHalf,
// self.scalar_type(),
// "_local_scalar_dense_cpu",
// [&] {
// scalar_t value = *self.data_ptr<scalar_t>();
// r = Scalar(value);
// }
// )
//
// You now write:
//
// AT_DISPATCH_V2(
// self.scalar_type(),
// "_local_scalar_dense_cpu",
// AT_WRAP([&] {
// scalar_t value = *self.data_ptr<scalar_t>();
// r = Scalar(value);
// }),
// AT_EXPAND(AT_ALL_TYPES),
// AT_EXPAND(AT_COMPLEX_TYPES),
// kComplexHalf,
// kHalf
// )
//
// Notably, it sports the following improvements:
//
// - It is not necessary to specify the arity (e.g.,
// AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND{2,3,4,...})
// when using the macro
//
// - It is not necessary to specify each dtype individually; if
// there is a set of related dtypes and you want to dispatch
// over all of them, you can simply say, e.g., AT_EXPAND(AT_INTEGRAL_TYPES)
// in your argument list.
//
// However, you must remember to wrap the payload body in AT_WRAP, or
// commas inside your lambda will be improperly handled. Furthermore,
// if more entries are added to ScalarType than this macro can support,
// it will fail with an obscure error (due to attempting to concatenate
// AT_AP with something that is not a number).
//
// The implementation strategy is to use the argument-counting trick
// (e.g., as described in https://stackoverflow.com/a/2124385/23845)
// to discover how many dtypes have been passed, and then dispatch to a
// hand-written macro for each arity that applies AT_DISPATCH_CASE as
// many times as necessary. The hand-written macros can be regenerated
// for other arities with the script below.
//
// There is some delicacy in the implementation in controlling when
// macro expansion occurs, mediated with AT_EXPAND and AT_GUARD. I mostly
// relied on GPT4 to help me get it right.
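//
// As a rough sketch of the expansion (argument forwarding details
// elided; t, "f" and body are placeholders, not real code), a call
// with two extra dtypes goes roughly through:
//
//   AT_DISPATCH_V2(t, "f", AT_WRAP(body), kHalf, kBFloat16)
//   -> AT_DISPATCH_SWITCH(t, "f", AT_AP_VAR(AT_WRAP(body), t, kHalf, kBFloat16))
//   -> AT_DISPATCH_SWITCH(t, "f", AT_AP2(body, kHalf, kBFloat16))
//   -> AT_DISPATCH_SWITCH(t, "f",
//        AT_DISPATCH_CASE(kHalf, body) AT_DISPATCH_CASE(kBFloat16, body))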
// Public API macros
// See documentation above
#define AT_DISPATCH_V2(TYPE, NAME, BODY, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, AT_AP_VAR(AT_WRAP(BODY), TYPE, __VA_ARGS__))
// This macro lets you pass an arbitrary expression that may contain
// internal commas to another macro without the commas causing the
// expression to be interpreted as multiple arguments.
#define AT_WRAP(...) __VA_ARGS__
#define AT_FLOAT8_TYPES \
c10::kFloat8_e5m2, c10::kFloat8_e5m2fnuz, c10::kFloat8_e4m3fn, \
c10::kFloat8_e4m3fnuz
#define AT_INTEGRAL_TYPES \
c10::kByte, c10::kChar, c10::kInt, c10::kLong, c10::kShort
#define AT_FLOATING_TYPES c10::kDouble, c10::kFloat
#define AT_BAREBONES_UNSIGNED_TYPES c10::kUInt16, c10::kUInt32, c10::kUInt64
#define AT_INTEGRAL_TYPES_V2 \
AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)
#define AT_COMPLEX_TYPES c10::kComplexDouble, c10::kComplexFloat
#define AT_QINT_TYPES c10::kQInt8, c10::kQUInt8, c10::kQInt32
// NB: not *actually* all types
#define AT_ALL_TYPES AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_FLOATING_TYPES)
#define AT_ALL_TYPES_AND_COMPLEX \
AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_COMPLEX_TYPES)
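// A minimal usage sketch combining a type-set macro with individual
// dtypes (fill_kernel, t and value are hypothetical names):
//
//   AT_DISPATCH_V2(t.scalar_type(), "fill_", AT_WRAP([&] {
//     fill_kernel<scalar_t>(t, value);
//   }), AT_EXPAND(AT_ALL_TYPES), kBool, kHalf, kBFloat16);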
// Helper macros
#define AT_AP_VAR(N, T, ...) \
AT_EXPAND(AT_CONCAT(AT_AP, AT_NUM_ARGS(__VA_ARGS__))(AT_WRAP(N), __VA_ARGS__))
#define AT_CONCAT(a, b) AT_CONCAT_AUX(a, b)
#define AT_CONCAT_AUX(a, b) a##b
#define AT_EXPAND(X) X
// Ensure we never have too many scalar types for the expansion here to
// support. To bump this, you must regenerate the macros below.
static_assert(static_cast<int>(c10::ScalarType::NumOptions) < 45);
// Python code to regenerate the generated code below:
#if 0
num_args = 45
nums = ', '.join(str(i) for i in reversed(range(num_args+1)))
args = ', '.join(f'_{i}' for i in range(1, num_args+1))
print(f'#define AT_NUM_ARGS(...) AT_EXPAND(AT_NUM_ARGS_AUX(__VA_ARGS__, {nums}))')
print(f'#define AT_NUM_ARGS_AUX({args}, N, ...) N')
for i in range(1, num_args+1):
    args = ', '.join(f'_{i}' for i in range(1, i+1))
    cases = ' '.join([f'AT_DISPATCH_CASE(_{j}, N)' for j in range(1, i+1)])
    print(f'#define AT_AP{i}(N, {args}) {cases}')
#endif
// Begin generated code
// clang-format off
#define AT_NUM_ARGS(...) AT_EXPAND(AT_NUM_ARGS_AUX(__VA_ARGS__, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0))
#define AT_NUM_ARGS_AUX(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, N, ...) N
#define AT_AP1(N, _1) AT_DISPATCH_CASE(_1, N)
#define AT_AP2(N, _1, _2) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N)
#define AT_AP3(N, _1, _2, _3) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N)
#define AT_AP4(N, _1, _2, _3, _4) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N)
#define AT_AP5(N, _1, _2, _3, _4, _5) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N)
#define AT_AP6(N, _1, _2, _3, _4, _5, _6) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N)
#define AT_AP7(N, _1, _2, _3, _4, _5, _6, _7) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N)
#define AT_AP8(N, _1, _2, _3, _4, _5, _6, _7, _8) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N)
#define AT_AP9(N, _1, _2, _3, _4, _5, _6, _7, _8, _9) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N)
#define AT_AP10(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N)
#define AT_AP11(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N)
#define AT_AP12(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N)
#define AT_AP13(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N)
#define AT_AP14(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N)
#define AT_AP15(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N)
#define AT_AP16(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N)
#define AT_AP17(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N)
#define AT_AP18(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N)
#define AT_AP19(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N)
#define AT_AP20(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N)
#define AT_AP21(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N)
#define AT_AP22(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N)
#define AT_AP23(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N)
#define AT_AP24(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N)
#define AT_AP25(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N)
#define AT_AP26(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N)
#define AT_AP27(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N)
#define AT_AP28(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N)
#define AT_AP29(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N)
#define AT_AP30(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N)
#define AT_AP31(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N)
#define AT_AP32(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N)
#define AT_AP33(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N)
#define AT_AP34(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N)
#define AT_AP35(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N)
#define AT_AP36(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N)
#define AT_AP37(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N)
#define AT_AP38(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N)
#define AT_AP39(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N)
#define AT_AP40(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N)
#define AT_AP41(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N)
#define AT_AP42(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N)
#define AT_AP43(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N)
#define AT_AP44(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N)
#define AT_AP45(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N)
// End generated code
// clang-format on

View File

@ -0,0 +1,34 @@
#pragma once
#include <ATen/Utils.h>
#include <c10/macros/Export.h>
#include <c10/util/Exception.h>
namespace c10 {
class DynamicLibraryError : public Error {
using Error::Error;
};
} // namespace c10
namespace at {
struct DynamicLibrary {
AT_DISALLOW_COPY_AND_ASSIGN(DynamicLibrary);
TORCH_API DynamicLibrary(
const char* name,
const char* alt_name = nullptr,
bool leak_handle = false);
TORCH_API void* sym(const char* name);
TORCH_API ~DynamicLibrary();
private:
bool leak_handle;
void* handle = nullptr;
};
} // namespace at
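// A usage sketch (the library and symbol names are purely illustrative,
// not part of any real library):
//
//   at::DynamicLibrary lib("libfoo.so");
//   auto add_one = reinterpret_cast<int (*)(int)>(lib.sym("foo_add_one"));
//   int r = add_one(41);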

View File

@ -0,0 +1,166 @@
#pragma once
#include <ATen/core/TensorBase.h>
namespace at::detail {
inline void check_size_nonnegative(ArrayRef<int64_t> size) {
for (const auto& x : size) {
TORCH_CHECK(
x >= 0,
"Trying to create tensor with negative dimension ",
x,
": ",
size);
}
}
inline void check_size_nonnegative(ArrayRef<c10::SymInt> size) {
for (const auto& x : size) {
TORCH_CHECK(
x.expect_size(__FILE__, __LINE__),
"Trying to create tensor with negative dimension ",
x,
": ",
size);
}
}
TORCH_API size_t computeStorageNbytesContiguous(
IntArrayRef sizes,
size_t itemsize,
size_t storage_offset = 0);
TORCH_API SymInt computeStorageNbytesContiguous(
SymIntArrayRef sizes,
const SymInt& itemsize,
const SymInt& storage_offset = 0);
TORCH_API size_t computeStorageNbytes(
IntArrayRef sizes,
IntArrayRef strides,
size_t itemsize,
size_t storage_offset = 0);
TORCH_API SymInt computeStorageNbytes(
SymIntArrayRef sizes,
SymIntArrayRef strides,
const SymInt& itemsize,
const SymInt& storage_offset = 0);
TORCH_API TensorBase empty_generic(
IntArrayRef size,
c10::Allocator* allocator,
c10::DispatchKeySet ks,
ScalarType scalar_type,
std::optional<c10::MemoryFormat> memory_format_opt);
TORCH_API TensorBase empty_generic_symint(
SymIntArrayRef size,
c10::Allocator* allocator,
c10::DispatchKeySet ks,
ScalarType scalar_type,
std::optional<c10::MemoryFormat> memory_format_opt);
TORCH_API TensorBase empty_strided_generic(
IntArrayRef size,
IntArrayRef stride,
c10::Allocator* allocator,
c10::DispatchKeySet ks,
ScalarType scalar_type);
TORCH_API TensorBase empty_strided_symint_generic(
SymIntArrayRef size,
SymIntArrayRef stride,
c10::Allocator* allocator,
c10::DispatchKeySet ks,
ScalarType scalar_type);
TORCH_API TensorBase empty_cpu(
IntArrayRef size,
ScalarType dtype,
bool pin_memory = false,
std::optional<c10::MemoryFormat> memory_format_opt = std::nullopt);
TORCH_API TensorBase empty_cpu(
IntArrayRef size,
std::optional<ScalarType> dtype_opt,
std::optional<Layout> layout_opt,
std::optional<Device> device_opt,
std::optional<bool> pin_memory_opt,
std::optional<c10::MemoryFormat> memory_format_opt);
TORCH_API TensorBase empty_cpu(IntArrayRef size, const TensorOptions& options);
TORCH_API TensorBase empty_strided_cpu(
IntArrayRef size,
IntArrayRef stride,
ScalarType dtype,
bool pin_memory = false);
TORCH_API TensorBase empty_strided_cpu(
IntArrayRef size,
IntArrayRef stride,
std::optional<ScalarType> dtype_opt,
std::optional<Layout> layout_opt,
std::optional<Device> device_opt,
std::optional<bool> pin_memory_opt);
TORCH_API TensorBase empty_strided_cpu(
IntArrayRef size,
IntArrayRef stride,
const TensorOptions& options);
TORCH_API TensorBase empty_meta(
IntArrayRef size,
ScalarType dtype,
std::optional<c10::MemoryFormat> memory_format_opt = std::nullopt);
TORCH_API TensorBase empty_meta(
IntArrayRef size,
std::optional<ScalarType> dtype_opt,
std::optional<Layout> layout_opt,
std::optional<Device> device_opt,
std::optional<bool> pin_memory_opt,
std::optional<c10::MemoryFormat> memory_format_opt);
TORCH_API TensorBase empty_symint_meta(
SymIntArrayRef size,
std::optional<ScalarType> dtype_opt,
std::optional<Layout> layout_opt,
std::optional<Device> device_opt,
std::optional<bool> pin_memory_opt,
std::optional<c10::MemoryFormat> memory_format_opt);
TORCH_API TensorBase empty_meta(IntArrayRef size, const TensorOptions& options);
TORCH_API TensorBase
empty_strided_meta(IntArrayRef size, IntArrayRef stride, ScalarType dtype);
TORCH_API TensorBase empty_strided_meta(
IntArrayRef size,
IntArrayRef stride,
std::optional<ScalarType> dtype_opt,
std::optional<Layout> layout_opt,
std::optional<Device> device_opt,
std::optional<bool> pin_memory_opt);
TORCH_API TensorBase empty_strided_meta(
IntArrayRef size,
IntArrayRef stride,
const TensorOptions& options);
TORCH_API TensorBase empty_strided_symint_meta(
SymIntArrayRef size,
SymIntArrayRef stride,
ScalarType dtype);
TORCH_API TensorBase empty_strided_symint_meta(
SymIntArrayRef size,
SymIntArrayRef stride,
std::optional<ScalarType> dtype_opt,
std::optional<Layout> layout_opt,
std::optional<Device> device_opt);
TORCH_API TensorBase empty_strided_symint_meta(
SymIntArrayRef size,
SymIntArrayRef stride,
const TensorOptions& options);
} // namespace at::detail

View File

@ -0,0 +1,30 @@
#include <ATen/core/TensorBase.h>
// Broadcasting utilities for working with TensorBase
namespace at {
namespace internal {
TORCH_API TensorBase expand_slow_path(const TensorBase& self, IntArrayRef size);
} // namespace internal
inline c10::MaybeOwned<TensorBase> expand_size(
const TensorBase& self,
IntArrayRef size) {
if (size.equals(self.sizes())) {
return c10::MaybeOwned<TensorBase>::borrowed(self);
}
return c10::MaybeOwned<TensorBase>::owned(
at::internal::expand_slow_path(self, size));
}
c10::MaybeOwned<TensorBase> expand_size(TensorBase&& self, IntArrayRef size) =
delete;
inline c10::MaybeOwned<TensorBase> expand_inplace(
const TensorBase& tensor,
const TensorBase& to_expand) {
return expand_size(to_expand, tensor.sizes());
}
c10::MaybeOwned<TensorBase> expand_inplace(
const TensorBase& tensor,
TensorBase&& to_expand) = delete;
} // namespace at

View File

@ -0,0 +1,527 @@
#pragma once
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/view.h>
#include <ATen/ops/view_copy.h>
#endif
#include <ATen/Tensor.h>
#include <ATen/core/DimVector.h>
#include <c10/util/Exception.h>
#include <c10/util/MaybeOwned.h>
#include <c10/util/irange.h>
#include <functional>
#include <tuple>
#include <utility>
namespace at {
TORCH_API std::vector<int64_t> infer_size(IntArrayRef a, IntArrayRef b);
TORCH_API std::vector<SymInt> infer_size_symint(
SymIntArrayRef a,
SymIntArrayRef b);
TORCH_API DimVector infer_size_dimvector(IntArrayRef a, IntArrayRef b);
TORCH_API SymDimVector
infer_size_symdimvector(SymIntArrayRef a, SymIntArrayRef b);
// Named type instead of a pair/tuple so that we can be sure to
// construct the vectors in place and get NRVO.
template <typename Container>
struct InferExpandGeometryResult {
Container sizes;
Container strides;
explicit InferExpandGeometryResult(size_t ndim)
: sizes(ndim), strides(ndim) {}
explicit InferExpandGeometryResult(IntArrayRef sizes_, size_t ndim)
: sizes(sizes_.begin(), sizes_.end()), strides(ndim) {}
};
TORCH_API std::tuple<std::vector<int64_t>, std::vector<int64_t>>
inferExpandGeometry(
IntArrayRef tensor_sizes,
IntArrayRef tensor_strides,
IntArrayRef sizes);
TORCH_API InferExpandGeometryResult<DimVector> inferExpandGeometry_dimvector(
IntArrayRef tensor_sizes,
IntArrayRef tensor_strides,
IntArrayRef sizes);
TORCH_API std::vector<int64_t> infer_dense_strides(
IntArrayRef tensor_sizes,
IntArrayRef tensor_strides);
// True if input shapes are expandable
// NOTE: infer_size does a similar check; please keep them in sync if a change
// is needed
inline bool are_expandable(IntArrayRef shape1, IntArrayRef shape2) {
size_t ndim1 = shape1.size();
size_t ndim2 = shape2.size();
size_t ndim = ndim1 < ndim2 ? ndim1 : ndim2;
for (int64_t i = static_cast<int64_t>(ndim) - 1; i >= 0; --i) {
if (shape1[--ndim1] == shape2[--ndim2] || shape1[ndim1] == 1 ||
shape2[ndim2] == 1) {
continue;
}
return false;
}
return true;
}
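// Illustrative sketch (editor's addition): shapes are compared right-to-left,
// and a pair of trailing dims is compatible when the sizes match or either is
// 1. The literal shapes below are hypothetical.
inline bool are_expandable_example() {
  // (3, 1) vs (2, 3, 4): 1 broadcasts against 4 and 3 matches 3 -> expandable.
  // (3, 2) vs (2, 3, 4): 2 vs 4 with neither equal to 1 -> not expandable.
  return are_expandable({3, 1}, {2, 3, 4}) &&
      !are_expandable({3, 2}, {2, 3, 4});
}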
// avoid copy-construction of Tensor by using a reference_wrapper.
inline void check_defined(
std::initializer_list<std::reference_wrapper<const Tensor>> tensors,
const char* api_name) {
for (auto& t : tensors) {
if (!t.get().defined()) {
AT_ERROR(api_name, "(...) called with an undefined Tensor");
}
}
}
// NOTE [ ExpandUtils Borrowing ]
//
// Functions in ExpandUtils return `c10::MaybeOwned<Tensor>` because
// expansion may not actually be needed, in which case we can improve
// efficiency by returning
// `c10::MaybeOwned<Tensor>::borrowed(to_expand)`. However, this means
// that you need to be careful: the returned `c10::MaybeOwned<Tensor>`
// must not outlive the original `Tensor` object that `to_expand`
// referred to! The deleted rvalue reference overloads of these
// functions help with this by preventing trivial use of a temporary
// resulting from a function call, but it is still possible to make a
// mistake.
inline c10::MaybeOwned<Tensor> expand_inplace(
const Tensor& tensor,
const Tensor& to_expand) {
if (tensor.sym_sizes().equals(to_expand.sym_sizes())) {
return c10::MaybeOwned<Tensor>::borrowed(to_expand);
}
return c10::MaybeOwned<Tensor>::owned(
to_expand.expand_symint(tensor.sym_sizes()));
}
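// Illustrative sketch (editor's addition): the intended lifetime pattern for
// the MaybeOwned result described in the note above - it is consumed inside
// the scope that owns `src` and never outlives it. add_ stands in for any
// in-place op, and `src` is assumed to be broadcastable to `dst`'s sizes.
inline void expand_inplace_usage_sketch(Tensor& dst, const Tensor& src) {
  // Borrows `src` when its sizes already match `dst`, otherwise owns a fresh
  // expanded tensor; either way, `expanded` must not outlive `src`.
  c10::MaybeOwned<Tensor> expanded = expand_inplace(dst, src);
  dst.add_(*expanded);
}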
inline c10::MaybeOwned<Tensor> expand_inplace(
const Tensor& tensor,
Tensor&& to_expand) = delete;
inline c10::MaybeOwned<Tensor> expand_inplace(
const Tensor& tensor,
const Tensor& to_expand,
const char* api_name) {
check_defined({tensor, to_expand}, api_name);
return expand_inplace(tensor, to_expand);
}
inline c10::MaybeOwned<Tensor> expand_inplace(
const Tensor& tensor,
Tensor&& to_expand,
const char* api_name) = delete;
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
expand_inplace(
const Tensor& tensor,
const Tensor& to_expand1,
const Tensor& to_expand2) {
if (tensor.sizes().equals(to_expand1.sizes()) &&
tensor.sizes().equals((to_expand2.sizes()))) {
return std::make_tuple(
c10::MaybeOwned<Tensor>::borrowed(to_expand1),
c10::MaybeOwned<Tensor>::borrowed(to_expand2));
}
return std::make_tuple(
c10::MaybeOwned<Tensor>::owned(to_expand1.expand(tensor.sizes())),
c10::MaybeOwned<Tensor>::owned(to_expand2.expand(tensor.sizes())));
}
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
expand_inplace(
const Tensor& tensor,
Tensor&& to_expand1,
const Tensor& to_expand2) = delete;
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
expand_inplace(
const Tensor& tensor,
const Tensor& to_expand1,
Tensor&& to_expand2) = delete;
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
expand_inplace(const Tensor& tensor, Tensor&& to_expand1, Tensor&& to_expand2) =
delete;
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
expand_inplace(
const Tensor& tensor,
const Tensor& to_expand1,
const Tensor& to_expand2,
const char* api_name) {
check_defined({tensor, to_expand1, to_expand2}, api_name);
return expand_inplace(tensor, to_expand1, to_expand2);
}
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
expand_inplace(
const Tensor& tensor,
Tensor&& to_expand1,
const Tensor& to_expand2,
const char* api_name) = delete;
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
expand_inplace(
const Tensor& tensor,
const Tensor& to_expand1,
Tensor&& to_expand2,
const char* api_name) = delete;
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
expand_inplace(
const Tensor& tensor,
Tensor&& to_expand1,
Tensor&& to_expand2,
const char* api_name) = delete;
// See NOTE [ ExpandUtils Borrowing ] above for `MaybeOwned` explanation.
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
expand_outplace(const Tensor& to_expand1, const Tensor& to_expand2) {
auto s1 = to_expand1.sym_sizes();
auto s2 = to_expand2.sym_sizes();
if (s1.equals(s2)) {
return std::make_tuple(
c10::MaybeOwned<Tensor>::borrowed(to_expand1),
c10::MaybeOwned<Tensor>::borrowed(to_expand2));
}
auto expanded_size = infer_size_symdimvector(s1, s2);
return std::make_tuple(
c10::MaybeOwned<Tensor>::owned(to_expand1.expand_symint(expanded_size)),
c10::MaybeOwned<Tensor>::owned(to_expand2.expand_symint(expanded_size)));
}
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
expand_outplace(Tensor&& to_expand1, const Tensor& to_expand2) = delete;
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
expand_outplace(const Tensor& to_expand1, Tensor&& to_expand2) = delete;
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
expand_outplace(Tensor&& to_expand1, Tensor&& to_expand2) = delete;
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
expand_outplace(
const Tensor& to_expand1,
const Tensor& to_expand2,
const char* api_name) {
check_defined({to_expand1, to_expand2}, api_name);
return expand_outplace(to_expand1, to_expand2);
}
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
expand_outplace(
Tensor&& to_expand1,
const Tensor& to_expand2,
const char* api_name) = delete;
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
expand_outplace(
const Tensor& to_expand1,
Tensor&& to_expand2,
const char* api_name) = delete;
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
expand_outplace(
Tensor&& to_expand1,
Tensor&& to_expand2,
const char* api_name) = delete;
inline std::tuple<
c10::MaybeOwned<Tensor>,
c10::MaybeOwned<Tensor>,
c10::MaybeOwned<Tensor>>
expand_outplace(
const Tensor& to_expand1,
const Tensor& to_expand2,
const Tensor& to_expand3) {
if (to_expand1.sizes().equals(to_expand2.sizes()) &&
to_expand1.sizes().equals(to_expand3.sizes())) {
return std::make_tuple(
c10::MaybeOwned<Tensor>::borrowed(to_expand1),
c10::MaybeOwned<Tensor>::borrowed(to_expand2),
c10::MaybeOwned<Tensor>::borrowed(to_expand3));
}
auto expanded_size12 =
infer_size_dimvector(to_expand1.sizes(), to_expand2.sizes());
auto expanded_size =
infer_size_dimvector(expanded_size12, to_expand3.sizes());
return std::make_tuple(
c10::MaybeOwned<Tensor>::owned(to_expand1.expand(expanded_size)),
c10::MaybeOwned<Tensor>::owned(to_expand2.expand(expanded_size)),
c10::MaybeOwned<Tensor>::owned(to_expand3.expand(expanded_size)));
}
inline std::tuple<
c10::MaybeOwned<Tensor>,
c10::MaybeOwned<Tensor>,
c10::MaybeOwned<Tensor>>
expand_outplace(
Tensor&& to_expand1,
const Tensor& to_expand2,
const Tensor& to_expand3) = delete;
inline std::tuple<
c10::MaybeOwned<Tensor>,
c10::MaybeOwned<Tensor>,
c10::MaybeOwned<Tensor>>
expand_outplace(
const Tensor& to_expand1,
Tensor&& to_expand2,
const Tensor& to_expand3) = delete;
inline std::tuple<
c10::MaybeOwned<Tensor>,
c10::MaybeOwned<Tensor>,
c10::MaybeOwned<Tensor>>
expand_outplace(
Tensor&& to_expand1,
Tensor&& to_expand2,
const Tensor& to_expand3) = delete;
inline std::tuple<
c10::MaybeOwned<Tensor>,
c10::MaybeOwned<Tensor>,
c10::MaybeOwned<Tensor>>
expand_outplace(
const Tensor& to_expand1,
const Tensor& to_expand2,
Tensor&& to_expand3) = delete;
inline std::tuple<
c10::MaybeOwned<Tensor>,
c10::MaybeOwned<Tensor>,
c10::MaybeOwned<Tensor>>
expand_outplace(
Tensor&& to_expand1,
const Tensor& to_expand2,
Tensor&& to_expand3) = delete;
inline std::tuple<
c10::MaybeOwned<Tensor>,
c10::MaybeOwned<Tensor>,
c10::MaybeOwned<Tensor>>
expand_outplace(
const Tensor& to_expand1,
Tensor&& to_expand2,
Tensor&& to_expand3) = delete;
inline std::tuple<
c10::MaybeOwned<Tensor>,
c10::MaybeOwned<Tensor>,
c10::MaybeOwned<Tensor>>
expand_outplace(Tensor&& to_expand1, Tensor&& to_expand2, Tensor&& to_expand3) =
delete;
inline std::tuple<
c10::MaybeOwned<Tensor>,
c10::MaybeOwned<Tensor>,
c10::MaybeOwned<Tensor>>
expand_outplace(
const Tensor& to_expand1,
const Tensor& to_expand2,
const Tensor& to_expand3,
const char* api_name) {
check_defined({to_expand1, to_expand2, to_expand3}, api_name);
return expand_outplace(to_expand1, to_expand2, to_expand3);
}
inline std::tuple<
c10::MaybeOwned<Tensor>,
c10::MaybeOwned<Tensor>,
c10::MaybeOwned<Tensor>>
expand_outplace(
Tensor&& to_expand1,
const Tensor& to_expand2,
const Tensor& to_expand3,
const char* api_name) = delete;
inline std::tuple<
c10::MaybeOwned<Tensor>,
c10::MaybeOwned<Tensor>,
c10::MaybeOwned<Tensor>>
expand_outplace(
const Tensor& to_expand1,
Tensor&& to_expand2,
const Tensor& to_expand3,
const char* api_name) = delete;
inline std::tuple<
c10::MaybeOwned<Tensor>,
c10::MaybeOwned<Tensor>,
c10::MaybeOwned<Tensor>>
expand_outplace(
Tensor&& to_expand1,
Tensor&& to_expand2,
const Tensor& to_expand3,
const char* api_name) = delete;
inline std::tuple<
c10::MaybeOwned<Tensor>,
c10::MaybeOwned<Tensor>,
c10::MaybeOwned<Tensor>>
expand_outplace(
const Tensor& to_expand1,
const Tensor& to_expand2,
Tensor&& to_expand3,
const char* api_name) = delete;
inline std::tuple<
c10::MaybeOwned<Tensor>,
c10::MaybeOwned<Tensor>,
c10::MaybeOwned<Tensor>>
expand_outplace(
Tensor&& to_expand1,
const Tensor& to_expand2,
Tensor&& to_expand3,
const char* api_name) = delete;
inline std::tuple<
c10::MaybeOwned<Tensor>,
c10::MaybeOwned<Tensor>,
c10::MaybeOwned<Tensor>>
expand_outplace(
const Tensor& to_expand1,
Tensor&& to_expand2,
Tensor&& to_expand3,
const char* api_name) = delete;
inline std::tuple<
c10::MaybeOwned<Tensor>,
c10::MaybeOwned<Tensor>,
c10::MaybeOwned<Tensor>>
expand_outplace(
Tensor&& to_expand1,
Tensor&& to_expand2,
Tensor&& to_expand3,
const char* api_name) = delete;
inline c10::MaybeOwned<Tensor> expand_size(
const Tensor& to_expand,
IntArrayRef sizes) {
if (to_expand.sizes().equals(sizes)) {
return c10::MaybeOwned<Tensor>::borrowed(to_expand);
}
return c10::MaybeOwned<Tensor>::owned(to_expand.expand(sizes));
}
inline c10::MaybeOwned<Tensor> expand_size(
Tensor&& to_expand,
IntArrayRef sizes) = delete;
inline c10::MaybeOwned<Tensor> expand_size(
const Tensor& to_expand,
IntArrayRef sizes,
const char* api_name) {
check_defined({to_expand}, api_name);
return expand_size(to_expand, sizes);
}
inline c10::MaybeOwned<Tensor> expand_size(
Tensor&& to_expand,
IntArrayRef sizes,
const char* api_name) = delete;
inline std::vector<Tensor> expand_outplace(TensorList to_expand) {
// expands a list of Tensors; ignores undefined (null) tensors
bool first = true;
DimVector sizes;
for (const auto i : c10::irange(to_expand.size())) {
if (!to_expand[i].defined()) {
continue;
} else if (first) {
sizes = to_expand[i].sizes();
first = false;
} else {
sizes = infer_size_dimvector(sizes, to_expand[i].sizes());
}
}
std::vector<Tensor> result(to_expand.size());
for (const auto i : c10::irange(to_expand.size())) {
if (!to_expand[i].defined()) {
continue;
} else if (to_expand[i].sizes().equals(sizes)) {
result[i] = to_expand[i];
} else {
result[i] = to_expand[i].expand(sizes);
}
}
return result;
}
template <typename T>
inline Tensor _sum_to(
Tensor tensor,
const c10::ArrayRef<T> shape,
bool always_return_non_view = false) {
if (shape.size() == 0) {
return tensor.sum();
}
auto sizes = at::symint::sizes<T>(tensor);
c10::SmallVector<int64_t, 8> reduce_dims;
const int64_t leading_dims = sizes.size() - shape.size();
for (const auto i : c10::irange(leading_dims)) {
reduce_dims.push_back(i);
}
for (int64_t i = leading_dims; i < static_cast<int64_t>(sizes.size()); ++i) {
if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(shape[i - leading_dims], 1)) &&
TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(sizes[i], 1))) {
reduce_dims.push_back(i);
}
}
if (!reduce_dims.empty()) {
tensor = tensor.sum(reduce_dims, /*keepdim=*/true);
}
if (always_return_non_view) {
// This is only actually used by the functionalization pass.
// We want to be able to guarantee that this function doesn't return a view
// of the input.
return leading_dims > 0 ? at::symint::view_copy<T>(tensor, shape)
: tensor.clone();
} else {
return leading_dims > 0 ? at::symint::view<T>(tensor, shape) : tensor;
}
}
inline Tensor sum_to(
Tensor tensor,
const c10::SymIntArrayRef shape,
bool always_return_non_view = false) {
return _sum_to(std::move(tensor), shape, always_return_non_view);
}
// Sums `tensor` repeatedly to produce a tensor of shape `shape`.
// Precondition: is_expandable_to(shape, tensor.sizes()) must be true
inline Tensor sum_to(
Tensor tensor,
const IntArrayRef shape,
bool always_return_non_view = false) {
return _sum_to(std::move(tensor), shape, always_return_non_view);
}
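// Illustrative sketch (editor's addition): the typical autograd-style use of
// sum_to - reducing a broadcasted gradient back to a parameter's shape. The
// `param_shape` argument is hypothetical and must satisfy the precondition
// above: is_expandable_to(param_shape, grad.sizes()).
inline Tensor sum_to_usage_sketch(const Tensor& grad, IntArrayRef param_shape) {
  // E.g. a grad of shape (2, 3) summed to param_shape (1, 3) reduces dim 0
  // with keepdim=true, producing a tensor of shape (1, 3).
  return sum_to(grad, param_shape);
}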
inline bool is_expandable_to(
SymIntArrayRef shape,
c10::SymIntArrayRef desired) {
size_t ndim = shape.size();
size_t target_dim = desired.size();
if (ndim > target_dim) {
return false;
}
for (const auto i : c10::irange(ndim)) {
const auto& size = shape[ndim - i - 1];
const auto& target = desired[target_dim - i - 1];
if (size != target && size != 1) {
return false;
}
}
return true;
}
inline bool is_expandable_to(IntArrayRef shape, IntArrayRef desired) {
auto sym_shape = c10::SymIntArrayRef(
reinterpret_cast<const c10::SymInt*>(shape.data()), shape.size());
auto sym_desired = c10::SymIntArrayRef(
reinterpret_cast<const c10::SymInt*>(desired.data()), desired.size());
return is_expandable_to(sym_shape, sym_desired);
}
} // namespace at

View File

@ -0,0 +1 @@
#include <ATen/core/Formatting.h>

View File

@ -0,0 +1,46 @@
#pragma once
#include <c10/macros/Macros.h>
#include <memory>
namespace at::functorch {
// NOTE [functorch TLS in pytorch/pytorch]
//
// functorch lives out-of-tree. However, it has some TLS that needs to be
// propagated. The solution for that is we store a pointer to the TLS
// inside pytorch/pytorch and extend FuncTorchTLSBase inside functorch to
// include whatever functorch needs.
//
// We need to store a pointer due to the indirection:
// inside functorch, we will create a subclass of FunctorchTLSBase called
// FuncTorchTLSImpl that actually contains metadata, like the DynamicLayerStack.
// FuncTorchTLSBase doesn't have any metadata because it hasn't been defined
// yet.
//
// Here in pytorch/pytorch, we will pass around FuncTorchTLSBase*, but inside
// functorch, we will assign a FuncTorchTLSImpl* to the FunctorchTLSBase*.
// We can't directly pass around FunctorchTLSBase (without a pointer) because
// FuncTorchTLSImpl does not fit inside a FuncTorchTLSBase by virtue of having
// more elements.
struct TORCH_API FuncTorchTLSBase {
virtual ~FuncTorchTLSBase() = default;
virtual std::unique_ptr<FuncTorchTLSBase> deepcopy() const = 0;
virtual int64_t checkSupportsSingleLevelAutogradFunction() const = 0;
virtual void checkSupportsCppAutogradFunction() const = 0;
virtual void checkSupportsInplaceRequiresGrad() const = 0;
virtual void checkSupportsRetainGrad() const = 0;
};
// returns deepcopy of the functorch tls
TORCH_API std::unique_ptr<FuncTorchTLSBase> getCopyOfFuncTorchTLS();
// sets the functorch tls. always does a deep copy.
TORCH_API void setFuncTorchTLS(
const std::shared_ptr<const FuncTorchTLSBase>& state);
// get a mutable reference to the functorch tls
TORCH_API std::unique_ptr<FuncTorchTLSBase>& functorchTLSAccessor();
} // namespace at::functorch

View File

@ -0,0 +1,208 @@
#pragma once
#include <ATen/Tensor.h>
#include <utility>
namespace at::functionalization {
// See Note [Functionalization Pass In Core]
// ViewMeta is a class used by the functionalization pass to navigate between
// a base tensor and a view tensor.
// For example, if I call `b = a.view1(...)`
// the functionalization pass will generate and store a ViewMeta on b that looks
// like:
//
// ViewMeta(
//   [<captures>](const Tensor& base, int64_t mutated_view_idx) {
//     return base.view1(...);
//   },
//   [<captures>](const at::Tensor& base, const at::Tensor& mutated_view,
//                int64_t mutated_view_idx) -> at::Tensor {
//     return at::functionalization::impl::view1_inverse(base, mutated_view, ...);
//   })
//
// The forward_fn lambda describes how to replay view1 on a tensor.
//
// The reverse_fn lambda describes how, given a tensor that is already a view,
// to get back the corresponding base tensor. See Note [Functionalization Pass:
// View Inverses] for details.
struct ViewMeta {
ViewMeta(
std::function<Tensor(const Tensor&, int64_t)> forward,
std::function<Tensor(const Tensor&, const Tensor&, int64_t)> reverse,
bool has_symbolic_inputs,
bool is_multi_output = false,
bool is_as_strided = false,
int64_t out_idx = 0)
: forward_fn(std::move(forward)),
reverse_fn(std::move(reverse)),
out_index(out_idx),
is_multi_output(is_multi_output),
is_as_strided(is_as_strided),
has_symbolic_inputs(has_symbolic_inputs) {}
std::function<Tensor(const Tensor&, int64_t)> forward_fn;
std::function<Tensor(const Tensor&, const Tensor&, int64_t)> reverse_fn;
// See Note [out_idx in ViewMeta]
int64_t out_index;
// Tells us if this is a multi-output view
bool is_multi_output;
bool is_as_strided;
// Tells us if this view operation has any symbolic inputs
bool has_symbolic_inputs;
// Returns a copy of the current ViewMeta, if out_idx matches the current
// out_index. Otherwise, returns a new ViewMeta with the same forward/reverse
// functions, but a new out index.
ViewMeta to_out_idx(int64_t out_idx);
};
// FunctionalStorageImpl is a subclass of StorageImpl used by the
// functionalization pass. It has no underlying data (similar to meta storage).
// It also knows how to reflect mutations to tensors in the absence of a valid
// data pointer.
//
// A storage represents the state shared by (potentially multiple) views of the
// same tensor. For example, in the following code:
//
// b = a.view1(...)
// c = b.view2(...)
// b.add_(1)
// --> storage.add_update(b, {view1_meta})
//
// The call to add_(1) will result in a call to alias.add_update(b,
// {view1_meta}), queueing up the mutation from b onto the alias. Later, suppose
// c is used in an expression (e.g. you try to print c, or pass it to an
// operator). Doing so will involve "syncing" c. First we apply any pending
// updates to the alias, and then we regenerate c by replaying its views off of
// the updated alias. E.g.:
//
// print(str(c))
// --> c.sync_()
// --> alias.apply_updates() // after this, the alias will be updated to
// reflect the mutation to b
struct TORCH_API FunctionalStorageImpl : public c10::StorageImpl {
public:
struct Update {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
const at::Tensor new_val;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
const std::vector<ViewMeta> view_metas;
};
explicit FunctionalStorageImpl(const Tensor& value);
void add_update(
const Tensor& updated_val,
const std::vector<ViewMeta>& view_metas);
bool apply_updates();
const Tensor& base() {
return base_;
}
size_t generation() const {
return generation_;
}
void freeze() {
frozen_ = true;
}
c10::SymInt get_storage_size(bool before) {
if (before) {
return original_storage_size_;
} else {
return curr_storage_size_;
}
}
~FunctionalStorageImpl() override = default;
void mark_mutation() {
mutation_counter_++;
}
void mark_mutation_during_no_grad_or_inference_mode() {
mutation_counter_during_no_grad_or_inference_mode_++;
}
void mark_mutation_hidden_from_autograd() {
mutation_counter_hidden_from_autograd_++;
}
bool are_all_mutations_under_no_grad_or_inference_mode() const {
auto non_autograd_mutations =
mutation_counter_during_no_grad_or_inference_mode_ +
mutation_counter_hidden_from_autograd_;
// The <= is because both counters will technically be incremented, if we
// perform e.g. a triton kernel mutation under no_grad
return mutation_counter_ <= non_autograd_mutations;
}
bool are_all_mutations_hidden_from_autograd() const {
// mutations under no_grad / inference_mode are technically not hidden from
// autograd - they change the version counter
return mutation_counter_ <= mutation_counter_hidden_from_autograd_;
}
void mark_inductor_storage_resize(c10::SymInt new_size) {
inductor_storage_resized_ = true;
curr_storage_size_ = std::move(new_size);
}
bool was_inductor_storage_resized() {
return inductor_storage_resized_;
}
private:
// NB: base_ should always point to a tensor BELOW the current
// functionalization layer. This is mainly to avoid reference cycles. e.g.
// given `b = a.view(...)`, both a.storage_ and b.storage_ are a
// FunctionalStorageImpl containing an Alias, which contains a Tensor
// `base_`. In this case (where a and b are FunctionalTensorWrapper's), base_
// should point not to a, but to a's unwrapped value, `a.value_`. See Note
// [Functionalization: Alias Removal] for a diagram that shows this
// visually.
at::Tensor base_;
std::vector<Update> updates_;
// generation_ gets incremented every time a mutation is queued onto the
// alias. It is used to determine if a given tensor is "up to date", or if it
// needs to be regenerated from the alias.
size_t generation_ = 0;
// If frozen, no more mutations are allowed on this storage. Once frozen, a
// storage cannot be unfrozen.
bool frozen_ = false;
// These mutation counters are bumped on the storage
// whenever a FunctionalTensorWrapper experiences a mutation.
// When the mutation is under no_grad, or comes from a triton kernel, we also
// bump the corresponding during_no_grad or hidden_from_autograd counters. Why
// do we need to detect these two situations separately from "normal" input
// mutations? (1) "normal" input mutations can mutate autograd metadata like
// .grad_fn,
// in which case they need to be replayed outside of the compiled graph
// (2) "no_grad" input mutations are generally safe to keep in the graph (and
// compile),
// but they bump the tensor's VC, so we need to mark_dirty() on the inputs
// in torch.compile
// (3) mutations that are fully hidden from autograd (e.g. from a triton
// kernel)
// do not mutate any autograd state, and can be fully kept in the graph
// When we detect that an input was mutated, we need to be able to tell if:
// (1) all of the mutations were from triton kernels
// (2) all of the mutations were under no_grad
uint64_t mutation_counter_during_no_grad_or_inference_mode_ = 0;
uint64_t mutation_counter_ = 0;
uint64_t mutation_counter_hidden_from_autograd_ = 0;
// Used to tell if:
// (1) There were any storage resizes on a graph input
// (2) The original/curr storage size tell us if these resizes result in a nop
bool inductor_storage_resized_ = false;
c10::SymInt original_storage_size_;
c10::SymInt curr_storage_size_;
};
} // namespace at::functionalization

View File

@ -0,0 +1,454 @@
#pragma once
#include <ATen/ArrayRef.h>
#include <ATen/FunctionalStorageImpl.h>
#include <ATen/core/IListRef.h>
#include <ATen/core/List.h>
#include <ATen/core/boxing/BoxedKernel.h>
#include <ATen/core/boxing/impl/boxing.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <c10/core/DispatchKey.h>
namespace at {
// Note [Functionalization Pass In Core]
// The Functionalization pass is used to remove aliasing from a pytorch program.
//
// This is useful for backends that don't support aliasing, like XLA and Vulkan.
// It's also necessary in order to remove mutation from a program, which is
// needed in Functorch.
//
// Consider this program:
// a = torch.ones(...)
// b = a.view(...)
// b.add_(1)
//
// In this program, b is meant to alias with a due to the use of view(). At the
// end of the program, both a and b are full of 2's. However, backends that
// don't support aliasing aren't able to correctly implement the view()
// operator. Instead, they can opt into the Functionalization pass, which will
// sit between the user and the backend, and provide the necessary aliasing
// logic.
//
// The functionalization pass will turn the above program into a slightly
// different program that has the same semantics, transparently to the user,
// that backends like XLA/Vulkan are able to implement:
//   a = torch.ones(...)
//   b = a.view_copy(...)  # view() replaced with view_copy();
//                         # backends like XLA/Vulkan can implement this!
//   b.add_(1)
//   a.add_(1)             # Our functionalization pass machinery knows that a
//                         # and b are aliased - it applies b's mutation to a too.
//
// So, how does the functionalization pass keep track of which tensors are
// aliased? The pass works by wrapping EVERY tensor in the program inside of a
// FunctionalTensorWrapper, which knows about its alias'd tensors.
//
// See Note [Functionalization: Alias Removal] for details on the aliasing
// machinery. See Note [Functionalization: Mutation Removal] for details on
// mutation removal.
struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
explicit FunctionalTensorWrapper(const Tensor& value);
// Additional constructor to create a FunctionalTensorWrapper directly from an
// underlying tensor that was created from a view. For example, the code b =
// a.view1() will generate a constructor call to FunctionalTensorWrapper(b, a,
// view1_meta)
explicit FunctionalTensorWrapper(
const Tensor& view_value,
const FunctionalTensorWrapper* base,
const functionalization::ViewMeta& meta);
// Get the underlying, actual tensor, that doesn't know anything about
// functionalization.
const Tensor& value() const {
return value_;
};
// The concept of "level" is only ever important to functorch; it's exposed
// here as more of a hook for functorch to use.
int64_t level() const {
return level_;
};
void set_level(int64_t level) {
level_ = level;
}
bool has_metadata_mutation() const {
return has_metadata_mutation_;
};
void mark_mutation() {
functional_storage_impl()->mark_mutation();
}
// Denotes a mutation that's hidden from autograd,
// e.g. for the purposes of passing a tensor to a triton kernel
void mark_mutation_hidden_from_autograd() {
functional_storage_impl()->mark_mutation_hidden_from_autograd();
}
void mark_mutation_during_no_grad_or_inference_mode() {
functional_storage_impl()->mark_mutation_during_no_grad_or_inference_mode();
}
// Are all the mutations happening to the tensor hidden from autograd
bool are_all_mutations_hidden_from_autograd() const {
return functional_storage_impl()->are_all_mutations_hidden_from_autograd();
}
// Did all mutations happen under no_grad or inference_mode
// (We also need to ignore mutations fully hidden from autograd here)
bool are_all_mutations_under_no_grad_or_inference_mode() const {
return functional_storage_impl()
->are_all_mutations_under_no_grad_or_inference_mode();
}
void maybe_mark_symbolic(const functionalization::ViewMeta& meta) {
is_symbolic_ = is_symbolic_ | meta.has_symbolic_inputs;
}
bool is_symbolic() const {
return is_symbolic_;
}
// Runs the forward_fn of every ViewMeta collected in the current instance
// to some other base.
Tensor apply_view_metas(const Tensor& base);
// Sync's the underlying tensor with its alias, if it's out of date. This
// involves two steps: 1) Apply any pending updates/mutations to the alias 2)
// Replay the views (if any) to regenerate the current tensor off of the
// updated alias.
void sync_();
// Performs step (1) of the sync. This is its own public API because it's
// needed by view_inplace ops like transpose_. See Note [Functionalization
// Pass - Inplace View Ops]
void regenerate_from_base();
// Performs step (2) of the sync. This is its own public API because it's
// needed by functorch. functorch wants to make sure that all input tensors to
// a functionalized program have been properly synced so it can properly
// propagate mutations to inputs. It can't just call sync_(), because the
// FunctionalTensorWrapper will look like it has no aliases and sync_ will be
// a noop. We use the reference count on storage_ to determine if the wrapper
// is aliased, and by the time functorch is ready to propagate updates to
// inputs, any intermediate views of the input created by the program will
// have been deallocated. This function also returns whether or not the base
// actually had any updates to apply.
bool apply_updates();
// Takes the current state of value_ and snapshots it, sending it as a pending
// update to the alias.
void commit_update();
// When any tensor is mutated, the tensor increments its alias's "generation".
// Separately, each tensor maintains its own "generation" counter, which is
// used to determine if it's up-to-date with its alias. The act of syncing a
// tensor will set a tensor's generation equal to its alias's generation.
bool is_up_to_date() const;
// Freezes the storage of this tensor, preventing subsequent mutations
void freeze_storage() const;
// Every FunctionalTensorWrapper contains a vector of ViewMeta objects
// describing the series of view ops that ran to generate the current tensor
// from the base tensor. This method is used by inplace-view ops like
// transpose_. It appends a ViewMeta to the existing stack, and refreshes the
// tensor by replaying the views off of the alias.
void mutate_view_meta(const at::functionalization::ViewMeta& meta);
// Custom implementation of self.set_(src)
void set__impl(const FunctionalTensorWrapper* other);
// Custom implementation of resize_storage_bytes_(self, new_size)
void storage_resize_(const c10::SymInt& new_size);
// Returns whether the current tensor's data was ever mutated
bool has_data_mutation();
//
// Returns whether the current FunctionalTensorWrapper
// experienced a set_() call.
bool was_storage_changed() {
return was_storage_changed_;
}
void set_storage_changed() {
was_storage_changed_ = true;
}
// A FunctionalTensor is considered a base if it is not a view of another
// tensor.
bool isBaseTensor() const {
return view_metas_.empty();
}
c10::SymInt get_storage_size(bool before) {
return functional_storage_impl()->get_storage_size(before);
}
// Returns whether the FunctionalTensor experienced an
// untyped_storage().resize_() call
bool was_inductor_storage_resized() {
return functional_storage_impl()->was_inductor_storage_resized();
}
// The functionalization pass can be used to remove mutations.
// It does so by replacing any mutation op with its corresponding
// out-of-place op, followed by a call to replace_(). E.g.:
//
// a.add_(1)
//
// will turn into:
//
// tmp = a.add(1)
// a.replace_(tmp)
//
// replace_() swaps out the wrapped tensor, value_, with tmp.
void replace_(const Tensor& other, bool from_lazy_regenerate = false);
bool is_multi_output_view() {
return is_multi_output_view_;
}
// See Note[resize_() in functionalization pass]
void maybe_replace_storage(const Tensor& other);
// Replaces the storage with a new functional storage,
// and clears the view_metas_ stack.
// WARNING: Calling this function will sever the aliasing relationship between
// the current FunctionalTensorWrapper and any of its outstanding aliases.
// Please only call if you know what you're doing.
void _unsafe_reset_storage();
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
const c10::VariableVersion& version_counter,
bool allow_tensor_metadata_change) const override;
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
c10::VariableVersion&& version_counter,
bool allow_tensor_metadata_change) const override;
~FunctionalTensorWrapper() override = default;
// FunctionalTensorWrapper overrides all custom size/stride function,
// so that if the inner tensor has a custom implementation
// we make sure to call that implementation.
at::IntArrayRef sizes_custom() const override;
at::IntArrayRef strides_custom() const override;
int64_t dim_custom() const override;
int64_t numel_custom() const override;
bool is_contiguous_custom(at::MemoryFormat memory_format) const override;
c10::SymIntArrayRef sym_sizes_custom() const override;
c10::SymInt sym_size_custom(int64_t d) const override;
c10::SymIntArrayRef sym_strides_custom() const override;
c10::SymInt sym_storage_offset_custom() const override;
c10::Device device_custom() const override;
c10::Layout layout_impl() const override;
private:
const char* tensorimpl_type_name() const override;
void set_constructor_metadata();
functionalization::FunctionalStorageImpl* functional_storage_impl() const;
// This is used to re-implement shallow_copy_and_detach for
// FunctionalTensorWrapper. The implementation is identical, but we just need
// to return a subclass instead of a plain TensorImpl.
// TODO: maybe it's possible to arrange for that to happen automatically
// without an override here?
template <typename VariableVersion>
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach_core(
VariableVersion&& version_counter,
bool allow_tensor_metadata_change) const;
void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override;
void copy_tensor_metadata_and_refresh(
const FunctionalTensorWrapper* src_impl,
FunctionalTensorWrapper* dest_impl,
const c10::VariableVersion& version_counter,
bool allow_tensor_metadata_change) const;
// Note that value is not taken by reference: internally, the wrapper will
// change the value tensor that it points to over time.
Tensor value_;
int64_t level_{};
// These two counters are used for identifying
// whether all the mutations on a given tensor are hidden from autograd or
// not. If we have an input mutation that is hidden from autograd, then once
// we convert the input mutation to a copy_() we know it will be safe to hide
// the copy_() from autograd as well.
bool has_metadata_mutation_ = false;
bool is_multi_output_view_ = false;
// Did the tensor experience a set_() call.
bool was_storage_changed_ = false;
// Did the tensor experience any view operation with symbolic int.
bool is_symbolic_ = false;
size_t generation_ = 0;
std::vector<at::functionalization::ViewMeta> view_metas_;
protected:
static void copy_tensor_metadata(
const FunctionalTensorWrapper* src_impl,
FunctionalTensorWrapper* dest_impl,
const c10::VariableVersion& version_counter,
bool allow_tensor_metadata_change);
};
// Utility functions for the functionalization pass.
namespace functionalization {
namespace impl {
TORCH_API inline FunctionalTensorWrapper* unsafeGetFunctionalWrapper(
const Tensor& tensor) {
auto functional_impl =
static_cast<FunctionalTensorWrapper*>(tensor.unsafeGetTensorImpl());
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(functional_impl != nullptr);
return functional_impl;
}
TORCH_API bool isBaseTensor(const at::Tensor& tensor);
TORCH_API bool isFunctionalTensor(const at::Tensor& tensor);
TORCH_API bool isFunctionalTensor(const std::optional<Tensor>& t);
TORCH_API bool isFunctionalTensor(
const c10::List<std::optional<Tensor>>& t_list);
TORCH_API bool isFunctionalTensor(ITensorListRef list);
TORCH_API Tensor to_functional_tensor(const Tensor& tensor);
TORCH_API std::optional<Tensor> to_functional_tensor(
const std::optional<Tensor>& tensor);
TORCH_API c10::List<std::optional<Tensor>> to_functional_tensor(
const c10::List<std::optional<Tensor>>& t_list);
TORCH_API std::vector<Tensor> to_functional_tensor(ITensorListRef t_list);
TORCH_API void freeze_functional_tensor(const Tensor& tensor);
TORCH_API Tensor
from_functional_tensor(const Tensor& tensor, bool assert_functional = true);
TORCH_API std::optional<Tensor> from_functional_tensor(
const std::optional<Tensor>& t,
bool assert_functional = true);
TORCH_API c10::List<std::optional<Tensor>> from_functional_tensor(
const c10::List<std::optional<Tensor>>& t_list);
TORCH_API std::vector<Tensor> from_functional_tensor(ITensorListRef t_list);
TORCH_API void sync(const at::Tensor& t);
TORCH_API void sync(const std::optional<Tensor>& t);
TORCH_API void sync(const c10::List<std::optional<Tensor>>& t_list);
TORCH_API void sync(ITensorListRef t_list);
TORCH_API void replace_(const Tensor& functional_tensor, const Tensor& other);
TORCH_API void replace_(
const ITensorListRef functional_tensor,
ITensorListRef other);
TORCH_API void commit_update(const Tensor& functional_tensor);
TORCH_API void commit_update(ITensorListRef functional_tensor);
TORCH_API void unsafe_reset_storage(const Tensor& functional_tensor);
TORCH_API void mark_mutation_hidden_from_autograd(
const Tensor& functional_tensor);
TORCH_API bool are_all_mutations_hidden_from_autograd(
const Tensor& functional_tensor);
TORCH_API bool are_all_mutations_under_no_grad_or_inference_mode(
const Tensor& functional_tensor);
// These two methods are XLA-specific logic and are no-ops
// for the normal functionalization flow.
TORCH_API void propagate_xla_data(
const Tensor& functional_tensor,
const Tensor& other);
TORCH_API void propagate_xla_data(
const ITensorListRef functional_tensor,
ITensorListRef other);
TORCH_API void propagate_xla_data_direct(
const Tensor& tensor,
const Tensor& other);
TORCH_API void propagate_xla_data_direct(
const ITensorListRef tensor,
ITensorListRef other);
Tensor create_functional_tensor_with_view_meta(
const Tensor& view_to_wrap,
const Tensor& base,
functionalization::ViewMeta meta,
int64_t out_idx = 0);
std::vector<Tensor> create_functional_tensor_with_view_meta(
ITensorListRef view_to_wrap,
const Tensor& base,
const functionalization::ViewMeta& meta);
void mutate_view_meta(
const Tensor& self,
const functionalization::ViewMeta& meta);
void set_sizes_strides_offset(const Tensor& out, const Tensor& meta_out);
void set_sizes_strides_offset(
const std::vector<Tensor>& outs,
const std::vector<Tensor>& meta_outs);
// ~~~~~ TLS used in functionalization ~~~~~
TORCH_API bool getFunctionalizationReapplyViewsTLS();
TORCH_API void setFunctionalizationReapplyViewsTLS(bool reapply_views);
class TORCH_API FunctionalizationReapplyViewsGuard {
public:
FunctionalizationReapplyViewsGuard(bool reapply_views)
: prev_(getFunctionalizationReapplyViewsTLS()) {
setFunctionalizationReapplyViewsTLS(reapply_views);
}
~FunctionalizationReapplyViewsGuard() {
setFunctionalizationReapplyViewsTLS(prev_);
}
FunctionalizationReapplyViewsGuard(
const FunctionalizationReapplyViewsGuard&) = delete;
FunctionalizationReapplyViewsGuard operator=(
const FunctionalizationReapplyViewsGuard&) = delete;
FunctionalizationReapplyViewsGuard(FunctionalizationReapplyViewsGuard&&) =
delete;
FunctionalizationReapplyViewsGuard operator=(
FunctionalizationReapplyViewsGuard&&) = delete;
private:
bool prev_;
};
} // namespace impl
// Helper function to call an out-of-place composite aten kernel that may use
// mutations / views internally, and functionalize them.
TORCH_API void functionalize_op_helper(
const c10::OperatorHandle& op,
torch::jit::Stack* stack);
template <class Op, bool symint, class ReturnType, class... ParameterTypes>
struct _functionalize_aten_op final {};
template <class Op, bool symint, class ReturnType, class... ParameterTypes>
struct _functionalize_aten_op<Op, symint, ReturnType(ParameterTypes...)> final {
static ReturnType call(
typename c10::maybe_keep_symint<symint, ParameterTypes>::type... args) {
using FuncType = ReturnType(
typename c10::maybe_keep_symint<symint, ParameterTypes>::type...);
auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow(
(const char*)Op::name, (const char*)Op::overload_name)
.typed<FuncType>();
return c10::impl::BoxedKernelWrapper<FuncType>::call(
c10::BoxedKernel::makeFromFunction<functionalize_op_helper>(),
op,
// BoxedKernelWrapper knows to ignore this keyset argument,
// because functionalize_op_helper doesn't take in a DispatchKeySet
c10::DispatchKeySet(),
args...);
}
};
template <class Op>
using functionalize_aten_op =
_functionalize_aten_op<Op, false, typename Op::schema>;
template <class Op>
using functionalize_aten_op_symint =
_functionalize_aten_op<Op, true, typename Op::schema>;
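// Illustrative sketch (editor's addition, hedged): generated functionalization
// kernels are assumed to call composite ops through these aliases roughly as
//
//   at::Tensor out = at::functionalization::functionalize_aten_op<
//       ATEN_OP(matmul)>::call(self, other);
//
// where ATEN_OP comes from <ATen/Operators.h> and supplies the Op tag type
// carrying the schema, name, and overload_name used above.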
} // namespace functionalization
} // namespace at

File diff suppressed because it is too large

View File

@ -0,0 +1,2 @@
#pragma once
#include <ATen/core/Generator.h>

View File

@ -0,0 +1,88 @@
#pragma once
#include <ATen/DimVector.h>
#include <c10/core/ScalarType.h>
#include <c10/core/SymIntArrayRef.h>
#include <c10/util/DimVector.h>
#include <optional>
#include <sstream>
#include <vector>
namespace at {
// Infers the size of a dim with size -1, if it exists. Also checks that new
// shape is compatible with the number of elements.
//
// templated to handle std::vector<int64_t> and DimVector use cases, see
// below
//
template <typename InputArrayRef, typename NumelType, typename ResultVec>
inline void infer_size_impl(
InputArrayRef shape,
NumelType numel,
ResultVec& res) {
NumelType newsize = 1;
// N.B. this is an index, not a sym dim!
std::optional<int64_t> infer_dim;
for (int64_t dim = 0, ndim = shape.size(); dim != ndim; dim++) {
if (shape[dim] == -1) {
if (infer_dim) {
throw std::runtime_error("only one dimension can be inferred");
}
infer_dim = dim;
} else if (shape[dim] >= 0) {
newsize *= shape[dim];
} else {
AT_ERROR("invalid shape dimension ", shape[dim]);
}
}
if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(numel, newsize)) ||
(infer_dim && newsize > 0 && numel % newsize == 0)) {
if (infer_dim) {
// We have a degree of freedom here to select the dimension size; follow
// NumPy semantics and just bail. However, a nice error message is needed
// because users often use `view` as a way to flatten & unflatten
// dimensions and will otherwise be confused why
// empty_tensor.view( 0, 0)
// works yet
// empty_tensor.view(-1, 0)
// doesn't.
TORCH_CHECK(
newsize != 0,
"cannot reshape tensor of 0 elements into shape ",
shape,
" because the unspecified dimension size -1 can be any "
"value and is ambiguous");
res[*infer_dim] = numel / newsize;
}
return;
}
std::ostringstream ss;
ss << "shape '" << shape << "' is invalid for input of size " << numel;
throw std::runtime_error(ss.str());
}
inline std::vector<int64_t> infer_size(IntArrayRef shape, int64_t numel) {
auto res = shape.vec();
infer_size_impl(shape, numel, res);
return res;
}
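// Illustrative sketch (editor's addition): a single -1 entry is inferred from
// the element count. The literal shape and numel are hypothetical.
inline std::vector<int64_t> infer_size_example() {
  // 6 elements reshaped to (2, -1) -> (2, 3). A second -1, or a numel that is
  // not divisible by the known dims, would throw instead.
  return infer_size(IntArrayRef{2, -1}, /*numel=*/6);
}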
inline at::DimVector infer_size_dv(IntArrayRef shape, int64_t numel) {
auto res = at::DimVector(shape);
infer_size_impl(shape, numel, res);
return res;
}
inline at::SymDimVector infer_size_dv(
c10::SymIntArrayRef shape,
c10::SymInt numel) {
auto res = at::SymDimVector(shape);
infer_size_impl<c10::SymIntArrayRef, c10::SymInt, at::SymDimVector>(
shape, std::move(numel), res);
return res;
}
} // namespace at

View File

@ -0,0 +1,15 @@
#pragma once
#include <c10/core/TensorOptions.h>
namespace at {
// Represents the initial TensorOptions, before the "defaults" are ever changed.
// This is designed to be used in library code, where the explicit devices,
// dtypes, etc. are known. NOTE: this is not a stable API.
inline TensorOptions initialTensorOptions() {
return TensorOptions(kCPU).dtype(kFloat).layout(kStrided).requires_grad(
false);
}
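// Illustrative sketch (editor's addition): library code can layer explicit
// fields on top of the initial defaults instead of the user-configurable
// global defaults. The dtype choice below is arbitrary.
inline TensorOptions initialLongOptionsExample() {
  return initialTensorOptions().dtype(c10::kLong);
}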
} // namespace at

View File

@ -0,0 +1,2 @@
#pragma once
#include <c10/core/Layout.h>

View File

@ -0,0 +1,25 @@
#pragma once
#include <ATen/ATen.h>
#include <ATen/core/op_registration/op_registration.h>
#include <torch/library.h>
namespace at {
// If an operator doesn't have a batching rule implemented then we fallback
// to this implementation. The fallback only works on out-of-place operators
// that return only tensors with new memory (e.g., no in-place operators, no
// view operations).
//
// The fallback effectively takes all of the BatchedTensors in `stack`, slices
// them, and runs `op` on all of the corresponding slices to produce slices
// of the outputs. The output slices then get `torch.stack`ed to create the
// final returns.
//
// The performance of the fallback is not very good because it introduces an
// extra copy from stacking the sliced outputs. Because of this, we prefer to
// write batching rules for operators whenever possible.
void batchedTensorForLoopFallback(
const c10::OperatorHandle& op,
torch::jit::Stack* stack);
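// Illustrative sketch (editor's addition, hedged): the fallback is meant to be
// installed as a boxed fallback for the Batched dispatch key from a .cpp
// translation unit, roughly along these lines:
//
//   TORCH_LIBRARY_IMPL(_, Batched, m) {
//     m.fallback(torch::CppFunction::makeFromBoxedFunction<
//         &at::batchedTensorForLoopFallback>());
//   }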
} // namespace at

View File

@ -0,0 +1,160 @@
#pragma once
#include <bitset>
#include <ATen/ArrayRef.h>
#include <ATen/SmallVector.h>
#include <ATen/Tensor.h>
namespace at {
// We assume this in a few other places in the codebase,
// but there isn't a centralized definition.
constexpr int64_t kVmapMaxTensorDims = 64;
// The valid vmap levels range from [0, 64). This effectively means that we
// support a maximum of 64 nested vmaps.
constexpr int64_t kVmapNumLevels = 64;
// Store this number of elements of BatchDims on the stack. Most people will
// probably use <= 5 nested vmaps, but adjust this number as necessary.
constexpr int64_t kBatchDimsStackSize = 5;
// a BatchDim represents a "private" dimension on a Tensor created inside of
// vmap. It is a (level, dim) tuple, with the `dim` indicating which dimension
// is being vmap'ed over and the `level` being an identifier for which vmap
// said dimension was created inside. The `dim` corresponds to a "physical
// dim" - it is a dimension index on the underlying physical tensor that is
// being vmapped over.
struct BatchDim {
BatchDim(int64_t level, int64_t dim) : dim_(dim), level_(level) {}
int64_t dim() const {
return dim_;
}
int64_t level() const {
return level_;
}
private:
int64_t dim_;
int64_t level_;
};
using BatchDims = SmallVector<BatchDim, kBatchDimsStackSize>;
using BatchDimsRef = ArrayRef<BatchDim>;
// A BatchedTensorImpl holds an underlying Tensor and a list of BatchDims.
// NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a
// BatchedTensorImpl.
//
// The batch dimensions are treated as being "private"; they are not
// user-visible. For example, in the following Tensor,
// bt = BatchedTensorImpl(ones(2, 3, 5, 7), [(lvl=1, dim=0), (lvl=2, dim=1)])
// dimensions 0 and 1 are batch dimensions.
//
// bt.sizes() returns (5, 7); bt.sum(0) performs a reduction over the (public)
// dim 0, which is equivalent to dim 2 in the underlying ones(2, 3, 5, 7)
// tensor.
struct TORCH_API BatchedTensorImpl : public c10::TensorImpl {
explicit BatchedTensorImpl(Tensor value, BatchDims bdims);
// Returns a reference to BatchDims that represent which dimensions of this
// tensor are private.
BatchDimsRef bdims() const {
return bdims_;
}
// BatchedTensorImpl wraps a Tensor
const Tensor& value() const {
return value_;
};
// Given a public dimension index, return the dimension index in the
// underlying value() tensor. For example, if we have
// bt = BatchedTensorImpl(ones(2, 3, 5, 7), [(lvl=1, dim=0), (lvl=2,
// dim=2)])
// bt.actualDim(0) -> 1
// bt.actualDim(1) -> 3
// bt.actualDim(2) -> Error
int64_t actualDim(int64_t dim, bool wrap_dim = true) const;
// We have to override this because we opted into CustomStrides
IntArrayRef strides_custom() const override;
// Override a bunch of methods inherited from TensorImpl to return error
// messages.
bool is_contiguous_custom(at::MemoryFormat memory_format) const override;
void set_size(int64_t dim, int64_t new_size) override;
void set_stride(int64_t dim, int64_t new_stride) override;
void set_storage_offset(int64_t storage_offset) override;
#ifdef DEBUG
bool has_storage() const override;
#endif
private:
// see NOTE: [BatchedTensorImpl levels invariant]
void checkInvariants() const;
const char* tensorimpl_type_name() const override;
Tensor value_;
// Note: [BatchedTensorImpl levels invariant]
// There is an invariant that the BatchDims must be stored in increasing
// `level` order. That is, for i < j, bdims_[i].level must be less than
// bdims_[j].level.
BatchDims bdims_;
};
// NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a
// BatchedTensorImpl.
inline bool isBatchedTensor(const Tensor& tensor) {
return tensor.unsafeGetTensorImpl()->key_set().has(DispatchKey::Batched);
}
// It is unsafe to call this on a Tensor that is not backed by a
// BatchedTensorImpl. Please use `maybeGetBatchedImpl` whenever possible.
inline BatchedTensorImpl* unsafeGetBatchedImpl(const Tensor& tensor) {
return static_cast<BatchedTensorImpl*>(tensor.unsafeGetTensorImpl());
}
inline BatchedTensorImpl* maybeGetBatchedImpl(const Tensor& tensor) {
if (!isBatchedTensor(tensor)) {
return nullptr;
}
return unsafeGetBatchedImpl(tensor);
}
// Returns a bitset. If bit i is set, then that means dim i is a batchdim.
inline std::bitset<kVmapMaxTensorDims> createBatchDimBitset(
BatchDimsRef bdims) {
std::bitset<kVmapMaxTensorDims> is_bdim;
for (const auto& bdim : bdims) {
is_bdim.set(bdim.dim());
}
return is_bdim;
}
// Creates a bitset for all of the levels present in `bdims`
inline std::bitset<kVmapNumLevels> createVmapLevelsBitset(BatchDimsRef bdims) {
std::bitset<kVmapNumLevels> result;
for (const auto& bdim : bdims) {
result.set(bdim.level());
}
return result;
}
inline std::ostream& operator<<(std::ostream& out, const BatchDim& bdim) {
out << "(lvl=" << bdim.level() << ", dim=" << bdim.dim() << ")";
return out;
}
// Use this to construct a BatchedTensor from a regular Tensor
TORCH_API Tensor makeBatched(const Tensor& tensor, BatchDims bdims);
// Adds a batch dim to `tensor`, returning a BatchedTensor
TORCH_API Tensor addBatchDim(const Tensor& tensor, int64_t level, int64_t dim);
// Checks if an inplace operation on self and other is "vmap compatible".
// See NOTE: [vmap-incompatible in-place operations] for the definition of this.
TORCH_API bool inplaceIsVmapCompatible(const Tensor& self, const Tensor& other);
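// Illustrative sketch (editor's addition): wrapping a regular tensor so that
// vmap treats one of its physical dims as a batch dim. The level/dim values
// are hypothetical.
inline Tensor add_batch_dim_sketch(const Tensor& t) {
  // Physical dim 0 of `t` becomes the batch dim for vmap level 1; the
  // resulting BatchedTensor's public sizes exclude that dim.
  return addBatchDim(t, /*level=*/1, /*dim=*/0);
}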
} // namespace at

View File

@ -0,0 +1,26 @@
#pragma once
#include <c10/core/impl/LocalDispatchKeySet.h>
namespace at::impl {
// VmapMode contains a thread local count of how many nested vmaps
// we are currently inside. That number is known as the `vmap level`.
// VmapMode is used in the implementation of the Python `torch.vmap` API.
//
// NOTE: this is NOT the c++ api for torch.vmap. That doesn't exist yet.
struct TORCH_API VmapMode {
// Returns the vmap level, aka the count of how many nested vmaps we're in.
static int64_t current_vmap_level();
// Increment the count of nested vmaps. If this causes the vmap level to be
// greater than 0, then it enables DispatchKey::VmapMode on all tensors.
static int64_t increment_nesting();
// Decrements the count of nested vmaps. If this causes the vmap level to be
// equal to 0, then it disables DispatchKey::VmapMode on all tensors.
static int64_t decrement_nesting();
};
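// Illustrative sketch (editor's addition): an RAII pairing of the two calls,
// mirroring what the Python `torch.vmap` implementation is assumed to do
// around each nested vmap invocation.
struct VmapLevelGuardSketch {
  VmapLevelGuardSketch() : level_(VmapMode::increment_nesting()) {}
  ~VmapLevelGuardSketch() {
    VmapMode::decrement_nesting();
  }
  VmapLevelGuardSketch(const VmapLevelGuardSketch&) = delete;
  VmapLevelGuardSketch& operator=(const VmapLevelGuardSketch&) = delete;
  // The vmap level that this guard entered.
  int64_t level() const {
    return level_;
  }
 private:
  int64_t level_;
};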
} // namespace at::impl

View File

@ -0,0 +1,183 @@
#pragma once
#include <ATen/LegacyBatchedTensorImpl.h>
#include <ATen/core/IListRef.h>
namespace at {
// This file contains abstractions used for transforming *logical* vmap
// arguments into *physical* arguments. (Keep reading for definitions of these
// terms).
// NOTE: [Logical vs physical args]
// Consider the following vmap.
// vmap(vmap(func, in_dims=(2,)), in_dims=(0,))(torch.ones(2, 3, 4))
// This would produce a BatchedTensor wrapping a Tensor of size [2, 3, 4],
// with batch dims 0 and 2:
// BatchedTensor(ones(2, 3, 4), bdims=[(lvl=1,dim=0),(lvl=2,dim=2)])
//
// We say the *logical* view of the tensor has size [3] -- tensors inside
// `func` appear to have size [3].
// However, the *physical* underlying tensor (the one passed to vmap) has size
// [2, 3, 4].
//
// This notion of logical vs physical also extends to non-tensor arguments.
// Consider the previous tensor; let's assume the user called
// `torch.sum(tensor, dim=0)` inside of `func`. Then the logical
// dimension they are reducing over is dim 0 but the physical dim is dim 1
// (the first non-batch dimension)
// Forward declared; see NOTE: [What is a VmapPhysicalView?]
struct VmapPhysicalView;
// Most PyTorch operators take 4 or fewer inputs.
constexpr int64_t kVmapTransformStaticInputSize = 4;
using VmapPhysicalViewVec =
SmallVector<VmapPhysicalView, kVmapTransformStaticInputSize>;
// PyTorch generally advertises good performance for <= 5 dims
// (see ATen/core/DimVector.h). We add a few extra dims (~3) for vmap
// dimensions to get 8. Adjust this number as necessary.
constexpr int64_t kVmapStaticDimVecSize = 8;
using VmapDimVector = SmallVector<int64_t, kVmapStaticDimVecSize>;
using VmapSymDimVector = SmallVector<c10::SymInt, kVmapStaticDimVecSize>;
// NOTE: [What is a VmapTransform?]
// A *VmapTransform* converts logical views of tensors to physical views.
//
// Batching rules use VmapTransforms to convert logical arguments to
// physical arguments, then call one or more at:: operators that handle the
// physical arguments, and then convert the physical result back to a logical
// argument.
// VmapTransform for operators that take tensors with multiple batch dims.
// Given one or more logical views on Tensors, `logicalToPhysical`
// permutes all of the batch dims to the front of the tensor, aligns
// and expands the batch dims to match each other (according to their `level`),
// and returns a VmapPhysicalView on the tensor(s).
struct TORCH_API MultiBatchVmapTransform {
static VmapPhysicalView logicalToPhysical(const Tensor& logical_tensor);
static VmapPhysicalViewVec logicalToPhysical(ITensorListRef logical_tensors);
};
// VmapTransform for operators that broadcast all inputs.
// Given some logical views on Tensors, `logicalToPhysical`:
// - permutes all of the batch dims to the front of the tensors
// - aligns all the batch dims to the collective levels of all of the tensors.
// If a tensor does not have a batch dim for a vmap level, then it receives
// a size-one dimension for said level.
// - aligns the non-batch dims to have the same dimensionality, adding extra
// size-1 dimensions in between the batch dimensions and the non-batch
// dimensions so that the batch dimensions are lined up from the right.
//
// For example: given inputs of size (B, 2) and (B, 3, 2) where B is the batch
// dimension, BroadcastingVmapTransform returns VmapPhysicalViews that wrap
// tensors of size (B, 1, 2) and (B, 3, 2).
//
// Given inputs of size (B, 2) and (2,), BroadcastingVmapTransform returns
// VmapPhysicalViews wrapping tensors of size (B, 2) and (1, 2). We don't
// actually *need* to return a tensor of size (1, 2) for the second tensor
// because the broadcasting operation takes care of that for us, but we do
// it anyways to keep things simple.
struct TORCH_API BroadcastingVmapTransform {
static VmapPhysicalViewVec logicalToPhysical(TensorList logical_tensors);
};
// Forward declared; if you're reading this file head to toe, don't worry about
// it yet.
struct VmapPhysicalToLogicalMap;
// NOTE: [What is a VmapPhysicalView?]
// VmapPhysicalView represents a physical view on a Tensor.
//
// One can use it to further convert logical dimension indices, logical shapes,
// and more to their physical variants, or convert a new (physical) tensor into
// a logical BatchedTensor. (TODO(rzou): some of these are not yet implemented).
//
// VmapPhysicalView stores a physical tensor with all of its batch dimensions at
// the front and some levels that correspond to said batch dimensions.
//
// The levels bitset specifies which vmap levels correspond to the batch
// dimensions at the front of the tensor. In particular, the number of set bits
// corresponds to the number of batch dimensions on `tensor` and the rightmost
// bit of `levels` specifies the maximum number of nested vmaps we are in at
// this point in time.
// For example, given:
// physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5, 6), levels={1, 3})
//
// The rightmost set bit of `levels` is at position 3, indicating that the
// number of nested vmaps is less than or equal to 3.
// bitset: 010100
// ^
// |
// levels: 012345
struct TORCH_API VmapPhysicalView {
VmapPhysicalView(Tensor&& tensor, std::bitset<kVmapNumLevels> levels)
: levels_(levels), tensor_(std::move(tensor)) {
TORCH_INTERNAL_ASSERT(!isBatchedTensor(tensor_));
}
Tensor& tensor() {
return tensor_;
}
const Tensor& tensor() const {
return tensor_;
}
// Maps logical dim indices to physical dim indices. Also does dim wrapping.
//
// For example, given:
// physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5), levels={1, 3})
//
// Then physical_view.getPhysicalDims({0, 1}) returns {2, 3}.
// This is because the size of `levels` tells us that the first two dimensions
// of `tensor_` are batch dimensions, so a logical dim of `n` is actually
// a physical dim of `n + 2`.
VmapDimVector getPhysicalDims(OptionalIntArrayRef logical_dims) const;
int64_t getPhysicalDim(int64_t logical_dim) const;
// Returns a VmapPhysicalToLogicalMap object. This can be used for
// mapping a physical tensor to a new logical tensor (BatchedTensor)
VmapPhysicalToLogicalMap getPhysicalToLogicalMap() const;
// Maps a logical shape to a physical shape by pre-pending the batch
// sizes to the logical shape.
VmapDimVector getPhysicalShape(IntArrayRef logical_shape) const;
int64_t numBatchDims() const;
private:
int64_t numLogicalDims() const;
std::bitset<kVmapNumLevels> levels_;
Tensor tensor_;
};
// Convenience struct used for mapping a physical tensor (a non-BatchedTensor)
// to a logical one (BatchedTensor). It holds some levels that are used to do
// the mapping and assumes that the batch dimensions in the physical tensor all
// occur at the front of the tensor.
struct TORCH_API VmapPhysicalToLogicalMap {
VmapPhysicalToLogicalMap(std::bitset<kVmapNumLevels> levels)
: levels_(levels) {}
// Maps a physical tensor to a new logical tensor (BatchedTensor).
// Assumes that all of the "batch dimensions" are at the front
// of the physical tensor. For example, given:
// - x = rank-4 Tensor with size 2, 3, 5, 7
// - levels = (2, 4)
// Returns:
// - BatchedTensor(x, bdims=[(dim=0,lvl=2), (dim=1, lvl=4)])
Tensor apply(const Tensor& physical_tensor) const;
// Given a vector of physical tensors,
// 1. maps each tensor to a new logical tensor. Assumes that all of the
// "batch dimensions" are at the front of the physical tensors.
// 2. stores the new logical tensors back into the passed-in vector. This is
// to avoid additional dynamic allocations.
void applyInplace(std::vector<Tensor>& physical_tensors) const;
std::bitset<kVmapNumLevels> levels_;
};
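// Editor's note: a sketch (not part of the original header) of how a batching
// rule typically strings these pieces together; `sum_batching_rule` here is an
// illustrative stand-in for the real batching rules defined elsewhere in ATen.
//
//   Tensor sum_batching_rule(const Tensor& self, IntArrayRef dims) {
//     auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
//     auto dims_physical = self_physical.getPhysicalDims(dims);
//     auto result = at::sum(self_physical.tensor(), dims_physical);
//     return self_physical.getPhysicalToLogicalMap().apply(result);
//   }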
} // namespace at

View File

@ -0,0 +1,31 @@
#pragma once
#include <c10/util/Exception.h>
#include <ostream>
#include <string>
namespace at {
enum class LinalgBackend : int8_t { Default, Cusolver, Magma };
inline std::string LinalgBackendToString(at::LinalgBackend backend) {
switch (backend) {
case LinalgBackend::Default:
return "at::LinalgBackend::Default";
case LinalgBackend::Cusolver:
return "at::LinalgBackend::Cusolver";
case LinalgBackend::Magma:
return "at::LinalgBackend::Magma";
default:
TORCH_CHECK(false, "Unknown linalg backend");
}
}
inline std::ostream& operator<<(
std::ostream& stream,
at::LinalgBackend backend) {
return stream << LinalgBackendToString(backend);
}
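// Editor's note: a minimal usage sketch (not part of the original header);
// assumes the caller includes <iostream>.
//
//   std::cout << at::LinalgBackend::Cusolver;  // prints "at::LinalgBackend::Cusolver"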
} // namespace at

View File

@ -0,0 +1,143 @@
#pragma once
#include <c10/core/Allocator.h>
#include <c10/util/string_view.h>
namespace at {
enum MappedAllocatorModes {
ALLOCATOR_MAPPED_SHARED = 1,
ALLOCATOR_MAPPED_SHAREDMEM = 2,
ALLOCATOR_MAPPED_EXCLUSIVE = 4,
ALLOCATOR_MAPPED_NOCREATE = 8,
ALLOCATOR_MAPPED_KEEPFD = 16,
ALLOCATOR_MAPPED_FROMFD = 32,
ALLOCATOR_MAPPED_UNLINK = 64
};
// Sentinel value/type to help distinguish the file descriptor constructor from
// the non-file descriptor constructor
enum WithFd { WITH_FD };
TORCH_API std::string NewProcessWideShmHandle();
class TORCH_API MapAllocator {
public:
MapAllocator(c10::string_view filename, int flags, size_t size);
MapAllocator(
WithFd,
c10::string_view filename,
int fd,
int flags,
size_t size);
MapAllocator(const MapAllocator&) = delete;
MapAllocator& operator=(const MapAllocator&) = delete;
MapAllocator(MapAllocator&&) = delete;
MapAllocator& operator=(MapAllocator&&) = delete;
const char* filename() const {
return filename_.c_str();
}
int fd() const {
#ifdef _WIN32
TORCH_CHECK(false, "MapAllocator::fd() is unsupported on Windows");
#else
return fd_;
#endif
}
ptrdiff_t size() const {
return size_;
}
// Return a pointer to the actual data for this allocator
// (in the case of the refcounted allocator, this is offset
// from the base pointer.)
virtual void* data() const {
return base_ptr_;
}
int flags() const {
return flags_;
}
static MapAllocator* fromDataPtr(const at::DataPtr&);
static at::DataPtr makeDataPtr(
c10::string_view filename,
int flags,
size_t size,
size_t* actual_size_out);
static at::DataPtr makeDataPtr(
WithFd,
const char* filename,
int fd,
int flags,
size_t size,
size_t* actual_size_out);
// Closes the data. Helps us avoid destructor shenanigans
virtual void close();
// This is very dangerous. You have to redefine this destructor for each
// subclass
virtual ~MapAllocator();
protected:
bool closed_ = false;
std::string filename_;
int flags_ = 0;
ptrdiff_t size_; /* mapped size */
#ifdef _WIN32
void* handle_;
void* event_;
std::string eventname_;
#else
int fd_ = -1;
#endif
void* base_ptr_ = nullptr;
};
// Base-from-member idiom
struct TORCH_API RefcountedMapAllocatorArgCheck {
RefcountedMapAllocatorArgCheck(int flags);
};
class TORCH_API RefcountedMapAllocator : private RefcountedMapAllocatorArgCheck,
public MapAllocator {
public:
RefcountedMapAllocator(const char* filename, int flags, size_t size);
RefcountedMapAllocator(
WithFd,
const char* filename,
int fd,
int flags,
size_t size);
static RefcountedMapAllocator* fromDataPtr(const at::DataPtr&);
static at::DataPtr makeDataPtr(
const char* filename,
int flags,
size_t size,
size_t* actual_size_out);
static at::DataPtr makeDataPtr(
WithFd,
const char* filename,
int fd,
int flags,
size_t size,
size_t* actual_size_out);
void* data() const override;
void incref();
int decref();
void close() override;
~RefcountedMapAllocator() override {
RefcountedMapAllocator::close();
}
protected:
void checkFlags();
void initializeAlloc();
};
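// Editor's note: a minimal usage sketch (not part of the original header); the
// file name and size below are placeholders.
//
//   size_t actual_size = 0;
//   at::DataPtr ptr = at::MapAllocator::makeDataPtr(
//       "/tmp/example_mapping", at::ALLOCATOR_MAPPED_SHARED, /*size=*/4096,
//       &actual_size);
//   // ptr.get() now points at the mapped region of `actual_size` bytes, and
//   // the mapping is released when `ptr` is destroyed.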
} // namespace at

View File

@ -0,0 +1,107 @@
#pragma once
#include <ATen/Utils.h>
#include <c10/util/ArrayRef.h>
namespace at {
/// MatrixRef - Like an ArrayRef, but with an extra recorded stride so that
/// we can easily view it as a multidimensional array.
///
/// Like ArrayRef, this class does not own the underlying data, it is expected
/// to be used in situations where the data resides in some other buffer.
///
/// This is intended to be trivially copyable, so it should be passed by
/// value.
///
/// For now, 2D only (so the copies are actually cheap, without having
/// to write a SmallVector class) and contiguous only (so we can
/// return non-strided ArrayRef on index).
///
/// P.S. dimension 0 indexes rows, dimension 1 indexes columns
template <typename T>
class MatrixRef {
public:
typedef size_t size_type;
private:
/// Underlying ArrayRef
ArrayRef<T> arr;
/// Stride of dim 0 (outer dimension)
size_type stride0;
// Stride of dim 1 is assumed to be 1
public:
/// Construct an empty MatrixRef.
/*implicit*/ MatrixRef() : arr(nullptr), stride0(0) {}
/// Construct a MatrixRef from an ArrayRef and outer stride.
/*implicit*/ MatrixRef(ArrayRef<T> arr, size_type stride0)
: arr(arr), stride0(stride0) {
TORCH_CHECK(
arr.size() % stride0 == 0,
"MatrixRef: ArrayRef size ",
arr.size(),
" not divisible by stride ",
stride0)
}
/// @}
/// @name Simple Operations
/// @{
/// empty - Check if the matrix is empty.
bool empty() const {
return arr.empty();
}
const T* data() const {
return arr.data();
}
/// size - Get the size of a dimension.
size_t size(size_t dim) const {
if (dim == 0) {
return arr.size() / stride0;
} else if (dim == 1) {
return stride0;
} else {
TORCH_CHECK(
0, "MatrixRef: out of bounds dimension ", dim, "; expected 0 or 1");
}
}
size_t numel() const {
return arr.size();
}
/// equals - Check for element-wise equality.
bool equals(MatrixRef RHS) const {
return stride0 == RHS.stride0 && arr.equals(RHS.arr);
}
/// @}
/// @name Operator Overloads
/// @{
ArrayRef<T> operator[](size_t Index) const {
return arr.slice(Index * stride0, stride0);
}
/// Disallow accidental assignment from a temporary.
///
/// The declaration here is extra complicated so that "matrixRef = {}"
/// continues to select the move assignment operator.
template <typename U>
std::enable_if_t<std::is_same_v<U, T>, MatrixRef<T>>& operator=(
U&& Temporary) = delete;
/// Disallow accidental assignment from a temporary.
///
/// The declaration here is extra complicated so that "matrixRef = {}"
/// continues to select the move assignment operator.
template <typename U>
std::enable_if_t<std::is_same_v<U, T>, MatrixRef<T>>& operator=(
std::initializer_list<U>) = delete;
};
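/// Editor's note: a minimal usage sketch (not part of the original header),
/// viewing a flat 6-element buffer as a 2x3 matrix:
///
///   std::array<float, 6> data = {0, 1, 2, 3, 4, 5};
///   at::MatrixRef<float> m(at::ArrayRef<float>(data.data(), data.size()), 3);
///   // m.size(0) == 2, m.size(1) == 3, and m[1] is the ArrayRef {3, 4, 5}.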
} // end namespace at

View File

@ -0,0 +1,42 @@
#pragma once
#include <c10/macros/Export.h>
namespace c10 {
struct TensorImpl;
}
namespace at {
class TensorBase;
// MemOverlap: Whether or not there is memory overlap
//
// No: Absolutely no memory overlap
// Yes: Absolutely yes memory overlap
// TooHard: There might be memory overlap, but it was too expensive to compute.
//
// NB: Please update the python test for these if you renumber them.
enum class MemOverlap { No, Yes, TooHard };
enum class MemOverlapStatus { Full, Partial, No, TooHard };
TORCH_API MemOverlap has_internal_overlap(const TensorBase& t);
TORCH_API MemOverlap has_internal_overlap(c10::TensorImpl* t);
TORCH_API void assert_no_internal_overlap(const TensorBase& t);
TORCH_API void assert_no_internal_overlap(c10::TensorImpl* t);
TORCH_API MemOverlapStatus
get_overlap_status(const TensorBase& a, const TensorBase& b);
TORCH_API MemOverlapStatus
get_overlap_status(const c10::TensorImpl* a, const c10::TensorImpl* b);
TORCH_API void assert_no_partial_overlap(
const TensorBase& a,
const TensorBase& b);
void assert_no_partial_overlap(c10::TensorImpl* a, c10::TensorImpl* b);
TORCH_API void assert_no_overlap(const TensorBase& a, const TensorBase& b);
TORCH_API void assert_no_overlap(c10::TensorImpl* a, c10::TensorImpl* b);
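// Editor's note: a minimal sketch (not part of the original header). A tensor
// whose strides make distinct indices alias the same memory reports internal
// overlap:
//
//   auto t = at::ones({1, 3}).expand({4, 3});  // stride 0 along dim 0
//   bool overlapping = at::has_internal_overlap(t) == at::MemOverlap::Yes;
//   at::assert_no_internal_overlap(t);         // throws for `t`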
} // namespace at

View File

@ -0,0 +1,29 @@
#include <ATen/core/TensorBody.h>
// TODO Undo all logic introduced for Note [Avoiding Include Cycles In Static Dispatch]
// Code introduced to avoid cyclic dependency in static dispatch is no longer
// needed as static dispatch logic is moved from TensorBody.h, which caused cycles in the first place,
// to Operators.cpp for supporting multiple backends with multiple kernels.
//
// Note [Avoiding Include Cycles In Static Dispatch]
// In order to avoid #include cycles in the static dispatch build, we've carefully split out
// the static function definition files into {DispatchKey}Functions.h and {DispatchKey}Functions_inl.h.
//
// Without this split, the include cycle looks like TensorBody.h -> CPUFunctions.h -> TensorBody.h.
// - TensorBody.h #includes CPUFunctions.h in the static dispatch build, because the tensor methods
// all need to call into the fastpath C++ API defined in CPUFunctions.h. The methods are also all
// directly inlined into TensorBody.h.
// - CPUFunctions.h #includes TensorBody.h because it contains function declarations for the entire C++ API,
// which include functions that have defaultable std::optional<Tensor> arguments.
// That requires knowing the full Tensor class definition.
//
// We break the cycle by doing the following:
// - Split out CPUFunctions.h into two files: CPUFunctions.h and CPUFunctions_inl.h
// - CPUFunctions.h is a dummy file that just includes the Tensor class and includes CPUFunctions_inl.h
// - CPUFunctions_inl.h includes everything else
// - (only in the static dispatch build) TensorBody.h makes sure to finish defining the Tensor class,
//   and then it includes CPUFunctions_inl.h.
// - All other files that want the cpu fastpath functions can include CPUFunctions.h directly.
// - This also means that in the static dispatch build, CPUFunctions.h only needs to
//   #include TensorBody.h, and it will automatically bring in CPUFunctions_inl.h.
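// Editor's note: a sketch (not part of the original comment) of the resulting
// include structure in the static dispatch build, using CPU as the example:
//
//   CPUFunctions.h      -> #includes TensorBody.h (dummy wrapper)
//   TensorBody.h        -> finishes defining Tensor, then #includes CPUFunctions_inl.h
//   CPUFunctions_inl.h  -> the actual fastpath function declarations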
#include <ATen/MetaFunctions_inl.h>

View File

@ -0,0 +1,325 @@
#pragma once
// @generated by torchgen/gen.py from DispatchKeyFunctions_inl.h
// NB: The implementing C++ file is RegisterDispatchKey.cpp
// The only #includes we need are for custom classes that have defaults in the C++ API
#include <c10/core/MemoryFormat.h>
#include <c10/core/Scalar.h>
#include <ATen/core/Reduction.h>
#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
#error This change adds a dependency on all pytorch operators, meaning the \
file will need to be re-compiled every time an operator is changed or added. \
Consider including a specific operator from \
<ATen/ops/{my_operator}_meta_dispatch.h>. \
See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
#endif
#include <ATen/ops/_add_relu_meta_dispatch.h>
#include <ATen/ops/_addmm_activation_meta_dispatch.h>
#include <ATen/ops/_amp_update_scale_meta_dispatch.h>
#include <ATen/ops/_coalesced_meta_dispatch.h>
#include <ATen/ops/_convert_indices_from_coo_to_csr_meta_dispatch.h>
#include <ATen/ops/_convert_indices_from_csr_to_coo_meta_dispatch.h>
#include <ATen/ops/_ctc_loss_meta_dispatch.h>
#include <ATen/ops/_efficientzerotensor_meta_dispatch.h>
#include <ATen/ops/_fill_mem_eff_dropout_mask_meta_dispatch.h>
#include <ATen/ops/_fused_sdp_choice_meta_dispatch.h>
#include <ATen/ops/_index_put_impl_meta_dispatch.h>
#include <ATen/ops/_linalg_det_meta_dispatch.h>
#include <ATen/ops/_linalg_eigh_meta_dispatch.h>
#include <ATen/ops/_linalg_slogdet_meta_dispatch.h>
#include <ATen/ops/_linalg_solve_ex_meta_dispatch.h>
#include <ATen/ops/_linalg_svd_meta_dispatch.h>
#include <ATen/ops/_log_softmax_meta_dispatch.h>
#include <ATen/ops/_log_softmax_backward_data_meta_dispatch.h>
#include <ATen/ops/_mkldnn_transpose_meta_dispatch.h>
#include <ATen/ops/_reshape_alias_meta_dispatch.h>
#include <ATen/ops/_resize_output_meta_dispatch.h>
#include <ATen/ops/_softmax_meta_dispatch.h>
#include <ATen/ops/_softmax_backward_data_meta_dispatch.h>
#include <ATen/ops/_sparse_coo_tensor_with_dims_meta_dispatch.h>
#include <ATen/ops/_sparse_coo_tensor_with_dims_and_tensors_meta_dispatch.h>
#include <ATen/ops/_upsample_bicubic2d_aa_meta_dispatch.h>
#include <ATen/ops/_upsample_bicubic2d_aa_backward_meta_dispatch.h>
#include <ATen/ops/_upsample_bilinear2d_aa_meta_dispatch.h>
#include <ATen/ops/_upsample_bilinear2d_aa_backward_meta_dispatch.h>
#include <ATen/ops/_upsample_nearest_exact1d_meta_dispatch.h>
#include <ATen/ops/_upsample_nearest_exact1d_backward_meta_dispatch.h>
#include <ATen/ops/_upsample_nearest_exact2d_meta_dispatch.h>
#include <ATen/ops/_upsample_nearest_exact2d_backward_meta_dispatch.h>
#include <ATen/ops/_upsample_nearest_exact3d_meta_dispatch.h>
#include <ATen/ops/_upsample_nearest_exact3d_backward_meta_dispatch.h>
#include <ATen/ops/acos_meta_dispatch.h>
#include <ATen/ops/acosh_meta_dispatch.h>
#include <ATen/ops/adaptive_max_pool2d_meta_dispatch.h>
#include <ATen/ops/adaptive_max_pool2d_backward_meta_dispatch.h>
#include <ATen/ops/adaptive_max_pool3d_meta_dispatch.h>
#include <ATen/ops/adaptive_max_pool3d_backward_meta_dispatch.h>
#include <ATen/ops/add_meta_dispatch.h>
#include <ATen/ops/addbmm_meta_dispatch.h>
#include <ATen/ops/addcdiv_meta_dispatch.h>
#include <ATen/ops/addcmul_meta_dispatch.h>
#include <ATen/ops/addmm_meta_dispatch.h>
#include <ATen/ops/addmv_meta_dispatch.h>
#include <ATen/ops/all_meta_dispatch.h>
#include <ATen/ops/amax_meta_dispatch.h>
#include <ATen/ops/amin_meta_dispatch.h>
#include <ATen/ops/aminmax_meta_dispatch.h>
#include <ATen/ops/any_meta_dispatch.h>
#include <ATen/ops/arange_meta_dispatch.h>
#include <ATen/ops/argmax_meta_dispatch.h>
#include <ATen/ops/argmin_meta_dispatch.h>
#include <ATen/ops/as_strided_meta_dispatch.h>
#include <ATen/ops/asin_meta_dispatch.h>
#include <ATen/ops/asinh_meta_dispatch.h>
#include <ATen/ops/atan_meta_dispatch.h>
#include <ATen/ops/atan2_meta_dispatch.h>
#include <ATen/ops/atanh_meta_dispatch.h>
#include <ATen/ops/avg_pool2d_meta_dispatch.h>
#include <ATen/ops/avg_pool2d_backward_meta_dispatch.h>
#include <ATen/ops/avg_pool3d_meta_dispatch.h>
#include <ATen/ops/avg_pool3d_backward_meta_dispatch.h>
#include <ATen/ops/baddbmm_meta_dispatch.h>
#include <ATen/ops/bernoulli_meta_dispatch.h>
#include <ATen/ops/bitwise_and_meta_dispatch.h>
#include <ATen/ops/bitwise_left_shift_meta_dispatch.h>
#include <ATen/ops/bitwise_not_meta_dispatch.h>
#include <ATen/ops/bitwise_or_meta_dispatch.h>
#include <ATen/ops/bitwise_right_shift_meta_dispatch.h>
#include <ATen/ops/bitwise_xor_meta_dispatch.h>
#include <ATen/ops/bmm_meta_dispatch.h>
#include <ATen/ops/cat_meta_dispatch.h>
#include <ATen/ops/cauchy_meta_dispatch.h>
#include <ATen/ops/ceil_meta_dispatch.h>
#include <ATen/ops/clamp_meta_dispatch.h>
#include <ATen/ops/clamp_max_meta_dispatch.h>
#include <ATen/ops/clamp_min_meta_dispatch.h>
#include <ATen/ops/copy_meta_dispatch.h>
#include <ATen/ops/copy_sparse_to_sparse_meta_dispatch.h>
#include <ATen/ops/copysign_meta_dispatch.h>
#include <ATen/ops/cos_meta_dispatch.h>
#include <ATen/ops/cosh_meta_dispatch.h>
#include <ATen/ops/cumprod_meta_dispatch.h>
#include <ATen/ops/cumsum_meta_dispatch.h>
#include <ATen/ops/digamma_meta_dispatch.h>
#include <ATen/ops/div_meta_dispatch.h>
#include <ATen/ops/elu_meta_dispatch.h>
#include <ATen/ops/elu_backward_meta_dispatch.h>
#include <ATen/ops/embedding_renorm_meta_dispatch.h>
#include <ATen/ops/empty_meta_dispatch.h>
#include <ATen/ops/empty_strided_meta_dispatch.h>
#include <ATen/ops/eq_meta_dispatch.h>
#include <ATen/ops/erf_meta_dispatch.h>
#include <ATen/ops/erfc_meta_dispatch.h>
#include <ATen/ops/erfinv_meta_dispatch.h>
#include <ATen/ops/exp_meta_dispatch.h>
#include <ATen/ops/exp2_meta_dispatch.h>
#include <ATen/ops/expm1_meta_dispatch.h>
#include <ATen/ops/exponential_meta_dispatch.h>
#include <ATen/ops/eye_meta_dispatch.h>
#include <ATen/ops/fill_meta_dispatch.h>
#include <ATen/ops/floor_meta_dispatch.h>
#include <ATen/ops/floor_divide_meta_dispatch.h>
#include <ATen/ops/fmax_meta_dispatch.h>
#include <ATen/ops/fmin_meta_dispatch.h>
#include <ATen/ops/fmod_meta_dispatch.h>
#include <ATen/ops/frac_meta_dispatch.h>
#include <ATen/ops/fractional_max_pool2d_meta_dispatch.h>
#include <ATen/ops/fractional_max_pool2d_backward_meta_dispatch.h>
#include <ATen/ops/fractional_max_pool3d_meta_dispatch.h>
#include <ATen/ops/gather_meta_dispatch.h>
#include <ATen/ops/gcd_meta_dispatch.h>
#include <ATen/ops/ge_meta_dispatch.h>
#include <ATen/ops/gelu_meta_dispatch.h>
#include <ATen/ops/gelu_backward_meta_dispatch.h>
#include <ATen/ops/geometric_meta_dispatch.h>
#include <ATen/ops/glu_meta_dispatch.h>
#include <ATen/ops/gt_meta_dispatch.h>
#include <ATen/ops/hardshrink_meta_dispatch.h>
#include <ATen/ops/hardshrink_backward_meta_dispatch.h>
#include <ATen/ops/hardsigmoid_meta_dispatch.h>
#include <ATen/ops/hardsigmoid_backward_meta_dispatch.h>
#include <ATen/ops/hardswish_meta_dispatch.h>
#include <ATen/ops/hardtanh_meta_dispatch.h>
#include <ATen/ops/heaviside_meta_dispatch.h>
#include <ATen/ops/hypot_meta_dispatch.h>
#include <ATen/ops/i0_meta_dispatch.h>
#include <ATen/ops/igamma_meta_dispatch.h>
#include <ATen/ops/igammac_meta_dispatch.h>
#include <ATen/ops/index_meta_dispatch.h>
#include <ATen/ops/index_add_meta_dispatch.h>
#include <ATen/ops/index_copy_meta_dispatch.h>
#include <ATen/ops/index_fill_meta_dispatch.h>
#include <ATen/ops/index_reduce_meta_dispatch.h>
#include <ATen/ops/isin_meta_dispatch.h>
#include <ATen/ops/isneginf_meta_dispatch.h>
#include <ATen/ops/isposinf_meta_dispatch.h>
#include <ATen/ops/lcm_meta_dispatch.h>
#include <ATen/ops/le_meta_dispatch.h>
#include <ATen/ops/leaky_relu_meta_dispatch.h>
#include <ATen/ops/leaky_relu_backward_meta_dispatch.h>
#include <ATen/ops/lerp_meta_dispatch.h>
#include <ATen/ops/lgamma_meta_dispatch.h>
#include <ATen/ops/linalg_cholesky_ex_meta_dispatch.h>
#include <ATen/ops/linalg_cross_meta_dispatch.h>
#include <ATen/ops/linalg_inv_ex_meta_dispatch.h>
#include <ATen/ops/linalg_ldl_factor_ex_meta_dispatch.h>
#include <ATen/ops/linalg_ldl_solve_meta_dispatch.h>
#include <ATen/ops/linalg_lu_meta_dispatch.h>
#include <ATen/ops/linalg_lu_factor_ex_meta_dispatch.h>
#include <ATen/ops/linalg_lu_solve_meta_dispatch.h>
#include <ATen/ops/linalg_qr_meta_dispatch.h>
#include <ATen/ops/linalg_vector_norm_meta_dispatch.h>
#include <ATen/ops/linspace_meta_dispatch.h>
#include <ATen/ops/log_meta_dispatch.h>
#include <ATen/ops/log10_meta_dispatch.h>
#include <ATen/ops/log1p_meta_dispatch.h>
#include <ATen/ops/log2_meta_dispatch.h>
#include <ATen/ops/log_normal_meta_dispatch.h>
#include <ATen/ops/logaddexp_meta_dispatch.h>
#include <ATen/ops/logaddexp2_meta_dispatch.h>
#include <ATen/ops/logit_meta_dispatch.h>
#include <ATen/ops/logit_backward_meta_dispatch.h>
#include <ATen/ops/logspace_meta_dispatch.h>
#include <ATen/ops/lshift_meta_dispatch.h>
#include <ATen/ops/lt_meta_dispatch.h>
#include <ATen/ops/lu_unpack_meta_dispatch.h>
#include <ATen/ops/masked_fill_meta_dispatch.h>
#include <ATen/ops/masked_scatter_meta_dispatch.h>
#include <ATen/ops/max_meta_dispatch.h>
#include <ATen/ops/max_pool2d_with_indices_meta_dispatch.h>
#include <ATen/ops/max_pool2d_with_indices_backward_meta_dispatch.h>
#include <ATen/ops/maximum_meta_dispatch.h>
#include <ATen/ops/mean_meta_dispatch.h>
#include <ATen/ops/min_meta_dispatch.h>
#include <ATen/ops/minimum_meta_dispatch.h>
#include <ATen/ops/mish_meta_dispatch.h>
#include <ATen/ops/mm_meta_dispatch.h>
#include <ATen/ops/mse_loss_meta_dispatch.h>
#include <ATen/ops/mul_meta_dispatch.h>
#include <ATen/ops/ne_meta_dispatch.h>
#include <ATen/ops/neg_meta_dispatch.h>
#include <ATen/ops/nextafter_meta_dispatch.h>
#include <ATen/ops/nll_loss_backward_meta_dispatch.h>
#include <ATen/ops/nll_loss_forward_meta_dispatch.h>
#include <ATen/ops/norm_meta_dispatch.h>
#include <ATen/ops/normal_meta_dispatch.h>
#include <ATen/ops/polygamma_meta_dispatch.h>
#include <ATen/ops/pow_meta_dispatch.h>
#include <ATen/ops/prod_meta_dispatch.h>
#include <ATen/ops/put_meta_dispatch.h>
#include <ATen/ops/random_meta_dispatch.h>
#include <ATen/ops/range_meta_dispatch.h>
#include <ATen/ops/reciprocal_meta_dispatch.h>
#include <ATen/ops/reflection_pad1d_meta_dispatch.h>
#include <ATen/ops/reflection_pad1d_backward_meta_dispatch.h>
#include <ATen/ops/reflection_pad3d_meta_dispatch.h>
#include <ATen/ops/reflection_pad3d_backward_meta_dispatch.h>
#include <ATen/ops/relu_meta_dispatch.h>
#include <ATen/ops/remainder_meta_dispatch.h>
#include <ATen/ops/renorm_meta_dispatch.h>
#include <ATen/ops/replication_pad1d_meta_dispatch.h>
#include <ATen/ops/replication_pad1d_backward_meta_dispatch.h>
#include <ATen/ops/replication_pad2d_meta_dispatch.h>
#include <ATen/ops/replication_pad3d_meta_dispatch.h>
#include <ATen/ops/resize_meta_dispatch.h>
#include <ATen/ops/resize_as_sparse_meta_dispatch.h>
#include <ATen/ops/round_meta_dispatch.h>
#include <ATen/ops/rrelu_with_noise_meta_dispatch.h>
#include <ATen/ops/rshift_meta_dispatch.h>
#include <ATen/ops/rsqrt_meta_dispatch.h>
#include <ATen/ops/scatter_meta_dispatch.h>
#include <ATen/ops/scatter_add_meta_dispatch.h>
#include <ATen/ops/scatter_reduce_meta_dispatch.h>
#include <ATen/ops/set_meta_dispatch.h>
#include <ATen/ops/sgn_meta_dispatch.h>
#include <ATen/ops/sigmoid_meta_dispatch.h>
#include <ATen/ops/sigmoid_backward_meta_dispatch.h>
#include <ATen/ops/sign_meta_dispatch.h>
#include <ATen/ops/signbit_meta_dispatch.h>
#include <ATen/ops/silu_meta_dispatch.h>
#include <ATen/ops/silu_backward_meta_dispatch.h>
#include <ATen/ops/sin_meta_dispatch.h>
#include <ATen/ops/sinc_meta_dispatch.h>
#include <ATen/ops/sinh_meta_dispatch.h>
#include <ATen/ops/slow_conv_transpose2d_meta_dispatch.h>
#include <ATen/ops/smooth_l1_loss_meta_dispatch.h>
#include <ATen/ops/softplus_meta_dispatch.h>
#include <ATen/ops/softplus_backward_meta_dispatch.h>
#include <ATen/ops/softshrink_meta_dispatch.h>
#include <ATen/ops/softshrink_backward_meta_dispatch.h>
#include <ATen/ops/sort_meta_dispatch.h>
#include <ATen/ops/sparse_resize_meta_dispatch.h>
#include <ATen/ops/sparse_resize_and_clear_meta_dispatch.h>
#include <ATen/ops/special_airy_ai_meta_dispatch.h>
#include <ATen/ops/special_bessel_j0_meta_dispatch.h>
#include <ATen/ops/special_bessel_j1_meta_dispatch.h>
#include <ATen/ops/special_bessel_y0_meta_dispatch.h>
#include <ATen/ops/special_bessel_y1_meta_dispatch.h>
#include <ATen/ops/special_chebyshev_polynomial_t_meta_dispatch.h>
#include <ATen/ops/special_chebyshev_polynomial_u_meta_dispatch.h>
#include <ATen/ops/special_chebyshev_polynomial_v_meta_dispatch.h>
#include <ATen/ops/special_chebyshev_polynomial_w_meta_dispatch.h>
#include <ATen/ops/special_entr_meta_dispatch.h>
#include <ATen/ops/special_erfcx_meta_dispatch.h>
#include <ATen/ops/special_hermite_polynomial_h_meta_dispatch.h>
#include <ATen/ops/special_hermite_polynomial_he_meta_dispatch.h>
#include <ATen/ops/special_i0e_meta_dispatch.h>
#include <ATen/ops/special_i1_meta_dispatch.h>
#include <ATen/ops/special_i1e_meta_dispatch.h>
#include <ATen/ops/special_laguerre_polynomial_l_meta_dispatch.h>
#include <ATen/ops/special_legendre_polynomial_p_meta_dispatch.h>
#include <ATen/ops/special_log_ndtr_meta_dispatch.h>
#include <ATen/ops/special_modified_bessel_i0_meta_dispatch.h>
#include <ATen/ops/special_modified_bessel_i1_meta_dispatch.h>
#include <ATen/ops/special_modified_bessel_k0_meta_dispatch.h>
#include <ATen/ops/special_modified_bessel_k1_meta_dispatch.h>
#include <ATen/ops/special_ndtri_meta_dispatch.h>
#include <ATen/ops/special_scaled_modified_bessel_k0_meta_dispatch.h>
#include <ATen/ops/special_scaled_modified_bessel_k1_meta_dispatch.h>
#include <ATen/ops/special_shifted_chebyshev_polynomial_t_meta_dispatch.h>
#include <ATen/ops/special_shifted_chebyshev_polynomial_u_meta_dispatch.h>
#include <ATen/ops/special_shifted_chebyshev_polynomial_v_meta_dispatch.h>
#include <ATen/ops/special_shifted_chebyshev_polynomial_w_meta_dispatch.h>
#include <ATen/ops/special_spherical_bessel_j0_meta_dispatch.h>
#include <ATen/ops/special_xlog1py_meta_dispatch.h>
#include <ATen/ops/special_zeta_meta_dispatch.h>
#include <ATen/ops/sqrt_meta_dispatch.h>
#include <ATen/ops/sub_meta_dispatch.h>
#include <ATen/ops/sum_meta_dispatch.h>
#include <ATen/ops/tan_meta_dispatch.h>
#include <ATen/ops/tanh_meta_dispatch.h>
#include <ATen/ops/tanh_backward_meta_dispatch.h>
#include <ATen/ops/threshold_meta_dispatch.h>
#include <ATen/ops/threshold_backward_meta_dispatch.h>
#include <ATen/ops/topk_meta_dispatch.h>
#include <ATen/ops/triangular_solve_meta_dispatch.h>
#include <ATen/ops/tril_meta_dispatch.h>
#include <ATen/ops/triu_meta_dispatch.h>
#include <ATen/ops/trunc_meta_dispatch.h>
#include <ATen/ops/unfold_meta_dispatch.h>
#include <ATen/ops/uniform_meta_dispatch.h>
#include <ATen/ops/upsample_bicubic2d_meta_dispatch.h>
#include <ATen/ops/upsample_bicubic2d_backward_meta_dispatch.h>
#include <ATen/ops/upsample_bilinear2d_meta_dispatch.h>
#include <ATen/ops/upsample_bilinear2d_backward_meta_dispatch.h>
#include <ATen/ops/upsample_linear1d_meta_dispatch.h>
#include <ATen/ops/upsample_linear1d_backward_meta_dispatch.h>
#include <ATen/ops/upsample_nearest1d_meta_dispatch.h>
#include <ATen/ops/upsample_nearest1d_backward_meta_dispatch.h>
#include <ATen/ops/upsample_nearest2d_meta_dispatch.h>
#include <ATen/ops/upsample_nearest2d_backward_meta_dispatch.h>
#include <ATen/ops/upsample_nearest3d_meta_dispatch.h>
#include <ATen/ops/upsample_nearest3d_backward_meta_dispatch.h>
#include <ATen/ops/upsample_trilinear3d_meta_dispatch.h>
#include <ATen/ops/upsample_trilinear3d_backward_meta_dispatch.h>
#include <ATen/ops/view_meta_dispatch.h>
#include <ATen/ops/view_as_complex_meta_dispatch.h>
#include <ATen/ops/view_as_real_meta_dispatch.h>
#include <ATen/ops/xlogy_meta_dispatch.h>
#include <ATen/ops/zero_meta_dispatch.h>

View File

@ -0,0 +1,443 @@
#pragma once
// @generated by torchgen/gen.py from MethodOperators.h
#ifdef TORCH_ASSERT_NO_OPERATORS
#error This change adds a dependency on native_functions.yaml, \
meaning the file will need to be re-compiled every time an operator \
is changed or added. Consider if your change would be better placed in \
another file, or if a more specific header might achieve the same goal. \
See NOTE: [Tensor vs. TensorBase]
#endif
// Forward declarations of any types needed in the operator signatures.
// We can't directly include these classes because it will cause circular include dependencies.
// This file is included by TensorBody.h, which defines the Tensor class.
#include <ATen/core/ATen_fwd.h>
#include <ATen/ops/_addmm_activation_ops.h>
#include <ATen/ops/_autocast_to_full_precision_ops.h>
#include <ATen/ops/_autocast_to_reduced_precision_ops.h>
#include <ATen/ops/_backward_ops.h>
#include <ATen/ops/_coalesced_ops.h>
#include <ATen/ops/_conj_ops.h>
#include <ATen/ops/_conj_physical_ops.h>
#include <ATen/ops/_dimI_ops.h>
#include <ATen/ops/_dimV_ops.h>
#include <ATen/ops/_fw_primal_ops.h>
#include <ATen/ops/_indices_ops.h>
#include <ATen/ops/_is_all_true_ops.h>
#include <ATen/ops/_is_any_true_ops.h>
#include <ATen/ops/_is_zerotensor_ops.h>
#include <ATen/ops/_lazy_clone_ops.h>
#include <ATen/ops/_neg_view_ops.h>
#include <ATen/ops/_nested_tensor_size_ops.h>
#include <ATen/ops/_nested_tensor_storage_offsets_ops.h>
#include <ATen/ops/_nested_tensor_strides_ops.h>
#include <ATen/ops/_nnz_ops.h>
#include <ATen/ops/_reshape_alias_ops.h>
#include <ATen/ops/_sparse_mask_projection_ops.h>
#include <ATen/ops/_to_dense_ops.h>
#include <ATen/ops/_to_sparse_bsc_ops.h>
#include <ATen/ops/_to_sparse_bsr_ops.h>
#include <ATen/ops/_to_sparse_csc_ops.h>
#include <ATen/ops/_to_sparse_csr_ops.h>
#include <ATen/ops/_to_sparse_ops.h>
#include <ATen/ops/_values_ops.h>
#include <ATen/ops/_version_ops.h>
#include <ATen/ops/abs_ops.h>
#include <ATen/ops/absolute_ops.h>
#include <ATen/ops/acos_ops.h>
#include <ATen/ops/acosh_ops.h>
#include <ATen/ops/add_ops.h>
#include <ATen/ops/addbmm_ops.h>
#include <ATen/ops/addcdiv_ops.h>
#include <ATen/ops/addcmul_ops.h>
#include <ATen/ops/addmm_ops.h>
#include <ATen/ops/addmv_ops.h>
#include <ATen/ops/addr_ops.h>
#include <ATen/ops/adjoint_ops.h>
#include <ATen/ops/alias_ops.h>
#include <ATen/ops/align_as_ops.h>
#include <ATen/ops/align_to_ops.h>
#include <ATen/ops/all_ops.h>
#include <ATen/ops/allclose_ops.h>
#include <ATen/ops/amax_ops.h>
#include <ATen/ops/amin_ops.h>
#include <ATen/ops/aminmax_ops.h>
#include <ATen/ops/and_ops.h>
#include <ATen/ops/angle_ops.h>
#include <ATen/ops/any_ops.h>
#include <ATen/ops/arccos_ops.h>
#include <ATen/ops/arccosh_ops.h>
#include <ATen/ops/arcsin_ops.h>
#include <ATen/ops/arcsinh_ops.h>
#include <ATen/ops/arctan2_ops.h>
#include <ATen/ops/arctan_ops.h>
#include <ATen/ops/arctanh_ops.h>
#include <ATen/ops/argmax_ops.h>
#include <ATen/ops/argmin_ops.h>
#include <ATen/ops/argsort_ops.h>
#include <ATen/ops/argwhere_ops.h>
#include <ATen/ops/as_strided_ops.h>
#include <ATen/ops/as_strided_scatter_ops.h>
#include <ATen/ops/asin_ops.h>
#include <ATen/ops/asinh_ops.h>
#include <ATen/ops/atan2_ops.h>
#include <ATen/ops/atan_ops.h>
#include <ATen/ops/atanh_ops.h>
#include <ATen/ops/baddbmm_ops.h>
#include <ATen/ops/bernoulli_ops.h>
#include <ATen/ops/bincount_ops.h>
#include <ATen/ops/bitwise_and_ops.h>
#include <ATen/ops/bitwise_left_shift_ops.h>
#include <ATen/ops/bitwise_not_ops.h>
#include <ATen/ops/bitwise_or_ops.h>
#include <ATen/ops/bitwise_right_shift_ops.h>
#include <ATen/ops/bitwise_xor_ops.h>
#include <ATen/ops/bmm_ops.h>
#include <ATen/ops/broadcast_to_ops.h>
#include <ATen/ops/cauchy_ops.h>
#include <ATen/ops/ccol_indices_ops.h>
#include <ATen/ops/ceil_ops.h>
#include <ATen/ops/chalf_ops.h>
#include <ATen/ops/cholesky_inverse_ops.h>
#include <ATen/ops/cholesky_ops.h>
#include <ATen/ops/cholesky_solve_ops.h>
#include <ATen/ops/chunk_ops.h>
#include <ATen/ops/clamp_max_ops.h>
#include <ATen/ops/clamp_min_ops.h>
#include <ATen/ops/clamp_ops.h>
#include <ATen/ops/clip_ops.h>
#include <ATen/ops/clone_ops.h>
#include <ATen/ops/coalesce_ops.h>
#include <ATen/ops/col_indices_ops.h>
#include <ATen/ops/conj_ops.h>
#include <ATen/ops/conj_physical_ops.h>
#include <ATen/ops/contiguous_ops.h>
#include <ATen/ops/copy_ops.h>
#include <ATen/ops/copysign_ops.h>
#include <ATen/ops/corrcoef_ops.h>
#include <ATen/ops/cos_ops.h>
#include <ATen/ops/cosh_ops.h>
#include <ATen/ops/count_nonzero_ops.h>
#include <ATen/ops/cov_ops.h>
#include <ATen/ops/cross_ops.h>
#include <ATen/ops/crow_indices_ops.h>
#include <ATen/ops/cummax_ops.h>
#include <ATen/ops/cummin_ops.h>
#include <ATen/ops/cumprod_ops.h>
#include <ATen/ops/cumsum_ops.h>
#include <ATen/ops/data_ops.h>
#include <ATen/ops/deg2rad_ops.h>
#include <ATen/ops/dense_dim_ops.h>
#include <ATen/ops/dequantize_ops.h>
#include <ATen/ops/det_ops.h>
#include <ATen/ops/detach_ops.h>
#include <ATen/ops/diag_embed_ops.h>
#include <ATen/ops/diag_ops.h>
#include <ATen/ops/diagflat_ops.h>
#include <ATen/ops/diagonal_ops.h>
#include <ATen/ops/diagonal_scatter_ops.h>
#include <ATen/ops/diff_ops.h>
#include <ATen/ops/digamma_ops.h>
#include <ATen/ops/dist_ops.h>
#include <ATen/ops/div_ops.h>
#include <ATen/ops/divide_ops.h>
#include <ATen/ops/dot_ops.h>
#include <ATen/ops/dsplit_ops.h>
#include <ATen/ops/eq_ops.h>
#include <ATen/ops/equal_ops.h>
#include <ATen/ops/erf_ops.h>
#include <ATen/ops/erfc_ops.h>
#include <ATen/ops/erfinv_ops.h>
#include <ATen/ops/exp2_ops.h>
#include <ATen/ops/exp_ops.h>
#include <ATen/ops/expand_as_ops.h>
#include <ATen/ops/expand_ops.h>
#include <ATen/ops/expm1_ops.h>
#include <ATen/ops/exponential_ops.h>
#include <ATen/ops/fill_diagonal_ops.h>
#include <ATen/ops/fill_ops.h>
#include <ATen/ops/fix_ops.h>
#include <ATen/ops/flatten_ops.h>
#include <ATen/ops/flip_ops.h>
#include <ATen/ops/fliplr_ops.h>
#include <ATen/ops/flipud_ops.h>
#include <ATen/ops/float_power_ops.h>
#include <ATen/ops/floor_divide_ops.h>
#include <ATen/ops/floor_ops.h>
#include <ATen/ops/fmax_ops.h>
#include <ATen/ops/fmin_ops.h>
#include <ATen/ops/fmod_ops.h>
#include <ATen/ops/frac_ops.h>
#include <ATen/ops/frexp_ops.h>
#include <ATen/ops/gather_ops.h>
#include <ATen/ops/gcd_ops.h>
#include <ATen/ops/ge_ops.h>
#include <ATen/ops/geometric_ops.h>
#include <ATen/ops/geqrf_ops.h>
#include <ATen/ops/ger_ops.h>
#include <ATen/ops/greater_equal_ops.h>
#include <ATen/ops/greater_ops.h>
#include <ATen/ops/gt_ops.h>
#include <ATen/ops/hardshrink_backward_ops.h>
#include <ATen/ops/hardshrink_ops.h>
#include <ATen/ops/heaviside_ops.h>
#include <ATen/ops/histc_ops.h>
#include <ATen/ops/histogram_ops.h>
#include <ATen/ops/hsplit_ops.h>
#include <ATen/ops/hypot_ops.h>
#include <ATen/ops/i0_ops.h>
#include <ATen/ops/igamma_ops.h>
#include <ATen/ops/igammac_ops.h>
#include <ATen/ops/index_add_ops.h>
#include <ATen/ops/index_copy_ops.h>
#include <ATen/ops/index_fill_ops.h>
#include <ATen/ops/index_ops.h>
#include <ATen/ops/index_put_ops.h>
#include <ATen/ops/index_reduce_ops.h>
#include <ATen/ops/index_select_ops.h>
#include <ATen/ops/indices_ops.h>
#include <ATen/ops/inner_ops.h>
#include <ATen/ops/int_repr_ops.h>
#include <ATen/ops/inverse_ops.h>
#include <ATen/ops/is_coalesced_ops.h>
#include <ATen/ops/is_complex_ops.h>
#include <ATen/ops/is_conj_ops.h>
#include <ATen/ops/is_distributed_ops.h>
#include <ATen/ops/is_floating_point_ops.h>
#include <ATen/ops/is_inference_ops.h>
#include <ATen/ops/is_leaf_ops.h>
#include <ATen/ops/is_neg_ops.h>
#include <ATen/ops/is_nonzero_ops.h>
#include <ATen/ops/is_pinned_ops.h>
#include <ATen/ops/is_same_size_ops.h>
#include <ATen/ops/is_set_to_ops.h>
#include <ATen/ops/is_signed_ops.h>
#include <ATen/ops/isclose_ops.h>
#include <ATen/ops/isfinite_ops.h>
#include <ATen/ops/isinf_ops.h>
#include <ATen/ops/isnan_ops.h>
#include <ATen/ops/isneginf_ops.h>
#include <ATen/ops/isposinf_ops.h>
#include <ATen/ops/isreal_ops.h>
#include <ATen/ops/istft_ops.h>
#include <ATen/ops/item_ops.h>
#include <ATen/ops/kron_ops.h>
#include <ATen/ops/kthvalue_ops.h>
#include <ATen/ops/lcm_ops.h>
#include <ATen/ops/ldexp_ops.h>
#include <ATen/ops/le_ops.h>
#include <ATen/ops/lerp_ops.h>
#include <ATen/ops/less_equal_ops.h>
#include <ATen/ops/less_ops.h>
#include <ATen/ops/lgamma_ops.h>
#include <ATen/ops/log10_ops.h>
#include <ATen/ops/log1p_ops.h>
#include <ATen/ops/log2_ops.h>
#include <ATen/ops/log_normal_ops.h>
#include <ATen/ops/log_ops.h>
#include <ATen/ops/log_softmax_ops.h>
#include <ATen/ops/logaddexp2_ops.h>
#include <ATen/ops/logaddexp_ops.h>
#include <ATen/ops/logcumsumexp_ops.h>
#include <ATen/ops/logdet_ops.h>
#include <ATen/ops/logical_and_ops.h>
#include <ATen/ops/logical_not_ops.h>
#include <ATen/ops/logical_or_ops.h>
#include <ATen/ops/logical_xor_ops.h>
#include <ATen/ops/logit_ops.h>
#include <ATen/ops/logsumexp_ops.h>
#include <ATen/ops/lshift_ops.h>
#include <ATen/ops/lt_ops.h>
#include <ATen/ops/lu_solve_ops.h>
#include <ATen/ops/mH_ops.h>
#include <ATen/ops/mT_ops.h>
#include <ATen/ops/masked_fill_ops.h>
#include <ATen/ops/masked_scatter_ops.h>
#include <ATen/ops/masked_select_ops.h>
#include <ATen/ops/matmul_ops.h>
#include <ATen/ops/matrix_H_ops.h>
#include <ATen/ops/matrix_exp_ops.h>
#include <ATen/ops/matrix_power_ops.h>
#include <ATen/ops/max_ops.h>
#include <ATen/ops/maximum_ops.h>
#include <ATen/ops/mean_ops.h>
#include <ATen/ops/median_ops.h>
#include <ATen/ops/min_ops.h>
#include <ATen/ops/minimum_ops.h>
#include <ATen/ops/mm_ops.h>
#include <ATen/ops/mode_ops.h>
#include <ATen/ops/moveaxis_ops.h>
#include <ATen/ops/movedim_ops.h>
#include <ATen/ops/msort_ops.h>
#include <ATen/ops/mul_ops.h>
#include <ATen/ops/multinomial_ops.h>
#include <ATen/ops/multiply_ops.h>
#include <ATen/ops/mv_ops.h>
#include <ATen/ops/mvlgamma_ops.h>
#include <ATen/ops/nan_to_num_ops.h>
#include <ATen/ops/nanmean_ops.h>
#include <ATen/ops/nanmedian_ops.h>
#include <ATen/ops/nanquantile_ops.h>
#include <ATen/ops/nansum_ops.h>
#include <ATen/ops/narrow_copy_ops.h>
#include <ATen/ops/narrow_ops.h>
#include <ATen/ops/ne_ops.h>
#include <ATen/ops/neg_ops.h>
#include <ATen/ops/negative_ops.h>
#include <ATen/ops/new_empty_ops.h>
#include <ATen/ops/new_empty_strided_ops.h>
#include <ATen/ops/new_full_ops.h>
#include <ATen/ops/new_ones_ops.h>
#include <ATen/ops/new_zeros_ops.h>
#include <ATen/ops/nextafter_ops.h>
#include <ATen/ops/nonzero_numpy_ops.h>
#include <ATen/ops/nonzero_ops.h>
#include <ATen/ops/nonzero_static_ops.h>
#include <ATen/ops/norm_ops.h>
#include <ATen/ops/normal_ops.h>
#include <ATen/ops/not_equal_ops.h>
#include <ATen/ops/numpy_T_ops.h>
#include <ATen/ops/or_ops.h>
#include <ATen/ops/orgqr_ops.h>
#include <ATen/ops/ormqr_ops.h>
#include <ATen/ops/outer_ops.h>
#include <ATen/ops/output_nr_ops.h>
#include <ATen/ops/permute_ops.h>
#include <ATen/ops/pin_memory_ops.h>
#include <ATen/ops/pinverse_ops.h>
#include <ATen/ops/polygamma_ops.h>
#include <ATen/ops/positive_ops.h>
#include <ATen/ops/pow_ops.h>
#include <ATen/ops/prelu_ops.h>
#include <ATen/ops/prod_ops.h>
#include <ATen/ops/put_ops.h>
#include <ATen/ops/q_per_channel_axis_ops.h>
#include <ATen/ops/q_per_channel_scales_ops.h>
#include <ATen/ops/q_per_channel_zero_points_ops.h>
#include <ATen/ops/q_scale_ops.h>
#include <ATen/ops/q_zero_point_ops.h>
#include <ATen/ops/qr_ops.h>
#include <ATen/ops/qscheme_ops.h>
#include <ATen/ops/quantile_ops.h>
#include <ATen/ops/rad2deg_ops.h>
#include <ATen/ops/random_ops.h>
#include <ATen/ops/ravel_ops.h>
#include <ATen/ops/reciprocal_ops.h>
#include <ATen/ops/record_stream_ops.h>
#include <ATen/ops/refine_names_ops.h>
#include <ATen/ops/relu_ops.h>
#include <ATen/ops/remainder_ops.h>
#include <ATen/ops/rename_ops.h>
#include <ATen/ops/renorm_ops.h>
#include <ATen/ops/repeat_interleave_ops.h>
#include <ATen/ops/repeat_ops.h>
#include <ATen/ops/requires_grad_ops.h>
#include <ATen/ops/reshape_as_ops.h>
#include <ATen/ops/reshape_ops.h>
#include <ATen/ops/resize_as_ops.h>
#include <ATen/ops/resize_as_sparse_ops.h>
#include <ATen/ops/resize_ops.h>
#include <ATen/ops/resolve_conj_ops.h>
#include <ATen/ops/resolve_neg_ops.h>
#include <ATen/ops/retain_grad_ops.h>
#include <ATen/ops/retains_grad_ops.h>
#include <ATen/ops/roll_ops.h>
#include <ATen/ops/rot90_ops.h>
#include <ATen/ops/round_ops.h>
#include <ATen/ops/row_indices_ops.h>
#include <ATen/ops/rshift_ops.h>
#include <ATen/ops/rsqrt_ops.h>
#include <ATen/ops/scatter_add_ops.h>
#include <ATen/ops/scatter_ops.h>
#include <ATen/ops/scatter_reduce_ops.h>
#include <ATen/ops/select_ops.h>
#include <ATen/ops/select_scatter_ops.h>
#include <ATen/ops/set_data_ops.h>
#include <ATen/ops/set_ops.h>
#include <ATen/ops/sgn_ops.h>
#include <ATen/ops/sigmoid_ops.h>
#include <ATen/ops/sign_ops.h>
#include <ATen/ops/signbit_ops.h>
#include <ATen/ops/sin_ops.h>
#include <ATen/ops/sinc_ops.h>
#include <ATen/ops/sinh_ops.h>
#include <ATen/ops/size_ops.h>
#include <ATen/ops/slice_inverse_ops.h>
#include <ATen/ops/slice_ops.h>
#include <ATen/ops/slice_scatter_ops.h>
#include <ATen/ops/slogdet_ops.h>
#include <ATen/ops/smm_ops.h>
#include <ATen/ops/softmax_ops.h>
#include <ATen/ops/sort_ops.h>
#include <ATen/ops/sparse_dim_ops.h>
#include <ATen/ops/sparse_mask_ops.h>
#include <ATen/ops/sparse_resize_and_clear_ops.h>
#include <ATen/ops/sparse_resize_ops.h>
#include <ATen/ops/split_ops.h>
#include <ATen/ops/split_with_sizes_ops.h>
#include <ATen/ops/sqrt_ops.h>
#include <ATen/ops/square_ops.h>
#include <ATen/ops/squeeze_ops.h>
#include <ATen/ops/sspaddmm_ops.h>
#include <ATen/ops/std_ops.h>
#include <ATen/ops/stft_ops.h>
#include <ATen/ops/stride_ops.h>
#include <ATen/ops/sub_ops.h>
#include <ATen/ops/subtract_ops.h>
#include <ATen/ops/sum_ops.h>
#include <ATen/ops/sum_to_size_ops.h>
#include <ATen/ops/svd_ops.h>
#include <ATen/ops/swapaxes_ops.h>
#include <ATen/ops/swapdims_ops.h>
#include <ATen/ops/t_ops.h>
#include <ATen/ops/take_along_dim_ops.h>
#include <ATen/ops/take_ops.h>
#include <ATen/ops/tan_ops.h>
#include <ATen/ops/tanh_ops.h>
#include <ATen/ops/tensor_split_ops.h>
#include <ATen/ops/tile_ops.h>
#include <ATen/ops/to_dense_ops.h>
#include <ATen/ops/to_mkldnn_ops.h>
#include <ATen/ops/to_ops.h>
#include <ATen/ops/to_padded_tensor_ops.h>
#include <ATen/ops/to_sparse_bsc_ops.h>
#include <ATen/ops/to_sparse_bsr_ops.h>
#include <ATen/ops/to_sparse_csc_ops.h>
#include <ATen/ops/to_sparse_csr_ops.h>
#include <ATen/ops/to_sparse_ops.h>
#include <ATen/ops/topk_ops.h>
#include <ATen/ops/trace_ops.h>
#include <ATen/ops/transpose_ops.h>
#include <ATen/ops/triangular_solve_ops.h>
#include <ATen/ops/tril_ops.h>
#include <ATen/ops/triu_ops.h>
#include <ATen/ops/true_divide_ops.h>
#include <ATen/ops/trunc_ops.h>
#include <ATen/ops/type_as_ops.h>
#include <ATen/ops/unbind_ops.h>
#include <ATen/ops/unflatten_ops.h>
#include <ATen/ops/unfold_ops.h>
#include <ATen/ops/uniform_ops.h>
#include <ATen/ops/unsafe_chunk_ops.h>
#include <ATen/ops/unsafe_split_ops.h>
#include <ATen/ops/unsafe_split_with_sizes_ops.h>
#include <ATen/ops/unsqueeze_ops.h>
#include <ATen/ops/values_ops.h>
#include <ATen/ops/var_ops.h>
#include <ATen/ops/vdot_ops.h>
#include <ATen/ops/view_as_ops.h>
#include <ATen/ops/view_ops.h>
#include <ATen/ops/vsplit_ops.h>
#include <ATen/ops/where_ops.h>
#include <ATen/ops/xlogy_ops.h>
#include <ATen/ops/xor_ops.h>
#include <ATen/ops/zero_ops.h>
namespace at {
namespace _ops {
} // namespace _ops
} // namespace at

View File

@ -0,0 +1 @@
#include <ATen/core/NamedTensor.h>

View File

@ -0,0 +1,214 @@
#pragma once
#include <ATen/NamedTensor.h>
#include <ATen/TensorNames.h>
#include <ATen/WrapDimUtilsMulti.h>
#include <ATen/core/DimVector.h>
#include <ATen/core/Tensor.h>
namespace at {
using NameVector = SmallVector<Dimname, kDimVectorStaticSize>;
inline bool has_names(const ITensorListRef& tensors) {
return std::any_of(tensors.begin(), tensors.end(), [](const Tensor& t) {
return t.has_names();
});
}
// Converts dim to a positional index. Errors if `dim` cannot be used to
// refer to any dimension of the tensor.
TORCH_API int64_t dimname_to_position(const Tensor& tensor, Dimname dim);
TORCH_API std::vector<int64_t> dimnames_to_positions(
const Tensor& tensor,
DimnameList dims);
// Unifies two DimnameLists to produce a third. This is useful for implementing
// the named inference rule for binary broadcasting operations like add.
//
// There are three main constraints:
// 1) Check matching: Names must match positionally from the right.
// 2) Check misaligned: If a name `n` is in `names`, then it must appear at
//    the same index from the right in `other`.
// 3) The output names are obtained by unifying the names individually from the
//    right.
TORCH_API std::vector<Dimname> unify_from_right(
DimnameList names,
DimnameList other,
const char* action = "broadcast");
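// Editor's note: a small worked example (not part of the original header),
// written in named-tensor notation rather than compilable C++:
//
//   unify_from_right(['N', 'C'], ['C'])      -> ['N', 'C']
//   unify_from_right(['N', 'C'], ['N', 'C']) -> ['N', 'C']
//   unify_from_right(['N', 'C'], ['H'])      -> error: 'C' and 'H' do not match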
[[noreturn]] inline void reportNYIDimnameOverload(const char* op_name) {
TORCH_CHECK(
false,
op_name,
": You passed a dimname (string) to this op in place of a dimension "
"index but it does not yet support this behavior. Please pass a dimension "
"index to work around this.");
}
// [NOTE] Writing name inference rules
//
// Operators that support named tensors are either composed of operations that
// support named tensors or implement some name inference rule. An op that
// implements its own name inference rule generally looks like the following:
//
// Tensor op(...) {
// perform_shape_checks(...);
// # (1)
// auto maybe_outnames = compute_outnames(...);
// auto result = [&]() {
// NoNamesGuard guard;
// return op_impl(...);
// }();
// # (2)
// propagate_names_if_nonempty(result, maybe_outnames);
//
// Each op has (1) a compute outnames step and (2) a propagate names step.
//
// compute_outnames is responsible for checking that input names match and
// determining what the output names should be. It returns either:
// - {} (if the input tensors are all unnamed)
// - non-empty outnames.
//
// propagate_names_if_nonempty propagates the outnames if they exist to the
// result tensors.
//
// The {} case is an optimization; if the user does not use named tensors they
// pay no perf cost for it.
namespace namedinference {
const Tensor& propagate_names_if_present_and_nonempty(
const Tensor& result,
std::optional<DimnameList> maybe_names,
bool validate_names = false);
// Propagates `names` to `result` if `names` is not empty.
// `names` can be empty; see [NOTE] Writing name inference rules
// If `names` is not empty, `names.size()` should equal `result.dim()`.
// When in doubt, use this overload instead of the others.
TORCH_API const Tensor& propagate_names_if_nonempty(
const Tensor& result,
DimnameList maybe_names,
bool validate_names = false);
// Propagates `names` to `result`. Only use this if we are certain that there
// are names to propagate (that names is not empty).
TORCH_API const Tensor& propagate_names(
const Tensor& result,
DimnameList names,
bool validate_names = false);
// Propagates all names from src to result.
TORCH_API void propagate_names(const Tensor& result, const Tensor& src);
// Propagates all names except for those at the excluded_idxs.
TORCH_API void propagate_names_except(
const Tensor& result,
const Tensor& src,
IntArrayRef excluded_idxs);
// Used for reduction ops that have a `keepdim` arg.
TORCH_API void propagate_names_for_reduction(
const Tensor& result,
const Tensor& src,
IntArrayRef excluded_idxs,
bool keepdim);
TORCH_API void propagate_names_for_expand(
const Tensor& result,
const Tensor& self);
TORCH_API std::vector<Dimname> compute_cat_outnames(
const MaterializedITensorListRef& tensors);
TORCH_API std::vector<Dimname> compute_broadcast_outnames(
const Tensor& self,
const Tensor& other);
TORCH_API std::vector<Dimname> broadcast_to_outnames(
const Tensor& tensor,
const Tensor& reference_tensor,
const char* op_name);
TORCH_API std::vector<Dimname> compute_matmul_outnames(
const Tensor& self,
const Tensor& other);
TORCH_API std::vector<Dimname> compute_cdist_outnames(
const Tensor& self,
const Tensor& other);
TORCH_API std::vector<Dimname> compute_bmm_outnames(
const Tensor& result,
const Tensor& self,
const Tensor& other);
TORCH_API std::vector<Dimname> compute_squeeze_outnames(const Tensor& tensor);
TORCH_API std::vector<Dimname> compute_squeeze_outnames(
const Tensor& tensor,
std::bitset<dim_bitset_size> dims);
std::vector<Dimname> compute_diagonal_outnames(
const Tensor& tensor,
int64_t dim1,
int64_t dim2);
// TensorImpl* overloads for Legacy TH/THC code. Use these sparingly.
TORCH_API TensorImpl* propagate_names_if_nonempty(
TensorImpl* result,
DimnameList maybe_names,
bool validate_names = false);
TORCH_API TensorImpl* propagate_names(
TensorImpl* result,
DimnameList names,
bool validate_names = false);
TORCH_API void propagate_names(TensorImpl* result, /*const */ TensorImpl* src);
TORCH_API inline void propagate_names(
const TensorBase& result,
DimnameList names,
bool validate_names = false) {
propagate_names(result.unsafeGetTensorImpl(), names, validate_names);
}
TORCH_API inline void propagate_names_if_nonempty(
const TensorBase& result,
DimnameList names,
bool validate_names = false) {
propagate_names_if_nonempty(
result.unsafeGetTensorImpl(), names, validate_names);
}
TORCH_API inline void propagate_names(
const TensorBase& result,
const TensorBase& src) {
propagate_names(result.unsafeGetTensorImpl(), src.unsafeGetTensorImpl());
}
// result = m1 @ m2 + bias
TORCH_API std::vector<Dimname> propagate_names_for_addmm(
const Tensor& m1,
const Tensor& m2,
const Tensor& bias);
TORCH_API std::vector<Dimname> propagate_names_for_addmv(
const Tensor& mat,
const Tensor& vec,
const Tensor& bias);
TORCH_API void check_names_for_dot(TensorImpl* vec1, TensorImpl* vec2);
TORCH_API std::vector<Dimname> compute_baddbmm_outnames(
const Tensor& result,
const Tensor& self,
const Tensor& other,
const Tensor& bias);
TORCH_API bool are_names_equal(TensorImpl* self, TensorImpl* other);
} // namespace namedinference
} // namespace at

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -0,0 +1,286 @@
#pragma once
#include <ATen/MemoryOverlap.h>
#include <ATen/Tensor.h>
#include <c10/core/DispatchKey.h>
#include <c10/core/DispatchKeySet.h>
#include <c10/core/MemoryFormat.h>
#include <c10/core/TensorImpl.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Exception.h>
#include <c10/util/Metaprogramming.h>
#include <c10/util/irange.h>
namespace at::native {
struct NestedTensorImpl;
inline bool nested_tensor_impl_is_contiguous(const NestedTensorImpl* nt);
int64_t get_numel_from_nested_size_tensor(const at::Tensor& tensor);
at::Tensor construct_nested_strides(const at::Tensor& nested_size);
at::Tensor construct_offsets(const at::Tensor& nested_size);
struct TORCH_API NestedTensorImpl : public c10::TensorImpl {
explicit NestedTensorImpl(
Storage storage,
c10::DispatchKeySet key_set,
const caffe2::TypeMeta data_type,
at::Tensor nested_sizes,
at::Tensor nested_strides,
at::Tensor storage_offsets);
explicit NestedTensorImpl(
const at::Tensor& buffer,
at::Tensor nested_sizes,
at::Tensor nested_strides,
at::Tensor storage_offsets);
// Assumes the buffer is contiguous; `nested_strides` and `offsets`
// can be inferred from `nested_sizes`.
explicit NestedTensorImpl(
const at::Tensor& buffer,
const at::Tensor& nested_sizes);
// This constructor is used for creating view tensors from nested tensors
explicit NestedTensorImpl(
c10::TensorImpl::ImplType impl_type,
const at::Tensor& base_tensor,
at::Tensor nested_sizes,
at::Tensor nested_strides,
at::Tensor storage_offsets);
// TODO: don't expose private implementation details like this; in
// particular, resizing this tensor will mess up our dim() and
// callers cannot fix it.
const Tensor& get_nested_sizes() const {
return nested_sizes_;
}
// TODO: don't expose private implementation details like this
const Tensor& get_nested_strides() const {
return nested_strides_;
}
const Tensor& get_storage_offsets() const {
return storage_offsets_;
}
// Returns nullopt if the ith dimension is irregular. The ith dimension
// of a NestedTensor is regular if the unbound tensors match in
// size at the (i-1)th dimension.
std::optional<int64_t> opt_size(int64_t d) const;
int64_t size(int64_t d) const {
std::optional<int64_t> optional_size = this->opt_size(d);
TORCH_CHECK(
optional_size.has_value(),
"Given dimension ",
d,
" is irregular and does not have a size.");
return *optional_size;
}
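// Editor's note: a small worked example (not part of the original header).
// For a nested tensor built from two tensors of sizes (2, 3) and (2, 5):
//   opt_size(0) == 2        (two constituent tensors)
//   opt_size(1) == 2        (both constituents agree at this dimension)
//   opt_size(2) == nullopt  (3 vs 5 is irregular), so size(2) throws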
/**
* Return a view of the nested tensor as a 1 dimensional contiguous tensor.
*
* The buffer tensor created by this function shares the same storage_impl as
* the original nested tensor, and therefore can be seen as a view.
*
* @return A newly constructed view tensor
*/
at::Tensor get_buffer() const {
TORCH_CHECK(
nested_tensor_impl_is_contiguous(this),
"NestedTensor must be contiguous to get buffer.");
return get_unsafe_storage_as_tensor();
}
/**
* If possible use get_buffer() instead. This function returns the storage
* as a tensor directly, which is not safe to use in general. If using this
* function, the caller must account for nested_sizes,
* nested_strides and storage_offsets.
*
* @return A newly constructed view tensor
*/
at::Tensor get_unsafe_storage_as_tensor() const {
auto buffer_key_set_ = generate_buffer_key_set();
const auto buffer_size = get_buffer_size();
auto buffer_tensor_impl = c10::make_intrusive<TensorImpl>(
c10::TensorImpl::VIEW, Storage(storage_), buffer_key_set_, data_type_);
buffer_tensor_impl->set_sizes_contiguous(
c10::makeArrayRef(static_cast<int64_t>(buffer_size)));
return Tensor(buffer_tensor_impl);
}
size_t get_buffer_size() const {
return storage_.nbytes() / data_type_.itemsize();
}
protected:
const char* tensorimpl_type_name() const override;
// TODO: numel_custom and is_contiguous_custom can be profitably overridden
// with real implementations
int64_t numel_custom() const override;
c10::SymInt sym_numel_custom() const override;
bool is_contiguous_custom(MemoryFormat) const override;
int64_t size_custom(int64_t d) const override {
return this->size(d);
}
c10::SymInt sym_size_custom(int64_t d) const override {
return c10::SymInt{this->size(d)};
}
IntArrayRef sizes_custom() const override;
c10::SymIntArrayRef sym_sizes_custom() const override;
IntArrayRef strides_custom() const override;
c10::SymIntArrayRef sym_strides_custom() const override;
// this one is real
int64_t dim_custom() const override;
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
const c10::VariableVersion& version_counter,
bool allow_tensor_metadata_change) const override;
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
c10::VariableVersion&& version_counter,
bool allow_tensor_metadata_change) const override;
void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override {
copy_tensor_metadata(
/*src_impl=*/impl.get(),
/*dest_impl=*/this,
/*version_counter=*/version_counter(),
/*allow_tensor_metadata_change=*/allow_tensor_metadata_change());
}
private:
// Must be called after any changes to our dim() to sync the state
// to TensorImpl.
void refresh_dim();
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
const at::Tensor nested_sizes_, nested_strides_;
  // The starting positions of the underlying tensors in the contiguous buffer,
  // i.e. the buffer memory offsets at which the underlying tensors begin.
  // We keep this metadata because, without strong enough constraints, it
  // cannot be derived from `nested_sizes_` and `nested_strides_`:
  // 1. when the buffer has gaps, e.g. [tensor1, blank, tensor2]
  //    (this can happen e.g. after slicing a nested tensor)
  // 2. when multiple tensors share the same memory
  // 3. when the nesting order is changed, e.g. [tensor1, tensor3, tensor2]
  // A strong enough constraint would be:
  // 1. every underlying tensor is contiguous in memory
  //    && the tensors are nested in ascending memory order
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
const at::Tensor storage_offsets_;
// NOTE: -1 here means the size is missing
  // Optional to allow it to be computed lazily from the nested sizes.
// TODO: maybe we can remove this metadata since
// we can compute it from `nested_sizes_`
mutable std::optional<std::vector<int64_t>> opt_sizes_;
template <typename VariableVersion>
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach_core(
VariableVersion&& version_counter,
bool allow_tensor_metadata_change) const;
/**
* Generates a non-nested key_set from a nested tensor.
*
   * Many nested tensor kernel implementations generate a buffer tensor and
   * redispatch it to a non-nested kernel; this function generates the key
   * set used by that buffer tensor.
*
* @return Appropriate key set for non-nested tensor
*/
inline c10::DispatchKeySet generate_buffer_key_set() const {
auto buffer_key_set = this->key_set();
const bool Autograd = buffer_key_set.has_any(c10::autograd_dispatch_keyset);
// Remove nested tensor specific keys
buffer_key_set = buffer_key_set -
c10::DispatchKeySet{
c10::DispatchKey::NestedTensor,
c10::DispatchKey::AutogradNestedTensor};
// Add dense tensor specific keys
buffer_key_set =
buffer_key_set | c10::DispatchKeySet{c10::DispatchKey::Dense};
buffer_key_set = Autograd
? c10::DispatchKeySet{c10::DispatchKey::Autograd} | buffer_key_set
: buffer_key_set;
return buffer_key_set;
}
};
inline NestedTensorImpl* get_nested_tensor_impl_or_null(
const at::Tensor& tensor) {
if (tensor.is_nested()) {
return static_cast<NestedTensorImpl*>(tensor.unsafeGetTensorImpl());
}
return nullptr;
}
inline NestedTensorImpl* get_nested_tensor_impl(const at::Tensor& tensor) {
TORCH_CHECK(
tensor.is_nested(), "get_nested_tensor_impl requires a NestedTensor.");
return static_cast<NestedTensorImpl*>(tensor.unsafeGetTensorImpl());
}
inline bool nested_tensor_impl_is_contiguous(const NestedTensorImpl* nt) {
int64_t ntensors = nt->size(0);
if (ntensors == 0) {
return true;
}
const Tensor &sizemat = nt->get_nested_sizes(),
&stridemat = nt->get_nested_strides();
const int64_t* offsets_ptr =
nt->get_storage_offsets().const_data_ptr<int64_t>();
int64_t orig_dim = sizemat.size(1);
// nesting scalars
if (orig_dim == 0) {
    // each scalar must be contiguous;
    // return false if there is blank memory between underlying scalars
for (int64_t i = 0; i < ntensors; i++) {
if (offsets_ptr[i] != i) {
return false;
}
}
}
// nesting tensors
else {
    // return false if any underlying tensor is non-contiguous
const int64_t *sizemat_ptr = sizemat.const_data_ptr<int64_t>(),
*stridemat_ptr = stridemat.const_data_ptr<int64_t>();
for (int64_t i = 0; i < ntensors; i++) {
if (stridemat_ptr[orig_dim - 1] != 1) {
return false;
}
int64_t product = sizemat_ptr[orig_dim - 1];
for (int64_t j = orig_dim - 2; j >= 0; j--) {
if (stridemat_ptr[j] != product) {
return false;
}
product *= sizemat_ptr[j];
}
sizemat_ptr += orig_dim;
stridemat_ptr += orig_dim;
}
    // return false if there is blank memory between underlying tensors
if (offsets_ptr[0] != 0) {
return false;
}
sizemat_ptr = sizemat.const_data_ptr<int64_t>();
stridemat_ptr = stridemat.const_data_ptr<int64_t>();
for (int64_t i = 1; i < ntensors; i++) {
if (offsets_ptr[i] !=
offsets_ptr[i - 1] + *sizemat_ptr * *stridemat_ptr) {
return false;
}
sizemat_ptr += orig_dim;
stridemat_ptr += orig_dim;
}
}
// everything is fine
return true;
}
inline const at::Tensor& get_nested_sizes(const at::Tensor& tensor) {
return get_nested_tensor_impl(tensor)->get_nested_sizes();
}
} // namespace at::native

View File

@ -0,0 +1,203 @@
#pragma once
#ifdef __HIPCC__
#include <hip/hip_runtime.h>
#endif
#include <c10/macros/Macros.h>
#include <c10/util/BFloat16.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e4m3fnuz.h>
#include <c10/util/Float8_e5m2.h>
#include <c10/util/Float8_e5m2fnuz.h>
#include <c10/util/Half.h>
#include <c10/util/complex.h>
#include <cmath>
#include <type_traits>
namespace at {
// std::isnan isn't performant to use on integral types; it will
// (uselessly) convert to floating point and then do the test.
// This function is.
template <typename T, std::enable_if_t<std::is_integral_v<T>, int> = 0>
inline C10_HOST_DEVICE bool _isnan(T /*val*/) {
return false;
}
template <typename T, std::enable_if_t<std::is_floating_point_v<T>, int> = 0>
inline C10_HOST_DEVICE bool _isnan(T val) {
#if defined(__CUDACC__) || defined(__HIPCC__)
return ::isnan(val);
#else
return std::isnan(val);
#endif
}
template <typename T, std::enable_if_t<c10::is_complex<T>::value, int> = 0>
inline C10_HOST_DEVICE bool _isnan(T val) {
return std::isnan(val.real()) || std::isnan(val.imag());
}
template <typename T, std::enable_if_t<std::is_same_v<T, at::Half>, int> = 0>
inline C10_HOST_DEVICE bool _isnan(T val) {
return at::_isnan(static_cast<float>(val));
}
template <
typename T,
std::enable_if_t<std::is_same_v<T, at::BFloat16>, int> = 0>
inline C10_HOST_DEVICE bool _isnan(at::BFloat16 val) {
return at::_isnan(static_cast<float>(val));
}
inline C10_HOST_DEVICE bool _isnan(at::BFloat16 val) {
return at::_isnan(static_cast<float>(val));
}
template <
typename T,
std::enable_if_t<std::is_same_v<T, at::Float8_e5m2>, int> = 0>
inline C10_HOST_DEVICE bool _isnan(T val) {
return val.isnan();
}
template <
typename T,
std::enable_if_t<std::is_same_v<T, at::Float8_e4m3fn>, int> = 0>
inline C10_HOST_DEVICE bool _isnan(T val) {
return val.isnan();
}
template <
typename T,
std::enable_if_t<std::is_same_v<T, at::Float8_e5m2fnuz>, int> = 0>
inline C10_HOST_DEVICE bool _isnan(T val) {
return val.isnan();
}
template <
typename T,
std::enable_if_t<std::is_same_v<T, at::Float8_e4m3fnuz>, int> = 0>
inline C10_HOST_DEVICE bool _isnan(T val) {
return val.isnan();
}
// std::isinf isn't performant to use on integral types; it will
// (uselessly) convert to floating point and then do the test.
// This function is.
template <typename T, std::enable_if_t<std::is_integral_v<T>, int> = 0>
inline C10_HOST_DEVICE bool _isinf(T /*val*/) {
return false;
}
template <typename T, std::enable_if_t<std::is_floating_point_v<T>, int> = 0>
inline C10_HOST_DEVICE bool _isinf(T val) {
#if defined(__CUDACC__) || defined(__HIPCC__)
return ::isinf(val);
#else
return std::isinf(val);
#endif
}
inline C10_HOST_DEVICE bool _isinf(at::Half val) {
return at::_isinf(static_cast<float>(val));
}
inline C10_HOST_DEVICE bool _isinf(at::BFloat16 val) {
return at::_isinf(static_cast<float>(val));
}
inline C10_HOST_DEVICE bool _isinf(at::Float8_e5m2 val) {
return val.isinf();
}
inline C10_HOST_DEVICE bool _isinf(at::Float8_e4m3fn val [[maybe_unused]]) {
return false;
}
inline C10_HOST_DEVICE bool _isinf(at::Float8_e5m2fnuz val [[maybe_unused]]) {
return false;
}
inline C10_HOST_DEVICE bool _isinf(at::Float8_e4m3fnuz val [[maybe_unused]]) {
return false;
}
template <typename T>
C10_HOST_DEVICE inline T exp(T x) {
static_assert(
!std::is_same_v<T, double>,
"this template must be used with float or less precise type");
#if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__)
// use __expf fast approximation for peak bandwidth
return __expf(x);
#else
return ::exp(x);
#endif
}
template <>
C10_HOST_DEVICE inline double exp<double>(double x) {
return ::exp(x);
}
template <typename T>
C10_HOST_DEVICE inline T log(T x) {
static_assert(
!std::is_same_v<T, double>,
"this template must be used with float or less precise type");
#if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__)
// use __logf fast approximation for peak bandwidth
return __logf(x);
#else
return ::log(x);
#endif
}
template <>
C10_HOST_DEVICE inline double log<double>(double x) {
return ::log(x);
}
template <typename T>
C10_HOST_DEVICE inline T log1p(T x) {
static_assert(
!std::is_same_v<T, double>,
"this template must be used with float or less precise type");
#if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__)
// use __logf fast approximation for peak bandwidth
// NOTE: There is no __log1pf so unfortunately we lose precision.
return __logf(1.0f + x);
#else
return ::log1p(x);
#endif
}
template <>
C10_HOST_DEVICE inline double log1p<double>(double x) {
return ::log1p(x);
}
template <typename T>
C10_HOST_DEVICE inline T tan(T x) {
static_assert(
!std::is_same_v<T, double>,
"this template must be used with float or less precise type");
#if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__)
// use __tanf fast approximation for peak bandwidth
return __tanf(x);
#else
return ::tan(x);
#endif
}
template <>
C10_HOST_DEVICE inline double tan<double>(double x) {
return ::tan(x);
}
} // namespace at

View File

@ -0,0 +1,69 @@
#pragma once
#include <c10/core/ScalarType.h>
#include <c10/util/BFloat16.h>
#include <c10/util/Exception.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e4m3fnuz.h>
#include <c10/util/Float8_e5m2.h>
#include <c10/util/Float8_e5m2fnuz.h>
#include <c10/util/Half.h>
namespace at {
// For FP16 or BFloat16 inputs, ops should perform internal math in FP32.
template <typename scalar_t>
struct OpMathType {
using type = scalar_t;
};
template <>
struct OpMathType<at::Half> {
using type = float;
};
template <>
struct OpMathType<at::BFloat16> {
using type = float;
};
template <>
struct OpMathType<at::Float8_e5m2> {
using type = float;
};
template <>
struct OpMathType<at::Float8_e4m3fn> {
using type = float;
};
template <>
struct OpMathType<at::Float8_e5m2fnuz> {
using type = float;
};
template <>
struct OpMathType<at::Float8_e4m3fnuz> {
using type = float;
};
template <>
struct OpMathType<c10::complex<Half>> {
using type = c10::complex<float>;
};
template <typename T>
using opmath_type = typename OpMathType<T>::type;
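// A small illustration of opmath_type as an accumulation type inside a kernel
// (a sketch only; `data` and `n` are assumed caller-provided):
//
//   template <typename scalar_t>
//   scalar_t sum_impl(const scalar_t* data, int64_t n) {
//     using opmath_t = at::opmath_type<scalar_t>;  // float for Half/BFloat16
//     opmath_t acc = opmath_t(0);
//     for (int64_t i = 0; i < n; ++i) {
//       acc += static_cast<opmath_t>(data[i]);
//     }
//     return static_cast<scalar_t>(acc);
//   }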
namespace {
inline c10::ScalarType toOpMathType(const c10::ScalarType type) {
switch (type) {
#define DEFINE_CASE(scalar_t, TypeNum) \
case ScalarType::TypeNum: \
return CppTypeToScalarType<at::opmath_type<scalar_t>>::value;
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CASE)
#undef DEFINE_CASE
default:
TORCH_INTERNAL_ASSERT(false, "Unrecognized ScalarType: ", type);
}
}
} // namespace
} // namespace at

View File

@ -0,0 +1,187 @@
#pragma once
#include <c10/core/MemoryFormat.h>
#include <c10/core/SymIntArrayRef.h>
#include <c10/core/TensorImpl.h>
#include <c10/util/Exception.h>
namespace at {
// An "Opaque" TensorImpl -- there are no strides and (for now)
// even data() is not supported (thus no pointer arithmetic).
// NOTE: We could allow data() in the future, but would have to ensure pointer
// arithmetic code is properly guarded.
//
// NOTE: This does not support resize_ (and other metadata-changing ops) because
// of `shallow_copy_and_detach`. We would need to define an interface to
// "shallow copy" in order to add support.
template <typename OpaqueHandle>
struct TORCH_API OpaqueTensorImpl : public TensorImpl {
// public constructor for now...
OpaqueTensorImpl(
at::DispatchKeySet key_set,
const caffe2::TypeMeta data_type,
c10::Device device,
OpaqueHandle opaque_handle,
c10::IntArrayRef sizes,
bool is_non_overlapping_and_dense = true)
: TensorImpl(key_set, data_type, device),
opaque_handle_(std::move(opaque_handle)) {
set_storage_access_should_throw();
set_custom_sizes_strides(SizesStridesPolicy::CustomStrides);
sizes_and_strides_.set_sizes(sizes);
refresh_numel();
// NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer)
is_non_overlapping_and_dense_ = is_non_overlapping_and_dense;
}
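  // Construction sketch for a hypothetical backend handle type
  // `MyBackendHandle` (all names below are illustrative, and the
  // at::detail::make_tensor helper is assumed to be available):
  //
  //   at::Tensor t =
  //       at::detail::make_tensor<at::OpaqueTensorImpl<MyBackendHandle>>(
  //           c10::DispatchKeySet(c10::DispatchKey::PrivateUse1),
  //           caffe2::TypeMeta::Make<float>(),
  //           c10::Device(c10::DeviceType::PrivateUse1, 0),
  //           std::move(handle),
  //           /*sizes=*/c10::IntArrayRef({2, 3}));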
// Destructor doesn't call release_resources because it's
// unnecessary; don't forget to change that if needed!
void release_resources() override {
TensorImpl::release_resources();
opaque_handle_ = {};
}
void set_size(int64_t dim, int64_t new_size) override {
AT_ERROR("opaque tensors do not have set_size");
}
void set_stride(int64_t dim, int64_t new_stride) override {
AT_ERROR("opaque tensors do not have set_stride");
}
void set_storage_offset(int64_t storage_offset) override {
AT_ERROR("opaque tensors do not have set_storage_offset");
}
#ifdef DEBUG
bool has_storage() const override {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
!storage_, "OpaqueTensorImpl assumes that storage_ is never set");
return false;
}
#endif
/**
* Return a TensorImpl that is a shallow-copy of this TensorImpl.
*
* For usage of `version_counter` and `allow_tensor_metadata_change`,
* see NOTE [ TensorImpl Shallow-Copying ].
*/
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
const c10::VariableVersion& version_counter,
bool allow_tensor_metadata_change) const override {
auto impl = c10::make_intrusive<OpaqueTensorImpl<OpaqueHandle>>(
key_set(),
dtype(),
device(),
opaque_handle_,
sizes_and_strides_.sizes_arrayref());
copy_tensor_metadata(
/*src_opaque_impl=*/this,
/*dest_opaque_impl=*/impl.get(),
/*version_counter=*/version_counter,
/*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
impl->refresh_numel();
return impl;
}
/**
* Return a TensorImpl that is a shallow-copy of this TensorImpl.
*
* For usage of `version_counter` and `allow_tensor_metadata_change`,
* see NOTE [ TensorImpl Shallow-Copying ].
*/
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
c10::VariableVersion&& version_counter,
bool allow_tensor_metadata_change) const override {
auto impl = c10::make_intrusive<OpaqueTensorImpl<OpaqueHandle>>(
key_set(),
dtype(),
device(),
opaque_handle_,
sizes_and_strides_.sizes_arrayref());
copy_tensor_metadata(
/*src_opaque_impl=*/this,
/*dest_opaque_impl=*/impl.get(),
/*version_counter=*/std::move(version_counter),
/*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
impl->refresh_numel();
return impl;
}
/**
* Shallow-copies data from another TensorImpl into this TensorImpl.
*
* For why this function doesn't check this TensorImpl's
* `allow_tensor_metadata_change_`, see NOTE [ TensorImpl Shallow-Copying ].
*/
void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override {
AT_ASSERT(has_compatible_shallow_copy_type(impl->key_set()));
auto opaque_impl =
static_cast<const OpaqueTensorImpl<OpaqueHandle>*>(impl.get());
copy_tensor_metadata(
/*src_impl=*/opaque_impl,
/*dest_impl=*/this,
/*version_counter=*/version_counter(),
/*allow_tensor_metadata_change=*/allow_tensor_metadata_change());
refresh_numel();
}
const OpaqueHandle& opaque_handle() const {
return opaque_handle_;
}
OpaqueHandle& unsafe_opaque_handle() {
return opaque_handle_;
}
protected:
/**
* Copy the tensor metadata fields (e.g. sizes / strides / storage pointer /
* storage_offset) from one TensorImpl to another TensorImpl.
*
* For usage of `version_counter` and `allow_tensor_metadata_change`, see NOTE
* [ TensorImpl Shallow-Copying ].
*/
static void copy_tensor_metadata(
const OpaqueTensorImpl<OpaqueHandle>* src_opaque_impl,
OpaqueTensorImpl<OpaqueHandle>* dest_opaque_impl,
const c10::VariableVersion& version_counter,
bool allow_tensor_metadata_change) {
TensorImpl::copy_tensor_metadata(
src_opaque_impl,
dest_opaque_impl,
version_counter,
allow_tensor_metadata_change);
// OpaqueTensorImpl-specific fields.
dest_opaque_impl->opaque_handle_ = src_opaque_impl->opaque_handle_;
}
static void copy_tensor_metadata(
const OpaqueTensorImpl<OpaqueHandle>* src_opaque_impl,
OpaqueTensorImpl<OpaqueHandle>* dest_opaque_impl,
c10::VariableVersion&& version_counter,
bool allow_tensor_metadata_change) {
TensorImpl::copy_tensor_metadata(
src_opaque_impl,
dest_opaque_impl,
std::move(version_counter),
allow_tensor_metadata_change);
// OpaqueTensorImpl-specific fields.
dest_opaque_impl->opaque_handle_ = src_opaque_impl->opaque_handle_;
}
private:
const char* tensorimpl_type_name() const override {
return "OpaqueTensorImpl";
}
OpaqueHandle opaque_handle_;
};
} // namespace at

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,17 @@
#pragma once
#include <ATen/Parallel.h>
#include <c10/core/thread_pool.h>
namespace at {
class TORCH_API PTThreadPool : public c10::ThreadPool {
public:
explicit PTThreadPool(int pool_size, int numa_node_id = -1)
: c10::ThreadPool(pool_size, numa_node_id, []() {
c10::setThreadName("PTThreadPool");
at::init_num_threads();
}) {}
};
} // namespace at

View File

@ -0,0 +1,28 @@
#pragma once
#include <c10/util/Exception.h>
#include <c10/util/string_view.h>
namespace at {
enum class padding_mode {
reflect,
replicate,
circular,
constant,
};
static inline c10::string_view padding_mode_string(padding_mode m) {
switch (m) {
case padding_mode::reflect:
return "reflect";
case padding_mode::replicate:
return "replicate";
case padding_mode::circular:
return "circular";
case padding_mode::constant:
return "constant";
}
TORCH_CHECK(false, "Invalid padding mode (", static_cast<int64_t>(m), ")");
}
} // namespace at

View File

@ -0,0 +1,93 @@
#pragma once
#include <c10/util/Exception.h>
#include <c10/util/ParallelGuard.h>
#include <c10/util/SmallVector.h>
namespace at {
template <class F>
inline void parallel_for(
const int64_t begin,
const int64_t end,
const int64_t grain_size,
const F& f) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(grain_size >= 0);
if (begin >= end) {
return;
}
#ifdef INTRA_OP_PARALLEL
at::internal::lazy_init_num_threads();
const auto numiter = end - begin;
const bool use_parallel =
(numiter > grain_size && numiter > 1 && !at::in_parallel_region() &&
at::get_num_threads() > 1);
if (!use_parallel) {
internal::ThreadIdGuard tid_guard(0);
c10::ParallelGuard guard(true);
f(begin, end);
return;
}
internal::invoke_parallel(
begin, end, grain_size, [&](int64_t begin, int64_t end) {
c10::ParallelGuard guard(true);
f(begin, end);
});
#else
internal::ThreadIdGuard tid_guard(0);
c10::ParallelGuard guard(true);
f(begin, end);
#endif
}
template <class scalar_t, class F, class SF>
inline scalar_t parallel_reduce(
const int64_t begin,
const int64_t end,
const int64_t grain_size,
const scalar_t ident,
const F& f,
const SF& sf) {
TORCH_CHECK(grain_size >= 0);
if (begin >= end) {
return ident;
}
#ifdef INTRA_OP_PARALLEL
at::internal::lazy_init_num_threads();
const auto max_threads = at::get_num_threads();
const bool use_parallel =
((end - begin) > grain_size && !at::in_parallel_region() &&
max_threads > 1);
if (!use_parallel) {
internal::ThreadIdGuard tid_guard(0);
c10::ParallelGuard guard(true);
return f(begin, end, ident);
}
c10::SmallVector<scalar_t, 64> results(max_threads, ident);
internal::invoke_parallel(
begin,
end,
grain_size,
[&](const int64_t my_begin, const int64_t my_end) {
const auto tid = at::get_thread_num();
c10::ParallelGuard guard(true);
results[tid] = f(my_begin, my_end, ident);
});
scalar_t result = ident;
for (auto partial_result : results) {
result = sf(result, partial_result);
}
return result;
#else
internal::ThreadIdGuard tid_guard(0);
c10::ParallelGuard guard(true);
return f(begin, end, ident);
#endif
}
} // namespace at

View File

@ -0,0 +1,158 @@
#pragma once
#include <ATen/Config.h>
#include <c10/macros/Macros.h>
#include <functional>
#include <string>
namespace at {
inline int64_t divup(int64_t x, int64_t y) {
return (x + y - 1) / y;
}
// Called during new thread initialization
TORCH_API void init_num_threads();
// Sets the number of threads to be used in parallel region
TORCH_API void set_num_threads(int);
// Returns the maximum number of threads that may be used in a parallel region
TORCH_API int get_num_threads();
// Returns the current thread number (starting from 0)
// in the current parallel region, or 0 in the sequential region
TORCH_API int get_thread_num();
// Checks whether the code runs in parallel region
TORCH_API bool in_parallel_region();
namespace internal {
// Initialise num_threads lazily at first parallel call
inline void lazy_init_num_threads() {
thread_local bool init = false;
if (C10_UNLIKELY(!init)) {
at::init_num_threads();
init = true;
}
}
TORCH_API void set_thread_num(int);
class TORCH_API ThreadIdGuard {
public:
ThreadIdGuard(int new_id) : old_id_(at::get_thread_num()) {
set_thread_num(new_id);
}
~ThreadIdGuard() {
set_thread_num(old_id_);
}
private:
int old_id_;
};
} // namespace internal
/*
parallel_for
begin: index at which to start applying user function
end: index at which to stop applying user function
grain_size: number of elements per chunk; impacts the degree of parallelization
f: user function applied in parallel to the chunks, signature:
void f(int64_t begin, int64_t end)
Warning: parallel_for does NOT copy thread local
states from the current thread to the worker threads.
This means for example that Tensor operations CANNOT be used in the
body of your function, only data pointers.
*/
template <class F>
inline void parallel_for(
const int64_t begin,
const int64_t end,
const int64_t grain_size,
const F& f);
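/*
A minimal usage sketch of parallel_for, doubling a buffer in place. `in`, `out`
and `n` are assumed to be caller-provided raw pointers/length; only data
pointers are touched in the body, per the warning above.

  at::parallel_for(0, n, 2048, [&](int64_t begin, int64_t end) {
    for (int64_t i = begin; i < end; ++i) {
      out[i] = in[i] * 2.0f;
    }
  });
*/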
/*
parallel_reduce
begin: index at which to start applying reduction
end: index at which to stop applying reduction
grain_size: number of elements per chunk; impacts the number of elements in the
intermediate results tensor and the degree of parallelization.
ident: identity for binary combination function sf. sf(ident, x) needs to return
x.
f: function for reduction over a chunk. f needs to be of signature scalar_t
f(int64_t partial_begin, int64_t partial_end, scalar_t identity)
sf: function to combine two partial results. sf needs to be of signature
scalar_t sf(scalar_t x, scalar_t y)
For example, you might have a tensor of 10000 entries and want to sum together
all the elements. parallel_reduce with a grain_size of 2500 will then allocate
an intermediate result tensor with 4 elements. Then it will execute the function
"f" you provide and pass the beginning and end index of these chunks, so
0-2499, 2500-4999, etc. and the combination identity. It will then write out
the result from each of these chunks into the intermediate result tensor. After
that it'll reduce the partial results from each chunk into a single number using
the combination function sf and the identity ident. For a total summation this
would be "+" and 0 respectively. This is similar to tbb's approach [1], where
you need to provide a function to accumulate a subrange, a function to combine
two partial results and an identity.
Warning: parallel_reduce does NOT copy thread local
states from the current thread to the worker threads.
This means for example that Tensor operations CANNOT be used in the
body of your function, only data pointers.
[1] https://software.intel.com/en-us/node/506154
*/
template <class scalar_t, class F, class SF>
inline scalar_t parallel_reduce(
const int64_t begin,
const int64_t end,
const int64_t grain_size,
const scalar_t ident,
const F& f,
const SF& sf);
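/*
A minimal usage sketch of parallel_reduce, mirroring the summation example
above. `data` and `n` are assumed to be caller-provided; std::plus comes from
<functional>.

  float total = at::parallel_reduce(
      0, n, 2500, 0.0f,
      [&](int64_t begin, int64_t end, float ident) {
        float partial = ident;
        for (int64_t i = begin; i < end; ++i) {
          partial += data[i];
        }
        return partial;
      },
      std::plus<float>());
*/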
// Returns a detailed string describing parallelization settings
TORCH_API std::string get_parallel_info();
// Sets number of threads used for inter-op parallelism
TORCH_API void set_num_interop_threads(int);
// Returns the number of threads used for inter-op parallelism
TORCH_API int get_num_interop_threads();
// Launches inter-op parallel task
TORCH_API void launch(std::function<void()> func);
namespace internal {
void launch_no_thread_state(std::function<void()> fn);
} // namespace internal
// Launches intra-op parallel task
TORCH_API void intraop_launch(std::function<void()> func);
// Returns number of intra-op threads used by default
TORCH_API int intraop_default_num_threads();
} // namespace at
#if AT_PARALLEL_OPENMP
#include <ATen/ParallelOpenMP.h> // IWYU pragma: keep
#elif AT_PARALLEL_NATIVE
#include <ATen/ParallelNative.h> // IWYU pragma: keep
#endif
#include <ATen/Parallel-inl.h> // IWYU pragma: keep

View File

@ -0,0 +1,13 @@
#pragma once
#include <ATen/core/ivalue.h>
#include <c10/macros/Macros.h>
#include <functional>
namespace at {
// Launches intra-op parallel task, returns a future
TORCH_API c10::intrusive_ptr<c10::ivalue::Future> intraop_launch_future(
std::function<void()> func);
} // namespace at

View File

@ -0,0 +1,15 @@
#pragma once
#include <c10/util/Exception.h>
#define INTRA_OP_PARALLEL
namespace at::internal {
TORCH_API void invoke_parallel(
const int64_t begin,
const int64_t end,
const int64_t grain_size,
const std::function<void(int64_t, int64_t)>& f);
} // namespace at::internal

View File

@ -0,0 +1,54 @@
#pragma once
#include <algorithm>
#include <atomic>
#include <cstddef>
#include <exception>
#ifdef _OPENMP
#define INTRA_OP_PARALLEL
#include <omp.h>
#endif
#ifdef _OPENMP
namespace at::internal {
template <typename F>
inline void invoke_parallel(
int64_t begin,
int64_t end,
int64_t grain_size,
const F& f) {
std::atomic_flag err_flag = ATOMIC_FLAG_INIT;
std::exception_ptr eptr;
#pragma omp parallel
{
// choose number of tasks based on grain size and number of threads
// can't use num_threads clause due to bugs in GOMP's thread pool (See
// #32008)
int64_t num_threads = omp_get_num_threads();
if (grain_size > 0) {
num_threads = std::min(num_threads, divup((end - begin), grain_size));
}
int64_t tid = omp_get_thread_num();
int64_t chunk_size = divup((end - begin), num_threads);
int64_t begin_tid = begin + tid * chunk_size;
if (begin_tid < end) {
try {
internal::ThreadIdGuard tid_guard(tid);
f(begin_tid, std::min(end, chunk_size + begin_tid));
} catch (...) {
if (!err_flag.test_and_set()) {
eptr = std::current_exception();
}
}
}
}
if (eptr) {
std::rethrow_exception(eptr);
}
}
} // namespace at::internal
#endif // _OPENMP

View File

@ -0,0 +1,36 @@
#pragma once
#include <c10/core/SafePyObject.h>
#include <c10/macros/Macros.h>
namespace at::impl {
enum TorchFunctionDisabledState { ENABLED, SUBCLASSES_DISABLED, ALL_DISABLED };
struct TORCH_API PythonTorchFunctionTLS {
static void set_disabled_state(TorchFunctionDisabledState disabled_state_);
static TorchFunctionDisabledState get_disabled_state();
static void push_onto_stack(std::shared_ptr<SafePyObject> mode);
static const std::shared_ptr<SafePyObject> pop_stack();
static const std::shared_ptr<SafePyObject>& get_stack_at(int64_t idx);
static int64_t stack_len();
static const PythonTorchFunctionTLS& get_state();
static void set_state(const PythonTorchFunctionTLS& state);
private:
// The mode TLS is split into
  // - disabled_state, which says which parts of torch function are disabled
  // - stack_, which is a vector of modes representing the stack of
  //   user-defined modes
TorchFunctionDisabledState disabled_state_ =
TorchFunctionDisabledState::ENABLED;
std::vector<std::shared_ptr<c10::SafePyObject>> stack_;
};
TORCH_API bool torch_function_mode_enabled();
TORCH_API bool torch_function_all_disabled();
} // namespace at::impl

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,66 @@
#pragma once
#include <c10/core/SafePyObject.h>
#include <c10/macros/Export.h>
#include <c10/util/python_stub.h>
#include <optional>
#include <stack>
#include <string>
#include <utility>
namespace at {
namespace impl {
struct TORCH_API SavedTensorDefaultHooksTLS {
// PyObject is defined in c10/util/python_stub.h
std::stack<std::pair<c10::SafePyObject, c10::SafePyObject>> stack;
// See NOTE: [Disabling SavedTensorDefaultHooks] for context
// NOTE: [disabled_error_message invariant]
// disabled_error_message is nullopt IFF Saved Tensor hooks is enabled
// We did this for efficiency (so we didn't have to keep a separate bool
// around)
std::optional<std::string> disabled_error_message;
// See NOTE: [Deferring tensor pack/unpack hooks until runtime]
bool is_tracing = false;
};
} // namespace impl
struct TORCH_API SavedTensorDefaultHooks {
static void push_hooks(
c10::SafePyObject pack_hook,
c10::SafePyObject unpack_hook);
static std::pair<c10::SafePyObject, c10::SafePyObject> pop_hooks();
static std::optional<std::pair<c10::SafePyObject, c10::SafePyObject>>
get_hooks();
static void lazy_initialize();
static const impl::SavedTensorDefaultHooksTLS& get_tls_state();
static void set_tls_state(const impl::SavedTensorDefaultHooksTLS& tls);
// NOTE: [Disabling SavedTensorDefaultHooks]
// A developer of a PyTorch feature may choose to disable SavedTensorDefault
// hooks, especially if their feature does not work with it. If they are
// disabled, then the following will raise an error:
// - Attempting to push_hooks
// - calling disable(message) with a non-zero stack (hooks) size
static void disable(const std::string& error_message);
static void enable();
static bool is_enabled();
static const std::optional<std::string>& get_disabled_error_message();
// NOTE: [Deferring tensor pack/unpack hooks until runtime]
// To preserve eager semantics of pack/unpack hooks firing only once per saved
// variable, Dynamo/AOTAutograd need to defer hook firing until runtime. Using
  // disable() would error loudly at trace time, and pushing a no-op hook would
// fail when the traced code is wrapped in a disable_saved_tensors_hooks ctx.
// To do so, we disable these hooks during tracing. See
// https://github.com/pytorch/pytorch/issues/113263.
static bool set_tracing(bool is_tracing);
};
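// A brief sketch of the disable/enable protocol described above (the error
// message is illustrative):
//
//   at::SavedTensorDefaultHooks::disable(
//       "saved tensor hooks are not supported inside this feature");
//   TORCH_INTERNAL_ASSERT(!at::SavedTensorDefaultHooks::is_enabled());
//   // ... any push_hooks() attempt now errors out ...
//   at::SavedTensorDefaultHooks::enable();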
} // namespace at

View File

@ -0,0 +1,3 @@
#pragma once
#include <ATen/core/Scalar.h>

View File

@ -0,0 +1,53 @@
#pragma once
#include <ATen/Tensor.h>
#include <c10/core/Scalar.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/scalar_tensor.h>
#endif
namespace at::detail {
// When filling a 1-element CPU tensor with a number, we want to skip
// everything else and manipulate the data ptr directly.
// Ideally this fast path should be implemented in TensorIterator,
// but we also want to skip compute_types, which is not avoidable
// in TensorIterator for now.
Tensor& scalar_fill(Tensor& self, const Scalar& value);
TORCH_API Tensor scalar_tensor_static(
const Scalar& s,
std::optional<ScalarType> dtype_opt,
std::optional<Device> device_opt);
} // namespace at::detail
// This is in the c10 namespace because we use ADL to find the functions in it.
namespace c10 {
// FIXME: this should be (and was) Scalar::toTensor, but there is currently no
// way to implement this without going through Derived Types (which are not part
// of core).
inline at::Tensor scalar_to_tensor(
const Scalar& s,
const Device device = at::kCPU) {
// This is the fast track we have for CPU scalar tensors.
if (device == at::kCPU) {
return at::detail::scalar_tensor_static(s, s.type(), at::kCPU);
}
return at::scalar_tensor(s, at::device(device).dtype(s.type()));
}
} // namespace c10
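// Usage sketch (values are illustrative): on CPU this takes the static fast
// path above, otherwise it falls through to at::scalar_tensor.
//
//   at::Tensor a = c10::scalar_to_tensor(at::Scalar(3.5));           // CPU fast path
//   at::Tensor b = c10::scalar_to_tensor(at::Scalar(2), at::kCUDA);  // generic path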
namespace at::native {
inline Tensor wrapped_scalar_tensor(
const Scalar& scalar,
const Device device = at::kCPU) {
auto tensor = scalar_to_tensor(scalar, device);
tensor.unsafeGetTensorImpl()->set_wrapped_number(true);
return tensor;
}
} // namespace at::native

View File

@ -0,0 +1,4 @@
#pragma once
#include <ATen/core/ATenGeneral.h> // for BC reasons
#include <c10/core/Backend.h>
#include <c10/core/ScalarType.h>

View File

@ -0,0 +1,13 @@
#pragma once
#include <c10/macros/Export.h>
#include <cstdint>
// A simple thread local enumeration, used to link forward and backward pass
// ops and is used by autograd and observers framework
namespace at::sequence_number {
TORCH_API uint64_t peek();
TORCH_API uint64_t get_and_increment();
} // namespace at::sequence_number

View File

@ -0,0 +1,2 @@
#pragma once
#include <c10/util/SmallVector.h>

View File

@ -0,0 +1,206 @@
#pragma once
#include <ATen/Tensor.h>
#include <c10/core/TensorImpl.h>
#include <c10/core/impl/TorchDispatchModeTLS.h>
#include <c10/util/Exception.h>
namespace at {
// Struct implementing a sparse CSR tensor. It uses three 1-D tensors for
// denoting the data: `crow_indices_`, `col_indices_` and `values_`.
// The `crow_indices_` tensor is an integer tensor of shape `(size(0) + 1)`
// that represents the compressed row indices of the CSR tensor. The
// `col_indices_` tensor is an integer tensor of shape `(nnz())`
// that explicitly stores the column indices of each value of the sparse
// tensor. The `values_` tensor can be of any pytorch-supported data type
// and has shape `(nnz())`.
//
// Since the main advantage of the CSR format over the COO format is speed of
// computation, care must be taken to facilitate smooth interfacing of
// these data structures with optimized libraries such as MKL and MAGMA.
// Since the MKL interface for pytorch currently uses indexing with int32
// type, it is important to make sure that the `crow_indices` and `col_indices`
// are of type int32 when calling MKL routines such as SPMM or SPMV.
//
// If not calling MKL, it should be alright to use 64 bit integer tensors
// for indexing.
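//
// A worked example of the layout: the 3x3 matrix
//
//   [[1, 0, 2],
//    [0, 0, 3],
//    [4, 5, 0]]
//
// is stored with
//
//   crow_indices_ = [0, 2, 3, 5]    // row i occupies entries [crow[i], crow[i+1])
//   col_indices_  = [0, 2, 2, 0, 1]
//   values_       = [1, 2, 3, 4, 5]
//
// so nnz() == 5 and crow_indices_ has size(0) + 1 == 4 entries.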
struct TORCH_API SparseCsrTensorImpl : public TensorImpl {
Tensor crow_indices_;
Tensor col_indices_;
Tensor values_;
Layout layout_;
public:
explicit SparseCsrTensorImpl(
at::DispatchKeySet,
at::Device device,
Layout layout,
const caffe2::TypeMeta);
void resize_(int64_t nnz, IntArrayRef size);
void resize_and_clear_(
int64_t sparse_dim,
int64_t dense_dim,
IntArrayRef size);
void resize_as_sparse_compressed_tensor_(const Tensor& src);
void set_member_tensors(
const Tensor& crow_indices,
const Tensor& col_indices,
const Tensor& values,
c10::SymIntArrayRef size);
void set_member_tensors(
const Tensor& crow_indices,
const Tensor& col_indices,
const Tensor& values,
IntArrayRef size);
const Tensor& compressed_indices() const {
return crow_indices_;
}
const Tensor& plain_indices() const {
return col_indices_;
}
const Tensor& values() const {
return values_;
}
int64_t nnz() {
return col_indices_.size(-1);
}
inline int64_t batch_dim() const noexcept {
return crow_indices_.dim() - 1;
}
inline int64_t sparse_dim() const noexcept {
return 2;
}
inline int64_t dense_dim() const noexcept {
return values_.dim() - batch_dim() - block_dim() - 1;
}
private:
inline int64_t block_dim() const noexcept {
return (layout_ == kSparseBsr || layout_ == kSparseBsc ? 2 : 0);
}
protected:
IntArrayRef strides_custom() const override;
SymIntArrayRef sym_strides_custom() const override;
bool is_contiguous_custom(MemoryFormat) const override;
public:
void set_size(int64_t dim, int64_t new_size) override;
void set_stride(int64_t dim, int64_t new_stride) override;
void set_storage_offset(int64_t storage_offset) override;
Layout layout_impl() const override {
return layout_;
}
void set_layout(Layout layout) {
switch (layout) {
case kSparseCsr:
case kSparseCsc:
case kSparseBsr:
case kSparseBsc:
layout_ = layout;
break;
default:
TORCH_CHECK(false, "unsupported layout ", layout);
}
}
template <typename VariableVersion>
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach_core(
VariableVersion&& version_counter,
bool allow_tensor_metadata_change) const {
const auto mode_stack_len = c10::impl::TorchDispatchModeTLS::stack_len();
c10::impl::PyInterpreter&& interpreter = nullptr;
if (mode_stack_len > 0 &&
!c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python)) {
const auto& cur_torch_dispatch_mode_state =
c10::impl::TorchDispatchModeTLS::get_stack_at(mode_stack_len - 1);
interpreter = cur_torch_dispatch_mode_state->pyinterpreter();
} else if (
key_set_.has(DispatchKey::Python) &&
!c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python)) {
interpreter = pyobj_slot_.load_pyobj_interpreter();
} else {
// otherwise just copy the SparseTensorImpl and not the PyObject.
auto impl = c10::make_intrusive<SparseCsrTensorImpl>(
key_set(), device(), layout_impl(), dtype());
copy_tensor_metadata(
/*src_sparse_impl=*/this,
/*dest_sparse_impl=*/impl.get(),
/*version_counter=*/version_counter,
/*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
impl->refresh_numel();
return impl;
}
auto r = interpreter->detach(this);
r->set_version_counter(std::forward<VariableVersion>(version_counter));
r->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
return r;
}
/**
* Return a TensorImpl that is a shallow-copy of this TensorImpl.
*
* For usage of `version_counter` and `allow_tensor_metadata_change`,
* see NOTE [ TensorImpl Shallow-Copying ].
*/
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
const c10::VariableVersion& version_counter,
bool allow_tensor_metadata_change) const override {
return shallow_copy_and_detach_core(
version_counter, allow_tensor_metadata_change);
}
/**
* Return a TensorImpl that is a shallow-copy of this TensorImpl.
*
* For usage of `version_counter` and `allow_tensor_metadata_change`,
* see NOTE [ TensorImpl Shallow-Copying ].
*/
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
c10::VariableVersion&& version_counter,
bool allow_tensor_metadata_change) const override {
return shallow_copy_and_detach_core(
std::move(version_counter), allow_tensor_metadata_change);
}
private:
explicit SparseCsrTensorImpl(
at::DispatchKeySet key_set,
const caffe2::TypeMeta data_type,
at::Tensor crow_indices,
at::Tensor col_indices,
at::Tensor values,
at::Layout layout);
const char* tensorimpl_type_name() const override;
/**
* Copy the tensor metadata fields (e.g. sizes / strides / storage pointer /
* storage_offset) from one TensorImpl to another TensorImpl.
*
* For usage of `version_counter` and `allow_tensor_metadata_change`, see NOTE
* [ TensorImpl Shallow-Copying ].
*/
static void copy_tensor_metadata(
const SparseCsrTensorImpl* src_sparse_impl,
SparseCsrTensorImpl* dest_sparse_impl,
c10::VariableVersion version_counter,
bool allow_tensor_metadata_change) {
TensorImpl::copy_tensor_metadata(
src_sparse_impl,
dest_sparse_impl,
std::move(version_counter),
allow_tensor_metadata_change);
// Sparse-specific fields
dest_sparse_impl->crow_indices_ = src_sparse_impl->compressed_indices();
dest_sparse_impl->col_indices_ = src_sparse_impl->plain_indices();
dest_sparse_impl->values_ = src_sparse_impl->values();
dest_sparse_impl->layout_ = src_sparse_impl->layout_impl();
}
};
} // namespace at

View File

@ -0,0 +1,441 @@
#pragma once
#include <ATen/SparseCsrTensorImpl.h>
#include <ATen/SparseTensorImpl.h>
#include <ATen/core/Tensor.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Operators.h>
#else
#include <ATen/ops/_sparse_compressed_tensor_unsafe.h>
#include <ATen/ops/resize_as_sparse_native.h>
#endif
#define AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(LAYOUT, NAME, ...) \
[&] { \
const auto& the_layout = LAYOUT; \
switch (the_layout) { \
case kSparseCsr: \
case kSparseCsc: \
case kSparseBsr: \
case kSparseBsc: \
return __VA_ARGS__(); \
default: \
AT_ERROR( \
NAME, \
" expected sparse compressed tensor layout but got ", \
the_layout); \
} \
}()
#define AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS( \
LAYOUT, NAME, ROW_DIM_ACTION, COLUMN_DIM_ACTION) \
[&]() { \
const auto& the_layout = LAYOUT; \
switch (the_layout) { \
case kSparseCsr: \
case kSparseBsr: \
return (ROW_DIM_ACTION)(); \
case kSparseCsc: \
case kSparseBsc: \
return (COLUMN_DIM_ACTION)(); \
default: \
AT_ERROR( \
NAME, \
" expected sparse compressed tensor layout but got ", \
the_layout); \
} \
}()
#define AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS( \
LAYOUT, NAME, NO_BLOCK_ACTION, BLOCK_ACTION) \
[&]() { \
const auto& the_layout = LAYOUT; \
switch (the_layout) { \
case kSparseCsr: \
case kSparseCsc: \
return (NO_BLOCK_ACTION)(); \
case kSparseBsr: \
case kSparseBsc: \
return (BLOCK_ACTION)(); \
default: \
AT_ERROR( \
NAME, \
" expected sparse compressed tensor layout but got ", \
the_layout); \
} \
}()
#define AT_DISPATCH_SPARSE_ROW_COMPRESSED_LAYOUTS( \
LAYOUT, NAME, ROW_DIM_ACTION) \
[&]() { \
const auto& the_layout = LAYOUT; \
switch (the_layout) { \
case kSparseCsr: \
case kSparseBsr: \
return (ROW_DIM_ACTION)(); \
default: \
AT_ERROR( \
NAME, \
" expected sparse row compressed tensor layout but got ", \
the_layout); \
} \
}()
#define AT_DISPATCH_SPARSE_COL_COMPRESSED_LAYOUTS( \
LAYOUT, NAME, COL_DIM_ACTION) \
[&]() { \
const auto& the_layout = LAYOUT; \
switch (the_layout) { \
case kSparseCsc: \
case kSparseBsc: \
return (COL_DIM_ACTION)(); \
default: \
AT_ERROR( \
NAME, \
" expected sparse column compressed tensor layout but got ", \
the_layout); \
} \
}()
#define AT_DISPATCH_SPARSE_COMPRESSED_NONBLOCK_LAYOUTS(LAYOUT, NAME, ACTION) \
[&]() { \
const auto& the_layout = LAYOUT; \
switch (the_layout) { \
case kSparseCsr: \
case kSparseCsc: \
return (ACTION)(); \
default: \
AT_ERROR( \
NAME, \
" expected sparse compressed (non-block) tensor layout but got ", \
the_layout); \
} \
}()
#define AT_DISPATCH_SPARSE_COMPRESSED_BLOCK_LAYOUTS(LAYOUT, NAME, ACTION) \
[&]() { \
const auto& the_layout = LAYOUT; \
switch (the_layout) { \
case kSparseBsr: \
case kSparseBsc: \
return (ACTION)(); \
default: \
AT_ERROR( \
NAME, \
" expected sparse compressed block tensor layout but got ", \
the_layout); \
} \
}()
#define AT_DISPATCH_SPARSE_VALUE_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, \
NAME, \
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4( \
kComplexHalf, kHalf, kBool, kBFloat16, __VA_ARGS__))
namespace at::sparse_csr {
// Implements RAII object to manage checking sparse tensor invariants:
class CheckSparseTensorInvariants {
bool old_state;
public:
CheckSparseTensorInvariants(bool state) {
old_state = at::globalContext().checkSparseTensorInvariants();
at::globalContext().setCheckSparseTensorInvariants(state);
}
~CheckSparseTensorInvariants() {
at::globalContext().setCheckSparseTensorInvariants(old_state);
}
};
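// Typical RAII usage (a sketch): temporarily force invariant checking on and
// restore the previous global setting on scope exit.
//
//   {
//     at::sparse_csr::CheckSparseTensorInvariants guard(/*state=*/true);
//     // ... construct or validate sparse compressed tensors here ...
//   }  // previous setting restored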
using SparseCsrTensor = Tensor;
inline bool is_sparse_compressed(const Layout& layout) {
switch (layout) {
case kSparseCsr:
case kSparseCsc:
case kSparseBsr:
case kSparseBsc:
return true;
default:;
}
return false;
}
inline bool is_sparse_compressed(const Tensor& self) {
return is_sparse_compressed(self.layout());
}
inline SparseCsrTensorImpl* get_sparse_csr_impl(const SparseCsrTensor& self) {
AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(
self.layout(), "get_sparse_csr_impl", [&] {});
return static_cast<SparseCsrTensorImpl*>(self.unsafeGetTensorImpl());
}
inline std::string layoutToString(
Layout layout,
bool upper = false,
bool lower = false) {
switch (layout) {
case kSparseCsr:
return (upper ? "CSR" : (lower ? "csr" : "Csr"));
case kSparseCsc:
return (upper ? "CSC" : (lower ? "csc" : "Csc"));
case kSparseBsr:
return (upper ? "BSR" : (lower ? "bsr" : "Bsr"));
case kSparseBsc:
return (upper ? "BSC" : (lower ? "bsc" : "Bsc"));
default:
TORCH_CHECK(false, "Not a sparse compressed layout:", layout);
return "";
}
}
inline bool isCompressedRow(Layout layout) {
return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
layout, "isCompressedRow", [&] { return true; }, [&] { return false; });
}
inline bool isCompressedColumn(Layout layout) {
return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
layout,
"isCompressedColumn",
[&] { return false; },
[&] { return true; });
}
inline std::string compressedIndicesName(Layout layout) {
return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
layout,
"compressedIndicesName",
[&] { return "crow_indices"; },
[&] { return "ccol_indices"; });
}
inline std::string plainIndicesName(Layout layout) {
return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
layout,
"plainIndicesName",
[&] { return "col_indices"; },
[&] { return "row_indices"; });
}
inline std::string compressedDimName(Layout layout) {
switch (layout) {
case kSparseCsr:
return "row";
case kSparseCsc:
return "column";
case kSparseBsr:
return "row block";
case kSparseBsc:
return "column block";
default:
TORCH_CHECK(false, "Not a sparse compressed layout:", layout);
return "";
}
}
inline std::string plainDimName(Layout layout) {
switch (layout) {
case kSparseCsr:
return "column";
case kSparseCsc:
return "row";
case kSparseBsr:
return "column block";
case kSparseBsc:
return "row block";
default:
TORCH_CHECK(false, "Not a sparse compressed layout:", layout);
return "";
}
}
inline size_t rowDimension(Layout layout, IntArrayRef size) {
return size.size() - (isCompressedRow(layout) ? 2 : 1);
}
inline size_t columnDimension(Layout layout, IntArrayRef size) {
return size.size() - (isCompressedColumn(layout) ? 2 : 1);
}
inline size_t compressedDimension(
Layout layout,
IntArrayRef size,
size_t dense_ndim = 0) {
return size.size() - dense_ndim - (isCompressedRow(layout) ? 2 : 1);
}
inline size_t plainDimension(
Layout layout,
IntArrayRef size,
size_t dense_ndim = 0) {
return size.size() - dense_ndim - (isCompressedRow(layout) ? 1 : 2);
}
inline int64_t numBatchDimensions(Tensor const& self) {
return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
self.layout(),
"numBatchDimensions",
[&self] { return self.crow_indices().dim() - 1; },
[&self] { return self.ccol_indices().dim() - 1; });
}
inline std::pair<Tensor, Tensor> getCompressedPlainIndices(Tensor const& self) {
return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
self.layout(),
"getCompressedPlainIndices",
[&self] {
return std::make_pair(self.crow_indices(), self.col_indices());
},
[&self] {
return std::make_pair(self.ccol_indices(), self.row_indices());
});
}
inline ScalarType getIndexDtype(Tensor const& self) {
switch (self.layout()) {
case kSparseCsr:
case kSparseBsr:
return self.crow_indices().scalar_type();
case kSparseCsc:
case kSparseBsc:
return self.ccol_indices().scalar_type();
case kSparse:
return self._indices().scalar_type();
default:
return ScalarType::Long;
}
}
inline Layout flip_compressed_layout(Layout layout) {
switch (layout) {
case kSparseCsr:
return kSparseCsc;
case kSparseCsc:
return kSparseCsr;
case kSparseBsr:
return kSparseBsc;
case kSparseBsc:
return kSparseBsr;
default:
TORCH_CHECK(false, "Not a sparse compressed layout:", layout);
return kSparseCsr;
}
}
inline DimVector getBlockSize(Tensor const& self) {
int64_t n_batch = numBatchDimensions(self);
return at::DimVector(self.values().sizes().slice(n_batch + 1, 2));
}
inline at::OptionalArray<at::SymInt> getSymIntBlockSize(Tensor const& self) {
if (self.layout() == at::kSparseBsr || self.layout() == at::kSparseBsc) {
int64_t n_batch = numBatchDimensions(self);
return self.values().sym_sizes().slice(n_batch + 1, 2).vec();
} else {
return {};
}
}
template <typename binary_op_t, typename binary_op_out_t>
inline bool only_sparse_compressed_binary_op_trivial_cases(
const Tensor& self,
const Tensor& other,
const Scalar& alpha,
Tensor& out,
const binary_op_t& binary_op,
const binary_op_out_t& binary_op_out) {
// Only sparse compressed! Just like the name says :)
TORCH_INTERNAL_ASSERT(at::sparse_csr::is_sparse_compressed(self));
TORCH_INTERNAL_ASSERT(at::sparse_csr::is_sparse_compressed(other));
TORCH_INTERNAL_ASSERT(at::sparse_csr::is_sparse_compressed(out));
// Bypass BLAS if there are matches in (self, other, out)
if (self.is_same(out) && self.is_same(other)) {
binary_op_out(self.values(), other.values(), alpha);
return true;
}
if (self.is_same(other)) {
auto [compressed_indices, plain_indices] =
at::sparse_csr::getCompressedPlainIndices(self);
static_cast<SparseCsrTensorImpl*>(out.unsafeGetTensorImpl())
->set_member_tensors(
compressed_indices,
plain_indices,
binary_op(self.values(), other.values(), alpha),
self.sizes());
return true;
}
return false;
}
inline bool only_sparse_compressed_add_trivial_cases(
const Tensor& self,
const Tensor& other,
const Scalar& alpha,
Tensor& out) {
return only_sparse_compressed_binary_op_trivial_cases(
self,
other,
alpha,
out,
[](const Tensor& v1, const Tensor& v2, const Scalar& alpha) {
return v1.add(v2, alpha);
},
[](const Tensor& v1, const Tensor& v2, const Scalar& alpha) {
return v1.add_(v2, alpha);
});
}
inline Tensor to_type(const Tensor& input, ScalarType dtype) {
auto [compressed_indices, plain_indices] =
at::sparse_csr::getCompressedPlainIndices(input);
return at::_sparse_compressed_tensor_unsafe(
compressed_indices,
plain_indices,
std::move(input.values()).to(dtype),
input.sizes(),
dtype,
input.layout(),
input.device(),
input.options().pinned_memory_opt());
}
template <typename acc_t, typename scalar_t>
inline std::tuple<Tensor, Tensor> create_acc_buffer(
TensorOptions option,
ScalarType type,
int64_t nnz = -1) {
Tensor new_values, new_values_acc;
constexpr bool need_acc = !std::is_same_v<scalar_t, acc_t>;
bool is_integral = at::isIntegralType(type, /*includeBool=*/true);
if constexpr (need_acc) {
auto acc_dtype = CppTypeToScalarType<acc_t>::value;
new_values_acc = at::empty({}, option.dtype(acc_dtype));
new_values = is_integral ? new_values_acc : at::empty({}, option);
} else {
new_values = new_values_acc = at::empty({}, option);
}
if (nnz != -1) {
return std::make_tuple(
new_values.resize_(nnz), new_values_acc.resize_(nnz));
} else {
return std::make_tuple(new_values, new_values_acc);
}
}
inline void copy_from_acc_buffer(Tensor& new_values, Tensor& new_values_acc) {
if (!new_values_acc.is_same(new_values)) {
new_values.copy_(new_values_acc);
}
}
} // namespace at::sparse_csr

View File

@ -0,0 +1,421 @@
#pragma once
#include <ATen/Tensor.h>
#include <c10/core/TensorImpl.h>
#include <c10/core/impl/TorchDispatchModeTLS.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/resize.h>
#endif
namespace at {
struct TORCH_API SparseTensorImpl : public TensorImpl {
// Stored in COO format, indices + values.
// INVARIANTS:
// sparse_dim: range [0, len(shape)]; sparse_dim + dense_dim = len(shape)
// dense_dim : range [0, len(shape)]; sparse_dim + dense_dim = len(shape)
// _indices.shape: dimensionality: 2, shape: (sparse_dim, nnz)
// _values.shape: dimensionality: 1 + dense_dim. shape: (nnz,
// shape[sparse_dim:])
int64_t sparse_dim_ = 0; // number of sparse dimensions
int64_t dense_dim_ = 0; // number of dense dimensions
Tensor indices_; // always a LongTensor
Tensor values_;
// A sparse tensor is 'coalesced' if every index occurs at most once in
// the indices tensor, and the indices are in sorted order. (This means
// that it is very easy to convert a coalesced tensor to CSR format: you
// need only compute CSR format indices.)
//
// Most math operations can only be performed on coalesced sparse tensors,
// because many algorithms proceed by merging two sorted lists (of indices).
bool coalesced_ = false;
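  // A worked example of the layout: the 3x3 matrix
  //
  //   [[0, 7, 0],
  //    [0, 0, 0],
  //    [8, 0, 9]]
  //
  // with sparse_dim_ == 2, dense_dim_ == 0 and nnz == 3 is stored as
  //
  //   indices_ = [[0, 2, 2],   // row of each nonzero
  //               [1, 0, 2]]   // column of each nonzero
  //   values_  = [7, 8, 9]
  //
  // This instance is coalesced: each (row, col) pair appears once and the
  // pairs are in sorted (lexicographic) order.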
// compute_numel with integer multiplication overflow check, see gh-57542
void refresh_numel() {
TensorImpl::safe_refresh_numel();
}
public:
// Public for now...
explicit SparseTensorImpl(at::DispatchKeySet, const caffe2::TypeMeta);
void release_resources() override;
int64_t nnz() const {
return values_.size(0);
}
c10::SymInt sym_nnz() const {
return values_.sym_size(0);
}
int64_t sparse_dim() const {
return sparse_dim_;
}
int64_t dense_dim() const {
return dense_dim_;
}
bool coalesced() const {
return coalesced_;
}
Tensor indices() const {
return indices_;
}
Tensor values() const {
return values_;
}
void set_size(int64_t dim, int64_t new_size) override;
void set_stride(int64_t dim, int64_t new_stride) override;
void set_storage_offset(int64_t storage_offset) override;
#ifdef DEBUG
bool has_storage() const override;
#endif
// WARNING: This function does NOT preserve invariants of sparse_dim/dense_dim
// with respect to indices and values
void raw_resize_(int64_t sparse_dim, int64_t dense_dim, IntArrayRef size) {
TORCH_CHECK(
allow_tensor_metadata_change(),
"raw_resize_ ",
err_msg_tensor_metadata_change_not_allowed);
TORCH_CHECK(
!has_symbolic_sizes_strides_,
"raw_resize_ called on tensor with symbolic shape")
set_sizes_and_strides(size, std::vector<int64_t>(size.size()));
sparse_dim_ = sparse_dim;
dense_dim_ = dense_dim;
refresh_numel();
}
// NOTE: This function preserves invariants of sparse_dim/dense_dim with
// respect to indices and values.
//
// NOTE: This function supports the following cases:
// 1. When we keep the number of dense dimensions unchanged, and NOT shrinking
// the size of any of the dense dimensions.
// 2. When we keep the number of sparse dimensions unchanged, and NOT
// shrinking the size of any of the sparse dimensions.
// 3. When the sparse tensor has zero nnz, in which case we are free to change
// the shapes of both its sparse and dense dimensions.
//
// This function DOESN'T support (and will throw an error) the following
// cases:
// 1. When we attempt to change the number of sparse dimensions on a non-empty
// sparse tensor (such an operation will invalidate the indices stored).
// 2. When we attempt to change the number of dense dimensions on a non-empty
// sparse tensor (such an operation will behave differently from an equivalent
// dense tensor's resize method, and for API consistency we don't support it).
// 3. When we attempt to shrink the size of any of the dense dimensions on a
// non-empty sparse tensor (such an operation will behave differently from an
// equivalent dense tensor's resize method, and for API consistency we don't
// support it).
// 4. When we attempt to shrink the size of any of the sparse dimensions on a
// non-empty sparse tensor (this could make some of the stored indices
// out-of-bound and thus unsafe).
template <typename T>
void _resize_(int64_t sparse_dim, int64_t dense_dim, ArrayRef<T> size) {
TORCH_CHECK(
allow_tensor_metadata_change(),
"resize_ ",
err_msg_tensor_metadata_change_not_allowed);
TORCH_CHECK(
!has_symbolic_sizes_strides_,
"resize_ called on tensor with symbolic shape")
TORCH_CHECK(
sparse_dim + dense_dim == static_cast<int64_t>(size.size()),
"number of dimensions must be sparse_dim (",
sparse_dim,
") + dense_dim (",
dense_dim,
"), but got ",
size.size());
if (nnz() > 0) {
[[maybe_unused]] auto constexpr alt_options_msg =
"You could try the following options:\n\
1. If you need an empty sparse tensor of this size, call `x = torch.sparse_coo_tensor(size)`.\n\
2. If you need to resize this tensor, you have the following options:\n\
1. For both sparse and dense dimensions, keep the number of them constant and the size of them non-shrinking, and then try the same call again.\n\
2. Or, create a new sparse tensor with the correct indices and values from this sparse tensor.";
TORCH_CHECK(
sparse_dim == sparse_dim_,
"changing the number of sparse dimensions (from ",
sparse_dim_,
" to ",
sparse_dim,
") on a non-empty sparse tensor is not supported.\n",
alt_options_msg);
TORCH_CHECK(
dense_dim == dense_dim_,
"changing the number of dense dimensions (from ",
dense_dim_,
" to ",
dense_dim,
") on a non-empty sparse tensor is not supported.\n",
alt_options_msg);
bool shrinking_sparse_dims = false;
bool shrinking_dense_dim = false;
auto sparse_size_original = generic_sizes<T>().slice(0, sparse_dim);
auto sparse_size_new = size.slice(0, sparse_dim);
for (const auto i : c10::irange(sparse_dim)) {
if (sparse_size_new[i] < sparse_size_original[i]) {
shrinking_sparse_dims = true;
break;
}
}
auto dense_size_original = generic_sizes<T>().slice(sparse_dim);
auto dense_size_new = size.slice(sparse_dim);
for (const auto i : c10::irange(dense_dim)) {
if (dense_size_new[i] < dense_size_original[i]) {
shrinking_dense_dim = true;
break;
}
}
TORCH_CHECK(
!shrinking_sparse_dims,
"shrinking the size of sparse dimensions (from ",
sparse_size_original,
" to ",
sparse_size_new,
") on a non-empty sparse tensor is not supported.\n",
alt_options_msg);
TORCH_CHECK(
!shrinking_dense_dim,
"shrinking the size of dense dimensions (from ",
dense_size_original,
" to ",
dense_size_new,
") on a non-empty sparse tensor is not supported.\n",
alt_options_msg);
}
auto sizes_and_strides = generic_sizes<T>();
const bool size_equals_sizes = std::equal(
size.begin(),
size.end(),
sizes_and_strides.begin(),
sizes_and_strides.end());
if ((!size_equals_sizes) || (sparse_dim != sparse_dim_) ||
(dense_dim != dense_dim_)) {
auto nnz = at::symint::sizes<T>(values())[0];
std::vector<T> values_size = {nnz};
auto dense_size = size.slice(sparse_dim);
values_size.insert(
values_size.end(), dense_size.begin(), dense_size.end());
at::symint::resize_<T>(values_, values_size);
at::symint::resize_<T>(indices_, {T(sparse_dim), nnz});
}
if (!size_equals_sizes) {
set_sizes_and_strides(size, std::vector<T>(size.size()));
}
sparse_dim_ = sparse_dim;
dense_dim_ = dense_dim;
refresh_numel();
}
void resize_(int64_t sparse_dim, int64_t dense_dim, ArrayRef<int64_t> size) {
return _resize_(sparse_dim, dense_dim, size);
}
void resize_(
int64_t sparse_dim,
int64_t dense_dim,
ArrayRef<c10::SymInt> size) {
return _resize_(sparse_dim, dense_dim, size);
}
// NOTE: this function will resize the sparse tensor and also set `indices`
// and `values` to empty.
void resize_and_clear_(
int64_t sparse_dim,
int64_t dense_dim,
IntArrayRef size) {
TORCH_CHECK(
allow_tensor_metadata_change(),
"resize_and_clear_ ",
err_msg_tensor_metadata_change_not_allowed);
TORCH_CHECK(
!has_symbolic_sizes_strides_,
"resize_and_clear_ called on tensor with symbolic shape")
TORCH_CHECK(
sparse_dim + dense_dim == static_cast<int64_t>(size.size()),
"number of dimensions must be sparse_dim (",
sparse_dim,
") + dense_dim (",
dense_dim,
"), but got ",
size.size());
set_sizes_and_strides(size, std::vector<int64_t>(size.size()));
sparse_dim_ = sparse_dim;
dense_dim_ = dense_dim;
auto empty_indices = at::empty({sparse_dim, 0}, indices().options());
std::vector<int64_t> values_size = {0};
auto dense_size = sizes().slice(sparse_dim);
values_size.insert(values_size.end(), dense_size.begin(), dense_size.end());
auto empty_values = at::empty(values_size, values().options());
set_indices_and_values_unsafe(empty_indices, empty_values);
refresh_numel();
}
void set_coalesced(bool coalesced) {
TORCH_CHECK(
allow_tensor_metadata_change(),
"set_coalesced ",
err_msg_tensor_metadata_change_not_allowed);
coalesced_ = coalesced;
}
// NOTE: this function is only used internally and not exposed to Python
// frontend
void set_nnz_and_narrow(int64_t new_nnz) {
TORCH_CHECK(
allow_tensor_metadata_change(),
"set_nnz_and_narrow ",
err_msg_tensor_metadata_change_not_allowed);
AT_ASSERT(new_nnz <= nnz());
indices_ = indices_.narrow(1, 0, new_nnz);
values_ = values_.narrow(0, 0, new_nnz);
if (new_nnz < 2) {
coalesced_ = true;
}
}
// Takes indices and values and directly puts them into the sparse tensor, no
// copy. NOTE: this function is unsafe because it doesn't check whether any
  // indices are out of the bounds of `sizes`, so it should ONLY be used where
  // we know that the indices are guaranteed to be within bounds. This used to
  // be called THSTensor_(_move). NB: this used to be able to avoid a refcount
  // bump, but I was too lazy to make it happen.
void set_indices_and_values_unsafe(
const Tensor& indices,
const Tensor& values);
template <typename VariableVersion>
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach_core(
VariableVersion&& version_counter,
bool allow_tensor_metadata_change) const {
const auto mode_stack_len = c10::impl::TorchDispatchModeTLS::stack_len();
c10::impl::PyInterpreter&& interpreter = nullptr;
if (mode_stack_len > 0 &&
!c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python)) {
const auto& cur_torch_dispatch_mode_state =
c10::impl::TorchDispatchModeTLS::get_stack_at(mode_stack_len - 1);
interpreter = cur_torch_dispatch_mode_state->pyinterpreter();
} else if (
key_set_.has(DispatchKey::Python) &&
!c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python)) {
interpreter = pyobj_slot_.load_pyobj_interpreter();
} else {
// otherwise just copy the SparseTensorImpl and not the PyObject.
auto impl = c10::make_intrusive<SparseTensorImpl>(key_set(), dtype());
copy_tensor_metadata(
/*src_sparse_impl=*/this,
/*dest_sparse_impl=*/impl.get(),
/*version_counter=*/version_counter,
/*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
impl->refresh_numel();
return impl;
}
auto r = interpreter->detach(this);
r->set_version_counter(std::forward<VariableVersion>(version_counter));
r->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
return r;
}
/**
* Return a TensorImpl that is a shallow-copy of this TensorImpl.
*
* For usage of `version_counter` and `allow_tensor_metadata_change`,
* see NOTE [ TensorImpl Shallow-Copying ].
*/
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
const c10::VariableVersion& version_counter,
bool allow_tensor_metadata_change) const override {
return shallow_copy_and_detach_core(
version_counter, allow_tensor_metadata_change);
}
/**
* Return a TensorImpl that is a shallow-copy of this TensorImpl.
*
* For usage of `version_counter` and `allow_tensor_metadata_change`,
* see NOTE [ TensorImpl Shallow-Copying ].
*/
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
c10::VariableVersion&& version_counter,
bool allow_tensor_metadata_change) const override {
return shallow_copy_and_detach_core(
std::move(version_counter), allow_tensor_metadata_change);
}
/**
* Shallow-copies data from another TensorImpl into this TensorImpl.
*
* For why this function doesn't check this TensorImpl's
* `allow_tensor_metadata_change_`, see NOTE [ TensorImpl Shallow-Copying ].
*/
void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override {
AT_ASSERT(has_compatible_shallow_copy_type(impl->key_set()));
auto sparse_impl = static_cast<const SparseTensorImpl*>(impl.get());
copy_tensor_metadata(
/*src_sparse_impl=*/sparse_impl,
/*dest_sparse_impl=*/this,
/*version_counter=*/version_counter(),
/*allow_tensor_metadata_change=*/allow_tensor_metadata_change());
refresh_numel();
}
private:
explicit SparseTensorImpl(
at::DispatchKeySet,
const caffe2::TypeMeta,
at::Tensor indices,
at::Tensor values);
/**
* Copy the tensor metadata fields (e.g. sizes / strides / storage pointer /
* storage_offset) from one TensorImpl to another TensorImpl.
*
* For usage of `version_counter` and `allow_tensor_metadata_change`, see NOTE
* [ TensorImpl Shallow-Copying ].
*/
static void copy_tensor_metadata(
const SparseTensorImpl* src_sparse_impl,
SparseTensorImpl* dest_sparse_impl,
c10::VariableVersion version_counter,
bool allow_tensor_metadata_change) {
TensorImpl::copy_tensor_metadata(
src_sparse_impl,
dest_sparse_impl,
std::move(version_counter),
allow_tensor_metadata_change);
// Sparse-specific fields
dest_sparse_impl->sparse_dim_ = src_sparse_impl->sparse_dim();
dest_sparse_impl->dense_dim_ = src_sparse_impl->dense_dim();
dest_sparse_impl->indices_ = src_sparse_impl->indices();
dest_sparse_impl->values_ = src_sparse_impl->values();
dest_sparse_impl->coalesced_ = src_sparse_impl->coalesced();
}
const char* tensorimpl_type_name() const override;
};
} // namespace at

View File

@ -0,0 +1,2 @@
#pragma once
#include <c10/core/Storage.h>

View File

@ -0,0 +1,49 @@
#pragma once
#include <c10/core/Storage.h>
#include <c10/core/StorageImpl.h>
#include <c10/util/intrusive_ptr.h>
namespace at {
class TensorBase;
// Here we define a series of utils to create/manipulate ATen backed
// c10 storage implementations.
/**
* Create a new shared memory storage impl managed by file descriptor
*
* @param size size in bytes
*/
C10_EXPORT c10::intrusive_ptr<c10::StorageImpl> new_shm_fd_storage(size_t size);
/**
* Copy src to dst
* Caller must guarantee the validness of the storage objects
* during the entire copy process, esp. when it's async.
*
* This can probably live in c10 namespace later if needed,
* but for now keep it in at to keep implementation simple.
*
* @param dst dst tensor
* @param src src tensor
* @param non_blocking (default false) whether this operation blocks caller
*/
C10_EXPORT void storage_copy(
c10::Storage& dst,
const c10::Storage& src,
bool non_blocking = false);
/**
* In place change the storage to shm based.
*
* This is only applicable to CPU tensors not already shared.
* Otherwise, it's a no op to mirror the THP tensor behavior:
* https://pytorch.org/docs/stable/generated/torch.Tensor.share_memory_.html
*
* @param t a tensor
*/
C10_EXPORT void share_memory_(TensorBase& t);
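// Illustrative usage sketch (editorial example, not part of the original
// header; `src` is assumed to be an existing c10::Storage):
//
//   at::Tensor t = at::zeros({8}, at::kFloat); // plain CPU tensor
//   at::share_memory_(t);                      // its storage now lives in shm
//
//   c10::Storage dst(at::new_shm_fd_storage(src.nbytes()));
//   at::storage_copy(dst, src, /*non_blocking=*/false);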
} // namespace at

View File

@ -0,0 +1,3 @@
#pragma once
#include <ATen/core/Tensor.h>

View File

@ -0,0 +1,2 @@
#pragma once
#include <ATen/core/TensorAccessor.h>

View File

@ -0,0 +1,144 @@
#pragma once
#include <ATen/core/TensorBase.h>
#include <c10/core/WrapDimMinimal.h>
namespace at {
// Return whether the tensor geometry represented by `sizes` and `strides` is
// contiguous. Although we now cache is_contiguous in the tensor, this is still
// useful because it allows checking whether a particular geometry is
// contiguous without explicitly constructing a tensor, e.g., when you want to
// choose a kernel strategy based on whether a subgeometry is contiguous.
TORCH_API bool geometry_is_contiguous(IntArrayRef sizes, IntArrayRef strides);
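// For example (editorial sketch, not from the original source):
//
//   at::geometry_is_contiguous({2, 3}, {3, 1}); // true: row-major strides
//   at::geometry_is_contiguous({2, 3}, {1, 2}); // false: column-major layout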
struct TORCH_API TensorGeometry {
TensorGeometry() = default;
explicit TensorGeometry(c10::SymIntArrayRef sizes)
: sizes_(sizes.vec()),
strides_(sizes.size()),
has_symbolic_sizes_strides_(
!c10::asIntArrayRefSlowOpt(sizes).has_value()) {
int64_t dim = static_cast<int64_t>(sizes.size());
c10::SymInt expected_stride = 1;
for (int64_t i = dim - 1; i >= 0; i--) {
strides_[i] = expected_stride;
expected_stride *= sizes_[i];
}
numel_ = expected_stride;
}
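  // For example (editorial note): sizes {2, 3, 4} yield strides {12, 4, 1} and
  // numel_ == 24, i.e. the standard row-major (contiguous) layout.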
explicit TensorGeometry(const TensorBase& t)
: sizes_(t.sym_sizes().vec()),
strides_(t.sym_strides().vec()),
storage_offset_(t.sym_storage_offset()),
numel_(t.sym_numel()),
has_symbolic_sizes_strides_(
t.unsafeGetTensorImpl()->has_symbolic_sizes_strides()) {}
// true if the tensor is contiguous
bool is_contiguous() const;
int64_t dim() const {
return static_cast<int64_t>(sizes_.size());
}
int64_t size(int64_t dim) const {
TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_);
dim = c10::maybe_wrap_dim(dim, this->dim());
return sizes_.at(static_cast<size_t>(dim)).as_int_unchecked();
}
c10::IntArrayRef sizes() const {
TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_);
return c10::asIntArrayRefUnchecked(sizes_);
}
int64_t stride(int64_t dim) const {
TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_);
dim = c10::maybe_wrap_dim(dim, this->dim());
return strides_.at(static_cast<size_t>(dim)).as_int_unchecked();
}
c10::IntArrayRef strides() const {
TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_);
return c10::asIntArrayRefUnchecked(strides_);
}
int64_t storage_offset() const {
TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_);
return storage_offset_.as_int_unchecked();
}
int64_t numel() const {
TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_);
return numel_.as_int_unchecked();
}
c10::SymInt sym_size(int64_t dim) const {
dim = c10::maybe_wrap_dim(dim, this->dim());
return sizes_.at(static_cast<size_t>(dim));
}
c10::SymIntArrayRef sym_sizes() const {
return sizes_;
}
c10::SymInt sym_stride(int64_t dim) const {
dim = c10::maybe_wrap_dim(dim, this->dim());
return strides_.at(static_cast<size_t>(dim));
}
c10::SymIntArrayRef sym_strides() const {
return strides_;
}
c10::SymInt sym_storage_offset() const {
return storage_offset_;
}
c10::SymInt sym_numel() const {
return numel_;
}
TensorGeometry transpose(int64_t dim0, int64_t dim1) {
TensorGeometry r = *this; // copy
TORCH_CHECK(
dim0 < dim(),
"transpose: dim0=",
dim0,
" out of range (dim=",
dim(),
")")
TORCH_CHECK(
dim1 < dim(),
"transpose: dim1=",
dim1,
" out of range (dim=",
dim(),
")")
std::swap(r.sizes_[dim0], r.sizes_[dim1]);
std::swap(r.strides_[dim0], r.strides_[dim1]);
return r;
}
std::vector<c10::SymInt>& mutable_sizes() {
return sizes_;
}
std::vector<c10::SymInt>& mutable_strides() {
return strides_;
}
c10::SymInt& mutable_storage_offset() {
return storage_offset_;
}
void recompute() {
// recalculate numel after a change
c10::SymInt numel = 1;
for (const auto& i : sizes_) {
numel = numel * i;
}
numel_ = std::move(numel);
has_symbolic_sizes_strides_ =
!c10::asIntArrayRefSlowOpt(sizes_).has_value();
}
private:
std::vector<c10::SymInt> sizes_;
std::vector<c10::SymInt> strides_;
c10::SymInt storage_offset_;
c10::SymInt numel_;
bool has_symbolic_sizes_strides_{false};
};
} // namespace at

View File

@ -0,0 +1,737 @@
#pragma once
#include <ATen/ExpandUtils.h>
#include <ATen/ScalarOps.h>
#include <ATen/core/Tensor.h>
#include <ATen/core/TensorBody.h>
#include <c10/core/SymInt.h>
#include <c10/util/irange.h>
#include <optional>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/alias.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/scalar_tensor.h>
#include <ATen/ops/zeros.h>
#endif
#include <ATen/core/List.h>
#include <utility>
namespace at::indexing {
constexpr int64_t INDEX_MIN = c10::SymInt::min_representable_int();
constexpr int64_t INDEX_MAX = -(INDEX_MIN + 1);
enum class TensorIndexType { None, Ellipsis, SymInt, Boolean, Slice, Tensor };
constexpr std::nullopt_t None = std::nullopt;
struct TORCH_API EllipsisIndexType final {
EllipsisIndexType() = default;
};
TORCH_API extern const EllipsisIndexType Ellipsis;
struct TORCH_API Slice final {
public:
Slice(
std::optional<c10::SymInt> start_index = std::nullopt,
std::optional<c10::SymInt> stop_index = std::nullopt,
std::optional<c10::SymInt> step_index = std::nullopt) {
if (!step_index.has_value()) {
step_ = c10::SymInt(1);
} else {
step_ = std::move(step_index).value();
}
TORCH_CHECK_VALUE(
step_.sym_ne(0).expect_true(__FILE__, __LINE__),
"slice step cannot be zero");
if (!start_index.has_value()) {
start_ = c10::SymInt(step_ < 0 ? INDEX_MAX : 0);
} else {
start_ = std::move(start_index).value();
}
if (!stop_index.has_value()) {
stop_ = c10::SymInt(step_ < 0 ? INDEX_MIN : INDEX_MAX);
} else {
stop_ = std::move(stop_index).value();
}
}
inline c10::SymInt start() const {
return start_;
}
inline c10::SymInt stop() const {
return stop_;
}
inline c10::SymInt step() const {
return step_;
}
private:
c10::SymInt start_;
c10::SymInt stop_;
c10::SymInt step_;
};
TORCH_API std::ostream& operator<<(std::ostream& stream, const Slice& slice);
// `at::indexing::TensorIndex` is used for converting C++ tensor indices such as
// `{None, "...", Ellipsis, 0, true, Slice(1, None, 2), torch::tensor({1, 2})}`
// into its equivalent `std::vector<TensorIndex>`, so that further tensor
// indexing operations can be performed using the supplied indices.
//
// There is a one-to-one correspondence between Python and C++ tensor index types:
// Python | C++
// -----------------------------------------------------
// `None` | `at::indexing::None`
// `Ellipsis` | `at::indexing::Ellipsis`
// `...` | `"..."`
// `123` | `123`
// `True` / `False` | `true` / `false`
// `:` | `Slice()` / `Slice(None, None)`
// `::` | `Slice()` / `Slice(None, None, None)`
// `1:` | `Slice(1, None)`
// `1::` | `Slice(1, None, None)`
// `:3` | `Slice(None, 3)`
// `:3:` | `Slice(None, 3, None)`
// `::2` | `Slice(None, None, 2)`
// `1:3` | `Slice(1, 3)`
// `1::2` | `Slice(1, None, 2)`
// `:3:2` | `Slice(None, 3, 2)`
// `1:3:2` | `Slice(1, 3, 2)`
// `torch.tensor([1, 2])`  | `torch::tensor({1, 2})`
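//
// A minimal usage sketch (editorial example; `x` and the Long index tensor
// `idx` are assumed to already exist):
//
//   using namespace at::indexing;
//   // Python: y = x[0, True, 1::2, idx]
//   at::Tensor y = x.index({0, true, Slice(1, None, 2), idx});
//   // Python: x[..., 0] = 5
//   x.index_put_({Ellipsis, 0}, 5);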
struct TORCH_API TensorIndex final {
// Case 1: `at::indexing::None`
TensorIndex(std::nullopt_t) : type_(TensorIndexType::None) {}
// Case 2: "..." / `at::indexing::Ellipsis`
TensorIndex(at::indexing::EllipsisIndexType)
: type_(TensorIndexType::Ellipsis) {}
TensorIndex(const char* str) : TensorIndex(at::indexing::Ellipsis) {
TORCH_CHECK_VALUE(
strcmp(str, "...") == 0,
"Expected \"...\" to represent an ellipsis index, but got \"",
str,
"\"");
}
// Case 3: (Sym) Integer value
TensorIndex(SymInt integer)
: integer_(std::move(integer)), type_(TensorIndexType::SymInt) {}
TensorIndex(int64_t integer) : TensorIndex(SymInt(integer)) {}
TensorIndex(int integer) : TensorIndex(SymInt(integer)) {}
// Case 4: Boolean value
template <class T, class = std::enable_if_t<std::is_same_v<bool, T>>>
TensorIndex(T boolean) : boolean_(boolean), type_(TensorIndexType::Boolean) {}
// Case 5: Slice represented in `at::indexing::Slice` form
TensorIndex(Slice slice)
: slice_(std::move(slice)), type_(TensorIndexType::Slice) {}
// Case 6: Tensor value
TensorIndex(Tensor tensor)
: tensor_(std::move(tensor)), type_(TensorIndexType::Tensor) {}
inline bool is_none() const {
return type_ == TensorIndexType::None;
}
inline bool is_ellipsis() const {
return type_ == TensorIndexType::Ellipsis;
}
inline bool is_integer() const {
return type_ == TensorIndexType::SymInt;
}
inline SymInt integer() const {
return integer_;
}
inline bool is_boolean() const {
return type_ == TensorIndexType::Boolean;
}
inline bool boolean() const {
return boolean_;
}
inline bool is_slice() const {
return type_ == TensorIndexType::Slice;
}
inline const Slice& slice() const {
return slice_;
}
inline bool is_tensor() const {
return type_ == TensorIndexType::Tensor;
}
inline const Tensor& tensor() const {
return tensor_;
}
private:
SymInt integer_ = 0;
bool boolean_ = false;
Slice slice_;
Tensor tensor_;
TensorIndexType type_;
};
TORCH_API std::ostream& operator<<(
std::ostream& stream,
const TensorIndex& tensor_index);
TORCH_API std::ostream& operator<<(
std::ostream& stream,
const std::vector<TensorIndex>& tensor_indices);
namespace impl {
inline Tensor applySlice(
const Tensor& self,
int64_t dim,
c10::SymInt start,
c10::SymInt stop,
c10::SymInt step,
bool disable_slice_optimization,
const at::Device& self_device,
const std::optional<SymIntArrayRef>& self_sizes) {
// TODO: implement negative step
TORCH_CHECK_VALUE(
step.sym_gt(0).expect_true(__FILE__, __LINE__),
"step must be greater than zero");
// See NOTE [nested tensor size for indexing]
if (self_sizes.has_value()) {
// Skip this optimization if we are tracing, as the trace may be polymorphic
// over the shape of the `self` tensor, and we still want to record
// the slice.
SymInt length = (self_device == at::kCPU || self_device == at::kCUDA)
? (*self_sizes)[dim]
: self.sym_size(dim);
if (!disable_slice_optimization &&
TORCH_GUARD_SIZE_OBLIVIOUS(start.sym_eq(0)) &&
TORCH_GUARD_SIZE_OBLIVIOUS(length.sym_eq(stop)) && step == 1) {
return self;
}
}
return self.slice_symint(
dim, std::move(start), std::move(stop), std::move(step));
}
inline Tensor applySelect(
const Tensor& self,
int64_t dim,
SymInt index,
int64_t real_dim,
const at::Device& /*self_device*/,
const std::optional<SymIntArrayRef>& self_sizes) {
// See NOTE [nested tensor size for indexing]
if (self_sizes.has_value()) {
auto maybe_index = index.maybe_as_int();
if (maybe_index.has_value()) {
TORCH_CHECK_INDEX(
!(maybe_index.value() == 0 && dim == 0 && self_sizes->empty()),
"invalid index of a 0-dim tensor. ",
"Use `tensor.item()` in Python or `tensor.item<T>()` in C++ to convert a 0-dim tensor to a number");
}
auto size = (*self_sizes)[dim];
    // Note: `size >= -index` is not equivalent to `size > -1 - index` if index
    // is INT64_MIN. For std::numeric_limits<int64_t>::min(), the result of
    // unary minus is undefined by the standard, but in practice it equals the
    // value itself. On the other hand, index wrapping is valid for all
    // negative int64_t values, as x[INT64_MIN] is the same as x[INT64_MAX].
TORCH_CHECK_INDEX(
size > -1 - index && size > index,
"index ",
index,
" is out of bounds for dimension ",
real_dim,
" with size ",
size);
}
// if the index is negative, do not normalize it because that would fix the
// index on the current tensor size in the tracer. aten::select also works on
// negative indices
return self.select_symint(dim, std::move(index));
}
inline Tensor boolToIndexingTensorCPUOrCUDA(const Tensor& self, bool value) {
// booleans add a dimension of size 1. true indexes this dimension as if 0:,
// false as empty.
if (value) {
return at::empty({1}, self.options().dtype(kLong)).fill_(0.);
} else {
return at::empty({0}, self.options().dtype(kLong));
}
}
inline Tensor boolToIndexingTensorNonNativeDeviceType(
const Tensor& self,
bool value) {
// booleans add a dimension of size 1. true indexes this dimension as if 0:,
// false as empty.
if (value) {
return at::zeros({1}, self.options().dtype(kLong));
} else {
return at::empty({0}, self.options().dtype(kLong));
}
}
inline Tensor boolToIndexingTensor(
const Tensor& self,
bool value,
const at::Device& self_device) {
if (self_device == at::kCPU || self_device == at::kCUDA) {
return boolToIndexingTensorCPUOrCUDA(self, value);
} else {
return boolToIndexingTensorNonNativeDeviceType(self, value);
}
}
inline Tensor scalarToTensorNonNativeDeviceType(
const Scalar& v,
const TensorOptions& options) {
return at::scalar_tensor(v, options);
}
inline void recordTensorIndex(
const Tensor& tensor,
std::vector<Tensor>& outIndices,
int64_t* dim_ptr) {
// TODO: check scalarType
outIndices.resize(*dim_ptr + 1);
outIndices[*dim_ptr] = tensor;
(*dim_ptr)++;
}
inline c10::List<::std::optional<Tensor>> typeConvertIndices(
const Tensor& /*self*/,
std::vector<Tensor>&& indices) {
c10::List<::std::optional<Tensor>> converted_inds;
converted_inds.reserve(indices.size());
for (auto&& i : std::move(indices)) {
converted_inds.push_back(std::move(i));
}
return converted_inds;
}
// NOTE: Why do we mirror instead of replace the `count_specified_dimensions`
// function in torch/csrc/autograd/python_variable_indexing.cpp? It's because
// `count_specified_dimensions` is on the hot path of Python tensor multi-dim
// indexing (i.e. it's called by `applySlicing` which is called by
// `THPVariable_getitem` / `THPVariable_setitem` when handling indexing of more
// than one dimension). If we were to merge the Python/C++
// `count_specified_dimensions` function, on the Python side we would have to
// construct a `std::vector` container to be consumed by the C++
// `count_specified_dimensions` function, which adds 100s of nanoseconds
// overhead and is undesirable.
inline int64_t count_specified_dimensions(
const ArrayRef<TensorIndex>& indices) {
// Count the number of indexed dimensions (everything but ellipsis and None)
int64_t count = 0;
for (auto& obj : indices) {
if (obj.is_tensor()) {
auto& tensor = obj.tensor();
if (tensor.scalar_type() == kByte || tensor.scalar_type() == kBool) {
count += tensor.dim();
} else {
count++;
}
} else if (!obj.is_none() && !obj.is_ellipsis() && !obj.is_boolean()) {
count++;
}
}
return count;
}
} // namespace impl
// NOTE: Many functions below are only for consumption from Python indexing
// implementation, they include:
//
// - `Tensor scalarToTensor(...)`
// - `IntArrayRef slicePrefix1sSize(...)`
// - `void copy_to(...)`
// - `Tensor handleDimInMultiDimIndexing(...)`
// - `Tensor dispatch_index(...)`
// - `Tensor dispatch_index_put_(...)`
// - `Tensor get_item(...)`
// - `void set_item(...)`
//
// The rest of the functions are in `at::indexing::impl` namespace, signifying
// that they shouldn't be used from Python indexing implementation.
inline Tensor scalarToTensor(
const Scalar& v,
const TensorOptions& options,
const at::Device& self_device) {
if (self_device == at::kCPU && !v.isSymbolic()) {
return at::detail::scalar_tensor_static(
v, options.dtype_opt()->toScalarType(), self_device);
} else {
return impl::scalarToTensorNonNativeDeviceType(v, options);
}
}
// To match numpy semantics:
// As a special case for backwards compatibility,
// strip away unit dimensions from the left of 'src'
inline SymIntArrayRef slicePrefix1sSize(const SymIntArrayRef& sizes) {
size_t first_non1_src = sizes.size();
for (const auto i : c10::irange(sizes.size())) {
// Unbacked SymInt has different behavior, but this is sound because
// failing to slice will only ever cause an error, not divergent
// behavior
if (!sizes[i].has_hint() || sizes[i] != 1) {
first_non1_src = i;
break;
}
}
return sizes.slice(first_non1_src);
}
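// For example (editorial note): sizes {1, 1, 3, 4} become {3, 4}, while
// {3, 1, 4} is returned unchanged because its leading dimension is not 1.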
inline void copy_to(const Tensor& dst, const Tensor& src) {
if (dst.sym_sizes().equals(src.sym_sizes())) {
// A shortcut to avoid generating hard-coded constant sizes during tracing.
// This is not a perfect solution: when src & dst have different shapes,
// constants will still appear. Users can workaround that case by
// dst[index..] = src.reshape(..)
dst.copy_(src);
return;
} else if (src.dim() == 0 && src.device().type() == at::kCPU) {
dst.fill_(src);
return;
}
auto src_view = src.view_symint(slicePrefix1sSize(src.sym_sizes()));
c10::MaybeOwned<Tensor> b_src = expand_inplace(dst, src_view, "setitem");
dst.copy_(*b_src);
}
// See NOTE [ Setting `disable_slice_optimization` when calling C++ tensor
// indexing functions from Python ]
inline Tensor handleDimInMultiDimIndexing(
const Tensor& prev_dim_result,
const Tensor& original_tensor,
const TensorIndex& index,
int64_t* dim_ptr,
int64_t* specified_dims_ptr,
int64_t real_dim,
std::vector<Tensor>& outIndices,
bool disable_slice_optimization,
const at::Device& original_tensor_device,
const std::optional<SymIntArrayRef>& prev_dim_result_sizes) {
if (index.is_integer()) {
return impl::applySelect(
prev_dim_result,
*dim_ptr,
index.integer(),
real_dim,
original_tensor_device,
prev_dim_result_sizes);
} else if (index.is_slice()) {
Tensor result = impl::applySlice(
prev_dim_result,
*dim_ptr,
index.slice().start(),
index.slice().stop(),
index.slice().step(),
/*disable_slice_optimization=*/disable_slice_optimization,
original_tensor_device,
prev_dim_result_sizes);
(*dim_ptr)++;
return result;
} else if (index.is_ellipsis()) {
(*dim_ptr) += original_tensor.dim() - (*specified_dims_ptr);
return prev_dim_result;
} else if (index.is_none()) {
Tensor result = prev_dim_result.unsqueeze(*dim_ptr);
(*dim_ptr)++;
return result;
} else if (index.is_boolean()) {
Tensor result = prev_dim_result.unsqueeze(*dim_ptr);
impl::recordTensorIndex(
impl::boolToIndexingTensor(
result, index.boolean(), original_tensor_device),
outIndices,
dim_ptr);
return result;
} else if (index.is_tensor()) {
Tensor result = prev_dim_result;
const Tensor& tensor = index.tensor();
auto scalar_type = tensor.scalar_type();
if (tensor.dim() == 0 &&
at::isIntegralType(scalar_type, /*includeBool=*/true)) {
if (scalar_type != at::kByte && scalar_type != at::kBool) {
result = impl::applySelect(
result,
*dim_ptr,
tensor.item<int64_t>(),
real_dim,
original_tensor_device,
prev_dim_result_sizes);
} else {
result = result.unsqueeze(*dim_ptr);
if (scalar_type == at::kBool) {
impl::recordTensorIndex(
impl::boolToIndexingTensor(
result, tensor.item<bool>() != 0, original_tensor_device),
outIndices,
dim_ptr);
} else {
impl::recordTensorIndex(
impl::boolToIndexingTensor(
result, tensor.item<uint8_t>() != 0, original_tensor_device),
outIndices,
dim_ptr);
}
}
} else {
impl::recordTensorIndex(tensor, outIndices, dim_ptr);
}
return result;
} else {
TORCH_INTERNAL_ASSERT(false, "Invalid TensorIndex type");
}
}
namespace impl {
// This mirrors `applySlicing` in
// torch/csrc/autograd/python_variable_indexing.cpp
inline Tensor applySlicing(
const Tensor& self,
const ArrayRef<TensorIndex>& indices,
std::vector<Tensor>& outIndices,
bool disable_slice_optimization,
const at::Device& self_device,
const std::optional<SymIntArrayRef>& self_sizes) {
int64_t dim = 0;
int64_t specified_dims = impl::count_specified_dimensions(indices);
// See NOTE [nested tensor size for indexing]
if (self_sizes.has_value()) {
TORCH_CHECK_INDEX(
specified_dims <= (int64_t)self_sizes->size(),
"too many indices for tensor of dimension ",
(int)self_sizes->size());
}
Tensor result = self;
for (const auto i : c10::irange(indices.size())) {
auto& obj = indices[i];
// See NOTE [nested tensor size for indexing]
std::optional<SymIntArrayRef> result_sizes = result.is_nested()
? std::optional<SymIntArrayRef>(std::nullopt)
: std::optional<SymIntArrayRef>(result.sym_sizes());
result = handleDimInMultiDimIndexing(
/*prev_dim_result=*/result,
/*original_tensor=*/self,
/*index=*/obj,
/*dim_ptr=*/&dim,
/*specified_dims_ptr=*/&specified_dims,
/*real_dim=*/static_cast<int64_t>(i),
/*outIndices=*/outIndices,
/*disable_slice_optimization=*/disable_slice_optimization,
/*original_tensor_device=*/self_device,
/*prev_dim_result_sizes=*/result_sizes);
}
return result;
}
} // namespace impl
inline Tensor dispatch_index(
const Tensor& self,
std::vector<Tensor>&& indices) {
return self.index(impl::typeConvertIndices(self, std::move(indices)));
}
inline Tensor dispatch_index_put_(
Tensor& self,
std::vector<Tensor>&& indices,
const Tensor& value) {
return self.index_put_(
impl::typeConvertIndices(self, std::move(indices)), value);
}
// NOTE [ Setting `disable_slice_optimization` when calling C++ tensor indexing
// functions from Python ]
//
// Question: When should we set `disable_slice_optimization` to `true` when
// calling C++ tensor indexing functions from Python indexing code?
//
// Answer: What "slice optimization" means: when we have a slicing expression
// like `x[0:5, 0]`, where the sliced tensor was of size 5 in dimension 0, we
// would skip dispatching the actual slice call as an optimization. However,
// here are the cases where we DON'T want this optimization:
//
// 1. When we are doing 1-D slicing (e.g. `tensor[:]`).
// Reason: we always return a shallow copy for expressions such as
// `tensor[:]` / `tensor[...]` / `tensor[:, :]`. (Note that for `tensor[:,
// :]`, we return an alias of `tensor` by doing the following:
// ```
// Tensor sliced = impl::applySlicing(self, indices, tensorIndices,
// disable_slice_optimization, self_device, self_sizes); if
// (tensorIndices.empty()) {
// if (sliced.is_same(self)) {
// // ensure we return a shallow copy for things like x[...]
// sliced = at::alias(sliced);
// }
// return sliced;
// }
// ```)
// 2. When we are doing JIT tracing.
// Reason: JIT tracing needs the `self.slice(...)` call to properly trace the
// slice operation.
// This mirrors `THPVariable_getitem` in
// torch/csrc/autograd/python_variable_indexing.cpp See NOTE [ Setting
// `disable_slice_optimization` when calling C++ tensor indexing functions from
// Python ]
inline Tensor get_item(
const Tensor& self,
const ArrayRef<TensorIndex>& indices,
bool disable_slice_optimization = false) {
at::Device self_device = self.device();
// NOTE [nested tensor size for indexing]
// nested tensor does not have a size (yet) so for now we represent its size
// as null may need to be changed after we reach a better solution for nested
// tensor size
std::optional<SymIntArrayRef> self_sizes = self.is_nested()
? std::optional<SymIntArrayRef>(std::nullopt)
: std::optional<SymIntArrayRef>(self.sym_sizes());
// handle simple types: integers, slices, none, ellipsis, bool
if (indices.size() == 1) {
const TensorIndex& index = indices[0];
if (index.is_integer()) {
return impl::applySelect(
self, 0, index.integer(), 0, self_device, self_sizes);
} else if (index.is_slice()) {
return impl::applySlice(
self,
0,
index.slice().start(),
index.slice().stop(),
index.slice().step(),
/*disable_slice_optimization=*/true,
self_device,
self_sizes);
} else if (index.is_none()) {
return self.unsqueeze(0);
} else if (index.is_ellipsis()) {
return at::alias(self);
} else if (index.is_boolean()) {
Tensor result = self.unsqueeze(0);
return dispatch_index(
result,
std::vector<Tensor>{impl::boolToIndexingTensor(
result, index.boolean(), self_device)});
}
}
std::vector<Tensor> tensorIndices;
Tensor sliced = impl::applySlicing(
self,
indices,
tensorIndices,
disable_slice_optimization,
self_device,
self_sizes);
if (tensorIndices.empty()) {
if (sliced.is_same(self)) {
// ensure we return a shallow copy for things like x[...]
sliced = at::alias(sliced);
}
return sliced;
}
// indexing by tensors ("advanced" indexing)
return dispatch_index(sliced, std::move(tensorIndices));
}
// This mirrors `THPVariable_setitem` in
// torch/csrc/autograd/python_variable_indexing.cpp for "the assigned value is a
// Tensor" case See NOTE [ Setting `disable_slice_optimization` when calling C++
// tensor indexing functions from Python ]
inline void set_item(
const Tensor& self,
const ArrayRef<TensorIndex>& indices,
const Tensor& value,
bool disable_slice_optimization = false) {
at::Device self_device = self.device();
SymIntArrayRef self_sizes = self.sym_sizes();
// handle simple types: integers, slices, ellipsis, bool
if (indices.size() == 1) {
const TensorIndex& index = indices[0];
if (index.is_boolean() && !index.boolean()) {
// do nothing for false (technically we should check the size, but we
      // don't have real 0-sized shapes).
return;
} else if (index.is_ellipsis()) {
copy_to(self, value);
return;
} else if (index.is_none() || (index.is_boolean() && index.boolean())) {
copy_to(self.unsqueeze(0), value);
return;
} else if (index.is_integer()) {
copy_to(
impl::applySelect(
self, 0, index.integer(), 0, self_device, self_sizes),
value);
return;
} else if (index.is_slice()) {
copy_to(
impl::applySlice(
self,
0,
index.slice().start(),
index.slice().stop(),
index.slice().step(),
/*disable_slice_optimization=*/disable_slice_optimization,
self_device,
self_sizes),
value);
return;
}
}
std::vector<Tensor> tensorIndices;
Tensor sliced = impl::applySlicing(
self,
indices,
tensorIndices,
disable_slice_optimization,
self_device,
self_sizes);
if (tensorIndices.empty()) {
copy_to(sliced, value);
return;
}
SymIntArrayRef valueSizes = value.sym_sizes();
SymIntArrayRef slicedValueSizes = slicePrefix1sSize(valueSizes);
Tensor valuesSliced;
if (!valueSizes.equals(slicedValueSizes)) {
valuesSliced = value.view_symint(slicedValueSizes);
} else {
valuesSliced = value;
}
dispatch_index_put_(sliced, std::move(tensorIndices), valuesSliced);
return;
}
} // namespace at::indexing

File diff suppressed because it is too large

View File

@ -0,0 +1,72 @@
#pragma once
#include <ATen/native/TensorIterator.h>
#include <c10/util/SmallBuffer.h>
#include <c10/util/irange.h>
namespace at {
struct DimCounter {
DimCounter(IntArrayRef shape, Range range);
void increment(const std::array<int64_t, 2>& step);
bool is_done() const;
std::array<int64_t, 2> max_2d_step() const;
IntArrayRef shape;
Range range;
c10::SmallBuffer<int64_t, 4> values;
int64_t offset;
};
namespace internal {
inline void get_data_ptrs(
char** ptrs,
ArrayRef<char*> base,
IntArrayRef strides,
IntArrayRef counter) {
const auto ntensors = base.size();
const auto ndim = counter.size();
std::copy(base.begin(), base.end(), ptrs);
for (const auto dim : c10::irange(ndim)) {
int64_t value = counter[dim];
for (const auto arg : c10::irange(ntensors)) {
ptrs[arg] += value * strides[dim * ntensors + arg];
}
}
}
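// For example (editorial note): with ntensors == 2 and ndim == 2 the strides
// buffer is laid out as
//   { stride(dim0, t0), stride(dim0, t1), stride(dim1, t0), stride(dim1, t1) }
// so for counter == {i, j} each pointer becomes
//   ptrs[arg] = base[arg] + i * strides[0 * 2 + arg] + j * strides[1 * 2 + arg].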
inline void serial_for_each(
IntArrayRef shape,
IntArrayRef strides,
char** base_ptrs,
size_t ntensors,
typename TensorIteratorBase::loop2d_t loop,
Range range) {
const auto ndim = shape.size();
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
strides.size() == ntensors * std::max(size_t{2}, ndim));
if (ndim <= 1) {
if (range.begin == 0) {
loop(base_ptrs, strides.data(), range.size(), 1);
} else {
c10::SmallBuffer<char*, 4> ptrs(ntensors);
get_data_ptrs(ptrs.data(), {base_ptrs, ntensors}, strides, {range.begin});
loop(ptrs.data(), strides.data(), range.size(), 1);
}
} else {
c10::SmallBuffer<char*, 4> ptrs(ntensors);
auto counter = DimCounter(shape, range);
while (!counter.is_done()) {
get_data_ptrs(
ptrs.data(), {base_ptrs, ntensors}, strides, counter.values);
auto step = counter.max_2d_step();
loop(ptrs.data(), strides.data(), step[0], step[1]);
counter.increment(step);
}
}
}
} // namespace internal
} // namespace at

View File

@ -0,0 +1,137 @@
#pragma once
#include <ATen/DimVector.h>
#include <ATen/core/Dimname.h>
#include <c10/core/TensorOptions.h>
#include <c10/util/strides.h>
namespace at {
class Tensor;
namespace impl {
// Use this to define the prototype for a meta function. There are two
// versions; one that takes one argument (just the operator name), or FUNC2
// variant that takes two arguments (operator name and overload name).
//
// Example usage:
//
// TORCH_META_FUNC2(add, Tensor) (
// const Tensor& self, const Tensor& other
// ) {
// ... compute sizes and options ...
// set_output(sizes, options);
// }
//
#define TORCH_META_FUNC(name) void structured_##name::meta
#define TORCH_META_FUNC2(name, overload) \
void structured_##name##_##overload::meta
// These are versions of TORCH_META_FUNC(2) that include a precompute_out struct
// as a return value. They should be used when the kernel in question has
// precomputed values declared in native_functions.yaml and the corresponding
// implementation should return an instance of the aforementioned struct.
#define TORCH_PRECOMPUTE_META_FUNC(name) \
structured_##name::meta_return_ty structured_##name::meta
#define TORCH_PRECOMPUTE_META_FUNC2(name, overload) \
structured_##name##_##overload::meta_return_ty \
structured_##name##_##overload::meta
// Use this to create a precompute struct in a meta function.
#define TORCH_PRECOMPUTE_STRUCT(name) structured_##name::precompute_out<>
#define TORCH_PRECOMPUTE_STRUCT2(name, overload) \
structured_##name##_##overload::precompute_out<>
// Use this to define the prototype for an implementation. This takes only
// one argument, which is the name of the dispatch key entry you're
// implementing.
//
// Example usage:
//
// TORCH_IMPL_FUNC(add_cpu) (
// Tensor& result, const Tensor& self, const Tensor& other
// ) {
// ... do the actual implementation ...
// }
//
#define TORCH_IMPL_FUNC(name) void structured_##name::impl
// Base class for all structured kernel classes. The set_output virtual
// method varies depending on whether the operator is
// functional/out/inplace, and could also be specialized for CPU/CUDA/etc
// (although presently it isn't).
//
// A notable subclass of this interface is TensorIteratorBase.
struct TORCH_API MetaBase {
MetaBase() = default;
MetaBase(const MetaBase&) = default;
MetaBase& operator=(const MetaBase&) = default;
MetaBase(MetaBase&&) noexcept = default;
MetaBase& operator=(MetaBase&&) noexcept = default;
virtual const Tensor& maybe_get_output(int64_t output_idx) = 0;
// Note: [set_output_*]
// See: https://github.com/pytorch/pytorch/issues/69813
// Whenever defining the output properties in the META function of a
// structured kernel (what was usually done with `set_output`), use one of
// these 3 variants, instead. In order to decide which variant to use, check
// the following decision tree:
//
// - Can the kernel you are going to implement support output tensors
// with arbitrary strides?
// |
// -- YES: `set_output_raw_strided`
// |
// -- NO: Should the output tensor strides be contiguous?
// |
// -- YES: `set_output_contiguous`
// |
// -- NO: `set_output_strided`
//
// Use this function whenever the kernel requires specific strides for the
// output. If `strides` does not match the given output strides, proxy outputs
// will be created and passed to the IMPL function.
virtual void set_output_strided(
int64_t output_idx [[maybe_unused]],
IntArrayRef sizes [[maybe_unused]],
IntArrayRef strides [[maybe_unused]],
TensorOptions options [[maybe_unused]],
DimnameList names [[maybe_unused]] = {}) {
TORCH_INTERNAL_ASSERT(false, "set_output_strided not implemented.");
}
// Use this function whenever the kernel knows how to handle arbitrary strided
// outputs. This function has the same behavior as the old `set_output`: it
// will only re-stride if the given output was resized.
virtual void set_output_raw_strided(
int64_t output_idx [[maybe_unused]],
IntArrayRef sizes [[maybe_unused]],
IntArrayRef strides_hint [[maybe_unused]],
TensorOptions options [[maybe_unused]],
DimnameList names [[maybe_unused]] = {}) {
TORCH_INTERNAL_ASSERT(false, "set_output_strided not implemented.");
}
// Use this function if the kernel requires contiguous strides.
// Alias for `set_output_strided`, but with contiguous strides.
void set_output_contiguous(
int64_t output_idx,
IntArrayRef sizes,
TensorOptions options,
DimnameList names = {}) {
auto strides = c10::contiguous_strides(sizes);
set_output_strided(output_idx, sizes, strides, options, names);
}
// Returns a reference to an undefined tensor if there is no presupplied
// output
const Tensor& maybe_get_output() {
return maybe_get_output(0);
}
virtual ~MetaBase() = default;
};
} // namespace impl
} // namespace at

View File

@ -0,0 +1,75 @@
#pragma once
#include <ATen/WrapDimUtils.h>
namespace at::namedinference {
// TensorName and TensorNames are wrappers around Dimname and DimnameList
// that contain helper functions to make writing name inference rules easier.
//
// A TensorName represents a Dimname associated with some DimnameList (from a
// Tensor). This encapsulates all the information that is needed to check if
// names *match* and to *unify* names.
//
// Definition: Two names in two tensors *match* if they are equal, or if at
// least one of them is a wildcard that can be *refined* to the other name.
//
// Definition: unify(name, other) fails if the names do not match. Otherwise,
// it returns the most refined of name and other.
//
// Here is an example of checking if two names match.
// tensor: Tensor[A, None]
// other: Tensor[A]
//
// Let's say we wish to check if tensor.names[-1] matches other.names[-1].
// None (in tensor) cannot match A (in other) because if the None were refined
// to A, `tensor` would have duplicate names [A, A]. Therefore we need to check
// tensor.names [A, None] for the existence of A.
struct TORCH_API TensorName {
explicit TensorName(ArrayRef<Dimname> origin, int origin_idx)
: origin_(origin),
name_(origin[maybe_wrap_dim(
origin_idx,
static_cast<int64_t>(origin.size()))]),
origin_idx_(origin_idx) {}
// op_name is only used for error reporting.
const TensorName& unify(const TensorName& other, const char* op_name) const;
Dimname toDimname() const;
private:
ArrayRef<Dimname> origin_;
Dimname name_;
int origin_idx_; // A named tensor can have at most 64 dims.
TORCH_API friend std::ostream& operator<<(
std::ostream& out,
const TensorName& tensorname);
};
using TensorNameVec = SmallVector<TensorName, 10>;
struct TORCH_API TensorNames {
explicit TensorNames(ArrayRef<Dimname> names);
// Create TensorNames from names[start:end]. Each individual TensorName stores
// `names`, NOT names[start:end], because the original tensor's names are
// `names`.
explicit TensorNames(ArrayRef<Dimname> names, int64_t start, int64_t end);
// op_name is only used for error reporting.
TensorNames& unifyFromRightInplace(
const TensorNames& other,
const char* op_name = "unify");
void checkUnique(const char* op_name) const;
void append(TensorName name);
std::vector<Dimname> toDimnameVec() const;
private:
  explicit TensorNames(TensorNameVec&& names) : names_(std::move(names)) {}
TensorNameVec names_;
};
} // namespace at::namedinference

View File

@ -0,0 +1,51 @@
#pragma once
#include <ATen/core/Tensor.h>
#include <c10/core/Scalar.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty_like.h>
#endif
namespace at {
#define AT_FORALL_BINARY_OPS(_) \
_(+, x.add(y), y.add(x)) \
_(*, x.mul(y), y.mul(x)) \
_(-, \
x.sub(y), \
::at::empty_like(y, at::MemoryFormat::Preserve).fill_(x).sub_(y)) \
_(/, \
x.div(y), \
::at::empty_like(y, at::MemoryFormat::Preserve).fill_(x).div_(y)) \
_(%, \
x.remainder(y), \
::at::empty_like(y, at::MemoryFormat::Preserve).fill_(x).remainder_(y)) \
_(&, x.bitwise_and(y), y.bitwise_and(x)) \
_(|, x.bitwise_or(y), y.bitwise_or(x)) \
_(^, x.bitwise_xor(y), y.bitwise_xor(x)) \
_(<, x.lt(y), y.gt(x)) \
_(<=, x.le(y), y.ge(x)) \
_(>, x.gt(y), y.lt(x)) \
_(>=, x.ge(y), y.le(x)) \
_(==, x.eq(y), y.eq(x)) \
_(!=, x.ne(y), y.ne(x))
#define DEFINE_OPERATOR(op, body, reverse_scalar_body) \
inline Tensor operator op(const Tensor& x, const Tensor& y) { \
return body; \
} \
inline Tensor operator op(const Tensor& x, const Scalar& y) { \
return body; \
} \
inline Tensor operator op(const Scalar& x, const Tensor& y) { \
return reverse_scalar_body; \
}
AT_FORALL_BINARY_OPS(DEFINE_OPERATOR)
#undef DEFINE_OPERATOR
#undef AT_FORALL_BINARY_OPS
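// Illustrative usage sketch (editorial example, not part of the original
// header):
//
//   at::Tensor x = at::rand({2, 2});
//   at::Tensor y = x * 2;    // Tensor op Scalar -> x.mul(2)
//   at::Tensor z = 1 - x;    // Scalar op Tensor -> empty_like(x).fill_(1).sub_(x)
//   at::Tensor m = (x >= y); // comparison operators are generated the same way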
} // namespace at

View File

@ -0,0 +1,2 @@
#pragma once
#include <c10/core/TensorOptions.h>

View File

@ -0,0 +1,88 @@
#pragma once
#include <ATen/core/List.h>
#include <ATen/core/Tensor.h>
#include <c10/core/impl/TorchDispatchModeTLS.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/equal.h>
#endif
namespace at {
// Note [Tensor-subclass-like Tensors]
// Tensor-subclass-like is defined as:
// - a Tensor subclass (via __torch_dispatch__ in Python or extending
// TensorImpl in C++)
// - anything else that shares the same perils as Tensor subclasses.
// For example, many Tensor subclasses do not have storage and meta Tensors
// do not have storage either, so meta Tensors belong here.
//
// We should ensure that PyTorch internals supports Tensor-subclass-like
// objects. In particular, Tensor-subclass-like objects struggle with two
// classes of operations that are problematic for Tensor subclasses:
// 1. Because some Tensor subclasses do not have storage, .item() or
//    .data_ptr() calls are problematic.
// 2. Certain in-place operations can eliminate the typing of the Tensor
// subclass. For example:
// >>> torch.zeros(input.sizes(), grad.options()).diag().copy_(input)
// If input is a Tensor subclass, then the above ends up either erroring out
// or returning a regular non-Tensor-subclass Tensor!
constexpr auto kFunctorchWrappedTensors = DispatchKeySet(
{DispatchKey::FuncTorchGradWrapper,
DispatchKey::FuncTorchBatched,
DispatchKey::Functionalize});
constexpr auto kTensorSubclassLike =
kFunctorchWrappedTensors |
DispatchKeySet(
{// WARNING: DO NOT put combined backend component + functionality keys
// here, you will incorrectly always match on the functionality key
// no matter the backend component
DispatchKey::Batched,
DispatchKey::Sparse,
DispatchKey::SparseCsr,
DispatchKey::Python}) |
DispatchKeySet(BackendComponent::MetaBit);
inline bool isTensorSubclassLike(const Tensor& tensor) {
if (c10::impl::dispatch_mode_enabled())
return true;
auto key_set = tensor.unsafeGetTensorImpl()->key_set();
return !(key_set & kTensorSubclassLike).empty();
}
inline bool areAnyTensorSubclassLike(TensorList tensors) {
if (c10::impl::dispatch_mode_enabled())
return true;
return std::any_of(tensors.begin(), tensors.end(), isTensorSubclassLike);
}
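// Illustrative usage sketch (editorial example; `grad` and `input` are assumed
// to be at::Tensor arguments of some kernel):
//
//   if (at::areAnyTensorSubclassLike({grad, input})) {
//     // take a composite-compliant path: avoid .item() / .data_ptr() and
//     // in-place copies that would drop the subclass
//   }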
inline bool areAnyOptionalTensorSubclassLike(
const c10::List<std::optional<Tensor>>& tensors) {
if (c10::impl::dispatch_mode_enabled())
return true;
return std::any_of(
tensors.begin(),
tensors.end(),
[](const std::optional<Tensor>& opt_tensor) {
return (
opt_tensor.has_value() && isTensorSubclassLike(opt_tensor.value()));
});
}
// Helper function for testing the truthfulness of a scalar tensor
// in a Composite Compliant manner.
// NOTE: This function expects a scalar tensor of boolean dtype.
// E.g.
// Non-Composite Compliant Pattern : (t == 0).all().item<bool>()
// Composite Compliant Pattern     : is_scalar_tensor_true((t == 0).all())
inline bool is_scalar_tensor_true(const Tensor& t) {
TORCH_INTERNAL_ASSERT(t.dim() == 0)
TORCH_INTERNAL_ASSERT(t.scalar_type() == kBool)
return at::equal(t, t.new_ones({}, t.options()));
}
} // namespace at

View File

@ -0,0 +1,190 @@
#pragma once
#include <ATen/DimVector.h>
#include <ATen/EmptyTensor.h>
#include <ATen/Tensor.h>
#include <ATen/TensorGeometry.h>
#include <ATen/Utils.h>
#include <utility>
// These functions are NOT in Utils.h, because this file has a dep on Tensor.h
#define TORCH_CHECK_TENSOR_ALL(cond, ...) \
TORCH_CHECK((cond)._is_all_true().item<bool>(), __VA_ARGS__);
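// For example (editorial sketch; `weights` is assumed to be an at::Tensor):
//
//   TORCH_CHECK_TENSOR_ALL(weights >= 0, "weights must be non-negative");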
namespace at {
// The following are utility functions for checking that arguments
// make sense. These are particularly useful for native functions,
// which do NO argument checking by default.
struct TORCH_API TensorArg {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
const Tensor& tensor;
const char* name;
int pos; // 1-indexed
TensorArg(const Tensor& tensor, const char* name, int pos)
: tensor(tensor), name(name), pos(pos) {}
// Try to mitigate any possibility of dangling reference to temporaries.
// NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved)
TensorArg(Tensor&& tensor, const char* name, int pos) = delete;
const Tensor* operator->() const {
return &tensor;
}
const Tensor& operator*() const {
return tensor;
}
};
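// Illustrative usage sketch (editorial example; `my_op_check` and its
// arguments are placeholders, not part of ATen):
//
//   void my_op_check(const at::Tensor& input, const at::Tensor& weight) {
//     at::TensorArg input_arg{input, "input", 1}, weight_arg{weight, "weight", 2};
//     at::CheckedFrom c = "my_op";
//     at::checkDim(c, input_arg, /*dim=*/4);
//     at::checkSameType(c, input_arg, weight_arg);
//   }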
struct TORCH_API TensorGeometryArg {
TensorGeometry tensor;
const char* name;
int pos; // 1-indexed
/* implicit */ TensorGeometryArg(TensorArg arg)
: tensor(TensorGeometry{arg.tensor}), name(arg.name), pos(arg.pos) {}
TensorGeometryArg(TensorGeometry tensor, const char* name, int pos)
: tensor(std::move(tensor)), name(name), pos(pos) {}
const TensorGeometry* operator->() const {
return &tensor;
}
const TensorGeometry& operator*() const {
return tensor;
}
};
// A string describing which function did checks on its input
// arguments.
// TODO: Consider generalizing this into a call stack.
using CheckedFrom = const char*;
// The undefined convention: singular operators assume their arguments
// are defined, but functions which take multiple tensors will
// implicitly filter out undefined tensors (to make it easier to perform
// tests which should apply if the tensor is defined, and should not
// otherwise.)
//
// NB: This means that the n-ary operators take lists of TensorArg,
// not TensorGeometryArg, because the Tensor to TensorGeometry
// conversion will blow up if you have undefined tensors.
TORCH_API std::ostream& operator<<(
std::ostream& out,
const TensorGeometryArg& t);
TORCH_API void checkDim(
CheckedFrom c,
const Tensor& tensor,
const char* name,
int pos, // 1-indexed
int64_t dim);
TORCH_API void checkDim(CheckedFrom c, const TensorGeometryArg& t, int64_t dim);
// NB: this is an inclusive-exclusive range
TORCH_API void checkDimRange(
CheckedFrom c,
const TensorGeometryArg& t,
int64_t dim_start,
int64_t dim_end);
TORCH_API void checkSameDim(
CheckedFrom c,
const TensorGeometryArg& t1,
const TensorGeometryArg& t2);
TORCH_API void checkContiguous(CheckedFrom c, const TensorGeometryArg& t);
TORCH_API void checkAllContiguous(CheckedFrom c, at::ArrayRef<TensorArg> ts);
TORCH_API void checkSize(
CheckedFrom c,
const TensorGeometryArg& t,
IntArrayRef sizes);
TORCH_API void checkSize_symint(
CheckedFrom c,
const TensorGeometryArg& t,
c10::SymIntArrayRef sizes);
TORCH_API void checkSize(
CheckedFrom c,
const TensorGeometryArg& t,
int64_t dim,
int64_t size);
TORCH_API void checkSize_symint(
CheckedFrom c,
const TensorGeometryArg& t,
int64_t dim,
const c10::SymInt& size);
TORCH_API void checkNumel(
CheckedFrom c,
const TensorGeometryArg& t,
int64_t numel);
TORCH_API void checkSameNumel(
CheckedFrom c,
const TensorArg& t1,
const TensorArg& t2);
TORCH_API void checkAllSameNumel(CheckedFrom c, ArrayRef<TensorArg> tensors);
TORCH_API void checkScalarType(CheckedFrom c, const TensorArg& t, ScalarType s);
TORCH_API void checkScalarTypes(
CheckedFrom c,
const TensorArg& t,
at::ArrayRef<ScalarType> l);
TORCH_API void checkSameGPU(
CheckedFrom c,
const TensorArg& t1,
const TensorArg& t2);
TORCH_API void checkAllSameGPU(CheckedFrom c, ArrayRef<TensorArg> tensors);
TORCH_API void checkSameType(
CheckedFrom c,
const TensorArg& t1,
const TensorArg& t2);
TORCH_API void checkAllSameType(CheckedFrom c, ArrayRef<TensorArg> tensors);
TORCH_API void checkSameSize(
CheckedFrom c,
const TensorArg& t1,
const TensorArg& t2);
TORCH_API void checkAllSameSize(CheckedFrom c, ArrayRef<TensorArg> tensors);
TORCH_API void checkDefined(CheckedFrom c, const TensorArg& t);
TORCH_API void checkAllDefined(CheckedFrom c, at::ArrayRef<TensorArg> t);
// FixMe: does TensorArg slow things down?
TORCH_API void checkBackend(
CheckedFrom c,
at::ArrayRef<Tensor> t,
at::Backend backend);
TORCH_API void checkDeviceType(
CheckedFrom c,
at::ArrayRef<Tensor> tensors,
at::DeviceType device_type);
TORCH_API void checkLayout(CheckedFrom c, const Tensor& t, Layout layout);
TORCH_API void checkLayout(
CheckedFrom c,
at::ArrayRef<Tensor> tensors,
at::Layout layout);
// Methods for getting data_ptr if tensor is defined
TORCH_API void* maybe_data_ptr(const Tensor& tensor);
TORCH_API void* maybe_data_ptr(const TensorArg& tensor);
TORCH_API void check_dim_size(
const Tensor& tensor,
int64_t dim,
int64_t dim_size,
int64_t size);
namespace detail {
TORCH_API std::vector<int64_t> defaultStrides(IntArrayRef sizes);
TORCH_API std::optional<std::vector<int64_t>> computeStride(
IntArrayRef oldshape,
IntArrayRef oldstride,
IntArrayRef newshape);
TORCH_API std::optional<SymDimVector> computeStride(
c10::SymIntArrayRef oldshape,
c10::SymIntArrayRef oldstride,
c10::SymIntArrayRef newshape);
TORCH_API std::optional<DimVector> computeStride(
IntArrayRef oldshape,
IntArrayRef oldstride,
const DimVector& newshape);
} // namespace detail
} // namespace at

View File

@ -0,0 +1,21 @@
#pragma once
#include <c10/core/SafePyObject.h>
#include <c10/macros/Macros.h>
#include <unordered_map>
namespace at::impl {
struct TORCH_API ThreadLocalPythonObjects {
static void set(const std::string& key, std::shared_ptr<SafePyObject> value);
static const std::shared_ptr<SafePyObject>& get(const std::string& key);
static bool contains(const std::string& key);
static const ThreadLocalPythonObjects& get_state();
static void set_state(ThreadLocalPythonObjects state);
private:
std::unordered_map<std::string, std::shared_ptr<c10::SafePyObject>> obj_dict_;
};
} // namespace at::impl

Some files were not shown because too many files have changed in this diff