I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions


@ -0,0 +1,14 @@
/// Flush-To-Zero and Denormals-Are-Zero mode
///
/// Flush-To-Zero (FTZ) and Denormals-Are-Zero (DAZ) are modes that bypass
/// IEEE 754 methods of dealing with denormal floating-point numbers on x86-64
/// and some x86 CPUs. They result in reduced precision for values near zero,
/// but increased performance.
///
/// See https://software.intel.com/en-us/articles/x87-and-sse-floating-point-assists-in-ia-32-flush-to-zero-ftz-and-denormals-are-zero-daz
namespace at::cpu {
bool set_flush_denormal(bool on);
} // namespace at::cpu
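// Added usage sketch (illustrative only, not part of this header): set_flush_denormal
// returns false when the CPU/build cannot toggle FTZ/DAZ, so callers can check it.
inline void ftz_region_example() {
  const bool enabled = at::cpu::set_flush_denormal(true);
  // ... run code that produces many values near zero ...
  if (enabled) {
    at::cpu::set_flush_denormal(false);  // restore default IEEE 754 denormal handling
  }
}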


@ -0,0 +1,30 @@
#pragma once
#include <cstdint>
#include <c10/macros/Export.h>
namespace at::cpu {
TORCH_API bool is_avx2_supported();
TORCH_API bool is_avx512_supported();
// Detect if the CPU supports Vector Neural Network Instructions (VNNI).
TORCH_API bool is_avx512_vnni_supported();
// Detect if the CPU supports the AVX512_BF16 ISA.
TORCH_API bool is_avx512_bf16_supported();
// Detect if the CPU supports Advanced Matrix Extensions (AMX).
TORCH_API bool is_amx_tile_supported();
// Enable the system to use AMX instructions.
TORCH_API bool init_amx();
// Get the L1 data cache size per core in bytes.
TORCH_API uint32_t L1d_cache_size();
// Get the L2 cache size per core in bytes.
TORCH_API uint32_t L2_cache_size();
} // namespace at::cpu
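// Added usage sketch (illustrative only, not part of this header): gate kernel
// dispatch on runtime CPU capability, initializing AMX tile state before relying
// on an AMX code path.
inline int pick_isa_level_example() {
  if (at::cpu::is_amx_tile_supported() && at::cpu::init_amx()) {
    return 3;  // AMX kernels are usable
  }
  if (at::cpu::is_avx512_supported()) {
    return 2;  // AVX-512 kernels
  }
  if (at::cpu::is_avx2_supported()) {
    return 1;  // AVX2 kernels
  }
  return 0;    // scalar / generic fallback
}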


@ -0,0 +1,4 @@
#pragma once
#include <ATen/cpu/vec/functional_base.h>
#include <ATen/cpu/vec/functional_bfloat16.h>


@ -0,0 +1,358 @@
#pragma once
// DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with AVX]
#include <ATen/cpu/vec/vec.h>
#include <c10/util/irange.h>
namespace at::vec {
// Slow path: reduce the first `size` lanes of acc_vec one element at a time
// (only lane 0 of the final vector is returned).
template <typename scalar_t, typename Op>
inline scalar_t vec_reduce_all(
const Op& vec_fun,
vec::Vectorized<scalar_t> acc_vec,
int64_t size) {
using Vec = vec::Vectorized<scalar_t>;
scalar_t acc_arr[Vec::size()];
acc_vec.store(acc_arr);
for (const auto i : c10::irange(1, size)) {
std::array<scalar_t, Vec::size()> acc_arr_next = {0};
acc_arr_next[0] = acc_arr[i];
Vec acc_vec_next = Vec::loadu(acc_arr_next.data());
acc_vec = vec_fun(acc_vec, acc_vec_next);
}
acc_vec.store(acc_arr);
return acc_arr[0];
}
template <typename scalar_t, typename Op>
struct VecReduceAllSIMD {
static inline scalar_t apply(const Op& vec_fun, const Vectorized<scalar_t>& acc_vec) {
return vec_reduce_all(vec_fun, acc_vec, Vectorized<scalar_t>::size());
}
};
#if defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) && !defined(C10_MOBILE)
#if defined(CPU_CAPABILITY_AVX2)
template <typename Op>
struct VecReduceAllSIMD<float, Op> {
static inline float apply(const Op& vec_fun, const Vectorized<float>& acc_vec) {
using Vec = Vectorized<float>;
Vec v = acc_vec;
// 128-bit shuffle
Vec v1 = _mm256_permute2f128_ps(v, v, 0x1);
v = vec_fun(v, v1);
// 64-bit shuffle
v1 = _mm256_shuffle_ps(v, v, 0x4E);
v = vec_fun(v, v1);
// 32-bit shuffle
v1 = _mm256_shuffle_ps(v, v, 0xB1);
v = vec_fun(v, v1);
return _mm256_cvtss_f32(v);
}
};
#endif // defined(CPU_CAPABILITY_AVX2)
#if defined(CPU_CAPABILITY_AVX512)
template <typename Op>
struct VecReduceAllSIMD<float, Op> {
static inline float apply(const Op& vec_fun, const Vectorized<float>& acc_vec) {
using Vec = Vectorized<float>;
Vec v = acc_vec;
// 256-bit shuffle
Vec v1 = _mm512_shuffle_f32x4(v, v, 0x4E);
v = vec_fun(v, v1);
// 128-bit shuffle
v1 = _mm512_shuffle_f32x4(v, v, 0xB1);
v = vec_fun(v, v1);
// 64-bit shuffle
v1 = _mm512_shuffle_ps(v, v, 0x4E);
v = vec_fun(v, v1);
// 32-bit shuffle
v1 = _mm512_shuffle_ps(v, v, 0xB1);
v = vec_fun(v, v1);
return _mm512_cvtss_f32(v);
}
};
#endif // defined(CPU_CAPABILITY_AVX512)
#endif // defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) && !defined(C10_MOBILE)
#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__)
template <typename Op>
struct VecReduceAllSIMD<float, Op> {
static inline float apply(const Op& vec_fun, const Vectorized<float>& acc_vec) {
using Vec = Vectorized<float>;
Vec v = acc_vec;
// 128-bit shuffle: [a1, a2, a3, a4, a5, a6, a7, a8] -> [a5, a6, a7, a8, a1, a2, a3, a4]
Vec v1 = {v.get_high(), v.get_low()};
// [a1+a5, a2+a6, a3+a7, a4+a8, -, -, -, -] ('+' stands for the reduction function. Note that the last 4 elements are not required)
v = vec_fun(v, v1);
// 64-bit shuffle: [a1+a5, a2+a6, a3+a7, a4+a8, -, -, -, -] -> [a3+a7, a4+a8, a1+a5, a2+a6, -, -, -, -]
float32x4_t v1_1 = vextq_f32(v.get_low(), v.get_low(), 2);
v1 = {v1_1, v1_1};
// [a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, -, -, -, -]
v = vec_fun(v, v1);
// 32-bit shuffle: [a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, -, -, -, -] -> [a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, -, -, -, -]
v1_1 = vrev64q_f32(v.get_low());
v1 = {v1_1, v1_1};
// [a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8, -, -, -, -]
v = vec_fun(v, v1);
return v.get_low()[0];
}
};
#endif // defined(__aarch64__)
template <typename scalar_t, typename Op>
inline scalar_t vec_reduce_all(const Op& vec_fun, const Vectorized<scalar_t>& acc_vec) {
return VecReduceAllSIMD<scalar_t, Op>::apply(vec_fun, acc_vec);
}
template <typename scalar_t, typename Op,
typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
inline scalar_t reduce_all(const Op& vec_fun, const scalar_t* data, int64_t size) {
using Vec = vec::Vectorized<scalar_t>;
if (size < Vec::size())
return vec_reduce_all(vec_fun, Vec::loadu(data, size), size);
int64_t d = Vec::size();
Vec acc_vec = Vec::loadu(data);
for (; d < size - (size % Vec::size()); d += Vec::size()) {
Vec data_vec = Vec::loadu(data + d);
acc_vec = vec_fun(acc_vec, data_vec);
}
if (size - d > 0) {
Vec data_vec = Vec::loadu(data + d, size - d);
acc_vec = Vec::set(acc_vec, vec_fun(acc_vec, data_vec), size - d);
}
return vec_reduce_all(vec_fun, acc_vec);
}
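// Added usage sketch (illustrative only, not part of the original header):
// computing the maximum of a float buffer with reduce_all. Assumes the
// at::vec::maximum overload for Vectorized<float> from vec_base.h.
inline float reduce_max_example(const float* data, int64_t n) {
  return reduce_all<float>(
      [](const Vectorized<float>& x, const Vectorized<float>& y) {
        return maximum(x, y);
      },
      data, n);
}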
// similar to reduce_all, but reduces into two outputs
template <typename scalar_t, typename Op1, typename Op2,
typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
inline std::pair<scalar_t, scalar_t> reduce2_all(const Op1& vec_fun1, const Op2& vec_fun2,
const scalar_t* data, int64_t size) {
using Vec = vec::Vectorized<scalar_t>;
if (size < Vec::size()) {
auto loaded_data = Vec::loadu(data, size);
return std::pair<scalar_t, scalar_t>(
vec_reduce_all(vec_fun1, loaded_data, size),
vec_reduce_all(vec_fun2, loaded_data, size));
}
int64_t d = Vec::size();
Vec acc_vec1 = Vec::loadu(data);
Vec acc_vec2 = Vec::loadu(data);
for (; d < size - (size % Vec::size()); d += Vec::size()) {
Vec data_vec = Vec::loadu(data + d);
acc_vec1 = vec_fun1(acc_vec1, data_vec);
acc_vec2 = vec_fun2(acc_vec2, data_vec);
}
if (size - d > 0) {
Vec data_vec = Vec::loadu(data + d, size - d);
acc_vec1 = Vec::set(acc_vec1, vec_fun1(acc_vec1, data_vec), size - d);
acc_vec2 = Vec::set(acc_vec2, vec_fun2(acc_vec2, data_vec), size - d);
}
return std::pair<scalar_t, scalar_t>(
vec_reduce_all(vec_fun1, acc_vec1),
vec_reduce_all(vec_fun2, acc_vec2));
}
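// Added usage sketch (illustrative only, not part of the original header):
// one-pass min/max of a float buffer using reduce2_all. Assumes the minimum /
// maximum overloads for Vectorized<float> from vec_base.h.
inline std::pair<float, float> reduce_minmax_example(const float* data, int64_t n) {
  return reduce2_all<float>(
      [](const Vectorized<float>& x, const Vectorized<float>& y) { return minimum(x, y); },
      [](const Vectorized<float>& x, const Vectorized<float>& y) { return maximum(x, y); },
      data, n);
}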
template <typename scalar_t, typename MapOp, typename ReduceOp,
typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
inline scalar_t map_reduce_all(
const MapOp& map_fun,
const ReduceOp& red_fun,
const scalar_t* data,
int64_t size) {
using Vec = vec::Vectorized<scalar_t>;
if (size < Vec::size())
return vec_reduce_all(red_fun, map_fun(Vec::loadu(data, size)), size);
int64_t d = Vec::size();
Vec acc_vec = map_fun(Vec::loadu(data));
for (; d < size - (size % Vec::size()); d += Vec::size()) {
Vec data_vec = Vec::loadu(data + d);
data_vec = map_fun(data_vec);
acc_vec = red_fun(acc_vec, data_vec);
}
if (size - d > 0) {
Vec data_vec = Vec::loadu(data + d, size - d);
data_vec = map_fun(data_vec);
acc_vec = Vec::set(acc_vec, red_fun(acc_vec, data_vec), size - d);
}
return vec_reduce_all(red_fun, acc_vec);
}
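// Added usage sketch (illustrative only, not part of the original header):
// sum of squares via map_reduce_all, mapping each vector to x * x and reducing with +.
inline float sum_of_squares_example(const float* data, int64_t n) {
  return map_reduce_all<float>(
      [](const Vectorized<float>& x) { return x * x; },
      [](const Vectorized<float>& x, const Vectorized<float>& y) { return x + y; },
      data, n);
}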
template <typename scalar_t, typename MapOp, typename ReduceOp,
typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
inline scalar_t map2_reduce_all(
const MapOp& map_fun,
const ReduceOp& red_fun,
const scalar_t* data,
const scalar_t* data2,
int64_t size) {
using Vec = vec::Vectorized<scalar_t>;
if (size < Vec::size()) {
Vec data_vec = Vec::loadu(data, size);
Vec data2_vec = Vec::loadu(data2, size);
data_vec = map_fun(data_vec, data2_vec);
return vec_reduce_all(red_fun, data_vec, size);
}
int64_t d = Vec::size();
Vec acc_vec = map_fun(Vec::loadu(data), Vec::loadu(data2));
for (; d < size - (size % Vec::size()); d += Vec::size()) {
Vec data_vec = Vec::loadu(data + d);
Vec data2_vec = Vec::loadu(data2 + d);
data_vec = map_fun(data_vec, data2_vec);
acc_vec = red_fun(acc_vec, data_vec);
}
if (size - d > 0) {
Vec data_vec = Vec::loadu(data + d, size - d);
Vec data2_vec = Vec::loadu(data2 + d, size - d);
data_vec = map_fun(data_vec, data2_vec);
acc_vec = Vec::set(acc_vec, red_fun(acc_vec, data_vec), size - d);
}
return vec_reduce_all(red_fun, acc_vec);
}
template <typename scalar_t, typename MapOp, typename ReduceOp,
typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
inline scalar_t map3_reduce_all(
const MapOp& map_fun,
const ReduceOp& red_fun,
const scalar_t* data,
const scalar_t* data2,
const scalar_t* data3,
int64_t size) {
using Vec = vec::Vectorized<scalar_t>;
if (size < Vec::size()) {
Vec data_vec = Vec::loadu(data, size);
Vec data2_vec = Vec::loadu(data2, size);
Vec data3_vec = Vec::loadu(data3, size);
data_vec = map_fun(data_vec, data2_vec, data3_vec);
return vec_reduce_all(red_fun, data_vec, size);
}
int64_t d = Vec::size();
Vec acc_vec = map_fun(Vec::loadu(data), Vec::loadu(data2), Vec::loadu(data3));
for (; d < size - (size % Vec::size()); d += Vec::size()) {
Vec data_vec = Vec::loadu(data + d);
Vec data2_vec = Vec::loadu(data2 + d);
Vec data3_vec = Vec::loadu(data3 + d);
data_vec = map_fun(data_vec, data2_vec, data3_vec);
acc_vec = red_fun(acc_vec, data_vec);
}
if (size - d > 0) {
Vec data_vec = Vec::loadu(data + d, size - d);
Vec data2_vec = Vec::loadu(data2 + d, size - d);
Vec data3_vec = Vec::loadu(data3 + d, size - d);
data_vec = map_fun(data_vec, data2_vec, data3_vec);
acc_vec = Vec::set(acc_vec, red_fun(acc_vec, data_vec), size - d);
}
return vec_reduce_all(red_fun, acc_vec);
}
template <typename scalar_t, typename Op,
typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
inline void map(
const Op& vec_fun,
scalar_t* output_data,
const scalar_t* input_data,
int64_t size) {
using Vec = vec::Vectorized<scalar_t>;
int64_t d = 0;
for (; d < size - (size % Vec::size()); d += Vec::size()) {
Vec output_vec = vec_fun(Vec::loadu(input_data + d));
output_vec.store(output_data + d);
}
if (size - d > 0) {
Vec output_vec = vec_fun(Vec::loadu(input_data + d, size - d));
output_vec.store(output_data + d, size - d);
}
}
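// Added usage sketch (illustrative only, not part of the original header):
// elementwise scaling with map, out[i] = alpha * in[i].
inline void scale_example(float* out, const float* in, float alpha, int64_t n) {
  const Vectorized<float> valpha(alpha);  // broadcast the scalar across all lanes
  map<float>(
      [valpha](const Vectorized<float>& x) { return x * valpha; },
      out, in, n);
}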
template <typename scalar_t, typename Op,
typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
inline void map2(
const Op& vec_fun,
scalar_t* output_data,
const scalar_t* input_data,
const scalar_t* input_data2,
int64_t size) {
using Vec = vec::Vectorized<scalar_t>;
int64_t d = 0;
for (; d < size - (size % Vec::size()); d += Vec::size()) {
Vec data_vec = Vec::loadu(input_data + d);
Vec data_vec2 = Vec::loadu(input_data2 + d);
Vec output_vec = vec_fun(data_vec, data_vec2);
output_vec.store(output_data + d);
}
if (size - d > 0) {
Vec data_vec = Vec::loadu(input_data + d, size - d);
Vec data_vec2 = Vec::loadu(input_data2 + d, size - d);
Vec output_vec = vec_fun(data_vec, data_vec2);
output_vec.store(output_data + d, size - d);
}
}
template <typename scalar_t, typename Op,
typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
inline void map3(
const Op& vec_fun,
scalar_t* output_data,
const scalar_t* input_data1,
const scalar_t* input_data2,
const scalar_t* input_data3,
int64_t size) {
using Vec = vec::Vectorized<scalar_t>;
int64_t d = 0;
for (; d < size - (size % Vec::size()); d += Vec::size()) {
Vec data_vec1 = Vec::loadu(input_data1 + d);
Vec data_vec2 = Vec::loadu(input_data2 + d);
Vec data_vec3 = Vec::loadu(input_data3 + d);
Vec output_vec = vec_fun(data_vec1, data_vec2, data_vec3);
output_vec.store(output_data + d);
}
if (size - d > 0) {
Vec data_vec1 = Vec::loadu(input_data1 + d, size - d);
Vec data_vec2 = Vec::loadu(input_data2 + d, size - d);
Vec data_vec3 = Vec::loadu(input_data3 + d, size - d);
Vec output_vec = vec_fun(data_vec1, data_vec2, data_vec3);
output_vec.store(output_data + d, size - d);
}
}
template <typename scalar_t, typename Op,
typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
inline void map4(
const Op& vec_fun,
scalar_t* output_data,
const scalar_t* input_data1,
const scalar_t* input_data2,
const scalar_t* input_data3,
const scalar_t* input_data4,
int64_t size) {
using Vec = vec::Vectorized<scalar_t>;
int64_t d = 0;
for (; d < size - (size % Vec::size()); d += Vec::size()) {
Vec data_vec1 = Vec::loadu(input_data1 + d);
Vec data_vec2 = Vec::loadu(input_data2 + d);
Vec data_vec3 = Vec::loadu(input_data3 + d);
Vec data_vec4 = Vec::loadu(input_data4 + d);
Vec output_vec = vec_fun(data_vec1, data_vec2, data_vec3, data_vec4);
output_vec.store(output_data + d);
}
if (size - d > 0) {
Vec data_vec1 = Vec::loadu(input_data1 + d, size - d);
Vec data_vec2 = Vec::loadu(input_data2 + d, size - d);
Vec data_vec3 = Vec::loadu(input_data3 + d, size - d);
Vec data_vec4 = Vec::loadu(input_data4 + d, size - d);
Vec output_vec = vec_fun(data_vec1, data_vec2, data_vec3, data_vec4);
output_vec.store(output_data + d, size - d);
}
}
} // namespace at::vec


@ -0,0 +1,549 @@
#pragma once
// DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with AVX]
#include <ATen/cpu/vec/vec.h>
namespace at::vec {
// BFloat16 specialization
template <typename scalar_t> struct VecScalarType { using type = scalar_t; };
template <> struct VecScalarType<BFloat16> { using type = float; };
template <> struct VecScalarType<Half> { using type = float; };
// This is different from at::acc_type since we only need to specialize BFloat16 and Half
template <typename scalar_t>
using vec_scalar_t = typename VecScalarType<scalar_t>::type;
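// Added illustration (not part of the original header): the accumulation scalar
// type is float for the reduced floating-point types and the identity otherwise
// (assumes <type_traits> is reachable through the includes above).
static_assert(std::is_same_v<vec_scalar_t<BFloat16>, float>, "");
static_assert(std::is_same_v<vec_scalar_t<Half>, float>, "");
static_assert(std::is_same_v<vec_scalar_t<double>, double>, "");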
// Vector conversion between float and bfloat16/half
template <typename scalar_t,
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
inline std::tuple<Vectorized<float>, Vectorized<float>> convert_to_float(const Vectorized<scalar_t>&);
template <>
inline std::tuple<Vectorized<float>, Vectorized<float>> convert_to_float<BFloat16> (const Vectorized<BFloat16>& a) {
return convert_bfloat16_float(a);
}
template <>
inline std::tuple<Vectorized<float>, Vectorized<float>> convert_to_float<Half> (const Vectorized<Half>& a) {
return convert_half_float(a);
}
template <typename scalar_t,
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
inline Vectorized<scalar_t> convert_from_float(const Vectorized<float>&, const Vectorized<float>&);
template <>
inline Vectorized<BFloat16> convert_from_float<BFloat16>(const Vectorized<float>& a, const Vectorized<float>& b) {
return convert_float_bfloat16(a, b);
}
template <>
inline Vectorized<Half> convert_from_float<Half>(const Vectorized<float>& a, const Vectorized<float>& b) {
return convert_float_half(a, b);
}
template <typename scalar_t,
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
inline void load_to_float(const scalar_t *data, Vectorized<float> &out1, Vectorized<float> &out2);
template <>
inline void load_to_float<BFloat16> (const BFloat16 *data, Vectorized<float> &out1, Vectorized<float> &out2) {
load_fp32_from_bf16(data, out1, out2);
}
template <>
inline void load_to_float<Half> (const Half *data, Vectorized<float> &out1, Vectorized<float> &out2) {
load_fp32_from_fp16(data, out1, out2);
}
template <typename scalar_t,
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
inline void load_to_float(const scalar_t *data, Vectorized<float> &out);
template <>
inline void load_to_float<BFloat16> (const BFloat16 *data, Vectorized<float> &out) {
load_fp32_from_bf16(data, out);
}
template <>
inline void load_to_float<Half> (const Half *data, Vectorized<float> &out) {
load_fp32_from_fp16(data, out);
}
// Note that we already have specialized members of Vectorized<scalar_t> for BFloat16,
// so the following code would already run smoothly:
// using Vec = Vectorized<BFloat16>;
// Vec one = Vec(BFloat16(1));
// vec::map([](Vec x) { return one / (one + x.exp()); }, y_ptr, x_ptr, N);
//
// So why do we still need to specialize "functional"?
// If we specialize at the Vectorized<> level, the above example needs 3 pairs of
// bf16->fp32/fp32->bf16 conversions, one for each of ".exp()", "+" and "/".
// If we specialize at the vec::map<>() level, only 1 pair of conversions is needed:
// bf16->fp32 for the input BFloat16 vector and fp32->bf16 for the output.
//
// The following BFloat16 functionality only converts data types for the input
// and output vectors (the reduce functions only convert the final scalar back to bf16).
// Compared to specializing Vectorized<>, this gives
// 1. better performance, since there are fewer data type conversions;
// 2. less rounding error, since intermediate results are kept in fp32;
// 3. accumulation performed in fp32.
//
// If you plan to extend this file, please make sure to add unit tests at
// aten/src/ATen/test/vec_test_all_types.cpp
//
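// Added illustration (commented out, since the reduced-precision map overload is
// only defined further below in this file): a BFloat16 sigmoid written against
// Vectorized<float>, so only one bf16<->fp32 conversion pair happens per
// input/output vector. y_ptr, x_ptr and N are the same placeholder names used in
// the example above.
//
//   using fVec = Vectorized<float>;
//   vec::map<BFloat16>(
//       [](fVec x) { fVec one(1.f); return one / (one + x.exp()); },
//       y_ptr, x_ptr, N);
//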
template <typename scalar_t, typename Op,
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
inline float reduce_all(const Op& vec_fun, const scalar_t* data, int64_t size) {
using bVec = vec::Vectorized<scalar_t>;
using fVec = vec::Vectorized<float>;
if (size < bVec::size()) {
bVec data_bvec = bVec::loadu(data, size);
auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
if (size > fVec::size()) {
data_fvec0 = fVec::set(data_fvec0, vec_fun(data_fvec0, data_fvec1), size - fVec::size());
return vec_reduce_all<float>(vec_fun, data_fvec0, fVec::size());
} else {
return vec_reduce_all<float>(vec_fun, data_fvec0, size);
}
}
int64_t d = bVec::size();
bVec acc_bvec = bVec::loadu(data);
auto [acc_fvec0, acc_fvec1] = convert_to_float<scalar_t>(acc_bvec);
for (; d < size - (size % bVec::size()); d += bVec::size()) {
bVec data_bvec = bVec::loadu(data + d);
auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
acc_fvec0 = vec_fun(acc_fvec0, data_fvec0);
acc_fvec1 = vec_fun(acc_fvec1, data_fvec1);
}
if (size - d > 0) {
bVec data_bvec = bVec::loadu(data + d, size - d);
auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
if (size - d > fVec::size()) {
acc_fvec0 = vec_fun(acc_fvec0, data_fvec0);
acc_fvec1 = fVec::set(acc_fvec1, vec_fun(acc_fvec1, data_fvec1), size - d - fVec::size());
} else {
acc_fvec0 = fVec::set(acc_fvec0, vec_fun(acc_fvec0, data_fvec0), size - d);
}
}
acc_fvec0 = vec_fun(acc_fvec0, acc_fvec1);
return vec_reduce_all<float>(vec_fun, acc_fvec0);
}
template <typename scalar_t, typename Op1, typename Op2,
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
inline std::pair<float, float> reduce2_all(const Op1& vec_fun1, const Op2& vec_fun2,
const scalar_t* data, int64_t size) {
using bVec = vec::Vectorized<scalar_t>;
using fVec = vec::Vectorized<float>;
if (size < bVec::size()) {
bVec data_bvec = bVec::loadu(data, size);
auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
if (size > fVec::size()) {
fVec acc1_fvec = fVec::set(data_fvec0, vec_fun1(data_fvec0, data_fvec1), size - fVec::size());
fVec acc2_fvec = fVec::set(data_fvec0, vec_fun2(data_fvec0, data_fvec1), size - fVec::size());
return std::pair<float, float>(
vec_reduce_all<float>(vec_fun1, acc1_fvec, fVec::size()),
vec_reduce_all<float>(vec_fun2, acc2_fvec, fVec::size()));
} else {
return std::pair<float, float>(
vec_reduce_all<float>(vec_fun1, data_fvec0, size),
vec_reduce_all<float>(vec_fun2, data_fvec0, size));
}
}
int64_t d = bVec::size();
bVec acc_bvec = bVec::loadu(data);
auto [acc1_fvec0, acc1_fvec1] = convert_to_float<scalar_t>(acc_bvec);
auto [acc2_fvec0, acc2_fvec1] = convert_to_float<scalar_t>(acc_bvec);
for (; d < size - (size % bVec::size()); d += bVec::size()) {
bVec data_bvec = bVec::loadu(data + d);
auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
acc1_fvec0 = vec_fun1(acc1_fvec0, data_fvec0);
acc1_fvec1 = vec_fun1(acc1_fvec1, data_fvec1);
acc2_fvec0 = vec_fun2(acc2_fvec0, data_fvec0);
acc2_fvec1 = vec_fun2(acc2_fvec1, data_fvec1);
}
if (size - d > 0) {
bVec data_bvec = bVec::loadu(data + d, size - d);
auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
if (size - d > fVec::size()) {
acc1_fvec0 = vec_fun1(acc1_fvec0, data_fvec0);
acc1_fvec1 = fVec::set(acc1_fvec1, vec_fun1(acc1_fvec1, data_fvec1), size - d - fVec::size());
acc2_fvec0 = vec_fun2(acc2_fvec0, data_fvec0);
acc2_fvec1 = fVec::set(acc2_fvec1, vec_fun2(acc2_fvec1, data_fvec1), size - d - fVec::size());
} else {
acc1_fvec0 = fVec::set(acc1_fvec0, vec_fun1(acc1_fvec0, data_fvec0), size - d);
acc2_fvec0 = fVec::set(acc2_fvec0, vec_fun2(acc2_fvec0, data_fvec0), size - d);
}
}
acc1_fvec0 = vec_fun1(acc1_fvec0, acc1_fvec1);
acc2_fvec0 = vec_fun2(acc2_fvec0, acc2_fvec1);
return std::pair<float, float>(
vec_reduce_all<float>(vec_fun1, acc1_fvec0),
vec_reduce_all<float>(vec_fun2, acc2_fvec0));
}
template <typename scalar_t, typename MapOp, typename ReduceOp,
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
inline float map_reduce_all(
const MapOp& map_fun,
const ReduceOp& red_fun,
const scalar_t* data,
int64_t size) {
using bVec = vec::Vectorized<scalar_t>;
using fVec = vec::Vectorized<float>;
if (size < bVec::size()) {
bVec data_bvec = bVec::loadu(data, size);
auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
if (size > fVec::size()) {
data_fvec0 = map_fun(data_fvec0);
data_fvec1 = map_fun(data_fvec1);
data_fvec0 = fVec::set(data_fvec0, red_fun(data_fvec0, data_fvec1), size - fVec::size());
return vec_reduce_all<float>(red_fun, data_fvec0, fVec::size());
} else {
data_fvec0 = map_fun(data_fvec0);
return vec_reduce_all<float>(red_fun, data_fvec0, size);
}
}
int64_t d = bVec::size();
bVec acc_bvec = bVec::loadu(data);
auto [acc_fvec0, acc_fvec1] = convert_to_float<scalar_t>(acc_bvec);
acc_fvec0 = map_fun(acc_fvec0);
acc_fvec1 = map_fun(acc_fvec1);
for (; d < size - (size % bVec::size()); d += bVec::size()) {
bVec data_bvec = bVec::loadu(data + d);
auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
data_fvec0 = map_fun(data_fvec0);
data_fvec1 = map_fun(data_fvec1);
acc_fvec0 = red_fun(acc_fvec0, data_fvec0);
acc_fvec1 = red_fun(acc_fvec1, data_fvec1);
}
if (size - d > 0) {
bVec data_bvec = bVec::loadu(data + d, size - d);
auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
if (size - d > fVec::size()) {
data_fvec0 = map_fun(data_fvec0);
data_fvec1 = map_fun(data_fvec1);
acc_fvec0 = red_fun(acc_fvec0, data_fvec0);
acc_fvec1 = fVec::set(acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size());
} else {
data_fvec0 = map_fun(data_fvec0);
acc_fvec0 = fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d);
}
}
acc_fvec0 = red_fun(acc_fvec0, acc_fvec1);
return vec_reduce_all<float>(red_fun, acc_fvec0);
}
template <typename scalar_t, typename MapOp, typename ReduceOp,
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
inline float map2_reduce_all(
const MapOp& map_fun,
const ReduceOp& red_fun,
const scalar_t* data,
const scalar_t* data2,
int64_t size) {
using bVec = vec::Vectorized<scalar_t>;
using fVec = vec::Vectorized<float>;
if (size < bVec::size()) {
bVec data_bvec = bVec::loadu(data, size);
auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
bVec data2_bvec = bVec::loadu(data2, size);
auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec);
if (size > fVec::size()) {
data_fvec0 = map_fun(data_fvec0, data2_fvec0);
data_fvec1 = map_fun(data_fvec1, data2_fvec1);
data_fvec0 = fVec::set(data_fvec0, red_fun(data_fvec0, data_fvec1), size - fVec::size());
return vec_reduce_all<float>(red_fun, data_fvec0, fVec::size());
} else {
data_fvec0 = map_fun(data_fvec0, data2_fvec0);
return vec_reduce_all<float>(red_fun, data_fvec0, size);
}
}
int64_t d = bVec::size();
bVec acc_bvec = bVec::loadu(data);
auto [acc_fvec0, acc_fvec1] = convert_to_float<scalar_t>(acc_bvec);
bVec acc2_bvec = bVec::loadu(data2);
auto [acc2_fvec0, acc2_fvec1] = convert_to_float<scalar_t>(acc2_bvec);
acc_fvec0 = map_fun(acc_fvec0, acc2_fvec0);
acc_fvec1 = map_fun(acc_fvec1, acc2_fvec1);
for (; d < size - (size % bVec::size()); d += bVec::size()) {
bVec data_bvec = bVec::loadu(data + d);
auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
bVec data2_bvec = bVec::loadu(data2 + d);
auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec);
data_fvec0 = map_fun(data_fvec0, data2_fvec0);
data_fvec1 = map_fun(data_fvec1, data2_fvec1);
acc_fvec0 = red_fun(acc_fvec0, data_fvec0);
acc_fvec1 = red_fun(acc_fvec1, data_fvec1);
}
if (size - d > 0) {
bVec data_bvec = bVec::loadu(data + d, size - d);
auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
bVec data2_bvec = bVec::loadu(data2 + d, size - d);
auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec);
if (size - d > fVec::size()) {
data_fvec0 = map_fun(data_fvec0, data2_fvec0);
data_fvec1 = map_fun(data_fvec1, data2_fvec1);
acc_fvec0 = red_fun(acc_fvec0, data_fvec0);
acc_fvec1 = fVec::set(acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size());
} else {
data_fvec0 = map_fun(data_fvec0, data2_fvec0);
acc_fvec0 = fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d);
}
}
acc_fvec0 = red_fun(acc_fvec0, acc_fvec1);
return vec_reduce_all<float>(red_fun, acc_fvec0);
}
template <typename scalar_t, typename MapOp, typename ReduceOp,
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
inline float map3_reduce_all(
const MapOp& map_fun,
const ReduceOp& red_fun,
const scalar_t* data,
const scalar_t* data2,
const scalar_t* data3,
int64_t size) {
using bVec = vec::Vectorized<scalar_t>;
using fVec = vec::Vectorized<float>;
if (size < bVec::size()) {
bVec data_bvec = bVec::loadu(data, size);
auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
bVec data2_bvec = bVec::loadu(data2, size);
auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec);
bVec data3_bvec = bVec::loadu(data3, size);
auto [data3_fvec0, data3_fvec1] = convert_to_float<scalar_t>(data3_bvec);
if (size > fVec::size()) {
data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0);
data_fvec1 = map_fun(data_fvec1, data2_fvec1, data3_fvec1);
data_fvec0 = fVec::set(data_fvec0, red_fun(data_fvec0, data_fvec1), size - fVec::size());
return vec_reduce_all<float>(red_fun, data_fvec0, fVec::size());
} else {
data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0);
return vec_reduce_all<float>(red_fun, data_fvec0, size);
}
}
int64_t d = bVec::size();
bVec acc_bvec = bVec::loadu(data);
auto [acc_fvec0, acc_fvec1] = convert_to_float<scalar_t>(acc_bvec);
bVec acc2_bvec = bVec::loadu(data2);
auto [acc2_fvec0, acc2_fvec1] = convert_to_float<scalar_t>(acc2_bvec);
bVec acc3_bvec = bVec::loadu(data3);
auto [acc3_fvec0, acc3_fvec1] = convert_to_float<scalar_t>(acc3_bvec);
acc_fvec0 = map_fun(acc_fvec0, acc2_fvec0, acc3_fvec0);
acc_fvec1 = map_fun(acc_fvec1, acc2_fvec1, acc3_fvec1);
for (; d < size - (size % bVec::size()); d += bVec::size()) {
bVec data_bvec = bVec::loadu(data + d);
auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
bVec data2_bvec = bVec::loadu(data2 + d);
auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec);
bVec data3_bvec = bVec::loadu(data3 + d);
auto [data3_fvec0, data3_fvec1] = convert_to_float<scalar_t>(data3_bvec);
data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0);
data_fvec1 = map_fun(data_fvec1, data2_fvec1, data3_fvec1);
acc_fvec0 = red_fun(acc_fvec0, data_fvec0);
acc_fvec1 = red_fun(acc_fvec1, data_fvec1);
}
if (size - d > 0) {
bVec data_bvec = bVec::loadu(data + d, size - d);
auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
bVec data2_bvec = bVec::loadu(data2 + d, size - d);
auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec);
bVec data3_bvec = bVec::loadu(data3 + d, size - d);
auto [data3_fvec0, data3_fvec1] = convert_to_float<scalar_t>(data3_bvec);
if (size - d > fVec::size()) {
data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0);
data_fvec1 = map_fun(data_fvec1, data2_fvec1, data3_fvec1);
acc_fvec0 = red_fun(acc_fvec0, data_fvec0);
acc_fvec1 = fVec::set(acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size());
} else {
data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0);
acc_fvec0 = fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d);
}
}
acc_fvec0 = red_fun(acc_fvec0, acc_fvec1);
return vec_reduce_all<float>(red_fun, acc_fvec0);
}
template <typename scalar_t, typename Op,
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
inline void map(
const Op& vec_fun,
scalar_t* output_data,
const scalar_t* input_data,
int64_t size) {
using bVec = vec::Vectorized<scalar_t>;
using fVec = vec::Vectorized<float>;
int64_t d = 0;
for (; d < size - (size % bVec::size()); d += bVec::size()) {
bVec data_bvec = bVec::loadu(input_data + d);
auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
fVec output_fvec0 = vec_fun(data_fvec0);
fVec output_fvec1 = vec_fun(data_fvec1);
bVec output_bvec = convert_from_float<scalar_t>(output_fvec0, output_fvec1);
output_bvec.store(output_data + d);
}
if (size - d > 0) {
bVec data_bvec = bVec::loadu(input_data + d, size - d);
auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
fVec output_fvec0 = vec_fun(data_fvec0);
fVec output_fvec1 = vec_fun(data_fvec1);
bVec output_bvec = convert_from_float<scalar_t>(output_fvec0, output_fvec1);
output_bvec.store(output_data + d, size - d);
}
}
template <typename scalar_t, typename Op,
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
inline void map(
const Op& vec_fun,
scalar_t* output_data,
const float* input_data,
int64_t size) {
using bVec = vec::Vectorized<scalar_t>;
using fVec = vec::Vectorized<float>;
int64_t d = 0;
for (; d < size - (size % bVec::size()); d += bVec::size()) {
fVec data_fvec0 = fVec::loadu(input_data + d);
fVec data_fvec1 = fVec::loadu(input_data + d + fVec::size());
fVec output_fvec0 = vec_fun(data_fvec0);
fVec output_fvec1 = vec_fun(data_fvec1);
bVec output_bvec = convert_from_float<scalar_t>(output_fvec0, output_fvec1);
output_bvec.store(output_data + d);
}
if (size - d > 0) {
fVec data_fvec0, data_fvec1;
if (size - d > fVec::size()) {
data_fvec0 = fVec::loadu(input_data + d);
data_fvec1 = fVec::loadu(input_data + d + fVec::size(), size - d - fVec::size());
} else {
// choose to align with behaviour of bVec::loadu(ptr, size),
// which leaves data_fvec1 uninitialized
data_fvec0 = fVec::loadu(input_data + d, size - d);
}
fVec output_fvec0 = vec_fun(data_fvec0);
fVec output_fvec1 = vec_fun(data_fvec1);
bVec output_bvec = convert_from_float<scalar_t>(output_fvec0, output_fvec1);
output_bvec.store(output_data + d, size - d);
}
}
template <typename scalar_t, typename Op,
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
inline void map2(
const Op& vec_fun,
scalar_t* output_data,
const scalar_t* input_data,
const scalar_t* input_data2,
int64_t size) {
using bVec = vec::Vectorized<scalar_t>;
using fVec = vec::Vectorized<float>;
int64_t d = 0;
for (; d < size - (size % bVec::size()); d += bVec::size()) {
bVec data_bvec = bVec::loadu(input_data + d);
auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
bVec data2_bvec = bVec::loadu(input_data2 + d);
auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec);
fVec output_fvec0 = vec_fun(data_fvec0, data2_fvec0);
fVec output_fvec1 = vec_fun(data_fvec1, data2_fvec1);
bVec output_bvec = convert_from_float<scalar_t>(output_fvec0, output_fvec1);
output_bvec.store(output_data + d);
}
if (size - d > 0) {
bVec data_bvec = bVec::loadu(input_data + d, size - d);
auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
bVec data2_bvec = bVec::loadu(input_data2 + d, size - d);
auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec);
fVec output_fvec0 = vec_fun(data_fvec0, data2_fvec0);
fVec output_fvec1 = vec_fun(data_fvec1, data2_fvec1);
bVec output_bvec = convert_from_float<scalar_t>(output_fvec0, output_fvec1);
output_bvec.store(output_data + d, size - d);
}
}
template <typename scalar_t, typename Op,
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
inline void map3(
const Op& vec_fun,
scalar_t* output_data,
const scalar_t* input_data1,
const scalar_t* input_data2,
const scalar_t* input_data3,
int64_t size) {
using bVec = vec::Vectorized<scalar_t>;
using fVec = vec::Vectorized<float>;
int64_t d = 0;
for (; d < size - (size % bVec::size()); d += bVec::size()) {
bVec data1_bvec = bVec::loadu(input_data1 + d);
auto [data1_fvec0, data1_fvec1] = convert_to_float<scalar_t>(data1_bvec);
bVec data2_bvec = bVec::loadu(input_data2 + d);
auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec);
bVec data3_bvec = bVec::loadu(input_data3 + d);
auto [data3_fvec0, data3_fvec1] = convert_to_float<scalar_t>(data3_bvec);
fVec output_fvec0 = vec_fun(data1_fvec0, data2_fvec0, data3_fvec0);
fVec output_fvec1 = vec_fun(data1_fvec1, data2_fvec1, data3_fvec1);
bVec output_bvec = convert_from_float<scalar_t>(output_fvec0, output_fvec1);
output_bvec.store(output_data + d);
}
if (size - d > 0) {
bVec data1_bvec = bVec::loadu(input_data1 + d, size - d);
auto [data1_fvec0, data1_fvec1] = convert_to_float<scalar_t>(data1_bvec);
bVec data2_bvec = bVec::loadu(input_data2 + d, size - d);
auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec);
bVec data3_bvec = bVec::loadu(input_data3 + d, size - d);
auto [data3_fvec0, data3_fvec1] = convert_to_float<scalar_t>(data3_bvec);
fVec output_fvec0 = vec_fun(data1_fvec0, data2_fvec0, data3_fvec0);
fVec output_fvec1 = vec_fun(data1_fvec1, data2_fvec1, data3_fvec1);
bVec output_bvec = convert_from_float<scalar_t>(output_fvec0, output_fvec1);
output_bvec.store(output_data + d, size - d);
}
}
template <typename scalar_t, typename Op,
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
inline void map4(
const Op& vec_fun,
scalar_t* output_data,
const scalar_t* input_data1,
const scalar_t* input_data2,
const scalar_t* input_data3,
const scalar_t* input_data4,
int64_t size) {
using bVec = vec::Vectorized<scalar_t>;
using fVec = vec::Vectorized<float>;
int64_t d = 0;
for (; d < size - (size % bVec::size()); d += bVec::size()) {
bVec data1_bvec = bVec::loadu(input_data1 + d);
auto [data1_fvec0, data1_fvec1] = convert_to_float<scalar_t>(data1_bvec);
bVec data2_bvec = bVec::loadu(input_data2 + d);
auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec);
bVec data3_bvec = bVec::loadu(input_data3 + d);
auto [data3_fvec0, data3_fvec1] = convert_to_float<scalar_t>(data3_bvec);
bVec data4_bvec = bVec::loadu(input_data4 + d);
auto [data4_fvec0, data4_fvec1] = convert_to_float<scalar_t>(data4_bvec);
fVec output_fvec0 = vec_fun(data1_fvec0, data2_fvec0, data3_fvec0, data4_fvec0);
fVec output_fvec1 = vec_fun(data1_fvec1, data2_fvec1, data3_fvec1, data4_fvec1);
bVec output_bvec = convert_from_float<scalar_t>(output_fvec0, output_fvec1);
output_bvec.store(output_data + d);
}
if (size - d > 0) {
bVec data1_bvec = bVec::loadu(input_data1 + d, size - d);
auto [data1_fvec0, data1_fvec1] = convert_to_float<scalar_t>(data1_bvec);
bVec data2_bvec = bVec::loadu(input_data2 + d, size - d);
auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec);
bVec data3_bvec = bVec::loadu(input_data3 + d, size - d);
auto [data3_fvec0, data3_fvec1] = convert_to_float<scalar_t>(data3_bvec);
bVec data4_bvec = bVec::loadu(input_data4 + d, size - d);
auto [data4_fvec0, data4_fvec1] = convert_to_float<scalar_t>(data4_bvec);
fVec output_fvec0 = vec_fun(data1_fvec0, data2_fvec0, data3_fvec0, data4_fvec0);
fVec output_fvec1 = vec_fun(data1_fvec1, data2_fvec1, data3_fvec1, data4_fvec1);
bVec output_bvec = convert_from_float<scalar_t>(output_fvec0, output_fvec1);
output_bvec.store(output_data + d, size - d);
}
}
} // namespace at::vec


@ -0,0 +1,43 @@
#pragma once
#if defined(__GNUC__) && (defined(__x86_64__) || defined(__i386__))
/* GCC or clang-compatible compiler, targeting x86/x86-64 */
#include <x86intrin.h>
#elif defined(__clang__) && (defined(__ARM_NEON__) || defined(__aarch64__))
/* Clang-compatible compiler, targeting arm neon */
#include <arm_neon.h>
#elif defined(_MSC_VER)
/* Microsoft C/C++-compatible compiler */
#include <intrin.h>
#if _MSC_VER <= 1900
#define _mm256_extract_epi64(X, Y) (_mm_extract_epi64(_mm256_extractf128_si256(X, Y >> 1), Y % 2))
#define _mm256_extract_epi32(X, Y) (_mm_extract_epi32(_mm256_extractf128_si256(X, Y >> 2), Y % 4))
#define _mm256_extract_epi16(X, Y) (_mm_extract_epi16(_mm256_extractf128_si256(X, Y >> 3), Y % 8))
#define _mm256_extract_epi8(X, Y) (_mm_extract_epi8(_mm256_extractf128_si256(X, Y >> 4), Y % 16))
#endif
#elif defined(__GNUC__) && (defined(__ARM_NEON__) || defined(__aarch64__))
/* GCC-compatible compiler, targeting ARM with NEON */
#include <arm_neon.h>
#if defined (MISSING_ARM_VLD1)
#include <ATen/cpu/vec/vec256/missing_vld1_neon.h>
#elif defined (MISSING_ARM_VST1)
#include <ATen/cpu/vec/vec256/missing_vst1_neon.h>
#endif
#elif defined(__GNUC__) && defined(__IWMMXT__)
/* GCC-compatible compiler, targeting ARM with WMMX */
#include <mmintrin.h>
#elif defined(__s390x__)
// targets IBM z/Architecture
// vecintrin.h will be included later
#elif (defined(__GNUC__) || defined(__xlC__)) && \
(defined(__VEC__) || defined(__ALTIVEC__))
/* XLC or GCC-compatible compiler, targeting PowerPC with VMX/VSX */
#include <altivec.h>
/* We need to undef those tokens defined by <altivec.h> to avoid conflicts
with the C++ types. => Can still use __bool/__vector */
#undef bool
#undef vector
#undef pixel
#elif defined(__GNUC__) && defined(__SPE__)
/* GCC-compatible compiler, targeting PowerPC with SPE */
#include <spe.h>
#endif


@ -0,0 +1,47 @@
#pragma once
#if defined(CPU_CAPABILITY_AVX512)
#include <ATen/cpu/vec/vec512/vec512.h>
#else
#include <ATen/cpu/vec/vec256/vec256.h>
#endif
namespace at::vec {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {
inline Vectorized<bool> convert_to_bool(Vectorized<int8_t> x) {
__at_align__ bool buffer[x.size()];
x.ne(Vectorized<int8_t>(0)).store(buffer);
Vectorized<bool> ret;
static_assert(x.size() == ret.size());
std::memcpy(ret, buffer, ret.size() * sizeof(bool));
return ret;
}
template <>
inline Vectorized<bool> Vectorized<bool>::loadu(const void* ptr) {
// See NOTE [Loading boolean values]
return convert_to_bool(Vectorized<int8_t>::loadu(ptr));
}
template <>
inline Vectorized<bool> Vectorized<bool>::loadu(const void* ptr, int64_t count) {
// See NOTE [Loading boolean values]
return convert_to_bool(Vectorized<int8_t>::loadu(ptr, count));
}
template <typename VT>
struct VecHoldType { using hold_type = typename VT::value_type; };
template <>
struct VecHoldType<Vectorized<BFloat16>> { using hold_type = BFloat16; };
template <>
struct VecHoldType<Vectorized<Half>> { using hold_type = Half; };
template <typename VT>
using vechold_type = typename VecHoldType<VT>::hold_type;
}} // namespace at::vec::CPU_CAPABILITY
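// Added illustration (not part of the original header): vechold_type maps a
// Vectorized type back to its element type, including the reduced-precision
// specializations above. For example, inside at::vec::CPU_CAPABILITY:
//   static_assert(std::is_same_v<vechold_type<Vectorized<float>>, float>, "");
//   static_assert(std::is_same_v<vechold_type<Vectorized<BFloat16>>, BFloat16>, "");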


@ -0,0 +1,452 @@
/* Workaround for missing vld1_*_x2 and vst1_*_x2 intrinsics in gcc-7. */
__extension__ extern __inline uint8x8x2_t
__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
vld1_u8_x2 (const uint8_t *__a)
{
uint8x8x2_t ret;
asm volatile("ld1 {%S0.8b - %T0.8b}, %1" : "=w" (ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline int8x8x2_t
__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
vld1_s8_x2 (const int8_t *__a)
{
int8x8x2_t ret;
asm volatile("ld1 {%S0.8b - %T0.8b}, %1" : "=w" (ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline uint16x4x2_t
__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
vld1_u16_x2 (const uint16_t *__a)
{
uint16x4x2_t ret;
asm volatile("ld1 {%S0.4h - %T0.4h}, %1" : "=w" (ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline int16x4x2_t
__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
vld1_s16_x2 (const int16_t *__a)
{
int16x4x2_t ret;
asm volatile("ld1 {%S0.4h - %T0.4h}, %1" : "=w" (ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline uint32x2x2_t
__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
vld1_u32_x2 (const uint32_t *__a)
{
uint32x2x2_t ret;
asm volatile("ld1 {%S0.2s - %T0.2s}, %1" : "=w" (ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline int32x2x2_t
__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
vld1_s32_x2 (const int32_t *__a)
{
int32x2x2_t ret;
asm volatile("ld1 {%S0.2s - %T0.2s}, %1" : "=w" (ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline uint64x1x2_t
__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
vld1_u64_x2 (const uint64_t *__a)
{
uint64x1x2_t ret;
asm volatile("ld1 {%S0.1d - %T0.1d}, %1" : "=w" (ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline int64x1x2_t
__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
vld1_s64_x2 (const int64_t *__a)
{
int64x1x2_t ret;
asm volatile("ld1 {%S0.1d - %T0.1d}, %1" : "=w" (ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline float16x4x2_t
__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
vld1_f16_x2 (const float16_t *__a)
{
float16x4x2_t ret;
asm volatile("ld1 {%S0.4h - %T0.4h}, %1" : "=w" (ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline float32x2x2_t
__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
vld1_f32_x2 (const float32_t *__a)
{
float32x2x2_t ret;
asm volatile("ld1 {%S0.2s - %T0.2s}, %1" : "=w" (ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline float64x1x2_t
__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
vld1_f64_x2 (const float64_t *__a)
{
float64x1x2_t ret;
asm volatile("ld1 {%S0.1d - %T0.1d}, %1" : "=w" (ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline poly8x8x2_t
__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
vld1_p8_x2 (const poly8_t *__a)
{
poly8x8x2_t ret;
asm volatile("ld1 {%S0.8b - %T0.8b}, %1" : "=w" (ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline poly16x4x2_t
__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
vld1_p16_x2 (const poly16_t *__a)
{
poly16x4x2_t ret;
asm volatile("ld1 {%S0.4h - %T0.4h}, %1" : "=w" (ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline poly64x1x2_t
__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
vld1_p64_x2 (const poly64_t *__a)
{
poly64x1x2_t ret;
asm volatile("ld1 {%S0.1d - %T0.1d}, %1" : "=w" (ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline uint8x16x2_t
__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
vld1q_u8_x2 (const uint8_t *__a)
{
uint8x16x2_t ret;
asm volatile("ld1 {%S0.16b - %T0.16b}, %1" : "=w" (ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline int8x16x2_t
__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
vld1q_s8_x2 (const int8_t *__a)
{
int8x16x2_t ret;
asm volatile("ld1 {%S0.16b - %T0.16b}, %1" : "=w" (ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline uint16x8x2_t
__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
vld1q_u16_x2 (const uint16_t *__a)
{
uint16x8x2_t ret;
asm volatile("ld1 {%S0.8h - %T0.8h}, %1" : "=w" (ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline int16x8x2_t
__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
vld1q_s16_x2 (const int16_t *__a)
{
int16x8x2_t ret;
asm volatile("ld1 {%S0.8h - %T0.8h}, %1" : "=w" (ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline uint32x4x2_t
__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
vld1q_u32_x2 (const uint32_t *__a)
{
uint32x4x2_t ret;
asm volatile("ld1 {%S0.4s - %T0.4s}, %1" : "=w" (ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline int32x4x2_t
__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
vld1q_s32_x2 (const int32_t *__a)
{
int32x4x2_t ret;
asm volatile("ld1 {%S0.4s - %T0.4s}, %1" : "=w" (ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline uint64x2x2_t
__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
vld1q_u64_x2 (const uint64_t *__a)
{
uint64x2x2_t ret;
asm volatile("ld1 {%S0.2d - %T0.2d}, %1" : "=w" (ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline int64x2x2_t
__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
vld1q_s64_x2 (const int64_t *__a)
{
int64x2x2_t ret;
asm volatile("ld1 {%S0.2d - %T0.2d}, %1" : "=w" (ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline float16x8x2_t
__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
vld1q_f16_x2 (const float16_t *__a)
{
float16x8x2_t ret;
asm volatile("ld1 {%S0.8h - %T0.8h}, %1" : "=w" (ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline float32x4x2_t
__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
vld1q_f32_x2 (const float32_t *__a)
{
float32x4x2_t ret;
asm volatile("ld1 {%S0.4s - %T0.4s}, %1" : "=w" (ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline float64x2x2_t
__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
vld1q_f64_x2 (const float64_t *__a)
{
float64x2x2_t ret;
asm volatile("ld1 {%S0.2d - %T0.2d}, %1" : "=w" (ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline poly8x16x2_t
__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
vld1q_p8_x2 (const poly8_t *__a)
{
poly8x16x2_t ret;
asm volatile("ld1 {%S0.16b - %T0.16b}, %1" : "=w" (ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline poly16x8x2_t
__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
vld1q_p16_x2 (const poly16_t *__a)
{
poly16x8x2_t ret;
asm volatile("ld1 {%S0.8h - %T0.8h}, %1" : "=w" (ret) : "Q"(*__a));
return ret;
}
__extension__ extern __inline poly64x2x2_t
__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
vld1q_p64_x2 (const poly64_t *__a)
{
poly64x2x2_t ret;
asm volatile("ld1 {%S0.2d - %T0.2d}, %1" : "=w" (ret) : "Q"(*__a));
return ret;
}
/* vst1x2 */
__extension__ extern __inline void
__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
vst1_s64_x2 (int64_t * __a, int64x1x2_t val)
{
asm volatile("st1 {%S1.1d - %T1.1d}, %0" : "=Q" (*__a) : "w" (val));
}
__extension__ extern __inline void
__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
vst1_u64_x2 (uint64_t * __a, uint64x1x2_t val)
{
asm volatile("st1 {%S1.1d - %T1.1d}, %0" : "=Q" (*__a) : "w" (val));
}
__extension__ extern __inline void
__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
vst1_f64_x2 (float64_t * __a, float64x1x2_t val)
{
asm volatile("st1 {%S1.1d - %T1.1d}, %0" : "=Q" (*__a) : "w" (val));
}
__extension__ extern __inline void
__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
vst1_s8_x2 (int8_t * __a, int8x8x2_t val)
{
asm volatile("st1 {%S1.8b - %T1.8b}, %0" : "=Q" (*__a) : "w" (val));
}
__extension__ extern __inline void
__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
vst1_p8_x2 (poly8_t * __a, poly8x8x2_t val)
{
asm volatile("st1 {%S1.8b - %T1.8b}, %0" : "=Q" (*__a) : "w" (val));
}
__extension__ extern __inline void
__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
vst1_s16_x2 (int16_t * __a, int16x4x2_t val)
{
asm volatile("st1 {%S1.4h - %T1.4h}, %0" : "=Q" (*__a) : "w" (val));
}
__extension__ extern __inline void
__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
vst1_p16_x2 (poly16_t * __a, poly16x4x2_t val)
{
asm volatile("st1 {%S1.4h - %T1.4h}, %0" : "=Q" (*__a) : "w" (val));
}
__extension__ extern __inline void
__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
vst1_s32_x2 (int32_t * __a, int32x2x2_t val)
{
asm volatile("st1 {%S1.2s - %T1.2s}, %0" : "=Q" (*__a) : "w" (val));
}
__extension__ extern __inline void
__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
vst1_u8_x2 (uint8_t * __a, uint8x8x2_t val)
{
asm volatile("st1 {%S1.8b - %T1.8b}, %0" : "=Q" (*__a) : "w" (val));
}
__extension__ extern __inline void
__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
vst1_u16_x2 (uint16_t * __a, uint16x4x2_t val)
{
asm volatile("st1 {%S1.4h - %T1.4h}, %0" : "=Q" (*__a) : "w" (val));
}
__extension__ extern __inline void
__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
vst1_u32_x2 (uint32_t * __a, uint32x2x2_t val)
{
asm volatile("st1 {%S1.2s - %T1.2s}, %0" : "=Q" (*__a) : "w" (val));
}
__extension__ extern __inline void
__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
vst1_f16_x2 (float16_t * __a, float16x4x2_t val)
{
asm volatile("st1 {%S1.4h - %T1.4h}, %0" : "=Q" (*__a) : "w" (val));
}
__extension__ extern __inline void
__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
vst1_f32_x2 (float32_t * __a, float32x2x2_t val)
{
asm volatile("st1 {%S1.2s - %T1.2s}, %0" : "=Q" (*__a) : "w" (val));
}
__extension__ extern __inline void
__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
vst1_p64_x2 (poly64_t * __a, poly64x1x2_t val)
{
asm volatile("st1 {%S1.1d - %T1.1d}, %0" : "=Q" (*__a) : "w" (val));
}
__extension__ extern __inline void
__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
vst1q_s8_x2 (int8_t * __a, int8x16x2_t val)
{
asm volatile("st1 {%S1.16b - %T1.16b}, %0" : "=Q" (*__a) : "w" (val));
}
__extension__ extern __inline void
__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
vst1q_p8_x2 (poly8_t * __a, poly8x16x2_t val)
{
asm volatile("st1 {%S1.16b - %T1.16b}, %0" : "=Q" (*__a) : "w" (val));
}
__extension__ extern __inline void
__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
vst1q_s16_x2 (int16_t * __a, int16x8x2_t val)
{
asm volatile("st1 {%S1.8h - %T1.8h}, %0" : "=Q" (*__a) : "w" (val));
}
__extension__ extern __inline void
__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
vst1q_p16_x2 (poly16_t * __a, poly16x8x2_t val)
{
asm volatile("st1 {%S1.8h - %T1.8h}, %0" : "=Q" (*__a) : "w" (val));
}
__extension__ extern __inline void
__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
vst1q_s32_x2 (int32_t * __a, int32x4x2_t val)
{
asm volatile("st1 {%S1.4s - %T1.4s}, %0" : "=Q" (*__a) : "w" (val));
}
__extension__ extern __inline void
__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
vst1q_s64_x2 (int64_t * __a, int64x2x2_t val)
{
asm volatile("st1 {%S1.2d - %T1.2d}, %0" : "=Q" (*__a) : "w" (val));
}
__extension__ extern __inline void
__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
vst1q_u8_x2 (uint8_t * __a, uint8x16x2_t val)
{
asm volatile("st1 {%S1.16b - %T1.16b}, %0" : "=Q" (*__a) : "w" (val));
}
__extension__ extern __inline void
__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
vst1q_u16_x2 (uint16_t * __a, uint16x8x2_t val)
{
asm volatile("st1 {%S1.8h - %T1.8h}, %0" : "=Q" (*__a) : "w" (val));
}
__extension__ extern __inline void
__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
vst1q_u32_x2 (uint32_t * __a, uint32x4x2_t val)
{
asm volatile("st1 {%S1.4s - %T1.4s}, %0" : "=Q" (*__a) : "w" (val));
}
__extension__ extern __inline void
__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
vst1q_u64_x2 (uint64_t * __a, uint64x2x2_t val)
{
asm volatile("st1 {%S1.2d - %T1.2d}, %0" : "=Q" (*__a) : "w" (val));
}
__extension__ extern __inline void
__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
vst1q_f16_x2 (float16_t * __a, float16x8x2_t val)
{
asm volatile("st1 {%S1.8h - %T1.8h}, %0" : "=Q" (*__a) : "w" (val));
}
__extension__ extern __inline void
__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
vst1q_f32_x2 (float32_t * __a, float32x4x2_t val)
{
asm volatile("st1 {%S1.4s - %T1.4s}, %0" : "=Q" (*__a) : "w" (val));
}
__extension__ extern __inline void
__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
vst1q_f64_x2 (float64_t * __a, float64x2x2_t val)
{
asm volatile("st1 {%S1.2d - %T1.2d}, %0" : "=Q" (*__a) : "w" (val));
}
__extension__ extern __inline void
__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
vst1q_p64_x2 (poly64_t * __a, poly64x2x2_t val)
{
asm volatile("st1 {%S1.2d - %T1.2d}, %0" : "=Q" (*__a) : "w" (val));
}
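// Added illustration (not part of the workaround header): the *_x2 intrinsics
// load or store two consecutive NEON registers at once. For example, copying
// eight consecutive floats on aarch64:
//
//   float32x4x2_t v = vld1q_f32_x2(src);   // loads src[0..7] into two q registers
//   vst1q_f32_x2(dst, v);                  // stores them back to dst[0..7]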


@ -0,0 +1,8 @@
/* Workaround for missing vst1q_f32_x2 in gcc-8. */
__extension__ extern __inline void
__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
vst1q_f32_x2 (float32_t * __a, float32x4x2_t val)
{
asm volatile("st1 {%S1.4s - %T1.4s}, %0" : "=Q" (*__a) : "w" (val));
}


@ -0,0 +1,330 @@
#pragma once
// DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with AVX]
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#if !(defined(__VSX__) || defined(CPU_CAPABILITY_VSX) || defined(CPU_CAPABILITY_ZVECTOR))
#include <ATen/cpu/vec/vec256/vec256_float.h>
#include <ATen/cpu/vec/vec256/vec256_float_neon.h>
#include <ATen/cpu/vec/vec256/vec256_half_neon.h>
#include <ATen/cpu/vec/vec256/vec256_bfloat16.h>
#include <ATen/cpu/vec/vec256/vec256_double.h>
#include <ATen/cpu/vec/vec256/vec256_int.h>
#include <ATen/cpu/vec/vec256/vec256_qint.h>
#include <ATen/cpu/vec/vec256/vec256_complex_float.h>
#include <ATen/cpu/vec/vec256/vec256_complex_double.h>
#elif defined(__VSX__) || defined(CPU_CAPABILITY_VSX)
#include <ATen/cpu/vec/vec256/vsx/vec256_common_vsx.h>
#else
#include <ATen/cpu/vec/vec256/zarch/vec256_zarch.h>
#include <ATen/cpu/vec/vec256/vec256_bfloat16.h>
#endif
#include <ATen/cpu/vec/vec256/vec256_convert.h>
#include <ATen/cpu/vec/vec256/vec256_mask.h>
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <ostream>
namespace at::vec {
// Note [CPU_CAPABILITY namespace]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// This header, and all of its subheaders, will be compiled with
// different architecture flags for each supported set of vector
// intrinsics. So we need to make sure they aren't inadvertently
// linked together. We do this by declaring objects in an `inline
// namespace` which changes the name mangling, but can still be
// accessed as `at::vec`.
inline namespace CPU_CAPABILITY {
inline std::ostream& operator<<(std::ostream& stream, const c10::qint32& val) {
stream << val.val_;
return stream;
}
inline std::ostream& operator<<(std::ostream& stream, const c10::qint8& val) {
stream << static_cast<int>(val.val_);
return stream;
}
inline std::ostream& operator<<(std::ostream& stream, const c10::quint8& val) {
stream << static_cast<unsigned int>(val.val_);
return stream;
}
template <typename T>
std::ostream& operator<<(std::ostream& stream, const Vectorized<T>& vec) {
T buf[Vectorized<T>::size()];
vec.store(buf);
stream << "vec[";
for (int i = 0; i != Vectorized<T>::size(); i++) {
if (i != 0) {
stream << ", ";
}
stream << buf[i];
}
stream << "]";
return stream;
}
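// Added illustration (not part of the original header): the streaming operator
// prints every lane, e.g. with AVX2 (8 float lanes):
//
//   std::cout << Vectorized<float>(1.0f);   // prints "vec[1, 1, 1, 1, 1, 1, 1, 1]"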
#if defined(CPU_CAPABILITY_AVX2)
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST (AVX2) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
template<>
inline Vectorized<float> cast<float, double>(const Vectorized<double>& src) {
return _mm256_castpd_ps(src);
}
template<>
inline Vectorized<double> cast<double, float>(const Vectorized<float>& src) {
return _mm256_castps_pd(src);
}
template<>
inline Vectorized<float> cast<float, int32_t>(const Vectorized<int32_t>& src) {
return _mm256_castsi256_ps(src);
}
template<>
inline Vectorized<double> cast<double, int64_t>(const Vectorized<int64_t>& src) {
return _mm256_castsi256_pd(src);
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#ifndef _MSC_VER
// MSVC does not handle these overloaded function templates well.
template<int64_t scale = 1>
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<double>>
inline gather(const double* base_addr, const Vectorized<int64_t>& vindex) {
return _mm256_i64gather_pd(base_addr, vindex, scale);
}
template<int64_t scale = 1>
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<float>>
inline gather(const float* base_addr, const Vectorized<int32_t>& vindex) {
return _mm256_i32gather_ps(base_addr, vindex, scale);
}
#endif
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MASK GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#ifndef _MSC_VER
// MSVC does not handle these function overloads well.
template<int64_t scale = 1>
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<double>>
inline mask_gather(const Vectorized<double>& src, const double* base_addr,
const Vectorized<int64_t>& vindex, Vectorized<double>& mask) {
return _mm256_mask_i64gather_pd(src, base_addr, vindex, mask, scale);
}
template<int64_t scale = 1>
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<float>>
inline mask_gather(const Vectorized<float>& src, const float* base_addr,
const Vectorized<int32_t>& vindex, Vectorized<float>& mask) {
return _mm256_mask_i32gather_ps(src, base_addr, vindex, mask, scale);
}
#endif
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CONVERT ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Only works for inputs in the range: [-2^51, 2^51]
// From: https://stackoverflow.com/a/41148578
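// Informal sketch of the trick below: the literal 0x0018000000000000 is the
// integer 2^52 + 2^51. Adding it as a double pushes any x in [-2^51, 2^51]
// up to around 2^52, where the spacing between consecutive doubles is exactly
// 1.0, so the low mantissa bits of the sum encode x as an integer offset.
// Reinterpreting both the sum and the magic constant as int64 and subtracting
// therefore yields x, rounded under the current FP rounding mode
// (round-to-nearest by default).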
template<>
Vectorized<int64_t>
inline convert_to_int_of_same_size<double>(const Vectorized<double> &src) {
auto x = _mm256_add_pd(src, _mm256_set1_pd(0x0018000000000000));
return _mm256_sub_epi64(
_mm256_castpd_si256(x),
_mm256_castpd_si256(_mm256_set1_pd(0x0018000000000000))
);
}
template<>
Vectorized<int32_t>
inline convert_to_int_of_same_size<float>(const Vectorized<float> &src) {
return _mm256_cvttps_epi32(src);
}
// From: https://stackoverflow.com/a/41148578
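// Informal sketch: AVX2 has no direct int64 -> double conversion (that needs
// AVX-512DQ), so each 64-bit value is split into its low and high 32-bit
// halves. Each half is planted into the mantissa of a double next to a magic
// constant (2^52 for the low half, 2^84 + 2^63 for the shifted, sign-adjusted
// high half), and the constants are then removed again with one floating-point
// subtraction and one addition.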
template<>
Vectorized<double>
inline convert_to_fp_of_same_size<double>(const Vectorized<int64_t> &src) {
__m256i magic_i_lo = _mm256_set1_epi64x(0x4330000000000000); /* 2^52 */
__m256i magic_i_hi32 = _mm256_set1_epi64x(0x4530000080000000); /* 2^84 + 2^63 */
__m256i magic_i_all = _mm256_set1_epi64x(0x4530000080100000); /* 2^84 + 2^63 + 2^52 */
__m256d magic_d_all = _mm256_castsi256_pd(magic_i_all);
__m256i v_lo = _mm256_blend_epi32(magic_i_lo, src, 0b01010101); /* v_low = low32 + 2^52 */
__m256i v_hi = _mm256_srli_epi64(src, 32);
v_hi = _mm256_xor_si256(v_hi, magic_i_hi32); /* v_hi = high32*2^32 + 2^84 + 2^63 */
/* int64 = low32 + high32*2^32 = v_hi + v_lo - 2^52 - 2^63 - 2^84 */
__m256d v_hi_dbl = _mm256_sub_pd(_mm256_castsi256_pd(v_hi), magic_d_all);
__m256d result = _mm256_add_pd(v_hi_dbl, _mm256_castsi256_pd(v_lo));
return result;
}
template<>
Vectorized<float>
inline convert_to_fp_of_same_size<float>(const Vectorized<int32_t> &src) {
return _mm256_cvtepi32_ps(src);
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ INTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
template <>
std::pair<Vectorized<double>, Vectorized<double>>
inline interleave2<double>(const Vectorized<double>& a, const Vectorized<double>& b) {
// inputs:
// a = {a0, a1, a2, a3}
// b = {b0, b1, b2, b3}
// swap lanes:
// a_swapped = {a0, a1, b0, b1}
// b_swapped = {a2, a3, b2, b3}
auto a_swapped = _mm256_permute2f128_pd(a, b, 0b0100000); // 0, 2. 4 bits apart
auto b_swapped = _mm256_permute2f128_pd(a, b, 0b0110001); // 1, 3. 4 bits apart
// group cols crossing lanes:
// return {a0, b0, a1, b1}
// {a2, b2, a3, b3}
return std::make_pair(_mm256_permute4x64_pd(a_swapped, 0b11011000), // 0, 2, 1, 3
_mm256_permute4x64_pd(b_swapped, 0b11011000)); // 0, 2, 1, 3
}
template <>
std::pair<Vectorized<float>, Vectorized<float>>
inline interleave2<float>(const Vectorized<float>& a, const Vectorized<float>& b) {
// inputs:
// a = {a0, a1, a2, a3, a4, a5, a6, a7}
// b = {b0, b1, b2, b3, b4, b5, b6, b7}
// swap lanes:
// a_swapped = {a0, a1, a2, a3, b0, b1, b2, b3}
// b_swapped = {a4, a5, a6, a7, b4, b5, b6, b7}
// TODO: can we support caching this?
auto a_swapped = _mm256_permute2f128_ps(a, b, 0b0100000); // 0, 2. 4 bits apart
auto b_swapped = _mm256_permute2f128_ps(a, b, 0b0110001); // 1, 3. 4 bits apart
// group cols crossing lanes:
// return {a0, b0, a1, b1, a2, b2, a3, b3}
// {a4, b4, a5, b5, a6, b6, a7, b7}
const __m256i group_ctrl = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7);
return std::make_pair(_mm256_permutevar8x32_ps(a_swapped, group_ctrl),
_mm256_permutevar8x32_ps(b_swapped, group_ctrl));
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ DEINTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
template <>
std::pair<Vectorized<double>, Vectorized<double>>
inline deinterleave2<double>(const Vectorized<double>& a, const Vectorized<double>& b) {
// inputs:
// a = {a0, b0, a1, b1}
// b = {a2, b2, a3, b3}
// group cols crossing lanes:
// a_grouped = {a0, a1, b0, b1}
// b_grouped = {a2, a3, b2, b3}
auto a_grouped = _mm256_permute4x64_pd(a, 0b11011000); // 0, 2, 1, 3
auto b_grouped = _mm256_permute4x64_pd(b, 0b11011000); // 0, 2, 1, 3
// swap lanes:
// return {a0, a1, a2, a3}
// {b0, b1, b2, b3}
return std::make_pair(_mm256_permute2f128_pd(a_grouped, b_grouped, 0b0100000), // 0, 2. 4 bits apart
_mm256_permute2f128_pd(a_grouped, b_grouped, 0b0110001)); // 1, 3. 4 bits apart
}
template <>
std::pair<Vectorized<float>, Vectorized<float>>
inline deinterleave2<float>(const Vectorized<float>& a, const Vectorized<float>& b) {
// inputs:
// a = {a0, b0, a1, b1, a2, b2, a3, b3}
// b = {a4, b4, a5, b5, a6, b6, a7, b7}
// group cols crossing lanes:
// a_grouped = {a0, a1, a2, a3, b0, b1, b2, b3}
// b_grouped = {a4, a5, a6, a7, b4, b5, b6, b7}
// TODO: can we support caching this?
const __m256i group_ctrl = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7);
auto a_grouped = _mm256_permutevar8x32_ps(a, group_ctrl);
auto b_grouped = _mm256_permutevar8x32_ps(b, group_ctrl);
// swap lanes:
// return {a0, a1, a2, a3, a4, a5, a6, a7}
// {b0, b1, b2, b3, b4, b5, b6, b7}
return std::make_pair(_mm256_permute2f128_ps(a_grouped, b_grouped, 0b0100000), // 0, 2. 4 bits apart
_mm256_permute2f128_ps(a_grouped, b_grouped, 0b0110001)); // 1, 3. 4 bits apart
}
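// Illustrative example for the two routines above: given
//   a = {1, 2, 3, 4, 5, 6, 7, 8} and b = {9, 10, 11, 12, 13, 14, 15, 16},
// interleave2(a, b) returns
//   {1, 9, 2, 10, 3, 11, 4, 12} and {5, 13, 6, 14, 7, 15, 8, 16},
// and passing that pair to deinterleave2 recovers the original a and b.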
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FLIP ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
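// flip reverses the element order of a vector, e.g. a Vectorized<double>
// holding {0, 1, 2, 3} becomes {3, 2, 1, 0}. For the 8- and 16-bit variants
// below, _mm256_shuffle_epi8 can only permute bytes within each 128-bit lane,
// so the elements are reversed per lane first and _mm256_permute2x128_si256
// then swaps the two lanes to complete the reversal.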
template<>
inline Vectorized<float> flip(const Vectorized<float> & v) {
const __m256i mask_float = _mm256_set_epi32(0, 1, 2, 3, 4, 5, 6, 7);
return _mm256_permutevar8x32_ps(v, mask_float);
}
template<>
inline Vectorized<double> flip(const Vectorized<double> & v) {
return _mm256_permute4x64_pd(v, 27); // 27 == _MM_SHUFFLE(0, 1, 2, 3)
}
template<>
inline Vectorized<int64_t> flip(const Vectorized<int64_t> & v) {
return _mm256_permute4x64_epi64(v, 27); // 27 == _MM_SHUFFLE(0, 1, 2, 3)
}
template<>
inline Vectorized<int32_t> flip(const Vectorized<int32_t> & v) {
const __m256i mask_int32 = _mm256_set_epi32(0, 1, 2, 3, 4, 5, 6, 7);
return _mm256_permutevar8x32_epi32(v, mask_int32);
}
template<>
inline Vectorized<int16_t> flip(const Vectorized<int16_t> & v) {
const __m256i mask = _mm256_set_epi8(
1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 15, 14,
1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 15, 14
);
auto reversed = _mm256_shuffle_epi8(v, mask);
return _mm256_permute2x128_si256(reversed, reversed, 1);
}
inline __m256i flip8(const __m256i & v) {
const __m256i mask_int8 = _mm256_set_epi8(
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
);
auto reversed = _mm256_shuffle_epi8(v, mask_int8);
return _mm256_permute2x128_si256(reversed, reversed, 1);
}
template<>
inline Vectorized<int8_t> flip(const Vectorized<int8_t> & v) {
return flip8(v);
}
template<>
inline Vectorized<uint8_t> flip(const Vectorized<uint8_t> & v) {
return flip8(v);
}
inline Vectorized<bool> operator&&(
const Vectorized<bool>& self,
const Vectorized<bool>& other) {
const __m256i* self_ = reinterpret_cast<const __m256i*>(self.as_bytes());
const __m256i* other_ = reinterpret_cast<const __m256i*>(other.as_bytes());
__m256i out = _mm256_and_si256(*self_, *other_);
Vectorized<bool> ret;
std::memcpy(ret, &out, ret.size() * sizeof(bool));
return ret;
}
#endif // defined(CPU_CAPABILITY_AVX2)
}} // namespace at::vec::CPU_CAPABILITY

File diff suppressed because it is too large

View File

@ -0,0 +1,432 @@
#pragma once
// DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with AVX]
#include <c10/util/complex.h>
#include <c10/util/irange.h>
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#if defined(CPU_CAPABILITY_AVX2)
#define SLEEF_STATIC_LIBS
#include <sleef.h>
#endif
namespace at::vec {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {
#if defined(CPU_CAPABILITY_AVX2)
template <> class Vectorized<c10::complex<double>> {
private:
__m256d values;
public:
using value_type = c10::complex<double>;
using size_type = int;
static constexpr size_type size() {
return 2;
}
Vectorized() {}
Vectorized(__m256d v) : values(v) {}
Vectorized(c10::complex<double> val) {
double real_value = val.real();
double imag_value = val.imag();
values = _mm256_setr_pd(real_value, imag_value,
real_value, imag_value);
}
Vectorized(c10::complex<double> val1, c10::complex<double> val2) {
values = _mm256_setr_pd(val1.real(), val1.imag(),
val2.real(), val2.imag());
}
operator __m256d() const {
return values;
}
template <int64_t mask>
static Vectorized<c10::complex<double>> blend(const Vectorized<c10::complex<double>>& a, const Vectorized<c10::complex<double>>& b) {
// convert c10::complex<V> index mask to V index mask: xy -> xxyy
static_assert (mask > -1 && mask < 4, "Unexpected mask value");
switch (mask) {
case 0:
return a;
case 1:
return _mm256_blend_pd(a.values, b.values, 0x03);
case 2:
return _mm256_blend_pd(a.values, b.values, 0x0c);
case 3: break;
}
return b;
}
static Vectorized<c10::complex<double>> blendv(const Vectorized<c10::complex<double>>& a, const Vectorized<c10::complex<double>>& b,
const Vectorized<c10::complex<double>>& mask) {
// convert c10::complex<V> index mask to V index mask: xy -> xxyy
auto mask_ = _mm256_unpacklo_pd(mask.values, mask.values);
return _mm256_blendv_pd(a.values, b.values, mask_);
}
template<typename step_t>
static Vectorized<c10::complex<double>> arange(c10::complex<double> base = 0., step_t step = static_cast<step_t>(1)) {
return Vectorized<c10::complex<double>>(base,
base + step);
}
static Vectorized<c10::complex<double>> set(const Vectorized<c10::complex<double>>& a, const Vectorized<c10::complex<double>>& b,
int64_t count = size()) {
switch (count) {
case 0:
return a;
case 1:
return blend<1>(a, b);
}
return b;
}
static Vectorized<c10::complex<double>> loadu(const void* ptr, int64_t count = size()) {
if (count == size())
return _mm256_loadu_pd(reinterpret_cast<const double*>(ptr));
__at_align__ double tmp_values[2*size()];
// Ensure uninitialized memory does not change the output value. See https://github.com/pytorch/pytorch/issues/32502
// for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
// instructions while a loop would be compiled to one instruction.
for (const auto i : c10::irange(2*size())) {
tmp_values[i] = 0.0;
}
std::memcpy(
tmp_values,
reinterpret_cast<const double*>(ptr),
count * sizeof(c10::complex<double>));
return _mm256_load_pd(tmp_values);
}
void store(void* ptr, int count = size()) const {
if (count == size()) {
_mm256_storeu_pd(reinterpret_cast<double*>(ptr), values);
} else if (count > 0) {
double tmp_values[2*size()];
_mm256_storeu_pd(reinterpret_cast<double*>(tmp_values), values);
std::memcpy(ptr, tmp_values, count * sizeof(c10::complex<double>));
}
}
const c10::complex<double>& operator[](int idx) const = delete;
c10::complex<double>& operator[](int idx) = delete;
Vectorized<c10::complex<double>> map(c10::complex<double> (*const f)(const c10::complex<double> &)) const {
__at_align__ c10::complex<double> tmp[size()];
store(tmp);
for (const auto i : c10::irange(size())) {
tmp[i] = f(tmp[i]);
}
return loadu(tmp);
}
__m256d abs_2_() const {
auto val_2 = _mm256_mul_pd(values, values); // a*a b*b
return _mm256_hadd_pd(val_2, val_2); // a*a+b*b a*a+b*b
}
__m256d abs_() const {
auto real = _mm256_movedup_pd(values); // real real
// movehdup_pd does not exist...
auto imag = _mm256_permute_pd(values, 0xf); // imag imag
return Sleef_hypotd4_u05(real, imag); // abs abs
}
Vectorized<c10::complex<double>> abs() const {
const __m256d real_mask = _mm256_castsi256_pd(_mm256_setr_epi64x(0xFFFFFFFFFFFFFFFF, 0x0000000000000000,
0xFFFFFFFFFFFFFFFF, 0x0000000000000000));
return _mm256_and_pd(abs_(), real_mask); // abs 0
}
__m256d angle_() const {
//angle = atan2(b, a)
auto b_a = _mm256_permute_pd(values, 0x05); // b a
return Sleef_atan2d4_u10(values, b_a); // 90-angle angle
}
Vectorized<c10::complex<double>> angle() const {
const __m256d real_mask = _mm256_castsi256_pd(_mm256_setr_epi64x(0xFFFFFFFFFFFFFFFF, 0x0000000000000000,
0xFFFFFFFFFFFFFFFF, 0x0000000000000000));
auto angle = _mm256_permute_pd(angle_(), 0x05); // angle 90-angle
return _mm256_and_pd(angle, real_mask); // angle 0
}
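// sgn(z) = z / |z| elementwise, with sgn(0) defined as 0: the final blendv
// substitutes zero wherever |z| == 0 so the 0/0 NaNs from the division never
// escape.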
Vectorized<c10::complex<double>> sgn() const {
auto abs = abs_();
auto zero = _mm256_setzero_pd();
auto mask = _mm256_cmp_pd(abs, zero, _CMP_EQ_OQ);
auto div = _mm256_div_pd(values, abs);
return _mm256_blendv_pd(div, zero, mask);
}
__m256d real_() const {
const __m256d real_mask = _mm256_castsi256_pd(_mm256_setr_epi64x(0xFFFFFFFFFFFFFFFF, 0x0000000000000000,
0xFFFFFFFFFFFFFFFF, 0x0000000000000000));
return _mm256_and_pd(values, real_mask);
}
Vectorized<c10::complex<double>> real() const {
return real_();
}
__m256d imag_() const {
const __m256d imag_mask = _mm256_castsi256_pd(_mm256_setr_epi64x(0x0000000000000000, 0xFFFFFFFFFFFFFFFF,
0x0000000000000000, 0xFFFFFFFFFFFFFFFF));
return _mm256_and_pd(values, imag_mask);
}
Vectorized<c10::complex<double>> imag() const {
return _mm256_permute_pd(imag_(), 0x05); //b a
}
__m256d conj_() const {
const __m256d sign_mask = _mm256_setr_pd(0.0, -0.0, 0.0, -0.0);
return _mm256_xor_pd(values, sign_mask); // a -b
}
Vectorized<c10::complex<double>> conj() const {
return conj_();
}
Vectorized<c10::complex<double>> log() const {
// Most trigonometric ops use the log() op to improve complex number performance.
return map(std::log);
}
Vectorized<c10::complex<double>> log2() const {
const __m256d log2_ = _mm256_set1_pd(std::log(2));
return _mm256_div_pd(log(), log2_);
}
Vectorized<c10::complex<double>> log10() const {
const __m256d log10_ = _mm256_set1_pd(std::log(10));
return _mm256_div_pd(log(), log10_);
}
Vectorized<c10::complex<double>> log1p() const {
return map(std::log1p);
}
Vectorized<c10::complex<double>> asin() const {
// asin(x)
// = -i*ln(iz + sqrt(1 - z^2))
// = -i*ln((ai - b) + sqrt(1 - (a + bi)*(a + bi)))
// = -i*ln((-b + ai) + sqrt(1 - (a**2 - b**2) - 2*abi))
const __m256d one = _mm256_set1_pd(1);
auto conj = conj_();
auto b_a = _mm256_permute_pd(conj, 0x05); //-b a
auto ab = _mm256_mul_pd(conj, b_a); //-ab -ab
auto im = _mm256_add_pd(ab, ab); //-2ab -2ab
auto val_2 = _mm256_mul_pd(values, values); // a*a b*b
auto re = _mm256_hsub_pd(val_2, _mm256_permute_pd(val_2, 0x05)); // a*a-b*b b*b-a*a
re = _mm256_sub_pd(one, re);
auto root = Vectorized(_mm256_blend_pd(re, im, 0x0A)).sqrt(); //sqrt(re + i*im)
auto ln = Vectorized(_mm256_add_pd(b_a, root)).log(); //ln(iz + sqrt())
return Vectorized(_mm256_permute_pd(ln.values, 0x05)).conj(); //-i*ln()
}
Vectorized<c10::complex<double>> acos() const {
// acos(x) = pi/2 - asin(x)
constexpr auto pi_2d = c10::pi<double> / 2;
const __m256d pi_2 = _mm256_setr_pd(pi_2d, 0.0, pi_2d, 0.0);
return _mm256_sub_pd(pi_2, asin());
}
Vectorized<c10::complex<double>> atan() const;
Vectorized<c10::complex<double>> atanh() const {
return map(std::atanh);
}
Vectorized<c10::complex<double>> exp() const {
//exp(a + bi)
// = exp(a)*(cos(b) + sin(b)i)
auto exp = Sleef_expd4_u10(values); //exp(a) exp(b)
exp = _mm256_blend_pd(exp, _mm256_permute_pd(exp, 0x05), 0x0A); //exp(a) exp(a)
auto sin_cos = Sleef_sincosd4_u10(values); //[sin(a), cos(a)] [sin(b), cos(b)]
auto cos_sin = _mm256_blend_pd(_mm256_permute_pd(sin_cos.y, 0x05),
sin_cos.x, 0x0A); //cos(b) sin(b)
return _mm256_mul_pd(exp, cos_sin);
}
Vectorized<c10::complex<double>> exp2() const {
// Use identity 2**x = exp(log(2) * x)
const __m256d ln_2 = _mm256_set1_pd(c10::ln_2<double>);
Vectorized<c10::complex<double>> scaled_values = _mm256_mul_pd(values, ln_2);
return scaled_values.exp();
}
Vectorized<c10::complex<double>> expm1() const {
return map(std::expm1);
}
Vectorized<c10::complex<double>> sin() const {
return map(std::sin);
}
Vectorized<c10::complex<double>> sinh() const {
return map(std::sinh);
}
Vectorized<c10::complex<double>> cos() const {
return map(std::cos);
}
Vectorized<c10::complex<double>> cosh() const {
return map(std::cosh);
}
Vectorized<c10::complex<double>> ceil() const {
return _mm256_ceil_pd(values);
}
Vectorized<c10::complex<double>> floor() const {
return _mm256_floor_pd(values);
}
Vectorized<c10::complex<double>> neg() const {
auto zero = _mm256_setzero_pd();
return _mm256_sub_pd(zero, values);
}
Vectorized<c10::complex<double>> round() const {
return _mm256_round_pd(values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
}
Vectorized<c10::complex<double>> tan() const {
return map(std::tan);
}
Vectorized<c10::complex<double>> tanh() const {
return map(std::tanh);
}
Vectorized<c10::complex<double>> trunc() const {
return _mm256_round_pd(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
}
Vectorized<c10::complex<double>> sqrt() const {
return map(std::sqrt);
}
Vectorized<c10::complex<double>> reciprocal() const;
Vectorized<c10::complex<double>> rsqrt() const {
return sqrt().reciprocal();
}
Vectorized<c10::complex<double>> pow(const Vectorized<c10::complex<double>> &exp) const {
__at_align__ c10::complex<double> x_tmp[size()];
__at_align__ c10::complex<double> y_tmp[size()];
store(x_tmp);
exp.store(y_tmp);
for (const auto i : c10::irange(size())) {
x_tmp[i] = std::pow(x_tmp[i], y_tmp[i]);
}
return loadu(x_tmp);
}
// Comparison using the _CMP_**_OQ predicate.
// `O`: get false if an operand is NaN
// `Q`: do not raise if an operand is NaN
Vectorized<c10::complex<double>> operator==(const Vectorized<c10::complex<double>>& other) const {
return _mm256_cmp_pd(values, other.values, _CMP_EQ_OQ);
}
Vectorized<c10::complex<double>> operator!=(const Vectorized<c10::complex<double>>& other) const {
return _mm256_cmp_pd(values, other.values, _CMP_NEQ_UQ);
}
Vectorized<c10::complex<double>> operator<(const Vectorized<c10::complex<double>>&) const {
TORCH_CHECK(false, "not supported for complex numbers");
}
Vectorized<c10::complex<double>> operator<=(const Vectorized<c10::complex<double>>&) const {
TORCH_CHECK(false, "not supported for complex numbers");
}
Vectorized<c10::complex<double>> operator>(const Vectorized<c10::complex<double>>&) const {
TORCH_CHECK(false, "not supported for complex numbers");
}
Vectorized<c10::complex<double>> operator>=(const Vectorized<c10::complex<double>>&) const {
TORCH_CHECK(false, "not supported for complex numbers");
}
Vectorized<c10::complex<double>> eq(const Vectorized<c10::complex<double>>& other) const;
Vectorized<c10::complex<double>> ne(const Vectorized<c10::complex<double>>& other) const;
};
template <> Vectorized<c10::complex<double>> inline operator+(const Vectorized<c10::complex<double>> &a, const Vectorized<c10::complex<double>> &b) {
return _mm256_add_pd(a, b);
}
template <> Vectorized<c10::complex<double>> inline operator-(const Vectorized<c10::complex<double>> &a, const Vectorized<c10::complex<double>> &b) {
return _mm256_sub_pd(a, b);
}
template <> Vectorized<c10::complex<double>> inline operator*(const Vectorized<c10::complex<double>> &a, const Vectorized<c10::complex<double>> &b) {
//(a + bi) * (c + di) = (ac - bd) + (ad + bc)i
const __m256d sign_mask = _mm256_setr_pd(0.0, -0.0, 0.0, -0.0);
auto ac_bd = _mm256_mul_pd(a, b); //ac bd
auto d_c = _mm256_permute_pd(b, 0x05); //d c
d_c = _mm256_xor_pd(sign_mask, d_c); //d -c
auto ad_bc = _mm256_mul_pd(a, d_c); //ad -bc
auto ret = _mm256_hsub_pd(ac_bd, ad_bc); //ac - bd ad + bc
return ret;
}
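// Division below pre-scales both operands by 1/max(|c|, |d|). The scale factor
// cancels between the numerator products and abs_2_() of the scaled
// denominator, so it only serves to keep the intermediate products away from
// overflow/underflow when c or d is very large or very small.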
template <> Vectorized<c10::complex<double>> inline operator/(const Vectorized<c10::complex<double>> &a, const Vectorized<c10::complex<double>> &b) {
//re + im*i = (a + bi) / (c + di)
auto mask = _mm256_set1_pd(-0.f);
auto fabs_cd = _mm256_andnot_pd(mask, b); // |c| |d|
auto fabs_dc = _mm256_permute_pd(fabs_cd, 0x05); // |d| |c|
auto scale = _mm256_div_pd(_mm256_set1_pd(1.0f), _mm256_max_pd(fabs_cd, fabs_dc)); // 1/sc 1/sc
auto a2 = _mm256_mul_pd(a, scale); // a/sc b/sc
auto b2 = _mm256_mul_pd(b, scale); // c/sc d/sc
auto acbd2 = _mm256_mul_pd(a2, b2);
const __m256d sign_mask = _mm256_setr_pd(-0.0, 0.0, -0.0, 0.0);
auto dc2 = _mm256_permute_pd(b2, 0x05); // d/sc c/sc
dc2 = _mm256_xor_pd(sign_mask, dc2); // -d/|c,d| c/sc
auto adbc2 = _mm256_mul_pd(a2, dc2); //-ad/sc^2 bc/sc^2
auto res2 = _mm256_hadd_pd(acbd2, adbc2); //(ac+bd)/sc^2 (bc-ad)/sc^2
// get the denominator
auto denom2 = Vectorized<c10::complex<double>>(b2).abs_2_(); // (c^2+d^2)/sc^2 (c^2+d^2)/sc^2
res2 = _mm256_div_pd(res2, denom2);
return res2;
}
// reciprocal. Implement this here so we can use multiplication.
inline Vectorized<c10::complex<double>> Vectorized<c10::complex<double>>::reciprocal() const{
//re + im*i = (a + bi) / (c + di)
//re = (ac + bd)/abs_2() = c/abs_2()
//im = (bc - ad)/abs_2() = -d/abs_2()
const __m256d sign_mask = _mm256_setr_pd(0.0, -0.0, 0.0, -0.0);
auto c_d = _mm256_xor_pd(sign_mask, values); //c -d
return _mm256_div_pd(c_d, abs_2_());
}
inline Vectorized<c10::complex<double>> Vectorized<c10::complex<double>>::atan() const {
// atan(x) = i/2 * ln((i + z)/(i - z))
const __m256d i = _mm256_setr_pd(0.0, 1.0, 0.0, 1.0);
const Vectorized i_half = _mm256_setr_pd(0.0, 0.5, 0.0, 0.5);
auto sum = Vectorized(_mm256_add_pd(i, values)); // a 1+b
auto sub = Vectorized(_mm256_sub_pd(i, values)); // -a 1-b
auto ln = (sum/sub).log(); // ln((i + z)/(i - z))
return i_half*ln; // i/2*ln()
}
template <>
Vectorized<c10::complex<double>> inline maximum(const Vectorized<c10::complex<double>>& a, const Vectorized<c10::complex<double>>& b) {
auto abs_a = a.abs_2_();
auto abs_b = b.abs_2_();
auto mask = _mm256_cmp_pd(abs_a, abs_b, _CMP_LT_OQ);
auto max = _mm256_blendv_pd(a, b, mask);
// Exploit the fact that all-ones is a NaN.
auto isnan = _mm256_cmp_pd(abs_a, abs_b, _CMP_UNORD_Q);
return _mm256_or_pd(max, isnan);
}
template <>
Vectorized<c10::complex<double>> inline minimum(const Vectorized<c10::complex<double>>& a, const Vectorized<c10::complex<double>>& b) {
auto abs_a = a.abs_2_();
auto abs_b = b.abs_2_();
auto mask = _mm256_cmp_pd(abs_a, abs_b, _CMP_GT_OQ);
auto min = _mm256_blendv_pd(a, b, mask);
// Exploit the fact that all-ones is a NaN.
auto isnan = _mm256_cmp_pd(abs_a, abs_b, _CMP_UNORD_Q);
return _mm256_or_pd(min, isnan);
}
template <>
Vectorized<c10::complex<double>> inline operator&(const Vectorized<c10::complex<double>>& a, const Vectorized<c10::complex<double>>& b) {
return _mm256_and_pd(a, b);
}
template <>
Vectorized<c10::complex<double>> inline operator|(const Vectorized<c10::complex<double>>& a, const Vectorized<c10::complex<double>>& b) {
return _mm256_or_pd(a, b);
}
template <>
Vectorized<c10::complex<double>> inline operator^(const Vectorized<c10::complex<double>>& a, const Vectorized<c10::complex<double>>& b) {
return _mm256_xor_pd(a, b);
}
inline Vectorized<c10::complex<double>> Vectorized<c10::complex<double>>::eq(const Vectorized<c10::complex<double>>& other) const {
auto eq = (*this == other); // compares real and imag individually
// If both real numbers and imag numbers are equal, then the complex numbers are equal
return (eq.real() & eq.imag()) & Vectorized<c10::complex<double>>(_mm256_set1_pd(1.0));
}
inline Vectorized<c10::complex<double>> Vectorized<c10::complex<double>>::ne(const Vectorized<c10::complex<double>>& other) const {
auto ne = (*this != other); // compares real and imag individually
// If either real numbers or imag numbers are not equal, then the complex numbers are not equal
return (ne.real() | ne.imag()) & Vectorized<c10::complex<double>>(_mm256_set1_pd(1.0));
}
#endif
}} // namespace at::vec::CPU_CAPABILITY

View File

@ -0,0 +1,469 @@
#pragma once
// DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with AVX]
#include <c10/util/complex.h>
#include <c10/util/irange.h>
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#if defined(CPU_CAPABILITY_AVX2)
#define SLEEF_STATIC_LIBS
#include <sleef.h>
#endif
namespace at::vec {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {
#if defined(CPU_CAPABILITY_AVX2)
template <> class Vectorized<c10::complex<float>> {
private:
__m256 values;
public:
using value_type = c10::complex<float>;
using size_type = int;
static constexpr size_type size() {
return 4;
}
Vectorized() {}
Vectorized(__m256 v) : values(v) {}
Vectorized(c10::complex<float> val) {
float real_value = val.real();
float imag_value = val.imag();
values = _mm256_setr_ps(real_value, imag_value,
real_value, imag_value,
real_value, imag_value,
real_value, imag_value
);
}
Vectorized(c10::complex<float> val1, c10::complex<float> val2, c10::complex<float> val3, c10::complex<float> val4) {
values = _mm256_setr_ps(val1.real(), val1.imag(),
val2.real(), val2.imag(),
val3.real(), val3.imag(),
val4.real(), val4.imag()
);
}
operator __m256() const {
return values;
}
template <int64_t mask>
static Vectorized<c10::complex<float>> blend(const Vectorized<c10::complex<float>>& a, const Vectorized<c10::complex<float>>& b) {
// convert c10::complex<V> index mask to V index mask: xy -> xxyy
static_assert(mask > -1 && mask < 16, "Unexpected mask range");
switch (mask) {
case 0:
return a;
case 1:
return _mm256_blend_ps(a.values, b.values, 0x03); //b0000 0001 = b0000 0011
case 2:
return _mm256_blend_ps(a.values, b.values, 0x0C); //b0000 0010 = b0000 1100
case 3:
return _mm256_blend_ps(a.values, b.values, 0x0F); //b0000 0011 = b0000 1111
case 4:
return _mm256_blend_ps(a.values, b.values, 0x30); //b0000 0100 = b0011 0000
case 5:
return _mm256_blend_ps(a.values, b.values, 0x33); //b0000 0101 = b0011 0011
case 6:
return _mm256_blend_ps(a.values, b.values, 0x3C); //b0000 0110 = b0011 1100
case 7:
return _mm256_blend_ps(a.values, b.values, 0x3F); //b0000 0111 = b0011 1111
case 8:
return _mm256_blend_ps(a.values, b.values, 0xC0); //b0000 1000 = b1100 0000
case 9:
return _mm256_blend_ps(a.values, b.values, 0xC3); //b0000 1001 = b1100 0011
case 10:
return _mm256_blend_ps(a.values, b.values, 0xCC); //b0000 1010 = b1100 1100
case 11:
return _mm256_blend_ps(a.values, b.values, 0xCF); //b0000 1011 = b1100 1111
case 12:
return _mm256_blend_ps(a.values, b.values, 0xF0); //b0000 1100 = b1111 0000
case 13:
return _mm256_blend_ps(a.values, b.values, 0xF3); //b0000 1101 = b1111 0011
case 14:
return _mm256_blend_ps(a.values, b.values, 0xFC); //b0000 1110 = b1111 1100
default: break;
}
return b;
}
static Vectorized<c10::complex<float>> blendv(const Vectorized<c10::complex<float>>& a, const Vectorized<c10::complex<float>>& b,
const Vectorized<c10::complex<float>>& mask) {
// convert c10::complex<V> index mask to V index mask: xy -> xxyy
auto mask_ = _mm256_unpacklo_ps(mask.values, mask.values);
return _mm256_blendv_ps(a.values, b.values, mask_);
}
template<typename step_t>
static Vectorized<c10::complex<float>> arange(c10::complex<float> base = 0., step_t step = static_cast<step_t>(1)) {
return Vectorized<c10::complex<float>>(base,
base + step,
base + c10::complex<float>(2)*step,
base + c10::complex<float>(3)*step);
}
static Vectorized<c10::complex<float>> set(const Vectorized<c10::complex<float>>& a, const Vectorized<c10::complex<float>>& b,
int64_t count = size()) {
switch (count) {
case 0:
return a;
case 1:
return blend<1>(a, b);
case 2:
return blend<3>(a, b);
case 3:
return blend<7>(a, b);
}
return b;
}
static Vectorized<c10::complex<float>> loadu(const void* ptr, int64_t count = size()) {
if (count == size())
return _mm256_loadu_ps(reinterpret_cast<const float*>(ptr));
__at_align__ float tmp_values[2*size()];
// Ensure uninitialized memory does not change the output value. See https://github.com/pytorch/pytorch/issues/32502
// for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
// instructions while a loop would be compiled to one instruction.
for (const auto i : c10::irange(2*size())) {
tmp_values[i] = 0.0;
}
std::memcpy(
tmp_values,
reinterpret_cast<const float*>(ptr),
count * sizeof(c10::complex<float>));
return _mm256_load_ps(tmp_values);
}
void store(void* ptr, int count = size()) const {
if (count == size()) {
_mm256_storeu_ps(reinterpret_cast<float*>(ptr), values);
} else if (count > 0) {
float tmp_values[2*size()];
_mm256_storeu_ps(reinterpret_cast<float*>(tmp_values), values);
std::memcpy(ptr, tmp_values, count * sizeof(c10::complex<float>));
}
}
const c10::complex<float>& operator[](int idx) const = delete;
c10::complex<float>& operator[](int idx) = delete;
Vectorized<c10::complex<float>> map(c10::complex<float> (*const f)(const c10::complex<float> &)) const {
__at_align__ c10::complex<float> tmp[size()];
store(tmp);
for (const auto i : c10::irange(size())) {
tmp[i] = f(tmp[i]);
}
return loadu(tmp);
}
__m256 abs_2_() const {
auto val_2 = _mm256_mul_ps(values, values); // a*a b*b
auto ret = _mm256_hadd_ps(val_2, val_2); // a*a+b*b a*a+b*b
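// hadd leaves each 128-bit lane as {s0, s1, s0, s1}; the 0xD8 permute below
// reorders it to {s0, s0, s1, s1} so every complex slot sees its own squared
// modulus.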
return _mm256_permute_ps(ret, 0xD8);
}
__m256 abs_() const {
auto real = _mm256_moveldup_ps(values); // real real
auto imag = _mm256_movehdup_ps(values); // imag imag
return Sleef_hypotf8_u05(real, imag); // abs abs
}
Vectorized<c10::complex<float>> abs() const {
const __m256 real_mask = _mm256_castsi256_ps(_mm256_setr_epi32(0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000,
0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000));
return _mm256_and_ps(abs_(), real_mask); // abs 0
}
__m256 angle_() const {
//angle = atan2(b, a)
auto b_a = _mm256_permute_ps(values, 0xB1); // b a
return Sleef_atan2f8_u10(values, b_a); // 90-angle angle
}
Vectorized<c10::complex<float>> angle() const {
const __m256 real_mask = _mm256_castsi256_ps(_mm256_setr_epi32(0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000,
0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000));
auto angle = _mm256_permute_ps(angle_(), 0xB1); // angle 90-angle
return _mm256_and_ps(angle, real_mask); // angle 0
}
Vectorized<c10::complex<float>> sgn() const {
auto abs = abs_();
auto zero = _mm256_setzero_ps();
auto mask = _mm256_cmp_ps(abs, zero, _CMP_EQ_OQ);
auto div = _mm256_div_ps(values, abs);
return _mm256_blendv_ps(div, zero, mask);
}
__m256 real_() const {
const __m256 real_mask = _mm256_castsi256_ps(_mm256_setr_epi32(0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000,
0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000));
return _mm256_and_ps(values, real_mask);
}
Vectorized<c10::complex<float>> real() const {
return real_();
}
__m256 imag_() const {
const __m256 imag_mask = _mm256_castsi256_ps(_mm256_setr_epi32(0x00000000, 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF,
0x00000000, 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF));
return _mm256_and_ps(values, imag_mask);
}
Vectorized<c10::complex<float>> imag() const {
return _mm256_permute_ps(imag_(), 0xB1); //b a
}
__m256 conj_() const {
const __m256 sign_mask = _mm256_setr_ps(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0);
return _mm256_xor_ps(values, sign_mask); // a -b
}
Vectorized<c10::complex<float>> conj() const {
return conj_();
}
Vectorized<c10::complex<float>> log() const {
// Most trigonometric ops use the log() op to improve complex number performance.
return map(std::log);
}
Vectorized<c10::complex<float>> log2() const {
const __m256 log2_ = _mm256_set1_ps(std::log(2));
return _mm256_div_ps(log(), log2_);
}
Vectorized<c10::complex<float>> log10() const {
const __m256 log10_ = _mm256_set1_ps(std::log(10));
return _mm256_div_ps(log(), log10_);
}
Vectorized<c10::complex<float>> log1p() const {
return map(std::log1p);
}
Vectorized<c10::complex<float>> asin() const {
// asin(x)
// = -i*ln(iz + sqrt(1 - z^2))
// = -i*ln((ai - b) + sqrt(1 - (a + bi)*(a + bi)))
// = -i*ln((-b + ai) + sqrt(1 - (a**2 - b**2) - 2*abi))
const __m256 one = _mm256_set1_ps(1);
auto conj = conj_();
auto b_a = _mm256_permute_ps(conj, 0xB1); //-b a
auto ab = _mm256_mul_ps(conj, b_a); //-ab -ab
auto im = _mm256_add_ps(ab, ab); //-2ab -2ab
auto val_2 = _mm256_mul_ps(values, values); // a*a b*b
auto re = _mm256_hsub_ps(val_2, _mm256_permute_ps(val_2, 0xB1)); // a*a-b*b b*b-a*a
re = _mm256_permute_ps(re, 0xD8);
re = _mm256_sub_ps(one, re);
auto root = Vectorized(_mm256_blend_ps(re, im, 0xAA)).sqrt(); //sqrt(re + i*im)
auto ln = Vectorized(_mm256_add_ps(b_a, root)).log(); //ln(iz + sqrt())
return Vectorized(_mm256_permute_ps(ln.values, 0xB1)).conj(); //-i*ln()
}
Vectorized<c10::complex<float>> acos() const {
return map(std::acos);
}
Vectorized<c10::complex<float>> atan() const;
Vectorized<c10::complex<float>> atanh() const {
return map(std::atanh);
}
Vectorized<c10::complex<float>> exp() const {
//exp(a + bi)
// = exp(a)*(cos(b) + sin(b)i)
auto exp = Sleef_expf8_u10(values); //exp(a) exp(b)
exp = _mm256_blend_ps(exp, _mm256_permute_ps(exp, 0xB1), 0xAA); //exp(a) exp(a)
auto sin_cos = Sleef_sincosf8_u10(values); //[sin(a), cos(a)] [sin(b), cos(b)]
auto cos_sin = _mm256_blend_ps(_mm256_permute_ps(sin_cos.y, 0xB1),
sin_cos.x, 0xAA); //cos(b) sin(b)
return _mm256_mul_ps(exp, cos_sin);
}
Vectorized<c10::complex<float>> exp2() const {
// Use identity 2**x = exp(log(2) * x)
const __m256 ln_2 = _mm256_set1_ps(c10::ln_2<float>);
Vectorized<c10::complex<float>> scaled_values = _mm256_mul_ps(values, ln_2);
return scaled_values.exp();
}
Vectorized<c10::complex<float>> expm1() const {
return map(std::expm1);
}
Vectorized<c10::complex<float>> sin() const {
return map(std::sin);
}
Vectorized<c10::complex<float>> sinh() const {
return map(std::sinh);
}
Vectorized<c10::complex<float>> cos() const {
return map(std::cos);
}
Vectorized<c10::complex<float>> cosh() const {
return map(std::cosh);
}
Vectorized<c10::complex<float>> ceil() const {
return _mm256_ceil_ps(values);
}
Vectorized<c10::complex<float>> floor() const {
return _mm256_floor_ps(values);
}
Vectorized<c10::complex<float>> neg() const {
auto zero = _mm256_setzero_ps();
return _mm256_sub_ps(zero, values);
}
Vectorized<c10::complex<float>> round() const {
return _mm256_round_ps(values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
}
Vectorized<c10::complex<float>> tan() const {
return map(std::tan);
}
Vectorized<c10::complex<float>> tanh() const {
return map(std::tanh);
}
Vectorized<c10::complex<float>> trunc() const {
return _mm256_round_ps(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
}
Vectorized<c10::complex<float>> sqrt() const {
return map(std::sqrt);
}
Vectorized<c10::complex<float>> reciprocal() const;
Vectorized<c10::complex<float>> rsqrt() const {
return sqrt().reciprocal();
}
Vectorized<c10::complex<float>> pow(const Vectorized<c10::complex<float>> &exp) const {
__at_align__ c10::complex<float> x_tmp[size()];
__at_align__ c10::complex<float> y_tmp[size()];
store(x_tmp);
exp.store(y_tmp);
for (const auto i : c10::irange(size())) {
x_tmp[i] = std::pow(x_tmp[i], y_tmp[i]);
}
return loadu(x_tmp);
}
// Comparison using the _CMP_**_OQ predicate.
// `O`: get false if an operand is NaN
// `Q`: do not raise if an operand is NaN
Vectorized<c10::complex<float>> operator==(const Vectorized<c10::complex<float>>& other) const {
return _mm256_cmp_ps(values, other.values, _CMP_EQ_OQ);
}
Vectorized<c10::complex<float>> operator!=(const Vectorized<c10::complex<float>>& other) const {
return _mm256_cmp_ps(values, other.values, _CMP_NEQ_UQ);
}
Vectorized<c10::complex<float>> operator<(const Vectorized<c10::complex<float>>& /*other*/) const {
TORCH_CHECK(false, "not supported for complex numbers");
}
Vectorized<c10::complex<float>> operator<=(const Vectorized<c10::complex<float>>& /*other*/) const {
TORCH_CHECK(false, "not supported for complex numbers");
}
Vectorized<c10::complex<float>> operator>(const Vectorized<c10::complex<float>>& /*other*/) const {
TORCH_CHECK(false, "not supported for complex numbers");
}
Vectorized<c10::complex<float>> operator>=(const Vectorized<c10::complex<float>>& /*other*/) const {
TORCH_CHECK(false, "not supported for complex numbers");
}
Vectorized<c10::complex<float>> eq(const Vectorized<c10::complex<float>>& other) const;
Vectorized<c10::complex<float>> ne(const Vectorized<c10::complex<float>>& other) const;
};
template <> Vectorized<c10::complex<float>> inline operator+(const Vectorized<c10::complex<float>> &a, const Vectorized<c10::complex<float>> &b) {
return _mm256_add_ps(a, b);
}
template <> Vectorized<c10::complex<float>> inline operator-(const Vectorized<c10::complex<float>> &a, const Vectorized<c10::complex<float>> &b) {
return _mm256_sub_ps(a, b);
}
template <> Vectorized<c10::complex<float>> inline operator*(const Vectorized<c10::complex<float>> &a, const Vectorized<c10::complex<float>> &b) {
//(a + bi) * (c + di) = (ac - bd) + (ad + bc)i
const __m256 sign_mask = _mm256_setr_ps(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0);
auto ac_bd = _mm256_mul_ps(a, b); //ac bd
auto d_c = _mm256_permute_ps(b, 0xB1); //d c
d_c = _mm256_xor_ps(sign_mask, d_c); //d -c
auto ad_bc = _mm256_mul_ps(a, d_c); //ad -bc
auto ret = _mm256_hsub_ps(ac_bd, ad_bc); //ac - bd ad + bc
ret = _mm256_permute_ps(ret, 0xD8);
return ret;
}
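// As in the double path, division pre-scales the operands by an approximate
// 1/max(|c|, |d|) to avoid overflow/underflow. _mm256_rcp_ps is only accurate
// to about 12 bits, which is harmless here: the scale factor cancels between
// the numerator products and abs_2_() of the scaled denominator.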
template <> Vectorized<c10::complex<float>> inline operator/(const Vectorized<c10::complex<float>> &a, const Vectorized<c10::complex<float>> &b) {
//re + im*i = (a + bi) / (c + di)
auto mask = _mm256_set1_ps(-0.f);
auto fabs_cd = _mm256_andnot_ps(mask, b); // |c| |d|
auto fabs_dc = _mm256_permute_ps(fabs_cd, 0xB1); // |d| |c|
auto scale = _mm256_rcp_ps(_mm256_max_ps(fabs_cd, fabs_dc)); // 1/sc 1/sc
auto a2 = _mm256_mul_ps(a, scale); // a/sc b/sc
auto b2 = _mm256_mul_ps(b, scale); // c/sc d/sc
auto acbd2 = _mm256_mul_ps(a2, b2);
const __m256 sign_mask = _mm256_setr_ps(-0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0);
auto dc2 = _mm256_permute_ps(b2, 0xB1); // d/sc c/sc
dc2 = _mm256_xor_ps(sign_mask, dc2); // -d/|c,d| c/sc
auto adbc2 = _mm256_mul_ps(a2, dc2); //-ad/sc^2 bc/sc^2
auto res2 = _mm256_hadd_ps(acbd2, adbc2); //(ac+bd)/sc^2 (bc-ad)/sc^2
res2 = _mm256_permute_ps(res2, 0xD8);
// get the denominator
auto denom2 = Vectorized<c10::complex<float>>(b2).abs_2_(); // (c^2+d^2)/sc^2 (c^2+d^2)/sc^2
res2 = _mm256_div_ps(res2, denom2);
return res2;
}
// reciprocal. Implement this here so we can use multiplication.
inline Vectorized<c10::complex<float>> Vectorized<c10::complex<float>>::reciprocal() const {
//re + im*i = (a + bi) / (c + di)
//re = (ac + bd)/abs_2() = c/abs_2()
//im = (bc - ad)/abs_2() = -d/abs_2()
const __m256 sign_mask = _mm256_setr_ps(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0);
auto c_d = _mm256_xor_ps(sign_mask, values); //c -d
return _mm256_div_ps(c_d, abs_2_());
}
inline Vectorized<c10::complex<float>> Vectorized<c10::complex<float>>::atan() const {
// atan(x) = i/2 * ln((i + z)/(i - z))
const __m256 i = _mm256_setr_ps(0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0);
const Vectorized i_half = _mm256_setr_ps(0.0, 0.5, 0.0, 0.5, 0.0, 0.5, 0.0, 0.5);
auto sum = Vectorized(_mm256_add_ps(i, values)); // a 1+b
auto sub = Vectorized(_mm256_sub_ps(i, values)); // -a 1-b
auto ln = (sum/sub).log(); // ln((i + z)/(i - z))
return i_half*ln; // i/2*ln()
}
template <>
Vectorized<c10::complex<float>> inline maximum(const Vectorized<c10::complex<float>>& a, const Vectorized<c10::complex<float>>& b) {
auto abs_a = a.abs_2_();
auto abs_b = b.abs_2_();
auto mask = _mm256_cmp_ps(abs_a, abs_b, _CMP_LT_OQ);
auto max = _mm256_blendv_ps(a, b, mask);
// Exploit the fact that all-ones is a NaN.
auto isnan = _mm256_cmp_ps(abs_a, abs_b, _CMP_UNORD_Q);
return _mm256_or_ps(max, isnan);
}
template <>
Vectorized<c10::complex<float>> inline minimum(const Vectorized<c10::complex<float>>& a, const Vectorized<c10::complex<float>>& b) {
auto abs_a = a.abs_2_();
auto abs_b = b.abs_2_();
auto mask = _mm256_cmp_ps(abs_a, abs_b, _CMP_GT_OQ);
auto min = _mm256_blendv_ps(a, b, mask);
// Exploit the fact that all-ones is a NaN.
auto isnan = _mm256_cmp_ps(abs_a, abs_b, _CMP_UNORD_Q);
return _mm256_or_ps(min, isnan);
}
template <>
Vectorized<c10::complex<float>> inline operator&(const Vectorized<c10::complex<float>>& a, const Vectorized<c10::complex<float>>& b) {
return _mm256_and_ps(a, b);
}
template <>
Vectorized<c10::complex<float>> inline operator|(const Vectorized<c10::complex<float>>& a, const Vectorized<c10::complex<float>>& b) {
return _mm256_or_ps(a, b);
}
template <>
Vectorized<c10::complex<float>> inline operator^(const Vectorized<c10::complex<float>>& a, const Vectorized<c10::complex<float>>& b) {
return _mm256_xor_ps(a, b);
}
inline Vectorized<c10::complex<float>> Vectorized<c10::complex<float>>::eq(
const Vectorized<c10::complex<float>>& other) const {
auto eq = (*this == other); // compares real and imag individually
// If both real numbers and imag numbers are equal, then the complex numbers are equal
return (eq.real() & eq.imag()) & Vectorized<c10::complex<float>>(_mm256_set1_ps(1.0f));
}
inline Vectorized<c10::complex<float>> Vectorized<c10::complex<float>>::ne(
const Vectorized<c10::complex<float>>& other) const {
auto ne = (*this != other); // compares real and imag individually
// If either real numbers or imag numbers are not equal, then the complex numbers are not equal
return (ne.real() | ne.imag()) & Vectorized<c10::complex<float>>(_mm256_set1_ps(1.0f));
}
#endif
}} // namespace at::vec::CPU_CAPABILITY

View File

@ -0,0 +1,308 @@
#pragma once
#include <ATen/cpu/vec/functional_bfloat16.h>
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <ATen/cpu/vec/vec_convert.h>
namespace at::vec {
inline namespace CPU_CAPABILITY {
#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
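// bfloat16 is the upper 16 bits of an IEEE-754 float32, so widening to float
// amounts to placing each 16-bit payload in the high half of a 32-bit lane
// (done by cvtbf16_fp32 below); narrowing back goes through cvtfp32_bf16,
// which also rounds. The Half specializations follow the same pattern via
// cvtfp16_fp32 / cvtfp32_fp16.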
template <>
struct VecConvert<float, 1, BFloat16, 1> {
static inline VectorizedN<float, 1> apply(
const VectorizedN<BFloat16, 1>& src) {
VectorizedN<float, 1> result;
__m256 value;
cvtbf16_fp32(_mm256_castsi256_si128(src[0]), value);
result[0] = value;
return result;
}
};
template <>
struct VecConvert<float, 1, Half, 1> {
static inline VectorizedN<float, 1> apply(const VectorizedN<Half, 1>& src) {
VectorizedN<float, 1> result;
__m256 value;
cvtfp16_fp32(_mm256_castsi256_si128(src[0]), value);
result[0] = value;
return result;
}
};
template <>
struct VecConvert<BFloat16, 1, float, 1> {
static inline VectorizedN<BFloat16, 1> apply(
const VectorizedN<float, 1>& src) {
VectorizedN<BFloat16, 1> result;
result[0] = _mm256_castsi128_si256(cvtfp32_bf16(src[0]));
return result;
}
};
template <>
struct VecConvert<BFloat16, 1, float, 2> {
static inline VectorizedN<BFloat16, 1> apply(
const VectorizedN<float, 2>& src) {
VectorizedN<BFloat16, 1> result;
result[0] = convert_float_bfloat16(src[0], src[1]);
return result;
}
};
template <>
struct VecConvert<float, 2, BFloat16, 1> {
static inline VectorizedN<float, 2> apply(
const VectorizedN<BFloat16, 1>& src) {
VectorizedN<float, 2> result;
std::tie(result[0], result[1]) = convert_bfloat16_float(src[0]);
return result;
}
};
template <>
struct VecConvert<Half, 1, float, 1> {
static inline VectorizedN<Half, 1> apply(const VectorizedN<float, 1>& src) {
VectorizedN<Half, 1> result;
result[0] = _mm256_castsi128_si256(cvtfp32_fp16(src[0]));
return result;
}
};
template <>
struct VecConvert<Half, 1, float, 2> {
static inline VectorizedN<Half, 1> apply(const VectorizedN<float, 2>& src) {
VectorizedN<Half, 1> result;
result[0] = convert_float_half(src[0], src[1]);
return result;
}
};
template <>
struct VecConvert<float, 2, Half, 1> {
static inline VectorizedN<float, 2> apply(const VectorizedN<Half, 1>& src) {
VectorizedN<float, 2> result;
std::tie(result[0], result[1]) = convert_half_float(src[0]);
return result;
}
};
template <>
inline Vectorized<double> convert_to_fp_of_same_size<double>(
const Vectorized<int64_t>& src);
template <>
struct VecConvert<float, 1, int64_t, 2> {
static inline VectorizedN<float, 1> apply(
const VectorizedN<int64_t, 2>& src) {
auto low_double = at::vec::convert_to_fp_of_same_size<double>(src[0]);
auto low = _mm256_cvtpd_ps(low_double);
auto high_double = at::vec::convert_to_fp_of_same_size<double>(src[1]);
auto high = _mm256_cvtpd_ps(high_double);
return Vectorized<float>(
_mm256_insertf128_ps(_mm256_castps128_ps256(low), high, 1));
}
};
template <>
struct VecConvert<int64_t, 2, float, 1> {
static inline VectorizedN<int64_t, 2> apply(
const VectorizedN<float, 1>& src) {
// Scalarization is the most reliable way of converting fp to int64 on AVX2.
// Check: https://stackoverflow.com/questions/41144668
float buffer[8];
src.store(buffer);
at::vec::VectorizedN<int64_t, 2> result;
result[0] = Vectorized<int64_t>(
static_cast<int64_t>(buffer[0]),
static_cast<int64_t>(buffer[1]),
static_cast<int64_t>(buffer[2]),
static_cast<int64_t>(buffer[3]));
result[1] = Vectorized<int64_t>(
static_cast<int64_t>(buffer[4]),
static_cast<int64_t>(buffer[5]),
static_cast<int64_t>(buffer[6]),
static_cast<int64_t>(buffer[7]));
return result;
}
};
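// The int64 -> int32 narrowing below is a truncation: the in-lane shuffle
// keeps the low 32 bits of every 64-bit element, the 64-bit permute packs
// those into the lower half of each source vector, and the final blend
// stitches the two halves together into a single Vectorized<int32_t>.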
template <>
struct VecConvert<int32_t, 1, int64_t, 2> {
static inline VectorizedN<int32_t, 1> apply(
const VectorizedN<int64_t, 2>& src) {
auto low = _mm256_shuffle_epi32(src[0], _MM_SHUFFLE(2, 0, 2, 0));
auto high = _mm256_shuffle_epi32(src[1], _MM_SHUFFLE(2, 0, 2, 0));
auto low_perm = _mm256_permute4x64_epi64(low, _MM_SHUFFLE(3, 1, 2, 0));
auto high_perm = _mm256_permute4x64_epi64(high, _MM_SHUFFLE(3, 1, 2, 0));
return Vectorized<int32_t>(_mm256_blend_epi32(low_perm, high_perm, 0xF0));
}
};
template <>
struct VecConvert<int64_t, 2, int32_t, 1> {
static inline VectorizedN<int64_t, 2> apply(
const VectorizedN<int32_t, 1>& src) {
at::vec::VectorizedN<int64_t, 2> result;
result[0] = _mm256_cvtepi32_epi64(_mm256_castsi256_si128(src[0]));
result[1] = _mm256_cvtepi32_epi64(_mm256_extracti128_si256(src[0], 1));
return result;
}
};
template <>
struct VecConvert<int32_t, 1, int8_t, 1> {
static inline VectorizedN<int32_t, 1> apply(
const VectorizedN<int8_t, 1>& src) {
auto src128 = _mm256_castsi256_si128(src[0]);
return Vectorized<int32_t>(_mm256_cvtepi8_epi32(src128));
}
};
template <>
struct VecConvert<int32_t, 1, uint8_t, 1> {
static inline VectorizedN<int32_t, 1> apply(
const VectorizedN<uint8_t, 1>& src) {
auto src128 = _mm256_castsi256_si128(src[0]);
return Vectorized<int32_t>(_mm256_cvtepu8_epi32(src128));
}
};
template <>
struct VecConvert<int32_t, 1, float, 1> {
static inline VectorizedN<int32_t, 1> apply(
const VectorizedN<float, 1>& src) {
return Vectorized<int32_t>(_mm256_cvttps_epi32(src[0]));
}
};
template <>
struct VecConvert<float, 1, int32_t, 1> {
static inline VectorizedN<float, 1> apply(
const VectorizedN<int32_t, 1>& src) {
return Vectorized<float>(_mm256_cvtepi32_ps(src[0]));
}
};
template <>
struct VecConvert<int16_t, 1, uint8_t, 1> {
static inline VectorizedN<int16_t, 1> apply(
const VectorizedN<uint8_t, 1>& src) {
auto src128 = _mm256_castsi256_si128(src[0]);
return Vectorized<int16_t>(_mm256_cvtepu8_epi16(src128));
}
};
template <typename dst_t, typename src_t>
struct VecConvert<
dst_t,
1,
src_t,
1,
typename std::enable_if_t<
(is_reduced_floating_point_v<dst_t> && is_8bit_integer_v<src_t>) ||
(is_reduced_floating_point_v<src_t> && is_8bit_integer_v<dst_t>),
void>> {
static inline VectorizedN<dst_t, 1> apply(const VectorizedN<src_t, 1>& src) {
VectorizedN<float, 1> tmp_fp32 = VecConvert<float, 1, src_t, 1>::apply(src);
return VecConvert<dst_t, 1, float, 1>::apply(tmp_fp32);
}
};
template <typename dst_t>
struct VecConvert<
dst_t,
1,
float,
1,
typename std::enable_if_t<is_8bit_integer_v<dst_t>,
void>> {
static inline VectorizedN<dst_t, 1> apply(const VectorizedN<float, 1>& src) {
return convert_float_to_int8<dst_t>(src[0]);
}
};
template <typename dst_t>
struct VecConvert<
dst_t,
1,
int64_t,
2,
typename std::enable_if<
std::is_same_v<dst_t, int8_t> ||
std::is_same_v<dst_t, uint8_t>>::type> {
static inline VectorizedN<dst_t, 1> apply(
const VectorizedN<int64_t, 2>& src) {
return VecConvert<dst_t, 1, int32_t, 1>::apply(
VecConvert<int32_t, 1, int64_t, 2>::apply(src));
}
};
#endif /* defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) */
#if (defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)) || defined(CPU_CAPABILITY_NEON)
template <typename src_t>
struct VecConvert<
float,
1,
src_t,
1,
typename std::enable_if_t<is_8bit_integer_v<src_t>,
void>> {
static inline VectorizedN<float, 1> apply(const VectorizedN<src_t, 1>& src) {
return convert_int8_to_float<src_t>(src[0]);
}
};
#endif
#if defined(CPU_CAPABILITY_NEON)
template <>
struct VecConvert<float, 1, BFloat16, 1> {
static inline VectorizedN<float, 1> apply(
const VectorizedN<BFloat16, 1>& src) {
VectorizedN<float, 1> result;
uint16x8_t u16_8 = vld1q_u16(reinterpret_cast<const uint16_t*>(&src[0]));
int32x4_t shift = vdupq_n_s32(16);
auto u16_low1 = vget_low_u16(u16_8);
auto u16_high1 = vget_high_u16(u16_8);
float32x4_t f32x4_0 = vreinterpretq_f32_u32(vshlq_u32(vmovl_u16(u16_low1), shift));
float32x4_t f32x4_1 = vreinterpretq_f32_u32(vshlq_u32(vmovl_u16(u16_high1), shift));
result[0] = {f32x4_0, f32x4_1};
return result;
}
};
#endif
template <typename src_t>
struct VecConvert<
float,
1,
src_t,
1,
typename std::enable_if_t<is_reduced_floating_point_v<src_t>, void>> {
static inline VectorizedN<float, 1> apply(const VectorizedN<src_t, 1>& src) {
auto [res_vec1, res_vec2] = convert_to_float<src_t>(src[0]);
return res_vec1;
}
};
template <typename dst_t>
struct VecConvert<
dst_t,
1,
float,
1,
typename std::enable_if_t<is_reduced_floating_point_v<dst_t>, void>> {
static inline VectorizedN<dst_t, 1> apply(const VectorizedN<float, 1>& src) {
return convert_from_float<dst_t>(src[0], src[0]);
}
};
} // namespace CPU_CAPABILITY
} // namespace at::vec

View File

@ -0,0 +1,447 @@
#pragma once
// DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with AVX]
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <c10/util/irange.h>
#if defined(CPU_CAPABILITY_AVX2)
#define SLEEF_STATIC_LIBS
#include <sleef.h>
#endif
namespace at::vec {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {
#if defined(CPU_CAPABILITY_AVX2)
template <> class Vectorized<double> {
private:
__m256d values;
public:
using value_type = double;
using size_type = int;
static constexpr size_type size() {
return 4;
}
Vectorized() {}
Vectorized(__m256d v) : values(v) {}
Vectorized(double val) {
values = _mm256_set1_pd(val);
}
Vectorized(double val1, double val2, double val3, double val4) {
values = _mm256_setr_pd(val1, val2, val3, val4);
}
operator __m256d() const {
return values;
}
template <int64_t mask>
static Vectorized<double> blend(const Vectorized<double>& a, const Vectorized<double>& b) {
return _mm256_blend_pd(a.values, b.values, mask);
}
static Vectorized<double> blendv(const Vectorized<double>& a, const Vectorized<double>& b,
const Vectorized<double>& mask) {
return _mm256_blendv_pd(a.values, b.values, mask.values);
}
template<typename step_t>
static Vectorized<double> arange(double base = 0., step_t step = static_cast<step_t>(1)) {
return Vectorized<double>(base, base + step, base + 2 * step, base + 3 * step);
}
static Vectorized<double> set(const Vectorized<double>& a, const Vectorized<double>& b,
int64_t count = size()) {
switch (count) {
case 0:
return a;
case 1:
return blend<1>(a, b);
case 2:
return blend<3>(a, b);
case 3:
return blend<7>(a, b);
}
return b;
}
static Vectorized<double> loadu(const void* ptr, int64_t count = size()) {
if (count == size())
return _mm256_loadu_pd(reinterpret_cast<const double*>(ptr));
__at_align__ double tmp_values[size()];
// Ensure uninitialized memory does not change the output value. See https://github.com/pytorch/pytorch/issues/32502
// for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
// instructions while a loop would be compiled to one instruction.
for (const auto i : c10::irange(size())) {
tmp_values[i] = 0.0;
}
std::memcpy(
tmp_values,
reinterpret_cast<const double*>(ptr),
count * sizeof(double));
return _mm256_load_pd(tmp_values);
}
void store(void* ptr, int count = size()) const {
if (count == size()) {
_mm256_storeu_pd(reinterpret_cast<double*>(ptr), values);
} else if (count > 0) {
double tmp_values[size()];
_mm256_storeu_pd(reinterpret_cast<double*>(tmp_values), values);
std::memcpy(ptr, tmp_values, count * sizeof(double));
}
}
const double& operator[](int idx) const = delete;
double& operator[](int idx) = delete;
int zero_mask() const {
// returns an integer mask where all zero elements are translated to 1-bit and others are translated to 0-bit
__m256d cmp = _mm256_cmp_pd(values, _mm256_set1_pd(0.0), _CMP_EQ_OQ);
return _mm256_movemask_pd(cmp);
}
Vectorized<double> isnan() const {
return _mm256_cmp_pd(values, _mm256_set1_pd(0.0), _CMP_UNORD_Q);
}
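// has_inf_nan() relies on x - x being +/-0.0 for every finite lane and NaN for
// any lane holding +/-inf or NaN, so any payload bits left in self_sub (outside
// the positions masked off by 0x77777777) indicate an inf/NaN element.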
bool has_inf_nan() const {
__m256d self_sub = _mm256_sub_pd(values, values);
return (_mm256_movemask_epi8(_mm256_castpd_si256(self_sub)) & 0x77777777) != 0;
}
Vectorized<double> map(double (*const f)(double)) const {
__at_align__ double tmp[size()];
store(tmp);
for (const auto i : c10::irange(size())) {
tmp[i] = f(tmp[i]);
}
return loadu(tmp);
}
Vectorized<double> abs() const {
auto mask = _mm256_set1_pd(-0.f);
return _mm256_andnot_pd(mask, values);
}
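// For real inputs angle() is 0 for non-negative values, pi for negative
// values, and NaN wherever the input is NaN (nan_mask selects the lanes where
// the ordered self-comparison produced all-zero bits).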
Vectorized<double> angle() const {
const auto zero_vec = _mm256_set1_pd(0.f);
const auto nan_vec = _mm256_set1_pd(NAN);
const auto not_nan_mask = _mm256_cmp_pd(values, values, _CMP_EQ_OQ);
const auto nan_mask = _mm256_cmp_pd(not_nan_mask, zero_vec, _CMP_EQ_OQ);
const auto pi = _mm256_set1_pd(c10::pi<double>);
const auto neg_mask = _mm256_cmp_pd(values, zero_vec, _CMP_LT_OQ);
auto angle = _mm256_blendv_pd(zero_vec, pi, neg_mask);
angle = _mm256_blendv_pd(angle, nan_vec, nan_mask);
return angle;
}
Vectorized<double> real() const {
return *this;
}
Vectorized<double> imag() const {
return _mm256_set1_pd(0);
}
Vectorized<double> conj() const {
return *this;
}
Vectorized<double> acos() const {
return Vectorized<double>(Sleef_acosd4_u10(values));
}
Vectorized<double> acosh() const {
return Vectorized<double>(Sleef_acoshd4_u10(values));
}
Vectorized<double> asin() const {
return Vectorized<double>(Sleef_asind4_u10(values));
}
Vectorized<double> atan() const {
return Vectorized<double>(Sleef_atand4_u10(values));
}
Vectorized<double> atanh() const {
return Vectorized<double>(Sleef_atanhd4_u10(values));
}
Vectorized<double> atan2(const Vectorized<double> &b) const {
return Vectorized<double>(Sleef_atan2d4_u10(values, b));
}
Vectorized<double> copysign(const Vectorized<double> &sign) const {
return Vectorized<double>(Sleef_copysignd4(values, sign));
}
Vectorized<double> erf() const {
return Vectorized<double>(Sleef_erfd4_u10(values));
}
Vectorized<double> erfc() const {
return Vectorized<double>(Sleef_erfcd4_u15(values));
}
Vectorized<double> erfinv() const {
return map(calc_erfinv);
}
Vectorized<double> exp() const {
return Vectorized<double>(Sleef_expd4_u10(values));
}
Vectorized<double> exp2() const {
return Vectorized<double>(Sleef_exp2d4_u10(values));
}
Vectorized<double> expm1() const {
return Vectorized<double>(Sleef_expm1d4_u10(values));
}
Vectorized<double> exp_u20() const {
return exp();
}
Vectorized<double> fmod(const Vectorized<double>& q) const {
return Vectorized<double>(Sleef_fmodd4(values, q));
}
Vectorized<double> hypot(const Vectorized<double> &b) const {
return Vectorized<double>(Sleef_hypotd4_u05(values, b));
}
Vectorized<double> i0() const {
return map(calc_i0);
}
Vectorized<double> i0e() const {
return map(calc_i0e);
}
Vectorized<double> digamma() const {
return map(calc_digamma);
}
Vectorized<double> igamma(const Vectorized<double> &x) const {
__at_align__ double tmp[size()];
__at_align__ double tmp_x[size()];
store(tmp);
x.store(tmp_x);
for (const auto i : c10::irange(size())) {
tmp[i] = calc_igamma(tmp[i], tmp_x[i]);
}
return loadu(tmp);
}
Vectorized<double> igammac(const Vectorized<double> &x) const {
__at_align__ double tmp[size()];
__at_align__ double tmp_x[size()];
store(tmp);
x.store(tmp_x);
for (const auto i : c10::irange(size())) {
tmp[i] = calc_igammac(tmp[i], tmp_x[i]);
}
return loadu(tmp);
}
Vectorized<double> log() const {
return Vectorized<double>(Sleef_logd4_u10(values));
}
Vectorized<double> log2() const {
return Vectorized<double>(Sleef_log2d4_u10(values));
}
Vectorized<double> log10() const {
return Vectorized<double>(Sleef_log10d4_u10(values));
}
Vectorized<double> log1p() const {
return Vectorized<double>(Sleef_log1pd4_u10(values));
}
Vectorized<double> sin() const {
return Vectorized<double>(Sleef_sind4_u10(values));
}
Vectorized<double> sinh() const {
return Vectorized<double>(Sleef_sinhd4_u10(values));
}
Vectorized<double> cos() const {
return Vectorized<double>(Sleef_cosd4_u10(values));
}
Vectorized<double> cosh() const {
return Vectorized<double>(Sleef_coshd4_u10(values));
}
Vectorized<double> ceil() const {
return _mm256_ceil_pd(values);
}
Vectorized<double> floor() const {
return _mm256_floor_pd(values);
}
Vectorized<double> frac() const;
Vectorized<double> neg() const {
return _mm256_xor_pd(_mm256_set1_pd(-0.), values);
}
Vectorized<double> nextafter(const Vectorized<double> &b) const {
return Vectorized<double>(Sleef_nextafterd4(values, b));
}
Vectorized<double> round() const {
return _mm256_round_pd(values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
}
Vectorized<double> tan() const {
return Vectorized<double>(Sleef_tand4_u10(values));
}
Vectorized<double> tanh() const {
return Vectorized<double>(Sleef_tanhd4_u10(values));
}
Vectorized<double> trunc() const {
return _mm256_round_pd(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
}
Vectorized<double> lgamma() const {
return Vectorized<double>(Sleef_lgammad4_u10(values));
}
Vectorized<double> sqrt() const {
return _mm256_sqrt_pd(values);
}
Vectorized<double> reciprocal() const {
return _mm256_div_pd(_mm256_set1_pd(1), values);
}
Vectorized<double> rsqrt() const {
return _mm256_div_pd(_mm256_set1_pd(1), _mm256_sqrt_pd(values));
}
Vectorized<double> pow(const Vectorized<double> &b) const {
return Vectorized<double>(Sleef_powd4_u10(values, b));
}
// Comparison using the _CMP_**_OQ predicate.
// `O`: get false if an operand is NaN
// `Q`: do not raise if an operand is NaN
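  // Note: operator!= below uses _CMP_NEQ_UQ (unordered) so that NaN != x is true.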
Vectorized<double> operator==(const Vectorized<double>& other) const {
return _mm256_cmp_pd(values, other.values, _CMP_EQ_OQ);
}
Vectorized<double> operator!=(const Vectorized<double>& other) const {
return _mm256_cmp_pd(values, other.values, _CMP_NEQ_UQ);
}
Vectorized<double> operator<(const Vectorized<double>& other) const {
return _mm256_cmp_pd(values, other.values, _CMP_LT_OQ);
}
Vectorized<double> operator<=(const Vectorized<double>& other) const {
return _mm256_cmp_pd(values, other.values, _CMP_LE_OQ);
}
Vectorized<double> operator>(const Vectorized<double>& other) const {
return _mm256_cmp_pd(values, other.values, _CMP_GT_OQ);
}
Vectorized<double> operator>=(const Vectorized<double>& other) const {
return _mm256_cmp_pd(values, other.values, _CMP_GE_OQ);
}
Vectorized<double> eq(const Vectorized<double>& other) const;
Vectorized<double> ne(const Vectorized<double>& other) const;
Vectorized<double> lt(const Vectorized<double>& other) const;
Vectorized<double> le(const Vectorized<double>& other) const;
Vectorized<double> gt(const Vectorized<double>& other) const;
Vectorized<double> ge(const Vectorized<double>& other) const;
};
template <>
Vectorized<double> inline operator+(const Vectorized<double>& a, const Vectorized<double>& b) {
return _mm256_add_pd(a, b);
}
template <>
Vectorized<double> inline operator-(const Vectorized<double>& a, const Vectorized<double>& b) {
return _mm256_sub_pd(a, b);
}
template <>
Vectorized<double> inline operator*(const Vectorized<double>& a, const Vectorized<double>& b) {
return _mm256_mul_pd(a, b);
}
template <>
Vectorized<double> inline operator/(const Vectorized<double>& a, const Vectorized<double>& b) {
return _mm256_div_pd(a, b);
}
// frac. Implement this here so we can use subtraction.
inline Vectorized<double> Vectorized<double>::frac() const {
return *this - this->trunc();
}
// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<double> inline maximum(const Vectorized<double>& a, const Vectorized<double>& b) {
Vectorized<double> max = _mm256_max_pd(a, b);
Vectorized<double> isnan = _mm256_cmp_pd(a, b, _CMP_UNORD_Q);
// Exploit the fact that all-ones is a NaN.
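  // _CMP_UNORD_Q yields all-ones in lanes where a or b is NaN, so OR-ing it
  // into the max turns exactly those lanes into a (quiet) NaN bit pattern.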
return _mm256_or_pd(max, isnan);
}
// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<double> inline minimum(const Vectorized<double>& a, const Vectorized<double>& b) {
Vectorized<double> min = _mm256_min_pd(a, b);
Vectorized<double> isnan = _mm256_cmp_pd(a, b, _CMP_UNORD_Q);
// Exploit the fact that all-ones is a NaN.
return _mm256_or_pd(min, isnan);
}
template <>
Vectorized<double> inline clamp(const Vectorized<double>& a, const Vectorized<double>& min, const Vectorized<double>& max) {
return _mm256_min_pd(max, _mm256_max_pd(min, a));
}
template <>
Vectorized<double> inline clamp_min(const Vectorized<double>& a, const Vectorized<double>& min) {
return _mm256_max_pd(min, a);
}
template <>
Vectorized<double> inline clamp_max(const Vectorized<double>& a, const Vectorized<double>& max) {
return _mm256_min_pd(max, a);
}
template <>
Vectorized<double> inline operator&(const Vectorized<double>& a, const Vectorized<double>& b) {
return _mm256_and_pd(a, b);
}
template <>
Vectorized<double> inline operator|(const Vectorized<double>& a, const Vectorized<double>& b) {
return _mm256_or_pd(a, b);
}
template <>
Vectorized<double> inline operator^(const Vectorized<double>& a, const Vectorized<double>& b) {
return _mm256_xor_pd(a, b);
}
inline Vectorized<double> Vectorized<double>::eq(const Vectorized<double>& other) const {
return (*this == other) & Vectorized<double>(1.0);
}
inline Vectorized<double> Vectorized<double>::ne(const Vectorized<double>& other) const {
return (*this != other) & Vectorized<double>(1.0);
}
inline Vectorized<double> Vectorized<double>::gt(const Vectorized<double>& other) const {
return (*this > other) & Vectorized<double>(1.0);
}
inline Vectorized<double> Vectorized<double>::ge(const Vectorized<double>& other) const {
return (*this >= other) & Vectorized<double>(1.0);
}
inline Vectorized<double> Vectorized<double>::lt(const Vectorized<double>& other) const {
return (*this < other) & Vectorized<double>(1.0);
}
inline Vectorized<double> Vectorized<double>::le(const Vectorized<double>& other) const {
return (*this <= other) & Vectorized<double>(1.0);
}
template <>
inline void convert(const double* src, double* dst, int64_t n) {
int64_t i;
#ifndef __msvc_cl__
#pragma unroll
#endif
for (i = 0; i <= (n - Vectorized<double>::size()); i += Vectorized<double>::size()) {
_mm256_storeu_pd(dst + i, _mm256_loadu_pd(src + i));
}
#ifndef __msvc_cl__
#pragma unroll
#endif
for (; i < n; i++) {
dst[i] = src[i];
}
}
#ifdef CPU_CAPABILITY_AVX2
template <>
Vectorized<double> inline fmadd(const Vectorized<double>& a, const Vectorized<double>& b, const Vectorized<double>& c) {
return _mm256_fmadd_pd(a, b, c);
}
template <>
Vectorized<double> inline fmsub(const Vectorized<double>& a, const Vectorized<double>& b, const Vectorized<double>& c) {
return _mm256_fmsub_pd(a, b, c);
}
#endif
#endif
}} // namespace at::vec::CPU_CAPABILITY

View File

@ -0,0 +1,656 @@
#pragma once
// DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with AVX]
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <c10/util/irange.h>
#if defined(CPU_CAPABILITY_AVX2)
#define SLEEF_STATIC_LIBS
#include <sleef.h>
#endif
namespace at::vec {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {
#if defined(CPU_CAPABILITY_AVX2)
template <> class Vectorized<float> {
private:
__m256 values;
public:
using value_type = float;
using size_type = int;
static constexpr size_type size() {
return 8;
}
Vectorized() {}
Vectorized(__m256 v) : values(v) {}
Vectorized(float val) {
values = _mm256_set1_ps(val);
}
Vectorized(float val1, float val2, float val3, float val4,
float val5, float val6, float val7, float val8) {
values = _mm256_setr_ps(val1, val2, val3, val4, val5, val6, val7, val8);
}
operator __m256() const {
return values;
}
template <int64_t mask>
static Vectorized<float> blend(const Vectorized<float>& a, const Vectorized<float>& b) {
return _mm256_blend_ps(a.values, b.values, mask);
}
static Vectorized<float> blendv(const Vectorized<float>& a, const Vectorized<float>& b,
const Vectorized<float>& mask) {
return _mm256_blendv_ps(a.values, b.values, mask.values);
}
template<typename step_t>
static Vectorized<float> arange(float base = 0.f, step_t step = static_cast<step_t>(1)) {
return Vectorized<float>(
base, base + step, base + 2 * step, base + 3 * step,
base + 4 * step, base + 5 * step, base + 6 * step, base + 7 * step);
}
static Vectorized<float> set(const Vectorized<float>& a, const Vectorized<float>& b,
int64_t count = size()) {
switch (count) {
case 0:
return a;
case 1:
return blend<1>(a, b);
case 2:
return blend<3>(a, b);
case 3:
return blend<7>(a, b);
case 4:
return blend<15>(a, b);
case 5:
return blend<31>(a, b);
case 6:
return blend<63>(a, b);
case 7:
return blend<127>(a, b);
}
return b;
}
static Vectorized<float> loadu(const void* ptr, int64_t count = size()) {
if (count == size())
return _mm256_loadu_ps(reinterpret_cast<const float*>(ptr));
__at_align__ float tmp_values[size()];
    // Ensure uninitialized memory does not change the output value. See https://github.com/pytorch/pytorch/issues/32502
// for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
// instructions while a loop would be compiled to one instruction.
for (const auto i : c10::irange(size())) {
tmp_values[i] = 0.0;
}
std::memcpy(
tmp_values, reinterpret_cast<const float*>(ptr), count * sizeof(float));
return _mm256_loadu_ps(tmp_values);
}
void store(void* ptr, int64_t count = size()) const {
if (count == size()) {
_mm256_storeu_ps(reinterpret_cast<float*>(ptr), values);
} else if (count > 0) {
float tmp_values[size()];
_mm256_storeu_ps(reinterpret_cast<float*>(tmp_values), values);
std::memcpy(ptr, tmp_values, count * sizeof(float));
}
}
const float& operator[](int idx) const = delete;
float& operator[](int idx) = delete;
int zero_mask() const {
// returns an integer mask where all zero elements are translated to 1-bit and others are translated to 0-bit
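    // e.g. {0.f, 3.f, 0.f, 0.f, 5.f, 0.f, 0.f, 0.f} -> 0b11101101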
__m256 cmp = _mm256_cmp_ps(values, _mm256_set1_ps(0.0f), _CMP_EQ_OQ);
return _mm256_movemask_ps(cmp);
}
Vectorized<float> isnan() const {
return _mm256_cmp_ps(values, _mm256_set1_ps(0.0f), _CMP_UNORD_Q);
}
bool has_inf_nan() const {
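    // values - values is +/-0.0 for finite lanes and NaN for Inf/NaN lanes.
    // movemask_epi8 collects the MSB of every byte; masking with 0x77777777
    // drops the sign byte of each lane (so -0.0 stays clean), while a NaN's
    // all-ones exponent sets bit 23, i.e. the MSB of the next byte down.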
__m256 self_sub = _mm256_sub_ps(values, values);
return (_mm256_movemask_epi8(_mm256_castps_si256(self_sub)) & 0x77777777) != 0;
}
Vectorized<float> map(float (*const f)(float)) const {
__at_align__ float tmp[size()];
store(tmp);
for (const auto i : c10::irange(size())) {
tmp[i] = f(tmp[i]);
}
return loadu(tmp);
}
Vectorized<float> abs() const {
auto mask = _mm256_set1_ps(-0.f);
return _mm256_andnot_ps(mask, values);
}
Vectorized<float> angle() const {
const auto zero_vec = _mm256_set1_ps(0.f);
const auto nan_vec = _mm256_set1_ps(NAN);
const auto not_nan_mask = _mm256_cmp_ps(values, values, _CMP_EQ_OQ);
const auto nan_mask = _mm256_cmp_ps(not_nan_mask, zero_vec, _CMP_EQ_OQ);
const auto pi = _mm256_set1_ps(c10::pi<float>);
const auto neg_mask = _mm256_cmp_ps(values, zero_vec, _CMP_LT_OQ);
auto angle = _mm256_blendv_ps(zero_vec, pi, neg_mask);
angle = _mm256_blendv_ps(angle, nan_vec, nan_mask);
return angle;
}
Vectorized<float> real() const {
return *this;
}
Vectorized<float> imag() const {
return _mm256_set1_ps(0);
}
Vectorized<float> conj() const {
return *this;
}
Vectorized<float> acos() const {
return Vectorized<float>(Sleef_acosf8_u10(values));
}
Vectorized<float> acosh() const {
return Vectorized<float>(Sleef_acoshf8_u10(values));
}
Vectorized<float> asin() const {
return Vectorized<float>(Sleef_asinf8_u10(values));
}
Vectorized<float> atan() const {
return Vectorized<float>(Sleef_atanf8_u10(values));
}
Vectorized<float> atanh() const {
return Vectorized<float>(Sleef_atanhf8_u10(values));
}
Vectorized<float> atan2(const Vectorized<float> &b) const {
return Vectorized<float>(Sleef_atan2f8_u10(values, b));
}
Vectorized<float> copysign(const Vectorized<float> &sign) const {
return Vectorized<float>(Sleef_copysignf8(values, sign));
}
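  // Vectorized polynomial approximation of erf. The constants p, p1..p5 below
  // are those of Abramowitz & Stegun formula 7.1.26:
  //   erf(x) ~ sign(x) * (1 - (a1*t + a2*t^2 + ... + a5*t^5) * exp(-x*x)),
  //   t = 1 / (1 + p*|x|), with maximum absolute error around 1.5e-7.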
Vectorized<float> erf() const {
// constants
const auto neg_zero_vec = _mm256_set1_ps(-0.f);
const auto one_vec = _mm256_set1_ps(1.0f);
const auto p = _mm256_set1_ps(0.3275911f);
const auto p1 = _mm256_set1_ps(0.254829592f);
const auto p2 = _mm256_set1_ps(-0.284496736f);
const auto p3 = _mm256_set1_ps(1.421413741f);
const auto p4 = _mm256_set1_ps(-1.453152027f);
const auto p5 = _mm256_set1_ps(1.061405429f);
// sign(x)
auto sign_mask = _mm256_and_ps(neg_zero_vec, values);
auto abs_vec = _mm256_xor_ps(sign_mask, values);
// t = 1 / (p * abs(x) + 1)
auto tmp0 = _mm256_fmadd_ps(p, abs_vec, one_vec);
auto t = _mm256_div_ps(one_vec, tmp0);
// r = p5 * t ^ 4 + p4 * t ^ 3 + p3 * t ^ 2 + p2 * t + p1
auto tmp1 = _mm256_fmadd_ps(p5, t, p4);
auto tmp2 = _mm256_fmadd_ps(tmp1, t, p3);
auto tmp3 = _mm256_fmadd_ps(tmp2, t, p2);
auto r = _mm256_fmadd_ps(tmp3, t, p1);
// - exp(- x * x)
auto pow_2 = _mm256_mul_ps(values, values);
auto neg_pow_2 = _mm256_xor_ps(neg_zero_vec, pow_2);
// auto tmp4 = exp(neg_pow_2);
auto tmp4 = Vectorized<float>(Sleef_expf8_u10(neg_pow_2));
auto tmp5 = _mm256_xor_ps(neg_zero_vec, tmp4);
// erf(x) = sign(x) * (1 - r * t * exp(- x * x))
auto tmp6 = _mm256_mul_ps(tmp5, t);
auto tmp7 = _mm256_fmadd_ps(tmp6, r, one_vec);
return _mm256_xor_ps(sign_mask, tmp7);
}
Vectorized<float> erfc() const {
return Vectorized<float>(Sleef_erfcf8_u15(values));
}
Vectorized<float> erfinv() const {
return map(calc_erfinv);
}
Vectorized<float> exp() const {
return Vectorized<float>(Sleef_expf8_u10(values));
}
Vectorized<float> exp2() const {
return Vectorized<float>(Sleef_exp2f8_u10(values));
}
Vectorized<float> expm1() const {
return Vectorized<float>(Sleef_expm1f8_u10(values));
}
Vectorized<float> exp_u20() const {
// A faster version of exp with ULP=20
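    // Strategy: range-reduce x = n*ln(2) + r with |r| <= ln(2)/2, evaluate a
    // degree-5 polynomial for exp(r), then rebuild 2^n by stuffing n into the
    // float exponent field. 2^(n-1) is formed and the result is multiplied by
    // 2 at the end, presumably so the scale factor stays representable when
    // n reaches 128 near the upper input clamp.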
static __m256 vec_factorial_1 =
_mm256_set1_ps(0.999999701f); // 1/factorial(1)
static __m256 vec_factorial_2 =
_mm256_set1_ps(0.499991506f); // 1/factorial(2)
static __m256 vec_factorial_3 =
_mm256_set1_ps(0.166676521f); // 1/factorial(3)
static __m256 vec_factorial_4 =
_mm256_set1_ps(0.0418978221f); // 1/factorial(4)
static __m256 vec_factorial_5 =
_mm256_set1_ps(0.00828929059f); // 1/factorial(5)
static __m256 vec_exp_log2ef =
_mm256_castsi256_ps(_mm256_set1_epi32(0x3fb8aa3b)); // log2(e)
static __m256 vec_half = _mm256_set1_ps(0.5f);
static __m256 vec_one = _mm256_set1_ps(1.f);
static __m256 vec_zero = _mm256_set1_ps(0.f);
static __m256 vec_two = _mm256_set1_ps(2.f);
static __m256 vec_ln2f = _mm256_castsi256_ps(_mm256_set1_epi32(0x3f317218)); // ln(2)
static __m256 vec_ln_flt_min = _mm256_castsi256_ps(_mm256_set1_epi32(0xc2aeac50));
static __m256 vec_ln_flt_max = _mm256_castsi256_ps(_mm256_set1_epi32(0x42b17218));
static __m256i vec_127 = _mm256_set1_epi32(0x0000007f);
static int n_mantissa_bits = 23;
// exp(x) =
// = exp(n * ln(2) + r) // divide x by ln(2) and get quot and rem
// = 2^n * exp(r) // simplify the exp(n*ln(2)) expression
auto less_ln_flt_min_mask =
_mm256_cmp_ps(values, vec_ln_flt_min, 1 /*_CMP_LT_OS*/);
auto vec_src = _mm256_min_ps(values, vec_ln_flt_max);
vec_src = _mm256_max_ps(vec_src, vec_ln_flt_min);
// fx = floorf(x * log2ef + 0.5)
auto vec_fx = _mm256_fmadd_ps(vec_src, vec_exp_log2ef, vec_half);
vec_fx = _mm256_floor_ps(vec_fx);
// x = x - fx * ln2
auto vec_exp_poly = _mm256_fnmadd_ps(vec_fx, vec_ln2f, vec_src);
// compute polynomial
auto vec_res =
_mm256_fmadd_ps(vec_exp_poly, vec_factorial_5, vec_factorial_4);
vec_res = _mm256_fmadd_ps(vec_exp_poly, vec_res, vec_factorial_3);
vec_res = _mm256_fmadd_ps(vec_exp_poly, vec_res, vec_factorial_2);
vec_res = _mm256_fmadd_ps(vec_exp_poly, vec_res, vec_factorial_1);
vec_res = _mm256_fmadd_ps(vec_exp_poly, vec_res, vec_one);
// compute 2^(n-1)
auto vec_exp_number = _mm256_sub_ps(vec_fx, vec_one);
auto vec_exp_number_i = _mm256_cvtps_epi32(vec_exp_number);
auto vec_two_pow_n_i = _mm256_add_epi32(vec_exp_number_i, vec_127);
vec_two_pow_n_i = _mm256_slli_epi32(vec_two_pow_n_i, n_mantissa_bits);
auto vec_two_pow_n = _mm256_castsi256_ps(vec_two_pow_n_i);
vec_two_pow_n =
_mm256_blendv_ps(vec_two_pow_n, vec_zero, less_ln_flt_min_mask);
// y = y * 2^n
vec_res = _mm256_mul_ps(vec_res, vec_two_pow_n);
vec_res = _mm256_mul_ps(vec_res, vec_two);
return vec_res;
}
Vectorized<float> fmod(const Vectorized<float>& q) const {
return Vectorized<float>(Sleef_fmodf8(values, q));
}
Vectorized<float> log() const {
return Vectorized<float>(Sleef_logf8_u10(values));
}
Vectorized<float> log2() const {
return Vectorized<float>(Sleef_log2f8_u10(values));
}
Vectorized<float> log10() const {
return Vectorized<float>(Sleef_log10f8_u10(values));
}
Vectorized<float> log1p() const {
return Vectorized<float>(Sleef_log1pf8_u10(values));
}
Vectorized<float> frac() const;
Vectorized<float> sin() const {
return Vectorized<float>(Sleef_sinf8_u35(values));
}
Vectorized<float> sinh() const {
return Vectorized<float>(Sleef_sinhf8_u10(values));
}
Vectorized<float> cos() const {
return Vectorized<float>(Sleef_cosf8_u35(values));
}
Vectorized<float> cosh() const {
return Vectorized<float>(Sleef_coshf8_u10(values));
}
Vectorized<float> ceil() const {
return _mm256_ceil_ps(values);
}
Vectorized<float> floor() const {
return _mm256_floor_ps(values);
}
Vectorized<float> hypot(const Vectorized<float> &b) const {
return Vectorized<float>(Sleef_hypotf8_u05(values, b));
}
Vectorized<float> i0() const {
return map(calc_i0);
}
Vectorized<float> i0e() const {
return map(calc_i0e);
}
Vectorized<float> digamma() const {
return map(calc_digamma);
}
Vectorized<float> igamma(const Vectorized<float> &x) const {
__at_align__ float tmp[size()];
__at_align__ float tmp_x[size()];
store(tmp);
x.store(tmp_x);
for (const auto i : c10::irange(size())) {
tmp[i] = calc_igamma(tmp[i], tmp_x[i]);
}
return loadu(tmp);
}
Vectorized<float> igammac(const Vectorized<float> &x) const {
__at_align__ float tmp[size()];
__at_align__ float tmp_x[size()];
store(tmp);
x.store(tmp_x);
for (const auto i : c10::irange(size())) {
tmp[i] = calc_igammac(tmp[i], tmp_x[i]);
}
return loadu(tmp);
}
Vectorized<float> neg() const {
return _mm256_xor_ps(_mm256_set1_ps(-0.f), values);
}
Vectorized<float> nextafter(const Vectorized<float> &b) const {
return Vectorized<float>(Sleef_nextafterf8(values, b));
}
Vectorized<float> round() const {
return _mm256_round_ps(values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
}
Vectorized<float> tan() const {
return Vectorized<float>(Sleef_tanf8_u10(values));
}
Vectorized<float> tanh() const {
return Vectorized<float>(Sleef_tanhf8_u10(values));
}
Vectorized<float> trunc() const {
return _mm256_round_ps(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
}
Vectorized<float> lgamma() const {
return Vectorized<float>(Sleef_lgammaf8_u10(values));
}
Vectorized<float> sqrt() const {
return _mm256_sqrt_ps(values);
}
Vectorized<float> reciprocal() const {
return _mm256_div_ps(_mm256_set1_ps(1), values);
}
Vectorized<float> rsqrt() const {
return _mm256_div_ps(_mm256_set1_ps(1), _mm256_sqrt_ps(values));
}
Vectorized<float> pow(const Vectorized<float> &b) const {
return Vectorized<float>(Sleef_powf8_u10(values, b));
}
// Comparison using the _CMP_**_OQ predicate.
// `O`: get false if an operand is NaN
// `Q`: do not raise if an operand is NaN
Vectorized<float> operator==(const Vectorized<float>& other) const {
return _mm256_cmp_ps(values, other.values, _CMP_EQ_OQ);
}
Vectorized<float> operator!=(const Vectorized<float>& other) const {
return _mm256_cmp_ps(values, other.values, _CMP_NEQ_UQ);
}
Vectorized<float> operator<(const Vectorized<float>& other) const {
return _mm256_cmp_ps(values, other.values, _CMP_LT_OQ);
}
Vectorized<float> operator<=(const Vectorized<float>& other) const {
return _mm256_cmp_ps(values, other.values, _CMP_LE_OQ);
}
Vectorized<float> operator>(const Vectorized<float>& other) const {
return _mm256_cmp_ps(values, other.values, _CMP_GT_OQ);
}
Vectorized<float> operator>=(const Vectorized<float>& other) const {
return _mm256_cmp_ps(values, other.values, _CMP_GE_OQ);
}
Vectorized<float> eq(const Vectorized<float>& other) const;
Vectorized<float> ne(const Vectorized<float>& other) const;
Vectorized<float> gt(const Vectorized<float>& other) const;
Vectorized<float> ge(const Vectorized<float>& other) const;
Vectorized<float> lt(const Vectorized<float>& other) const;
Vectorized<float> le(const Vectorized<float>& other) const;
};
template <>
Vectorized<float> inline operator+(const Vectorized<float>& a, const Vectorized<float>& b) {
return _mm256_add_ps(a, b);
}
template <>
Vectorized<float> inline operator-(const Vectorized<float>& a, const Vectorized<float>& b) {
return _mm256_sub_ps(a, b);
}
template <>
Vectorized<float> inline operator*(const Vectorized<float>& a, const Vectorized<float>& b) {
return _mm256_mul_ps(a, b);
}
template <>
Vectorized<float> inline operator/(const Vectorized<float>& a, const Vectorized<float>& b) {
return _mm256_div_ps(a, b);
}
// frac. Implement this here so we can use subtraction
inline Vectorized<float> Vectorized<float>::frac() const {
return *this - this->trunc();
}
// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<float> inline maximum(const Vectorized<float>& a, const Vectorized<float>& b) {
Vectorized<float> max = _mm256_max_ps(a, b);
Vectorized<float> isnan = _mm256_cmp_ps(a, b, _CMP_UNORD_Q);
// Exploit the fact that all-ones is a NaN.
return _mm256_or_ps(max, isnan);
}
// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<float> inline minimum(const Vectorized<float>& a, const Vectorized<float>& b) {
Vectorized<float> min = _mm256_min_ps(a, b);
Vectorized<float> isnan = _mm256_cmp_ps(a, b, _CMP_UNORD_Q);
// Exploit the fact that all-ones is a NaN.
return _mm256_or_ps(min, isnan);
}
template <>
Vectorized<float> inline clamp(const Vectorized<float>& a, const Vectorized<float>& min, const Vectorized<float>& max) {
return _mm256_min_ps(max, _mm256_max_ps(min, a));
}
template <>
Vectorized<float> inline clamp_max(const Vectorized<float>& a, const Vectorized<float>& max) {
return _mm256_min_ps(max, a);
}
template <>
Vectorized<float> inline clamp_min(const Vectorized<float>& a, const Vectorized<float>& min) {
return _mm256_max_ps(min, a);
}
template <>
Vectorized<float> inline operator&(const Vectorized<float>& a, const Vectorized<float>& b) {
return _mm256_and_ps(a, b);
}
template <>
Vectorized<float> inline operator|(const Vectorized<float>& a, const Vectorized<float>& b) {
return _mm256_or_ps(a, b);
}
template <>
Vectorized<float> inline operator^(const Vectorized<float>& a, const Vectorized<float>& b) {
return _mm256_xor_ps(a, b);
}
inline Vectorized<float> Vectorized<float>::eq(const Vectorized<float>& other) const {
return (*this == other) & Vectorized<float>(1.0f);
}
inline Vectorized<float> Vectorized<float>::ne(const Vectorized<float>& other) const {
return (*this != other) & Vectorized<float>(1.0f);
}
inline Vectorized<float> Vectorized<float>::gt(const Vectorized<float>& other) const {
return (*this > other) & Vectorized<float>(1.0f);
}
inline Vectorized<float> Vectorized<float>::ge(const Vectorized<float>& other) const {
return (*this >= other) & Vectorized<float>(1.0f);
}
inline Vectorized<float> Vectorized<float>::lt(const Vectorized<float>& other) const {
return (*this < other) & Vectorized<float>(1.0f);
}
inline Vectorized<float> Vectorized<float>::le(const Vectorized<float>& other) const {
return (*this <= other) & Vectorized<float>(1.0f);
}
template <>
inline void convert(const float* src, float* dst, int64_t n) {
int64_t i;
#ifndef __msvc_cl__
#pragma unroll
#endif
for (i = 0; i <= (n - Vectorized<float>::size()); i += Vectorized<float>::size()) {
_mm256_storeu_ps(dst + i, _mm256_loadu_ps(src + i));
}
#ifndef __msvc_cl__
#pragma unroll
#endif
for (; i < n; i++) {
dst[i] = src[i];
}
}
template <>
Vectorized<float> inline fmadd(const Vectorized<float>& a, const Vectorized<float>& b, const Vectorized<float>& c) {
return _mm256_fmadd_ps(a, b, c);
}
template <>
Vectorized<float> inline fmsub(const Vectorized<float>& a, const Vectorized<float>& b, const Vectorized<float>& c) {
return _mm256_fmsub_ps(a, b, c);
}
// Used by Inductor CPP codegen
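// transpose_mxn<float, 8, 8> writes dst[j * ld_dst + i] = src[i * ld_src + j]
// for i, j in [0, 8); ld_src / ld_dst are row strides in elements. The tile is
// transposed with three rounds of interleaving (32-bit, 64-bit, 128-bit)
// instead of scalar gathers.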
template<>
inline void transpose_mxn<float, 8, 8>(
const float* src,
int64_t ld_src,
float* dst,
int64_t ld_dst) {
// load from src to registers
// a: a0 a1 a2 a3 a4 a5 a6 a7
// b: b0 b1 b2 b3 b4 b5 b6 b7
// c: c0 c1 c2 c3 c4 c5 c6 c7
// d: d0 d1 d2 d3 d4 d5 d6 d7
// e: e0 e1 e2 e3 e4 e5 e6 e7
// f: f0 f1 f2 f3 f4 f5 f6 f7
// g: g0 g1 g2 g3 g4 g5 g6 g7
// h: h0 h1 h2 h3 h4 h5 h6 h7
__m256 a = _mm256_loadu_ps(&src[0 * ld_src]);
__m256 b = _mm256_loadu_ps(&src[1 * ld_src]);
__m256 c = _mm256_loadu_ps(&src[2 * ld_src]);
__m256 d = _mm256_loadu_ps(&src[3 * ld_src]);
__m256 e = _mm256_loadu_ps(&src[4 * ld_src]);
__m256 f = _mm256_loadu_ps(&src[5 * ld_src]);
__m256 g = _mm256_loadu_ps(&src[6 * ld_src]);
__m256 h = _mm256_loadu_ps(&src[7 * ld_src]);
__m256 ta, tb, tc, td, te, tf, tg, th;
// unpacking and interleaving 32-bit elements
// a0 b0 a1 b1 a4 b4 a5 b5
// a2 b2 a3 b3 a6 b6 a7 b7
// c0 d0 c1 d1 ...
// c2 d2 c3 d3 ...
// e0 f0 e1 f1 ...
// e2 f2 e3 f3 ...
// g0 h0 g1 h1 ...
// g2 h2 g3 h3 ...
ta = _mm256_unpacklo_ps(a, b);
tb = _mm256_unpackhi_ps(a, b);
tc = _mm256_unpacklo_ps(c, d);
td = _mm256_unpackhi_ps(c, d);
te = _mm256_unpacklo_ps(e, f);
tf = _mm256_unpackhi_ps(e, f);
tg = _mm256_unpacklo_ps(g, h);
th = _mm256_unpackhi_ps(g, h);
// unpacking and interleaving 64-bit elements
// a0 b0 c0 d0 a4 b4 c4 d4
// a1 b1 c1 d1 ...
// a2 b2 c2 d2 ...
// a3 b3 c3 d3 ...
// e0 f0 g0 h0 e4 f4 g4 h4
// e1 f1 g1 h1 ...
// e2 f2 g2 h2 ...
// e3 f3 g3 h3 ...
a = _mm256_castpd_ps(
_mm256_unpacklo_pd(_mm256_castps_pd(ta), _mm256_castps_pd(tc)));
b = _mm256_castpd_ps(
_mm256_unpackhi_pd(_mm256_castps_pd(ta), _mm256_castps_pd(tc)));
c = _mm256_castpd_ps(
_mm256_unpacklo_pd(_mm256_castps_pd(tb), _mm256_castps_pd(td)));
d = _mm256_castpd_ps(
_mm256_unpackhi_pd(_mm256_castps_pd(tb), _mm256_castps_pd(td)));
e = _mm256_castpd_ps(
_mm256_unpacklo_pd(_mm256_castps_pd(te), _mm256_castps_pd(tg)));
f = _mm256_castpd_ps(
_mm256_unpackhi_pd(_mm256_castps_pd(te), _mm256_castps_pd(tg)));
g = _mm256_castpd_ps(
_mm256_unpacklo_pd(_mm256_castps_pd(tf), _mm256_castps_pd(th)));
h = _mm256_castpd_ps(
_mm256_unpackhi_pd(_mm256_castps_pd(tf), _mm256_castps_pd(th)));
// shuffle 128-bits (composed of 4 32-bit elements)
// a0 b0 c0 d0 e0 f0 g0 h0
// a1 b1 c1 d1 ...
// a2 b2 c2 d2 ...
// a3 b3 c3 d3 ...
// a4 b4 c4 d4 ...
// a5 b5 c5 d5 ...
// a6 b6 c6 d6 ...
// a7 b7 c7 d7 ...
ta = _mm256_permute2f128_ps(a, e, 0x20);
tb = _mm256_permute2f128_ps(b, f, 0x20);
tc = _mm256_permute2f128_ps(c, g, 0x20);
td = _mm256_permute2f128_ps(d, h, 0x20);
te = _mm256_permute2f128_ps(a, e, 0x31);
tf = _mm256_permute2f128_ps(b, f, 0x31);
tg = _mm256_permute2f128_ps(c, g, 0x31);
th = _mm256_permute2f128_ps(d, h, 0x31);
// store from registers to dst
_mm256_storeu_ps(&dst[0 * ld_dst], ta);
_mm256_storeu_ps(&dst[1 * ld_dst], tb);
_mm256_storeu_ps(&dst[2 * ld_dst], tc);
_mm256_storeu_ps(&dst[3 * ld_dst], td);
_mm256_storeu_ps(&dst[4 * ld_dst], te);
_mm256_storeu_ps(&dst[5 * ld_dst], tf);
_mm256_storeu_ps(&dst[6 * ld_dst], tg);
_mm256_storeu_ps(&dst[7 * ld_dst], th);
}
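// The 16x16 transpose is composed of four 8x8 block transposes: the two
// diagonal blocks are transposed in place and the two off-diagonal blocks are
// transposed while being swapped (note the exchanged src/dst offsets below).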
template<>
inline void transpose_mxn<float, 16, 16>(
const float* src,
int64_t ld_src,
float* dst,
int64_t ld_dst) {
transpose_mxn<float, 8, 8>(
src , ld_src, dst, ld_dst);
transpose_mxn<float, 8, 8>(
src + 8, ld_src, dst + 8 * ld_dst, ld_dst);
transpose_mxn<float, 8, 8>(
src + 8 * ld_src, ld_src, dst + 8, ld_dst);
transpose_mxn<float, 8, 8>(
src + 8 * ld_src + 8, ld_src, dst + 8 * ld_dst + 8, ld_dst);
}
#endif
}} // namespace at::vec::CPU_CAPABILITY

View File

@ -0,0 +1,909 @@
#pragma once
// DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with AVX]
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <c10/util/irange.h>
#if defined(__aarch64__) && defined(AT_BUILD_ARM_VEC256_WITH_SLEEF)
#include <sleef.h>
#endif
// Sleef offers vectorized versions of some transcendentals
// such as sin, cos, tan, etc.
// However for now opting for STL, since we are not building
// with Sleef for mobile yet.
namespace at::vec {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {
// Right now contains only aarch64 implementation.
// aarch32 is not currently supported, for the following two reasons:
// 1. Due to differences in the ISA between aarch32 and aarch64, intrinsics
//    that work for aarch64 don't work for aarch32.
// 2. Android NDK r21 has problems with compiling aarch32.
// Clang seg faults.
// https://github.com/android/ndk/issues/1248
// https://bugs.llvm.org/show_bug.cgi?id=45824
// Most likely we will do aarch32 support with inline asm.
#if defined(__aarch64__)
#ifdef __BIG_ENDIAN__
#error "Big endian is not supported."
#endif
#if defined(AT_BUILD_ARM_VEC256_WITH_SLEEF)
#define USE_SLEEF(sleef_code, non_sleef_code) sleef_code
#else
#define USE_SLEEF(sleef_code, non_sleef_code) non_sleef_code
#endif
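// USE_SLEEF(sleef_code, non_sleef_code) expands to the Sleef-based expression
// (or block) when the build links Sleef, and to the scalar fallback otherwise.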
template<int index, bool mask_val>
struct BlendRegs {
static float32x4_t impl(
const float32x4_t& a, const float32x4_t& b, float32x4_t& res);
};
template<int index>
struct BlendRegs<index, true>{
static float32x4_t impl(
const float32x4_t& a, const float32x4_t& b, float32x4_t& res) {
return vsetq_lane_f32(vgetq_lane_f32(b, index), res, index);
}
};
template<int index>
struct BlendRegs<index, false>{
static float32x4_t impl(
const float32x4_t& a, const float32x4_t& b, float32x4_t& res) {
return vsetq_lane_f32(vgetq_lane_f32(a, index), res, index);
}
};
template <> class Vectorized<float> {
private:
float32x4x2_t values;
public:
using value_type = float;
using size_type = int;
static constexpr size_type size() {
return 8;
}
Vectorized() {}
Vectorized(float32x4x2_t v) : values(v) {}
Vectorized(float val) : values{vdupq_n_f32(val), vdupq_n_f32(val) } {}
Vectorized(float val0, float val1, float val2, float val3,
float val4, float val5, float val6, float val7) :
values{val0, val1, val2, val3, val4, val5, val6, val7} {}
Vectorized(float32x4_t val0, float32x4_t val1) : values{val0, val1} {}
operator float32x4x2_t() const {
return values;
}
template <int64_t mask>
static Vectorized<float> blend(const Vectorized<float>& a, const Vectorized<float>& b) {
Vectorized<float> vec;
// 0.
vec.values.val[0] =
BlendRegs<0, (mask & 0x01)!=0>::impl(
a.values.val[0], b.values.val[0], vec.values.val[0]);
vec.values.val[0] =
BlendRegs<1, (mask & 0x02)!=0>::impl(
a.values.val[0], b.values.val[0], vec.values.val[0]);
vec.values.val[0] =
BlendRegs<2, (mask & 0x04)!=0>::impl(
a.values.val[0], b.values.val[0], vec.values.val[0]);
vec.values.val[0] =
BlendRegs<3, (mask & 0x08)!=0>::impl(
a.values.val[0], b.values.val[0], vec.values.val[0]);
// 1.
vec.values.val[1] =
BlendRegs<0, (mask & 0x10)!=0>::impl(
a.values.val[1], b.values.val[1], vec.values.val[1]);
vec.values.val[1] =
BlendRegs<1, (mask & 0x20)!=0>::impl(
a.values.val[1], b.values.val[1], vec.values.val[1]);
vec.values.val[1] =
BlendRegs<2, (mask & 0x40)!=0>::impl(
a.values.val[1], b.values.val[1], vec.values.val[1]);
vec.values.val[1] =
BlendRegs<3, (mask & 0x80)!=0>::impl(
a.values.val[1], b.values.val[1], vec.values.val[1]);
return vec;
}
static Vectorized<float> blendv(const Vectorized<float>& a, const Vectorized<float>& b,
const Vectorized<float>& mask) {
// TODO
    // NB: This requires that each 32-bit lane of the mask be either all
    // zeros or all ones.
// We perhaps need some kind of an assert?
// But that will affect performance.
Vectorized<float> vec(mask.values);
vec.values.val[0] = vbslq_f32(
vreinterpretq_u32_f32(vec.values.val[0]),
b.values.val[0],
a.values.val[0]);
vec.values.val[1] = vbslq_f32(
vreinterpretq_u32_f32(vec.values.val[1]),
b.values.val[1],
a.values.val[1]);
return vec;
}
template<typename step_t>
static Vectorized<float> arange(float base = 0.f, step_t step = static_cast<step_t>(1)) {
const Vectorized<float> base_vec(base);
const Vectorized<float> step_vec(step);
const Vectorized<float> step_sizes(0, 1, 2, 3, 4, 5, 6, 7);
return fmadd(step_sizes, step_vec, base_vec);
}
static Vectorized<float> set(const Vectorized<float>& a, const Vectorized<float>& b,
int64_t count = size()) {
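    // The first `count` lanes come from b and the rest from a; partial halves
    // are handled with per-lane all-ones/all-zeros masks and vbslq_f32.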
switch (count) {
case 0:
return a;
case 1:
{
Vectorized<float> vec;
static uint32x4_t mask_low = {0xFFFFFFFF, 0x0, 0x0, 0x0};
vec.values.val[0] = vreinterpretq_f32_u32(mask_low);
vec.values.val[1] = a.values.val[1];
vec.values.val[0] = vbslq_f32(
vreinterpretq_u32_f32(vec.values.val[0]),
b.values.val[0],
a.values.val[0]);
return vec;
}
case 2:
{
Vectorized<float> vec;
static uint32x4_t mask_low = {0xFFFFFFFF, 0xFFFFFFFF, 0x0, 0x0};
vec.values.val[0] = vreinterpretq_f32_u32(mask_low);
vec.values.val[1] = a.values.val[1];
vec.values.val[0] = vbslq_f32(
vreinterpretq_u32_f32(vec.values.val[0]),
b.values.val[0],
a.values.val[0]);
return vec;
}
case 3:
{
Vectorized<float> vec;
static uint32x4_t mask_low = {0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x0};
vec.values.val[0] = vreinterpretq_f32_u32(mask_low);
vec.values.val[1] = a.values.val[1];
vec.values.val[0] = vbslq_f32(
vreinterpretq_u32_f32(vec.values.val[0]),
b.values.val[0],
a.values.val[0]);
return vec;
}
case 4:
return Vectorized<float>(b.values.val[0], a.values.val[1]);
case 5:
{
Vectorized<float> vec;
static uint32x4_t mask_high = {0xFFFFFFFF, 0x0, 0x0, 0x0};
vec.values.val[0] = b.values.val[0];
vec.values.val[1] = vreinterpretq_f32_u32(mask_high);
vec.values.val[1] = vbslq_f32(
vreinterpretq_u32_f32(vec.values.val[1]),
b.values.val[1],
a.values.val[1]);
return vec;
}
case 6:
{
Vectorized<float> vec;
static uint32x4_t mask_high = {0xFFFFFFFF, 0xFFFFFFFF, 0x0, 0x0};
vec.values.val[0] = b.values.val[0];
vec.values.val[1] = vreinterpretq_f32_u32(mask_high);
vec.values.val[1] = vbslq_f32(
vreinterpretq_u32_f32(vec.values.val[1]),
b.values.val[1],
a.values.val[1]);
return vec;
}
case 7:
{
Vectorized<float> vec;
static uint32x4_t mask_high = {0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x0};
vec.values.val[0] = b.values.val[0];
vec.values.val[1] = vreinterpretq_f32_u32(mask_high);
vec.values.val[1] = vbslq_f32(
vreinterpretq_u32_f32(vec.values.val[1]),
b.values.val[1],
a.values.val[1]);
return vec;
}
}
return b;
}
static Vectorized<float> loadu(const void* ptr, int64_t count = size()) {
if (count == size()) {
return vld1q_f32_x2(reinterpret_cast<const float*>(ptr));
}
else if (count == (size() >> 1)) {
Vectorized<float> res;
res.values.val[0] = vld1q_f32(reinterpret_cast<const float*>(ptr));
res.values.val[1] = vdupq_n_f32(0.f);
return res;
}
else {
__at_align__ float tmp_values[size()];
for (const auto i : c10::irange(size())) {
tmp_values[i] = 0.0;
}
std::memcpy(
tmp_values,
reinterpret_cast<const float*>(ptr),
count * sizeof(float));
return vld1q_f32_x2(reinterpret_cast<const float*>(tmp_values));
}
}
void store(void* ptr, int64_t count = size()) const {
if (count == size()) {
vst1q_f32_x2(reinterpret_cast<float*>(ptr), values);
}
else if (count == (size() >> 1)) {
vst1q_f32(reinterpret_cast<float*>(ptr), values.val[0]);
}
else {
float tmp_values[size()];
vst1q_f32_x2(reinterpret_cast<float*>(tmp_values), values);
std::memcpy(ptr, tmp_values, count * sizeof(float));
}
}
inline const float32x4_t& get_low() const {
return values.val[0];
}
inline float32x4_t& get_low() {
return values.val[0];
}
inline const float32x4_t& get_high() const {
return values.val[1];
}
inline float32x4_t& get_high() {
return values.val[1];
}
// Very slow implementation of indexing.
// Only required because vec256_qint refers to this.
// Once we specialize that implementation for ARM
// this should be removed. TODO (kimishpatel)
float operator[](int idx) const {
__at_align__ float tmp[size()];
store(tmp);
return tmp[idx];
}
float operator[](int idx) {
__at_align__ float tmp[size()];
store(tmp);
return tmp[idx];
}
  // For the boolean version, checks like "is any lane set" / "are all lanes
  // zero" can be done faster in a different way.
int zero_mask() const {
__at_align__ float tmp[size()];
store(tmp);
int mask = 0;
for (int i = 0; i < size(); ++ i) {
if (tmp[i] == 0.f) {
mask |= (1 << i);
}
}
return mask;
}
Vectorized<float> isnan() const {
__at_align__ float tmp[size()];
__at_align__ float res[size()];
store(tmp);
for (const auto i : c10::irange(size())) {
if (_isnan(tmp[i])) {
std::memset(static_cast<void*>(&res[i]), 0xFF, sizeof(float));
} else {
std::memset(static_cast<void*>(&res[i]), 0, sizeof(float));
}
}
return loadu(res);
};
bool has_inf_nan() const {
__at_align__ float tmp[size()];
store(tmp);
for (const auto i : c10::irange(size())) {
if(_isnan(tmp[i]) || _isinf(tmp[i])) {
return true;
}
}
return false;
}
Vectorized<float> map(float (*const f)(float)) const {
__at_align__ float tmp[size()];
store(tmp);
for (const auto i : c10::irange(size())) {
tmp[i] = f(tmp[i]);
}
return loadu(tmp);
}
Vectorized<float> abs() const {
return Vectorized<float>(vabsq_f32(values.val[0]), vabsq_f32(values.val[1]));
}
Vectorized<float> angle() const {
auto zero = Vectorized<float>(0);
auto pi = Vectorized<float>(c10::pi<float>);
auto tmp = blendv(zero, pi, *this < zero);
return blendv(tmp, *this, isnan());
}
Vectorized<float> real() const {
return *this;
}
Vectorized<float> imag() const {
return Vectorized<float>(0.f);
}
Vectorized<float> conj() const {
return *this;
}
Vectorized<float> acos() const {
return USE_SLEEF(
Vectorized<float>(Sleef_acosf4_u10(values.val[0]), Sleef_acosf4_u10(values.val[1])),
map(std::acos)
);
}
Vectorized<float> acosh() const {
return USE_SLEEF(
Vectorized<float>(Sleef_acoshf4_u10(values.val[0]), Sleef_acoshf4_u10(values.val[1])),
map(std::acosh)
);
}
Vectorized<float> asin() const {
return USE_SLEEF(
Vectorized<float>(Sleef_asinf4_u10(values.val[0]), Sleef_asinf4_u10(values.val[1])),
map(std::asin)
);
}
Vectorized<float> atan() const {
return USE_SLEEF(
Vectorized<float>(Sleef_atanf4_u10(values.val[0]), Sleef_atanf4_u10(values.val[1])),
map(std::atan)
);
}
Vectorized<float> atanh() const {
return USE_SLEEF(
Vectorized<float>(Sleef_atanhf4_u10(values.val[0]), Sleef_atanhf4_u10(values.val[1])),
map(std::atanh)
);
}
Vectorized<float> atan2(const Vectorized<float> &exp) const {
USE_SLEEF(
{
return Vectorized<float>(Sleef_atan2f4_u10(values.val[0], exp.values.val[0]),
Sleef_atan2f4_u10(values.val[1], exp.values.val[1]));
},
{
__at_align__ float tmp[size()];
__at_align__ float tmp_exp[size()];
store(tmp);
exp.store(tmp_exp);
for (const auto i : c10::irange(size())) {
tmp[i] = std::atan2(tmp[i], tmp_exp[i]);
}
return loadu(tmp);
}
)
}
Vectorized<float> copysign(const Vectorized<float> &sign) const {
USE_SLEEF(
{
return Vectorized<float>(Sleef_copysignf4(values.val[0], sign.values.val[0]),
Sleef_copysignf4(values.val[1], sign.values.val[1]));
},
{
__at_align__ float tmp[size()];
__at_align__ float tmp_sign[size()];
store(tmp);
sign.store(tmp_sign);
for (size_type i = 0; i < size(); i++) {
tmp[i] = std::copysign(tmp[i], tmp_sign[i]);
}
return loadu(tmp);
}
)
}
Vectorized<float> erf() const;
Vectorized<float> erfc() const {
return USE_SLEEF(
Vectorized<float>(Sleef_erfcf4_u15(values.val[0]), Sleef_erfcf4_u15(values.val[1])),
map(std::erfc)
);
}
Vectorized<float> erfinv() const {
return map(calc_erfinv);
}
Vectorized<float> exp() const {
return USE_SLEEF(
Vectorized<float>(Sleef_expf4_u10(values.val[0]), Sleef_expf4_u10(values.val[1])),
map(std::exp)
);
}
Vectorized<float> exp2() const {
return USE_SLEEF(
Vectorized<float>(Sleef_exp2f4_u10(values.val[0]), Sleef_exp2f4_u10(values.val[1])),
map(std::exp2)
);
}
Vectorized<float> expm1() const {
return USE_SLEEF(
Vectorized<float>(Sleef_expm1f4_u10(values.val[0]), Sleef_expm1f4_u10(values.val[1])),
map(std::expm1)
);
}
Vectorized<float> exp_u20() const {
return exp();
}
Vectorized<float> fmod(const Vectorized<float>& q) const {
USE_SLEEF(
{
return Vectorized<float>(Sleef_fmodf4(values.val[0], q.values.val[0]),
Sleef_fmodf4(values.val[1], q.values.val[1]));
},
{
__at_align__ float tmp[size()];
__at_align__ float tmp_q[size()];
store(tmp);
q.store(tmp_q);
for (const auto i : c10::irange(size())) {
tmp[i] = std::fmod(tmp[i], tmp_q[i]);
}
return loadu(tmp);
}
)
}
Vectorized<float> hypot(const Vectorized<float> &b) const {
USE_SLEEF(
{
return Vectorized<float>(Sleef_hypotf4_u05(values.val[0], b.values.val[0]),
Sleef_hypotf4_u05(values.val[1], b.values.val[1]));
},
{
__at_align__ float tmp[size()];
__at_align__ float tmp_b[size()];
store(tmp);
b.store(tmp_b);
for (const auto i : c10::irange(size())) {
tmp[i] = std::hypot(tmp[i], tmp_b[i]);
}
return loadu(tmp);
}
)
}
Vectorized<float> i0() const {
return map(calc_i0);
}
Vectorized<float> i0e() const {
return map(calc_i0e);
}
Vectorized<float> digamma() const {
return map(calc_digamma);
}
Vectorized<float> igamma(const Vectorized<float> &x) const {
__at_align__ float tmp[size()];
__at_align__ float tmp_x[size()];
store(tmp);
x.store(tmp_x);
for (const auto i : c10::irange(size())) {
tmp[i] = calc_igamma(tmp[i], tmp_x[i]);
}
return loadu(tmp);
}
Vectorized<float> igammac(const Vectorized<float> &x) const {
__at_align__ float tmp[size()];
__at_align__ float tmp_x[size()];
store(tmp);
x.store(tmp_x);
for (const auto i : c10::irange(size())) {
tmp[i] = calc_igammac(tmp[i], tmp_x[i]);
}
return loadu(tmp);
}
Vectorized<float> log() const {
return USE_SLEEF(
Vectorized<float>(Sleef_logf4_u10(values.val[0]), Sleef_logf4_u10(values.val[1])),
map(std::log)
);
}
Vectorized<float> log10() const {
return USE_SLEEF(
Vectorized<float>(Sleef_log10f4_u10(values.val[0]), Sleef_log10f4_u10(values.val[1])),
map(std::log10)
);
}
Vectorized<float> log1p() const {
return USE_SLEEF(
Vectorized<float>(Sleef_log1pf4_u10(values.val[0]), Sleef_log1pf4_u10(values.val[1])),
map(std::log1p)
);
}
Vectorized<float> log2() const {
return USE_SLEEF(
Vectorized<float>(Sleef_log2f4_u10(values.val[0]), Sleef_log2f4_u10(values.val[1])),
map(std::log2)
);
}
Vectorized<float> nextafter(const Vectorized<float> &b) const {
USE_SLEEF(
{
return Vectorized<float>(Sleef_nextafterf4(values.val[0], b.values.val[0]),
Sleef_nextafterf4(values.val[1], b.values.val[1]));
},
{
__at_align__ float tmp[size()];
__at_align__ float tmp_b[size()];
store(tmp);
b.store(tmp_b);
for (const auto i : c10::irange(size())) {
tmp[i] = std::nextafter(tmp[i], tmp_b[i]);
}
return loadu(tmp);
}
)
}
Vectorized<float> frac() const;
Vectorized<float> sin() const {
return USE_SLEEF(
Vectorized<float>(Sleef_sinf4_u10(values.val[0]), Sleef_sinf4_u10(values.val[1])),
map(std::sin)
);
}
Vectorized<float> sinh() const {
return USE_SLEEF(
Vectorized<float>(Sleef_sinhf4_u10(values.val[0]), Sleef_sinhf4_u10(values.val[1])),
map(std::sinh)
);
}
Vectorized<float> cos() const {
return USE_SLEEF(
Vectorized<float>(Sleef_cosf4_u10(values.val[0]), Sleef_cosf4_u10(values.val[1])),
map(std::cos)
);
}
Vectorized<float> cosh() const {
return USE_SLEEF(
Vectorized<float>(Sleef_coshf4_u10(values.val[0]), Sleef_coshf4_u10(values.val[1])),
map(std::cosh)
);
}
Vectorized<float> ceil() const {
return map(at::native::ceil_impl);
}
Vectorized<float> floor() const {
return map(at::native::floor_impl);
}
Vectorized<float> neg() const {
return Vectorized<float>(
vnegq_f32(values.val[0]),
vnegq_f32(values.val[1]));
}
Vectorized<float> round() const {
// We do not use std::round because we would like to round midway numbers to the nearest even integer.
return map(at::native::round_impl);
}
Vectorized<float> tan() const {
return USE_SLEEF(
Vectorized<float>(Sleef_tanf4_u10(values.val[0]), Sleef_tanf4_u10(values.val[1])),
map(std::tan)
);
}
Vectorized<float> tanh() const {
return USE_SLEEF(
Vectorized<float>(Sleef_tanhf4_u10(values.val[0]), Sleef_tanhf4_u10(values.val[1])),
map(std::tanh)
);
}
Vectorized<float> trunc() const {
float32x4_t r0 = vrndq_f32(values.val[0]);
float32x4_t r1 = vrndq_f32(values.val[1]);
return Vectorized<float>(r0, r1);
}
Vectorized<float> lgamma() const {
return USE_SLEEF(
Vectorized<float>(Sleef_lgammaf4_u10(values.val[0]), Sleef_lgammaf4_u10(values.val[1])),
map(std::lgamma)
);
}
Vectorized<float> sqrt() const {
return Vectorized<float>(
vsqrtq_f32(values.val[0]),
vsqrtq_f32(values.val[1]));
}
Vectorized<float> reciprocal() const {
auto r0 = vdivq_f32(vdupq_n_f32(1.0f), values.val[0]);
auto r1 = vdivq_f32(vdupq_n_f32(1.0f), values.val[1]);
return Vectorized<float>(r0, r1);
}
Vectorized<float> rsqrt() const {
return this->sqrt().reciprocal();
}
Vectorized<float> pow(const Vectorized<float> &exp) const {
USE_SLEEF(
{
return Vectorized<float>(Sleef_powf4_u10(values.val[0], exp.values.val[0]),
Sleef_powf4_u10(values.val[1], exp.values.val[1]));
},
{
__at_align__ float tmp[size()];
__at_align__ float tmp_exp[size()];
store(tmp);
exp.store(tmp_exp);
for (const auto i : c10::irange(size())) {
tmp[i] = std::pow(tmp[i], tmp_exp[i]);
}
return loadu(tmp);
}
)
}
Vectorized<float> operator==(const Vectorized<float>& other) const {
float32x4_t r0 =
vreinterpretq_f32_u32(vceqq_f32(values.val[0], other.values.val[0]));
float32x4_t r1 =
vreinterpretq_f32_u32(vceqq_f32(values.val[1], other.values.val[1]));
return Vectorized<float>(r0, r1);
}
Vectorized<float> operator!=(const Vectorized<float>& other) const {
float32x4_t r0 = vreinterpretq_f32_u32(
vmvnq_u32(vceqq_f32(values.val[0], other.values.val[0])));
float32x4_t r1 = vreinterpretq_f32_u32(
vmvnq_u32(vceqq_f32(values.val[1], other.values.val[1])));
return Vectorized<float>(r0, r1);
}
Vectorized<float> operator<(const Vectorized<float>& other) const {
float32x4_t r0 =
vreinterpretq_f32_u32(vcltq_f32(values.val[0], other.values.val[0]));
float32x4_t r1 =
vreinterpretq_f32_u32(vcltq_f32(values.val[1], other.values.val[1]));
return Vectorized<float>(r0, r1);
}
Vectorized<float> operator<=(const Vectorized<float>& other) const {
float32x4_t r0 =
vreinterpretq_f32_u32(vcleq_f32(values.val[0], other.values.val[0]));
float32x4_t r1 =
vreinterpretq_f32_u32(vcleq_f32(values.val[1], other.values.val[1]));
return Vectorized<float>(r0, r1);
}
Vectorized<float> operator>(const Vectorized<float>& other) const {
float32x4_t r0 =
vreinterpretq_f32_u32(vcgtq_f32(values.val[0], other.values.val[0]));
float32x4_t r1 =
vreinterpretq_f32_u32(vcgtq_f32(values.val[1], other.values.val[1]));
return Vectorized<float>(r0, r1);
}
Vectorized<float> operator>=(const Vectorized<float>& other) const {
float32x4_t r0 =
vreinterpretq_f32_u32(vcgeq_f32(values.val[0], other.values.val[0]));
float32x4_t r1 =
vreinterpretq_f32_u32(vcgeq_f32(values.val[1], other.values.val[1]));
return Vectorized<float>(r0, r1);
}
Vectorized<float> eq(const Vectorized<float>& other) const;
Vectorized<float> ne(const Vectorized<float>& other) const;
Vectorized<float> gt(const Vectorized<float>& other) const;
Vectorized<float> ge(const Vectorized<float>& other) const;
Vectorized<float> lt(const Vectorized<float>& other) const;
Vectorized<float> le(const Vectorized<float>& other) const;
};
template <>
Vectorized<float> inline operator+(const Vectorized<float>& a, const Vectorized<float>& b) {
float32x4_t r0 = vaddq_f32(a.get_low(), b.get_low());
float32x4_t r1 = vaddq_f32(a.get_high(), b.get_high());
return Vectorized<float>(r0, r1);
}
template <>
Vectorized<float> inline operator-(const Vectorized<float>& a, const Vectorized<float>& b) {
float32x4_t r0 = vsubq_f32(a.get_low(), b.get_low());
float32x4_t r1 = vsubq_f32(a.get_high(), b.get_high());
return Vectorized<float>(r0, r1);
}
template <>
Vectorized<float> inline operator*(const Vectorized<float>& a, const Vectorized<float>& b) {
float32x4_t r0 = vmulq_f32(a.get_low(), b.get_low());
float32x4_t r1 = vmulq_f32(a.get_high(), b.get_high());
return Vectorized<float>(r0, r1);
}
template <>
Vectorized<float> inline operator/(const Vectorized<float>& a, const Vectorized<float>& b) {
float32x4_t r0 = vdivq_f32(a.get_low(), b.get_low());
float32x4_t r1 = vdivq_f32(a.get_high(), b.get_high());
return Vectorized<float>(r0, r1);
}
// frac. Implement this here so we can use subtraction
inline Vectorized<float> Vectorized<float>::frac() const {
return *this - this->trunc();
}
// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
// either input is a NaN. When built with Sleef, the Sleef fmax kernel is used
// on the fast path; vmaxq_f32 (which propagates NaN) handles inputs that may
// contain Inf/NaN.
template <>
Vectorized<float> inline maximum(const Vectorized<float>& a, const Vectorized<float>& b) {
  if (!a.has_inf_nan() && !b.has_inf_nan()) {
    return USE_SLEEF(
      Vectorized<float>(Sleef_fmaxf4(a.get_low(), b.get_low()), Sleef_fmaxf4(a.get_high(), b.get_high())),
      Vectorized<float>(vmaxq_f32(a.get_low(), b.get_low()), vmaxq_f32(a.get_high(), b.get_high())));
  } else {
    return Vectorized<float>(vmaxq_f32(a.get_low(), b.get_low()), vmaxq_f32(a.get_high(), b.get_high()));
  }
}
// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<float> inline minimum(const Vectorized<float>& a, const Vectorized<float>& b) {
float32x4_t r0 = vminq_f32(a.get_low(), b.get_low());
float32x4_t r1 = vminq_f32(a.get_high(), b.get_high());
return Vectorized<float>(r0, r1);
}
template <>
Vectorized<float> inline clamp(const Vectorized<float>& a, const Vectorized<float>& min, const Vectorized<float>& max) {
return minimum(max, maximum(min, a));
}
template <>
Vectorized<float> inline clamp_max(const Vectorized<float>& a, const Vectorized<float>& max) {
return minimum(max, a);
}
template <>
Vectorized<float> inline clamp_min(const Vectorized<float>& a, const Vectorized<float>& min) {
return maximum(min, a);
}
template <>
Vectorized<float> inline operator&(const Vectorized<float>& a, const Vectorized<float>& b) {
float32x4_t r0 = vreinterpretq_f32_u32(vandq_u32(
vreinterpretq_u32_f32(a.get_low()),
vreinterpretq_u32_f32(b.get_low())));
float32x4_t r1 = vreinterpretq_f32_u32(vandq_u32(
vreinterpretq_u32_f32(a.get_high()),
vreinterpretq_u32_f32(b.get_high())));
return Vectorized<float>(r0, r1);
}
template <>
Vectorized<float> inline operator|(const Vectorized<float>& a, const Vectorized<float>& b) {
float32x4_t r0 = vreinterpretq_f32_u32(vorrq_u32(
vreinterpretq_u32_f32(a.get_low()),
vreinterpretq_u32_f32(b.get_low())));
float32x4_t r1 = vreinterpretq_f32_u32(vorrq_u32(
vreinterpretq_u32_f32(a.get_high()),
vreinterpretq_u32_f32(b.get_high())));
return Vectorized<float>(r0, r1);
}
template <>
Vectorized<float> inline operator^(const Vectorized<float>& a, const Vectorized<float>& b) {
float32x4_t r0 = vreinterpretq_f32_u32(veorq_u32(
vreinterpretq_u32_f32(a.get_low()),
vreinterpretq_u32_f32(b.get_low())));
float32x4_t r1 = vreinterpretq_f32_u32(veorq_u32(
vreinterpretq_u32_f32(a.get_high()),
vreinterpretq_u32_f32(b.get_high())));
return Vectorized<float>(r0, r1);
}
inline Vectorized<float> Vectorized<float>::eq(const Vectorized<float>& other) const {
return (*this == other) & Vectorized<float>(1.0f);
}
inline Vectorized<float> Vectorized<float>::ne(const Vectorized<float>& other) const {
return (*this != other) & Vectorized<float>(1.0f);
}
inline Vectorized<float> Vectorized<float>::gt(const Vectorized<float>& other) const {
return (*this > other) & Vectorized<float>(1.0f);
}
inline Vectorized<float> Vectorized<float>::ge(const Vectorized<float>& other) const {
return (*this >= other) & Vectorized<float>(1.0f);
}
inline Vectorized<float> Vectorized<float>::lt(const Vectorized<float>& other) const {
return (*this < other) & Vectorized<float>(1.0f);
}
inline Vectorized<float> Vectorized<float>::le(const Vectorized<float>& other) const {
return (*this <= other) & Vectorized<float>(1.0f);
}
template <>
inline void convert(const float* src, int32_t* dst, int64_t n) {
int64_t i;
#ifndef __msvc_cl__
#pragma unroll
#endif
for (i = 0; i <= (n - Vectorized<float>::size()); i += Vectorized<float>::size()) {
vst1q_s32(dst + i, vcvtq_s32_f32(vld1q_f32(src + i)));
vst1q_s32(dst + i + 4, vcvtq_s32_f32(vld1q_f32(src + i + 4)));
}
#ifndef __msvc_cl__
#pragma unroll
#endif
for (; i < n; i++) {
dst[i] = static_cast<int32_t>(src[i]);
}
}
template <>
inline void convert(const int32_t* src, float* dst, int64_t n) {
int64_t i;
#ifndef __msvc_cl__
#pragma unroll
#endif
for (i = 0; i <= (n - Vectorized<float>::size()); i += Vectorized<float>::size()) {
vst1q_f32(dst + i, vcvtq_f32_s32(vld1q_s32(src + i)));
vst1q_f32(dst + i + 4, vcvtq_f32_s32(vld1q_s32(src + i + 4)));
}
#ifndef __msvc_cl__
#pragma unroll
#endif
for (; i < n; i++) {
dst[i] = static_cast<float>(src[i]);
}
}
template <>
Vectorized<float> inline fmadd(const Vectorized<float>& a, const Vectorized<float>& b, const Vectorized<float>& c) {
float32x4_t r0 = vfmaq_f32(c.get_low(), a.get_low(), b.get_low());
float32x4_t r1 = vfmaq_f32(c.get_high(), a.get_high(), b.get_high());
return Vectorized<float>(r0, r1);
}
template <>
Vectorized<float> inline fmsub(const Vectorized<float>& a, const Vectorized<float>& b, const Vectorized<float>& c) {
  // fmsub is a * b - c, but vfmsq_f32(c, a, b) computes c - a * b, so fuse the
  // multiply-add onto a negated c instead.
  float32x4_t r0 = vfmaq_f32(vnegq_f32(c.get_low()), a.get_low(), b.get_low());
  float32x4_t r1 = vfmaq_f32(vnegq_f32(c.get_high()), a.get_high(), b.get_high());
return Vectorized<float>(r0, r1);
}
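// Polynomial approximation of erf using the Abramowitz & Stegun 7.1.26
// constants (p, p1..p5 below); maximum absolute error is around 1.5e-7.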
inline Vectorized<float> Vectorized<float>::erf() const {
// constants
const Vectorized<float> neg_zero_vec(-0.f);
const Vectorized<float> one_vec(1.0f);
const Vectorized<float> p(0.3275911f);
const Vectorized<float> p1(0.254829592f);
const Vectorized<float> p2(-0.284496736f);
const Vectorized<float> p3(1.421413741f);
const Vectorized<float> p4(-1.453152027f);
const Vectorized<float> p5(1.061405429f);
// sign(x)
auto sign_mask = neg_zero_vec & *this;
auto abs_vec = this->abs();
// t = 1 / (p * abs(x) + 1)
auto tmp0 = fmadd(p, abs_vec, one_vec);
auto t = one_vec / tmp0;
// r = p5 * t ^ 4 + p4 * t ^ 3 + p3 * t ^ 2 + p2 * t + p1
auto tmp1 = fmadd(p5, t, p4);
auto tmp2 = fmadd(tmp1, t, p3);
auto tmp3 = fmadd(tmp2, t, p2);
auto r = fmadd(tmp3, t, p1);
// - exp(- x * x)
auto pow_2 = (*this) * (*this);
auto neg_pow_2 = pow_2 ^ neg_zero_vec;
auto tmp4 = neg_pow_2.map(std::exp); // This can be swapped for a faster implementation of exp.
auto tmp5 = tmp4 ^ neg_zero_vec;
// erf(x) = sign(x) * (1 - r * t * exp(- x * x))
auto tmp6 = t * tmp5;
auto tmp7 = fmadd(tmp6, r, one_vec);
return tmp7 ^ sign_mask;
}
#endif /* defined(__aarch64__) */
}} // namespace at::vec::CPU_CAPABILITY

View File

@ -0,0 +1,826 @@
#pragma once
// DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with AVX]
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec256/vec256_float_neon.h>
#include <ATen/cpu/vec/vec_base.h>
#include <c10/util/Half.h>
#include <c10/util/irange.h>
namespace at::vec {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {
// Right now contains only aarch64 implementation.
// aarch32 is not currently supported, for the following two reasons:
// 1. Due to differences in the ISA between aarch32 and aarch64, intrinsics
//    that work for aarch64 don't work for aarch32.
// 2. Android NDK r21 has problems with compiling aarch32.
// Clang seg faults.
// https://github.com/android/ndk/issues/1248
// https://bugs.llvm.org/show_bug.cgi?id=45824
// Most likely we will do aarch32 support with inline asm.
#if !defined(C10_MOBILE) && defined(__aarch64__) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
#ifdef __BIG_ENDIAN__
#error "Big endian is not supported."
#endif
template <int index, bool mask_val>
struct BlendHalfRegs {
static float16x8_t impl(
const float16x8_t& a,
const float16x8_t& b,
float16x8_t& res);
};
template <int index>
struct BlendHalfRegs<index, true> {
static float16x8_t impl(
const float16x8_t& a,
const float16x8_t& b,
float16x8_t& res) {
return vsetq_lane_f16(vgetq_lane_f16(b, index), res, index);
}
};
template <int index>
struct BlendHalfRegs<index, false> {
static float16x8_t impl(
const float16x8_t& a,
const float16x8_t& b,
float16x8_t& res) {
return vsetq_lane_f16(vgetq_lane_f16(a, index), res, index);
}
};
// On ARM, Half type supports float16_t->Half constructor and Half->float16_t
// conversion
template <>
class Vectorized<c10::Half> {
private:
float16x8x2_t values;
public:
// value_type should be c10::Half to fit interface with vec_base.h
using value_type = c10::Half;
using size_type = int;
static constexpr size_type size() {
static_assert(sizeof(float16x8x2_t) == 16 * sizeof(value_type));
return 16;
}
private:
// We use these private map functions to implement various methods
Vectorized<c10::Half> map2(
const Vectorized<c10::Half>& second,
c10::Half (*const f)(c10::Half, c10::Half)) const {
__at_align__ c10::Half tmp_first[size()];
__at_align__ c10::Half tmp_second[size()];
store(tmp_first); // store this to tmp_first
second.store(tmp_second);
for (const auto i : c10::irange(size())) {
tmp_first[i] = f(tmp_first[i], tmp_second[i]);
}
return loadu(tmp_first);
}
Vectorized<c10::Half> map_with_vec_float_method(
Vectorized<float> (Vectorized<float>::*m)() const) const {
// Convert low float16x8_t to 2 float32x4_t variables, apply m, and convert
// back
float32x4_t v00 = vcvt_f32_f16(vget_low_f16(values.val[0]));
float32x4_t v01 = vcvt_f32_f16(vget_high_f16(values.val[0]));
Vectorized<float> mv0 = (Vectorized<float>(v00, v01).*m)();
float16x4_t r00 = vcvt_f16_f32(mv0.get_low());
float16x4_t r01 = vcvt_f16_f32(mv0.get_high());
// Convert high float16x8_t to 2 float32x4_t variables, apply m, and convert
// back
float32x4_t v10 = vcvt_f32_f16(vget_low_f16(values.val[1]));
float32x4_t v11 = vcvt_f32_f16(vget_high_f16(values.val[1]));
Vectorized<float> mv1 = (Vectorized<float>(v10, v11).*m)();
float16x4_t r10 = vcvt_f16_f32(mv1.get_low());
float16x4_t r11 = vcvt_f16_f32(mv1.get_high());
// Pack result into Vectorized<c10::Half>
return Vectorized<c10::Half>(
vcombine_f16(r00, r01), vcombine_f16(r10, r11));
}
Vectorized<c10::Half> map2_with_vec_float_method(
const Vectorized<c10::Half>& second,
Vectorized<float> (Vectorized<float>::*m)(const Vectorized<float>&)
const) const {
// Convert low float16x8_t to 2 float32x4_t variables, apply m, and convert
// back
float32x4_t v00 = vcvt_f32_f16(vget_low_f16(values.val[0]));
float32x4_t v01 = vcvt_f32_f16(vget_high_f16(values.val[0]));
float32x4_t second_v00 = vcvt_f32_f16(vget_low_f16(second.get_low()));
float32x4_t second_v01 = vcvt_f32_f16(vget_high_f16(second.get_low()));
Vectorized<float> mv0 = (Vectorized<float>(v00, v01).*m)(
Vectorized<float>(second_v00, second_v01));
float16x4_t r00 = vcvt_f16_f32(mv0.get_low());
float16x4_t r01 = vcvt_f16_f32(mv0.get_high());
// Convert high float16x8_t to 2 float32x4_t variables, apply m, and convert
// back
float32x4_t v10 = vcvt_f32_f16(vget_low_f16(values.val[1]));
float32x4_t v11 = vcvt_f32_f16(vget_high_f16(values.val[1]));
float32x4_t second_v10 = vcvt_f32_f16(vget_low_f16(second.get_high()));
float32x4_t second_v11 = vcvt_f32_f16(vget_high_f16(second.get_high()));
Vectorized<float> mv1 = (Vectorized<float>(v10, v11).*m)(
Vectorized<float>(second_v10, second_v11));
float16x4_t r10 = vcvt_f16_f32(mv1.get_low());
float16x4_t r11 = vcvt_f16_f32(mv1.get_high());
// Pack result into Vectorized<c10::Half>
return Vectorized<c10::Half>(
vcombine_f16(r00, r01), vcombine_f16(r10, r11));
}
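// Added explanatory note: both helpers above widen each float16x8_t half into
// two float32x4_t registers with vcvt_f32_f16, invoke the requested
// Vectorized<float> member, and narrow back with vcvt_f16_f32. For example,
// exp() below effectively runs Vectorized<float>::exp on
//   vcvt_f32_f16(vget_low_f16(values.val[0])) and
//   vcvt_f32_f16(vget_high_f16(values.val[0]))
// before repacking the results into a Vectorized<c10::Half>.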
public:
// constructor
Vectorized() {}
Vectorized(float16x8x2_t v) : values(v) {}
// A ctor that accepts c10::Half is needed to fit interface with vec_base.h
// A second constructor that takes float16_t is also included
Vectorized(c10::Half val)
: values{vdupq_n_f16((float16_t)val), vdupq_n_f16((float16_t)val)} {
}
Vectorized(float16_t val) : values{vdupq_n_f16(val), vdupq_n_f16(val)} {}
Vectorized(
float16_t val0,
float16_t val1,
float16_t val2,
float16_t val3,
float16_t val4,
float16_t val5,
float16_t val6,
float16_t val7,
float16_t val8,
float16_t val9,
float16_t val10,
float16_t val11,
float16_t val12,
float16_t val13,
float16_t val14,
float16_t val15)
: values{
val0,
val1,
val2,
val3,
val4,
val5,
val6,
val7,
val8,
val9,
val10,
val11,
val12,
val13,
val14,
val15} {}
Vectorized(float16x8_t val0, float16x8_t val1) : values{val0, val1} {}
operator float16x8x2_t() const {
return values;
}
template <int64_t mask>
static Vectorized<c10::Half> blend(
const Vectorized<c10::Half>& a,
const Vectorized<c10::Half>& b) {
Vectorized<c10::Half> vec;
// 0.
vec.values.val[0] = BlendHalfRegs<0, (mask & 0x01) != 0>::impl(
a.values.val[0], b.values.val[0], vec.values.val[0]);
vec.values.val[0] = BlendHalfRegs<1, (mask & 0x02) != 0>::impl(
a.values.val[0], b.values.val[0], vec.values.val[0]);
vec.values.val[0] = BlendHalfRegs<2, (mask & 0x04) != 0>::impl(
a.values.val[0], b.values.val[0], vec.values.val[0]);
vec.values.val[0] = BlendHalfRegs<3, (mask & 0x08) != 0>::impl(
a.values.val[0], b.values.val[0], vec.values.val[0]);
vec.values.val[0] = BlendHalfRegs<4, (mask & 0x10) != 0>::impl(
a.values.val[0], b.values.val[0], vec.values.val[0]);
vec.values.val[0] = BlendHalfRegs<5, (mask & 0x20) != 0>::impl(
a.values.val[0], b.values.val[0], vec.values.val[0]);
vec.values.val[0] = BlendHalfRegs<6, (mask & 0x40) != 0>::impl(
a.values.val[0], b.values.val[0], vec.values.val[0]);
vec.values.val[0] = BlendHalfRegs<7, (mask & 0x80) != 0>::impl(
a.values.val[0], b.values.val[0], vec.values.val[0]);
// 1. (mask bits 8..15 select the lanes of the second register)
vec.values.val[1] = BlendHalfRegs<0, (mask & 0x100) != 0>::impl(
a.values.val[1], b.values.val[1], vec.values.val[1]);
vec.values.val[1] = BlendHalfRegs<1, (mask & 0x200) != 0>::impl(
a.values.val[1], b.values.val[1], vec.values.val[1]);
vec.values.val[1] = BlendHalfRegs<2, (mask & 0x400) != 0>::impl(
a.values.val[1], b.values.val[1], vec.values.val[1]);
vec.values.val[1] = BlendHalfRegs<3, (mask & 0x800) != 0>::impl(
a.values.val[1], b.values.val[1], vec.values.val[1]);
vec.values.val[1] = BlendHalfRegs<4, (mask & 0x1000) != 0>::impl(
a.values.val[1], b.values.val[1], vec.values.val[1]);
vec.values.val[1] = BlendHalfRegs<5, (mask & 0x2000) != 0>::impl(
a.values.val[1], b.values.val[1], vec.values.val[1]);
vec.values.val[1] = BlendHalfRegs<6, (mask & 0x4000) != 0>::impl(
a.values.val[1], b.values.val[1], vec.values.val[1]);
vec.values.val[1] = BlendHalfRegs<7, (mask & 0x8000) != 0>::impl(
a.values.val[1], b.values.val[1], vec.values.val[1]);
return vec;
}
static Vectorized<c10::Half> blendv(
const Vectorized<c10::Half>& a,
const Vectorized<c10::Half>& b,
const Vectorized<c10::Half>& mask) {
// Note: using blendv is very awkward because 0xFFFF is one of many NaNs in
// FP16. It's unfortunate that the mask has type Half (required by
// vec_base).
// TODO
// NB: This requires that each value, i.e., each uint value,
// of the mask either all be zeros or all be 1s.
// We perhaps need some kind of an assert?
// But that will affect performance.
Vectorized<c10::Half> vec(mask.values);
vec.values.val[0] = vbslq_f16(
vreinterpretq_u16_f16(vec.values.val[0]),
b.values.val[0],
a.values.val[0]);
vec.values.val[1] = vbslq_f16(
vreinterpretq_u16_f16(vec.values.val[1]),
b.values.val[1],
a.values.val[1]);
return vec;
}
template <typename step_t>
static Vectorized<c10::Half> arange(
c10::Half base = 0.0,
step_t step = static_cast<step_t>(1)) {
const Vectorized<c10::Half> base_vec(base);
const Vectorized<c10::Half> step_vec(step);
const Vectorized<c10::Half> step_sizes(
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
return fmadd(step_sizes, step_vec, base_vec);
}
static Vectorized<c10::Half> set(
const Vectorized<c10::Half>& a,
const Vectorized<c10::Half>& b,
int64_t count = size()) {
uint16_t pre_mask[16] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
for (int i = 0; i < count; i++) {
pre_mask[i] = 0xFFFF;
}
uint16x8x2_t mask = vld1q_u16_x2(pre_mask);
// Using blendv is awkward because 0xFFFF is one of many NaNs in FP16,
// so we directly use vbslq_f16 instead
Vectorized<c10::Half> vec(
vbslq_f16(
// Low bits
mask.val[0],
b.values.val[0],
a.values.val[0]),
// High bits
vbslq_f16(mask.val[1], b.values.val[1], a.values.val[1]));
return vec;
}
static Vectorized<c10::Half> loadu(const void* ptr, int64_t count = size()) {
if (count == size()) {
return vld1q_f16_x2(reinterpret_cast<const float16_t*>(ptr));
} else if (count == (size() >> 1)) {
Vectorized<c10::Half> res;
res.values.val[0] = vld1q_f16(reinterpret_cast<const float16_t*>(ptr));
std::memset(&res.values.val[1], 0, sizeof(res.values.val[1]));
return res;
}
__at_align__ float16_t tmp_values[size()];
for (const auto i : c10::irange(size())) {
tmp_values[i] = 0;
}
std::memcpy(
tmp_values,
reinterpret_cast<const float16_t*>(ptr),
count * sizeof(float16_t));
return vld1q_f16_x2(reinterpret_cast<const float16_t*>(tmp_values));
}
void store(void* ptr, int64_t count = size()) const {
if (count == size()) {
vst1q_f16_x2(reinterpret_cast<float16_t*>(ptr), values);
return;
} else if (count == (size() >> 1)) {
vst1q_f16(reinterpret_cast<float16_t*>(ptr), values.val[0]);
} else {
float16_t tmp_values[size()];
vst1q_f16_x2(reinterpret_cast<float16_t*>(tmp_values), values);
std::memcpy(ptr, tmp_values, count * sizeof(float16_t));
}
}
inline const float16x8_t& get_low() const {
return values.val[0];
}
inline float16x8_t& get_low() {
return values.val[0];
}
inline const float16x8_t& get_high() const {
return values.val[1];
}
inline float16x8_t& get_high() {
return values.val[1];
}
// Very slow implementation of indexing.
// Only required because vec256_qint refers to this.
// Once we specialize that implementation for ARM
// this should be removed. TODO (kimishpatel)
c10::Half operator[](int idx) const {
__at_align__ c10::Half tmp[size()];
store(tmp);
return tmp[idx];
}
c10::Half operator[](int idx) {
__at_align__ c10::Half tmp[size()];
store(tmp);
return tmp[idx];
}
// For a boolean version where we want to check if any element is 1 / all are
// zero, etc., this could be done faster in a different way.
int zero_mask() const {
__at_align__ c10::Half tmp[size()];
store(tmp);
int mask = 0;
for (int i = 0; i < size(); ++i) {
if (tmp[i] == 0) {
mask |= (1 << i);
}
}
return mask;
}
Vectorized<c10::Half> isnan() const {
__at_align__ c10::Half tmp[size()];
__at_align__ c10::Half res[size()];
store(tmp);
for (const auto i : c10::irange(size())) {
if (_isnan(tmp[i])) {
std::memset(static_cast<void*>(&res[i]), 0xFF, sizeof(c10::Half));
} else {
std::memset(static_cast<void*>(&res[i]), 0, sizeof(c10::Half));
}
}
return loadu(res);
}
bool has_inf_nan() const {
__at_align__ c10::Half tmp[size()];
store(tmp);
for (const auto i : c10::irange(size())) {
if (_isnan(tmp[i]) || _isinf(tmp[i])) {
return true;
}
}
return false;
}
Vectorized<c10::Half> map(c10::Half (*const f)(c10::Half)) const {
__at_align__ c10::Half tmp[size()];
store(tmp);
for (const auto i : c10::irange(size())) {
tmp[i] = f(tmp[i]);
}
return loadu(tmp);
}
Vectorized<c10::Half> abs() const {
return Vectorized<c10::Half>(
vabsq_f16(values.val[0]), vabsq_f16(values.val[1]));
}
Vectorized<c10::Half> angle() const {
auto zero = Vectorized<c10::Half>(0);
auto pi = Vectorized<c10::Half>(c10::pi<c10::Half>);
auto tmp = blendv(zero, pi, *this < zero);
return blendv(tmp, *this, isnan());
}
Vectorized<c10::Half> real() const {
return *this;
}
Vectorized<c10::Half> imag() const {
return Vectorized<c10::Half>(0);
}
Vectorized<c10::Half> conj() const {
return *this;
}
// Sleef does not support FP16, so many math functions are applied by
// converting to FP32, applying the math function, and then converting back to
// FP16.
Vectorized<c10::Half> acos() const {
return map_with_vec_float_method(&Vectorized<float>::acos);
}
Vectorized<c10::Half> acosh() const {
return map_with_vec_float_method(&Vectorized<float>::acosh);
}
Vectorized<c10::Half> asin() const {
return map_with_vec_float_method(&Vectorized<float>::asin);
}
Vectorized<c10::Half> atan() const {
return map_with_vec_float_method(&Vectorized<float>::atan);
}
Vectorized<c10::Half> atanh() const {
return map_with_vec_float_method(&Vectorized<float>::atanh);
}
Vectorized<c10::Half> atan2(const Vectorized<c10::Half>& exp) const {
return map2_with_vec_float_method(exp, &Vectorized<float>::atan2);
}
Vectorized<c10::Half> copysign(const Vectorized<c10::Half>& sign) const {
return map2_with_vec_float_method(sign, &Vectorized<float>::copysign);
}
Vectorized<c10::Half> erf() const {
return map_with_vec_float_method(&Vectorized<float>::erf);
}
Vectorized<c10::Half> erfc() const {
return map_with_vec_float_method(&Vectorized<float>::erfc);
}
Vectorized<c10::Half> erfinv() const {
return map_with_vec_float_method(&Vectorized<float>::erfinv);
}
Vectorized<c10::Half> exp() const {
return map_with_vec_float_method(&Vectorized<float>::exp);
}
Vectorized<c10::Half> exp2() const {
return map_with_vec_float_method(&Vectorized<float>::exp2);
}
Vectorized<c10::Half> expm1() const {
return map_with_vec_float_method(&Vectorized<float>::expm1);
}
Vectorized<c10::Half> exp_u20() const {
return map_with_vec_float_method(&Vectorized<float>::exp_u20);
}
Vectorized<c10::Half> fmod(const Vectorized<c10::Half>& q) const {
// This function is questionable with a conversion, so we use map2
return map2(q, std::fmod);
}
Vectorized<c10::Half> hypot(const Vectorized<c10::Half>& b) const {
return map2_with_vec_float_method(b, &Vectorized<float>::hypot);
}
Vectorized<c10::Half> i0() const {
return map_with_vec_float_method(&Vectorized<float>::i0);
}
Vectorized<c10::Half> i0e() const {
return map_with_vec_float_method(&Vectorized<float>::i0e);
}
Vectorized<c10::Half> digamma() const {
return map_with_vec_float_method(&Vectorized<float>::digamma);
}
Vectorized<c10::Half> igamma(const Vectorized<c10::Half>& x) const {
return map2_with_vec_float_method(x, &Vectorized<float>::igamma);
}
Vectorized<c10::Half> igammac(const Vectorized<c10::Half>& x) const {
return map2_with_vec_float_method(x, &Vectorized<float>::igammac);
}
Vectorized<c10::Half> log() const {
return map_with_vec_float_method(&Vectorized<float>::log);
}
Vectorized<c10::Half> log10() const {
return map_with_vec_float_method(&Vectorized<float>::log10);
}
Vectorized<c10::Half> log1p() const {
return map_with_vec_float_method(&Vectorized<float>::log1p);
}
Vectorized<c10::Half> log2() const {
return map_with_vec_float_method(&Vectorized<float>::log2);
}
Vectorized<c10::Half> nextafter(const Vectorized<c10::Half>& b) const {
// This function does not make sense with conversion, so we use map2
return map2(b, std::nextafter);
}
Vectorized<c10::Half> frac() const;
Vectorized<c10::Half> sin() const {
return map_with_vec_float_method(&Vectorized<float>::sin);
}
Vectorized<c10::Half> sinh() const {
return map_with_vec_float_method(&Vectorized<float>::sinh);
}
Vectorized<c10::Half> cos() const {
return map_with_vec_float_method(&Vectorized<float>::cos);
}
Vectorized<c10::Half> cosh() const {
return map_with_vec_float_method(&Vectorized<float>::cosh);
}
Vectorized<c10::Half> ceil() const {
// This function is questionable with a conversion, so we use map
return map(at::native::ceil_impl);
}
Vectorized<c10::Half> floor() const {
// This function is questionable with a conversion, so we use map
return map(at::native::floor_impl);
}
Vectorized<c10::Half> neg() const {
return Vectorized<c10::Half>(
vnegq_f16(values.val[0]), vnegq_f16(values.val[1]));
}
inline Vectorized<c10::Half> round() const {
// This function is questionable with a conversion, so we use map
return map(at::native::round_impl);
}
inline Vectorized<c10::Half> tan() const {
return map_with_vec_float_method(&Vectorized<float>::tan);
}
inline Vectorized<c10::Half> tanh() const {
return map_with_vec_float_method(&Vectorized<float>::tanh);
}
Vectorized<c10::Half> trunc() const {
float16x8_t r0 = vrndq_f16(values.val[0]);
float16x8_t r1 = vrndq_f16(values.val[1]);
return Vectorized<c10::Half>(r0, r1);
}
Vectorized<c10::Half> lgamma() const {
return map_with_vec_float_method(&Vectorized<float>::lgamma);
}
Vectorized<c10::Half> sqrt() const {
return Vectorized<c10::Half>(
vsqrtq_f16(values.val[0]), vsqrtq_f16(values.val[1]));
}
Vectorized<c10::Half> reciprocal() const {
auto ones = vdupq_n_f16(1.0f);
auto r0 = vdivq_f16(ones, values.val[0]);
auto r1 = vdivq_f16(ones, values.val[1]);
return Vectorized<c10::Half>(r0, r1);
}
Vectorized<c10::Half> rsqrt() const {
return this->sqrt().reciprocal();
}
Vectorized<c10::Half> pow(const Vectorized<c10::Half>& exp) const {
return map2_with_vec_float_method(exp, &Vectorized<float>::pow);
}
Vectorized<c10::Half> operator==(const Vectorized<c10::Half>& other) const {
float16x8_t r0 =
vreinterpretq_f16_u16(vceqq_f16(values.val[0], other.values.val[0]));
float16x8_t r1 =
vreinterpretq_f16_u16(vceqq_f16(values.val[1], other.values.val[1]));
return Vectorized<c10::Half>(r0, r1);
}
Vectorized<c10::Half> operator!=(const Vectorized<c10::Half>& other) const {
float16x8_t r0 = vreinterpretq_f16_u16(
vmvnq_u16(vceqq_f16(values.val[0], other.values.val[0])));
float16x8_t r1 = vreinterpretq_f16_u16(
vmvnq_u16(vceqq_f16(values.val[1], other.values.val[1])));
return Vectorized<c10::Half>(r0, r1);
}
Vectorized<c10::Half> operator<(const Vectorized<c10::Half>& other) const {
float16x8_t r0 =
vreinterpretq_f16_u16(vcltq_f16(values.val[0], other.values.val[0]));
float16x8_t r1 =
vreinterpretq_f16_u16(vcltq_f16(values.val[1], other.values.val[1]));
return Vectorized<c10::Half>(r0, r1);
}
Vectorized<c10::Half> operator<=(const Vectorized<c10::Half>& other) const {
float16x8_t r0 =
vreinterpretq_f16_u16(vcleq_f16(values.val[0], other.values.val[0]));
float16x8_t r1 =
vreinterpretq_f16_u16(vcleq_f16(values.val[1], other.values.val[1]));
return Vectorized<c10::Half>(r0, r1);
}
Vectorized<c10::Half> operator>(const Vectorized<c10::Half>& other) const {
float16x8_t r0 =
vreinterpretq_f16_u16(vcgtq_f16(values.val[0], other.values.val[0]));
float16x8_t r1 =
vreinterpretq_f16_u16(vcgtq_f16(values.val[1], other.values.val[1]));
return Vectorized<c10::Half>(r0, r1);
}
Vectorized<c10::Half> operator>=(const Vectorized<c10::Half>& other) const {
float16x8_t r0 =
vreinterpretq_f16_u16(vcgeq_f16(values.val[0], other.values.val[0]));
float16x8_t r1 =
vreinterpretq_f16_u16(vcgeq_f16(values.val[1], other.values.val[1]));
return Vectorized<c10::Half>(r0, r1);
}
Vectorized<c10::Half> eq(const Vectorized<c10::Half>& other) const;
Vectorized<c10::Half> ne(const Vectorized<c10::Half>& other) const;
Vectorized<c10::Half> gt(const Vectorized<c10::Half>& other) const;
Vectorized<c10::Half> ge(const Vectorized<c10::Half>& other) const;
Vectorized<c10::Half> lt(const Vectorized<c10::Half>& other) const;
Vectorized<c10::Half> le(const Vectorized<c10::Half>& other) const;
}; // Vectorized<Half>
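// Usage sketch (hypothetical, added for illustration only):
//   __at_align__ c10::Half src[16], dst[16];
//   auto x = Vectorized<c10::Half>::loadu(src);       // load 16 halves
//   auto y = Vectorized<c10::Half>(c10::Half(2.0f));  // broadcast a scalar
//   (x * y + y).store(dst);                           // elementwise ops
// The arithmetic operators used here are the free functions defined next.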
template <>
Vectorized<c10::Half> inline operator+(
const Vectorized<c10::Half>& a,
const Vectorized<c10::Half>& b) {
float16x8_t r0 = vaddq_f16(a.get_low(), b.get_low());
float16x8_t r1 = vaddq_f16(a.get_high(), b.get_high());
return Vectorized<c10::Half>(r0, r1);
}
template <>
Vectorized<c10::Half> inline operator-(
const Vectorized<c10::Half>& a,
const Vectorized<c10::Half>& b) {
float16x8_t r0 = vsubq_f16(a.get_low(), b.get_low());
float16x8_t r1 = vsubq_f16(a.get_high(), b.get_high());
return Vectorized<c10::Half>(r0, r1);
}
template <>
Vectorized<c10::Half> inline operator*(
const Vectorized<c10::Half>& a,
const Vectorized<c10::Half>& b) {
float16x8_t r0 = vmulq_f16(a.get_low(), b.get_low());
float16x8_t r1 = vmulq_f16(a.get_high(), b.get_high());
return Vectorized<c10::Half>(r0, r1);
}
template <>
Vectorized<c10::Half> inline operator/(
const Vectorized<c10::Half>& a,
const Vectorized<c10::Half>& b) {
float16x8_t r0 = vdivq_f16(a.get_low(), b.get_low());
float16x8_t r1 = vdivq_f16(a.get_high(), b.get_high());
return Vectorized<c10::Half>(r0, r1);
}
// frac. Implement this here so we can use subtraction
inline Vectorized<c10::Half> Vectorized<c10::Half>::frac() const {
return *this - this->trunc();
}
// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<c10::Half> inline maximum(
const Vectorized<c10::Half>& a,
const Vectorized<c10::Half>& b) {
float16x8_t r0 = vmaxq_f16(a.get_low(), b.get_low());
float16x8_t r1 = vmaxq_f16(a.get_high(), b.get_high());
return Vectorized<c10::Half>(r0, r1);
}
// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<c10::Half> inline minimum(
const Vectorized<c10::Half>& a,
const Vectorized<c10::Half>& b) {
float16x8_t r0 = vminq_f16(a.get_low(), b.get_low());
float16x8_t r1 = vminq_f16(a.get_high(), b.get_high());
return Vectorized<c10::Half>(r0, r1);
}
template <>
Vectorized<c10::Half> inline clamp(
const Vectorized<c10::Half>& a,
const Vectorized<c10::Half>& min,
const Vectorized<c10::Half>& max) {
return minimum(max, maximum(min, a));
}
template <>
Vectorized<c10::Half> inline clamp_max(
const Vectorized<c10::Half>& a,
const Vectorized<c10::Half>& max) {
return minimum(max, a);
}
template <>
Vectorized<c10::Half> inline clamp_min(
const Vectorized<c10::Half>& a,
const Vectorized<c10::Half>& min) {
return maximum(min, a);
}
template <>
Vectorized<c10::Half> inline operator&(
const Vectorized<c10::Half>& a,
const Vectorized<c10::Half>& b) {
float16x8_t r0 = vreinterpretq_f16_u16(vandq_u16(
vreinterpretq_u16_f16(a.get_low()), vreinterpretq_u16_f16(b.get_low())));
float16x8_t r1 = vreinterpretq_f16_u16(vandq_u16(
vreinterpretq_u16_f16(a.get_high()),
vreinterpretq_u16_f16(b.get_high())));
return Vectorized<c10::Half>(r0, r1);
}
template <>
Vectorized<c10::Half> inline operator|(
const Vectorized<c10::Half>& a,
const Vectorized<c10::Half>& b) {
float16x8_t r0 = vreinterpretq_f16_u16(vorrq_u16(
vreinterpretq_u16_f16(a.get_low()), vreinterpretq_u16_f16(b.get_low())));
float16x8_t r1 = vreinterpretq_f16_u16(vorrq_u16(
vreinterpretq_u16_f16(a.get_high()),
vreinterpretq_u16_f16(b.get_high())));
return Vectorized<c10::Half>(r0, r1);
}
template <>
Vectorized<c10::Half> inline operator^(
const Vectorized<c10::Half>& a,
const Vectorized<c10::Half>& b) {
float16x8_t r0 = vreinterpretq_f16_u16(veorq_u16(
vreinterpretq_u16_f16(a.get_low()), vreinterpretq_u16_f16(b.get_low())));
float16x8_t r1 = vreinterpretq_f16_u16(veorq_u16(
vreinterpretq_u16_f16(a.get_high()),
vreinterpretq_u16_f16(b.get_high())));
return Vectorized<c10::Half>(r0, r1);
}
inline Vectorized<c10::Half> Vectorized<c10::Half>::eq(
const Vectorized<c10::Half>& other) const {
return (*this == other) & Vectorized<c10::Half>(1);
}
inline Vectorized<c10::Half> Vectorized<c10::Half>::ne(
const Vectorized<c10::Half>& other) const {
return (*this != other) & Vectorized<c10::Half>(1);
}
inline Vectorized<c10::Half> Vectorized<c10::Half>::gt(
const Vectorized<c10::Half>& other) const {
return (*this > other) & Vectorized<c10::Half>(1);
}
inline Vectorized<c10::Half> Vectorized<c10::Half>::ge(
const Vectorized<c10::Half>& other) const {
return (*this >= other) & Vectorized<c10::Half>(1);
}
inline Vectorized<c10::Half> Vectorized<c10::Half>::lt(
const Vectorized<c10::Half>& other) const {
return (*this < other) & Vectorized<c10::Half>(1);
}
inline Vectorized<c10::Half> Vectorized<c10::Half>::le(
const Vectorized<c10::Half>& other) const {
return (*this <= other) & Vectorized<c10::Half>(1);
}
template <>
inline void convert(const float16_t* src, int16_t* dst, int64_t n) {
int64_t i;
#ifndef __msvc_cl__
#pragma unroll
#endif
for (i = 0; i <= (n - Vectorized<c10::Half>::size());
i += Vectorized<c10::Half>::size()) {
vst1q_s16(dst + i, vcvtq_s16_f16(vld1q_f16(src + i)));
vst1q_s16(dst + i + 8, vcvtq_s16_f16(vld1q_f16(src + i + 8)));
}
#ifndef __msvc_cl__
#pragma unroll
#endif
for (; i < n; i++) {
dst[i] = static_cast<int16_t>(src[i]);
}
}
template <>
inline void convert(const int16_t* src, float16_t* dst, int64_t n) {
int64_t i;
#ifndef __msvc_cl__
#pragma unroll
#endif
for (i = 0; i <= (n - Vectorized<c10::Half>::size());
i += Vectorized<c10::Half>::size()) {
vst1q_f16(dst + i, vcvtq_f16_s16(vld1q_s16(src + i)));
vst1q_f16(dst + i + 8, vcvtq_f16_s16(vld1q_s16(src + i + 8)));
}
#ifndef __msvc_cl__
#pragma unroll
#endif
for (; i < n; i++) {
dst[i] = static_cast<float16_t>(src[i]);
}
}
template <>
Vectorized<c10::Half> inline fmadd(
const Vectorized<c10::Half>& a,
const Vectorized<c10::Half>& b,
const Vectorized<c10::Half>& c) {
float16x8_t r0 = vfmaq_f16(c.get_low(), a.get_low(), b.get_low());
float16x8_t r1 = vfmaq_f16(c.get_high(), a.get_high(), b.get_high());
return Vectorized<c10::Half>(r0, r1);
}
template <>
Vectorized<c10::Half> inline fmsub(
const Vectorized<c10::Half>& a,
const Vectorized<c10::Half>& b,
const Vectorized<c10::Half>& c) {
float16x8_t r0 = vfmsq_f16(c.get_low(), a.get_low(), b.get_low());
float16x8_t r1 = vfmsq_f16(c.get_high(), a.get_high(), b.get_high());
return Vectorized<c10::Half>(r0, r1);
}
#endif /* !defined(C10_MOBILE) && defined(__aarch64__) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) */
} // namespace CPU_CAPABILITY
} // namespace at::vec

File diff suppressed because it is too large

View File

@@ -0,0 +1,298 @@
#pragma once
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <ATen/cpu/vec/vec_mask.h>
namespace at::vec {
inline namespace CPU_CAPABILITY {
#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
template <typename T, int dst_n, typename mask_t, int mask_n>
struct VecMaskLoad<
T,
dst_n,
mask_t,
mask_n,
typename std::enable_if_t<
(mask_n == dst_n * 2 && dst_n >= 1) &&
(std::is_same_v<T, float> || std::is_same_v<T, int32_t>),
void>> {
static inline VectorizedN<T, dst_n> apply(
const T* ptr,
const VecMask<mask_t, mask_n>& vec_mask) {
VectorizedN<mask_t, 2> tmp_vec;
VectorizedN<T, dst_n> result;
for (int i = 0; i < dst_n; i++) {
tmp_vec[0] = vec_mask[2 * i];
tmp_vec[1] = vec_mask[2 * i + 1];
auto int64_mask = VecMask<mask_t, 2>(tmp_vec).template cast<int64_t, 2>();
auto int_mask = int64_mask.template cast<int, 1>()[0];
if constexpr (std::is_same_v<T, float>) {
result[i] = Vectorized<T>(
_mm256_maskload_ps(ptr + i * Vectorized<T>::size(), int_mask));
} else {
result[i] = Vectorized<T>(
_mm256_maskload_epi32(ptr + i * Vectorized<T>::size(), int_mask));
}
}
return result;
}
};
template <typename T, int dst_n, typename mask_t>
struct VecMaskLoad<
T,
dst_n,
mask_t,
dst_n,
typename std::enable_if_t<
std::is_same_v<T, float> || std::is_same_v<T, int32_t>,
void>> {
static inline VectorizedN<T, dst_n> apply(
const T* ptr,
const VecMask<mask_t, dst_n>& vec_mask) {
VectorizedN<T, dst_n> result;
#ifndef _MSC_VER
#pragma unroll
#endif
for (int i = 0; i < dst_n; i++) {
auto tmp_mask = VecMask<mask_t, 1>(vec_mask[i]);
auto int_mask = tmp_mask.template cast<int, 1>()[0];
if constexpr (std::is_same_v<T, float>) {
result[i] = Vectorized<T>(
_mm256_maskload_ps(ptr + i * Vectorized<T>::size(), int_mask));
} else {
result[i] = Vectorized<T>(
_mm256_maskload_epi32(ptr + i * Vectorized<T>::size(), int_mask));
}
}
return result;
}
};
template <typename T, typename mask_t>
struct VecMaskLoad<
T,
2,
mask_t,
1,
typename std::enable_if_t<
std::is_same_v<T, int64_t> || std::is_same_v<T, double>>> {
static inline VectorizedN<T, 2> apply(
const T* ptr,
const VecMask<mask_t, 1>& vec_mask) {
auto int64_mask = vec_mask.template cast<int64_t, 2>();
auto result = at::vec::VectorizedN<T, 2>();
if constexpr (std::is_same_v<T, double>) {
result[0] = _mm256_maskload_pd(ptr, int64_mask[0]);
result[1] = _mm256_maskload_pd(
ptr + at::vec::Vectorized<T>::size(), int64_mask[1]);
} else {
result[0] = _mm256_maskload_epi64(
reinterpret_cast<const long long*>(ptr), int64_mask[0]);
result[1] = _mm256_maskload_epi64(
reinterpret_cast<const long long*>(
ptr + at::vec::Vectorized<T>::size()),
int64_mask[1]);
}
return result;
}
};
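// Added explanatory note: _mm256_maskload_ps / _mm256_maskload_epi32 /
// _mm256_maskload_pd / _mm256_maskload_epi64 load element i only when the most
// significant bit of mask element i is set and produce zero otherwise, which
// is why the boolean VecMask is first cast to an integer vector of matching
// element width in the specializations above.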
// TODO: add specialization of VecMaskLoad for bfloat16/half and int8/uint8
template <int N>
struct VecMaskCast<float, N, int, N> {
static inline VecMask<float, N> apply(const VecMask<int, N>& vec_mask) {
VectorizedN<float, N> result;
#ifndef _MSC_VER
#pragma unroll
#endif
for (int i = 0; i < N; ++i) {
result[i] = _mm256_castsi256_ps(vec_mask[i]);
}
return result;
}
};
template <int N>
struct VecMaskCast<int, N, float, N> {
static inline VecMask<int, N> apply(const VecMask<float, N>& vec_mask) {
VectorizedN<int, N> result;
#ifndef _MSC_VER
#pragma unroll
#endif
for (int i = 0; i < N; ++i) {
result[i] = _mm256_castps_si256(vec_mask[i]);
}
return result;
}
};
template <int N>
struct VecMaskCast<int64_t, N, double, N> {
static inline VecMask<int64_t, N> apply(const VecMask<double, N>& vec_mask) {
VectorizedN<int64_t, N> result;
#ifndef _MSC_VER
#pragma unroll
#endif
for (int i = 0; i < N; ++i) {
result[i] = _mm256_castpd_si256(vec_mask[i]);
}
return result;
}
};
template <int N>
struct VecMaskCast<double, N, int64_t, N> {
static inline VecMask<double, N> apply(const VecMask<int64_t, N>& vec_mask) {
VectorizedN<double, N> result;
#ifndef _MSC_VER
#pragma unroll
#endif
for (int i = 0; i < N; ++i) {
result[i] = _mm256_castsi256_pd(vec_mask[i]);
}
return result;
}
};
template <int dst_n, typename mask_t, int mask_n>
struct VecMaskCast<
int64_t,
dst_n,
mask_t,
mask_n,
typename std::enable_if_t<
(dst_n == 2 * mask_n) &&
(std::is_same_v<mask_t, float> || std::is_same_v<mask_t, int>),
void>> {
static inline VecMask<int64_t, dst_n> apply(
const VecMask<mask_t, mask_n>& vec_mask) {
VectorizedN<int64_t, dst_n> result;
auto int_mask = vec_mask.template cast<int, mask_n>();
#ifndef _MSC_VER
#pragma unroll
#endif
for (int i = 0; i < mask_n; ++i) {
auto int64_vec =
convert<int64_t, 2, int, 1>(VectorizedN<int, 1>(int_mask[i]));
result[2 * i] = int64_vec[0];
result[2 * i + 1] = int64_vec[1];
}
return VecMask<int64_t, dst_n>(result);
}
};
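// Added explanatory note: the widening above goes through the generic
// convert<int64_t, 2, int, 1>() helper, so each 8-lane int mask becomes two
// 4-lane int64_t masks carrying the same true/false pattern.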
template <typename dst_t, int dst_n, int mask_n>
struct VecMaskCast<
dst_t,
dst_n,
int64_t,
mask_n,
typename std::enable_if_t<
(mask_n == 2 * dst_n) &&
(std::is_same_v<dst_t, float> || std::is_same_v<dst_t, int>),
void>> {
static inline VecMask<dst_t, dst_n> apply(
const VecMask<int64_t, mask_n>& vec_mask) {
VectorizedN<int, dst_n> result;
VectorizedN<int64_t, 2> int64_vec;
for (int i = 0; i < dst_n; ++i) {
int64_vec[0] = vec_mask[2 * i];
int64_vec[1] = vec_mask[2 * i + 1];
result[i] = convert<int, 1, int64_t, 2>(int64_vec);
}
return VecMask<int, dst_n>(result).template cast<dst_t, dst_n>();
}
};
template <>
struct VecMaskCast<double, 2, float, 1> {
static inline VecMask<double, 2> apply(const VecMask<float, 1>& vec_mask) {
auto int64_mask = VecMaskCast<int64_t, 2, float, 1>::apply(vec_mask);
return VecMaskCast<double, 2, int64_t, 2>::apply(int64_mask);
}
};
template <>
struct VecMaskCast<float, 1, double, 2> {
static inline VecMask<float, 1> apply(const VecMask<double, 2>& vec_mask) {
auto int64_mask = VecMaskCast<int64_t, 2, double, 2>::apply(vec_mask);
return VecMaskCast<float, 1, int64_t, 2>::apply(int64_mask);
}
};
template <>
inline bool VecMask<int, 1>::all_zero() const {
return _mm256_testz_si256(mask_[0], mask_[0]);
}
template <>
inline bool VecMask<int, 1>::is_masked(int i) const {
return _mm256_movemask_ps(_mm256_castsi256_ps(mask_[0])) & (1 << i);
}
template <>
inline bool VecMask<int, 1>::all_masked() const {
int mask = _mm256_movemask_ps(_mm256_castsi256_ps(mask_[0]));
return mask == 0xff;
}
template <int N>
struct VecMaskCheck<int64_t, N> {
static inline bool all_zero(const VectorizedN<int64_t, N>& vec_mask) {
bool all_zero = true;
for (int i = 0; i < N; ++i) {
all_zero = all_zero && (_mm256_testz_si256(vec_mask[i], vec_mask[i]) > 0);
if (!all_zero) {
return all_zero;
}
}
return all_zero;
}
static inline bool is_masked(const VectorizedN<int64_t, N>& vec_mask, int i) {
for (int j = 0; j < N; ++j) {
if (i < (j + 1) * 4) {
return _mm256_movemask_pd(_mm256_castsi256_pd(vec_mask[j])) &
(1 << (i - j * 4));
}
}
return false;
}
static inline bool all_masked(const VectorizedN<int64_t, N>& vec_mask) {
bool all_masked = true;
for (int i = 0; i < N; ++i) {
all_masked = all_masked &&
(_mm256_movemask_pd(_mm256_castsi256_pd(vec_mask[i])) == 0x0f);
if (!all_masked) {
return all_masked;
}
}
return all_masked;
}
};
#define VEC_MASK_METHOD_WITH_CAST_TO_INT( \
T, N, return_type, method, args_def, args) \
template <> \
inline return_type VecMask<T, N>::method args_def const { \
return cast<int, 1>().method args; \
}
VEC_MASK_METHOD_WITH_CAST_TO_INT(float, 1, bool, all_zero, (), ())
VEC_MASK_METHOD_WITH_CAST_TO_INT(int64_t, 2, bool, all_zero, (), ())
VEC_MASK_METHOD_WITH_CAST_TO_INT(float, 1, bool, is_masked, (int i), (i))
VEC_MASK_METHOD_WITH_CAST_TO_INT(int64_t, 2, bool, is_masked, (int i), (i))
VEC_MASK_METHOD_WITH_CAST_TO_INT(float, 1, bool, all_masked, (), ())
VEC_MASK_METHOD_WITH_CAST_TO_INT(int64_t, 2, bool, all_masked, (), ())
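// For reference (added comment), the first instantiation above expands to:
//   template <>
//   inline bool VecMask<float, 1>::all_zero() const {
//     return cast<int, 1>().all_zero();
//   }
// i.e. the float and int64_t masks reuse the int specializations defined
// earlier in this file.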
#undef VEC_MASK_METHOD_WITH_CAST_TO_INT
#endif
} // namespace CPU_CAPABILITY
} // namespace at::vec

File diff suppressed because it is too large

View File

@@ -0,0 +1,73 @@
#pragma once
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h>
#include <ATen/cpu/vec/vec_base.h>
#include <c10/util/irange.h>
namespace at {
namespace vec {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {
inline std::tuple<Vectorized<float>, Vectorized<float>> convert_bfloat16_float(
const Vectorized<BFloat16>& a) {
constexpr int64_t K = Vectorized<BFloat16>::size();
__at_align__ float arr[K];
__at_align__ BFloat16 arr2[K];
a.store(arr2);
convert(arr2, arr, K);
return std::make_tuple(
Vectorized<float>::loadu(arr),
Vectorized<float>::loadu(arr + Vectorized<float>::size()));
}
inline Vectorized<BFloat16> convert_float_bfloat16(
const Vectorized<float>& a,
const Vectorized<float>& b) {
constexpr int64_t K = Vectorized<BFloat16>::size();
__at_align__ float arr[K];
__at_align__ BFloat16 arr2[K];
a.store(arr);
b.store(arr + Vectorized<float>::size());
convert(arr, arr2, K);
return Vectorized<BFloat16>::loadu(arr2);
}
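// Added explanatory note: there is no dedicated VSX BFloat16 vector conversion
// here, so both helpers above round-trip through the scalar convert() using
// aligned stack buffers; this keeps the results correct at the cost of
// vectorization.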
inline void load_fp32_from_bf16(const c10::BFloat16* data, Vectorized<float>& out) {
__at_align__ float values[Vectorized<float>::size()];
for (const auto k : c10::irange(Vectorized<float>::size())) {
values[k] = data[k];
}
out = Vectorized<float>::loadu(values);
}
inline void load_fp32_from_bf16(
const c10::BFloat16* data,
Vectorized<float>& out1,
Vectorized<float>& out2) {
load_fp32_from_bf16(data, out1);
data += Vectorized<float>::size();
load_fp32_from_bf16(data, out2);
}
inline void load_fp32_from_fp16(const c10::Half* data, Vectorized<float>& out) {
__at_align__ float values[Vectorized<float>::size()];
for (const auto k : c10::irange(Vectorized<float>::size())) {
values[k] = data[k];
}
out = Vectorized<float>::loadu(values);
}
inline void load_fp32_from_fp16(
const c10::Half* data,
Vectorized<float>& out1,
Vectorized<float>& out2) {
load_fp32_from_fp16(data, out1);
data += Vectorized<float>::size();
load_fp32_from_fp16(data, out2);
}
} // namespace
} // namespace vec
} // namespace at

View File

@@ -0,0 +1,246 @@
#pragma once
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h>
// Note: header order is important here
#include <ATen/cpu/vec/vec256/vsx/vec256_double_vsx.h>
#include <ATen/cpu/vec/vec256/vsx/vec256_float_vsx.h>
#include <ATen/cpu/vec/vec256/vsx/vec256_int16_vsx.h>
#include <ATen/cpu/vec/vec256/vsx/vec256_int32_vsx.h>
#include <ATen/cpu/vec/vec256/vsx/vec256_int64_vsx.h>
#include <ATen/cpu/vec/vec256/vsx/vec256_qint32_vsx.h>
#include <ATen/cpu/vec/vec256/vsx/vec256_qint8_vsx.h>
#include <ATen/cpu/vec/vec256/vsx/vec256_quint8_vsx.h>
#include <ATen/cpu/vec/vec256/vsx/vec256_complex_float_vsx.h>
#include <ATen/cpu/vec/vec256/vsx/vec256_complex_double_vsx.h>
#include <ATen/cpu/vec/vec256/vsx/vec256_bfloat16_vsx.h>
namespace at {
namespace vec {
inline namespace CPU_CAPABILITY {
DEFINE_CLAMP_FUNCS(c10::quint8)
DEFINE_CLAMP_FUNCS(c10::qint8)
DEFINE_CLAMP_FUNCS(c10::qint32)
DEFINE_CLAMP_FUNCS(int16_t)
DEFINE_CLAMP_FUNCS(int32_t)
DEFINE_CLAMP_FUNCS(int64_t)
DEFINE_CLAMP_FUNCS(float)
DEFINE_CLAMP_FUNCS(double)
template <>
Vectorized<double> C10_ALWAYS_INLINE fmadd(
const Vectorized<double>& a,
const Vectorized<double>& b,
const Vectorized<double>& c) {
return Vectorized<double>{
vec_madd(a.vec0(), b.vec0(), c.vec0()),
vec_madd(a.vec1(), b.vec1(), c.vec1())};
}
template <>
Vectorized<int64_t> C10_ALWAYS_INLINE fmadd(
const Vectorized<int64_t>& a,
const Vectorized<int64_t>& b,
const Vectorized<int64_t>& c) {
return Vectorized<int64_t>{
a.vec0() * b.vec0() + c.vec0(), a.vec1() * b.vec1() + c.vec1()};
}
template <>
Vectorized<int32_t> C10_ALWAYS_INLINE fmadd(
const Vectorized<int32_t>& a,
const Vectorized<int32_t>& b,
const Vectorized<int32_t>& c) {
return Vectorized<int32_t>{
a.vec0() * b.vec0() + c.vec0(), a.vec1() * b.vec1() + c.vec1()};
}
template <>
Vectorized<int16_t> C10_ALWAYS_INLINE fmadd(
const Vectorized<int16_t>& a,
const Vectorized<int16_t>& b,
const Vectorized<int16_t>& c) {
return Vectorized<int16_t>{
a.vec0() * b.vec0() + c.vec0(), a.vec1() * b.vec1() + c.vec1()};
}
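// Added explanatory note: the integer fmadd specializations above compute a
// plain a * b + c per element; unlike the double version (vec_madd) there is
// no fused multiply-add instruction involved, they only satisfy the generic
// fmadd() interface.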
DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(float)
DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(double)
DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(int64_t)
DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(int32_t)
DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(int16_t)
template <>
Vectorized<int64_t> C10_ALWAYS_INLINE
convert_to_int_of_same_size<double>(const Vectorized<double>& src) {
return Vectorized<int64_t>{vec_signed(src.vec0()), vec_signed(src.vec1())};
}
template <>
Vectorized<int32_t> C10_ALWAYS_INLINE
convert_to_int_of_same_size<float>(
const Vectorized<float>& src) {
return Vectorized<int32_t>{vec_signed(src.vec0()), vec_signed(src.vec1())};
}
template <>
inline void convert(const int32_t* src, float* dst, int64_t n) {
// int32_t and float have same size
int64_t i;
for (i = 0; i <= (n - Vectorized<float>::size()); i += Vectorized<float>::size()) {
const int32_t* src_a = src + i;
float* dst_a = dst + i;
vint32 input_vec0 = vec_vsx_ld(offset0, reinterpret_cast<const vint32*>(src_a));
vint32 input_vec1 =
vec_vsx_ld(offset16, reinterpret_cast<const vint32*>(src_a));
vfloat32 c0 = vec_float(input_vec0);
vfloat32 c1 = vec_float(input_vec1);
vec_vsx_st(c0, offset0, dst_a);
vec_vsx_st(c1, offset16, dst_a);
}
for (; i < n; i++) {
dst[i] = static_cast<float>(src[i]);
}
}
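// Added note (assumption): offset0 and offset16 are presumably the 0- and
// 16-byte offsets defined in vsx_helpers.h, addressing the two 128-bit halves
// handed to vec_vsx_ld / vec_vsx_st above.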
template <>
inline void convert(const int64_t* src, double* dst, int64_t n) {
int64_t i;
for (i = 0; i <= (n - Vectorized<double>::size()); i += Vectorized<double>::size()) {
const int64_t* src_a = src + i;
double* dst_a = dst + i;
vint64 input_vec0 =
vec_vsx_ld(offset0, reinterpret_cast<const vint64*>(src_a));
vint64 input_vec1 =
vec_vsx_ld(offset16, reinterpret_cast<const vint64*>(src_a));
vfloat64 c0 = vec_double(input_vec0);
vfloat64 c1 = vec_double(input_vec1);
vec_vsx_st(c0, offset0, reinterpret_cast<double*>(dst_a));
vec_vsx_st(c1, offset16, reinterpret_cast<double*>(dst_a));
}
for (; i < n; i++) {
dst[i] = static_cast<double>(src[i]);
}
}
// Generic implementation to fix compiler error
// TODO: Add optimized version for ppc64
inline std::tuple<Vectorized<float>, Vectorized<float>> convert_half_float(
const Vectorized<Half>& a) {
constexpr int64_t K = Vectorized<Half>::size();
__at_align__ float arr[K];
__at_align__ Half arr2[K];
a.store(arr2);
convert(arr2, arr, K);
return std::make_tuple(
Vectorized<float>::loadu(arr),
Vectorized<float>::loadu(arr + Vectorized<float>::size()));
}
inline Vectorized<Half> convert_float_half(
const Vectorized<float>& a, const Vectorized<float>& b) {
constexpr int64_t K = Vectorized<Half>::size();
__at_align__ float arr[K];
__at_align__ Half arr2[K];
a.store(arr);
b.store(arr + Vectorized<float>::size());
convert(arr, arr2, K);
return Vectorized<Half>::loadu(arr2);
}
template <>
std::pair<Vectorized<double>, Vectorized<double>> inline interleave2<double>(
const Vectorized<double>& a,
const Vectorized<double>& b) {
// inputs:
// a = {a0, a1, a2, a3}
// b = {b0, b1, b2, b3}
vfloat64 ab00 = vec_xxpermdi(a.vec0(), b.vec0(), 0);
vfloat64 ab11 = vec_xxpermdi(a.vec0(), b.vec0(), 3);
vfloat64 ab2_00 = vec_xxpermdi(a.vec1(), b.vec1(), 0);
vfloat64 ab2_11 = vec_xxpermdi(a.vec1(), b.vec1(), 3);
// return {a0, b0, a1, b1}
// {a2, b2, a3, b3}
return std::make_pair(
Vectorized<double>{ab00, ab11}, Vectorized<double>{ab2_00, ab2_11});
}
template <>
std::pair<Vectorized<double>, Vectorized<double>> inline deinterleave2<double>(
const Vectorized<double>& a,
const Vectorized<double>& b) {
// inputs:
// a = {a0, b0, a1, b1}
// b = {a2, b2, a3, b3}
vfloat64 aa01 = vec_xxpermdi(a.vec0(), a.vec1(), 0);
vfloat64 aa23 = vec_xxpermdi(b.vec0(), b.vec1(), 0);
vfloat64 bb_01 = vec_xxpermdi(a.vec0(), a.vec1(), 3);
vfloat64 bb_23 = vec_xxpermdi(b.vec0(), b.vec1(), 3);
// swap lanes:
// return {a0, a1, a2, a3}
// {b0, b1, b2, b3}
return std::make_pair(
Vectorized<double>{aa01, aa23}, Vectorized<double>{bb_01, bb_23});
}
template <>
std::pair<Vectorized<float>, Vectorized<float>> inline interleave2<float>(
const Vectorized<float>& a,
const Vectorized<float>& b) {
// inputs:
// a = {a0, a1, a2, a3,, a4, a5, a6, a7}
// b = {b0, b1, b2, b3,, b4, b5, b6, b7}
vfloat32 ab0011 = vec_mergeh(a.vec0(), b.vec0());
vfloat32 ab2233 = vec_mergel(a.vec0(), b.vec0());
vfloat32 ab2_0011 = vec_mergeh(a.vec1(), b.vec1());
vfloat32 ab2_2233 = vec_mergel(a.vec1(), b.vec1());
// group cols crossing lanes:
// return {a0, b0, a1, b1,, a2, b2, a3, b3}
// {a4, b4, a5, b5,, a6, b6, a7, b7}
return std::make_pair(
Vectorized<float>{ab0011, ab2233}, Vectorized<float>{ab2_0011, ab2_2233});
}
template <>
std::pair<Vectorized<float>, Vectorized<float>> inline deinterleave2<float>(
const Vectorized<float>& a,
const Vectorized<float>& b) {
// inputs:
// a = {a0, b0, a1, b1,, a2, b2, a3, b3}
// b = {a4, b4, a5, b5,, a6, b6, a7, b7}
// {a0,a2,b0,b2} {a1,a3,b1,b3}
vfloat32 a0a2b0b2 = vec_mergeh(a.vec0(), a.vec1());
vfloat32 a1a3b1b3 = vec_mergel(a.vec0(), a.vec1());
vfloat32 aa0123 = vec_mergeh(a0a2b0b2, a1a3b1b3);
vfloat32 bb0123 = vec_mergel(a0a2b0b2, a1a3b1b3);
vfloat32 a0a2b0b2_2 = vec_mergeh(b.vec0(), b.vec1());
vfloat32 a1a3b1b3_2 = vec_mergel(b.vec0(), b.vec1());
vfloat32 aa0123_2 = vec_mergeh(a0a2b0b2_2, a1a3b1b3_2);
vfloat32 bb0123_2 = vec_mergel(a0a2b0b2_2, a1a3b1b3_2);
// it could be done with vec_perm, too
// swap lanes:
// return {a0, a1, a2, a3,, a4, a5, a6, a7}
// {b0, b1, b2, b3,, b4, b5, b6, b7}
return std::make_pair(
Vectorized<float>{aa0123, aa0123_2}, Vectorized<float>{bb0123, bb0123_2});
}
} // namespace
} // namespace vec
} // namespace at

View File

@@ -0,0 +1,584 @@
#pragma once
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h>
#include <c10/util/complex.h>
#include <c10/util/irange.h>
namespace at {
namespace vec {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {
using ComplexDbl = c10::complex<double>;
template <>
class Vectorized<ComplexDbl> {
union {
struct {
vfloat64 _vec0;
vfloat64 _vec1;
};
struct {
vbool64 _vecb0;
vbool64 _vecb1;
};
} __attribute__((__may_alias__));
public:
using value_type = ComplexDbl;
using vec_internal_type = vfloat64;
using vec_internal_mask_type = vbool64;
using size_type = int;
static constexpr size_type size() {
return 2;
}
Vectorized() {}
C10_ALWAYS_INLINE Vectorized(vfloat64 v) : _vec0{v}, _vec1{v} {}
C10_ALWAYS_INLINE Vectorized(vbool64 vmask) : _vecb0{vmask}, _vecb1{vmask} {}
C10_ALWAYS_INLINE Vectorized(vfloat64 v1, vfloat64 v2) : _vec0{v1}, _vec1{v2} {}
C10_ALWAYS_INLINE Vectorized(vbool64 v1, vbool64 v2) : _vecb0{v1}, _vecb1{v2} {}
Vectorized(ComplexDbl val) {
double real_value = val.real();
double imag_value = val.imag();
_vec0 = vfloat64{real_value, imag_value};
_vec1 = vfloat64{real_value, imag_value};
}
Vectorized(ComplexDbl val1, ComplexDbl val2) {
_vec0 = vfloat64{val1.real(), val1.imag()};
_vec1 = vfloat64{val2.real(), val2.imag()};
}
C10_ALWAYS_INLINE const vec_internal_type& vec0() const {
return _vec0;
}
C10_ALWAYS_INLINE const vec_internal_type& vec1() const {
return _vec1;
}
template <int64_t mask>
static std::enable_if_t<blendChoiceComplexDbl(mask) == 0, Vectorized<ComplexDbl>>
C10_ALWAYS_INLINE
blend(const Vectorized<ComplexDbl>& a, const Vectorized<ComplexDbl>& b) {
return a;
}
template <int64_t mask>
static std::enable_if_t<blendChoiceComplexDbl(mask) == 1, Vectorized<ComplexDbl>>
C10_ALWAYS_INLINE
blend(const Vectorized<ComplexDbl>& a, const Vectorized<ComplexDbl>& b) {
return b;
}
template <int64_t mask>
static std::enable_if_t<blendChoiceComplexDbl(mask) == 2, Vectorized<ComplexDbl>>
C10_ALWAYS_INLINE
blend(const Vectorized<ComplexDbl>& a, const Vectorized<ComplexDbl>& b) {
return {b._vec0, a._vec1};
}
template <int64_t mask>
static std::enable_if_t<blendChoiceComplexDbl(mask) == 3, Vectorized<ComplexDbl>>
C10_ALWAYS_INLINE
blend(const Vectorized<ComplexDbl>& a, const Vectorized<ComplexDbl>& b) {
return {a._vec0, b._vec1};
}
template <int64_t mask>
static Vectorized<ComplexDbl> C10_ALWAYS_INLINE
el_blend(const Vectorized<ComplexDbl>& a, const Vectorized<ComplexDbl>& b) {
const vbool64 mask_1st = VsxDblMask1(mask);
const vbool64 mask_2nd = VsxDblMask2(mask);
return {
(vfloat64)vec_sel(a._vec0, b._vec0, mask_1st),
(vfloat64)vec_sel(a._vec1, b._vec1, mask_2nd)};
}
static Vectorized<ComplexDbl> blendv(
const Vectorized<ComplexDbl>& a,
const Vectorized<ComplexDbl>& b,
const Vectorized<ComplexDbl>& mask) {
// convert std::complex<V> index mask to V index mask: xy -> xxyy
auto mask_complex =
Vectorized<ComplexDbl>(vec_splat(mask._vec0, 0), vec_splat(mask._vec1, 0));
return {
vec_sel(a._vec0, b._vec0, mask_complex._vecb0),
vec_sel(a._vec1, b._vec1, mask_complex._vecb1)};
}
static Vectorized<ComplexDbl> C10_ALWAYS_INLINE elwise_blendv(
const Vectorized<ComplexDbl>& a,
const Vectorized<ComplexDbl>& b,
const Vectorized<ComplexDbl>& mask) {
return {
vec_sel(a._vec0, b._vec0, mask._vecb0),
vec_sel(a._vec1, b._vec1, mask._vecb1)};
}
template <typename step_t>
static Vectorized<ComplexDbl> arange(
ComplexDbl base = 0.,
step_t step = static_cast<step_t>(1)) {
return Vectorized<ComplexDbl>(base, base + step);
}
static Vectorized<ComplexDbl> set(
const Vectorized<ComplexDbl>& a,
const Vectorized<ComplexDbl>& b,
int64_t count = size()) {
switch (count) {
case 0:
return a;
case 1:
return blend<1>(a, b);
}
return b;
}
static Vectorized<value_type> C10_ALWAYS_INLINE
loadu(const void* ptr, int count = size()) {
if (count == size()) {
return {
vec_vsx_ld(offset0, reinterpret_cast<const double*>(ptr)),
vec_vsx_ld(offset16, reinterpret_cast<const double*>(ptr))};
}
__at_align__ value_type tmp_values[size()] = {};
std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type));
return {
vec_vsx_ld(offset0, reinterpret_cast<const double*>(tmp_values)),
vec_vsx_ld(offset16, reinterpret_cast<const double*>(tmp_values))};
}
void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const {
if (count == size()) {
vec_vsx_st(_vec0, offset0, reinterpret_cast<double*>(ptr));
vec_vsx_st(_vec1, offset16, reinterpret_cast<double*>(ptr));
} else if (count > 0) {
__at_align__ value_type tmp_values[size()];
vec_vsx_st(_vec0, offset0, reinterpret_cast<double*>(tmp_values));
vec_vsx_st(_vec1, offset16, reinterpret_cast<double*>(tmp_values));
std::memcpy(
ptr, tmp_values, std::min(count, size()) * sizeof(value_type));
}
}
const ComplexDbl& operator[](int idx) const = delete;
ComplexDbl& operator[](int idx) = delete;
Vectorized<ComplexDbl> map(ComplexDbl (*const f)(ComplexDbl)) const {
__at_align__ ComplexDbl tmp[size()];
store(tmp);
for (const auto i : c10::irange(size())) {
tmp[i] = f(tmp[i]);
}
return loadu(tmp);
}
Vectorized<ComplexDbl> map(ComplexDbl (*const f)(const ComplexDbl&)) const {
__at_align__ ComplexDbl tmp[size()];
store(tmp);
for (const auto i : c10::irange(size())) {
tmp[i] = f(tmp[i]);
}
return loadu(tmp);
}
Vectorized<ComplexDbl> el_swapped() const {
vfloat64 v0 = vec_xxpermdi(_vec0, _vec0, 2);
vfloat64 v1 = vec_xxpermdi(_vec1, _vec1, 2);
return {v0, v1};
}
Vectorized<ComplexDbl> el_madd(
const Vectorized<ComplexDbl>& multiplier,
const Vectorized<ComplexDbl>& val) const {
return {
vec_madd(_vec0, multiplier._vec0, val._vec0),
vec_madd(_vec1, multiplier._vec1, val._vec1)};
}
Vectorized<ComplexDbl> el_mergeo() const {
vfloat64 v0 = vec_splat(_vec0, 1);
vfloat64 v1 = vec_splat(_vec1, 1);
return {v0, v1};
}
Vectorized<ComplexDbl> el_mergee() const {
vfloat64 v0 = vec_splat(_vec0, 0);
vfloat64 v1 = vec_splat(_vec1, 0);
return {v0, v1};
}
static Vectorized<ComplexDbl> el_mergee(
Vectorized<ComplexDbl>& first,
Vectorized<ComplexDbl>& second) {
return {
vec_mergeh(first._vec0, second._vec0),
vec_mergeh(first._vec1, second._vec1)};
}
static Vectorized<ComplexDbl> el_mergeo(
Vectorized<ComplexDbl>& first,
Vectorized<ComplexDbl>& second) {
return {
vec_mergel(first._vec0, second._vec0),
vec_mergel(first._vec1, second._vec1)};
}
Vectorized<ComplexDbl> abs_2_() const {
auto a = (*this).elwise_mult(*this);
auto permuted = a.el_swapped();
a = a + permuted;
return a;
}
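// Worked example (added comment): for a slot holding z = a + bi,
// elwise_mult(*this) yields {a*a, b*b}, el_swapped() yields {b*b, a*a}, and
// their sum leaves |z|^2 = a*a + b*b replicated in both lanes of the slot.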
Vectorized<ComplexDbl> abs_() const {
auto vi = el_mergeo();
auto vr = el_mergee();
return {Sleef_hypotd2_u05vsx(vr._vec0, vi._vec0), Sleef_hypotd2_u05vsx(vr._vec1, vi._vec1)};
}
Vectorized<ComplexDbl> abs() const {
return abs_() & vd_real_mask;
}
Vectorized<ComplexDbl> angle_() const {
// angle = atan2(b, a)
// auto b_a = _mm256_permute_pd(values, 0x05); // b a
// return Sleef_atan2d4_u10(values, b_a); // 90-angle angle
Vectorized<ComplexDbl> ret;
ret._vec0[0] = std::atan2(_vec0[1], _vec0[0]);
ret._vec1[0] = std::atan2(_vec1[1], _vec1[0]);
return ret;
}
Vectorized<ComplexDbl> angle() const {
return angle_() & vd_real_mask;
}
Vectorized<ComplexDbl> real_() const {
return *this & vd_real_mask;
}
Vectorized<ComplexDbl> real() const {
return *this & vd_real_mask;
}
Vectorized<ComplexDbl> imag_() const {
return *this & vd_imag_mask;
}
Vectorized<ComplexDbl> imag() const {
return imag_().el_swapped();
}
Vectorized<ComplexDbl> conj_() const {
return *this ^ vd_isign_mask;
}
Vectorized<ComplexDbl> conj() const {
return *this ^ vd_isign_mask;
}
Vectorized<ComplexDbl> log() const {
// Most trigonometric ops use the log() op to improve complex number
// performance.
return map(std::log);
}
Vectorized<ComplexDbl> log2() const {
// log2eB_inv
auto ret = log();
return ret.elwise_mult(vd_log2e_inv);
}
Vectorized<ComplexDbl> log10() const {
auto ret = log();
return ret.elwise_mult(vd_log10e_inv);
}
Vectorized<ComplexDbl> log1p() const {
return map(std::log1p);
}
Vectorized<ComplexDbl> asin() const {
// asin(x)
// = -i*ln(iz + sqrt(1 -z^2))
// = -i*ln((ai - b) + sqrt(1 - (a + bi)*(a + bi)))
// = -i*ln((-b + ai) + sqrt(1 - (a**2 - b**2) - 2*abi))
auto conj = conj_();
auto b_a = conj.el_swapped();
auto ab = conj.elwise_mult(b_a);
auto im = ab + ab;
auto val_2 = (*this).elwise_mult(*this);
auto val_2_swapped = val_2.el_swapped();
auto re = horizontal_sub(val_2, val_2_swapped);
re = Vectorized<ComplexDbl>(vd_one) - re;
auto root = el_blend<0x0A>(re, im).sqrt();
auto ln = (b_a + root).log();
return ln.el_swapped().conj();
}
Vectorized<ComplexDbl> acos() const {
// acos(x) = pi/2 - asin(x)
return Vectorized(vd_pi_2) - asin();
}
Vectorized<ComplexDbl> atan() const {
// atan(x) = i/2 * ln((i + z)/(i - z))
auto ione = Vectorized(vd_imag_one);
auto sum = ione + *this;
auto sub = ione - *this;
auto ln = (sum / sub).log(); // ln((i + z)/(i - z))
return ln * vd_imag_half; // i/2*ln()
}
Vectorized<ComplexDbl> atanh() const {
return map(std::atanh);
}
Vectorized<ComplexDbl> sin() const {
return map(std::sin);
}
Vectorized<ComplexDbl> sinh() const {
return map(std::sinh);
}
Vectorized<ComplexDbl> cos() const {
return map(std::cos);
}
Vectorized<ComplexDbl> cosh() const {
return map(std::cosh);
}
Vectorized<ComplexDbl> tan() const {
return map(std::tan);
}
Vectorized<ComplexDbl> tanh() const {
return map(std::tanh);
}
Vectorized<ComplexDbl> ceil() const {
return {vec_ceil(_vec0), vec_ceil(_vec1)};
}
Vectorized<ComplexDbl> floor() const {
return {vec_floor(_vec0), vec_floor(_vec1)};
}
Vectorized<ComplexDbl> neg() const {
auto z = Vectorized<ComplexDbl>(vd_zero);
return z - *this;
}
Vectorized<ComplexDbl> round() const {
return {vec_rint(_vec0), vec_rint(_vec1)};
}
Vectorized<ComplexDbl> trunc() const {
return {vec_trunc(_vec0), vec_trunc(_vec1)};
}
Vectorized<ComplexDbl> elwise_sqrt() const {
return {vec_sqrt(_vec0), vec_sqrt(_vec1)};
}
Vectorized<ComplexDbl> sqrt() const {
return map(std::sqrt);
}
Vectorized<ComplexDbl> reciprocal() const {
// 1 / (c + di) = (c - di) / (c^2 + d^2)
// re = c/abs_2()
// im = -d/abs_2()
auto c_d = *this ^ vd_isign_mask; // c -d
auto abs = abs_2_();
return c_d.elwise_div(abs);
}
Vectorized<ComplexDbl> rsqrt() const {
return sqrt().reciprocal();
}
static Vectorized<ComplexDbl> horizontal_add(
Vectorized<ComplexDbl>& first,
Vectorized<ComplexDbl>& second) {
// Operates on individual floats, see _mm_hadd_ps
// {f0+f1, s0+s1, f2+f3, s2+s3, ...}
// i.e. it sums the re and im of each value and interleaves first and second:
// {f_re0 + f_im0, s_re0 + s_im0, f_re1 + f_im1, s_re1 + s_im1, ...}
return el_mergee(first, second) + el_mergeo(first, second);
}
static Vectorized<ComplexDbl> horizontal_sub(
Vectorized<ComplexDbl>& first,
Vectorized<ComplexDbl>& second) {
// we will simulate it differently with 6 instructions total
// let's permute second so that we can add it, getting horizontal sums
auto first_perm = first.el_swapped(); // 2perm
auto second_perm = second.el_swapped(); // 2perm
// sum
auto first_ret = first - first_perm; // 2sub
auto second_ret = second - second_perm; // 2 sub
// now let's choose the even elements
return el_mergee(first_ret, second_ret); // 2 mergee's
}
Vectorized<ComplexDbl> inline operator*(const Vectorized<ComplexDbl>& b) const {
//(a + bi) * (c + di) = (ac - bd) + (ad + bc)i
#if 1
// this is more vsx friendly than simulating horizontal from x86
auto vi = b.el_mergeo();
auto vr = b.el_mergee();
vi = vi ^ vd_rsign_mask;
auto ret = elwise_mult(vr);
auto vx_swapped = el_swapped();
ret = vx_swapped.el_madd(vi, ret);
#else
auto ac_bd = elwise_mult(b);
auto d_c = b.el_swapped();
d_c = d_c ^ vd_isign_mask;
auto ad_bc = elwise_mult(d_c);
auto ret = horizontal_sub(ac_bd, ad_bc);
#endif
return ret;
}
Vectorized<ComplexDbl> inline operator/(const Vectorized<ComplexDbl>& b) const {
// re + im*i = (a + bi) / (c + di)
// re = (ac + bd)/abs_2()
// im = (bc - ad)/abs_2()
auto fabs_cd = Vectorized{
vec_andc(b._vec0, vd_sign_mask),
vec_andc(b._vec1, vd_sign_mask)}; // |c| |d|
auto fabs_dc = fabs_cd.el_swapped(); // |d| |c|
auto scale = fabs_cd.elwise_max(fabs_dc); // sc = max(|c|, |d|)
auto a2 = elwise_div(scale); // a/sc b/sc
auto b2 = b.elwise_div(scale); // c/sc d/sc
auto acbd2 = a2.elwise_mult(b2); // ac/sc^2 bd/sc^2
auto dc2 = b2.el_swapped(); // d/sc c/sc
dc2 = dc2 ^ vd_rsign_mask; // -d/sc c/sc
auto adbc2 = a2.elwise_mult(dc2); // -ad/sc^2 bc/sc^2
auto ret = horizontal_add(acbd2, adbc2); // (ac+bd)/sc^2 (bc-ad)/sc^2
auto denom2 = b2.abs_2_(); // (c^2+d^2)/sc^2 (c^2+d^2)/sc^2
ret = ret.elwise_div(denom2);
return ret;
}
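// Rationale (added comment): dividing numerator and denominator by
// sc = max(|c|, |d|) keeps the intermediate products close to 1 in magnitude,
// so e.g. (1 + 0i) / (1e200 + 1e200i) does not overflow in c*c + d*d the way
// the unscaled (ac+bd)/(c^2+d^2) formulation would.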
Vectorized<ComplexDbl> exp() const {
return map(std::exp);
}
Vectorized<ComplexDbl> exp2() const {
return map(exp2_impl);
}
Vectorized<ComplexDbl> expm1() const {
return map(std::expm1);
}
Vectorized<ComplexDbl> pow(const Vectorized<ComplexDbl>& exp) const {
__at_align__ ComplexDbl x_tmp[size()];
__at_align__ ComplexDbl y_tmp[size()];
store(x_tmp);
exp.store(y_tmp);
for (const auto i : c10::irange(size())) {
x_tmp[i] = std::pow(x_tmp[i], y_tmp[i]);
}
return loadu(x_tmp);
}
Vectorized<ComplexDbl> sgn() const {
return map(at::native::sgn_impl);
}
Vectorized<ComplexDbl> operator<(const Vectorized<ComplexDbl>& other) const {
TORCH_CHECK(false, "not supported for complex numbers");
}
Vectorized<ComplexDbl> operator<=(const Vectorized<ComplexDbl>& other) const {
TORCH_CHECK(false, "not supported for complex numbers");
}
Vectorized<ComplexDbl> operator>(const Vectorized<ComplexDbl>& other) const {
TORCH_CHECK(false, "not supported for complex numbers");
}
Vectorized<ComplexDbl> operator>=(const Vectorized<ComplexDbl>& other) const {
TORCH_CHECK(false, "not supported for complex numbers");
}
Vectorized<ComplexDbl> eq(const Vectorized<ComplexDbl>& other) const {
auto eq = (*this == other); // compares real and imag individually
// If both real numbers and imag numbers are equal, then the complex numbers are equal
return (eq.real() & eq.imag()) & vd_one;
}
Vectorized<ComplexDbl> ne(const Vectorized<ComplexDbl>& other) const {
auto ne = (*this != other); // compares real and imag individually
// If either real numbers or imag numbers are not equal, then the complex numbers are not equal
return (ne.real() | ne.imag()) & vd_one;
}
DEFINE_MEMBER_OP(operator==, ComplexDbl, vec_cmpeq)
DEFINE_MEMBER_OP(operator!=, ComplexDbl, vec_cmpne)
DEFINE_MEMBER_OP(operator+, ComplexDbl, vec_add)
DEFINE_MEMBER_OP(operator-, ComplexDbl, vec_sub)
DEFINE_MEMBER_OP(operator&, ComplexDbl, vec_and)
DEFINE_MEMBER_OP(operator|, ComplexDbl, vec_or)
DEFINE_MEMBER_OP(operator^, ComplexDbl, vec_xor)
// elementwise helpers
DEFINE_MEMBER_OP(elwise_mult, ComplexDbl, vec_mul)
DEFINE_MEMBER_OP(elwise_div, ComplexDbl, vec_div)
DEFINE_MEMBER_OP(elwise_gt, ComplexDbl, vec_cmpgt)
DEFINE_MEMBER_OP(elwise_ge, ComplexDbl, vec_cmpge)
DEFINE_MEMBER_OP(elwise_lt, ComplexDbl, vec_cmplt)
DEFINE_MEMBER_OP(elwise_le, ComplexDbl, vec_cmple)
DEFINE_MEMBER_OP(elwise_max, ComplexDbl, vec_max)
};
template <>
Vectorized<ComplexDbl> inline maximum(
const Vectorized<ComplexDbl>& a,
const Vectorized<ComplexDbl>& b) {
auto abs_a = a.abs_2_();
auto abs_b = b.abs_2_();
// auto mask = _mm256_cmp_ps(abs_a, abs_b, _CMP_LT_OQ);
// auto max = _mm256_blendv_ps(a, b, mask);
auto mask = abs_a.elwise_lt(abs_b);
auto max = Vectorized<ComplexDbl>::elwise_blendv(a, b, mask);
return max;
// Exploit the fact that all-ones is a NaN.
// auto isnan = _mm256_cmp_ps(abs_a, abs_b, _CMP_UNORD_Q);
// return _mm256_or_ps(max, isnan);
}
template <>
Vectorized<ComplexDbl> inline minimum(
const Vectorized<ComplexDbl>& a,
const Vectorized<ComplexDbl>& b) {
auto abs_a = a.abs_2_();
auto abs_b = b.abs_2_();
// auto mask = _mm256_cmp_ps(abs_a, abs_b, _CMP_GT_OQ);
// auto min = _mm256_blendv_ps(a, b, mask);
auto mask = abs_a.elwise_gt(abs_b);
auto min = Vectorized<ComplexDbl>::elwise_blendv(a, b, mask);
return min;
// Exploit the fact that all-ones is a NaN.
// auto isnan = _mm256_cmp_ps(abs_a, abs_b, _CMP_UNORD_Q);
// return _mm256_or_ps(min, isnan);
}
template <>
Vectorized<ComplexDbl> C10_ALWAYS_INLINE operator+(const Vectorized<ComplexDbl>& a, const Vectorized<ComplexDbl>& b) {
return Vectorized<ComplexDbl>{vec_add(a.vec0(), b.vec0()), vec_add(a.vec1(), b.vec1())};
}
template <>
Vectorized<ComplexDbl> C10_ALWAYS_INLINE operator-(const Vectorized<ComplexDbl>& a, const Vectorized<ComplexDbl>& b) {
return Vectorized<ComplexDbl>{vec_sub(a.vec0(), b.vec0()), vec_sub(a.vec1(), b.vec1())};
}
template <>
Vectorized<ComplexDbl> C10_ALWAYS_INLINE operator&(const Vectorized<ComplexDbl>& a, const Vectorized<ComplexDbl>& b) {
return Vectorized<ComplexDbl>{vec_and(a.vec0(), b.vec0()), vec_and(a.vec1(), b.vec1())};
}
template <>
Vectorized<ComplexDbl> C10_ALWAYS_INLINE operator|(const Vectorized<ComplexDbl>& a, const Vectorized<ComplexDbl>& b) {
return Vectorized<ComplexDbl>{vec_or(a.vec0(), b.vec0()), vec_or(a.vec1(), b.vec1())};
}
template <>
Vectorized<ComplexDbl> C10_ALWAYS_INLINE operator^(const Vectorized<ComplexDbl>& a, const Vectorized<ComplexDbl>& b) {
return Vectorized<ComplexDbl>{vec_xor(a.vec0(), b.vec0()), vec_xor(a.vec1(), b.vec1())};
}
} // namespace
} // namespace vec
} // namespace at

View File

@@ -0,0 +1,660 @@
#pragma once
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h>
#include <c10/util/complex.h>
#include <c10/util/irange.h>
namespace at {
namespace vec {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {
using ComplexFlt = c10::complex<float>;
template <>
class Vectorized<ComplexFlt> {
private:
union {
struct {
vfloat32 _vec0;
vfloat32 _vec1;
};
struct {
vbool32 _vecb0;
vbool32 _vecb1;
};
} __attribute__((__may_alias__));
public:
using value_type = ComplexFlt;
using vec_internal_type = vfloat32;
using vec_internal_mask_type = vbool32;
using size_type = int;
static constexpr size_type size() {
return 4;
}
Vectorized() {}
C10_ALWAYS_INLINE Vectorized(vfloat32 v) : _vec0{v}, _vec1{v} {}
C10_ALWAYS_INLINE Vectorized(vbool32 vmask) : _vecb0{vmask}, _vecb1{vmask} {}
C10_ALWAYS_INLINE Vectorized(vfloat32 v1, vfloat32 v2) : _vec0{v1}, _vec1{v2} {}
C10_ALWAYS_INLINE Vectorized(vbool32 v1, vbool32 v2) : _vecb0{v1}, _vecb1{v2} {}
Vectorized(ComplexFlt val) {
float real_value = val.real();
float imag_value = val.imag();
_vec0 = vfloat32{real_value, imag_value, real_value, imag_value};
_vec1 = vfloat32{real_value, imag_value, real_value, imag_value};
}
Vectorized(ComplexFlt val1, ComplexFlt val2, ComplexFlt val3, ComplexFlt val4) {
_vec0 = vfloat32{val1.real(), val1.imag(), val2.real(), val2.imag()};
_vec1 = vfloat32{val3.real(), val3.imag(), val4.real(), val4.imag()};
}
C10_ALWAYS_INLINE const vec_internal_type& vec0() const {
return _vec0;
}
C10_ALWAYS_INLINE const vec_internal_type& vec1() const {
return _vec1;
}
template <uint64_t mask>
static std::enable_if_t<blendChoiceComplex(mask) == 0, Vectorized<ComplexFlt>>
C10_ALWAYS_INLINE
blend(const Vectorized<ComplexFlt>& a, const Vectorized<ComplexFlt>& b) {
return a;
}
template <uint64_t mask>
static std::enable_if_t<blendChoiceComplex(mask) == 1, Vectorized<ComplexFlt>>
C10_ALWAYS_INLINE
blend(const Vectorized<ComplexFlt>& a, const Vectorized<ComplexFlt>& b) {
return b;
}
template <uint64_t mask>
static std::enable_if_t<blendChoiceComplex(mask) == 2, Vectorized<ComplexFlt>>
C10_ALWAYS_INLINE
blend(const Vectorized<ComplexFlt>& a, const Vectorized<ComplexFlt>& b) {
return {b._vec0, a._vec1};
}
template <uint64_t mask>
static std::enable_if_t<blendChoiceComplex(mask) == 3, Vectorized<ComplexFlt>>
C10_ALWAYS_INLINE
blend(const Vectorized<ComplexFlt>& a, const Vectorized<ComplexFlt>& b) {
return {a._vec0, b._vec1};
}
template <uint64_t mask>
static std::enable_if_t<blendChoiceComplex(mask) == 4, Vectorized<ComplexFlt>>
C10_ALWAYS_INLINE
blend(const Vectorized<ComplexFlt>& a, const Vectorized<ComplexFlt>& b) {
const vbool32 mask_1st = VsxComplexMask1(mask);
return {(vfloat32)vec_sel(a._vec0, b._vec0, mask_1st), a._vec1};
}
template <uint64_t mask>
static std::enable_if_t<blendChoiceComplex(mask) == 5, Vectorized<ComplexFlt>>
C10_ALWAYS_INLINE
blend(const Vectorized<ComplexFlt>& a, const Vectorized<ComplexFlt>& b) {
const vbool32 mask_1st = VsxComplexMask1(mask);
return {(vfloat32)vec_sel(a._vec0, b._vec0, mask_1st), b._vec1};
}
template <uint64_t mask>
static std::enable_if_t<blendChoiceComplex(mask) == 6, Vectorized<ComplexFlt>>
C10_ALWAYS_INLINE
blend(const Vectorized<ComplexFlt>& a, const Vectorized<ComplexFlt>& b) {
const vbool32 mask_2nd = VsxComplexMask2(mask);
// generated masks
return {a._vec0, (vfloat32)vec_sel(a._vec1, b._vec1, mask_2nd)};
}
template <uint64_t mask>
static std::enable_if_t<blendChoiceComplex(mask) == 7, Vectorized<ComplexFlt>>
C10_ALWAYS_INLINE
blend(const Vectorized<ComplexFlt>& a, const Vectorized<ComplexFlt>& b) {
const vbool32 mask_2nd = VsxComplexMask2(mask);
// generated masks
return {b._vec0, (vfloat32)vec_sel(a._vec1, b._vec1, mask_2nd)};
}
template <uint64_t mask>
static std::enable_if_t<blendChoiceComplex(mask) == 8, Vectorized<ComplexFlt>>
C10_ALWAYS_INLINE
blend(const Vectorized<ComplexFlt>& a, const Vectorized<ComplexFlt>& b) {
const vbool32 mask_1st = VsxComplexMask1(mask);
const vbool32 mask_2nd = VsxComplexMask2(mask);
return {
(vfloat32)vec_sel(a._vec0, b._vec0, mask_1st),
(vfloat32)vec_sel(a._vec1, b._vec1, mask_2nd)};
}
template <int64_t mask>
static Vectorized<ComplexFlt> C10_ALWAYS_INLINE
el_blend(const Vectorized<ComplexFlt>& a, const Vectorized<ComplexFlt>& b) {
const vbool32 mask_1st = VsxMask1(mask);
const vbool32 mask_2nd = VsxMask2(mask);
return {
(vfloat32)vec_sel(a._vec0, b._vec0, mask_1st),
(vfloat32)vec_sel(a._vec1, b._vec1, mask_2nd)};
}
static Vectorized<ComplexFlt> blendv(
const Vectorized<ComplexFlt>& a,
const Vectorized<ComplexFlt>& b,
const Vectorized<ComplexFlt>& mask) {
// convert std::complex<V> index mask to V index mask: xy -> xxyy
auto mask_complex = Vectorized<ComplexFlt>(
vec_mergeh(mask._vec0, mask._vec0), vec_mergeh(mask._vec1, mask._vec1));
return {
vec_sel(a._vec0, b._vec0, reinterpret_cast<vbool32>(mask_complex._vec0)),
vec_sel(a._vec1, b._vec1, reinterpret_cast<vbool32>(mask_complex._vec1)),
};
}
static Vectorized<ComplexFlt> elwise_blendv(
const Vectorized<ComplexFlt>& a,
const Vectorized<ComplexFlt>& b,
const Vectorized<ComplexFlt>& mask) {
return {
vec_sel(a._vec0, b._vec0, reinterpret_cast<vbool32>(mask._vec0)),
vec_sel(a._vec1, b._vec1, reinterpret_cast<vbool32>(mask._vec1)),
};
}
template <typename step_t>
static Vectorized<ComplexFlt> arange(
ComplexFlt base = 0.,
step_t step = static_cast<step_t>(1)) {
return Vectorized<ComplexFlt>(
base,
base + step,
base + ComplexFlt(2) * step,
base + ComplexFlt(3) * step);
}
static Vectorized<ComplexFlt> set(
const Vectorized<ComplexFlt>& a,
const Vectorized<ComplexFlt>& b,
int64_t count = size()) {
switch (count) {
case 0:
return a;
case 1:
return blend<1>(a, b);
case 2:
return blend<3>(a, b);
case 3:
return blend<7>(a, b);
}
return b;
}
static Vectorized<value_type> C10_ALWAYS_INLINE
loadu(const void* ptr, int count = size()) {
if (count == size()) {
return {
vec_vsx_ld(offset0, reinterpret_cast<const float*>(ptr)),
vec_vsx_ld(offset16, reinterpret_cast<const float*>(ptr))};
}
__at_align__ value_type tmp_values[size()] = {};
std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type));
return {
vec_vsx_ld(offset0, reinterpret_cast<const float*>(tmp_values)),
vec_vsx_ld(offset16, reinterpret_cast<const float*>(tmp_values))};
}
void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const {
if (count == size()) {
vec_vsx_st(_vec0, offset0, reinterpret_cast<float*>(ptr));
vec_vsx_st(_vec1, offset16, reinterpret_cast<float*>(ptr));
} else if (count > 0) {
__at_align__ value_type tmp_values[size()];
vec_vsx_st(_vec0, offset0, reinterpret_cast<float*>(tmp_values));
vec_vsx_st(_vec1, offset16, reinterpret_cast<float*>(tmp_values));
std::memcpy(
ptr, tmp_values, std::min(count, size()) * sizeof(value_type));
}
}
const ComplexFlt& operator[](int idx) const = delete;
ComplexFlt& operator[](int idx) = delete;
Vectorized<ComplexFlt> map(ComplexFlt (*const f)(ComplexFlt)) const {
__at_align__ ComplexFlt tmp[size()];
store(tmp);
for (const auto i : c10::irange(size())) {
tmp[i] = f(tmp[i]);
}
return loadu(tmp);
}
Vectorized<ComplexFlt> map(ComplexFlt (*const f)(const ComplexFlt&)) const {
__at_align__ ComplexFlt tmp[size()];
store(tmp);
for (const auto i : c10::irange(size())) {
tmp[i] = f(tmp[i]);
}
return loadu(tmp);
}
static Vectorized<ComplexFlt> horizontal_add(
Vectorized<ComplexFlt>& first,
Vectorized<ComplexFlt>& second) {
// Operates on individual floats, see _mm_hadd_ps
// {f0+f1, s0+s1, f2+f3, s2+s3, ...}
// i.e. it sums the re and im of each value and interleaves first and second:
// {f_re0 + f_im0, s_re0 + s_im0, f_re1 + f_im1, s_re1 + s_im1, ...}
return el_mergee(first, second) + el_mergeo(first, second);
}
static Vectorized<ComplexFlt> horizontal_sub_permD8(
Vectorized<ComplexFlt>& first,
Vectorized<ComplexFlt>& second) {
// we simulate it differently, with 6 instructions in total
// let's permute the inputs so that we can subtract, getting horizontal differences
auto first_perm = first.el_swapped(); // 2perm
auto second_perm = second.el_swapped(); // 2perm
// subtract
auto first_ret = first - first_perm; // 2sub
auto second_ret = second - second_perm; // 2sub
// now let's choose the even elements
return el_mergee(first_ret, second_ret); // 2 mergee's
}
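// Step-by-step, with first = {f0, f1, f2, f3} and second = {s0, s1, s2, s3}
// (per 128-bit half): first - first.el_swapped() = {f0-f1, f1-f0, f2-f3, f3-f2},
// and el_mergee of the two differences yields {f0-f1, s0-s1, f2-f3, s2-s3},
// i.e. the horizontal differences of first and second, interleaved.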
Vectorized<ComplexFlt> abs_2_() const {
auto a = (*this).elwise_mult(*this);
auto permuted = a.el_swapped();
a = a + permuted;
return a.el_mergee();
}
Vectorized<ComplexFlt> abs_() const {
auto vi = el_mergeo();
auto vr = el_mergee();
return {Sleef_hypotf4_u05vsx(vr._vec0, vi._vec0), Sleef_hypotf4_u05vsx(vr._vec1, vi._vec1)};
}
Vectorized<ComplexFlt> abs() const {
return abs_() & real_mask;
}
Vectorized<ComplexFlt> real_() const {
return *this & real_mask;
}
Vectorized<ComplexFlt> real() const {
return *this & real_mask;
}
Vectorized<ComplexFlt> imag_() const {
return *this & imag_mask;
}
Vectorized<ComplexFlt> imag() const {
// we can use swap_mask or sldwi
auto ret = imag_();
return {
vec_sldw(ret._vec0, ret._vec0, 3), vec_sldw(ret._vec1, ret._vec1, 3)};
}
Vectorized<ComplexFlt> conj_() const {
return *this ^ isign_mask;
}
Vectorized<ComplexFlt> conj() const {
return *this ^ isign_mask;
}
Vectorized<ComplexFlt> log() const {
// Many complex trigonometric ops are implemented in terms of log(),
// which helps complex-number performance.
return map(std::log);
}
Vectorized<ComplexFlt> log2() const {
// log2eB_inv
auto ret = log();
return ret.elwise_mult(log2e_inv);
}
Vectorized<ComplexFlt> log10() const {
auto ret = log();
return ret.elwise_mult(log10e_inv);
}
Vectorized<ComplexFlt> log1p() const {
return map(std::log1p);
}
Vectorized<ComplexFlt> el_swapped() const {
vfloat32 v0 = vec_perm(_vec0, _vec0, swap_mask);
vfloat32 v1 = vec_perm(_vec1, _vec1, swap_mask);
return {v0, v1};
}
Vectorized<ComplexFlt> el_mergee() const {
// mergee could also be implemented with vec_perm and a mask
return {vec_mergee(_vecb0, _vecb0), vec_mergee(_vecb1, _vecb1)};
}
Vectorized<ComplexFlt> el_mergeo() const {
// mergeo could also be implemented with vec_perm and a mask
return {vec_mergeo(_vecb0, _vecb0), vec_mergeo(_vecb1, _vecb1)};
}
Vectorized<ComplexFlt> el_madd(
const Vectorized<ComplexFlt>& multiplier,
const Vectorized<ComplexFlt>& val) const {
return {
vec_madd(_vec0, multiplier._vec0, val._vec0),
vec_madd(_vec1, multiplier._vec1, val._vec1)};
}
static Vectorized<ComplexFlt> el_mergee(
Vectorized<ComplexFlt>& first,
Vectorized<ComplexFlt>& second) {
return {
vec_mergee(first._vecb0, second._vecb0),
vec_mergee(first._vecb1, second._vecb1)};
}
static Vectorized<ComplexFlt> el_mergeo(
Vectorized<ComplexFlt>& first,
Vectorized<ComplexFlt>& second) {
return {
vec_mergeo(first._vecb0, second._vecb0),
vec_mergeo(first._vecb1, second._vecb1)};
}
Vectorized<ComplexFlt> angle_() const {
// angle = atan2(b, a)
// auto b_a = _mm256_permute_ps(values, 0xB1); // b a
// return Sleef_atan2f8_u10(values, b_a); // 90-angle angle
Vectorized<ComplexFlt> ret;
for (int i = 0; i < 4; i += 2) {
ret._vec0[i] = std::atan2(_vec0[i + 1], _vec0[i]);
ret._vec1[i] = std::atan2(_vec1[i + 1], _vec1[i]);
}
return ret;
}
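// Example: for z = 1 + 1i this computes atan2(1, 1) = pi/4 in the real slot;
// angle() below then clears the imaginary slot via real_mask.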
Vectorized<ComplexFlt> angle() const {
return angle_() & real_mask;
}
Vectorized<ComplexFlt> sin() const {
return map(std::sin);
}
Vectorized<ComplexFlt> sinh() const {
return map(std::sinh);
}
Vectorized<ComplexFlt> cos() const {
return map(std::cos);
}
Vectorized<ComplexFlt> cosh() const {
return map(std::cosh);
}
Vectorized<ComplexFlt> ceil() const {
return {vec_ceil(_vec0), vec_ceil(_vec1)};
}
Vectorized<ComplexFlt> floor() const {
return {vec_floor(_vec0), vec_floor(_vec1)};
}
Vectorized<ComplexFlt> neg() const {
auto z = Vectorized<ComplexFlt>(zero);
return z - *this;
}
Vectorized<ComplexFlt> round() const {
return {vec_round(_vec0), vec_round(_vec1)};
}
Vectorized<ComplexFlt> tan() const {
return map(std::tan);
}
Vectorized<ComplexFlt> tanh() const {
return map(std::tanh);
}
Vectorized<ComplexFlt> trunc() const {
return {vec_trunc(_vec0), vec_trunc(_vec1)};
}
Vectorized<ComplexFlt> elwise_sqrt() const {
return {vec_sqrt(_vec0), vec_sqrt(_vec1)};
}
Vectorized<ComplexFlt> sqrt() const {
return map(std::sqrt);
}
Vectorized<ComplexFlt> reciprocal() const {
// 1 / (c + di) = (c - di) / (c^2 + d^2)
// re = c / abs_2()
// im = -d / abs_2()
auto c_d = *this ^ isign_mask; // c -d
auto abs = abs_2_();
return c_d.elwise_div(abs);
}
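// Worked example: 1 / (3 + 4i) = (3 - 4i) / 25 = 0.12 - 0.16i; here c_d holds
// {3, -4} and abs_2_() holds {25, 25}, so the element-wise divide gives the result.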
Vectorized<ComplexFlt> rsqrt() const {
return sqrt().reciprocal();
}
Vectorized<ComplexFlt> pow(const Vectorized<ComplexFlt>& exp) const {
__at_align__ ComplexFlt x_tmp[size()];
__at_align__ ComplexFlt y_tmp[size()];
store(x_tmp);
exp.store(y_tmp);
for (const auto i : c10::irange(size())) {
x_tmp[i] = std::pow(x_tmp[i], y_tmp[i]);
}
return loadu(x_tmp);
}
Vectorized<ComplexFlt> atan() const {
// atan(x) = i/2 * ln((i + z)/(i - z))
auto ione = Vectorized(imag_one);
auto sum = ione + *this;
auto sub = ione - *this;
auto ln = (sum / sub).log(); // ln((i + z)/(i - z))
return ln * imag_half; // i/2*ln()
}
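// Sanity check of the formula at z = 1: (i + 1)/(i - 1) = -i, ln(-i) = -i*pi/2,
// and i/2 * (-i*pi/2) = pi/4, which matches atan(1).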
Vectorized<ComplexFlt> atanh() const {
return map(std::atanh);
}
Vectorized<ComplexFlt> acos() const {
// acos(x) = pi/2 - asin(x)
return Vectorized(pi_2) - asin();
}
Vectorized<ComplexFlt> inline operator*(const Vectorized<ComplexFlt>& b) const {
//(a + bi) * (c + di) = (ac - bd) + (ad + bc)i
#if 1
// this is more VSX-friendly than simulating the x86 horizontal ops
auto vi = b.el_mergeo();
auto vr = b.el_mergee();
vi = vi ^ rsign_mask;
auto ret = elwise_mult(vr);
auto vx_swapped = el_swapped();
ret = vx_swapped.el_madd(vi, ret);
return ret;
#else
auto ac_bd = elwise_mult(b);
auto d_c = b.el_swapped();
d_c = d_c ^ isign_mask;
auto ad_bc = elwise_mult(d_c);
auto ret = horizontal_sub_permD8(ac_bd, ad_bc);
return ret;
#endif
}
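// Tracing the VSX path for (a + bi) * (c + di): vr = {c, c}, vi = {d, d},
// vi ^ rsign_mask = {-d, d}; elwise_mult(vr) = {ac, bc}; el_swapped() = {b, a};
// el_madd({-d, d}, {ac, bc}) = {ac - bd, ad + bc}, i.e. the expected product.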
Vectorized<ComplexFlt> inline operator/(const Vectorized<ComplexFlt>& b) const {
// re + im*i = (a + bi) / (c + di)
// re = (ac + bd)/abs_2()
// im = (bc - ad)/abs_2()
auto fabs_cd = Vectorized{
vec_andc(b._vec0, sign_mask),
vec_andc(b._vec1, sign_mask)}; // |c| |d|
auto fabs_dc = fabs_cd.el_swapped(); // |d| |c|
auto scale = fabs_cd.elwise_max(fabs_dc); // sc = max(|c|, |d|)
auto a2 = elwise_div(scale); // a/sc b/sc
auto b2 = b.elwise_div(scale); // c/sc d/sc
auto acbd2 = a2.elwise_mult(b2); // ac/sc^2 bd/sc^2
auto dc2 = b2.el_swapped(); // d/sc c/sc
dc2 = dc2 ^ rsign_mask; // -d/sc c/sc
auto adbc2 = a2.elwise_mult(dc2); // -ad/sc^2 bc/sc^2
auto ret = horizontal_add(acbd2, adbc2); // (ac+bd)/sc^2 (bc-ad)/sc^2
auto denom2 = b2.abs_2_(); // (c^2+d^2)/sc^2 (c^2+d^2)/sc^2
ret = ret.elwise_div(denom2);
return ret;
}
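// The divide scales both operands by sc = max(|c|, |d|) before forming
// (ac + bd) + (bc - ad)i over (c^2 + d^2); the sc^2 factors cancel, so the
// result is unchanged while the intermediate products are kept away from
// overflow/underflow.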
Vectorized<ComplexFlt> asin() const {
// asin(x)
// = -i*ln(iz + sqrt(1 -z^2))
// = -i*ln((ai - b) + sqrt(1 - (a + bi)*(a + bi)))
// = -i*ln((-b + ai) + sqrt(1 - (a**2 - b**2) - 2*abi))
#if 1
auto conj = conj_();
auto b_a = conj.el_swapped();
auto ab = conj.elwise_mult(b_a);
auto im = ab + ab;
auto val_2 = (*this).elwise_mult(*this);
auto val_2_swapped = val_2.el_swapped();
auto re = horizontal_sub_permD8(val_2, val_2_swapped);
re = Vectorized<ComplexFlt>(one) - re;
auto root = el_blend<0xAA>(re, im).sqrt();
auto ln = (b_a + root).log();
return ln.el_swapped().conj();
#else
return map(std::asin);
#endif
}
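// In the branch above, re holds 1 - (a^2 - b^2) and im holds -2ab (conj flips
// the sign of b), so el_blend<0xAA>(re, im) assembles 1 - z^2 before sqrt();
// the final el_swapped().conj() applies the leading -i of the formula.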
Vectorized<ComplexFlt> exp() const {
return map(std::exp);
}
Vectorized<ComplexFlt> exp2() const {
return map(exp2_impl);
}
Vectorized<ComplexFlt> expm1() const {
return map(std::expm1);
}
Vectorized<ComplexFlt> eq(const Vectorized<ComplexFlt>& other) const {
auto eq = (*this == other); // compares real and imag individually
// If both real numbers and imag numbers are equal, then the complex numbers are equal
return (eq.real() & eq.imag()) & one;
}
Vectorized<ComplexFlt> ne(const Vectorized<ComplexFlt>& other) const {
auto ne = (*this != other); // compares real and imag individually
// If either the real or the imaginary parts differ, the complex numbers are not equal
return (ne.real() | ne.imag()) & one;
}
Vectorized<ComplexFlt> sgn() const {
return map(at::native::sgn_impl);
}
Vectorized<ComplexFlt> operator<(const Vectorized<ComplexFlt>& other) const {
TORCH_CHECK(false, "not supported for complex numbers");
}
Vectorized<ComplexFlt> operator<=(const Vectorized<ComplexFlt>& other) const {
TORCH_CHECK(false, "not supported for complex numbers");
}
Vectorized<ComplexFlt> operator>(const Vectorized<ComplexFlt>& other) const {
TORCH_CHECK(false, "not supported for complex numbers");
}
Vectorized<ComplexFlt> operator>=(const Vectorized<ComplexFlt>& other) const {
TORCH_CHECK(false, "not supported for complex numbers");
}
DEFINE_MEMBER_OP(operator==, ComplexFlt, vec_cmpeq)
DEFINE_MEMBER_OP(operator!=, ComplexFlt, vec_cmpne)
DEFINE_MEMBER_OP(operator+, ComplexFlt, vec_add)
DEFINE_MEMBER_OP(operator-, ComplexFlt, vec_sub)
DEFINE_MEMBER_OP(operator&, ComplexFlt, vec_and)
DEFINE_MEMBER_OP(operator|, ComplexFlt, vec_or)
DEFINE_MEMBER_OP(operator^, ComplexFlt, vec_xor)
// elementwise helpers
DEFINE_MEMBER_OP(elwise_mult, ComplexFlt, vec_mul)
DEFINE_MEMBER_OP(elwise_div, ComplexFlt, vec_div)
DEFINE_MEMBER_OP(elwise_gt, ComplexFlt, vec_cmpgt)
DEFINE_MEMBER_OP(elwise_ge, ComplexFlt, vec_cmpge)
DEFINE_MEMBER_OP(elwise_lt, ComplexFlt, vec_cmplt)
DEFINE_MEMBER_OP(elwise_le, ComplexFlt, vec_cmple)
DEFINE_MEMBER_OP(elwise_max, ComplexFlt, vec_max)
};
template <>
Vectorized<ComplexFlt> inline maximum(
const Vectorized<ComplexFlt>& a,
const Vectorized<ComplexFlt>& b) {
auto abs_a = a.abs_2_();
auto abs_b = b.abs_2_();
// auto mask = _mm256_cmp_ps(abs_a, abs_b, _CMP_LT_OQ);
// auto max = _mm256_blendv_ps(a, b, mask);
auto mask = abs_a.elwise_lt(abs_b);
auto max = Vectorized<ComplexFlt>::elwise_blendv(a, b, mask);
return max;
// Exploit the fact that all-ones is a NaN.
// auto isnan = _mm256_cmp_ps(abs_a, abs_b, _CMP_UNORD_Q);
// return _mm256_or_ps(max, isnan);
}
template <>
Vectorized<ComplexFlt> inline minimum(
const Vectorized<ComplexFlt>& a,
const Vectorized<ComplexFlt>& b) {
auto abs_a = a.abs_2_();
auto abs_b = b.abs_2_();
// auto mask = _mm256_cmp_ps(abs_a, abs_b, _CMP_GT_OQ);
// auto min = _mm256_blendv_ps(a, b, mask);
auto mask = abs_a.elwise_gt(abs_b);
auto min = Vectorized<ComplexFlt>::elwise_blendv(a, b, mask);
return min;
// Exploit the fact that all-ones is a NaN.
// auto isnan = _mm256_cmp_ps(abs_a, abs_b, _CMP_UNORD_Q);
// return _mm256_or_ps(min, isnan);
}
template <>
Vectorized<ComplexFlt> C10_ALWAYS_INLINE operator+(const Vectorized<ComplexFlt>& a, const Vectorized<ComplexFlt>& b) {
return Vectorized<ComplexFlt>{vec_add(a.vec0(), b.vec0()), vec_add(a.vec1(), b.vec1())};
}
template <>
Vectorized<ComplexFlt> C10_ALWAYS_INLINE operator-(const Vectorized<ComplexFlt>& a, const Vectorized<ComplexFlt>& b) {
return Vectorized<ComplexFlt>{vec_sub(a.vec0(), b.vec0()), vec_sub(a.vec1(), b.vec1())};
}
template <>
Vectorized<ComplexFlt> C10_ALWAYS_INLINE operator&(const Vectorized<ComplexFlt>& a, const Vectorized<ComplexFlt>& b) {
return Vectorized<ComplexFlt>{vec_and(a.vec0(), b.vec0()), vec_and(a.vec1(), b.vec1())};
}
template <>
Vectorized<ComplexFlt> C10_ALWAYS_INLINE operator|(const Vectorized<ComplexFlt>& a, const Vectorized<ComplexFlt>& b) {
return Vectorized<ComplexFlt>{vec_or(a.vec0(), b.vec0()), vec_or(a.vec1(), b.vec1())};
}
template <>
Vectorized<ComplexFlt> C10_ALWAYS_INLINE operator^(const Vectorized<ComplexFlt>& a, const Vectorized<ComplexFlt>& b) {
return Vectorized<ComplexFlt>{vec_xor(a.vec0(), b.vec0()), vec_xor(a.vec1(), b.vec1())};
}
} // namespace
} // namespace vec
} // namespace at

View File

@@ -0,0 +1,477 @@
#pragma once
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h>
#include <c10/util/irange.h>
#include <sleef.h>
namespace at {
namespace vec {
inline namespace CPU_CAPABILITY {
template <>
class Vectorized<double> {
private:
union {
struct {
vfloat64 _vec0;
vfloat64 _vec1;
};
struct {
vbool64 _vecb0;
vbool64 _vecb1;
};
} __attribute__((__may_alias__));
public:
using value_type = double;
using vec_internal_type = vfloat64;
using vec_internal_mask_type = vbool64;
using size_type = int;
static constexpr size_type size() {
return 4;
}
Vectorized() {}
C10_ALWAYS_INLINE Vectorized(vfloat64 v) : _vec0{v}, _vec1{v} {}
C10_ALWAYS_INLINE Vectorized(vbool64 vmask) : _vecb0{vmask}, _vecb1{vmask} {}
C10_ALWAYS_INLINE Vectorized(vfloat64 v1, vfloat64 v2) : _vec0{v1}, _vec1{v2} {}
C10_ALWAYS_INLINE Vectorized(vbool64 v1, vbool64 v2) : _vecb0{v1}, _vecb1{v2} {}
C10_ALWAYS_INLINE Vectorized(double scalar)
: _vec0{vec_splats(scalar)}, _vec1{vec_splats(scalar)} {}
C10_ALWAYS_INLINE Vectorized(
double scalar1,
double scalar2,
double scalar3,
double scalar4)
: _vec0{vfloat64{scalar1, scalar2}}, _vec1{vfloat64{scalar3, scalar4}} {}
C10_ALWAYS_INLINE const vec_internal_type& vec0() const {
return _vec0;
}
C10_ALWAYS_INLINE const vec_internal_type& vec1() const {
return _vec1;
}
int zero_mask() const {
auto cmp = (*this == vd_zero);
return (cmp._vecb0[0] & 1) | (cmp._vecb0[1] & 2) | (cmp._vecb1[0] & 4) |
(cmp._vecb1[1] & 8);
}
template <int64_t mask>
static std::enable_if_t<blendChoiceDbl(mask) == 0, Vectorized<double>> C10_ALWAYS_INLINE
blend(const Vectorized<double>& a, const Vectorized<double>& b) {
return a;
}
template <int64_t mask>
static std::enable_if_t<blendChoiceDbl(mask) == 1, Vectorized<double>> C10_ALWAYS_INLINE
blend(const Vectorized<double>& a, const Vectorized<double>& b) {
return b;
}
template <int64_t mask>
static std::enable_if_t<blendChoiceDbl(mask) == 2, Vectorized<double>> C10_ALWAYS_INLINE
blend(const Vectorized<double>& a, const Vectorized<double>& b) {
return { b._vec0, a._vec1 };
}
template <int64_t mask>
static std::enable_if_t<blendChoiceDbl(mask) == 3, Vectorized<double>> C10_ALWAYS_INLINE
blend(const Vectorized<double>& a, const Vectorized<double>& b) {
return { a._vec0, b._vec1 };
}
template <int64_t mask>
static std::enable_if_t<blendChoiceDbl(mask) == 4, Vectorized<double>> C10_ALWAYS_INLINE
blend(const Vectorized<double>& a, const Vectorized<double>& b) {
const vbool64 mask_1st = VsxDblMask1(mask);
return { (vfloat64)vec_sel(a._vec0, b._vec0, mask_1st), a._vec1 };
}
template <int64_t mask>
static std::enable_if_t<blendChoiceDbl(mask) == 5, Vectorized<double>> C10_ALWAYS_INLINE
blend(const Vectorized<double>& a, const Vectorized<double>& b) {
const vbool64 mask_1st = VsxDblMask1(mask);
return { (vfloat64)vec_sel(a._vec0, b._vec0, mask_1st), b._vec1 };
}
template <int64_t mask>
static std::enable_if_t<blendChoiceDbl(mask) == 6,
Vectorized<double>>
C10_ALWAYS_INLINE blend(const Vectorized<double>& a, const Vectorized<double>& b) {
const vbool64 mask_2nd = VsxDblMask2(mask);
// generated masks
return { a._vec0,
(vfloat64)vec_sel(a._vec1, b._vec1, mask_2nd) };
}
template <int64_t mask>
static std::enable_if_t<blendChoiceDbl(mask) == 7,
Vectorized<double>>
C10_ALWAYS_INLINE blend(const Vectorized<double>& a, const Vectorized<double>& b) {
const vbool64 mask_2nd = VsxDblMask2(mask);
// generated masks
return { b._vec0,
(vfloat64)vec_sel(a._vec1, b._vec1, mask_2nd) };
}
template <int64_t mask>
static std::enable_if_t<blendChoiceDbl(mask) == 8, Vectorized<double>>
C10_ALWAYS_INLINE blend(const Vectorized<double>& a, const Vectorized<double>& b) {
const vbool64 mask_1st = VsxDblMask1(mask);
const vbool64 mask_2nd = VsxDblMask2(mask);
return {
(vfloat64)vec_sel(a._vec0, b._vec0, mask_1st),
(vfloat64)vec_sel(a._vec1, b._vec1, mask_2nd) };
}
static Vectorized<double> C10_ALWAYS_INLINE blendv(
const Vectorized<double>& a,
const Vectorized<double>& b,
const Vectorized<double>& mask) {
// the mask used here is returned by a vec256 comparison
return {
vec_sel(a._vec0, b._vec0, mask._vecb0),
vec_sel(a._vec1, b._vec1, mask._vecb1)};
}
template <typename step_t>
static Vectorized<double> arange(double base = 0., step_t step = static_cast<step_t>(1)) {
return Vectorized<double>(base, base + step, base + 2 * step, base + 3 * step);
}
static Vectorized<double> C10_ALWAYS_INLINE
set(const Vectorized<double>& a, const Vectorized<double>& b, size_t count = size()) {
switch (count) {
case 0:
return a;
case 1:
return blend<1>(a, b);
case 2:
return blend<3>(a, b);
case 3:
return blend<7>(a, b);
}
return b;
}
static Vectorized<value_type> C10_ALWAYS_INLINE
loadu(const void* ptr, int count = size()) {
if (count == size()) {
return {
vec_vsx_ld(offset0, reinterpret_cast<const value_type*>(ptr)),
vec_vsx_ld(offset16, reinterpret_cast<const value_type*>(ptr))};
}
__at_align__ value_type tmp_values[size()] = {};
std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type));
return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)};
}
void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const {
if (count == size()) {
vec_vsx_st(_vec0, offset0, reinterpret_cast<value_type*>(ptr));
vec_vsx_st(_vec1, offset16, reinterpret_cast<value_type*>(ptr));
} else if (count > 0) {
__at_align__ value_type tmp_values[size()];
vec_vsx_st(_vec0, offset0, tmp_values);
vec_vsx_st(_vec1, offset16, tmp_values);
std::memcpy(
ptr, tmp_values, std::min(count, size()) * sizeof(value_type));
}
}
const double& operator[](int idx) const = delete;
double& operator[](int idx) = delete;
Vectorized<double> map(double (*const f)(double)) const {
Vectorized<double> ret;
for (const auto i : c10::irange(size()/2)) {
ret._vec0[i] = f(_vec0[i]);
}
for (const auto i : c10::irange(size()/2)) {
ret._vec1[i] = f(_vec1[i]);
}
return ret;
}
Vectorized<double> mapbi(double (*const f)(double, double), const Vectorized<double>& other)
const {
Vectorized<double> ret;
for (const auto i : c10::irange(size()/2)) {
ret._vec0[i] = f(_vec0[i], other._vec0[i]);
}
for (const auto i : c10::irange(size()/2)) {
ret._vec1[i] = f(_vec1[i], other._vec1[i]);
}
return ret;
}
Vectorized<double> C10_ALWAYS_INLINE abs() const {
return {vec_abs(_vec0), vec_abs(_vec1)};
}
Vectorized<double> C10_ALWAYS_INLINE acos() const {
return {Sleef_acosd2_u10(_vec0), Sleef_acosd2_u10(_vec1)};
}
Vectorized<double> C10_ALWAYS_INLINE acosh() const {
return {Sleef_acoshd2_u10(_vec0), Sleef_acoshd2_u10(_vec1)};
}
Vectorized<double> C10_ALWAYS_INLINE asin() const {
return {Sleef_asind2_u10(_vec0), Sleef_asind2_u10(_vec1)};
}
Vectorized<double> atan() const {
return {Sleef_atand2_u10(_vec0), Sleef_atand2_u10(_vec1)};
}
Vectorized<double> atanh() const {
return {Sleef_atanhd2_u10(_vec0), Sleef_atanhd2_u10(_vec1)};
}
Vectorized<double> atan2(const Vectorized<double>& b) const {
return {Sleef_atan2d2_u10(_vec0, b._vec0), Sleef_atan2d2_u10(_vec1, b._vec1)};
}
Vectorized<double> copysign(const Vectorized<double> &sign) const {
return {Sleef_copysignd2(_vec0, sign._vec0), Sleef_copysignd2(_vec1, sign._vec1)};
}
Vectorized<double> erf() const {
return {Sleef_erfd2_u10(_vec0), Sleef_erfd2_u10(_vec1)};
}
Vectorized<double> erfc() const {
return {Sleef_erfcd2_u15(_vec0), Sleef_erfcd2_u15(_vec1)};
}
Vectorized<double> C10_ALWAYS_INLINE exp() const {
return {Sleef_expd2_u10(_vec0), Sleef_expd2_u10(_vec1)};
}
Vectorized<double> C10_ALWAYS_INLINE exp2() const {
return {Sleef_exp2d2_u10(_vec0), Sleef_exp2d2_u10(_vec1)};
}
Vectorized<double> expm1() const {
return {Sleef_expm1d2_u10(_vec0), Sleef_expm1d2_u10(_vec1)};
}
Vectorized<double> C10_ALWAYS_INLINE exp_u20() const {
return exp();
}
Vectorized<double> lgamma() const __ubsan_ignore_undefined__ {
return {Sleef_lgammad2_u10(_vec0), Sleef_lgammad2_u10(_vec1)};
}
Vectorized<double> erfinv() const {
return map(calc_erfinv);
}
Vectorized<double> angle() const {
auto tmp = blendv(
Vectorized<double>(0), Vectorized<double>(c10::pi<double>), *this < Vectorized<double>(0));
return blendv(tmp, *this, isnan());
}
Vectorized<double> real() const {
return *this;
}
Vectorized<double> imag() const {
return Vectorized<double>{0};
}
Vectorized<double> conj() const {
return *this;
}
Vectorized<double> C10_ALWAYS_INLINE log() const {
return {Sleef_logd2_u10(_vec0), Sleef_logd2_u10(_vec1)};
}
Vectorized<double> C10_ALWAYS_INLINE log10() const {
return {Sleef_log10d2_u10(_vec0), Sleef_log10d2_u10(_vec1)};
}
Vectorized<double> C10_ALWAYS_INLINE log1p() const {
return {Sleef_log1pd2_u10(_vec0), Sleef_log1pd2_u10(_vec1)};
}
Vectorized<double> C10_ALWAYS_INLINE log2() const {
return {Sleef_log2d2_u10(_vec0), Sleef_log2d2_u10(_vec1)};
}
Vectorized<double> C10_ALWAYS_INLINE ceil() const {
return {vec_ceil(_vec0), vec_ceil(_vec1)};
}
Vectorized<double> C10_ALWAYS_INLINE cos() const {
return {Sleef_cosd2_u10(_vec0), Sleef_cosd2_u10(_vec1)};
}
Vectorized<double> C10_ALWAYS_INLINE cosh() const {
return {Sleef_coshd2_u10(_vec0), Sleef_coshd2_u10(_vec1)};
}
Vectorized<double> C10_ALWAYS_INLINE floor() const {
return {vec_floor(_vec0), vec_floor(_vec1)};
}
Vectorized<double> C10_ALWAYS_INLINE neg() const {
return {vec_neg(_vec0), vec_neg(_vec1)};
}
Vectorized<double> C10_ALWAYS_INLINE round() const {
return {vec_rint(_vec0), vec_rint(_vec1)};
}
Vectorized<double> C10_ALWAYS_INLINE sin() const {
return {Sleef_sind2_u10(_vec0), Sleef_sind2_u10(_vec1)};
}
Vectorized<double> C10_ALWAYS_INLINE sinh() const {
return {Sleef_sinhd2_u10(_vec0), Sleef_sinhd2_u10(_vec1)};
}
Vectorized<double> C10_ALWAYS_INLINE tan() const {
return {Sleef_tand2_u10(_vec0), Sleef_tand2_u10(_vec1)};
}
Vectorized<double> C10_ALWAYS_INLINE tanh() const {
return {Sleef_tanhd2_u10(_vec0), Sleef_tanhd2_u10(_vec1)};
}
Vectorized<double> C10_ALWAYS_INLINE trunc() const {
return {vec_trunc(_vec0), vec_trunc(_vec1)};
}
Vectorized<double> C10_ALWAYS_INLINE frac() const {
return *this - trunc();
}
Vectorized<double> C10_ALWAYS_INLINE sqrt() const {
return {vec_sqrt(_vec0), vec_sqrt(_vec1)};
}
Vectorized<double> C10_ALWAYS_INLINE reciprocal() const {
return {
vec_div(vd_one, _vec0), // vec_re(_vec0) is only an estimate, so use a true divide
vec_div(vd_one, _vec1)};
}
Vectorized<double> C10_ALWAYS_INLINE rsqrt() const {
return sqrt().reciprocal();
}
Vectorized<double> C10_ALWAYS_INLINE pow(const Vectorized<double>& b) const {
return {Sleef_powd2_u10(_vec0, b._vec0), Sleef_powd2_u10(_vec1, b._vec1)};
}
Vectorized<double> C10_ALWAYS_INLINE fmod(const Vectorized<double>& b) const {
return {Sleef_fmodd2(_vec0, b._vec0),Sleef_fmodd2(_vec1, b._vec1)};
}
Vectorized<double> hypot(const Vectorized<double>& b) const {
return {Sleef_hypotd2_u05(_vec0, b._vec0), Sleef_hypotd2_u05(_vec1, b._vec1)};
}
Vectorized<double> nextafter(const Vectorized<double>& b) const {
return {Sleef_nextafterd2(_vec0, b._vec0), Sleef_nextafterd2(_vec1, b._vec1)};
}
Vectorized<double> igamma(const Vectorized<double>& x) const {
return mapbi(calc_igamma, x);
}
Vectorized<double> igammac(const Vectorized<double>& x) const {
return mapbi(calc_igammac, x);
}
Vectorized<double> i0() const {
return map(calc_i0);
}
Vectorized<double> i0e() const {
return map(calc_i0e);
}
Vectorized<double> digamma() const {
return map(calc_digamma);
}
Vectorized<double> _nor() const {
return {vec_nor(_vec0, _vec0), vec_nor(_vec1, _vec1)};
}
Vectorized<double> isnan() const {
auto x = *this;
auto ret = (x == x);
return ret._nor();
}
bool has_inf_nan() const {
for (const auto i : c10::irange(size()/2)) {
if(_isnan(_vec0[i]) || _isinf(_vec0[i])) {
return true;
}
}
for (const auto i : c10::irange(size()/2)) {
if(_isnan(_vec1[i]) || _isinf(_vec1[i])) {
return true;
}
}
return false;
}
DEFINE_MEMBER_OP(operator==, double, vec_cmpeq)
DEFINE_MEMBER_OP(operator!=, double, vec_cmpne)
DEFINE_MEMBER_OP(operator<, double, vec_cmplt)
DEFINE_MEMBER_OP(operator<=, double, vec_cmple)
DEFINE_MEMBER_OP(operator>, double, vec_cmpgt)
DEFINE_MEMBER_OP(operator>=, double, vec_cmpge)
DEFINE_MEMBER_OP_AND_ONE(eq, double, vec_cmpeq)
DEFINE_MEMBER_OP_AND_ONE(ne, double, vec_cmpne)
DEFINE_MEMBER_OP_AND_ONE(lt, double, vec_cmplt)
DEFINE_MEMBER_OP_AND_ONE(le, double, vec_cmple)
DEFINE_MEMBER_OP_AND_ONE(gt, double, vec_cmpgt)
DEFINE_MEMBER_OP_AND_ONE(ge, double, vec_cmpge)
DEFINE_MEMBER_OP(operator+, double, vec_add)
DEFINE_MEMBER_OP(operator-, double, vec_sub)
DEFINE_MEMBER_OP(operator*, double, vec_mul)
DEFINE_MEMBER_OP(operator/, double, vec_div)
DEFINE_MEMBER_OP(maximum, double, vec_max_nan2)
DEFINE_MEMBER_OP(minimum, double, vec_min_nan2)
DEFINE_MEMBER_OP(operator&, double, vec_and)
DEFINE_MEMBER_OP(operator|, double, vec_or)
DEFINE_MEMBER_OP(operator^, double, vec_xor)
DEFINE_MEMBER_TERNARY_OP(madd, double, vec_madd)
};
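// Usage sketch (illustrative only): the typical pattern is loadu/compute/store,
// e.g. for 4 doubles:
//   double in[4] = {1.0, 4.0, 9.0, 16.0}, out[4];
//   auto v = Vectorized<double>::loadu(in);
//   v.sqrt().store(out); // out = {1.0, 2.0, 3.0, 4.0}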
template <>
Vectorized<double> inline maximum(
const Vectorized<double>& a,
const Vectorized<double>& b) {
return a.maximum(b);
}
template <>
Vectorized<double> inline minimum(
const Vectorized<double>& a,
const Vectorized<double>& b) {
return a.minimum(b);
}
template <>
Vectorized<double> C10_ALWAYS_INLINE operator+(const Vectorized<double>& a, const Vectorized<double>& b) {
return Vectorized<double>{vec_add(a.vec0(), b.vec0()), vec_add(a.vec1(), b.vec1())};
}
template <>
Vectorized<double> C10_ALWAYS_INLINE operator-(const Vectorized<double>& a, const Vectorized<double>& b) {
return Vectorized<double>{vec_sub(a.vec0(), b.vec0()), vec_sub(a.vec1(), b.vec1())};
}
template <>
Vectorized<double> C10_ALWAYS_INLINE operator*(const Vectorized<double>& a, const Vectorized<double>& b) {
return Vectorized<double>{vec_mul(a.vec0(), b.vec0()), vec_mul(a.vec1(), b.vec1())};
}
template <>
Vectorized<double> C10_ALWAYS_INLINE operator/(const Vectorized<double>& a, const Vectorized<double>& b) {
return Vectorized<double>{vec_div(a.vec0(), b.vec0()), vec_div(a.vec1(), b.vec1())};
}
template <>
Vectorized<double> C10_ALWAYS_INLINE operator&(const Vectorized<double>& a, const Vectorized<double>& b) {
return Vectorized<double>{vec_and(a.vec0(), b.vec0()), vec_and(a.vec1(), b.vec1())};
}
template <>
Vectorized<double> C10_ALWAYS_INLINE operator|(const Vectorized<double>& a, const Vectorized<double>& b) {
return Vectorized<double>{vec_or(a.vec0(), b.vec0()), vec_or(a.vec1(), b.vec1())};
}
template <>
Vectorized<double> C10_ALWAYS_INLINE operator^(const Vectorized<double>& a, const Vectorized<double>& b) {
return Vectorized<double>{vec_xor(a.vec0(), b.vec0()), vec_xor(a.vec1(), b.vec1())};
}
} // namespace
} // namespace vec
} // namespace at

View File

@@ -0,0 +1,499 @@
#pragma once
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h>
#include <sleef.h>
namespace at {
namespace vec {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {
template <>
class Vectorized<float> {
private:
union {
struct {
vfloat32 _vec0;
vfloat32 _vec1;
};
struct {
vbool32 _vecb0;
vbool32 _vecb1;
};
} __attribute__((__may_alias__));
public:
using value_type = float;
using vec_internal_type = vfloat32;
using vec_internal_mask_type = vbool32;
using size_type = int;
static constexpr size_type size() {
return 8;
}
Vectorized() {}
C10_ALWAYS_INLINE Vectorized(vfloat32 v) : _vec0{v}, _vec1{v} {}
C10_ALWAYS_INLINE Vectorized(vbool32 vmask) : _vecb0{vmask}, _vecb1{vmask} {}
C10_ALWAYS_INLINE Vectorized(vfloat32 v1, vfloat32 v2) : _vec0{v1}, _vec1{v2} {}
C10_ALWAYS_INLINE Vectorized(vbool32 v1, vbool32 v2) : _vecb0{v1}, _vecb1{v2} {}
C10_ALWAYS_INLINE Vectorized(float scalar)
: _vec0{vec_splats(scalar)}, _vec1{vec_splats(scalar)} {}
C10_ALWAYS_INLINE Vectorized(
float scalar1,
float scalar2,
float scalar3,
float scalar4,
float scalar5,
float scalar6,
float scalar7,
float scalar8)
: _vec0{vfloat32{scalar1, scalar2, scalar3, scalar4}},
_vec1{vfloat32{scalar5, scalar6, scalar7, scalar8}} {}
C10_ALWAYS_INLINE const vec_internal_type& vec0() const {
return _vec0;
}
C10_ALWAYS_INLINE const vec_internal_type& vec1() const {
return _vec1;
}
template <int64_t mask>
static std::enable_if_t<blendChoice(mask) == 0, Vectorized<float>> C10_ALWAYS_INLINE
blend(const Vectorized<float>& a, const Vectorized<float>& b) {
return a;
}
template <int64_t mask>
static std::enable_if_t<blendChoice(mask) == 1, Vectorized<float>> C10_ALWAYS_INLINE
blend(const Vectorized<float>& a, const Vectorized<float>& b) {
return b;
}
template <int64_t mask>
static std::enable_if_t<blendChoice(mask) == 2, Vectorized<float>> C10_ALWAYS_INLINE
blend(const Vectorized<float>& a, const Vectorized<float>& b) {
return {b._vec0, a._vec1};
}
template <int64_t mask>
static std::enable_if_t<blendChoice(mask) == 3, Vectorized<float>> C10_ALWAYS_INLINE
blend(const Vectorized<float>& a, const Vectorized<float>& b) {
return {a._vec0, b._vec1};
}
template <int64_t mask>
static std::enable_if_t<blendChoice(mask) == 4, Vectorized<float>> C10_ALWAYS_INLINE
blend(const Vectorized<float>& a, const Vectorized<float>& b) {
const vbool32 mask_1st = VsxMask1(mask);
return {(vfloat32)vec_sel(a._vec0, b._vec0, mask_1st), a._vec1};
}
template <int64_t mask>
static std::enable_if_t<blendChoice(mask) == 5, Vectorized<float>> C10_ALWAYS_INLINE
blend(const Vectorized<float>& a, const Vectorized<float>& b) {
const vbool32 mask_1st = VsxMask1(mask);
return {(vfloat32)vec_sel(a._vec0, b._vec0, mask_1st), b._vec1};
}
template <int64_t mask>
static std::enable_if_t<blendChoice(mask) == 6, Vectorized<float>> C10_ALWAYS_INLINE
blend(const Vectorized<float>& a, const Vectorized<float>& b) {
const vbool32 mask_2nd = VsxMask2(mask);
// generated masks
return {a._vec0, (vfloat32)vec_sel(a._vec1, b._vec1, mask_2nd)};
}
template <int64_t mask>
static std::enable_if_t<blendChoice(mask) == 7, Vectorized<float>> C10_ALWAYS_INLINE
blend(const Vectorized<float>& a, const Vectorized<float>& b) {
const vbool32 mask_2nd = VsxMask2(mask);
// generated masks
return {b._vec0, (vfloat32)vec_sel(a._vec1, b._vec1, mask_2nd)};
}
template <int64_t mask>
static std::enable_if_t<blendChoice(mask) == 8, Vectorized<float>> C10_ALWAYS_INLINE
blend(const Vectorized<float>& a, const Vectorized<float>& b) {
const vbool32 mask_1st = VsxMask1(mask);
const vbool32 mask_2nd = VsxMask2(mask);
return {
(vfloat32)vec_sel(a._vec0, b._vec0, mask_1st),
(vfloat32)vec_sel(a._vec1, b._vec1, mask_2nd)};
}
static Vectorized<float> C10_ALWAYS_INLINE blendv(
const Vectorized<float>& a,
const Vectorized<float>& b,
const Vectorized<float>& mask) {
// the mask used here is returned by a vec256 comparison,
// so we can use the same mask directly with vec_sel
return {
vec_sel(a._vec0, b._vec0, mask._vecb0),
vec_sel(a._vec1, b._vec1, mask._vecb1)};
}
template <typename step_t>
static Vectorized<float> arange(float base = 0.f, step_t step = static_cast<step_t>(1)) {
return Vectorized<float>(
base,
base + step,
base + 2 * step,
base + 3 * step,
base + 4 * step,
base + 5 * step,
base + 6 * step,
base + 7 * step);
}
static Vectorized<float> set(
const Vectorized<float>& a,
const Vectorized<float>& b,
size_t count = size()) {
switch (count) {
case 0:
return a;
case 1:
return blend<1>(a, b);
case 2:
return blend<3>(a, b);
case 3:
return blend<7>(a, b);
case 4:
return blend<15>(a, b);
case 5:
return blend<31>(a, b);
case 6:
return blend<63>(a, b);
case 7:
return blend<127>(a, b);
}
return b;
}
static Vectorized<value_type> C10_ALWAYS_INLINE
loadu(const void* ptr, int count = size()) {
if (count == size()) {
return {
vec_vsx_ld(offset0, reinterpret_cast<const value_type*>(ptr)),
vec_vsx_ld(offset16, reinterpret_cast<const value_type*>(ptr))};
}
__at_align__ value_type tmp_values[size()] = {};
std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type));
return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)};
}
void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const {
if (count == size()) {
vec_vsx_st(_vec0, offset0, reinterpret_cast<value_type*>(ptr));
vec_vsx_st(_vec1, offset16, reinterpret_cast<value_type*>(ptr));
} else if (count > 0) {
__at_align__ value_type tmp_values[size()];
vec_vsx_st(_vec0, offset0, tmp_values);
vec_vsx_st(_vec1, offset16, tmp_values);
std::memcpy(
ptr, tmp_values, std::min(count, size()) * sizeof(value_type));
}
}
const float& operator[](int idx) const = delete;
float& operator[](int idx) = delete;
Vectorized<float> map(float (*const f)(float)) const {
Vectorized<float> ret;
for (int i = 0; i < size() / 2; i++) {
ret._vec0[i] = f(_vec0[i]);
}
for (int i = 0; i < size() / 2; i++) {
ret._vec1[i] = f(_vec1[i]);
}
return ret;
}
Vectorized<float> mapbi(float (*const f)(float, float), const Vectorized<float>& other)
const {
Vectorized<float> ret;
for (int i = 0; i < size() / 2; i++) {
ret._vec0[i] = f(_vec0[i], other._vec0[i]);
}
for (int i = 0; i < size() / 2; i++) {
ret._vec1[i] = f(_vec1[i], other._vec1[i]);
}
return ret;
}
Vectorized<float> _nor() const {
return {vec_nor(_vec0, _vec0), vec_nor(_vec1, _vec1)};
}
Vectorized<float> isnan() const {
auto x = *this;
auto ret = (x == x);
return ret._nor();
}
bool has_inf_nan() const {
for (const auto i : c10::irange(size()/2)) {
if(_isnan(_vec0[i]) || _isinf(_vec0[i])) {
return true;
}
}
for (const auto i : c10::irange(size()/2)) {
if(_isnan(_vec1[i]) || _isinf(_vec1[i])) {
return true;
}
}
return false;
}
int zero_mask() const {
// returns an integer mask where all zero elements are translated to 1-bit
// and others are translated to 0-bit
//__m256 cmp = _mm256_cmp_ps(values, _mm256_set1_ps(0.0f), _CMP_EQ_OQ);
auto cmp = (*this == zero);
// return _mm256_movemask_ps(cmp);
// possible simulation: mask = lvsl(0); vbpermq(vec, mask << 5)
vuint64 result0 = vec_vbpermq((vuint8)cmp._vecb0, mask_zero_bits);
vuint64 result1 = vec_vbpermq((vuint8)cmp._vecb1, mask_zero_bits);
return (result0[1] >> 12 | (result1[1] >> 8));
}
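// Example (assuming the intended _mm256_movemask_ps bit order, element i -> bit i):
// for the vector {0, 1, 0, 2, 3, 0, 4, 5}, zero_mask() returns 0b00100101,
// i.e. bits 0, 2 and 5 are set because those lanes compare equal to zero.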
Vectorized<float> C10_ALWAYS_INLINE abs() const {
return {vec_abs(_vec0), vec_abs(_vec1)};
}
Vectorized<float> C10_ALWAYS_INLINE acos() const {
return {Sleef_acosf4_u10(_vec0), Sleef_acosf4_u10(_vec1)};
}
Vectorized<float> C10_ALWAYS_INLINE acosh() const {
return {Sleef_acoshf4_u10(_vec0), Sleef_acoshf4_u10(_vec1)};
}
Vectorized<float> C10_ALWAYS_INLINE asin() const {
return {Sleef_asinf4_u10(_vec0), Sleef_asinf4_u10(_vec1)};
}
Vectorized<float> atan() const {
return {Sleef_atanf4_u10(_vec0), Sleef_atanf4_u10(_vec1)};
}
Vectorized<float> atanh() const {
return {Sleef_atanhf4_u10(_vec0), Sleef_atanhf4_u10(_vec1)};
}
Vectorized<float> atan2(const Vectorized<float>& b) const {
return {Sleef_atan2f4_u10(_vec0, b._vec0), Sleef_atan2f4_u10(_vec1, b._vec1)};
}
Vectorized<float> copysign(const Vectorized<float> &sign) const {
return {Sleef_copysignf4(_vec0, sign._vec0), Sleef_copysignf4(_vec1, sign._vec1)};
}
Vectorized<float> lgamma() const {
return {Sleef_lgammaf4_u10(_vec0), Sleef_lgammaf4_u10(_vec1)};
}
Vectorized<float> erf() const {
return {Sleef_erff4_u10(_vec0), Sleef_erff4_u10(_vec1)};
}
Vectorized<float> erfc() const {
return {Sleef_erfcf4_u15(_vec0), Sleef_erfcf4_u15(_vec1)};
}
Vectorized<float> erfinv() const {
return map(calc_erfinv);
}
Vectorized<float> angle() const {
auto tmp = blendv(
Vectorized<float>(0), Vectorized<float>(c10::pi<float>), *this < Vectorized<float>(0));
return blendv(tmp, *this, isnan());
}
Vectorized<float> real() const {
return *this;
}
Vectorized<float> imag() const {
return Vectorized<float>{0};
}
Vectorized<float> conj() const {
return *this;
}
Vectorized<float> C10_ALWAYS_INLINE exp() const {
return {Sleef_expf4_u10(_vec0), Sleef_expf4_u10(_vec1)};
}
Vectorized<float> C10_ALWAYS_INLINE exp2() const {
return {Sleef_exp2f4_u10(_vec0), Sleef_exp2f4_u10(_vec1)};
}
Vectorized<float> expm1() const {
return {Sleef_expm1f4_u10(_vec0), Sleef_expm1f4_u10(_vec1)};
}
Vectorized<float> C10_ALWAYS_INLINE exp_u20() const {
return exp();
}
Vectorized<float> C10_ALWAYS_INLINE log() const {
return {Sleef_logf4_u10(_vec0), Sleef_logf4_u10(_vec1)};
}
Vectorized<float> C10_ALWAYS_INLINE log10() const {
return {Sleef_log10f4_u10(_vec0), Sleef_log10f4_u10(_vec1)};
}
Vectorized<float> C10_ALWAYS_INLINE log1p() const {
return {Sleef_log1pf4_u10(_vec0), Sleef_log1pf4_u10(_vec1)};
}
Vectorized<float> C10_ALWAYS_INLINE log2() const {
return {Sleef_log2f4_u10(_vec0), Sleef_log2f4_u10(_vec1)};
}
Vectorized<float> C10_ALWAYS_INLINE ceil() const {
return {vec_ceil(_vec0), vec_ceil(_vec1)};
}
Vectorized<float> C10_ALWAYS_INLINE cos() const {
return {Sleef_cosf4_u10(_vec0), Sleef_cosf4_u10(_vec1)};
}
Vectorized<float> C10_ALWAYS_INLINE cosh() const {
return {Sleef_coshf4_u10(_vec0), Sleef_coshf4_u10(_vec1)};
}
Vectorized<float> C10_ALWAYS_INLINE floor() const {
return {vec_floor(_vec0), vec_floor(_vec1)};
}
Vectorized<float> C10_ALWAYS_INLINE neg() const {
return {vec_neg(_vec0), vec_neg(_vec1)};
}
Vectorized<float> C10_ALWAYS_INLINE round() const {
return {vec_round(_vec0), vec_round(_vec1)};
}
Vectorized<float> C10_ALWAYS_INLINE sin() const {
return {Sleef_sinf4_u10(_vec0), Sleef_sinf4_u10(_vec1)};
}
Vectorized<float> C10_ALWAYS_INLINE sinh() const {
return {Sleef_sinhf4_u10(_vec0), Sleef_sinhf4_u10(_vec1)};
}
Vectorized<float> C10_ALWAYS_INLINE tan() const {
return {Sleef_tanf4_u10(_vec0), Sleef_tanf4_u10(_vec1)};
}
Vectorized<float> C10_ALWAYS_INLINE tanh() const {
return {Sleef_tanhf4_u10(_vec0), Sleef_tanhf4_u10(_vec1)};
}
Vectorized<float> C10_ALWAYS_INLINE trunc() const {
return {vec_trunc(_vec0), vec_trunc(_vec1)};
}
Vectorized<float> C10_ALWAYS_INLINE frac() const {
return *this - trunc();
}
Vectorized<float> C10_ALWAYS_INLINE sqrt() const {
return {vec_sqrt(_vec0), vec_sqrt(_vec1)};
}
Vectorized<float> C10_ALWAYS_INLINE reciprocal() const {
return Vectorized<float>(one) / (*this);
}
Vectorized<float> C10_ALWAYS_INLINE rsqrt() const {
return sqrt().reciprocal();
}
Vectorized<float> C10_ALWAYS_INLINE pow(const Vectorized<float>& exp) const {
return {Sleef_powf4_u10(_vec0, exp._vec0), Sleef_powf4_u10(_vec1, exp._vec1)};
}
Vectorized<float> fmod(const Vectorized<float>& b) const {
return {Sleef_fmodf4(_vec0, b._vec0),Sleef_fmodf4(_vec1, b._vec1)};
}
Vectorized<float> hypot(const Vectorized<float>& b) const {
return {Sleef_hypotf4_u05(_vec0, b._vec0), Sleef_hypotf4_u05(_vec1, b._vec1)};
}
Vectorized<float> nextafter(const Vectorized<float>& b) const {
return {Sleef_nextafterf4(_vec0, b._vec0), Sleef_nextafterf4(_vec1, b._vec1)};
}
Vectorized<float> igamma(const Vectorized<float>& x) const {
return mapbi(calc_igamma, x);
}
Vectorized<float> igammac(const Vectorized<float>& x) const {
return mapbi(calc_igammac, x);
}
Vectorized<float> i0() const {
return map(calc_i0);
}
Vectorized<float> i0e() const {
return map(calc_i0e);
}
Vectorized<float> digamma() const {
return map(calc_digamma);
}
DEFINE_MEMBER_OP(operator==, float, vec_cmpeq)
DEFINE_MEMBER_OP(operator!=, float, vec_cmpne)
DEFINE_MEMBER_OP(operator<, float, vec_cmplt)
DEFINE_MEMBER_OP(operator<=, float, vec_cmple)
DEFINE_MEMBER_OP(operator>, float, vec_cmpgt)
DEFINE_MEMBER_OP(operator>=, float, vec_cmpge)
DEFINE_MEMBER_OP_AND_ONE(eq, float, vec_cmpeq)
DEFINE_MEMBER_OP_AND_ONE(ne, float, vec_cmpne)
DEFINE_MEMBER_OP_AND_ONE(lt, float, vec_cmplt)
DEFINE_MEMBER_OP_AND_ONE(le, float, vec_cmple)
DEFINE_MEMBER_OP_AND_ONE(gt, float, vec_cmpgt)
DEFINE_MEMBER_OP_AND_ONE(ge, float, vec_cmpge)
DEFINE_MEMBER_OP(operator+, float, vec_add)
DEFINE_MEMBER_OP(operator-, float, vec_sub)
DEFINE_MEMBER_OP(operator*, float, vec_mul)
DEFINE_MEMBER_OP(operator/, float, vec_div)
DEFINE_MEMBER_OP(maximum, float, vec_max_nan2)
DEFINE_MEMBER_OP(minimum, float, vec_min_nan2)
DEFINE_MEMBER_OP(operator&, float, vec_and)
DEFINE_MEMBER_OP(operator|, float, vec_or)
DEFINE_MEMBER_OP(operator^, float, vec_xor)
DEFINE_MEMBER_TERNARY_OP(madd, float, vec_madd)
};
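// Usage sketch (illustrative only): blendv() combined with a comparison gives a
// branch-free select, e.g. an absolute value:
//   auto x = Vectorized<float>(-3.f, 1.f, -2.f, 5.f, 0.f, -1.f, 7.f, -8.f);
//   auto ax = Vectorized<float>::blendv(x, x.neg(), x < Vectorized<float>(0.f));
//   // ax = {3, 1, 2, 5, 0, 1, 7, 8}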
template <>
Vectorized<float> inline maximum(const Vectorized<float>& a, const Vectorized<float>& b) {
return a.maximum(b);
}
template <>
Vectorized<float> inline minimum(const Vectorized<float>& a, const Vectorized<float>& b) {
return a.minimum(b);
}
template <>
Vectorized<float> C10_ALWAYS_INLINE operator+(const Vectorized<float>& a, const Vectorized<float>& b) {
return Vectorized<float>{vec_add(a.vec0(), b.vec0()), vec_add(a.vec1(), b.vec1())};
}
template <>
Vectorized<float> C10_ALWAYS_INLINE operator-(const Vectorized<float>& a, const Vectorized<float>& b) {
return Vectorized<float>{vec_sub(a.vec0(), b.vec0()), vec_sub(a.vec1(), b.vec1())};
}
template <>
Vectorized<float> C10_ALWAYS_INLINE operator*(const Vectorized<float>& a, const Vectorized<float>& b) {
return Vectorized<float>{vec_mul(a.vec0(), b.vec0()), vec_mul(a.vec1(), b.vec1())};
}
template <>
Vectorized<float> C10_ALWAYS_INLINE operator/(const Vectorized<float>& a, const Vectorized<float>& b) {
return Vectorized<float>{vec_div(a.vec0(), b.vec0()), vec_div(a.vec1(), b.vec1())};
}
template <>
Vectorized<float> C10_ALWAYS_INLINE operator&(const Vectorized<float>& a, const Vectorized<float>& b) {
return Vectorized<float>{vec_and(a.vec0(), b.vec0()), vec_and(a.vec1(), b.vec1())};
}
template <>
Vectorized<float> C10_ALWAYS_INLINE operator|(const Vectorized<float>& a, const Vectorized<float>& b) {
return Vectorized<float>{vec_or(a.vec0(), b.vec0()), vec_or(a.vec1(), b.vec1())};
}
template <>
Vectorized<float> C10_ALWAYS_INLINE operator^(const Vectorized<float>& a, const Vectorized<float>& b) {
return Vectorized<float>{vec_xor(a.vec0(), b.vec0()), vec_xor(a.vec1(), b.vec1())};
}
} // namespace
} // namespace vec
} // namespace at

View File

@@ -0,0 +1,402 @@
#pragma once
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h>
namespace at {
namespace vec {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {
template <>
class Vectorized<int16_t> {
private:
union {
struct {
vint16 _vec0;
vint16 _vec1;
};
struct {
vbool16 _vecb0;
vbool16 _vecb1;
};
} __attribute__((__may_alias__));
public:
using value_type = int16_t;
using vec_internal_type = vint16;
using vec_internal_mask_type = vbool16;
using size_type = int;
static constexpr size_type size() {
return 16;
}
Vectorized() {}
C10_ALWAYS_INLINE Vectorized(vint16 v) : _vec0{v}, _vec1{v} {}
C10_ALWAYS_INLINE Vectorized(vbool16 vmask) : _vecb0{vmask}, _vecb1{vmask} {}
C10_ALWAYS_INLINE Vectorized(vint16 v1, vint16 v2) : _vec0{v1}, _vec1{v2} {}
C10_ALWAYS_INLINE Vectorized(vbool16 v1, vbool16 v2) : _vecb0{v1}, _vecb1{v2} {}
C10_ALWAYS_INLINE Vectorized(int16_t scalar)
: _vec0{vec_splats(scalar)}, _vec1{vec_splats(scalar)} {}
C10_ALWAYS_INLINE Vectorized(
int16_t scalar1,
int16_t scalar2,
int16_t scalar3,
int16_t scalar4,
int16_t scalar5,
int16_t scalar6,
int16_t scalar7,
int16_t scalar8,
int16_t scalar9,
int16_t scalar10,
int16_t scalar11,
int16_t scalar12,
int16_t scalar13,
int16_t scalar14,
int16_t scalar15,
int16_t scalar16)
: _vec0{vint16{
scalar1,
scalar2,
scalar3,
scalar4,
scalar5,
scalar6,
scalar7,
scalar8}},
_vec1{vint16{
scalar9,
scalar10,
scalar11,
scalar12,
scalar13,
scalar14,
scalar15,
scalar16}} {}
C10_ALWAYS_INLINE const vec_internal_type& vec0() const {
return _vec0;
}
C10_ALWAYS_INLINE const vec_internal_type& vec1() const {
return _vec1;
}
template <uint64_t mask>
static std::enable_if_t<mask == 0, Vectorized<int16_t>> C10_ALWAYS_INLINE
blend(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
return a;
}
template <uint64_t mask>
static std::enable_if_t<(mask & 65535) == 65535, Vectorized<int16_t>>
C10_ALWAYS_INLINE blend(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
return b;
}
template <uint64_t mask>
static std::enable_if_t<mask == 255, Vectorized<int16_t>> C10_ALWAYS_INLINE
blend(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
return {b._vec0, a._vec1};
}
template <uint64_t mask>
static std::enable_if_t<(mask > 0 && mask < 255), Vectorized<int16_t>>
C10_ALWAYS_INLINE blend(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
constexpr int16_t g0 = (mask & 1) * 0xffff;
constexpr int16_t g1 = ((mask & 2) >> 1) * 0xffff;
constexpr int16_t g2 = ((mask & 4) >> 2) * 0xffff;
constexpr int16_t g3 = ((mask & 8) >> 3) * 0xffff;
constexpr int16_t g4 = ((mask & 16) >> 4) * 0xffff;
constexpr int16_t g5 = ((mask & 32) >> 5) * 0xffff;
constexpr int16_t g6 = ((mask & 64) >> 6) * 0xffff;
constexpr int16_t g7 = ((mask & 128) >> 7) * 0xffff;
const vint16 mask_1st = vint16{g0, g1, g2, g3, g4, g5, g6, g7};
return {(vint16)vec_sel(a._vec0, b._vec0, (vbool16)mask_1st), a._vec1};
}
template <uint64_t mask>
static std::enable_if_t<
(mask > 255 && (mask & 65535) != 65535 && ((mask & 255) == 255)),
Vectorized<int16_t>>
C10_ALWAYS_INLINE blend(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
// bits 8..15 of the mask select the second half of the result
constexpr int16_t g0_2 = ((mask & 256) >> 8) * 0xffff;
constexpr int16_t g1_2 = ((mask & 512) >> 9) * 0xffff;
constexpr int16_t g2_2 = ((mask & 1024) >> 10) * 0xffff;
constexpr int16_t g3_2 = ((mask & 2048) >> 11) * 0xffff;
constexpr int16_t g4_2 = ((mask & 4096) >> 12) * 0xffff;
constexpr int16_t g5_2 = ((mask & 8192) >> 13) * 0xffff;
constexpr int16_t g6_2 = ((mask & 16384) >> 14) * 0xffff;
constexpr int16_t g7_2 = ((mask & 32768) >> 15) * 0xffff;
const vint16 mask_2nd =
vint16{g0_2, g1_2, g2_2, g3_2, g4_2, g5_2, g6_2, g7_2};
// generated masks
return {b._vec0, (vint16)vec_sel(a._vec1, b._vec1, (vbool16)mask_2nd)};
}
template <uint64_t mask>
static std::enable_if_t<
(mask > 255 && ((mask & 65535) != 65535) && ((mask & 255) == 0)),
Vectorized<int16_t>>
C10_ALWAYS_INLINE blend(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
// bits 8..15 of the mask select the second half of the result
constexpr int16_t g0_2 = ((mask & 256) >> 8) * 0xffff;
constexpr int16_t g1_2 = ((mask & 512) >> 9) * 0xffff;
constexpr int16_t g2_2 = ((mask & 1024) >> 10) * 0xffff;
constexpr int16_t g3_2 = ((mask & 2048) >> 11) * 0xffff;
constexpr int16_t g4_2 = ((mask & 4096) >> 12) * 0xffff;
constexpr int16_t g5_2 = ((mask & 8192) >> 13) * 0xffff;
constexpr int16_t g6_2 = ((mask & 16384) >> 14) * 0xffff;
constexpr int16_t g7_2 = ((mask & 32768) >> 15) * 0xffff;
const vint16 mask_2nd =
vint16{g0_2, g1_2, g2_2, g3_2, g4_2, g5_2, g6_2, g7_2};
// generated masks
return {a._vec0, (vint16)vec_sel(a._vec1, b._vec1, (vbool16)mask_2nd)};
}
template <uint64_t mask>
static std::enable_if_t<
(mask > 255 && ((mask & 65535) != 65535) && ((mask & 255) != 0) &&
((mask & 255) != 255)),
Vectorized<int16_t>>
C10_ALWAYS_INLINE blend(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
constexpr int16_t g0 = (mask & 1) * 0xffff;
constexpr int16_t g1 = ((mask & 2) >> 1) * 0xffff;
constexpr int16_t g2 = ((mask & 4) >> 2) * 0xffff;
constexpr int16_t g3 = ((mask & 8) >> 3) * 0xffff;
constexpr int16_t g4 = ((mask & 16) >> 4) * 0xffff;
constexpr int16_t g5 = ((mask & 32) >> 5) * 0xffff;
constexpr int16_t g6 = ((mask & 64) >> 6) * 0xffff;
constexpr int16_t g7 = ((mask & 128) >> 7) * 0xffff;
// bits 8..15 of the mask select the second half of the result
constexpr int16_t g0_2 = ((mask & 256) >> 8) * 0xffff;
constexpr int16_t g1_2 = ((mask & 512) >> 9) * 0xffff;
constexpr int16_t g2_2 = ((mask & 1024) >> 10) * 0xffff;
constexpr int16_t g3_2 = ((mask & 2048) >> 11) * 0xffff;
constexpr int16_t g4_2 = ((mask & 4096) >> 12) * 0xffff;
constexpr int16_t g5_2 = ((mask & 8192) >> 13) * 0xffff;
constexpr int16_t g6_2 = ((mask & 16384) >> 14) * 0xffff;
constexpr int16_t g7_2 = ((mask & 32768) >> 15) * 0xffff;
const vint16 mask_1st = vint16{g0, g1, g2, g3, g4, g5, g6, g7};
const vint16 mask_2nd =
vint16{g0_2, g1_2, g2_2, g3_2, g4_2, g5_2, g6_2, g7_2};
// generated masks
return {
(vint16)vec_sel(a._vec0, b._vec0, (vbool16)mask_1st),
(vint16)vec_sel(a._vec1, b._vec1, (vbool16)mask_2nd)};
}
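// In all of the blend() specializations above, bit i of the template mask
// selects element i from b: elements 0..7 live in _vec0 and are driven by
// bits 0..7, elements 8..15 live in _vec1 and are driven by bits 8..15;
// e.g. blend<0b101>(a, b) = {b0, a1, b2, a3, ..., a15}.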
static Vectorized<int16_t> C10_ALWAYS_INLINE blendv(
const Vectorized<int16_t>& a,
const Vectorized<int16_t>& b,
const Vectorized<int16_t>& mask) {
// the mask used here is returned by a vec256 comparison,
// so we can use the same mask directly with vec_sel
// warning: an Intel-style (movemask) mask will not work properly
return {
vec_sel(a._vec0, b._vec0, mask._vecb0),
vec_sel(a._vec1, b._vec1, mask._vecb1)};
}
template <typename step_t>
static Vectorized<int16_t> arange(int16_t base = 0, step_t step = static_cast<step_t>(1)) {
return Vectorized<int16_t>(
base,
base + step,
base + 2 * step,
base + 3 * step,
base + 4 * step,
base + 5 * step,
base + 6 * step,
base + 7 * step,
base + 8 * step,
base + 9 * step,
base + 10 * step,
base + 11 * step,
base + 12 * step,
base + 13 * step,
base + 14 * step,
base + 15 * step);
}
static Vectorized<int16_t> set(
const Vectorized<int16_t>& a,
const Vectorized<int16_t>& b,
size_t count = size()) {
switch (count) {
case 0:
return a;
case 1:
return blend<1>(a, b);
case 2:
return blend<3>(a, b);
case 3:
return blend<7>(a, b);
case 4:
return blend<15>(a, b);
case 5:
return blend<31>(a, b);
case 6:
return blend<63>(a, b);
case 7:
return blend<127>(a, b);
case 8:
return blend<255>(a, b);
case 9:
return blend<511>(a, b);
case 10:
return blend<1023>(a, b);
case 11:
return blend<2047>(a, b);
case 12:
return blend<4095>(a, b);
case 13:
return blend<8191>(a, b);
case 14:
return blend<16383>(a, b);
case 15:
return blend<32767>(a, b);
}
return b;
}
static Vectorized<value_type> C10_ALWAYS_INLINE
loadu(const void* ptr, int count = size()) {
if (count == size()) {
return {
vec_vsx_ld(offset0, reinterpret_cast<const value_type*>(ptr)),
vec_vsx_ld(offset16, reinterpret_cast<const value_type*>(ptr))};
}
__at_align__ value_type tmp_values[size()] = {};
std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type));
return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)};
}
void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const {
if (count == size()) {
vec_vsx_st(_vec0, offset0, reinterpret_cast<value_type*>(ptr));
vec_vsx_st(_vec1, offset16, reinterpret_cast<value_type*>(ptr));
} else if (count > 0) {
__at_align__ value_type tmp_values[size()];
vec_vsx_st(_vec0, offset0, tmp_values);
vec_vsx_st(_vec1, offset16, tmp_values);
std::memcpy(ptr, tmp_values, std::min(count, size()) * sizeof(value_type));
}
}
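  // A hedged sketch of the usual tail-handling loop built on the count-aware
  // loadu/store above (src, dst and n are hypothetical; n int16_t elements):
  //
  //   int i = 0;
  //   for (; i + Vectorized<int16_t>::size() <= n; i += Vectorized<int16_t>::size()) {
  //     Vectorized<int16_t>::loadu(src + i).abs().store(dst + i);
  //   }
  //   if (i < n) {  // partial tail: touch only the remaining n - i lanes
  //     Vectorized<int16_t>::loadu(src + i, n - i).abs().store(dst + i, n - i);
  //   }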
const int16_t& operator[](int idx) const = delete;
int16_t& operator[](int idx) = delete;
Vectorized<int16_t> angle() const {
return blendv(
Vectorized<int16_t>(0), Vectorized<int16_t>(c10::pi<int16_t>), *this < Vectorized<int16_t>(0));
}
Vectorized<int16_t> real() const {
return *this;
}
Vectorized<int16_t> imag() const {
return Vectorized<int16_t>{0};
}
Vectorized<int16_t> conj() const {
return *this;
}
Vectorized<int16_t> C10_ALWAYS_INLINE abs() const {
return {vec_abs(_vec0), vec_abs(_vec1)};
}
Vectorized<int16_t> C10_ALWAYS_INLINE neg() const {
return {vec_neg(_vec0), vec_neg(_vec1)};
}
DEFINE_MEMBER_UNARY_OP(operator~, int16_t, vec_not)
DEFINE_MEMBER_OP(operator==, int16_t, vec_cmpeq)
DEFINE_MEMBER_OP(operator!=, int16_t, vec_cmpne)
DEFINE_MEMBER_OP(operator<, int16_t, vec_cmplt)
DEFINE_MEMBER_OP(operator<=, int16_t, vec_cmple)
DEFINE_MEMBER_OP(operator>, int16_t, vec_cmpgt)
DEFINE_MEMBER_OP(operator>=, int16_t, vec_cmpge)
DEFINE_MEMBER_OP_AND_ONE(eq, int16_t, vec_cmpeq)
DEFINE_MEMBER_OP_AND_ONE(ne, int16_t, vec_cmpne)
DEFINE_MEMBER_OP_AND_ONE(lt, int16_t, vec_cmplt)
DEFINE_MEMBER_OP_AND_ONE(le, int16_t, vec_cmple)
DEFINE_MEMBER_OP_AND_ONE(gt, int16_t, vec_cmpgt)
DEFINE_MEMBER_OP_AND_ONE(ge, int16_t, vec_cmpge)
DEFINE_MEMBER_OP(operator+, int16_t, vec_add)
DEFINE_MEMBER_OP(operator-, int16_t, vec_sub)
DEFINE_MEMBER_OP(operator*, int16_t, vec_mul)
DEFINE_MEMBER_EMULATE_BINARY_OP(operator/, int16_t, /)
DEFINE_MEMBER_OP(maximum, int16_t, vec_max)
DEFINE_MEMBER_OP(minimum, int16_t, vec_min)
DEFINE_MEMBER_OP(operator&, int16_t, vec_and)
DEFINE_MEMBER_OP(operator|, int16_t, vec_or)
DEFINE_MEMBER_OP(operator^, int16_t, vec_xor)
};
template <>
Vectorized<int16_t> inline operator<<(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
vuint16 shift_vec0 = reinterpret_cast<vuint16>(b.vec0());
vuint16 shift_vec1 = reinterpret_cast<vuint16>(b.vec1());
return Vectorized<int16_t>{vec_sl(a.vec0(), shift_vec0), vec_sl(a.vec1(), shift_vec1)};
}
template <>
Vectorized<int16_t> inline operator>>(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
vuint16 shift_vec0 = reinterpret_cast<vuint16>(b.vec0());
  vuint16 shift_vec1 = reinterpret_cast<vuint16>(b.vec1());
return Vectorized<int16_t>{vec_sr(a.vec0(), shift_vec0), vec_sr(a.vec1(), shift_vec1)};
}
template <>
Vectorized<int16_t> inline maximum(
const Vectorized<int16_t>& a,
const Vectorized<int16_t>& b) {
return a.maximum(b);
}
template <>
Vectorized<int16_t> inline minimum(
const Vectorized<int16_t>& a,
const Vectorized<int16_t>& b) {
return a.minimum(b);
}
template <>
Vectorized<int16_t> C10_ALWAYS_INLINE operator+(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
return Vectorized<int16_t>{vec_add(a.vec0(), b.vec0()), vec_add(a.vec1(), b.vec1())};
}
template <>
Vectorized<int16_t> C10_ALWAYS_INLINE operator-(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
return Vectorized<int16_t>{vec_sub(a.vec0(), b.vec0()), vec_sub(a.vec1(), b.vec1())};
}
template <>
Vectorized<int16_t> C10_ALWAYS_INLINE operator*(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
return Vectorized<int16_t>{vec_mul(a.vec0(), b.vec0()), vec_mul(a.vec1(), b.vec1())};
}
template <>
Vectorized<int16_t> C10_ALWAYS_INLINE operator/(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
return Vectorized<int16_t>{a.vec0()/b.vec0(), a.vec1()/b.vec1()};
}
template <>
Vectorized<int16_t> C10_ALWAYS_INLINE operator&(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
return Vectorized<int16_t>{vec_and(a.vec0(), b.vec0()), vec_and(a.vec1(), b.vec1())};
}
template <>
Vectorized<int16_t> C10_ALWAYS_INLINE operator|(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
return Vectorized<int16_t>{vec_or(a.vec0(), b.vec0()), vec_or(a.vec1(), b.vec1())};
}
template <>
Vectorized<int16_t> C10_ALWAYS_INLINE operator^(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
return Vectorized<int16_t>{vec_xor(a.vec0(), b.vec0()), vec_xor(a.vec1(), b.vec1())};
}
} // namespace
} // namespace vec
} // namespace at

View File

@ -0,0 +1,333 @@
#pragma once
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h>
namespace at {
namespace vec {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {
template <>
class Vectorized<int32_t> {
private:
union {
struct {
vint32 _vec0;
vint32 _vec1;
};
struct {
vbool32 _vecb0;
vbool32 _vecb1;
};
} __attribute__((__may_alias__));
public:
using value_type = int32_t;
using vec_internal_type = vint32;
using vec_internal_mask_type = vbool32;
using size_type = int;
static constexpr size_type size() {
return 8;
}
Vectorized() {}
C10_ALWAYS_INLINE Vectorized(vint32 v) : _vec0{v}, _vec1{v} {}
C10_ALWAYS_INLINE Vectorized(vbool32 vmask) : _vecb0{vmask}, _vecb1{vmask} {}
C10_ALWAYS_INLINE Vectorized(vint32 v1, vint32 v2) : _vec0{v1}, _vec1{v2} {}
C10_ALWAYS_INLINE Vectorized(vbool32 v1, vbool32 v2) : _vecb0{v1}, _vecb1{v2} {}
C10_ALWAYS_INLINE Vectorized(int32_t scalar)
: _vec0{vec_splats(scalar)}, _vec1{vec_splats(scalar)} {}
C10_ALWAYS_INLINE Vectorized(
int32_t scalar1,
int32_t scalar2,
int32_t scalar3,
int32_t scalar4,
int32_t scalar5,
int32_t scalar6,
int32_t scalar7,
int32_t scalar8)
: _vec0{vint32{scalar1, scalar2, scalar3, scalar4}},
_vec1{vint32{scalar5, scalar6, scalar7, scalar8}} {}
C10_ALWAYS_INLINE const vec_internal_type& vec0() const {
return _vec0;
}
C10_ALWAYS_INLINE const vec_internal_type& vec1() const {
return _vec1;
}
template <uint64_t mask>
static std::enable_if_t<mask == 0, Vectorized<int32_t>> C10_ALWAYS_INLINE
blend(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
return a;
}
template <uint64_t mask>
static std::enable_if_t<(mask & 255) == 255, Vectorized<int32_t>> C10_ALWAYS_INLINE
blend(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
return b;
}
template <uint64_t mask>
static std::enable_if_t<mask == 15, Vectorized<int32_t>> C10_ALWAYS_INLINE
blend(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
return {b._vec0, a._vec1};
}
template <uint64_t mask>
static std::enable_if_t<(mask > 0 && mask < 15), Vectorized<int32_t>>
C10_ALWAYS_INLINE blend(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
constexpr uint32_t g0 = (mask & 1) * 0xffffffff;
constexpr uint32_t g1 = ((mask & 2) >> 1) * 0xffffffff;
constexpr uint32_t g2 = ((mask & 4) >> 2) * 0xffffffff;
constexpr uint32_t g3 = ((mask & 8) >> 3) * 0xffffffff;
const vbool32 mask_1st = (vbool32){g0, g1, g2, g3};
return {(vint32)vec_sel(a._vec0, b._vec0, (vbool32)mask_1st), a._vec1};
}
template <uint64_t mask>
static std::enable_if_t<
(mask > 15 && (mask & 255) != 255 && ((mask & 15) == 15)),
Vectorized<int32_t>>
C10_ALWAYS_INLINE blend(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
constexpr uint32_t mask2 = (mask & 255) >> 4;
constexpr uint32_t g0_2 = (mask2 & 1) * 0xffffffff;
constexpr uint32_t g1_2 = ((mask2 & 2) >> 1) * 0xffffffff;
constexpr uint32_t g2_2 = ((mask2 & 4) >> 2) * 0xffffffff;
constexpr uint32_t g3_2 = ((mask2 & 8) >> 3) * 0xffffffff;
const vbool32 mask_2nd = (vbool32){g0_2, g1_2, g2_2, g3_2};
// generated masks
return {b._vec0, (vint32)vec_sel(a._vec1, b._vec1, (vbool32)mask_2nd)};
}
template <uint64_t mask>
static std::enable_if_t<
(mask > 15 && ((mask & 255) != 255) && ((mask & 15) == 0)),
Vectorized<int32_t>>
C10_ALWAYS_INLINE blend(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
constexpr uint32_t mask2 = (mask & 255) >> 4;
constexpr uint32_t g0_2 = (mask2 & 1) * 0xffffffff;
constexpr uint32_t g1_2 = ((mask2 & 2) >> 1) * 0xffffffff;
constexpr uint32_t g2_2 = ((mask2 & 4) >> 2) * 0xffffffff;
constexpr uint32_t g3_2 = ((mask2 & 8) >> 3) * 0xffffffff;
const vbool32 mask_2nd = (vbool32){g0_2, g1_2, g2_2, g3_2};
// generated masks
return {a, (vint32)vec_sel(a._vec1, b._vec1, (vbool32)mask_2nd)};
}
template <uint64_t mask>
static std::enable_if_t<
(mask > 15 && ((mask & 255) != 255) && ((mask & 15) != 0) &&
((mask & 15) != 15)),
Vectorized<int32_t>>
C10_ALWAYS_INLINE blend(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
constexpr uint32_t g0 = (mask & 1) * 0xffffffff;
constexpr uint32_t g1 = ((mask & 2) >> 1) * 0xffffffff;
constexpr uint32_t g2 = ((mask & 4) >> 2) * 0xffffffff;
constexpr uint32_t g3 = ((mask & 8) >> 3) * 0xffffffff;
constexpr uint32_t mask2 = (mask & 255) >> 4;
constexpr uint32_t g0_2 = (mask2 & 1) * 0xffffffff;
constexpr uint32_t g1_2 = ((mask2 & 2) >> 1) * 0xffffffff;
constexpr uint32_t g2_2 = ((mask2 & 4) >> 2) * 0xffffffff;
constexpr uint32_t g3_2 = ((mask2 & 8) >> 3) * 0xffffffff;
const vbool32 mask_1st = (vbool32){g0, g1, g2, g3};
const vbool32 mask_2nd = (vbool32){g0_2, g1_2, g2_2, g3_2};
// generated masks
return {
(vint32)vec_sel(a._vec0, b._vec0, (vbool32)mask_1st),
(vint32)vec_sel(a._vec1, b._vec1, (vbool32)mask_2nd)};
}
static Vectorized<int32_t> C10_ALWAYS_INLINE blendv(
const Vectorized<int32_t>& a,
const Vectorized<int32_t>& b,
const Vectorized<int32_t>& mask) {
    // the mask used here is expected to come from a vec256 comparison,
    // so it can be passed directly to vec_sel
    // warning: an Intel-style bitfield mask will not work properly here
return {
vec_sel(a._vec0, b._vec0, mask._vecb0),
vec_sel(a._vec1, b._vec1, mask._vecb1)};
}
template <typename step_t>
  static Vectorized<int32_t> arange(int32_t base = 0, step_t step = static_cast<step_t>(1)) {
return Vectorized<int32_t>(
base,
base + step,
base + 2 * step,
base + 3 * step,
base + 4 * step,
base + 5 * step,
base + 6 * step,
base + 7 * step);
}
static Vectorized<int32_t> set(
const Vectorized<int32_t>& a,
const Vectorized<int32_t>& b,
size_t count = size()) {
switch (count) {
case 0:
return a;
case 1:
return blend<1>(a, b);
case 2:
return blend<3>(a, b);
case 3:
return blend<7>(a, b);
case 4:
return blend<15>(a, b);
case 5:
return blend<31>(a, b);
case 6:
return blend<63>(a, b);
case 7:
return blend<127>(a, b);
}
return b;
}
static Vectorized<value_type> C10_ALWAYS_INLINE
loadu(const void* ptr, int count = size()) {
if (count == size()) {
return {
vec_vsx_ld(offset0, reinterpret_cast<const value_type*>(ptr)),
vec_vsx_ld(offset16, reinterpret_cast<const value_type*>(ptr))};
}
__at_align__ value_type tmp_values[size()] = {};
std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type));
return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)};
}
void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const {
if (count == size()) {
vec_vsx_st(_vec0, offset0, reinterpret_cast<value_type*>(ptr));
vec_vsx_st(_vec1, offset16, reinterpret_cast<value_type*>(ptr));
} else if (count > 0) {
__at_align__ value_type tmp_values[size()];
vec_vsx_st(_vec0, offset0, tmp_values);
vec_vsx_st(_vec1, offset16, tmp_values);
std::memcpy(
ptr, tmp_values, std::min(count, size()) * sizeof(value_type));
}
}
const int32_t& operator[](int idx) const = delete;
int32_t& operator[](int idx) = delete;
Vectorized<int32_t> angle() const {
return blendv(
Vectorized<int32_t>(0), Vectorized<int32_t>(c10::pi<int32_t>), *this < Vectorized<int32_t>(0));
}
Vectorized<int32_t> real() const {
return *this;
}
Vectorized<int32_t> imag() const {
return Vectorized<int32_t>{0};
}
Vectorized<int32_t> conj() const {
return *this;
}
Vectorized<int32_t> C10_ALWAYS_INLINE abs() const {
return {vec_abs(_vec0), vec_abs(_vec1)};
}
Vectorized<int32_t> C10_ALWAYS_INLINE neg() const {
return {vec_neg(_vec0), vec_neg(_vec1)};
}
DEFINE_MEMBER_UNARY_OP(operator~, int32_t, vec_not)
DEFINE_MEMBER_OP(operator==, int32_t, vec_cmpeq)
DEFINE_MEMBER_OP(operator!=, int32_t, vec_cmpne)
DEFINE_MEMBER_OP(operator<, int32_t, vec_cmplt)
DEFINE_MEMBER_OP(operator<=, int32_t, vec_cmple)
DEFINE_MEMBER_OP(operator>, int32_t, vec_cmpgt)
DEFINE_MEMBER_OP(operator>=, int32_t, vec_cmpge)
DEFINE_MEMBER_OP_AND_ONE(eq, int32_t, vec_cmpeq)
DEFINE_MEMBER_OP_AND_ONE(ne, int32_t, vec_cmpne)
DEFINE_MEMBER_OP_AND_ONE(lt, int32_t, vec_cmplt)
DEFINE_MEMBER_OP_AND_ONE(le, int32_t, vec_cmple)
DEFINE_MEMBER_OP_AND_ONE(gt, int32_t, vec_cmpgt)
DEFINE_MEMBER_OP_AND_ONE(ge, int32_t, vec_cmpge)
DEFINE_MEMBER_OP(operator+, int32_t, vec_add)
DEFINE_MEMBER_OP(operator-, int32_t, vec_sub)
DEFINE_MEMBER_OP(operator*, int32_t, vec_mul)
DEFINE_MEMBER_EMULATE_BINARY_OP(operator/, int32_t, /)
DEFINE_MEMBER_OP(maximum, int32_t, vec_max)
DEFINE_MEMBER_OP(minimum, int32_t, vec_min)
DEFINE_MEMBER_OP(operator&, int32_t, vec_and)
DEFINE_MEMBER_OP(operator|, int32_t, vec_or)
DEFINE_MEMBER_OP(operator^, int32_t, vec_xor)
};
template <>
Vectorized<int32_t> inline operator<<(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
vuint32 shift_vec0 = reinterpret_cast<vuint32>(b.vec0());
  vuint32 shift_vec1 = reinterpret_cast<vuint32>(b.vec1());
return Vectorized<int32_t>{vec_sl(a.vec0(), shift_vec0), vec_sl(a.vec1(), shift_vec1)};
}
template <>
Vectorized<int32_t> inline operator>>(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
vuint32 shift_vec0 = reinterpret_cast<vuint32>(b.vec0());
  vuint32 shift_vec1 = reinterpret_cast<vuint32>(b.vec1());
return Vectorized<int32_t>{vec_sr(a.vec0(), shift_vec0), vec_sr(a.vec1(), shift_vec1)};
}
template <>
Vectorized<int32_t> inline maximum(
const Vectorized<int32_t>& a,
const Vectorized<int32_t>& b) {
return a.maximum(b);
}
template <>
Vectorized<int32_t> inline minimum(
const Vectorized<int32_t>& a,
const Vectorized<int32_t>& b) {
return a.minimum(b);
}
template <>
Vectorized<int32_t> C10_ALWAYS_INLINE operator+(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
return Vectorized<int32_t>{vec_add(a.vec0(), b.vec0()), vec_add(a.vec1(), b.vec1())};
}
template <>
Vectorized<int32_t> C10_ALWAYS_INLINE operator-(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
return Vectorized<int32_t>{vec_sub(a.vec0(), b.vec0()), vec_sub(a.vec1(), b.vec1())};
}
template <>
Vectorized<int32_t> C10_ALWAYS_INLINE operator*(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
return Vectorized<int32_t>{vec_mul(a.vec0(), b.vec0()), vec_mul(a.vec1(), b.vec1())};
}
template <>
Vectorized<int32_t> C10_ALWAYS_INLINE operator/(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
return Vectorized<int32_t>{a.vec0()/b.vec0(), a.vec1()/b.vec1()};
}
template <>
Vectorized<int32_t> C10_ALWAYS_INLINE operator&(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
return Vectorized<int32_t>{vec_and(a.vec0(), b.vec0()), vec_and(a.vec1(), b.vec1())};
}
template <>
Vectorized<int32_t> C10_ALWAYS_INLINE operator|(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
return Vectorized<int32_t>{vec_or(a.vec0(), b.vec0()), vec_or(a.vec1(), b.vec1())};
}
template <>
Vectorized<int32_t> C10_ALWAYS_INLINE operator^(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
return Vectorized<int32_t>{vec_xor(a.vec0(), b.vec0()), vec_xor(a.vec1(), b.vec1())};
}
} // namespace
} // namespace vec
} // namespace at

View File

@ -0,0 +1,286 @@
#pragma once
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h>
namespace at {
namespace vec {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {
template <>
class Vectorized<int64_t> {
private:
union {
struct {
vint64 _vec0;
vint64 _vec1;
};
struct {
vbool64 _vecb0;
vbool64 _vecb1;
};
} __attribute__((__may_alias__));
public:
using value_type = int64_t;
using vec_internal_type = vint64;
using vec_internal_mask_type = vbool64;
using size_type = int;
using ElementType = signed long long;
static constexpr size_type size() {
return 4;
}
Vectorized() {}
C10_ALWAYS_INLINE Vectorized(vint64 v) : _vec0{v}, _vec1{v} {}
C10_ALWAYS_INLINE Vectorized(vbool64 vmask) : _vecb0{vmask}, _vecb1{vmask} {}
C10_ALWAYS_INLINE Vectorized(vint64 v1, vint64 v2) : _vec0{v1}, _vec1{v2} {}
C10_ALWAYS_INLINE Vectorized(vbool64 v1, vbool64 v2) : _vecb0{v1}, _vecb1{v2} {}
C10_ALWAYS_INLINE Vectorized(int64_t scalar)
: _vec0{vec_splats(scalar)}, _vec1{vec_splats(scalar)} {}
C10_ALWAYS_INLINE Vectorized(
int64_t scalar1,
int64_t scalar2,
int64_t scalar3,
int64_t scalar4)
: _vec0{vint64{scalar1, scalar2}}, _vec1{vint64{scalar3, scalar4}} {}
C10_ALWAYS_INLINE const vec_internal_type& vec0() const {
return _vec0;
}
C10_ALWAYS_INLINE const vec_internal_type& vec1() const {
return _vec1;
}
template <uint64_t mask>
static std::enable_if_t<mask == 0, Vectorized<int64_t>> C10_ALWAYS_INLINE
blend(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
return a;
}
template <uint64_t mask>
static std::enable_if_t<mask == 3, Vectorized<int64_t>> C10_ALWAYS_INLINE
blend(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
return {b._vec0, a._vec1};
}
template <uint64_t mask>
static std::enable_if_t<(mask & 15) == 15, Vectorized<int64_t>> C10_ALWAYS_INLINE
blend(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
return b;
}
template <uint64_t mask>
static std::enable_if_t<(mask > 0 && mask < 3), Vectorized<int64_t>> C10_ALWAYS_INLINE
blend(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
constexpr uint64_t g0 = (mask & 1) * 0xffffffffffffffff;
constexpr uint64_t g1 = ((mask & 2) >> 1) * 0xffffffffffffffff;
const vbool64 mask_1st = (vbool64){g0, g1};
return {(vint64)vec_sel(a._vec0, b._vec0, (vbool64)mask_1st), a._vec1};
}
template <uint64_t mask>
static std::enable_if_t<(mask > 3) && (mask & 3) == 0, Vectorized<int64_t>>
C10_ALWAYS_INLINE blend(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
constexpr uint64_t g0_2 = ((mask & 4) >> 2) * 0xffffffffffffffff;
constexpr uint64_t g1_2 = ((mask & 8) >> 3) * 0xffffffffffffffff;
const vbool64 mask_2nd = (vbool64){g0_2, g1_2};
return {a._vec0, (vint64)vec_sel(a._vec1, b._vec1, (vbool64)mask_2nd)};
}
template <uint64_t mask>
static std::enable_if_t<
(mask > 3) && (mask & 3) != 0 && (mask & 15) != 15,
Vectorized<int64_t>>
C10_ALWAYS_INLINE blend(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
constexpr uint64_t g0 = (mask & 1) * 0xffffffffffffffff;
constexpr uint64_t g1 = ((mask & 2) >> 1) * 0xffffffffffffffff;
constexpr uint64_t g0_2 = ((mask & 4) >> 2) * 0xffffffffffffffff;
constexpr uint64_t g1_2 = ((mask & 8) >> 3) * 0xffffffffffffffff;
const vbool64 mask_1st = (vbool64){g0, g1};
const vbool64 mask_2nd = (vbool64){g0_2, g1_2};
return {
(vint64)vec_sel(a._vec0, b._vec0, (vbool64)mask_1st),
(vint64)vec_sel(a._vec1, b._vec1, (vbool64)mask_2nd)};
}
static Vectorized<int64_t> C10_ALWAYS_INLINE blendv(
const Vectorized<int64_t>& a,
const Vectorized<int64_t>& b,
const Vectorized<int64_t>& mask) {
    // the mask used here is expected to come from a vec256 comparison
return {
vec_sel(a._vec0, b._vec0, mask._vecb0),
vec_sel(a._vec1, b._vec1, mask._vecb1)};
}
template <typename step_t>
  static Vectorized<int64_t> arange(int64_t base = 0, step_t step = static_cast<step_t>(1)) {
return Vectorized<int64_t>(base, base + step, base + 2 * step, base + 3 * step);
}
static Vectorized<int64_t> C10_ALWAYS_INLINE
set(const Vectorized<int64_t>& a,
const Vectorized<int64_t>& b,
size_t count = size()) {
switch (count) {
case 0:
return a;
case 1:
return blend<1>(a, b);
case 2:
return blend<3>(a, b);
case 3:
return blend<7>(a, b);
}
return b;
}
static Vectorized<value_type> C10_ALWAYS_INLINE
loadu(const void* ptr, int count = size()) {
if (count == size()) {
static_assert(sizeof(double) == sizeof(value_type));
const double* dptr = reinterpret_cast<const double*>(ptr);
return {// treat it as double load
(vint64)vec_vsx_ld(offset0, dptr),
(vint64)vec_vsx_ld(offset16, dptr)};
}
__at_align__ double tmp_values[size()] = {};
std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type));
return {
(vint64)vec_vsx_ld(offset0, tmp_values),
(vint64)vec_vsx_ld(offset16, tmp_values)};
}
void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const {
if (count == size()) {
double* dptr = reinterpret_cast<double*>(ptr);
vec_vsx_st((vfloat64)_vec0, offset0, dptr);
vec_vsx_st((vfloat64)_vec1, offset16, dptr);
} else if (count > 0) {
__at_align__ double tmp_values[size()];
vec_vsx_st((vfloat64)_vec0, offset0, tmp_values);
vec_vsx_st((vfloat64)_vec1, offset16, tmp_values);
std::memcpy(
ptr, tmp_values, std::min(count, size()) * sizeof(value_type));
}
}
const int64_t& operator[](int idx) const = delete;
int64_t& operator[](int idx) = delete;
Vectorized<int64_t> angle() const {
return blendv(
Vectorized<int64_t>(0), Vectorized<int64_t>(c10::pi<int64_t>), *this < Vectorized<int64_t>(0));
}
Vectorized<int64_t> real() const {
return *this;
}
Vectorized<int64_t> imag() const {
return Vectorized<int64_t>{0};
}
Vectorized<int64_t> conj() const {
return *this;
}
Vectorized<int64_t> C10_ALWAYS_INLINE abs() const {
return {vec_abs(_vec0), vec_abs(_vec1)};
}
Vectorized<int64_t> C10_ALWAYS_INLINE neg() const {
return {vec_neg(_vec0), vec_neg(_vec1)};
}
DEFINE_MEMBER_UNARY_OP(operator~, int64_t, vec_not)
DEFINE_MEMBER_OP(operator==, int64_t, vec_cmpeq)
DEFINE_MEMBER_OP(operator!=, int64_t, vec_cmpne)
DEFINE_MEMBER_OP(operator<, int64_t, vec_cmplt)
DEFINE_MEMBER_OP(operator<=, int64_t, vec_cmple)
DEFINE_MEMBER_OP(operator>, int64_t, vec_cmpgt)
DEFINE_MEMBER_OP(operator>=, int64_t, vec_cmpge)
DEFINE_MEMBER_OP_AND_ONE(eq, int64_t, vec_cmpeq)
DEFINE_MEMBER_OP_AND_ONE(ne, int64_t, vec_cmpne)
DEFINE_MEMBER_OP_AND_ONE(lt, int64_t, vec_cmplt)
DEFINE_MEMBER_OP_AND_ONE(le, int64_t, vec_cmple)
DEFINE_MEMBER_OP_AND_ONE(gt, int64_t, vec_cmpgt)
DEFINE_MEMBER_OP_AND_ONE(ge, int64_t, vec_cmpge)
DEFINE_MEMBER_OP(operator+, int64_t, vec_add)
DEFINE_MEMBER_OP(operator-, int64_t, vec_sub)
DEFINE_MEMBER_OP(operator*, int64_t, vec_mul)
DEFINE_MEMBER_OP(operator/, int64_t, vec_div)
DEFINE_MEMBER_OP(maximum, int64_t, vec_max)
DEFINE_MEMBER_OP(minimum, int64_t, vec_min)
DEFINE_MEMBER_OP(operator&, int64_t, vec_and)
DEFINE_MEMBER_OP(operator|, int64_t, vec_or)
DEFINE_MEMBER_OP(operator^, int64_t, vec_xor)
};
template <>
Vectorized<int64_t> inline operator<<(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
vuint64 shift_vec0 = reinterpret_cast<vuint64>(b.vec0());
  vuint64 shift_vec1 = reinterpret_cast<vuint64>(b.vec1());
return Vectorized<int64_t>{vec_sl(a.vec0(), shift_vec0), vec_sl(a.vec1(), shift_vec1)};
}
template <>
Vectorized<int64_t> inline operator>>(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
vuint64 shift_vec0 = reinterpret_cast<vuint64>(b.vec0());
  vuint64 shift_vec1 = reinterpret_cast<vuint64>(b.vec1());
return Vectorized<int64_t>{vec_sr(a.vec0(), shift_vec0), vec_sr(a.vec1(), shift_vec1)};
}
template <>
Vectorized<int64_t> inline maximum(
const Vectorized<int64_t>& a,
const Vectorized<int64_t>& b) {
return a.maximum(b);
}
template <>
Vectorized<int64_t> inline minimum(
const Vectorized<int64_t>& a,
const Vectorized<int64_t>& b) {
return a.minimum(b);
}
template <>
Vectorized<int64_t> C10_ALWAYS_INLINE operator+(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
return Vectorized<int64_t>{vec_add(a.vec0(), b.vec0()), vec_add(a.vec1(), b.vec1())};
}
template <>
Vectorized<int64_t> C10_ALWAYS_INLINE operator-(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
return Vectorized<int64_t>{vec_sub(a.vec0(), b.vec0()), vec_sub(a.vec1(), b.vec1())};
}
template <>
Vectorized<int64_t> C10_ALWAYS_INLINE operator*(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
return Vectorized<int64_t>{vec_mul(a.vec0(), b.vec0()), vec_mul(a.vec1(), b.vec1())};
}
template <>
Vectorized<int64_t> C10_ALWAYS_INLINE operator/(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
return Vectorized<int64_t>{vec_div(a.vec0(), b.vec0()), vec_div(a.vec1(), b.vec1())};
}
template <>
Vectorized<int64_t> C10_ALWAYS_INLINE operator&(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
return Vectorized<int64_t>{vec_and(a.vec0(), b.vec0()), vec_and(a.vec1(), b.vec1())};
}
template <>
Vectorized<int64_t> C10_ALWAYS_INLINE operator|(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
return Vectorized<int64_t>{vec_or(a.vec0(), b.vec0()), vec_or(a.vec1(), b.vec1())};
}
template <>
Vectorized<int64_t> C10_ALWAYS_INLINE operator^(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
return Vectorized<int64_t>{vec_xor(a.vec0(), b.vec0()), vec_xor(a.vec1(), b.vec1())};
}
} // namespace
} // namespace vec
} // namespace at

View File

@ -0,0 +1,281 @@
#pragma once
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h>
#include <c10/util/qint32.h>
#include <array>
// This file defines Vectorized<> for the quantized types.
//
//
// Currently, we simply use these classes as efficient converters between
// the quantized types and Vectorized<float>, usually in bandwidth-bound cases
// where doing the arithmetic in full-precision is acceptable (e.g.
// elementwise operators).
//
//
// Conversions are as follows:
// Vectorized<qint32> -> 1x Vectorized<float>
//
// The size of the returned float vector is specified by the special
// constexpr function float_num_vecs. The type of the value returned
// from dequantize (and expected as an argument to quantize) is
// specified by float_vec_return_type.
//
// When writing kernels with these vectors, it is expected that floating-
// point operations will be carried out in a loop over Vectorized<T>::float_num_vecs
// iterations.
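//
// A minimal sketch of that pattern (illustrative only; the pointer names,
// scalar parameters and the elementwise op are assumptions, not part of this
// header):
//
//   using qvec = at::vec::Vectorized<c10::qint32>;
//   qvec qv = qvec::loadu(q_in);                         // q_in: const c10::qint32*
//   auto fvecs = qv.dequantize(scale_v, zp_v, premul_v); // float_num_vecs() == 1
//   for (auto& fv : fvecs) {
//     fv = fv * fv;                                      // any float-domain computation
//   }
//   qvec::quantize(fvecs, scale, zero_point, inverse_scale).store(q_out);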
namespace at {
namespace vec {
inline namespace CPU_CAPABILITY {
template <>
struct Vectorized<c10::qint32> {
private:
union {
struct {
vint32 _vec0;
vint32 _vec1;
};
struct {
vbool32 _vecb0;
vbool32 _vecb1;
};
} __attribute__((__may_alias__));
public:
Vectorized() {}
using size_type = int;
static constexpr size_type size() {
return 8;
}
static constexpr size_t float_num_vecs() {
return 1;
}
static constexpr int int_num_vecs() {
return 1;
}
using float_vec_return_type = std::array<Vectorized<float>, 1>;
using int_vec_return_type = std::array<Vectorized<c10::qint32>, 1>;
using value_type = c10::qint32::underlying;
using vec_internal_type = vint32;
using vec_internal_mask_type = vbool32;
C10_ALWAYS_INLINE Vectorized(vint32 v) : _vec0{v}, _vec1{v} {}
C10_ALWAYS_INLINE Vectorized(vbool32 vmask) : _vecb0{vmask}, _vecb1{vmask} {}
C10_ALWAYS_INLINE Vectorized(vint32 v1, vint32 v2) : _vec0{v1}, _vec1{v2} {}
C10_ALWAYS_INLINE Vectorized(vbool32 v1, vbool32 v2) : _vecb0{v1}, _vecb1{v2} {}
Vectorized(const c10::qint32& val)
: _vec0(vec_splats(val.val_)), _vec1(vec_splats(val.val_)) {}
static Vectorized<c10::qint32> C10_ALWAYS_INLINE
loadu(const void* ptr, int count = size()) {
if (count == size()) {
return {
vec_vsx_ld(offset0, reinterpret_cast<const value_type*>(ptr)),
vec_vsx_ld(offset16, reinterpret_cast<const value_type*>(ptr))};
}
__at_align__ value_type tmp_values[size()] = {};
std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type));
return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)};
}
void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const {
if (count == size()) {
vec_vsx_st(_vec0, offset0, reinterpret_cast<value_type*>(ptr));
vec_vsx_st(_vec1, offset16, reinterpret_cast<value_type*>(ptr));
} else if (count > 0) {
__at_align__ value_type tmp_values[size()];
vec_vsx_st(_vec0, offset0, tmp_values);
vec_vsx_st(_vec1, offset16, tmp_values);
std::memcpy(
ptr, tmp_values, std::min(count, size()) * sizeof(value_type));
}
}
C10_ALWAYS_INLINE const vec_internal_type& vec0() const {
return _vec0;
}
C10_ALWAYS_INLINE const vec_internal_type& vec1() const {
return _vec1;
}
float_vec_return_type dequantize(
Vectorized<float> scale,
Vectorized<float> zero_point,
Vectorized<float> scale_zp_premul) const {
vfloat32 float_vals0 = vec_float(_vec0);
vfloat32 float_vals1 = vec_float(_vec1);
vfloat32 scale_vec0 = scale.vec0();
vfloat32 scale_vec1 = scale.vec1();
vfloat32 scale_zp_premul0 = scale_zp_premul.vec0();
vfloat32 scale_zp_premul1 = scale_zp_premul.vec1();
return {Vectorized<float>{
vec_madd(scale_vec0, float_vals0, scale_zp_premul0),
vec_madd(scale_vec1, float_vals1, scale_zp_premul1)}};
}
float_vec_return_type dequantize(
Vectorized<float> scale,
Vectorized<float> zero_point) const {
vfloat32 float_vals0 = vec_float(_vec0);
vfloat32 float_vals1 = vec_float(_vec1);
vfloat32 scale_vec0 = scale.vec0();
vfloat32 scale_vec1 = scale.vec1();
vfloat32 zero_point0 = zero_point.vec0();
vfloat32 zero_point1 = zero_point.vec1();
return {Vectorized<float>{
(float_vals0 - zero_point0) * scale_vec0,
(float_vals1 - zero_point1) * scale_vec1}};
}
static Vectorized<c10::qint32> quantize(
const float_vec_return_type& rhs,
float scale,
int32_t zero_point,
float inverse_scale) {
Vectorized<c10::qint32> retval;
const vint32 vmin = vec_splats(std::numeric_limits<value_type>::min());
const vint32 vmax = vec_splats(std::numeric_limits<value_type>::max());
vfloat32 inverse_scale_v = vec_splats(inverse_scale);
vfloat32 vec_zero_point = vec_splats((float)(zero_point));
Vectorized<float> vf0 = rhs[0];
vfloat32 vecf0 = vf0.vec0();
vfloat32 vecf1 = vf0.vec1();
vecf0 = vec_mul(vecf0, inverse_scale_v);
vecf1 = vec_mul(vecf1, inverse_scale_v);
vecf0 = vec_add(vec_rint(vecf0), vec_zero_point);
vecf1 = vec_add(vec_rint(vecf1), vec_zero_point);
vint32 veci0 = vec_signed(vecf0);
vint32 veci1 = vec_signed(vecf1);
veci0 = vec_max(veci0, vmin);
veci1 = vec_max(veci1, vmin);
veci0 = vec_min(veci0, vmax);
veci1 = vec_min(veci1, vmax);
return {veci0, veci1};
}
Vectorized<c10::qint32> relu(Vectorized<c10::qint32> zero_point) const {
return {vec_max(_vec0, zero_point._vec0), vec_max(_vec1, zero_point._vec1)};
}
Vectorized<c10::qint32> relu6(
Vectorized<c10::qint32> zero_point,
Vectorized<c10::qint32> q_six) const {
vint32 max0 = vec_max(_vec0, zero_point._vec0);
vint32 max1 = vec_max(_vec1, zero_point._vec1);
return {vec_min(max0, q_six._vec0), vec_min(max1, q_six._vec1)};
}
int_vec_return_type widening_subtract(Vectorized<c10::qint32> b) const {
return {*this - b};
}
static Vectorized<c10::qint32> requantize_from_int(
const int_vec_return_type& inp,
float multiplier,
int32_t zero_point) {
const vint32 vmin = vec_splats(std::numeric_limits<value_type>::min());
const vint32 vmax = vec_splats(std::numeric_limits<value_type>::max());
vfloat32 vec_mult = vec_splats(multiplier);
vint32 vec_zero_point = vec_splats(zero_point);
Vectorized<c10::qint32> vi = inp[0];
vfloat32 vecf0 = vec_float(vi.vec0());
vfloat32 vecf1 = vec_float(vi.vec1());
vecf0 = vec_mul(vecf0, vec_mult);
vecf1 = vec_mul(vecf1, vec_mult);
vecf0 = vec_rint(vecf0);
vecf1 = vec_rint(vecf1);
    vint32 veci0 = vec_add(vec_signed(vecf0), vec_zero_point);
    vint32 veci1 = vec_add(vec_signed(vecf1), vec_zero_point);
veci0 = vec_max(veci0, vmin);
veci1 = vec_max(veci1, vmin);
veci0 = vec_min(veci0, vmax);
veci1 = vec_min(veci1, vmax);
return {veci0, veci1};
}
DEFINE_MEMBER_OP(operator==, c10::qint32, vec_cmpeq)
DEFINE_MEMBER_OP(operator!=, c10::qint32, vec_cmpne)
DEFINE_MEMBER_OP(operator<, c10::qint32, vec_cmplt)
DEFINE_MEMBER_OP(operator<=, c10::qint32, vec_cmple)
DEFINE_MEMBER_OP(operator>, c10::qint32, vec_cmpgt)
DEFINE_MEMBER_OP(operator>=, c10::qint32, vec_cmpge)
DEFINE_MEMBER_OP(operator+, c10::qint32, vec_add)
DEFINE_MEMBER_OP(operator-, c10::qint32, vec_sub)
DEFINE_MEMBER_OP(operator*, c10::qint32, vec_mul)
DEFINE_MEMBER_EMULATE_BINARY_OP(operator/, c10::qint32, /)
DEFINE_MEMBER_OP(maximum, c10::qint32, vec_max)
DEFINE_MEMBER_OP(minimum, c10::qint32, vec_min)
DEFINE_MEMBER_OP(operator&, c10::qint32, vec_and)
DEFINE_MEMBER_OP(operator|, c10::qint32, vec_or)
DEFINE_MEMBER_OP(operator^, c10::qint32, vec_xor)
};
template <>
Vectorized<c10::qint32> inline maximum(
const Vectorized<c10::qint32>& a,
const Vectorized<c10::qint32>& b) {
return a.maximum(b);
}
template <>
Vectorized<c10::qint32> inline minimum(
const Vectorized<c10::qint32>& a,
const Vectorized<c10::qint32>& b) {
return a.minimum(b);
}
template <>
Vectorized<c10::qint32> C10_ALWAYS_INLINE operator+(const Vectorized<c10::qint32>& a, const Vectorized<c10::qint32>& b) {
return Vectorized<c10::qint32>{vec_add(a.vec0(), b.vec0()), vec_add(a.vec1(), b.vec1())};
}
template <>
Vectorized<c10::qint32> C10_ALWAYS_INLINE operator-(const Vectorized<c10::qint32>& a, const Vectorized<c10::qint32>& b) {
return Vectorized<c10::qint32>{vec_sub(a.vec0(), b.vec0()), vec_sub(a.vec1(), b.vec1())};
}
template <>
Vectorized<c10::qint32> C10_ALWAYS_INLINE operator*(const Vectorized<c10::qint32>& a, const Vectorized<c10::qint32>& b) {
return Vectorized<c10::qint32>{vec_mul(a.vec0(), b.vec0()), vec_mul(a.vec1(), b.vec1())};
}
template <>
Vectorized<c10::qint32> C10_ALWAYS_INLINE operator/(const Vectorized<c10::qint32>& a, const Vectorized<c10::qint32>& b) {
return Vectorized<c10::qint32>{a.vec0()/b.vec0(), a.vec1()/b.vec1()};
}
template <>
Vectorized<c10::qint32> C10_ALWAYS_INLINE operator&(const Vectorized<c10::qint32>& a, const Vectorized<c10::qint32>& b) {
return Vectorized<c10::qint32>{vec_and(a.vec0(), b.vec0()), vec_and(a.vec1(), b.vec1())};
}
template <>
Vectorized<c10::qint32> C10_ALWAYS_INLINE operator|(const Vectorized<c10::qint32>& a, const Vectorized<c10::qint32>& b) {
return Vectorized<c10::qint32>{vec_or(a.vec0(), b.vec0()), vec_or(a.vec1(), b.vec1())};
}
template <>
Vectorized<c10::qint32> C10_ALWAYS_INLINE operator^(const Vectorized<c10::qint32>& a, const Vectorized<c10::qint32>& b) {
return Vectorized<c10::qint32>{vec_xor(a.vec0(), b.vec0()), vec_xor(a.vec1(), b.vec1())};
}
} // namespace
} // namespace vec
} // namespace at

View File

@ -0,0 +1,483 @@
#pragma once
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h>
#include <c10/util/qint8.h>
#include <array>
// This file defines Vectorized<> for the quantized types.
//
//
// Currently, we simply use these classes as efficient converters between
// the quantized types and Vectorized<float>, usually in bandwidth-bound cases
// where doing the arithmetic in full-precision is acceptable (e.g.
// elementwise operators).
//
//
// Conversions are as follows:
// Vectorized<qint8> -> 4x Vectorized<float>
//
// The size of the returned float vector is specified by the special
// constexpr function float_num_vecs. The type of the value returned
// from dequantize (and expected as an argument to quantize) is
// specified by float_vec_return_type.
//
// When writing kernels with these vectors, it is expected that floating-
// point operations will be carried out in a loop over Vectorized<T>::float_num_vecs
// iterations.
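//
// For qint8 the ratio is 4:1, so a hedged sketch of the same pattern loops
// over all four float sub-vectors (variable names are hypothetical):
//
//   auto qv = at::vec::Vectorized<c10::qint8>::loadu(q_in);    // 32 int8 lanes
//   auto fvecs = qv.dequantize(scale_v, zp_v, premul_v);       // 4x Vectorized<float>
//   for (auto& fv : fvecs) {
//     fv = at::vec::maximum(fv, at::vec::Vectorized<float>(0.0f));  // e.g. float-domain relu
//   }
//   at::vec::Vectorized<c10::qint8>::quantize(fvecs, scale, zero_point, inverse_scale)
//       .store(q_out);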
namespace at {
namespace vec {
inline namespace CPU_CAPABILITY {
template <>
struct Vectorized<c10::qint8> {
private:
union {
struct {
vint8 _vec0;
vint8 _vec1;
};
struct {
vbool8 _vecb0;
vbool8 _vecb1;
};
} __attribute__((__may_alias__));
public:
Vectorized() {}
using size_type = int;
static constexpr size_type size() {
return 32;
}
static constexpr size_t float_num_vecs() {
return 4;
}
static constexpr int int_num_vecs() {
return 4;
}
using float_vec_return_type = std::array<Vectorized<float>, 4>;
using int_vec_return_type = std::array<Vectorized<c10::qint32>, 4>;
using value_type = typename c10::qint8::underlying;
using vec_internal_type = vint8;
using vec_internal_mask_type = vbool8;
// Broadcast constructor
C10_ALWAYS_INLINE Vectorized(const c10::qint8& val)
: _vec0{vec_splats(val.val_)}, _vec1{vec_splats(val.val_)} {}
C10_ALWAYS_INLINE Vectorized(const Vectorized<c10::qint8>& other)
: _vec0{other._vec0}, _vec1(other._vec1) {}
C10_ALWAYS_INLINE Vectorized(vint8 v) : _vec0{v}, _vec1{v} {}
C10_ALWAYS_INLINE Vectorized(vbool8 vmask) : _vecb0{vmask}, _vecb1{vmask} {}
C10_ALWAYS_INLINE Vectorized(vint8 v1, vint8 v2) : _vec0{v1}, _vec1{v2} {}
C10_ALWAYS_INLINE Vectorized(vbool8 v1, vbool8 v2) : _vecb0{v1}, _vecb1{v2} {}
C10_ALWAYS_INLINE const vec_internal_type& vec0() const {
return _vec0;
}
C10_ALWAYS_INLINE const vec_internal_type& vec1() const {
return _vec1;
}
static C10_ALWAYS_INLINE Vectorized<c10::qint8> loadu(
const void* ptr,
int count = size()) {
if (count == size()) {
return {
vec_vsx_ld(offset0, reinterpret_cast<const vint8*>(ptr)),
vec_vsx_ld(offset16, reinterpret_cast<const vint8*>(ptr))};
}
__at_align__ value_type tmp_values[size()] = {};
std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type));
return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)};
}
void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const {
if (count == size()) {
vec_vsx_st(_vec0, offset0, reinterpret_cast<value_type*>(ptr));
vec_vsx_st(_vec1, offset16, reinterpret_cast<value_type*>(ptr));
} else if (count > 0) {
__at_align__ value_type tmp_values[size()];
vec_vsx_st(_vec0, offset0, tmp_values);
vec_vsx_st(_vec1, offset16, tmp_values);
std::memcpy(
ptr, tmp_values, std::min(count, size()) * sizeof(value_type));
}
}
public:
float_vec_return_type C10_ALWAYS_INLINE dequantize(
Vectorized<float> scale,
Vectorized<float> zero_point,
Vectorized<float> scale_zp_premul) const {
vint16 vecshi0 = vec_unpackh(_vec0);
vint16 vecshi1 = vec_unpackl(_vec0);
vint16 vecshi2 = vec_unpackh(_vec1);
vint16 vecshi3 = vec_unpackl(_vec1);
vint32 veci0 = vec_unpackh(vecshi0);
vint32 veci1 = vec_unpackl(vecshi0);
vint32 veci2 = vec_unpackh(vecshi1);
vint32 veci3 = vec_unpackl(vecshi1);
vint32 veci4 = vec_unpackh(vecshi2);
vint32 veci5 = vec_unpackl(vecshi2);
vint32 veci6 = vec_unpackh(vecshi3);
vint32 veci7 = vec_unpackl(vecshi3);
vfloat32 vecf0_0 = vec_float(veci0);
vfloat32 vecf1_0 = vec_float(veci1);
vfloat32 vecf0_1 = vec_float(veci2);
vfloat32 vecf1_1 = vec_float(veci3);
vfloat32 vecf0_2 = vec_float(veci4);
vfloat32 vecf1_2 = vec_float(veci5);
vfloat32 vecf0_3 = vec_float(veci6);
vfloat32 vecf1_3 = vec_float(veci7);
vfloat32 scale_vec0 = scale.vec0();
vfloat32 scale_vec1 = scale.vec1();
vfloat32 scale_zp_premul0 = scale_zp_premul.vec0();
vfloat32 scale_zp_premul1 = scale_zp_premul.vec1();
return {
Vectorized<float>{
vec_madd(scale_vec0, vecf0_0, scale_zp_premul0),
vec_madd(scale_vec1, vecf1_0, scale_zp_premul1)},
Vectorized<float>{
vec_madd(scale_vec0, vecf0_1, scale_zp_premul0),
vec_madd(scale_vec1, vecf1_1, scale_zp_premul1)},
Vectorized<float>{
vec_madd(scale_vec0, vecf0_2, scale_zp_premul0),
vec_madd(scale_vec1, vecf1_2, scale_zp_premul1)},
Vectorized<float>{
vec_madd(scale_vec0, vecf0_3, scale_zp_premul0),
vec_madd(scale_vec1, vecf1_3, scale_zp_premul1)}};
}
float_vec_return_type C10_ALWAYS_INLINE dequantize(
Vectorized<float> scale,
Vectorized<float> zero_point) const {
vint16 vecshi0 = vec_unpackh(_vec0);
vint16 vecshi1 = vec_unpackl(_vec0);
vint16 vecshi2 = vec_unpackh(_vec1);
vint16 vecshi3 = vec_unpackl(_vec1);
vint32 veci0 = vec_unpackh(vecshi0);
vint32 veci1 = vec_unpackl(vecshi0);
vint32 veci2 = vec_unpackh(vecshi1);
vint32 veci3 = vec_unpackl(vecshi1);
vint32 veci4 = vec_unpackh(vecshi2);
vint32 veci5 = vec_unpackl(vecshi2);
vint32 veci6 = vec_unpackh(vecshi3);
vint32 veci7 = vec_unpackl(vecshi3);
vfloat32 vecf0_0 = vec_float(veci0);
vfloat32 vecf1_0 = vec_float(veci1);
vfloat32 vecf0_1 = vec_float(veci2);
vfloat32 vecf1_1 = vec_float(veci3);
vfloat32 vecf0_2 = vec_float(veci4);
vfloat32 vecf1_2 = vec_float(veci5);
vfloat32 vecf0_3 = vec_float(veci6);
vfloat32 vecf1_3 = vec_float(veci7);
vfloat32 scale_vec0 = scale.vec0();
vfloat32 scale_vec1 = scale.vec1();
vfloat32 zero_point0 = zero_point.vec0();
vfloat32 zero_point1 = zero_point.vec1();
return {
Vectorized<float>{
(vecf0_0 - zero_point0) * scale_vec0,
(vecf1_0 - zero_point1) * scale_vec1},
Vectorized<float>{
(vecf0_1 - zero_point0) * scale_vec0,
(vecf1_1 - zero_point1) * scale_vec1},
Vectorized<float>{
(vecf0_2 - zero_point0) * scale_vec0,
(vecf1_2 - zero_point1) * scale_vec1},
Vectorized<float>{
(vecf0_3 - zero_point0) * scale_vec0,
(vecf1_3 - zero_point1) * scale_vec1}};
}
static Vectorized<c10::qint8> quantize(
const float_vec_return_type& rhs,
float scale,
int32_t zero_point,
float inverse_scale) {
// constexpr int32_t min_val = std::numeric_limits<value_type>::min();
// constexpr int32_t max_val = std::numeric_limits<value_type>::max();
vfloat32 inverse_scale_v = vec_splats(inverse_scale);
vfloat32 vec_zero_point = vec_splats((float)zero_point);
// vint32 vmin = vec_splats(min_val);
// vint32 vmax = vec_splats(max_val);
Vectorized<float> vf0 = rhs[0];
Vectorized<float> vf1 = rhs[1];
Vectorized<float> vf2 = rhs[2];
Vectorized<float> vf3 = rhs[3];
vfloat32 vecf0 = vf0.vec0();
vfloat32 vecf1 = vf0.vec1();
vfloat32 vecf2 = vf1.vec0();
vfloat32 vecf3 = vf1.vec1();
vfloat32 vecf4 = vf2.vec0();
vfloat32 vecf5 = vf2.vec1();
vfloat32 vecf6 = vf3.vec0();
vfloat32 vecf7 = vf3.vec1();
vecf0 = vec_mul(vecf0, inverse_scale_v);
vecf1 = vec_mul(vecf1, inverse_scale_v);
vecf2 = vec_mul(vecf2, inverse_scale_v);
vecf3 = vec_mul(vecf3, inverse_scale_v);
vecf4 = vec_mul(vecf4, inverse_scale_v);
vecf5 = vec_mul(vecf5, inverse_scale_v);
vecf6 = vec_mul(vecf6, inverse_scale_v);
vecf7 = vec_mul(vecf7, inverse_scale_v);
vecf0 = vec_add(vec_rint(vecf0), vec_zero_point);
vecf1 = vec_add(vec_rint(vecf1), vec_zero_point);
vecf2 = vec_add(vec_rint(vecf2), vec_zero_point);
vecf3 = vec_add(vec_rint(vecf3), vec_zero_point);
vecf4 = vec_add(vec_rint(vecf4), vec_zero_point);
vecf5 = vec_add(vec_rint(vecf5), vec_zero_point);
vecf6 = vec_add(vec_rint(vecf6), vec_zero_point);
vecf7 = vec_add(vec_rint(vecf7), vec_zero_point);
vint32 veci0 = vec_signed(vecf0);
vint32 veci1 = vec_signed(vecf1);
vint32 veci2 = vec_signed(vecf2);
vint32 veci3 = vec_signed(vecf3);
vint32 veci4 = vec_signed(vecf4);
vint32 veci5 = vec_signed(vecf5);
vint32 veci6 = vec_signed(vecf6);
vint32 veci7 = vec_signed(vecf7);
// veci0 = vec_min(vmax, vec_max( vmin, vecf0)) ;
// veci1 = vec_min(vmax, vec_max( vmin, vecf1)) ;
// veci2 = vec_min(vmax, vec_max( vmin, vecf2)) ;
// veci3 = vec_min(vmax, vec_max( vmin, vecf3)) ;
// veci4 = vec_min(vmax, vec_max( vmin, vecf4)) ;
// veci5 = vec_min(vmax, vec_max( vmin, vecf5)) ;
// veci6 = vec_min(vmax, vec_max( vmin, vecf6)) ;
// veci7 = vec_min(vmax, vec_max( vmin, vecf7)) ;
    // vec_packs already saturates to the destination range, so the explicit
    // clamping commented out above is unnecessary
vint16 vecshi0 = vec_packs(veci0, veci1);
vint16 vecshi1 = vec_packs(veci2, veci3);
vint16 vecshi2 = vec_packs(veci4, veci5);
vint16 vecshi3 = vec_packs(veci6, veci7);
vint8 vec0 = vec_packs(vecshi0, vecshi1);
vint8 vec1 = vec_packs(vecshi2, vecshi3);
return {vec0, vec1};
}
Vectorized<c10::qint8> C10_ALWAYS_INLINE relu(Vectorized<c10::qint8> zero_point) const {
return {vec_max(_vec0, zero_point._vec0), vec_max(_vec1, zero_point._vec1)};
}
Vectorized<c10::qint8> C10_ALWAYS_INLINE
relu6(Vectorized<c10::qint8> zero_point, Vectorized<c10::qint8> q_six) const {
vint8 max0 = vec_max(_vec0, zero_point._vec0);
vint8 max1 = vec_max(_vec1, zero_point._vec1);
return {vec_min(max0, q_six._vec0), vec_min(max1, q_six._vec1)};
}
int_vec_return_type widening_subtract(Vectorized<c10::qint8> b) const {
vint16 vecshi0 = vec_unpackh(_vec0);
vint16 vecBshi0 = vec_unpackh(b._vec0);
vint16 vecshi1 = vec_unpackl(_vec0);
vint16 vecBshi1 = vec_unpackl(b._vec0);
vint16 vecshi2 = vec_unpackh(_vec1);
vint16 vecBshi2 = vec_unpackh(b._vec1);
vint16 vecshi3 = vec_unpackl(_vec1);
vint16 vecBshi3 = vec_unpackl(b._vec1);
vint32 veci0 = vec_unpackh(vecshi0);
vint32 vecBi0 = vec_unpackh(vecBshi0);
vint32 veci1 = vec_unpackl(vecshi0);
vint32 vecBi1 = vec_unpackl(vecBshi0);
vint32 veci2 = vec_unpackh(vecshi1);
vint32 vecBi2 = vec_unpackh(vecBshi1);
vint32 veci3 = vec_unpackl(vecshi1);
vint32 vecBi3 = vec_unpackl(vecBshi1);
vint32 veci4 = vec_unpackh(vecshi2);
vint32 vecBi4 = vec_unpackh(vecBshi2);
vint32 veci5 = vec_unpackl(vecshi2);
vint32 vecBi5 = vec_unpackl(vecBshi2);
vint32 veci6 = vec_unpackh(vecshi3);
vint32 vecBi6 = vec_unpackh(vecBshi3);
vint32 veci7 = vec_unpackl(vecshi3);
vint32 vecBi7 = vec_unpackl(vecBshi3);
return {
Vectorized<c10::qint32>(veci0 - vecBi0, veci1 - vecBi1),
Vectorized<c10::qint32>(veci2 - vecBi2, veci3 - vecBi3),
Vectorized<c10::qint32>(veci4 - vecBi4, veci5 - vecBi5),
Vectorized<c10::qint32>(veci6 - vecBi6, veci7 - vecBi7)};
}
static Vectorized<c10::qint8> requantize_from_int(
const int_vec_return_type& inp,
float multiplier,
int32_t zero_point) {
vfloat32 vec_multiplier = vec_splats(multiplier);
vint32 vec_zero_point = vec_splats(zero_point);
Vectorized<c10::qint32> vi0 = inp[0];
Vectorized<c10::qint32> vi1 = inp[1];
Vectorized<c10::qint32> vi2 = inp[2];
Vectorized<c10::qint32> vi3 = inp[3];
vfloat32 vecf0 = vec_float(vi0.vec0());
vfloat32 vecf1 = vec_float(vi0.vec1());
vfloat32 vecf2 = vec_float(vi1.vec0());
vfloat32 vecf3 = vec_float(vi1.vec1());
vfloat32 vecf4 = vec_float(vi2.vec0());
vfloat32 vecf5 = vec_float(vi2.vec1());
vfloat32 vecf6 = vec_float(vi3.vec0());
vfloat32 vecf7 = vec_float(vi3.vec1());
vecf0 = vec_mul(vecf0, vec_multiplier);
vecf1 = vec_mul(vecf1, vec_multiplier);
vecf2 = vec_mul(vecf2, vec_multiplier);
vecf3 = vec_mul(vecf3, vec_multiplier);
vecf4 = vec_mul(vecf4, vec_multiplier);
vecf5 = vec_mul(vecf5, vec_multiplier);
vecf6 = vec_mul(vecf6, vec_multiplier);
vecf7 = vec_mul(vecf7, vec_multiplier);
vecf0 = vec_rint(vecf0);
vecf1 = vec_rint(vecf1);
vecf2 = vec_rint(vecf2);
vecf3 = vec_rint(vecf3);
vecf4 = vec_rint(vecf4);
vecf5 = vec_rint(vecf5);
vecf6 = vec_rint(vecf6);
vecf7 = vec_rint(vecf7);
vint32 veci0 = vec_signed(vecf0);
vint32 veci1 = vec_signed(vecf1);
vint32 veci2 = vec_signed(vecf2);
vint32 veci3 = vec_signed(vecf3);
vint32 veci4 = vec_signed(vecf4);
vint32 veci5 = vec_signed(vecf5);
vint32 veci6 = vec_signed(vecf6);
vint32 veci7 = vec_signed(vecf7);
veci0 = vec_add(veci0, vec_zero_point);
veci1 = vec_add(veci1, vec_zero_point);
veci2 = vec_add(veci2, vec_zero_point);
veci3 = vec_add(veci3, vec_zero_point);
veci4 = vec_add(veci4, vec_zero_point);
veci5 = vec_add(veci5, vec_zero_point);
veci6 = vec_add(veci6, vec_zero_point);
veci7 = vec_add(veci7, vec_zero_point);
vint16 vecshi0 = vec_packs(veci0, veci1);
vint16 vecshi1 = vec_packs(veci2, veci3);
vint16 vecshi2 = vec_packs(veci4, veci5);
vint16 vecshi3 = vec_packs(veci6, veci7);
vint8 vec0 = vec_packs(vecshi0, vecshi1);
vint8 vec1 = vec_packs(vecshi2, vecshi3);
return {vec0, vec1};
}
DEFINE_MEMBER_OP(operator==, c10::qint8, vec_cmpeq)
DEFINE_MEMBER_OP(operator!=, c10::qint8, vec_cmpne)
DEFINE_MEMBER_OP(operator<, c10::qint8, vec_cmplt)
DEFINE_MEMBER_OP(operator<=, c10::qint8, vec_cmple)
DEFINE_MEMBER_OP(operator>, c10::qint8, vec_cmpgt)
DEFINE_MEMBER_OP(operator>=, c10::qint8, vec_cmpge)
DEFINE_MEMBER_OP(operator+, c10::qint8, vec_add)
DEFINE_MEMBER_OP(operator-, c10::qint8, vec_sub)
DEFINE_MEMBER_OP(operator*, c10::qint8, vec_mul)
DEFINE_MEMBER_EMULATE_BINARY_OP(operator/, c10::qint8, /)
DEFINE_MEMBER_OP(maximum, c10::qint8, vec_max)
DEFINE_MEMBER_OP(minimum, c10::qint8, vec_min)
DEFINE_MEMBER_OP(operator&, c10::qint8, vec_and)
DEFINE_MEMBER_OP(operator|, c10::qint8, vec_or)
DEFINE_MEMBER_OP(operator^, c10::qint8, vec_xor)
};
template <>
Vectorized<c10::qint8> inline maximum(
const Vectorized<c10::qint8>& a,
const Vectorized<c10::qint8>& b) {
return a.maximum(b);
}
template <>
Vectorized<c10::qint8> inline minimum(
const Vectorized<c10::qint8>& a,
const Vectorized<c10::qint8>& b) {
return a.minimum(b);
}
template <>
Vectorized<c10::qint8> C10_ALWAYS_INLINE operator+(const Vectorized<c10::qint8>& a, const Vectorized<c10::qint8>& b) {
return Vectorized<c10::qint8>{vec_add(a.vec0(), b.vec0()), vec_add(a.vec1(), b.vec1())};
}
template <>
Vectorized<c10::qint8> C10_ALWAYS_INLINE operator-(const Vectorized<c10::qint8>& a, const Vectorized<c10::qint8>& b) {
return Vectorized<c10::qint8>{vec_sub(a.vec0(), b.vec0()), vec_sub(a.vec1(), b.vec1())};
}
template <>
Vectorized<c10::qint8> C10_ALWAYS_INLINE operator*(const Vectorized<c10::qint8>& a, const Vectorized<c10::qint8>& b) {
return Vectorized<c10::qint8>{vec_mul(a.vec0(), b.vec0()), vec_mul(a.vec1(), b.vec1())};
}
template <>
Vectorized<c10::qint8> C10_ALWAYS_INLINE operator/(const Vectorized<c10::qint8>& a, const Vectorized<c10::qint8>& b) {
return Vectorized<c10::qint8>{a.vec0()/b.vec0(), a.vec1()/b.vec1()};
}
template <>
Vectorized<c10::qint8> C10_ALWAYS_INLINE operator&(const Vectorized<c10::qint8>& a, const Vectorized<c10::qint8>& b) {
return Vectorized<c10::qint8>{vec_and(a.vec0(), b.vec0()), vec_and(a.vec1(), b.vec1())};
}
template <>
Vectorized<c10::qint8> C10_ALWAYS_INLINE operator|(const Vectorized<c10::qint8>& a, const Vectorized<c10::qint8>& b) {
return Vectorized<c10::qint8>{vec_or(a.vec0(), b.vec0()), vec_or(a.vec1(), b.vec1())};
}
template <>
Vectorized<c10::qint8> C10_ALWAYS_INLINE operator^(const Vectorized<c10::qint8>& a, const Vectorized<c10::qint8>& b) {
return Vectorized<c10::qint8>{vec_xor(a.vec0(), b.vec0()), vec_xor(a.vec1(), b.vec1())};
}
} // namespace
} // namespace vec
} // namespace at

View File

@ -0,0 +1,501 @@
#pragma once
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h>
#include <c10/util/irange.h>
#include <c10/util/quint8.h>
#include <array>
// This file defines Vectorized<> for the quantized types.
//
//
// Currently, we simply use these classes as efficient converters between
// the quantized types and Vectorized<float>, usually in bandwidth-bound cases
// where doing the arithmetic in full-precision is acceptable (e.g.
// elementwise operators).
//
//
// Conversions are as follows:
// Vectorized<quint8> -> 4x Vectorized<float>
//
// The size of the returned float vector is specified by the special
// constexpr function float_num_vecs. The type of the value returned
// from dequantize (and expected as an argument to quantize) is
// specified by float_vec_return_type.
//
// When writing kernels with these vectors, it is expected that floating-
// point operations will be carried out in a loop over Vectorized<T>::float_num_vecs
// iterations.
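//
// The int_vec_return_type path is what requantized integer arithmetic (e.g. a
// quantized elementwise difference) is built on. A hedged sketch, with the
// multiplier/zero-point handling simplified and all names assumed:
//
//   auto qa = at::vec::Vectorized<c10::quint8>::loadu(a_ptr);
//   auto qb = at::vec::Vectorized<c10::quint8>::loadu(b_ptr);
//   auto wide = qa.widening_subtract(qb);            // 4x Vectorized<c10::qint32> of a - b
//   auto qr = at::vec::Vectorized<c10::quint8>::requantize_from_int(
//       wide, requant_multiplier, out_zero_point);   // rescale and clamp back to quint8
//   qr.store(out_ptr);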
namespace at {
namespace vec {
inline namespace CPU_CAPABILITY {
const vint16 mask_unsigned = vec_splats((short int)0xFF);
template <>
struct Vectorized<c10::quint8> {
private:
union {
struct {
vuint8 _vec0;
vuint8 _vec1;
};
struct {
vbool8 _vecb0;
vbool8 _vecb1;
};
} __attribute__((__may_alias__));
public:
Vectorized() {}
using size_type = int;
static constexpr size_type size() {
return 32;
}
static constexpr size_t float_num_vecs() {
return 4;
}
static constexpr int int_num_vecs() {
return 4;
}
using float_vec_return_type = std::array<Vectorized<float>, 4>;
using int_vec_return_type = std::array<Vectorized<c10::qint32>, 4>;
using value_type = typename c10::quint8::underlying;
using vec_internal_type = vuint8;
using vec_internal_mask_type = vbool8;
// Broadcast constructor
C10_ALWAYS_INLINE Vectorized(const c10::quint8& val)
: _vec0(vec_splats(val.val_)), _vec1(vec_splats(val.val_)) {}
C10_ALWAYS_INLINE Vectorized(const Vectorized<c10::quint8>& other)
: _vec0{other._vec0}, _vec1(other._vec1) {}
C10_ALWAYS_INLINE Vectorized(vuint8 v) : _vec0{v}, _vec1{v} {}
C10_ALWAYS_INLINE Vectorized(vbool8 vmask) : _vecb0{vmask}, _vecb1{vmask} {}
C10_ALWAYS_INLINE Vectorized(vuint8 v1, vuint8 v2) : _vec0{v1}, _vec1{v2} {}
C10_ALWAYS_INLINE Vectorized(vbool8 v1, vbool8 v2) : _vecb0{v1}, _vecb1{v2} {}
C10_ALWAYS_INLINE const vec_internal_type& vec0() const {
return _vec0;
}
C10_ALWAYS_INLINE const vec_internal_type& vec1() const {
return _vec1;
}
static C10_ALWAYS_INLINE Vectorized<c10::quint8> loadu(
const void* ptr,
int count = size()) {
if (count == size()) {
return {
vec_vsx_ld(offset0, reinterpret_cast<const value_type*>(ptr)),
vec_vsx_ld(offset16, reinterpret_cast<const value_type*>(ptr))};
}
__at_align__ value_type tmp_values[size()] = {};
std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type));
return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)};
}
void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const {
if (count == size()) {
vec_vsx_st(_vec0, offset0, reinterpret_cast<value_type*>(ptr));
vec_vsx_st(_vec1, offset16, reinterpret_cast<value_type*>(ptr));
} else if (count > 0) {
__at_align__ value_type tmp_values[size()];
vec_vsx_st(_vec0, offset0, tmp_values);
vec_vsx_st(_vec1, offset16, tmp_values);
std::memcpy(
ptr, tmp_values, std::min(count, size()) * sizeof(value_type));
}
}
public:
float_vec_return_type C10_ALWAYS_INLINE dequantize(
Vectorized<float> scale,
Vectorized<float> zero_point,
Vectorized<float> scale_zp_premul) const {
    // vec_unpack* sign-extends, so first unpack the unsigned bytes as if signed
vint16 vecshi0 = vec_unpackh((vint8)_vec0);
vint16 vecshi1 = vec_unpackl((vint8)_vec0);
vint16 vecshi2 = vec_unpackh((vint8)_vec1);
vint16 vecshi3 = vec_unpackl((vint8)_vec1);
    // then mask off the sign-extended high bits to recover the unsigned values
vecshi0 = vec_and(vecshi0, mask_unsigned);
vecshi1 = vec_and(vecshi1, mask_unsigned);
vecshi2 = vec_and(vecshi2, mask_unsigned);
vecshi3 = vec_and(vecshi3, mask_unsigned);
vint32 veci0 = vec_unpackh(vecshi0);
vint32 veci1 = vec_unpackl(vecshi0);
vint32 veci2 = vec_unpackh(vecshi1);
vint32 veci3 = vec_unpackl(vecshi1);
vint32 veci4 = vec_unpackh(vecshi2);
vint32 veci5 = vec_unpackl(vecshi2);
vint32 veci6 = vec_unpackh(vecshi3);
vint32 veci7 = vec_unpackl(vecshi3);
vfloat32 vecf0_0 = vec_float(veci0);
vfloat32 vecf1_0 = vec_float(veci1);
vfloat32 vecf0_1 = vec_float(veci2);
vfloat32 vecf1_1 = vec_float(veci3);
vfloat32 vecf0_2 = vec_float(veci4);
vfloat32 vecf1_2 = vec_float(veci5);
vfloat32 vecf0_3 = vec_float(veci6);
vfloat32 vecf1_3 = vec_float(veci7);
vfloat32 scale_vec0 = scale.vec0();
vfloat32 scale_vec1 = scale.vec1();
vfloat32 scale_zp_premul0 = scale_zp_premul.vec0();
vfloat32 scale_zp_premul1 = scale_zp_premul.vec1();
return {
Vectorized<float>{
vec_madd(scale_vec0, vecf0_0, scale_zp_premul0),
vec_madd(scale_vec1, vecf1_0, scale_zp_premul1)},
Vectorized<float>{
vec_madd(scale_vec0, vecf0_1, scale_zp_premul0),
vec_madd(scale_vec1, vecf1_1, scale_zp_premul1)},
Vectorized<float>{
vec_madd(scale_vec0, vecf0_2, scale_zp_premul0),
vec_madd(scale_vec1, vecf1_2, scale_zp_premul1)},
Vectorized<float>{
vec_madd(scale_vec0, vecf0_3, scale_zp_premul0),
vec_madd(scale_vec1, vecf1_3, scale_zp_premul1)}};
}
float_vec_return_type C10_ALWAYS_INLINE dequantize(
Vectorized<float> scale,
Vectorized<float> zero_point) const {
// unpacking unsigned as signed
vint16 vecshi0 = vec_unpackh((vint8)_vec0);
vint16 vecshi1 = vec_unpackl((vint8)_vec0);
vint16 vecshi2 = vec_unpackh((vint8)_vec1);
vint16 vecshi3 = vec_unpackl((vint8)_vec1);
// signed -> unsigned
vecshi0 = vec_and(vecshi0, mask_unsigned);
vecshi1 = vec_and(vecshi1, mask_unsigned);
vecshi2 = vec_and(vecshi2, mask_unsigned);
vecshi3 = vec_and(vecshi3, mask_unsigned);
vint32 veci0 = vec_unpackh(vecshi0);
vint32 veci1 = vec_unpackl(vecshi0);
vint32 veci2 = vec_unpackh(vecshi1);
vint32 veci3 = vec_unpackl(vecshi1);
vint32 veci4 = vec_unpackh(vecshi2);
vint32 veci5 = vec_unpackl(vecshi2);
vint32 veci6 = vec_unpackh(vecshi3);
vint32 veci7 = vec_unpackl(vecshi3);
vfloat32 vecf0_0 = vec_float(veci0);
vfloat32 vecf1_0 = vec_float(veci1);
vfloat32 vecf0_1 = vec_float(veci2);
vfloat32 vecf1_1 = vec_float(veci3);
vfloat32 vecf0_2 = vec_float(veci4);
vfloat32 vecf1_2 = vec_float(veci5);
vfloat32 vecf0_3 = vec_float(veci6);
vfloat32 vecf1_3 = vec_float(veci7);
vfloat32 scale_vec0 = scale.vec0();
vfloat32 scale_vec1 = scale.vec1();
vfloat32 zero_point0 = zero_point.vec0();
vfloat32 zero_point1 = zero_point.vec1();
return {
Vectorized<float>{
(vecf0_0 - zero_point0) * scale_vec0,
(vecf1_0 - zero_point1) * scale_vec1},
Vectorized<float>{
(vecf0_1 - zero_point0) * scale_vec0,
(vecf1_1 - zero_point1) * scale_vec1},
Vectorized<float>{
(vecf0_2 - zero_point0) * scale_vec0,
(vecf1_2 - zero_point1) * scale_vec1},
Vectorized<float>{
(vecf0_3 - zero_point0) * scale_vec0,
(vecf1_3 - zero_point1) * scale_vec1}};
}
static Vectorized<c10::quint8> quantize(
const float_vec_return_type& rhs,
float scale,
int32_t zero_point,
float inverse_scale) {
// constexpr int32_t min_val = std::numeric_limits<value_type>::min();
// constexpr int32_t max_val = std::numeric_limits<value_type>::max();
vfloat32 vec_inverse = vec_splats(inverse_scale);
vfloat32 vec_zero_point = vec_splats((float)zero_point);
// vuint32 vmin = vec_splats(min_val);
// vuint32 vmax = vec_splats(max_val);
Vectorized<float> vf0 = rhs[0];
Vectorized<float> vf1 = rhs[1];
Vectorized<float> vf2 = rhs[2];
Vectorized<float> vf3 = rhs[3];
vfloat32 vecf0 = vf0.vec0();
vfloat32 vecf1 = vf0.vec1();
vfloat32 vecf2 = vf1.vec0();
vfloat32 vecf3 = vf1.vec1();
vfloat32 vecf4 = vf2.vec0();
vfloat32 vecf5 = vf2.vec1();
vfloat32 vecf6 = vf3.vec0();
vfloat32 vecf7 = vf3.vec1();
vecf0 = vec_mul(vecf0, vec_inverse);
vecf1 = vec_mul(vecf1, vec_inverse);
vecf2 = vec_mul(vecf2, vec_inverse);
vecf3 = vec_mul(vecf3, vec_inverse);
vecf4 = vec_mul(vecf4, vec_inverse);
vecf5 = vec_mul(vecf5, vec_inverse);
vecf6 = vec_mul(vecf6, vec_inverse);
vecf7 = vec_mul(vecf7, vec_inverse);
vecf0 = vec_add(vec_rint(vecf0), vec_zero_point);
vecf1 = vec_add(vec_rint(vecf1), vec_zero_point);
vecf2 = vec_add(vec_rint(vecf2), vec_zero_point);
vecf3 = vec_add(vec_rint(vecf3), vec_zero_point);
vecf4 = vec_add(vec_rint(vecf4), vec_zero_point);
vecf5 = vec_add(vec_rint(vecf5), vec_zero_point);
vecf6 = vec_add(vec_rint(vecf6), vec_zero_point);
vecf7 = vec_add(vec_rint(vecf7), vec_zero_point);
vint32 veci0 = vec_signed(vecf0);
vint32 veci1 = vec_signed(vecf1);
vint32 veci2 = vec_signed(vecf2);
vint32 veci3 = vec_signed(vecf3);
vint32 veci4 = vec_signed(vecf4);
vint32 veci5 = vec_signed(vecf5);
vint32 veci6 = vec_signed(vecf6);
vint32 veci7 = vec_signed(vecf7);
vint16 vecshi0 = vec_packs(veci0, veci1);
vint16 vecshi1 = vec_packs(veci2, veci3);
vint16 vecshi2 = vec_packs(veci4, veci5);
vint16 vecshi3 = vec_packs(veci6, veci7);
vuint8 vec0 = vec_packsu(vecshi0, vecshi1);
vuint8 vec1 = vec_packsu(vecshi2, vecshi3);
return {vec0, vec1};
}
Vectorized<c10::quint8> C10_ALWAYS_INLINE relu(Vectorized<c10::quint8> zero_point) const {
return {vec_max(_vec0, zero_point._vec0), vec_max(_vec1, zero_point._vec1)};
}
Vectorized<c10::quint8> C10_ALWAYS_INLINE
relu6(Vectorized<c10::quint8> zero_point, Vectorized<c10::quint8> q_six) const {
vuint8 max0 = vec_max(_vec0, zero_point._vec0);
vuint8 max1 = vec_max(_vec1, zero_point._vec1);
return {vec_min(max0, q_six._vec0), vec_min(max1, q_six._vec1)};
}
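// Widens both operands from u8 to i32 (masking off the sign extension that
// vec_unpack* introduces) and returns the lane-wise differences as four
// qint32 vectors; this is the first step of requantization.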
int_vec_return_type widening_subtract(Vectorized<c10::quint8> b) const {
vint16 vecshi0 = vec_unpackh((vint8)_vec0);
vint16 vecBshi0 = vec_unpackh((vint8)b._vec0);
vint16 vecshi1 = vec_unpackl((vint8)_vec0);
vint16 vecBshi1 = vec_unpackl((vint8)b._vec0);
vint16 vecshi2 = vec_unpackh((vint8)_vec1);
vint16 vecBshi2 = vec_unpackh((vint8)b._vec1);
vint16 vecshi3 = vec_unpackl((vint8)_vec1);
vint16 vecBshi3 = vec_unpackl((vint8)b._vec1);
vecshi0 = vec_and(vecshi0, mask_unsigned);
vecBshi0 = vec_and(vecBshi0, mask_unsigned);
vecshi1 = vec_and(vecshi1, mask_unsigned);
vecBshi1 = vec_and(vecBshi1, mask_unsigned);
vecshi2 = vec_and(vecshi2, mask_unsigned);
vecBshi2 = vec_and(vecBshi2, mask_unsigned);
vecshi3 = vec_and(vecshi3, mask_unsigned);
vecBshi3 = vec_and(vecBshi3, mask_unsigned);
vint32 veci0 = vec_unpackh(vecshi0);
vint32 vecBi0 = vec_unpackh(vecBshi0);
vint32 veci1 = vec_unpackl(vecshi0);
vint32 vecBi1 = vec_unpackl(vecBshi0);
vint32 veci2 = vec_unpackh(vecshi1);
vint32 vecBi2 = vec_unpackh(vecBshi1);
vint32 veci3 = vec_unpackl(vecshi1);
vint32 vecBi3 = vec_unpackl(vecBshi1);
vint32 veci4 = vec_unpackh(vecshi2);
vint32 vecBi4 = vec_unpackh(vecBshi2);
vint32 veci5 = vec_unpackl(vecshi2);
vint32 vecBi5 = vec_unpackl(vecBshi2);
vint32 veci6 = vec_unpackh(vecshi3);
vint32 vecBi6 = vec_unpackh(vecBshi3);
vint32 veci7 = vec_unpackl(vecshi3);
vint32 vecBi7 = vec_unpackl(vecBshi3);
return {
Vectorized<c10::qint32>(veci0 - vecBi0, veci1 - vecBi1),
Vectorized<c10::qint32>(veci2 - vecBi2, veci3 - vecBi3),
Vectorized<c10::qint32>(veci4 - vecBi4, veci5 - vecBi5),
Vectorized<c10::qint32>(veci6 - vecBi6, veci7 - vecBi7)};
}
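// Maps i32 accumulators back to quint8 as round(x * multiplier) + zero_point,
// then narrows with saturating packs down to the [0, 255] range.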
static Vectorized<c10::quint8> requantize_from_int(
const int_vec_return_type& inp,
float multiplier,
int32_t zero_point) {
vfloat32 vec_multiplier = vec_splats(multiplier);
vint32 vec_zero_point = vec_splats(zero_point);
Vectorized<c10::qint32> vi0 = inp[0];
Vectorized<c10::qint32> vi1 = inp[1];
Vectorized<c10::qint32> vi2 = inp[2];
Vectorized<c10::qint32> vi3 = inp[3];
vfloat32 vecf0 = vec_float(vi0.vec0());
vfloat32 vecf1 = vec_float(vi0.vec1());
vfloat32 vecf2 = vec_float(vi1.vec0());
vfloat32 vecf3 = vec_float(vi1.vec1());
vfloat32 vecf4 = vec_float(vi2.vec0());
vfloat32 vecf5 = vec_float(vi2.vec1());
vfloat32 vecf6 = vec_float(vi3.vec0());
vfloat32 vecf7 = vec_float(vi3.vec1());
vecf0 = vec_mul(vecf0, vec_multiplier);
vecf1 = vec_mul(vecf1, vec_multiplier);
vecf2 = vec_mul(vecf2, vec_multiplier);
vecf3 = vec_mul(vecf3, vec_multiplier);
vecf4 = vec_mul(vecf4, vec_multiplier);
vecf5 = vec_mul(vecf5, vec_multiplier);
vecf6 = vec_mul(vecf6, vec_multiplier);
vecf7 = vec_mul(vecf7, vec_multiplier);
vecf0 = vec_rint(vecf0);
vecf1 = vec_rint(vecf1);
vecf2 = vec_rint(vecf2);
vecf3 = vec_rint(vecf3);
vecf4 = vec_rint(vecf4);
vecf5 = vec_rint(vecf5);
vecf6 = vec_rint(vecf6);
vecf7 = vec_rint(vecf7);
vint32 veci0 = vec_signed(vecf0);
vint32 veci1 = vec_signed(vecf1);
vint32 veci2 = vec_signed(vecf2);
vint32 veci3 = vec_signed(vecf3);
vint32 veci4 = vec_signed(vecf4);
vint32 veci5 = vec_signed(vecf5);
vint32 veci6 = vec_signed(vecf6);
vint32 veci7 = vec_signed(vecf7);
veci0 = vec_add(veci0, vec_zero_point);
veci1 = vec_add(veci1, vec_zero_point);
veci2 = vec_add(veci2, vec_zero_point);
veci3 = vec_add(veci3, vec_zero_point);
veci4 = vec_add(veci4, vec_zero_point);
veci5 = vec_add(veci5, vec_zero_point);
veci6 = vec_add(veci6, vec_zero_point);
veci7 = vec_add(veci7, vec_zero_point);
vint16 vecshi0 = vec_packs(veci0, veci1);
vint16 vecshi1 = vec_packs(veci2, veci3);
vint16 vecshi2 = vec_packs(veci4, veci5);
vint16 vecshi3 = vec_packs(veci6, veci7);
vuint8 vec0 = vec_packsu(vecshi0, vecshi1);
vuint8 vec1 = vec_packsu(vecshi2, vecshi3);
return {vec0, vec1};
}
DEFINE_MEMBER_OP(operator==, c10::quint8, vec_cmpeq)
DEFINE_MEMBER_OP(operator!=, c10::quint8, vec_cmpne)
DEFINE_MEMBER_OP(operator<, c10::quint8, vec_cmplt)
DEFINE_MEMBER_OP(operator<=, c10::quint8, vec_cmple)
DEFINE_MEMBER_OP(operator>, c10::quint8, vec_cmpgt)
DEFINE_MEMBER_OP(operator>=, c10::quint8, vec_cmpge)
DEFINE_MEMBER_OP(operator+, c10::quint8, vec_add)
DEFINE_MEMBER_OP(operator-, c10::quint8, vec_sub)
DEFINE_MEMBER_OP(operator*, c10::quint8, vec_mul)
DEFINE_MEMBER_EMULATE_BINARY_OP(operator/, c10::quint8, /)
DEFINE_MEMBER_OP(maximum, c10::quint8, vec_max)
DEFINE_MEMBER_OP(minimum, c10::quint8, vec_min)
DEFINE_MEMBER_OP(operator&, c10::quint8, vec_and)
DEFINE_MEMBER_OP(operator|, c10::quint8, vec_or)
DEFINE_MEMBER_OP(operator^, c10::quint8, vec_xor)
};
template <>
Vectorized<c10::quint8> inline maximum(
const Vectorized<c10::quint8>& a,
const Vectorized<c10::quint8>& b) {
return a.maximum(b);
}
template <>
Vectorized<c10::quint8> inline minimum(
const Vectorized<c10::quint8>& a,
const Vectorized<c10::quint8>& b) {
return a.minimum(b);
}
template <>
Vectorized<c10::quint8> C10_ALWAYS_INLINE operator+(const Vectorized<c10::quint8>& a, const Vectorized<c10::quint8>& b) {
return Vectorized<c10::quint8>{vec_add(a.vec0(), b.vec0()), vec_add(a.vec1(), b.vec1())};
}
template <>
Vectorized<c10::quint8> C10_ALWAYS_INLINE operator-(const Vectorized<c10::quint8>& a, const Vectorized<c10::quint8>& b) {
return Vectorized<c10::quint8>{vec_sub(a.vec0(), b.vec0()), vec_sub(a.vec1(), b.vec1())};
}
template <>
Vectorized<c10::quint8> C10_ALWAYS_INLINE operator*(const Vectorized<c10::quint8>& a, const Vectorized<c10::quint8>& b) {
return Vectorized<c10::quint8>{vec_mul(a.vec0(), b.vec0()), vec_mul(a.vec1(), b.vec1())};
}
template <>
Vectorized<c10::quint8> C10_ALWAYS_INLINE operator/(const Vectorized<c10::quint8>& a, const Vectorized<c10::quint8>& b) {
return Vectorized<c10::quint8>{a.vec0()/b.vec0(), a.vec1()/b.vec1()};
}
template <>
Vectorized<c10::quint8> C10_ALWAYS_INLINE operator&(const Vectorized<c10::quint8>& a, const Vectorized<c10::quint8>& b) {
return Vectorized<c10::quint8>{vec_and(a.vec0(), b.vec0()), vec_and(a.vec1(), b.vec1())};
}
template <>
Vectorized<c10::quint8> C10_ALWAYS_INLINE operator|(const Vectorized<c10::quint8>& a, const Vectorized<c10::quint8>& b) {
return Vectorized<c10::quint8>{vec_or(a.vec0(), b.vec0()), vec_or(a.vec1(), b.vec1())};
}
template <>
Vectorized<c10::quint8> C10_ALWAYS_INLINE operator^(const Vectorized<c10::quint8>& a, const Vectorized<c10::quint8>& b) {
return Vectorized<c10::quint8>{vec_xor(a.vec0(), b.vec0()), vec_xor(a.vec1(), b.vec1())};
}
} // namespace
} // namespace vec
} // namespace at

View File

@ -0,0 +1,474 @@
#pragma once
#include <cstdint>
#include <c10/macros/Macros.h>
#include <ATen/cpu/vec/intrinsics.h>
#if defined(__clang__)
typedef __vector __bool char vbool8;
typedef __vector __bool short vbool16;
typedef __vector __bool int vbool32;
typedef __vector __bool long long vbool64;
using vint8 = __attribute__((vector_size(16))) signed char;
using vint16 = __attribute__((vector_size(16))) signed short;
using vint32 = __attribute__((vector_size(16))) signed int;
using vint64 = __attribute__((vector_size(16))) signed long long;
using vuint8 = __attribute__((vector_size(16))) unsigned char;
using vuint16 = __attribute__((vector_size(16))) unsigned short;
using vuint32 = __attribute__((vector_size(16))) unsigned int;
using vuint64 = __attribute__((vector_size(16))) unsigned long long;
using vfloat32 = __attribute__((vector_size(16))) float;
using vfloat64 = __attribute__((vector_size(16))) double;
#else
using vbool8 = __attribute__((altivec(vector__))) __attribute__((altivec(bool__))) char;
using vbool16 = __attribute__((altivec(vector__))) __attribute__((altivec(bool__))) short;
using vbool32 = __attribute__((altivec(vector__))) __attribute__((altivec(bool__))) int;
using vbool64 = __attribute__((altivec(vector__))) __attribute__((altivec(bool__))) long long;
using vint8 = __attribute__((altivec(vector__))) signed char;
using vint16 = __attribute__((altivec(vector__))) signed short;
using vint32 = __attribute__((altivec(vector__))) signed int;
using vint64 = __attribute__((altivec(vector__))) signed long long;
using vuint8 = __attribute__((altivec(vector__))) unsigned char;
using vuint16 = __attribute__((altivec(vector__))) unsigned short;
using vuint32 = __attribute__((altivec(vector__))) unsigned int;
using vuint64 = __attribute__((altivec(vector__))) unsigned long long;
using vfloat32 = __attribute__((altivec(vector__))) float;
using vfloat64 = __attribute__((altivec(vector__))) double;
#endif
#if !defined(vec_float)
C10_ALWAYS_INLINE vfloat32 vec_float(const vint32& vec_in) {
vfloat32 vec_out;
__asm__("xvcvsxwsp %x0,%x1" : "=wf"(vec_out) : "wa"(vec_in));
return vec_out;
}
#endif
#if !defined(vec_signed)
C10_ALWAYS_INLINE vint32 vec_signed(const vfloat32& vec_in) {
vint32 vec_out;
__asm__("xvcvspsxws %x0,%x1" : "=wa"(vec_out) : "wf"(vec_in));
return vec_out;
}
C10_ALWAYS_INLINE vint64 vec_signed(const vfloat64& vec_in) {
vint64 vec_out;
__asm__("xvcvdpsxds %x0,%x1" : "=wa"(vec_out) : "wd"(vec_in));
return vec_out;
}
#endif
#if !defined(vec_neg)
C10_ALWAYS_INLINE vfloat32 vec_neg(const vfloat32& vec_in) {
vfloat32 vec_out;
__asm__("xvnegsp %x0,%x1" : "=wf"(vec_out) : "wf"(vec_in));
return vec_out;
}
C10_ALWAYS_INLINE vfloat64 vec_neg(const vfloat64& vec_in) {
vfloat64 vec_out;
__asm__("xvnegdp %x0,%x1" : "=wd"(vec_out) : "wd"(vec_in));
return vec_out;
}
C10_ALWAYS_INLINE vint16 vec_neg(const vint16& vec_in) {
vint16 vint0 = {0, 0, 0, 0 ,0, 0, 0, 0};
return vec_vsubuhm(vint0, vec_in);
}
C10_ALWAYS_INLINE vint32 vec_neg(const vint32& vec_in) {
vint32 vint0 = {0, 0, 0, 0};
return vec_vsubuwm(vint0, vec_in);
}
C10_ALWAYS_INLINE vint64 vec_neg(const vint64& vec_in) {
return -vec_in;
}
#endif
#if !defined(vec_sldw)
template <unsigned int C>
C10_ALWAYS_INLINE vfloat32
vec_sldw_aux(const vfloat32& vec_in0, const vfloat32& vec_in1) {
vfloat32 vec_out;
__asm("xxsldwi %x0, %x1, %x2, %3 "
: "=wa"(vec_out)
: "wa"(vec_in0), "wa"(vec_in1), "I"(C));
return vec_out;
}
#define vec_sldw(a, b, c) vec_sldw_aux<c>(a, b)
#endif
#define vec_not(a) vec_nor(a, a)
#if defined(__clang__) && !defined(vec_splats)
C10_ALWAYS_INLINE vint64 vec_splats(const int64_t& a) {
return vec_splats(a);
}
#endif
// Vectorized min/max which return a if any operand is nan
template <class T>
C10_ALWAYS_INLINE T vec_min_nan(const T& a, const T& b) {
return vec_min(a, b);
}
template <class T>
C10_ALWAYS_INLINE T vec_max_nan(const T& a, const T& b) {
return vec_max(a, b);
}
// Specializations for float/double taken from Eigen
template<>
C10_ALWAYS_INLINE vfloat32 vec_min_nan<vfloat32>(const vfloat32& a, const vfloat32& b)
{
// NOTE: about 10% slower than vec_min, but consistent with std::min and SSE regarding NaN
vfloat32 ret;
__asm__ ("xvcmpgesp %x0,%x1,%x2\n\txxsel %x0,%x1,%x2,%x0" : "=&wa" (ret) : "wa" (a), "wa" (b));
return ret;
}
// Specializations for float/double taken from Eigen
template<>
C10_ALWAYS_INLINE vfloat32 vec_max_nan<vfloat32>(const vfloat32& a, const vfloat32& b)
{
// NOTE: about 10% slower than vec_max, but consistent with std::max and SSE regarding NaN
vfloat32 ret;
__asm__ ("xvcmpgtsp %x0,%x2,%x1\n\txxsel %x0,%x1,%x2,%x0" : "=&wa" (ret) : "wa" (a), "wa" (b));
return ret;
}
template<>
C10_ALWAYS_INLINE vfloat64 vec_min_nan<vfloat64>(const vfloat64& a, const vfloat64& b)
{
// NOTE: about 10% slower than vec_min, but consistent with std::min and SSE regarding NaN
vfloat64 ret;
__asm__ ("xvcmpgedp %x0,%x1,%x2\n\txxsel %x0,%x1,%x2,%x0" : "=&wa" (ret) : "wa" (a), "wa" (b));
return ret;
}
template<>
C10_ALWAYS_INLINE vfloat64 vec_max_nan<vfloat64>(const vfloat64& a, const vfloat64& b)
{
// NOTE: about 10% slower than vec_max, but consistent with std::max and SSE regarding NaN
vfloat64 ret;
__asm__ ("xvcmpgtdp %x0,%x2,%x1\n\txxsel %x0,%x1,%x2,%x0" : "=&wa" (ret) : "wa" (a), "wa" (b));
return ret;
}
// Vectorized min/max functions that return NaN if either operand is NaN
#define C10_VSX_VEC_NAN_PROPAG(name, type, btype, func) \
C10_ALWAYS_INLINE type name(const type& a, const type& b) { \
type tmp = func(a, b); \
btype nan_a = vec_cmpne(a, a); \
btype nan_b = vec_cmpne(b, b); \
tmp = vec_sel(tmp, a, nan_a); \
return vec_sel(tmp, b, nan_b); \
}
C10_VSX_VEC_NAN_PROPAG(vec_min_nan2, vfloat32, vbool32, vec_min)
C10_VSX_VEC_NAN_PROPAG(vec_max_nan2, vfloat32, vbool32, vec_max)
C10_VSX_VEC_NAN_PROPAG(vec_min_nan2, vfloat64, vbool64, vec_min)
C10_VSX_VEC_NAN_PROPAG(vec_max_nan2, vfloat64, vbool64, vec_max)
#undef C10_VSX_VEC_NAN_PROPAG
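// The *_nan2 helpers above compute the regular min/max and then force any lane
// where an operand is NaN back to that NaN (vec_cmpne(x, x) is true only for
// NaN lanes), so NaNs propagate through the result.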
#define DEFINE_MEMBER_UNARY_OP(op, op_type, func) \
Vectorized<op_type> C10_ALWAYS_INLINE op() const { \
return Vectorized<op_type>{func(_vec0), func(_vec1)}; \
}
#define DEFINE_MEMBER_OP(op, op_type, func) \
Vectorized<op_type> C10_ALWAYS_INLINE op(const Vectorized<op_type>& other) const { \
return Vectorized<op_type>{ \
func(_vec0, other._vec0), func(_vec1, other._vec1)}; \
}
#define DEFINE_MEMBER_BITWISE_OP(op, op_type, func) \
Vectorized<op_type> C10_ALWAYS_INLINE op(const Vectorized<op_type>& other) const { \
return Vectorized<op_type>{ \
func(_vecb0, other._vecb0), func(_vecb1, other._vecb1)}; \
}
#define DEFINE_MEMBER_TERNARY_OP(op, op_type, func) \
Vectorized<op_type> C10_ALWAYS_INLINE op( \
const Vectorized<op_type>& b, const Vectorized<op_type>& c) const { \
return Vectorized<op_type>{ \
func(_vec0, b._vec0, c._vec0), func(_vec1, b._vec1, c._vec1)}; \
}
#define DEFINE_MEMBER_EMULATE_BINARY_OP(op, op_type, binary_op) \
Vectorized<op_type> C10_ALWAYS_INLINE op(const Vectorized<op_type>& b) const { \
Vectorized<op_type>::vec_internal_type ret_0; \
Vectorized<op_type>::vec_internal_type ret_1; \
for (int i = 0; i < Vectorized<op_type>::size() / 2; i++) { \
ret_0[i] = _vec0[i] binary_op b._vec0[i]; \
ret_1[i] = _vec1[i] binary_op b._vec1[i]; \
} \
return Vectorized<op_type>{ret_0, ret_1}; \
}
#define DEFINE_MEMBER_OP_AND_ONE(op, op_type, func) \
Vectorized<op_type> C10_ALWAYS_INLINE op(const Vectorized<op_type>& other) const { \
using vvtype = Vectorized<op_type>::vec_internal_type; \
const vvtype v_one = vec_splats(static_cast<op_type>(1.0)); \
vvtype ret0 = (vvtype)func(_vec0, other._vec0); \
vvtype ret1 = (vvtype)func(_vec1, other._vec1); \
return Vectorized<op_type>{vec_and(ret0, v_one), vec_and(ret1, v_one)}; \
}
#define DEFINE_CLAMP_FUNCS(operand_type) \
template <> \
Vectorized<operand_type> C10_ALWAYS_INLINE clamp( \
const Vectorized<operand_type>& a, \
const Vectorized<operand_type>& min, \
const Vectorized<operand_type>& max) { \
return Vectorized<operand_type>{ \
vec_min_nan(vec_max_nan(a.vec0(), min.vec0()), max.vec0()), \
vec_min_nan(vec_max_nan(a.vec1(), min.vec1()), max.vec1())}; \
} \
template <> \
Vectorized<operand_type> C10_ALWAYS_INLINE clamp_min( \
const Vectorized<operand_type>& a, const Vectorized<operand_type>& min) { \
return Vectorized<operand_type>{ \
vec_max_nan(a.vec0(), min.vec0()), \
vec_max_nan(a.vec1(), min.vec1())}; \
} \
template <> \
Vectorized<operand_type> C10_ALWAYS_INLINE clamp_max( \
const Vectorized<operand_type>& a, const Vectorized<operand_type>& max) { \
return Vectorized<operand_type>{ \
vec_min_nan(a.vec0(), max.vec0()), \
vec_min_nan(a.vec1(), max.vec1())}; \
}
#define DEFINE_REINTERPRET_CAST_FUNCS( \
first_type, cast_type, cast_inner_vector_type) \
template <> \
C10_ALWAYS_INLINE Vectorized<cast_type> cast<cast_type, first_type>( \
const Vectorized<first_type>& src) { \
return Vectorized<cast_type>{(cast_inner_vector_type)src.vec0(), \
(cast_inner_vector_type)src.vec1()}; \
}
#define DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(first_type) \
DEFINE_REINTERPRET_CAST_FUNCS(first_type, double, vfloat64) \
DEFINE_REINTERPRET_CAST_FUNCS(first_type, float, vfloat32) \
DEFINE_REINTERPRET_CAST_FUNCS(first_type, int64_t, vint64) \
DEFINE_REINTERPRET_CAST_FUNCS(first_type, int32_t, vint32) \
DEFINE_REINTERPRET_CAST_FUNCS(first_type, int16_t, vint16)
// it can be used to emulate blend faster
constexpr int blendChoice(uint32_t mask, uint32_t half1 = 0xF, uint32_t half2 = 0xF0) {
uint32_t none = 0;
uint32_t both = half1 | half2;
// clamp it between 0 and both
mask = mask & both;
// return (a._vec0, a._vec1)
if (mask == none) return 0;
// return (b._vec0,b._vec1)
else if (mask == both)
return 1;
// return (b._vec0,a._vec1)
else if (mask == half1)
return 2;
// return (a._vec0,b._vec1)
else if (mask == half2)
return 3;
// return (*_vec0,a._vec1)
else if (mask > 0 && mask < half1)
return 4;
// return (*_vec0,b._vec1)
else if ((mask & half2) == half2)
return 5;
// return (a._vec0,*_vec1)
else if ((mask & half1) == 0 && mask > half1)
return 6;
// return (b._vec0,*_vec1)
else if ((mask & half1) == half1 && mask > half1)
return 7;
// return (*_vec0,*_vec1)
return 8;
}
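// For the default halves: blendChoice(0x00) == 0 (keep a), blendChoice(0xFF) == 1
// (take b), blendChoice(0x0F) == 2 (b._vec0, a._vec1) and blendChoice(0xF0) == 3
// (a._vec0, b._vec1); the remaining codes mean one or both halves need an
// element-wise select.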
// it can be used to emulate blend faster
constexpr int blendChoiceDbl(uint32_t mask) {
// clamp it between 0 and 0xF
return blendChoice(mask, 0x3, 0xC);
}
constexpr vbool32 VsxMask1(uint32_t mask) {
uint32_t g0 = (mask & 1) * 0xffffffff;
uint32_t g1 = ((mask & 2) >> 1) * 0xffffffff;
uint32_t g2 = ((mask & 4) >> 2) * 0xffffffff;
uint32_t g3 = ((mask & 8) >> 3) * 0xffffffff;
return (vbool32){g0, g1, g2, g3};
}
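// Bit i of the mask selects lane i, e.g.
// VsxMask1(0b0101) == {0xffffffff, 0x0, 0xffffffff, 0x0}.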
constexpr vbool32 VsxMask2(uint32_t mask) {
uint32_t mask2 = (mask & 0xFF) >> 4;
return VsxMask1(mask2);
}
constexpr vbool64 VsxDblMask1(uint32_t mask) {
uint64_t g0 = (mask & 1) * 0xffffffffffffffff;
uint64_t g1 = ((mask & 2) >> 1) * 0xffffffffffffffff;
return (vbool64){g0, g1};
}
constexpr vbool64 VsxDblMask2(uint32_t mask) {
uint32_t mask2 = (mask & 0xF) >> 2;
return VsxDblMask1(mask2);
}
constexpr int maskForComplex(uint32_t mask) {
mask = mask & 0xF;
int complex_mask = 0;
if (mask & 1) complex_mask |= 3;
if (mask & 2) complex_mask |= (3 << 2);
if (mask & 4) complex_mask |= (3 << 4);
if (mask & 8) complex_mask |= (3 << 6);
return complex_mask;
}
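// Each complex element spans two float lanes, so bit i of the element mask
// expands to lane bits 2i and 2i+1 (e.g. 0b01 -> 0b0011).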
constexpr int maskForComplexDbl(uint32_t mask) {
mask = mask & 0x3;
int complex_mask = 0;
if (mask & 1) complex_mask |= 3;
if (mask & 2) complex_mask |= (3 << 2);
return complex_mask;
}
constexpr int blendChoiceComplex(uint32_t mask) {
return blendChoice(maskForComplex(mask));
}
constexpr int blendChoiceComplexDbl(uint32_t mask) {
return blendChoiceDbl(maskForComplexDbl(mask));
}
constexpr vbool32 VsxComplexMask1(uint32_t mask) {
return VsxMask1(maskForComplex(mask));
}
constexpr vbool32 VsxComplexMask2(uint32_t mask) {
uint32_t mask2 = (mask & 0xF) >> 2;
return VsxMask1(maskForComplex(mask2));
}
constexpr vbool64 VsxComplexDblMask1(uint32_t mask) { return VsxDblMask1(mask); }
constexpr vbool64 VsxComplexDblMask2(uint32_t mask) {
uint32_t mask2 = (mask & 0xF) >> 2;
return VsxDblMask1(mask2);
}
// constants
namespace at {
namespace vec {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {
//
constexpr int offset0 = 0;
constexpr int offset16 = 16;
// #Constants
const vuint8 mask_zero_bits = vuint8{128, 128, 128, 128, 128, 128, 128, 128,
128, 128, 128, 128, 96, 64, 32, 0};
const vuint8 swap_mask =
vuint8{4, 5, 6, 7, 0, 1, 2, 3, 12, 13, 14, 15, 8, 9, 10, 11};
const vint32 v0x7f = vec_splats(0x7f);
const vint32 vi_0 = vec_splats((int)(0));
const vint32 vi_1 = vec_splats((int)1);
const vint32 vi_2 = vec_splats((int)2);
const vint32 vi_4 = vec_splats((int)4);
const vint32 vi_inv1 = vec_splats((int)~1);
const vuint32 vu_29 = vec_splats(29u);
const vuint32 vu_23 = vec_splats(23u);
const vbool32 inv_mant_mask = (vbool32)vec_splats((unsigned int)~0xff800000);
const vbool32 sign_mask = (vbool32)vec_splats((int)0x80000000);
const vbool32 real_mask = vbool32{0xFFFFFFFF, 0x0, 0xFFFFFFFF, 0x0};
const vbool32 imag_mask = vbool32{0x0, 0xFFFFFFFF, 0x0, 0xFFFFFFFF};
const vbool32 isign_mask = vbool32{0x0, 0x80000000, 0x0, 0x80000000};
const vbool32 rsign_mask = vbool32{0x80000000, 0x0, 0x80000000, 0x0};
const vbool64 vd_sign_mask = vbool64{0x8000000000000000, 0x8000000000000000};
const vbool64 vd_imag_mask = vbool64{0x0, 0xFFFFFFFFFFFFFFFF};
const vbool64 vd_real_mask = vbool64{0xFFFFFFFFFFFFFFFF, 0x0};
const vbool64 vd_isign_mask = vbool64{0x0, 0x8000000000000000};
const vbool64 vd_rsign_mask = vbool64{0x8000000000000000, 0x0};
const vfloat32 zero = vec_splats(0.f);
const vfloat32 half = vec_splats(0.5f);
const vfloat32 one = vec_splats(1.f);
const vfloat32 two = vec_splats(2.0f);
const vfloat32 _4div_pi = vec_splats(1.27323954473516f);
const vfloat32 v_inf = (vfloat32)vec_splats(0x7f800000u);
const vfloat32 v_minus_inf = vfloat32{ 0xff800000u, 0xff800000u, 0xff800000u, 0xff800000u };
const vfloat32 v_nan = (vfloat32)vec_splats(0x7fffffff);
const vfloat32 log10e_inv = vec_splats(0.43429448190325176f);
const vfloat32 log2e_inv = vec_splats(1.4426950408889634f);
const vfloat32 log2eB_inv = vec_splats(1.442695036924675f);
const vfloat32 cephes_SQRTHF = vec_splats(0.707106781186547524f);
const vfloat32 coscof_p0 = vec_splats(2.443315711809948E-005f);
const vfloat32 coscof_p1 = vec_splats(-1.388731625493765E-003f);
const vfloat32 coscof_p2 = vec_splats(4.166664568298827E-002f);
const vfloat32 exp_hi = vec_splats(104.f);
const vfloat32 exp_lo = vec_splats(-104.f);
const vfloat32 exp_p0 = vec_splats(0.000198527617612853646278381f);
const vfloat32 exp_p1 = vec_splats((0.00139304355252534151077271f));
const vfloat32 exp_p2 = vec_splats(0.00833336077630519866943359f);
const vfloat32 exp_p3 = vec_splats(0.0416664853692054748535156f);
const vfloat32 exp_p4 = vec_splats(0.166666671633720397949219f);
const vfloat32 exp_p5 = vec_splats(0.5f);
const vfloat32 log_p0 = vec_splats(7.0376836292E-2f);
const vfloat32 log_p1 = vec_splats(-1.1514610310E-1f);
const vfloat32 log_p2 = vec_splats(1.1676998740E-1f);
const vfloat32 log_p3 = vec_splats(-1.2420140846E-1f);
const vfloat32 log_p4 = vec_splats(+1.4249322787E-1f);
const vfloat32 log_p5 = vec_splats(-1.6668057665E-1f);
const vfloat32 log_p6 = vec_splats(+2.0000714765E-1f);
const vfloat32 log_p7 = vec_splats(-2.4999993993E-1f);
const vfloat32 log_p8 = vec_splats(+3.3333331174E-1f);
const vfloat32 log_q1 = vec_splats(-2.12194440e-4f);
const vfloat32 log_q2 = vec_splats(0.693359375f);
const vfloat32 max_logf = vec_splats(88.02969187150841f);
const vfloat32 max_numf = vec_splats(1.7014117331926442990585209174225846272e38f);
const vfloat32 min_inf = (vfloat32)vec_splats(0xff800000u);
const vfloat32 min_norm_pos = (vfloat32)vec_splats(0x0800000u);
const vfloat32 minus_cephes_dp1 = vec_splats(-0.78515625f);
const vfloat32 minus_cephes_dp2 = vec_splats(-2.4187564849853515625e-4f);
const vfloat32 minus_cephes_dp3 = vec_splats(-3.77489497744594108e-8f);
const vfloat32 negln2f_hi = vec_splats(-0.693145751953125f);
const vfloat32 negln2f_lo = vec_splats(-1.428606765330187045e-06f);
const vfloat32 p0 = vec_splats(2.03721912945E-4f);
const vfloat32 p1 = vec_splats(8.33028376239E-3f);
const vfloat32 p2 = vec_splats(1.66667160211E-1f);
const vfloat32 sincof_p0 = vec_splats(-1.9515295891E-4f);
const vfloat32 sincof_p1 = vec_splats(8.3321608736E-3f);
const vfloat32 sincof_p2 = vec_splats(-1.6666654611E-1f);
const vfloat32 tanh_0p625 = vec_splats(0.625f);
const vfloat32 tanh_half_max = vec_splats(44.014845935754205f);
const vfloat32 tanh_p0 = vec_splats(-5.70498872745E-3f);
const vfloat32 tanh_p1 = vec_splats(2.06390887954E-2f);
const vfloat32 tanh_p2 = vec_splats(-5.37397155531E-2f);
const vfloat32 tanh_p3 = vec_splats(1.33314422036E-1f);
const vfloat32 tanh_p4 = vec_splats(-3.33332819422E-1f);
const vfloat32 vcheck = vec_splats((float)(1LL << 24));
const vfloat32 imag_one = vfloat32{0.f, 1.f, 0.f, 1.f};
const vfloat32 imag_half = vfloat32{0.f, 0.5f, 0.f, 0.5f};
const vfloat32 sqrt2_2 = vfloat32{0.70710676908493042f, 0.70710676908493042,
0.70710676908493042, 0.70710676908493042};
const vfloat32 pi_2 = vfloat32{M_PI / 2, 0.0, M_PI / 2, 0.0};
const vfloat32 vf_89 = vfloat32{89.f, 89.f, 89.f, 89.f};
const vfloat64 vd_one = vec_splats(1.0);
const vfloat64 vd_zero = vec_splats(0.0);
const vfloat64 vd_log10e_inv = vec_splats(0.43429448190325176);
const vfloat64 vd_log2e_inv = vec_splats(1.4426950408889634);
const vfloat64 vd_imag_one = vfloat64{0.0, 1.0};
const vfloat64 vd_imag_half = vfloat64{0.0, 0.5};
const vfloat64 vd_sqrt2_2 = vfloat64{0.70710678118654757, 0.70710678118654757};
const vfloat64 vd_pi_2 = vfloat64{M_PI / 2.0, 0.0};
} // namespace
} // namespace vec
} // namespace at

File diff suppressed because it is too large

View File

@ -0,0 +1,291 @@
#pragma once
// DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with AVX]
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <ATen/cpu/vec/vec512/vec512_float.h>
#include <ATen/cpu/vec/vec512/vec512_bfloat16.h>
#include <ATen/cpu/vec/vec512/vec512_double.h>
#include <ATen/cpu/vec/vec512/vec512_int.h>
#include <ATen/cpu/vec/vec512/vec512_qint.h>
#include <ATen/cpu/vec/vec512/vec512_complex_float.h>
#include <ATen/cpu/vec/vec512/vec512_complex_double.h>
#include <ATen/cpu/vec/vec512/vec512_convert.h>
#include <ATen/cpu/vec/vec512/vec512_mask.h>
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <ostream>
namespace at {
namespace vec {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {
inline std::ostream& operator<<(std::ostream& stream, const c10::qint32& val) {
stream << val.val_;
return stream;
}
inline std::ostream& operator<<(std::ostream& stream, const c10::qint8& val) {
stream << static_cast<int>(val.val_);
return stream;
}
inline std::ostream& operator<<(std::ostream& stream, const c10::quint8& val) {
stream << static_cast<unsigned int>(val.val_);
return stream;
}
template <typename T>
std::ostream& operator<<(std::ostream& stream, const Vectorized<T>& vec) {
T buf[Vectorized<T>::size()];
vec.store(buf);
stream << "vec[";
for (int i = 0; i != Vectorized<T>::size(); i++) {
if (i != 0) {
stream << ", ";
}
stream << buf[i];
}
stream << "]";
return stream;
}
#if defined(CPU_CAPABILITY_AVX512)
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST (AVX512) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
template<>
inline Vectorized<float> cast<float, double>(const Vectorized<double>& src) {
return _mm512_castpd_ps(src);
}
template<>
inline Vectorized<double> cast<double, float>(const Vectorized<float>& src) {
return _mm512_castps_pd(src);
}
template<>
inline Vectorized<float> cast<float, int32_t>(const Vectorized<int32_t>& src) {
return _mm512_castsi512_ps(src);
}
template<>
inline Vectorized<double> cast<double, int64_t>(const Vectorized<int64_t>& src) {
return _mm512_castsi512_pd(src);
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#ifndef _MSC_VER
// MSVC is not working well on complex function overload.
template<int64_t scale = 1>
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<double>>
inline gather(const double* base_addr, const Vectorized<int64_t>& vindex) {
return _mm512_i64gather_pd(vindex, base_addr, scale);
}
template<int64_t scale = 1>
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<float>>
inline gather(const float* base_addr, const Vectorized<int32_t>& vindex) {
return _mm512_i32gather_ps(vindex, base_addr, scale);
}
#endif
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MASK GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#ifndef _MSC_VER
// MSVC is not working well on complex function overload.
template<int64_t scale = 1>
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<double>>
inline mask_gather(const Vectorized<double>& src, const double* base_addr,
const Vectorized<int64_t>& vindex, Vectorized<double>& mask) {
auto all_ones = _mm512_castsi512_pd(_mm512_set1_epi64(0xFFFFFFFFFFFFFFFF));
auto mask_ = _mm512_cmp_pd_mask(all_ones, mask.values, _CMP_EQ_OQ);
return _mm512_mask_i64gather_pd(src, mask_, vindex, base_addr, scale);
}
template<int64_t scale = 1>
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<float>>
inline mask_gather(const Vectorized<float>& src, const float* base_addr,
const Vectorized<int32_t>& vindex, Vectorized<float>& mask) {
auto all_ones = _mm512_castsi512_ps(_mm512_set1_epi32(0xFFFFFFFF));
auto mask_ = _mm512_cmp_ps_mask(all_ones, mask.values, _CMP_EQ_OQ);
return _mm512_mask_i32gather_ps(src, mask_, vindex, base_addr, scale);
}
#endif
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CONVERT ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
template<>
Vectorized<int64_t>
inline convert_to_int_of_same_size<double>(const Vectorized<double> &src) {
return _mm512_cvtpd_epi64(src);
}
template<>
Vectorized<int32_t>
inline convert_to_int_of_same_size<float>(const Vectorized<float> &src) {
return _mm512_cvttps_epi32(src);
}
template<>
Vectorized<double>
inline convert_to_fp_of_same_size<double>(const Vectorized<int64_t> &src) {
return _mm512_cvtepi64_pd(src);
}
template<>
Vectorized<float>
inline convert_to_fp_of_same_size<float>(const Vectorized<int32_t> &src) {
return _mm512_cvtepi32_ps(src);
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ INTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
template <>
std::pair<Vectorized<double>, Vectorized<double>>
inline interleave2<double>(const Vectorized<double>& a, const Vectorized<double>& b) {
// inputs:
// a = {a0, a1, a2, a3, a4, a5, a6, a7}
// b = {b0, b1, b2, b3, b4, b5, b6, b7}
// group cols crossing lanes:
// return {a0, b0, a1, b1, a2, b2, a3, b3}
// {a4, b4, a5, b5, a6, b6, a7, b7}
__m512i idx1 = _mm512_set_epi64(11, 3, 10, 2, 9, 1, 8, 0);
__m512i idx2 = _mm512_set_epi64(15, 7, 14, 6, 13, 5, 12, 4);
return std::make_pair(_mm512_mask_permutex2var_pd(a, 0xff, idx1, b),
_mm512_mask_permutex2var_pd(a, 0xff, idx2, b));
}
template <>
std::pair<Vectorized<float>, Vectorized<float>>
inline interleave2<float>(const Vectorized<float>& a, const Vectorized<float>& b) {
// inputs:
// a = {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15}
// b = {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15}
//
// return:
// {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7}
// {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, b15}
__m512i idx1 = _mm512_set_epi32(23, 7, 22, 6, 21, 5, 20, 4,
19, 3, 18, 2, 17, 1, 16, 0);
__m512i idx2 = _mm512_set_epi32(31, 15, 30, 14, 29, 13, 28, 12,
27, 11, 26, 10, 25, 9, 24, 8);
return std::make_pair(_mm512_mask_permutex2var_ps(a, 0xffff, idx1, b),
_mm512_mask_permutex2var_ps(a, 0xffff, idx2, b));
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ DEINTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
template <>
std::pair<Vectorized<double>, Vectorized<double>>
inline deinterleave2<double>(const Vectorized<double>& a, const Vectorized<double>& b) {
// inputs:
// a = {a0, b0, a1, b1, a2, b2, a3, b3}
// b = {a4, b4, a5, b5, a6, b6, a7, b7}
// output:
// return {a0, a1, a2, a3, a4, a5, a6, a7}
// {b0, b1, b2, b3, b4, b5, b6, b7}
// idx1 gathers the even-indexed elements (the a values), idx2 the odd-indexed ones (the b values)
__m512i idx1 = _mm512_set_epi64(14, 12, 10, 8, 6, 4, 2, 0);
__m512i idx2 = _mm512_set_epi64(15, 13, 11, 9, 7, 5, 3, 1);
return std::make_pair(_mm512_mask_permutex2var_pd(a, 0xff, idx1, b),
_mm512_mask_permutex2var_pd(a, 0xff, idx2, b));
}
template <>
std::pair<Vectorized<float>, Vectorized<float>>
inline deinterleave2<float>(const Vectorized<float>& a, const Vectorized<float>& b) {
// inputs:
// a = {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7}
// b = {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, b15}
// output:
// return {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15}
// {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15}
__m512i idx1 = _mm512_set_epi32(30, 28, 26, 24, 22, 20, 18, 16,
14, 12, 10, 8, 6, 4, 2, 0);
__m512i idx2 = _mm512_set_epi32(31, 29, 27, 25, 23, 21, 19, 17,
15, 13, 11, 9, 7, 5, 3, 1);
return std::make_pair(_mm512_mask_permutex2var_ps(a, 0xffff, idx1, b),
_mm512_mask_permutex2var_ps(a, 0xffff, idx2, b));
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FLIP ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
template<>
inline Vectorized<float> flip(const Vectorized<float> & v) {
const __m512i mask = _mm512_set_epi32(0, 1, 2, 3, 4, 5, 6, 7,
8, 9, 10, 11, 12, 13, 14, 15);
return _mm512_permutexvar_ps(mask, v);
}
template<>
inline Vectorized<double> flip(const Vectorized<double> & v) {
const __m512i mask = _mm512_set_epi64(0, 1, 2, 3, 4, 5, 6, 7);
return _mm512_permutexvar_pd(mask, v);
}
template<>
inline Vectorized<int64_t> flip(const Vectorized<int64_t> & v) {
const __m512i mask = _mm512_set_epi64(0, 1, 2, 3, 4, 5, 6, 7);
return _mm512_permutexvar_epi64(mask, v);
}
template<>
inline Vectorized<int32_t> flip(const Vectorized<int32_t> & v) {
const __m512i mask = _mm512_set_epi32(0, 1, 2, 3, 4, 5, 6, 7,
8, 9, 10, 11, 12, 13, 14, 15);
return _mm512_permutexvar_epi32(mask, v);
}
template<>
inline Vectorized<int16_t> flip(const Vectorized<int16_t> & v) {
const __m512i mask = _mm512_set_epi16(
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31
);
return _mm512_permutexvar_epi16(mask, v);
}
inline __m512i flip8(const __m512i & v) {
const __m512i mask1 = _mm512_set_epi8(
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
);
const __m512i mask2 = _mm512_set_epi64(1, 0, 3, 2, 5, 4, 7, 6);
auto reversed_vec = _mm512_shuffle_epi8(v, mask1);
return _mm512_permutexvar_epi64(mask2, reversed_vec);
}
template<>
inline Vectorized<int8_t> flip(const Vectorized<int8_t> & v) {
return flip8(v);
}
template<>
inline Vectorized<uint8_t> flip(const Vectorized<uint8_t> & v) {
return flip8(v);
}
inline Vectorized<bool> operator&&(
const Vectorized<bool>& self,
const Vectorized<bool>& other) {
const __m512i* self_ = reinterpret_cast<const __m512i*>(self.as_bytes());
const __m512i* other_ = reinterpret_cast<const __m512i*>(other.as_bytes());
__m512i out = _mm512_and_si512(*self_, *other_);
Vectorized<bool> ret;
// We do not have a constructor that takes __m512i, so we need to memcpy
std::memcpy(ret, &out, ret.size() * sizeof(bool));
return ret;
}
#endif // defined(CPU_CAPABILITY_AVX512)
}}}

File diff suppressed because it is too large

View File

@ -0,0 +1,513 @@
#pragma once
// DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with AVX]
#include <c10/util/complex.h>
#include <c10/util/irange.h>
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#if defined(CPU_CAPABILITY_AVX512)
#define SLEEF_STATIC_LIBS
#include <sleef.h>
#endif
namespace at {
namespace vec {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {
#if defined(CPU_CAPABILITY_AVX512)
template <> class Vectorized<c10::complex<double>> {
private:
__m512d values;
static constexpr __m512i zero_vector {0, 0, 0, 0, 0, 0, 0, 0};
public:
using value_type = c10::complex<double>;
using size_type = int;
static constexpr size_type size() {
return 4;
}
Vectorized() {}
Vectorized(__m512d v) : values(v) {}
Vectorized(c10::complex<double> val) {
double real_value = val.real();
double imag_value = val.imag();
values = _mm512_setr_pd(real_value, imag_value, real_value, imag_value,
real_value, imag_value, real_value, imag_value);
}
Vectorized(c10::complex<double> val1, c10::complex<double> val2,
c10::complex<double> val3, c10::complex<double> val4) {
values = _mm512_setr_pd(val1.real(), val1.imag(),
val2.real(), val2.imag(),
val3.real(), val3.imag(),
val4.real(), val4.imag());
}
operator __m512d() const {
return values;
}
template <int64_t mask>
static Vectorized<c10::complex<double>> blend(const Vectorized<c10::complex<double>>& a,
const Vectorized<c10::complex<double>>& b) {
// convert c10::complex<V> index mask to V index mask: xy -> xxyy
// NOLINTNEXTLINE(clang-diagnostic-warning)
switch (mask) {
case 0:
return a;
case 1:
return _mm512_mask_blend_pd(0x03, a.values, b.values); //b0000 0001 = b0000 0011
case 2:
return _mm512_mask_blend_pd(0x0C, a.values, b.values); //b0000 0010 = b0000 1100
case 3:
return _mm512_mask_blend_pd(0x0F, a.values, b.values); //b0000 0011 = b0000 1111
case 4:
return _mm512_mask_blend_pd(0x30, a.values, b.values); //b0000 0100 = b0011 0000
case 5:
return _mm512_mask_blend_pd(0x33, a.values, b.values); //b0000 0101 = b0011 0011
case 6:
return _mm512_mask_blend_pd(0x3C, a.values, b.values); //b0000 0110 = b0011 1100
case 7:
return _mm512_mask_blend_pd(0x3F, a.values, b.values); //b0000 0111 = b0011 1111
case 8:
return _mm512_mask_blend_pd(0xC0, a.values, b.values); //b0000 1000 = b1100 0000
case 9:
return _mm512_mask_blend_pd(0xC3, a.values, b.values); //b0000 1001 = b1100 0011
case 10:
return _mm512_mask_blend_pd(0xCC, a.values, b.values); //b0000 1010 = b1100 1100
case 11:
return _mm512_mask_blend_pd(0xCF, a.values, b.values); //b0000 1011 = b1100 1111
case 12:
return _mm512_mask_blend_pd(0xF0, a.values, b.values); //b0000 1100 = b1111 0000
case 13:
return _mm512_mask_blend_pd(0xF3, a.values, b.values); //b0000 1101 = b1111 0011
case 14:
return _mm512_mask_blend_pd(0xFC, a.values, b.values); //b0000 1110 = b1111 1100
case 15:
return _mm512_mask_blend_pd(0xFF, a.values, b.values); //b0000 1111 = b1111 1111
}
return b;
}
static Vectorized<c10::complex<double>> blendv(const Vectorized<c10::complex<double>>& a,
const Vectorized<c10::complex<double>>& b,
const Vectorized<c10::complex<double>>& mask) {
// convert c10::complex<V> index mask to V index mask: xy -> xxyy
auto mask_ = _mm512_unpacklo_pd(mask.values, mask.values);
auto all_ones = _mm512_set1_epi64(0xFFFFFFFFFFFFFFFF);
auto mmask = _mm512_cmp_epi64_mask(_mm512_castpd_si512(mask_), all_ones, _MM_CMPINT_EQ);
return _mm512_mask_blend_pd(mmask, a.values, b.values);
}
template<typename step_t>
static Vectorized<c10::complex<double>> arange(c10::complex<double> base = 0.,
step_t step = static_cast<step_t>(1)) {
return Vectorized<c10::complex<double>>(base,
base + c10::complex<double>(1)*step,
base + c10::complex<double>(2)*step,
base + c10::complex<double>(3)*step);
}
static Vectorized<c10::complex<double>> set(const Vectorized<c10::complex<double>>& a,
const Vectorized<c10::complex<double>>& b,
int64_t count = size()) {
switch (count) {
case 0:
return a;
case 1:
return blend<1>(a, b);
case 2:
return blend<3>(a, b);
case 3:
return blend<7>(a, b);
}
return b;
}
static Vectorized<c10::complex<double>> loadu(const void* ptr, int64_t count = size()) {
if (count == size())
return _mm512_loadu_pd(reinterpret_cast<const double*>(ptr));
__at_align__ double tmp_values[2*size()];
// Ensure uninitialized memory does not change the output value. See https://github.com/pytorch/pytorch/issues/32502
// for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
// instructions while a loop would be compiled to one instruction.
for (const auto i : c10::irange(2*size())) {
tmp_values[i] = 0.0;
}
std::memcpy(
tmp_values,
reinterpret_cast<const double*>(ptr),
count * sizeof(c10::complex<double>));
return _mm512_load_pd(tmp_values);
}
void store(void* ptr, int count = size()) const {
if (count == size()) {
_mm512_storeu_pd(reinterpret_cast<double*>(ptr), values);
} else if (count > 0) {
double tmp_values[2*size()];
_mm512_storeu_pd(reinterpret_cast<double*>(tmp_values), values);
std::memcpy(ptr, tmp_values, count * sizeof(c10::complex<double>));
}
}
const c10::complex<double>& operator[](int idx) const = delete;
c10::complex<double>& operator[](int idx) = delete;
Vectorized<c10::complex<double>> map(c10::complex<double> (*const f)(const c10::complex<double> &)) const {
__at_align__ c10::complex<double> tmp[size()];
store(tmp);
for (const auto i : c10::irange(size())) {
tmp[i] = f(tmp[i]);
}
return loadu(tmp);
}
// AVX512 doesn't have horizontal add & horizontal sub instructions.
// TODO: hadd_pd() & hsub_pd() may have scope for improvement.
static inline __m512d hadd_pd(__m512d a, __m512d b) {
__m512i idx1 = _mm512_set_epi64(14, 6, 12, 4, 10, 2, 8, 0);
__m512i idx2 = _mm512_set_epi64(15, 7, 13, 5, 11, 3, 9, 1);
return _mm512_add_pd(_mm512_mask_permutex2var_pd(a, 0xff, idx1, b),
_mm512_mask_permutex2var_pd(a, 0xff, idx2, b));
}
static inline __m512d hsub_pd(__m512d a, __m512d b) {
__m512i idx1 = _mm512_set_epi64(14, 6, 12, 4, 10, 2, 8, 0);
__m512i idx2 = _mm512_set_epi64(15, 7, 13, 5, 11, 3, 9, 1);
return _mm512_sub_pd(_mm512_mask_permutex2var_pd(a, 0xff, idx1, b),
_mm512_mask_permutex2var_pd(a, 0xff, idx2, b));
}
__m512d abs_2_() const {
auto val_2 = _mm512_mul_pd(values, values); // a*a b*b
return hadd_pd(val_2, val_2); // a*a+b*b a*a+b*b
}
__m512d abs_() const {
auto real = _mm512_movedup_pd(values); // real real
// movehdup_pd does not exist...
auto imag = _mm512_permute_pd(values, 0xff); // imag imag
return Sleef_hypotd8_u05(real, imag); // abs abs
}
Vectorized<c10::complex<double>> abs() const {
const __m512d real_mask = _mm512_castsi512_pd(_mm512_setr_epi64(0xFFFFFFFFFFFFFFFF, 0x0000000000000000,
0xFFFFFFFFFFFFFFFF, 0x0000000000000000,
0xFFFFFFFFFFFFFFFF, 0x0000000000000000,
0xFFFFFFFFFFFFFFFF, 0x0000000000000000));
return _mm512_and_pd(abs_(), real_mask); // abs 0
}
__m512d angle_() const {
// angle = atan2(b, a)
auto b_a = _mm512_permute_pd(values, 0x55); // b a
return Sleef_atan2d8_u10(values, b_a); // 90-angle angle
}
Vectorized<c10::complex<double>> angle() const {
const __m512d real_mask = _mm512_castsi512_pd(_mm512_setr_epi64(0xFFFFFFFFFFFFFFFF, 0x0000000000000000,
0xFFFFFFFFFFFFFFFF, 0x0000000000000000,
0xFFFFFFFFFFFFFFFF, 0x0000000000000000,
0xFFFFFFFFFFFFFFFF, 0x0000000000000000));
auto angle = _mm512_permute_pd(angle_(), 0x55); // angle 90-angle
return _mm512_and_pd(angle, real_mask); // angle 0
}
Vectorized<c10::complex<double>> sgn() const {
auto abs = abs_();
auto zero = _mm512_setzero_pd();
auto mask = _mm512_cmp_pd_mask(abs, zero, _CMP_EQ_OQ);
auto div = _mm512_div_pd(values, abs);
return _mm512_mask_blend_pd(mask, div, zero);
}
__m512d real_() const {
const __m512d real_mask = _mm512_castsi512_pd(_mm512_setr_epi64(0xFFFFFFFFFFFFFFFF, 0x0000000000000000,
0xFFFFFFFFFFFFFFFF, 0x0000000000000000,
0xFFFFFFFFFFFFFFFF, 0x0000000000000000,
0xFFFFFFFFFFFFFFFF, 0x0000000000000000));
return _mm512_and_pd(values, real_mask);
}
Vectorized<c10::complex<double>> real() const {
return real_();
}
__m512d imag_() const {
const __m512d imag_mask = _mm512_castsi512_pd(_mm512_setr_epi64(0x0000000000000000, 0xFFFFFFFFFFFFFFFF,
0x0000000000000000, 0xFFFFFFFFFFFFFFFF,
0x0000000000000000, 0xFFFFFFFFFFFFFFFF,
0x0000000000000000, 0xFFFFFFFFFFFFFFFF));
return _mm512_and_pd(values, imag_mask);
}
Vectorized<c10::complex<double>> imag() const {
return _mm512_permute_pd(imag_(), 0x55); //b a
}
__m512d conj_() const {
const __m512d sign_mask = _mm512_setr_pd(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0);
return _mm512_xor_pd(values, sign_mask); // a -b
}
Vectorized<c10::complex<double>> conj() const {
return conj_();
}
Vectorized<c10::complex<double>> log() const {
// Most trigonometric ops use the log() op to improve complex number performance.
return map(std::log);
}
Vectorized<c10::complex<double>> log2() const {
const __m512d log2_ = _mm512_set1_pd(std::log(2));
return _mm512_div_pd(log(), log2_);
}
Vectorized<c10::complex<double>> log10() const {
const __m512d log10_ = _mm512_set1_pd(std::log(10));
return _mm512_div_pd(log(), log10_);
}
Vectorized<c10::complex<double>> log1p() const {
return map(std::log1p);
}
Vectorized<c10::complex<double>> asin() const {
// asin(x)
// = -i*ln(iz + sqrt(1 -z^2))
// = -i*ln((ai - b) + sqrt(1 - (a + bi)*(a + bi)))
// = -i*ln((-b + ai) + sqrt(1 - (a**2 - b**2) - 2*abi))
const __m512d one = _mm512_set1_pd(1);
auto conj = conj_();
auto b_a = _mm512_permute_pd(conj, 0x55); //-b a
auto ab = _mm512_mul_pd(conj, b_a); //-ab -ab
auto im = _mm512_add_pd(ab, ab); //-2ab -2ab
auto val_2 = _mm512_mul_pd(values, values); // a*a b*b
auto re = hsub_pd(val_2, _mm512_permute_pd(val_2, 0x55)); // a*a-b*b b*b-a*a
re = _mm512_sub_pd(one, re);
auto root = Vectorized(_mm512_mask_blend_pd(0xAA, re, im)).sqrt(); //sqrt(re + i*im)
auto ln = Vectorized(_mm512_add_pd(b_a, root)).log(); //ln(iz + sqrt())
return Vectorized(_mm512_permute_pd(ln.values, 0x55)).conj(); //-i*ln()
}
Vectorized<c10::complex<double>> acos() const {
// acos(x) = pi/2 - asin(x)
constexpr auto pi_2d = c10::pi<double> / 2;
const __m512d pi_2 = _mm512_setr_pd(pi_2d, 0.0, pi_2d, 0.0, pi_2d, 0.0, pi_2d, 0.0);
return _mm512_sub_pd(pi_2, asin());
}
Vectorized<c10::complex<double>> atan() const;
Vectorized<c10::complex<double>> atanh() const {
return map(std::atanh);
}
Vectorized<c10::complex<double>> exp() const {
//exp(a + bi)
// = exp(a)*(cos(b) + sin(b)i)
auto exp = Sleef_expd8_u10(values); //exp(a) exp(b)
exp = _mm512_mask_blend_pd(0xAA, exp, _mm512_permute_pd(exp, 0x55)); //exp(a) exp(a)
auto sin_cos = Sleef_sincosd8_u10(values); //[sin(a), cos(a)] [sin(b), cos(b)]
auto cos_sin = _mm512_mask_blend_pd(0xAA, _mm512_permute_pd(sin_cos.y, 0x55),
sin_cos.x); //cos(b) sin(b)
return _mm512_mul_pd(exp, cos_sin);
}
Vectorized<c10::complex<double>> exp2() const {
// Use identity 2**x = exp(log(2) * x)
const __m512d ln_2 = _mm512_set1_pd(c10::ln_2<double>);
Vectorized<c10::complex<double>> scaled_values = _mm512_mul_pd(values, ln_2);
return scaled_values.exp();
}
Vectorized<c10::complex<double>> expm1() const {
return map(std::expm1);
}
Vectorized<c10::complex<double>> sin() const {
return map(std::sin);
}
Vectorized<c10::complex<double>> sinh() const {
return map(std::sinh);
}
Vectorized<c10::complex<double>> cos() const {
return map(std::cos);
}
Vectorized<c10::complex<double>> cosh() const {
return map(std::cosh);
}
Vectorized<c10::complex<double>> ceil() const {
return _mm512_ceil_pd(values);
}
Vectorized<c10::complex<double>> floor() const {
return _mm512_floor_pd(values);
}
Vectorized<c10::complex<double>> neg() const {
auto zero = _mm512_setzero_pd();
return _mm512_sub_pd(zero, values);
}
Vectorized<c10::complex<double>> round() const {
return _mm512_roundscale_pd(values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
}
Vectorized<c10::complex<double>> tan() const {
return map(std::tan);
}
Vectorized<c10::complex<double>> tanh() const {
return map(std::tanh);
}
Vectorized<c10::complex<double>> trunc() const {
return _mm512_roundscale_pd(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
}
Vectorized<c10::complex<double>> sqrt() const {
return map(std::sqrt);
}
Vectorized<c10::complex<double>> reciprocal() const;
Vectorized<c10::complex<double>> rsqrt() const {
return sqrt().reciprocal();
}
Vectorized<c10::complex<double>> pow(const Vectorized<c10::complex<double>> &exp) const {
__at_align__ c10::complex<double> x_tmp[size()];
__at_align__ c10::complex<double> y_tmp[size()];
store(x_tmp);
exp.store(y_tmp);
for (const auto i : c10::irange(size())) {
x_tmp[i] = std::pow(x_tmp[i], y_tmp[i]);
}
return loadu(x_tmp);
}
// Comparison using the _CMP_**_OQ predicate.
// `O`: get false if an operand is NaN
// `Q`: do not raise if an operand is NaN
Vectorized<c10::complex<double>> operator==(const Vectorized<c10::complex<double>>& other) const {
auto mask = _mm512_cmp_pd_mask(values, other.values, _CMP_EQ_OQ);
return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, mask,
0xFFFFFFFFFFFFFFFF));
}
Vectorized<c10::complex<double>> operator!=(const Vectorized<c10::complex<double>>& other) const {
auto mask = _mm512_cmp_pd_mask(values, other.values, _CMP_NEQ_UQ);
return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, mask,
0xFFFFFFFFFFFFFFFF));
}
Vectorized<c10::complex<double>> operator<(const Vectorized<c10::complex<double>>& other [[maybe_unused]]) const {
TORCH_CHECK(false, "not supported for complex numbers");
}
Vectorized<c10::complex<double>> operator<=(const Vectorized<c10::complex<double>>& other [[maybe_unused]]) const {
TORCH_CHECK(false, "not supported for complex numbers");
}
Vectorized<c10::complex<double>> operator>(const Vectorized<c10::complex<double>>& other [[maybe_unused]]) const {
TORCH_CHECK(false, "not supported for complex numbers");
}
Vectorized<c10::complex<double>> operator>=(const Vectorized<c10::complex<double>>& other [[maybe_unused]]) const {
TORCH_CHECK(false, "not supported for complex numbers");
}
Vectorized<c10::complex<double>> eq(const Vectorized<c10::complex<double>>& other) const;
Vectorized<c10::complex<double>> ne(const Vectorized<c10::complex<double>>& other) const;
};
template <> Vectorized<c10::complex<double>> inline operator+(const Vectorized<c10::complex<double>> &a,
const Vectorized<c10::complex<double>> &b) {
return _mm512_add_pd(a, b);
}
template <> Vectorized<c10::complex<double>> inline operator-(const Vectorized<c10::complex<double>> &a,
const Vectorized<c10::complex<double>> &b) {
return _mm512_sub_pd(a, b);
}
template <> Vectorized<c10::complex<double>> inline operator*(const Vectorized<c10::complex<double>> &a,
const Vectorized<c10::complex<double>> &b) {
//(a + bi) * (c + di) = (ac - bd) + (ad + bc)i
const __m512d sign_mask = _mm512_setr_pd(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0);
auto ac_bd = _mm512_mul_pd(a, b); //ac bd
auto d_c = _mm512_permute_pd(b, 0x55); //d c
d_c = _mm512_xor_pd(sign_mask, d_c); //d -c
auto ad_bc = _mm512_mul_pd(a, d_c); //ad -bc
auto ret = Vectorized<c10::complex<double>>::hsub_pd(ac_bd, ad_bc); //ac - bd ad + bc
return ret;
}
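// The division below first rescales both operands by the approximate
// reciprocal of max(|c|, |d|) so the c^2 + d^2 denominator cannot overflow or
// underflow; the quotient is unchanged because numerator and denominator are
// scaled identically.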
template <> Vectorized<c10::complex<double>> inline operator/(const Vectorized<c10::complex<double>> &a,
const Vectorized<c10::complex<double>> &b) {
//re + im*i = (a + bi) / (c + di)
auto mask = _mm512_set1_pd(-0.f);
auto fabs_cd = _mm512_andnot_pd(mask, b); // |c| |d|
auto fabs_dc = _mm512_permute_pd(fabs_cd, 0x55); // |d| |c|
auto scale = _mm512_rcp14_pd(_mm512_max_pd(fabs_cd, fabs_dc)); // 1/sc 1/sc
auto a2 = _mm512_mul_pd(a, scale); // a/sc b/sc
auto b2 = _mm512_mul_pd(b, scale); // c/sc d/sc
auto acbd2 = _mm512_mul_pd(a2, b2);
const __m512d sign_mask = _mm512_setr_pd(-0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0);
auto dc2 = _mm512_permute_pd(b2, 0x55); // d/sc c/sc
dc2 = _mm512_xor_pd(sign_mask, dc2); // -d/|c,d| c/sc
auto adbc2 = _mm512_mul_pd(a2, dc2); //-ad/sc^2 bc/sc^2
auto res2 = Vectorized<c10::complex<double>>::hadd_pd(acbd2, adbc2); //(ac+bd)/sc^2 (bc-ad)/sc^2
// get the denominator
auto denom2 = Vectorized<c10::complex<double>>(b2).abs_2_(); // (c^2+d^2)/sc^2 (c^2+d^2)/sc^2
res2 = _mm512_div_pd(res2, denom2);
return res2;
}
// reciprocal. Implement this here so we can use multiplication.
inline Vectorized<c10::complex<double>> Vectorized<c10::complex<double>>::reciprocal() const{
//re + im*i = (a + bi) / (c + di)
//re = (ac + bd)/abs_2() = c/abs_2()
//im = (bc - ad)/abs_2() = d/abs_2()
const __m512d sign_mask = _mm512_setr_pd(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0);
auto c_d = _mm512_xor_pd(sign_mask, values); //c -d
return _mm512_div_pd(c_d, abs_2_());
}
inline Vectorized<c10::complex<double>> Vectorized<c10::complex<double>>::atan() const {
// atan(x) = i/2 * ln((i + z)/(i - z))
const __m512d i = _mm512_setr_pd(0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0);
const Vectorized i_half = _mm512_setr_pd(0.0, 0.5, 0.0, 0.5, 0.0, 0.5, 0.0, 0.5);
auto sum = Vectorized(_mm512_add_pd(i, values)); // a 1+b
auto sub = Vectorized(_mm512_sub_pd(i, values)); // -a 1-b
auto ln = (sum/sub).log(); // ln((i + z)/(i - z))
return i_half*ln; // i/2*ln()
}
template <>
Vectorized<c10::complex<double>> inline maximum(const Vectorized<c10::complex<double>>& a,
const Vectorized<c10::complex<double>>& b) {
auto zero_vec = _mm512_set1_epi64(0);
auto abs_a = a.abs_2_();
auto abs_b = b.abs_2_();
auto mask = _mm512_cmp_pd_mask(abs_a, abs_b, _CMP_LT_OQ);
auto max = _mm512_mask_blend_pd(mask, a, b);
// Exploit the fact that all-ones is a NaN.
auto isnan_mask = _mm512_cmp_pd_mask(abs_a, abs_b, _CMP_UNORD_Q);
auto isnan = _mm512_mask_set1_epi64(zero_vec, isnan_mask,
0xFFFFFFFFFFFFFFFF);
return _mm512_or_pd(max, _mm512_castsi512_pd(isnan));
}
template <>
Vectorized<c10::complex<double>> inline minimum(const Vectorized<c10::complex<double>>& a,
const Vectorized<c10::complex<double>>& b) {
auto zero_vec = _mm512_set1_epi64(0);
auto abs_a = a.abs_2_();
auto abs_b = b.abs_2_();
auto mask = _mm512_cmp_pd_mask(abs_a, abs_b, _CMP_GT_OQ);
auto min = _mm512_mask_blend_pd(mask, a, b);
// Exploit the fact that all-ones is a NaN.
auto isnan_mask = _mm512_cmp_pd_mask(abs_a, abs_b, _CMP_UNORD_Q);
auto isnan = _mm512_mask_set1_epi64(zero_vec, isnan_mask,
0xFFFFFFFFFFFFFFFF);
return _mm512_or_pd(min, _mm512_castsi512_pd(isnan));
}
template <>
Vectorized<c10::complex<double>> inline operator&(const Vectorized<c10::complex<double>>& a,
const Vectorized<c10::complex<double>>& b) {
return _mm512_and_pd(a, b);
}
template <>
Vectorized<c10::complex<double>> inline operator|(const Vectorized<c10::complex<double>>& a,
const Vectorized<c10::complex<double>>& b) {
return _mm512_or_pd(a, b);
}
template <>
Vectorized<c10::complex<double>> inline operator^(const Vectorized<c10::complex<double>>& a,
const Vectorized<c10::complex<double>>& b) {
return _mm512_xor_pd(a, b);
}
inline Vectorized<c10::complex<double>> Vectorized<c10::complex<double>>::eq(const Vectorized<c10::complex<double>>& other) const {
auto eq = (*this == other); // compares real and imag individually
// If both real numbers and imag numbers are equal, then the complex numbers are equal
return (eq.real() & eq.imag()) & Vectorized<c10::complex<double>>(_mm512_set1_pd(1.0));
}
inline Vectorized<c10::complex<double>> Vectorized<c10::complex<double>>::ne(const Vectorized<c10::complex<double>>& other) const {
auto ne = (*this != other); // compares real and imag individually
// If either real numbers or imag numbers are not equal, then the complex numbers are not equal
return (ne.real() | ne.imag()) & Vectorized<c10::complex<double>>(_mm512_set1_pd(1.0));
}
#endif
}}}

File diff suppressed because it is too large

View File

@ -0,0 +1,262 @@
#pragma once
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec512/vec512_bfloat16.h>
#include <ATen/cpu/vec/vec_base.h>
#include <ATen/cpu/vec/vec_convert.h>
namespace at::vec {
inline namespace CPU_CAPABILITY {
#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
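// Each VecConvert<dst_t, dst_n, src_t, src_n> specialization converts a
// VectorizedN<src_t, src_n> into a VectorizedN<dst_t, dst_n> using AVX512
// conversion intrinsics; when the destination type is wider, either only the
// low lanes of the source are converted or the result spans more registers.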
template <>
struct VecConvert<float, 1, BFloat16, 1> {
static inline VectorizedN<float, 1> apply(
const VectorizedN<BFloat16, 1>& src) {
VectorizedN<float, 1> result;
__m512 value;
cvtbf16_fp32(_mm512_castsi512_si256(src[0]), value);
result[0] = value;
return result;
}
};
template <>
struct VecConvert<float, 1, Half, 1> {
static inline VectorizedN<float, 1> apply(const VectorizedN<Half, 1>& src) {
VectorizedN<float, 1> result;
__m512 value;
cvtfp16_fp32(_mm512_castsi512_si256(src[0]), value);
result[0] = value;
return result;
}
};
template <>
struct VecConvert<BFloat16, 1, float, 1> {
static inline VectorizedN<BFloat16, 1> apply(
const VectorizedN<float, 1>& src) {
VectorizedN<BFloat16, 1> result;
result[0] = _mm512_castsi256_si512(cvtfp32_bf16(src[0]));
return result;
}
};
template <>
struct VecConvert<BFloat16, 1, float, 2> {
static inline VectorizedN<BFloat16, 1> apply(
const VectorizedN<float, 2>& src) {
VectorizedN<BFloat16, 1> result;
result[0] = convert_float_bfloat16(src[0], src[1]);
return result;
}
};
template <>
struct VecConvert<float, 2, BFloat16, 1> {
static inline VectorizedN<float, 2> apply(
const VectorizedN<BFloat16, 1>& src) {
VectorizedN<float, 2> result;
std::tie(result[0], result[1]) = convert_bfloat16_float(src[0]);
return result;
}
};
template <>
struct VecConvert<Half, 1, float, 1> {
static inline VectorizedN<Half, 1> apply(const VectorizedN<float, 1>& src) {
VectorizedN<Half, 1> result;
result[0] = _mm512_castsi256_si512(cvtfp32_fp16(src[0]));
return result;
}
};
template <>
struct VecConvert<Half, 1, float, 2> {
static inline VectorizedN<Half, 1> apply(const VectorizedN<float, 2>& src) {
VectorizedN<Half, 1> result;
result[0] = convert_float_half(src[0], src[1]);
return result;
}
};
template <>
struct VecConvert<float, 2, Half, 1> {
static inline VectorizedN<float, 2> apply(const VectorizedN<Half, 1>& src) {
VectorizedN<float, 2> result;
std::tie(result[0], result[1]) = convert_half_float(src[0]);
return result;
}
};
template <>
struct VecConvert<float, 1, int64_t, 2> {
static inline VectorizedN<float, 1> apply(
const VectorizedN<int64_t, 2>& src) {
auto low = _mm512_cvtepi64_ps(src[0]);
auto high = _mm512_cvtepi64_ps(src[1]);
return Vectorized<float>(
_mm512_insertf32x8(_mm512_castps256_ps512(low), high, 1));
}
};
template <>
struct VecConvert<int64_t, 2, float, 1> {
static inline VectorizedN<int64_t, 2> apply(
const VectorizedN<float, 1>& src) {
at::vec::VectorizedN<int64_t, 2> result;
result[0] = _mm512_cvt_roundps_epi64(
_mm512_castps512_ps256(src[0]), _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC);
result[1] = _mm512_cvt_roundps_epi64(
_mm512_extractf32x8_ps(src[0], 1),
_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC);
return result;
}
};
template <>
struct VecConvert<int32_t, 1, int64_t, 2> {
static inline VectorizedN<int32_t, 1> apply(
const VectorizedN<int64_t, 2>& src) {
auto low = _mm512_cvtepi64_epi32(src[0]);
auto high = _mm512_cvtepi64_epi32(src[1]);
return Vectorized<int32_t>(
_mm512_inserti32x8(_mm512_castsi256_si512(low), high, 1));
}
};
template <>
struct VecConvert<int64_t, 2, int32_t, 1> {
static inline VectorizedN<int64_t, 2> apply(
const VectorizedN<int32_t, 1>& src) {
at::vec::VectorizedN<int64_t, 2> result;
result[0] = _mm512_cvtepi32_epi64(_mm512_castsi512_si256(src[0]));
result[1] = _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(src[0], 1));
return result;
}
};
template <>
struct VecConvert<int32_t, 1, int8_t, 1> {
static inline VectorizedN<int32_t, 1> apply(
const VectorizedN<int8_t, 1>& src) {
auto src128 = _mm512_castsi512_si128(src[0]);
return Vectorized<int32_t>(_mm512_cvtepi8_epi32(src128));
}
};
template <>
struct VecConvert<int32_t, 1, uint8_t, 1> {
static inline VectorizedN<int32_t, 1> apply(
const VectorizedN<uint8_t, 1>& src) {
auto src128 = _mm512_castsi512_si128(src[0]);
return Vectorized<int32_t>(_mm512_cvtepu8_epi32(src128));
}
};
template <>
struct VecConvert<int32_t, 1, float, 1> {
static inline VectorizedN<int32_t, 1> apply(
const VectorizedN<float, 1>& src) {
return Vectorized<int32_t>(_mm512_cvttps_epi32(src[0]));
}
};
template <>
struct VecConvert<float, 1, int32_t, 1> {
static inline VectorizedN<float, 1> apply(
const VectorizedN<int32_t, 1>& src) {
return Vectorized<float>(_mm512_cvtepi32_ps(src[0]));
}
};
template <>
struct VecConvert<int16_t, 1, uint8_t, 1> {
static inline VectorizedN<int16_t, 1> apply(
const VectorizedN<uint8_t, 1>& src) {
auto src256 = _mm512_castsi512_si256(src[0]);
return Vectorized<int16_t>(_mm512_cvtepu8_epi16(src256));
}
};
template <>
struct VecConvert<int8_t, 1, int32_t, 1> {
static inline VectorizedN<int8_t, 1> apply(
const VectorizedN<int32_t, 1>& src) {
auto src128 = _mm512_cvtepi32_epi8(src[0]);
return Vectorized<int8_t>(_mm512_castsi128_si512(src128));
}
};
template <>
struct VecConvert<int8_t, 1, int16_t, 1> {
static inline VectorizedN<int8_t, 1> apply(
const VectorizedN<int16_t, 1>& src) {
auto src256 = _mm512_cvtepi16_epi8(src[0]);
return Vectorized<int8_t>(_mm512_castsi256_si512(src256));
}
};
template <typename dst_t, typename src_t>
struct VecConvert<
dst_t,
1,
src_t,
1,
typename std::enable_if_t<
(is_reduced_floating_point_v<dst_t> && is_8bit_integer_v<src_t>) ||
(is_reduced_floating_point_v<src_t> && is_8bit_integer_v<dst_t>),
void>> {
static inline VectorizedN<dst_t, 1> apply(const VectorizedN<src_t, 1>& src) {
VectorizedN<float, 1> tmp_fp32 = VecConvert<float, 1, src_t, 1>::apply(src);
return VecConvert<dst_t, 1, float, 1>::apply(tmp_fp32);
}
};
template <typename dst_t>
struct VecConvert<
dst_t,
1,
float,
1,
typename std::enable_if_t<is_8bit_integer_v<dst_t>,
void>> {
static inline VectorizedN<dst_t, 1> apply(const VectorizedN<float, 1>& src) {
return convert_float_to_int8<dst_t>(src[0]);
}
};
template <typename src_t>
struct VecConvert<
float,
1,
src_t,
1,
typename std::enable_if_t<is_8bit_integer_v<src_t>,
void>> {
static inline VectorizedN<float, 1> apply(const VectorizedN<src_t, 1>& src) {
return convert_int8_to_float<src_t>(src[0]);
}
};
template <typename dst_t>
struct VecConvert<
dst_t,
1,
int64_t,
2,
typename std::enable_if<
std::is_same_v<dst_t, int8_t> ||
std::is_same_v<dst_t, uint8_t>>::type> {
static inline VectorizedN<dst_t, 1> apply(
const VectorizedN<int64_t, 2>& src) {
return VecConvert<dst_t, 1, int32_t, 1>::apply(
VecConvert<int32_t, 1, int64_t, 2>::apply(src));
}
};
#endif
} // namespace CPU_CAPABILITY
} // namespace at::vec
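As a usage note (an illustrative sketch only, assuming an AVX512 translation unit): these specializations are reached through the generic at::vec::convert<> helpers declared in vec_convert.h, so a BFloat16-to-float round trip like the one below dispatches to VecConvert<float, 1, BFloat16, 1> and VecConvert<BFloat16, 1, float, 1> above.

#include <ATen/cpu/vec/vec.h>
#include <c10/util/BFloat16.h>

// Illustrative only: convert the first 16 BFloat16 lanes to float and back.
// Only half of the BFloat16 vector participates because Vectorized<float>
// holds half as many elements as Vectorized<BFloat16>.
inline void bf16_fp32_roundtrip(const c10::BFloat16* src, c10::BFloat16* dst) {
  auto bf16_vec = at::vec::Vectorized<c10::BFloat16>::loadu(src);
  at::vec::Vectorized<float> fp32_vec = at::vec::convert<float>(bf16_vec);
  auto back = at::vec::convert<c10::BFloat16>(fp32_vec);
  back.store(dst, at::vec::Vectorized<float>::size());
}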

View File

@ -0,0 +1,472 @@
#pragma once
// DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with AVX]
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <c10/util/irange.h>
#if (defined(CPU_CAPABILITY_AVX512))
#define SLEEF_STATIC_LIBS
#include <sleef.h>
#endif
namespace at {
namespace vec {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {
#if defined(CPU_CAPABILITY_AVX512)
template <> class Vectorized<double> {
private:
static constexpr __m512i zero_vector {0, 0, 0, 0, 0, 0, 0, 0};
public:
// values needs to be public for compilation with clang
// as vec512.h uses it
__m512d values;
using value_type = double;
using size_type = int;
static constexpr size_type size() {
return 8;
}
Vectorized() {}
Vectorized(__m512d v) : values(v) {}
Vectorized(double val) {
values = _mm512_set1_pd(val);
}
Vectorized(double val1, double val2, double val3, double val4,
double val5, double val6, double val7, double val8) {
values = _mm512_setr_pd(val1, val2, val3, val4, val5, val6, val7, val8);
}
operator __m512d() const {
return values;
}
template <int64_t mask>
static Vectorized<double> blend(const Vectorized<double>& a, const Vectorized<double>& b) {
return _mm512_mask_blend_pd(mask, a.values, b.values);
}
static Vectorized<double> blendv(const Vectorized<double>& a, const Vectorized<double>& b,
const Vectorized<double>& mask) {
auto all_ones = _mm512_set1_epi64(0xFFFFFFFFFFFFFFFF);
auto mmask = _mm512_cmp_epi64_mask(_mm512_castpd_si512(mask.values), all_ones, _MM_CMPINT_EQ);
return _mm512_mask_blend_pd(mmask, a.values, b.values);
}
template<typename step_t>
static Vectorized<double> arange(double base = 0., step_t step = static_cast<step_t>(1)) {
return Vectorized<double>(base, base + step, base + 2 * step, base + 3 * step,
base + 4 * step, base + 5 * step, base + 6 * step,
base + 7 * step);
}
static Vectorized<double> set(const Vectorized<double>& a, const Vectorized<double>& b,
int64_t count = size()) {
switch (count) {
case 0:
return a;
case 1:
return blend<1>(a, b);
case 2:
return blend<3>(a, b);
case 3:
return blend<7>(a, b);
case 4:
return blend<15>(a, b);
case 5:
return blend<31>(a, b);
case 6:
return blend<63>(a, b);
case 7:
return blend<127>(a, b);
}
return b;
}
static Vectorized<double> loadu(const void* ptr, int64_t count = size()) {
if (count == size())
return _mm512_loadu_pd(reinterpret_cast<const double*>(ptr));
__mmask8 mask = (1ULL << count) - 1;
return _mm512_maskz_loadu_pd(mask, ptr);
}
void store(void* ptr, int count = size()) const {
if (count == size()) {
_mm512_storeu_pd(reinterpret_cast<double*>(ptr), values);
} else if (count > 0) {
__mmask8 mask = (1ULL << count) - 1;
_mm512_mask_storeu_pd(reinterpret_cast<double*>(ptr), mask, values);
}
}
const double& operator[](int idx) const = delete;
double& operator[](int idx) = delete;
int zero_mask() const {
// Returns an integer mask where a 1 bit marks a zero element and a 0 bit marks a non-zero element.
__mmask8 cmp = _mm512_cmp_pd_mask(values, _mm512_set1_pd(0.0), _CMP_EQ_OQ);
return static_cast<int32_t>(cmp);
}
Vectorized<double> isnan() const {
auto cmp_mask = _mm512_cmp_pd_mask(values, _mm512_set1_pd(0.0), _CMP_UNORD_Q);
return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, cmp_mask,
0xFFFFFFFFFFFFFFFF));
}
bool has_inf_nan() const {
__m512d self_sub = _mm512_sub_pd(values, values);
return (_mm512_movepi8_mask(_mm512_castpd_si512(self_sub)) & 0x7777777777777777) != 0;
}
Vectorized<double> map(double (*const f)(double)) const {
__at_align__ double tmp[size()];
store(tmp);
for (const auto i : c10::irange(size())) {
tmp[i] = f(tmp[i]);
}
return loadu(tmp);
}
Vectorized<double> abs() const {
auto mask = _mm512_set1_pd(-0.f);
return _mm512_andnot_pd(mask, values);
}
Vectorized<double> angle() const {
const auto zero_vec = _mm512_castsi512_pd(zero_vector);
const auto nan_vec = _mm512_set1_pd(NAN);
const auto not_nan_mask = _mm512_cmp_pd_mask(values, values, _CMP_EQ_OQ);
const auto not_nan = _mm512_mask_set1_epi64(zero_vector, not_nan_mask,
0xFFFFFFFFFFFFFFFF);
const auto nan_mask = _mm512_cmp_pd_mask(_mm512_castsi512_pd(not_nan),
zero_vec, _CMP_EQ_OQ);
const auto pi = _mm512_set1_pd(c10::pi<double>);
const auto neg_mask = _mm512_cmp_pd_mask(values, zero_vec, _CMP_LT_OQ);
auto angle = _mm512_mask_blend_pd(neg_mask, zero_vec, pi);
angle = _mm512_mask_blend_pd(nan_mask, angle, nan_vec);
return angle;
}
Vectorized<double> real() const {
return *this;
}
Vectorized<double> imag() const {
return _mm512_set1_pd(0);
}
Vectorized<double> conj() const {
return *this;
}
Vectorized<double> acos() const {
return Vectorized<double>(Sleef_acosd8_u10(values));
}
Vectorized<double> acosh() const {
return Vectorized<double>(Sleef_acoshd8_u10(values));
}
Vectorized<double> asin() const {
return Vectorized<double>(Sleef_asind8_u10(values));
}
Vectorized<double> atan() const {
return Vectorized<double>(Sleef_atand8_u10(values));
}
Vectorized<double> atanh() const {
return Vectorized<double>(Sleef_atanhd8_u10(values));
}
Vectorized<double> atan2(const Vectorized<double> &b) const {
return Vectorized<double>(Sleef_atan2d8_u10(values, b));
}
Vectorized<double> copysign(const Vectorized<double> &sign) const {
return Vectorized<double>(Sleef_copysignd8(values, sign));
}
Vectorized<double> erf() const {
return Vectorized<double>(Sleef_erfd8_u10(values));
}
Vectorized<double> erfc() const {
return Vectorized<double>(Sleef_erfcd8_u15(values));
}
Vectorized<double> erfinv() const {
return map(calc_erfinv);
}
Vectorized<double> exp() const {
return Vectorized<double>(Sleef_expd8_u10(values));
}
Vectorized<double> exp2() const {
return Vectorized<double>(Sleef_exp2d8_u10(values));
}
Vectorized<double> expm1() const {
return Vectorized<double>(Sleef_expm1d8_u10(values));
}
Vectorized<double> exp_u20() const {
return exp();
}
Vectorized<double> fmod(const Vectorized<double>& q) const {
return Vectorized<double>(Sleef_fmodd8(values, q));
}
Vectorized<double> hypot(const Vectorized<double> &b) const {
return Vectorized<double>(Sleef_hypotd8_u05(values, b));
}
Vectorized<double> i0() const {
return map(calc_i0);
}
Vectorized<double> i0e() const {
return map(calc_i0e);
}
Vectorized<double> digamma() const {
return map(calc_digamma);
}
Vectorized<double> igamma(const Vectorized<double> &x) const {
__at_align__ double tmp[size()];
__at_align__ double tmp_x[size()];
store(tmp);
x.store(tmp_x);
for (const auto i : c10::irange(size())) {
tmp[i] = calc_igamma(tmp[i], tmp_x[i]);
}
return loadu(tmp);
}
Vectorized<double> igammac(const Vectorized<double> &x) const {
__at_align__ double tmp[size()];
__at_align__ double tmp_x[size()];
store(tmp);
x.store(tmp_x);
for (const auto i : c10::irange(size())) {
tmp[i] = calc_igammac(tmp[i], tmp_x[i]);
}
return loadu(tmp);
}
Vectorized<double> log() const {
return Vectorized<double>(Sleef_logd8_u10(values));
}
Vectorized<double> log2() const {
return Vectorized<double>(Sleef_log2d8_u10(values));
}
Vectorized<double> log10() const {
return Vectorized<double>(Sleef_log10d8_u10(values));
}
Vectorized<double> log1p() const {
return Vectorized<double>(Sleef_log1pd8_u10(values));
}
Vectorized<double> sin() const {
return Vectorized<double>(Sleef_sind8_u10(values));
}
Vectorized<double> sinh() const {
return Vectorized<double>(Sleef_sinhd8_u10(values));
}
Vectorized<double> cos() const {
return Vectorized<double>(Sleef_cosd8_u10(values));
}
Vectorized<double> cosh() const {
return Vectorized<double>(Sleef_coshd8_u10(values));
}
Vectorized<double> ceil() const {
return _mm512_ceil_pd(values);
}
Vectorized<double> floor() const {
return _mm512_floor_pd(values);
}
Vectorized<double> frac() const;
Vectorized<double> neg() const {
return _mm512_xor_pd(_mm512_set1_pd(-0.), values);
}
Vectorized<double> nextafter(const Vectorized<double> &b) const {
return Vectorized<double>(Sleef_nextafterd8(values, b));
}
Vectorized<double> round() const {
return _mm512_roundscale_pd(values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
}
Vectorized<double> tan() const {
return Vectorized<double>(Sleef_tand8_u10(values));
}
Vectorized<double> tanh() const {
return Vectorized<double>(Sleef_tanhd8_u10(values));
}
Vectorized<double> trunc() const {
return _mm512_roundscale_pd(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
}
Vectorized<double> lgamma() const {
return Vectorized<double>(Sleef_lgammad8_u10(values));
}
Vectorized<double> sqrt() const {
return _mm512_sqrt_pd(values);
}
Vectorized<double> reciprocal() const {
return _mm512_div_pd(_mm512_set1_pd(1), values);
}
Vectorized<double> rsqrt() const {
return _mm512_div_pd(_mm512_set1_pd(1), _mm512_sqrt_pd(values));
}
Vectorized<double> pow(const Vectorized<double> &b) const {
return Vectorized<double>(Sleef_powd8_u10(values, b));
}
// Comparison using the _CMP_**_OQ predicate.
// `O`: get false if an operand is NaN
// `Q`: do not raise if an operand is NaN
Vectorized<double> operator==(const Vectorized<double>& other) const {
auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_EQ_OQ);
return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, cmp_mask,
0xFFFFFFFFFFFFFFFF));
}
Vectorized<double> operator!=(const Vectorized<double>& other) const {
auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_NEQ_UQ);
return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, cmp_mask,
0xFFFFFFFFFFFFFFFF));
}
Vectorized<double> operator<(const Vectorized<double>& other) const {
auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_LT_OQ);
return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, cmp_mask,
0xFFFFFFFFFFFFFFFF));
}
Vectorized<double> operator<=(const Vectorized<double>& other) const {
auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_LE_OQ);
return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, cmp_mask,
0xFFFFFFFFFFFFFFFF));
}
Vectorized<double> operator>(const Vectorized<double>& other) const {
auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_GT_OQ);
return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, cmp_mask,
0xFFFFFFFFFFFFFFFF));
}
Vectorized<double> operator>=(const Vectorized<double>& other) const {
auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_GE_OQ);
return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, cmp_mask,
0xFFFFFFFFFFFFFFFF));
}
Vectorized<double> eq(const Vectorized<double>& other) const;
Vectorized<double> ne(const Vectorized<double>& other) const;
Vectorized<double> lt(const Vectorized<double>& other) const;
Vectorized<double> le(const Vectorized<double>& other) const;
Vectorized<double> gt(const Vectorized<double>& other) const;
Vectorized<double> ge(const Vectorized<double>& other) const;
};
template <>
Vectorized<double> inline operator+(const Vectorized<double>& a, const Vectorized<double>& b) {
return _mm512_add_pd(a, b);
}
template <>
Vectorized<double> inline operator-(const Vectorized<double>& a, const Vectorized<double>& b) {
return _mm512_sub_pd(a, b);
}
template <>
Vectorized<double> inline operator*(const Vectorized<double>& a, const Vectorized<double>& b) {
return _mm512_mul_pd(a, b);
}
template <>
Vectorized<double> inline operator/(const Vectorized<double>& a, const Vectorized<double>& b) {
return _mm512_div_pd(a, b);
}
// frac. Implement this here so we can use subtraction.
inline Vectorized<double> Vectorized<double>::frac() const {
return *this - this->trunc();
}
// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<double> inline maximum(const Vectorized<double>& a, const Vectorized<double>& b) {
auto zero_vec = _mm512_set1_epi64(0);
Vectorized<double> max = _mm512_max_pd(a, b);
auto isnan_mask = _mm512_cmp_pd_mask(a, b, _CMP_UNORD_Q);
auto isnan = _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vec, isnan_mask,
0xFFFFFFFFFFFFFFFF));
// Exploit the fact that all-ones is a NaN.
return _mm512_or_pd(max, isnan);
}
// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<double> inline minimum(const Vectorized<double>& a, const Vectorized<double>& b) {
auto zero_vec = _mm512_set1_epi64(0);
Vectorized<double> min = _mm512_min_pd(a, b);
auto isnan_mask = _mm512_cmp_pd_mask(a, b, _CMP_UNORD_Q);
auto isnan = _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vec, isnan_mask,
0xFFFFFFFFFFFFFFFF));
// Exploit the fact that all-ones is a NaN.
return _mm512_or_pd(min, isnan);
}
template <>
Vectorized<double> inline clamp(const Vectorized<double>& a, const Vectorized<double>& min, const Vectorized<double>& max) {
return _mm512_min_pd(max, _mm512_max_pd(min, a));
}
template <>
Vectorized<double> inline clamp_min(const Vectorized<double>& a, const Vectorized<double>& min) {
return _mm512_max_pd(min, a);
}
template <>
Vectorized<double> inline clamp_max(const Vectorized<double>& a, const Vectorized<double>& max) {
return _mm512_min_pd(max, a);
}
template <>
Vectorized<double> inline operator&(const Vectorized<double>& a, const Vectorized<double>& b) {
return _mm512_and_pd(a, b);
}
template <>
Vectorized<double> inline operator|(const Vectorized<double>& a, const Vectorized<double>& b) {
return _mm512_or_pd(a, b);
}
template <>
Vectorized<double> inline operator^(const Vectorized<double>& a, const Vectorized<double>& b) {
return _mm512_xor_pd(a, b);
}
inline Vectorized<double> Vectorized<double>::eq(const Vectorized<double>& other) const {
return (*this == other) & Vectorized<double>(1.0);
}
inline Vectorized<double> Vectorized<double>::ne(const Vectorized<double>& other) const {
return (*this != other) & Vectorized<double>(1.0);
}
inline Vectorized<double> Vectorized<double>::gt(const Vectorized<double>& other) const {
return (*this > other) & Vectorized<double>(1.0);
}
inline Vectorized<double> Vectorized<double>::ge(const Vectorized<double>& other) const {
return (*this >= other) & Vectorized<double>(1.0);
}
inline Vectorized<double> Vectorized<double>::lt(const Vectorized<double>& other) const {
return (*this < other) & Vectorized<double>(1.0);
}
inline Vectorized<double> Vectorized<double>::le(const Vectorized<double>& other) const {
return (*this <= other) & Vectorized<double>(1.0);
}
template <>
inline void convert(const double* src, double* dst, int64_t n) {
int64_t i;
#ifndef __msvc_cl__
#pragma unroll
#endif
for (i = 0; i <= (n - Vectorized<double>::size()); i += Vectorized<double>::size()) {
_mm512_storeu_pd(dst + i, _mm512_loadu_pd(src + i));
}
#ifndef __msvc_cl__
#pragma unroll
#endif
for (; i < n; i++) {
dst[i] = src[i];
}
}
template <>
Vectorized<double> inline fmadd(const Vectorized<double>& a, const Vectorized<double>& b, const Vectorized<double>& c) {
return _mm512_fmadd_pd(a, b, c);
}
template <>
Vectorized<double> inline fmsub(const Vectorized<double>& a, const Vectorized<double>& b, const Vectorized<double>& c) {
return _mm512_fmsub_pd(a, b, c);
}
#endif
}}}
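A short usage sketch of the Vectorized<double> primitives defined above (illustrative only; the function and buffer names are placeholders, and the code assumes an AVX512 translation unit):

#include <ATen/cpu/vec/vec.h>
#include <algorithm>
#include <cmath>
#include <cstdint>

// Illustrative only: y[i] = max(a[i] * b[i] + c, a[i]) using loadu / fmadd /
// maximum / store from the header above, with a scalar tail for the remainder.
inline void fma_then_max(const double* a, const double* b, double c,
                         double* y, int64_t n) {
  using Vec = at::vec::Vectorized<double>;
  const Vec cv(c);
  int64_t i = 0;
  for (; i + Vec::size() <= n; i += Vec::size()) {
    Vec av = Vec::loadu(a + i);
    Vec bv = Vec::loadu(b + i);
    // maximum() propagates NaN if either operand is NaN (IEEE 754-style).
    at::vec::maximum(at::vec::fmadd(av, bv, cv), av).store(y + i);
  }
  for (; i < n; ++i) {
    double t = a[i] * b[i] + c;
    y[i] = (std::isnan(t) || std::isnan(a[i])) ? NAN : std::max(t, a[i]);
  }
}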

View File

@ -0,0 +1,708 @@
#pragma once
// DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with AVX]
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <c10/util/irange.h>
#if defined(CPU_CAPABILITY_AVX512)
#define SLEEF_STATIC_LIBS
#include <sleef.h>
#endif
namespace at {
namespace vec {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {
#if defined(CPU_CAPABILITY_AVX512)
template <> class Vectorized<float> {
private:
static constexpr __m512i zero_vec {0, 0, 0, 0, 0, 0, 0, 0};
public:
__m512 values;
using value_type = float;
using size_type = int;
static constexpr size_type size() {
return 16;
}
Vectorized() {}
Vectorized(__m512 v) : values(v) {}
Vectorized(float val) {
values = _mm512_set1_ps(val);
}
Vectorized(float val1, float val2, float val3, float val4,
float val5, float val6, float val7, float val8,
float val9, float val10, float val11, float val12,
float val13, float val14, float val15, float val16) {
values = _mm512_setr_ps(val1, val2, val3, val4, val5, val6, val7, val8,
val9, val10, val11, val12, val13, val14, val15, val16);
}
operator __m512() const {
return values;
}
template <int64_t mask>
static Vectorized<float> blend(const Vectorized<float>& a, const Vectorized<float>& b) {
return _mm512_mask_blend_ps(mask, a.values, b.values);
}
static Vectorized<float> blendv(const Vectorized<float>& a, const Vectorized<float>& b,
const Vectorized<float>& mask) {
auto all_ones = _mm512_set1_epi32(0xFFFFFFFF);
auto mmask = _mm512_cmp_epi32_mask(_mm512_castps_si512(mask.values), all_ones, _MM_CMPINT_EQ);
return _mm512_mask_blend_ps(mmask, a.values, b.values);
}
template<typename step_t>
static Vectorized<float> arange(float base = 0.f, step_t step = static_cast<step_t>(1)) {
return Vectorized<float>(
base, base + step, base + 2 * step, base + 3 * step,
base + 4 * step, base + 5 * step, base + 6 * step, base + 7 * step,
base + 8 * step, base + 9 * step, base + 10 * step, base + 11 * step,
base + 12 * step, base + 13 * step, base + 14 * step, base + 15 * step);
}
static Vectorized<float> set(const Vectorized<float>& a, const Vectorized<float>& b,
int64_t count = size()) {
switch (count) {
case 0:
return a;
case 1:
return blend<1>(a, b);
case 2:
return blend<3>(a, b);
case 3:
return blend<7>(a, b);
case 4:
return blend<15>(a, b);
case 5:
return blend<31>(a, b);
case 6:
return blend<63>(a, b);
case 7:
return blend<127>(a, b);
case 8:
return blend<255>(a, b);
case 9:
return blend<511>(a, b);
case 10:
return blend<1023>(a, b);
case 11:
return blend<2047>(a, b);
case 12:
return blend<4095>(a, b);
case 13:
return blend<8191>(a, b);
case 14:
return blend<16383>(a, b);
case 15:
return blend<32767>(a, b);
}
return b;
}
static Vectorized<float> loadu(const void* ptr, int64_t count = size()) {
if (count == size())
return _mm512_loadu_ps(reinterpret_cast<const float*>(ptr));
__mmask16 mask = (1ULL << count) - 1;
return _mm512_maskz_loadu_ps(mask, ptr);
}
void store(void* ptr, int64_t count = size()) const {
if (count == size()) {
_mm512_storeu_ps(reinterpret_cast<float*>(ptr), values);
} else if (count > 0) {
__mmask16 mask = (1ULL << count) - 1;
_mm512_mask_storeu_ps(reinterpret_cast<float*>(ptr), mask, values);
}
}
const float& operator[](int idx) const = delete;
float& operator[](int idx) = delete;
int zero_mask() const {
// Returns an integer mask where a 1 bit marks a zero element and a 0 bit marks a non-zero element.
__mmask16 cmp = _mm512_cmp_ps_mask(values, _mm512_set1_ps(0.0), _CMP_EQ_OQ);
return static_cast<int32_t>(cmp);
}
Vectorized<float> isnan() const {
auto mask = _mm512_cmp_ps_mask(values, _mm512_set1_ps(0.0), _CMP_UNORD_Q);
return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, mask,
0xFFFFFFFF));
}
bool has_inf_nan() const {
__m512 self_sub = _mm512_sub_ps(values, values);
return (_mm512_movepi8_mask(_mm512_castps_si512(self_sub)) & 0x7777777777777777) != 0;
}
Vectorized<float> map(float (*const f)(float)) const {
__at_align__ float tmp[size()];
store(tmp);
for (const auto i : c10::irange(size())) {
tmp[i] = f(tmp[i]);
}
return loadu(tmp);
}
Vectorized<float> abs() const {
auto mask = _mm512_set1_ps(-0.f);
return _mm512_andnot_ps(mask, values);
}
Vectorized<float> angle() const {
__m512 zero_vec = _mm512_set1_ps(0.f);
const auto nan_vec = _mm512_set1_ps(NAN);
const auto not_nan_mask = _mm512_cmp_ps_mask(values, values, _CMP_EQ_OQ);
const auto not_nan_vec = _mm512_mask_set1_epi32(_mm512_castps_si512(zero_vec),
not_nan_mask, 0xFFFFFFFF);
const auto nan_mask = _mm512_cmp_ps_mask(_mm512_castsi512_ps(not_nan_vec),
zero_vec, _CMP_EQ_OQ);
const auto pi = _mm512_set1_ps(c10::pi<float>);
const auto neg_mask = _mm512_cmp_ps_mask(values, zero_vec, _CMP_LT_OQ);
auto angle = _mm512_mask_blend_ps(neg_mask, zero_vec, pi);
angle = _mm512_mask_blend_ps(nan_mask, angle, nan_vec);
return angle;
}
Vectorized<float> real() const {
return *this;
}
Vectorized<float> imag() const {
return _mm512_set1_ps(0);
}
Vectorized<float> conj() const {
return *this;
}
Vectorized<float> acos() const {
return Vectorized<float>(Sleef_acosf16_u10(values));
}
Vectorized<float> acosh() const {
return Vectorized<float>(Sleef_acoshf16_u10(values));
}
Vectorized<float> asin() const {
return Vectorized<float>(Sleef_asinf16_u10(values));
}
Vectorized<float> atan() const {
return Vectorized<float>(Sleef_atanf16_u10(values));
}
Vectorized<float> atanh() const {
return Vectorized<float>(Sleef_atanhf16_u10(values));
}
Vectorized<float> atan2(const Vectorized<float> &b) const {
return Vectorized<float>(Sleef_atan2f16_u10(values, b));
}
Vectorized<float> copysign(const Vectorized<float> &sign) const {
return Vectorized<float>(Sleef_copysignf16(values, sign));
}
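// Polynomial approximation of erf: with t = 1 / (1 + p * |x|),
// erf(|x|) ~= 1 - (p1*t + p2*t^2 + p3*t^3 + p4*t^4 + p5*t^5) * exp(-x*x),
// and the sign of x is restored at the end. The coefficients below match the
// classic Abramowitz & Stegun 7.1.26 fit (max absolute error ~1.5e-7).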
Vectorized<float> erf() const {
// constants
const auto neg_zero_vec = _mm512_set1_ps(-0.f);
const auto one_vec = _mm512_set1_ps(1.0f);
const auto p = _mm512_set1_ps(0.3275911f);
const auto p1 = _mm512_set1_ps(0.254829592f);
const auto p2 = _mm512_set1_ps(-0.284496736f);
const auto p3 = _mm512_set1_ps(1.421413741f);
const auto p4 = _mm512_set1_ps(-1.453152027f);
const auto p5 = _mm512_set1_ps(1.061405429f);
// sign(x)
auto sign_mask = _mm512_and_ps(neg_zero_vec, values);
auto abs_vec = _mm512_abs_ps(values);
// t = 1 / (p * abs(x) + 1)
auto tmp0 = _mm512_fmadd_ps(p, abs_vec, one_vec);
auto t = _mm512_div_ps(one_vec, tmp0);
// r = p5 * t ^ 4 + p4 * t ^ 3 + p3 * t ^ 2 + p2 * t + p1
auto tmp1 = _mm512_fmadd_ps(p5, t, p4);
auto tmp2 = _mm512_fmadd_ps(tmp1, t, p3);
auto tmp3 = _mm512_fmadd_ps(tmp2, t, p2);
auto r = _mm512_fmadd_ps(tmp3, t, p1);
// - exp(- x * x)
auto pow_2 = _mm512_mul_ps(values, values);
auto neg_pow_2 = _mm512_xor_ps(neg_zero_vec, pow_2);
// auto tmp4 = exp(neg_pow_2);
auto tmp4 = Vectorized<float>(Sleef_expf16_u10(neg_pow_2));
auto tmp5 = _mm512_xor_ps(neg_zero_vec, tmp4);
// erf(x) = sign(x) * (1 - r * t * exp(- x * x))
auto tmp6 = _mm512_mul_ps(tmp5, t);
auto tmp7 = _mm512_fmadd_ps(tmp6, r, one_vec);
return _mm512_xor_ps(sign_mask, tmp7);
}
Vectorized<float> erfc() const {
return Vectorized<float>(Sleef_erfcf16_u15(values));
}
Vectorized<float> erfinv() const {
return map(calc_erfinv);
}
Vectorized<float> exp() const {
return Vectorized<float>(Sleef_expf16_u10(values));
}
Vectorized<float> exp2() const {
return Vectorized<float>(Sleef_exp2f16_u10(values));
}
Vectorized<float> expm1() const {
return Vectorized<float>(Sleef_expm1f16_u10(values));
}
Vectorized<float> exp_u20() const {
// A faster version of exp with ULP=20
static __m512 vec_factorial_1 =
_mm512_set1_ps(0.999999701f); // 1/factorial(1)
static __m512 vec_factorial_2 =
_mm512_set1_ps(0.499991506f); // 1/factorial(2)
static __m512 vec_factorial_3 =
_mm512_set1_ps(0.166676521f); // 1/factorial(3)
static __m512 vec_factorial_4 =
_mm512_set1_ps(0.0418978221f); // 1/factorial(4)
static __m512 vec_factorial_5 =
_mm512_set1_ps(0.00828929059f); // 1/factorial(5)
static __m512 vec_exp_log2ef =
_mm512_castsi512_ps(_mm512_set1_epi32(0x3fb8aa3b)); // log2(e)
static __m512 vec_half = _mm512_set1_ps(0.5f);
static __m512 vec_one = _mm512_set1_ps(1.f);
static __m512 vec_zero = _mm512_set1_ps(0.f);
static __m512 vec_two = _mm512_set1_ps(2.f);
static __m512 vec_ln2f = _mm512_castsi512_ps(_mm512_set1_epi32(0x3f317218)); // ln(2)
static __m512 vec_ln_flt_min = _mm512_castsi512_ps(_mm512_set1_epi32(0xc2aeac50));
static __m512 vec_ln_flt_max = _mm512_castsi512_ps(_mm512_set1_epi32(0x42b17218));
static __m512i vec_127 = _mm512_set1_epi32(0x0000007f);
static int n_mantissa_bits = 23;
// exp(x) =
// = exp(n * ln(2) + r) // divide x by ln(2) and get quot and rem
// = 2^n * exp(r) // simplify the exp(n*ln(2)) expression
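// Concretely: n = floor(x * log2(e) + 0.5) and r = x - n * ln(2), keeping |r|
// near or below ln(2)/2. exp(r) is approximated by the degree-5 polynomial
// built from the vec_factorial_* coefficients, and 2^n is reconstructed by
// adding the exponent bias (127) to n - 1 and shifting it into the float
// exponent field; the trailing multiply by vec_two restores the factor that
// was split off as 2^(n-1). Inputs are first clamped to
// [ln(FLT_MIN), ln(FLT_MAX)], and lanes whose original value was below
// ln(FLT_MIN) are flushed to zero via less_ln_flt_min_mask.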
auto less_ln_flt_min_mask =
_mm512_cmp_ps_mask(values, vec_ln_flt_min, 1 /*_CMP_LT_OS*/);
auto vec_src = _mm512_min_ps(values, vec_ln_flt_max);
vec_src = _mm512_max_ps(vec_src, vec_ln_flt_min);
// fx = floorf(x * log2ef + 0.5)
auto vec_fx = _mm512_fmadd_ps(vec_src, vec_exp_log2ef, vec_half);
auto vec_fx_i = _mm512_cvt_roundps_epi32(
vec_fx, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC);
vec_fx = _mm512_cvtepi32_ps(vec_fx_i);
// x = x - fx * ln2
auto vec_exp_poly = _mm512_fnmadd_ps(vec_fx, vec_ln2f, vec_src);
// compute polynomial
auto vec_res =
_mm512_fmadd_ps(vec_exp_poly, vec_factorial_5, vec_factorial_4);
vec_res = _mm512_fmadd_ps(vec_exp_poly, vec_res, vec_factorial_3);
vec_res = _mm512_fmadd_ps(vec_exp_poly, vec_res, vec_factorial_2);
vec_res = _mm512_fmadd_ps(vec_exp_poly, vec_res, vec_factorial_1);
vec_res = _mm512_fmadd_ps(vec_exp_poly, vec_res, vec_one);
// compute 2^(n-1)
auto vec_exp_number = _mm512_sub_ps(vec_fx, vec_one);
auto vec_exp_number_i = _mm512_cvtps_epi32(vec_exp_number);
auto vec_two_pow_n_i = _mm512_add_epi32(vec_exp_number_i, vec_127);
vec_two_pow_n_i = _mm512_slli_epi32(vec_two_pow_n_i, n_mantissa_bits);
auto vec_two_pow_n = _mm512_castsi512_ps(vec_two_pow_n_i);
vec_two_pow_n =
_mm512_mask_blend_ps(less_ln_flt_min_mask, vec_two_pow_n, vec_zero);
// y = y * 2^n
vec_res = _mm512_mul_ps(vec_res, vec_two_pow_n);
vec_res = _mm512_mul_ps(vec_res, vec_two);
return vec_res;
}
Vectorized<float> fmod(const Vectorized<float>& q) const {
return Vectorized<float>(Sleef_fmodf16(values, q));
}
Vectorized<float> log() const {
return Vectorized<float>(Sleef_logf16_u10(values));
}
Vectorized<float> log2() const {
return Vectorized<float>(Sleef_log2f16_u10(values));
}
Vectorized<float> log10() const {
return Vectorized<float>(Sleef_log10f16_u10(values));
}
Vectorized<float> log1p() const {
return Vectorized<float>(Sleef_log1pf16_u10(values));
}
Vectorized<float> frac() const;
Vectorized<float> sin() const {
return Vectorized<float>(Sleef_sinf16_u35(values));
}
Vectorized<float> sinh() const {
return Vectorized<float>(Sleef_sinhf16_u10(values));
}
Vectorized<float> cos() const {
return Vectorized<float>(Sleef_cosf16_u35(values));
}
Vectorized<float> cosh() const {
return Vectorized<float>(Sleef_coshf16_u10(values));
}
Vectorized<float> ceil() const {
return _mm512_ceil_ps(values);
}
Vectorized<float> floor() const {
return _mm512_floor_ps(values);
}
Vectorized<float> hypot(const Vectorized<float> &b) const {
return Vectorized<float>(Sleef_hypotf16_u05(values, b));
}
Vectorized<float> i0() const {
return map(calc_i0);
}
Vectorized<float> i0e() const {
return map(calc_i0e);
}
Vectorized<float> digamma() const {
return map(calc_digamma);
}
Vectorized<float> igamma(const Vectorized<float> &x) const {
__at_align__ float tmp[size()];
__at_align__ float tmp_x[size()];
store(tmp);
x.store(tmp_x);
for (const auto i : c10::irange(size())) {
tmp[i] = calc_igamma(tmp[i], tmp_x[i]);
}
return loadu(tmp);
}
Vectorized<float> igammac(const Vectorized<float> &x) const {
__at_align__ float tmp[size()];
__at_align__ float tmp_x[size()];
store(tmp);
x.store(tmp_x);
for (const auto i : c10::irange(size())) {
tmp[i] = calc_igammac(tmp[i], tmp_x[i]);
}
return loadu(tmp);
}
Vectorized<float> neg() const {
return _mm512_xor_ps(_mm512_set1_ps(-0.f), values);
}
Vectorized<float> nextafter(const Vectorized<float> &b) const {
return Vectorized<float>(Sleef_nextafterf16(values, b));
}
Vectorized<float> round() const {
return _mm512_roundscale_ps(values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
}
Vectorized<float> tan() const {
return Vectorized<float>(Sleef_tanf16_u10(values));
}
Vectorized<float> tanh() const {
return Vectorized<float>(Sleef_tanhf16_u10(values));
}
Vectorized<float> trunc() const {
return _mm512_roundscale_ps(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
}
Vectorized<float> lgamma() const {
return Vectorized<float>(Sleef_lgammaf16_u10(values));
}
Vectorized<float> sqrt() const {
return _mm512_sqrt_ps(values);
}
Vectorized<float> reciprocal() const {
return _mm512_div_ps(_mm512_set1_ps(1), values);
}
Vectorized<float> rsqrt() const {
return _mm512_div_ps(_mm512_set1_ps(1), _mm512_sqrt_ps(values));
}
Vectorized<float> pow(const Vectorized<float> &b) const {
return Vectorized<float>(Sleef_powf16_u10(values, b));
}
// Comparison using the _CMP_**_OQ predicate.
// `O`: get false if an operand is NaN
// `Q`: do not raise if an operand is NaN
Vectorized<float> operator==(const Vectorized<float>& other) const {
auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_EQ_OQ);
return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, mask,
0xFFFFFFFF));
}
Vectorized<float> operator!=(const Vectorized<float>& other) const {
auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_NEQ_UQ);
return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, mask,
0xFFFFFFFF));
}
Vectorized<float> operator<(const Vectorized<float>& other) const {
auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_LT_OQ);
return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, mask,
0xFFFFFFFF));
}
Vectorized<float> operator<=(const Vectorized<float>& other) const {
auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_LE_OQ);
return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, mask,
0xFFFFFFFF));
}
Vectorized<float> operator>(const Vectorized<float>& other) const {
auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_GT_OQ);
return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, mask,
0xFFFFFFFF));
}
Vectorized<float> operator>=(const Vectorized<float>& other) const {
auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_GE_OQ);
return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, mask,
0xFFFFFFFF));
}
Vectorized<float> eq(const Vectorized<float>& other) const;
Vectorized<float> ne(const Vectorized<float>& other) const;
Vectorized<float> gt(const Vectorized<float>& other) const;
Vectorized<float> ge(const Vectorized<float>& other) const;
Vectorized<float> lt(const Vectorized<float>& other) const;
Vectorized<float> le(const Vectorized<float>& other) const;
};
template <>
Vectorized<float> inline operator+(const Vectorized<float>& a, const Vectorized<float>& b) {
return _mm512_add_ps(a, b);
}
template <>
Vectorized<float> inline operator-(const Vectorized<float>& a, const Vectorized<float>& b) {
return _mm512_sub_ps(a, b);
}
template <>
Vectorized<float> inline operator*(const Vectorized<float>& a, const Vectorized<float>& b) {
return _mm512_mul_ps(a, b);
}
template <>
Vectorized<float> inline operator/(const Vectorized<float>& a, const Vectorized<float>& b) {
return _mm512_div_ps(a, b);
}
// frac. Implement this here so we can use subtraction
inline Vectorized<float> Vectorized<float>::frac() const {
return *this - this->trunc();
}
// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<float> inline maximum(const Vectorized<float>& a, const Vectorized<float>& b) {
auto zero_vec = _mm512_set1_epi32(0);
auto max = _mm512_max_ps(a, b);
auto isnan_mask = _mm512_cmp_ps_mask(a, b, _CMP_UNORD_Q);
auto isnan = _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, isnan_mask,
0xFFFFFFFF));
// Exploit the fact that all-ones is a NaN.
return _mm512_or_ps(max, isnan);
}
// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<float> inline minimum(const Vectorized<float>& a, const Vectorized<float>& b) {
auto zero_vec = _mm512_set1_epi32(0);
auto min = _mm512_min_ps(a, b);
auto isnan_mask = _mm512_cmp_ps_mask(a, b, _CMP_UNORD_Q);
auto isnan = _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, isnan_mask,
0xFFFFFFFF));
// Exploit the fact that all-ones is a NaN.
return _mm512_or_ps(min, isnan);
}
template <>
Vectorized<float> inline clamp(const Vectorized<float>& a, const Vectorized<float>& min, const Vectorized<float>& max) {
return _mm512_min_ps(max, _mm512_max_ps(min, a));
}
template <>
Vectorized<float> inline clamp_max(const Vectorized<float>& a, const Vectorized<float>& max) {
return _mm512_min_ps(max, a);
}
template <>
Vectorized<float> inline clamp_min(const Vectorized<float>& a, const Vectorized<float>& min) {
return _mm512_max_ps(min, a);
}
template <>
Vectorized<float> inline operator&(const Vectorized<float>& a, const Vectorized<float>& b) {
return _mm512_and_ps(a, b);
}
template <>
Vectorized<float> inline operator|(const Vectorized<float>& a, const Vectorized<float>& b) {
return _mm512_or_ps(a, b);
}
template <>
Vectorized<float> inline operator^(const Vectorized<float>& a, const Vectorized<float>& b) {
return _mm512_xor_ps(a, b);
}
inline Vectorized<float> Vectorized<float>::eq(const Vectorized<float>& other) const {
return (*this == other) & Vectorized<float>(1.0f);
}
inline Vectorized<float> Vectorized<float>::ne(const Vectorized<float>& other) const {
return (*this != other) & Vectorized<float>(1.0f);
}
inline Vectorized<float> Vectorized<float>::gt(const Vectorized<float>& other) const {
return (*this > other) & Vectorized<float>(1.0f);
}
inline Vectorized<float> Vectorized<float>::ge(const Vectorized<float>& other) const {
return (*this >= other) & Vectorized<float>(1.0f);
}
inline Vectorized<float> Vectorized<float>::lt(const Vectorized<float>& other) const {
return (*this < other) & Vectorized<float>(1.0f);
}
inline Vectorized<float> Vectorized<float>::le(const Vectorized<float>& other) const {
return (*this <= other) & Vectorized<float>(1.0f);
}
template <>
inline void convert(const float* src, float* dst, int64_t n) {
int64_t i;
#ifndef __msvc_cl__
#pragma unroll
#endif
for (i = 0; i <= (n - Vectorized<float>::size()); i += Vectorized<float>::size()) {
_mm512_storeu_ps(dst + i, _mm512_loadu_ps(src + i));
}
#ifndef __msvc_cl__
#pragma unroll
#endif
for (; i < n; i++) {
dst[i] = src[i];
}
}
template <>
Vectorized<float> inline fmadd(const Vectorized<float>& a, const Vectorized<float>& b, const Vectorized<float>& c) {
return _mm512_fmadd_ps(a, b, c);
}
template <>
Vectorized<float> inline fmsub(const Vectorized<float>& a, const Vectorized<float>& b, const Vectorized<float>& c) {
return _mm512_fmsub_ps(a, b, c);
}
// TODO(jgong5): rewrite with ATEN vectorized (need to add unpack and shuffle)
// Used by Inductor CPP codegen
// Code referred to FBGEMM:
// https://github.com/pytorch/FBGEMM/blob/39a423e4ad1a04b77fea81c7d09c3e6f8984fae9/src/UtilsAvx512.cc#L230-L304
// kernel for transposing mxn where m, n <= 16
// M + (M + 1) / 2 * 2 + (M + 3) / 4 * 4 + (M + 7) / 8 * 8 + 2 * N instructions
inline void transpose_mxn_16x16(const float* src, int64_t ld_src, float* dst, int64_t ld_dst, int M, int N) {
TORCH_CHECK(M <= 16 && N <= 16, "transpose_mxn<float> expects M, N <= 16.");
// load from src to registers
__m512 input[16];
int i;
if (N == 16) {
for (i = 0; i < M; ++i) {
input[i] = _mm512_loadu_ps(&src[i * ld_src]);
}
} else {
__mmask16 src_mask = (1 << N) - 1;
for (i = 0; i < M; ++i) {
input[i] = _mm512_maskz_loadu_ps(src_mask, &src[i * ld_src]);
}
}
for (; i < 16; ++i) {
// Not strictly needed, but zeroing avoids an uninitialized-variable warning.
// Shouldn't be much overhead because xor can be executed in parallel with
// other instructions.
input[i] = _mm512_setzero_ps();
}
// unpacking and interleaving 32-bit elements
__m512 temp[16];
for (i = 0; i < (M + 1) / 2; ++i) {
temp[2 * i] = _mm512_unpacklo_ps(input[2 * i], input[2 * i + 1]);
temp[2 * i + 1] = _mm512_unpackhi_ps(input[2 * i], input[2 * i + 1]);
}
for (i = i * 2; i < 16; ++i) {
temp[i] = _mm512_setzero_ps();
}
// unpacking and interleaving 64-bit elements
for (i = 0; i < (M + 3) / 4; ++i) {
input[4 * i] = _mm512_castpd_ps(_mm512_unpacklo_pd(
_mm512_castps_pd(temp[4 * i]), _mm512_castps_pd(temp[4 * i + 2])));
input[4 * i + 1] = _mm512_castpd_ps(_mm512_unpackhi_pd(
_mm512_castps_pd(temp[4 * i]), _mm512_castps_pd(temp[4 * i + 2])));
input[4 * i + 2] = _mm512_castpd_ps(_mm512_unpacklo_pd(
_mm512_castps_pd(temp[4 * i + 1]), _mm512_castps_pd(temp[4 * i + 3])));
input[4 * i + 3] = _mm512_castpd_ps(_mm512_unpackhi_pd(
_mm512_castps_pd(temp[4 * i + 1]), _mm512_castps_pd(temp[4 * i + 3])));
}
// shuffle 128-bits (composed of 4 32-bit elements)
for (i = 0; i < (M + 7) / 8; ++i) {
temp[8 * i] = _mm512_shuffle_f32x4(input[8 * i], input[8 * i + 4], 0x88);
temp[8 * i + 1] =
_mm512_shuffle_f32x4(input[8 * i + 1], input[8 * i + 5], 0x88);
temp[8 * i + 2] =
_mm512_shuffle_f32x4(input[8 * i + 2], input[8 * i + 6], 0x88);
temp[8 * i + 3] =
_mm512_shuffle_f32x4(input[8 * i + 3], input[8 * i + 7], 0x88);
temp[8 * i + 4] =
_mm512_shuffle_f32x4(input[8 * i], input[8 * i + 4], 0xdd);
temp[8 * i + 5] =
_mm512_shuffle_f32x4(input[8 * i + 1], input[8 * i + 5], 0xdd);
temp[8 * i + 6] =
_mm512_shuffle_f32x4(input[8 * i + 2], input[8 * i + 6], 0xdd);
temp[8 * i + 7] =
_mm512_shuffle_f32x4(input[8 * i + 3], input[8 * i + 7], 0xdd);
}
for (i = 0; i < N; ++i) {
if (i < 8) {
input[i] = _mm512_shuffle_f32x4(temp[i], temp[8 + i], 0x88);
} else {
input[i] = _mm512_shuffle_f32x4(temp[i - 8], temp[i], 0xdd);
}
}
// store from registers to dst
if (M == 16) {
for (i = 0; i < N; ++i) {
_mm512_storeu_ps(&dst[i * ld_dst], input[i]);
}
} else {
__mmask16 dst_mask = (1 << M) - 1;
for (i = 0; i < N; ++i) {
_mm512_mask_storeu_ps(&dst[i * ld_dst], dst_mask, input[i]);
}
}
}
template<>
inline void transpose_mxn<float>(const float* src, int64_t ld_src, float* dst, int64_t ld_dst, int M, int N) {
int64_t i = 0;
for (; i < M / 16 * 16; i += 16) {
int64_t j = 0;
for (; j < N / 16 * 16; j += 16) {
transpose_mxn_16x16(
src + i * ld_src + j, ld_src, dst + j * ld_dst + i, ld_dst, 16, 16);
}
// handle remainder j
int nrem = N - j;
if (nrem > 0) {
transpose_mxn_16x16(
src + i * ld_src + j, ld_src, dst + j * ld_dst + i, ld_dst, 16, nrem);
}
}
// handle remainder i
int mrem = M - i;
if (mrem > 0) {
int j = 0;
for (; j < N / 16 * 16; j += 16) {
transpose_mxn_16x16(
src + i * ld_src + j, ld_src, dst + j * ld_dst + i, ld_dst, mrem, 16);
}
// handle remainder j
int nrem = N - j;
transpose_mxn_16x16(
src + i * ld_src + j, ld_src, dst + j * ld_dst + i, ld_dst, mrem, nrem);
}
}
template <typename T, int M, int N,
typename std::enable_if_t<std::is_same<T, float>::value, int> = 0>
inline void transpose_mxn(const float* src, int64_t ld_src, float* dst, int64_t ld_dst) {
transpose_mxn<float>(src, ld_src, dst, ld_dst, M, N);
}
#endif
}}}
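A minimal calling sketch for the transpose kernels above (illustrative only; it assumes an AVX512 translation unit and dense row-major buffers):

#include <ATen/cpu/vec/vec.h>

// Illustrative only: transpose a dense 20 x 10 row-major float block into a
// dense 10 x 20 row-major destination. ld_src / ld_dst are the element strides
// between consecutive rows of src and dst, so for dense buffers they equal N
// and M. Interior 16x16 tiles take the fast path; the M- and N-remainders use
// the masked loads/stores in transpose_mxn_16x16.
inline void transpose_20x10(const float* src, float* dst) {
  constexpr int M = 20;
  constexpr int N = 10;
  at::vec::transpose_mxn<float>(src, /*ld_src=*/N, dst, /*ld_dst=*/M, M, N);
}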

File diff suppressed because it is too large

View File

@ -0,0 +1,393 @@
#pragma once
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <ATen/cpu/vec/vec_mask.h>
namespace at::vec {
inline namespace CPU_CAPABILITY {
#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
template <typename T, int dst_n, typename mask_t, int mask_n>
struct VecMaskLoad<
T,
dst_n,
mask_t,
mask_n,
typename std::enable_if_t<
(mask_n == dst_n * 2 && dst_n >= 1) &&
(std::is_same_v<T, float> || std::is_same_v<T, int32_t>),
void>> {
static inline VectorizedN<T, dst_n> apply(
const T* ptr,
const VecMask<mask_t, mask_n>& vec_mask) {
at::vec::Vectorized<T> zero_vec(0);
auto all_ones = _mm512_set1_epi32(0xFFFFFFFF);
VectorizedN<mask_t, 2> tmp_vec;
VectorizedN<T, dst_n> result;
for (int i = 0; i < dst_n; i++) {
tmp_vec[0] = vec_mask[2 * i];
tmp_vec[1] = vec_mask[2 * i + 1];
auto int64_mask = VecMask<mask_t, 2>(tmp_vec).template cast<int64_t, 2>();
auto int_mask = int64_mask.template cast<int, 1>()[0];
auto mmask = _mm512_cmp_epi32_mask(int_mask, all_ones, _MM_CMPINT_EQ);
if constexpr (std::is_same_v<T, float>) {
result[i] = Vectorized<T>(_mm512_mask_loadu_ps(
zero_vec, mmask, ptr + i * Vectorized<T>::size()));
} else {
result[i] = Vectorized<T>(_mm512_mask_loadu_epi32(
zero_vec, mmask, ptr + i * Vectorized<T>::size()));
}
}
return result;
}
};
template <typename T, int dst_n, typename mask_t>
struct VecMaskLoad<
T,
dst_n,
mask_t,
dst_n,
typename std::enable_if_t<
std::is_same_v<T, float> || std::is_same_v<T, int32_t>,
void>> {
static inline VectorizedN<T, dst_n> apply(
const T* ptr,
const VecMask<mask_t, dst_n>& vec_mask) {
at::vec::Vectorized<T> zero_vec(0);
auto all_ones = _mm512_set1_epi32(0xFFFFFFFF);
VectorizedN<T, dst_n> result;
#ifndef _MSC_VER
#pragma unroll
#endif
for (int i = 0; i < dst_n; i++) {
auto tmp_mask = VecMask<mask_t, 1>(vec_mask[i]);
auto int_mask = tmp_mask.template cast<int, 1>()[0];
auto mmask = _mm512_cmp_epi32_mask(int_mask, all_ones, _MM_CMPINT_EQ);
if constexpr (std::is_same_v<T, float>) {
result[i] = Vectorized<T>(_mm512_mask_loadu_ps(
zero_vec, mmask, ptr + i * Vectorized<T>::size()));
} else {
result[i] = Vectorized<T>(_mm512_mask_loadu_epi32(
zero_vec, mmask, ptr + i * Vectorized<T>::size()));
}
}
return result;
}
};
template <typename data_t, int dst_n, typename mask_t>
struct VecMaskLoad<
data_t,
dst_n,
mask_t,
dst_n,
typename std::enable_if<
std::is_same_v<data_t, BFloat16> ||
std::is_same_v<data_t, Half>>::type> {
static inline VectorizedN<data_t, dst_n> apply(
const data_t* ptr,
const VecMask<mask_t, dst_n>& vec_mask) {
auto all_ones = _mm512_set1_epi32(0xFFFFFFFF);
VectorizedN<data_t, dst_n> result;
#ifndef _MSC_VER
#pragma unroll
#endif
for (int i = 0; i < dst_n; i++) {
auto tmp_mask = VecMask<mask_t, 1>(vec_mask[i]);
auto int_mask = tmp_mask.template cast<int, 2>();
auto mmask0 = _mm512_cmp_epi32_mask(int_mask[0], all_ones, _MM_CMPINT_EQ);
auto mmask1 = _mm512_cmp_epi32_mask(int_mask[1], all_ones, _MM_CMPINT_EQ);
auto zero = _mm256_set1_epi16(0);
auto temp0 = _mm256_mask_loadu_epi16(
zero, mmask0, ptr + (2 * i) * Vectorized<int>::size());
auto temp1 = _mm256_mask_loadu_epi16(
zero, mmask1, ptr + (2 * i + 1) * Vectorized<int>::size());
result[i] = Vectorized<data_t>(
_mm512_inserti32x8(_mm512_castsi256_si512(temp0), temp1, 1));
}
return result;
}
};
template <typename data_t, int dst_n, typename mask_t, int mask_n>
struct VecMaskLoad<
data_t,
dst_n,
mask_t,
mask_n,
typename std::enable_if_t<
(mask_n == 2 * dst_n && dst_n >= 1) &&
(std::is_same_v<data_t, BFloat16> || std::is_same_v<data_t, Half>)>> {
static inline VectorizedN<data_t, dst_n> apply(
const data_t* ptr,
const VecMask<mask_t, mask_n>& vec_mask) {
auto all_ones = _mm512_set1_epi32(0xFFFFFFFF);
VectorizedN<data_t, dst_n> result;
VectorizedN<mask_t, 2> tmp_vec;
for (int i = 0; i < dst_n; i++) {
tmp_vec[0] = vec_mask[2 * i];
tmp_vec[1] = vec_mask[2 * i + 1];
auto int_mask = VecMask<mask_t, 2>(tmp_vec).template cast<int, 2>();
auto mmask0 = _mm512_cmp_epi32_mask(int_mask[0], all_ones, _MM_CMPINT_EQ);
auto mmask1 = _mm512_cmp_epi32_mask(int_mask[1], all_ones, _MM_CMPINT_EQ);
auto zero = _mm256_set1_epi16(0);
auto temp0 = _mm256_mask_loadu_epi16(
zero, mmask0, ptr + (2 * i) * Vectorized<int>::size());
auto temp1 = _mm256_mask_loadu_epi16(
zero, mmask1, ptr + (2 * i + 1) * Vectorized<int>::size());
result[i] = Vectorized<data_t>(
_mm512_inserti32x8(_mm512_castsi256_si512(temp0), temp1, 1));
}
return result;
}
};
template <typename data_t, typename mask_t>
struct VecMaskLoad<
data_t,
1,
mask_t,
1,
typename std::enable_if<
std::is_same_v<data_t, int8_t> ||
std::is_same_v<data_t, uint8_t>>::type> {
static inline VectorizedN<data_t, 1> apply(
const data_t* ptr,
const VecMask<mask_t, 1>& vec_mask) {
auto all_ones = _mm512_set1_epi32(0xFFFFFFFF);
auto int_mask = vec_mask.template cast<int, 1>()[0];
auto mmask = _mm512_cmp_epi32_mask(int_mask, all_ones, _MM_CMPINT_EQ);
auto zero = _mm_set1_epi8(0);
auto temp = _mm_mask_loadu_epi8(zero, mmask, ptr);
return Vectorized<data_t>(
_mm512_inserti64x2(_mm512_set1_epi32(0), temp, 0));
}
};
template <typename data_t, typename mask_t>
struct VecMaskLoad<
data_t,
2,
mask_t,
1,
typename std::enable_if<
std::is_same_v<data_t, int64_t> ||
std::is_same_v<data_t, double>>::type> {
static inline VectorizedN<data_t, 2> apply(
const data_t* ptr,
const VecMask<mask_t, 1>& vec_mask) {
auto all_ones = _mm512_set1_epi32(0xFFFFFFFF);
at::vec::Vectorized<data_t> zero_vec(0);
auto int_mask = vec_mask.template cast<int, 1>()[0];
auto mmask = _mm512_cmp_epi32_mask(int_mask, all_ones, _MM_CMPINT_EQ);
at::vec::VectorizedN<data_t, 2> result;
if constexpr (std::is_same_v<data_t, double>) {
result[0] = _mm512_mask_loadu_pd(zero_vec, (__mmask8)mmask, ptr);
result[1] =
_mm512_mask_loadu_pd(zero_vec, (__mmask8)(mmask >> 8), ptr + 8);
} else {
result[0] = _mm512_mask_loadu_epi64(zero_vec, (__mmask8)mmask, ptr);
result[1] =
_mm512_mask_loadu_epi64(zero_vec, (__mmask8)(mmask >> 8), ptr + 8);
}
return result;
}
};
template <int N>
struct VecMaskCast<float, N, int, N> {
static inline VecMask<float, N> apply(const VecMask<int, N>& vec_mask) {
VectorizedN<float, N> result;
#ifndef _MSC_VER
#pragma unroll
#endif
for (int i = 0; i < N; ++i) {
result[i] = _mm512_castsi512_ps(vec_mask[i]);
}
return result;
}
};
template <int N>
struct VecMaskCast<int, N, float, N> {
static inline VecMask<int, N> apply(const VecMask<float, N>& vec_mask) {
VectorizedN<int, N> result;
#ifndef _MSC_VER
#pragma unroll
#endif
for (int i = 0; i < N; ++i) {
result[i] = _mm512_castps_si512(vec_mask[i]);
}
return result;
}
};
template <int N>
struct VecMaskCast<int64_t, N, double, N> {
static inline VecMask<int64_t, N> apply(const VecMask<double, N>& vec_mask) {
VectorizedN<int64_t, N> result;
#ifndef _MSC_VER
#pragma unroll
#endif
for (int i = 0; i < N; ++i) {
result[i] = _mm512_castpd_si512(vec_mask[i]);
}
return result;
}
};
template <int N>
struct VecMaskCast<double, N, int64_t, N> {
static inline VecMask<double, N> apply(const VecMask<int64_t, N>& vec_mask) {
VectorizedN<double, N> result;
#ifndef _MSC_VER
#pragma unroll
#endif
for (int i = 0; i < N; ++i) {
result[i] = _mm512_castsi512_pd(vec_mask[i]);
}
return result;
}
};
template <int dst_n, typename mask_t, int mask_n>
struct VecMaskCast<
int64_t,
dst_n,
mask_t,
mask_n,
typename std::enable_if_t<
(dst_n == 2 * mask_n) &&
(std::is_same_v<mask_t, float> || std::is_same_v<mask_t, int>),
void>> {
static inline VecMask<int64_t, dst_n> apply(
const VecMask<mask_t, mask_n>& vec_mask) {
VectorizedN<int64_t, dst_n> result;
auto int_mask = vec_mask.template cast<int, mask_n>();
#ifndef _MSC_VER
#pragma unroll
#endif
for (int i = 0; i < mask_n; ++i) {
auto int64_vec =
convert<int64_t, 2, int, 1>(VectorizedN<int, 1>(int_mask[i]));
result[2 * i] = int64_vec[0];
result[2 * i + 1] = int64_vec[1];
}
return VecMask<int64_t, dst_n>(result);
}
};
template <typename dst_t, int dst_n, int mask_n>
struct VecMaskCast<
dst_t,
dst_n,
int64_t,
mask_n,
typename std::enable_if_t<
(mask_n == 2 * dst_n) &&
(std::is_same_v<dst_t, float> || std::is_same_v<dst_t, int>),
void>> {
static inline VecMask<dst_t, dst_n> apply(
const VecMask<int64_t, mask_n>& vec_mask) {
VectorizedN<int, dst_n> result;
VectorizedN<int64_t, 2> int64_vec;
for (int i = 0; i < dst_n; ++i) {
int64_vec[0] = vec_mask[2 * i];
int64_vec[1] = vec_mask[2 * i + 1];
result[i] = convert<int, 1, int64_t, 2>(int64_vec);
}
return VecMask<int, dst_n>(result).template cast<dst_t, dst_n>();
}
};
template <>
struct VecMaskCast<double, 2, float, 1> {
static inline VecMask<double, 2> apply(const VecMask<float, 1>& vec_mask) {
auto int64_mask = VecMaskCast<int64_t, 2, float, 1>::apply(vec_mask);
return VecMaskCast<double, 2, int64_t, 2>::apply(int64_mask);
}
};
template <>
struct VecMaskCast<float, 1, double, 2> {
static inline VecMask<float, 1> apply(const VecMask<double, 2>& vec_mask) {
auto int64_mask = VecMaskCast<int64_t, 2, double, 2>::apply(vec_mask);
return VecMaskCast<float, 1, int64_t, 2>::apply(int64_mask);
}
};
template <>
inline bool VecMask<int, 1>::all_zero() const {
__mmask16 mask = _mm512_test_epi32_mask(mask_[0], mask_[0]);
return mask == 0;
}
template <>
inline bool VecMask<int, 1>::is_masked(int i) const {
return _mm512_movepi32_mask(mask_[0]) & (1 << i);
}
template <>
inline bool VecMask<int, 1>::all_masked() const {
__mmask16 mask = _mm512_movepi32_mask(mask_[0]);
return mask == 0xffff;
}
template <int N>
struct VecMaskCheck<int64_t, N> {
static inline bool all_zero(const VectorizedN<int64_t, N>& vec_mask) {
bool all_zero = true;
for (int i = 0; i < N; ++i) {
all_zero =
all_zero && (_mm512_test_epi64_mask(vec_mask[i], vec_mask[i]) == 0);
if (!all_zero) {
return all_zero;
}
}
return all_zero;
}
static inline bool is_masked(const VectorizedN<int64_t, N>& vec_mask, int i) {
for (int j = 0; j < N; ++j) {
if (i < (j + 1) * 8) {
return _mm512_movepi64_mask(vec_mask[j]) & (1 << (i - j * 8));
}
}
return false;
}
static inline bool all_masked(const VectorizedN<int64_t, N>& vec_mask) {
bool all_masked = true;
for (int i = 0; i < N; ++i) {
all_masked = all_masked && (_mm512_movepi64_mask(vec_mask[i]) == 0xff);
if (!all_masked) {
return all_masked;
}
}
return all_masked;
}
};
#define VEC_MASK_METHOD_WITH_CAST_TO_INT( \
T, N, return_type, method, args_def, args) \
template <> \
inline return_type VecMask<T, N>::method args_def const { \
return cast<int, 1>().method args; \
}
VEC_MASK_METHOD_WITH_CAST_TO_INT(float, 1, bool, all_zero, (), ())
VEC_MASK_METHOD_WITH_CAST_TO_INT(int64_t, 2, bool, all_zero, (), ())
VEC_MASK_METHOD_WITH_CAST_TO_INT(float, 1, bool, is_masked, (int i), (i))
VEC_MASK_METHOD_WITH_CAST_TO_INT(int64_t, 2, bool, is_masked, (int i), (i))
VEC_MASK_METHOD_WITH_CAST_TO_INT(float, 1, bool, all_masked, (), ())
VEC_MASK_METHOD_WITH_CAST_TO_INT(int64_t, 2, bool, all_masked, (), ())
#undef VEC_MASK_METHOD_WITH_CAST_TO_INT
#endif
} // namespace CPU_CAPABILITY
} // namespace at::vec
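A small usage sketch for the AVX512 VecMask<int, 1> specializations above (illustrative only; it assumes mask lanes follow the all-ones/zero convention that these specializations expect):

#include <ATen/cpu/vec/vec.h>

// Illustrative only: selected lanes must hold all-ones (-1) and unselected
// lanes 0, since is_masked() reads the lane's sign bit and the masked loads
// above compare lanes against 0xFFFFFFFF. Returns the first selected lane
// index, or -1 when every lane is zero.
inline int first_selected_lane(const int* mask_src) {
  auto lanes = at::vec::Vectorized<int>::loadu(mask_src);
  at::vec::VecMask<int, 1> m(lanes);
  if (m.all_zero()) {
    return -1;
  }
  for (int i = 0; i < at::vec::VecMask<int, 1>::size(); ++i) {
    if (m.is_masked(i)) {
      return i;
    }
  }
  return -1;
}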

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -0,0 +1,65 @@
#pragma once
#include <ATen/cpu/vec/vec_base.h>
#include <ATen/cpu/vec/vec_n.h>
namespace at::vec {
inline namespace CPU_CAPABILITY {
template <
typename dst_t,
int dst_n,
typename src_t,
int src_n,
typename Enabled = void>
struct VecConvert {
static inline VectorizedN<dst_t, dst_n> apply(
const VectorizedN<src_t, src_n>& src) {
constexpr int count = std::min(
VectorizedN<src_t, src_n>::size(), VectorizedN<dst_t, dst_n>::size());
__at_align__ src_t src_buf[VectorizedN<src_t, src_n>::size()];
src.store(src_buf);
__at_align__ dst_t dst_buf[VectorizedN<dst_t, dst_n>::size()];
for (int i = 0; i < count; i++) {
dst_buf[i] = static_cast<dst_t>(src_buf[i]);
}
return VectorizedN<dst_t, dst_n>::loadu(dst_buf, count);
}
};
template <typename dst_t, typename src_t>
inline std::enable_if_t<std::is_same_v<dst_t, src_t>, Vectorized<src_t>>
convert(const Vectorized<src_t>& src) {
return src;
}
template <typename dst_t, typename src_t>
inline std::enable_if_t<!std::is_same_v<dst_t, src_t>, Vectorized<dst_t>>
convert(const Vectorized<src_t>& src) {
return VecConvert<dst_t, 1, src_t, 1>::apply(src);
}
template <
typename dst_t,
int dst_n,
typename src_t,
int src_n,
std::enable_if_t<dst_n != 1, int> = 0>
inline VectorizedN<dst_t, dst_n> convert(const VectorizedN<src_t, src_n>& src) {
return VecConvert<dst_t, dst_n, src_t, src_n>::apply(src);
}
template <
typename dst_t,
int dst_n,
typename src_t,
int src_n,
bool keep = false,
std::enable_if_t<dst_n == 1, int> = 0>
inline std::conditional_t<keep, VectorizedN<dst_t, 1>, Vectorized<dst_t>>
convert(const VectorizedN<src_t, src_n>& src) {
return VecConvert<dst_t, dst_n, src_t, src_n>::apply(src);
}
} // namespace CPU_CAPABILITY
} // namespace at::vec
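For orientation (an illustrative sketch only): the 1:1 convert<> overload above is the entry point that the per-ISA VecConvert specializations hook into; when no specialization matches, it falls back to the store / static_cast / loadu loop in the primary template.

#include <ATen/cpu/vec/vec.h>

// Illustrative only: widen an int32 vector to float. Under AVX512 this
// resolves to the _mm512_cvtepi32_ps specialization; otherwise it runs the
// generic element-wise fallback defined above.
inline at::vec::Vectorized<float> int32_to_float(
    const at::vec::Vectorized<int32_t>& v) {
  return at::vec::convert<float>(v);
}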

View File

@ -0,0 +1,50 @@
#pragma once
#include <ATen/cpu/vec/intrinsics.h>
namespace at::vec {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {
#if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && \
!defined(__APPLE__)
static inline uint16_t float2half_scalar(float val) {
#if defined(CPU_CAPABILITY_AVX2)
#if defined(_MSC_VER)
__m256 v = _mm256_set1_ps(val);
__m128i o =
_mm256_cvtps_ph(v, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
return static_cast<std::uint16_t>(_mm_cvtsi128_si32(o));
#else
return _cvtss_sh(val, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
#endif
#elif defined(CPU_CAPABILITY_AVX512)
__m512 v = _mm512_set1_ps(val);
__m256i o =
_mm512_cvtps_ph(v, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
return static_cast<std::uint16_t>(
_mm_cvtsi128_si32(_mm256_castsi256_si128(o)));
#endif
}
static inline float half2float_scalar(uint16_t val) {
#if defined(CPU_CAPABILITY_AVX2)
#if defined(_MSC_VER)
__m128i v = _mm_cvtsi32_si128(val);
__m256 o = _mm256_cvtph_ps(v);
return _mm256_cvtss_f32(o);
#else
return _cvtsh_ss(val);
#endif
#elif defined(CPU_CAPABILITY_AVX512)
__m256i v =
_mm256_setr_epi16(val, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0);
__m512 o = _mm512_cvtph_ps(v);
return _mm512_cvtss_f32(o);
#endif
}
#endif
} // namespace CPU_CAPABILITY
} // namespace at::vec
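A tiny round-trip sketch for the scalar half-precision helpers above (illustrative only; it assumes an AVX2 or AVX512 translation unit with the header above included):

#include <cstdint>

// Illustrative only: encode a float to IEEE fp16 bits and decode it back.
// The conversion rounds to nearest-even, so values that are not exactly
// representable in half precision (e.g. 0.1f) come back slightly rounded.
inline float half_roundtrip(float x) {
  std::uint16_t bits = at::vec::float2half_scalar(x);
  return at::vec::half2float_scalar(bits);
}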

View File

@ -0,0 +1,294 @@
#pragma once
#include <ATen/cpu/vec/vec_base.h>
#include <ATen/cpu/vec/vec_n.h>
namespace at::vec {
inline namespace CPU_CAPABILITY {
/**
* The `VecMask` class provides a convenient interface for working with
* vectorized masks in SIMD operations. It encapsulates a `VectorizedN<T, N>`
* mask that can be used directly in masked vectorized operations. It provides
* various methods for manipulating and accessing the mask elements:
* 1. `from` and `to`: Conversion between a vector of boolean values and a
* vectorized mask.
* 2. `cast`: Casts the mask to a different base type.
* 3. `all_zero`: Checks if all mask elements are zero.
* 4. `is_masked`: Checks if a specific element is masked.
* 5. `loadu`: Loads data from memory using the mask.
* 6. `all_masked`: Checks if all mask elements are masked.
*
* Some helper template classes are provided to simplify the specialization of
* `VecMask` for a specific CPU arch:
* 1. `VecMaskLoad`: Loads data from memory using the mask.
* 2. `VecMaskTo`: Converts the mask to boolean.
* 3. `VecMaskCast`: Casts the mask to a different base type.
*
*/
template <typename T, int N>
class VecMask;
template <
typename data_t,
int data_n,
typename mask_t,
int mask_n,
typename Enabled = void>
struct VecMaskLoad {
static inline VectorizedN<data_t, data_n> apply(
const data_t* ptr,
const VecMask<mask_t, mask_n>& vec_mask) {
constexpr typename VecMask<mask_t, mask_n>::size_type size =
VecMask<mask_t, mask_n>::size();
static_assert(VectorizedN<data_t, data_n>::size() >= size);
__at_align__ data_t data[size];
__at_align__ mask_t mask[size];
auto mask_ = VectorizedN<mask_t, mask_n>(vec_mask);
mask_.store(mask);
for (int i = 0; i < size; i++) {
data[i] = mask[i] ? ptr[i] : static_cast<data_t>(0);
}
return VectorizedN<data_t, data_n>::loadu(data, size);
}
};
template <
typename dst_t,
int dst_n,
typename src_t,
int src_n,
typename Enabled = void>
struct VecMaskTo {
static inline VecMask<dst_t, dst_n> apply(
const VecMask<src_t, src_n>& vec_mask) {
auto zeros = VectorizedN<dst_t, dst_n>(static_cast<dst_t>(0));
auto ones = VectorizedN<dst_t, dst_n>(static_cast<dst_t>(1));
return VectorizedN<dst_t, dst_n>::blendv(
zeros, ones, vec_mask.template cast<dst_t, dst_n>());
}
};
template <typename dst_t, int dst_n, typename src_t, int src_n, typename Enabled = void>
struct VecMaskCast {
static inline VecMask<dst_t, dst_n> apply(
const VecMask<src_t, src_n>& vec_mask) {
return VecMask<dst_t, dst_n>::from(VectorizedN<src_t, src_n>(vec_mask));
}
};
template <typename T, int N>
struct VecMaskCast<T, N, T, N> {
static inline VecMask<T, N> apply(const VecMask<T, N>& vec_mask) {
return vec_mask;
}
};
template <typename T, int N>
struct VecMaskCheck {
static inline bool all_zero(const VectorizedN<T, N>& vec_mask) {
__at_align__ T mask[VectorizedN<T, N>::size()];
vec_mask.store(mask);
return std::all_of(
mask, mask + VectorizedN<T, N>::size(), [](T m) { return m == static_cast<T>(0); });
}
static inline bool all_masked(const VectorizedN<T, N>& vec_mask) {
__at_align__ T mask[VectorizedN<T, N>::size()];
vec_mask.store(mask);
return std::all_of(
mask, mask + VectorizedN<T, N>::size(), [](T m) { return m != static_cast<T>(0); });
}
static inline bool is_masked(const VectorizedN<T, N>& vec_mask, int i) {
__at_align__ T mask[VectorizedN<T, N>::size()];
vec_mask.store(mask);
return mask[i] != static_cast<T>(0);
}
};
template <typename T, int N>
class VecMask {
public:
using size_type = int;
static constexpr size_type size() {
return VectorizedN<T, N>::size();
}
private:
VectorizedN<T, N> mask_;
public:
VecMask() : mask_(static_cast<T>(0)) {}
VecMask(const VectorizedN<T, N>& mask) : mask_(mask) {}
template <int L = N, typename std::enable_if_t<L == 1, int> = 0>
VecMask(const Vectorized<T>& mask) : mask_(mask) {}
template <typename U, int L>
static VecMask<T, N> from(const VectorizedN<U, L>& b_vec) {
__at_align__ U b_buf[size()];
if constexpr (size() >= VectorizedN<U, L>::size()) {
b_vec.store(b_buf);
for (int i = VectorizedN<U, L>::size(); i < size(); i++) {
b_buf[i] = static_cast<U>(0);
}
} else {
b_vec.store(b_buf, size());
}
return from(b_buf);
}
template <typename U>
static VecMask<T, N> from(U b) {
using int_t = int_same_size_t<T>;
T mask = b ? c10::bit_cast<T>((int_t)(~(int_t)0)) : (T)0;
return VectorizedN<T, N>(mask);
}
template <typename U>
static VecMask<T, N> from(U* b) {
using int_t = int_same_size_t<T>;
__at_align__ T mask[size()];
#ifndef __msvc_cl__
#pragma unroll
#endif
for (int i = 0; i < size(); i++) {
*(int_t*)(mask + i) = b[i] ? ~(int_t)0 : (int_t)0;
}
return VectorizedN<T, N>(VectorizedN<T, N>::loadu(mask));
}
static VecMask<T, N> blendv(
const VecMask<T, N>& c,
const VecMask<T, N>& b,
const VecMask<T, N>& a) {
VectorizedN<T, N> result = VectorizedN<T, N>::blendv(
VectorizedN<T, N>(c),
VectorizedN<T, N>(b),
VectorizedN<T, N>(a));
return result;
}
static VecMask<T, N> set(
const VecMask<T, N>& a,
const VecMask<T, N>& b,
int64_t count = size()) {
VectorizedN<T, N> result = VectorizedN<T, N>::set(
VectorizedN<T, N>(a),
VectorizedN<T, N>(b),
count);
return result;
}
void store(bool* b, int count = size()) {
constexpr int L = (VectorizedN<T, N>::size() + Vectorized<bool>::size() - 1) / Vectorized<bool>::size();
auto res = this->to<bool, L>();
res.store(b, count);
return;
}
template <typename U, int L, std::enable_if_t<L >= 2, int> = 0>
inline VectorizedN<U, L> to() const {
return VecMaskTo<U, L, T, N>::apply(*this);
}
template <typename U, int L, std::enable_if_t<L == 1, int> = 0>
inline Vectorized<U> to() const {
return VecMaskTo<U, L, T, N>::apply(*this);
}
template <typename U, int L>
inline VecMask<U, L> cast() const {
return VecMaskCast<U, L, T, N>::apply(*this);
}
inline bool all_zero() const {
return VecMaskCheck<T, N>::all_zero(mask_);
}
inline bool all_masked() const {
return VecMaskCheck<T, N>::all_masked(mask_);
}
inline bool is_masked(int i) const {
return VecMaskCheck<T, N>::is_masked(mask_, i);
}
inline operator VectorizedN<T, N>() const {
return mask_;
}
template <int L = N, typename std::enable_if_t<L == 1, int> = 0>
inline operator Vectorized<T>() const {
return mask_[0];
}
inline Vectorized<T> operator[](int i) const {
return mask_[i];
}
template <
typename U,
int L,
std::enable_if_t<L >= 2 && VectorizedN<U, L>::size() >= size(), int> = 0>
VectorizedN<U, L> loadu(const U* ptr) const {
return VecMaskLoad<U, L, T, N>::apply(ptr, *this);
}
template <
typename U,
int L,
std::enable_if_t<L == 1 && Vectorized<U>::size() >= size(), int> = 0>
Vectorized<U> loadu(const U* ptr) const {
return VecMaskLoad<U, L, T, N>::apply(ptr, *this);
}
};
#define VEC_MASK_DEFINE_UNARY_OP_GLOBAL(op) \
template <typename T, int N> \
inline VecMask<T, N> op(const VecMask<T, N>& a) { \
return op(VectorizedN<T, N>(a)); \
}
#define VEC_MASK_DEFINE_BINARY_OP_GLOBAL(op) \
template < \
typename T, \
int N, \
typename V, \
int M, \
std::enable_if_t<VecMask<T, N>::size() == VecMask<V, M>::size(), int> = \
0> \
inline VecMask<T, N> op(const VecMask<T, N>& a, const VecMask<V, M>& b) { \
return op( \
VectorizedN<T, N>(a), VectorizedN<T, N>(b.template cast<T, N>())); \
}
#define VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(op, EXPR) \
template < \
typename T, \
int N, \
typename V, \
int M, \
std::enable_if_t<VecMask<T, N>::size() == VecMask<V, M>::size(), int> = \
0> \
inline VecMask<T, N> op(const VecMask<T, N>& a, const VecMask<V, M>& b) { \
return EXPR; \
}
VEC_MASK_DEFINE_UNARY_OP_GLOBAL(operator~)
VEC_MASK_DEFINE_BINARY_OP_GLOBAL(operator&)
VEC_MASK_DEFINE_BINARY_OP_GLOBAL(operator|)
VEC_MASK_DEFINE_BINARY_OP_GLOBAL(operator^)
VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator>, a & ~b)
VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator<, ~a & b)
VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator==, ~(a ^ b))
VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator>=, (a == b) | (a > b))
VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator<=, (a == b) | (a < b))
VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator!=, (a ^ b))
#undef VEC_MASK_DEFINE_UNARY_OP_GLOBAL
#undef VEC_MASK_DEFINE_BINARY_OP_GLOBAL
#undef VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL
} // namespace CPU_CAPABILITY
} // namespace at::vec
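
A short usage sketch for `VecMask`, tying together `from`, `all_zero`, `is_masked`, and the masked `loadu` described in the header comment above. This is a hypothetical call site, not part of the commit; it assumes an int-typed mask over float data and that the vec headers are included.

// Hypothetical call site; `flags` must hold VecMask<int, 1>::size() entries and
// `src` at least as many floats.
void vec_mask_example(const float* src, bool* flags) {
  using namespace at::vec;
  // Build an all-ones/all-zeros lane mask from a bool array.
  auto mask = VecMask<int, 1>::from(flags);
  if (mask.all_zero()) {
    return; // no lane selected
  }
  // Masked load: lanes whose mask is off are filled with zero.
  Vectorized<float> vals = mask.loadu<float, 1>(src);
  // Per-lane query, e.g. for a scalar tail loop.
  bool first_selected = mask.is_masked(0);
  (void)vals;
  (void)first_selected;
}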

View File

@ -0,0 +1,361 @@
#pragma once
#include <ATen/cpu/vec/vec_base.h>
#include <array>
namespace at::vec {
inline namespace CPU_CAPABILITY {
/**
* @brief A class template representing a vectorized type with
* `N * Vectorized<T>::size()` elements, aiming to support vectors of
* arbitrary size. A specific use case of it is to represent vectors
* converted from data types with different sizes but with the same
* number of vector elements, e.g., `VectorizedN<float, 2>` can be
* a vector converted from two `Vectorized<bfloat16>`, `VectorizedN<int64_t, 2>`
* can be a vector converted from two `Vectorized<int32_t>` etc.
*
* It supports most of the operations of `Vectorized<T>`
* and the implementation delegates to `Vectorized<T>` with loops over `N`.
*
* @tparam T The underlying type of the vectorized elements.
* @tparam N The number of underlying `Vectorized<T>`.
*/
template <typename T, int N>
class VectorizedN {
public:
using value_type = T;
using size_type = int;
static constexpr size_type size_T = sizeof(T);
static constexpr size_type size() {
return Vectorized<T>::size() * N;
}
private:
std::array<Vectorized<T>, N> values;
public:
// methods not implemented yet:
// variadic constructor, operator T*, as_bytes, zero_mask
#define VECTORIZEDN_DEFINE_UNARY_OP(op) \
VectorizedN<T, N> op() const { \
return unary_op([](const Vectorized<T>& a) { return a.op(); }); \
}
#define VECTORIZEDN_DEFINE_BINARY_OP(op) \
VectorizedN<T, N> op(const VectorizedN<T, N>& other) const { \
return binary_op( \
other, [](const Vectorized<T>& a, const Vectorized<T>& b) { \
return a.op(b); \
}); \
}
template <typename Op>
inline VectorizedN<T, N> unary_op(Op op) const {
VectorizedN<T, N> result;
#ifndef _MSC_VER
#pragma unroll
#endif
for (int i = 0; i < N; ++i) {
result.values[i] = op(values[i]);
}
return result;
}
template <typename Op>
inline VectorizedN<T, N> binary_op(const VectorizedN<T, N>& other, Op op)
const {
VectorizedN<T, N> result;
#ifndef _MSC_VER
#pragma unroll
#endif
for (int i = 0; i < N; ++i) {
result.values[i] = op(values[i], other.values[i]);
}
return result;
}
VectorizedN() = default;
explicit VectorizedN(T val) {
for (int i = 0; i < N; ++i) {
values[i] = Vectorized<T>(val);
}
}
template <int L = N, typename std::enable_if_t<L == 1, int> = 0>
VectorizedN(const Vectorized<T>& val) : values({val}) {}
template <int L = N, typename std::enable_if_t<L == 2, int> = 0>
VectorizedN(const Vectorized<T>& val_0, const Vectorized<T>& val_1) : values({val_0, val_1}) {}
template <int L = N, typename std::enable_if_t<L == 1, int> = 0>
inline operator Vectorized<T>() const {
return values[0];
}
inline const Vectorized<T>& operator[](int i) const {
return values[i];
}
inline Vectorized<T>& operator[](int i) {
return values[i];
}
template <int64_t mask>
static VectorizedN<T, N> blend(
const VectorizedN<T, N>& a,
const VectorizedN<T, N>& b) {
VectorizedN<T, N> result;
for (int i = 0; i < N; ++i) {
result.values[i] = Vectorized<T>::template blend<mask>(a.values[i], b.values[i]);
}
return result;
}
static VectorizedN<T, N> blendv(
const VectorizedN<T, N>& a,
const VectorizedN<T, N>& b,
const VectorizedN<T, N>& mask) {
VectorizedN<T, N> result;
for (int i = 0; i < N; ++i) {
result.values[i] =
Vectorized<T>::blendv(a.values[i], b.values[i], mask.values[i]);
}
return result;
}
template <typename step_t>
static VectorizedN<T, N> arange(
T base = static_cast<T>(0),
step_t step = static_cast<step_t>(1)) {
VectorizedN<T, N> result;
for (int i = 0; i < N; ++i) {
result.values[i] = Vectorized<T>::arange(base, step);
base += step * Vectorized<T>::size();
}
return result;
}
static VectorizedN<T, N> set(
const VectorizedN<T, N>& a,
const VectorizedN<T, N>& b,
int64_t count = size()) {
VectorizedN<T, N> result;
for (int i = 0; i < N; ++i) {
if (count > 0) {
result.values[i] = Vectorized<T>::set(
a.values[i],
b.values[i],
std::min(count, (int64_t)Vectorized<T>::size()));
count -= Vectorized<T>::size();
} else {
result.values[i] = a.values[i];
}
}
return result;
}
static VectorizedN<T, N> loadu(const void* ptr) {
VectorizedN<T, N> result;
for (int i = 0; i < N; ++i) {
result.values[i] = Vectorized<T>::loadu(ptr);
ptr = static_cast<const T*>(ptr) + Vectorized<T>::size();
}
return result;
}
static VectorizedN<T, N> loadu(const void* ptr, int64_t count) {
VectorizedN<T, N> result;
for (int i = 0; i < N; ++i) {
result.values[i] = Vectorized<T>::loadu(
ptr, std::min(count, (int64_t)Vectorized<T>::size()));
ptr = static_cast<const T*>(ptr) + Vectorized<T>::size();
count -= Vectorized<T>::size();
if (count <= 0) {
break;
}
}
return result;
}
void store(void* ptr) const {
for (int i = 0; i < N; ++i) {
values[i].store(ptr);
ptr = static_cast<T*>(ptr) + Vectorized<T>::size();
}
}
void store(void* ptr, int count) const {
for (int i = 0; i < N; ++i) {
values[i].store(ptr, std::min(count, (int)Vectorized<T>::size()));
ptr = static_cast<T*>(ptr) + Vectorized<T>::size();
count -= Vectorized<T>::size();
if (count <= 0) {
break;
}
}
}
bool has_inf_nan() const {
for (int i = 0; i < N; ++i) {
if (values[i].has_inf_nan()) {
return true;
}
}
return false;
}
VectorizedN<T, N> map(T (*const f)(T)) const {
VectorizedN<T, N> result;
for (int i = 0; i < N; ++i) {
result.values[i] = values[i].map(f);
}
return result;
}
VectorizedN<T, N> map(T (*const f)(const T&)) const {
VectorizedN<T, N> result;
for (int i = 0; i < N; ++i) {
result.values[i] = values[i].map(f);
}
return result;
}
VECTORIZEDN_DEFINE_UNARY_OP(isnan)
VECTORIZEDN_DEFINE_UNARY_OP(abs)
VECTORIZEDN_DEFINE_UNARY_OP(sgn)
VECTORIZEDN_DEFINE_UNARY_OP(angle)
VECTORIZEDN_DEFINE_UNARY_OP(real)
VECTORIZEDN_DEFINE_UNARY_OP(imag)
VECTORIZEDN_DEFINE_UNARY_OP(conj)
VECTORIZEDN_DEFINE_UNARY_OP(acos)
VECTORIZEDN_DEFINE_UNARY_OP(acosh)
VECTORIZEDN_DEFINE_UNARY_OP(asin)
VECTORIZEDN_DEFINE_UNARY_OP(atan)
VECTORIZEDN_DEFINE_UNARY_OP(atanh)
VECTORIZEDN_DEFINE_BINARY_OP(atan2)
VECTORIZEDN_DEFINE_BINARY_OP(copysign)
VECTORIZEDN_DEFINE_UNARY_OP(erf)
VECTORIZEDN_DEFINE_UNARY_OP(erfc)
VECTORIZEDN_DEFINE_UNARY_OP(erfinv)
VECTORIZEDN_DEFINE_UNARY_OP(exp)
VECTORIZEDN_DEFINE_UNARY_OP(exp2)
VECTORIZEDN_DEFINE_UNARY_OP(expm1)
VECTORIZEDN_DEFINE_UNARY_OP(exp_u20)
VECTORIZEDN_DEFINE_UNARY_OP(frac)
VECTORIZEDN_DEFINE_BINARY_OP(fmod)
VECTORIZEDN_DEFINE_UNARY_OP(log)
VECTORIZEDN_DEFINE_UNARY_OP(log10)
VECTORIZEDN_DEFINE_UNARY_OP(log1p)
VECTORIZEDN_DEFINE_UNARY_OP(log2)
VECTORIZEDN_DEFINE_UNARY_OP(ceil)
VECTORIZEDN_DEFINE_UNARY_OP(cos)
VECTORIZEDN_DEFINE_UNARY_OP(cosh)
VECTORIZEDN_DEFINE_UNARY_OP(floor)
VECTORIZEDN_DEFINE_BINARY_OP(hypot)
VECTORIZEDN_DEFINE_UNARY_OP(i0)
VECTORIZEDN_DEFINE_UNARY_OP(i0e)
VECTORIZEDN_DEFINE_UNARY_OP(digamma)
VECTORIZEDN_DEFINE_BINARY_OP(igamma)
VECTORIZEDN_DEFINE_BINARY_OP(igammac)
VECTORIZEDN_DEFINE_UNARY_OP(neg)
VECTORIZEDN_DEFINE_BINARY_OP(nextafter)
VECTORIZEDN_DEFINE_UNARY_OP(round)
VECTORIZEDN_DEFINE_UNARY_OP(sin)
VECTORIZEDN_DEFINE_UNARY_OP(sinh)
VECTORIZEDN_DEFINE_UNARY_OP(tan)
VECTORIZEDN_DEFINE_UNARY_OP(tanh)
VECTORIZEDN_DEFINE_UNARY_OP(trunc)
VECTORIZEDN_DEFINE_UNARY_OP(lgamma)
VECTORIZEDN_DEFINE_UNARY_OP(sqrt)
VECTORIZEDN_DEFINE_UNARY_OP(reciprocal)
VECTORIZEDN_DEFINE_UNARY_OP(rsqrt)
VECTORIZEDN_DEFINE_BINARY_OP(pow)
VECTORIZEDN_DEFINE_BINARY_OP(operator==)
VECTORIZEDN_DEFINE_BINARY_OP(operator!=)
VECTORIZEDN_DEFINE_BINARY_OP(operator>=)
VECTORIZEDN_DEFINE_BINARY_OP(operator<=)
VECTORIZEDN_DEFINE_BINARY_OP(operator>)
VECTORIZEDN_DEFINE_BINARY_OP(operator<)
VECTORIZEDN_DEFINE_BINARY_OP(eq)
VECTORIZEDN_DEFINE_BINARY_OP(ne)
VECTORIZEDN_DEFINE_BINARY_OP(gt)
VECTORIZEDN_DEFINE_BINARY_OP(ge)
VECTORIZEDN_DEFINE_BINARY_OP(lt)
VECTORIZEDN_DEFINE_BINARY_OP(le)
#undef VECTORIZEDN_DEFINE_UNARY_OP
#undef VECTORIZEDN_DEFINE_BINARY_OP
};
#define VECTORIZEDN_DEFINE_UNARY_OP_GLOBAL(op) \
template <typename T, int N> \
inline VectorizedN<T, N> op(const VectorizedN<T, N>& a) { \
return a.unary_op([](const Vectorized<T>& a) { return op(a); }); \
}
#define VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(op) \
template <typename T, int N> \
inline VectorizedN<T, N> op( \
const VectorizedN<T, N>& a, const VectorizedN<T, N>& b) { \
return a.binary_op(b, [](const Vectorized<T>& a, const Vectorized<T>& b) { \
return op(a, b); \
}); \
}
#define VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(op) \
template <typename T, int N> \
inline VectorizedN<T, N>& op( \
VectorizedN<T, N>& a, const VectorizedN<T, N>& b) { \
a = a.binary_op(b, [](const Vectorized<T>& a, const Vectorized<T>& b) { \
return op(a, b); \
}); \
return a; \
}
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator+)
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator-)
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator*)
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator/)
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator%)
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator||)
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator<<)
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator>>)
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(maximum)
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(minimum)
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(fmadd)
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(fmsub)
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(clamp)
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(clamp_max)
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(clamp_min)
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator&)
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator|)
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator^)
VECTORIZEDN_DEFINE_UNARY_OP_GLOBAL(operator~)
VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator+=)
VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator-=)
VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator*=)
VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator/=)
VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator%=)
VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator<<=)
VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator>>=)
#undef VECTORIZEDN_DEFINE_UNARY_OP_GLOBAL
#undef VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL
#undef VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL
template <typename T, int N, typename OpVec>
inline T vec_reduce_all(const OpVec& vec_fun, VectorizedN<T, N> acc_vec) {
Vectorized<T> vec_result = acc_vec[0];
for (int i = 1; i < N; i++) {
vec_result = vec_fun(vec_result, acc_vec[i]);
}
return vec_reduce_all(vec_fun, vec_result);
}
} // namespace CPU_CAPABILITY
} // namespace at::vec
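
A small sketch of how `VectorizedN` composes `Vectorized<T>` registers, ending with the `vec_reduce_all` overload defined above. This is a hypothetical call site; lane counts depend on the CPU_CAPABILITY the header is compiled under, and the reduction assumes the `Vectorized<float>` overload of `vec_reduce_all` from the functional header is also visible.

// Hypothetical call site; `out` must hold VectorizedN<float, 2>::size() floats.
void vectorized_n_example(float* out) {
  using namespace at::vec;
  auto a = VectorizedN<float, 2>::arange(0.0f, 1.0f); // 0, 1, 2, ... across both halves
  auto b = VectorizedN<float, 2>(3.0f);               // broadcast 3 into both registers
  auto c = a * b + a;                                 // delegated lane-wise to Vectorized<float>
  c.store(out);
  // Horizontal sum over all N * Vectorized<float>::size() lanes.
  float total = vec_reduce_all(
      [](const Vectorized<float>& x, const Vectorized<float>& y) { return x + y; },
      c);
  (void)total;
}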

View File

@ -0,0 +1,170 @@
#pragma once
#include <ATen/Config.h>
#include <ATen/Parallel.h>
#include <ATen/OpMathType.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#include <c10/util/complex.h>
// This header implements various unary operations using an MKL VML style
// interface.
// It implements various functions with a simple interface. For example, it
// enables the user to call vsin(float* out, const float* in, size). This
// function takes a pointer to a contiguous output array of floats and a
// constant input array. It applies sin to each value in the input array and
// writes the result into the output array. out and in may point to the same
// memory, i.e. in-place operation is fully supported. These functions also
// implement their own parallelization, so take precautions when calling them
// from threaded functions.
// When MKL is available, it calls into MKL's VML library, similarly to NumPy.
// If MKL is not available, it falls back to SLEEF.
// This file might be compiled under AVX or AVX2 when called from e.g.
// UnaryOpsKernel.cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <type_traits>
#if AT_MKL_ENABLED() && !defined(__APPLE__)
#include <mkl.h>
#endif
namespace at::vml {
inline namespace CPU_CAPABILITY {
using namespace vec;
template <typename scalar_t>
inline void vrsqrt(scalar_t* out, scalar_t* in, int64_t size) {
parallel_for(0, size, 2048, [out, in](int64_t begin, int64_t end) {
map(
[](const Vectorized<scalar_t>& x) {
return Vectorized<scalar_t>((scalar_t)(1)) / x.sqrt();
},
out + begin,
in + begin,
end - begin);
});
}
// NB: We ignore numerical errors by convention and leave them to the user
#define IMPLEMENT_VML(op) \
template <typename scalar_t> \
inline void v##op(scalar_t* out, const scalar_t* in, int64_t size) { \
using vec_t = Vectorized<vec_scalar_t<scalar_t>>; \
vec::map([](vec_t x) { return x.op(); }, out, in, size); \
}
IMPLEMENT_VML(abs)
IMPLEMENT_VML(acos)
IMPLEMENT_VML(asin)
IMPLEMENT_VML(atan)
IMPLEMENT_VML(atanh)
IMPLEMENT_VML(ceil)
IMPLEMENT_VML(cos)
// IMPLEMENT_VML(cosh)
IMPLEMENT_VML(erf)
IMPLEMENT_VML(erfc)
IMPLEMENT_VML(erfinv)
IMPLEMENT_VML(exp)
IMPLEMENT_VML(expm1)
IMPLEMENT_VML(floor)
IMPLEMENT_VML(i0)
IMPLEMENT_VML(i0e)
IMPLEMENT_VML(digamma)
IMPLEMENT_VML(reciprocal)
IMPLEMENT_VML(log)
IMPLEMENT_VML(log10)
IMPLEMENT_VML(log1p)
IMPLEMENT_VML(log2)
IMPLEMENT_VML(neg)
IMPLEMENT_VML(sin)
// IMPLEMENT_VML(sinh)
IMPLEMENT_VML(sqrt)
IMPLEMENT_VML(round)
IMPLEMENT_VML(rsqrt)
IMPLEMENT_VML(tan)
IMPLEMENT_VML(tanh)
IMPLEMENT_VML(trunc)
IMPLEMENT_VML(lgamma)
#if AT_MKL_ENABLED() && !defined(__APPLE__)
// NB: LP64 MKL is the most commonly used configuration and is assumed here.
// That means MKL_INT is expected to be a 32-bit or 64-bit integer type, which
// the static_assert below checks.
static_assert(
std::is_same_v<MKL_INT, int32_t> || std::is_same_v<MKL_INT, int64_t>,
"MKL_INT is assumed to be int32_t or int64_t");
#define IMPLEMENT_VML_MKL_STUB(op, mklop, type, mkltype) \
template <> \
inline void v##op(type * out, const type * in, int64_t size) { \
int64_t max_mkl_ind = std::numeric_limits<MKL_INT>::max(); \
if (size <= static_cast<int64_t>(max_mkl_ind)) { \
vm##mkltype##mklop( \
size, in, out, VML_HA | VML_FTZDAZ_OFF | VML_ERRMODE_IGNORE); \
} else { \
MKL_INT ind = 0; \
int64_t chunks = size / max_mkl_ind; \
int64_t rest = size % max_mkl_ind; \
for (; ind < chunks; ind++) { \
vm##mkltype##mklop( \
max_mkl_ind, \
in + ind * max_mkl_ind, \
out + ind * max_mkl_ind, \
VML_HA | VML_FTZDAZ_OFF | VML_ERRMODE_IGNORE); \
} \
vm##mkltype##mklop( \
rest, \
in + ind * max_mkl_ind, \
out + ind * max_mkl_ind, \
VML_HA | VML_FTZDAZ_OFF | VML_ERRMODE_IGNORE); \
} \
}
#define IMPLEMENT_VML_MKL(op, mklop) \
IMPLEMENT_VML_MKL_STUB(op, mklop, float, s) \
IMPLEMENT_VML_MKL_STUB(op, mklop, double, d)
// NB: abs, cosh and sinh were temporarily disabled due to issues with Apple
// NB: expm1 is disabled because on some configs it produces expm1(nan)=-1
IMPLEMENT_VML_MKL(acos, Acos)
IMPLEMENT_VML_MKL(asin, Asin)
IMPLEMENT_VML_MKL(atan, Atan)
IMPLEMENT_VML_MKL(cos, Cos)
// IMPLEMENT_VML_MKL(cosh, Cosh)
IMPLEMENT_VML_MKL(erf, Erf)
IMPLEMENT_VML_MKL(erfc, Erfc)
IMPLEMENT_VML_MKL(erfinv, ErfInv)
IMPLEMENT_VML_MKL(exp, Exp)
// IMPLEMENT_VML_MKL(expm1, Expm1)
IMPLEMENT_VML_MKL(log, Ln)
IMPLEMENT_VML_MKL(log10, Log10)
IMPLEMENT_VML_MKL(sin, Sin)
// IMPLEMENT_VML_MKL(sinh, Sinh)
IMPLEMENT_VML_MKL(sqrt, Sqrt)
IMPLEMENT_VML_MKL(tan, Tan)
IMPLEMENT_VML_MKL(tanh, Tanh)
IMPLEMENT_VML_MKL(trunc, Trunc)
// Not vectorized in MKL version tested
// IMPLEMENT_VML_MKL(abs, Abs)
// IMPLEMENT_VML_MKL(log1p, Log1p)
#if INTEL_MKL_VERSION >= 20180406
IMPLEMENT_VML_MKL(log2, Log2)
#endif
#endif
} // namespace CPU_CAPABILITY
} // namespace at::vml
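
A usage sketch for the VML-style interface described in the header comment above. This is a hypothetical call site, not part of the commit; buffer size and values are illustrative only.

#include <cstdint>
#include <vector>

void vml_example() {
  std::vector<float> in(1024, 0.5f);
  std::vector<float> out(1024);
  // Element-wise sin; dispatches to MKL VML when available, to SLEEF-backed
  // Vectorized math otherwise.
  at::vml::vsin(out.data(), in.data(), static_cast<int64_t>(in.size()));
  // out and in may alias, so in-place operation is supported.
  at::vml::vlog(out.data(), out.data(), static_cast<int64_t>(out.size()));
}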