I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

@ -0,0 +1,98 @@
#pragma once
#include <ATen/native/DispatchStub.h>
#include <c10/util/Exception.h>
#include <c10/util/string_view.h>
namespace c10 {
class Scalar;
}
namespace at {
struct TensorIterator;
struct TensorIteratorBase;
class TensorBase;
}
namespace at::native {
// These constants control the approximation behavior of the GELU function.
enum class GeluType {
None, // Baseline GELU
Tanh, // Tanh GELU approximation
END
};
inline GeluType get_gelutype_enum(const c10::string_view approximate) {
if (approximate == "none") {
return GeluType::None;
} else if (approximate == "tanh") {
return GeluType::Tanh;
} else {
TORCH_CHECK(false, "approximate argument must be either none or tanh.");
}
}
inline std::string gelutype_to_string(const GeluType type) {
switch(type) {
case GeluType::None: return "none";
case GeluType::Tanh: return "tanh";
default: TORCH_CHECK(false, "unknown GELU type: ", static_cast<int>(type));
}
}
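// As a point of reference (an illustrative standalone sketch, not the actual
// ATen kernel), GeluType::None corresponds to the exact GELU,
// 0.5 * x * (1 + erf(x / sqrt(2))), and GeluType::Tanh to the tanh
// approximation:
//
//   #include <cmath>
//   inline double gelu_reference(double x, GeluType approximate) {
//     if (approximate == GeluType::Tanh) {
//       constexpr double kBeta = 0.7978845608028654; // sqrt(2 / pi)
//       return 0.5 * x * (1.0 + std::tanh(kBeta * (x + 0.044715 * x * x * x)));
//     }
//     return 0.5 * x * (1.0 + std::erf(x / std::sqrt(2.0)));
//   }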
using structured_activation_fn = void (*)(TensorIteratorBase&);
using structured_activation_backward_fn = void (*)(TensorIteratorBase&);
using activation_fn = void (*)(TensorIterator&);
using activation_backward_fn = void (*)(TensorIterator&);
using softplus_fn = void (*)(TensorIteratorBase&, const c10::Scalar&, const c10::Scalar&);
using softplus_backward_fn = void (*)(TensorIteratorBase&, const c10::Scalar&, const c10::Scalar&);
using threshold_fn = void (*)(TensorIteratorBase&, const c10::Scalar&, const c10::Scalar&);
using hardtanh_backward_fn = void (*)(TensorIterator&, const c10::Scalar&, const c10::Scalar&);
using hardsigmoid_fn = void(*)(TensorIteratorBase&);
using hardsigmoid_backward_fn = void(*)(TensorIteratorBase&);
using hardswish_fn = void(*)(TensorIterator&);
using hardswish_backward_fn = void(*)(TensorIterator&);
using shrink_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
using softshrink_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
using shrink_backward_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
using elu_fn = void (*)(TensorIteratorBase&, const c10::Scalar&, const c10::Scalar&, const c10::Scalar&);
using elu_backward_fn = void (*)(TensorIteratorBase&, const c10::Scalar&, const c10::Scalar&, const c10::Scalar&, bool);
using leaky_relu_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
using leaky_relu_backward_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
using log_sigmoid_cpu_fn = void (*)(TensorBase&, TensorBase&, const TensorBase&);
using gelu_fn = void (*)(TensorIteratorBase&, GeluType);
using gelu_backward_fn = void (*)(TensorIteratorBase&, GeluType);
using glu_jvp_fn = void (*)(TensorIteratorBase&);
DECLARE_DISPATCH(elu_fn, elu_stub);
DECLARE_DISPATCH(elu_backward_fn, elu_backward_stub);
DECLARE_DISPATCH(softplus_fn, softplus_stub);
DECLARE_DISPATCH(softplus_backward_fn, softplus_backward_stub);
DECLARE_DISPATCH(log_sigmoid_cpu_fn, log_sigmoid_cpu_stub);
DECLARE_DISPATCH(activation_backward_fn, log_sigmoid_backward_stub);
DECLARE_DISPATCH(threshold_fn, threshold_stub);
DECLARE_DISPATCH(gelu_fn, GeluKernel);
DECLARE_DISPATCH(gelu_backward_fn, GeluBackwardKernel);
DECLARE_DISPATCH(hardtanh_backward_fn, hardtanh_backward_stub);
DECLARE_DISPATCH(hardsigmoid_fn, hardsigmoid_stub);
DECLARE_DISPATCH(hardsigmoid_backward_fn, hardsigmoid_backward_stub);
DECLARE_DISPATCH(hardswish_fn, hardswish_stub);
DECLARE_DISPATCH(hardswish_backward_fn, hardswish_backward_stub);
DECLARE_DISPATCH(shrink_fn, hardshrink_stub);
DECLARE_DISPATCH(softshrink_fn, softshrink_stub);
DECLARE_DISPATCH(shrink_backward_fn, shrink_backward_stub);
DECLARE_DISPATCH(leaky_relu_fn, leaky_relu_stub);
DECLARE_DISPATCH(leaky_relu_backward_fn, leaky_relu_backward_stub);
DECLARE_DISPATCH(structured_activation_fn, glu_stub);
DECLARE_DISPATCH(activation_backward_fn, glu_backward_stub);
DECLARE_DISPATCH(glu_jvp_fn, glu_jvp_stub);
DECLARE_DISPATCH(structured_activation_fn, silu_stub);
DECLARE_DISPATCH(structured_activation_backward_fn, silu_backward_stub);
DECLARE_DISPATCH(structured_activation_fn, mish_stub);
DECLARE_DISPATCH(activation_backward_fn, mish_backward_stub);
DECLARE_DISPATCH(activation_fn, prelu_stub);
DECLARE_DISPATCH(activation_backward_fn, prelu_backward_stub);
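// A hedged sketch of how these stubs are typically wired up (gelu_kernel_impl
// is a hypothetical kernel name used for illustration): a backend-specific
// translation unit registers its implementation, and call sites invoke the
// stub with a device type as the first argument.
//
//   // In a CPU kernel .cpp file:
//   static void gelu_kernel_impl(TensorIteratorBase& iter, GeluType approximate);
//   REGISTER_DISPATCH(GeluKernel, &gelu_kernel_impl);
//
//   // At the call site:
//   GeluKernel(kCPU, iter, get_gelutype_enum(approximate));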
} // namespace at::native

@ -0,0 +1,49 @@
#pragma once
#include <ATen/core/Tensor.h>
#include <ATen/native/DispatchStub.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/irange.h>
#include <cmath>
namespace at::native {
using adaptive_avg_pooling2d_fn = void(*)(Tensor& output, const Tensor& input, IntArrayRef output_size);
using adaptive_avg_pooling2d_backward_fn = void(*)(Tensor& grad_input, const Tensor& grad_output);
DECLARE_DISPATCH(adaptive_avg_pooling2d_fn, adaptive_avg_pool2d_kernel);
DECLARE_DISPATCH(adaptive_avg_pooling2d_backward_fn, adaptive_avg_pool2d_backward_kernel);
using adaptive_max_pooling2d_fn = void(*)(const Tensor& output, const Tensor& indices, const Tensor& input, IntArrayRef output_size);
using adaptive_max_pooling2d_backward_fn = void(*)(const Tensor& grad_input, const Tensor& grad_output, const Tensor& indices);
DECLARE_DISPATCH(adaptive_max_pooling2d_fn, adaptive_max_pool2d_kernel);
DECLARE_DISPATCH(adaptive_max_pooling2d_backward_fn, adaptive_max_pool2d_backward_kernel);
using adaptive_avg_pooling3d_fn = void(*)(Tensor& output, const Tensor& input, IntArrayRef output_size);
using adaptive_avg_pooling3d_backward_fn = void(*)(Tensor& grad_input, const Tensor& grad_output);
DECLARE_DISPATCH(adaptive_avg_pooling3d_fn, adaptive_avg_pool3d_kernel);
DECLARE_DISPATCH(adaptive_avg_pooling3d_backward_fn, adaptive_avg_pool3d_backward_kernel);
using adaptive_max_pooling3d_fn = void(*)(const Tensor& output, const Tensor& indices, const Tensor& input, IntArrayRef output_size);
using adaptive_max_pooling3d_backward_fn = void(*)(const Tensor& grad_input, const Tensor& grad_output, const Tensor& indices);
DECLARE_DISPATCH(adaptive_max_pooling3d_fn, adaptive_max_pool3d_kernel);
DECLARE_DISPATCH(adaptive_max_pooling3d_backward_fn, adaptive_max_pool3d_backward_kernel);
// Maps output index `a`, out of `b` output bins spanning an input of size `c`,
// to the first input index of its bin: floor(a * c / b).
inline int64_t start_index(int64_t a, int64_t b, int64_t c) {
return (a / b) * c + ((a % b) * c) / b;
}
// One past the last input index of output bin `a`: ceil((a + 1) * c / b).
inline int64_t end_index(int64_t a, int64_t b, int64_t c) {
return 1 + ((a + 1) * c - 1) / b;
}
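// For example, with b = 3 output bins over an input of size c = 10 the bins
// are [0, 4), [3, 7) and [6, 10) for a = 0, 1, 2; neighbouring bins may
// overlap, which is what makes the pooling adaptive.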
inline void adaptive_pool_empty_output_check(const Tensor& gradOutput_, const char* arg_name) {
int64_t ndim = gradOutput_.ndimension();
for (const auto i : c10::irange(1, ndim)) {
TORCH_CHECK(gradOutput_.size(i) > 0,
arg_name, "(): Expected grad_output to have non-zero size for non-batch dimensions, "
"but grad_output has sizes ", gradOutput_.sizes(), " with dimension ", i,
" being empty");
}
}
} // namespace at::native

@ -0,0 +1,28 @@
#pragma once
#include <ATen/native/DispatchStub.h>
#include <ATen/core/ATen_fwd.h>
namespace at {
class Tensor;
namespace native {
using _amp_foreach_non_finite_check_and_unscale_cpu__fn = void (*)(
TensorList,
Tensor&,
const Tensor&);
using _amp_update_scale_cpu__fn = Tensor& (*)(
Tensor&,
Tensor&,
const Tensor&,
double,
double,
int64_t);
DECLARE_DISPATCH(_amp_foreach_non_finite_check_and_unscale_cpu__fn, _amp_foreach_non_finite_check_and_unscale_cpu_stub);
DECLARE_DISPATCH(_amp_update_scale_cpu__fn, _amp_update_scale_cpu_stub);
} // namespace native
} // namespace at

@ -0,0 +1,321 @@
#pragma once
#include <optional>
#include <c10/util/string_view.h>
#include <ATen/Config.h>
#include <ATen/native/DispatchStub.h>
// Forward declare TI
namespace at {
class Tensor;
struct TensorIterator;
namespace native {
enum class TransposeType;
}
}
namespace at::native {
enum class LapackLstsqDriverType : int64_t { Gels, Gelsd, Gelsy, Gelss};
#if AT_BUILD_WITH_LAPACK()
// Define per-batch functions to be used in the implementation of batched
// linear algebra operations
template <class scalar_t>
void lapackCholesky(char uplo, int n, scalar_t *a, int lda, int *info);
template <class scalar_t>
void lapackCholeskyInverse(char uplo, int n, scalar_t *a, int lda, int *info);
template <class scalar_t, class value_t=scalar_t>
void lapackEig(char jobvl, char jobvr, int n, scalar_t *a, int lda, scalar_t *w, scalar_t* vl, int ldvl, scalar_t *vr, int ldvr, scalar_t *work, int lwork, value_t *rwork, int *info);
template <class scalar_t>
void lapackGeqrf(int m, int n, scalar_t *a, int lda, scalar_t *tau, scalar_t *work, int lwork, int *info);
template <class scalar_t>
void lapackOrgqr(int m, int n, int k, scalar_t *a, int lda, scalar_t *tau, scalar_t *work, int lwork, int *info);
template <class scalar_t>
void lapackOrmqr(char side, char trans, int m, int n, int k, scalar_t *a, int lda, scalar_t *tau, scalar_t *c, int ldc, scalar_t *work, int lwork, int *info);
template <class scalar_t, class value_t = scalar_t>
void lapackSyevd(char jobz, char uplo, int n, scalar_t* a, int lda, value_t* w, scalar_t* work, int lwork, value_t* rwork, int lrwork, int* iwork, int liwork, int* info);
template <class scalar_t>
void lapackGels(char trans, int m, int n, int nrhs,
scalar_t *a, int lda, scalar_t *b, int ldb,
scalar_t *work, int lwork, int *info);
template <class scalar_t, class value_t = scalar_t>
void lapackGelsd(int m, int n, int nrhs,
scalar_t *a, int lda, scalar_t *b, int ldb,
value_t *s, value_t rcond, int *rank,
scalar_t* work, int lwork,
value_t *rwork, int* iwork, int *info);
template <class scalar_t, class value_t = scalar_t>
void lapackGelsy(int m, int n, int nrhs,
scalar_t *a, int lda, scalar_t *b, int ldb,
int *jpvt, value_t rcond, int *rank,
scalar_t *work, int lwork, value_t* rwork, int *info);
template <class scalar_t, class value_t = scalar_t>
void lapackGelss(int m, int n, int nrhs,
scalar_t *a, int lda, scalar_t *b, int ldb,
value_t *s, value_t rcond, int *rank,
scalar_t *work, int lwork,
value_t *rwork, int *info);
template <LapackLstsqDriverType, class scalar_t, class value_t = scalar_t>
struct lapackLstsq_impl;
template <class scalar_t, class value_t>
struct lapackLstsq_impl<LapackLstsqDriverType::Gels, scalar_t, value_t> {
static void call(
char trans, int m, int n, int nrhs,
scalar_t *a, int lda, scalar_t *b, int ldb,
scalar_t *work, int lwork, int *info, // Gels flavor
int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor
value_t *s, // Gelss flavor
int *iwork // Gelsd flavor
) {
lapackGels<scalar_t>(
trans, m, n, nrhs,
a, lda, b, ldb,
work, lwork, info);
}
};
template <class scalar_t, class value_t>
struct lapackLstsq_impl<LapackLstsqDriverType::Gelsy, scalar_t, value_t> {
static void call(
char trans, int m, int n, int nrhs,
scalar_t *a, int lda, scalar_t *b, int ldb,
scalar_t *work, int lwork, int *info, // Gels flavor
int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor
value_t *s, // Gelss flavor
int *iwork // Gelsd flavor
) {
lapackGelsy<scalar_t, value_t>(
m, n, nrhs,
a, lda, b, ldb,
jpvt, rcond, rank,
work, lwork, rwork, info);
}
};
template <class scalar_t, class value_t>
struct lapackLstsq_impl<LapackLstsqDriverType::Gelsd, scalar_t, value_t> {
static void call(
char trans, int m, int n, int nrhs,
scalar_t *a, int lda, scalar_t *b, int ldb,
scalar_t *work, int lwork, int *info, // Gels flavor
int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor
value_t *s, // Gelss flavor
int *iwork // Gelsd flavor
) {
lapackGelsd<scalar_t, value_t>(
m, n, nrhs,
a, lda, b, ldb,
s, rcond, rank,
work, lwork,
rwork, iwork, info);
}
};
template <class scalar_t, class value_t>
struct lapackLstsq_impl<LapackLstsqDriverType::Gelss, scalar_t, value_t> {
static void call(
char trans, int m, int n, int nrhs,
scalar_t *a, int lda, scalar_t *b, int ldb,
scalar_t *work, int lwork, int *info, // Gels flavor
int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor
value_t *s, // Gelss flavor
int *iwork // Gelsd flavor
) {
lapackGelss<scalar_t, value_t>(
m, n, nrhs,
a, lda, b, ldb,
s, rcond, rank,
work, lwork,
rwork, info);
}
};
template <LapackLstsqDriverType driver_type, class scalar_t, class value_t = scalar_t>
void lapackLstsq(
char trans, int m, int n, int nrhs,
scalar_t *a, int lda, scalar_t *b, int ldb,
scalar_t *work, int lwork, int *info, // Gels flavor
int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor
value_t *s, // Gelss flavor
int *iwork // Gelsd flavor
) {
lapackLstsq_impl<driver_type, scalar_t, value_t>::call(
trans, m, n, nrhs,
a, lda, b, ldb,
work, lwork, info,
jpvt, rcond, rank, rwork,
s,
iwork);
}
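// lapackLstsq takes the union of all four drivers' arguments and forwards only
// the ones the selected driver needs; the driver is picked at compile time via
// the non-type template parameter. A hedged sketch of a call routed to gelsd
// (buffers are assumed to be allocated per the usual LAPACK workspace rules):
//
//   int rank = 0, info = 0;
//   lapackLstsq<LapackLstsqDriverType::Gelsd, double>(
//       'N', m, n, nrhs,
//       a, lda, b, ldb,
//       work, lwork, &info,
//       /*jpvt=*/nullptr, rcond, &rank, /*rwork=*/nullptr,
//       s, iwork);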
template <class scalar_t>
void lapackLuSolve(char trans, int n, int nrhs, scalar_t *a, int lda, int *ipiv, scalar_t *b, int ldb, int *info);
template <class scalar_t>
void lapackLu(int m, int n, scalar_t *a, int lda, int *ipiv, int *info);
template <class scalar_t>
void lapackLdlHermitian(
char uplo,
int n,
scalar_t* a,
int lda,
int* ipiv,
scalar_t* work,
int lwork,
int* info);
template <class scalar_t>
void lapackLdlSymmetric(
char uplo,
int n,
scalar_t* a,
int lda,
int* ipiv,
scalar_t* work,
int lwork,
int* info);
template <class scalar_t>
void lapackLdlSolveHermitian(
char uplo,
int n,
int nrhs,
scalar_t* a,
int lda,
int* ipiv,
scalar_t* b,
int ldb,
int* info);
template <class scalar_t>
void lapackLdlSolveSymmetric(
char uplo,
int n,
int nrhs,
scalar_t* a,
int lda,
int* ipiv,
scalar_t* b,
int ldb,
int* info);
template<class scalar_t, class value_t=scalar_t>
void lapackSvd(char jobz, int m, int n, scalar_t *a, int lda, value_t *s, scalar_t *u, int ldu, scalar_t *vt, int ldvt, scalar_t *work, int lwork, value_t *rwork, int *iwork, int *info);
#endif
#if AT_BUILD_WITH_BLAS()
template <class scalar_t>
void blasTriangularSolve(char side, char uplo, char trans, char diag, int n, int nrhs, scalar_t* a, int lda, scalar_t* b, int ldb);
#endif
using cholesky_fn = void (*)(const Tensor& /*input*/, const Tensor& /*info*/, bool /*upper*/);
DECLARE_DISPATCH(cholesky_fn, cholesky_stub);
using cholesky_inverse_fn = Tensor& (*)(Tensor& /*result*/, Tensor& /*infos*/, bool /*upper*/);
DECLARE_DISPATCH(cholesky_inverse_fn, cholesky_inverse_stub);
using linalg_eig_fn = void (*)(Tensor& /*eigenvalues*/, Tensor& /*eigenvectors*/, Tensor& /*infos*/, const Tensor& /*input*/, bool /*compute_eigenvectors*/);
DECLARE_DISPATCH(linalg_eig_fn, linalg_eig_stub);
using geqrf_fn = void (*)(const Tensor& /*input*/, const Tensor& /*tau*/);
DECLARE_DISPATCH(geqrf_fn, geqrf_stub);
using orgqr_fn = Tensor& (*)(Tensor& /*result*/, const Tensor& /*tau*/);
DECLARE_DISPATCH(orgqr_fn, orgqr_stub);
using ormqr_fn = void (*)(const Tensor& /*input*/, const Tensor& /*tau*/, const Tensor& /*other*/, bool /*left*/, bool /*transpose*/);
DECLARE_DISPATCH(ormqr_fn, ormqr_stub);
using linalg_eigh_fn = void (*)(
const Tensor& /*eigenvalues*/,
const Tensor& /*eigenvectors*/,
const Tensor& /*infos*/,
bool /*upper*/,
bool /*compute_eigenvectors*/);
DECLARE_DISPATCH(linalg_eigh_fn, linalg_eigh_stub);
using lstsq_fn = void (*)(
const Tensor& /*a*/,
Tensor& /*b*/,
Tensor& /*rank*/,
Tensor& /*singular_values*/,
Tensor& /*infos*/,
double /*rcond*/,
std::string /*driver_name*/);
DECLARE_DISPATCH(lstsq_fn, lstsq_stub);
using triangular_solve_fn = void (*)(
const Tensor& /*A*/,
const Tensor& /*B*/,
bool /*left*/,
bool /*upper*/,
TransposeType /*transpose*/,
bool /*unitriangular*/);
DECLARE_DISPATCH(triangular_solve_fn, triangular_solve_stub);
using lu_factor_fn = void (*)(
const Tensor& /*input*/,
const Tensor& /*pivots*/,
const Tensor& /*infos*/,
bool /*compute_pivots*/);
DECLARE_DISPATCH(lu_factor_fn, lu_factor_stub);
using unpack_pivots_fn = void(*)(
TensorIterator& iter,
const int64_t dim_size,
const int64_t max_pivot);
DECLARE_DISPATCH(unpack_pivots_fn, unpack_pivots_stub);
using lu_solve_fn = void (*)(
const Tensor& /*LU*/,
const Tensor& /*pivots*/,
const Tensor& /*B*/,
TransposeType /*trans*/);
DECLARE_DISPATCH(lu_solve_fn, lu_solve_stub);
using ldl_factor_fn = void (*)(
const Tensor& /*LD*/,
const Tensor& /*pivots*/,
const Tensor& /*info*/,
bool /*upper*/,
bool /*hermitian*/);
DECLARE_DISPATCH(ldl_factor_fn, ldl_factor_stub);
using svd_fn = void (*)(
const Tensor& /*A*/,
const bool /*full_matrices*/,
const bool /*compute_uv*/,
const std::optional<c10::string_view>& /*driver*/,
const Tensor& /*U*/,
const Tensor& /*S*/,
const Tensor& /*Vh*/,
const Tensor& /*info*/);
DECLARE_DISPATCH(svd_fn, svd_stub);
using ldl_solve_fn = void (*)(
const Tensor& /*LD*/,
const Tensor& /*pivots*/,
const Tensor& /*result*/,
bool /*upper*/,
bool /*hermitian*/);
DECLARE_DISPATCH(ldl_solve_fn, ldl_solve_stub);
} // namespace at::native

@ -0,0 +1,119 @@
#pragma once
#include <ATen/core/TensorBase.h>
#include <ATen/native/DispatchStub.h>
#include <c10/core/Scalar.h>
#include <c10/util/TypeSafeSignMath.h>
namespace at {
struct TensorIterator;
struct TensorIteratorBase;
}
namespace at::native {
inline void alpha_check(const ScalarType dtype, const Scalar& alpha) {
TORCH_CHECK(! alpha.isBoolean() || dtype == ScalarType::Bool,
"Boolean alpha only supported for Boolean results.");
TORCH_CHECK(isFloatingType(dtype) || isComplexType(dtype)
|| alpha.isIntegral(true),
"For integral input tensors, argument alpha must not be a floating point number.");
TORCH_CHECK(isComplexType(dtype) || !alpha.isComplex(),
"For non-complex input tensors, argument alpha must not be a complex number.")
}
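// For example, at the Python level torch.add(a, b, alpha=2.5) fails this check
// when a and b are integer tensors, and a boolean alpha is only accepted when
// the result dtype is Bool.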
// Basic checking for all sub functions.
inline void sub_check(const TensorBase& self, const TensorBase& other) {
TORCH_CHECK(self.scalar_type() != kBool || other.scalar_type() != kBool,
"Subtraction, the `-` operator, with two bool tensors is not supported. "
"Use the `^` or `logical_xor()` operator instead.")
TORCH_CHECK(self.scalar_type() != kBool && other.scalar_type() != kBool,
"Subtraction, the `-` operator, with a bool tensor is not supported. "
"If you are trying to invert a mask, use the `~` or `logical_not()` operator instead.");
}
inline void sub_check(const TensorBase& self, const Scalar& scalar) {
TORCH_CHECK(self.scalar_type() != kBool || !scalar.isBoolean(),
"Subtraction, the `-` operator, with two bool tensors is not supported. "
"Use the `^` or `logical_xor()` operator instead.")
TORCH_CHECK(self.scalar_type() != kBool && !scalar.isBoolean(),
"Subtraction, the `-` operator, with a bool tensor is not supported. "
"If you are trying to invert a mask, use the `~` or `logical_not()` operator instead.");
}
using structured_binary_fn_alpha = void(*)(TensorIteratorBase&, const Scalar& alpha);
using structured_binary_fn_double = void(*)(TensorIteratorBase&, double);
using structured_binary_fn = void(*)(TensorIteratorBase&);
using binary_fn_alpha = void(*)(TensorIteratorBase&, const Scalar& alpha);
using binary_fn_double = void(*)(TensorIterator&, double);
using binary_fn = void(*)(TensorIterator&);
using binary_clamp_fn_alpha =
void(*)(TensorIterator&, const Scalar& alpha, const Scalar& min_val, const Scalar& max_val);
// NB: codegenned
DECLARE_DISPATCH(structured_binary_fn_alpha, add_stub);
DECLARE_DISPATCH(binary_clamp_fn_alpha, add_clamp_stub);
DECLARE_DISPATCH(structured_binary_fn_alpha, sub_stub);
DECLARE_DISPATCH(structured_binary_fn, mul_stub);
DECLARE_DISPATCH(structured_binary_fn, div_true_stub);
DECLARE_DISPATCH(structured_binary_fn, div_floor_stub);
DECLARE_DISPATCH(structured_binary_fn, div_trunc_stub);
DECLARE_DISPATCH(structured_binary_fn, atan2_stub);
DECLARE_DISPATCH(structured_binary_fn, remainder_stub);
DECLARE_DISPATCH(structured_binary_fn, bitwise_and_stub);
DECLARE_DISPATCH(structured_binary_fn, bitwise_or_stub);
DECLARE_DISPATCH(structured_binary_fn, bitwise_xor_stub);
DECLARE_DISPATCH(structured_binary_fn, lshift_stub);
DECLARE_DISPATCH(structured_binary_fn, rshift_stub);
DECLARE_DISPATCH(binary_fn, logical_xor_stub);
DECLARE_DISPATCH(binary_fn, logical_and_stub);
DECLARE_DISPATCH(binary_fn, logical_or_stub);
DECLARE_DISPATCH(structured_binary_fn, lt_stub);
DECLARE_DISPATCH(structured_binary_fn, le_stub);
DECLARE_DISPATCH(structured_binary_fn, gt_stub);
DECLARE_DISPATCH(structured_binary_fn, ge_stub);
DECLARE_DISPATCH(structured_binary_fn, eq_stub);
DECLARE_DISPATCH(structured_binary_fn, ne_stub);
DECLARE_DISPATCH(binary_fn, max_elementwise_stub);
DECLARE_DISPATCH(binary_fn, min_elementwise_stub);
DECLARE_DISPATCH(structured_binary_fn, maximum_stub);
DECLARE_DISPATCH(structured_binary_fn, minimum_stub);
DECLARE_DISPATCH(structured_binary_fn, fmax_stub);
DECLARE_DISPATCH(structured_binary_fn, fmin_stub);
DECLARE_DISPATCH(structured_binary_fn_double, smooth_l1_stub);
DECLARE_DISPATCH(binary_fn_double, huber_stub);
DECLARE_DISPATCH(structured_binary_fn, sigmoid_backward_stub);
DECLARE_DISPATCH(binary_fn_alpha, logit_backward_stub);
DECLARE_DISPATCH(structured_binary_fn, tanh_backward_stub);
DECLARE_DISPATCH(structured_binary_fn, mse_stub);
DECLARE_DISPATCH(structured_binary_fn, fmod_stub);
DECLARE_DISPATCH(structured_binary_fn, logaddexp_stub);
DECLARE_DISPATCH(structured_binary_fn, logaddexp2_stub);
DECLARE_DISPATCH(structured_binary_fn, gcd_stub);
DECLARE_DISPATCH(structured_binary_fn, lcm_stub);
DECLARE_DISPATCH(structured_binary_fn, hypot_stub);
DECLARE_DISPATCH(structured_binary_fn, igamma_stub);
DECLARE_DISPATCH(structured_binary_fn, igammac_stub);
DECLARE_DISPATCH(structured_binary_fn, nextafter_stub);
DECLARE_DISPATCH(structured_binary_fn, heaviside_stub);
DECLARE_DISPATCH(structured_binary_fn, copysign_stub);
DECLARE_DISPATCH(structured_binary_fn, xlogy_stub);
DECLARE_DISPATCH(structured_binary_fn, xlog1py_stub);
DECLARE_DISPATCH(structured_binary_fn, zeta_stub);
DECLARE_DISPATCH(structured_binary_fn, chebyshev_polynomial_t_stub);
DECLARE_DISPATCH(structured_binary_fn, chebyshev_polynomial_u_stub);
DECLARE_DISPATCH(structured_binary_fn, chebyshev_polynomial_v_stub);
DECLARE_DISPATCH(structured_binary_fn, chebyshev_polynomial_w_stub);
DECLARE_DISPATCH(structured_binary_fn, hermite_polynomial_h_stub);
DECLARE_DISPATCH(structured_binary_fn, hermite_polynomial_he_stub);
DECLARE_DISPATCH(structured_binary_fn, laguerre_polynomial_l_stub);
DECLARE_DISPATCH(structured_binary_fn, legendre_polynomial_p_stub);
DECLARE_DISPATCH(structured_binary_fn, shifted_chebyshev_polynomial_t_stub);
DECLARE_DISPATCH(structured_binary_fn, shifted_chebyshev_polynomial_u_stub);
DECLARE_DISPATCH(structured_binary_fn, shifted_chebyshev_polynomial_v_stub);
DECLARE_DISPATCH(structured_binary_fn, shifted_chebyshev_polynomial_w_stub);
} // namespace at::native

@ -0,0 +1,173 @@
#pragma once
#include <ATen/core/Tensor.h>
#include <ATen/native/TypeProperties.h>
#include <ATen/ScalarOps.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/result_type.h>
#endif
namespace at::native {
// The original values are given by raw_*. If an original value is not contiguous, a contiguous copy is made
// into the corresponding trimmed_* value. Additionally, if the dtypes of the boundary and input tensors do
// not match, they are converted to a common super type so comparisons are done between the same types.
// For any trimmed_* tensor, if its outgoing value matches its incoming value (typically null), the
// corresponding raw_* version should be used, since it was already contiguous and of the right type.
inline void searchsorted_maybe_trim_input_tensors(
Tensor& trimmed_input,
Tensor& trimmed_boundaries,
Tensor& trimmed_sorter,
const Tensor& raw_input,
const Tensor& raw_boundaries,
const Tensor& raw_sorter) {
bool in_is_contiguous = raw_input.is_contiguous();
bool bd_is_contiguous = raw_boundaries.is_contiguous();
bool sort_is_contiguous = raw_sorter.is_contiguous();
if (!in_is_contiguous) {
TORCH_WARN_ONCE("torch.searchsorted(): input value tensor is non-contiguous, this will lower the performance due "
"to extra data copy when converting non-contiguous tensor to contiguous, please use contiguous input value "
"tensor if possible. This message will only appear once per program.");
trimmed_input = raw_input.contiguous();
}
if (!bd_is_contiguous) {
TORCH_WARN_ONCE("torch.searchsorted(): boundary tensor is non-contiguous, this will lower the performance due "
"to extra data copy when converting non-contiguous tensor to contiguous, please use contiguous boundary "
"tensor if possible. This message will only appear once per program.");
trimmed_boundaries = raw_boundaries.contiguous();
}
if (!sort_is_contiguous) {
TORCH_WARN_ONCE("torch.searchsorted(): sorter tensor is non-contiguous, this will lower the performance due "
"to extra data copy when converting non-contiguous tensor to contiguous, please use contiguous sorter "
"tensor if possible. This message will only appear once per program.");
trimmed_sorter = raw_sorter.contiguous();
}
if (raw_input.dtype() != raw_boundaries.dtype()) {
at::native::ResultTypeState state = {};
state = at::native::update_result_type_state(raw_boundaries, state);
state = at::native::update_result_type_state(raw_input, state);
ScalarType common_stype = at::native::result_type(state);
TORCH_INTERNAL_ASSERT(common_stype != ScalarType::Undefined);
if (common_stype != raw_input.scalar_type()) {
trimmed_input = in_is_contiguous ? raw_input.to(common_stype) : trimmed_input.to(common_stype);
}
if (common_stype != raw_boundaries.scalar_type()) {
trimmed_boundaries = bd_is_contiguous ? raw_boundaries.to(common_stype) : trimmed_boundaries.to(common_stype);
}
}
}
/* unused but needed for internal jagged tensor class */
inline void searchsorted_maybe_trim_input_tensors(
Tensor& trimmed_input,
Tensor& trimmed_boundaries,
const Tensor& raw_input,
const Tensor& raw_boundaries) {
Tensor trimmed_sorter;
Tensor raw_sorter;
return searchsorted_maybe_trim_input_tensors(
trimmed_input,
trimmed_boundaries,
trimmed_sorter,
raw_input,
raw_boundaries,
raw_sorter);
}
inline bool searchsorted_dims_matched_before_last_dim(const Tensor& boundaries, const Tensor& input) {
if (boundaries.dim() != input.dim()) {
return false;
}
const auto& dims_bd = boundaries.sizes();
const auto& dims_in = input.sizes();
for (int64_t dim = 0; dim + 1 < boundaries.dim(); ++dim) {
if (dims_bd[dim] != dims_in[dim]) {
return false;
}
}
return true;
}
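// For example, boundaries of sizes (4, 10) and input of sizes (4, 7) match
// (every dimension before the last agrees), so row i of input is searched
// within row i of boundaries; sizes (4, 10) and (5, 7) do not match.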
inline Tensor searchsorted_scalar_tensor(const Scalar& scalar, const c10::Device& device) {
auto tensor = c10::scalar_to_tensor(scalar, device);
// This is to adopt the scalar promotion rules defined in native/TypeProperties.h
// So we have the same type promotion rules as binary operations.
tensor.unsafeGetTensorImpl()->set_wrapped_number(true);
return tensor;
}
inline void searchsorted_pre_check(
const Tensor& boundaries,
const Tensor& input,
const Tensor& output,
const bool out_int32,
const bool right,
const std::optional<c10::string_view> side_opt,
const Tensor& sorter) {
if (side_opt) {
const c10::string_view side = *side_opt;
TORCH_CHECK(side == "left" || side == "right", "torch.searchsorted(): side can only be 'left' or 'right' but ",
"got ", side);
// assume the user has not explicitly set (right=False, side="right")
TORCH_CHECK(!right || side == "right", "torch.searchsorted(): side and right can't be set to opposites, got side "
"of ", side, " while right was True");
}
TORCH_CHECK(boundaries.device() == input.device(), "torch.searchsorted(): boundaries and input value tensors ",
"should have same device type, but got boundaries tensor device type ", boundaries.device(), " and input value ",
"tensor device type ", input.device());
if (sorter.defined()) {
TORCH_CHECK(sorter.device() == boundaries.device(), "torch.searchsorted(): sorter and boundary tensors should ",
"have same device type, but got sorter tensor device type ", sorter.device(), " and input value tensor ",
"device type ", boundaries.device());
TORCH_CHECK(sorter.sizes() == boundaries.sizes(), "torch.searchsorted(): boundary and sorter must have the same "
"size, but got boundary tensor ", boundaries.sizes(), " and got sorter tensor ", sorter.sizes());
TORCH_CHECK(sorter.scalar_type() == ScalarType::Long, "torch.searchsorted(): sorter must be a tensor of long ",
"dtype but got dtype ", sorter.scalar_type());
if (sorter.numel() > 0) {
auto minmax = sorter.aminmax();
int64_t vmin = std::get<0>(minmax).item().toLong();
int64_t vmax = std::get<1>(minmax).item().toLong();
TORCH_CHECK(vmin >= 0 && vmax < sorter.sizes().back(), "torch.searchsorted(): sorter index out of range");
}
}
TORCH_CHECK(input.dim() > 0 || (input.dim() == 0 && input.numel() == 1 && boundaries.dim() == 1),
"torch.searchsorted(): input value can be a scalar only when boundaries tensor dimension is 1, but we got ",
"boundaries tensor dim(", boundaries.dim(), ") and input value's dim(", input.dim(), ") numel(",
input.numel(), ")");
TORCH_CHECK(boundaries.dim() != 0, "torch.searchsorted(): boundaries tensor should have positive dimension, but ",
"got 0 dimension");
TORCH_CHECK(boundaries.dim() == 1 || searchsorted_dims_matched_before_last_dim(boundaries, input),
"torch.searchsorted(): boundaries tensor should be 1 dimension or the first N-1 dimensions of boundaries tensor ",
"and input value tensor must match, but we got boundaries tensor ", boundaries.sizes(), " and input value tensor ",
input.sizes());
ScalarType output_dtype = output.scalar_type();
TORCH_CHECK(
(output_dtype == ScalarType::Long && !out_int32) ||
(output_dtype == ScalarType::Int && out_int32),
"torch.searchsorted(): output tensor's dtype is wrong, it can only be Int(int32) or Long(int64) depending on ",
"whether out_int32 flag is True, but we got output tensor's dtype ", output_dtype,
" and out_int32 flag is ", (out_int32 ? "True" : "False"));
if (out_int32) {
TORCH_CHECK(boundaries.sizes().back() < INT_MAX,
"torch.searchsorted(): the size of boundaries' last dimension should be less than ", INT_MAX, ", but we got ",
boundaries.sizes().back());
}
}
} // namespace at::native

@ -0,0 +1,226 @@
#pragma once
#include <ATen/OpMathType.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/TransposeType.h>
#include <c10/util/complex.h>
#include <c10/core/ScalarType.h>
#include <c10/core/Scalar.h>
namespace at::native::cpublas {
namespace internal {
void normalize_last_dims(
TransposeType transa, TransposeType transb,
int64_t m, int64_t n, int64_t k,
int64_t *lda, int64_t *ldb, int64_t *ldc);
} // namespace internal
using gemm_fn = void(*)(
at::ScalarType type,
TransposeType transa, TransposeType transb,
int64_t m, int64_t n, int64_t k,
const Scalar& alpha,
const void *a, int64_t lda,
const void *b, int64_t ldb,
const Scalar& beta,
void *c, int64_t ldc);
DECLARE_DISPATCH(gemm_fn, gemm_stub);
template <typename scalar_t>
void gemm(
TransposeType transa, TransposeType transb,
int64_t m, int64_t n, int64_t k,
at::opmath_type<scalar_t> alpha,
const scalar_t *a, int64_t lda,
const scalar_t *b, int64_t ldb,
at::opmath_type<scalar_t> beta,
scalar_t *c, int64_t ldc) {
internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
gemm_stub(
kCPU, c10::CppTypeToScalarType<scalar_t>::value,
transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}
void gemm(
TransposeType transa, TransposeType transb,
int64_t m, int64_t n, int64_t k,
double alpha,
const double *a, int64_t lda,
const double *b, int64_t ldb,
double beta,
double *c, int64_t ldc);
void gemm(
TransposeType transa, TransposeType transb,
int64_t m, int64_t n, int64_t k,
float alpha,
const float *a, int64_t lda,
const float *b, int64_t ldb,
float beta,
float *c, int64_t ldc);
void gemm(
TransposeType transa, TransposeType transb,
int64_t m, int64_t n, int64_t k,
float alpha,
const at::BFloat16 *a, int64_t lda,
const at::BFloat16 *b, int64_t ldb,
float beta,
at::BFloat16 *c, int64_t ldc);
void gemm(
TransposeType transa, TransposeType transb,
int64_t m, int64_t n, int64_t k,
const float alpha,
const at::BFloat16 *a, int64_t lda,
const at::BFloat16 *b, int64_t ldb,
const float beta,
float *c, int64_t ldc);
void gemm(
TransposeType transa, TransposeType transb,
int64_t m, int64_t n, int64_t k,
float alpha,
const at::Half *a, int64_t lda,
const at::Half *b, int64_t ldb,
float beta,
at::Half *c, int64_t ldc);
void gemm(
TransposeType transa, TransposeType transb,
int64_t m, int64_t n, int64_t k,
const float alpha,
const at::Half *a, int64_t lda,
const at::Half *b, int64_t ldb,
const float beta,
float *c, int64_t ldc);
void gemm(
TransposeType transa, TransposeType transb,
int64_t m, int64_t n, int64_t k,
c10::complex<double> alpha,
const c10::complex<double> *a, int64_t lda,
const c10::complex<double> *b, int64_t ldb,
c10::complex<double> beta,
c10::complex<double> *c, int64_t ldc);
void gemm(
TransposeType transa, TransposeType transb,
int64_t m, int64_t n, int64_t k,
c10::complex<float> alpha,
const c10::complex<float> *a, int64_t lda,
const c10::complex<float> *b, int64_t ldb,
c10::complex<float> beta,
c10::complex<float> *c, int64_t ldc);
void gemm(
TransposeType transa, TransposeType transb,
int64_t m, int64_t n, int64_t k,
int64_t alpha,
const int64_t *a, int64_t lda,
const int64_t *b, int64_t ldb,
int64_t beta,
int64_t *c, int64_t ldc);
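// All of the overloads above compute C = alpha * op(A) * op(B) + beta * C with
// op() given by transa/transb, following the usual BLAS column-major
// convention (assumed here for illustration), where lda/ldb/ldc are leading
// dimensions. A hedged reference loop for the NoTranspose/NoTranspose case:
//
//   // C is m x n, A is m x k, B is k x n, all column-major.
//   for (int64_t j = 0; j < n; ++j) {
//     for (int64_t i = 0; i < m; ++i) {
//       double acc = 0;
//       for (int64_t l = 0; l < k; ++l) {
//         acc += a[i + l * lda] * b[l + j * ldb];
//       }
//       c[i + j * ldc] = alpha * acc + beta * c[i + j * ldc];
//     }
//   }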
template <typename scalar_t>
void gemm_batched(
TransposeType transa, TransposeType transb,
int64_t batch_size, int64_t m, int64_t n, int64_t k,
scalar_t alpha,
const scalar_t * const *a, int64_t lda,
const scalar_t * const *b, int64_t ldb,
const scalar_t beta,
scalar_t * const *c, int64_t ldc);
template <typename scalar_t>
void gemm_batched_with_stride(
TransposeType transa, TransposeType transb,
int64_t batch_size, int64_t m, int64_t n, int64_t k,
scalar_t alpha,
const scalar_t *a, int64_t lda, int64_t batch_stride_a,
const scalar_t *b, int64_t ldb, int64_t batch_stride_b,
scalar_t beta,
scalar_t *c, int64_t ldc, int64_t batch_stride_c);
using axpy_fn = void(*)(at::ScalarType type, int64_t n, const Scalar& a, const void *x, int64_t incx, void *y, int64_t incy);
DECLARE_DISPATCH(axpy_fn, axpy_stub);
template<typename scalar_t>
void axpy(int64_t n, scalar_t a, const scalar_t *x, int64_t incx, scalar_t *y, int64_t incy){
if(n == 1)
{
incx = 1;
incy = 1;
}
axpy_stub(
kCPU, c10::CppTypeToScalarType<scalar_t>::value,
n, a, x, incx, y, incy);
}
void axpy(int64_t n, double a, const double *x, int64_t incx, double *y, int64_t incy);
void axpy(int64_t n, float a, const float *x, int64_t incx, float *y, int64_t incy);
void axpy(int64_t n, c10::complex<double> a, const c10::complex<double> *x, int64_t incx, c10::complex<double> *y, int64_t incy);
void axpy(int64_t n, c10::complex<float> a, const c10::complex<float> *x, int64_t incx, c10::complex<float> *y, int64_t incy);
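// axpy computes y[i * incy] = a * x[i * incx] + y[i * incy] for i in [0, n);
// e.g. axpy(3, 2.0f, x, 1, y, 1) adds 2 * x into y element-wise.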
using copy_fn = void(*)(at::ScalarType type, int64_t n, const void *x, int64_t incx, void *y, int64_t incy);
DECLARE_DISPATCH(copy_fn, copy_stub);
template<typename scalar_t>
void copy(int64_t n, const scalar_t *x, int64_t incx, scalar_t *y, int64_t incy) {
if(n == 1)
{
incx = 1;
incy = 1;
}
copy_stub(
kCPU, c10::CppTypeToScalarType<scalar_t>::value,
n, x, incx, y, incy);
}
void copy(int64_t n, const double *x, int64_t incx, double *y, int64_t incy);
void copy(int64_t n, const float *x, int64_t incx, float *y, int64_t incy);
void copy(int64_t n, const c10::complex<double> *x, int64_t incx, c10::complex<double> *y, int64_t incy);
void copy(int64_t n, const c10::complex<float> *x, int64_t incx, c10::complex<float> *y, int64_t incy);
// Batch-reduce GEMM
// Operates by the following formula:
// C = alpha * SUM(A[i] x B[i]) + beta * C, i = 0 to batch size
// A Base pointer to a tensor A.
// B Base pointer to a tensor B.
// C Pointer to a tensor C (accumulation buffer).
TORCH_API void brgemm(
int64_t M,
int64_t N,
int64_t K,
int64_t ld_a,
int64_t ld_b,
int64_t ld_c,
const float alpha,
const float beta,
const at::Half* A,
const at::Half* B,
float* C);
// Release brgemm hardware context
void brgemm_release();
// Pack B matrix to get better performance if needed
void pack(
int64_t K,
int64_t N,
int64_t ld_in,
int64_t ld_out,
ScalarType dt_in,
ScalarType dt_out,
const void* in,
void* out);
// Whether packing is needed on this platform.
bool need_pack(ScalarType dt_in);
} // namespace at::native::cpublas

@ -0,0 +1,46 @@
#pragma once
#include <ATen/core/ivalue.h>
#include <ATen/core/stack.h>
#include <ATen/core/boxing/KernelFunction.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <c10/util/Metaprogramming.h>
#include <torch/library.h>
namespace at::native {
// This function implements a boxed fallback to CPU.
// External backends can add their own custom logging on top of it to customize their own CPU fallbacks.
TORCH_API void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool error_on_views = false,
c10::DispatchKey cpu_dispatch_key = c10::DispatchKey::CPU);
// This is a helper function that backends can use to directly call their boxed CPU fallback
// TODO: update and add a usage example after https://github.com/pytorch/pytorch/pull/58092 lands.
template<c10::KernelFunction::BoxedKernelFunction* fallback_fn, class Op, bool symint, class ReturnType, class... ParameterTypes>
struct _call_fallback_fn final {};
template<c10::KernelFunction::BoxedKernelFunction* fallback_fn, class Op, bool symint, class ReturnType, class... ParameterTypes>
struct _call_fallback_fn<fallback_fn, Op, symint, ReturnType(ParameterTypes...)> final {
static ReturnType call(typename c10::maybe_keep_symint<symint, ParameterTypes>::type... args) {
auto op = c10::Dispatcher::singleton()
// TODO: figure out how to make compiler happy without dynamic casts
.findSchemaOrThrow((const char*) Op::name, (const char*) Op::overload_name)
//.findSchemaOrThrow("a", "b")
.typed<ReturnType (typename c10::maybe_keep_symint<symint, ParameterTypes>::type...)>();
return c10::impl::BoxedKernelWrapper<ReturnType (typename c10::maybe_keep_symint<symint, ParameterTypes>::type...)>::call(
c10::BoxedKernel::makeFromFunction<fallback_fn>(),
op,
c10::DispatchKeySet(), // we know that the cpu_fallback doesn't use the dispatch keyset.
// TODO: get std::forward<> to work
args...
);
}
};
template<c10::KernelFunction::BoxedKernelFunction* fallback_fn, class Op>
using call_fallback_fn_symint = _call_fallback_fn<fallback_fn, Op, true, typename Op::schema>;
template<c10::KernelFunction::BoxedKernelFunction* fallback_fn, class Op>
using call_fallback_fn = _call_fallback_fn<fallback_fn, Op, false, typename Op::schema>;
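// A hedged usage sketch (my_backend_fallback and my_abs are hypothetical
// names; ATEN_OP is the operator-handle helper macro defined alongside the
// dispatcher): an external backend can route a single op through its boxed
// CPU fallback like this:
//
//   void my_backend_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack);
//
//   at::Tensor my_abs(const at::Tensor& self) {
//     return at::native::call_fallback_fn<&my_backend_fallback, ATEN_OP(abs)>::call(self);
//   }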
} // namespace at::native

@ -0,0 +1,13 @@
#pragma once
#include <c10/macros/Export.h>
#include <limits>
namespace at {
class TensorBase;
}
namespace at::native {
TORCH_API bool canUse32BitIndexMath(const at::TensorBase &t, int64_t max_elem=std::numeric_limits<int32_t>::max());
}

@ -0,0 +1,97 @@
#pragma once
#include <ATen/core/Tensor.h>
#include <c10/util/irange.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/view_as_real_native.h>
#include <ATen/ops/view_as_complex_native.h>
#include <utility>
#endif
// WARNING: this header contains non-inline functions and should be only
// included from ONE cpp file
namespace at::native {
// View tensor with new dtype, storage offset, sizes and strides
inline Tensor view_tensor(
const Tensor &tensor, ScalarType dtype,
c10::SymInt offset, SymIntArrayRef sizes, SymIntArrayRef strides) {
Storage storage = tensor.storage();
auto key_set = tensor.key_set().remove(DispatchKey::Conjugate);
auto new_tensor = detail::make_tensor<TensorImpl>(
c10::TensorImpl::VIEW, std::move(storage), key_set, scalarTypeToTypeMeta(dtype));
auto * impl = new_tensor.unsafeGetTensorImpl();
impl->set_sizes_and_strides(sizes, strides, offset);
return new_tensor;
}
inline SymDimVector computeStrideForViewAsReal(SymIntArrayRef oldstride) {
SymDimVector res(oldstride.size() + 1);
for (const auto i : c10::irange(oldstride.size())) {
res[i] = oldstride[i] * 2;
}
res.back() = 1;
return res;
}
inline Tensor _view_as_real_physical(const Tensor& self) {
TORCH_CHECK(self.is_complex(), "view_as_real is only supported for complex tensors");
auto old_sizes = self.sym_sizes();
SymDimVector new_sizes(old_sizes.size() + 1);
std::copy(old_sizes.begin(), old_sizes.end(), new_sizes.begin());
// last dimension will always have two elements containing the real and imag vals
new_sizes.back() = 2;
auto new_strides = computeStrideForViewAsReal(self.sym_strides());
auto new_storage_offset = self.sym_storage_offset() * 2;
const auto float_type = c10::toRealValueType(self.scalar_type());
auto real_tensor = view_tensor(self, float_type, std::move(new_storage_offset), new_sizes, new_strides);
return real_tensor;
}
// Expects a complex tensor as input and returns a view with the
// corresponding real dtype, where a new last dimension of size 2 holds
// the real and imaginary parts of each complex value.
Tensor view_as_real(const Tensor& self) {
TORCH_CHECK(!self.is_conj(), "view_as_real doesn't work on unresolved conjugated tensors. To resolve the conjugate tensor so you can view it as real, use self.resolve_conj(); however, be warned that the resulting tensor will NOT alias the original.");
return _view_as_real_physical(self);
}
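// For example, a c10::complex<float> tensor with sizes (2, 3), strides (3, 1)
// and storage offset 4 is viewed as a float tensor with sizes (2, 3, 2),
// strides (6, 2, 1) and storage offset 8; the new last dimension holds the
// (real, imag) pairs and the view aliases the original storage.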
inline SymDimVector computeStrideForViewAsComplex(SymIntArrayRef oldstride) {
const auto dim = oldstride.size();
TORCH_CHECK(dim > 0 && oldstride[dim - 1] == 1, "Tensor must have a last dimension with stride 1");
SymDimVector res(dim - 1);
for (const auto i : c10::irange(res.size())) {
TORCH_CHECK(oldstride[i] % 2 == 0, "Tensor must have a stride divisible by 2 for all but last dimension");
res[i] = oldstride[i] / 2;
}
return res;
}
// expects as input a float or double tensor with last dimension of size 2
// and returns back a tensor with corresponding complex dtype
Tensor view_as_complex(const Tensor& self) {
TORCH_CHECK(
self.scalar_type() == kFloat || self.scalar_type() == kDouble || self.scalar_type() == kHalf,
"view_as_complex is only supported for half, float and double tensors, but got a tensor of scalar type: ", self.scalar_type());
auto old_sizes = self.sym_sizes();
TORCH_CHECK(!old_sizes.empty(), "Input tensor must have one or more dimensions");
TORCH_CHECK(old_sizes[old_sizes.size()-1] == 2, "Tensor must have a last dimension of size 2");
SymDimVector new_sizes(old_sizes.begin(), old_sizes.end() - 1);
const auto new_strides = computeStrideForViewAsComplex(self.sym_strides());
const auto complex_type = c10::toComplexType(self.scalar_type());
TORCH_CHECK(self.sym_storage_offset() % 2 == 0, "Tensor must have a storage_offset divisible by 2");
const auto new_storage_offset = self.sym_storage_offset() / 2;
return view_tensor(self, complex_type, new_storage_offset, new_sizes, new_strides);
}
} // namespace at::native

@ -0,0 +1,34 @@
#pragma once
#include <ATen/native/CompositeRandomAccessorCommon.h>
namespace at::native {
struct TupleInfoCPU {
template <typename ...Types>
using tuple = std::tuple<Types...>;
template <typename ...Types>
static constexpr auto tie(Types&... args) noexcept {
return std::tie(args...);
}
};
template <typename KeyAccessor, typename ValueAccessor>
using CompositeRandomAccessorCPU =
CompositeRandomAccessor<KeyAccessor, ValueAccessor, TupleInfoCPU>;
template <typename Values, typename References>
void swap(
references_holder<Values, References> rh1,
references_holder<Values, References> rh2
) {
return std::swap(rh1.data(), rh2.data());
}
template <int N, typename Values, typename References>
auto get(references_holder<Values, References> rh) -> decltype(std::get<N>(rh.data())) {
return std::get<N>(rh.data());
}
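// A hedged sketch of the intended use (requires <algorithm> and <vector>):
// treat two parallel buffers as one sequence of (key, value) pairs so that
// both get permuted together, e.g. sorting values by key:
//
//   std::vector<int64_t> keys{3, 1, 2};
//   std::vector<float> vals{30.f, 10.f, 20.f};
//   CompositeRandomAccessorCPU<int64_t*, float*> it(keys.data(), vals.data());
//   std::sort(it, it + 3, [](auto lhs, auto rhs) {
//     return std::get<0>(lhs.data()) < std::get<0>(rhs.data());
//   });
//   // keys is now {1, 2, 3} and vals is {10.f, 20.f, 30.f}.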
} // namespace at::native

@ -0,0 +1,263 @@
#pragma once
#include <utility>
namespace at::native {
namespace {
// operator_brackets_proxy is used in
// CompositeRandomAccessor in place of operator[].
// For some iterators, references returned by operator[]
// could become invalid, operator_brackets_proxy tries to
// resolve that by making accessor[n] to be equivalent to
// *(accessor + n).
template <typename Accessor>
class operator_brackets_proxy {
using reference = typename std::iterator_traits<Accessor>::reference;
using value_type = typename std::iterator_traits<Accessor>::value_type;
public:
C10_HOST_DEVICE
operator_brackets_proxy(Accessor const& accessor)
: accessor(accessor)
{}
C10_HOST_DEVICE
operator reference() {
return *accessor;
}
C10_HOST_DEVICE
reference operator*() {
return *accessor;
}
C10_HOST_DEVICE
operator_brackets_proxy& operator=(value_type const& val) {
*accessor = val;
return *this;
}
private:
Accessor accessor;
};
}
// references_holder is used as a surrogate for the
// references type from std::iterator_traits in CompositeRandomAccessor.
// It is assumed in CompositeRandomAccessor that
// References = tuple<Types&...>,
// Values = tuple<Types...> by default,
// but they could be anything as long as References could be
// cast to Values.
// If you plan to use it with STL, for example, you will need to
// define `swap` and `get` (aka std::get) methods.
template <typename Values, typename References>
class references_holder {
public:
using values = Values;
using references = References;
C10_HOST_DEVICE
references_holder(references refs)
: refs{std::move(refs)}
{}
C10_HOST_DEVICE
operator references() {
return refs;
}
C10_HOST_DEVICE
operator values() {
return refs;
}
C10_HOST_DEVICE
references_holder& operator=(values vals) {
refs = vals;
return *this;
}
C10_HOST_DEVICE
references& data() {
return refs;
}
protected:
references refs;
};
// CompositeRandomAccessor is essentially a simplified version of
// a random access iterator over two random access iterators.
// TupleInfo should contain a variadic type `tuple`, and a method `tie`,
// which constructs a tuple of references from a variadic list of arguments.
template <typename KeyAccessor, typename ValueAccessor, typename TupleInfo>
class CompositeRandomAccessor {
using self_type = CompositeRandomAccessor<KeyAccessor, ValueAccessor, TupleInfo>;
using key_accessor_value_type =
typename std::iterator_traits<KeyAccessor>::value_type;
using value_accessor_value_type =
typename std::iterator_traits<ValueAccessor>::value_type;
using key_accessor_reference_type =
typename std::iterator_traits<KeyAccessor>::reference;
using value_accessor_reference_type =
typename std::iterator_traits<ValueAccessor>::reference;
using composite_value_type = typename TupleInfo::template tuple<
key_accessor_value_type,
value_accessor_value_type>;
using composite_reference = typename TupleInfo::template tuple<
key_accessor_reference_type,
value_accessor_reference_type>;
public:
using value_type = composite_value_type;
using reference = references_holder<composite_value_type, composite_reference>;
// Note that CompositeRandomAccessor does not hold key and values
// in a specific datastructure, which means that a pointer to a (key, value)
// is not defined. Hence we just use a pointer type of the KeyAccessor.
using pointer = typename std::iterator_traits<KeyAccessor>::pointer;
using difference_type = typename std::iterator_traits<KeyAccessor>::difference_type;
using iterator_category = std::random_access_iterator_tag;
C10_HOST_DEVICE
CompositeRandomAccessor() = default;
C10_HOST_DEVICE
CompositeRandomAccessor(KeyAccessor keys, ValueAccessor values)
: keys(keys), values(values)
{}
// Pointer-like operations {
C10_HOST_DEVICE
reference operator*() const {
return TupleInfo::tie(*keys, *values);
}
// operator->() is supposed to return a pointer type.
// Since CompositeRandomAccessor does not hold pointers to pairs,
// we just return a pointer to a key.
C10_HOST_DEVICE
auto* operator->() const {
return keys.operator->();
}
C10_HOST_DEVICE
reference operator[](difference_type idx) {
return operator_brackets_proxy<self_type>(
CompositeRandomAccessor(keys + idx, values + idx)
);
}
// }
// Prefix/postfix increment/decrement {
C10_HOST_DEVICE
CompositeRandomAccessor& operator++() {
++keys;
++values;
return *this;
}
C10_HOST_DEVICE
CompositeRandomAccessor operator++(int) {
CompositeRandomAccessor copy(*this);
++*this;
return copy;
}
C10_HOST_DEVICE
CompositeRandomAccessor& operator--() {
--keys;
--values;
return *this;
}
C10_HOST_DEVICE
CompositeRandomAccessor operator--(int) {
CompositeRandomAccessor copy(*this);
--*this;
return copy;
}
// }
// Arithmetic operations {
C10_HOST_DEVICE
CompositeRandomAccessor& operator+=(difference_type offset) {
keys += offset;
values += offset;
return *this;
}
C10_HOST_DEVICE
CompositeRandomAccessor operator+(difference_type offset) const {
return CompositeRandomAccessor(keys + offset, values + offset);
}
C10_HOST_DEVICE
friend CompositeRandomAccessor operator+(
difference_type offset,
const CompositeRandomAccessor& accessor
) {
return accessor + offset;
}
C10_HOST_DEVICE
CompositeRandomAccessor& operator-=(difference_type offset) {
keys -= offset;
values -= offset;
return *this;
}
C10_HOST_DEVICE
CompositeRandomAccessor operator-(difference_type offset) const {
return CompositeRandomAccessor(keys - offset, values - offset);
}
C10_HOST_DEVICE
difference_type operator-(const CompositeRandomAccessor& other) const {
return keys - other.keys;
}
// }
// Comparison operators {
C10_HOST_DEVICE
bool operator==(const CompositeRandomAccessor& other) const {
return keys == other.keys;
}
C10_HOST_DEVICE
bool operator!=(const CompositeRandomAccessor& other) const {
return keys != other.keys;
}
C10_HOST_DEVICE
bool operator<(const CompositeRandomAccessor& other) const {
return keys < other.keys;
}
C10_HOST_DEVICE
bool operator<=(const CompositeRandomAccessor& other) const {
return keys <= other.keys;
}
C10_HOST_DEVICE
bool operator>(const CompositeRandomAccessor& other) const {
return keys > other.keys;
}
C10_HOST_DEVICE
bool operator>=(const CompositeRandomAccessor& other) const {
return keys >= other.keys;
}
// }
protected:
KeyAccessor keys;
ValueAccessor values;
};
} // namespace at::native

@ -0,0 +1,449 @@
#pragma once
#include <ATen/core/Tensor.h>
#include <ATen/TensorUtils.h>
#include <ATen/detail/CUDAHooksInterface.h>
#include <ATen/native/DispatchStub.h>
#include <c10/util/env.h>
#include <c10/util/irange.h>
#include <utility>
namespace at::native {
using conv_depthwise2d_backward_fn = std::tuple<at::Tensor,at::Tensor>(*)(
const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
at::IntArrayRef, at::IntArrayRef, std::array<bool, 2>);
DECLARE_DISPATCH(conv_depthwise2d_backward_fn, conv_depthwise2d_backward_stub);
using conv_depthwise3d_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
at::IntArrayRef, at::IntArrayRef, std::array<bool, 3>);
DECLARE_DISPATCH(conv_depthwise3d_backward_fn, conv_depthwise3d_backward_stub);
using cudnn_convolution_backward_fn = std::tuple<at::Tensor,at::Tensor>(*)(
const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
at::IntArrayRef, int64_t, bool, bool, bool, std::array<bool,2>);
DECLARE_DISPATCH(cudnn_convolution_backward_fn, cudnn_convolution_backward_stub);
using mps_convolution_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
at::IntArrayRef, int64_t, std::array<bool,3>);
DECLARE_DISPATCH(mps_convolution_backward_fn, mps_convolution_backward_stub);
using cudnn_convolution_transpose_backward_fn = std::tuple<at::Tensor,at::Tensor>(*)(
const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
at::IntArrayRef, at::IntArrayRef, int64_t, bool, bool, bool, std::array<bool,2>);
DECLARE_DISPATCH(cudnn_convolution_transpose_backward_fn, cudnn_convolution_transpose_backward_stub);
using miopen_convolution_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
at::IntArrayRef, int64_t, bool, bool, std::array<bool,3>);
DECLARE_DISPATCH(miopen_convolution_backward_fn, miopen_convolution_backward_stub);
using miopen_convolution_transpose_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
at::IntArrayRef, at::IntArrayRef, int64_t, bool, bool, std::array<bool,3>);
DECLARE_DISPATCH(miopen_convolution_transpose_backward_fn, miopen_convolution_transpose_backward_stub);
using miopen_depthwise_convolution_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
at::IntArrayRef, int64_t, bool, bool, std::array<bool,3>);
DECLARE_DISPATCH(miopen_depthwise_convolution_backward_fn, miopen_depthwise_convolution_backward_stub);
using mkldnn_convolution_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
at::IntArrayRef, int64_t, std::array<bool,3>);
DECLARE_DISPATCH(mkldnn_convolution_backward_fn, mkldnn_convolution_backward_stub);
using mkldnn_convolution_transpose_fn = Tensor(*)(const Tensor&, const Tensor&, const std::optional<Tensor>&,
IntArrayRef, IntArrayRef, IntArrayRef, IntArrayRef, int64_t);
DECLARE_DISPATCH(mkldnn_convolution_transpose_fn, mkldnn_convolution_transpose_stub);
using mkldnn_convolution_transpose_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
at::IntArrayRef, at::IntArrayRef, int64_t, std::array<bool,3>);
DECLARE_DISPATCH(mkldnn_convolution_transpose_backward_fn, mkldnn_convolution_transpose_backward_stub);
using slow_conv_dilated2d_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
at::IntArrayRef, at::IntArrayRef, std::array<bool, 3>);
DECLARE_DISPATCH(slow_conv_dilated2d_backward_fn, slow_conv_dilated2d_backward_stub);
using slow_conv_dilated3d_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
at::IntArrayRef, at::IntArrayRef, std::array<bool, 3>);
DECLARE_DISPATCH(slow_conv_dilated3d_backward_fn, slow_conv_dilated3d_backward_stub);
using slow_conv_transpose2d_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
at::IntArrayRef, at::IntArrayRef, at::IntArrayRef, std::array<bool,3>);
DECLARE_DISPATCH(slow_conv_transpose2d_backward_fn, slow_conv_transpose2d_backward_stub);
using slow_conv_transpose3d_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
at::IntArrayRef, at::IntArrayRef, at::IntArrayRef, std::array<bool,3>);
DECLARE_DISPATCH(slow_conv_transpose3d_backward_fn, slow_conv_transpose3d_backward_stub);
namespace {
bool is_cudnnv8_heuristic_mode_b() {
static const bool is_cudnnv8_heuristic_mode_b = c10::utils::check_env("TORCH_CUDNN_USE_HEURISTIC_MODE_B") == true;
return is_cudnnv8_heuristic_mode_b;
}
}
inline bool cudnnv8_enabled_check_debug() {
static bool cudnnv8_flag = c10::utils::check_env("TORCH_CUDNN_V8_API_DISABLED") != true;
static bool cudnnv8_debug = c10::utils::check_env("TORCH_CUDNN_V8_API_DEBUG") == true;
static uint8_t cudnnv8_debugcount = 0;
if (cudnnv8_debug == 1 && cudnnv8_debugcount < 10) {
TORCH_WARN("TORCH_CUDNN_V8_DEBUG ON, V8 ON: ", cudnnv8_flag, " TORCH_CUDNN_USE_HEURISTIC_MODE B: ", is_cudnnv8_heuristic_mode_b());
cudnnv8_debugcount++;
}
return cudnnv8_flag == 1;
}
inline bool cudnnv8_use_heur_mode_b() {
return is_cudnnv8_heuristic_mode_b();
}
// Keep in sync with py::enum_ in Module.cpp
enum class ConvBackend {
CudaDepthwise2d,
CudaDepthwise3d,
Cudnn,
CudnnTranspose,
Empty,
Miopen,
MiopenDepthwise,
MiopenTranspose,
Mkldnn,
MkldnnTranspose,
MkldnnEmpty,
NnpackSpatial,
Overrideable,
Slow2d,
Slow3d,
SlowDilated2d,
SlowDilated3d,
SlowTranspose2d,
SlowTranspose3d,
Winograd3x3Depthwise,
Xnnpack2d,
Mps,
MpsTranspose,
};
// Overload for selecting the convolution backend from the full set of convolution inputs.
// This overload is exposed to python for testing, etc.
TORCH_API ConvBackend select_conv_backend(
const Tensor& input, const Tensor& weight, const std::optional<Tensor>& bias_opt,
SymIntArrayRef stride, SymIntArrayRef padding, SymIntArrayRef dilation,
bool transposed, SymIntArrayRef output_padding, c10::SymInt groups, const at::OptionalSymIntArrayRef bias_sizes_opt);
TORCH_API at::MemoryFormat _determine_backend_memory_format(const Tensor& input,
const Tensor& weight,
const ConvBackend backend);
// ---------------------------------------------------------------------
//
// Math
//
// ---------------------------------------------------------------------
constexpr int input_batch_size_dim = 0; // also grad_input
constexpr int input_channels_dim = 1;
constexpr int output_batch_size_dim = 0; // also grad_output
constexpr int output_channels_dim = 1;
constexpr int weight_output_channels_dim = 0;
constexpr int weight_input_channels_dim = 1;
// Often written as 2 + max_dim (extra dims for batch size and channels)
constexpr int max_dim = 3;
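// e.g. a 3-d convolution operates on 2 + 3 = 5-dimensional tensors laid out as
// [batch, channels, depth, height, width].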
// ---------------------------------------------------------------------
//
// Checking
//
// ---------------------------------------------------------------------
// Used on pad, stride and dilation
static void check_args(CheckedFrom c, IntArrayRef args, size_t expected_size, const char* arg_name)
{
TORCH_CHECK(args.size() <= expected_size,
"Too many ", arg_name, " values (", args.size(), ") supplied, expecting ",
expected_size, " (while checking arguments for ", c, ")");
TORCH_CHECK(args.size() >= expected_size,
"Not enough ", arg_name, " values (", args.size(), ") supplied, expecting ",
expected_size, " (while checking arguments for ", c, ")");
auto num_negative_values = std::count_if(args.begin(), args.end(), [](int64_t x){ return x < 0; });
if (num_negative_values > 0){
std::stringstream ss;
ss << arg_name << " should be non-negative but got (";
std::copy(args.begin(), args.end() - 1, std::ostream_iterator<int64_t>(ss, ", "));
ss << args.back() << ")" << " (while checking arguments for " << c << ")";
AT_ERROR(ss.str());
}
}
// NOTE [ Convolution checks ]
//
// NB: For many call sites, it is not strictly necessary to check all of
// these relationships (for example, for forward convolution, we compute
// the size of output ourselves, so we don't actually need to check
output). However, writing a single function that does everything
// means we get to reuse it for both forwards and all backwards
// variants, even when the set of "real" inputs varies. The magic of
// relational computing!
//
// (There is one downside, which is that it is slightly harder to write
// error messages which are able to distinguish between real inputs
// (which the user can change) and computed inputs (which the user can
// only indirectly affect). It would be an interesting exercise to
// come up with a general framework to handle such situations.)
inline void convolution_shape_check(
CheckedFrom c,
const TensorGeometryArg& input, const TensorGeometryArg& weight, const TensorGeometryArg& output,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups)
{
check_args(c, padding, input->dim() - 2, "padding");
check_args(c, stride, padding.size(), "stride");
check_args(c, dilation, padding.size(), "dilation");
// Input
checkDimRange(c, input, 3, 6 /* exclusive */);
checkSize_symint(c, input, input_channels_dim, weight->size(1) * groups);
// Weight
checkSameDim(c, input, weight);
// TODO: check that output->size() matches output_sizes
// TODO: check that weight matches output->sizes()
checkSameDim(c, input, output);
}
// NB: conv_output_size and conv_input_size are not bijections,
// as conv_output_size loses information; this is why conv_input_size
// takes an extra output_padding argument to resolve the ambiguity.
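// Worked example (illustrative numbers): for a spatial dim with input 7,
// kernel 3, padding 1, stride 2, dilation 1 the output size is
// (7 + 2*1 - 3) / 2 + 1 = 4. Input sizes 7 and 8 both map to output 4, so
// conv_input_size needs output_padding (0 or 1 here) to pick between them.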
template <typename T>
inline std::vector<T> _conv_output_size(
ArrayRef<T> input_size, ArrayRef<T> weight_size,
ArrayRef<T> padding, ArrayRef<T> stride, ArrayRef<T> dilation = ArrayRef<T>()
) {
// ASSERT(input_size.size() > 2)
// ASSERT(input_size.size() == weight_size.size())
bool has_dilation = !dilation.empty();
auto dim = input_size.size();
std::vector<T> output_size(dim);
output_size[0] = input_size[input_batch_size_dim];
output_size[1] = weight_size[weight_output_channels_dim];
for (const auto d : c10::irange(2, dim)) {
auto dilation_ = has_dilation ? dilation[d - 2] : 1;
auto kernel = dilation_ * (weight_size[d] - 1) + 1;
output_size[d] = (input_size[d] + (2 * padding[d - 2]) - kernel) / stride[d - 2] + 1;
}
return output_size;
}
inline std::vector<int64_t> conv_output_size(
IntArrayRef input_size, IntArrayRef weight_size,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation = IntArrayRef()
) {
return _conv_output_size(input_size, weight_size, padding, stride, dilation);
}
inline std::vector<c10::SymInt> conv_output_size(
SymIntArrayRef input_size, SymIntArrayRef weight_size,
SymIntArrayRef padding, SymIntArrayRef stride, SymIntArrayRef dilation = SymIntArrayRef()
) {
return _conv_output_size(input_size, weight_size, padding, stride, dilation);
}
template <typename T>
std::vector<T> _conv_input_size(
ArrayRef<T> output_size, ArrayRef<T> weight_size,
ArrayRef<T> padding, ArrayRef<T> output_padding, ArrayRef<T> stride, ArrayRef<T> dilation, T groups
) {
// ASSERT(output_size.size() > 2)
// ASSERT(output_size.size() == weight_size.size())
auto dim = output_size.size();
std::vector<T> input_size(dim);
input_size[0] = output_size[output_batch_size_dim];
input_size[1] = weight_size[weight_input_channels_dim] * groups;
for (const auto d : c10::irange(2, dim)) {
auto kernel = (weight_size[d] - 1) * dilation[d - 2] + 1;
input_size[d] = (output_size[d] - 1) * stride[d - 2] - (padding[d - 2] * 2) +
kernel + output_padding[d - 2];
}
return input_size;
}
inline std::vector<c10::SymInt> conv_input_size(
SymIntArrayRef output_size, SymIntArrayRef weight_size,
SymIntArrayRef padding, SymIntArrayRef output_padding, SymIntArrayRef stride, SymIntArrayRef dilation, c10::SymInt groups
) {
return _conv_input_size(output_size, weight_size, padding, output_padding, stride, dilation, std::move(groups));
}
inline std::vector<int64_t> conv_input_size(
IntArrayRef output_size, IntArrayRef weight_size,
IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
) {
return _conv_input_size(output_size, weight_size, padding, output_padding, stride, dilation, groups);
}
template <typename T>
std::vector<T> _conv_weight_size(
ArrayRef<T> input_size, ArrayRef<T> output_size,
ArrayRef<T> padding, ArrayRef<T> output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
) {
auto dim = input_size.size();
std::vector<T> weight_size(dim);
weight_size[0] = output_size[1];
weight_size[1] = input_size[1] / groups;
for (const auto d : c10::irange(2, dim)) {
auto kernel = input_size[d] - (output_size[d] - 1) * stride[d - 2]
+ padding[d - 2] * 2 - output_padding[d - 2];
weight_size[d] = (kernel - 1) / dilation[d - 2] + 1;
}
return weight_size;
}
inline std::vector<c10::SymInt> conv_weight_size(
SymIntArrayRef input_size, SymIntArrayRef output_size,
SymIntArrayRef padding, SymIntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
) {
return _conv_weight_size(input_size, output_size, padding, output_padding, stride, dilation, groups);
}
inline std::vector<int64_t> conv_weight_size(
IntArrayRef input_size, IntArrayRef output_size,
IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
) {
return _conv_weight_size(input_size, output_size, padding, output_padding, stride, dilation, groups);
}
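// Reshape a 1-d bias so it broadcasts over a dim-dimensional convolution
// output, e.g. dim == 4 turns a bias of shape [C] into shape [1, C, 1, 1].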
inline Tensor reshape_bias(int64_t dim, const Tensor& bias) {
std::vector<int64_t> shape(dim, 1);
shape[1] = -1;
return bias.reshape(shape);
}
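// Suggest the memory format to use for a cuDNN convolution: channels-last
// only when PyTorch is compiled with cuDNN, the dtype is not double, cuDNN is
// new enough (>= 7603 for 4-d weights / ChannelsLast, >= 8005 for 5-d
// weights / ChannelsLast3d), and either the input or the weight already
// suggests a channels-last layout; otherwise Contiguous.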
inline at::MemoryFormat cudnn_conv_suggest_memory_format(const at::Tensor& input, const at::Tensor& weight) {
// disable NHWC for float64 input.
if (!at::detail::getCUDAHooks().compiledWithCuDNN() ||
input.scalar_type() == at::kDouble ||
weight.scalar_type() == at::kDouble) {
return at::MemoryFormat::Contiguous;
}
long cudnn_version = at::detail::getCUDAHooks().versionCuDNN();
auto input_memory_format = input.suggest_memory_format();
auto weight_memory_format = weight.suggest_memory_format();
auto weight_ndim = weight.ndimension();
bool can_use_cudnn_channels_last_2d = (cudnn_version >= 7603) && (weight_ndim == 4) && (
(input_memory_format == at::MemoryFormat::ChannelsLast) ||
(weight_memory_format == at::MemoryFormat::ChannelsLast)
);
if (can_use_cudnn_channels_last_2d) {
return at::MemoryFormat::ChannelsLast;
}
bool can_use_cudnn_channels_last_3d = (cudnn_version >= 8005) && (weight_ndim == 5) && (
(input_memory_format == at::MemoryFormat::ChannelsLast3d) ||
(weight_memory_format == at::MemoryFormat::ChannelsLast3d)
);
if (can_use_cudnn_channels_last_3d) {
return at::MemoryFormat::ChannelsLast3d;
}
return at::MemoryFormat::Contiguous;
}
// controls whether emptyCache will be called following cudnn conv benchmarking
TORCH_API void _cudnn_set_conv_benchmark_empty_cache(bool enable);
TORCH_API bool _cudnn_get_conv_benchmark_empty_cache();
inline bool miopen_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
// disable NHWC for float64 input.
if (!at::detail::getCUDAHooks().compiledWithMIOpen() ||
input.scalar_type() == at::kDouble ||
weight.scalar_type() == at::kDouble) {
return false;
}
bool can_use_miopen_channels_last_2d = false;
// TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC once ROCm officially supports NHWC in MIOpen
// See #64427
static std::optional<bool> PYTORCH_MIOPEN_SUGGEST_NHWC = c10::utils::check_env("PYTORCH_MIOPEN_SUGGEST_NHWC");
auto input_memory_format = input.suggest_memory_format();
auto weight_memory_format = weight.suggest_memory_format();
can_use_miopen_channels_last_2d = PYTORCH_MIOPEN_SUGGEST_NHWC && *PYTORCH_MIOPEN_SUGGEST_NHWC && (
( (input_memory_format == at::MemoryFormat::ChannelsLast) ||
(weight_memory_format == at::MemoryFormat::ChannelsLast) )
);
bool can_use_miopen_channels_last_3d = false;
return can_use_miopen_channels_last_2d || can_use_miopen_channels_last_3d;
}
inline bool mkldnn_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
// disable NHWC for float64 input.
if (input.scalar_type() == at::kDouble ||
weight.scalar_type() == at::kDouble) {
return false;
}
// disable NHWC for MkldnnCPU tensor.
if (input.is_mkldnn() || weight.is_mkldnn()) {
return false;
}
auto input_memory_format = input.suggest_memory_format();
auto weight_memory_format = weight.suggest_memory_format();
bool can_use_mkldnn_channels_last_2d =
(input_memory_format == at::MemoryFormat::ChannelsLast) ||
(weight_memory_format == at::MemoryFormat::ChannelsLast);
bool can_use_mkldnn_channels_last_3d =
(input_memory_format == at::MemoryFormat::ChannelsLast3d) ||
(weight_memory_format == at::MemoryFormat::ChannelsLast3d);
return can_use_mkldnn_channels_last_2d || can_use_mkldnn_channels_last_3d;
}
inline bool thnn_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
auto input_memory_format = input.suggest_memory_format();
auto weight_memory_format = weight.suggest_memory_format();
bool can_use_thnn_channels_last_2d = input.device().is_cpu() && (
(input_memory_format == at::MemoryFormat::ChannelsLast) || (
weight_memory_format == at::MemoryFormat::ChannelsLast));
return can_use_thnn_channels_last_2d;
}
inline bool xpu_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
// check layout only for xpu tensor.
if (!input.is_xpu() || !weight.is_xpu()) {
return false;
}
// disable NHWC for float64 input.
if (input.scalar_type() == at::kDouble ||
weight.scalar_type() == at::kDouble) {
return false;
}
auto input_memory_format = input.suggest_memory_format();
auto weight_memory_format = weight.suggest_memory_format();
bool can_use_xpu_channels_last_2d =
(input_memory_format == at::MemoryFormat::ChannelsLast) ||
(weight_memory_format == at::MemoryFormat::ChannelsLast);
bool can_use_xpu_channels_last_3d =
(input_memory_format == at::MemoryFormat::ChannelsLast3d) ||
(weight_memory_format == at::MemoryFormat::ChannelsLast3d);
return can_use_xpu_channels_last_2d || can_use_xpu_channels_last_3d;
}
} // namespace at::native

View File

@ -0,0 +1,14 @@
#include <ATen/core/Tensor.h>
namespace at::native {
std::tuple<Tensor, Tensor, Tensor> slow_conv3d_backward_cpu(
const Tensor& grad_output,
const Tensor& self,
const Tensor& weight,
IntArrayRef kernel_size,
IntArrayRef stride,
IntArrayRef padding,
std::array<bool, 3> output_mask);
} // namespace at::native

View File

@ -0,0 +1,20 @@
#pragma once
#include <ATen/native/DispatchStub.h>
namespace at {
class Tensor;
struct TensorIterator;
class TensorBase;
namespace native {
using copy_fn = void (*)(TensorIterator&, bool non_blocking);
DECLARE_DISPATCH(copy_fn, copy_stub);
TORCH_API void copy_ignoring_overlaps(const TensorBase &dst, const TensorBase &src);
} // namespace native
} // namespace at

View File

@ -0,0 +1,14 @@
#pragma once
#include <ATen/native/DispatchStub.h>
namespace at {
class Tensor;
namespace native {
using cross_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const int64_t d);
DECLARE_DISPATCH(cross_fn, cross_stub);
}} // namespace at::native

View File

@ -0,0 +1,229 @@
#pragma once
#include <algorithm>
#include <vector>
#include <ATen/div_rtn.h>
#include <ATen/core/Tensor.h>
#include <c10/util/irange.h>
#define TORCH_CHECK_DIM_SIZE(T, DIM, DIM_SIZE, SIZE) \
TORCH_CHECK( \
T.dim() == DIM && T.size(DIM_SIZE) == SIZE, \
"Need " #T " of dimension ", \
DIM, \
" and " #T ".size[", \
DIM_SIZE, \
"] == ", \
SIZE, \
" but got input to be of shape ", \
T.sizes())
namespace at::native::internal {
namespace {
inline bool all_positive(IntArrayRef& arr) {
return std::all_of(
arr.begin(), arr.end(), [](int64_t item) { return item > 0; });
}
inline bool all_nonnegative(std::vector<int64_t>& arr) {
return std::all_of(
arr.begin(), arr.end(), [](int64_t item) { return item >= 0; });
}
} // namespace
// calculate the trailing (spatial) part of the output tensor sizes
template <int64_t dim>
std::vector<int64_t> get_output_size(
const Tensor& input,
IntArrayRef kernel_size,
IntArrayRef stride_size,
IntArrayRef pad_size,
IntArrayRef dilation_size) {
std::vector<int64_t> sizes;
for (const auto index : c10::irange(dim)) {
sizes.push_back(
div_rtn<int64_t>(
input.size(index + input.dim() - dim) + 2 * pad_size[index] -
(dilation_size[index] * (kernel_size[index] - 1) + 1),
stride_size[index]) +
1);
}
return sizes;
}
// calculate the sizes of output tensor
template <int64_t dim>
std::vector<int64_t> get_output_size(
const Tensor& input,
const Tensor& weight,
IntArrayRef kernel_size,
IntArrayRef stride_size,
IntArrayRef pad_size,
IntArrayRef dilation_size) {
auto output_size = get_output_size<dim>(
input, kernel_size, stride_size, pad_size, dilation_size);
output_size.insert(output_size.begin(), weight.size(0));
if (input.dim() == dim + 2) {
output_size.insert(output_size.begin(), input.size(0));
}
return output_size;
}
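// e.g. (illustrative) for dim == 2 and a batched input of shape [N, C, H, W],
// the second overload returns {N, weight.size(0), H_out, W_out}, with H_out and
// W_out given by the (input + 2*pad - dilated_kernel) / stride + 1 computation above.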
/*
slow_conv_dilated_shape_check - check user-input to dilated convolution
forward and backward functions.
*/
template <int64_t dim>
void slow_conv_dilated_shape_check(
const Tensor& input,
const Tensor& weight,
const Tensor& bias,
const Tensor& grad_output,
IntArrayRef kernel_size,
IntArrayRef stride_size,
IntArrayRef pad_size,
IntArrayRef dilation_size) {
/*
When the following tensors are defined:
bias, grad_weight, grad_output
then these are assumed to be contiguous without checking,
because these tensors are made contiguous by calling the
.contiguous() method or by resizing zero-sized tensors in the
forward/backward functions.
When grad_weight is defined then it is assumed without
checking to have the same shape as weight, see backward
functions.
*/
// Check size arguments
TORCH_CHECK(
kernel_size.size() == dim,
"kernel sizes length should be ",
dim,
", but got ",
kernel_size.size());
TORCH_CHECK(
stride_size.size() == dim,
"strides length should be ",
dim,
", but got ",
stride_size.size());
TORCH_CHECK(
dilation_size.size() == dim,
"dilations length should be ",
dim,
", but got ",
dilation_size.size());
TORCH_CHECK(
pad_size.size() == dim,
"pads length should be ",
dim,
", but got ",
pad_size.size());
TORCH_CHECK(
all_positive(kernel_size),
"kernel size should be greater than zero, but got ",
kernel_size);
TORCH_CHECK(
all_positive(stride_size),
"stride should be greater than zero, but got ",
stride_size);
TORCH_CHECK(
all_positive(dilation_size),
"dilation should be greater than zero, but got ",
dilation_size);
// check input
TORCH_CHECK(input.defined(), "input must be defined");
bool is_batch = input.dim() == dim + 2;
int64_t n = (is_batch ? 2 : 1);
int64_t ndim = n + dim;
if (!is_batch) {
// input dim has to be dim + 1 if not batched
TORCH_CHECK(
input.dim() == dim + 1,
"input must be ",
dim + 1,
"D (unbatched) or ",
dim + 2,
"D (batched) tensor but got ",
input.dim(),
"D tensor");
}
// check output sizes
auto output_size = get_output_size<dim>(
input, kernel_size, stride_size, pad_size, dilation_size);
TORCH_CHECK(
all_nonnegative(output_size),
"calculated output size ",
output_size,
" is too small (all sizes must be non-negative)");
// check weight
TORCH_CHECK(weight.defined(), "weight must be defined");
TORCH_CHECK(
weight.dim() == dim + 2,
"weight must be ",
dim + 2,
"D tensor but got ",
weight.dim(),
"D tensor dim=",
dim);
TORCH_CHECK(
weight.sizes().slice(2) == kernel_size,
"weight[2:] shape ",
weight.sizes().slice(2),
" must be equal to kernel_size ",
kernel_size);
TORCH_CHECK_DIM_SIZE(input, input.dim(), (is_batch ? 1 : 0), weight.size(1));
// check bias when present
if (bias.defined()) {
TORCH_CHECK(
bias.dim() == 1,
"bias must be 1D tensor but got ",
bias.dim(),
"D tensor");
TORCH_CHECK_DIM_SIZE(bias, 1, 0, weight.size(0));
}
// check grad_output when present
if (grad_output.defined()) {
TORCH_CHECK(
grad_output.dim() == ndim,
"grad_output must be ",
ndim,
"D tensor but got ",
grad_output.dim(),
"D tensor");
if (is_batch) {
TORCH_CHECK(
grad_output.size(0) == input.size(0),
"grad_output.size(0)=",
grad_output.size(0),
" must be input.size(0)=",
input.size(0));
}
TORCH_CHECK(
grad_output.size(n - 1) == weight.size(0),
"grad_output.size(",
n - 1,
")=",
grad_output.size(n - 1),
" must be weight.size(0)=",
weight.size(0));
TORCH_CHECK(
grad_output.sizes().slice(n) == output_size,
"grad_output[",
n,
":] shape",
grad_output.sizes().slice(n),
" must be equal to output size ",
output_size);
}
}
} // namespace at::native::internal

View File

@ -0,0 +1,444 @@
#pragma once
#include <c10/core/DeviceType.h>
#include <c10/macros/Macros.h>
#include <c10/util/Array.h>
#include <atomic>
#include <utility>
#include <variant>
// Implements instruction set specific function dispatch.
//
// Kernels that may make use of specialized instruction sets (e.g. AVX2) are
// compiled multiple times with different compiler flags (e.g. -mavx2). A
// DispatchStub contains a table of function pointers for a kernel. At runtime,
// the fastest available kernel is chosen based on the features reported by
// cpuinfo.
//
// Example:
//
// In native/MyKernel.h:
// using fn_type = void(*)(const Tensor& x);
// DECLARE_DISPATCH(fn_type, stub);
//
// In native/MyKernel.cpp
// DEFINE_DISPATCH(stub);
//
// In native/cpu/MyKernel.cpp:
// namespace {
// // use anonymous namespace so that different cpu versions won't conflict
// void kernel(const Tensor& x) { ... }
// }
// REGISTER_DISPATCH(stub, &kernel);
//
// To call:
// stub(kCPU, tensor);
//
// TODO: CPU instruction set selection should be folded into whatever
// the main dispatch mechanism is.
//
// Supported device types for registration:
// - CPU: Central Processing Unit
// - CUDA: NVIDIA GPUs
// - HIP: AMD GPUs
// - MPS: Apple Silicon GPUs (Metal Performance Shaders)
// - MTIA: Meta Training and Inference Devices
// - XPU: Intel GPUs
// - PrivateUse1: Reserved for private/custom device types
//
// If you want to update the list of supported devices, add a new dispatch_ptr
// member in DispatchStubImpl.h and update the get_call_ptr switch.
// You will also need to update the inlined list in `is_device_supported`.
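//
// For a non-CPU backend the registration is done from that backend's
// translation unit, e.g. (an illustrative sketch, names are made up):
//
// In native/cuda/MyKernel.cu:
//   void kernel_cuda(const Tensor& x) { ... }
//   REGISTER_CUDA_DISPATCH(stub, &kernel_cuda);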
//
//
// ignore warnings about DispatchStub::DEFAULT, AVX, AVX2 defined elsewhere
C10_CLANG_DIAGNOSTIC_PUSH()
C10_CLANG_DIAGNOSTIC_IGNORE("-Wundefined-var-template")
namespace at::native {
enum class CPUCapability {
DEFAULT = 0,
#if defined(HAVE_VSX_CPU_DEFINITION)
VSX = 1,
#elif defined(HAVE_ZVECTOR_CPU_DEFINITION)
ZVECTOR = 1,
#else
AVX2 = 1,
AVX512 = 2,
#endif
NUM_OPTIONS
};
// Enum for error types
enum class ErrorType {
MissingDeviceKernel,
DeviceNotSupported
};
// Alias for the return type using std::variant
using DispatchResult = std::variant<void*, ErrorType>;
CPUCapability get_cpu_capability();
template <typename FnPtr, typename T>
struct DispatchStub;
/**
* The sole purpose of this class is to outline methods that don't need to be
* specialized or otherwise inlined and duplicated (by the compiler due to
* template expansion), since it causes size bloat if there are a significant
* number of specialization of the DispatchStub<> class.
*/
struct TORCH_API DispatchStubImpl {
// The DispatchStubImpl::try_get_call_ptr() method is used to get the call
// pointer for a given device type. If the call pointer is not found,
// DispatchStubImpl::try_get_call_ptr() returns an ErrorType.
// The main difference between try_get_call_ptr() and get_call_ptr() is that
// try_get_call_ptr() will return the ErrorType and not raise an exception.
DispatchResult try_get_call_ptr(
c10::DeviceType device_type
, void *DEFAULT
#ifdef HAVE_AVX512_CPU_DEFINITION
, void *AVX512
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
, void *AVX2
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
, void *VSX
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
, void *ZVECTOR
#endif
);
// Analogous to try_get_call_ptr(), but it will return the ErrorType and not
// raise an exception.
DispatchResult try_choose_cpu_impl(
void *DEFAULT
#ifdef HAVE_AVX512_CPU_DEFINITION
, void *AVX512
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
, void *AVX2
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
, void *VSX
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
, void *ZVECTOR
#endif
);
void* get_call_ptr(
c10::DeviceType device_type
, void *DEFAULT
#ifdef HAVE_AVX512_CPU_DEFINITION
, void *AVX512
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
, void *AVX2
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
, void *VSX
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
, void *ZVECTOR
#endif
);
/**
* The CPU Dispatch actual method is chosen in decreasing order of preference by
* DispatchStubImpl::choose_cpu_impl() in case none is found by
* DispatchStubImpl::get_call_ptr() in cpu_dispatch_ptr.
*/
void* choose_cpu_impl(
void *DEFAULT
#ifdef HAVE_AVX512_CPU_DEFINITION
, void *AVX512
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
, void *AVX2
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
, void *VSX
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
, void *ZVECTOR
#endif
);
// Fixing dispatch error in Windows debug builds.
// See https://github.com/pytorch/pytorch/issues/22681 for more details.
#if defined(_MSC_VER) && defined(_DEBUG)
std::atomic<void*> cpu_dispatch_ptr;
void* cuda_dispatch_ptr;
void* hip_dispatch_ptr;
void* mps_dispatch_ptr;
void* mtia_dispatch_ptr;
#if defined(USE_XPU)
void* xpu_dispatch_ptr;
#endif
void* privateuse1_dispatch_ptr;
#else
std::atomic<void*> cpu_dispatch_ptr{nullptr};
void* cuda_dispatch_ptr = nullptr;
void* hip_dispatch_ptr = nullptr;
void* mps_dispatch_ptr = nullptr;
void* mtia_dispatch_ptr = nullptr;
#if defined(USE_XPU)
void* xpu_dispatch_ptr = nullptr;
#endif
void* privateuse1_dispatch_ptr = nullptr;
#endif
};
template <typename rT, typename T, typename... Args>
struct DispatchStub<rT (*)(Args...), T> {
using FnPtr = rT (*) (Args...);
DispatchStub() = default;
DispatchStub(const DispatchStub&) = delete;
DispatchStub& operator=(const DispatchStub&) = delete;
private:
FnPtr get_call_ptr(const c10::DeviceType device_type) {
return reinterpret_cast<FnPtr>(
impl.get_call_ptr(device_type
, reinterpret_cast<void*>(DEFAULT)
#ifdef HAVE_AVX512_CPU_DEFINITION
, reinterpret_cast<void*>(AVX512)
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
, reinterpret_cast<void*>(AVX2)
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
, reinterpret_cast<void*>(VSX)
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
, reinterpret_cast<void*>(ZVECTOR)
#endif
)
);
}
public:
template <typename... ArgTypes>
rT operator()(c10::DeviceType device_type, ArgTypes&&... args) {
FnPtr call_ptr = get_call_ptr(device_type);
return (*call_ptr)(std::forward<ArgTypes>(args)...);
}
void set_cuda_dispatch_ptr(FnPtr fn_ptr) {
impl.cuda_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
}
#if defined(USE_XPU)
void set_xpu_dispatch_ptr(FnPtr fn_ptr){
impl.xpu_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
}
#endif
void set_hip_dispatch_ptr(FnPtr fn_ptr) {
impl.hip_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
}
void set_mps_dispatch_ptr(FnPtr fn_ptr) {
impl.mps_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
}
void set_mtia_dispatch_ptr(FnPtr fn_ptr) {
impl.mtia_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
}
void set_privateuse1_dispatch_ptr(FnPtr fn_ptr) {
impl.privateuse1_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
}
// Returns true if the dispatcher has a kernel registered for this device
// type.
bool is_device_supported(const c10::DeviceType device_type) {
auto result = impl.try_get_call_ptr(device_type
, reinterpret_cast<void*>(DEFAULT)
#ifdef HAVE_AVX512_CPU_DEFINITION
, reinterpret_cast<void*>(AVX512)
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
, reinterpret_cast<void*>(AVX2)
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
, reinterpret_cast<void*>(VSX)
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
, reinterpret_cast<void*>(ZVECTOR)
#endif
);
if (std::holds_alternative<ErrorType>(result)){
return false;
}
return true;
};
static TORCH_API FnPtr DEFAULT;
#ifdef HAVE_AVX512_CPU_DEFINITION
static TORCH_API FnPtr AVX512;
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
static TORCH_API FnPtr AVX2;
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
static TORCH_API FnPtr VSX;
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
static TORCH_API FnPtr ZVECTOR;
#endif
private:
DispatchStubImpl impl;
};
namespace {
template <typename DispatchStub>
struct RegisterCUDADispatch {
RegisterCUDADispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
stub.set_cuda_dispatch_ptr(value);
}
};
template <typename DispatchStub>
struct RegisterXPUDispatch {
RegisterXPUDispatch(DispatchStub &stub, typename DispatchStub::FnPtr value){
stub.set_xpu_dispatch_ptr(value);
}
};
template <typename DispatchStub>
struct RegisterMPSDispatch {
RegisterMPSDispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
stub.set_mps_dispatch_ptr(value);
}
};
template <typename DispatchStub>
struct RegisterHIPDispatch {
RegisterHIPDispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
// TODO: make this point at hip_dispatch_ptr
stub.set_cuda_dispatch_ptr(value);
}
};
template <typename DispatchStub>
struct RegisterMTIADispatch {
RegisterMTIADispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
stub.set_mtia_dispatch_ptr(value);
}
};
template <typename DispatchStub>
struct RegisterPRIVATEUSE1Dispatch {
RegisterPRIVATEUSE1Dispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
stub.set_privateuse1_dispatch_ptr(value);
}
};
} // anonymous namespace
// Compiler will complain if you put things like std::tuple<Tensor, Tensor> in
// the `fn` argument of DECLARE_DISPATCH. Some possible workarounds, e.g.,
// adding parentheses and using helper struct to get rid of the parentheses, do
// not work with MSVC. So do a `using`-declaration if you need to pass in such
// `fn`, e.g., grid_sampler_2d_backward_cpu_kernel in GridSampleKernel.h.
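// An illustrative sketch of that workaround (names are made up):
//
//   using my_tuple_fn = std::tuple<Tensor, Tensor> (*)(const Tensor&);
//   DECLARE_DISPATCH(my_tuple_fn, my_tuple_stub);  // OK: the alias hides the comma
//
// whereas passing std::tuple<Tensor, Tensor> (*)(const Tensor&) directly would
// split the macro arguments at the comma.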
#define DECLARE_DISPATCH(fn, name) \
struct name##_DECLARE_DISPATCH_type : DispatchStub<fn, name##_DECLARE_DISPATCH_type> { \
name##_DECLARE_DISPATCH_type() = default; \
name##_DECLARE_DISPATCH_type(const name##_DECLARE_DISPATCH_type&) = delete; \
name##_DECLARE_DISPATCH_type& operator=(const name##_DECLARE_DISPATCH_type&) = delete; \
}; \
extern TORCH_API struct name##_DECLARE_DISPATCH_type name;
#define DEFINE_DISPATCH(name) struct name##_DECLARE_DISPATCH_type name
#define REGISTER_ARCH_DISPATCH(name, arch, fn) \
template <> name##_DECLARE_DISPATCH_type::FnPtr TORCH_API DispatchStub<name##_DECLARE_DISPATCH_type::FnPtr, struct name##_DECLARE_DISPATCH_type>::arch = fn;
#ifdef HAVE_AVX512_CPU_DEFINITION
#define REGISTER_AVX512_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, AVX512, fn)
#else
#define REGISTER_AVX512_DISPATCH(name, fn)
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
#define REGISTER_AVX2_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, AVX2, fn)
#else
#define REGISTER_AVX2_DISPATCH(name, fn)
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
#define REGISTER_VSX_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, VSX, fn)
#else
#define REGISTER_VSX_DISPATCH(name, fn)
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
#define REGISTER_ZVECTOR_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, ZVECTOR, fn)
#else
#define REGISTER_ZVECTOR_DISPATCH(name, fn)
#endif
// Macro to register the same kernel for all CPU arch types. This is useful
// if a kernel does not benefit from being recompiled across different arch types.
#define REGISTER_ALL_CPU_DISPATCH(name, fn) \
REGISTER_ARCH_DISPATCH(name, DEFAULT, fn) \
REGISTER_AVX512_DISPATCH(name, fn) \
REGISTER_AVX2_DISPATCH(name, fn) \
REGISTER_VSX_DISPATCH(name, fn) \
REGISTER_ZVECTOR_DISPATCH(name, fn)
#define REGISTER_NO_CPU_DISPATCH(name) \
REGISTER_ALL_CPU_DISPATCH(name, nullptr)
#define REGISTER_CUDA_DISPATCH(name, fn) \
static RegisterCUDADispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);
#define REGISTER_XPU_DISPATCH(name, fn) \
static RegisterXPUDispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);
#define REGISTER_HIP_DISPATCH(name, fn) \
static RegisterHIPDispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);
#define REGISTER_MPS_DISPATCH(name, fn) \
static RegisterMPSDispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);
#define REGISTER_MTIA_DISPATCH(name, fn) \
static RegisterMTIADispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);
#define REGISTER_PRIVATEUSE1_DISPATCH(name, fn) \
static RegisterPRIVATEUSE1Dispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);
// NB: This macro must be used in an actual 'cu' file; if you try using
// it from a 'cpp' file it will not work!
#if defined(__CUDACC__)
#define REGISTER_DISPATCH(name, fn) REGISTER_CUDA_DISPATCH(name, fn)
#elif defined(__HIPCC__)
// TODO: cut this over to HIP dispatch once we stop pretending that CUDA
// is HIP in the PyTorch HIPify build.
#define REGISTER_DISPATCH(name, fn) REGISTER_CUDA_DISPATCH(name, fn)
// #define REGISTER_DISPATCH(name, fn) REGISTER_HIP_DISPATCH(name, fn)
#elif defined(__OBJC__) && defined(USE_MPS)
// NB: this macro must be used from a 'mm' file in order to dispatch a MPS kernel
#define REGISTER_DISPATCH(name, fn) REGISTER_MPS_DISPATCH(name, fn)
#elif defined(CPU_CAPABILITY)
// Under CPU_CAPABILITY_AVX512, REGISTER_DISPATCH registers nullptr for the AVX512
// slot while the other CPU capabilities are registered as usual.
// Use ALSO_REGISTER_AVX512_DISPATCH when the kernel should also be dispatched for AVX512.
#ifdef CPU_CAPABILITY_AVX512
#define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, ((void*)(fn) ? nullptr : nullptr))
#else
#define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
#endif
#define ALSO_REGISTER_AVX512_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
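// Illustrative usage from a native/cpu/*.cpp kernel file (names are made up):
//   REGISTER_DISPATCH(my_stub, &my_kernel);                      // AVX512 slot stays nullptr
//   ALSO_REGISTER_AVX512_DISPATCH(my_other_stub, &my_avx512_ok_kernel);  // also gets an AVX512 entry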
#endif
} // namespace at::native
C10_CLANG_DIAGNOSTIC_POP()

View File

@ -0,0 +1,20 @@
#pragma once
#include <ATen/native/DispatchStub.h>
namespace at {
class Tensor;
namespace native {
using pdist_forward_fn = void(*)(Tensor&, const Tensor&, const double p);
using pdist_backward_fn = void(*)(Tensor&, const Tensor&, const Tensor&, const double p, const Tensor&);
using cdist_fn = void(*)(Tensor&, const Tensor&, const Tensor&, const double p);
using cdist_backward_fn = void(*)(Tensor&, const Tensor&, const Tensor&, const Tensor&, const double p, const Tensor&);
DECLARE_DISPATCH(pdist_forward_fn, pdist_forward_stub);
DECLARE_DISPATCH(pdist_backward_fn, pdist_backward_stub);
DECLARE_DISPATCH(cdist_fn, cdist_stub);
DECLARE_DISPATCH(cdist_backward_fn, cdist_backward_stub);
}} // namespace at::native

View File

@ -0,0 +1,394 @@
#pragma once
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Dispatch_v2.h>
#include <ATen/Generator.h>
#include <ATen/ExpandUtils.h>
#include <ATen/Tensor.h>
#include <ATen/MemoryOverlap.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/native/Resize.h>
#include <ATen/native/TensorIterator.h>
#include <cmath>
#include <limits>
#include <optional>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty_like.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/full.h>
#include <ATen/ops/view_as_real.h>
#endif
namespace at::native::templates {
// ==================================================== Random ========================================================
// The purpose of `update_from` and `update_to` is to find the closest valid int64_t values that can be used as the actual `from` and `to`.
// The current implementation of `random_` uses uint64_t arithmetic and casts the result to the target dtype (scalar_t).
// This casting can result in generating numbers that happen to be greater or equal to `to` value. For instance:
//
// auto actual = torch::empty({3, 3}, torch::half);
// actual.random_(0, 65504);
//
// If random's uint64_t arithmetic produces 65503 as a random value, after casting to torch::half it becomes 65504
// and violates the requirement that random value must be less than `to`. To resolve this issue `update_from` and `update_to`
// moves `from` to the right and `to` to the left to the next closest value that won't go outside [from, to) after casting to
// the target dtype. For `to` = 65504 it moves left for (1 << (log2(to) - 11 + 1)) = 32 and becomes 65472, which is previous
// available number for torch::half dtype.
template<typename scalar_t>
int64_t update_from(int64_t from) {
static_assert(
std::is_floating_point<scalar_t>::value ||
std::is_same<scalar_t, at::Half>::value ||
std::is_same<scalar_t, at::BFloat16>::value, "scalar_t must be floating-point type");
const auto from_plus_1 = static_cast<int64_t>(static_cast<scalar_t>(from + 1));
if (from_plus_1 < from) {
int64_t from_ = std::abs(from + 1);
int n = 0;
while (from_ >>= 1) ++n;
// NOLINTNEXTLINE(clang-analyzer-core.UndefinedBinaryOperatorResult)
from = from_plus_1 + (1LL << (n - std::numeric_limits<scalar_t>::digits + 1));
}
return from;
}
template<typename scalar_t>
int64_t update_to(int64_t to) {
static_assert(
std::is_floating_point<scalar_t>::value ||
std::is_same<scalar_t, at::Half>::value ||
std::is_same<scalar_t, at::BFloat16>::value, "scalar_t must be floating-point type");
const auto to_minus_1 = static_cast<int64_t>(static_cast<scalar_t>(to - 1));
if (to_minus_1 >= to) {
int64_t to_ = std::abs(to - 1);
int n = 0;
while (to_ >>= 1) ++n;
// NOLINTNEXTLINE(clang-analyzer-core.UndefinedBinaryOperatorResult)
to = to_minus_1 - (1LL << (n - std::numeric_limits<scalar_t>::digits + 1));
}
return to;
}
// Return early to avoid invoking the kernel on an empty tensor.
// See https://github.com/pytorch/pytorch/issues/103418 for more details
#define CHECK_EMPTY_AND_RETURN(tensor) \
if (tensor.numel() == 0) { \
return tensor; \
}
template<template<typename> class random_kernel, typename RNG>
at::Tensor& random_impl(at::Tensor& self, std::optional<Generator> generator) {
CHECK_EMPTY_AND_RETURN(self);
auto iter = at::TensorIterator::borrowing_nullary_op(self);
random_kernel<RNG>()(iter, generator);
return self;
}
#define CHECK_OUT_OF_BOUNDS(var, name, min, max, dtype) \
TORCH_CHECK(var >= min && var <= max, name, " is out of bounds for ", dtype);
#define WARN_OUT_OF_BOUNDS(var, name, digits, dtype) \
if (var < -(1LL << digits) || var > (1LL << digits)) { \
TORCH_WARN(name , " is out of bounds [-(2^", digits, "), 2^", digits, "]. ", \
"Due to precision limitations ", dtype, " can support discrete uniform distribution only within this range. ", \
"This warning will become an error in version 1.7 release, please fix the code in advance"); \
}
inline void check_from_to_in_range(int64_t from, int64_t to_inc, caffe2::TypeMeta dtype) {
const auto scalar_type = typeMetaToScalarType(dtype);
if (isFloatingType(scalar_type)) {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, scalar_type, "check_random_fp_bounds", [&] {
const auto min = static_cast<double>(std::numeric_limits<scalar_t>::lowest());
const auto max = static_cast<double>(std::numeric_limits<scalar_t>::max());
CHECK_OUT_OF_BOUNDS(from, "from", min, max, dtype);
CHECK_OUT_OF_BOUNDS(to_inc, "to - 1", min, max, dtype);
constexpr auto digits = std::numeric_limits<scalar_t>::digits;
WARN_OUT_OF_BOUNDS(from, "from", digits, dtype);
WARN_OUT_OF_BOUNDS(to_inc, "to - 1", digits, dtype);
});
} else if (scalar_type == kUInt64) {
// When you do a comparison between int64_t and uint64_t, the usual
// arithmetic conversions say that the int64_t value is promoted to
// unsigned. But this conversion wraps around: if I had -1 as my int64_t,
// then it will promote to 0xFFFFFFFFFFFFFFFF in uint64_t. This is never
// the right thing to do.
CHECK_OUT_OF_BOUNDS(from, "from", 0, INT64_MAX, dtype);
CHECK_OUT_OF_BOUNDS(to_inc, "to - 1", 0, INT64_MAX, dtype);
} else if (isIntegralType(scalar_type, /*includeBool=*/true)) {
AT_DISPATCH_V2(scalar_type, "check_random_integral_bounds", AT_WRAP([&]() {
const auto min = static_cast<int64_t>(std::numeric_limits<scalar_t>::lowest());
const auto max = static_cast<int64_t>(std::numeric_limits<scalar_t>::max());
CHECK_OUT_OF_BOUNDS(from, "from", min, max, dtype);
CHECK_OUT_OF_BOUNDS(to_inc, "to - 1", min, max, dtype);
}), AT_EXPAND(AT_INTEGRAL_TYPES), kUInt16, kUInt32, kBool);
} else {
TORCH_CHECK(false, "check_random_bounds handles only integral, floating-point and boolean types");
}
}
template<template<typename> class random_from_to_kernel, typename RNG>
at::Tensor& random_from_to_impl(at::Tensor& self, int64_t from, std::optional<int64_t> to_opt, std::optional<Generator> generator) {
uint64_t range = 0;
auto iter = at::TensorIterator::borrowing_nullary_op(self);
if (to_opt.has_value()) {
// [from, to)
int64_t to = *to_opt;
TORCH_CHECK(from < to, "random_ expects 'from' to be less than 'to', but got from=", from, " >= to=", to);
if (isFloatingType(iter.dtype())) {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "random_update_from_to", [&] {
from = update_from<scalar_t>(from);
to = update_to<scalar_t>(to);
TORCH_CHECK(from < to, "random_ expects 'from' casted to dtype to be less than 'to' casted to dtype, but got from=", from, " >= to=", to);
});
}
check_from_to_in_range(from, to - 1, self.dtype());
CHECK_EMPTY_AND_RETURN(self);
range = static_cast<uint64_t>(to) - static_cast<uint64_t>(from);
random_from_to_kernel<RNG>()(iter, range, from, generator);
} else if (from != std::numeric_limits<int64_t>::lowest()) {
// [from, std::numeric_limits<int64_t>::max()]
int64_t to_inc = 0;
if (isFloatingType(iter.dtype())) {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "random_from_to_range_calc", [&] {
constexpr int64_t scalar_t_max = static_cast<int64_t>(1) << std::numeric_limits<scalar_t>::digits;
to_inc = scalar_t_max > std::numeric_limits<int64_t>::max() ? std::numeric_limits<int64_t>::max() : static_cast<int64_t>(scalar_t_max);
from = update_from<scalar_t>(from);
TORCH_CHECK(from < to_inc, "random_ expects 'from' casted to dtype to be less than or equal to 'to_inc' casted to dtype, but got from=", from, " > to_inc=", to_inc);
});
} else if (isIntegralType(iter.dtype(), /*includeBool=*/true)) {
AT_DISPATCH_V2(self.scalar_type(), "random_from_to_range_calc", AT_WRAP([&] {
if constexpr (std::is_same_v<scalar_t, bool>) {
to_inc = static_cast<int64_t>(true);
} else {
to_inc = static_cast<int64_t>(std::numeric_limits<scalar_t>::max());
}
}), AT_EXPAND(AT_INTEGRAL_TYPES_V2), kBool);
} else {
TORCH_CHECK(false, "random_from_to_impl handles only integral, floating-point and boolean types");
}
check_from_to_in_range(from, to_inc, self.dtype());
CHECK_EMPTY_AND_RETURN(self);
range = static_cast<uint64_t>(to_inc) - static_cast<uint64_t>(from) + 1;
random_from_to_kernel<RNG>()(iter, range, from, generator);
} else {
// [std::numeric_limits<int64_t>::lowest(), std::numeric_limits<int64_t>::max()]
// range = 2^64
CHECK_EMPTY_AND_RETURN(self);
random_from_to_kernel<RNG>()(iter, generator);
}
return self;
}
// ==================================================== Normal ========================================================
#define CHECK_NORMAL_TENSOR_STD(std) \
do { \
TORCH_CHECK( \
!std.is_complex(), \
"normal expects standard deviation to be non-complex"); \
TORCH_CHECK( \
std.numel() == 0 || std.is_meta() || std.min().ge(0).item<bool>(), \
"normal expects all elements of std >= 0.0"); \
} while (0)
#define CHECK_NORMAL_STD(std) \
TORCH_CHECK(std >= 0.0, "normal expects std >= 0.0, but found std ", std);
template<template<typename> class normal_kernel, typename RNG>
Tensor& normal_impl_(Tensor& self, double mean, double std, std::optional<Generator> gen) {
CHECK_NORMAL_STD(std);
CHECK_EMPTY_AND_RETURN(self);
if (self.is_complex()) {
auto float_tensor = at::view_as_real(self);
// variance for normal distribution of the real and imaginary values
// is half of the input variance
normal_kernel<RNG>()(float_tensor, mean, std/(std::sqrt(2)), gen);
} else {
normal_kernel<RNG>()(self, mean, std, gen);
}
return self;
}
template<template<typename> class normal_kernel, typename RNG>
Tensor& normal_out_impl(Tensor& output, const Tensor& mean, double std, std::optional<Generator> gen) {
CHECK_NORMAL_STD(std);
auto std_tensor = at::empty_like(output, MemoryFormat::Contiguous);
auto shape = at::infer_size(mean.sizes(), std_tensor.sizes());
at::native::resize_output(output, shape);
normal_impl_<normal_kernel, RNG>(output, 0, std, gen);
output.add_(mean);
return output;
}
template<template<typename> class normal_kernel, typename RNG>
Tensor& normal_out_impl(Tensor& output, double mean, const Tensor& std, std::optional<Generator> gen) {
CHECK_NORMAL_TENSOR_STD(std);
auto mean_tensor = at::full({}, mean, output.options());
auto shape = at::infer_size(mean_tensor.sizes(), std.sizes());
at::native::resize_output(output, shape);
normal_impl_<normal_kernel, RNG>(output, 0, 1, gen);
// CUDA NB: addcmul_out copies the tensor to be added into the output.
// The previous function here was addcmul_out(output, mean_tensor, output, std, 1);
// The third argument is not a constant reference and hence the samples in output are overwritten.
// Consequently, the computation performed is mean_tensor + mean_tensor * std instead of mean_tensor + output * std
output.mul_(std).add_(mean_tensor);
return output;
}
template<template<typename> class normal_kernel, typename RNG>
Tensor& normal_out_impl(Tensor& output, const Tensor& mean, const Tensor& std, std::optional<Generator> gen) {
CHECK_NORMAL_TENSOR_STD(std);
auto shape = at::infer_size(mean.sizes(), std.sizes());
at::native::resize_output(output, shape);
normal_impl_<normal_kernel, RNG>(output, 0, 1, gen);
// CUDA NB: addcmul_out copies the tensor to be added into the output.
// The previous function here was addcmul_out(output, mean, output, std, 1);
// The third argument is not a constant reference and hence the samples in output are overwritten.
// Consequently, the computation performed is mean + mean * std instead of mean + output * std
output.mul_(std).add_(mean);
return output;
}
template<template<typename> class normal_kernel, typename RNG>
Tensor normal_impl(const Tensor& mean, double std, std::optional<Generator> gen) {
CHECK_NORMAL_STD(std);
Tensor ret = at::empty_like(mean, MemoryFormat::Contiguous);
normal_out_impl<normal_kernel, RNG>(ret, mean, std, gen);
return ret;
}
template<template<typename> class normal_kernel, typename RNG>
Tensor normal_impl(double mean, const Tensor& std, std::optional<Generator> gen) {
CHECK_NORMAL_TENSOR_STD(std);
Tensor ret = at::empty_like(std, MemoryFormat::Contiguous);
normal_out_impl<normal_kernel, RNG>(ret, mean, std, gen);
return ret;
}
template<template<typename> class normal_kernel, typename RNG>
Tensor normal_impl(const Tensor& mean, const Tensor& std, std::optional<Generator> gen) {
CHECK_NORMAL_TENSOR_STD(std);
auto shape = at::infer_size(mean.sizes(), std.sizes());
Tensor ret = at::empty(shape, mean.options(), MemoryFormat::Contiguous);
normal_out_impl<normal_kernel, RNG>(ret, mean, std, gen);
return ret;
}
// ==================================================== Uniform =======================================================
template<template<typename> class uniform_kernel, typename RNG>
at::Tensor& uniform_impl_(at::Tensor& self, double from, double to, std::optional<Generator> generator) {
if (self.is_complex()) {
CHECK_EMPTY_AND_RETURN(self);
auto float_tensor = at::view_as_real(self);
uniform_impl_<uniform_kernel, RNG>(float_tensor, from, to, generator);
} else {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "check_uniform_bounds", [&] {
[[maybe_unused]] const auto dtype = self.dtype();
const auto min = static_cast<double>(std::numeric_limits<scalar_t>::lowest());
const auto max = static_cast<double>(std::numeric_limits<scalar_t>::max());
CHECK_OUT_OF_BOUNDS(from, "from", min, max, dtype);
CHECK_OUT_OF_BOUNDS(to, "to", min, max, dtype);
TORCH_CHECK(from <= to, "uniform_ expects to return a [from, to) range, but found from=", from, " > to=", to);
TORCH_CHECK((to - from) <= std::numeric_limits<scalar_t>::max(),
"uniform_ expects to-from <= std::numeric_limits<", toString(self.scalar_type()),
">::max(), but found to=", to, " and from=", from,
" which result in to-from to exceed the limit");
from = std::min(std::max(from, min), max);
to = std::max(std::min(to, max), min);
});
CHECK_EMPTY_AND_RETURN(self);
auto iter = at::TensorIterator::borrowing_nullary_op(self);
uniform_kernel<RNG>()(iter, from, to, generator);
}
return self;
}
// ================================================== LogNormal =======================================================
template<template<typename> class log_normal_kernel, typename RNG>
at::Tensor& log_normal_impl_(at::Tensor& self, double mean, double std, std::optional<Generator> gen) {
TORCH_CHECK(std > 0.0, "log_normal_ expects std > 0.0, but found std=", std);
CHECK_EMPTY_AND_RETURN(self);
auto iter = TensorIterator::borrowing_nullary_op(self);
log_normal_kernel<RNG>()(iter, mean, std, gen);
return self;
}
// =================================================== Geometric ======================================================
template<template<typename> class geometric_kernel, typename RNG>
Tensor& geometric_impl_(Tensor& self, double p, std::optional<Generator> gen) {
TORCH_CHECK(0 < p && p < 1, "geometric_ expects p to be in (0, 1), but got p=", p);
CHECK_EMPTY_AND_RETURN(self);
auto iter = TensorIterator::borrowing_nullary_op(self);
geometric_kernel<RNG>()(iter, p, gen);
return self;
}
// ================================================== Exponential =====================================================
template<template<typename> class exponential_kernel, typename RNG>
Tensor& exponential_impl_(Tensor& self, double lambda, std::optional<Generator> gen) {
TORCH_CHECK(lambda > 0.0, "exponential_ expects lambda > 0.0, but found lambda=", lambda);
CHECK_EMPTY_AND_RETURN(self);
auto iter = TensorIterator::borrowing_nullary_op(self);
exponential_kernel<RNG>()(iter, lambda, gen);
return self;
}
// ==================================================== Cauchy ========================================================
template<template<typename> class cauchy_kernel, typename RNG>
Tensor& cauchy_impl_(Tensor& self, double median, double sigma, std::optional<Generator> gen) {
// TODO: instead of variable name 'sigma', use 'gamma' or 'scale'
// the variance, squared sigma, is undefined for cauchy distribution
TORCH_CHECK(sigma > 0.0, "cauchy_ expects sigma > 0.0, but found sigma=", sigma);
TORCH_CHECK(at::isFloatingType(self.scalar_type()), "Cauchy distribution is a continuous probability distribution. dtype must be a floating point but you specified ", self.dtype());
CHECK_EMPTY_AND_RETURN(self);
auto iter = TensorIterator::borrowing_nullary_op(self);
cauchy_kernel<RNG>()(iter, median, sigma, gen);
return self;
}
// ==================================================== Bernoulli =====================================================
template<template<typename> class bernoulli_tensor_kernel, typename RNG>
Tensor& bernoulli_impl_(Tensor& self, const Tensor& p_, std::optional<Generator> gen) {
CHECK_EMPTY_AND_RETURN(self);
NoNamesGuard guard;
at::assert_no_internal_overlap(self);
bernoulli_tensor_kernel<RNG>()(self, p_, gen);
return self;
}
template<template<typename> class bernoulli_scalar_kernel, typename RNG>
Tensor& bernoulli_impl_(Tensor& self, double p, std::optional<Generator> gen) {
TORCH_CHECK(0 <= p && p <= 1, "bernoulli_ expects p to be in [0, 1], but got p=", p);
CHECK_EMPTY_AND_RETURN(self);
at::assert_no_internal_overlap(self);
bernoulli_scalar_kernel<RNG>()(self, p, gen);
return self;
}
template<template<typename> class bernoulli_tensor_kernel, typename RNG>
Tensor& bernoulli_out_impl(Tensor& result, const Tensor& self, std::optional<Generator> gen) {
// result.resize_as_(self) requires self to have same dtype as result, so we
// use resize_ instead.
// TODO: Fix resize_as_. See pytorch/pytorch#11665.
result.resize_(self.sizes());
bernoulli_impl_<bernoulli_tensor_kernel, RNG>(result, self, gen);
namedinference::propagate_names(result, self);
return result;
}
#undef CHECK_OUT_OF_BOUNDS
#undef WARN_OUT_OF_BOUNDS
} // namespace at::native::templates

View File

@ -0,0 +1,518 @@
#pragma once
#include <ATen/native/Math.h>
#include <c10/macros/Macros.h>
#include <c10/util/MathConstants.h>
// ROCM hcc doesn't work well with using std:: in kernel functions
#if defined(__CUDA_ARCH__)
#include <c10/cuda/CUDAMathCompat.h>
#define compat_exp c10::cuda::compat::exp
#define compat_ceil c10::cuda::compat::ceil
#define compat_floor c10::cuda::compat::floor
#define compat_log c10::cuda::compat::log
#define compat_pow c10::cuda::compat::pow
#define compat_sqrt c10::cuda::compat::sqrt
#define compat_tan c10::cuda::compat::tan
#define compat_abs c10::cuda::compat::abs
#define compat_log1p c10::cuda::compat::log1p
#elif defined(__HIPCC__)
#include <c10/hip/HIPMathCompat.h>
#define compat_exp c10::hip::compat::exp
#define compat_ceil c10::hip::compat::ceil
#define compat_floor c10::hip::compat::floor
#define compat_log c10::hip::compat::log
#define compat_pow c10::hip::compat::pow
#define compat_sqrt c10::hip::compat::sqrt
#define compat_tan c10::hip::compat::tan
#define compat_abs c10::hip::compat::abs
#define compat_log1p c10::hip::compat::log1p
#else
#define compat_exp std::exp
#define compat_ceil std::ceil
#define compat_floor std::floor
#define compat_log std::log
#define compat_pow std::pow
#define compat_sqrt std::sqrt
#define compat_tan std::tan
#define compat_abs std::abs
#define compat_log1p std::log1p
#endif
namespace {
#if !defined(__CUDA_ARCH__) && !defined(__HIPCC__)
// we cannot use std::isnan directly due to some incompatibility of
// gcc constexpr'ing and nvcc
using std::isnan;
#endif
// Here sampler_t should be function type scalar_t(void). For gpu
// "sampler" is a device function, but since ROCM doesn't have
// equivalent to nvstd::function, we use a template type parameter to
// capture it.
template<typename scalar_t, typename sampler_t>
struct BaseSampler {
sampler_t sampler;
C10_DEVICE BaseSampler(const sampler_t& sampler): sampler(sampler) {}
C10_DEVICE scalar_t sample() {
return sampler();
}
};
// The function `sample_gamma` is adapted from Numpy's distributions.c implementation.
// It is MIT licensed, so here is the copyright:
/* Copyright 2005 Robert Kern (robert.kern@gmail.com)
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the
* "Software"), to deal in the Software without restriction, including
* without limitation the rights to use, copy, modify, merge, publish,
* distribute, sublicense, and/or sell copies of the Software, and to
* permit persons to whom the Software is furnished to do so, subject to
* the following conditions:
*
* The above copyright notice and this permission notice shall be included
* in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
template<typename scalar_t, typename accscalar_t, typename uniform_sampler_t, typename normal_sampler_t>
C10_DEVICE scalar_t sample_gamma(scalar_t alpha, BaseSampler<accscalar_t, uniform_sampler_t>& standard_uniform, BaseSampler<accscalar_t, normal_sampler_t>& standard_normal) {
accscalar_t scale = 1.0f;
// Boost alpha for higher acceptance probability.
if (alpha < 1.0f) {
if (alpha == 0.f) return 0.f;
scale *= compat_pow(1 - standard_uniform.sample(), 1.0f / alpha);
alpha += 1.0f;
}
// This implements the acceptance-rejection method of Marsaglia and Tsang (2000)
// doi:10.1145/358407.358414
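// A sketch of the loop below: with d = alpha - 1/3 and c = 1/sqrt(9*d), draw
// x ~ N(0, 1), form v = (1 + c*x)^3 (rejecting v <= 0), and accept d*v
// (times `scale`) when u < 1 - 0.0331*x^4 (fast squeeze) or when
// log(u) < 0.5*x^2 + d*(1 - v + log(v)).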
const accscalar_t d = alpha - 1.0f / 3.0f;
const accscalar_t c = 1.0f / compat_sqrt(9.0f * d);
for (;;) {
accscalar_t x, y;
do {
x = standard_normal.sample();
y = 1.0f + c * x;
} while (y <= 0);
const accscalar_t v = y * y * y;
const accscalar_t u = 1 - standard_uniform.sample();
const accscalar_t xx = x * x;
if (u < 1.0f - 0.0331f * xx * xx)
return static_cast<scalar_t>(scale * d * v);
if (compat_log(u) < 0.5f * xx + d * (1.0f - v + compat_log(v)))
return static_cast<scalar_t>(scale * d * v);
}
}
/* the functions stirling_approx_tail, binomial_inversion, and btrs are adapted
* from TensorFlow's random_binomial_op.cc implementation. That code is under
* copyright: 2019 The TensorFlow Authors.
*
* It was released under the Apache License, Version 2.0 (the "License"), available at:
* http://www.apache.org/licenses/LICENSE-2.0
*/
template<typename scalar_t>
C10_DEVICE scalar_t stirling_approx_tail(scalar_t k) {
const static scalar_t kTailValues[] = {
0.0810614667953272,
0.0413406959554092,
0.0276779256849983,
0.02079067210376509,
0.0166446911898211,
0.0138761288230707,
0.0118967099458917,
0.0104112652619720,
0.00925546218271273,
0.00833056343336287
};
if (k <= 9) {
return kTailValues[static_cast<size_t>(k)];
}
scalar_t kp1sq = (k + 1) * (k + 1);
return (1.0 / 12 - (1.0 / 360 - 1.0 / 1260 / kp1sq) / kp1sq) / (k + 1);
}
template<typename scalar_t, typename accscalar_t, typename uniform_sampler_t>
C10_DEVICE scalar_t binomial_inversion(scalar_t count, scalar_t prob, BaseSampler<accscalar_t, uniform_sampler_t>& standard_uniform) {
accscalar_t U;
accscalar_t geom_sum = 0;
scalar_t num_geom = 0;
accscalar_t logprob = compat_log1p(-prob);
while (1) {
U = standard_uniform.sample();
accscalar_t geom = compat_ceil(compat_log(U) / logprob);
geom_sum += geom;
if (geom_sum > count) {
break;
}
num_geom = num_geom + 1;
}
return num_geom;
}
template<typename scalar_t, typename accscalar_t, typename uniform_sampler_t>
C10_DEVICE scalar_t btrs(scalar_t count, scalar_t prob, BaseSampler<accscalar_t, uniform_sampler_t>& standard_uniform) {
scalar_t k;
accscalar_t U, V, us;
// This is spq in the paper.
const accscalar_t stddev = compat_sqrt(count * prob * (1 - prob));
// Other coefficients for Transformed Rejection sampling.
const accscalar_t b = 1.15 + 2.53 * stddev;
const accscalar_t a = -0.0873 + 0.0248 * b + 0.01 * prob;
const accscalar_t c = count * prob + 0.5;
const accscalar_t v_r = 0.92 - 4.2 / b;
const accscalar_t r = prob / (1 - prob);
const accscalar_t alpha = (2.83 + 5.1 / b) * stddev;
const accscalar_t m = compat_floor((count + 1) * prob);
while (1) {
U = standard_uniform.sample() - 0.5;
V = standard_uniform.sample();
us = 0.5 - compat_abs(U);
k = static_cast<scalar_t>(compat_floor((2 * a / us + b) * U + c));
// Reject non-sensical answers.
if (k < 0 || k > count) {
continue;
}
// Region for which the box is tight, and we can return our calculated value.
// This should happen 0.86 * v_r times. In the limit as n * p is large,
// the acceptance rate converges to ~79% (and in the lower regime it is ~24%).
if (us >= 0.07 && V <= v_r) {
return k;
}
// This deviates from Hormann's BTRS algorithm, as there is a log missing.
// For all (u, v) pairs outside of the bounding box, this calculates the
// transformed-reject ratio.
V = compat_log(V * alpha / (a / (us * us) + b));
accscalar_t upperbound =
((m + 0.5) * compat_log((m + 1) / (r * (count - m + 1))) +
(count + 1) * compat_log((count - m + 1) / (count - k + 1)) +
(k + 0.5) * compat_log(r * (count - k + 1) / (k + 1)) +
stirling_approx_tail<accscalar_t>(m) + stirling_approx_tail<accscalar_t>(count - m) -
stirling_approx_tail<accscalar_t>(k) - stirling_approx_tail<accscalar_t>(count - k));
if (V <= upperbound) {
return k;
}
}
}
template<typename scalar_t, typename accscalar_t, typename uniform_sampler_t>
C10_DEVICE scalar_t sample_binomial(scalar_t count, scalar_t prob, BaseSampler<accscalar_t, uniform_sampler_t>& standard_uniform) {
if (count <= 0.0 || prob <= 0.0) {
return 0;
} else if (prob >= 1.0) {
return count;
} else if (prob <= 0.5) {
if (count * prob >= 10.0) {
// btrs
return btrs<scalar_t, accscalar_t, uniform_sampler_t>(count, prob, standard_uniform);
} else {
// binomial inversion
return binomial_inversion<scalar_t, accscalar_t, uniform_sampler_t>(count, prob, standard_uniform);
}
} else if (prob > 0.5) {
scalar_t qprob = 1.0 - prob;
if (count * qprob >= 10.0) {
// btrs
return count - btrs<scalar_t, accscalar_t, uniform_sampler_t>(count, qprob, standard_uniform);
} else {
// count - binomial inversion
return count - binomial_inversion<scalar_t, accscalar_t, uniform_sampler_t>(count, qprob, standard_uniform);
}
} else {
// prob is nan?
return static_cast<scalar_t>(NAN);
}
}
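// A brief worked note on the dispatch above (added for clarity):
//   count = 100, prob = 0.3  -> count * prob = 30 >= 10, so btrs() is used;
//   count = 100, prob = 0.02 -> count * prob = 2  < 10, so binomial_inversion() is used;
//   count = 100, prob = 0.9  -> prob > 0.5, so the draw is 100 - Binomial(100, 0.1),
//                               and since 100 * 0.1 >= 10 the flipped draw uses btrs().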
/*
* This function is derived from the implementation of the digamma function in the Cephes Math Library.
* See note [3-Clause BSD License for the Cephes Math Library] in ATen/native/Math.h.
*/
template<typename scalar_t, typename accscalar_t>
C10_DEVICE inline scalar_t digamma_one(scalar_t x) {
constexpr accscalar_t PSI_10 = 2.25175258906672110764;
if (x == 0) {
return INFINITY;
}
accscalar_t additional_summand = 0;
int x_is_integer = x == compat_floor(x);
if (x < 0) {
if (x_is_integer) {
return INFINITY;
}
// it is more standard to write this as recursion, but
// nvcc does not like that
additional_summand = -c10::pi<scalar_t> /
compat_tan(c10::pi<scalar_t> * x);
x = 1 - x;
}
// Push x to be >= 10
accscalar_t result = 0;
while (x < 10) {
result -= 1 / x;
x += 1;
}
if (x == 10) {
return result + PSI_10 + additional_summand;
}
// Compute asymptotic digamma
static const accscalar_t A[] = {
8.33333333333333333333E-2,
-2.10927960927960927961E-2,
7.57575757575757575758E-3,
-4.16666666666666666667E-3,
3.96825396825396825397E-3,
-8.33333333333333333333E-3,
8.33333333333333333333E-2,
};
accscalar_t y = 0;
if (x < 1.0e17f) {
accscalar_t z = 1.0 / (x * x);
y = z * polevl<accscalar_t>(z, A, 6);
}
return static_cast<scalar_t>(
result + compat_log(x) - (0.5f / x) - y + additional_summand);
}
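// A brief note on the asymptotic branch above (added for clarity): for x >= 10 the
// code evaluates the standard expansion
//   psi(x) ~ log(x) - 1/(2x) - sum_{k=1..7} B_{2k} / (2k * x^{2k}),
// where B_{2k} are Bernoulli numbers. A[] stores B_{2k}/(2k) with the highest-order
// term first, so y = z * polevl(z, A, 6) with z = 1/x^2 evaluates the series tail.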
// Computes the reparameterized gradient -(d/dalpha cdf(x;alpha)) / pdf(x;alpha)
// for random number x drawn from a standard Gamma distribution Gamma(alpha).
template <typename scalar_t, typename accscalar_t>
C10_HOST_DEVICE scalar_t standard_gamma_grad_one(scalar_t alpha_, scalar_t x_) {
// Use a Taylor series expansion for small x.
accscalar_t x = static_cast<accscalar_t>(x_);
accscalar_t alpha = static_cast<accscalar_t>(alpha_);
if (x < 0.8f) {
accscalar_t numer = 1;
accscalar_t denom = alpha;
auto series1 = numer / denom;
auto series2 = numer / (denom * denom);
for (int i = 1; i <= 5; ++i) {
numer *= -x / static_cast<accscalar_t>(i);
denom += 1;
series1 += numer / denom;
series2 += numer / (denom * denom);
}
const auto pow_x_alpha = compat_pow(x, alpha);
const auto gamma_pdf = compat_pow(x, alpha - 1) * compat_exp(-x);
const auto gamma_cdf = pow_x_alpha * series1;
const auto gamma_cdf_alpha =
(compat_log(x) - digamma_one<accscalar_t, accscalar_t>(alpha)) *
gamma_cdf -
pow_x_alpha * series2;
const auto result = -gamma_cdf_alpha / gamma_pdf;
return isnan(result) ? static_cast<scalar_t>( 0.f ) : static_cast<scalar_t>(result);
}
// Use a Rice saddle point expansion for large alpha.
if (alpha > 8.0f) {
if (0.9f * alpha <= x && x <= 1.1f * alpha) {
const auto numer_1 = 1 + 24 * alpha * (1 + 12 * alpha);
const auto numer_2 = 1440 * (alpha * alpha) + 6 * x * (53 - 120 * x)
- 65 * x * x / alpha + alpha * (107 + 3600 * x);
const auto denom = 1244160 * (alpha * alpha) * (alpha * alpha);
return static_cast<scalar_t>(numer_1 * numer_2 / denom);
}
const auto denom = compat_sqrt(8 * alpha);
const auto term2 = denom / (alpha - x);
const auto term3 = compat_pow(
x - alpha - alpha * compat_log(x / alpha),
static_cast<accscalar_t>(-1.5));
const auto term23 = (x < alpha) ? term2 - term3 : term2 + term3;
const auto term1 = compat_log(x / alpha) * term23 -
compat_sqrt(2 / alpha) * (alpha + x) / ((alpha - x) * (alpha - x));
const auto stirling = 1 + 1 / (12 * alpha) * (1 + 1 / (24 * alpha));
const auto numer = x * term1;
return static_cast<scalar_t>(-stirling * numer / denom);
}
// Use a bivariate rational approximation to the reparameterized gradient.
const auto u = compat_log(x / alpha);
const auto v = compat_log(alpha);
static const accscalar_t coef_uv[3][8] = {
{0.16009398, -0.094634809, 0.025146376, -0.0030648343,
1, 0.32668115, 0.10406089, 0.0014179084},
{0.53487893, 0.1298071, 0.065735949, -0.0015649758,
0.16639465, 0.020070113, -0.0035938915, -0.00058392623},
{0.040121004, -0.0065914022, -0.0026286047, -0.0013441777,
0.017050642, -0.0021309326, 0.00085092367, -1.5247877e-07},
};
accscalar_t coef_v[8];
for (int i = 0; i < 8; ++ i) {
coef_v[i] = coef_uv[0][i] + u * (coef_uv[1][i] + u * coef_uv[2][i]);
}
const auto p = coef_v[0] + v * (coef_v[1] + v * (coef_v[2] + v * coef_v[3]));
const auto q = coef_v[4] + v * (coef_v[5] + v * (coef_v[6] + v * coef_v[7]));
return static_cast<scalar_t>(compat_exp(p / q));
}
// Approximate reparameterized gradient of Beta(x,alpha,beta) wrt alpha.
// Assumes x is close to zero and uses a Taylor expansion.
template <typename scalar_t, typename accscalar_t>
C10_DEVICE inline scalar_t _beta_grad_alpha_small(scalar_t x, scalar_t alpha, scalar_t beta) {
const scalar_t factor = digamma_one<scalar_t, accscalar_t>(alpha)
- digamma_one<scalar_t, accscalar_t>(alpha + beta) - compat_log(x);
scalar_t numer = 1;
scalar_t series = numer / alpha * (factor + 1 / alpha);
for (int i = 1; i <= 10; ++i) {
scalar_t casted_i = static_cast<scalar_t>(i);
numer *= (casted_i - beta) * x / casted_i;
const scalar_t denom = alpha + casted_i;
series += numer / denom * (factor + 1 / denom);
}
const scalar_t result = x * compat_pow(1 - x, -beta) * series;
return isnan(result) ? static_cast<scalar_t>( 0.f ) : result;
}
// Approximate reparameterized gradient of Beta(x,alpha,beta) wrt beta.
// Assumes x is close to zero and uses a Taylor expansion.
template <typename scalar_t, typename accscalar_t>
C10_DEVICE inline scalar_t _beta_grad_beta_small(scalar_t x, scalar_t alpha, scalar_t beta) {
const scalar_t factor = digamma_one<scalar_t, accscalar_t>(alpha + beta) - digamma_one<scalar_t, accscalar_t>(beta);
scalar_t numer = 1, betas = 1, dbetas = 0, series = factor / alpha;
for (int i = 1; i <= 8; ++i) {
scalar_t casted_i = static_cast<scalar_t>(i);
numer *= -x / casted_i;
dbetas = dbetas * (beta - casted_i) + betas;
betas = betas * (beta - casted_i);
series += numer / (alpha + casted_i) * (dbetas + factor * betas);
}
const scalar_t result = -compat_pow(1 - x, 1 - beta) * series;
return isnan(result) ? static_cast<scalar_t>( 0.f ) : result;
}
// Approximate reparameterized gradient of Beta(x,alpha,beta) wrt alpha.
// Assumes alpha and beta are both large and uses a Rice saddle point expansion.
// To ensure numerical stability, this computation is performed at higher precision.
template<typename scalar_t, typename accscalar_t>
C10_DEVICE inline scalar_t _beta_grad_alpha_mid(accscalar_t x, accscalar_t alpha, accscalar_t beta) {
const accscalar_t total = alpha + beta;
const accscalar_t mean = alpha / total;
const accscalar_t std = compat_sqrt(alpha * beta / (total + 1)) / total;
if (mean - 0.1 * std <= x && x <= mean + 0.1 * std) {
// Avoid the singularity at x = mean.
const accscalar_t poly = 47 * x * (beta * beta) * (beta * beta) + alpha * (
(43 + 20 * (16 + 27 * beta) * x) * (beta * beta) * beta + alpha * (
3 * (59 + 180 * beta - 90 * x) * (beta * beta) + alpha * (
(453 + 1620 * beta * (1 - x) - 455 * x) * beta + alpha * (
8 * (1 - x) * (135 * beta - 11)))));
const accscalar_t prefactor_num = (1 + 12 * alpha) * (1 + 12 * beta) / (total * total);
const accscalar_t prefactor_den = 12960 * alpha * alpha * alpha * beta * beta * (1 + 12 * total);
return prefactor_num / (1 - x) * poly / prefactor_den;
}
const accscalar_t prefactor = -x / compat_sqrt(2 * alpha * beta / total);
const accscalar_t stirling = (1 + 1 / (12 * alpha) + 1 / (288 * alpha * alpha))
* (1 + 1 / (12 * beta) + 1 / (288 * beta * beta))
/ (1 + 1 / (12 * total) + 1 / (288 * total * total));
const accscalar_t term1_num = 2 * (alpha * alpha) * (x - 1) + alpha * beta * (x - 1) - x * (beta * beta);
const accscalar_t axbx = alpha * (x - 1) + beta * x;
const accscalar_t term1_den = compat_sqrt(2 * alpha / beta) * compat_pow(total, static_cast<accscalar_t>(1.5f)) * axbx * axbx;
const accscalar_t term1 = term1_num / term1_den;
const accscalar_t term2 = 0.5f * compat_log(alpha / (total * x));
const accscalar_t term3_num = compat_sqrt(8 * alpha * beta / total);
const accscalar_t term3_den = beta * x + alpha * (x - 1);
const accscalar_t term3 = term3_num / term3_den;
const accscalar_t term4_base = beta * compat_log(beta / (total * (1 - x))) +
alpha * compat_log(alpha / (total * x));
const accscalar_t term4 = compat_pow(term4_base, static_cast<accscalar_t>(-1.5f));
const accscalar_t term1234 = term1 + term2 * (term3 + (x < mean ? term4 : -term4));
return static_cast<scalar_t>(stirling * prefactor * term1234);
}
// Computes a scaled reparameterized gradient
// -(d/dalpha cdf(x;alpha,beta)) / pdf(x;alpha,beta) / (1-x)
// for random number x drawn from a Beta distribution Beta(alpha,beta).
// This function inputs total=alpha+beta to make it easy to implement
// Dirichlet reparameterized gradients in terms of Betas.
template<typename scalar_t, typename accscalar_t>
C10_HOST_DEVICE inline scalar_t dirichlet_grad_one(scalar_t x, scalar_t alpha, scalar_t total) {
accscalar_t x_ = static_cast<accscalar_t>(x);
accscalar_t alpha_ = static_cast<accscalar_t>(alpha);
accscalar_t total_ = static_cast<accscalar_t>(total);
const scalar_t beta = total - alpha;
const accscalar_t beta_ = total_ - alpha_;
const scalar_t boundary = total * x * (1 - x);
// Use an asymptotic approximation for x close to 0.
if (x <= 0.5f && boundary < 2.5f) {
return _beta_grad_alpha_small<scalar_t, accscalar_t>(x, alpha, beta);
}
// Use an asymptotic approximation for x close to 1.
if (x >= 0.5f && boundary < 0.75f) {
return -_beta_grad_beta_small<scalar_t, accscalar_t>(1 - x, beta, alpha);
}
// Use an asymptotic approximation when alpha and (total - alpha) are both large.
if (alpha > 6 && beta > 6) {
return _beta_grad_alpha_mid<scalar_t, accscalar_t>(x_, alpha_, beta_);
}
// Use a rational correction to an analytic approximation.
static const accscalar_t c[2][3][3][4] = {
{{{1.003668233, -0.01061107488, -0.0657888334, 0.01201642863},
{0.6336835991, -0.3557432599, 0.05486251648, -0.001465281033},
{-0.03276231906, 0.004474107445, 0.002429354597, -0.0001557569013}},
{{0.221950385, -0.3187676331, 0.01799915743, 0.01074823814},
{-0.2951249643, 0.06219954479, 0.01535556598, 0.001550077057},
{0.02155310298, 0.004170831599, 0.001292462449, 6.976601077e-05}},
{{-0.05980841433, 0.008441916499, 0.01085618172, 0.002319392565},
{0.02911413504, 0.01400243777, -0.002721828457, 0.000751041181},
{0.005900514878, -0.001936558688, -9.495446725e-06, 5.385558597e-05}}},
{{{1, -0.02924021934, -0.04438342661, 0.007285809825},
{0.6357567472, -0.3473456711, 0.05454656494, -0.002407477521},
{-0.03301322327, 0.004845219414, 0.00231480583, -0.0002307248149}},
{{0.5925320577, -0.1757678135, 0.01505928619, 0.000564515273},
{0.1014815858, -0.06589186703, 0.01272886114, -0.0007316646956},
{-0.007258481865, 0.001096195486, 0.0003934994223, -4.12701925e-05}},
{{0.06469649321, -0.0236701437, 0.002902096474, -5.896963079e-05},
{0.001925008108, -0.002869809258, 0.0008000589141, -6.063713228e-05},
{-0.0003477407336, 6.959756487e-05, 1.097287507e-05, -1.650964693e-06}}},
};
const accscalar_t u = compat_log(x_);
const accscalar_t a = compat_log(alpha_) - u;
const accscalar_t b = compat_log(total_) - a;
const accscalar_t pow_u[3] = {1, u, u * u};
const accscalar_t pow_a[3] = {1, a, a * a};
accscalar_t p = 0.0;
accscalar_t q = 0.0;
for (int i = 0; i < 3; ++i) {
for (int j = 0; j < 3; ++j) {
const accscalar_t ua = pow_u[i] * pow_a[j];
p += ua * (c[0][i][j][0] + b * (c[0][i][j][1] + b * (c[0][i][j][2] + b * c[0][i][j][3])));
q += ua * (c[1][i][j][0] + b * (c[1][i][j][1] + b * (c[1][i][j][2] + b * c[1][i][j][3])));
}
}
const accscalar_t approx = x_ * (digamma_one<scalar_t, accscalar_t>(total_) - digamma_one<scalar_t, accscalar_t>(alpha_)) / beta_;
return static_cast<scalar_t>(p / q * approx);
}
} // namespace

View File

@ -0,0 +1,153 @@
#include <ATen/core/Tensor.h>
#include <ATen/Config.h>
#include <cstdint>
#ifdef USE_FBGEMM
#include <fbgemm/FbgemmEmbedding.h>
#endif
namespace at::native {
enum class EmbeddingBagMode {
SUM = 0,
MEAN = 1,
MAX = 2,
};
[[maybe_unused]] static bool operator==(int64_t op1, EmbeddingBagMode op2) {
return op1 == static_cast<int64_t>(op2);
}
[[maybe_unused]] static bool operator!=(int64_t op1, EmbeddingBagMode op2) {
return !(op1 == op2);
}
void check_arguments(
const Tensor& weight,
const Tensor& indices,
const Tensor& offsets,
const int64_t mode,
const std::optional<Tensor>& per_sample_weights,
bool include_last_offset);
void make_bag_size_out(
Tensor& bag_size_out,
const Tensor& offsets,
const Tensor& indices,
const int64_t mode,
const bool include_last_offset,
const bool requires_grad);
void make_max_indices_out(
Tensor& max_indices_out,
const Tensor& weight,
const Tensor& indices,
const Tensor& offsets,
const Tensor& bag_size,
const int64_t mode,
bool include_last_offset);
void make_offset2bag_out(
Tensor& offset2bag,
Tensor& output,
const Tensor& weight,
const Tensor& indices,
const Tensor& offsets,
const int64_t mode,
const std::optional<Tensor>& per_sample_weights,
const int64_t padding_idx = -1);
#ifdef USE_FBGEMM
template<bool has_weight, typename TIndex, typename TData>
struct _CallbackAndBlockSize {
using TCallback = typename fbgemm::EmbeddingSpMDMKernelSignature<TData, TIndex, TIndex, TData>::Type;
int64_t blockSize = -1;
TCallback callback = nullptr;
static TCallback generateCallback(int64_t block_size) {
return fbgemm::GenerateEmbeddingSpMDM<TData, TIndex, TIndex, TData>(
block_size,
has_weight,
/* normalize_by_lengths */false,
/* prefetch */16,
/* is_weight_positional */false,
/* use_offsets */true);
}
_CallbackAndBlockSize() = default;
explicit _CallbackAndBlockSize(std::optional<int64_t> maybe_block_size)
: blockSize(maybe_block_size.value_or(-1))
, callback(maybe_block_size.has_value() ? generateCallback(maybe_block_size.value()) : nullptr)
{}
};
template<typename... StorageMixins>
struct _EmbeddingBagKernelCacheImpl : private StorageMixins... {
_EmbeddingBagKernelCacheImpl() = default;
// use each of the mixins to store corresponding kernel and block size
explicit _EmbeddingBagKernelCacheImpl(std::optional<int64_t> maybe_block_size)
: StorageMixins(maybe_block_size)...
{}
// this method is thread safe (call sites may call from different threads)
template<bool has_weight, typename TIndex, typename TData>
typename _CallbackAndBlockSize<has_weight, TIndex, TData>::TCallback
getCallback(int64_t block_size) const {
// if the cache doesn't store the kernel for the incoming block size
// (so it is different from the one stored in corresponding mixin)
// regenerate the kernel (not writing it into the cache so we avoid locks)
if (block_size != _CallbackAndBlockSize<has_weight, TIndex, TData>::blockSize) {
return _CallbackAndBlockSize<has_weight, TIndex, TData>::generateCallback(block_size);
}
// else retrieve the cached kernel from the corresponding mixin
return _CallbackAndBlockSize<has_weight, TIndex, TData>::callback;
}
};
// instantiate the cache with the list of storage mixins
// for each of the 8 _EmbeddingBagKernelCache* usages in the EmbeddingBag.cpp impl file
using _EmbeddingBagKernelCache = _EmbeddingBagKernelCacheImpl<
_CallbackAndBlockSize<true, int32_t, float>,
_CallbackAndBlockSize<false, int32_t, float>,
_CallbackAndBlockSize<true, int64_t, float>,
_CallbackAndBlockSize<false, int64_t, float>,
_CallbackAndBlockSize<true, int32_t, unsigned short>,
_CallbackAndBlockSize<false, int32_t, unsigned short>,
_CallbackAndBlockSize<true, int64_t, unsigned short>,
_CallbackAndBlockSize<false, int64_t, unsigned short>>;
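// A minimal usage sketch of the cache above (illustrative only; `block_size` is a
// hypothetical variable holding the embedding width):
//   _EmbeddingBagKernelCache cache(std::make_optional<int64_t>(block_size));
//   auto kernel = cache.getCallback</*has_weight*/false, int64_t, float>(block_size);
// If the requested block size matches the cached one, the fbgemm kernel generated at
// construction time is returned; otherwise a fresh kernel is generated on the fly,
// leaving the cache untouched so getCallback() stays thread safe.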
#else
struct _EmbeddingBagKernelCache {
explicit _EmbeddingBagKernelCache(std::optional<int64_t> /* maybe_block_size */) {}
};
#endif
void _embedding_bag_cpu_impl_out(Tensor& output, Tensor& offset2bag,
Tensor& bag_size, Tensor* max_indices,
const Tensor &weight, const Tensor &indices,
const Tensor &offsets, const int64_t mode = 0,
const std::optional<Tensor>& per_sample_weights = std::nullopt,
bool include_last_offset = false,
int64_t padding_idx = -1,
_EmbeddingBagKernelCache* fbgemm_kernel_cache = nullptr);
void _embedding_bag_cpu_out(
at::Tensor& output,
at::Tensor& offset2bag,
at::Tensor& bag_size,
at::Tensor* p_max_indices,
const at::Tensor& weight,
const at::Tensor& indices,
const at::Tensor& offsets,
const bool scale_grad_by_freq,
const int64_t mode,
const bool sparse,
const std::optional<at::Tensor>& per_sample_weights,
const bool include_last_offset,
const std::optional<int64_t>& padding_idx,
_EmbeddingBagKernelCache* fbgemm_kernel_cache = nullptr);
} // namespace at::native

View File

@ -0,0 +1,21 @@
// Functions that fill Tensors with constants. Implementations are in Fill.cpp.
#pragma once
#include <ATen/native/DispatchStub.h>
namespace c10 {
class Scalar;
}
namespace at {
class Tensor;
struct TensorIterator;
namespace native {
DECLARE_DISPATCH(void(*)(TensorIterator&, const c10::Scalar&), fill_stub);
Tensor& fill_out(Tensor& self, const Scalar& value);
}} // namespace at::native

View File

@ -0,0 +1,396 @@
#pragma once
#include <ATen/Device.h>
#include <ATen/Dispatch.h>
#include <ATen/ScalarType.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/utils/ParamsHash.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/result_type_native.h>
#endif
#include <unordered_map>
#include <vector>
namespace at::native {
namespace {
// Check if the tensor list has either a boolean tensor or an integer tensor
inline bool has_integral_tensor(TensorList tensors, const bool includeBool) {
return std::any_of(
tensors.begin(), tensors.end(), [&includeBool](const auto& t) {
return at::isIntegralType(t.scalar_type(), includeBool);
});
}
// check if tensor list has bool tensors
inline bool has_bool_tensor(TensorList tensors) {
return std::any_of(tensors.begin(), tensors.end(), [](const auto& t) -> bool {
return t.scalar_type() == ScalarType::Bool;
});
}
// Check foreach API restrictions
// - Tensor lists must be non-empty.
// - All TensorLists and ScalarLists must have the same number of elements.
// - Corresponding tensors must have the same size.
inline void check_foreach_api_restrictions(TensorList tensors) {
TORCH_CHECK(!tensors.empty(), "Tensor list must have at least one tensor.");
}
inline void check_foreach_api_restrictions(
TensorList tensors,
ArrayRef<Scalar> scalars) {
check_foreach_api_restrictions(tensors);
TORCH_CHECK(
tensors.size() == scalars.size(),
"Tensor list must have same number of elements as scalar list.");
}
inline void check_foreach_api_restrictions(
TensorList tensors1,
TensorList tensors2) {
TORCH_CHECK(!tensors1.empty(), "Tensor list must have at least one tensor.");
TORCH_CHECK(!tensors2.empty(), "Tensor list must have at least one tensor.");
TORCH_CHECK(
tensors1.size() == tensors2.size(),
"Tensor lists must have the same number of tensors, got ",
tensors1.size(),
" and ",
tensors2.size());
}
inline void check_foreach_api_restrictions(
TensorList tensors1,
TensorList tensors2,
TensorList tensors3) {
TORCH_CHECK(!tensors1.empty(), "Tensor list must have at least one tensor.");
TORCH_CHECK(!tensors2.empty(), "Tensor list must have at least one tensor.");
TORCH_CHECK(!tensors3.empty(), "Tensor list must have at least one tensor.");
TORCH_CHECK(
tensors1.size() == tensors2.size(),
"Tensor lists must have the same number of tensors, got ",
tensors1.size(),
" and ",
tensors2.size());
TORCH_CHECK(
tensors1.size() == tensors3.size(),
"Tensor lists must have the same number of tensors, got ",
tensors1.size(),
" and ",
tensors3.size());
}
inline void check_foreach_api_restrictions(
TensorList tensors1,
TensorList tensors2,
TensorList tensors3,
ArrayRef<Scalar> scalars) {
check_foreach_api_restrictions(tensors1, tensors2, tensors3);
TORCH_CHECK(
tensors1.size() == scalars.size(),
"Tensor list must have same number of elements as scalar list, got ",
tensors1.size(),
" and ",
scalars.size());
}
// Helper function called in check_fast_path_restrictions to check whether all
// corresponding tensors (aligning in index across the tensorLists) share the
// same device and dtype.
inline bool _check_tensors_share_device_and_dtype(
ArrayRef<TensorList> tensorLists,
const bool skip_dtype_check = false) {
const auto expected_dtype = tensorLists[0][0].dtype();
const auto expected_device = tensorLists[0][0].device();
auto is_tensor_okay = [&](const Tensor& tensor) {
return (skip_dtype_check || tensor.dtype() == expected_dtype) &&
tensor.device() == expected_device && tensor.layout() == at::kStrided &&
tensor.is_non_overlapping_and_dense();
};
for (const auto& tensorList : tensorLists) {
for (const auto& tensor : tensorList) {
if (!is_tensor_okay(tensor)) {
return false;
}
}
}
return true;
}
// Helper function called in check_fast_path_restrictions to check if
// corresponding tensors in tensor lists have the same sizes and strides.
inline bool _check_tensors_share_sizes_and_strides(
ArrayRef<TensorList> tensorLists) {
auto is_diff_stride = [](const IntArrayRef& size,
const IntArrayRef& left_stride,
const IntArrayRef& right_stride) -> bool {
const size_t size_size = size.size();
for (const auto dim : c10::irange(size_size)) {
if (size[dim] == 1)
continue;
if (left_stride[dim] != right_stride[dim]) {
return true;
}
}
return false;
};
for (const auto i : c10::irange(1, tensorLists.size())) {
for (const auto j : c10::irange(tensorLists[0].size())) {
if (tensorLists[0][j].sizes() != tensorLists[i][j].sizes() ||
is_diff_stride(
tensorLists[0][j].sizes(),
tensorLists[0][j].strides(),
tensorLists[i][j].strides())) {
return false;
}
}
}
return true;
}
// Helper function called in check_fast_path_restrictions to check whether
// all tensors type promote properly with the scalars in scalarList. This
// function assumes that _check_tensors_share_device_and_dtype has already been
// called so that all corresponding tensors in tensorLists have the same dtype.
// Then, it is sufficient to check the type promotion with just one tensorList.
inline bool _check_tensors_do_type_promotion_with_scalars(
TensorList tensorList,
ArrayRef<Scalar> scalarList = {},
bool does_op_promote_integer_inputs_to_float = false) {
for (const auto i : c10::irange(tensorList.size())) {
// For division, integer inputs will result in float.
if (does_op_promote_integer_inputs_to_float) {
if (at::isIntegralType(
tensorList[i].scalar_type(), /*includeBool*/ true)) {
return false;
}
}
if (!scalarList.empty()) {
const auto& scalar =
scalarList.size() == 1 ? scalarList[0] : scalarList[i];
const auto& tensor = tensorList[i];
// note(mkozuki): This check might be responsible for
// `_foreach_add(bool_tensors, bool_tensors)` being pushed to slow path.
if (tensor.scalar_type() != at::native::result_type(scalar, tensor)) {
return false;
}
}
}
return true;
}
// To go via 'fast' path, several conditions must be satisfied
// - All tensors in all lists must have the same dtype.
// - All tensors must be on the same device
// - All tensors must have strided layout
// - All tensors must be non-overlapping and dense
// - Resulting tensor must have the same dtype as the input one
// [note: what's ``does_op_promote_integer_inputs_to_float=true``?]
// ``does_op_promote_integer_inputs_to_float=true`` means that the result of
// the op will be float even if its inputs are integer or boolean, which the
// fast path currently does not support. In short, when this flag is turned on,
// it keeps the op off the fast path.
// Please make sure to call check_foreach_api_restrictions before calling this
// method. There is a set of preconditions that have to be satisfied.
inline bool check_fast_path_restrictions(
ArrayRef<TensorList> tensorLists,
ArrayRef<Scalar> scalarList = {},
bool does_op_promote_integer_inputs_to_float = false) {
return _check_tensors_share_device_and_dtype(tensorLists) &&
_check_tensors_share_sizes_and_strides(tensorLists) &&
_check_tensors_do_type_promotion_with_scalars(
tensorLists[0],
scalarList,
does_op_promote_integer_inputs_to_float);
}
inline std::vector<c10::Scalar> convert_tensor_to_scalar_list(
const Tensor& scalarList_,
int64_t expect_length) {
std::vector<c10::Scalar> scalarList;
TORCH_CHECK(
scalarList_.device() == c10::kCPU,
"Expected scalars to be on CPU, got ",
scalarList_.device(),
" instead.");
TORCH_CHECK(
scalarList_.is_contiguous(), "Expected scalars to be contiguous.");
TORCH_CHECK(
scalarList_.dim() == 1,
"Expected packed scalar Tensor to be of dimension 1. Got ",
scalarList_.dim(),
" instead.");
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
kComplexHalf,
kHalf,
kBool,
kBFloat16,
scalarList_.scalar_type(),
"convert_tensor_to_scalar_list",
[&]() {
const scalar_t* scalar_data = scalarList_.const_data_ptr<scalar_t>();
TORCH_CHECK(
(expect_length == scalarList_.size(0)),
"Expected length of scalars to match input of length ",
expect_length,
" but got ",
scalarList_.size(0),
" instead.");
for (int64_t i = 0; i < scalarList_.size(0); i++) {
scalarList.emplace_back(scalar_data[i]);
}
});
return scalarList;
}
// see: [note: what's ``does_op_promote_integer_inputs_to_float=true``?]
inline bool can_use_fast_route(
ArrayRef<TensorList> tensorLists,
ArrayRef<Scalar> scalarList = {},
bool does_op_promote_integer_inputs_to_float = false) {
return check_fast_path_restrictions(
tensorLists, scalarList, does_op_promote_integer_inputs_to_float);
}
// see: [note: what's ``does_op_promote_integer_inputs_to_float=true``?]
inline bool can_use_fast_route(
TensorList tensors1,
TensorList tensors2,
bool does_op_promote_integer_inputs_to_float = false) {
return can_use_fast_route(
{tensors1, tensors2}, {}, does_op_promote_integer_inputs_to_float);
}
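// A minimal call-pattern sketch (illustrative only; foreach_op_fast_ and
// foreach_op_slow_ are hypothetical kernels):
//   check_foreach_api_restrictions(tensors1, tensors2);
//   if (can_use_fast_route(tensors1, tensors2)) {
//     return foreach_op_fast_(tensors1, tensors2);  // fused multi-tensor path
//   }
//   return foreach_op_slow_(tensors1, tensors2);    // per-tensor fallback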
using DeviceDtypeKey = std::pair<at::Device, at::ScalarType>;
using IndicesT = std::vector<size_t>;
using nested_optional_tensorvec_t =
std::vector<std::vector<std::optional<at::Tensor>>>;
using TensorsAndIndicesT = std::pair<nested_optional_tensorvec_t, IndicesT>;
using FlatMap = std::unordered_map<
DeviceDtypeKey,
TensorsAndIndicesT,
ParamsHash<DeviceDtypeKey>>;
inline FlatMap _group_tensors_by_first_tensors_device_and_dtype(
const nested_optional_tensorvec_t& nested_tensorlist,
const bool with_indices) {
FlatMap grouped_tensors_with_indices;
TORCH_CHECK(!nested_tensorlist.empty());
TORCH_CHECK(!nested_tensorlist[0].empty());
const auto num_lists = nested_tensorlist.size();
const auto num_tensors = nested_tensorlist[0].size();
TORCH_CHECK(std::all_of(
nested_tensorlist.cbegin(),
nested_tensorlist.cend(),
[&](const auto& tensorlist) -> bool {
// note(crcrpar): Allow empty tensorlists following
// ref:
// https://github.com/pytorch/pytorch/blob/85885301fd3c6adb8b9dc3cf7afadf6945566684/torch/utils/_foreach_utils.py#L21-L24
return tensorlist.size() == num_tensors || tensorlist.size() == 0;
}));
for (const auto& tensor_index : c10::irange(num_tensors)) {
const auto key = [&]() -> DeviceDtypeKey {
const auto t = nested_tensorlist[0][tensor_index];
TORCH_CHECK(
t.has_value(),
"Tensors of the first list of nested Tensor lists are supposed to be defined but ",
"the ",
tensor_index,
"-th Tensor is not.");
return {t->device(), t->scalar_type()};
}();
TORCH_CHECK(
std::all_of(
nested_tensorlist.cbegin(),
nested_tensorlist.cend(),
[&](const auto& tensorlist) -> bool {
if (tensorlist.size() == 0) {
return true;
}
const auto& tensor = tensorlist[tensor_index];
// note(crcrpar): Currently the scope of this function is
// optimizers so there could be `state_steps` and other scalars
// whose elements are float tensors no matter what the parameter's
// dtype is.
if (!tensor.has_value()) {
return true;
} else {
const auto s = tensor->scalar_type();
const auto d = tensor->device();
// Note: `step` or `state_step` is float32 by default.
if (key.first == d) {
return key.second == s || s == at::ScalarType::Float ||
s == at::ScalarType::Double;
} else if (d.is_cpu()) {
// note(crcrpar): There are some test cases (e.g.
// TestOptim::test_adam) where state_steps are on CPU and the
// others are on CUDA. Currently a state_step Tensor has the
// dtype of float.
return s == at::ScalarType::Float ||
s == at::ScalarType::Double;
} else {
return false;
}
}
}),
"Tensors of the same index must be on the same device and the same dtype except `step` tensors that can be CPU and float32/64 notwithstanding");
if (!grouped_tensors_with_indices.count(key)) {
grouped_tensors_with_indices.insert(
{key,
TensorsAndIndicesT{
[&]() -> nested_optional_tensorvec_t {
nested_optional_tensorvec_t nested_tensorvec;
nested_tensorvec.reserve(num_lists);
for (const auto& i : c10::irange(num_lists)) {
std::vector<std::optional<at::Tensor>> tensors;
if (!nested_tensorlist[i].empty()) {
// NB: num_tensors is the max possible length for any of
// the inner lists of tensor references. Reserving the max
// trades memory for perf. This should not have significant
// impact.
tensors.reserve(num_tensors);
}
nested_tensorvec.emplace_back(tensors);
}
return nested_tensorvec;
}(),
[&]() -> IndicesT {
if (!with_indices) {
return {};
} else {
IndicesT indices;
indices.reserve(num_tensors);
return indices;
}
}()}});
}
for (const auto& list_index : c10::irange(num_lists)) {
if (!nested_tensorlist[list_index].empty()) {
grouped_tensors_with_indices[key].first[list_index].emplace_back(
nested_tensorlist[list_index][tensor_index]);
}
}
if (with_indices) {
grouped_tensors_with_indices[key].second.emplace_back(tensor_index);
}
}
return grouped_tensors_with_indices;
}
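// A minimal usage sketch (illustrative only; params, grads and state_steps are
// hypothetical std::vector<std::optional<at::Tensor>> lists of equal length):
//   const auto grouped = _group_tensors_by_first_tensors_device_and_dtype(
//       {params, grads, state_steps}, /*with_indices=*/true);
//   for (const auto& [device_dtype, lists_and_indices] : grouped) {
//     // lists_and_indices.first  : the per-list tensors sharing this device/dtype
//     // lists_and_indices.second : the original positions of those tensors
//   }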
} // namespace
} // namespace at::native

View File

@ -0,0 +1,80 @@
#pragma once
#include <ATen/core/Tensor.h>
#include <ATen/TensorUtils.h>
#include <c10/util/irange.h>
namespace at::native {
template<typename scalar_t>
inline std::vector<int64_t> generate_intervals(
scalar_t sample,
int64_t inputSize,
int64_t outputSize,
int64_t poolSize) {
std::vector<int64_t> sequence(outputSize);
if (outputSize > 1) {
scalar_t alpha = static_cast<scalar_t>(inputSize - poolSize) /
static_cast<scalar_t>(outputSize - 1);
for (const auto i : c10::irange(outputSize - 1)) {
sequence[i] =
static_cast<int>((i + sample) * alpha) - static_cast<int>(sample * alpha);
}
}
if (outputSize > 0) {
sequence[outputSize - 1] = inputSize - poolSize;
}
return sequence;
}
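// A worked example of the interval generation above (added for clarity):
//   sample = 0.5, inputSize = 10, outputSize = 4, poolSize = 3
//   alpha = (10 - 3) / (4 - 1) = 7/3 ~ 2.333
//   sequence = {0, 2, 4, 7}: the first three starts come from the truncated
//   (i + sample) * alpha terms, while the last window is pinned to
//   inputSize - poolSize so pooling never reads past the end of the input.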
template <int64_t ndim>
inline void fractional_max_pool_check_shape(
const Tensor& input,
const Tensor& randomSamples) {
TORCH_CHECK(
input.scalar_type() == randomSamples.scalar_type(),
"Expect _random_samples to have the same dtype as input");
int64_t ndimension = randomSamples.ndimension();
TORCH_CHECK(
ndimension == 3,
"Expect _random_samples to have 3 dimensions, got ", ndimension);
int64_t N = randomSamples.size(0);
int64_t C = randomSamples.size(1);
int64_t D = randomSamples.size(2);
int64_t input_batch = 0, input_channel = 0;
if (ndim == 2) {
// fractional_max_pool2d
if (input.ndimension() == 3) {
input_batch = 1;
input_channel = input.size(0);
} else {
input_batch = input.size(0);
input_channel = input.size(1);
}
} else {
// fractional_max_pool3d
if (input.ndimension() == 4) {
input_batch = 1;
input_channel = input.size(0);
} else {
input_batch = input.size(0);
input_channel = input.size(1);
}
}
TORCH_CHECK(
N >= input_batch,
"Expect _random_samples.size(0) no less then input batch size.");
TORCH_CHECK(
C == input_channel,
"Expect _random_samples.size(1) equals to input channel size.");
TORCH_CHECK(
D == ndim,
"Expect _random_samples.size(2) equals to ", ndim, "; got ", D, ".");
}
} // namespace at::native

View File

@ -0,0 +1,20 @@
#pragma once
#include <ATen/native/DispatchStub.h>
#include <cstdint>
namespace at {
struct TensorIterator;
namespace native {
using _compute_linear_combination_fn = void(*)(
TensorIterator& iter,
int64_t in_stride,
int64_t coeff_stride,
int64_t num_summations
);
DECLARE_DISPATCH(_compute_linear_combination_fn, _compute_linear_combination_stub);
}} // namespace at::native

View File

@ -0,0 +1,20 @@
#include <ATen/core/Tensor.h>
#include <ATen/native/DispatchStub.h>
namespace at::native {
using fused_adagrad_fn = void (*)(
const at::Tensor& param,
const at::Tensor& grad,
const at::Tensor& state_sum,
const at::Tensor& state_step,
const double lr,
const double lr_decay,
const double weight_decay,
const double eps,
const bool maximize,
const float* grad_scale_ptr);
DECLARE_DISPATCH(fused_adagrad_fn, fused_adagrad_stub);
} // namespace at::native

View File

@ -0,0 +1,27 @@
#include <ATen/core/Tensor.h>
#include <ATen/native/DispatchStub.h>
namespace at::native {
enum class ADAM_MODE : uint8_t { ORIGINAL = 0, ADAMW = 1 };
using fused_adam_fn = void (*)(
const at::Tensor& param,
const at::Tensor& grad,
const at::Tensor& exp_avg,
const at::Tensor& exp_avg_sq,
const at::Tensor& max_exp_avg_sq,
const at::Tensor& state_step,
const double lr,
const double beta1,
const double beta2,
const double weight_decay,
const double eps,
const bool amsgrad,
const bool maximize,
const float* grad_scale_ptr,
const ADAM_MODE);
DECLARE_DISPATCH(fused_adam_fn, fused_adam_stub);
} // namespace at::native

View File

@ -0,0 +1,21 @@
#include <ATen/core/Tensor.h>
#include <ATen/native/DispatchStub.h>
namespace at::native {
using fused_sgd_fn = void (*)(
const at::Tensor& param,
const at::Tensor& grad,
const at::Tensor& momentum_buffer,
const double weight_decay,
const double momentum,
const double lr,
const double dampening,
const bool nesterov,
const bool maximize,
const bool is_first_step,
const float* grad_scale_ptr);
DECLARE_DISPATCH(fused_sgd_fn, fused_sgd_stub);
} // namespace at::native

View File

@ -0,0 +1,298 @@
#pragma once
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <utility>
#include <ATen/native/GridSamplerUtils.h>
namespace at::native {
using detail::GridSamplerInterpolation;
using detail::GridSamplerPadding;
// Unnormalizes a coordinate from the -1 to +1 scale to its pixel index value,
// where we view each pixel as an area between (idx - 0.5) and (idx + 0.5).
// if align_corners: -1 and +1 get sent to the centers of the corner pixels
// -1 --> 0
// +1 --> (size - 1)
// scale_factor = (size - 1) / 2
// if not align_corners: -1 and +1 get sent to the image edges
// -1 --> -0.5
// +1 --> (size - 1) + 0.5 == size - 0.5
// scale_factor = size / 2
template <typename scalar_t>
static inline scalar_t grid_sampler_unnormalize(scalar_t coord, int64_t size,
bool align_corners) {
if (align_corners) {
// unnormalize coord from [-1, 1] to [0, size - 1]
return ((coord + 1) / 2) * (size - 1);
} else {
// unnormalize coord from [-1, 1] to [-0.5, size - 0.5]
return ((coord + 1) * size - 1) / 2;
}
}
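// A worked example of the mapping above (added for clarity): coord = 0.5, size = 10
//   align_corners == true : ((0.5 + 1) / 2) * (10 - 1) = 6.75
//   align_corners == false: ((0.5 + 1) * 10 - 1) / 2   = 7.0
// The same normalized coordinate lands slightly differently because the false case
// sends -1/+1 to the outer pixel edges rather than to the corner pixel centers.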
// grid_sampler_unnormalize_set_grad works the same as grid_sampler_unnormalize
// except that it also returns the `d output / d input` via pointer argument
// `grad_in`.
// This is useful in the backward pass of grid_sampler.
template <typename scalar_t>
static inline scalar_t grid_sampler_unnormalize_set_grad(scalar_t coord, int64_t size,
bool align_corners, scalar_t *grad_in) {
if (align_corners) {
// unnormalize coord from [-1, 1] to [0, size - 1]
*grad_in = static_cast<scalar_t>(size - 1) / 2;
return ((coord + 1) / 2) * (size - 1);
} else {
// unnormalize coord from [-1, 1] to [-0.5, size - 0.5]
*grad_in = static_cast<scalar_t>(size) / 2;
return ((coord + 1) * size - 1) / 2;
}
}
// Clips coordinates to between 0 and clip_limit - 1
template<typename scalar_t>
static inline scalar_t clip_coordinates(scalar_t in, int64_t clip_limit) {
return std::min(static_cast<scalar_t>(clip_limit - 1), std::max(in, static_cast<scalar_t>(0)));
}
// clip_coordinates_set_grad works similarly to clip_coordinates except that
// it also returns the `d output / d input` via pointer argument `grad_in`.
// This is useful in the backward pass of grid_sampler.
template<typename scalar_t>
static inline scalar_t clip_coordinates_set_grad(scalar_t in, int64_t clip_limit,
scalar_t *grad_in) {
// Note that it is important for the gradient calculation that borders
// are considered out of bounds.
if (in <= static_cast<scalar_t>(0)) {
*grad_in = static_cast<scalar_t>(0);
return static_cast<scalar_t>(0);
} else {
scalar_t max = static_cast<scalar_t>(clip_limit - 1);
if (in >= max) {
*grad_in = static_cast<scalar_t>(0);
return max;
} else {
*grad_in = static_cast<scalar_t>(1);
return in;
}
}
}
// Reflects coordinates until they fall between low and high (inclusive).
// The bounds are passed as twice their value so that half-integer values
// can be represented as ints.
template<typename scalar_t>
static inline scalar_t reflect_coordinates(scalar_t in, int64_t twice_low,
int64_t twice_high) {
if (twice_low == twice_high) {
return static_cast<scalar_t>(0);
}
scalar_t min = static_cast<scalar_t>(twice_low) / 2;
scalar_t span = static_cast<scalar_t>(twice_high - twice_low) / 2;
in = std::fabs(in - min);
// `fmod` returns same sign as `in`, which is positive after the `fabs` above.
scalar_t extra = std::fmod(in, span);
int flips = static_cast<int>(std::floor(in / span));
if (flips % 2 == 0) {
return extra + min;
} else {
return span - extra + min;
}
}
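// A worked example of the reflection above (added for clarity): for size = 5 with
// align_corners, twice_low = 0 and twice_high = 8, so min = 0 and span = 4:
//   in = 5.5 -> extra = fmod(5.5, 4) = 1.5, flips = 1 (odd)  -> 4 - 1.5 = 2.5,
//               i.e. 5.5 mirrored about the upper bound 4;
//   in = 9.5 -> extra = 1.5,               flips = 2 (even) -> 1.5,
//               i.e. reflected twice and back inside [0, 4].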
// reflect_coordinates_set_grad works similarly to reflect_coordinates except
// that it also returns the `d output / d input` via pointer argument
// `grad_in`.
// This is useful in the backward pass of grid_sampler.
template<typename scalar_t>
static inline scalar_t reflect_coordinates_set_grad(scalar_t in, int64_t twice_low,
int64_t twice_high, scalar_t *grad_in) {
if (twice_low == twice_high) {
*grad_in = static_cast<scalar_t>(0);
return static_cast<scalar_t>(0);
}
int grad_in_mult_;
scalar_t min = static_cast<scalar_t>(twice_low) / 2;
scalar_t span = static_cast<scalar_t>(twice_high - twice_low) / 2;
in = in - min;
if (in < static_cast<scalar_t>(0)) {
grad_in_mult_ = -1;
in = -in;
} else {
grad_in_mult_ = 1;
}
// `fmod` returns same sign as `in`, which is positive after the `if` above.
scalar_t extra = std::fmod(in, span);
int flips = static_cast<int>(std::floor(in / span));
if (flips % 2 == 0) {
*grad_in = static_cast<scalar_t>(grad_in_mult_);
return extra + min;
} else {
*grad_in = static_cast<scalar_t>(-grad_in_mult_);
return span - extra + min;
}
}
// Mapping the out-of-boundary points back into boundary
// This would only affect padding_mode=border or reflection
template<typename scalar_t>
static inline scalar_t compute_coordinates(scalar_t coord, int64_t size,
GridSamplerPadding padding_mode,
bool align_corners) {
if (padding_mode == GridSamplerPadding::Border) {
// clip coordinates to image borders
coord = clip_coordinates(coord, size);
} else if (padding_mode == GridSamplerPadding::Reflection) {
// reflect coordinates by image borders
if (align_corners) {
coord = reflect_coordinates(coord, 0, 2*(size - 1));
} else {
coord = reflect_coordinates(coord, -1, 2*size - 1);
}
// clip coordinates to image borders
coord = clip_coordinates(coord, size);
}
return coord;
}
// Computes the pixel source index value for a grid coordinate
template <typename scalar_t>
static inline scalar_t grid_sampler_compute_source_index(
scalar_t coord,
int64_t size,
GridSamplerPadding padding_mode,
bool align_corners) {
coord = grid_sampler_unnormalize(coord, size, align_corners);
coord = compute_coordinates(coord, size, padding_mode, align_corners);
return coord;
}
// grid_sampler_compute_source_index_set_grad works similarly to
// grid_sampler_compute_source_index except that it also returns the
// `d output / d input` via pointer argument `grad_in`.
// This is useful in the backward pass of grid_sampler.
template <typename scalar_t>
static inline scalar_t grid_sampler_compute_source_index_set_grad(
scalar_t coord,
int64_t size,
GridSamplerPadding padding_mode,
bool align_corners,
scalar_t *grad_in) {
scalar_t grad_clip, grad_refl;
coord = grid_sampler_unnormalize_set_grad(coord, size, align_corners, grad_in);
if (padding_mode == GridSamplerPadding::Border) {
// clip coordinates to image borders
coord = clip_coordinates_set_grad(coord, size, &grad_clip);
*grad_in = (*grad_in) * grad_clip;
} else if (padding_mode == GridSamplerPadding::Reflection) {
// reflect coordinates by image borders
if (align_corners) {
coord = reflect_coordinates_set_grad(coord, 0, 2*(size - 1), &grad_refl);
} else {
coord = reflect_coordinates_set_grad(coord, -1, 2*size - 1, &grad_refl);
}
// clip coordinates to image borders
coord = clip_coordinates_set_grad(coord, size, &grad_clip);
*grad_in = (*grad_in) * grad_refl * grad_clip;
}
return coord;
}
static inline bool within_bounds_2d(int64_t h, int64_t w, int64_t H, int64_t W) {
return h >= 0 && h < H && w >= 0 && w < W;
}
static inline bool within_bounds_3d(int64_t d, int64_t h, int64_t w, int64_t D, int64_t H, int64_t W) {
return d >= 0 && d < D && h >= 0 && h < H && w >= 0 && w < W;
}
template<typename scalar_t>
static inline scalar_t get_value_bounded(
const scalar_t* data,
scalar_t x,
scalar_t y,
int64_t W,
int64_t H,
int64_t sW,
int64_t sH,
GridSamplerPadding padding_mode,
bool align_corners) {
x = compute_coordinates(x, W, padding_mode, align_corners);
y = compute_coordinates(y, H, padding_mode, align_corners);
int64_t ix = static_cast<int64_t>(x);
int64_t iy = static_cast<int64_t>(y);
if (within_bounds_2d(iy, ix, H, W)) {
return data[iy * sH + ix * sW];
}
return static_cast<scalar_t>(0);
}
template<typename scalar_t>
static inline void safe_add_2d(scalar_t *data, int64_t h, int64_t w,
int64_t sH, int64_t sW, int64_t H, int64_t W,
scalar_t delta) {
if (within_bounds_2d(h, w, H, W)) {
data[h * sH + w * sW] += delta;
}
}
template<typename scalar_t>
static inline void safe_add_3d(scalar_t *data, int64_t d, int64_t h, int64_t w,
int64_t sD, int64_t sH, int64_t sW,
int64_t D, int64_t H, int64_t W,
scalar_t delta) {
if (within_bounds_3d(d, h, w, D, H, W)) {
data[d * sD + h * sH + w * sW] += delta;
}
}
template<typename scalar_t>
static inline void add_value_bounded(
scalar_t* data,
scalar_t x,
scalar_t y,
int64_t W,
int64_t H,
int64_t sW,
int64_t sH,
scalar_t delta,
GridSamplerPadding padding_mode,
bool align_corners) {
x = compute_coordinates(x, W, padding_mode, align_corners);
y = compute_coordinates(y, H, padding_mode, align_corners);
int64_t ix = static_cast<int64_t>(x);
int64_t iy = static_cast<int64_t>(y);
safe_add_2d(data, iy, ix, sH, sW, H, W, delta);
}
// Calculate the differential of the cubic convolution, i.e. `d coeff / d x`
template<typename scalar_t>
static inline void get_cubic_coefficients_grad(
scalar_t coeffs[4],
scalar_t t) {
// Must be the same as forward calculation in
// aten/src/ATen/native/UpSample.h:get_cubic_upsample_coefficients
scalar_t A = -0.75;
scalar_t x;
x = -1 - t; // 1 < x = |-1 - tx| < 2
coeffs[0] = (-3 * A * x - 10 * A ) * x - 8 * A;
x = -t; // x = |0 - tx| <= 1
coeffs[1] = (-3 * (A + 2) * x - 2 * (A + 3)) * x;
x = 1 - t; // x = |1 - tx| <= 1
coeffs[2] = (3 * (A + 2) * x - 2 * (A + 3)) * x;
x = 2 - t; // 1 < x = |2 - tx| < 2
coeffs[3] = (3 * A * x - 10 * A) * x + 8 * A;
}
} // namespace at::native

View File

@ -0,0 +1,105 @@
#pragma once
// See NOTE: [Tensor vs. TensorBase]
// https://github.com/pytorch/pytorch/pull/66979
#include <ATen/core/TensorBase.h>
#include <ATen/native/TensorProperties.h>
#include <ATen/native/CanUse32BitIndexMath.h>
namespace at::native {
namespace detail {
enum class GridSamplerInterpolation {Bilinear, Nearest, Bicubic};
enum class GridSamplerPadding {Zeros, Border, Reflection};
} // namespace detail
using detail::GridSamplerInterpolation;
using detail::GridSamplerPadding;
// See NOTE [ grid_sampler Native Functions ].
inline void check_grid_sampler_common(
const TensorBase& input,
const TensorBase& grid
) {
auto input_opt = input.options();
auto grid_opt = grid.options();
TORCH_CHECK(
input.defined(),
"grid_sampler(): expected input to not be undefined");
TORCH_CHECK(
grid.defined(),
"grid_sampler(): expected grid to not be undefined");
TORCH_CHECK(
input_opt.device() == grid_opt.device(),
"grid_sampler(): expected input and grid to be on same device, but input "
"is on ", input_opt.device(), " and grid is on ", grid_opt.device());
TORCH_CHECK(
input_opt.layout() == kStrided && grid_opt.layout() == kStrided,
"grid_sampler(): expected input and grid to have torch.strided layout, but "
"input has ", input_opt.layout(), " and grid has ", grid_opt.layout());
TORCH_CHECK(
input.size(0) == grid.size(0),
"grid_sampler(): expected grid and input to have same batch size, but got "
"input with sizes ", input.sizes(), " and grid with sizes ", grid.sizes());
TORCH_CHECK(
grid.size(-1) == input.dim() - 2,
"grid_sampler(): expected grid to have size ", input.dim() - 2, " in last "
"dimension, but got grid with sizes ", grid.sizes());
for (const auto i : c10::irange(2, input.dim())) {
TORCH_CHECK(input.size(i) > 0,
"grid_sampler(): expected input to have non-empty spatial dimensions, "
"but input has sizes ", input.sizes(), " with dimension ", i, " being "
"empty");
}
}
// See NOTE [ grid_sampler Native Functions ].
inline void check_grid_sampler_2d(
const TensorBase& input,
const TensorBase& grid
) {
TORCH_CHECK(
input.dim() == 4 && input.dim() == grid.dim(),
"grid_sampler(): expected 4D input and grid with same number of "
"dimensions, but got input with sizes ", input.sizes(),
" and grid with sizes ", grid.sizes());
}
// See NOTE [ grid_sampler Native Functions ].
inline void check_grid_sampler_3d(
const TensorBase& input,
const TensorBase& grid,
int64_t interpolation_mode
) {
TORCH_CHECK(
input.dim() == 5 && input.dim() == grid.dim(),
"grid_sampler(): expected 5D input and grid with same number of "
"dimensions, but got input with sizes ", input.sizes(),
" and grid with sizes ", grid.sizes());
TORCH_CHECK(
!(input.dim() == 5 &&
static_cast<GridSamplerInterpolation>(interpolation_mode) ==
GridSamplerInterpolation::Bicubic),
"grid_sampler(): bicubic interpolation only supports 4D input");
}
// See NOTE [ grid_sampler Native Functions ].
// cuDNN does not support inputs with a channel dimension (size(1)) larger than 1024.
inline bool cond_cudnn_grid_sampler(
const TensorBase& input,
const TensorBase& grid
) {
return (
at::native::cudnn_is_acceptable(input) &&
at::native::cudnn_is_acceptable(grid) &&
at::native::canUse32BitIndexMath(input) &&
at::native::canUse32BitIndexMath(grid) &&
input.dim() == 4 &&
input.sym_size(1) <= 1024);
}
} // namespace at::native

View File

@ -0,0 +1,16 @@
#pragma once
#include <ATen/core/Tensor.h>
#include <ATen/native/DispatchStub.h>
namespace at::native {
using histogramdd_fn = void(*)(const Tensor&, const std::optional<Tensor>&, bool, Tensor&, const TensorList&);
using histogramdd_linear_fn = void(*)(const Tensor&, const std::optional<Tensor>&, bool, Tensor&, const TensorList&, bool);
using histogram_select_outer_bin_edges_fn = void(*)(const Tensor& input, const int64_t N, std::vector<double> &leftmost_edges, std::vector<double> &rightmost_edges);
DECLARE_DISPATCH(histogramdd_fn, histogramdd_stub);
DECLARE_DISPATCH(histogramdd_linear_fn, histogramdd_linear_stub);
DECLARE_DISPATCH(histogram_select_outer_bin_edges_fn, histogram_select_outer_bin_edges_stub);
} // namespace at::native

View File

@ -0,0 +1,41 @@
#pragma once
#include <ATen/native/DispatchStub.h>
#include <c10/util/ArrayRef.h>
namespace at {
class Tensor;
class TensorBase;
struct TensorIterator;
struct TensorIteratorBase;
}
namespace c10 {
class Scalar;
}
namespace at::native {
using index_fn = void(*)(TensorIteratorBase &, IntArrayRef indexed_sizes, IntArrayRef indexed_strides);
using index_fill_fn = void(*)(TensorIterator & iter, int64_t dim, int64_t self_dim_size, int64_t self_dim_stride, const Scalar& source);
using index_copy_fn = void(*)(TensorIterator & iter, int64_t dim, int64_t self_dim_size, int64_t self_dim_stride);
using index_put_fn = void(*)(TensorIterator &, IntArrayRef indexed_sizes, IntArrayRef indexed_strides, bool accumulate);
using put_fn = void(*)(TensorIterator & iter, const TensorBase& self, const bool accumulate);
using take_fn = void(*)(TensorIterator & iter, const TensorBase& input);
using flip_fn = void(*)(TensorIterator &, const bool);
using masked_fill_fn = void(*)(TensorIterator &, const Scalar& scalar);
using masked_select_fn = void(*)(TensorIterator &, int64_t orig_stride);
using masked_scatter_fn = void(*)(TensorIterator &, const TensorBase &);
DECLARE_DISPATCH(index_fn, index_stub);
DECLARE_DISPATCH(index_fill_fn, index_fill_stub);
DECLARE_DISPATCH(index_copy_fn, index_copy_stub);
DECLARE_DISPATCH(index_put_fn, index_put_stub);
DECLARE_DISPATCH(put_fn, put_stub);
DECLARE_DISPATCH(take_fn, take_stub);
DECLARE_DISPATCH(flip_fn, flip_stub);
DECLARE_DISPATCH(masked_fill_fn, masked_fill_stub);
DECLARE_DISPATCH(masked_select_fn, masked_select_serial_stub);
DECLARE_DISPATCH(masked_select_fn, masked_select_stub);
DECLARE_DISPATCH(masked_scatter_fn, masked_scatter_stub);
} // namespace at::native

View File

@ -0,0 +1,160 @@
#pragma once
#include <ATen/ExpandUtils.h>
#include <ATen/native/CanUse32BitIndexMath.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/core/IListRef.h>
#include <c10/util/irange.h>
namespace at::native {
[[noreturn]]
static void invalid_mask(const Tensor & self, int64_t idx, const Tensor & mask, int64_t maskIdx) {
TORCH_CHECK_INDEX(false, "The shape of the mask ", mask.sizes(), " at index ", maskIdx,
" does not match the shape of the indexed tensor ", self.sizes(), " at index ", idx);
}
static C10_UNUSED std::vector<Tensor> expandTensors(const Tensor & self, IOptTensorListRef indices) {
// If indices come in as ByteTensor or BoolTensor (masks), expand them into the equivalent indexing by LongTensors
std::vector<Tensor> result;
for (const auto& index_opt : indices) {
if (!index_opt.has_value()) {
result.emplace_back();
} else {
const auto& index = *index_opt;
if (index.scalar_type() == kByte || index.scalar_type() == kBool) {
if (index.scalar_type() == kByte) {
TORCH_WARN("indexing with dtype torch.uint8 is now deprecated," \
" please use a dtype torch.bool instead.");
}
// The sizes of the ByteTensor mask or bool tensor must match the sizes of the
// corresponding dimensions in self
for (const auto j : c10::irange(index.dim())) {
int64_t srcIdx = static_cast<int64_t>(result.size() + j);
if (index.size(j) != self.size(srcIdx)) {
invalid_mask(self, srcIdx, index, j);
}
}
// Replace with nonzeros
auto nonzero = index.nonzero();
for (const auto j : c10::irange(index.dim())) {
result.emplace_back(nonzero.select(1, j));
}
} else {
result.emplace_back(index);
}
}
}
return result;
}
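// A worked example of the mask expansion above (added for clarity): if self has
// shape (4, 5) and indices = {mask} with a bool mask of shape (4, 5) containing k
// true entries, mask.nonzero() has shape (k, 2) and the mask is replaced by two
// long index tensors of length k (row and column indices), so self[mask] becomes
// the equivalent advanced indexing self[rows, cols].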
static C10_UNUSED void checkIndexTensorTypes(IOptTensorListRef indices, bool allow_int=false) {
for (const auto& tensor : indices) {
if (tensor.has_value() && tensor->defined()) {
auto scalarType = tensor->scalar_type();
if (allow_int) {
if (scalarType != kLong && scalarType != kByte && scalarType != kBool && scalarType != kInt) {
TORCH_CHECK_INDEX(false, "tensors used as indices must be long, int, byte or bool tensors");
}
} else {
if (scalarType != kLong && scalarType != kByte && scalarType != kBool) {
TORCH_CHECK_INDEX(false, "tensors used as indices must be long, byte or bool tensors");
}
}
}
}
}
inline torch::List<std::optional<Tensor>> toListOfOptionalTensors(ArrayRef<Tensor> list) {
torch::List<std::optional<Tensor>> result;
result.reserve(list.size());
for (const Tensor& a : list) {
result.push_back(a);
}
return result;
}
inline torch::List<std::optional<Tensor>> toListOfOptionalTensors(ArrayRef<IValue> list) {
torch::List<std::optional<Tensor>> result;
result.reserve(list.size());
for (const IValue& a : list) {
result.push_back(a.isTensor() ? std::optional<Tensor>(a.toTensor()) : std::optional<Tensor>());
}
return result;
}
static C10_UNUSED bool hasContiguousSubspace(TensorList tl) {
// true if all the non-null tensors are adjacent
auto isDefined = [](const Tensor & tensor){ return tensor.defined(); };
auto isNull = [](const Tensor & tensor){ return !tensor.defined(); };
auto start = std::find_if(tl.begin(), tl.end(), isDefined);
auto stop = std::find_if(tl.rbegin(), tl.rend(), isDefined);
auto it = std::find_if(start, stop.base(), isNull);
return it == stop.base();
}
// Transposes the tensor and indices together so that all the non-null indices
// index the first k dimensions of the tensor. Returns the transposed tensor
// and the reordered indices. For example:
// transposeToFront(tensor, {nullptr, a, nullptr, b})
// returns
// tensor.permute([1, 3, 0, 2]), {a, b, nullptr, nullptr}
static C10_UNUSED std::tuple<Tensor, std::vector<Tensor>>
transposeToFront(const Tensor& self, TensorList indices) {
std::vector<int64_t> dims;
std::vector<Tensor> transposedIndices;
dims.reserve(self.dim());
for (const auto i : c10::irange(self.dim())) {
if (indices[i].defined()) {
dims.push_back(i);
transposedIndices.emplace_back(indices[i]);
}
}
for (const auto i : c10::irange(self.dim())) {
if (!indices[i].defined()) {
dims.push_back(i);
transposedIndices.emplace_back();
}
}
return std::make_tuple(self.permute(dims), std::move(transposedIndices));
}
inline std::tuple<Tensor, std::vector<Tensor>, std::vector<int64_t>>
transposeToFrontAndInvPerm(const Tensor& self, TensorList indices) {
std::vector<int64_t> dims;
std::vector<int64_t> invPerm;
std::vector<Tensor> transposedIndices;
dims.reserve(self.dim());
invPerm.resize(self.dim());
for (const auto i : c10::irange(self.dim())) {
if (indices[i].defined()) {
dims.push_back(i);
transposedIndices.emplace_back(indices[i]);
}
}
for (const auto i : c10::irange(self.dim())) {
if (!indices[i].defined()) {
dims.push_back(i);
transposedIndices.emplace_back();
}
}
for (const auto i : c10::irange(self.dim())) {
invPerm[dims[i]] = i;
}
return std::make_tuple(self.permute(dims), std::move(transposedIndices), std::move(invPerm));
}
struct AdvancedIndex {
AdvancedIndex(const Tensor& src, TensorList indices);
Tensor src;
std::vector<Tensor> indices;
DimVector indexed_sizes;
DimVector indexed_strides;
int64_t dims_before;
int64_t dims_after;
};
} //namespace at::native

View File

@ -0,0 +1,46 @@
#pragma once
#include <ATen/native/DispatchStub.h>
#include <ATen/OpMathType.h>
#include <ATen/TensorIterator.h>
#include <c10/core/Scalar.h>
namespace at::native {
template <typename scalar_t>
C10_HOST_DEVICE C10_ALWAYS_INLINE bool is_lerp_weight_small(scalar_t weight) {
return std::abs(weight) < scalar_t(0.5);
}
template <typename scalar_t>
C10_HOST_DEVICE C10_ALWAYS_INLINE bool is_lerp_weight_small(c10::complex<scalar_t> weight) {
// Avoid the sqrt in abs(weight)
return (weight.real() * weight.real() + weight.imag() * weight.imag()) < scalar_t(0.25);
}
template <typename scalar_t, typename weight_t>
C10_HOST_DEVICE C10_ALWAYS_INLINE scalar_t lerp(scalar_t self_, scalar_t end_, weight_t weight_) {
using opmath_t = at::opmath_type<scalar_t>;
using opmath_weight_t = at::opmath_type<weight_t>;
opmath_t self = self_;
opmath_t end = end_;
opmath_weight_t weight = weight_;
// Conditional for better numerical accuracy. This has been discussed in
// https://github.com/pytorch/pytorch/pull/18871
return is_lerp_weight_small(weight)
? self + weight * (end - self)
: end - (end - self) * (opmath_t(1) - weight);
}
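// A worked example of the two branches above (self = 1, end = 3):
//   weight = 0.25 (small branch): 1 + 0.25 * (3 - 1)       = 1.5
//   weight = 0.75 (large branch): 3 - (3 - 1) * (1 - 0.75) = 2.5
// Both agree with self + weight * (end - self); the split only improves
// floating-point accuracy when weight is close to 1 (see the PR linked above).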
using lerp_fn_scalar = void (*)(
at::TensorIteratorBase& iter,
const Scalar& weight);
using lerp_fn_tensor = void (*)(
at::TensorIteratorBase& iter);
DECLARE_DISPATCH(lerp_fn_scalar, lerp_kernel_scalar_weight);
DECLARE_DISPATCH(lerp_fn_tensor, lerp_kernel_tensor_weight);
} // namespace at::native

View File

@ -0,0 +1,17 @@
#pragma once
#include <ATen/native/DispatchStub.h>
namespace c10 {
class Scalar;
}
namespace at {
struct TensorIterator;
}
namespace at::native {
using addr_fn = void (*)(TensorIterator &, const Scalar& beta, const Scalar& alpha);
DECLARE_DISPATCH(addr_fn, addr_stub);
} // namespace at::native

View File

@ -0,0 +1,623 @@
#pragma once
#include <c10/core/ScalarType.h>
#include <c10/util/irange.h>
#include <c10/util/Exception.h>
#include <c10/util/strides.h>
#include <ATen/core/Tensor.h>
#include <ATen/ExpandUtils.h>
#include <ATen/TensorUtils.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/TransposeType.h>
#include <limits>
#include <type_traits>
#include <sstream>
#include <cstring>
#include <cctype>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/arange.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/empty_strided.h>
#include <ATen/ops/zeros.h>
#endif
namespace at::native {
inline c10::MaybeOwned<Tensor> expect_resolved_conj(const Tensor& tensor) {
if (tensor.is_conj()) {
return c10::MaybeOwned<Tensor>::owned(tensor.resolve_conj());
} else {
return c10::MaybeOwned<Tensor>::borrowed(tensor);
}
}
inline DimVector batched_matrix_contiguous_strides(
const IntArrayRef sizes,
const bool f_contig = false) {
// f_contig chooses between the strides of a batch of Fortran (F-contiguous)
// and C-contiguous matrices
auto strides = c10::contiguous_strides(sizes);
auto dim = strides.size();
if (f_contig && dim >= 2) {
// Fix the strides of the last two dimensions, so that we return
// C-contiguous batches of F-contiguous matrices.
strides[dim - 1] = std::max(sizes[dim - 2], static_cast<int64_t>(1));
strides[dim - 2] = 1;
}
return strides;
}
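// For illustration: with sizes = {2, 3, 4}, c10::contiguous_strides gives
// {12, 4, 1}; with f_contig = true the last two entries are rewritten to give
// {12, 1, 3}, i.e. a C-contiguous batch of two F-contiguous 3x4 matrices.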
/*
* Clones a Tensor so that the following conditions hold:
* If we think of a Tensor as having size (B, M, N), where B is any number
* of batch dimensions, then:
* - Each (M, N) matrix is in column major form
* - Let Tensor P have size (B, M, N) and Q have size (B, M', N').
* Then when laid out in memory, the M by N matrix starting at
* P.data_ptr()[B * M * N] belongs to the same batch as the M' by N'
* matrix starting at Q.data_ptr()[B * M' * N'].
*/
inline Tensor cloneBatchedColumnMajor(const Tensor& src) {
// If src is already in batched column major format, then
// this will be efficient (no reordering of the data will occur)
// because the first transpose will make the tensor contiguous,
// and cloning a contiguous tensor is fast.
auto result = src.mT().clone(at::MemoryFormat::Contiguous);
result.transpose_(-2, -1);
return result;
}
/*
* contig chooses between C-contig (true) and F-contig (false)
*/
inline c10::MaybeOwned<Tensor> borrow_else_clone(const bool cond, const Tensor& borrow, const Tensor& clone, const bool contig) {
return cond ? c10::MaybeOwned<Tensor>::borrowed(borrow)
: c10::MaybeOwned<Tensor>::owned(contig ? clone.clone(MemoryFormat::Contiguous)
: cloneBatchedColumnMajor(clone));
}
/*
* This method is designed to be a faster alternative to
* `cloneBatchedColumnMajor` with some additional features,
* namely:
* 1. It uses `copy` instead of `clone`, which could be much faster.
* 2. The `nrows` parameter is used to create inputs with a number of rows larger
*  than in the original input, which is required for some LAPACK/MAGMA methods.
* 3. `desired_batch_sizes` is used to create copies with a batch size
*  which is either the original batch size of the input, or its larger
*  broadcasted shape.
*/
inline Tensor copyBatchedColumnMajor(const Tensor& src, int64_t nrows = -1,
at::OptionalIntArrayRef desired_batch_sizes = std::nullopt) {
nrows = (nrows == -1) ? src.size(-2) : nrows;
auto copy_sizes = desired_batch_sizes.has_value()
? desired_batch_sizes.value().vec()
: IntArrayRef(src.sizes().data(), src.dim() - 2).vec();
copy_sizes.insert(copy_sizes.end(), {nrows, src.size(-1)});
const auto copy_strides = batched_matrix_contiguous_strides(copy_sizes, /*f-contig*/true);
auto copy = at::empty_strided(copy_sizes, copy_strides, src.options());
copy.narrow(-2, 0, src.size(-2)).copy_(src);
return copy;
}
/*
* Given batches of matrices with arbitrary batch dim,
* computes the number of batches.
*/
inline int64_t batchCount(const Tensor& batched_matrices) {
int64_t result = 1;
for (int64_t i = 0; i < batched_matrices.ndimension() - 2; i++) {
result *= batched_matrices.size(i);
}
return result;
}
// Computes the number of elements of a matrix in a batched matrix tensor
inline int64_t matrixStride(const Tensor& batched_matrices) {
return batched_matrices.size(-1) * batched_matrices.size(-2);
}
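// For illustration: a tensor of shape (2, 3, 5, 4) holds 2 * 3 = 6 matrices,
// so batchCount returns 6 and matrixStride returns 5 * 4 = 20 elements per matrix.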
// Validates input shapes for operations on batches of square matrices (inverse, cholesky, symeig, eig)
inline void checkIsMatrix(const Tensor& A, const char* const f_name, const char* const arg_name = "A") {
TORCH_CHECK(A.dim() >= 2, f_name, ": The input tensor ", arg_name, " must have at least 2 dimensions.");
}
inline void squareCheckInputs(const Tensor& self, const char* const f_name, const char* const arg_name = "A") {
checkIsMatrix(self, f_name, arg_name);
TORCH_CHECK(self.sym_size(-1) == self.sym_size(-2),
f_name,
": ", arg_name, " must be batches of square matrices, "
"but they are ", self.sym_size(-2), " by ", self.sym_size(-1), " matrices");
}
inline void checkInputsSolver(const Tensor& A,
const Tensor& B,
const bool left,
const char* const f_name) {
squareCheckInputs(A, f_name, "A");
checkIsMatrix(B, f_name, "B");
TORCH_CHECK(left ? A.size(-2) == B.size(-2) : A.size(-1) == B.size(-1),
f_name, ": Incompatible shapes of A and B for the equation ",
left ? "AX = B" : "XA = B",
" (", A.size(-2), "x", A.size(-1), " and ", B.size(-2), "x", B.size(-1), ")");
}
inline bool is_row_or_column_contiguous(const Tensor& t) {
// This could be made more general, similar to how it's checked in matmul, which would allow to
// ellide the copy with strides such as (6, 12, 1, 3) or (3, 1, 9), but this is quite tricky.
// We choose to be conservative for simplicity
return t.is_contiguous() || t.transpose(-2, -1).is_contiguous();
}
inline TransposeType to_transpose_type(const bool contig, const bool conj) {
if (conj) {
if (contig) { TORCH_INTERNAL_ASSERT(false, "Invalid transpose type"); }
else { return TransposeType::ConjTranspose; }
} else {
if (contig) { return TransposeType::NoTranspose; }
else { return TransposeType::Transpose; }
}
}
// This function is designed to be used with linear algebra methods that minimize
// L(ax - b) = 0, where L is generally the identity map (`solve`, for example)
// or the L2 norm (`lstsq`).
// It is expected that `a` and `b` are contiguous tensors of column-major matrices
// (so that a.view({-1, a.size(-2), a.size(-1)}) succeeds, same for `b`),
// with the following additional properties:
//
// 1. a.dim() == b.dim()
// 2. a.shape[:-2] broadcasts over b.shape[:-2]
// 3. a.size(i) <= b.size(i) for i=0,..., a.dim() - 3 (only for batch dimensions)
//
// MAGMA/LAPACK modify tensor `a` in-place, and the main goal of this method
// is to be memory efficient, which means that if there exists an index i such that
// a.shape[i] < b.shape[i], 0 <= i <= a.dim() - 3,
// then instead of materializing copies of `a` in the broadcasted shape, we keep
// a buffer copy of `a` along with flags that check whether specific batch dimension
// indices for `a` were already accessed. If they were, we copy the data from the buffer
// into `a`. The number of copies does not exceed
// prod(max(a.shape[:-2], b.shape[:-2]) - a.shape[:-2] + 1)
// and this value is attained by tensors with non-empty batch dimensions.
//
// func_t `f` is a callable that is being supplied with
// scalar_t* a_working_ptr, scalar_t* b_working_ptr, int64_t a_linear_batch_idx.
// a_working_ptr and b_working_ptr can directly be passed to LAPACK/MAGMA routines,
// and a_linear_batch_idx is an index in the 3d representation which corresponds to
// the memory a_working_ptr points to, in other words:
// a_working_ptr == a.view({-1, a.size(-2), a.size(-1)}.select(0, a_linear_batch_idx).data_ptr<scalar_t>();
// a_linear_batch_idx is useful to store metadata related to `a`, such as, for example,
// its rank or singular values (see linalg_lstsq).
template<typename scalar_t, typename func_t>
void batch_iterator_with_broadcasting(const Tensor& a, const Tensor& b, const func_t& f) {
IntArrayRef a_batch_sizes(a.sizes().data(), a.dim() - 2);
IntArrayRef b_batch_sizes(b.sizes().data(), b.dim() - 2);
auto a_linear_batch_idx = at::arange(batchCount(a)).view(a_batch_sizes);
auto b_linear_batch_idx = at::arange(batchCount(b)).view(b_batch_sizes);
TensorIterator iter = TensorIteratorConfig()
.set_check_mem_overlap(false)
.check_all_same_dtype(false)
.resize_outputs(false)
.add_output(b_linear_batch_idx)
.add_input(a_linear_batch_idx)
.build();
auto m = a.size(-2);
auto n = a.size(-1);
auto a_3d = a.view({batchCount(a), m, n});
auto b_3d = b.view({batchCount(b), b.size(-2), b.size(-1)});
auto a_broadcasts_over_b = (a_batch_sizes != b_batch_sizes);
Tensor a_buffer, a_was_accessed, a_buffer_3d;
std::function<void(int64_t)> check_if_copy_needed_for_a
= [](int64_t /*a_curr_linear_batch_idx*/){};
if (a_broadcasts_over_b) {
a_buffer = at::empty_strided(a.sizes(), a.strides(), a.options())
.copy_(a);
a_was_accessed = at::zeros(batchCount(a), at::kBool);
a_buffer_3d = a_buffer.view({batchCount(a), m, n});
check_if_copy_needed_for_a = [&](int64_t a_curr_linear_batch_idx) {
auto* a_was_accessed_flag = a_was_accessed
.select(0, a_curr_linear_batch_idx)
.data_ptr<bool>();
if (!(*a_was_accessed_flag)) {
*a_was_accessed_flag = true;
}
else {
a_3d.select(0, a_curr_linear_batch_idx)
.copy_(a_buffer_3d.select(0, a_curr_linear_batch_idx));
}
};
}
auto loop = [&](char** data, const int64_t* strides, int64_t nelems) {
auto* b_batch_idx_ptr = data[0];
auto* a_batch_idx_ptr = data[1];
for (const auto elem C10_UNUSED : c10::irange(nelems)) {
auto b_curr_linear_batch_idx = *reinterpret_cast<int64_t*>(b_batch_idx_ptr);
auto a_curr_linear_batch_idx = *reinterpret_cast<int64_t*>(a_batch_idx_ptr);
check_if_copy_needed_for_a(a_curr_linear_batch_idx);
auto* a_working_ptr = a_3d.select(0, a_curr_linear_batch_idx)
.data_ptr<scalar_t>();
auto* b_working_ptr = b_3d.select(0, b_curr_linear_batch_idx)
.data_ptr<scalar_t>();
f(a_working_ptr, b_working_ptr, a_curr_linear_batch_idx);
b_batch_idx_ptr += strides[0];
a_batch_idx_ptr += strides[1];
}
};
iter.serial_for_each(loop, {0, batchCount(b)});
}
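// A minimal usage sketch (hypothetical shapes, not taken from any caller in
// this file): `a` has batch shape (1, 2) broadcasting over `b`'s batch shape
// (3, 2); both are prepared as contiguous column-major batches with
// cloneBatchedColumnMajor, and the callable receives raw pointers that could
// be handed to a LAPACK/MAGMA routine:
//
//   auto a = cloneBatchedColumnMajor(at::randn({1, 2, 4, 4}, at::kDouble));
//   auto b = cloneBatchedColumnMajor(at::randn({3, 2, 4, 1}, at::kDouble));
//   batch_iterator_with_broadcasting<double>(a, b,
//       [](double* a_working_ptr, double* b_working_ptr, int64_t a_linear_batch_idx) {
//         // invoke the LAPACK/MAGMA routine on (a_working_ptr, b_working_ptr) here
//       });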
// Returns the epsilon value for floating types except half
inline double _get_epsilon(const ScalarType& sc_type) {
switch (sc_type) {
case at::ScalarType::Float:
return static_cast<double>(std::numeric_limits<float>::epsilon());
case at::ScalarType::Double:
return std::numeric_limits<double>::epsilon();
default:
AT_ERROR("This function doesn't handle types other than float and double");
}
}
// Validates input shapes and devices
// for linear solve methods (solve, cholesky_solve, lu_solve, triangular_solve)
inline void linearSolveCheckInputs(const Tensor& self, const Tensor& A, const char* name) {
TORCH_CHECK(self.device() == A.device(),
"Expected b and A to be on the same device, but found b on ",
self.device(), " and A on ", A.device(), " instead.");
TORCH_CHECK(self.scalar_type() == A.scalar_type(),
"Expected b and A to have the same dtype, but found b of type ",
self.scalar_type(), " and A of type ", A.scalar_type(), " instead.");
TORCH_CHECK(A.size(-1) == A.size(-2),
"A must be batches of square matrices, "
"but they are ", A.size(-2), " by ", A.size(-1), " matrices");
TORCH_CHECK(A.size(-1) == self.size(-2),
"Incompatible matrix sizes for ", name, ": each A "
"matrix is ", A.size(-1), " by ", A.size(-1),
" but each b matrix is ", self.size(-2), " by ", self.size(-1));
}
inline void checkFloatingOrComplex(const Tensor& t, const char* const f_name, const bool allow_low_precision_dtypes=true) {
auto dtype = t.scalar_type();
TORCH_CHECK((at::isFloatingType(dtype) || at::isComplexType(dtype)),
f_name, ": Expected a floating point or complex tensor as input. Got ", dtype);
if (!allow_low_precision_dtypes) {
TORCH_CHECK(dtype == kFloat || dtype == kDouble || dtype == kComplexFloat || dtype == kComplexDouble,
f_name, ": Low precision dtypes not supported. Got ", dtype);
}
}
// Checks if all the Tensors in a TensorList are of the same dimensions
inline void checkAllSameDim(TensorList tensors, int64_t dim) {
for (auto &t : tensors) {
TORCH_CHECK(t.dim() == dim, "Tensor dimension is ", t.dim(), ", expected ", dim, " instead.");
}
}
inline std::tuple<std::vector<int64_t>, std::vector<int64_t>> _linalg_broadcast_batch_dims(const Tensor& arg1, const Tensor& arg2) {
// broadcast the batch dimensions of arg1 and arg2.
IntArrayRef arg1_batch_sizes(arg1.sizes().data(), arg1.ndimension() - 2);
IntArrayRef arg2_batch_sizes(arg2.sizes().data(), arg2.ndimension() - 2);
std::vector<int64_t> expand_batch_portion = infer_size(arg1_batch_sizes, arg2_batch_sizes);
std::vector<int64_t> arg1_expand_size({expand_batch_portion});
arg1_expand_size.insert(arg1_expand_size.end(), { arg1.size(-2), arg1.size(-1) });
std::vector<int64_t> arg2_expand_size({expand_batch_portion});
arg2_expand_size.insert(arg2_expand_size.end(), { arg2.size(-2), arg2.size(-1) });
return std::make_tuple(std::move(arg1_expand_size), std::move(arg2_expand_size));
}
inline std::tuple<Tensor,Tensor> _linalg_broadcast_batch_dims(const Tensor& arg1, const Tensor& arg2, const char* name) {
// If there's no name we assume we don't want to check the errors
if (name != nullptr) {
linearSolveCheckInputs(arg1, arg2, name);
}
auto [arg1_expand_size, arg2_expand_size] = at::native::_linalg_broadcast_batch_dims(arg1, arg2);
auto arg1_broadcasted = arg1_expand_size == arg1.sizes() ? arg1 : arg1.expand(arg1_expand_size);
auto arg2_broadcasted = arg2_expand_size == arg2.sizes() ? arg2 : arg2.expand(arg2_expand_size);
return std::make_tuple(arg1_broadcasted, arg2_broadcasted);
}
inline std::vector<int64_t> broadcast_batch_size(const Tensor& t1, const Tensor& t2, int64_t n_batch_dims) {
IntArrayRef t1_batch_sizes(t1.sizes().data(), n_batch_dims);
IntArrayRef t2_batch_sizes(t2.sizes().data(), n_batch_dims);
auto broadcasted_batch_sizes = infer_size(t1_batch_sizes, t2_batch_sizes);
return broadcasted_batch_sizes;
}
// Returns the tensor with the given axes permuted to the end.
inline Tensor _move_to_end(const Tensor& self, IntArrayRef axes) {
const std::vector<int64_t> a = axes.vec();
const int64_t ndim = self.ndimension();
std::vector<int64_t> perm;
for (const auto i : c10::irange(ndim)) {
auto it = std::find(a.begin(), a.end(), i);
if (it == a.end()) {
perm.push_back(i);
}
}
for (auto i : a) {
perm.push_back(i);
}
TORCH_CHECK((int64_t)perm.size() == ndim,
"duplicate or invalid axis in 'dim' argument for tensor with ndim==", ndim);
return self.permute(perm);
}
// parse the "mode" param in linalg_qr: return a tuple of bools (compute_q, reduced)
inline std::tuple<bool, bool> _parse_qr_mode(c10::string_view mode) {
bool compute_q;
bool reduced;
if (mode == "reduced") {
compute_q = true;
reduced = true;
} else if (mode == "complete") {
compute_q = true;
reduced = false;
} else if (mode == "r") {
compute_q = false;
reduced = true; // this is actually irrelevant in this mode
} else {
TORCH_CHECK(false, "qr received unrecognized mode '", mode,
"' but expected one of 'reduced' (default), 'r', or 'complete'");
}
return std::make_tuple(compute_q, reduced);
}
// Function to compute the sizes, strides and number of columns for the Q matrix in the QR decomposition
inline std::tuple<DimVector, DimVector, int64_t> _compute_geometry_for_Q(
const Tensor& input,
bool reduced) {
int64_t m = input.size(-2), n = input.size(-1);
int64_t n_columns_q;
// We need to compute the required size of Q based on the `reduced` option
DimVector q_sizes(input.sizes());
if (!reduced && m > n) {
q_sizes[input.dim() - 1] = m;
n_columns_q = m;
} else {
q_sizes[input.dim() - 1] = n;
n_columns_q = std::min(m, n);
}
auto q_strides = batched_matrix_contiguous_strides(q_sizes, /*f-contig*/true);
return std::make_tuple(q_sizes, q_strides, n_columns_q);
}
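// For illustration: for an input of shape (..., 5, 3), mode "reduced" yields
// q_sizes ending in (5, 3) with n_columns_q = 3, while mode "complete"
// (reduced = false, m > n) yields q_sizes ending in (5, 5) with n_columns_q = 5.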
inline bool svd_uses_cusolver(const Tensor& A) {
// if cusolver is available, it is used unless the user explicitly prefers the MAGMA backend
return A.is_cuda()
&& at::globalContext().hasCuSOLVER()
&& at::globalContext().linalgPreferredBackend() != at::LinalgBackend::Magma;
}
// Function used instead of .to so that the original strides are retained
// .to doesn't retain strides and makes the output tensor contiguous
inline Tensor same_stride_to(const Tensor& original_tensor, const at::TensorOptions& options) {
auto strided_to = at::empty_strided(original_tensor.sizes(),
original_tensor.strides(),
options);
strided_to.copy_(original_tensor);
return strided_to;
}
// Creates a dimension permutation array that can be given to `at::permute()`, which will shift
// the two specified dimensions to the end of a tensor, without changing the order of
// the other dimensions. `dim1` will be placed at the very end, and `dim0` will be
// placed just to the left of it.
//
// For instance, given a 4-D tensor, dimensions 1 and 3 can be shifted to the end by
// calling `create_dim_backshift_permutation(1, 3, 4)`. The resulting vector will
// be `vec(0, 2, 1, 3)`.
inline std::vector<int64_t> create_dim_backshift_permutation(int64_t dim0, int64_t dim1, int64_t ndim) {
TORCH_CHECK(
(dim0 != dim1) && (dim0 < ndim) && (dim0 >= 0) && (dim1 < ndim) && (dim1 >= 0),
"duplicate or invalid dimensions");
std::vector<int64_t> permutation(ndim);
int64_t cur_permuted_dim = 0;
for (const auto dim_ind : c10::irange(ndim)) {
if ((dim_ind != dim0) && (dim_ind != dim1)) {
permutation[cur_permuted_dim++] = dim_ind;
}
}
permutation[cur_permuted_dim++] = dim0;
permutation[cur_permuted_dim] = dim1;
return permutation;
}
// Creates a dimension permutation array that can be given to `at::permute()`, which
// will reverse a given permutation.
// The reverse permutation array is created by swapping the indices and their
// associated values from the given permutation array.
inline std::vector<int64_t> create_reverse_permutation(std::vector<int64_t> permutation) {
int64_t ndim = permutation.size();
std::vector<int64_t> reverse_permutation(ndim);
for (const auto dim_ind : c10::irange(ndim)) {
reverse_permutation[permutation[dim_ind]] = dim_ind;
}
return reverse_permutation;
}
// Compute R-work array size for MAGMA/LAPACK cgesdd/zgesdd
// See https://github.com/Reference-LAPACK/lapack/blob/122506cd8b6ce050a200920c3d4c0b153b150fd8/SRC/cgesdd.f#L186
inline int64_t computeLRWorkDim(const char jobz, int64_t m, int64_t n) {
auto mn = std::min(m, n);
auto mx = std::max(m, n);
if (jobz == 'N') {
#ifdef __APPLE__
// According to `vecLib.framework/Headers/clapack.h` Accelerate.framework is based on LAPACK 3.2.1
return 7 * mn;
#else
// This setting is valid for LAPACK 3.6+
return 5 * mn;
#endif
}
if (mx > 10 * mn) {
return 5 * mn * mn + 5 * mn;
}
return std::max(5 * mn * mn + 5 * mn, 2 * mx * mn + 2 * mn * mn + mn);
}
// This function checks whether the uplo argument input is valid
// Allowed strings are "u", "U", "l", "L"
inline void checkUplo(const c10::string_view uplo) {
// To use std::toupper safely with plain chars (or signed chars), the argument should first be converted to unsigned char
char uplo_uppercase = static_cast<char>(std::toupper(static_cast<unsigned char>(uplo[0])));
TORCH_CHECK(uplo.size() == 1 && (uplo_uppercase == 'U' || uplo_uppercase == 'L'),
"Expected UPLO argument to be 'L' or 'U', but got ", uplo);
}
inline void checkSameDevice(const std::string& fn_name, Tensor result, Tensor input, const std::string& result_name = "result") {
TORCH_CHECK(
result.device() == input.device(),
fn_name,
": Expected ", result_name, " and input tensors to be on the same device, but got ",
result_name, " on ", result.device(), " and input on ", input.device());
}
// Check the dtype of result and input tensors (for _out variants).
// Most linear algebra functions have the same dtype for input and output
// (either floating or complex type input), so we can check whether input's dtype can be casted to result's dtype.
// According to https://github.com/pytorch/pytorch/wiki/Developer-FAQ#how-does-out-work-in-pytorch
// c10::canCast is used for checking the "safe copy" dtype requirements.
inline void checkLinalgCompatibleDtype(const std::string& fn_name, Tensor result, Tensor input, const std::string& result_name = "result") {
bool can_cast = c10::canCast(input.scalar_type(), result.scalar_type());
TORCH_CHECK(
can_cast,
fn_name,
": Expected ", result_name, " to be safely castable from ", input.scalar_type(), " dtype, but got ",
result_name, " with dtype ", result.scalar_type());
}
// Alternatively, we can check whether the specific expected output type (result_type) can be safely casted to out tensor dtype (out_type)
inline void checkLinalgCompatibleDtype(const std::string& fn_name, ScalarType out_type, ScalarType result_type, const std::string& out_name = "result") {
bool can_cast = c10::canCast(result_type, out_type);
TORCH_CHECK(
can_cast,
fn_name,
": Expected ", out_name, " to be safely castable from ", result_type, " dtype, but got ",
out_name, " with dtype ", out_type);
}
inline void checkNotComplexTolerance(const Tensor& tol, const c10::string_view f_name, const c10::string_view tol_name) {
TORCH_CHECK(!at::isComplexType(tol.scalar_type()),
f_name, ": ", tol_name, " tensor of complex type is not supported. Got ", tol.scalar_type());
}
/*
Two types of 'other' tensors are supported when solving
a system of linear equations matmul(input, x) = other:
* 1-dimensional (1D) tensor or batch of 1D tensors (vector case)
* 2-dimensional (2D) tensor or batch of 2D tensors (matrix case).
The original torch.solve supported only the matrix case, while NumPy works for both cases.
For the batched input we need to be able to distinguish them.
Let input.shape = (batch_dimensions, m, n), then 'other' is of vector type if other.shape == (batch_dimensions, m).
This rule is compatible with NumPy, see https://github.com/numpy/numpy/blob/v1.20.0/numpy/linalg/linalg.py#L384-L389
*/
inline bool linalg_solve_is_vector_rhs(const Tensor& input, const Tensor& other) {
auto expected_batched_rhs_shape = SymIntArrayRef(input.sym_sizes().data(), input.dim() - 1); // input.shape[:-1]
bool vector_case = other.dim() == 1 || (input.dim() - 1 == other.dim() && other.sym_sizes().equals(expected_batched_rhs_shape));
return vector_case;
}
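// For illustration: with input of shape (2, 3, 3), an `other` of shape (2, 3)
// is treated as a batch of vectors (returns true), while an `other` of shape
// (2, 3, 1) is treated as a batch of 3x1 matrices (returns false).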
/*
Computes linear indices for a tensor with original_shape to access its elements as if it were a materialized broadcast tensor.
*/
inline Tensor get_linear_indices(int64_t numel, IntArrayRef original_shape, IntArrayRef broadcast_shape) {
TensorOptions options = at::TensorOptions().dtype(at::kLong).device(at::kCPU);
return at::arange(numel, options).view(original_shape).broadcast_to(broadcast_shape).contiguous();
}
class BroadcastLinearIndices {
private:
Tensor linear_indices_;
bool is_broadcasting_;
public:
BroadcastLinearIndices(
int64_t numel,
IntArrayRef original_shape,
IntArrayRef broadcast_shape) : is_broadcasting_(!original_shape.equals(broadcast_shape)) {
// The assumption is that the broadcast_shape is a materialized broadcast
// shape of the original_shape. We need to compute the linear indices
// compatible with the original_shape to access the elements in the original
// tensor corresponding to the broadcast tensor.
if (is_broadcasting_) {
linear_indices_ =
get_linear_indices(numel, original_shape, broadcast_shape);
}
}
int64_t operator()(int64_t broadcast_linear_index) {
return is_broadcasting_
? linear_indices_.data_ptr<int64_t>()[broadcast_linear_index]
: broadcast_linear_index;
}
};
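// A small usage sketch: for original_shape {3, 1} broadcast to {3, 4} with
// numel = 3, the precomputed index tensor flattens to
// {0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2}, so
//   BroadcastLinearIndices idx(3, {3, 1}, {3, 4});
//   idx(5);  // == 1, the original element backing broadcast element 5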
inline bool is_blas_compatible_column_major_order(const Tensor& input) {
IntArrayRef input_strides = input.strides();
IntArrayRef input_sizes = input.sizes();
auto ndim = input.dim();
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(ndim >= 2);
if (ndim > 3) {
return input.transpose(-2, -1).is_contiguous();
}
auto leading_dimension = input_strides[ndim - 1];
auto rows = input_sizes[ndim - 2];
bool batch_stride_compatible = true;
if (ndim == 3) {
auto cols = input_sizes[ndim - 1];
batch_stride_compatible =
input_strides[ndim - 3] >= leading_dimension * cols;
}
return (input_strides[ndim - 2] == 1) &&
(leading_dimension >= std::max<int64_t>(1, rows)) &&
batch_stride_compatible;
}
inline bool is_blas_compatible_row_major_order(const Tensor& input) {
IntArrayRef input_strides = input.strides();
IntArrayRef input_sizes = input.sizes();
auto ndim = input.dim();
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(ndim >= 2);
if (ndim > 3) {
return input.is_contiguous();
}
auto leading_dimension = input_strides[ndim - 2];
auto cols = input_sizes[ndim - 1];
bool batch_stride_compatible = true;
if (ndim == 3) {
auto rows = input_sizes[ndim - 2];
batch_stride_compatible =
input_strides[ndim - 3] >= leading_dimension * rows;
}
return (input_strides[ndim - 1] == 1) &&
(leading_dimension >= std::max<int64_t>(1, cols)) &&
batch_stride_compatible;
}
} // namespace at::native

View File

@ -0,0 +1,69 @@
#pragma once
#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorUtils.h>
namespace at::native {
inline void multilabel_margin_loss_shape_check(
int64_t& nframe,
int64_t& dim,
const int64_t& ndims,
const Tensor& input,
const Tensor& target) {
TORCH_CHECK(
(ndims == 2 && input.size(1) != 0) || (ndims == 1 && input.size(0) != 0) || ndims == 0,
"Expected non-empty vector or matrix with optional 0-dim batch size, but got: ",
input.sizes());
if (ndims <= 1) {
nframe = 1;
dim = ndims == 0 ? 1 : input.size(0);
TORCH_CHECK(
target.dim() <= 1 && target.numel() == dim,
"inconsistent target size: ", target.sizes(), " for input of size: ",
input.sizes());
} else {
nframe = input.size(0);
dim = input.size(1);
TORCH_CHECK(
target.dim() == 2 && target.size(0) == nframe &&
target.size(1) == dim,
"inconsistent target size: ", target.sizes(), " for input of size: ",
input.sizes());
}
}
inline void multi_margin_loss_shape_check(
int64_t& nframe,
int64_t& dim,
const int64_t& ndims,
const Tensor& input,
const Tensor& target,
const std::optional<Tensor>& weight) {
TORCH_CHECK(
(ndims == 2 && input.size(1) != 0) || (ndims == 1 && input.size(0) != 0) || ndims == 0,
"Expected non-empty vector or matrix with optional 0-dim batch size, but got: ",
input.sizes());
if (ndims <= 1) {
nframe = 1;
dim = ndims == 0 ? 1 : input.size(0);
} else {
nframe = input.size(0);
dim = input.size(1);
}
TORCH_CHECK(
target.dim() <= 1 && target.numel() == nframe,
"inconsistent target size, expected ", nframe, " but got ",
target.sizes());
if (weight && weight->defined()) {
TORCH_CHECK(
weight->dim() <= 1 && weight->numel() == dim,
"inconsistent weight size, expected ", dim, " but got ",
weight->sizes());
}
}
} // namespace at::native

File diff suppressed because it is too large

View File

@ -0,0 +1,71 @@
#pragma once
namespace at {
// views and their in-place version ops
#define TORCH_VIEW_FNS(m) \
m.impl("as_strided_", torch::CppFunction::makeFallthrough()); \
m.impl("detach", torch::CppFunction::makeFallthrough()); \
m.impl("detach_", torch::CppFunction::makeFallthrough()); \
m.impl("diagonal", torch::CppFunction::makeFallthrough()); \
m.impl("expand", torch::CppFunction::makeFallthrough()); \
m.impl("expand_as", torch::CppFunction::makeFallthrough()); \
m.impl("movedim.int", torch::CppFunction::makeFallthrough()); \
m.impl("movedim.intlist", torch::CppFunction::makeFallthrough()); \
m.impl("narrow", torch::CppFunction::makeFallthrough()); \
m.impl("permute", torch::CppFunction::makeFallthrough()); \
m.impl("select.Dimname", torch::CppFunction::makeFallthrough()); \
m.impl("select.int", torch::CppFunction::makeFallthrough()); \
m.impl("squeeze", torch::CppFunction::makeFallthrough()); \
m.impl("squeeze_", torch::CppFunction::makeFallthrough()); \
m.impl("transpose.int", torch::CppFunction::makeFallthrough()); \
m.impl("transpose.Dimname", torch::CppFunction::makeFallthrough()); \
m.impl("transpose_", torch::CppFunction::makeFallthrough()); \
m.impl("t", torch::CppFunction::makeFallthrough()); \
m.impl("t_", torch::CppFunction::makeFallthrough()); \
m.impl("real", torch::CppFunction::makeFallthrough()); \
m.impl("imag", torch::CppFunction::makeFallthrough()); \
m.impl("view_as_real", torch::CppFunction::makeFallthrough()); \
m.impl("unflatten.int", torch::CppFunction::makeFallthrough()); \
m.impl("unflatten.Dimname", torch::CppFunction::makeFallthrough()); \
m.impl("unfold", torch::CppFunction::makeFallthrough()); \
m.impl("unsqueeze", torch::CppFunction::makeFallthrough()); \
m.impl("unsqueeze_", torch::CppFunction::makeFallthrough()); \
m.impl("view_as", torch::CppFunction::makeFallthrough()); \
m.impl("unbind.int", torch::CppFunction::makeFallthrough()); \
m.impl("unbind.Dimname", torch::CppFunction::makeFallthrough()); \
m.impl("split.Tensor", torch::CppFunction::makeFallthrough()); \
m.impl("split_with_sizes", torch::CppFunction::makeFallthrough()); \
m.impl("swapaxes", torch::CppFunction::makeFallthrough()); \
m.impl("swapdims", torch::CppFunction::makeFallthrough()); \
m.impl("chunk", torch::CppFunction::makeFallthrough()); \
m.impl("reshape", torch::CppFunction::makeFallthrough()); \
m.impl("alias", torch::CppFunction::makeFallthrough()); \
m.impl("hsplit.int", torch::CppFunction::makeFallthrough()); \
m.impl("hsplit.array", torch::CppFunction::makeFallthrough()); \
m.impl("dsplit.int", torch::CppFunction::makeFallthrough()); \
m.impl("dsplit.array", torch::CppFunction::makeFallthrough()); \
m.impl("vsplit.int", torch::CppFunction::makeFallthrough()); \
m.impl("vsplit.array", torch::CppFunction::makeFallthrough()); \
m.impl("conj", torch::CppFunction::makeFallthrough()); \
m.impl("_conj", torch::CppFunction::makeFallthrough()); \
m.impl("_unsafe_view", torch::CppFunction::makeFallthrough()); \
m.impl("resize_", torch::CppFunction::makeFallthrough());
#define TENSOR_UTILITIES_AND_CONSTRUCTORS(m) \
m.impl("empty_like", torch::CppFunction::makeFallthrough()); \
m.impl("empty.memory_format", torch::CppFunction::makeFallthrough()); \
m.impl("empty.out", torch::CppFunction::makeFallthrough()); \
m.impl("empty_strided", torch::CppFunction::makeFallthrough()); \
m.impl("full_like", torch::CppFunction::makeFallthrough()); \
m.impl("stride.int", torch::CppFunction::makeFallthrough()); \
m.impl("stride.Dimname", torch::CppFunction::makeFallthrough()); \
m.impl("size.int", torch::CppFunction::makeFallthrough()); \
m.impl("size.Dimname", torch::CppFunction::makeFallthrough()); \
m.impl("is_complex", torch::CppFunction::makeFallthrough()); \
m.impl("is_floating_point", torch::CppFunction::makeFallthrough()); \
m.impl("requires_grad_", torch::CppFunction::makeFallthrough());
}
#define TORCH_VIEW_FNS_NATIVE_FN_REGISTRATION(m) \
m.impl("as_strided", torch::CppFunction::makeFallthrough()); \
m.impl("view", torch::CppFunction::makeFallthrough());

View File

@ -0,0 +1,157 @@
#include <ATen/core/Tensor.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/native/UnaryOps.h>
#include <ATen/native/Resize.h>
#include <c10/util/irange.h>
#include <torch/library.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/clone.h>
#include <utility>
#endif
namespace at::native {
// This fallback should only be used for operations that are self-inverse and have a corresponding tensor
// bit (internally implemented using DispatchKey) that maintains the state on the tensor.
// Currently there are two tensor bits that trigger this fallback: the conjugate bit and the negative bit.
// The conjugate bit is set on a tensor when `.conj()` is called, and the neg bit is set on a tensor when `.conj().imag` is called.
// NOTE: To use this fallback, `clone` and `copy_` should fully understand and be able to correctly handle the semantics of your math bit.
struct MathOpFallback {
MathOpFallback(DispatchKey key_, string op_name_) : key(key_), op_name(std::move(op_name_)) {}
virtual bool is_bit_set(const Tensor&) = 0;
void fallback_impl(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
/*
Situations to handle:
1. Out-of-place operation. Easy: materialize all inputs and
call it a day.
2. Inplace operation. Desugar x.add_(2) into x.conj_().add_(2).conj_().
Materialize other inputs as in (1).
3. out= operation. Desugar add(x, 2, out=y) into y.copy_(add(x, 2))
Materialize other inputs as in (1).
It is important to be able to tell if we READ from an argument and if we
WRITE to an argument. The conservative approach is to assume that we always
READ from an argument, but in out= operations you can skip conjugating, on
entry, inputs that never get used. In the current schema we can't easily tell
whether the operation is an in-place or an out= operation.
Note:
1. Mutable tensorlists containing tensors whose math bit is set to true are disallowed.
2. Mutable tensors with the math bit set to true are unconditionally cloned to ensure
correct behavior in the case when the mutable tensor shares memory with non-mutable arguments.
If we were to resolve the math bit in-place for mutable inputs, then non-mutable inputs sharing partial or full memory
with these mutable inputs would read wrong values in the following cases:
1. Non-mutable inputs have their math bit set to false.
2. The math bit for mutable input(s) is resolved before the non-mutable inputs (with the bit set to true and sharing memory
with one or more mutable arg(s)) are cloned.
At the end, the final values of the mutable arguments on the stack are copied back into the original mutable tensor inputs.
*/
const auto& arguments = op.schema().arguments();
const auto num_arguments = arguments.size();
const auto stack_start = stack->size() - num_arguments;
std::optional<bool> is_write;
for (const auto i : c10::irange(num_arguments)) {
// Three possible states:
// 1. alias_info has no value --> out-of-place operation
// 2. alias_info does have a value, alias_info->is_write=True --> in-place or out= operation
// 3. alias_info does have a value, alias_info->is_write=False --> view operation
const AliasInfo* alias_info = arguments[i].alias_info();
if (alias_info != nullptr) {
if (is_write.has_value()) {
TORCH_CHECK(*is_write == alias_info->isWrite(),
"Unsupported operator for ", op_name, " fallback: ", op.schema().name(),
op_name, " fallback doesn't work for operators with a mix "
"mutable and non-mutable inputs that alias with outputs, "
"this must be implemented manually. "
"If you got this error on a core op, please report a bug to PyTorch.");
} else {
is_write = alias_info->isWrite();
}
}
}
if (is_write.has_value() && !*is_write) {
// We assume that view operators automatically handle the math bit
// correctly by propagating the dispatch key in key_set.
// This is not necessarily always right, so you should test these cases.
op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, key), stack);
return;
}
// Mutable inputs with math bit set to True and their clones
std::vector<std::pair<Tensor, Tensor>> mutable_inputs_with_their_clones;
for (const auto i : c10::irange(num_arguments)) {
auto& ivalue = (*stack)[stack_start + i];
if (!(ivalue.isTensor() || ivalue.isTensorList())) {
continue;
}
const auto& argument = arguments[i];
bool mut_arg = false;
if (argument.alias_info()) {
// Was already tested by is_write loop above
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(argument.alias_info()->isWrite());
mut_arg = true;
}
if (ivalue.isTensor()) {
if (!is_bit_set(ivalue.toTensor())) {
continue;
}
auto tensor = std::move(ivalue).toTensor();
auto resolved_tensor = at::clone(tensor);
if (mut_arg) {
TORCH_CHECK(mutable_inputs_with_their_clones.empty(), op_name, " fallback does not support operators with more than one mutable tensor with ",
op_name, " bit set to true.");
mutable_inputs_with_their_clones.emplace_back(std::move(tensor), resolved_tensor);
}
(*stack)[stack_start + i] = std::move(resolved_tensor);
} else if (ivalue.isTensorList()) {
auto tensors = std::move(ivalue).toTensorList();
for(const auto j : c10::irange(tensors.size())) {
const auto& tensor = tensors[j];
if (!is_bit_set(tensor)) {
continue;
}
TORCH_CHECK(!mut_arg, " fallback doesn't currently support mutable TensorLists with ",
op_name, " inputs. Please materialize all the ", op_name, " input tensor(s) in the mutable TensorList inputs before calling ",
op.schema().name());
tensors[j] = at::clone(tensor);
}
(*stack)[stack_start + i] = std::move(tensors);
}
}
op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, key), stack);
TORCH_INTERNAL_ASSERT(mutable_inputs_with_their_clones.size() <= 1);
for (std::pair<Tensor, Tensor> mut_tensors: mutable_inputs_with_their_clones) {
auto& mutable_input = mut_tensors.first;
auto& cloned_mutable_input = mut_tensors.second;
auto& ivalue = (*stack)[stack_start];
auto returned_output = std::move(ivalue).toTensor();
// sanity check to ensure that the tensor in stack aliases the cloned_mutable_input
TORCH_INTERNAL_ASSERT(cloned_mutable_input.is_same(returned_output));
// necessary for out= arg
at::native::resize_output(mutable_input, returned_output.sizes());
mutable_input.copy_(returned_output);
(*stack)[stack_start] = std::move(mutable_input);
}
}
virtual ~MathOpFallback() = default;
DispatchKey key;
string op_name;
};
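// A minimal sketch (the real registrations live in the corresponding fallback
// .cpp files and may differ) of how a concrete math-bit fallback could
// subclass MathOpFallback, assuming DispatchKey::Conjugate and
// Tensor::is_conj() as used elsewhere in ATen:
//
//   struct ConjFallbackSketch : MathOpFallback {
//     ConjFallbackSketch() : MathOpFallback(DispatchKey::Conjugate, "conjugate") {}
//     bool is_bit_set(const Tensor& t) override {
//       return t.is_conj();
//     }
//   };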
} // namespace at::native

View File

@ -0,0 +1,97 @@
#pragma once
#include <ATen/core/Tensor.h>
#include <ATen/Parallel.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/Pool.h>
namespace at::native {
inline void check_max_pool1d(
const Tensor& self,
IntArrayRef kernel_size,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation,
bool ceil_mode) {
TORCH_CHECK(
self.dim() == 2 || self.dim() == 3,
"max_pool1d() Expected 2D or 3D input tensor, but got ", self.sym_sizes());
TORCH_CHECK(
kernel_size.size() == 1,
"max_pool1d() kernel_size must be an int, list of ints or tuple of ints of size 1 but got size ",
kernel_size.size());
TORCH_CHECK(
stride.empty() || stride.size() == 1,
"max_pool1d() stride must be None, an int, list of ints, or tuple of ints of size 1 but got size ",
stride.size());
TORCH_CHECK(
padding.size() == 1,
"max_pool1d() padding must be an int, list of ints, or tuple of ints of size 1 but got size ",
padding.size());
TORCH_CHECK(
dilation.size() == 1,
"max_pool1d() dilation must be an int, list of ints or tuple of ints of size 1 but got size ",
dilation.size());
// If stride=None then set it to kernel_size
if (stride.empty()) {
stride = kernel_size;
}
TORCH_CHECK(
kernel_size[0] > 0,
"max_pool1d() kernel_size must be greater than zero, but got ",
kernel_size[0]);
TORCH_CHECK(
stride[0] > 0, "max_pool1d() stride must be greater than zero, but got ", stride[0]);
TORCH_CHECK(
padding[0] >= 0, "max_pool1d() padding must be non-negative, but got ", padding[0]);
TORCH_CHECK(
padding[0] <= kernel_size[0] / 2,
"max_pool1d() padding should be at most half of kernel size, but got padding=",
padding[0],
" and kernel_size=",
kernel_size[0]);
TORCH_CHECK(
dilation[0] > 0, "max_pool1d() dilation must be greater than zero, but got ", dilation[0]);
const int64_t OW = pooling_output_shape(self.sym_size(-1).guard_int(__FILE__, __LINE__), kernel_size[0], padding[0], stride[0], dilation[0], ceil_mode);
TORCH_CHECK(OW > 0, "max_pool1d() Invalid computed output size: ", OW);
}
// TODO(Heitor) Template by dimension
struct PoolingParams1D {
int64_t NB; // Number of batches
int64_t NC; // Number of channels
int64_t IW; // Input width
int64_t OW; // Output width
int64_t KW; // Kernel width
int64_t SJ; // Column stride
int64_t PJ; // Column padding
int64_t DJ; // Column dilation
// Return index of input element for the given kernel and output index
inline int64_t index(int64_t kj, int64_t oj) const {
return oj * SJ + kj * DJ - PJ;
}
// Return index of first output within bounds for this kernel index
inline int64_t valid_output_start(int64_t kj) const {
int64_t ij = index(kj, 0);
return ij < 0 ? at::divup(-ij, SJ) : 0;
}
// Return index one past last output within bounds for this kernel index
inline int64_t valid_output_end(int64_t kj) const {
int64_t ij = index(kj, OW - 1);
return ij >= IW ? OW - at::divup(ij - (IW - 1), SJ) : OW;
}
};
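// A worked example: with IW = 10, KW = 3, SJ = 2, PJ = 1, DJ = 1 and OW = 5,
// kernel index kj = 0 maps output oj = 0 to input index index(0, 0) = -1,
// which is out of bounds, so valid_output_start(0) = 1; index(0, 4) = 7 is
// still inside the input, so valid_output_end(0) = 5.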
using pooling_fn = void (*)(Tensor&, const Tensor&, const PoolingParams1D&);
DECLARE_DISPATCH(pooling_fn, max_pool1d_stub);
} // namespace at::native

View File

@ -0,0 +1,27 @@
#include <ATen/core/TensorBase.h>
#include <algorithm>
#include <vector>
namespace at::native {
inline int64_t ensure_nonempty_dim(int64_t dim) {
return std::max<int64_t>(dim, 1);
}
inline int64_t ensure_nonempty_size(const TensorBase &t, int64_t dim) {
return t.dim() == 0 ? 1 : t.size(dim);
}
inline int64_t ensure_nonempty_stride(const TensorBase &t, int64_t dim) {
return t.dim() == 0 ? 1 : t.stride(dim);
}
using IdxVec = std::vector<int64_t>;
inline IdxVec ensure_nonempty_vec(IdxVec vec) {
if (vec.empty()) {
vec.push_back(1);
}
return vec;
}
} // namespace at::native

View File

@ -0,0 +1,26 @@
#pragma once
#include <ATen/core/Tensor.h>
#include <c10/util/irange.h>
#include <ATen/core/IListRef.h>
namespace at::native {
// This file contains non-symbolic signatures for ops that we have sym-intified the signature of.
// However, in certain cases (such as static runtime), we call the native versions of the ops directly.
// In those cases, we will duplicate the signature here with non-symbolic ints, and also duplicate the C++ implementation.
TORCH_API at::Tensor reshape(const at::Tensor& self, at::IntArrayRef proposed_shape);
TORCH_API at::Tensor narrow(const at::Tensor& self, int64_t dim, int64_t start, int64_t length);
TORCH_API at::Tensor _sparse_coo_tensor_unsafe(const at::Tensor & indices, const at::Tensor & values, at::IntArrayRef size, std::optional<at::ScalarType> dtype=std::nullopt, std::optional<at::Layout> layout=std::nullopt, std::optional<at::Device> device=std::nullopt, std::optional<bool> pin_memory=std::nullopt, std::optional<bool> is_coalesced=std::nullopt);
TORCH_API at::Tensor nll_loss(const at::Tensor & self, const at::Tensor & target, const std::optional<at::Tensor>& weight_opt, int64_t reduction, int64_t ignore_index);
TORCH_API at::Tensor nll_loss2d(const at::Tensor & self, const at::Tensor & target, const std::optional<at::Tensor>& weight_opt, int64_t reduction, int64_t ignore_index);
// The below ops don't get a duplicated C++ implementation.
// They are backward ops, which make them very unlikely to be called directly
// by external code (at::native::trace_backward).
// They get their own declaration for BC purposes however.
TORCH_API at::Tensor _embedding_bag_backward(const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, const at::Tensor & bag_size, const at::Tensor & maximum_indices, int64_t num_weights, bool scale_grad_by_freq, int64_t mode, bool sparse, const std::optional<at::Tensor> & per_sample_weights, int64_t padding_idx=-1);
TORCH_API at::Tensor _embedding_bag_sparse_backward(const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, const at::Tensor & bag_size, int64_t num_weights, bool scale_grad_by_freq, int64_t mode, const std::optional<at::Tensor> & per_sample_weights, int64_t padding_idx=-1);
TORCH_API at::Tensor value_selecting_reduction_backward(const at::Tensor & grad, int64_t dim, const at::Tensor & indices, at::IntArrayRef sizes, bool keepdim);
TORCH_API at::Tensor trace_backward(const at::Tensor & grad, at::IntArrayRef sizes);
TORCH_API at::Tensor index_select_backward(const at::Tensor & grad, at::IntArrayRef self_sizes, int64_t dim, const at::Tensor & index);
TORCH_API at::Tensor select(const at::Tensor& self, int64_t dim, int64_t index);
TORCH_API std::vector<Tensor> tensor_split(const Tensor& self, IntArrayRef indices, int64_t dim);
} // namespace at::native

View File

@ -0,0 +1,19 @@
#pragma once
#include <ATen/TensorIterator.h>
#include <ATen/native/DispatchStub.h>
namespace at::native {
using renorm_scale_factor_fn = void (*) (TensorIteratorBase& iter, double maxnorm);
DECLARE_DISPATCH(renorm_scale_factor_fn, renorm_scale_factor_stub);
enum class BatchNormBackend {
Native,
Cudnn,
Miopen,
};
TORCH_API BatchNormBackend _select_batch_norm_backend(const Tensor& input, const Tensor& weight, const Tensor& bias, const Tensor& running_mean, const Tensor& running_var, bool training, double eps);
} // namespace at::native

View File

@ -0,0 +1,62 @@
#pragma once
#include <ATen/core/Tensor.h>
#include <ATen/native/DispatchStub.h>
namespace at::native {
using padding_fn = void (*)(const Tensor&, const Tensor&, IntArrayRef);
// reflection padding
DECLARE_DISPATCH(padding_fn, reflection_pad1d_kernel);
DECLARE_DISPATCH(padding_fn, reflection_pad1d_backward_kernel);
DECLARE_DISPATCH(padding_fn, reflection_pad2d_kernel);
DECLARE_DISPATCH(padding_fn, reflection_pad2d_backward_kernel);
DECLARE_DISPATCH(padding_fn, reflection_pad3d_kernel);
DECLARE_DISPATCH(padding_fn, reflection_pad3d_backward_kernel);
// replication padding
DECLARE_DISPATCH(padding_fn, replication_pad1d_kernel);
DECLARE_DISPATCH(padding_fn, replication_pad1d_backward_kernel);
DECLARE_DISPATCH(padding_fn, replication_pad2d_kernel);
DECLARE_DISPATCH(padding_fn, replication_pad2d_backward_kernel);
DECLARE_DISPATCH(padding_fn, replication_pad3d_kernel);
DECLARE_DISPATCH(padding_fn, replication_pad3d_backward_kernel);
namespace padding {
template <int dim>
inline void check_valid_input(const Tensor& input, IntArrayRef padding) {
TORCH_CHECK(padding.size() == 2 * dim,
"padding size is expected to be ", 2 * dim,
", but got: ", padding.size());
int input_dim = input.dim();
bool is_batch_mode = input_dim == (dim + 2);
bool valid_batch_mode = is_batch_mode;
bool valid_non_batch_mode = !is_batch_mode;
if (is_batch_mode) {
// allow batch size of 0-dim.
for (const auto d : c10::irange(1, input_dim)) {
valid_batch_mode = valid_batch_mode && input.size(d) != 0;
}
} else {
for (const auto d : c10::irange(0, input_dim)) {
valid_non_batch_mode = valid_non_batch_mode && input.size(d) != 0;
}
}
// allow empty batch size but not other dimensions.
TORCH_CHECK(valid_batch_mode || valid_non_batch_mode,
"Expected ", dim + 1, "D or ", dim + 2,
"D (batch mode) tensor with possibly 0 batch size and other non-zero dimensions for input, but got: ",
input.sizes());
}
} // namespace padding
} // at::native

View File

@ -0,0 +1,47 @@
#include <ATen/core/Tensor.h>
#include <c10/util/Exception.h>
namespace at {
namespace native {
inline void check_pixel_shuffle_shapes(const Tensor& self, int64_t upscale_factor) {
TORCH_CHECK(self.dim() >= 3,
"pixel_shuffle expects input to have at least 3 dimensions, but got input with ",
self.dim(), " dimension(s)");
TORCH_CHECK(upscale_factor > 0,
"pixel_shuffle expects a positive upscale_factor, but got ",
upscale_factor);
int64_t c = self.size(-3);
int64_t upscale_factor_squared = upscale_factor * upscale_factor;
TORCH_CHECK(c % upscale_factor_squared == 0,
"pixel_shuffle expects its input's 'channel' dimension to be divisible by the square of "
"upscale_factor, but input.size(-3)=", c, " is not divisible by ", upscale_factor_squared);
}
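// For illustration: with upscale_factor = 2, an input of shape (N, 12, H, W)
// passes this check (12 is divisible by 4) and pixel_shuffle maps it to
// (N, 3, 2H, 2W), whereas a 10-channel input is rejected.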
inline void check_pixel_unshuffle_shapes(const Tensor& self, int64_t downscale_factor) {
TORCH_CHECK(
self.dim() >= 3,
"pixel_unshuffle expects input to have at least 3 dimensions, but got input with ",
self.dim(),
" dimension(s)");
TORCH_CHECK(
downscale_factor > 0,
"pixel_unshuffle expects a positive downscale_factor, but got ",
downscale_factor);
int64_t h = self.size(-2);
int64_t w = self.size(-1);
TORCH_CHECK(
h % downscale_factor == 0,
"pixel_unshuffle expects height to be divisible by downscale_factor, but input.size(-2)=",
h,
" is not divisible by ",
downscale_factor);
TORCH_CHECK(
w % downscale_factor == 0,
"pixel_unshuffle expects width to be divisible by downscale_factor, but input.size(-1)=",
w,
" is not divisible by ",
downscale_factor);
}
}} // namespace at::native

View File

@ -0,0 +1,28 @@
// Ternary and higher-order pointwise operations
#pragma once
#include <ATen/native/DispatchStub.h>
namespace c10 {
class Scalar;
}
namespace at {
struct TensorIterator;
struct TensorIteratorBase;
namespace native {
using pointwise_fn = void (*)(TensorIterator&, const Scalar& scalar);
using structured_pointwise_fn = void (*)(TensorIteratorBase&, const Scalar& scalar);
using pointwise_fn_double = void (*)(TensorIterator&, const Scalar&, double);
DECLARE_DISPATCH(structured_pointwise_fn, addcmul_stub);
DECLARE_DISPATCH(structured_pointwise_fn, addcdiv_stub);
DECLARE_DISPATCH(pointwise_fn_double, smooth_l1_backward_stub);
DECLARE_DISPATCH(pointwise_fn_double, huber_backward_stub);
DECLARE_DISPATCH(pointwise_fn, mse_backward_stub);
} // namespace native
} // namespace at

View File

@ -0,0 +1,355 @@
#include <ATen/core/Tensor.h>
#include <ATen/div_rtn.h>
#include <ATen/TensorUtils.h>
#include <ATen/native/DispatchStub.h>
#include <c10/util/irange.h>
#include <utility>
#pragma once
namespace at::native {
using max_pool2d_fn = void(*)(const Tensor& output, const Tensor& indices, const Tensor& input,
int kW, int kH, int dW, int dH, int padW, int padH, int dilationW, int dilationH);
using max_pool2d_backward_fn = void(*)(const Tensor& grad_input, const Tensor& grad_output, const Tensor& indices);
DECLARE_DISPATCH(max_pool2d_fn, max_pool2d_kernel);
DECLARE_DISPATCH(max_pool2d_backward_fn, max_pool2d_backward_kernel);
// average pooling has the same signature for forward and backward
using avg_pool2d_fn = void(*)(const Tensor& output, const Tensor& input, int64_t kW, int64_t kH,
int64_t dW, int64_t dH, int64_t padW, int64_t padH, bool count_include_pad, std::optional<int64_t> divisor_override);
using avg_pool2d_backward_fn = void(*)(const Tensor& output, const Tensor& input, int kW, int kH,
int dW, int dH, int padW, int padH, bool count_include_pad, std::optional<int64_t> divisor_override);
DECLARE_DISPATCH(avg_pool2d_fn, avg_pool2d_kernel);
DECLARE_DISPATCH(avg_pool2d_backward_fn, avg_pool2d_backward_kernel);
// average pooling has the same signature for forward and backward
using avg_pool3d_fn = void(*)(const Tensor& output, const Tensor& input,
int64_t kW, int64_t kH, int64_t kD, int64_t dW, int64_t dH, int64_t dD,
int64_t padW, int64_t padH, int64_t padD, bool count_include_pad,
std::optional<int64_t> divisor_override);
using avg_pool3d_backward_fn = void(*)(const Tensor& output, const Tensor& input,
int kW, int kH, int kD, int dW, int dH, int dD,
int padW, int padH, int padD, bool count_include_pad,
std::optional<int64_t> divisor_override);
DECLARE_DISPATCH(avg_pool3d_fn, avg_pool3d_kernel);
DECLARE_DISPATCH(avg_pool3d_backward_fn, avg_pool3d_backward_kernel);
using max_pool3d_fn = void(*)(Tensor& output, Tensor& indices, const Tensor& input,
int kW, int kH, int kD, int dW, int dH, int dD, int pW, int pH, int pD, int dilationW, int dilationH, int dilationD);
using max_pool3d_backward_fn = void(*)(Tensor& grad_input, const Tensor& grad_output, const Tensor& indices);
DECLARE_DISPATCH(max_pool3d_fn, max_pool3d_kernel);
DECLARE_DISPATCH(max_pool3d_backward_fn, max_pool3d_backward_kernel);
namespace {
template <typename dest_t, typename src_t>
inline dest_t
safe_downcast(src_t v)
{
TORCH_CHECK(std::numeric_limits<dest_t>::min() <= v && v <= std::numeric_limits<dest_t>::max(),
"integer out of range");
return static_cast<dest_t>(v);
}
template<typename T>
inline T pooling_output_shape_pad_lr(
T inputSize, T kernelSize, T pad_l, T pad_r, T stride, T dilation,
bool ceil_mode) {
T outputSize = div_rtn<T>(
inputSize + pad_l + pad_r - dilation * (kernelSize - 1) - 1 +
(ceil_mode ? stride - 1 : 0), stride) + 1;
if (ceil_mode) {
// ensure that the last pooling starts inside the image
// needed to avoid problems in ceil mode
if ((outputSize - 1) * stride >= inputSize + pad_l) {
--outputSize;
}
}
return outputSize;
}
template<typename T>
inline T pooling_output_shape(
T inputSize, T kernelSize, T pad, T stride, T dilation, bool ceil_mode) {
TORCH_CHECK(stride != 0, "stride should not be zero");
TORCH_CHECK(pad >= 0,
"pad must be non-negative, but got pad: ", pad);
TORCH_CHECK(pad <= ((kernelSize - 1) * dilation + 1) / 2,
"pad should be at most half of effective kernel size, but got pad=",
pad, ", kernel_size=", kernelSize, " and dilation=", dilation)
return pooling_output_shape_pad_lr(
inputSize, kernelSize, pad, pad, stride, dilation, ceil_mode);
}
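// A worked example: inputSize = 10, kernelSize = 3, pad = 1, stride = 2,
// dilation = 1 gives
//   floor((10 + 2*1 - 1*(3 - 1) - 1) / 2) + 1 = 5   with ceil_mode = false,
// and 6 with ceil_mode = true: the last window then starts at unpadded input
// index (6 - 1) * 2 - 1 = 9, still inside the image, so the ceil-mode
// correction above does not kick in.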
template <typename T>
std::pair<T, T> _pooling_same_mode_padding_lr(
T inputSize, T kernelSize, T stride, T dilation) {
// NOTE: with strides, the output shape is ceil(inputSize/stride)
auto total_padding = T(dilation) * (kernelSize - 1);
// Prefer symmetric padding if possible
if (stride > 2 && (total_padding % 2 == 1)) {
// The floor in the output size calculation gives us a little wiggle room
auto wiggle_room = inputSize % stride - 1;
if (wiggle_room > 0) {
total_padding = total_padding - 1;
}
}
auto left = total_padding / 2;
return {left, total_padding - left};
}
inline std::pair<int64_t, int64_t> pooling_same_mode_padding_lr(
int64_t inputSize, int64_t kernelSize, int64_t stride, int64_t dilation) {
return _pooling_same_mode_padding_lr(inputSize, kernelSize, stride, dilation);
}
inline std::pair<c10::SymInt, c10::SymInt> pooling_same_mode_padding_lr(
c10::SymInt inputSize, c10::SymInt kernelSize, c10::SymInt stride, c10::SymInt dilation) {
return _pooling_same_mode_padding_lr(std::move(inputSize), std::move(kernelSize), std::move(stride), std::move(dilation));
}
// AveragePool2d/DilatedMaxPool2d (forward)
inline void
pool2d_shape_check(
const Tensor& input,
int64_t kH, int64_t kW, int64_t dH, int64_t dW, int64_t padH, int64_t padW, int64_t dilationH, int64_t dilationW,
int64_t nInputPlane,
int64_t inputHeight, int64_t inputWidth,
int64_t outputHeight, int64_t outputWidth, MemoryFormat memory_format)
{
const int64_t ndim = input.ndimension();
#ifndef STRIP_ERROR_MESSAGES
const int64_t nOutputPlane = nInputPlane;
#endif
TORCH_CHECK(kW > 0 && kH > 0,
"kernel size should be greater than zero, but got ",
"kH: ", kH, " kW: ", kW);
TORCH_CHECK(dW > 0 && dH > 0,
"stride should be greater than zero, but got "
"dH: ", dH, " dW: ", dW);
TORCH_CHECK(dilationH > 0 && dilationW > 0,
"dilation should be greater than zero, but got ",
"dilationH: ", dilationH, " dilationW: ", dilationW);
bool valid_dims = input.size(1) != 0 && input.size(2) != 0;
if (memory_format == at::MemoryFormat::ChannelsLast){
// Expect tensor in NHWC format and allow 0-dim only for N.
TORCH_CHECK((ndim == 4 && valid_dims && input.size(3) != 0),
"Expected 4D (batch mode) tensor expected for input with channels_last layout"
" with optional 0 dim batch size for input, but got: ", input.sizes());
} else {
TORCH_CHECK((ndim == 3 && input.size(0) != 0 && valid_dims) ||
(ndim == 4 && valid_dims && input.size(3) != 0),
"Expected 3D or 4D (batch mode) tensor with optional 0 dim batch size for input, but got:",
input.sizes());
}
TORCH_CHECK(kW/2 >= padW && kH/2 >= padH,
"pad should be smaller than or equal to half of kernel size, but got ",
"padW = ", padW, ", padH = ", padH, ", kW = ", kW, ", kH = ", kH);
TORCH_CHECK(outputWidth >= 1 && outputHeight >= 1,
"Given input size: (",
nInputPlane, "x", inputHeight, "x", inputWidth, "). ",
"Calculated output size: (",
nOutputPlane, "x", outputHeight, "x", outputWidth, "). ",
"Output size is too small");
}
// DilatedMaxPool2d (backward)
inline void
max_pool2d_backward_shape_check(
const Tensor& input,
const Tensor& gradOutput,
const Tensor& indices,
int kH, int kW, int dH, int dW, int padH, int padW, int dilationH, int dilationW,
int64_t nInputPlane,
int64_t inputHeight, int64_t inputWidth,
int64_t outputHeight, int64_t outputWidth, MemoryFormat memory_format)
{
pool2d_shape_check(
input,
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, memory_format);
const int64_t ndim = input.ndimension();
const int64_t nOutputPlane = nInputPlane;
check_dim_size(gradOutput, ndim, ndim-3, nOutputPlane);
check_dim_size(gradOutput, ndim, ndim-2, outputHeight);
check_dim_size(gradOutput, ndim, ndim-1, outputWidth);
check_dim_size(indices, ndim, ndim-3, nOutputPlane);
check_dim_size(indices, ndim, ndim-2, outputHeight);
check_dim_size(indices, ndim, ndim-1, outputWidth);
}
// AveragePool2d (backward)
inline void
avg_pool2d_backward_shape_check(
const Tensor& input,
const Tensor& gradOutput,
int64_t /*nbatch*/,
int kH, int kW, int dH, int dW, int padH, int padW,
int64_t nInputPlane,
int64_t inputHeight, int64_t inputWidth,
int64_t outputHeight, int64_t outputWidth,
MemoryFormat memory_format)
{
pool2d_shape_check(
input,
kH, kW, dH, dW, padH, padW, 1, 1,
nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
memory_format);
const int64_t ndim = input.ndimension();
const int64_t nOutputPlane = nInputPlane;
check_dim_size(gradOutput, ndim, ndim-3, nOutputPlane);
check_dim_size(gradOutput, ndim, ndim-2, outputHeight);
check_dim_size(gradOutput, ndim, ndim-1, outputWidth);
}
// AveragePool3d/DilatedMaxPool3d (forward)
inline void
pool3d_shape_check(
const Tensor& input,
int64_t nslices,
int kT, int kH, int kW,
int dT, int dH, int dW,
int pT, int pH, int pW,
int dilationT, int dilationH, int dilationW,
int64_t itime, int64_t iheight, int64_t iwidth,
int64_t otime, int64_t oheight, int64_t owidth,
const char *fn_name,
bool check_input_size=false)
{
const int64_t ndim = input.ndimension();
TORCH_CHECK(kT > 0 && kW > 0 && kH > 0,
"kernel size should be greater than zero, but got ",
"kT: ", kT, " kH: ", kH, " kW: ", kW);
TORCH_CHECK(dT > 0 && dW > 0 && dH > 0,
"stride should be greater than zero, but got ",
"dT: ", dT, " dH: ", dH, " dW: ", dW);
TORCH_CHECK(dilationT > 0 && dilationW > 0 && dilationH > 0,
"dilation should be greater than zero, but got ",
"dilationT: ", dilationT, " dilationH: ", dilationH, " dilationW: ", dilationW);
TORCH_CHECK(ndim == 4 || ndim == 5,
fn_name, ": Expected 4D or 5D tensor for input, but got: ", input.sizes());
for (const auto i : c10::irange(ndim)) {
if (ndim == 5 && i == 0) {
// size of batch-dim can be 0.
continue;
}
TORCH_CHECK(
input.size(i) > 0,
fn_name,
": Expected input's non-batch dimensions to have positive length,"
" but input has a shape of ",
input.sizes(),
" and non-batch dimension ",
input.size(i),
" has length zero!")
}
if (check_input_size) { // AveragePool3d
TORCH_CHECK(itime >= kT && iheight >= kH && iwidth >= kW,
"input image ", "(T: ", itime, " H: ", iheight, " W: ", iwidth, ") smaller than ",
"kernel size ", "(kT: ", kT, " kH: ", kH, " kW: ", kW, ")");
}
TORCH_CHECK(kT/2 >= pT && kW/2 >= pW && kH/2 >= pH,
"pad should be smaller than or equal to half of kernel size, but got "
"kT: ", kT, " kW: ", kW, " kH: ", kH, " padT: ", pT, " padW: ", pW, " padH: ", pH);
TORCH_CHECK(otime >= 1 && owidth >= 1 && oheight >= 1,
"Given input size: (",
nslices,"x", itime, "x", iheight, "x", iwidth, "). ",
"Calculated output size: (",
nslices, "x", otime, "x", oheight, "x", owidth, "). ",
"Output size is too small");
}
inline void
max_pool3d_backward_shape_check(
const Tensor& input,
const Tensor& gradOutput,
const Tensor& indices,
int64_t nslices,
int kT, int kH, int kW,
int dT, int dH, int dW,
int pT, int pH, int pW,
int dilationT, int dilationH, int dilationW,
int64_t itime, int64_t iheight, int64_t iwidth,
int64_t otime, int64_t oheight, int64_t owidth,
const char* fn_name)
{
const int64_t ndim = input.ndimension();
pool3d_shape_check(
input,
nslices,
kT, kH, kW,
dT, dH, dW,
pT, pH, pW,
dilationT, dilationH, dilationW,
itime, iheight, iwidth,
otime, oheight, owidth, fn_name);
check_dim_size(gradOutput, ndim, ndim-4, nslices);
check_dim_size(gradOutput, ndim, ndim-3, otime);
check_dim_size(gradOutput, ndim, ndim-2, oheight);
check_dim_size(gradOutput, ndim, ndim-1, owidth);
check_dim_size(indices, ndim, ndim-4, nslices);
check_dim_size(indices, ndim, ndim-3, otime);
check_dim_size(indices, ndim, ndim-2, oheight);
check_dim_size(indices, ndim, ndim-1, owidth);
}
inline void
avg_pool3d_backward_shape_check(
const Tensor& input,
const Tensor& gradOutput,
int64_t nslices,
int kT, int kH, int kW,
int dT, int dH, int dW,
int pT, int pH, int pW,
int64_t itime, int64_t iheight, int64_t iwidth,
int64_t otime, int64_t oheight, int64_t owidth,
const char *fn_name)
{
const int64_t ndim = input.ndimension();
pool3d_shape_check(
input,
nslices,
kT, kH, kW,
dT, dH, dW,
pT, pH, pW,
1, 1, 1,
itime, iheight, iwidth,
otime, oheight, owidth,
fn_name, true);
check_dim_size(gradOutput, ndim, ndim-4, nslices);
check_dim_size(gradOutput, ndim, ndim-3, otime);
check_dim_size(gradOutput, ndim, ndim-2, oheight);
check_dim_size(gradOutput, ndim, ndim-1, owidth);
}
} // anonymous namespace
} // namespace at::native


@ -0,0 +1,69 @@
#pragma once
#include <ATen/native/DispatchStub.h>
namespace c10 {
class Scalar;
}
namespace at {
struct TensorIterator;
struct TensorIteratorBase;
namespace native {
#if defined(__CUDACC__) || defined(__HIPCC__)
#define HOST_DEVICE __host__ __device__
#else
#define HOST_DEVICE
#endif
// integral power in pytorch allows for negative exponents, giving truncated integral results.
// e.g. since 2**-1==0.5, the truncated integral result is zero. 1**negative_exponent and
// (-1)**negative_exponent are the only non-zero results (the latter alternating in sign).
template <class T,
typename std::enable_if<std::is_integral<T>::value, T>::type* = nullptr>
inline HOST_DEVICE __ubsan_ignore_signed_int_overflow__ T powi_impl(T a, T b) {
T result = 1;
while (b) {
if (b & 1) {
result *= a;
}
b /= 2;
a *= a;
}
return result;
}
template <class T,
typename std::enable_if<std::is_integral<T>::value && !std::is_signed<T>::value, T>::type* = nullptr>
inline HOST_DEVICE T powi(T a, T b) {
return powi_impl(a, b);
}
template <class T,
typename std::enable_if<std::is_integral<T>::value && std::is_signed<T>::value, T>::type* = nullptr>
inline HOST_DEVICE T powi(T a, T b) {
if ( b < 0 ) {
if ( a == 1 ) {
return 1;
} else if ( a == -1 ) {
auto negative = (-b) % static_cast<T>(2);
return negative ? -1 : 1;
} else {
return 0;
}
}
return powi_impl(a, b);
}
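// Illustrative usage sketch (hypothetical helper, not part of the original
// header); it exercises powi with the truncated integral semantics described
// above and assumes <cstdint> is reachable through the existing includes.
inline void powi_usage_example() {
  const int64_t pos = powi<int64_t>(2, 10);   // 1024: exponentiation by squaring
  const int64_t neg = powi<int64_t>(2, -1);   // 0: 2**-1 == 0.5 truncates to zero
  const int64_t one = powi<int64_t>(1, -7);   // 1: base 1 is unchanged by any exponent
  const int64_t alt = powi<int64_t>(-1, -3);  // -1: base -1 alternates with exponent parity
  (void)pos; (void)neg; (void)one; (void)alt; // suppress unused-variable warnings
}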
using pow_tensor_tensor_fn = void (*)(TensorIteratorBase&);
using pow_tensor_scalar_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
DECLARE_DISPATCH(pow_tensor_tensor_fn, pow_tensor_tensor_stub);
DECLARE_DISPATCH(pow_tensor_scalar_fn, pow_tensor_scalar_stub);
} // namespace native
} // namespace at


@ -0,0 +1,53 @@
#pragma once
#include <ATen/core/Tensor.h>
#include <ATen/native/DispatchStub.h>
namespace at::native {
using lstm_fn = void(*)(Tensor&, Tensor&, Tensor&, const Tensor&, TensorList, TensorList, bool, int64_t, double, bool, bool, bool);
using rnn_fn = void(*)(Tensor&, Tensor&, const Tensor&, const Tensor&, TensorList, bool, int64_t, double, bool, bool, bool);
using lstm_packed_fn = void(*)(Tensor&, Tensor&, Tensor&, const Tensor&, const Tensor&, TensorList, TensorList, bool, int64_t, double, bool, bool);
using rnn_packed_fn = void(*)(Tensor&, Tensor&, const Tensor&, const Tensor&, const Tensor&, TensorList, bool, int64_t, double, bool, bool);
DECLARE_DISPATCH(lstm_fn, lstm_cudnn_stub);
DECLARE_DISPATCH(lstm_fn, lstm_miopen_stub);
DECLARE_DISPATCH(lstm_fn, lstm_mkldnn_stub);
DECLARE_DISPATCH(rnn_fn, gru_cudnn_stub);
DECLARE_DISPATCH(rnn_fn, gru_miopen_stub);
DECLARE_DISPATCH(rnn_fn, rnn_tanh_cudnn_stub);
DECLARE_DISPATCH(rnn_fn, rnn_tanh_miopen_stub);
DECLARE_DISPATCH(rnn_fn, rnn_relu_cudnn_stub);
DECLARE_DISPATCH(rnn_fn, rnn_relu_miopen_stub);
DECLARE_DISPATCH(lstm_packed_fn, lstm_packed_cudnn_stub);
DECLARE_DISPATCH(lstm_packed_fn, lstm_packed_miopen_stub);
DECLARE_DISPATCH(rnn_packed_fn, gru_packed_cudnn_stub);
DECLARE_DISPATCH(rnn_packed_fn, gru_packed_miopen_stub);
DECLARE_DISPATCH(rnn_packed_fn, rnn_tanh_packed_cudnn_stub);
DECLARE_DISPATCH(rnn_packed_fn, rnn_tanh_packed_miopen_stub);
DECLARE_DISPATCH(rnn_packed_fn, rnn_relu_packed_cudnn_stub);
DECLARE_DISPATCH(rnn_packed_fn, rnn_relu_packed_miopen_stub);
inline void check_attributes(const Tensor& input, const TensorList& params, const TensorList& hiddens, bool check_dtype=false) {
auto input_device = input.device();
auto input_dtype = input.scalar_type();
auto check_tensors = [&](const std::string& name, const Tensor& t) {
if (!t.defined()) return;
auto t_device = t.device();
TORCH_CHECK(input_device == t_device,
"Input and ", name, " tensors are not at the same device, found input tensor at ",
input_device, " and ", name, " tensor at ", t_device);
if (check_dtype) {
auto t_dtype = t.scalar_type();
TORCH_CHECK(input_dtype == t_dtype,
"Input and ", name, " tensors are not the same dtype, found input tensor with ",
input_dtype, " and ", name, " tensor with ", t_dtype);
}
};
for (const auto& h : hiddens) check_tensors("hidden", h);
for (const auto& p : params) check_tensors("parameter", p);
}
} // namespace at::native


@ -0,0 +1,12 @@
#include <ATen/native/DispatchStub.h>
#include <c10/core/Scalar.h>
namespace at {
struct TensorIterator;
namespace native {
DECLARE_DISPATCH(void(*)(TensorIterator&, const Scalar&, const Scalar&, const Scalar&), arange_stub);
DECLARE_DISPATCH(void(*)(TensorIterator&, const Scalar&, const Scalar&, int64_t), linspace_stub);
}} // namespace at::native


@ -0,0 +1,16 @@
#pragma once
#include <ATen/native/DispatchStub.h>
namespace at {
class Tensor;
}
namespace at::native {
using reduce_all_fn = void (*)(Tensor & result, const Tensor & self);
using reduce_min_max_fn = void (*)(Tensor & max_result, Tensor & min_result, const Tensor & self);
DECLARE_DISPATCH(reduce_all_fn, min_all_stub);
DECLARE_DISPATCH(reduce_all_fn, max_all_stub);
} // namespace at::native


@ -0,0 +1,56 @@
#pragma once
#include <ATen/native/DispatchStub.h>
#include <c10/util/ArrayRef.h>
#include <optional>
namespace c10 {
class Scalar;
}
namespace at {
struct TensorIterator;
class Tensor;
}
namespace at::native {
using reduce_fn = void(*)(TensorIterator &);
DECLARE_DISPATCH(reduce_fn, sum_stub);
DECLARE_DISPATCH(reduce_fn, nansum_stub);
DECLARE_DISPATCH(reduce_fn, prod_stub);
DECLARE_DISPATCH(reduce_fn, mean_stub);
DECLARE_DISPATCH(reduce_fn, and_stub);
DECLARE_DISPATCH(reduce_fn, or_stub);
DECLARE_DISPATCH(reduce_fn, min_values_stub);
DECLARE_DISPATCH(reduce_fn, max_values_stub);
DECLARE_DISPATCH(reduce_fn, argmax_stub);
DECLARE_DISPATCH(reduce_fn, argmin_stub);
using reduce_std_var_function =
void (*)(TensorIterator&, double correction, bool take_sqrt);
DECLARE_DISPATCH(reduce_std_var_function, std_var_stub);
using reduce_norm_fn =
void (*)(Tensor&, const Tensor&, const c10::Scalar&, std::optional<int64_t>);
DECLARE_DISPATCH(reduce_norm_fn, norm_kernel);
using reduce_fn_flag = void(*)(TensorIterator &, const c10::Scalar&);
DECLARE_DISPATCH(reduce_fn_flag, norm_stub);
using structured_cum_fn = void (*)(const Tensor&, const Tensor&, int64_t);
using cum_fn = void (*)(Tensor&, const Tensor&, int64_t);
DECLARE_DISPATCH(structured_cum_fn, cumsum_stub);
DECLARE_DISPATCH(structured_cum_fn, cumprod_stub);
DECLARE_DISPATCH(cum_fn, logcumsumexp_stub);
DECLARE_DISPATCH(void (*)(const Tensor&, int64_t, bool, Tensor&, Tensor&), aminmax_stub);
DECLARE_DISPATCH(void (*)(const Tensor&, Tensor&, Tensor&), aminmax_allreduce_stub);
// Used in cuda/Normalization.cu
TORCH_API std::tuple<Tensor&,Tensor&> var_mean_out(
Tensor &result1, Tensor &result2, const Tensor &self, IntArrayRef dim,
int64_t correction, bool keepdim);
} // namespace at::native


@ -0,0 +1,455 @@
#pragma once
#include <limits>
#include <ATen/core/Tensor.h>
#include <ATen/native/Resize.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/NonEmptyUtils.h>
#include <ATen/WrapDimUtilsMulti.h>
#include <c10/core/ScalarType.h>
#include <c10/util/irange.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/scalar_tensor.h>
#endif
namespace at::native {
// Maximum and minimum possible scalar values, including infinities
template <typename scalar_t>
constexpr scalar_t upper_bound() {
using lim = std::numeric_limits<scalar_t>;
return lim::has_infinity ? lim::infinity() : lim::max();
}
template <typename scalar_t>
constexpr scalar_t lower_bound() {
using lim = std::numeric_limits<scalar_t>;
return lim::has_infinity ? -lim::infinity() : lim::lowest();
}
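// Illustrative checks (sketch, not part of the original header): types with an
// infinity use it as the bound, integral types fall back to finite extrema.
static_assert(upper_bound<int>() == std::numeric_limits<int>::max(),
              "integral upper bound is max()");
static_assert(lower_bound<int>() == std::numeric_limits<int>::lowest(),
              "integral lower bound is lowest()");
static_assert(upper_bound<double>() == std::numeric_limits<double>::infinity(),
              "floating-point upper bound is +infinity");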
inline Tensor restride_dim(
const Tensor& src, int64_t dim,
IntArrayRef replacement_shape
) {
auto strides = ensure_nonempty_vec(src.strides().vec());
strides[dim] = 0;
return src.as_strided(replacement_shape, strides);
}
inline void _dimreduce_setup(const Tensor &result, const Tensor &self,
int64_t dim) {
IntArrayRef self_sizes = self.sizes();
std::vector<int64_t> result_sizes;
result_sizes.insert(result_sizes.end(), self_sizes.begin(), self_sizes.end());
result_sizes[dim] = 1;
result.resize_(result_sizes);
}
inline bool _dimreduce_return_trivial(const Tensor &result, const Tensor &self,
const Scalar& ident, int64_t dim, bool keepdim) {
if (self.numel() == 1 && self.ndimension() == 0) {
result.resize_({});
result.fill_(self);
return true;
}
// Return identity
if (self.numel() == 0) {
_dimreduce_setup(result, self, dim);
result.fill_(ident);
if (!keepdim) result.squeeze_(dim);
return true;
}
return false;
}
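// Worked example (illustrative): summing a tensor of shape {0, 3} over dim = 0
// with keepdim = false and ident = 0 takes the numel() == 0 branch above: the
// result is resized to {1, 3}, filled with 0, then squeezed to shape {3}.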
inline bool _dimreduce_return_trivial_no_ident(Tensor &result, const Tensor &self,
int64_t /*dim*/, bool /*keepdim*/, const char* /*fn_name*/) {
if (self.numel() == 1 && self.ndimension() == 0) {
result.resize_({});
result.fill_(self);
return true;
}
return false;
}
inline std::optional<Tensor> _allreduce_return_trivial(
const Tensor& self,
const Scalar& ident) {
// Return identity
if (self.numel() == 0) {
return at::scalar_tensor(ident, self.options());
}
return std::nullopt;
}
#define OPTION_TYPE_EQUALITY_CHECK(option, out, self) \
{ \
TORCH_CHECK(\
out.option() == self.option(),\
"expected ", #option, " ",\
self.option(),\
" but found ", out.option())\
}
inline void check_scalar_type_device_layout_equal(const Tensor& out, const Tensor& self) {
OPTION_TYPE_EQUALITY_CHECK(scalar_type, out, self);
OPTION_TYPE_EQUALITY_CHECK(device, out.options(), self.options());
OPTION_TYPE_EQUALITY_CHECK(layout, out.options(), self.options());
}
inline Tensor integer_upcast(const Tensor& self, std::optional<ScalarType> dtype) {
ScalarType scalarType = self.scalar_type();
TORCH_CHECK(!isBarebonesUnsignedType(scalarType), "integer upcasting for uint16, uint32 and uint64 is not currently implemented");
ScalarType upcast_scalarType = dtype.value_or(at::isIntegralType(scalarType, /*includeBool=*/true) ? ScalarType::Long : scalarType);
return self.toType(upcast_scalarType);
}
using DimMask = TensorIterator::DimMask;
inline DimVector make_dim_vector(OptionalIntArrayRef opt_dims, int64_t ndim) {
if (opt_dims.has_value()) {
return DimVector(opt_dims.value());
} else {
std::vector<int64_t> all_dims(ndim);
std::iota(all_dims.begin(), all_dims.end(), 0);
return DimVector(all_dims);
}
}
inline DimMask make_dim_mask(OptionalIntArrayRef opt_dims, int64_t ndim, bool allow_empty_dims=false) {
DimMask mask;
if (opt_dims.has_value()) {
auto dims = opt_dims.value();
if (dims.empty() && !allow_empty_dims) {
mask = DimMask().flip();
} else {
mask = at::dim_list_to_bitset(dims, ndim);
}
} else {
mask = DimMask().flip();
}
return mask;
}
inline DimVector shape_from_dim_mask(const Tensor& self, DimMask mask, bool keepdim) {
auto shape = DimVector(self.sizes());
for (int dim = shape.size() - 1; dim >= 0; dim--) {
if (mask[dim]) {
if (keepdim) {
shape[dim] = 1;
} else {
shape.erase(shape.begin() + dim);
}
}
}
return shape;
}
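// Worked example (illustrative): for self of shape {2, 3, 4} and a mask
// selecting dims {0, 2}, this returns {1, 3, 1} with keepdim = true and {3}
// with keepdim = false; reduced dims are erased back-to-front so the remaining
// indices stay valid while erasing.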
inline void resize_reduction_result(
Tensor& result, const Tensor& self, DimMask mask, bool keepdim,
ScalarType /*dtype*/)
{
auto shape = shape_from_dim_mask(self, mask, keepdim);
TORCH_CHECK(result.defined(), "Cannot create a new tensor inside a reduction op. You likely tried to call an operator with an out argument but the out argument was an undefined tensor.");
at::native::resize_output(result, shape);
}
inline Tensor create_reduction_result(
const Tensor& self, at::OptionalIntArrayRef dim, bool keepdim, ScalarType dtype
) {
DimMask mask = make_dim_mask(dim, self.dim());
auto shape = shape_from_dim_mask(self, mask, keepdim);
return at::empty(shape, self.options().dtype(dtype));
}
inline Tensor review_reduce_result(const Tensor& result, int ndim, DimMask mask, bool keepdim) {
if (keepdim) {
return result;
}
auto shape = DimVector(result.sizes());
auto stride = DimVector(result.strides());
for (const auto dim : c10::irange(ndim)) {
if (mask[dim]) {
shape.insert(shape.begin() + dim, 1);
stride.insert(stride.begin() + dim, 0);
}
}
return result.as_strided(shape, stride);
}
inline TensorIterator make_reduction(
const char* name, Tensor& result, const Tensor& self,
at::OptionalIntArrayRef dim_opt,
bool keepdim, ScalarType in_dtype, ScalarType out_dtype) {
// check that result type and dtype match if provided
TORCH_CHECK(
!result.defined() || result.scalar_type() == out_dtype,
name, ": provided dtype must match dtype of result. Got ",
toString(result.scalar_type()),
" and ",
toString(out_dtype),
".");
// dim={} performs an all-reduce, same as dim=None
IntArrayRef dim = dim_opt.value_or(IntArrayRef{});
int64_t ndim = self.dim();
auto mask = make_dim_mask(dim, ndim);
resize_reduction_result(result, self, mask, keepdim, out_dtype);
auto viewed_result = review_reduce_result(result, ndim, mask, keepdim);
namedinference::propagate_names_for_reduction(result, self, dim, keepdim);
if (self.scalar_type() == in_dtype) {
return TensorIterator::reduce_op(viewed_result, self);
}
return TensorIterator::reduce_op(viewed_result, self.to(in_dtype));
}
inline C10_UNUSED TensorIterator make_reduction(
const char* name, Tensor& result, const Tensor& self,
at::OptionalIntArrayRef dim, bool keepdim, ScalarType out_dtype) {
// special case for type promotion in mixed precision, improves computational
// efficiency.
// We don't generalize this to common mismatched input/output types to avoid cross
// product of templated kernel launches.
const bool gpu_lowp_to_f32 = (
self.is_cuda() && (self.scalar_type() == kHalf || self.scalar_type() == kBFloat16) && out_dtype == kFloat);
auto in_dtype = gpu_lowp_to_f32 ? self.scalar_type()
: self.is_complex() ? c10::toComplexType(out_dtype)
: out_dtype;
return make_reduction(name, result, self, dim, keepdim, in_dtype, out_dtype);
}
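// Illustrative consequence of the gpu_lowp_to_f32 special case above (sketch,
// not part of the original header): summing a CUDA kHalf tensor into a kFloat
// result keeps in_dtype == kHalf, so no upfront self.to(kFloat) copy is made;
// the promotion to the float output dtype is left to the reduction kernel.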
inline TensorIterator make_reduction(
const char* name, Tensor& result1, Tensor& result2, const Tensor& self,
at::OptionalIntArrayRef dim_opt, bool keepdim, ScalarType dtype1,
ScalarType dtype2) {
// check that result type and dtype match if provided
TORCH_CHECK(
(!result1.defined() || result1.scalar_type() == dtype1) && (!result2.defined() || result2.scalar_type() == dtype2),
name, ": provided dtype must match dtype of result. Got ",
toString(result1.scalar_type()), toString(result2.scalar_type()),
" and ",
toString(dtype1), toString(dtype2),
".");
// dim={} performs an all-reduce, same as dim=None
auto dim = dim_opt.value_or(IntArrayRef{});
int64_t ndim = self.dim();
DimMask mask = make_dim_mask(dim, ndim);
resize_reduction_result(result1, self, mask, keepdim, dtype1);
auto viewed_result1 = review_reduce_result(result1, ndim, mask, keepdim);
resize_reduction_result(result2, self, mask, keepdim, dtype2);
auto viewed_result2 = review_reduce_result(result2, ndim, mask, keepdim);
namedinference::propagate_names_for_reduction(result1, self, dim, keepdim);
namedinference::propagate_names_for_reduction(result2, self, dim, keepdim);
// special case for type promotion in mixed precision, improves computational
// efficiency.
// We don't generalize this to common mismatched input/output types to avoid cross
// product of templated kernel launches.
if (self.scalar_type() == dtype1 ||
(self.is_cuda() && self.scalar_type() == kHalf && dtype1 == kFloat)) {
return TensorIterator::reduce_op(viewed_result1, viewed_result2, self);
}
return TensorIterator::reduce_op(viewed_result1, viewed_result2, self.to(dtype1));
}
inline C10_UNUSED TensorIterator make_reduction(
const char* name, Tensor& result1, Tensor& result2, const Tensor& self,
at::OptionalIntArrayRef dim, bool keepdim, ScalarType dtype) {
return make_reduction(name, result1, result2, self, dim, keepdim, dtype, dtype);
}
inline void zero_numel_check_dims(const Tensor& self, const int64_t dim, const char *fn_name) {
if (self.ndimension() == 0) {
TORCH_CHECK_INDEX(dim == 0 || dim == -1, fn_name,
": Expected reduction dim -1 or 0 for scalar but got ", dim);
}
else {
TORCH_CHECK_INDEX(self.size(dim) != 0, fn_name,
": Expected reduction dim ", dim, " to have non-zero size.");
}
}
inline void zero_numel_check_dims(const Tensor& self, const IntArrayRef dim, const char *fn_name) {
TORCH_CHECK(
!dim.empty(),
fn_name, ": Expected reduction dim to be specified for input.numel() == 0. ",
"Specify the reduction dim with the 'dim' argument.");
for (const int64_t d : dim) {
zero_numel_check_dims(self, d, fn_name);
}
}
inline std::vector<int64_t> get_zero_numel_tensor_size(
const Tensor& self,
const int64_t dim,
const bool keepdim,
const char* fn_name) {
TORCH_INTERNAL_ASSERT(self.numel() == 0, fn_name, ": Expected self.numel() == 0.");
zero_numel_check_dims(self, dim, fn_name);
std::vector<int64_t> sizes;
if (keepdim) {
sizes = self.sizes().vec();
sizes[dim] = 1;
}
else {
for (const auto d : c10::irange(self.dim())) {
if (d != dim) {
sizes.push_back(self.sizes()[d]);
}
}
}
return sizes;
}
// Resize the result tensor and indices when result.numel() == 0 depending on values of
// dim and keepdim for returning tensors containing reduction results.
// This function should be called when you are reducing a zero-numel tensor and want to
// resize the output and return it. This function exists for resizing zero-numel
// tensors when the size of the reduction dimension is non-zero.
inline C10_UNUSED void zero_numel_tensor_resize(Tensor& result, Tensor& result_indices,
const Tensor& self, const int64_t dim,
const bool keepdim, const char *fn_name) {
auto sizes = get_zero_numel_tensor_size(self, dim, keepdim, fn_name);
at::native::resize_output(result, sizes);
at::native::resize_output(result_indices, sizes);
}
inline ScalarType get_dtype_from_self(
const Tensor& self,
const std::optional<ScalarType>& dtype,
bool promote_integers) {
if (dtype.has_value()) {
return dtype.value();
}
ScalarType src_type = self.scalar_type();
if (promote_integers && at::isIntegralType(src_type, /*includeBool=*/true)) {
return kLong;
}
return src_type;
}
inline ScalarType get_dtype_from_result(Tensor& result, std::optional<ScalarType> dtype) {
TORCH_CHECK(result.defined(), "Cannot create a new tensor inside a reduction op. You likely tried to call an operator with an out argument but the out argument was an undefined tensor.");
if (dtype.has_value()) {
return dtype.value();
} else {
return result.scalar_type();
}
}
} // namespace at::native
namespace at::meta {
inline C10_UNUSED DimVector get_reduction_shape(
const Tensor& self,
IntArrayRef dims,
bool keepdim,
bool allow_empty_dims=false) {
auto mask = native::make_dim_mask(dims, self.dim(), allow_empty_dims);
return native::shape_from_dim_mask(self, mask, keepdim);
}
inline void resize_reduction(
impl::MetaBase& meta,
const Tensor& self,
OptionalIntArrayRef opt_dims,
bool keepdim,
ScalarType out_dtype,
bool allow_empty_dims=false) {
DimVector dims_ = at::native::make_dim_vector(opt_dims, self.dim());
maybe_wrap_dims(dims_, self.dim());
auto shape = get_reduction_shape(self, dims_, keepdim, allow_empty_dims);
if (self.layout() == kStrided) {
meta.set_output_raw_strided(0, shape, {}, self.options().dtype(out_dtype));
} else if (shape.empty()) {
meta.set_output_raw_strided(0, shape, {}, self.options().dtype(out_dtype).layout(kStrided));
} else {
TORCH_CHECK(false, "resize_reduction: support for output with ", self.layout(), " layout is not implemented yet");
}
namedinference::propagate_names_for_reduction(
meta.maybe_get_output(), self, dims_, keepdim);
}
inline void resize_reduction_with_indices(
impl::MetaBase& meta,
const Tensor& self,
IntArrayRef dims,
bool keepdim,
ScalarType out_dtype) {
DimVector dims_(dims);
maybe_wrap_dims(dims_, self.dim());
auto shape = get_reduction_shape(self, dims_, keepdim);
meta.set_output_raw_strided(0, shape, {}, self.options().dtype(out_dtype));
meta.set_output_raw_strided(1, shape, {}, self.options().dtype(kLong));
namedinference::propagate_names_for_reduction(
meta.maybe_get_output(0), self, dims_, keepdim);
namedinference::propagate_names_for_reduction(
meta.maybe_get_output(1), self, dims_, keepdim);
}
inline TensorIterator make_reduction(
const Tensor& self,
const Tensor& result,
OptionalIntArrayRef opt_dims,
bool keepdim,
ScalarType in_dtype) {
int64_t ndim = self.dim();
auto mask = at::native::make_dim_mask(opt_dims, ndim);
auto viewed_result =
at::native::review_reduce_result(result, ndim, mask, keepdim);
if (self.scalar_type() == in_dtype) {
return TensorIterator::reduce_op(viewed_result, self);
}
return TensorIterator::reduce_op(viewed_result, self.to(in_dtype));
}
inline TensorIterator make_reduction(
const Tensor& self,
const Tensor& result1,
const Tensor& result2,
IntArrayRef dims,
bool keepdim,
ScalarType dtype1,
ScalarType /*dtype2*/) {
int64_t ndim = self.dim();
auto mask = at::native::make_dim_mask(dims, ndim);
auto viewed_result1 = at::native::review_reduce_result(result1, ndim, mask, keepdim);
auto viewed_result2 = at::native::review_reduce_result(result2, ndim, mask, keepdim);
// special case for type promotion in mixed precision, improves computational efficiency.
// We don't generalize this to common mismatched input/output types to avoid cross product
// of templated kernel launches.
if (self.scalar_type() == dtype1 ||
(self.is_cuda() && self.scalar_type() == kHalf && dtype1 == kFloat)) {
return TensorIterator::reduce_op(viewed_result1, viewed_result2, self);
}
return TensorIterator::reduce_op(viewed_result1, viewed_result2, self.to(dtype1));
}
inline C10_UNUSED TensorIterator make_reduction_from_out_ty(
const Tensor& self,
const Tensor& result,
OptionalIntArrayRef opt_dims,
bool keepdim,
ScalarType out_dtype) {
// special case for type promotion in mixed precision, improves computational
// efficiency.
// We don't generalize this to common mismatched input/output types to avoid cross
// product of templated kernel launches.
const bool gpu_lowp_to_f32 =
(self.is_cuda() &&
(self.scalar_type() == kHalf || self.scalar_type() == kBFloat16) &&
out_dtype == kFloat);
auto in_dtype = gpu_lowp_to_f32 ? self.scalar_type() : out_dtype;
return make_reduction(self, result, opt_dims, keepdim, in_dtype);
}
} // namespace at::meta


@ -0,0 +1,40 @@
#pragma once
#include <c10/core/Scalar.h>
namespace at::native {
enum class ReductionType {MAX, MEAN, MIN, SUM, PROD};
inline ReductionType get_reduction_enum(const c10::string_view& reduce) {
if (reduce == "max" || reduce == "amax") {
return ReductionType::MAX;
} else if (reduce == "mean") {
return ReductionType::MEAN;
} else if (reduce == "min" || reduce == "amin") {
return ReductionType::MIN;
} else if (reduce == "sum") {
return ReductionType::SUM;
} else if (reduce == "prod") {
return ReductionType::PROD;
} else {
TORCH_CHECK(false, "reduce argument must be either sum, prod, mean, amax or amin, got ", reduce);
}
}
// used for `scatter_reduce`, old options for BC.
inline ReductionType get_operator_enum(const c10::string_view reduce, bool use_new_options) {
if (use_new_options) {
return get_reduction_enum(reduce);
} else {
if (reduce == "add") {
return ReductionType::SUM;
} else if (reduce == "multiply") {
return ReductionType::PROD;
} else {
TORCH_CHECK(false, "reduce argument must be either add or multiply.")
}
}
}
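// Illustrative mapping (not part of the original header):
//   get_operator_enum("amax", /*use_new_options=*/true)      -> ReductionType::MAX
//   get_operator_enum("mean", /*use_new_options=*/true)      -> ReductionType::MEAN
//   get_operator_enum("add", /*use_new_options=*/false)      -> ReductionType::SUM
//   get_operator_enum("multiply", /*use_new_options=*/false) -> ReductionType::PROD
// The old "add"/"multiply" spellings are only kept for backward compatibility
// with the original scatter_ reduce argument.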
} // at::native


@ -0,0 +1,48 @@
#pragma once
#include <ATen/core/Tensor.h>
#include <ATen/TensorOperators.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#endif
namespace at::native {
template <
typename index_t,
void compute(const index_t*, const int64_t*, index_t*, int64_t, int64_t)>
static inline Tensor repeat_interleave_common(
const Tensor& repeats,
std::optional<int64_t> output_size) {
TORCH_CHECK(
repeats.dim() == 1, "repeat_interleave only accept 1D vector as repeat");
TORCH_CHECK(
repeats.scalar_type() == at::kLong || repeats.scalar_type() == at::kInt,
"repeats has to be Long or Int tensor");
if (repeats.size(0) == 0) {
return at::empty_like(repeats, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
}
Tensor repeats_ = repeats.contiguous();
Tensor cumsum = repeats.cumsum(0);
int64_t total = 0;
if (output_size.has_value()) {
total = output_size.value();
} else {
total = cumsum[-1].item<int64_t>();
TORCH_CHECK(
(repeats >= 0).all().item<uint8_t>(), "repeats can not be negative");
}
Tensor result = at::empty({total}, repeats.options());
const index_t* repeat_ptr = repeats_.const_data_ptr<index_t>();
const int64_t* cumsum_ptr = cumsum.const_data_ptr<int64_t>();
index_t* result_ptr = result.data_ptr<index_t>();
compute(repeat_ptr, cumsum_ptr, result_ptr, repeats.size(0), total);
return result;
}
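// Worked example (illustrative): for repeats = [2, 0, 3] the cumulative sum is
// [2, 2, 5], so total == 5 and the caller-supplied `compute` kernel writes each
// element's index repeated by its count, producing result = [0, 0, 2, 2, 2]
// (element 1 contributes nothing because its repeat count is zero).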
} // namespace at::native


@ -0,0 +1,173 @@
#pragma once
#include <ATen/core/Tensor.h>
#include <ATen/native/ResizeCommon.h>
#include <ATen/EmptyTensor.h>
#include <ATen/TensorUtils.h>
#include <c10/core/CPUAllocator.h>
#include <utility>
namespace at::native {
// TODO: make all operations that resize given outputs use this function
// for consistency and maintainability.
// Some operations like `cat` might not be able to make the use of
// resize_output directly. For more details to understand how it works in `cat`,
// see https://github.com/pytorch/pytorch/pull/62560#discussion_r687363362
// Resizes outputs
// Functions accepting output tensors, like with the "out" kwarg, should
// call this function to handle resizing their output tensor.
// Issues a warning if the output tensor has one or more elements and
// needs resizing
// NOTE: In the future the warning will become an error
// Returns a bool saying whether or not the resize actually happened
TORCH_API bool resize_output(const Tensor& output, IntArrayRef shape);
// WARNING: Do NOT call this directly. If you are resizing an output and want
// to support dynamic shapes call at::resize__symint and resize_output_check_symint.
// For more details, see: https://github.com/pytorch/pytorch/pull/111530/files#r1365845272
TORCH_API bool resize_output_symint(const Tensor& output, SymIntArrayRef shape);
// Utility for resize_output
// Returns a bool saying whether the resize should happen and
// raises a warning if the output has one or more elements and needs resizing
TORCH_API bool resize_output_check(const Tensor& output, IntArrayRef shape);
TORCH_API bool resize_output_check_symint(const Tensor& output, SymIntArrayRef shape);
TORCH_API void resize_bytes_cpu(StorageImpl* storage, size_t size_bytes);
TORCH_API void resize_bytes_meta(StorageImpl* storage, c10::SymInt size_bytes);
TORCH_API void resize_bytes_nocuda(const Storage& storage, const c10::SymInt& size_bytes);
inline void maybe_resize_storage_cpu(TensorImpl* self, size_t new_size_bytes) {
// It does not make sense to try to resize a storage
// to hold 0 elements, and this can break
// if storage_offset is positive but
// new_size is 0, so just bail in that case
// (same comment is in cuda/Resize.h)
if (self->numel() == 0) {
return;
}
const Storage& storage = self->unsafe_storage();
if (!storage) {
auto new_storage = c10::make_intrusive<StorageImpl>(
StorageImpl::use_byte_size_t(),
new_size_bytes,
c10::GetCPUAllocator(),
true);
self->set_storage_keep_dtype(std::move(new_storage));
} else if (new_size_bytes > storage.nbytes()) {
resize_bytes_cpu(storage.unsafeGetStorageImpl(), new_size_bytes);
}
}
TORCH_API TensorImpl* resize_impl_cpu_(
TensorImpl* self,
IntArrayRef size,
at::OptionalIntArrayRef stride,
bool resize_storage = true);
template <typename T>
T maybe_convert_symint(c10::SymInt) = delete;
template <>
inline c10::SymInt maybe_convert_symint(c10::SymInt x) { return x; }
template <>
inline int64_t maybe_convert_symint(c10::SymInt x) { return x.guard_int(__FILE__, __LINE__); }
template <typename T>
inline void checkInBoundsForStorage(
ArrayRef<T> size,
ArrayRef<T> stride,
T storage_offset,
const caffe2::TypeMeta& data_type,
const Storage& new_storage) {
T storage_size_bytes =
at::detail::computeStorageNbytes(size, stride, data_type.itemsize());
T storage_offset_bytes = storage_offset * data_type.itemsize();
if (storage_size_bytes == 0) {
// NB: (a tensor with arbitrary 0 dims)'s storage can have any numel.
return;
}
T new_storage_size_bytes = maybe_convert_symint<T>(new_storage.sym_nbytes());
TORCH_CHECK(
storage_size_bytes + storage_offset_bytes <= new_storage_size_bytes,
"setStorage: sizes ",
size,
", strides ",
stride,
","
" storage offset ",
storage_offset,
", and itemsize ",
data_type.itemsize(),
" requiring a storage size of ",
storage_size_bytes + storage_offset_bytes,
" are out of bounds for storage of size ",
new_storage_size_bytes);
}
template <typename T>
inline void checkSetStorage(Tensor& result, Storage storage, T storage_offset,
ArrayRef<T> size, ArrayRef<T> stride) {
// FIXME: stride should be optional
if (stride.data()) {
TORCH_CHECK(size.size() == stride.size(), "unequal size length (", size.size(),
") and stride length (", stride.size(), ")");
}
#ifdef DEBUG
TORCH_CHECK(size.size() <= INT_MAX, "size length (", size.size(), ") greater than INT_MAX");
#endif
// storage: note this can't be replaced with result.set_(storage) as the semantics of that
// function is to set the tensor size to be equal to the size of the storage.
if (!result.storage().is_alias_of(storage)) {
// Caffe2 might have tensors whose storages are null, but we
// don't allow it in PyTorch.
TORCH_INTERNAL_ASSERT(storage);
TORCH_INTERNAL_ASSERT(result.storage());
// We used to allow this, but this breaks device caching.
// Let's put an actual error message for this one.
TORCH_CHECK(result.storage().device() == storage.device(),
"Attempted to set the storage of a tensor on device \"", result.storage().device(),
"\" to a storage on different device \"", storage.device(),
"\". This is no longer allowed; the devices must match.");
result.unsafeGetTensorImpl()->set_storage_keep_dtype(std::move(storage));
}
// storageOffset
TORCH_CHECK(storage_offset >= 0, "Tensor: invalid storage offset ", storage_offset);
}
/**
* Set self's sizes, strides, and storage_offset.
* (size, stride, storage_offset) must be in bounds for self's storage.
*/
template <typename T>
inline void setStrided(
const Tensor& self,
ArrayRef<T> size,
ArrayRef<T> stride,
T storage_offset) {
TORCH_CHECK(size.size() == stride.size(), "mismatch in length of strides and shape");
for (const auto& val : stride) {
TORCH_CHECK(val >= 0,
"as_strided: Negative strides are not supported at the moment, "
"got strides: ", stride);
}
auto* self_ = self.unsafeGetTensorImpl();
checkInBoundsForStorage(
size, stride, storage_offset, self_->dtype(), self_->storage());
/* storage offset */
TORCH_CHECK(storage_offset >= 0, "Tensor: invalid storage offset ", storage_offset);
self_->set_sizes_and_strides(size, stride, std::make_optional(storage_offset));
}
} // namespace at::native


@ -0,0 +1,75 @@
#pragma once
#include <ATen/core/Tensor.h>
#include <ATen/native/TensorFactories.h>
#include <ATen/NamedTensorUtils.h>
#include <c10/util/irange.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#endif
namespace at::native {
template <typename T>
inline T storage_size_for(ArrayRef<T> size, ArrayRef<T> stride) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(size.size() == stride.size(),
"storage_size_for(size, stride) requires that size and stride ",
"have the same size as a precondition.");
T storage_size = 1;
for (const auto dim : c10::irange(size.size())) {
if (size[dim] == 0) {
storage_size = 0;
break;
}
storage_size += (size[dim] - 1) * stride[dim];
}
return storage_size;
}
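// Worked example (illustrative): for size = {2, 3} and stride = {3, 1} the
// minimal storage is 1 + (2 - 1) * 3 + (3 - 1) * 1 = 6 elements; any dimension
// of size 0 short-circuits the result to 0, since such a tensor touches no
// storage at all.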
inline const Tensor& resize_named_tensor_(
const Tensor& self,
IntArrayRef size,
std::optional<MemoryFormat> optional_memory_format) {
TORCH_INTERNAL_ASSERT(self.has_names());
TORCH_CHECK(
self.sizes() == size,
"Cannot resize named tensor with resize_ or resize_as_ (tried to resize "
"Tensor",
self.names(),
" with size ",
self.sizes(),
" to ",
size,
"). This may be caused by passing a named tensor ",
"as an `out=` argument; please ensure that the sizes are the same. ");
TORCH_CHECK(
!optional_memory_format.has_value(),
"Unsupported memory format for named tensor resize ",
optional_memory_format.value());
return self;
}
// For deterministic output, fill new elements that were added after a storage
// resize with NaN or MAX_INT. `old_storage_nbytes` is the size of the storage
// before the resize happened.
inline const Tensor& fill_resize_deterministic_(const Tensor& tensor, int64_t old_storage_nbytes) {
const at::Storage& storage = tensor.unsafeGetTensorImpl()->unsafe_storage();
int64_t new_storage_nbytes = storage.nbytes();
int64_t old_storage_numel = old_storage_nbytes / tensor.itemsize();
int64_t new_storage_numel = new_storage_nbytes / tensor.itemsize();
if (new_storage_numel > old_storage_numel) {
at::Tensor tensor_view = at::empty({}, at::TensorOptions().dtype(tensor.scalar_type()).device(tensor.device()));
tensor_view.set_(
storage,
/*storage_offset=*/old_storage_numel,
/*size=*/{new_storage_numel - old_storage_numel},
/*stride=*/{1});
at::native::fill_empty_deterministic_(tensor_view);
}
return tensor;
}
} // namespace at::native


@ -0,0 +1,128 @@
#pragma once
#include <vector>
#include <ATen/core/Tensor.h>
#include <ATen/native/ReduceOpsUtils.h>
#include <c10/util/irange.h>
namespace at::native {
namespace {
// checks whether index.dtype == int64
// and self.dtype == src.dtype if src is a Tensor
inline void scatter_gather_dtype_check(
const std::string& method_name,
const Tensor& self,
const Tensor& index,
const std::optional<Tensor>& src_opt = std::nullopt
) {
if (index.numel() != 0) {
TORCH_CHECK(
index.scalar_type() == at::ScalarType::Long,
method_name, "(): Expected dtype int64 for index"
);
}
if (src_opt.has_value()) {
const auto& src = src_opt.value();
TORCH_CHECK(
self.scalar_type() == src.scalar_type(),
method_name, "(): Expected self.dtype to be equal to src.dtype"
);
}
}
// Used for `gather`-like methods
// Note: self means the input tensor here
// Test:
// 1. index.size(d) <= self.size(d) for all d != dim
// 2. index.dim() == self.dim()
inline void gather_shape_check(const Tensor& self, int64_t dim,
const Tensor& index
) {
auto self_dims = ensure_nonempty_dim(self.dim());
TORCH_CHECK(self_dims == ensure_nonempty_dim(index.dim()),
"Index tensor must have the same number of dimensions as input tensor"
);
for (const auto i : c10::irange(self_dims)) {
if (i != dim) {
TORCH_CHECK(
ensure_nonempty_size(index, i) <= ensure_nonempty_size(self, i),
"Size does not match at dimension ", i,
" expected index ", index.sizes(),
" to be smaller than self ", self.sizes(),
" apart from dimension ", dim
);
}
}
}
// Used for `scatter` and `scatter_add`
// Tests:
// 1. index.size(d) <= self.size(d) for all d != dim
// 2. index.size(d) <= src.size(d) for all d if src is a Tensor
// 3. index.dim() == self.dim() == src.dim()
inline void scatter_shape_check(
const Tensor& self, int64_t dim, const Tensor& index,
const std::optional<Tensor>& src_opt = std::nullopt
) {
if (index.numel() == 0) return;
TORCH_CHECK(
ensure_nonempty_dim(self.dim()) == ensure_nonempty_dim(index.dim()),
"Index tensor must have the same number of dimensions as self tensor"
);
bool is_wrong_shape = false;
int64_t self_dims = ensure_nonempty_dim(self.dim());
// Check: index.size(d) <= self.size(d) for all d != dim
for (const auto d : c10::irange(self_dims)) {
int64_t index_d_size = ensure_nonempty_size(index, d);
if (d == dim) continue;
if (index_d_size > ensure_nonempty_size(self, d)) {
is_wrong_shape = true;
break;
}
}
// Check: index.size(d) <= src.size(d) for all d if src is Tensor
if (!is_wrong_shape && src_opt.has_value()) {
const auto& src = src_opt.value();
for (const auto d : c10::irange(self_dims)) {
int64_t index_d_size = ensure_nonempty_size(index, d);
if (index_d_size > ensure_nonempty_size(src, d)) {
is_wrong_shape = true;
break;
}
}
}
if (src_opt.has_value()) {
const auto& src = src_opt.value();
TORCH_CHECK(
ensure_nonempty_dim(src.dim()) == ensure_nonempty_dim(index.dim()),
"Index tensor must have the same number of dimensions as src tensor"
);
TORCH_CHECK(!is_wrong_shape,
"Expected index ", index.sizes(),
" to be smaller than self ", self.sizes(),
" apart from dimension ", dim,
" and to be smaller size than src ", src.sizes()
);
}
else {
TORCH_CHECK(!is_wrong_shape,
"Expected index ", index.sizes(),
" to be smaller than self ", self.sizes(),
" apart from dimension ", dim
);
}
}
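// Worked example (illustrative): scattering along dim = 0 into self of shape
// {3, 4} with index of shape {2, 4} and src of shape {2, 4} passes every check.
// An index of shape {2, 5} fails because 5 > self.size(1), and index {2, 4}
// with src of shape {2, 3} fails because 4 > src.size(1).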
} // anonymous namespace
} // namespace at::native


@ -0,0 +1,50 @@
#pragma once
#include <ATen/native/DispatchStub.h>
#include <ATen/native/ReductionType.h>
#include <c10/core/Scalar.h>
#include <optional>
namespace at {
class Tensor;
namespace native {
using segment_reduce_lengths_fn = Tensor (*)(
ReductionType,
const Tensor&,
const Tensor&,
int64_t,
const std::optional<Scalar>&);
DECLARE_DISPATCH(segment_reduce_lengths_fn, _segment_reduce_lengths_stub);
using segment_reduce_offsets_fn = Tensor (*)(
ReductionType,
const Tensor&,
const Tensor&,
int64_t,
const std::optional<Scalar>&);
DECLARE_DISPATCH(segment_reduce_offsets_fn, _segment_reduce_offsets_stub);
using segment_reduce_lengths_backward_fn = Tensor (*)(
const Tensor&,
const Tensor&,
const Tensor&,
ReductionType,
const Tensor&,
int64_t,
const std::optional<Scalar>&);
DECLARE_DISPATCH(segment_reduce_lengths_backward_fn, _segment_reduce_lengths_backward_stub);
using segment_reduce_offsets_backward_fn = Tensor (*)(
const Tensor&,
const Tensor&,
const Tensor&,
ReductionType,
const Tensor&,
int64_t,
const std::optional<Scalar>&);
DECLARE_DISPATCH(segment_reduce_offsets_backward_fn, _segment_reduce_offsets_backward_stub);
} // namespace native
} // namespace at


@ -0,0 +1,544 @@
#pragma once
// Please note that this file is
// used across both CPU and GPU.
#include <type_traits>
#include <complex>
#include <c10/macros/Macros.h>
#include <ATen/detail/FunctionTraits.h>
#include <ATen/NumericUtils.h>
#if defined(__CUDACC__)
#include <ATen/cuda/DeviceUtils.cuh>
#include <ATen/native/cuda/DeviceSqrt.cuh>
#elif defined(__HIPCC__)
#include <ATen/hip/DeviceUtils.cuh>
#include <ATen/native/hip/DeviceSqrt.cuh>
#endif
#if defined(__CUDACC__) || defined(__HIPCC__)
#include <thrust/pair.h>
#else
#include <cmath>
#define device_sqrt std::sqrt
#endif
#if defined(__CUDACC__) || defined(__HIPCC__)
template <typename scalar_t>
inline C10_DEVICE scalar_t max_propagate_nan(scalar_t a, scalar_t b) {
#if defined(__HIPCC__)
// TODO: remove this special case for HIP when issue is fixed:
// https://github.com/ROCm-Developer-Tools/HIP/issues/2209
scalar_t max = at::_isnan(a) ? a : (at::_isnan(b) ? b : std::max(a, b));
#else
scalar_t max = at::_isnan(b) ? b : std::max(a, b);
#endif
return max;
}
template <typename scalar_t>
inline C10_DEVICE scalar_t min_propagate_nan(scalar_t a, scalar_t b) {
#if defined(__HIPCC__)
// TODO: remove this special case for HIP when issue is fixed:
// https://github.com/ROCm-Developer-Tools/HIP/issues/2209
scalar_t min = at::_isnan(a) ? a : (at::_isnan(b) ? b : std::min(a, b));
#else
scalar_t min = at::_isnan(b) ? b : std::min(a, b);
#endif
return min;
}
#define MAX(X, Y) max_propagate_nan(X,Y)
#define MIN(X, Y) min_propagate_nan(X,Y)
#else
#include <ATen/native/cpu/zmath.h>
#define MAX(X, Y) max_impl(X,Y)
#define MIN(X, Y) min_impl(X,Y)
#endif
// ROCM hcc doesn't work well with using std:: in kernel functions
#if defined(__CUDA_ARCH__)
#include <c10/cuda/CUDAMathCompat.h>
#define compat_pow c10::cuda::compat::pow
#elif defined(__HIPCC__)
#include <c10/hip/HIPMathCompat.h>
#define compat_pow c10::hip::compat::pow
#else
#define compat_pow std::pow
#endif
namespace at { namespace native {
namespace detail {
#if defined(__CUDACC__) || defined(__HIPCC__)
template <typename T1, typename T2> using pair = thrust::pair<T1, T2>;
#else
template <typename T1, typename T2> using pair = std::pair<T1, T2>;
#endif
} // namespace detail
template <typename scalar_t, typename index_t>
struct WelfordData {
scalar_t mean;
scalar_t m2;
index_t n;
scalar_t nf;
C10_HOST_DEVICE WelfordData() : mean(0), m2(0), n(0), nf(0) {}
C10_HOST_DEVICE WelfordData(
scalar_t mean,
scalar_t m2,
index_t n,
scalar_t nf)
: mean(mean), m2(m2), n(n), nf(nf) {}
};
template <typename scalar_t, typename acc_scalar_t, typename index_t, typename res_t>
struct WelfordOps {
acc_scalar_t correction;
bool take_sqrt;
public:
using acc_t = WelfordData<acc_scalar_t, index_t>;
inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, index_t /*idx*/) const {
// We accumulate n in index_t to avoid cumulative rounding error, but still
// need nf for use in combine where int32 may overflow.
index_t new_n = acc.n + 1;
acc_scalar_t new_nf = static_cast<acc_scalar_t>(new_n);
acc_scalar_t delta = data - acc.mean;
acc_scalar_t new_mean = acc.mean + delta / new_nf;
acc_scalar_t new_delta = data - new_mean;
return {
new_mean,
acc.m2 + delta * new_delta,
new_n,
new_nf,
};
}
inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
if (a.nf == 0) {
return b;
}
if (b.nf == 0) {
return a;
}
acc_scalar_t delta = b.mean - a.mean;
acc_scalar_t new_count = a.nf + b.nf;
acc_scalar_t nb_over_n = b.nf / new_count;
return {
a.mean + delta * nb_over_n,
a.m2 + b.m2 + delta * delta * a.nf * nb_over_n,
// set acc.n to -1 since it might not be able to represent the combined count
// correctly within its range; -1 marks it as unused to avoid confusion
-1,
new_count
};
}
inline C10_DEVICE res_t project(acc_t acc) const __ubsan_ignore_float_divide_by_zero__ {
const auto mean = static_cast<scalar_t>(acc.mean);
const auto divisor = acc.nf > correction ? acc.nf - correction : 0;
const auto var = acc.m2 / divisor;
res_t results(take_sqrt ? device_sqrt(var) : var, mean);
return results;
}
static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
return acc;
}
#if defined(__CUDACC__) || defined(__HIPCC__)
inline __device__ acc_t warp_shfl_down(acc_t acc, int offset) const {
return {
WARP_SHFL_DOWN(acc.mean, offset)
, WARP_SHFL_DOWN(acc.m2, offset)
, WARP_SHFL_DOWN(acc.n, offset)
, WARP_SHFL_DOWN(acc.nf, offset)
};
}
#endif
C10_HOST_DEVICE WelfordOps(acc_scalar_t correction, bool take_sqrt)
: correction(correction), take_sqrt(take_sqrt) {}
};
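// Reference formulas implemented above (Welford's online update and the
// Chan et al. pairwise combination), written out for illustration:
//   reduce:  delta = x - mean;   mean += delta / n;   m2 += delta * (x - mean)
//   combine: n = n_a + n_b;      delta = mean_b - mean_a
//            mean = mean_a + delta * n_b / n
//            m2   = m2_a + m2_b + delta^2 * n_a * n_b / n
//   project: var = m2 / max(n - correction, 0), then sqrt if take_sqrt is set.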
template <typename scalar_t, typename acc_t=scalar_t, typename factor_t=acc_t, typename out_t = acc_t>
struct MeanOps {
factor_t factor;
inline C10_DEVICE acc_t reduce(acc_t a, scalar_t b, int64_t /*idx*/) const {
return combine(a, static_cast<acc_t>(b));
}
inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
return a + b;
}
inline C10_DEVICE out_t project(acc_t a) const {
return a * factor;
}
static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
return acc;
}
#if defined(__CUDACC__) || defined(__HIPCC__)
inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const {
return WARP_SHFL_DOWN(data, offset);
}
#endif
MeanOps(factor_t factor): factor(factor) {
}
};
// This accumulator template is used to calculate the minimum absolute value of
// a set of numbers.
// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
// value. These types differ for complex number input support.
template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t>
struct AbsMinOps {
inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
return MIN(acc, static_cast<acc_t>(std::abs(data)));
}
inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
return MIN(a, b);
}
inline C10_DEVICE out_t project(acc_t a) const {
return a;
}
static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
return acc;
}
#if defined(__CUDACC__) || defined(__HIPCC__)
inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
return WARP_SHFL_DOWN(acc, offset);
}
#endif
};
// This accumulator template is used to calculate the maximum absolute value of
// a set of numbers.
// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
// value. These types differ for complex number input support.
template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t>
struct AbsMaxOps {
inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
return MAX(acc, static_cast<acc_t>(std::abs(data)));
}
inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
return MAX(a, b);
}
inline C10_DEVICE out_t project(acc_t a) const {
return a;
}
static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
return acc;
}
#if defined(__CUDACC__) || defined(__HIPCC__)
inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
return WARP_SHFL_DOWN(acc, offset);
}
#endif
};
// This accumulator template is used to calculate the norm of the absolute value
// of a set of numbers.
// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
// value. These types differ for complex number input support.
template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t>
struct NormOps {
acc_t norm_;
inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
return acc + compat_pow(static_cast<acc_t>(std::abs(data)), norm_);
}
inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
return a + b;
}
inline C10_DEVICE out_t project(acc_t a) const {
return compat_pow(a, static_cast<acc_t>(1.0) / norm_);
}
static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
return acc;
}
#if defined(__CUDACC__) || defined(__HIPCC__)
inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
return WARP_SHFL_DOWN(acc, offset);
}
#endif
NormOps(acc_t norm_): norm_(norm_) {
}
};
// This accumulator template is used to calculate the order zero norm of the
// absolute value of a set of numbers.
// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
// value. These types differ for complex number input support.
template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t>
struct NormZeroOps {
inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
return acc + (data == static_cast<scalar_t>(0) ? static_cast<acc_t>(0) : static_cast<acc_t>(1));
}
inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
return a + b;
}
inline C10_DEVICE out_t project(acc_t a) const {
return a;
}
static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
return acc;
}
#if defined(__CUDACC__) || defined(__HIPCC__)
inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
return WARP_SHFL_DOWN(acc, offset);
}
#endif
};
// This accumulator template is used to calculate the order one norm of the
// absolute value of a set of numbers.
// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
// value. These types differ for complex number input support.
template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t>
struct NormOneOps {
inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
return acc + static_cast<acc_t>(std::abs(data));
}
inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
return a + b;
}
inline C10_DEVICE out_t project(acc_t a) const {
return a;
}
static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
return acc;
}
#if defined(__CUDACC__) || defined(__HIPCC__)
inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
return WARP_SHFL_DOWN(acc, offset);
}
#endif
};
template<typename acc_t>
struct AbsSwitch {};
template<typename scalar_t, typename acc_t>
inline C10_DEVICE acc_t abs_if_complex(scalar_t data, AbsSwitch<acc_t>) {
return static_cast<acc_t>(data);
}
template<typename scalar_t, typename acc_t>
inline C10_DEVICE acc_t abs_if_complex(std::complex<scalar_t> data, AbsSwitch<acc_t>) {
return static_cast<acc_t>(std::abs(data));
}
template<typename scalar_t, typename acc_t>
inline C10_DEVICE acc_t abs_if_complex(c10::complex<scalar_t> data, AbsSwitch<acc_t>) {
return static_cast<acc_t>(std::abs(data));
}
// This accumulator template is used to calculate the order two norm of the
// absolute value of a set of numbers.
// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
// value. These types differ for complex number input support.
template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t>
struct NormTwoOps {
inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
acc_t data_ = abs_if_complex(data, AbsSwitch<acc_t>());
return acc + data_ * data_;
}
inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
return a + b;
}
inline C10_DEVICE out_t project(acc_t a) const {
return device_sqrt(a);
}
static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
return acc;
}
#if defined(__CUDACC__) || defined(__HIPCC__)
inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
return WARP_SHFL_DOWN(acc, offset);
}
#endif
};
template <typename acc_t, typename data_t>
struct NanSumOps {
inline C10_DEVICE acc_t reduce(acc_t a, data_t b, int64_t /*idx*/) const {
return a + (at::_isnan(b) ? acc_t{0.} : acc_t{b});
}
inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
return a + b;
}
inline C10_DEVICE data_t project(acc_t a) const {
return data_t{a};
}
static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
return acc;
}
#if defined(__CUDACC__) || defined(__HIPCC__)
inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const {
return WARP_SHFL_DOWN(data, offset);
}
#endif
};
namespace detail {
template <typename scalar_t>
struct LessOrNan {
C10_DEVICE bool operator () (scalar_t a, scalar_t b, int64_t idx_a, int64_t idx_b) const {
// If (a == b), then choose the one with lower idx, else min(a, b)
if (at::_isnan(a)) {
if (at::_isnan(b)) {
return idx_a < idx_b;
}
return true;
}
return (a == b) ? idx_a < idx_b : (a < b);
}
};
template <typename scalar_t>
struct GreaterOrNan {
C10_DEVICE bool operator () (scalar_t a, scalar_t b, int64_t idx_a, int64_t idx_b) const {
// If (a == b), then choose the one with lower idx, else max(a, b)
if (at::_isnan(a)) {
if (at::_isnan(b)) {
return idx_a < idx_b;
}
return true;
}
return (a == b) ? idx_a < idx_b : (a > b);
}
};
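// Illustrative behaviour (not part of the original header): reducing the values
// {2.0, NaN, 1.0} with MinOps/ArgMinOps (built on LessOrNan) yields the NaN at
// index 1, since a NaN compares as "smaller" than everything and therefore
// propagates; for equal values such as {1.0, 1.0}, the tie is broken towards
// the lower index, so index 0 wins.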
template <typename comp_t>
struct MinMaxReductionOps {
using scalar_t = typename binary_function_traits<comp_t>::arg1_t;
using index_t = int64_t;
using arg_t = detail::pair<scalar_t, index_t>;
static C10_DEVICE arg_t project(arg_t arg) {
return arg;
}
static C10_DEVICE arg_t reduce(arg_t arg, scalar_t val, int64_t idx) {
return comp_t{}(arg.first, val, arg.second, idx) ? arg : arg_t(val, idx);
}
static C10_DEVICE arg_t combine(arg_t a, arg_t b) {
return comp_t{}(a.first, b.first, a.second, b.second) ? a : b;
}
static C10_DEVICE arg_t translate_idx(arg_t a, int64_t base_idx) {
return {a.first, a.second + base_idx};
}
#if defined(__CUDACC__) || defined(__HIPCC__)
static C10_DEVICE arg_t warp_shfl_down(arg_t arg, int offset) {
return arg_t(WARP_SHFL_DOWN(arg.first, offset),
WARP_SHFL_DOWN(arg.second, offset));
}
#endif
};
template <typename comp_t>
struct ArgReductionOps : public MinMaxReductionOps<comp_t> {
using typename MinMaxReductionOps<comp_t>::scalar_t;
using typename MinMaxReductionOps<comp_t>::index_t;
using typename MinMaxReductionOps<comp_t>::arg_t;
static C10_DEVICE index_t project(arg_t arg) {
return arg.second;
}
};
} // namespace detail
template <typename scalar_t>
struct ArgMaxOps :
public detail::ArgReductionOps<detail::GreaterOrNan<scalar_t>> {
};
template <typename scalar_t>
struct ArgMinOps :
public detail::ArgReductionOps<detail::LessOrNan<scalar_t>> {
};
template <typename scalar_t>
struct MinOps :
public detail::MinMaxReductionOps<detail::LessOrNan<scalar_t>> {
};
template <typename scalar_t>
struct MaxOps :
public detail::MinMaxReductionOps<detail::GreaterOrNan<scalar_t>> {
};
template <typename scalar_t, typename acc_scalar_t, typename index_t>
struct MinMaxOps {
using acc_t = detail::pair<acc_scalar_t, acc_scalar_t>;
inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, index_t /*idx*/) const {
return combine(acc, {data, data});
}
inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
auto min_val = (at::_isnan(a.first) || a.first < b.first) ? a.first : b.first;
auto max_val = (at::_isnan(a.second) || a.second > b.second) ? a.second : b.second;
return {min_val, max_val};
}
inline C10_DEVICE acc_t project(acc_t acc) const {
return acc;
}
static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
return acc;
}
#if defined(__CUDACC__) || defined(__HIPCC__)
inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
return {
WARP_SHFL_DOWN(acc.first, offset), WARP_SHFL_DOWN(acc.second, offset)
};
}
#endif
};
}} // namespace at::native
#undef MAX
#undef MIN


@ -0,0 +1,55 @@
/// This file contains some tensor-agnostic operations to be used in the
/// core functions of the `SobolEngine`
#include <ATen/core/Tensor.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/arange.h>
#include <ATen/ops/mul.h>
#include <ATen/ops/pow.h>
#endif
namespace at::native::sobol_utils {
/// Function to return the minimum number of bits needed to represent the integer `n`
inline int64_t bit_length(const int64_t n) {
int64_t nbits, nloc;
for (nloc = n, nbits = 0; nloc > 0; nloc /= 2, nbits++);
return nbits;
}
/// Function to get the position of the rightmost zero in the bit representation of an integer
/// This value is the zero-indexed position
inline int64_t rightmost_zero(const int64_t n) {
int64_t z, i;
for (z = n, i = 0; z % 2 == 1; z /= 2, i++);
return i;
}
/// Function to get a subsequence of bits in the representation of an integer starting from
/// `pos` and of length `length`
inline int64_t bitsubseq(const int64_t n, const int64_t pos, const int64_t length) {
return (n >> pos) & ((1 << length) - 1);
}
/// Function to perform the inner product between a batched square matrix and a power of 2 vector
inline at::Tensor cdot_pow2(const at::Tensor& bmat) {
at::Tensor inter = at::arange(bmat.size(-1) - 1, -1, -1, bmat.options());
inter = at::pow(2, inter).expand_as(bmat);
return at::mul(inter, bmat).sum(-1);
}
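/// Worked examples (illustrative, not part of the original header):
///   bit_length(6)       -> 3  (6 = 0b110 needs three bits)
///   rightmost_zero(7)   -> 3  (7 = 0b111, so the first zero bit is at position 3)
///   bitsubseq(53, 2, 3) -> 5  (53 = 0b110101; bits 2..4 read as 0b101)
///   cdot_pow2 treats each row of `bmat` as a binary number with its first
///   entry as the most significant bit, e.g. the row {1, 0, 1} maps to 5.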
/// All definitions below this point are data. These are constant, and should not be modified
/// without notice
constexpr int64_t MAXDIM = 21201;
constexpr int64_t MAXDEG = 18;
constexpr int64_t MAXBIT = 30;
constexpr int64_t LARGEST_NUMBER = 1 << MAXBIT;
constexpr float RECIPD = 1.0 / LARGEST_NUMBER;
extern const int64_t poly[MAXDIM];
extern const int64_t initsobolstate[MAXDIM][MAXDEG];
} // namespace at::native::sobol_utils


@ -0,0 +1,28 @@
#pragma once
#include <ATen/native/DispatchStub.h>
#include <cstdint>
namespace at {
class TensorBase;
}
namespace at::native {
enum class QUANTILE_INTERPOLATION_MODE : uint8_t {
LINEAR,
LOWER,
HIGHER,
MIDPOINT,
NEAREST
};
using sort_fn = void(*)(const TensorBase&, const TensorBase&, const TensorBase&, int64_t, bool, bool);
using topk_fn = void(*)(const TensorBase&, const TensorBase&, const TensorBase&, int64_t, int64_t, bool, bool);
DECLARE_DISPATCH(sort_fn, sort_stub);
DECLARE_DISPATCH(topk_fn, topk_stub);
void _fill_indices(const TensorBase &indices, int64_t dim);
} // namespace at::native

View File

@ -0,0 +1,88 @@
#pragma once
#include <ATen/NumericUtils.h>
#include <ATen/native/Resize.h>
#include <c10/util/irange.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#endif
namespace at::native {
// ensure we get good values and indices for kthvalue, mode
// this will always be with the reducing dim as 1-d
inline void _reduction_with_indices_allocate_or_resize_output(
Tensor& values,
Tensor& indices,
const Tensor& self,
int64_t dim_,
bool keepdim) {
int64_t dim = maybe_wrap_dim(dim_, self.dim(), /*wrap_scalar=*/true);
auto result_sizes = self.sizes().vec();
if (!result_sizes.empty()) {
result_sizes[dim] = 1;
}
if (values.defined()) {
TORCH_CHECK(
self.options().type_equal(values.options()),
"output values must be of same type as input");
if (!keepdim && values.dim() == self.dim() - 1) {
// unsqueeze to preserve passed in noncontiguous tensor in resize
values.unsqueeze_(dim);
}
resize_output(values, result_sizes);
} else {
values = at::empty(result_sizes, self.options());
}
if (indices.defined()) {
TORCH_CHECK(
indices.dtype() == kLong, "output indices must be of scalar type Long");
TORCH_CHECK(
indices.device() == self.device(),
"output indices must be on same device as input");
if (!keepdim && indices.dim() == self.dim() - 1) {
// unsqueeze to preserve passed in noncontiguous tensor in resize
indices.unsqueeze_(dim);
}
resize_output(indices, result_sizes);
} else {
indices = at::empty(result_sizes, self.options().dtype(kLong));
}
}
// ensure we get good values and indices for topk
inline void _allocate_or_resize_output_with_indices(
Tensor& values,
Tensor& indices,
const Tensor& self,
int64_t dim_,
int64_t k) {
int64_t dim = maybe_wrap_dim(dim_, self.dim(), /*wrap_scalar=*/true);
auto result_sizes = self.sizes().vec();
if (!result_sizes.empty()) {
result_sizes[dim] = k;
}
if (values.defined()) {
TORCH_CHECK(
self.options().type_equal(values.options()),
"output values must be of same type as input");
values.resize_(result_sizes);
} else {
values = at::empty(result_sizes, self.options());
}
if (indices.defined()) {
TORCH_CHECK(
indices.dtype() == kLong, "output indices must be of scalar type Long");
TORCH_CHECK(
indices.device() == self.device(),
"output indices must be on same device as input");
indices.resize_(result_sizes);
} else {
indices = at::empty(result_sizes, self.options().dtype(kLong));
}
}
} // namespace at::native

View File

@ -0,0 +1,190 @@
#pragma once
#include <ATen/Parallel.h>
#include <ATen/SparseTensorImpl.h>
#include <ATen/core/Tensor.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/tensor.h>
#endif
namespace at::sparse {
// Just for documentary purposes
using SparseTensor = Tensor;
using SparseType = Type;
// This is an internal utility function for getting at the SparseTensorImpl,
// so that we can write sparse tensor specific accessors for special fields
// in SparseTensor. You should only use this for writing low level
// setters/getters for SparseTensorImpl fields; otherwise, you should use
// the low level setters/getters that were implemented using this.
//
// This may be called repeatedly, so make sure it's pretty cheap.
inline SparseTensorImpl* get_sparse_impl(const SparseTensor& self) {
TORCH_INTERNAL_ASSERT(
self.is_sparse(), "_internal_get_SparseTensorImpl: not a sparse tensor");
return static_cast<SparseTensorImpl*>(self.unsafeGetTensorImpl());
}
// Takes indices and values and directly puts them into the sparse tensor, no
// copy. This used to be called THSTensor_(_move)
inline void alias_into_sparse(
const SparseTensor& self,
const Tensor& indices,
const Tensor& values) {
get_sparse_impl(self)->set_indices_and_values_unsafe(indices, values);
}
// Take indices and values and makes a (data) copy of them to put into the
// sparse indices/values. This used to be called THSTensor_(_set)
inline void copy_into_sparse(
const SparseTensor& self,
const Tensor& indices,
const Tensor& values,
bool non_blocking) {
alias_into_sparse(
self,
indices.to(self._indices().options(), non_blocking, /*copy=*/true),
values.to(self._values().options(), non_blocking, /*copy=*/true));
}
// TODO: put this into the public API
inline bool is_same_tensor(const Tensor& lhs, const Tensor& rhs) {
return lhs.unsafeGetTensorImpl() == rhs.unsafeGetTensorImpl();
}
inline bool is_same_density(const SparseTensor& self, const SparseTensor& src) {
return self.sparse_dim() == src.sparse_dim() &&
self.dense_dim() == src.dense_dim();
}
// Give us a new values tensor, with the same dimensionality
// as 'values' but with a new number of non-zero elements.
// TODO: Expose this for real in ATen, some day?
// NB: Doesn't preserve data.
inline Tensor new_values_with_size_of(const Tensor& values, int64_t nnz) {
std::vector<int64_t> size = values.sizes().vec();
size[0] = nnz;
return at::empty(size, values.options());
}
// NOTE [ Flatten Sparse Indices ]
// This helper function flattens a sparse indices tensor (a Tensor) into a 1D
// indices tensor. E.g.,
// input = [[2, 4, 0],
// [3, 1, 10]]
// full_size = [2, 12]
// output = [ 2 * 12 + 3, 4 * 12 + 1, 0 * 12 + 10 ] = [27, 49, 10]
//
// In other words, assuming that each `indices[i, :]` is a valid index into a
// tensor `t` of shape `full_size`, this returns the corresponding indices into
// the flattened tensor `t.reshape( prod(full_size[:indices.size(0)]), -1 )`.
// If force_clone is true, the result will be forced to be a clone of self.
TORCH_API Tensor flatten_indices(
const Tensor& indices,
IntArrayRef full_size,
bool force_clone = false);
// Flatten sparse tensor's indices from nD to 1D, similar to NOTE [ Flatten
// Sparse Indices ], except this one allows partial flatten: only flatten on
// specified dims. Note that the flatten indices might be uncoalesced if
// dims_to_flatten.size() < sparse_dim. Also if input indices is already
// coalesced, the flattened indices will also be sorted.
//
// args:
// indices: sparse tensor indices
// sizes: sparse tensor sizes
// dims_to_flatten: a list of dim index to flatten
//
// Ex1:
// indices = [[2, 4, 0],
// [3, 1, 3]]
// sizes = [2, 12]
// dims_to_flatten = [0, 1]
// new_indices = [ 2 * 12 + 3, 4 * 12 + 1, 0 * 12 + 3 ] = [27, 49, 3]
//
// Ex2:
// dims_to_flatten = [1]
// new_indices = [ 3, 1, 3 ] # uncoalesced
TORCH_API Tensor flatten_indices_by_dims(
const Tensor& indices,
const IntArrayRef& sizes,
const IntArrayRef& dims_to_flatten);
// Find the CSR representation for a row `indices` from the COO format
TORCH_API Tensor coo_to_csr(const int64_t* indices, int64_t dim, int64_t nnz);
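// Illustrative sketch (an assumption about the usual COO->CSR convention, not
// taken from this header): for sorted row indices [0, 0, 1, 3] with dim == 4
// and nnz == 4, the compressed row pointers would be [0, 2, 3, 3, 4], i.e.
// csr[i + 1] counts the nonzeros in rows 0..i.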
TORCH_API Tensor zeros_like_with_indices(const Tensor& t);
template <size_t static_shape_max_len>
class TensorGeometryHolder {
using geometry_holder_t = std::array<int64_t, static_shape_max_len>;
public:
explicit TensorGeometryHolder(
IntArrayRef sizes,
IntArrayRef strides,
TensorOptions options = {}) {
std::copy(sizes.begin(), sizes.end(), t_sizes.begin());
std::copy(strides.begin(), strides.end(), t_strides.begin());
}
explicit TensorGeometryHolder(const Tensor& t)
: TensorGeometryHolder(t.sizes(), t.strides()) {}
auto operator*() const {
return std::make_tuple(t_sizes, t_strides);
}
private:
geometry_holder_t t_sizes;
geometry_holder_t t_strides;
};
template <>
class TensorGeometryHolder<0> {
using geometry_holder_t = Tensor;
public:
explicit TensorGeometryHolder(
IntArrayRef sizes,
IntArrayRef strides,
TensorOptions options) {
const int64_t t_ndims = sizes.size();
const auto cpu_options = TensorOptions(options).dtype(kLong).device(kCPU);
Tensor t_sizes_and_strides_cpu = at::empty({2, t_ndims}, cpu_options);
t_sizes_and_strides_cpu.select(0, 0).copy_(at::tensor(sizes, cpu_options));
t_sizes_and_strides_cpu.select(0, 1).copy_(
at::tensor(strides, cpu_options));
const Tensor t_sizes_and_strides =
t_sizes_and_strides_cpu.to(options.device());
t_sizes = t_sizes_and_strides.select(0, 0);
t_strides = t_sizes_and_strides.select(0, 1);
}
explicit TensorGeometryHolder(const Tensor& t)
: TensorGeometryHolder(t.sizes(), t.strides(), t.options()) {}
auto operator*() const {
return std::make_tuple(
t_sizes.template data_ptr<int64_t>(),
t_strides.template data_ptr<int64_t>());
}
private:
geometry_holder_t t_sizes;
geometry_holder_t t_strides;
};
// Return all indices of a tensor with the given shape.
//
// full_coo_indices(shape) is equivalent to
// torch.ones(shape).nonzero().transpose(-2, -1) but much faster.
TORCH_API Tensor full_coo_indices(IntArrayRef sizes, TensorOptions options);
} // namespace at::sparse

View File

@ -0,0 +1,84 @@
#pragma once
#include <string>
#include <stdexcept>
#include <sstream>
#include <c10/core/ScalarType.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Exception.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/core/TensorBase.h>
namespace at::native {
// Normalization types used in _fft_with_size
enum class fft_norm_mode {
none, // No normalization
by_root_n, // Divide by sqrt(signal_size)
by_n, // Divide by signal_size
};
// NOTE [ Fourier Transform Conjugate Symmetry ]
//
// Real-to-complex Fourier transform satisfies the conjugate symmetry. That is,
// assuming X is the transformed K-dimensional signal, we have
//
// X[i_1, ..., i_K] = X[j_1, ..., j_K]*,
//
// where j_k = (N_k - i_k) mod N_k, N_k being the signal size at dim k,
// * is the conjugate operator.
//
// Therefore, in such cases, FFT libraries return only roughly half of the
// values to avoid redundancy:
//
// X[:, :, ..., :floor(N / 2) + 1]
//
// This is also the assumption in cuFFT and MKL. In ATen SpectralOps, such
// halved signal will also be returned by default (flag onesided=True).
// The following infer_ft_real_to_complex_onesided_size function calculates the
// onesided size from the twosided size.
//
// Note that this loses some information about the size of signal at last
// dimension. E.g., both 11 and 10 map to 6. Hence, the following
// infer_ft_complex_to_real_onesided_size function takes in optional parameter
// to infer the twosided size from given onesided size.
//
// cuFFT doc: http://docs.nvidia.com/cuda/cufft/index.html#multi-dimensional
// MKL doc: https://software.intel.com/en-us/mkl-developer-reference-c-dfti-complex-storage-dfti-real-storage-dfti-conjugate-even-storage#CONJUGATE_EVEN_STORAGE
inline int64_t infer_ft_real_to_complex_onesided_size(int64_t real_size) {
return (real_size / 2) + 1;
}
inline int64_t infer_ft_complex_to_real_onesided_size(int64_t complex_size,
int64_t expected_size=-1) {
int64_t base = (complex_size - 1) * 2;
if (expected_size < 0) {
return base + 1;
} else if (base == expected_size) {
return base;
} else if (base + 1 == expected_size) {
return base + 1;
} else {
std::ostringstream ss;
ss << "expected real signal size " << expected_size << " is incompatible "
<< "with onesided complex frequency size " << complex_size;
AT_ERROR(ss.str());
}
}
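// Illustrative worked values (a sketch added for clarity, not part of the
// upstream header):
//   infer_ft_real_to_complex_onesided_size(10) == 6
//   infer_ft_real_to_complex_onesided_size(11) == 6      // 10 and 11 collide
//   infer_ft_complex_to_real_onesided_size(6) == 11      // defaults to base + 1
//   infer_ft_complex_to_real_onesided_size(6, 10) == 10  // disambiguated by the hint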
using fft_fill_with_conjugate_symmetry_fn =
void (*)(ScalarType dtype, IntArrayRef mirror_dims, IntArrayRef half_sizes,
IntArrayRef in_strides, const void* in_data,
IntArrayRef out_strides, void* out_data);
DECLARE_DISPATCH(fft_fill_with_conjugate_symmetry_fn, fft_fill_with_conjugate_symmetry_stub);
// In real-to-complex transform, cuFFT and MKL only fill half of the values
// due to conjugate symmetry. This function fills in the other half of the full
// fft by using the Hermitian symmetry in the signal.
// self should be the shape of the full signal and dims.back() should be the
// one-sided dimension.
// See NOTE [ Fourier Transform Conjugate Symmetry ]
TORCH_API void _fft_fill_with_conjugate_symmetry_(const Tensor& self, IntArrayRef dims);
} // namespace at::native

View File

@ -0,0 +1,301 @@
#pragma once
namespace at::native {
// (Const)StridedRandomAccessor is a
// (const) random access iterator defined over
// a strided array.
// The traits below are to introduce __restrict__
// modifier on different platforms.
template <typename T>
struct DefaultPtrTraits {
using PtrType = T*;
};
#if (defined(_WIN32) || defined(_WIN64))
#define RESTRICT __restrict
#else
#define RESTRICT __restrict__
#endif
template <typename T>
struct RestrictPtrTraits {
using PtrType = T* RESTRICT;
};
template <
typename T,
typename index_t = int64_t,
template <typename U> class PtrTraits = DefaultPtrTraits
>
class ConstStridedRandomAccessor {
public:
using difference_type = index_t;
using value_type = const T;
using pointer = const typename PtrTraits<T>::PtrType;
using reference = const value_type&;
using iterator_category = std::random_access_iterator_tag;
using PtrType = typename PtrTraits<T>::PtrType;
using index_type = index_t;
// Constructors {
C10_HOST_DEVICE
ConstStridedRandomAccessor(PtrType ptr, index_t stride)
: ptr{ptr}, stride{stride}
{}
C10_HOST_DEVICE
explicit ConstStridedRandomAccessor(PtrType ptr)
: ptr{ptr}, stride{static_cast<index_t>(1)}
{}
C10_HOST_DEVICE
ConstStridedRandomAccessor()
: ptr{nullptr}, stride{static_cast<index_t>(1)}
{}
// }
// Pointer-like operations {
C10_HOST_DEVICE
reference operator*() const {
return *ptr;
}
C10_HOST_DEVICE
const value_type* operator->() const {
return reinterpret_cast<const value_type*>(ptr);
}
C10_HOST_DEVICE
reference operator[](index_t idx) const {
return ptr[idx * stride];
}
// }
// Prefix/postfix increment/decrement {
C10_HOST_DEVICE
ConstStridedRandomAccessor& operator++() {
ptr += stride;
return *this;
}
C10_HOST_DEVICE
ConstStridedRandomAccessor operator++(int) {
ConstStridedRandomAccessor copy(*this);
++*this;
return copy;
}
C10_HOST_DEVICE
ConstStridedRandomAccessor& operator--() {
ptr -= stride;
return *this;
}
C10_HOST_DEVICE
ConstStridedRandomAccessor operator--(int) {
ConstStridedRandomAccessor copy(*this);
--*this;
return copy;
}
// }
// Arithmetic operations {
C10_HOST_DEVICE
ConstStridedRandomAccessor& operator+=(index_t offset) {
ptr += offset * stride;
return *this;
}
C10_HOST_DEVICE
ConstStridedRandomAccessor operator+(index_t offset) const {
return ConstStridedRandomAccessor(ptr + offset * stride, stride);
}
C10_HOST_DEVICE
friend ConstStridedRandomAccessor operator+(
index_t offset,
const ConstStridedRandomAccessor& accessor
) {
return accessor + offset;
}
C10_HOST_DEVICE
ConstStridedRandomAccessor& operator-=(index_t offset) {
ptr -= offset * stride;
return *this;
}
C10_HOST_DEVICE
ConstStridedRandomAccessor operator-(index_t offset) const {
return ConstStridedRandomAccessor(ptr - offset * stride, stride);
}
// Note that this operator is well-defined when `this` and `other`
// represent the same sequences, i.e. when
// 1. this.stride == other.stride,
// 2. |other - this| / this.stride is an Integer.
C10_HOST_DEVICE
difference_type operator-(const ConstStridedRandomAccessor& other) const {
return (ptr - other.ptr) / stride;
}
// }
// Comparison operators {
C10_HOST_DEVICE
bool operator==(const ConstStridedRandomAccessor& other) const {
return (ptr == other.ptr) && (stride == other.stride);
}
C10_HOST_DEVICE
bool operator!=(const ConstStridedRandomAccessor& other) const {
return !(*this == other);
}
C10_HOST_DEVICE
bool operator<(const ConstStridedRandomAccessor& other) const {
return ptr < other.ptr;
}
C10_HOST_DEVICE
bool operator<=(const ConstStridedRandomAccessor& other) const {
return (*this < other) || (*this == other);
}
C10_HOST_DEVICE
bool operator>(const ConstStridedRandomAccessor& other) const {
return !(*this <= other);
}
C10_HOST_DEVICE
bool operator>=(const ConstStridedRandomAccessor& other) const {
return !(*this < other);
}
// }
protected:
PtrType ptr;
index_t stride;
};
template <
typename T,
typename index_t = int64_t,
template <typename U> class PtrTraits = DefaultPtrTraits
>
class StridedRandomAccessor
: public ConstStridedRandomAccessor<T, index_t, PtrTraits> {
public:
using difference_type = index_t;
using value_type = T;
using pointer = typename PtrTraits<T>::PtrType;
using reference = value_type&;
using BaseType = ConstStridedRandomAccessor<T, index_t, PtrTraits>;
using PtrType = typename PtrTraits<T>::PtrType;
// Constructors {
C10_HOST_DEVICE
StridedRandomAccessor(PtrType ptr, index_t stride)
: BaseType(ptr, stride)
{}
C10_HOST_DEVICE
explicit StridedRandomAccessor(PtrType ptr)
: BaseType(ptr)
{}
C10_HOST_DEVICE
StridedRandomAccessor()
: BaseType()
{}
// }
// Pointer-like operations {
C10_HOST_DEVICE
reference operator*() const {
return *this->ptr;
}
C10_HOST_DEVICE
value_type* operator->() const {
return reinterpret_cast<value_type*>(this->ptr);
}
C10_HOST_DEVICE
reference operator[](index_t idx) const {
return this->ptr[idx * this->stride];
}
// }
// Prefix/postfix increment/decrement {
C10_HOST_DEVICE
StridedRandomAccessor& operator++() {
this->ptr += this->stride;
return *this;
}
C10_HOST_DEVICE
StridedRandomAccessor operator++(int) {
StridedRandomAccessor copy(*this);
++*this;
return copy;
}
C10_HOST_DEVICE
StridedRandomAccessor& operator--() {
this->ptr -= this->stride;
return *this;
}
C10_HOST_DEVICE
StridedRandomAccessor operator--(int) {
StridedRandomAccessor copy(*this);
--*this;
return copy;
}
// }
// Arithmetic operations {
C10_HOST_DEVICE
StridedRandomAccessor& operator+=(index_t offset) {
this->ptr += offset * this->stride;
return *this;
}
C10_HOST_DEVICE
StridedRandomAccessor operator+(index_t offset) const {
return StridedRandomAccessor(this->ptr + offset * this->stride, this->stride);
}
C10_HOST_DEVICE
friend StridedRandomAccessor operator+(
index_t offset,
const StridedRandomAccessor& accessor
) {
return accessor + offset;
}
C10_HOST_DEVICE
StridedRandomAccessor& operator-=(index_t offset) {
this->ptr -= offset * this->stride;
return *this;
}
C10_HOST_DEVICE
StridedRandomAccessor operator-(index_t offset) const {
return StridedRandomAccessor(this->ptr - offset * this->stride, this->stride);
}
// Note that here we call BaseType::operator- version
C10_HOST_DEVICE
difference_type operator-(const BaseType& other) const {
return (static_cast<const BaseType&>(*this) - other);
}
// }
};
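// Illustrative usage sketch (not part of the upstream header): for a
// row-major 2x3 buffer `float data[6]`,
//   ConstStridedRandomAccessor<float> col1(data + 1, /*stride=*/3);
// walks column 1: *col1 == data[1], col1[1] == data[4], and ++col1 advances
// the underlying pointer by 3 elements. StridedRandomAccessor layers writable
// operator*, operator-> and operator[] on top of the same pointer arithmetic.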
} // namespace at::native

View File

@ -0,0 +1,49 @@
#pragma once
// Indexing tensors by tensors
#include <ATen/core/List.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/ReductionType.h>
namespace at {
struct TensorIterator;
}
namespace at::native {
using index_put_with_sort_fn = void(*)(Tensor &, const c10::List<std::optional<Tensor>> &, const Tensor &, bool accumulate, bool unsafe);
using index_put_with_sort_quantized_fn = void(*)(Tensor& self, const c10::List<std::optional<Tensor>>& indices, const Tensor& value, double scale, int zero_point, bool unsafe);
using gather_fn = void (*)(const Tensor & result, const Tensor & self, int64_t dim, const Tensor & index);
using scatter_fn = void(*)(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src);
using scatter_fill_fn = void(*)(const Tensor& self, int64_t dim, const Tensor& index, const Scalar& src);
using scatter_add_fn = void(*)(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src);
using scatter_reduce_fn = void(*)(const Tensor& self, const int64_t dim, const Tensor& index,
const Tensor& src, const ReductionType& reduce);
using scatter_scalar_reduce_fn = void(*)(const Tensor& self, const int64_t dim, const Tensor& index,
const Scalar& value, const ReductionType& reduce);
using scatter_reduce_two_fn = void(*)(const Tensor& self, const int64_t dim, const Tensor& index,
const Tensor& src, const ReductionType& reduce);
DECLARE_DISPATCH(index_put_with_sort_fn, index_put_with_sort_stub);
DECLARE_DISPATCH(index_put_with_sort_quantized_fn, index_put_with_sort_quantized_stub);
DECLARE_DISPATCH(gather_fn, gather_stub);
DECLARE_DISPATCH(scatter_fn, scatter_stub);
DECLARE_DISPATCH(scatter_fill_fn, scatter_fill_stub);
DECLARE_DISPATCH(scatter_add_fn, scatter_add_stub);
DECLARE_DISPATCH(scatter_reduce_fn, scatter_reduce_stub);
DECLARE_DISPATCH(scatter_scalar_reduce_fn, scatter_scalar_reduce_stub);
DECLARE_DISPATCH(scatter_reduce_two_fn, scatter_reduce_two_stub);
TORCH_API Tensor& index_out(Tensor& result, const Tensor & self, const c10::List<std::optional<at::Tensor>>& indices);
using scatter_add_expanded_index_fn = void(*)(const Tensor&, const Tensor&, const Tensor&);
using scatter_reduce_expanded_index_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const ReductionType& reduce, bool);
using gather_expanded_index_fn = void (*)(const Tensor&, const Tensor&, const Tensor&);
DECLARE_DISPATCH(scatter_add_expanded_index_fn, scatter_add_expanded_index_stub);
DECLARE_DISPATCH(scatter_reduce_expanded_index_fn, scatter_reduce_expanded_index_stub);
DECLARE_DISPATCH(gather_expanded_index_fn, gather_expanded_index_stub);
} // namespace at::native

View File

@ -0,0 +1,94 @@
#pragma once
#include <ATen/core/Tensor.h>
#include <ATen/native/IndexingUtils.h>
#include <ATen/native/TensorIterator.h>
namespace at::native {
namespace {
#ifndef STRIP_ERROR_MESSAGES
inline std::string shapes_as_str(TensorList tensors) {
std::ostringstream os;
bool first = true;
for (auto& tensor : tensors) {
if (tensor.defined()) {
if (!first) {
os << ", ";
}
os << tensor.sizes();
first = false;
}
}
return os.str();
}
#endif
} // anonymous namespace
inline std::tuple<bool, Tensor> canDispatchToMaskedFill(const Tensor& self, const torch::List<std::optional<at::Tensor>>& indices,
const Tensor& value){
if (!(value.numel() == 1 && value.device().is_cpu())) {
return std::make_tuple(false, Tensor());
}
int64_t num_ind = 0;
Tensor mask;
auto self_device = self.device();
for (const std::optional<Tensor>& i: indices) {
if (!i.has_value() || !(*i).defined()){
num_ind++;
} else {
const Tensor &index = *i;
if ((index.scalar_type() != kByte && index.scalar_type() != kBool) ||
index.device() != self_device || mask.defined()){
return std::make_tuple(false, Tensor());
} else {
mask = index;
for (const auto j : c10::irange(index.dim())) {
int64_t srcIdx = num_ind + j;
TORCH_CHECK_INDEX(index.size(j) == self.size(srcIdx), "The shape of the mask ", index.sizes(), " at index ", j,
" does not match the shape of the indexed tensor ", self.sizes(), " at index ", srcIdx);
}
num_ind += mask.ndimension();
}
}
}
for (C10_UNUSED const auto i : c10::irange(num_ind, self.ndimension())) {
mask = mask.unsqueeze(-1);
}
return std::make_tuple(true, mask);
}
inline AdvancedIndex make_info(Tensor self, IOptTensorListRef orig) {
checkIndexTensorTypes(orig, /*allow_int*/ true);
// first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more LongTensors
auto indices = expandTensors(self, orig);
// next broadcast all index tensors together
try {
indices = expand_outplace(indices);
} catch (std::exception& e) {
TORCH_CHECK_INDEX(false, "shape mismatch: indexing tensors could not be broadcast together"
" with shapes ", shapes_as_str(indices));
}
// add missing null Tensors so that it matches self.dim()
while (indices.size() < (size_t)self.dim()) {
indices.emplace_back();
}
// if the non-null indices are not all adjacent, transpose self and indices
// together so that they're adjacent at the front
if (!hasContiguousSubspace(indices)) {
std::tie(self, indices) = transposeToFront(self, indices);
}
// Ensure indices are on the same device as self
for (auto & indice : indices) {
if (indice.defined() && indice.device() != self.device()) {
indice = indice.to(self.device());
}
}
for (auto & indice : indices) {
if (indice.defined() && indice.dtype() == at::kInt) {
indice = indice.to(at::kLong);
}
}
return AdvancedIndex(self, indices);
}
} // namespace at::native

View File

@ -0,0 +1,49 @@
#pragma once
#include <ATen/native/DispatchStub.h>
namespace c10 {
class Scalar;
}
namespace at {
class Tensor;
struct TensorIterator;
struct TensorIteratorBase;
}
namespace at::native {
using reduce_minmax_fn =
void (*)(Tensor&, Tensor&, const Tensor&, int64_t, bool);
using structured_reduce_minmax_fn =
void (*)(const Tensor&, const Tensor&, const Tensor&, int64_t, bool);
DECLARE_DISPATCH(structured_reduce_minmax_fn, max_stub);
DECLARE_DISPATCH(structured_reduce_minmax_fn, min_stub);
using where_fn = void (*)(TensorIterator &);
DECLARE_DISPATCH(where_fn, where_kernel);
using is_infinity_op_fn = void (*)(TensorIteratorBase &);
DECLARE_DISPATCH(is_infinity_op_fn, isposinf_stub);
DECLARE_DISPATCH(is_infinity_op_fn, isneginf_stub);
using mode_fn = void (*)(Tensor&, Tensor&, const Tensor&, int64_t, bool);
DECLARE_DISPATCH(mode_fn, mode_stub);
using clamp_tensor_fn = void (*)(TensorIteratorBase &);
DECLARE_DISPATCH(clamp_tensor_fn, clamp_stub);
namespace detail {
enum class ClampLimits {Min, Max, MinMax};
}
DECLARE_DISPATCH(void (*)(TensorIteratorBase &, const c10::Scalar&, const c10::Scalar&), clamp_scalar_stub);
DECLARE_DISPATCH(void (*)(TensorIteratorBase &, c10::Scalar), clamp_min_scalar_stub);
DECLARE_DISPATCH(void (*)(TensorIteratorBase &, c10::Scalar), clamp_max_scalar_stub);
using isin_default_fn = void (*)(const Tensor&, const Tensor&, bool, const Tensor&);
DECLARE_DISPATCH(isin_default_fn, isin_default_stub);
} // namespace at::native

View File

@ -0,0 +1,26 @@
#pragma once
#include <c10/core/Device.h>
#include <c10/core/Layout.h>
#include <c10/core/MemoryFormat.h>
#include <c10/core/ScalarType.h>
#include <optional>
namespace at {
class Tensor;
namespace native {
bool to_will_alias(
const Tensor& self,
std::optional<ScalarType> dtype,
std::optional<Layout> layout,
std::optional<Device> device,
bool copy,
std::optional<c10::MemoryFormat> optional_memory_format);
Tensor to_meta(const Tensor& tensor);
std::optional<Tensor> to_meta(const std::optional<Tensor>& tensor);
std::vector<Tensor> to_meta(at::ITensorListRef t_list);
Tensor dense_to_sparse_with_mask(const Tensor& self, const Tensor& mask, std::optional<c10::Layout> layout, OptionalIntArrayRef blocksize, std::optional<int64_t> dense_dim_opt);
} // namespace native
} // namespace at

View File

@ -0,0 +1,55 @@
#pragma once
#include <ATen/core/Tensor.h>
#include <c10/util/irange.h>
namespace at::native {
//input tensors are non-zero dim and non-empty
template<typename T1, typename T2, typename Function>
void tensor_dim_apply3(const Tensor& self, Tensor& values, Tensor& indices, int64_t dim, Function func) {
int ndims = self.dim();
int tensor_dim_apply_has_finished = 0;
std::vector<int64_t> counter(ndims, 0);
const T1* self_data = self.const_data_ptr<T1>();
T1* values_data = values.data_ptr<T1>();
T2* indices_data = indices.data_ptr<T2>();
int64_t self_stride = self.stride(dim);
int64_t values_stride = values.stride(dim);
int64_t indices_stride = indices.stride(dim);
int self_dim_size = self.size(dim);
while (!tensor_dim_apply_has_finished) {
func(self_data, values_data, indices_data, self_dim_size, self_stride, values_stride, indices_stride);
if (ndims == 1) {
break;
}
for (const auto dim_i : c10::irange(ndims)) {
if (dim_i == dim) {
if (dim_i == (ndims - 1)) {
tensor_dim_apply_has_finished = 1;
break;
}
continue;
}
counter[dim_i]++;
self_data += self.stride(dim_i);
values_data += values.stride(dim_i);
indices_data += indices.stride(dim_i);
if (counter[dim_i] == self.size(dim_i)) {
if (dim_i == ndims-1) {
tensor_dim_apply_has_finished = 1;
break;
} else {
self_data -= counter[dim_i]*self.stride(dim_i);
values_data -= counter[dim_i]*values.stride(dim_i);
indices_data -= counter[dim_i]*indices.stride(dim_i);
counter[dim_i] = 0;
}
} else {
break;
}
}
}
}
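// Illustrative sketch (not part of the upstream header): for a 2x3 `self`
// with dim == 1, the outer while loop invokes `func` twice -- once per row --
// each time handing it pointers to a length-3 slice of self/values/indices
// together with the strides along dim.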
} // namespace at::native

View File

@ -0,0 +1,142 @@
#pragma once
#include <ATen/core/Tensor.h>
#include <ATen/EmptyTensor.h>
#include <ATen/TensorIterator.h>
#include <ATen/Dispatch.h>
#include <ATen/Dispatch_v2.h>
#include <ATen/native/DispatchStub.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/scalar_tensor.h>
#endif
namespace at::native {
// Different combinations of row, col, and offset can lead to two cases:
//
// Case 1 - Trapezoid (Triangle as a special case): row + offset <= col
// Example A: offset > 0
// 1 1 0 0 0
// 1 1 1 0 0
// 1 1 1 1 0
// Example B: offset <= 0
// 0 0 0
// 1 0 0
// 1 1 0
// In this case, we calculate the number of elements in the first row and
// last row of the tril respectively, and then compute the tril size.
//
// Case 2 - Trapezoid + Rectangle: row + offset > col
// Example:
// 1 1 0
// 1 1 1
// 1 1 1
// In this case, we first calculate the size of top trapezoid, and then
// calculate the size of the bottom rectangle.
inline int64_t get_tril_size(int64_t row, int64_t col, int64_t offset) {
// If either dimension is 0 then there is no tril
if (row == 0 || col == 0) {
return 0;
}
// number of elements in the first row of the tril
auto m_first_row = offset > 0 ?
std::min<int64_t>(col, 1 + offset) : // upper bounded by col
row + offset > 0; // either 0 or 1
// number of elements in the last row of the tril, bounded by [0, col]
auto m_last_row = std::max<int64_t>(0, std::min<int64_t>(col, row + offset));
// number of rows, bounded by [0, row]
auto n_row_all = std::max<int64_t>(0, std::min<int64_t>(row, row + offset));
auto n_row_trapezoid = (m_last_row - m_first_row + 1);
// calculate # of elements in the top trapezoid
auto tril_size = (m_first_row + m_last_row) * n_row_trapezoid >> 1;
// calculate # of elements in the bottom rectangle if there is any
auto diff_row = n_row_all - n_row_trapezoid;
if (diff_row > 0) {
tril_size += diff_row * col;
}
return tril_size;
}
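// Illustrative worked values for the two cases above (a sketch added for
// clarity, not part of the upstream header):
//   get_tril_size(/*row=*/3, /*col=*/5, /*offset=*/1) == 9   // Case 1: 2 + 3 + 4
//   get_tril_size(/*row=*/3, /*col=*/3, /*offset=*/1) == 8   // Case 2: 5 (trapezoid) + 3 (rectangle)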
inline void check_args(
int64_t row, int64_t col, std::optional<Layout> layout_opt) {
TORCH_CHECK(row >= 0, "row must be non-negative, got", row);
TORCH_CHECK(col >= 0, "col must be non-negative, got", col);
if (layout_opt.has_value()) {
TORCH_CHECK(
*layout_opt == at::kStrided,
"only support layout=torch.strided, got",
*layout_opt)
}
}
using at::check_size_nonnegative;
// assumes maximum value in created tensor is n-1 (e.g., torch.randperm(n))
inline void check_supported_max_int_with_precision(int64_t n, const Tensor& tensor) {
// match defined() to behavior of checks below
TORCH_CHECK(at::scalar_tensor(n>0?n-1:n, tensor.options()).defined(),
"n is too large for result tensor type: '", tensor.toString(), "'");
// Ensure sufficient precision for floating point representation.
switch (tensor.scalar_type()) {
case at::ScalarType::Half:
TORCH_CHECK(n <= (int64_t(1) << 11) + 1, "n cannot be greater than 2049 for Half type.");
break;
case at::ScalarType::Float:
TORCH_CHECK(n <= (int64_t(1) << 24) + 1, "n cannot be greater than 2^24+1 for Float type.");
break;
case at::ScalarType::Double: // Unlikely to happen, but doesn't hurt to check
TORCH_CHECK(n <= (int64_t(1) << 53) + 1, "n cannot be greater than 2^53+1 for Double type.");
break;
default:
break;
}
}
// Called by `empty*` functions when deterministic algorithms are enabled to
// fill the tensor with NaN if it is floating point or complex type, or fill
// with max value if it is integer type
inline Tensor& fill_empty_deterministic_(Tensor& tensor) {
if (tensor.is_floating_point() || tensor.is_complex()) {
AT_DISPATCH_V2(
tensor.scalar_type(), "fill_empty_deterministic_", AT_WRAP([&]() {
tensor.fill_(std::numeric_limits<scalar_t>::quiet_NaN());
}), AT_EXPAND(AT_FLOATING_TYPES), AT_EXPAND(AT_COMPLEX_TYPES), AT_EXPAND(AT_FLOAT8_TYPES), kBFloat16, kHalf);
} else {
AT_DISPATCH_V2(
tensor.scalar_type(), "fill_empty_deterministic_", AT_WRAP([&]() {
tensor.fill_(std::numeric_limits<scalar_t>::max());
}), kBool, AT_EXPAND(AT_INTEGRAL_TYPES_V2));
}
return tensor;
}
// The ZeroTensor allocator ignores whatever allocation is requested and always
// gives you nullptr
struct ZeroTensorAllocator final : public at::Allocator {
ZeroTensorAllocator(at::Device device) : device_(device) {};
~ZeroTensorAllocator() override = default;
static void deleter(void* const pointer) {
TORCH_INTERNAL_ASSERT(!pointer);
}
DataPtr allocate(const size_t /*nbytes*/) override {
return {nullptr, nullptr, &deleter, device_};
}
DeleterFnPtr raw_deleter() const override {
return deleter;
}
void copy_data(void* dest [[maybe_unused]], const void* src [[maybe_unused]], std::size_t count [[maybe_unused]]) const final {}
at::Device device_;
};
using binary_fn = void (*)(TensorIterator&);
DECLARE_DISPATCH(binary_fn, complex_stub);
DECLARE_DISPATCH(binary_fn, polar_stub);
} // namespace at::native

View File

@ -0,0 +1,2 @@
#pragma once
#include <ATen/TensorIterator.h>

View File

@ -0,0 +1,52 @@
#pragma once
#include <complex>
#include <type_traits>
#include <c10/core/ScalarType.h>
#include <ATen/detail/FunctionTraits.h>
#include <ATen/native/TensorIterator.h>
// This file includes utilities for dynamic_casting done by TensorIterator, see CUDALoops.cuh and Loops.h.
// dynamic_casting handles when the types expected by the iterator do not match the types of the arguments
// to the function that is being called.
// On CUDA, the cast is currently pushed down into the kernel (for performance reasons).
// On CPU, there is currently an internal assert that a dynamic_cast is not needed.
namespace at::native {
// `needs_dynamic_casting` compares the types expected by iterator
// (i.e. dtypes of the operands) with the actual type of the arguments
// (and returns) of func_t
template<typename func_t, int nargs=function_traits<func_t>::arity>
struct needs_dynamic_casting {
static bool check(TensorIteratorBase& iter) {
using traits = function_traits<func_t>;
using cpp_type = typename traits::template arg<nargs - 1>::type;
using cpp_map = c10::CppTypeToScalarType<cpp_type>;
if (iter.input_dtype(nargs-1) != cpp_map::value) {
return true;
}
return needs_dynamic_casting<func_t, nargs - 1>::check(iter);
}
};
template<typename func_t>
struct needs_dynamic_casting<func_t, 0> {
static bool check(TensorIteratorBase& iter) {
using traits = function_traits<func_t>;
using cpp_type = typename traits::result_type;
// we could assert output numbers are correct here, but checks
// (including arity) are currently pushed outside of this struct.
if constexpr (std::is_void_v<cpp_type>) {
return false;
} else {
return iter.dtype(0) != c10::CppTypeToScalarType<cpp_type>::value;
}
}
};
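// Illustrative sketch (assumed typical usage, not part of the upstream
// header): for a kernel functor such as
//   auto add = [](float a, float b) -> float { return a + b; };
// needs_dynamic_casting<decltype(add)>::check(iter) is false when the
// iterator's output and both inputs are kFloat, and true as soon as any
// operand dtype differs (e.g. double inputs), signalling that a cast path
// is required.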
} //namespace at::native

View File

@ -0,0 +1,12 @@
#pragma once
// See NOTE: [Tensor vs. TensorBase]
namespace at {
class TensorBase;
}
namespace at::native {
TORCH_API bool cudnn_is_acceptable(const TensorBase& self);
} // namespace at::native

View File

@ -0,0 +1,105 @@
#pragma once
#include <ATen/core/Tensor.h>
#include <c10/util/irange.h>
#include <ATen/core/IListRef.h>
namespace at::native {
TORCH_API at::Tensor clone_preserve_strides(const at::Tensor& self);
inline bool cat_should_skip_tensor(const Tensor& t) {
return t.sym_numel() == 0 && t.dim() == 1;
}
// Check to see if the shape of tensors is compatible
// for being concatenated along a given dimension.
inline void check_cat_shape_except_dim(const Tensor & first, const Tensor & second, int64_t dimension, int64_t index) {
int64_t first_dims = first.dim();
int64_t second_dims = second.dim();
TORCH_CHECK(first_dims == second_dims, "Tensors must have same number of dimensions: got ",
first_dims, " and ", second_dims);
for (const auto dim : c10::irange(first_dims)) {
if (dim == dimension) {
continue;
}
int64_t first_dim_size = first.sizes()[dim];
int64_t second_dim_size = second.sizes()[dim];
TORCH_CHECK(first_dim_size == second_dim_size, "Sizes of tensors must match except in dimension ",
dimension, ". Expected size ", static_cast<long long>(first_dim_size), " but got size ", static_cast<long long>(second_dim_size), " for tensor number ", index, " in the list.");
}
}
inline void check_cat_no_zero_dim(const MaterializedITensorListRef& tensors) {
int64_t i = 0;
for(const Tensor& t : tensors) {
TORCH_CHECK(t.dim() > 0,
"zero-dimensional tensor (at position ", i, ") cannot be concatenated");
i++;
}
}
inline int64_t get_num_splits(const Tensor& self, int64_t split_size, int64_t dim) {
TORCH_CHECK(self.dim() != 0, "split expects at least a 1-dimensional tensor");
TORCH_CHECK(split_size >= 0, "split expects split_size be non-negative, but got split_size=", split_size);
int64_t dim_size = self.size(dim);
TORCH_CHECK(split_size > 0 || dim_size == 0,
"split_size can only be 0 if dimension size is 0, "
"but got dimension size of ", dim_size);
// if split_size is 0 and dimension size is 0, there is 1 split.
int64_t num_splits = 1;
if (split_size != 0) {
// ensuring num_splits is at least 1 makes consistent the case where split_size > dim_size
// (returns a single split). We might want to error here, but keep it for BC.
num_splits = std::max<int64_t>((dim_size + split_size - 1) / split_size, 1);
}
return num_splits;
}
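// Illustrative worked values (a sketch added for clarity, not part of the
// upstream header): for a dimension of size 10,
//   get_num_splits(self, /*split_size=*/3, dim) == 4   // chunks of 3, 3, 3, 1
// and split_size == 0 is only accepted when size(dim) == 0, in which case a
// single (empty) split is reported.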
inline bool have_same_ndims(TensorList tensors) {
auto ndim = tensors[0].dim();
for (const auto tensor_idx : c10::irange(tensors.size())) {
if(tensors[tensor_idx].dim() != ndim) {
return false;
}
}
return true;
}
inline void leading_dimension_matches(TensorList tensors, int64_t dim) {
auto tensor_zero_size = tensors[0].sizes();
std::vector<c10::SymInt> leading_dim_sizes(tensor_zero_size.begin(), tensor_zero_size.begin() + dim);
for (const auto i : c10::irange(tensors.size())) {
at::Tensor tensor = tensors[i];
for(const auto j : c10::irange(dim)) {
TORCH_CHECK(
tensor.size(j) == leading_dim_sizes[j],
"_chunk_cat expects same sizes of 0,...,dim-1 dimensions for all tensors"
);
}
}
}
inline int64_t preprocess_chunk_cat_inputs(TensorList tensors, int64_t dim, int64_t num_chunks) {
TORCH_CHECK(num_chunks >= 1, "_chunk_cat expects positive num_chunks");
TORCH_CHECK(!tensors.empty(),
"_chunk_cat expects a non-empty input tensor list");
auto expected_dtype = tensors[0].dtype();
auto expected_device = tensors[0].device();
for(const auto i : c10::irange(tensors.size())) {
TORCH_CHECK(tensors[i].numel() > 0, "_chunk_cat expects non-empty tensor");
TORCH_CHECK(tensors[i].dtype() == expected_dtype, "_chunk_cat expects all input tensors with the same dtype");
TORCH_CHECK(tensors[i].device() == expected_device, "_chunk_cat expects all input tensors on the same device");
}
if (have_same_ndims(tensors)) {
dim = maybe_wrap_dim(dim, tensors[0].dim());
} else {
TORCH_CHECK(dim >= 0, "_chunk_cat expects non-negative dim when input tensors have different ndims")
for(const auto i : c10::irange(tensors.size())) {
TORCH_CHECK(dim < tensors[i].ndimension(), "_chunk_cat expects dim < ndim for all input tensors");
}
}
leading_dimension_matches(tensors, dim);
return dim;
}
} // namespace at::native

View File

@ -0,0 +1,30 @@
#include <ATen/core/Tensor.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/roll.h>
#endif
#include <c10/util/Exception.h>
namespace at::native {
static inline Tensor roll_common(const Tensor& self, IntArrayRef shifts, IntArrayRef dims) {
TORCH_CHECK(!shifts.empty(), "`shifts` required");
if (dims.empty() && shifts.size() == 1) {
auto flattened = self.contiguous().view(self.numel());
return roll(flattened, shifts[0], 0).view(self.sizes());
}
TORCH_CHECK(
shifts.size() == dims.size(),
"shifts and dimensions must align. shifts: ", shifts.size(), ", dims:", dims.size()
);
AT_ASSERT(dims.size() > 1);
auto tail_shifts = shifts.slice(1);
auto tail_dims = dims.slice(1);
auto first_dim_rolled = roll(self, shifts[0], dims[0]);
return at::roll(first_dim_rolled, tail_shifts, tail_dims);
}
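// Illustrative sketch (not part of the upstream header): with self = [1, 2,
// 3, 4, 5], shifts = {2} and dims = {}, the flattened branch produces
// [4, 5, 1, 2, 3]; with shifts = {1, 1} and dims = {0, 1} on a matrix, the
// first dimension is rolled here and the remaining shift is delegated to
// at::roll recursively.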
} // namespace at::native

View File

@ -0,0 +1,98 @@
#pragma once
#include <ATen/core/TensorAccessor.h>
#include <ATen/NumericUtils.h>
namespace at::native {
#ifdef CPU_CAPABILITY
inline namespace CPU_CAPABILITY {
#else
inline namespace DEFAULT {
#endif
// Core topk loop, shared between CPU and QuantizedCPU
template <typename scalar_t, typename accscalar_t>
void topk_impl_loop(
const int64_t mode_values_stride,
const int64_t mode_indices_stride,
const int64_t tmp_values_stride,
const int64_t k,
const int64_t dim_size,
const bool largest,
const bool sorted,
char** data, const int64_t* strides, const int64_t n) {
// If k is zero, then output values and indices are empty tensors
// So iterating over other dims is pointless
if (k == 0) {
return;
}
using elem_t = std::pair<accscalar_t, int64_t>;
std::vector<elem_t> queue(dim_size);
for (const auto i : c10::irange(n)) {
TensorAccessor<scalar_t, 1> mode_values(
reinterpret_cast<scalar_t*>(data[0] + i * strides[0]),
&k, &mode_values_stride);
TensorAccessor<int64_t, 1> mode_indices(
reinterpret_cast<int64_t*>(data[1] + i * strides[1]),
&k, &mode_indices_stride);
TensorAccessor<const scalar_t, 1> tmp_values(
reinterpret_cast<scalar_t*>(data[2] + i * strides[2]),
&dim_size, &tmp_values_stride);
auto n_2 = dim_size;
auto use_partial_sort = k * 64 <= n_2;
for (const auto j : c10::irange(n_2)) {
queue[j].first = tmp_values[j];
queue[j].second = j;
}
// we want nan to be sorted as top for numpy compatibility
if (use_partial_sort) {
if (largest) {
std::partial_sort(queue.begin(), queue.begin() + k, queue.end(),
[](const elem_t& x, const elem_t& y) -> bool {
return ((_isnan<accscalar_t>(x.first) && !_isnan<accscalar_t>(y.first)) || (x.first > y.first));
});
} else {
std::partial_sort(queue.begin(), queue.begin() + k, queue.end(),
[](const elem_t& x, const elem_t& y) -> bool {
return ((!_isnan<accscalar_t>(x.first) && _isnan<accscalar_t>(y.first)) || (x.first < y.first));
});
}
} else {
if (largest) {
std::nth_element(queue.begin(), queue.begin() + k - 1, queue.end(),
[](const elem_t& x, const elem_t& y) -> bool {
return ((_isnan<accscalar_t>(x.first) && !_isnan<accscalar_t>(y.first)) || (x.first > y.first));
});
if (sorted) {
std::sort(queue.begin(), queue.begin() + k - 1,
[](const elem_t& x, const elem_t& y) -> bool {
return ((_isnan<accscalar_t>(x.first) && !_isnan<accscalar_t>(y.first)) || (x.first > y.first));
});
}
} else {
std::nth_element(queue.begin(), queue.begin() + k -1, queue.end(),
[](const elem_t& x, const elem_t& y) -> bool {
return ((!_isnan<accscalar_t>(x.first) && _isnan<accscalar_t>(y.first)) || (x.first < y.first));
});
if (sorted) {
std::sort(queue.begin(), queue.begin() + k -1,
[](const elem_t& x, const elem_t& y) -> bool {
return ((!_isnan<accscalar_t>(x.first) && _isnan<accscalar_t>(y.first)) || (x.first < y.first));
});
}
}
}
for (const auto j : c10::irange(k)) {
mode_values[j] = queue[j].first;
mode_indices[j] = queue[j].second;
}
}
}
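// Illustrative sketch of the comparator semantics above (not part of the
// upstream header): for input values [1.0, NaN, 3.0] with k == 2 and
// largest == true, NaN is treated as the largest element (NumPy behaviour),
// so the selected (value, index) pairs are (NaN, 1) and (3.0, 2), in that
// order when sorted == true.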
} // namespace CPU_CAPABILITY
} // namespace at::native

View File

@ -0,0 +1,23 @@
#pragma once
#include <c10/util/Exception.h>
namespace at::native {
// Used as an interface between the different BLAS-like libraries
enum class TransposeType {
NoTranspose,
Transpose,
ConjTranspose,
};
// Transforms TransposeType into the BLAS / LAPACK format
static inline char to_blas(TransposeType trans) {
switch (trans) {
case TransposeType::Transpose: return 'T';
case TransposeType::NoTranspose: return 'N';
case TransposeType::ConjTranspose: return 'C';
}
TORCH_INTERNAL_ASSERT(false, "Invalid transpose type");
}
} // namespace at::native

View File

@ -0,0 +1,57 @@
#include <ATen/core/Tensor.h>
#include <ATen/native/LinearAlgebraUtils.h>
namespace at::native {
/*
* Given batches of matrices with arbitrary batch dim,
 * computes the number of batches for Triu and Tril. This ignores stride-0 dimensions.
*/
static inline int64_t batchCountTrilTriu(const Tensor& batched_matrices) {
int64_t result = 1;
for (int64_t i = 0; i < batched_matrices.ndimension() - 2; i++) {
if (batched_matrices.stride(i) != 0) {
result *= batched_matrices.size(i);
}
}
return result;
}
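// Illustrative sketch (not part of the upstream header): a batch of matrices
// with sizes [2, 4, 3, 3] yields batchCountTrilTriu == 8, while an expanded
// batch with the same sizes but stride 0 along dim 1 yields 2, since
// stride-0 (broadcast) dimensions are skipped.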
/* Checks a necessary property for the triu and tril implementations, hence the name.
* Here batch contiguity is checked for tensors with greater than 4 dimensions.
* Contiguous tensors and tensors with less than 3 dimensions pass this check
*/
static inline std::tuple<bool, Tensor> checkTrilTriuBatchContiguous(const Tensor& tensor, bool allow_zero_stride) {
// Complete contiguity is the most desired property, which is why
// we return true if the tensor is contiguous
if (tensor.is_contiguous()) {
auto default_strides_for_size = batched_matrix_contiguous_strides(tensor.sizes());
if (tensor.strides() == default_strides_for_size) {
return std::make_tuple(true, tensor);
} else {
return std::make_tuple(false, tensor.as_strided(tensor.sizes(), default_strides_for_size));
}
}
int64_t dims = tensor.dim();
// Tensors with dimension less than 4 are handled by default
if (allow_zero_stride && dims <= 3) {
return std::make_tuple(true, tensor);
}
int64_t expected_stride = tensor.size(-1) * tensor.size(-2);
for (int64_t i = dims - 3; i >= 0; i--) {
// Skip trivial dimension;
if (allow_zero_stride && i == 0 && (tensor.stride(i) == 0 || tensor.size(i) == 1)) {
continue;
}
if (expected_stride != tensor.stride(i)) {
return std::make_tuple(false, tensor.contiguous());
}
expected_stride *= tensor.size(i);
}
return std::make_tuple(true, tensor);
}
} // namespace at::native

View File

@ -0,0 +1,20 @@
#pragma once
#include <ATen/core/Tensor.h>
#include <ATen/core/IListRef.h>
namespace at::native {
struct ResultTypeState {
c10::ScalarType dimResult = ScalarType::Undefined;
c10::ScalarType wrappedResult = ScalarType::Undefined;
c10::ScalarType zeroResult = ScalarType::Undefined;
};
TORCH_API ResultTypeState update_result_type_state(const Tensor& tensor, const ResultTypeState& in_state);
TORCH_API ResultTypeState update_result_type_state(const Scalar& scalar, const ResultTypeState& in_state);
TORCH_API ScalarType result_type(const ResultTypeState& state);
TORCH_API ScalarType result_type(ITensorListRef tensors);
} // namespace at::native

View File

@ -0,0 +1,130 @@
#pragma once
#include <ATen/native/DispatchStub.h>
#include <ATen/Generator.h>
#include <c10/core/Scalar.h>
#include <stdexcept>
namespace at {
class Tensor;
class TensorBase;
struct TensorIteratorBase;
}
namespace at::native {
using unary_fn = void(*)(TensorIteratorBase&);
using unary_fn_with_scalar = void(*)(TensorIteratorBase&, const Scalar& a);
inline namespace CPU_CAPABILITY {
void conj_kernel(TensorIteratorBase &iter);
void neg_kernel(TensorIteratorBase &iter);
void reciprocal_kernel(TensorIteratorBase &iter);
void rsqrt_kernel(TensorIteratorBase& iter);
void sqrt_kernel(TensorIteratorBase& iter);
} // namespace CPU_CAPABILITY
DECLARE_DISPATCH(unary_fn, abs_stub);
DECLARE_DISPATCH(unary_fn, angle_stub);
DECLARE_DISPATCH(unary_fn, conj_physical_stub);
DECLARE_DISPATCH(unary_fn, acos_stub);
DECLARE_DISPATCH(unary_fn, acosh_stub);
DECLARE_DISPATCH(unary_fn, asinh_stub);
DECLARE_DISPATCH(unary_fn, atanh_stub);
DECLARE_DISPATCH(unary_fn, asin_stub);
DECLARE_DISPATCH(unary_fn, atan_stub);
DECLARE_DISPATCH(unary_fn, bitwise_not_stub);
DECLARE_DISPATCH(unary_fn, logical_not_stub);
DECLARE_DISPATCH(unary_fn, ceil_stub);
DECLARE_DISPATCH(unary_fn, cos_stub);
DECLARE_DISPATCH(unary_fn, cosh_stub);
DECLARE_DISPATCH(unary_fn, digamma_stub);
DECLARE_DISPATCH(unary_fn, special_entr_stub);
DECLARE_DISPATCH(unary_fn, special_erfcx_stub);
DECLARE_DISPATCH(unary_fn, erf_stub);
DECLARE_DISPATCH(unary_fn, erfc_stub);
DECLARE_DISPATCH(unary_fn, erfinv_stub);
DECLARE_DISPATCH(unary_fn, exp_stub);
DECLARE_DISPATCH(unary_fn, exp2_stub);
DECLARE_DISPATCH(unary_fn, expm1_stub);
DECLARE_DISPATCH(unary_fn, floor_stub);
DECLARE_DISPATCH(unary_fn, frac_stub);
DECLARE_DISPATCH(unary_fn, frexp_stub);
DECLARE_DISPATCH(unary_fn, i0_stub);
DECLARE_DISPATCH(unary_fn, special_i0e_stub);
DECLARE_DISPATCH(unary_fn, special_i1_stub);
DECLARE_DISPATCH(unary_fn, special_i1e_stub);
DECLARE_DISPATCH(unary_fn, log_stub);
DECLARE_DISPATCH(unary_fn, log10_stub);
DECLARE_DISPATCH(unary_fn, log1p_stub);
DECLARE_DISPATCH(unary_fn, log2_stub);
DECLARE_DISPATCH(unary_fn, special_ndtri_stub);
DECLARE_DISPATCH(unary_fn, special_log_ndtr_stub);
DECLARE_DISPATCH(unary_fn, neg_stub);
DECLARE_DISPATCH(unary_fn, reciprocal_stub);
DECLARE_DISPATCH(unary_fn, round_stub);
DECLARE_DISPATCH(unary_fn, rsqrt_stub);
DECLARE_DISPATCH(unary_fn, sigmoid_stub);
DECLARE_DISPATCH(unary_fn_with_scalar, logit_stub);
DECLARE_DISPATCH(unary_fn, sign_stub);
DECLARE_DISPATCH(unary_fn, signbit_stub);
DECLARE_DISPATCH(unary_fn, sgn_stub);
DECLARE_DISPATCH(unary_fn, sin_stub);
DECLARE_DISPATCH(unary_fn, sinc_stub);
DECLARE_DISPATCH(unary_fn, sinh_stub);
DECLARE_DISPATCH(unary_fn, sqrt_stub);
DECLARE_DISPATCH(unary_fn, tan_stub);
DECLARE_DISPATCH(unary_fn, tanh_stub);
DECLARE_DISPATCH(unary_fn, trigamma_stub);
DECLARE_DISPATCH(unary_fn, trunc_stub);
DECLARE_DISPATCH(unary_fn, lgamma_stub);
DECLARE_DISPATCH(unary_fn, special_airy_ai_stub);
DECLARE_DISPATCH(unary_fn, special_bessel_j0_stub);
DECLARE_DISPATCH(unary_fn, special_bessel_j1_stub);
DECLARE_DISPATCH(unary_fn, special_bessel_y0_stub);
DECLARE_DISPATCH(unary_fn, special_bessel_y1_stub);
DECLARE_DISPATCH(unary_fn, special_modified_bessel_i0_stub);
DECLARE_DISPATCH(unary_fn, special_modified_bessel_i1_stub);
DECLARE_DISPATCH(unary_fn, special_modified_bessel_k0_stub);
DECLARE_DISPATCH(unary_fn, special_modified_bessel_k1_stub);
DECLARE_DISPATCH(unary_fn, special_scaled_modified_bessel_k0_stub);
DECLARE_DISPATCH(unary_fn, special_scaled_modified_bessel_k1_stub);
DECLARE_DISPATCH(unary_fn, special_spherical_bessel_j0_stub);
// NB: these are actually defined in Distribution
DECLARE_DISPATCH(void(*)(const TensorBase&, const TensorBase&, std::optional<Generator>), bernoulli_tensor_stub);
DECLARE_DISPATCH(void(*)(const TensorBase&, const double, std::optional<Generator>), bernoulli_scalar_stub);
DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, const double, std::optional<Generator>), cauchy_stub);
DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, std::optional<Generator>), exponential_stub);
DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, std::optional<Generator>), geometric_stub);
DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, const double, std::optional<Generator>), log_normal_stub);
DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, const double, std::optional<Generator>), uniform_stub);
DECLARE_DISPATCH(void(*)(const TensorBase&, const double, const double, std::optional<Generator>), normal_stub);
DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const uint64_t, const int64_t, std::optional<Generator>), random_from_to_stub);
DECLARE_DISPATCH(void(*)(TensorIteratorBase&, std::optional<Generator>), random_full_64_bits_range_stub);
DECLARE_DISPATCH(void(*)(TensorIteratorBase&, std::optional<Generator>), random_stub);
DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const int64_t, const double), kaiser_window_stub);
DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const int64_t), polygamma_stub);
DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const Scalar& a, const Scalar& b), clamp_stub);
DECLARE_DISPATCH(
void (*)(Tensor&, const Tensor&, int64_t, std::optional<Generator>),
multinomial_with_replacement_stub);
DECLARE_DISPATCH(
void (*)(
TensorIteratorBase&,
std::optional<double>,
std::optional<double>,
std::optional<double>),
nan_to_num_stub);
DECLARE_DISPATCH(void (*)(TensorIteratorBase&, int64_t), round_decimals_stub);
// Missing unary functions
// digamma
// lgamma
// erfinv
// clone
// contiguous
// zero
} // namespace at::native

View File

@ -0,0 +1,48 @@
#pragma once
#include <ATen/native/DispatchStub.h>
#include <c10/core/ScalarType.h>
#include <cstdint>
namespace at::native {
using unfold2d_copy_fn = void (*)(
ScalarType dtype,
void *finput,
const void *input,
int64_t kH,
int64_t kW,
int64_t dH,
int64_t dW,
int64_t padH,
int64_t padW,
int64_t n_input_plane,
int64_t input_height,
int64_t input_width,
int64_t output_height,
int64_t output_width,
bool is_channels_last
);
using unfold2d_acc_fn = void (*)(
ScalarType dtype,
void *finput,
void *input,
int64_t kH,
int64_t kW,
int64_t dH,
int64_t dW,
int64_t padH,
int64_t padW,
int64_t n_input_plane,
int64_t input_height,
int64_t input_width,
int64_t output_height,
int64_t output_width,
bool is_channels_last
);
DECLARE_DISPATCH(unfold2d_copy_fn, unfolded2d_copy_stub);
DECLARE_DISPATCH(unfold2d_acc_fn, unfolded2d_acc_stub);
} // namespace at::native

View File

@ -0,0 +1,49 @@
#pragma once
#include <c10/core/ScalarType.h>
namespace at::native {
void Unfold3dCopyCPU(
ScalarType dtype,
const void *src,
int64_t C,
int64_t X_D,
int64_t X_H,
int64_t X_W,
int64_t Y_D,
int64_t Y_H,
int64_t Y_W,
int64_t kernel_d,
int64_t kernel_h,
int64_t kernel_w,
int64_t stride_d,
int64_t stride_h,
int64_t stride_w,
int64_t pad_d,
int64_t pad_h,
int64_t pad_w,
void* dst);
void Unfold3dAccCPU(
ScalarType dtype,
const void *src,
int64_t C,
int64_t X_D,
int64_t X_H,
int64_t X_W,
int64_t Y_D,
int64_t Y_H,
int64_t Y_W,
int64_t kernel_d,
int64_t kernel_h,
int64_t kernel_w,
int64_t stride_d,
int64_t stride_h,
int64_t stride_w,
int64_t pad_d,
int64_t pad_h,
int64_t pad_w,
void *dst);
} // namespace at::native

View File

@ -0,0 +1,112 @@
#pragma once
#include <ATen/core/Tensor.h>
#include <ATen/TensorIterator.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/NonEmptyUtils.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/arange.h>
#endif
namespace at::native {
using unfold_backward_fn = void (*)(
Tensor& grad_in,
const Tensor& grad,
int64_t dim,
int64_t size,
int64_t step
);
DECLARE_DISPATCH(unfold_backward_fn, unfold_backward_stub);
namespace {
// Note on naming: it is unconventional.
// grad_in does not mean that it is a gradient wrt to input,
// grad_in/grad_out is just an input/output of unfold_backward kernel.
static C10_UNUSED TensorIterator _make_unfold_backward_iter_over_grad_out(
Tensor& grad_out,
const Tensor& grad_in,
int64_t dim,
int64_t size,
int64_t step
) {
dim = maybe_wrap_dim(dim, grad_out.dim());
// last dim stores the folds
auto grad_out_dim_size = ensure_nonempty_size(grad_out, dim);
auto grad_in_dim_size = ensure_nonempty_size(grad_in, dim);
// dictates the number of elements to iterate over
// in dimension `dim`
auto iter_dim_size = std::min(
grad_out_dim_size,
(grad_in_dim_size - 1) * step + size
);
/* prepare grad_out for TensorIterator { */
auto grad_out_strides = ensure_nonempty_vec(grad_out.strides().vec());
auto grad_out_sizes = ensure_nonempty_vec(grad_out.sizes().vec());
grad_out_sizes[dim] = iter_dim_size;
auto grad_out_restrided = grad_out.as_strided(
grad_out_sizes, grad_out_strides
);
/* } */
/* prepare grad_in for TensorIterator { */
auto grad_in_strides = ensure_nonempty_vec(grad_in.strides().vec());
auto grad_in_sizes = ensure_nonempty_vec(grad_in.sizes().vec());
// set strides for dim to 0
// and size to 1 because
// this dimension is indexed inside the kernel
grad_in_strides[dim] = 0;
grad_in_sizes[dim] = 1;
grad_in_strides.pop_back();
grad_in_sizes.pop_back();
auto grad_in_restrided = grad_in.squeeze(-1).as_strided(
grad_in_sizes, grad_in_strides
);
/* } */
// During the TensorIterator iteration we have to know
// i_dim in grad_out[i_1,...,i_dim,...i_n],
// idx_dim stores this information
/* prepare idx_dim for TensorIterator { */
auto idx_dim = at::arange(
0, iter_dim_size, grad_in.options().dtype(at::kLong)
);
auto grad_out_dim = ensure_nonempty_dim(grad_out.dim());
auto idx_dim_strides = std::vector<int64_t>(grad_out_dim, 0);
auto idx_dim_sizes = std::vector<int64_t>(grad_out_dim, 1);
idx_dim_strides[dim] = 1;
idx_dim_sizes[dim] = iter_dim_size;
// idx_dim will broadcast over the shape determined by grad_out sizes in TensorIterator
auto idx_dim_restrided = idx_dim.as_strided(idx_dim_sizes, idx_dim_strides);
/* } */
auto iter = TensorIteratorConfig()
.set_check_mem_overlap(false)
.check_all_same_dtype(false)
.resize_outputs(false)
.add_owned_output(grad_out_restrided)
.add_owned_const_input(grad_in_restrided)
.add_owned_const_input(idx_dim_restrided)
.build();
return iter;
}
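// Illustrative sketch (not part of the upstream header): unfolding a
// length-10 dimension with size == 3 and step == 2 gives grad_in a fold
// dimension of 4, so iter_dim_size == std::min(10, (4 - 1) * 2 + 3) == 9 and
// the iterator walks 9 positions of grad_out along `dim`, with idx_dim
// supplying the running index i_dim to the kernel.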
} // anonymous namespace
} // namespace at::native

View File

@ -0,0 +1,505 @@
#pragma once
#include <math.h>
#include <ATen/OpMathType.h>
#include <ATen/TensorUtils.h>
#include <ATen/core/Tensor.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/cpu/utils.h>
/**
* Note [compute_scales_value]
* Note [area_pixel_compute_scale]
* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
* Interpolate with scale_factor can have different behaviors
* depending on the value of recompute_scale_factor:
*
* - With recompute_scale_factor = True (current default behavior):
* the scale_factor, when provided by the user, are used to calculate
* the output size. The input size and the computed output_size
* are then used to infer new values for the scales which are
* used in the interpolation. Because floating-point math is not exact,
* this may be a different value from the user-supplied scales.
*
* - With recompute_scale_factor = False (which will be the default
* behavior starting 1.5.0):
* the behavior follows opencv logic, and the scales provided by
* the user are the ones used in the interpolation calculations.
*
* If the scales are not provided or if they are provided but
* recompute_scale_factor is set to True (default behavior), the scales
* are computed from the input and the output size;
*
*
* When the scales are inferred from the input and output sizes,
* we view each pixel as an area, idx + 0.5 as its center index.
* Here is an example formula in 1D case.
* if align_corners: center of two corner pixel areas are preserved,
* (0.5, 0.5) -> (0.5, 0.5),
* (input_size - 0.5, 0.5) -> (output_size - 0.5, 0.5)
* scale = (input_size - 0.5 - 0.5) / (output_size - 0.5 - 0.5)
* src_index + 0.5 - 0.5 = scale * (dst_index + 0.5 - 0.5)
* if not align_corners: the whole range is scaled accordingly
* scale = input_size / output_size
* src_idx + 0.5 = scale * (dst_index + 0.5)
*/
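// Worked example (illustrative, not in the original source): with
// input_size = 4, output_size = 8 and align_corners = false, the scale is
// 4 / 8 = 0.5, so output index 3 maps to src_idx = 0.5 * (3 + 0.5) - 0.5 = 1.25.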
namespace at::native {
namespace upsample {
TORCH_API c10::SmallVector<int64_t, 3> compute_output_size(
c10::IntArrayRef input_size, // Full input tensor size.
at::OptionalIntArrayRef output_size,
std::optional<c10::ArrayRef<double>> scale_factors);
inline std::optional<double> get_scale_value(std::optional<c10::ArrayRef<double>> scales, int idx) {
if (!scales) {
return std::nullopt;
}
return scales->at(idx);
}
} // namespace upsample
using scale_t = std::optional<double>;
using upsampling_nearest1d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_w);
using _upsampling_nearest_exact1d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_w);
using upsampling_nearest2d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_h, scale_t scales_w);
using _upsampling_nearest_exact2d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_h, scale_t scales_w);
using upsampling_nearest3d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_d, scale_t scales_h, scale_t scales_w);
using _upsampling_nearest_exact3d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_d, scale_t scales_h, scale_t scales_w);
using upsampling_linear1d = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_w);
using upsampling_bilinear2d = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_h, scale_t scales_w);
using _upsampling_bilinear2d_aa = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_h, scale_t scales_w);
using upsampling_trilinear3d = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_d, scale_t scales_h, scale_t scales_w);
using upsampling_bicubic2d = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_h, scale_t scales_w);
using _upsampling_bicubic2d_aa = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_h, scale_t scales_w);
DECLARE_DISPATCH(upsampling_nearest1d, upsample_nearest1d_kernel);
DECLARE_DISPATCH(_upsampling_nearest_exact1d, _upsample_nearest_exact1d_kernel);
DECLARE_DISPATCH(upsampling_nearest2d, upsample_nearest2d_kernel);
DECLARE_DISPATCH(_upsampling_nearest_exact2d, _upsample_nearest_exact2d_kernel);
DECLARE_DISPATCH(upsampling_nearest3d, upsample_nearest3d_kernel);
DECLARE_DISPATCH(_upsampling_nearest_exact3d, _upsample_nearest_exact3d_kernel);
DECLARE_DISPATCH(upsampling_nearest1d, upsample_nearest1d_backward_kernel);
DECLARE_DISPATCH(_upsampling_nearest_exact1d, _upsample_nearest_exact1d_backward_kernel);
DECLARE_DISPATCH(upsampling_nearest2d, upsample_nearest2d_backward_kernel);
DECLARE_DISPATCH(_upsampling_nearest_exact2d, _upsample_nearest_exact2d_backward_kernel);
DECLARE_DISPATCH(upsampling_nearest3d, upsample_nearest3d_backward_kernel);
DECLARE_DISPATCH(_upsampling_nearest_exact3d, _upsample_nearest_exact3d_backward_kernel);
DECLARE_DISPATCH(upsampling_linear1d, upsample_linear1d_kernel);
DECLARE_DISPATCH(upsampling_bilinear2d, upsample_bilinear2d_kernel);
DECLARE_DISPATCH(_upsampling_bilinear2d_aa, _upsample_bilinear2d_aa_kernel);
DECLARE_DISPATCH(upsampling_trilinear3d, upsample_trilinear3d_kernel);
DECLARE_DISPATCH(upsampling_linear1d, upsample_linear1d_backward_kernel);
DECLARE_DISPATCH(upsampling_bilinear2d, upsample_bilinear2d_backward_kernel);
DECLARE_DISPATCH(_upsampling_bilinear2d_aa, _upsample_bilinear2d_aa_backward_kernel);
DECLARE_DISPATCH(upsampling_trilinear3d, upsample_trilinear3d_backward_kernel);
DECLARE_DISPATCH(upsampling_bicubic2d, upsample_bicubic2d_kernel);
DECLARE_DISPATCH(_upsampling_bicubic2d_aa, _upsample_bicubic2d_aa_kernel);
DECLARE_DISPATCH(_upsampling_bicubic2d_aa, _upsample_bicubic2d_aa_backward_kernel);
inline C10_UNUSED std::array<int64_t, 3> upsample_1d_common_check(IntArrayRef input_size, IntArrayRef output_size) {
TORCH_CHECK(
output_size.size() == 1,
"It is expected output_size equals to 1, but got size ",
output_size.size());
TORCH_CHECK(
input_size.size() == 3,
"It is expected input_size equals to 3, but got size ",
input_size.size());
int64_t output_width = output_size[0];
int64_t nbatch = input_size[0];
int64_t channels = input_size[1];
int64_t input_width = input_size[2];
TORCH_CHECK(
input_width > 0 && output_width > 0,
"Input and output sizes should be greater than 0, but got input (W: ",
input_width,
") and output (W: ",
output_width,
")");
return {nbatch, channels, output_width};
}
inline C10_UNUSED std::array<int64_t, 4> upsample_2d_common_check(IntArrayRef input_size, IntArrayRef output_size) {
TORCH_CHECK(
output_size.size() == 2,
"It is expected output_size equals to 2, but got size ",
output_size.size());
TORCH_CHECK(
input_size.size() == 4,
"It is expected input_size equals to 4, but got size ",
input_size.size());
int64_t output_height = output_size[0];
int64_t output_width = output_size[1];
int64_t nbatch = input_size[0];
int64_t channels = input_size[1];
int64_t input_height = input_size[2];
int64_t input_width = input_size[3];
TORCH_CHECK(
input_height > 0 && input_width > 0 && output_height > 0 &&
output_width > 0,
"Input and output sizes should be greater than 0,"
" but got input (H: ",
input_height,
", W: ",
input_width,
") output (H: ",
output_height,
", W: ",
output_width,
")");
return {nbatch, channels, output_height, output_width};
}
inline C10_UNUSED
std::array<int64_t, 5> upsample_3d_common_check(IntArrayRef input_size, IntArrayRef output_size) {
TORCH_CHECK(
output_size.size() == 3,
"It is expected output_size equals to 3, but got size ",
output_size.size());
TORCH_CHECK(
input_size.size() == 5,
"It is expected input_size equals to 5, but got size ",
input_size.size());
int64_t output_depth = output_size[0];
int64_t output_height = output_size[1];
int64_t output_width = output_size[2];
int64_t nbatch = input_size[0];
int64_t channels = input_size[1];
int64_t input_depth = input_size[2];
int64_t input_height = input_size[3];
int64_t input_width = input_size[4];
TORCH_CHECK(
input_depth > 0 && input_height > 0 && input_width > 0 &&
output_depth > 0 && output_height > 0 && output_width > 0,
"Input and output sizes should be greater than 0, but got input (D: ",
input_depth,
", H: ",
input_height,
", W: ",
input_width,
") output (D: ",
output_depth,
", H: ",
output_height,
", W: ",
output_width,
")");
return {nbatch, channels, output_depth, output_height, output_width};
}
inline void upsample_2d_shape_check(
const Tensor& input,
const Tensor& grad_output,
int64_t nbatch,
int64_t nchannels,
int64_t input_height,
int64_t input_width,
int64_t output_height,
int64_t output_width) {
TORCH_CHECK(
input_height > 0 && input_width > 0 && output_height > 0 &&
output_width > 0,
"Input and output sizes should be greater than 0,"
" but got input (H: ",
input_height,
", W: ",
input_width,
") output (H: ",
output_height,
", W: ",
output_width,
")");
if (input.defined()) {
// Allow for empty batch size but not other dimensions
TORCH_CHECK(
(input.numel() != 0 ||
(input.size(1) != 0 && input.size(2) != 0 && input.size(3) != 0)
) &&
input.dim() == 4,
"Non-empty 4D data tensor expected but got a tensor with sizes ",
input.sizes());
} else if (grad_output.defined()) {
check_dim_size(grad_output, 4, 0, nbatch);
check_dim_size(grad_output, 4, 1, nchannels);
check_dim_size(grad_output, 4, 2, output_height);
check_dim_size(grad_output, 4, 3, output_width);
}
}
template <typename scalar_t>
inline scalar_t compute_scales_value(
const std::optional<double> scale,
int64_t input_size,
int64_t output_size) {
// see Note [compute_scales_value]
// FIXME: remove magic > 0 after we ensure no models were serialized with -1 defaults.
return (scale.has_value() && scale.value() > 0.)
? static_cast<scalar_t>(1.0 / scale.value())
: (static_cast<scalar_t>(input_size) / output_size);
}
template <typename scalar_t>
inline scalar_t area_pixel_compute_scale(
int64_t input_size,
int64_t output_size,
bool align_corners,
const std::optional<double> scale) {
// see Note [area_pixel_compute_scale]
if(align_corners) {
if(output_size > 1) {
return static_cast<scalar_t>(input_size - 1) / (output_size - 1);
} else {
return static_cast<scalar_t>(0);
}
} else {
return compute_scales_value<scalar_t>(scale, input_size, output_size);
}
}
template <typename scalar_t>
inline scalar_t area_pixel_compute_source_index(
scalar_t scale,
int64_t dst_index,
bool align_corners,
bool cubic) {
if (align_corners) {
return scale * dst_index;
} else {
scalar_t src_idx = scale * (dst_index + static_cast<scalar_t>(0.5)) -
static_cast<scalar_t>(0.5);
// [Note] Follow Opencv resize logic:
// We allow negative src_idx here and later will use
// dx = src_idx - floorf(src_idx)
// to compute the "distance"(which affects weights).
// For linear modes, weight distribution doesn't matter
// for negative indices as they use 2 pixels to interpolate.
// For example, [-1, 0], they both use pixel 0 value so it
// doesn't affect if we bound the src_idx to 0 or not.
// TODO: Our current linear mode impls use unbound indices
// where we should and then remove this cubic flag.
// This matters in cubic mode, as we might need [-1, 0, 1, 2]
// to interpolate and the weights can be affected.
return (!cubic && src_idx < static_cast<scalar_t>(0)) ? scalar_t(0)
: src_idx;
}
}
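// Illustrative sketch (not part of the original header): combining the two
// helpers above to map one output index to its fractional source coordinate.
// The function name example_map_output_to_source is made up for illustration.
inline float example_map_output_to_source(
    int64_t input_size,
    int64_t output_size,
    int64_t dst_index,
    bool align_corners) {
  const float scale = area_pixel_compute_scale<float>(
      input_size, output_size, align_corners, /*scale=*/std::nullopt);
  return area_pixel_compute_source_index<float>(
      scale, dst_index, align_corners, /*cubic=*/false);
}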
inline int64_t nearest_neighbor_compute_source_index(
const float scale,
int64_t dst_index,
int64_t input_size) {
// Index computation matching OpenCV INTER_NEAREST
// which is buggy and kept for BC
const int64_t src_index =
std::min(static_cast<int64_t>(floorf(dst_index * scale)), input_size - 1);
return src_index;
}
inline int64_t nearest_neighbor_exact_compute_source_index(
const float scale,
int64_t dst_index,
int64_t input_size) {
// index_f32 = (output_index + 0.5) * scale - 0.5
// input_index = round(index_f32)
// Same as Pillow and Scikit-Image/Scipy ndi.zoom
const int64_t src_index =
std::min(static_cast<int64_t>(floorf((dst_index + 0.5) * scale)), input_size - 1);
return src_index;
}
inline int64_t nearest_idx(
int64_t output_index,
int64_t input_size,
int64_t output_size,
std::optional<double> scales) {
// This method specifically treats the cases output_size == input_size and
// output_size == 2 * input_size, which we would like to get rid of.
// We keep this method for BC and consider it deprecated.
// See nearest_exact_idx as replacement
if (output_size == input_size) {
// scale_factor = 1, simply copy
return output_index;
} else if (output_size == 2 * input_size) {
// scale_factor = 2, shift input index
return output_index >> 1;
} else {
float scale = compute_scales_value<float>(scales, input_size, output_size);
return nearest_neighbor_compute_source_index(scale, output_index, input_size);
}
}
inline int64_t nearest_exact_idx(
int64_t output_index,
int64_t input_size,
int64_t output_size,
std::optional<double> scales) {
float scale = compute_scales_value<float>(scales, input_size, output_size);
return nearest_neighbor_exact_compute_source_index(scale, output_index, input_size);
}
// Define a typedef to dispatch to nearest_idx or nearest_exact_idx
typedef int64_t (*nearest_idx_fn_t)(int64_t, int64_t, int64_t, std::optional<double>);
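// Illustrative sketch (not in the original header): selecting one of the two
// index helpers above through the nearest_idx_fn_t typedef. The flag name
// `exact` and the wrapper name are made up for illustration.
inline int64_t example_pick_nearest_index(
    bool exact,
    int64_t output_index,
    int64_t input_size,
    int64_t output_size,
    std::optional<double> scales) {
  nearest_idx_fn_t fn = exact ? &nearest_exact_idx : &nearest_idx;
  return fn(output_index, input_size, output_size, scales);
}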
template <typename scalar_t>
scalar_t upsample_get_value_bounded(
scalar_t* data,
int64_t width,
int64_t height,
int64_t x,
int64_t y) {
int64_t access_x = std::max(std::min(x, width - 1), static_cast<int64_t>(0));
int64_t access_y = std::max(std::min(y, height - 1), static_cast<int64_t>(0));
return data[access_y * width + access_x];
}
template <typename scalar_t>
void upsample_increment_value_bounded(
scalar_t* data,
int64_t width,
int64_t height,
int64_t x,
int64_t y,
scalar_t value) {
int64_t access_x = std::max(std::min(x, width - 1), static_cast<int64_t>(0));
int64_t access_y = std::max(std::min(y, height - 1), static_cast<int64_t>(0));
data[access_y * width + access_x] += value;
}
// Based on
// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
template <typename scalar_t>
scalar_t cubic_convolution1(scalar_t x, scalar_t A) {
return ((A + 2) * x - (A + 3)) * x * x + 1;
}
template <typename scalar_t>
scalar_t cubic_convolution2(scalar_t x, scalar_t A) {
return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A;
}
template <typename scalar_t>
void get_cubic_upsample_coefficients(
scalar_t coeffs[4],
scalar_t t) {
scalar_t A = -0.75;
scalar_t x1 = t;
coeffs[0] = cubic_convolution2<scalar_t>(x1 + 1.0, A);
coeffs[1] = cubic_convolution1<scalar_t>(x1, A);
// opposite coefficients
scalar_t x2 = 1.0 - t;
coeffs[2] = cubic_convolution1<scalar_t>(x2, A);
coeffs[3] = cubic_convolution2<scalar_t>(x2 + 1.0, A);
}
template <typename scalar_t>
inline scalar_t cubic_interp1d(
scalar_t x0,
scalar_t x1,
scalar_t x2,
scalar_t x3,
scalar_t t) {
scalar_t coeffs[4];
get_cubic_upsample_coefficients<scalar_t>(coeffs, t);
return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3];
}
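// Illustrative example (not in the original header): with the A = -0.75 kernel
// above, t = 0 reproduces x1 exactly and t = 1 reproduces x2, so the
// interpolant passes through the two inner samples. The wrapper name is made up.
inline float example_cubic_midpoint(float x0, float x1, float x2, float x3) {
  return cubic_interp1d<float>(x0, x1, x2, x3, /*t=*/0.5f);
}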
// when `real_input_index` becomes larger than the range the floating point
// type can accurately represent, the type casting to `int64_t` might exceed
// `input_size`, causing overflow. So we guard it with `std::min` below.
template<typename scalar_t, typename opmath_t>
inline void guard_index_and_lambda(const opmath_t& real_input_index, const int64_t& input_size, int64_t& input_index, scalar_t& lambda) {
input_index = std::min(static_cast<int64_t>(floorf(real_input_index)), input_size - 1);
lambda = std::min(
std::max(real_input_index - input_index, static_cast<opmath_t>(0)),
static_cast<opmath_t>(1)
);
}
template<typename scalar_t, typename opmath_t>
inline void compute_source_index_and_lambda(
int64_t& input_index0,
int64_t& input_index1,
scalar_t& lambda0,
scalar_t& lambda1,
opmath_t ratio,
int64_t output_index,
int64_t input_size,
int64_t output_size,
bool align_corners) {
if (output_size == input_size) {
// scale_factor = 1, simply copy
input_index0 = output_index;
input_index1 = output_index;
lambda0 = static_cast<scalar_t>(1);
lambda1 = static_cast<scalar_t>(0);
} else {
const auto real_input_index =
area_pixel_compute_source_index<opmath_t>(
ratio, output_index, align_corners, /*cubic=*/false);
guard_index_and_lambda(real_input_index, input_size, input_index0, lambda1);
int64_t offset = (input_index0 < input_size - 1) ? 1 : 0;
input_index1 = input_index0 + offset;
lambda0 = static_cast<scalar_t>(1.) - lambda1;
}
}
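// Illustrative sketch (not in the original header): producing one linearly
// interpolated output value from a contiguous 1D input using the helper above.
// The function name example_linear_sample_1d is made up for illustration.
inline float example_linear_sample_1d(
    const float* input,
    int64_t input_size,
    int64_t output_size,
    int64_t output_index,
    bool align_corners) {
  const float ratio = area_pixel_compute_scale<float>(
      input_size, output_size, align_corners, /*scale=*/std::nullopt);
  int64_t i0 = 0, i1 = 0;
  float w0 = 0.f, w1 = 0.f;
  compute_source_index_and_lambda<float, float>(
      i0, i1, w0, w1, ratio, output_index, input_size, output_size, align_corners);
  return w0 * input[i0] + w1 * input[i1];
}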
// It will not be used by data types other than BFloat16 and Half.
template <typename scalar_in, typename scalar_out,
typename std::enable_if_t<!is_reduced_floating_point_v<scalar_out> || !std::is_same<scalar_in, float>::value, int> = 0>
void inline apply_grad_input(scalar_in* buffer_ptr, scalar_out* gin, int64_t size) {
TORCH_CHECK((is_reduced_floating_point_v<scalar_out>),
"Upsample backward only support BFloat16 and Half in the lower precision data types on CPU.")
TORCH_CHECK((std::is_same<scalar_in, float>::value),
"Upsample backward should use float as acc buffer for BFloat16 and Half grad input on CPU.")
return;
}
template <typename scalar_in, typename scalar_out,
typename std::enable_if_t<is_reduced_floating_point_v<scalar_out> && std::is_same<scalar_in, float>::value, int> = 0>
void inline apply_grad_input(scalar_in* buffer_ptr, scalar_out* gin, int64_t size) {
using bVec = Vectorized<scalar_out>;
using fVec = Vectorized<float>;
int64_t d = 0;
for (; d < size - (size % bVec::size()); d += bVec::size()) {
bVec gin_bvec = bVec::loadu(gin + d);
auto [gin_fvec0, gin_fvec1] = convert_to_float<scalar_out>(gin_bvec);
gin_fvec0 += fVec::loadu(buffer_ptr + d);
gin_fvec1 += fVec::loadu(buffer_ptr + d + fVec::size());
fVec(0).store(buffer_ptr + d);
fVec(0).store(buffer_ptr + d + fVec::size());
convert_from_float<scalar_out>(gin_fvec0, gin_fvec1).store(gin + d);
}
for (; d < size; d++) {
gin[d] += buffer_ptr[d];
buffer_ptr[d] = 0;
}
}
} // namespace at::native

View File

@ -0,0 +1,38 @@
#pragma once
#include <ATen/core/Tensor.h>
#include <ATen/native/DispatchStub.h>
namespace at::native {
using batch_norm_fn = void (*)(Tensor&, const Tensor&, const Tensor&,
const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, bool, double);
using batch_norm_collect_stats_fn = void (*)(Tensor&, Tensor&, const Tensor&);
using batch_norm_backward_fn = void(*)(Tensor&, Tensor&, Tensor&, const Tensor&,
const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, bool, double);
DECLARE_DISPATCH(batch_norm_fn, batch_norm_cpu_stub);
DECLARE_DISPATCH(batch_norm_collect_stats_fn, batch_norm_cpu_collect_stats_stub);
DECLARE_DISPATCH(batch_norm_backward_fn, batch_norm_cpu_backward_stub);
// Returns a TensorAccessor for t when it is defined, and a null accessor
// otherwise, to work around accessing an undefined Tensor.
template <typename scalar_t>
static TensorAccessor<scalar_t, 1> conditional_accessor_1d(const Tensor& t) {
if (! t.defined()) {
return TensorAccessor<scalar_t, 1>(nullptr, nullptr, nullptr);
}
return t.accessor<scalar_t, 1>();
}
template <typename scalar_t>
static scalar_t* conditional_data_ptr(const Tensor& t) {
if constexpr (std::is_const_v<scalar_t>) {
return t.defined() ? t.contiguous().const_data_ptr<scalar_t>()
: nullptr;
} else {
return t.defined() ? t.contiguous().data_ptr<scalar_t>()
: nullptr;
}
}
} // namespace at::native

View File

@ -0,0 +1,37 @@
#ifndef ATOMIC_ADD_FLOAT
#define ATOMIC_ADD_FLOAT
#if (defined(__x86_64__) || defined(__i386__) || defined(__aarch64__))
#include <ATen/native/cpu/Intrinsics.h>
#else
#define _mm_pause()
#endif
#include <atomic>
static inline void cpu_atomic_add_float(float* dst, float fvalue)
{
typedef union {
unsigned intV;
float floatV;
} uf32_t;
uf32_t new_value, old_value;
std::atomic<unsigned>* dst_intV = (std::atomic<unsigned>*)(dst);
old_value.floatV = *dst;
new_value.floatV = old_value.floatV + fvalue;
unsigned* old_intV = (unsigned*)(&old_value.intV);
while (!std::atomic_compare_exchange_strong(dst_intV, old_intV, new_value.intV)) {
#ifdef __aarch64__
__asm__ __volatile__("yield;" : : : "memory");
#else
_mm_pause();
#endif
old_value.floatV = *dst;
new_value.floatV = old_value.floatV + fvalue;
}
}
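// Illustrative sketch (not part of the original header): several threads
// accumulating into a single float through the compare-exchange loop above.
// <thread> and <vector> are pulled in only for this example; the function name
// example_concurrent_sum is made up for illustration.
#include <thread>
#include <vector>
static inline float example_concurrent_sum(int num_threads)
{
  float total = 0.0f;
  std::vector<std::thread> workers;
  for (int t = 0; t < num_threads; ++t) {
    workers.emplace_back([&total] { cpu_atomic_add_float(&total, 1.0f); });
  }
  for (auto& w : workers) {
    w.join();
  }
  return total; // equals num_threads once every worker has joined
}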
#endif

View File

@ -0,0 +1,12 @@
#pragma once
#include <ATen/core/Tensor.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/core/IListRef.h>
namespace at::native {
using cat_serial_fn = void(*)(const Tensor &, const MaterializedITensorListRef&, int64_t);
DECLARE_DISPATCH(cat_serial_fn, cat_serial_stub);
} // namespace at::native

View File

@ -0,0 +1,14 @@
#pragma once
#include <ATen/native/DispatchStub.h>
#include <cstdint>
namespace at {
class TensorBase;
}
namespace at::native {
using channel_shuffle_fn = void(*)(TensorBase&, const TensorBase&, int64_t);
DECLARE_DISPATCH(channel_shuffle_fn, channel_shuffle_kernel);
} // at::native

View File

@ -0,0 +1,14 @@
#pragma once
#include <ATen/native/TensorIterator.h>
namespace at {
struct TensorIteratorBase;
namespace native {
inline namespace CPU_CAPABILITY {
void direct_copy_kernel(TensorIteratorBase &iter);
void copy_kernel(TensorIterator& iter, bool /*non_blocking*/);
}}} // namespace at::native::CPU_CAPABILITY

View File

@ -0,0 +1,21 @@
#pragma once
#include <ATen/native/DispatchStub.h>
#include <c10/util/ArrayRef.h>
/*
Depthwise 3x3 Winograd convolution operator
*/
namespace at {
class Tensor;
namespace native {
using convolution_depthwise3x3_winograd_fn =
Tensor (*)(const Tensor &, const Tensor &, const Tensor &, IntArrayRef, IntArrayRef, int64_t);
DECLARE_DISPATCH(convolution_depthwise3x3_winograd_fn, convolution_depthwise3x3_winograd_stub);
} // namespace native
} // namespace at

View File

@ -0,0 +1,425 @@
#pragma once
#include <ATen/CPUApplyUtils.h>
#include <ATen/Dispatch.h>
#include <ATen/Dispatch_v2.h>
#include <ATen/ExpandBase.h>
#include <ATen/core/DistributionsHelper.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cpu/Loops.h>
#include <mutex>
#ifdef CPU_CAPABILITY_AVX2
#include <ATen/native/cpu/avx_mathfun.h>
#include <c10/util/irange.h>
#endif
namespace at::native::templates::cpu {
namespace {
// ==================================================== Random ========================================================
template<typename RNG>
void random_from_to_kernel(TensorIteratorBase& iter, uint64_t range, int64_t base, RNG generator) {
AT_DISPATCH_V2(iter.dtype(), "random_from_to_kernel_cpu", AT_WRAP([&] {
std::lock_guard<std::mutex> lock(generator->mutex_);
cpu_serial_kernel(iter, [range, base, generator]() -> scalar_t {
uniform_int_from_to_distribution<scalar_t> random(range, base);
return random(generator);
});
}), kBool, kHalf, kBFloat16, AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
}
// This is a special kernel to handle a single specific case:
// from(inclusive) = std::numeric_limits<int64_t>::lowest()
// to(exclusive) = None (= std::numeric_limits<int64_t>::max() + 1)
template<typename RNG>
void random_full_64_bits_range_kernel(TensorIteratorBase& iter, RNG generator) {
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::BFloat16, iter.dtype(), "random_full_64_bits_range_kernel_cpu", [&] {
if constexpr (std::is_same_v<scalar_t, int64_t> ||
std::is_same_v<scalar_t, double> ||
std::is_same_v<scalar_t, float> ||
std::is_same_v<scalar_t, at::BFloat16>) {
std::lock_guard<std::mutex> lock(generator->mutex_);
cpu_serial_kernel(iter, [generator]() -> scalar_t {
uniform_int_full_range_distribution<scalar_t> random;
return random(generator);
});
} else {
TORCH_CHECK(false, "random_full_64_bits_range_kernel_cpu handles only int64, double, float and bfloat16");
}
});
}
template<typename RNG>
struct RandomFromToKernel {
void operator()(TensorIteratorBase& iter, uint64_t range, int64_t base, std::optional<Generator> gen) {
random_from_to_kernel(iter, range, base, check_generator<RNG>(gen));
}
void operator()(TensorIteratorBase& iter, std::optional<Generator> gen) {
random_full_64_bits_range_kernel(iter, check_generator<RNG>(gen));
}
};
template<typename RNG>
void random_kernel(TensorIteratorBase& iter, RNG generator) {
std::lock_guard<std::mutex> lock(generator->mutex_);
AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, iter.dtype(), "random_kernel_cpu", [&] {
cpu_serial_kernel(iter, [generator]() -> scalar_t {
uniform_int_distribution<scalar_t> random;
return random(generator);
});
});
}
template<typename RNG>
struct RandomKernel {
void operator()(TensorIteratorBase& iter, std::optional<Generator> gen) {
random_kernel(iter, check_generator<RNG>(gen));
}
};
// ==================================================== Normal ========================================================
#ifdef CPU_CAPABILITY_AVX2
static void normal_fill_16_AVX2(float *data,
const __m256* two_pi,
const __m256* one,
const __m256* minus_two,
const __m256* mean,
const __m256* std_v) {
const __m256 u1 = _mm256_sub_ps(*one, _mm256_loadu_ps(data));
const __m256 u2 = _mm256_loadu_ps(data + 8);
// sincos256_ps and log256_ps are from avx_mathfun.h
const __m256 radius = _mm256_sqrt_ps(_mm256_mul_ps(*minus_two, log256_ps(u1)));
const __m256 theta = _mm256_mul_ps(*two_pi, u2);
__m256 sintheta, costheta;
sincos256_ps(theta, &sintheta, &costheta);
const __m256 n1 = _mm256_mul_ps(radius, costheta);
const __m256 n2 = _mm256_mul_ps(radius, sintheta);
_mm256_storeu_ps(data, _mm256_fmadd_ps(n1, *std_v, *mean));
_mm256_storeu_ps(data + 8, _mm256_fmadd_ps(n2, *std_v, *mean));
}
template<typename RNG>
void normal_fill_AVX2(const TensorBase &self, const float mean, const float std, RNG generator) {
float *data = self.data_ptr<float>();
auto size = self.numel();
std::lock_guard<std::mutex> lock(generator->mutex_);
for (const auto i : c10::irange(size)) {
at::uniform_real_distribution<float> uniform(0, 1);
data[i] = uniform(generator);
}
const __m256 two_pi = _mm256_set1_ps(2.0f * c10::pi<double>);
const __m256 one = _mm256_set1_ps(1.0f);
const __m256 minus_two = _mm256_set1_ps(-2.0f);
const __m256 mean_v = _mm256_set1_ps(mean);
const __m256 std_v = _mm256_set1_ps(std);
for (int64_t i = 0; i < size - 15; i += 16) {
normal_fill_16_AVX2(data + i, &two_pi, &one, &minus_two, &mean_v, &std_v);
}
if (size % 16 != 0) {
// Recompute the last 16 values.
data = data + size - 16;
for (const auto i : c10::irange(16)) {
at::uniform_real_distribution<float> uniform(0, 1);
data[i] = uniform(generator);
}
normal_fill_16_AVX2(data, &two_pi, &one, &minus_two, &mean_v, &std_v);
}
}
#endif
template <typename scalar_t>
static void normal_fill_16(scalar_t *data, const scalar_t mean, const scalar_t std) {
for (const auto j : c10::irange(8)) {
const scalar_t u1 = 1 - data[j]; // [0, 1) -> (0, 1] for log.
const scalar_t u2 = data[j + 8];
const scalar_t radius = std::sqrt(-2 * std::log(u1));
const scalar_t theta = 2.0f * c10::pi<double> * u2;
data[j] = radius * std::cos(theta) * std + mean;
data[j + 8] = radius * std::sin(theta) * std + mean;
}
}
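// Illustrative note (not in the original source): the loop above is the
// Box-Muller transform -- each pair of uniforms (u1, u2) yields two
// independent standard normal samples radius * cos(theta) and
// radius * sin(theta), which are then scaled by `std` and shifted by `mean`.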
#if defined(__VSX__) || defined(CPU_CAPABILITY_VSX)
static void normal_fill_16_VSX(
    float *data,
    const Vectorized<float> &two_pi,
    const Vectorized<float> &one,
    const Vectorized<float> &minus_two,
    const Vectorized<float> &mean,
    const Vectorized<float> &std) {
  using Vec = Vectorized<float>;
  Vec u1 = one - Vec::loadu(data);
  Vec u2 = Vec::loadu(data + 8);
  Vec radius = (minus_two * u1.log());
  radius = radius.sqrt();
  Vec theta = two_pi * u2;
  Vec output_vec = radius * theta.cos() * std + mean;
  Vec output_vec2 = radius * theta.sin() * std + mean;
  output_vec.store(data);
  output_vec2.store(data + 8);
}
template <typename scalar_t, typename RNG>
void normal_fill_VSX(const TensorBase &self, const scalar_t mean, const scalar_t std, RNG generator) {
float *data = self.data_ptr<float>();
auto size = self.numel();
std::lock_guard<std::mutex> lock(generator->mutex_);
for (const auto i : c10::irange(size)) {
at::uniform_real_distribution<scalar_t> uniform(0, 1);
data[i] = uniform(generator);
}
using Vec = Vectorized<float>;
const Vec two_pi = Vec(2.0f * c10::pi<double>);
const Vec one = Vec(1.0f);
const Vec minus_two = Vec(-2.0f);
const Vec var_vec = Vec(std);
const Vec mean_vec = Vec(mean);
for (int64_t i = 0; i < size - 15; i += 16) {
if (Vec::size() == 8) {
  normal_fill_16_VSX(data + i, two_pi, one, minus_two, mean_vec, var_vec);
} else {
  normal_fill_16<scalar_t>(data + i, mean, std);
}
}
if (size % 16 != 0) {
// Recompute the last 16 values.
data = data + size - 16;
for (const auto i : c10::irange(16)) {
at::uniform_real_distribution<scalar_t> uniform(0, 1);
data[i] = uniform(generator);
}
if (Vec::size() == 8) {
  normal_fill_16_VSX(data, two_pi, one, minus_two, mean_vec, var_vec);
} else {
  normal_fill_16<scalar_t>(data, mean, std);
}
}
}
#endif //VSX
template <typename scalar_t, typename RNG>
void normal_fill(const TensorBase &self, const scalar_t mean, const scalar_t std, RNG generator) {
scalar_t *data = self.data_ptr<scalar_t>();
auto size = self.numel();
std::lock_guard<std::mutex> lock(generator->mutex_);
for (const auto i : c10::irange(size)) {
at::uniform_real_distribution<scalar_t> uniform(0, 1);
data[i] = uniform(generator);
}
for (int64_t i = 0; i < size - 15; i += 16) {
normal_fill_16<scalar_t>(data + i, mean, std);
}
if (size % 16 != 0) {
// Recompute the last 16 values.
data = data + size - 16;
for (const auto i : c10::irange(16)) {
at::uniform_real_distribution<scalar_t> uniform(0, 1);
data[i] = uniform(generator);
}
normal_fill_16<scalar_t>(data, mean, std);
}
}
template<typename RNG>
void normal_kernel(const TensorBase &self, double mean, double std, RNG generator) {
auto size = self.numel();
if (self.scalar_type() == ScalarType::Float && size >= 16 && self.is_contiguous()) {
#ifdef CPU_CAPABILITY_AVX2
normal_fill_AVX2(self, static_cast<float>(mean), static_cast<float>(std), generator);
#elif defined(__VSX__) || defined(CPU_CAPABILITY_VSX)
normal_fill_VSX(self, static_cast<float>(mean), static_cast<float>(std), generator);
#else
normal_fill(self, static_cast<float>(mean), static_cast<float>(std), generator);
#endif
} else {
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, self.scalar_type(), "normal_kernel_cpu", [&] {
if (size >= 16 && self.is_contiguous()) {
normal_fill<scalar_t>(self, static_cast<scalar_t>(mean), static_cast<scalar_t>(std), generator);
} else {
auto iter = TensorIterator::borrowing_nullary_op(self);
std::lock_guard<std::mutex> lock(generator->mutex_);
cpu_serial_kernel(iter, [mean, std, generator]() -> scalar_t {
at::normal_distribution<double> normal(mean, std);
return static_cast<scalar_t>(normal(generator));
});
}
});
}
}
template<typename RNG>
struct NormalKernel {
void operator()(Tensor& self, double mean, double std, std::optional<Generator> gen) {
normal_kernel(self, mean, std, check_generator<RNG>(gen));
}
};
// ==================================================== Uniform =======================================================
template<typename RNG>
void uniform_kernel(TensorIteratorBase& iter, double from_, double to_, RNG generator) {
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "uniform_kernel_cpu", [&]() {
std::lock_guard<std::mutex> lock(generator->mutex_);
auto from = static_cast<scalar_t>(from_);
auto to = static_cast<scalar_t>(to_);
at::uniform_real_distribution<scalar_t> uniform(from, to);
cpu_serial_kernel(iter, [&uniform, generator]() -> scalar_t {
return static_cast<scalar_t>(uniform(generator));
});
});
}
template<typename RNG>
struct UniformKernel {
void operator()(TensorIteratorBase& iter, double from, double to, std::optional<Generator> gen) {
uniform_kernel(iter, from, to, check_generator<RNG>(gen));
}
};
// ==================================================== Cauchy ========================================================
template<typename RNG>
void cauchy_kernel(TensorIteratorBase& iter, double median, double sigma, RNG generator) {
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "cauchy_cpu", [&]() {
std::lock_guard<std::mutex> lock(generator->mutex_);
at::cauchy_distribution<double> cauchy(median, sigma);
cpu_serial_kernel(iter, [&cauchy, generator]() -> scalar_t {
return static_cast<scalar_t>(cauchy(generator));
});
});
}
template<typename RNG>
struct CauchyKernel {
void operator()(TensorIteratorBase& iter, double median, double sigma, std::optional<Generator> gen) {
cauchy_kernel(iter, median, sigma, check_generator<RNG>(gen));
}
};
// ================================================== LogNormal =======================================================
template<typename RNG>
void log_normal_kernel(TensorIteratorBase& iter, double mean, double std, RNG generator) {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "log_normal_cpu", [&]() {
std::lock_guard<std::mutex> lock(generator->mutex_);
at::lognormal_distribution<double> logNormal(mean, std);
cpu_serial_kernel(iter, [&logNormal, generator]() -> scalar_t {
return static_cast<scalar_t>(logNormal(generator));
});
});
}
template<typename RNG>
struct LogNormalKernel {
void operator()(TensorIteratorBase& iter, double mean, double std, std::optional<Generator> gen) {
log_normal_kernel(iter, mean, std, check_generator<RNG>(gen));
}
};
// =================================================== Geometric ======================================================
template<typename RNG>
void geometric_kernel(TensorIteratorBase& iter, double p, RNG generator) {
AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "geometric_cpu", [&]() {
std::lock_guard<std::mutex> lock(generator->mutex_);
at::geometric_distribution<double> geometric(p);
cpu_serial_kernel(iter, [&geometric, generator]() -> scalar_t {
return static_cast<scalar_t>(geometric(generator));
});
});
}
template<typename RNG>
struct GeometricKernel {
void operator()(TensorIteratorBase& iter, double p, std::optional<Generator> gen) {
geometric_kernel(iter, p, check_generator<RNG>(gen));
}
};
// ================================================== Exponential =====================================================
template<typename RNG>
void exponential_kernel(TensorIteratorBase& iter, double lambda, RNG generator) {
TORCH_CHECK(isFloatingType(iter.dtype()), "Exponential distribution is a continuous probability distribution. dtype must be a floating point but you specified ", iter.dtype());
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "exponential_cpu", [&]() {
std::lock_guard<std::mutex> lock(generator->mutex_);
at::exponential_distribution<double> exponential(lambda);
cpu_serial_kernel(iter, [&exponential, generator]() -> scalar_t {
return static_cast<scalar_t>(exponential(generator));
});
});
}
template<typename RNG>
struct ExponentialKernel {
void operator()(TensorIteratorBase& iter, double lambda, std::optional<Generator> gen) {
exponential_kernel(iter, lambda, check_generator<RNG>(gen));
}
};
// ================================================== Bernoulli =======================================================
template<typename RNG>
void bernoulli_kernel(const TensorBase &self, const TensorBase &p_, RNG generator) {
AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half,
self.scalar_type(), "bernoulli_tensor_cpu_self_", [&] {
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(generator->mutex_);
using self_t = scalar_t;
auto p_cpu = p_.to(kCPU);
auto p = expand_inplace(self, p_cpu);
auto iter = TensorIteratorConfig()
.add_output(self)
.add_const_input(*p)
.check_all_same_dtype(false)
.build();
if (p->scalar_type() == kDouble) {
cpu_serial_kernel(iter, [&](const double p_val) -> self_t {
at::bernoulli_distribution<double> bernoulli(p_val);
return static_cast<self_t>(bernoulli(generator));
});
} else {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::BFloat16, at::ScalarType::Half,
p->scalar_type(), "bernoulli_tensor_cpu_p_", [&] {
using p_t = scalar_t;
cpu_serial_kernel(iter, [&](const p_t p_val) -> self_t {
at::bernoulli_distribution<float> bernoulli(p_val);
return static_cast<self_t>(bernoulli(generator));
});
});
}
});
}
template<typename RNG>
void bernoulli_kernel(const TensorBase &self, double p, RNG generator) {
AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half,
self.scalar_type(), "bernoulli_scalar_cpu_", [&] {
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(generator->mutex_);
auto iter = TensorIterator::borrowing_nullary_op(self);
cpu_serial_kernel(iter, [p, generator]() -> scalar_t {
at::bernoulli_distribution<double> bernoulli(p);
return static_cast<scalar_t>(bernoulli(generator));
});
});
}
template<typename RNG>
struct BernoulliKernel {
void operator()(const TensorBase &self, double p, std::optional<Generator> gen) {
bernoulli_kernel(self, p, check_generator<RNG>(gen));
}
void operator()(const TensorBase &self, const TensorBase &p_, std::optional<Generator> gen) {
bernoulli_kernel(self, p_, check_generator<RNG>(gen));
}
};
}}

View File

@ -0,0 +1,34 @@
#pragma once
#include <ATen/native/DispatchStub.h>
#include <array>
#include <cstdint>
namespace at {
class TensorBase;
}
namespace at::native {
using forward_2d_fn = void (*) (
const TensorBase &output,
const TensorBase &input,
const TensorBase &grid,
int64_t interpolation_mode,
int64_t padding_mode,
bool align_corners);
using backward_2d_fn = void (*) (
const TensorBase &grad_input,
const TensorBase &grad_grid,
const TensorBase &grad_output,
const TensorBase &input,
const TensorBase &grid,
int64_t interpolation_mode,
int64_t padding_mode,
bool align_corners,
std::array<bool, 2> output_mask);
DECLARE_DISPATCH(forward_2d_fn, grid_sampler_2d_cpu_kernel);
DECLARE_DISPATCH(backward_2d_fn, grid_sampler_2d_backward_cpu_kernel);
} // namespace at::native

View File

@ -0,0 +1,87 @@
#pragma once
#include <ATen/native/TensorIterator.h>
#include <c10/util/irange.h>
namespace at::native {
namespace {
static bool is_constant_index(int ntensor, const int64_t* strides) {
AT_ASSERT(ntensor >= 3);
for (const auto arg : c10::irange(2, ntensor)) {
if (strides[arg] != 0) {
return false;
}
}
return true;
}
struct Indexer {
Indexer(int64_t num_indexers, char** indexers, const int64_t* indexer_strides,
IntArrayRef original_sizes, IntArrayRef original_strides)
: num_indexers(num_indexers)
, indexers(indexers)
, indexer_strides(indexer_strides)
, original_strides(original_strides.data())
, original_sizes(original_sizes.data()) {
AT_ASSERT(static_cast<int64_t>(original_strides.size()) == num_indexers);
AT_ASSERT(static_cast<int64_t>(original_sizes.size()) == num_indexers);
}
int64_t num_indexers;
char** indexers;
const int64_t* indexer_strides;
const int64_t* original_strides;
const int64_t* original_sizes;
int64_t get(int64_t idx) {
int64_t offset = 0;
for (const auto j : c10::irange(num_indexers)) {
int64_t value = *(int64_t*)&indexers[j][idx * indexer_strides[j]];
int64_t size = original_sizes[j];
TORCH_CHECK_INDEX(value >= -size && value < size,
"index ", value, " is out of bounds for dimension ", j, " with size ", size);
if (value < 0) {
value += size;
}
offset += value * original_strides[j];
}
return offset;
}
};
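// Worked example (illustrative, not in the original header): with a single
// index tensor whose value at `idx` is -2 and a dimension of size 5, get(idx)
// wraps the index to 3 and returns 3 * original_strides[0] bytes.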
} // anonymous namespace
template <typename scalar_t, typename func_t>
void cpu_index_kernel(TensorIteratorBase& iter, IntArrayRef index_size, IntArrayRef index_stride,
const func_t& f, bool serial_execution=false)
{
int ntensor = iter.ntensors();
// When launching the parallel version of the index kernel, use a relatively small grain size
// (less than internal::GRAIN_SIZE) so the available threads get a more balanced workload and
// better cache locality. The grain size here was chosen via op benchmarks to amortize the
// thread launch overhead.
const int index_parallel_grain_size = 3000;
auto loop = [&](char** data, const int64_t* strides, int64_t n) {
auto indexer = Indexer(ntensor - 2, &data[2], &strides[2], index_size, index_stride);
char* dst = data[0];
char* src = data[1];
if (is_constant_index(ntensor, strides)) {
// specialization for when every element uses the same index
int64_t offset = indexer.get(0);
for (const auto i : c10::irange(n)) {
f(dst + strides[0] * i, src + strides[1] * i, offset);
}
} else {
for (const auto i : c10::irange(n)) {
int64_t offset = indexer.get(i);
f(dst + strides[0] * i, src + strides[1] * i, offset);
}
}
};
if (serial_execution) {
iter.serial_for_each(loop, {0, iter.numel()});
} else {
iter.for_each(loop, index_parallel_grain_size);
}
}
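// Illustrative sketch (not in the original header): a copy-style callback of
// the shape cpu_index_kernel expects. dst/src arrive as byte pointers offset
// by the iterator strides; `offset` is the extra byte offset produced by the
// Indexer. The wrapper name example_index_copy_loop is made up.
template <typename scalar_t>
void example_index_copy_loop(TensorIteratorBase& iter,
                             IntArrayRef index_size,
                             IntArrayRef index_stride) {
  cpu_index_kernel<scalar_t>(iter, index_size, index_stride,
      [](char* dst, char* src, int64_t offset) {
        *reinterpret_cast<scalar_t*>(dst) =
            *reinterpret_cast<scalar_t*>(src + offset);
      });
}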
} // namespace at::native

View File

@ -0,0 +1,33 @@
#pragma once
#if defined(__clang__) && (defined(__x86_64__) || defined(__i386__))
/* Clang-compatible compiler, targeting x86/x86-64 */
#include <x86intrin.h>
#elif defined(_MSC_VER)
/* Microsoft C/C++-compatible compiler */
#include <intrin.h>
#if _MSC_VER <= 1900
#define _mm256_extract_epi64(X, Y) (((uint64_t*)&X)[Y])
#endif
#elif defined(__GNUC__) && (defined(__x86_64__) || defined(__i386__))
/* GCC-compatible compiler, targeting x86/x86-64 */
#include <x86intrin.h>
#elif defined(__GNUC__) && defined(__ARM_NEON__)
/* GCC-compatible compiler, targeting ARM with NEON */
#include <arm_neon.h>
#elif defined(__GNUC__) && defined(__IWMMXT__)
/* GCC-compatible compiler, targeting ARM with WMMX */
#include <mmintrin.h>
#elif (defined(__GNUC__) || defined(__xlC__)) && \
(defined(__VEC__) || defined(__ALTIVEC__))
/* XLC or GCC-compatible compiler, targeting PowerPC with VMX/VSX */
#include <altivec.h>
/* We need to undef those tokens defined by <altivec.h> to avoid conflicts
with the C++ types. => Can still use __bool/__vector */
#undef bool
#undef vector
#undef pixel
#elif defined(__GNUC__) && defined(__SPE__)
/* GCC-compatible compiler, targeting PowerPC with SPE */
#include <spe.h>
#endif

View File

@ -0,0 +1,62 @@
#pragma once
namespace at::native { inline namespace CPU_CAPABILITY {
// n: number of function arguments (arity)
// traits: function_traits (see FunctionTraits.h)
// s: index of scalar argument or -1
template <int n, int stride_index, typename traits, int s=-1>
struct IsContiguous {
static bool eval(const int64_t* strides) {
using type = typename traits::template arg<n - 1>::type;
return strides[stride_index] == (s == n ? 0 : sizeof(type)) &&
IsContiguous<n - 1, stride_index - 1, traits, s>::eval(strides);
}
};
// will be called when an output exists
template <typename traits, int s>
struct IsContiguous<0, 0, traits, s> {
static bool eval(const int64_t* strides) {
return strides[0] == sizeof(typename traits::result_type);
}
};
// will be called when there is no output
template <typename traits, int s>
struct IsContiguous<0, -1, traits, s> {
static bool eval(const int64_t* /*strides*/) {
return true;
}
};
// output and all inputs are contiguous
template <typename traits,
typename std::enable_if<std::is_void<typename traits::result_type>::value>::type* = nullptr>
static inline bool is_contiguous(const int64_t* strides) {
return IsContiguous<traits::arity, traits::arity - 1, traits>::eval(strides);
}
template <typename traits,
typename std::enable_if<!std::is_void<typename traits::result_type>::value>::type* = nullptr>
static inline bool is_contiguous(const int64_t* strides) {
return IsContiguous<traits::arity, traits::arity, traits>::eval(strides);
}
// input at `s` is scalar (stride 0); output and other inputs are contiguous
// NB: output is typically at strides[0] so first input corresponds to s=1
template <typename traits, int s,
typename std::enable_if<std::is_void<typename traits::result_type>::value>::type* = nullptr>
static inline bool is_contiguous_scalar(const int64_t* strides) {
static_assert(s > 0 && s <= traits::arity, "scalar argument index out of bounds");
return IsContiguous<traits::arity, traits::arity - 1, traits, s>::eval(strides);
}
template <typename traits, int s,
typename std::enable_if<!std::is_void<typename traits::result_type>::value>::type* = nullptr>
static inline bool is_contiguous_scalar(const int64_t* strides) {
static_assert(s > 0 && s <= traits::arity, "scalar argument index out of bounds");
return IsContiguous<traits::arity, traits::arity, traits, s>::eval(strides);
}
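// Worked example (illustrative, not in the original header): for a binary op
// with signature float(float, float), traits::arity == 2 and the recursion
// above reduces to
//   strides[2] == sizeof(float) && strides[1] == sizeof(float) &&
//   strides[0] == sizeof(float),
// so a fully contiguous float kernel sees the stride triple {4, 4, 4}.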
}}

View File

@ -0,0 +1,61 @@
#pragma once
#include <c10/util/complex.h>
#include <ATen/NumericUtils.h>
namespace at::native {
inline namespace CPU_CAPABILITY {
// custom min and max to be used in logcumsumexp for complex arguments
template <typename scalar_t>
std::pair<c10::complex<scalar_t>, c10::complex<scalar_t>> _logcumsumexp_minmax(c10::complex<scalar_t> x, c10::complex<scalar_t> y) {
if (at::_isnan(y)) { // either real is nan or imag is nan
return std::make_pair(y, y);
} else if (at::_isnan(x)) { // either real is nan or imag is nan
return std::make_pair(x, x);
} else {
return (x.real() < y.real()) ? std::make_pair(x, y) : std::make_pair(y, x);
}
}
template <typename scalar_t>
scalar_t _log_add_exp_helper(scalar_t x, scalar_t y) {
// Reference : https://www.tensorflow.org/api_docs/python/tf/math/cumulative_logsumexp
scalar_t min = at::_isnan(y) ? y : std::min(x, y); // std::min returns first arg if one of the args is nan
scalar_t max = at::_isnan(y) ? y : std::max(x, y); // std::max returns first arg if one of the args is nan
if (min != max || std::isfinite(min)) {
// nan will be propagated here
return std::log1p(std::exp(min - max)) + max;
} else {
// special case to correctly handle infinite cases
return x;
}
}
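// Worked example (illustrative, not in the original header): for x = -1000 and
// y = -1001, exp(x) and exp(y) both underflow to 0 in double precision, yet
// _log_add_exp_helper(-1000.0, -1001.0) = log1p(exp(-1)) + (-1000), which is
// about -999.6867 and matches log(exp(-1000) + exp(-1001)) evaluated exactly.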
template <typename scalar_t>
c10::complex<scalar_t> _log_add_exp_helper(const c10::complex<scalar_t>& x, const c10::complex<scalar_t>& y) {
auto [min, max] = _logcumsumexp_minmax<scalar_t>(x, y);
auto min_real = std::real(min);
auto max_real = std::real(max);
if (at::_isnan(min)) { // either real is nan or imag is nan
// handling the "infectious" NaNs
return {std::numeric_limits<scalar_t>::quiet_NaN(), std::numeric_limits<scalar_t>::quiet_NaN()};
} else if (!std::isfinite(min_real) && (min_real == max_real)) {
if (min_real < 0) {
// handle the -inf case, the imaginary part here does not really matter as the exp(value)
// will be around 0.0 and the angle (i.e. the imaginary part) cannot be determined.
// It does not matter if we're taking the exp of this value
return min;
} else {
// handle the +inf case, we don't need the special precision for log1p for small values
// and to avoid producing nan in case of real(max) == real(min) == +inf
return std::log(std::exp(min) + std::exp(max));
}
} else {
return std::log1p(std::exp(min - max)) + max;
}
}
} // end namespace
} //end at::native
