I am done

This commit is contained in:
2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

View File

@@ -0,0 +1,81 @@
#include <c10/macros/Macros.h>
#include <c10/util/Backtrace.h>
#include <c10/util/env.h>
#include <cstdlib>
#include <exception>
#include <iostream>
#include <mutex>
#include <optional>
namespace c10 {
class AbortHandlerHelper {
public:
static AbortHandlerHelper& getInstance() {
#ifdef _WIN32
thread_local
#endif // _WIN32
static AbortHandlerHelper instance;
return instance;
}
void set(std::terminate_handler handler) {
std::lock_guard<std::mutex> lk(mutex);
if (!inited) {
prev = std::set_terminate(handler);
curr = std::get_terminate();
inited = true;
}
}
std::terminate_handler getPrev() const {
return prev;
}
private:
std::terminate_handler prev = nullptr;
std::terminate_handler curr = nullptr;
bool inited = false;
std::mutex mutex;
AbortHandlerHelper() = default;
~AbortHandlerHelper() {
// Only restore the handler if we are the current one
if (inited && curr == std::get_terminate()) {
std::set_terminate(prev);
}
}
public:
AbortHandlerHelper(AbortHandlerHelper const&) = delete;
void operator=(AbortHandlerHelper const&) = delete;
};
namespace detail {
C10_ALWAYS_INLINE void terminate_handler() {
std::cout << "Unhandled exception caught in c10/util/AbortHandler.h" << '\n';
auto backtrace = get_backtrace();
std::cout << backtrace << '\n' << std::flush;
auto prev_handler = AbortHandlerHelper::getInstance().getPrev();
if (prev_handler) {
prev_handler();
} else {
std::abort();
}
}
} // namespace detail
C10_ALWAYS_INLINE void set_terminate_handler() {
bool use_custom_terminate = false;
// On Windows it is enabled by default based on
// https://github.com/pytorch/pytorch/pull/50320#issuecomment-763147062
#ifdef _WIN32
use_custom_terminate = true;
#endif // _WIN32
auto result = c10::utils::check_env("TORCH_CUSTOM_TERMINATE");
if (result != std::nullopt) {
use_custom_terminate = result.value();
}
if (use_custom_terminate) {
AbortHandlerHelper::getInstance().set(detail::terminate_handler);
}
}
} // namespace c10
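
A minimal usage sketch (hypothetical example program, assuming it links
against c10): install the handler early so that a later uncaught exception
prints a backtrace before aborting. Outside Windows this is opt-in via the
TORCH_CUSTOM_TERMINATE environment variable.

#include <c10/util/AbortHandler.h>
#include <stdexcept>

int main() {
  // Installs detail::terminate_handler if enabled (on by default on
  // Windows; elsewhere set TORCH_CUSTOM_TERMINATE=1 first).
  c10::set_terminate_handler();
  // This uncaught exception now prints a backtrace before std::abort().
  throw std::runtime_error("boom");
}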

View File

@@ -0,0 +1,176 @@
//===--- AlignOf.h - Portable calculation of type alignment -----*- C++ -*-===//
//
// The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This file defines the AlignedCharArray and AlignedCharArrayUnion classes.
//
//===----------------------------------------------------------------------===//
// ATen: modified from llvm::AlignOf
// replaced LLVM_ALIGNAS with alignas
#pragma once
#include <cstddef>
namespace c10 {
/// \struct AlignedCharArray
/// \brief Helper for building an aligned character array type.
///
/// This template is used to explicitly build up a collection of aligned
/// character array types. We have to build these up using a macro and explicit
/// specialization to cope with MSVC (at least till 2015) where only an
/// integer literal can be used to specify an alignment constraint. Once built
/// up here, we can then begin to indirect between these using normal C++
/// template parameters.
// MSVC requires special handling here.
#ifndef _MSC_VER
template <size_t Alignment, size_t Size>
struct AlignedCharArray {
// NOLINTNEXTLINE(*c-arrays)
alignas(Alignment) char buffer[Size];
};
#else // _MSC_VER
/// \brief Create a type with an aligned char buffer.
template <size_t Alignment, size_t Size>
struct AlignedCharArray;
// We provide special variations of this template for the most common
// alignments because __declspec(align(...)) doesn't actually work when it is
// a member of a by-value function argument in MSVC, even if the alignment
// request is something reasonably like 8-byte or 16-byte. Note that we can't
// even include the declspec with the union that forces the alignment because
// MSVC warns on the existence of the declspec despite the union member forcing
// proper alignment.
template <size_t Size>
struct AlignedCharArray<1, Size> {
union {
char aligned;
char buffer[Size];
};
};
template <size_t Size>
struct AlignedCharArray<2, Size> {
union {
short aligned;
char buffer[Size];
};
};
template <size_t Size>
struct AlignedCharArray<4, Size> {
union {
int aligned;
char buffer[Size];
};
};
template <size_t Size>
struct AlignedCharArray<8, Size> {
union {
double aligned;
char buffer[Size];
};
};
// The rest of these are provided with a __declspec(align(...)) and we simply
// can't pass them by-value as function arguments on MSVC.
#define AT_ALIGNEDCHARARRAY_TEMPLATE_ALIGNMENT(x) \
template <size_t Size> \
struct AlignedCharArray<x, Size> { \
__declspec(align(x)) char buffer[Size]; \
};
AT_ALIGNEDCHARARRAY_TEMPLATE_ALIGNMENT(16)
AT_ALIGNEDCHARARRAY_TEMPLATE_ALIGNMENT(32)
AT_ALIGNEDCHARARRAY_TEMPLATE_ALIGNMENT(64)
AT_ALIGNEDCHARARRAY_TEMPLATE_ALIGNMENT(128)
#undef AT_ALIGNEDCHARARRAY_TEMPLATE_ALIGNMENT
#endif // _MSC_VER
namespace detail {
template <
typename T1,
typename T2 = char,
typename T3 = char,
typename T4 = char,
typename T5 = char,
typename T6 = char,
typename T7 = char,
typename T8 = char,
typename T9 = char,
typename T10 = char>
class AlignerImpl {
T1 t1;
T2 t2;
T3 t3;
T4 t4;
T5 t5;
T6 t6;
T7 t7;
T8 t8;
T9 t9;
T10 t10;
public:
AlignerImpl() = delete;
};
template <
typename T1,
typename T2 = char,
typename T3 = char,
typename T4 = char,
typename T5 = char,
typename T6 = char,
typename T7 = char,
typename T8 = char,
typename T9 = char,
typename T10 = char>
union SizerImpl {
// NOLINTNEXTLINE(*c-arrays)
char arr1[sizeof(T1)], arr2[sizeof(T2)], arr3[sizeof(T3)], arr4[sizeof(T4)],
arr5[sizeof(T5)], arr6[sizeof(T6)], arr7[sizeof(T7)], arr8[sizeof(T8)],
arr9[sizeof(T9)], arr10[sizeof(T10)];
};
} // end namespace detail
/// \brief This union template exposes a suitably aligned and sized character
/// array member which can hold elements of any of up to ten types.
///
/// These types may be arrays, structs, or any other types. The goal is to
/// expose a char array buffer member which can be used as suitable storage for
/// a placement new of any of these types. Support for more than ten types can
/// be added at the cost of more boilerplate.
template <
typename T1,
typename T2 = char,
typename T3 = char,
typename T4 = char,
typename T5 = char,
typename T6 = char,
typename T7 = char,
typename T8 = char,
typename T9 = char,
typename T10 = char>
struct AlignedCharArrayUnion
: AlignedCharArray<
alignof(detail::AlignerImpl<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10>),
sizeof(::c10::detail::
SizerImpl<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10>)> {};
} // end namespace c10
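
A short sketch of the intended use (the type names below are illustrative,
not part of the header): the union provides raw storage aligned and sized
for any of the listed types, into which an object is placement-new'd.

#include <c10/util/AlignOf.h>
#include <new>

struct Small { int a; };
struct Big { double d[4]; };

int main() {
  // Storage suitable for either Small or Big, aligned for both.
  c10::AlignedCharArrayUnion<Small, Big> storage;
  Big* b = new (storage.buffer) Big{};  // placement new into the buffer
  b->~Big();                            // manual destruction; no delete
}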

View File

@@ -0,0 +1,115 @@
// Copyright 2023-present Facebook. All Rights Reserved.
#pragma once
#include <c10/macros/Export.h>
#include <array>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <ctime>
#include <functional>
#include <type_traits>
#if defined(C10_IOS) && defined(C10_MOBILE)
#include <sys/time.h> // for gettimeofday()
#endif
#if defined(__i386__) || defined(__x86_64__) || defined(__amd64__)
#define C10_RDTSC
#if defined(_MSC_VER)
#include <intrin.h>
#elif defined(__CUDACC__) || defined(__HIPCC__)
#undef C10_RDTSC
#elif defined(__clang__)
// `__rdtsc` is available by default.
// NB: This has to be first, because Clang will also define `__GNUC__`
#elif defined(__GNUC__)
#include <x86intrin.h>
#else
#undef C10_RDTSC
#endif
#endif
namespace c10 {
using time_t = int64_t;
using steady_clock_t = std::conditional_t<
std::chrono::high_resolution_clock::is_steady,
std::chrono::high_resolution_clock,
std::chrono::steady_clock>;
inline time_t getTimeSinceEpoch() {
auto now = std::chrono::system_clock::now().time_since_epoch();
return std::chrono::duration_cast<std::chrono::nanoseconds>(now).count();
}
inline time_t getTime(bool allow_monotonic = false) {
#if defined(C10_IOS) && defined(C10_MOBILE)
// clock_gettime is only available on iOS 10.0 or newer. Unlike OS X, iOS
// can't rely on CLOCK_REALTIME, as it is defined whether or not
// clock_gettime is actually implemented.
struct timeval now;
gettimeofday(&now, NULL);
return static_cast<time_t>(now.tv_sec) * 1000000000 +
static_cast<time_t>(now.tv_usec) * 1000;
#elif defined(_WIN32) || defined(__MACH__)
return std::chrono::duration_cast<std::chrono::nanoseconds>(
steady_clock_t::now().time_since_epoch())
.count();
#else
// clock_gettime is *much* faster than the std::chrono implementation on Linux
struct timespec t {};
auto mode = CLOCK_REALTIME;
if (allow_monotonic) {
mode = CLOCK_MONOTONIC;
}
clock_gettime(mode, &t);
return static_cast<time_t>(t.tv_sec) * 1000000000 +
static_cast<time_t>(t.tv_nsec);
#endif
}
// We often do not need to capture true wall times. If a fast mechanism such
// as TSC is available we can use that instead and convert back to epoch time
// during post-processing. This greatly reduces the clock's contribution to
// profiling.
// http://btorpey.github.io/blog/2014/02/18/clock-sources-in-linux/
// https://quick-bench.com/q/r8opkkGZSJMu9wM_XTbDouq-0Io
// TODO: We should use
// `https://github.com/google/benchmark/blob/main/src/cycleclock.h`
inline auto getApproximateTime() {
#if defined(C10_RDTSC)
return static_cast<uint64_t>(__rdtsc());
#else
return getTime();
#endif
}
using approx_time_t = decltype(getApproximateTime());
static_assert(
std::is_same_v<approx_time_t, int64_t> ||
std::is_same_v<approx_time_t, uint64_t>,
"Expected either int64_t (`getTime`) or uint64_t (some TSC reads).");
// Convert `getCount` results to Nanoseconds since unix epoch.
class C10_API ApproximateClockToUnixTimeConverter final {
public:
ApproximateClockToUnixTimeConverter();
std::function<time_t(approx_time_t)> makeConverter();
struct UnixAndApproximateTimePair {
time_t t_;
approx_time_t approx_t_;
};
static UnixAndApproximateTimePair measurePair();
private:
static constexpr size_t replicates = 1001;
using time_pairs = std::array<UnixAndApproximateTimePair, replicates>;
time_pairs measurePairs();
time_pairs start_times_;
};
} // namespace c10
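
A sketch of the intended profiling pattern (variable names illustrative;
assumes linking against c10): record cheap TSC-based timestamps on the hot
path, then map them to Unix-epoch nanoseconds in post-processing.

#include <c10/util/ApproximateClock.h>

void profile_region() {
  // Construct early: the converter records its start calibration pairs now.
  c10::ApproximateClockToUnixTimeConverter converter;
  auto t0 = c10::getApproximateTime(); // cheap: rdtsc where available
  // ... hot code being profiled ...
  auto t1 = c10::getApproximateTime();
  // Post-processing: build the approx-time -> epoch-ns mapping once,
  // then convert the recorded timestamps.
  auto to_epoch_ns = converter.makeConverter();
  c10::time_t start_ns = to_epoch_ns(t0);
  c10::time_t end_ns = to_epoch_ns(t1);
  (void)start_ns;
  (void)end_ns;
}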

View File

@@ -0,0 +1,18 @@
#pragma once
#include <array>
#include <utility>
namespace c10 {
// This helper function creates a constexpr std::array
// From a compile time list of values, without requiring you to explicitly
// write out the length.
//
// See also https://stackoverflow.com/a/26351760/23845
template <typename V, typename... T>
inline constexpr auto array_of(T&&... t) -> std::array<V, sizeof...(T)> {
return {{std::forward<T>(t)...}};
}
} // namespace c10
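
For example (a minimal sketch, given the definition above), the element
count is deduced from the argument list:

#include <array>
#include <cstdint>

constexpr auto a = c10::array_of<int64_t>(1, 2, 3); // std::array<int64_t, 3>
constexpr std::array<int64_t, 3> b = {{1, 2, 3}};   // hand-written equivalent
static_assert(a.size() == b.size());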

View File

@@ -0,0 +1,380 @@
//===--- ArrayRef.h - Array Reference Wrapper -------------------*- C++ -*-===//
//
// The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
// ATen: modified from llvm::ArrayRef.
// removed llvm-specific functionality
// removed some implicit const -> non-const conversions that rely on
// complicated std::enable_if meta-programming
// removed a bunch of slice variants for simplicity...
#pragma once
#include <c10/macros/Macros.h>
#include <c10/util/Deprecated.h>
#include <c10/util/Exception.h>
#include <c10/util/SmallVector.h>
#include <array>
#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <iterator>
#include <ostream>
#include <type_traits>
#include <vector>
namespace c10 {
/// ArrayRef - Represent a constant reference to an array (0 or more elements
/// consecutively in memory), i.e. a start pointer and a length. It allows
/// various APIs to take consecutive elements easily and conveniently.
///
/// This class does not own the underlying data, it is expected to be used in
/// situations where the data resides in some other buffer, whose lifetime
/// extends past that of the ArrayRef. For this reason, it is not in general
/// safe to store an ArrayRef.
///
/// This is intended to be trivially copyable, so it should be passed by
/// value.
template <typename T>
class ArrayRef final {
public:
using iterator = const T*;
using const_iterator = const T*;
using size_type = size_t;
using value_type = T;
using reverse_iterator = std::reverse_iterator<iterator>;
private:
/// The start of the array, in an external buffer.
const T* Data;
/// The number of elements.
size_type Length;
void debugCheckNullptrInvariant() {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
Data != nullptr || Length == 0,
"created ArrayRef with nullptr and non-zero length! std::optional relies on this being illegal");
}
public:
/// @name Constructors
/// @{
/// Construct an empty ArrayRef.
/* implicit */ constexpr ArrayRef() : Data(nullptr), Length(0) {}
/// Construct an ArrayRef from a single element.
// TODO Make this explicit
constexpr ArrayRef(const T& OneElt) : Data(&OneElt), Length(1) {}
/// Construct an ArrayRef from a pointer and length.
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA ArrayRef(const T* data, size_t length)
: Data(data), Length(length) {
debugCheckNullptrInvariant();
}
/// Construct an ArrayRef from a range.
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA ArrayRef(const T* begin, const T* end)
: Data(begin), Length(end - begin) {
debugCheckNullptrInvariant();
}
/// Construct an ArrayRef from a SmallVector. This is templated in order to
/// avoid instantiating SmallVectorTemplateCommon<T> whenever we
/// copy-construct an ArrayRef.
template <typename U>
/* implicit */ ArrayRef(const SmallVectorTemplateCommon<T, U>& Vec)
: Data(Vec.data()), Length(Vec.size()) {
debugCheckNullptrInvariant();
}
template <
typename Container,
typename = std::enable_if_t<std::is_same_v<
std::remove_const_t<decltype(std::declval<Container>().data())>,
T*>>>
/* implicit */ ArrayRef(const Container& container)
: Data(container.data()), Length(container.size()) {
debugCheckNullptrInvariant();
}
/// Construct an ArrayRef from a std::vector.
// The static_assert below makes sure that this isn't used for
// std::vector<bool>, because ArrayRef can't work on a std::vector<bool>
// bitfield.
template <typename A>
/* implicit */ ArrayRef(const std::vector<T, A>& Vec)
: Data(Vec.data()), Length(Vec.size()) {
static_assert(
!std::is_same<T, bool>::value,
"ArrayRef<bool> cannot be constructed from a std::vector<bool> bitfield.");
}
/// Construct an ArrayRef from a std::array
template <size_t N>
/* implicit */ constexpr ArrayRef(const std::array<T, N>& Arr)
: Data(Arr.data()), Length(N) {}
/// Construct an ArrayRef from a C array.
template <size_t N>
// NOLINTNEXTLINE(*c-arrays*)
/* implicit */ constexpr ArrayRef(const T (&Arr)[N]) : Data(Arr), Length(N) {}
/// Construct an ArrayRef from a std::initializer_list.
/* implicit */ constexpr ArrayRef(const std::initializer_list<T>& Vec)
: Data(
std::begin(Vec) == std::end(Vec) ? static_cast<T*>(nullptr)
: std::begin(Vec)),
Length(Vec.size()) {}
/// @}
/// @name Simple Operations
/// @{
constexpr iterator begin() const {
return Data;
}
constexpr iterator end() const {
return Data + Length;
}
// These are actually the same as iterator, since ArrayRef only
// gives you const iterators.
constexpr const_iterator cbegin() const {
return Data;
}
constexpr const_iterator cend() const {
return Data + Length;
}
constexpr reverse_iterator rbegin() const {
return reverse_iterator(end());
}
constexpr reverse_iterator rend() const {
return reverse_iterator(begin());
}
/// empty - Check if the array is empty.
constexpr bool empty() const {
return Length == 0;
}
constexpr const T* data() const {
return Data;
}
/// size - Get the array size.
constexpr size_t size() const {
return Length;
}
/// front - Get the first element.
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA const T& front() const {
TORCH_CHECK(
!empty(), "ArrayRef: attempted to access front() of empty list");
return Data[0];
}
/// back - Get the last element.
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA const T& back() const {
TORCH_CHECK(!empty(), "ArrayRef: attempted to access back() of empty list");
return Data[Length - 1];
}
/// equals - Check for element-wise equality.
constexpr bool equals(ArrayRef RHS) const {
return Length == RHS.Length && std::equal(begin(), end(), RHS.begin());
}
/// slice(n, m) - Take M elements of the array starting at element N
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA ArrayRef<T> slice(size_t N, size_t M)
const {
TORCH_CHECK(
N + M <= size(),
"ArrayRef: invalid slice, N = ",
N,
"; M = ",
M,
"; size = ",
size());
return ArrayRef<T>(data() + N, M);
}
/// slice(n) - Chop off the first N elements of the array.
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA ArrayRef<T> slice(size_t N) const {
TORCH_CHECK(
N <= size(), "ArrayRef: invalid slice, N = ", N, "; size = ", size());
return slice(N, size() - N);
}
/// @}
/// @name Operator Overloads
/// @{
constexpr const T& operator[](size_t Index) const {
return Data[Index];
}
/// Vector compatibility
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA const T& at(size_t Index) const {
TORCH_CHECK(
Index < Length,
"ArrayRef: invalid index Index = ",
Index,
"; Length = ",
Length);
return Data[Index];
}
/// Disallow accidental assignment from a temporary.
///
/// The declaration here is extra complicated so that "arrayRef = {}"
/// continues to select the move assignment operator.
template <typename U>
std::enable_if_t<std::is_same_v<U, T>, ArrayRef<T>>& operator=(
// NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
U&& Temporary) = delete;
/// Disallow accidental assignment from a temporary.
///
/// The declaration here is extra complicated so that "arrayRef = {}"
/// continues to select the move assignment operator.
template <typename U>
std::enable_if_t<std::is_same_v<U, T>, ArrayRef<T>>& operator=(
std::initializer_list<U>) = delete;
/// @}
/// @name Expensive Operations
/// @{
std::vector<T> vec() const {
return std::vector<T>(Data, Data + Length);
}
/// @}
};
template <typename T>
std::ostream& operator<<(std::ostream& out, ArrayRef<T> list) {
int i = 0;
out << "[";
for (const auto& e : list) {
if (i++ > 0)
out << ", ";
out << e;
}
out << "]";
return out;
}
/// @name ArrayRef Convenience constructors
/// @{
/// Construct an ArrayRef from a single element.
template <typename T>
ArrayRef<T> makeArrayRef(const T& OneElt) {
return OneElt;
}
/// Construct an ArrayRef from a pointer and length.
template <typename T>
ArrayRef<T> makeArrayRef(const T* data, size_t length) {
return ArrayRef<T>(data, length);
}
/// Construct an ArrayRef from a range.
template <typename T>
ArrayRef<T> makeArrayRef(const T* begin, const T* end) {
return ArrayRef<T>(begin, end);
}
/// Construct an ArrayRef from a SmallVector.
template <typename T>
ArrayRef<T> makeArrayRef(const SmallVectorImpl<T>& Vec) {
return Vec;
}
/// Construct an ArrayRef from a SmallVector.
template <typename T, unsigned N>
ArrayRef<T> makeArrayRef(const SmallVector<T, N>& Vec) {
return Vec;
}
/// Construct an ArrayRef from a std::vector.
template <typename T>
ArrayRef<T> makeArrayRef(const std::vector<T>& Vec) {
return Vec;
}
/// Construct an ArrayRef from a std::array.
template <typename T, std::size_t N>
ArrayRef<T> makeArrayRef(const std::array<T, N>& Arr) {
return Arr;
}
/// Construct an ArrayRef from an ArrayRef (no-op) (const)
template <typename T>
ArrayRef<T> makeArrayRef(const ArrayRef<T>& Vec) {
return Vec;
}
/// Construct an ArrayRef from an ArrayRef (no-op)
template <typename T>
ArrayRef<T>& makeArrayRef(ArrayRef<T>& Vec) {
return Vec;
}
/// Construct an ArrayRef from a C array.
template <typename T, size_t N>
// NOLINTNEXTLINE(*c-arrays*)
ArrayRef<T> makeArrayRef(const T (&Arr)[N]) {
return ArrayRef<T>(Arr);
}
// WARNING: Template instantiation will NOT be willing to do implicit
// conversions to get you to a c10::ArrayRef, which is why we need so
// many overloads.
template <typename T>
bool operator==(c10::ArrayRef<T> a1, c10::ArrayRef<T> a2) {
return a1.equals(a2);
}
template <typename T>
bool operator!=(c10::ArrayRef<T> a1, c10::ArrayRef<T> a2) {
return !a1.equals(a2);
}
template <typename T>
bool operator==(const std::vector<T>& a1, c10::ArrayRef<T> a2) {
return c10::ArrayRef<T>(a1).equals(a2);
}
template <typename T>
bool operator!=(const std::vector<T>& a1, c10::ArrayRef<T> a2) {
return !c10::ArrayRef<T>(a1).equals(a2);
}
template <typename T>
bool operator==(c10::ArrayRef<T> a1, const std::vector<T>& a2) {
return a1.equals(c10::ArrayRef<T>(a2));
}
template <typename T>
bool operator!=(c10::ArrayRef<T> a1, const std::vector<T>& a2) {
return !a1.equals(c10::ArrayRef<T>(a2));
}
using IntArrayRef = ArrayRef<int64_t>;
// This alias is deprecated because it doesn't make ownership
// semantics obvious. Use IntArrayRef instead!
C10_DEFINE_DEPRECATED_USING(IntList, ArrayRef<int64_t>)
} // namespace c10
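
A brief usage sketch (function names illustrative): a single IntArrayRef
parameter accepts vectors, C arrays, and initializer lists without copying,
provided the underlying buffer outlives the call.

#include <c10/util/ArrayRef.h>
#include <cstdint>
#include <vector>

int64_t sum(c10::IntArrayRef xs) {
  int64_t total = 0;
  for (int64_t x : xs) {
    total += x;
  }
  return total;
}

int main() {
  std::vector<int64_t> v = {1, 2, 3, 4};
  int64_t arr[] = {5, 6};
  sum(v);         // view over the vector's buffer
  sum(arr);       // view over a C array
  sum({7, 8, 9}); // initializer_list: valid only for the duration of the call
  sum(c10::IntArrayRef(v).slice(1, 2)); // view of {2, 3}; still no copy
}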

View File

@@ -0,0 +1,361 @@
#pragma once
#include <c10/macros/Macros.h>
#include <c10/util/bit_cast.h>
#include <limits>
C10_CLANG_DIAGNOSTIC_PUSH()
#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
#endif
#if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
#if defined(CL_SYCL_LANGUAGE_VERSION)
#include <CL/sycl.hpp> // for SYCL 1.2.1
#else
#include <sycl/sycl.hpp> // for SYCL 2020
#endif
#include <ext/oneapi/bfloat16.hpp>
#endif
namespace c10 {
/// Constructors
inline C10_HOST_DEVICE BFloat16::BFloat16(float value)
:
#if defined(__CUDACC__) && !defined(USE_ROCM) && defined(__CUDA_ARCH__) && \
__CUDA_ARCH__ >= 800
x(__bfloat16_as_ushort(__float2bfloat16(value)))
#elif defined(__SYCL_DEVICE_ONLY__) && \
defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
x(c10::bit_cast<uint16_t>(sycl::ext::oneapi::bfloat16(value)))
#else
// RNE by default
x(detail::round_to_nearest_even(value))
#endif
{
}
/// Implicit conversions
inline C10_HOST_DEVICE BFloat16::operator float() const {
#if defined(__CUDACC__) && !defined(USE_ROCM)
return __bfloat162float(*reinterpret_cast<const __nv_bfloat16*>(&x));
#elif defined(__SYCL_DEVICE_ONLY__) && \
defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
return float(*reinterpret_cast<const sycl::ext::oneapi::bfloat16*>(&x));
#else
return detail::f32_from_bits(x);
#endif
}
#if defined(__CUDACC__) && !defined(USE_ROCM)
inline C10_HOST_DEVICE BFloat16::BFloat16(const __nv_bfloat16& value) {
x = *reinterpret_cast<const unsigned short*>(&value);
}
inline C10_HOST_DEVICE BFloat16::operator __nv_bfloat16() const {
return *reinterpret_cast<const __nv_bfloat16*>(&x);
}
#endif
#if defined(__HIPCC__) && defined(USE_ROCM)
// 6.2.0 introduced __hip_bfloat16_raw
#if defined(__BF16_HOST_DEVICE__)
inline C10_HOST_DEVICE BFloat16::BFloat16(const __hip_bfloat16& value) {
x = __hip_bfloat16_raw(value).x;
}
inline C10_HOST_DEVICE BFloat16::operator __hip_bfloat16() const {
return __hip_bfloat16(__hip_bfloat16_raw{x});
}
#else // !defined(__BF16_HOST_DEVICE__)
inline C10_HOST_DEVICE BFloat16::BFloat16(const __hip_bfloat16& value) {
x = value.data;
}
inline C10_HOST_DEVICE BFloat16::operator __hip_bfloat16() const {
return __hip_bfloat16{x};
}
#endif // !defined(__BF16_HOST_DEVICE__)
#endif // defined(__HIPCC__) && defined(USE_ROCM)
#if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
inline C10_HOST_DEVICE BFloat16::BFloat16(
const sycl::ext::oneapi::bfloat16& value) {
x = *reinterpret_cast<const unsigned short*>(&value);
}
inline C10_HOST_DEVICE BFloat16::operator sycl::ext::oneapi::bfloat16() const {
return *reinterpret_cast<const sycl::ext::oneapi::bfloat16*>(&x);
}
#endif
// CUDA intrinsics
#if defined(__CUDACC__) || defined(__HIPCC__)
inline C10_DEVICE BFloat16 __ldg(const BFloat16* ptr) {
#if !defined(USE_ROCM) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
return __ldg(reinterpret_cast<const __nv_bfloat16*>(ptr));
#else
return *ptr;
#endif
}
#endif
/// Arithmetic
inline C10_HOST_DEVICE BFloat16
operator+(const BFloat16& a, const BFloat16& b) {
return static_cast<float>(a) + static_cast<float>(b);
}
inline C10_HOST_DEVICE BFloat16
operator-(const BFloat16& a, const BFloat16& b) {
return static_cast<float>(a) - static_cast<float>(b);
}
inline C10_HOST_DEVICE BFloat16
operator*(const BFloat16& a, const BFloat16& b) {
return static_cast<float>(a) * static_cast<float>(b);
}
inline C10_HOST_DEVICE BFloat16 operator/(const BFloat16& a, const BFloat16& b)
__ubsan_ignore_float_divide_by_zero__ {
return static_cast<float>(a) / static_cast<float>(b);
}
inline C10_HOST_DEVICE BFloat16 operator-(const BFloat16& a) {
return -static_cast<float>(a);
}
inline C10_HOST_DEVICE BFloat16& operator+=(BFloat16& a, const BFloat16& b) {
a = a + b;
return a;
}
inline C10_HOST_DEVICE BFloat16& operator-=(BFloat16& a, const BFloat16& b) {
a = a - b;
return a;
}
inline C10_HOST_DEVICE BFloat16& operator*=(BFloat16& a, const BFloat16& b) {
a = a * b;
return a;
}
inline C10_HOST_DEVICE BFloat16& operator/=(BFloat16& a, const BFloat16& b) {
a = a / b;
return a;
}
inline C10_HOST_DEVICE BFloat16& operator|(BFloat16& a, const BFloat16& b) {
a.x = a.x | b.x;
return a;
}
inline C10_HOST_DEVICE BFloat16& operator^(BFloat16& a, const BFloat16& b) {
a.x = a.x ^ b.x;
return a;
}
inline C10_HOST_DEVICE BFloat16& operator&(BFloat16& a, const BFloat16& b) {
a.x = a.x & b.x;
return a;
}
/// Arithmetic with floats
inline C10_HOST_DEVICE float operator+(BFloat16 a, float b) {
return static_cast<float>(a) + b;
}
inline C10_HOST_DEVICE float operator-(BFloat16 a, float b) {
return static_cast<float>(a) - b;
}
inline C10_HOST_DEVICE float operator*(BFloat16 a, float b) {
return static_cast<float>(a) * b;
}
inline C10_HOST_DEVICE float operator/(BFloat16 a, float b) {
return static_cast<float>(a) / b;
}
inline C10_HOST_DEVICE float operator+(float a, BFloat16 b) {
return a + static_cast<float>(b);
}
inline C10_HOST_DEVICE float operator-(float a, BFloat16 b) {
return a - static_cast<float>(b);
}
inline C10_HOST_DEVICE float operator*(float a, BFloat16 b) {
return a * static_cast<float>(b);
}
inline C10_HOST_DEVICE float operator/(float a, BFloat16 b) {
return a / static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator+=(float& a, const BFloat16& b) {
return a += static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator-=(float& a, const BFloat16& b) {
return a -= static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator*=(float& a, const BFloat16& b) {
return a *= static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator/=(float& a, const BFloat16& b) {
return a /= static_cast<float>(b);
}
/// Arithmetic with doubles
inline C10_HOST_DEVICE double operator+(BFloat16 a, double b) {
return static_cast<double>(a) + b;
}
inline C10_HOST_DEVICE double operator-(BFloat16 a, double b) {
return static_cast<double>(a) - b;
}
inline C10_HOST_DEVICE double operator*(BFloat16 a, double b) {
return static_cast<double>(a) * b;
}
inline C10_HOST_DEVICE double operator/(BFloat16 a, double b) {
return static_cast<double>(a) / b;
}
inline C10_HOST_DEVICE double operator+(double a, BFloat16 b) {
return a + static_cast<double>(b);
}
inline C10_HOST_DEVICE double operator-(double a, BFloat16 b) {
return a - static_cast<double>(b);
}
inline C10_HOST_DEVICE double operator*(double a, BFloat16 b) {
return a * static_cast<double>(b);
}
inline C10_HOST_DEVICE double operator/(double a, BFloat16 b) {
return a / static_cast<double>(b);
}
/// Arithmetic with ints
inline C10_HOST_DEVICE BFloat16 operator+(BFloat16 a, int b) {
return a + static_cast<BFloat16>(b);
}
inline C10_HOST_DEVICE BFloat16 operator-(BFloat16 a, int b) {
return a - static_cast<BFloat16>(b);
}
inline C10_HOST_DEVICE BFloat16 operator*(BFloat16 a, int b) {
return a * static_cast<BFloat16>(b);
}
inline C10_HOST_DEVICE BFloat16 operator/(BFloat16 a, int b) {
return a / static_cast<BFloat16>(b);
}
inline C10_HOST_DEVICE BFloat16 operator+(int a, BFloat16 b) {
return static_cast<BFloat16>(a) + b;
}
inline C10_HOST_DEVICE BFloat16 operator-(int a, BFloat16 b) {
return static_cast<BFloat16>(a) - b;
}
inline C10_HOST_DEVICE BFloat16 operator*(int a, BFloat16 b) {
return static_cast<BFloat16>(a) * b;
}
inline C10_HOST_DEVICE BFloat16 operator/(int a, BFloat16 b) {
return static_cast<BFloat16>(a) / b;
}
/// Arithmetic with int64_t
inline C10_HOST_DEVICE BFloat16 operator+(BFloat16 a, int64_t b) {
return a + static_cast<BFloat16>(b);
}
inline C10_HOST_DEVICE BFloat16 operator-(BFloat16 a, int64_t b) {
return a - static_cast<BFloat16>(b);
}
inline C10_HOST_DEVICE BFloat16 operator*(BFloat16 a, int64_t b) {
return a * static_cast<BFloat16>(b);
}
inline C10_HOST_DEVICE BFloat16 operator/(BFloat16 a, int64_t b) {
return a / static_cast<BFloat16>(b);
}
inline C10_HOST_DEVICE BFloat16 operator+(int64_t a, BFloat16 b) {
return static_cast<BFloat16>(a) + b;
}
inline C10_HOST_DEVICE BFloat16 operator-(int64_t a, BFloat16 b) {
return static_cast<BFloat16>(a) - b;
}
inline C10_HOST_DEVICE BFloat16 operator*(int64_t a, BFloat16 b) {
return static_cast<BFloat16>(a) * b;
}
inline C10_HOST_DEVICE BFloat16 operator/(int64_t a, BFloat16 b) {
return static_cast<BFloat16>(a) / b;
}
// Overloading < and > operators, because std::max and std::min use them.
inline C10_HOST_DEVICE bool operator>(BFloat16& lhs, BFloat16& rhs) {
return float(lhs) > float(rhs);
}
inline C10_HOST_DEVICE bool operator<(BFloat16& lhs, BFloat16& rhs) {
return float(lhs) < float(rhs);
}
} // namespace c10
namespace std {
template <>
class numeric_limits<c10::BFloat16> {
public:
static constexpr bool is_signed = true;
static constexpr bool is_specialized = true;
static constexpr bool is_integer = false;
static constexpr bool is_exact = false;
static constexpr bool has_infinity = true;
static constexpr bool has_quiet_NaN = true;
static constexpr bool has_signaling_NaN = true;
static constexpr auto has_denorm = numeric_limits<float>::has_denorm;
static constexpr auto has_denorm_loss =
numeric_limits<float>::has_denorm_loss;
static constexpr auto round_style = numeric_limits<float>::round_style;
static constexpr bool is_iec559 = false;
static constexpr bool is_bounded = true;
static constexpr bool is_modulo = false;
static constexpr int digits = 8;
static constexpr int digits10 = 2;
static constexpr int max_digits10 = 4;
static constexpr int radix = 2;
static constexpr int min_exponent = -125;
static constexpr int min_exponent10 = -37;
static constexpr int max_exponent = 128;
static constexpr int max_exponent10 = 38;
static constexpr auto traps = numeric_limits<float>::traps;
static constexpr auto tinyness_before =
numeric_limits<float>::tinyness_before;
static constexpr c10::BFloat16 min() {
return c10::BFloat16(0x0080, c10::BFloat16::from_bits());
}
static constexpr c10::BFloat16 lowest() {
return c10::BFloat16(0xFF7F, c10::BFloat16::from_bits());
}
static constexpr c10::BFloat16 max() {
return c10::BFloat16(0x7F7F, c10::BFloat16::from_bits());
}
static constexpr c10::BFloat16 epsilon() {
return c10::BFloat16(0x3C00, c10::BFloat16::from_bits());
}
static constexpr c10::BFloat16 round_error() {
return c10::BFloat16(0x3F00, c10::BFloat16::from_bits());
}
static constexpr c10::BFloat16 infinity() {
return c10::BFloat16(0x7F80, c10::BFloat16::from_bits());
}
static constexpr c10::BFloat16 quiet_NaN() {
return c10::BFloat16(0x7FC0, c10::BFloat16::from_bits());
}
static constexpr c10::BFloat16 signaling_NaN() {
return c10::BFloat16(0x7F80, c10::BFloat16::from_bits());
}
static constexpr c10::BFloat16 denorm_min() {
return c10::BFloat16(0x0001, c10::BFloat16::from_bits());
}
};
} // namespace std
C10_CLANG_DIAGNOSTIC_POP()
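
The practical consequence of the operators above (a small illustrative
sketch): BFloat16 arithmetic is carried out in float and rounded back to
bfloat16 on the result, so only about 2-3 decimal digits survive (digits10
is 2 per the numeric_limits specialization above).

#include <c10/util/BFloat16.h>
#include <iostream>

int main() {
  c10::BFloat16 a = 1.0f;
  c10::BFloat16 b = 3.0f;
  c10::BFloat16 q = a / b; // computed as float, rounded to nearest-even bf16
  std::cout << float(q) << " vs " << 1.0f / 3.0f << '\n';
}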

View File

@@ -0,0 +1,292 @@
#pragma once
#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>
C10_CLANG_DIAGNOSTIC_PUSH()
#if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion")
C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion")
#endif
namespace std {
template <typename T>
struct is_reduced_floating_point
: std::integral_constant<
bool,
std::is_same_v<T, c10::Half> || std::is_same_v<T, c10::BFloat16>> {};
template <typename T>
constexpr bool is_reduced_floating_point_v =
is_reduced_floating_point<T>::value;
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T acos(T a) {
return std::acos(float(a));
}
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T asin(T a) {
return std::asin(float(a));
}
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T atan(T a) {
return std::atan(float(a));
}
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T atanh(T a) {
return std::atanh(float(a));
}
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T erf(T a) {
return std::erf(float(a));
}
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T erfc(T a) {
return std::erfc(float(a));
}
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T exp(T a) {
return std::exp(float(a));
}
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T expm1(T a) {
return std::expm1(float(a));
}
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline bool isfinite(T a) {
return std::isfinite(float(a));
}
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T log(T a) {
return std::log(float(a));
}
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T log10(T a) {
return std::log10(float(a));
}
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T log1p(T a) {
return std::log1p(float(a));
}
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T log2(T a) {
return std::log2(float(a));
}
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T ceil(T a) {
return std::ceil(float(a));
}
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T cos(T a) {
return std::cos(float(a));
}
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T floor(T a) {
return std::floor(float(a));
}
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T nearbyint(T a) {
return std::nearbyint(float(a));
}
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T sin(T a) {
return std::sin(float(a));
}
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T tan(T a) {
return std::tan(float(a));
}
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T sinh(T a) {
return std::sinh(float(a));
}
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T cosh(T a) {
return std::cosh(float(a));
}
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T tanh(T a) {
return std::tanh(float(a));
}
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T trunc(T a) {
return std::trunc(float(a));
}
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T lgamma(T a) {
return std::lgamma(float(a));
}
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T sqrt(T a) {
return std::sqrt(float(a));
}
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T rsqrt(T a) {
return 1.0 / std::sqrt(float(a));
}
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T abs(T a) {
return std::abs(float(a));
}
#if defined(_MSC_VER) && defined(__CUDACC__)
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T pow(T a, double b) {
return std::pow(float(a), float(b));
}
#else
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T pow(T a, double b) {
return std::pow(float(a), b);
}
#endif
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T pow(T a, T b) {
return std::pow(float(a), float(b));
}
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T fmod(T a, T b) {
return std::fmod(float(a), float(b));
}
/*
The following function is inspired from the implementation in `musl`
Link to License: https://git.musl-libc.org/cgit/musl/tree/COPYRIGHT
----------------------------------------------------------------------
Copyright © 2005-2020 Rich Felker, et al.
Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
----------------------------------------------------------------------
*/
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
C10_HOST_DEVICE inline T nextafter(T from, T to) {
// Reference:
// https://git.musl-libc.org/cgit/musl/tree/src/math/nextafter.c
using int_repr_t = uint16_t;
constexpr uint8_t bits = 16;
union {
T f;
int_repr_t i;
} ufrom = {from}, uto = {to};
// get a mask to get the sign bit i.e. MSB
int_repr_t sign_mask = int_repr_t{1} << (bits - 1);
// short-circuit: if either is NaN, return NaN
if (from != from || to != to) {
return from + to;
}
// short-circuit: if they are exactly the same.
if (ufrom.i == uto.i) {
return from;
}
// mask the sign-bit to zero i.e. positive
// equivalent to abs(x)
int_repr_t abs_from = ufrom.i & ~sign_mask;
int_repr_t abs_to = uto.i & ~sign_mask;
if (abs_from == 0) {
// if both are zero but with different sign,
// preserve the sign of `to`.
if (abs_to == 0) {
return to;
}
// smallest subnormal with sign of `to`.
ufrom.i = (uto.i & sign_mask) | int_repr_t{1};
return ufrom.f;
}
// if abs(from) > abs(to) or sign(from) != sign(to)
if (abs_from > abs_to || ((ufrom.i ^ uto.i) & sign_mask)) {
ufrom.i--;
} else {
ufrom.i++;
}
return ufrom.f;
}
} // namespace std
C10_CLANG_DIAGNOSTIC_POP()
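
A worked example of the bit-stepping above (header path assumed; bit values
follow the bfloat16 layout defined in c10/util/BFloat16.h): nextafter moves
the 16-bit integer representation by one, which lands on the adjacent
representable value.

#include <c10/util/BFloat16-math.h>

int main() {
  c10::BFloat16 one = 1.0f; // bits 0x3F80
  c10::BFloat16 up = std::nextafter(one, c10::BFloat16(2.0f));
  // With 7 mantissa bits, the next value toward 2.0 is 1 + 2^-7 = 1.0078125,
  // whose bits are 0x3F81.
  return up.x == 0x3F81 ? 0 : 1;
}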

View File

@@ -0,0 +1,133 @@
#pragma once
// Defines the bfloat16 type (brain floating-point). This representation uses
// 1 bit for the sign, 8 bits for the exponent and 7 bits for the mantissa.
#include <c10/macros/Macros.h>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <iosfwd>
#include <ostream>
#if defined(__CUDACC__) && !defined(USE_ROCM)
#include <cuda_bf16.h>
#endif
#if defined(__HIPCC__) && defined(USE_ROCM)
#include <hip/hip_bf16.h>
#endif
#if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
#if defined(CL_SYCL_LANGUAGE_VERSION)
#include <CL/sycl.hpp> // for SYCL 1.2.1
#else
#include <sycl/sycl.hpp> // for SYCL 2020
#endif
#include <ext/oneapi/bfloat16.hpp>
#endif
namespace c10 {
namespace detail {
inline C10_HOST_DEVICE float f32_from_bits(uint16_t src) {
float res = 0;
uint32_t tmp = src;
tmp <<= 16;
#if defined(USE_ROCM)
float* tempRes;
// We should be using memcpy in order to respect the strict aliasing rule
// but it fails in the HIP environment.
tempRes = reinterpret_cast<float*>(&tmp);
res = *tempRes;
#else
std::memcpy(&res, &tmp, sizeof(tmp));
#endif
return res;
}
inline C10_HOST_DEVICE uint16_t bits_from_f32(float src) {
uint32_t res = 0;
#if defined(USE_ROCM)
// We should be using memcpy in order to respect the strict aliasing rule
// but it fails in the HIP environment.
uint32_t* tempRes = reinterpret_cast<uint32_t*>(&src);
res = *tempRes;
#else
std::memcpy(&res, &src, sizeof(res));
#endif
return res >> 16;
}
inline C10_HOST_DEVICE uint16_t round_to_nearest_even(float src) {
#if defined(USE_ROCM)
if (src != src) {
#elif defined(_MSC_VER)
if (isnan(src)) {
#else
if (std::isnan(src)) {
#endif
return UINT16_C(0x7FC0);
} else {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
union {
uint32_t U32; // NOLINT(facebook-hte-BadMemberName)
float F32; // NOLINT(facebook-hte-BadMemberName)
};
F32 = src;
uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF);
return static_cast<uint16_t>((U32 + rounding_bias) >> 16);
}
}
} // namespace detail
struct alignas(2) BFloat16 {
uint16_t x;
// HIP wants __host__ __device__ tag, CUDA does not
#if defined(USE_ROCM)
C10_HOST_DEVICE BFloat16() = default;
#else
BFloat16() = default;
#endif
struct from_bits_t {};
static constexpr C10_HOST_DEVICE from_bits_t from_bits() {
return from_bits_t();
}
constexpr C10_HOST_DEVICE BFloat16(unsigned short bits, from_bits_t)
: x(bits) {}
/* implicit */ inline C10_HOST_DEVICE BFloat16(float value);
inline C10_HOST_DEVICE operator float() const;
#if defined(__CUDACC__) && !defined(USE_ROCM)
inline C10_HOST_DEVICE BFloat16(const __nv_bfloat16& value);
explicit inline C10_HOST_DEVICE operator __nv_bfloat16() const;
#endif
#if defined(__HIPCC__) && defined(USE_ROCM)
inline C10_HOST_DEVICE BFloat16(const __hip_bfloat16& value);
explicit inline C10_HOST_DEVICE operator __hip_bfloat16() const;
#endif
#if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
inline C10_HOST_DEVICE BFloat16(const sycl::ext::oneapi::bfloat16& value);
explicit inline C10_HOST_DEVICE operator sycl::ext::oneapi::bfloat16() const;
#endif
};
C10_API inline std::ostream& operator<<(
std::ostream& out,
const BFloat16& value) {
out << (float)value;
return out;
}
} // namespace c10
#include <c10/util/BFloat16-inl.h> // IWYU pragma: keep
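
A small worked example of round_to_nearest_even above: values exactly
halfway between two bfloat16 neighbors round toward the even (LSB-zero)
mantissa.

#include <c10/util/BFloat16.h>
#include <cassert>

int main() {
  // 1 + 2^-8 is exactly halfway between bf16 0x3F80 (1.0) and 0x3F81
  // (1.0078125); the tie breaks to the even mantissa, 0x3F80.
  assert(c10::detail::round_to_nearest_even(1.00390625f) == 0x3F80);
  // 1 + 3*2^-8 is halfway between 0x3F81 and 0x3F82; again the even wins.
  assert(c10::detail::round_to_nearest_even(1.01171875f) == 0x3F82);
}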

View File

@@ -0,0 +1,31 @@
#ifndef C10_UTIL_BACKTRACE_H_
#define C10_UTIL_BACKTRACE_H_
#include <cstddef>
#include <memory>
#include <string>
#include <typeinfo>
#include <c10/macros/Macros.h>
#include <c10/util/Lazy.h>
namespace c10 {
// Symbolizing the backtrace can be expensive; pass it around as a lazy string
// so it is symbolized only if actually needed.
using Backtrace = std::shared_ptr<const LazyValue<std::string>>;
// DEPRECATED: Prefer get_lazy_backtrace().
C10_API std::string get_backtrace(
size_t frames_to_skip = 0,
size_t maximum_number_of_frames = 64,
bool skip_python_frames = true);
C10_API Backtrace get_lazy_backtrace(
size_t frames_to_skip = 0,
size_t maximum_number_of_frames = 64,
bool skip_python_frames = true);
} // namespace c10
#endif // C10_UTIL_BACKTRACE_H_
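
A sketch of the intended lazy pattern (illustrative; assumes LazyValue
exposes get() as in c10/util/Lazy.h): capture the backtrace cheaply at the
point of error, and pay the symbolization cost only if it is displayed.

#include <c10/util/Backtrace.h>
#include <iostream>

void record_error() {
  // Capturing is cheap; symbolization is deferred.
  c10::Backtrace bt = c10::get_lazy_backtrace(/*frames_to_skip=*/1);
  // ... later, only if the trace is actually needed:
  std::cout << bt->get() << '\n'; // symbolized on first use
}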

View File

@@ -0,0 +1,116 @@
#pragma once
#include <cstddef>
#if defined(_MSC_VER)
#include <intrin.h>
#endif
namespace c10::utils {
/**
* This is a simple bitset class with 8 * sizeof(long long int) bits.
* You can set bits, unset bits, query bits by index,
* and query for the first set bit.
* Before using this class, please also take a look at std::bitset,
* which has more functionality and is more generic. It is probably
* a better fit for your use case. The sole reason for c10::utils::bitset
* to exist is that std::bitset lacks a find_first_set() method.
*/
struct bitset final {
private:
#if defined(_MSC_VER)
// MSVC's _BitScanForward64 expects int64_t
using bitset_type = int64_t;
#else
// POSIX ffsll expects long long int
using bitset_type = long long int;
#endif
public:
static constexpr size_t NUM_BITS() {
return 8 * sizeof(bitset_type);
}
constexpr bitset() noexcept = default;
constexpr bitset(const bitset&) noexcept = default;
constexpr bitset(bitset&&) noexcept = default;
// there is an issue with gcc 5.3.0 when defining a defaulted function as
// constexpr; see https://gcc.gnu.org/bugzilla/show_bug.cgi?id=68754.
bitset& operator=(const bitset&) noexcept = default;
bitset& operator=(bitset&&) noexcept = default;
constexpr void set(size_t index) noexcept {
bitset_ |= (static_cast<long long int>(1) << index);
}
constexpr void unset(size_t index) noexcept {
bitset_ &= ~(static_cast<long long int>(1) << index);
}
constexpr bool get(size_t index) const noexcept {
return bitset_ & (static_cast<long long int>(1) << index);
}
constexpr bool is_entirely_unset() const noexcept {
return 0 == bitset_;
}
// Call the given functor with the index of each bit that is set
template <class Func>
void for_each_set_bit(Func&& func) const {
bitset cur = *this;
size_t index = cur.find_first_set();
while (0 != index) {
// -1 because find_first_set() is one-indexed.
index -= 1;
func(index);
cur.unset(index);
index = cur.find_first_set();
}
}
private:
// Return the index of the first set bit. The returned index is one-indexed
// (i.e. if the very first bit is set, this function returns '1'), and a
// return of '0' means that there was no bit set.
size_t find_first_set() const {
#if defined(_MSC_VER) && (defined(_M_X64) || defined(_M_ARM64))
unsigned long result;
bool has_bits_set = (0 != _BitScanForward64(&result, bitset_));
if (!has_bits_set) {
return 0;
}
return result + 1;
#elif defined(_MSC_VER) && defined(_M_IX86)
unsigned long result;
if (static_cast<uint32_t>(bitset_) != 0) {
bool has_bits_set =
(0 != _BitScanForward(&result, static_cast<uint32_t>(bitset_)));
if (!has_bits_set) {
return 0;
}
return result + 1;
} else {
bool has_bits_set =
(0 != _BitScanForward(&result, static_cast<uint32_t>(bitset_ >> 32)));
if (!has_bits_set) {
return 0; // lower and upper halves are both zero: no bit is set
}
return result + 33;
}
#else
return __builtin_ffsll(bitset_);
#endif
}
friend bool operator==(bitset lhs, bitset rhs) noexcept {
return lhs.bitset_ == rhs.bitset_;
}
bitset_type bitset_{0};
};
inline bool operator!=(bitset lhs, bitset rhs) noexcept {
return !(lhs == rhs);
}
} // namespace c10::utils
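
A short usage sketch:

#include <c10/util/Bitset.h>
#include <cstddef>
#include <iostream>

int main() {
  c10::utils::bitset b;
  b.set(0);
  b.set(5);
  b.set(62);
  // Prints 0, 5, 62: the functor receives each set bit's zero-based index.
  b.for_each_set_bit([](std::size_t i) { std::cout << i << '\n'; });
}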

View File

@@ -0,0 +1,142 @@
#pragma once
#ifndef C10_UTIL_CPP17_H_
#define C10_UTIL_CPP17_H_
#include <c10/macros/Macros.h>
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
#if !defined(__clang__) && !defined(_MSC_VER) && defined(__GNUC__) && \
__GNUC__ < 9
#error \
"You're trying to build PyTorch with a too old version of GCC. We need GCC 9 or later."
#endif
#if defined(__clang__) && __clang_major__ < 9
#error \
"You're trying to build PyTorch with a too old version of Clang. We need Clang 9 or later."
#endif
#if (defined(_MSC_VER) && (!defined(_MSVC_LANG) || _MSVC_LANG < 201703L)) || \
(!defined(_MSC_VER) && __cplusplus < 201703L)
#error You need C++17 to compile PyTorch
#endif
#if defined(_WIN32) && (defined(min) || defined(max))
#error Macro clash with min and max -- define NOMINMAX when compiling your program on Windows
#endif
/*
* This header adds some polyfills with C++17 functionality
*/
namespace c10 {
// std::is_pod is deprecated in C++20; std::is_standard_layout and
// std::is_trivial were introduced in C++11, and std::conjunction in C++17.
template <typename T>
using is_pod = std::conjunction<std::is_standard_layout<T>, std::is_trivial<T>>;
template <typename T>
constexpr bool is_pod_v = is_pod<T>::value;
namespace guts {
template <typename Base, typename Child, typename... Args>
std::enable_if_t<
!std::is_array_v<Base> && !std::is_array_v<Child> &&
std::is_base_of_v<Base, Child>,
std::unique_ptr<Base>>
make_unique_base(Args&&... args) {
return std::unique_ptr<Base>(new Child(std::forward<Args>(args)...));
}
#if defined(__cpp_lib_apply) && !defined(__CUDA_ARCH__) && !defined(__HIP__)
template <class F, class Tuple>
C10_HOST_DEVICE inline constexpr decltype(auto) apply(F&& f, Tuple&& t) {
return std::apply(std::forward<F>(f), std::forward<Tuple>(t));
}
#else
// Implementation from http://en.cppreference.com/w/cpp/utility/apply (but
// modified)
// TODO This is an incomplete implementation of std::apply, not working for
// member functions.
namespace detail {
template <class F, class Tuple, std::size_t... INDEX>
#if defined(_MSC_VER)
// MSVC has a problem with the decltype() return type, but it also doesn't need
// it
C10_HOST_DEVICE constexpr auto apply_impl(
F&& f,
Tuple&& t,
std::index_sequence<INDEX...>)
#else
// GCC/Clang need the decltype() return type
C10_HOST_DEVICE constexpr decltype(auto) apply_impl(
F&& f,
Tuple&& t,
std::index_sequence<INDEX...>)
#endif
{
return std::forward<F>(f)(std::get<INDEX>(std::forward<Tuple>(t))...);
}
} // namespace detail
template <class F, class Tuple>
C10_HOST_DEVICE constexpr decltype(auto) apply(F&& f, Tuple&& t) {
return detail::apply_impl(
std::forward<F>(f),
std::forward<Tuple>(t),
std::make_index_sequence<
std::tuple_size<std::remove_reference_t<Tuple>>::value>{});
}
#endif
template <typename Functor, typename... Args>
std::enable_if_t<
std::is_member_pointer_v<std::decay_t<Functor>>,
typename std::invoke_result_t<Functor, Args...>>
invoke(Functor&& f, Args&&... args) {
return std::mem_fn(std::forward<Functor>(f))(std::forward<Args>(args)...);
}
template <typename Functor, typename... Args>
std::enable_if_t<
!std::is_member_pointer_v<std::decay_t<Functor>>,
typename std::invoke_result_t<Functor, Args...>>
invoke(Functor&& f, Args&&... args) {
return std::forward<Functor>(f)(std::forward<Args>(args)...);
}
namespace detail {
struct _identity final {
template <class T>
using type_identity = T;
template <class T>
decltype(auto) operator()(T&& arg) {
return std::forward<T>(arg);
}
};
template <class Func, class Enable = void>
struct function_takes_identity_argument : std::false_type {};
template <class Func>
struct function_takes_identity_argument<
Func,
std::void_t<decltype(std::declval<Func>()(_identity()))>> : std::true_type {
};
} // namespace detail
} // namespace guts
} // namespace c10
#endif // C10_UTIL_CPP17_H_
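
A small sketch of the two polyfills in use (types and values illustrative):

#include <c10/util/C++17.h>
#include <tuple>

struct Adder {
  int base;
  int add(int x) const { return base + x; }
};

int main() {
  // guts::apply unpacks a tuple into a call, like std::apply.
  int s = c10::guts::apply(
      [](int a, int b) { return a + b; }, std::make_tuple(1, 2)); // 3
  // guts::invoke also handles member pointers, like std::invoke.
  Adder adder{10};
  int t = c10::guts::invoke(&Adder::add, adder, 5); // 15
  return s + t == 18 ? 0 : 1;
}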

View File

@@ -0,0 +1,67 @@
#pragma once
#include <atomic>
#include <mutex>
#include <utility>
#include <c10/macros/Macros.h>
#include <c10/util/C++17.h>
namespace c10 {
// custom c10 call_once implementation to avoid the deadlock in std::call_once.
// The implementation here is a simplified version of the one in folly and
// likely has a much higher memory footprint.
template <typename Flag, typename F, typename... Args>
inline void call_once(Flag& flag, F&& f, Args&&... args) {
if (C10_LIKELY(flag.test_once())) {
return;
}
flag.call_once_slow(std::forward<F>(f), std::forward<Args>(args)...);
}
class once_flag {
public:
#ifndef _WIN32
// Running into a build error on MSVC; can't seem to get a repro locally, so
// I'm just avoiding constexpr.
//
// C:/actions-runner/_work/pytorch/pytorch\c10/util/CallOnce.h(26): error:
// defaulted default constructor cannot be constexpr because the
// corresponding implicitly declared default constructor would not be
// constexpr 1 error detected in the compilation of
// "C:/actions-runner/_work/pytorch/pytorch/aten/src/ATen/cuda/cub.cu".
constexpr
#endif
once_flag() noexcept = default;
once_flag(const once_flag&) = delete;
once_flag& operator=(const once_flag&) = delete;
private:
template <typename Flag, typename F, typename... Args>
friend void call_once(Flag& flag, F&& f, Args&&... args);
template <typename F, typename... Args>
void call_once_slow(F&& f, Args&&... args) {
std::lock_guard<std::mutex> guard(mutex_);
if (init_.load(std::memory_order_relaxed)) {
return;
}
c10::guts::invoke(std::forward<F>(f), std::forward<Args>(args)...);
init_.store(true, std::memory_order_release);
}
bool test_once() {
return init_.load(std::memory_order_acquire);
}
void reset_once() {
init_.store(false, std::memory_order_release);
}
private:
std::mutex mutex_;
std::atomic<bool> init_{false};
};
} // namespace c10
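
Usage mirrors std::call_once (a minimal sketch; names illustrative):

#include <c10/util/CallOnce.h>
#include <iostream>

c10::once_flag init_flag;

void init_subsystem() {
  std::cout << "initialized exactly once\n";
}

void every_caller() {
  // Fast path is a single acquire load; the mutex is only taken on the
  // first call (or while the first call is still in flight).
  c10::call_once(init_flag, init_subsystem);
}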

View File

@@ -0,0 +1,130 @@
#pragma once
#include <c10/util/IdWrapper.h>
#include <c10/util/string_view.h>
#include <cstddef>
#include <cstdint>
namespace c10::util {
namespace detail {
// NOLINTNEXTLINE(*c-arrays*)
constexpr uint64_t crc64_table[] = {
0x0000000000000000, 0x7ad870c830358979, 0xf5b0e190606b12f2,
0x8f689158505e9b8b, 0xc038e5739841b68f, 0xbae095bba8743ff6,
0x358804e3f82aa47d, 0x4f50742bc81f2d04, 0xab28ecb46814fe75,
0xd1f09c7c5821770c, 0x5e980d24087fec87, 0x24407dec384a65fe,
0x6b1009c7f05548fa, 0x11c8790fc060c183, 0x9ea0e857903e5a08,
0xe478989fa00bd371, 0x7d08ff3b88be6f81, 0x07d08ff3b88be6f8,
0x88b81eabe8d57d73, 0xf2606e63d8e0f40a, 0xbd301a4810ffd90e,
0xc7e86a8020ca5077, 0x4880fbd87094cbfc, 0x32588b1040a14285,
0xd620138fe0aa91f4, 0xacf86347d09f188d, 0x2390f21f80c18306,
0x594882d7b0f40a7f, 0x1618f6fc78eb277b, 0x6cc0863448deae02,
0xe3a8176c18803589, 0x997067a428b5bcf0, 0xfa11fe77117cdf02,
0x80c98ebf2149567b, 0x0fa11fe77117cdf0, 0x75796f2f41224489,
0x3a291b04893d698d, 0x40f16bccb908e0f4, 0xcf99fa94e9567b7f,
0xb5418a5cd963f206, 0x513912c379682177, 0x2be1620b495da80e,
0xa489f35319033385, 0xde51839b2936bafc, 0x9101f7b0e12997f8,
0xebd98778d11c1e81, 0x64b116208142850a, 0x1e6966e8b1770c73,
0x8719014c99c2b083, 0xfdc17184a9f739fa, 0x72a9e0dcf9a9a271,
0x08719014c99c2b08, 0x4721e43f0183060c, 0x3df994f731b68f75,
0xb29105af61e814fe, 0xc849756751dd9d87, 0x2c31edf8f1d64ef6,
0x56e99d30c1e3c78f, 0xd9810c6891bd5c04, 0xa3597ca0a188d57d,
0xec09088b6997f879, 0x96d1784359a27100, 0x19b9e91b09fcea8b,
0x636199d339c963f2, 0xdf7adabd7a6e2d6f, 0xa5a2aa754a5ba416,
0x2aca3b2d1a053f9d, 0x50124be52a30b6e4, 0x1f423fcee22f9be0,
0x659a4f06d21a1299, 0xeaf2de5e82448912, 0x902aae96b271006b,
0x74523609127ad31a, 0x0e8a46c1224f5a63, 0x81e2d7997211c1e8,
0xfb3aa75142244891, 0xb46ad37a8a3b6595, 0xceb2a3b2ba0eecec,
0x41da32eaea507767, 0x3b024222da65fe1e, 0xa2722586f2d042ee,
0xd8aa554ec2e5cb97, 0x57c2c41692bb501c, 0x2d1ab4dea28ed965,
0x624ac0f56a91f461, 0x1892b03d5aa47d18, 0x97fa21650afae693,
0xed2251ad3acf6fea, 0x095ac9329ac4bc9b, 0x7382b9faaaf135e2,
0xfcea28a2faafae69, 0x8632586aca9a2710, 0xc9622c4102850a14,
0xb3ba5c8932b0836d, 0x3cd2cdd162ee18e6, 0x460abd1952db919f,
0x256b24ca6b12f26d, 0x5fb354025b277b14, 0xd0dbc55a0b79e09f,
0xaa03b5923b4c69e6, 0xe553c1b9f35344e2, 0x9f8bb171c366cd9b,
0x10e3202993385610, 0x6a3b50e1a30ddf69, 0x8e43c87e03060c18,
0xf49bb8b633338561, 0x7bf329ee636d1eea, 0x012b592653589793,
0x4e7b2d0d9b47ba97, 0x34a35dc5ab7233ee, 0xbbcbcc9dfb2ca865,
0xc113bc55cb19211c, 0x5863dbf1e3ac9dec, 0x22bbab39d3991495,
0xadd33a6183c78f1e, 0xd70b4aa9b3f20667, 0x985b3e827bed2b63,
0xe2834e4a4bd8a21a, 0x6debdf121b863991, 0x1733afda2bb3b0e8,
0xf34b37458bb86399, 0x8993478dbb8deae0, 0x06fbd6d5ebd3716b,
0x7c23a61ddbe6f812, 0x3373d23613f9d516, 0x49aba2fe23cc5c6f,
0xc6c333a67392c7e4, 0xbc1b436e43a74e9d, 0x95ac9329ac4bc9b5,
0xef74e3e19c7e40cc, 0x601c72b9cc20db47, 0x1ac40271fc15523e,
0x5594765a340a7f3a, 0x2f4c0692043ff643, 0xa02497ca54616dc8,
0xdafce7026454e4b1, 0x3e847f9dc45f37c0, 0x445c0f55f46abeb9,
0xcb349e0da4342532, 0xb1eceec59401ac4b, 0xfebc9aee5c1e814f,
0x8464ea266c2b0836, 0x0b0c7b7e3c7593bd, 0x71d40bb60c401ac4,
0xe8a46c1224f5a634, 0x927c1cda14c02f4d, 0x1d148d82449eb4c6,
0x67ccfd4a74ab3dbf, 0x289c8961bcb410bb, 0x5244f9a98c8199c2,
0xdd2c68f1dcdf0249, 0xa7f41839ecea8b30, 0x438c80a64ce15841,
0x3954f06e7cd4d138, 0xb63c61362c8a4ab3, 0xcce411fe1cbfc3ca,
0x83b465d5d4a0eece, 0xf96c151de49567b7, 0x76048445b4cbfc3c,
0x0cdcf48d84fe7545, 0x6fbd6d5ebd3716b7, 0x15651d968d029fce,
0x9a0d8ccedd5c0445, 0xe0d5fc06ed698d3c, 0xaf85882d2576a038,
0xd55df8e515432941, 0x5a3569bd451db2ca, 0x20ed197575283bb3,
0xc49581ead523e8c2, 0xbe4df122e51661bb, 0x3125607ab548fa30,
0x4bfd10b2857d7349, 0x04ad64994d625e4d, 0x7e7514517d57d734,
0xf11d85092d094cbf, 0x8bc5f5c11d3cc5c6, 0x12b5926535897936,
0x686de2ad05bcf04f, 0xe70573f555e26bc4, 0x9ddd033d65d7e2bd,
0xd28d7716adc8cfb9, 0xa85507de9dfd46c0, 0x273d9686cda3dd4b,
0x5de5e64efd965432, 0xb99d7ed15d9d8743, 0xc3450e196da80e3a,
0x4c2d9f413df695b1, 0x36f5ef890dc31cc8, 0x79a59ba2c5dc31cc,
0x037deb6af5e9b8b5, 0x8c157a32a5b7233e, 0xf6cd0afa9582aa47,
0x4ad64994d625e4da, 0x300e395ce6106da3, 0xbf66a804b64ef628,
0xc5bed8cc867b7f51, 0x8aeeace74e645255, 0xf036dc2f7e51db2c,
0x7f5e4d772e0f40a7, 0x05863dbf1e3ac9de, 0xe1fea520be311aaf,
0x9b26d5e88e0493d6, 0x144e44b0de5a085d, 0x6e963478ee6f8124,
0x21c640532670ac20, 0x5b1e309b16452559, 0xd476a1c3461bbed2,
0xaeaed10b762e37ab, 0x37deb6af5e9b8b5b, 0x4d06c6676eae0222,
0xc26e573f3ef099a9, 0xb8b627f70ec510d0, 0xf7e653dcc6da3dd4,
0x8d3e2314f6efb4ad, 0x0256b24ca6b12f26, 0x788ec2849684a65f,
0x9cf65a1b368f752e, 0xe62e2ad306bafc57, 0x6946bb8b56e467dc,
0x139ecb4366d1eea5, 0x5ccebf68aecec3a1, 0x2616cfa09efb4ad8,
0xa97e5ef8cea5d153, 0xd3a62e30fe90582a, 0xb0c7b7e3c7593bd8,
0xca1fc72bf76cb2a1, 0x45775673a732292a, 0x3faf26bb9707a053,
0x70ff52905f188d57, 0x0a2722586f2d042e, 0x854fb3003f739fa5,
0xff97c3c80f4616dc, 0x1bef5b57af4dc5ad, 0x61372b9f9f784cd4,
0xee5fbac7cf26d75f, 0x9487ca0fff135e26, 0xdbd7be24370c7322,
0xa10fceec0739fa5b, 0x2e675fb4576761d0, 0x54bf2f7c6752e8a9,
0xcdcf48d84fe75459, 0xb71738107fd2dd20, 0x387fa9482f8c46ab,
0x42a7d9801fb9cfd2, 0x0df7adabd7a6e2d6, 0x772fdd63e7936baf,
0xf8474c3bb7cdf024, 0x829f3cf387f8795d, 0x66e7a46c27f3aa2c,
0x1c3fd4a417c62355, 0x935745fc4798b8de, 0xe98f353477ad31a7,
0xa6df411fbfb21ca3, 0xdc0731d78f8795da, 0x536fa08fdfd90e51,
0x29b7d047efec8728,
};
inline C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA uint64_t
crc64impl(uint64_t accumulator, const char* data, size_t size) {
for (size_t i = 0; i < size; ++i) {
accumulator =
crc64_table[(accumulator ^ data[i]) & 0xFF] ^ (accumulator >> 8);
}
return accumulator;
}
} // namespace detail
struct crc64_t final : IdWrapper<crc64_t, uint64_t> {
constexpr crc64_t(uint64_t checksum) : IdWrapper(checksum) {}
constexpr uint64_t checksum() const {
return this->underlyingId();
}
};
// CRC64 with Jones coefficients and an init value of 0.
inline C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA crc64_t
crc64(const char* str, size_t size) {
return crc64_t{detail::crc64impl(0, str, size)};
}
inline C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA crc64_t crc64(c10::string_view str) {
return crc64(str.data(), str.size());
}
} // namespace c10::util
// Allow usage of crc64_t in std::unordered_set
C10_DEFINE_HASH_FOR_IDWRAPPER(c10::util::crc64_t);
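// Sample usage (illustrative sketch, not part of the original header; the
// key string is hypothetical):
//
//   using c10::util::crc64;
//   using c10::util::crc64_t;
//   constexpr crc64_t kKernelId = crc64("conv2d", 6); // compile-time checksum
//   std::unordered_set<crc64_t> seen; // hashing enabled by the macro above
//   seen.insert(kKernelId);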

View File

@ -0,0 +1,48 @@
#pragma once
#include <c10/macros/Export.h>
#include <c10/util/Exception.h>
/// This file provides some simple utilities for detecting common deadlocks in
/// PyTorch. For now, we focus exclusively on detecting Python GIL deadlocks,
/// as the GIL is a wide-ranging lock that is taken out in many situations.
/// The basic strategy is before performing an operation that may block, you
/// can use TORCH_ASSERT_NO_GIL_WITHOUT_PYTHON_DEP() to assert that the GIL is
/// not held. This macro is to be used in contexts where no static dependency
/// on Python is available (we will handle indirecting a virtual call for you).
///
/// If the GIL is held by a torchdeploy interpreter, we always report false.
/// If you are in a context where Python bindings are available, it's better
/// to directly assert on PyGILState_Check (as it avoids a vcall and also
/// works correctly with torchdeploy).
#define TORCH_ASSERT_NO_GIL_WITHOUT_PYTHON_DEP() \
TORCH_INTERNAL_ASSERT( \
!c10::impl::check_python_gil(), \
"Holding GIL before a blocking operation! Please release the GIL before blocking, or see https://github.com/pytorch/pytorch/issues/56297 for how to release the GIL for destructors of objects")
namespace c10::impl {
C10_API bool check_python_gil();
struct C10_API PythonGILHooks {
virtual ~PythonGILHooks() = default;
// Returns true if we hold the GIL. If not linked against Python we
// always return false.
virtual bool check_python_gil() const = 0;
};
C10_API void SetPythonGILHooks(PythonGILHooks* factory);
// DO NOT call this registerer from a torch deploy instance! You will clobber
// other registrations.
struct C10_API PythonGILHooksRegisterer {
explicit PythonGILHooksRegisterer(PythonGILHooks* factory) {
SetPythonGILHooks(factory);
}
~PythonGILHooksRegisterer() {
SetPythonGILHooks(nullptr);
}
};
} // namespace c10::impl
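// Sample usage (illustrative sketch; wait_for_work and its arguments are
// hypothetical):
//
//   void wait_for_work(std::condition_variable& cv, std::mutex& m) {
//     // Fail fast if this thread is about to block while holding the GIL.
//     TORCH_ASSERT_NO_GIL_WITHOUT_PYTHON_DEP();
//     std::unique_lock<std::mutex> lock(m);
//     cv.wait(lock);
//   }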

View File

@ -0,0 +1,102 @@
#pragma once
/**
* This file provides portable macros for marking declarations
* as deprecated. You should generally use C10_DEPRECATED,
* except when marking 'using' declarations as deprecated,
* in which case you should use C10_DEFINE_DEPRECATED_USING
* (due to portability concerns).
*/
// Sample usage:
//
// C10_DEPRECATED void bad_func();
// struct C10_DEPRECATED BadStruct {
// ...
// };
// NB: __cplusplus doesn't work for MSVC, so for now MSVC always uses
// the "__declspec(deprecated)" implementation and not the C++14
// "[[deprecated]]" attribute. We tried enabling "[[deprecated]]" for C++14 on
// MSVC, but ran into issues with some older MSVC versions.
#if (defined(__cplusplus) && __cplusplus >= 201402L)
#define C10_DEPRECATED [[deprecated]]
#define C10_DEPRECATED_MESSAGE(message) [[deprecated(message)]]
#elif defined(__GNUC__)
#define C10_DEPRECATED __attribute__((deprecated))
// TODO Is there some way to implement this?
#define C10_DEPRECATED_MESSAGE(message) __attribute__((deprecated))
#elif defined(_MSC_VER)
#define C10_DEPRECATED __declspec(deprecated)
#define C10_DEPRECATED_MESSAGE(message) __declspec(deprecated(message))
#else
#warning "You need to implement C10_DEPRECATED for this compiler"
#define C10_DEPRECATED
#endif
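// Sample usage with a custom message (illustrative; old_func/new_func are
// hypothetical):
//
// C10_DEPRECATED_MESSAGE("old_func() is deprecated, use new_func() instead")
// void old_func();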
// Sample usage:
//
// C10_DEFINE_DEPRECATED_USING(BadType, int)
//
// which is the portable version of
//
// using BadType [[deprecated]] = int;
// Technically, the [[deprecated]] syntax is from the C++14 standard, but it
// works in many compilers.
#if defined(__has_cpp_attribute)
#if __has_cpp_attribute(deprecated) && !defined(__CUDACC__)
#define C10_DEFINE_DEPRECATED_USING(TypeName, TypeThingy) \
using TypeName [[deprecated]] = TypeThingy;
#endif
#endif
#if defined(_MSC_VER)
#if defined(__CUDACC__)
// neither [[deprecated]] nor __declspec(deprecated) work on nvcc on Windows;
// you get the error:
//
// error: attribute does not apply to any entity
//
// So we just turn the macro off in this case.
#if defined(C10_DEFINE_DEPRECATED_USING)
#undef C10_DEFINE_DEPRECATED_USING
#endif
#define C10_DEFINE_DEPRECATED_USING(TypeName, TypeThingy) \
using TypeName = TypeThingy;
#else
// [[deprecated]] does work on Windows without nvcc, but MSVC doesn't support
// `__has_cpp_attribute`, so the check above didn't fire. We use [[deprecated]]
// when C++14 is supported; otherwise __declspec(deprecated) is used as the
// alternative.
#ifndef C10_DEFINE_DEPRECATED_USING
#if defined(_MSVC_LANG) && _MSVC_LANG >= 201402L
#define C10_DEFINE_DEPRECATED_USING(TypeName, TypeThingy) \
using TypeName [[deprecated]] = TypeThingy;
#else
#define C10_DEFINE_DEPRECATED_USING(TypeName, TypeThingy) \
using TypeName = __declspec(deprecated) TypeThingy;
#endif
#endif
#endif
#endif
#if !defined(C10_DEFINE_DEPRECATED_USING) && defined(__GNUC__)
// nvcc has a bug where it doesn't understand __attribute__((deprecated))
// declarations even when the host compiler supports it. We'll only use this
// GCC attribute when not compiling with CUDA, and when using a GCC compiler
// that doesn't support the C++14 syntax we checked for above (available in
// __GNUC__ >= 5).
#if !defined(__CUDACC__)
#define C10_DEFINE_DEPRECATED_USING(TypeName, TypeThingy) \
using TypeName __attribute__((deprecated)) = TypeThingy;
#else
// With CUDA + GCC < 5, neither deprecation syntax is available, so turn the
// macro off.
#define C10_DEFINE_DEPRECATED_USING(TypeName, TypeThingy) \
using TypeName = TypeThingy;
#endif
#endif
#if !defined(C10_DEFINE_DEPRECATED_USING)
#warning "You need to implement C10_DEFINE_DEPRECATED_USING for this compiler"
#define C10_DEFINE_DEPRECATED_USING
#endif

View File

@ -0,0 +1,17 @@
#pragma once
#include <c10/core/SymInt.h>
#include <c10/core/impl/SizesAndStrides.h>
#include <c10/util/SmallVector.h>
#include <cstddef>
#include <cstdint>
namespace c10 {
constexpr size_t kDimVectorStaticSize = C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE;
/// A container for sizes or strides
using DimVector = SmallVector<int64_t, kDimVectorStaticSize>;
using SymDimVector = SmallVector<c10::SymInt, kDimVectorStaticSize>;
} // namespace c10
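// Sample usage (illustrative sketch, not part of the original header):
//
//   c10::DimVector sizes{2, 3, 4};
//   sizes.push_back(5); // stays on the inline buffer while the rank fits it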

View File

@ -0,0 +1,49 @@
#pragma once
#include <functional>
#include <memory>
#include <string_view>
#include <c10/macros/Macros.h>
namespace c10::monitor {
class C10_API DynamicCounter {
public:
using Callback = std::function<int64_t()>;
// Creates a dynamic counter that can be queried at any point in time by
// multiple backends. Only one counter with a given key can exist at any point
// in time.
//
// The callback is invoked every time the counter is queried.
// The callback must be thread-safe.
// The callback must not throw.
// The callback must not block.
DynamicCounter(std::string_view key, Callback getCounterCallback);
// Unregisters the callback.
// Waits for all ongoing callback invocations to finish.
~DynamicCounter();
private:
struct Guard;
std::unique_ptr<Guard> guard_;
};
namespace detail {
class DynamicCounterBackendIf {
public:
virtual ~DynamicCounterBackendIf() = default;
virtual void registerCounter(
std::string_view key,
DynamicCounter::Callback getCounterCallback) = 0;
// MUST wait for all ongoing callback invocations to finish
virtual void unregisterCounter(std::string_view key) = 0;
};
void C10_API
registerDynamicCounterBackend(std::unique_ptr<DynamicCounterBackendIf>);
} // namespace detail
} // namespace c10::monitor
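// Sample usage (illustrative sketch; the key and the counter variable are
// hypothetical):
//
//   std::atomic<int64_t> queue_depth{0};
//   // The callback below is thread-safe, non-throwing and non-blocking, as
//   // required by the contract above.
//   c10::monitor::DynamicCounter counter(
//       "my_queue_depth", [&queue_depth]() { return queue_depth.load(); });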

View File

@ -0,0 +1,714 @@
#ifndef C10_UTIL_EXCEPTION_H_
#define C10_UTIL_EXCEPTION_H_
#include <c10/macros/Export.h>
#include <c10/macros/Macros.h>
#include <c10/util/Backtrace.h>
#include <c10/util/Lazy.h>
#include <c10/util/StringUtil.h>
#include <cstdint>
#include <exception>
#include <memory>
#include <string>
#include <variant>
#include <vector>
#if defined(_MSC_VER) && _MSC_VER <= 1900
#define __func__ __FUNCTION__
#endif
namespace c10 {
/// The primary ATen error class.
/// Provides a complete error message with source location information via
/// `what()`, and a more concise message via `what_without_backtrace()`.
/// Don't throw this directly; use TORCH_CHECK/TORCH_INTERNAL_ASSERT instead.
///
/// NB: c10::Error is handled specially by the default torch error handler to
/// suppress the backtrace; see torch/csrc/Exceptions.h
class C10_API Error : public std::exception {
private:
// The actual error message.
std::string msg_;
// Context for the message (in order of decreasing specificity). Context will
// be automatically formatted appropriately, so it is not necessary to add
// extra leading/trailing newlines to strings inside this vector.
std::vector<std::string> context_;
// The C++ backtrace at the point when this exception was raised. This
// may be empty if there is no valid backtrace. (We don't use optional
// here to reduce the dependencies this file has.)
Backtrace backtrace_;
// These two are fields derived from msg_ / context_ and backtrace_, but we
// need fields for the strings so that we can return a const char* (as the
// signature of std::exception requires). Currently, the invariant
// is that these fields are ALWAYS populated consistently with respect
// to msg_ / context_ and backtrace_.
mutable OptimisticLazy<std::string> what_;
std::string what_without_backtrace_;
// This is a little debugging trick: you can stash a relevant pointer
// in caller, and then when you catch the exception, you can compare
// against pointers you have on hand to get more information about
// where the exception came from. In Caffe2, this is used to figure
// out which operator raised an exception.
const void* caller_;
public:
// PyTorch-style Error constructor. NB: the implementation of this
// is actually in Logging.cpp
Error(SourceLocation source_location, std::string msg);
// Caffe2-style error message
Error(
const char* file,
const uint32_t line,
const char* condition,
const std::string& msg,
Backtrace backtrace,
const void* caller = nullptr);
// Base constructor
Error(
std::string msg,
Backtrace backtrace = nullptr,
const void* caller = nullptr);
// Add some new context to the message stack. The last added context
// will be formatted at the end of the context list upon printing.
// WARNING: This method is O(n) in the size of the stack, so don't go
// wild adding a ridiculous amount of context to error messages.
void add_context(std::string msg);
const std::string& msg() const {
return msg_;
}
const std::vector<std::string>& context() const {
return context_;
}
const Backtrace& backtrace() const;
/// Returns the complete error message, including the source location.
/// The returned pointer is invalidated if you call add_context() on
/// this object.
const char* what() const noexcept override;
const void* caller() const noexcept {
return caller_;
}
/// Returns only the error message string, without source location.
/// The returned pointer is invalidated if you call add_context() on
/// this object.
virtual const char* what_without_backtrace() const noexcept {
return what_without_backtrace_.c_str();
}
private:
void refresh_what();
std::string compute_what(bool include_backtrace) const;
};
class C10_API Warning {
public:
class C10_API UserWarning {};
class C10_API DeprecationWarning {};
using warning_variant_t = std::variant<UserWarning, DeprecationWarning>;
Warning(
warning_variant_t type,
const SourceLocation& source_location,
std::string msg,
bool verbatim);
Warning(
warning_variant_t type,
SourceLocation source_location,
const char* msg,
bool verbatim);
Warning(
warning_variant_t type,
SourceLocation source_location,
::c10::detail::CompileTimeEmptyString msg,
bool verbatim);
// Getters for members
warning_variant_t type() const;
const SourceLocation& source_location() const;
const std::string& msg() const;
bool verbatim() const;
private:
// The type of warning
warning_variant_t type_;
// Where the warning happened.
SourceLocation source_location_;
// The actual warning message.
std::string msg_;
// See note: [Verbatim Warnings]
bool verbatim_;
};
using UserWarning = Warning::UserWarning;
using DeprecationWarning = Warning::DeprecationWarning;
// Issue a warning with a given message. Dispatched to the current
// warning handler.
void C10_API warn(const Warning& warning);
class C10_API WarningHandler {
public:
virtual ~WarningHandler() = default;
/// The default warning handler. Prints the message to stderr.
virtual void process(const Warning& warning);
};
namespace WarningUtils {
// Note: [Verbatim Warnings]
// Warnings originating in C++ code can appear out-of-place to Python users:
// a user runs a line in Python, but the warning references a line in C++.
// Some parts of PyTorch, like the JIT, are cognizant of this mismatch
// and take care to map warnings back to the user's program, but most
// of PyTorch simply throws a context-free warning. To allow warning
// handlers to add context where appropriate, warn takes the
// "verbatim" flag. When this is false a warning handler might append
// the C++ warning to a Python warning message that relates the warning
// back to the user's program. Callers who have already accounted for
// context in their warnings should set verbatim to true so their warnings
// appear without modification.
/// Sets the global warning handler. This is not thread-safe, so it should
/// generally be called once during initialization or while holding the GIL
/// for programs that use python.
/// User is responsible for keeping the WarningHandler alive until
/// it is not needed.
C10_API void set_warning_handler(WarningHandler* handler) noexcept(true);
/// Gets the global warning handler.
C10_API WarningHandler* get_warning_handler() noexcept(true);
class C10_API WarningHandlerGuard {
WarningHandler* prev_handler_;
public:
WarningHandlerGuard(WarningHandler* new_handler)
: prev_handler_(c10::WarningUtils::get_warning_handler()) {
c10::WarningUtils::set_warning_handler(new_handler);
}
~WarningHandlerGuard() {
c10::WarningUtils::set_warning_handler(prev_handler_);
}
};
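// Sample usage (illustrative sketch; CountingHandler is hypothetical):
//
//   struct CountingHandler final : c10::WarningHandler {
//     int count = 0;
//     void process(const c10::Warning&) override { ++count; }
//   };
//   CountingHandler handler;
//   {
//     c10::WarningUtils::WarningHandlerGuard guard(&handler);
//     TORCH_WARN("routed to CountingHandler while the guard is alive");
//   }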
/// The TORCH_WARN_ONCE macro is difficult to test for. Use
/// setWarnAlways(true) to turn it into TORCH_WARN, which can be
/// tested for more easily.
C10_API void set_warnAlways(bool) noexcept(true);
C10_API bool get_warnAlways() noexcept(true);
// A RAII guard that sets warn_always (not thread-local) on
// construction, and sets it back to the original value upon destruction.
struct C10_API WarnAlways {
public:
explicit WarnAlways(bool setting = true);
~WarnAlways();
private:
bool prev_setting;
};
} // namespace WarningUtils
// Like Error, but we always report the C++ backtrace, instead of only
// reporting when TORCH_SHOW_CPP_STACKTRACES
class C10_API ErrorAlwaysShowCppStacktrace : public Error {
using Error::Error;
const char* what_without_backtrace() const noexcept override {
return what();
}
};
// Used in ATen for out-of-bound indices that can reasonably only be detected
// lazily inside a kernel (See: advanced indexing). These turn into
// IndexError when they cross to Python.
class C10_API IndexError : public Error {
using Error::Error;
};
// Used in ATen for invalid values. These turn into
// ValueError when they cross to Python.
class C10_API ValueError : public Error {
using Error::Error;
};
// Used in ATen for invalid types. These turn into
// TypeError when they cross to Python.
class C10_API TypeError : public Error {
using Error::Error;
};
// Used in ATen for functionality that is not implemented. These turn into
// NotImplementedError when they cross to Python.
class C10_API NotImplementedError : public Error {
using Error::Error;
};
// Used in ATen for non-finite values. These turn into
// ExitException when they cross to Python.
class C10_API EnforceFiniteError : public Error {
using Error::Error;
};
// Used in Onnxifi backend lowering. These turn into
// ExitException when they cross to Python.
class C10_API OnnxfiBackendSystemError : public Error {
using Error::Error;
};
// Used for numerical errors from the linalg module. These
// turn into LinAlgError when they cross into Python.
class C10_API LinAlgError : public Error {
using Error::Error;
};
class C10_API OutOfMemoryError : public Error {
using Error::Error;
};
// Base error type for all distributed errors.
// These turn into DistError when they cross into Python.
class C10_API DistError : public Error {
using Error::Error;
};
// Used for collective communication library errors from the distributed module.
// These turn into DistBackendError when they cross into Python.
class C10_API DistBackendError : public DistError {
using DistError::DistError;
};
// Used for errors originating from the store.
// These turn into DistStoreError when they cross into Python.
class C10_API DistStoreError : public DistError {
using DistError::DistError;
};
// Used for errors originating from the TCP/IP stack and not from collective
// libraries. These turn into DistNetworkError when they cross into Python.
class C10_API DistNetworkError : public DistError {
using DistError::DistError;
};
// A utility function that returns an exception's description as a
// std::string, prepending the exception type to its what() content
C10_API std::string GetExceptionString(const std::exception& e);
} // namespace c10
// Private helper macro for implementing TORCH_INTERNAL_ASSERT and TORCH_CHECK
//
// Note: In debug builds with MSVC, __LINE__ might be of long type (a.k.a.
// int32_t), which is different from the definition of `SourceLocation`, which
// requires an unsigned int (a.k.a. uint32_t), and may cause a compile error
// with the message: error C2397: conversion from 'long' to 'uint32_t' requires
// a narrowing conversion. The static cast here is used to pass the build. If
// this is used inside a lambda, the __func__ macro expands to operator(),
// which isn't very useful, but is hard to fix in a macro, so we suppress the
// warning.
#define C10_THROW_ERROR(err_type, msg) \
throw ::c10::err_type( \
{__func__, __FILE__, static_cast<uint32_t>(__LINE__)}, msg)
#define C10_BUILD_ERROR(err_type, msg) \
::c10::err_type({__func__, __FILE__, static_cast<uint32_t>(__LINE__)}, msg)
// Private helper macro for workaround MSVC misexpansion of nested macro
// invocations involving __VA_ARGS__. See
// https://stackoverflow.com/questions/5134523/msvc-doesnt-expand-va-args-correctly
#define C10_EXPAND_MSVC_WORKAROUND(x) x
// On nvcc, C10_UNLIKELY thwarts missing return statement analysis. In cases
// where the unlikely expression may be a constant, use this macro to ensure
// return statement analysis keeps working (at the cost of not getting the
// likely/unlikely annotation on nvcc).
// https://github.com/pytorch/pytorch/issues/21418
//
// Currently, this is only used in the error reporting macros below. If you
// want to use it more generally, move me to Macros.h
//
// TODO: Brian Vaughan observed that we might be able to get this to work on
// nvcc by writing some sort of C++ overload that distinguishes constexpr inputs
// from non-constexpr. Since there isn't any evidence that losing C10_UNLIKELY
// in nvcc is causing us perf problems, this is not yet implemented, but this
// might be an interesting piece of C++ code for an intrepid bootcamper to
// write.
#if defined(__CUDACC__)
#define C10_UNLIKELY_OR_CONST(e) e
#else
#define C10_UNLIKELY_OR_CONST(e) C10_UNLIKELY(e)
#endif
// ----------------------------------------------------------------------------
// Error reporting macros
// ----------------------------------------------------------------------------
#ifdef STRIP_ERROR_MESSAGES
#define TORCH_RETHROW(e, ...) throw
#else
#define TORCH_RETHROW(e, ...) \
do { \
e.add_context(::c10::str(__VA_ARGS__)); \
throw; \
} while (false)
#endif
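// Sample usage (illustrative; run_kernel and op_name are hypothetical):
//
//   try {
//     run_kernel();
//   } catch (c10::Error& e) {
//     TORCH_RETHROW(e, "while running kernel for op ", op_name);
//   }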
// A utility macro to provide assert()-like functionality, that is, enforcement
// of internal invariants in code. It supports an arbitrary number of extra
// arguments (evaluated only on failure), which will be printed in the assert
// failure message using operator<< (this is handy for printing variables
// that may help with debugging).
//
// Usage:
// TORCH_INTERNAL_ASSERT(should_be_true);
// TORCH_INTERNAL_ASSERT(x == 0, "x = ", x);
//
// Assuming no bugs in PyTorch, the conditions tested by this macro should
// always be true; e.g., it should be possible to disable all of these
// conditions without changing observable user behavior. If you would like to
// do error reporting for user input, please use TORCH_CHECK instead.
//
// NOTE: It is SAFE to use this macro in production code; on failure, this
// simply raises an exception, it does NOT unceremoniously quit the process
// (unlike assert()).
//
#ifdef STRIP_ERROR_MESSAGES
#define TORCH_INTERNAL_ASSERT(cond, ...) \
if (C10_UNLIKELY_OR_CONST(!(cond))) { \
::c10::detail::torchCheckFail( \
__func__, \
__FILE__, \
static_cast<uint32_t>(__LINE__), \
#cond " INTERNAL ASSERT FAILED at " C10_STRINGIZE(__FILE__)); \
}
#else
// It would be nice if we could build a combined string literal out of
// the TORCH_INTERNAL_ASSERT prefix and a user-provided string literal
// as the first argument, but there doesn't seem to be any good way to
// do that while still supporting having a first argument that isn't a
// string literal.
#define TORCH_INTERNAL_ASSERT(cond, ...) \
if (C10_UNLIKELY_OR_CONST(!(cond))) { \
::c10::detail::torchInternalAssertFail( \
__func__, \
__FILE__, \
static_cast<uint32_t>(__LINE__), \
#cond \
" INTERNAL ASSERT FAILED at " C10_STRINGIZE(__FILE__) ":" C10_STRINGIZE( \
__LINE__) ", please report a bug to PyTorch. ", \
c10::str(__VA_ARGS__)); \
}
#endif
// A utility macro to make it easier to test for error conditions from user
// input. Like TORCH_INTERNAL_ASSERT, it supports an arbitrary number of extra
// arguments (evaluated only on failure), which will be printed in the error
// message using operator<< (e.g., you can pass any object which has
// operator<< defined. Most objects in PyTorch have these definitions!)
//
// Usage:
// TORCH_CHECK(should_be_true); // A default error message will be provided
// // in this case; but we recommend writing an
// // explicit error message, as it is more
// // user friendly.
// TORCH_CHECK(x == 0, "Expected x to be 0, but got ", x);
//
// On failure, this macro will raise an exception. If this exception propagates
// to Python, it will convert into a Python RuntimeError.
//
// NOTE: It is SAFE to use this macro in production code; on failure, this
// simply raises an exception, it does NOT unceremoniously quit the process
// (unlike CHECK() from glog.)
//
#define TORCH_CHECK_WITH(error_t, cond, ...) \
TORCH_CHECK_WITH_MSG(error_t, cond, "", __VA_ARGS__)
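// Sample usage of TORCH_CHECK_WITH (illustrative; scale is hypothetical):
//
//   TORCH_CHECK_WITH(
//       ValueError, scale > 0, "scale must be positive, got ", scale);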
#ifdef STRIP_ERROR_MESSAGES
#define TORCH_CHECK_MSG(cond, type, ...) \
(#cond #type " CHECK FAILED at " C10_STRINGIZE(__FILE__))
#define TORCH_CHECK_WITH_MSG(error_t, cond, type, ...) \
if (C10_UNLIKELY_OR_CONST(!(cond))) { \
C10_THROW_ERROR(Error, TORCH_CHECK_MSG(cond, type, __VA_ARGS__)); \
}
#else
namespace c10::detail {
template <typename... Args>
decltype(auto) torchCheckMsgImpl(const char* /*msg*/, const Args&... args) {
return ::c10::str(args...);
}
inline C10_API const char* torchCheckMsgImpl(const char* msg) {
return msg;
}
// If there is just 1 user-provided C-string argument, use it.
inline C10_API const char* torchCheckMsgImpl(
const char* /*msg*/,
const char* args) {
return args;
}
} // namespace c10::detail
#define TORCH_CHECK_MSG(cond, type, ...) \
(::c10::detail::torchCheckMsgImpl( \
"Expected " #cond \
" to be true, but got false. " \
"(Could this error message be improved? If so, " \
"please report an enhancement request to PyTorch.)", \
##__VA_ARGS__))
#define TORCH_CHECK_WITH_MSG(error_t, cond, type, ...) \
if (C10_UNLIKELY_OR_CONST(!(cond))) { \
C10_THROW_ERROR(error_t, TORCH_CHECK_MSG(cond, type, __VA_ARGS__)); \
}
#endif
namespace c10::detail {
[[noreturn]] C10_API void torchCheckFail(
const char* func,
const char* file,
uint32_t line,
const std::string& msg);
[[noreturn]] C10_API void torchCheckFail(
const char* func,
const char* file,
uint32_t line,
const char* msg);
// The c10::str() call that creates userMsg can have 1 of 3 return
// types depending on the number and types of arguments passed to
// TORCH_INTERNAL_ASSERT. 0 arguments will get a
// CompileTimeEmptyString, 1 const char * will be passed straight
// through, and anything else will get converted to std::string.
[[noreturn]] C10_API void torchInternalAssertFail(
const char* func,
const char* file,
uint32_t line,
const char* condMsg,
const char* userMsg);
[[noreturn]] inline C10_API void torchInternalAssertFail(
const char* func,
const char* file,
uint32_t line,
const char* condMsg,
::c10::detail::CompileTimeEmptyString /*userMsg*/) {
torchCheckFail(func, file, line, condMsg);
}
[[noreturn]] C10_API void torchInternalAssertFail(
const char* func,
const char* file,
uint32_t line,
const char* condMsg,
const std::string& userMsg);
} // namespace c10::detail
#ifdef STRIP_ERROR_MESSAGES
#define TORCH_CHECK(cond, ...) \
if (C10_UNLIKELY_OR_CONST(!(cond))) { \
::c10::detail::torchCheckFail( \
__func__, \
__FILE__, \
static_cast<uint32_t>(__LINE__), \
TORCH_CHECK_MSG(cond, "", __VA_ARGS__)); \
}
#else
#define TORCH_CHECK(cond, ...) \
if (C10_UNLIKELY_OR_CONST(!(cond))) { \
::c10::detail::torchCheckFail( \
__func__, \
__FILE__, \
static_cast<uint32_t>(__LINE__), \
TORCH_CHECK_MSG(cond, "", ##__VA_ARGS__)); \
}
#endif
// A utility macro that does what `TORCH_CHECK` does when compiled in host
// code, and otherwise does nothing. Intended for use in code shared between
// host and device as an alternative to `TORCH_CHECK`.
#if defined(__CUDACC__) || defined(__HIPCC__)
#define TORCH_CHECK_IF_NOT_ON_CUDA(cond, ...)
#else
#define TORCH_CHECK_IF_NOT_ON_CUDA(cond, ...) TORCH_CHECK(cond, ##__VA_ARGS__)
#endif
// Debug-only version of TORCH_INTERNAL_ASSERT. This macro only checks in debug
// builds, and does nothing in release builds. It is appropriate to use in
// situations where you want to add an assert to a hot path, but it is too
// expensive to run the assert in production builds.
#ifdef NDEBUG
// Optimized version - generates no code.
#define TORCH_INTERNAL_ASSERT_DEBUG_ONLY(...) \
while (false) \
C10_EXPAND_MSVC_WORKAROUND(TORCH_INTERNAL_ASSERT(__VA_ARGS__))
#else
#define TORCH_INTERNAL_ASSERT_DEBUG_ONLY(...) \
C10_EXPAND_MSVC_WORKAROUND(TORCH_INTERNAL_ASSERT(__VA_ARGS__))
#endif
// TODO: We're going to get a lot of similar looking string literals
// this way; check if this actually affects binary size.
// Like TORCH_CHECK, but raises LinAlgError instead of Error.
#define TORCH_CHECK_LINALG(cond, ...) \
TORCH_CHECK_WITH_MSG(LinAlgError, cond, "LINALG", __VA_ARGS__)
// Like TORCH_CHECK, but raises IndexErrors instead of Errors.
#define TORCH_CHECK_INDEX(cond, ...) \
TORCH_CHECK_WITH_MSG(IndexError, cond, "INDEX", __VA_ARGS__)
// Like TORCH_CHECK, but raises ValueErrors instead of Errors.
#define TORCH_CHECK_VALUE(cond, ...) \
TORCH_CHECK_WITH_MSG(ValueError, cond, "VALUE", __VA_ARGS__)
// Like TORCH_CHECK, but raises TypeErrors instead of Errors.
#define TORCH_CHECK_TYPE(cond, ...) \
TORCH_CHECK_WITH_MSG(TypeError, cond, "TYPE", __VA_ARGS__)
// Like TORCH_CHECK, but raises NotImplementedErrors instead of Errors.
#define TORCH_CHECK_NOT_IMPLEMENTED(cond, ...) \
TORCH_CHECK_WITH_MSG(NotImplementedError, cond, "TYPE", __VA_ARGS__)
#define TORCH_CHECK_ALWAYS_SHOW_CPP_STACKTRACE(cond, ...) \
TORCH_CHECK_WITH_MSG( \
ErrorAlwaysShowCppStacktrace, cond, "TYPE", ##__VA_ARGS__)
#ifdef STRIP_ERROR_MESSAGES
#define WARNING_MESSAGE_STRING(...) \
::c10::detail::CompileTimeEmptyString {}
#else
#define WARNING_MESSAGE_STRING(...) ::c10::str(__VA_ARGS__)
#endif
// Report a warning to the user. Accepts an arbitrary number of extra
// arguments which are concatenated into the warning message using operator<<
//
#ifdef DISABLE_WARN
#define _TORCH_WARN_WITH(...) ((void)0);
#else
#define _TORCH_WARN_WITH(warning_t, ...) \
::c10::warn(::c10::Warning( \
warning_t(), \
{__func__, __FILE__, static_cast<uint32_t>(__LINE__)}, \
WARNING_MESSAGE_STRING(__VA_ARGS__), \
false));
#endif
#define TORCH_WARN(...) _TORCH_WARN_WITH(::c10::UserWarning, __VA_ARGS__);
#define TORCH_WARN_DEPRECATION(...) \
_TORCH_WARN_WITH(::c10::DeprecationWarning, __VA_ARGS__);
// Report a warning to the user only once. Accepts an arbitrary number of extra
// arguments which are concatenated into the warning message using operator<<
//
#define _TORCH_WARN_ONCE(...) \
C10_UNUSED static const auto C10_ANONYMOUS_VARIABLE(torch_warn_once_) = \
[&] { \
TORCH_WARN(__VA_ARGS__); \
return true; \
}()
#ifdef DISABLE_WARN
#define TORCH_WARN_ONCE(...) ((void)0);
#else
#define TORCH_WARN_ONCE(...) \
if (::c10::WarningUtils::get_warnAlways()) { \
TORCH_WARN(__VA_ARGS__); \
} else { \
_TORCH_WARN_ONCE(__VA_ARGS__); \
}
#endif
// Report an error with a specific argument
// NOTE: using the argument name in TORCH_CHECK's message is preferred
#define TORCH_CHECK_ARG(cond, argN, ...) \
TORCH_CHECK(cond, "invalid argument ", argN, ": ", __VA_ARGS__)
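// Sample usage (illustrative; dim is hypothetical):
//
//   TORCH_CHECK_ARG(dim >= 0, 2, "expected a non-negative dim, got ", dim);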
// ----------------------------------------------------------------------------
// Deprecated macros
// ----------------------------------------------------------------------------
namespace c10::detail {
/*
// Deprecation disabled until we fix sites in our codebase
C10_DEPRECATED_MESSAGE("AT_ERROR(msg) is deprecated, use TORCH_CHECK(false, msg)
instead.")
*/
inline void deprecated_AT_ERROR() {}
/*
// Deprecation disabled until we fix sites in our codebase
C10_DEPRECATED_MESSAGE("AT_ASSERT is deprecated, if you mean to indicate an
internal invariant failure, use " \
"TORCH_INTERNAL_ASSERT instead; if you mean to do user
error checking, use " \ "TORCH_CHECK. See
https://github.com/pytorch/pytorch/issues/20287 for more details.")
*/
inline void deprecated_AT_ASSERT() {}
/*
// Deprecation disabled until we fix sites in our codebase
C10_DEPRECATED_MESSAGE("AT_ASSERTM is deprecated, if you mean to indicate an
internal invariant failure, use " \
"TORCH_INTERNAL_ASSERT instead; if you mean to do user
error checking, use " \ "TORCH_CHECK. See
https://github.com/pytorch/pytorch/issues/20287 for more details.")
*/
inline void deprecated_AT_ASSERTM() {}
} // namespace c10::detail
// Deprecated alias; this alias was deprecated because people kept mistakenly
// using it for user error checking. Use TORCH_INTERNAL_ASSERT or TORCH_CHECK
// instead. See https://github.com/pytorch/pytorch/issues/20287 for more
// details.
#define AT_ASSERT(...) \
do { \
::c10::detail::deprecated_AT_ASSERT(); \
C10_EXPAND_MSVC_WORKAROUND(TORCH_INTERNAL_ASSERT(__VA_ARGS__)); \
} while (false)
// Deprecated alias, like AT_ASSERT. The new TORCH_INTERNAL_ASSERT macro
// supports both 0-ary and variadic calls, so having a separate
// message-accepting macro is not necessary.
//
// NB: we MUST include cond explicitly here, as MSVC will miscompile the macro
// expansion, shunting all of __VA_ARGS__ to cond. An alternate workaround
// can be seen at
// https://stackoverflow.com/questions/5134523/msvc-doesnt-expand-va-args-correctly
#define AT_ASSERTM(cond, ...) \
do { \
::c10::detail::deprecated_AT_ASSERTM(); \
C10_EXPAND_MSVC_WORKAROUND(TORCH_INTERNAL_ASSERT(cond, __VA_ARGS__)); \
} while (false)
// Deprecated alias; this alias was deprecated because it represents extra API
// surface that makes it hard for people to understand what macro to use.
// Use TORCH_CHECK(false, ...) or TORCH_INTERNAL_ASSERT(false, ...) to
// unconditionally fail at a line of code.
#define AT_ERROR(...) \
do { \
::c10::detail::deprecated_AT_ERROR(); \
C10_EXPAND_MSVC_WORKAROUND(TORCH_CHECK(false, ::c10::str(__VA_ARGS__))); \
} while (false)
#endif // C10_UTIL_EXCEPTION_H_

View File

@ -0,0 +1,140 @@
#pragma once
#include <utility>
namespace c10 {
// See example implementation in TensorBase.h and TensorBody.h.
// Synopsis:
//
// repr_type -- type to use to store an owned T in ExclusivelyOwned.
//
// pointer_type -- pointer-esque type to return from
// ExclusivelyOwned's get() and operator*() methods.
//
// const_pointer_type -- similar to pointer_type, used for the const methods.
//
// static repr_type nullRepr() -- return a null instance of repr_type.
//
// template <class... Args>
// static repr_type createInPlace(Args&&... args) -- used by the in-place
// ExclusivelyOwned constructor.
//
// static repr_type moveToRepr(T&& x) -- move the given x into an
// instance of repr_type. used by the ExclusivelyOwned(T&&)
// constructor.
//
// static void destroyOwned(repr_type x) -- free memory for a
// known-exclusively-owned instance of x. Replaces calling repr_type's
// destructor. Being able to implement this more efficiently than
// repr_type's destructor is the main reason to use ExclusivelyOwned
// for a type.
//
// static T take(repr_type&) -- move out of the given repr_type into an owned T.
//
// static pointer_type getImpl(const repr_type&) -- return a pointer
// to the given repr_type. May take repr_type by value if that is more
// efficient.
template <typename T>
struct ExclusivelyOwnedTraits;
/// ExclusivelyOwned is a smart-pointer-like wrapper around an
/// exclusively-owned instance of some type T that normally has
/// mandatory reference counting (currently just Tensor). If you have
/// an isolated piece of code that knows that it has sole ownership of
/// an object of one of these types (i.e., because you created it
/// directly or using a factory function) and that object will not
/// escape from that isolated piece of code, then moving the object
/// into an ExclusivelyOwned will avoid an atomic reference count
/// decrement at destruction time.
///
/// If you directly create the Tensor in the first
/// place, you can use the in_place constructor of ExclusivelyOwned to
/// additionally avoid doing any stores to initialize the refcount &
/// weakcount.
template <typename T>
class ExclusivelyOwned {
using EOT = ExclusivelyOwnedTraits<T>;
typename ExclusivelyOwnedTraits<T>::repr_type repr_;
public:
ExclusivelyOwned() : repr_(EOT::nullRepr()) {}
explicit ExclusivelyOwned(T&& t) : repr_(EOT::moveToRepr(std::move(t))) {}
template <class... Args>
explicit ExclusivelyOwned(std::in_place_t, Args&&... args)
: repr_(EOT::createInPlace(std::forward<Args>(args)...)) {}
ExclusivelyOwned(const ExclusivelyOwned&) = delete;
ExclusivelyOwned(ExclusivelyOwned&& rhs) noexcept
: repr_(std::move(rhs.repr_)) {
rhs.repr_ = EOT::nullRepr();
}
ExclusivelyOwned& operator=(const ExclusivelyOwned&) = delete;
ExclusivelyOwned& operator=(ExclusivelyOwned&& rhs) noexcept {
EOT::destroyOwned(repr_);
repr_ = std::move(rhs.repr_);
rhs.repr_ = EOT::nullRepr();
return *this;
}
ExclusivelyOwned& operator=(T&& rhs) noexcept {
EOT::destroyOwned(repr_);
repr_ = EOT::moveToRepr(std::move(rhs));
return *this;
}
~ExclusivelyOwned() {
EOT::destroyOwned(repr_);
// Don't bother to call the destructor of repr_, since we already
// did specialized destruction for the exclusively-owned case in
// destroyOwned!
}
// We don't provide this because it would require us to be able to
// differentiate an owned-but-empty T from a lack of T. This is
// particularly problematic for Tensor, which wants to use an
// undefined Tensor as its null state.
explicit operator bool() const noexcept = delete;
operator T() && {
return take();
}
// NOTE: the equivalent operation on MaybeOwned is a moving
// operator*. For ExclusivelyOwned, take() and operator*() may well
// have different return types, so they are different functions.
T take() && {
return EOT::take(repr_);
}
typename EOT::const_pointer_type operator->() const {
return get();
}
typename EOT::const_pointer_type get() const {
return EOT::getImpl(repr_);
}
typename EOT::pointer_type operator->() {
return get();
}
typename EOT::pointer_type get() {
return EOT::getImpl(repr_);
}
std::remove_pointer_t<typename EOT::const_pointer_type>& operator*() const {
return *get();
}
std::remove_pointer_t<typename EOT::pointer_type>& operator*() {
return *get();
}
};
} // namespace c10
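// Sample usage (illustrative sketch, assuming the at::Tensor specialization
// of ExclusivelyOwnedTraits):
//
//   c10::ExclusivelyOwned<at::Tensor> owned(at::zeros({2, 2}));
//   owned->add_(1);                          // pointer-like access
//   at::Tensor t = std::move(owned).take();  // move the tensor back out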

View File

@ -0,0 +1,75 @@
#pragma once
#include <c10/core/TensorImpl.h>
#include <c10/core/UndefinedTensorImpl.h>
#include <utility>
namespace c10 {
// Shared ExclusivelyOwnedTraits implementation between caffe2::Tensor and
// at::TensorBase.
template <typename TensorType>
struct ExclusivelyOwnedTensorTraits {
using repr_type = TensorType;
using pointer_type = TensorType*;
using const_pointer_type = const TensorType*;
static repr_type nullRepr() {
return TensorType();
}
template <class... Args>
static repr_type createInPlace(Args&&... args) {
return TensorType(std::forward<Args>(args)...);
}
static repr_type moveToRepr(TensorType&& x) {
return std::move(x);
}
static void destroyOwned(TensorType& x) {
TensorImpl* const toDestroy = x.unsafeReleaseTensorImpl();
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
toDestroy != nullptr, "Tensor somehow got null TensorImpl?");
// May be 0 because UndefinedTensorImpl doesn't get its refcount
// incremented.
const bool isUndefined = toDestroy == UndefinedTensorImpl::singleton();
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
toDestroy->refcount_ == 1 || (toDestroy->refcount_ == 0 && isUndefined),
"ExclusivelyOwned<Tensor> destroyed with isUndefined ",
isUndefined,
" and refcount ",
toDestroy->refcount_,
", expected 1 or, if isUndefined, 0!");
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
toDestroy->weakcount_ == 1 ||
(toDestroy->weakcount_ == 0 &&
toDestroy == UndefinedTensorImpl::singleton()),
"ExclusivelyOwned<Tensor> destroyed with isUndefined ",
isUndefined,
" and weakcount ",
toDestroy->weakcount_,
", expected 1 or, if isUndefined, 0!");
if (!isUndefined) {
#ifndef NDEBUG
// Needed to pass the debug assertions in ~intrusive_ptr_target.
toDestroy->refcount_ = 0;
toDestroy->weakcount_ = 0;
#endif
delete toDestroy;
}
}
static TensorType take(TensorType& x) {
return std::move(x);
}
static pointer_type getImpl(repr_type& x) {
return &x;
}
static const_pointer_type getImpl(const repr_type& x) {
return &x;
}
};
} // namespace c10

View File

@ -0,0 +1,29 @@
#ifndef C10_UTIL_FBCODEMAPS_H_
#define C10_UTIL_FBCODEMAPS_H_
// Map typedefs so that we can use folly's F14 maps in fbcode without
// taking a folly dependency.
#ifdef FBCODE_CAFFE2
#include <folly/container/F14Map.h>
#include <folly/container/F14Set.h>
#else
#include <unordered_map>
#include <unordered_set>
#endif
namespace c10 {
#ifdef FBCODE_CAFFE2
template <typename Key, typename Value>
using FastMap = folly::F14FastMap<Key, Value>;
template <typename Key>
using FastSet = folly::F14FastSet<Key>;
#else
template <typename Key, typename Value>
using FastMap = std::unordered_map<Key, Value>;
template <typename Key>
using FastSet = std::unordered_set<Key>;
#endif
} // namespace c10
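// Sample usage (illustrative): code written against these aliases picks up
// folly's F14 containers inside fbcode and the std containers elsewhere.
//
//   c10::FastMap<std::string, int> op_counts;
//   op_counts["conv2d"]++;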
#endif // C10_UTIL_FBCODEMAPS_H_

View File

@ -0,0 +1,226 @@
#ifndef C10_UTIL_FLAGS_H_
#define C10_UTIL_FLAGS_H_
/* Commandline flags support for C10.
*
* This is a portable commandline flags tool for c10, so we can optionally
* choose to use gflags or a lightweight custom implementation if gflags is
 * not possible on a certain platform. If you have gflags installed, setting
 * the macro C10_USE_GFLAGS will seamlessly route everything to gflags.
*
 * To define a flag foo of type bool defaulting to true, do the following in the
* *global* namespace:
* C10_DEFINE_bool(foo, true, "An example.");
*
* To use it in another .cc file, you can use C10_DECLARE_* as follows:
* C10_DECLARE_bool(foo);
*
* In both cases, you can then access the flag via FLAGS_foo.
*
* It is recommended that you build with gflags. To learn more about the flags
* usage, refer to the gflags page here:
*
* https://gflags.github.io/gflags/
*
 * Note about Python users / devs: gflags is initialized from the C++ function
 * ParseCommandLineFlags, which is usually called in native binaries in the
 * main function. As Python does not have a modifiable main function, it is
 * usually difficult to change the flags after Python starts. Hence, it is
 * recommended that one sets the default values of the flags to ones that are
 * acceptable in general - that will allow Python to run without incorrect
 * flags.
*/
#include <c10/macros/Export.h>
#include <string>
#include <c10/util/Registry.h>
namespace c10 {
/**
* Sets the usage message when a commandline tool is called with "--help".
*/
C10_API void SetUsageMessage(const std::string& str);
/**
* Returns the usage message for the commandline tool set by SetUsageMessage.
*/
C10_API const char* UsageMessage();
/**
* Parses the commandline flags.
*
 * This function parses all the commandline arguments passed in via pargc
 * and pargv. Once it is finished, pargc and pargv will contain the remaining
 * commandline args that c10 does not deal with. Note that, following
* convention, argv[0] contains the binary name and is not parsed.
*/
C10_API bool ParseCommandLineFlags(int* pargc, char*** pargv);
/**
 * Checks if the commandline flags have already been parsed.
*/
C10_API bool CommandLineFlagsHasBeenParsed();
} // namespace c10
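/* Sample usage (illustrative sketch; the flag name is hypothetical):
 *
 *   C10_DEFINE_bool(c10_use_fast_path, true, "Enable the fast path.");
 *
 *   int main(int argc, char** argv) {
 *     c10::ParseCommandLineFlags(&argc, &argv);
 *     if (FLAGS_c10_use_fast_path) {
 *       // fast path
 *     }
 *     return 0;
 *   }
 */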
////////////////////////////////////////////////////////////////////////////////
// Below are gflags and non-gflags specific implementations.
// In general, they define the following macros for one to declare (use
// C10_DECLARE) or define (use C10_DEFINE) flags:
// C10_{DECLARE,DEFINE}_{int,int64,double,bool,string}
////////////////////////////////////////////////////////////////////////////////
#ifdef C10_USE_GFLAGS
////////////////////////////////////////////////////////////////////////////////
// Begin gflags section: most functions are basically rerouted to gflags.
////////////////////////////////////////////////////////////////////////////////
#include <gflags/gflags.h>
// C10 uses hidden visibility by default. However, in gflags, it only uses
// export on Windows platform (with dllexport) but not on linux/mac (with
// default visibility). As a result, to ensure that we are always exporting
// global variables, we will redefine the GFLAGS_DLL_DEFINE_FLAG macro if we
// are building C10 as a shared library.
// This has to be done after the inclusion of gflags, because some early
// versions of gflags.h (e.g. 2.0 on ubuntu 14.04) directly defines the
// macros, so we need to do definition after gflags is done.
#ifdef GFLAGS_DLL_DEFINE_FLAG
#undef GFLAGS_DLL_DEFINE_FLAG
#endif // GFLAGS_DLL_DEFINE_FLAG
#ifdef GFLAGS_DLL_DECLARE_FLAG
#undef GFLAGS_DLL_DECLARE_FLAG
#endif // GFLAGS_DLL_DECLARE_FLAG
#define GFLAGS_DLL_DEFINE_FLAG C10_EXPORT
#define GFLAGS_DLL_DECLARE_FLAG C10_IMPORT
// gflags before 2.0 uses namespace google and after 2.1 uses namespace gflags.
// Using GFLAGS_GFLAGS_H_ to capture this change.
#ifndef GFLAGS_GFLAGS_H_
namespace gflags = google;
#endif // GFLAGS_GFLAGS_H_
// Motivation about the gflags wrapper:
// (1) We would need to make sure that the gflags version and the non-gflags
// version of C10 are going to expose the same flags abstraction. One should
// explicitly use FLAGS_flag_name to access the flags.
// (2) For flag names, it is recommended to start with c10_ to distinguish it
// from regular gflags flags. For example, do
// C10_DEFINE_BOOL(c10_my_flag, true, "An example");
// to allow one to use FLAGS_c10_my_flag.
// (3) Gflags has a design issue in that it does not properly expose the global
// flags if one builds the library with -fvisibility=hidden. The current gflags
// (as of
// Aug 2018) only deals with the Windows case using dllexport, and not the Linux
// counterparts. As a result, we will explicitly use C10_EXPORT to export the
// flags defined in C10. This is done via a global reference, so the flag
// itself is not duplicated - under the hood it is the same global gflags flag.
#define C10_GFLAGS_DEF_WRAPPER(type, real_type, name, default_value, help_str) \
DEFINE_##type(name, default_value, help_str);
#define C10_DEFINE_int(name, default_value, help_str) \
C10_GFLAGS_DEF_WRAPPER(int32, gflags::int32, name, default_value, help_str)
#define C10_DEFINE_int32(name, default_value, help_str) \
C10_DEFINE_int(name, default_value, help_str)
#define C10_DEFINE_int64(name, default_value, help_str) \
C10_GFLAGS_DEF_WRAPPER(int64, gflags::int64, name, default_value, help_str)
#define C10_DEFINE_double(name, default_value, help_str) \
C10_GFLAGS_DEF_WRAPPER(double, double, name, default_value, help_str)
#define C10_DEFINE_bool(name, default_value, help_str) \
C10_GFLAGS_DEF_WRAPPER(bool, bool, name, default_value, help_str)
#define C10_DEFINE_string(name, default_value, help_str) \
C10_GFLAGS_DEF_WRAPPER(string, ::fLS::clstring, name, default_value, help_str)
// DECLARE_typed_var should be used in header files and in the global namespace.
#define C10_GFLAGS_DECLARE_WRAPPER(type, real_type, name) DECLARE_##type(name);
#define C10_DECLARE_int(name) \
C10_GFLAGS_DECLARE_WRAPPER(int32, gflags::int32, name)
#define C10_DECLARE_int32(name) C10_DECLARE_int(name)
#define C10_DECLARE_int64(name) \
C10_GFLAGS_DECLARE_WRAPPER(int64, gflags::int64, name)
#define C10_DECLARE_double(name) \
C10_GFLAGS_DECLARE_WRAPPER(double, double, name)
#define C10_DECLARE_bool(name) C10_GFLAGS_DECLARE_WRAPPER(bool, bool, name)
#define C10_DECLARE_string(name) \
C10_GFLAGS_DECLARE_WRAPPER(string, ::fLS::clstring, name)
////////////////////////////////////////////////////////////////////////////////
// End gflags section.
////////////////////////////////////////////////////////////////////////////////
#else // C10_USE_GFLAGS
////////////////////////////////////////////////////////////////////////////////
// Begin non-gflags section: providing equivalent functionality.
////////////////////////////////////////////////////////////////////////////////
namespace c10 {
class C10_API C10FlagParser {
public:
bool success() {
return success_;
}
protected:
template <typename T>
bool Parse(const std::string& content, T* value);
bool success_{false};
};
C10_DECLARE_REGISTRY(C10FlagsRegistry, C10FlagParser, const std::string&);
} // namespace c10
// The macros are defined outside the c10 namespace. In your code, you should
// write the C10_DEFINE_* and C10_DECLARE_* macros outside any namespace
// as well.
#define C10_DEFINE_typed_var(type, name, default_value, help_str) \
C10_EXPORT type FLAGS_##name = default_value; \
namespace c10 { \
namespace { \
class C10FlagParser_##name : public C10FlagParser { \
public: \
explicit C10FlagParser_##name(const std::string& content) { \
success_ = C10FlagParser::Parse<type>(content, &FLAGS_##name); \
} \
}; \
} \
RegistererC10FlagsRegistry g_C10FlagsRegistry_##name( \
#name, \
C10FlagsRegistry(), \
RegistererC10FlagsRegistry::DefaultCreator<C10FlagParser_##name>, \
"(" #type ", default " #default_value ") " help_str); \
}
#define C10_DEFINE_int(name, default_value, help_str) \
C10_DEFINE_typed_var(int, name, default_value, help_str)
#define C10_DEFINE_int32(name, default_value, help_str) \
C10_DEFINE_int(name, default_value, help_str)
#define C10_DEFINE_int64(name, default_value, help_str) \
C10_DEFINE_typed_var(int64_t, name, default_value, help_str)
#define C10_DEFINE_double(name, default_value, help_str) \
C10_DEFINE_typed_var(double, name, default_value, help_str)
#define C10_DEFINE_bool(name, default_value, help_str) \
C10_DEFINE_typed_var(bool, name, default_value, help_str)
#define C10_DEFINE_string(name, default_value, help_str) \
C10_DEFINE_typed_var(std::string, name, default_value, help_str)
// DECLARE_typed_var should be used in header files and in the global namespace.
#define C10_DECLARE_typed_var(type, name) C10_API extern type FLAGS_##name
#define C10_DECLARE_int(name) C10_DECLARE_typed_var(int, name)
#define C10_DECLARE_int32(name) C10_DECLARE_int(name)
#define C10_DECLARE_int64(name) C10_DECLARE_typed_var(int64_t, name)
#define C10_DECLARE_double(name) C10_DECLARE_typed_var(double, name)
#define C10_DECLARE_bool(name) C10_DECLARE_typed_var(bool, name)
#define C10_DECLARE_string(name) C10_DECLARE_typed_var(std::string, name)
////////////////////////////////////////////////////////////////////////////////
// End non-gflags section.
////////////////////////////////////////////////////////////////////////////////
#endif // C10_USE_GFLAGS
#endif // C10_UTIL_FLAGS_H_

View File

@ -0,0 +1,274 @@
#pragma once
#include <c10/macros/Macros.h>
#include <cstdint>
#include <limits>
C10_CLANG_DIAGNOSTIC_PUSH()
#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
#endif
namespace c10 {
/// Constructors
inline C10_HOST_DEVICE Float8_e4m3fn::Float8_e4m3fn(float value)
: x(detail::fp8e4m3fn_from_fp32_value(value)) {}
/// Implicit conversions
inline C10_HOST_DEVICE Float8_e4m3fn::operator float() const {
return detail::fp8e4m3fn_to_fp32_value(x);
}
/// Special values helper
inline C10_HOST_DEVICE bool Float8_e4m3fn::isnan() const {
return (x & 0b01111111) == 0b01111111;
}
/// Arithmetic
inline C10_HOST_DEVICE Float8_e4m3fn
operator+(const Float8_e4m3fn& a, const Float8_e4m3fn& b) {
return static_cast<float>(a) + static_cast<float>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fn
operator-(const Float8_e4m3fn& a, const Float8_e4m3fn& b) {
return static_cast<float>(a) - static_cast<float>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fn
operator*(const Float8_e4m3fn& a, const Float8_e4m3fn& b) {
return static_cast<float>(a) * static_cast<float>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fn operator/(
const Float8_e4m3fn& a,
const Float8_e4m3fn& b) __ubsan_ignore_float_divide_by_zero__ {
return static_cast<float>(a) / static_cast<float>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fn operator-(const Float8_e4m3fn& a) {
return -static_cast<float>(a);
}
inline C10_HOST_DEVICE Float8_e4m3fn& operator+=(
Float8_e4m3fn& a,
const Float8_e4m3fn& b) {
a = a + b;
return a;
}
inline C10_HOST_DEVICE Float8_e4m3fn& operator-=(
Float8_e4m3fn& a,
const Float8_e4m3fn& b) {
a = a - b;
return a;
}
inline C10_HOST_DEVICE Float8_e4m3fn& operator*=(
Float8_e4m3fn& a,
const Float8_e4m3fn& b) {
a = a * b;
return a;
}
inline C10_HOST_DEVICE Float8_e4m3fn& operator/=(
Float8_e4m3fn& a,
const Float8_e4m3fn& b) {
a = a / b;
return a;
}
/// Arithmetic with floats
inline C10_HOST_DEVICE float operator+(Float8_e4m3fn a, float b) {
return static_cast<float>(a) + b;
}
inline C10_HOST_DEVICE float operator-(Float8_e4m3fn a, float b) {
return static_cast<float>(a) - b;
}
inline C10_HOST_DEVICE float operator*(Float8_e4m3fn a, float b) {
return static_cast<float>(a) * b;
}
inline C10_HOST_DEVICE float operator/(Float8_e4m3fn a, float b)
__ubsan_ignore_float_divide_by_zero__ {
return static_cast<float>(a) / b;
}
inline C10_HOST_DEVICE float operator+(float a, Float8_e4m3fn b) {
return a + static_cast<float>(b);
}
inline C10_HOST_DEVICE float operator-(float a, Float8_e4m3fn b) {
return a - static_cast<float>(b);
}
inline C10_HOST_DEVICE float operator*(float a, Float8_e4m3fn b) {
return a * static_cast<float>(b);
}
inline C10_HOST_DEVICE float operator/(float a, Float8_e4m3fn b)
__ubsan_ignore_float_divide_by_zero__ {
return a / static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator+=(float& a, const Float8_e4m3fn& b) {
return a += static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator-=(float& a, const Float8_e4m3fn& b) {
return a -= static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator*=(float& a, const Float8_e4m3fn& b) {
return a *= static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator/=(float& a, const Float8_e4m3fn& b) {
return a /= static_cast<float>(b);
}
/// Arithmetic with doubles
inline C10_HOST_DEVICE double operator+(Float8_e4m3fn a, double b) {
return static_cast<double>(a) + b;
}
inline C10_HOST_DEVICE double operator-(Float8_e4m3fn a, double b) {
return static_cast<double>(a) - b;
}
inline C10_HOST_DEVICE double operator*(Float8_e4m3fn a, double b) {
return static_cast<double>(a) * b;
}
inline C10_HOST_DEVICE double operator/(Float8_e4m3fn a, double b)
__ubsan_ignore_float_divide_by_zero__ {
return static_cast<double>(a) / b;
}
inline C10_HOST_DEVICE double operator+(double a, Float8_e4m3fn b) {
return a + static_cast<double>(b);
}
inline C10_HOST_DEVICE double operator-(double a, Float8_e4m3fn b) {
return a - static_cast<double>(b);
}
inline C10_HOST_DEVICE double operator*(double a, Float8_e4m3fn b) {
return a * static_cast<double>(b);
}
inline C10_HOST_DEVICE double operator/(double a, Float8_e4m3fn b)
__ubsan_ignore_float_divide_by_zero__ {
return a / static_cast<double>(b);
}
/// Arithmetic with ints
inline C10_HOST_DEVICE Float8_e4m3fn operator+(Float8_e4m3fn a, int b) {
return a + static_cast<Float8_e4m3fn>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fn operator-(Float8_e4m3fn a, int b) {
return a - static_cast<Float8_e4m3fn>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fn operator*(Float8_e4m3fn a, int b) {
return a * static_cast<Float8_e4m3fn>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fn operator/(Float8_e4m3fn a, int b) {
return a / static_cast<Float8_e4m3fn>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fn operator+(int a, Float8_e4m3fn b) {
return static_cast<Float8_e4m3fn>(a) + b;
}
inline C10_HOST_DEVICE Float8_e4m3fn operator-(int a, Float8_e4m3fn b) {
return static_cast<Float8_e4m3fn>(a) - b;
}
inline C10_HOST_DEVICE Float8_e4m3fn operator*(int a, Float8_e4m3fn b) {
return static_cast<Float8_e4m3fn>(a) * b;
}
inline C10_HOST_DEVICE Float8_e4m3fn operator/(int a, Float8_e4m3fn b) {
return static_cast<Float8_e4m3fn>(a) / b;
}
//// Arithmetic with int64_t
inline C10_HOST_DEVICE Float8_e4m3fn operator+(Float8_e4m3fn a, int64_t b) {
return a + static_cast<Float8_e4m3fn>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fn operator-(Float8_e4m3fn a, int64_t b) {
return a - static_cast<Float8_e4m3fn>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fn operator*(Float8_e4m3fn a, int64_t b) {
return a * static_cast<Float8_e4m3fn>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fn operator/(Float8_e4m3fn a, int64_t b) {
return a / static_cast<Float8_e4m3fn>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fn operator+(int64_t a, Float8_e4m3fn b) {
return static_cast<Float8_e4m3fn>(a) + b;
}
inline C10_HOST_DEVICE Float8_e4m3fn operator-(int64_t a, Float8_e4m3fn b) {
return static_cast<Float8_e4m3fn>(a) - b;
}
inline C10_HOST_DEVICE Float8_e4m3fn operator*(int64_t a, Float8_e4m3fn b) {
return static_cast<Float8_e4m3fn>(a) * b;
}
inline C10_HOST_DEVICE Float8_e4m3fn operator/(int64_t a, Float8_e4m3fn b) {
return static_cast<Float8_e4m3fn>(a) / b;
}
/// NOTE: we do not define comparisons directly and instead rely on the implicit
/// conversion from c10::Float8_e4m3fn to float.
} // namespace c10
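// Sample usage (illustrative): arithmetic round-trips through float32, and
// isnan() detects the E4M3FN NaN encodings (0x7f / 0xff).
//
//   c10::Float8_e4m3fn a(1.5f), b(2.25f);
//   float sum = a + b; // computed in float32; exactly 3.75f here
//   c10::Float8_e4m3fn n(std::numeric_limits<float>::quiet_NaN());
//   bool is_nan = n.isnan(); // true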
namespace std {
template <>
class numeric_limits<c10::Float8_e4m3fn> {
public:
static constexpr bool is_specialized = true;
static constexpr bool is_signed = true;
static constexpr bool is_integer = false;
static constexpr bool is_exact = false;
static constexpr bool has_infinity = false;
static constexpr bool has_quiet_NaN = true;
static constexpr bool has_signaling_NaN = false;
static constexpr auto has_denorm = true;
static constexpr auto has_denorm_loss = true;
static constexpr auto round_style = numeric_limits<float>::round_style;
static constexpr bool is_iec559 = false;
static constexpr bool is_bounded = true;
static constexpr bool is_modulo = false;
static constexpr int digits = 4;
static constexpr int digits10 = 0;
static constexpr int max_digits10 = 3;
static constexpr int radix = 2;
static constexpr int min_exponent = -5;
static constexpr int min_exponent10 = -1;
static constexpr int max_exponent = 8;
static constexpr int max_exponent10 = 2;
static constexpr auto traps = numeric_limits<float>::traps;
static constexpr auto tinyness_before = false;
static constexpr c10::Float8_e4m3fn min() {
return c10::Float8_e4m3fn(0x08, c10::Float8_e4m3fn::from_bits());
}
static constexpr c10::Float8_e4m3fn lowest() {
return c10::Float8_e4m3fn(0xFE, c10::Float8_e4m3fn::from_bits());
}
static constexpr c10::Float8_e4m3fn max() {
return c10::Float8_e4m3fn(0x7E, c10::Float8_e4m3fn::from_bits());
}
static constexpr c10::Float8_e4m3fn epsilon() {
return c10::Float8_e4m3fn(0x20, c10::Float8_e4m3fn::from_bits());
}
static constexpr c10::Float8_e4m3fn round_error() {
return c10::Float8_e4m3fn(0x30, c10::Float8_e4m3fn::from_bits());
}
static constexpr c10::Float8_e4m3fn quiet_NaN() {
return c10::Float8_e4m3fn(0x7F, c10::Float8_e4m3fn::from_bits());
}
static constexpr c10::Float8_e4m3fn denorm_min() {
return c10::Float8_e4m3fn(0x01, c10::Float8_e4m3fn::from_bits());
}
};
} // namespace std
C10_CLANG_DIAGNOSTIC_POP()
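// A minimal sanity sketch for the limits above (illustrative only, not part
// of the c10 API); each decimal value follows from bias 7, 3 mantissa bits,
// and the absence of an infinity encoding:
inline bool float8_e4m3fn_limits_sketch() {
  using limits = std::numeric_limits<c10::Float8_e4m3fn>;
  return static_cast<float>(limits::max()) == 448.0f && // 0x7E
      static_cast<float>(limits::lowest()) == -448.0f && // 0xFE
      static_cast<float>(limits::min()) == 0.015625f && // 0x08 = 2^-6
      static_cast<float>(limits::denorm_min()) == 0.001953125f && // 2^-9
      limits::quiet_NaN().isnan(); // 0x7F
}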

View File

@ -0,0 +1,240 @@
#pragma once
/// Defines the Float8_e4m3fn type (8-bit floating-point) including conversions
/// to standard C types and basic arithmetic operations. Note that arithmetic
/// operations are implemented by converting to floating point and
/// performing the operation in float32.
/// Binary configuration:
/// s eeee mmm
/// 1 sign bit
/// 4 exponent bits
/// 3 mantissa bits
/// bias = 7
///
/// Implementation based on the paper https://arxiv.org/pdf/2209.05433.pdf
/// and inspired by Half implementation from pytorch/c10/util/Half.h
#include <c10/macros/Macros.h>
#include <c10/util/floating_point_utils.h>
#if defined(__cplusplus)
#include <cmath>
#include <cstdint>
#elif !defined(__OPENCL_VERSION__)
#include <math.h>
#include <stdint.h>
#endif
#ifdef _MSC_VER
#include <intrin.h>
#endif
#include <climits>
#include <iostream>
namespace c10 {
namespace detail {
/*
 * Convert an 8-bit floating-point number in fp8 E4M3FN format, in bit
* representation, to a 32-bit floating-point number in IEEE single-precision
* format, in bit representation.
*
* @note The implementation doesn't use any floating-point operations.
*/
inline C10_HOST_DEVICE float fp8e4m3fn_to_fp32_value(uint8_t input) {
/*
* Extend the fp8 E4M3FN number to 32 bits and shift to the
* upper part of the 32-bit word:
* +---+----+---+-----------------------------+
* | S |EEEE|MMM|0000 0000 0000 0000 0000 0000|
* +---+----+---+-----------------------------+
* Bits 31 27-30 24-26 0-23
*
* S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0
* - zero bits.
*/
const uint32_t w = (uint32_t)input << 24;
/*
* Extract the sign of the input number into the high bit of the 32-bit word:
*
* +---+----------------------------------+
* | S |0000000 00000000 00000000 00000000|
* +---+----------------------------------+
* Bits 31 0-31
*/
const uint32_t sign = w & UINT32_C(0x80000000);
/*
* Extract mantissa and biased exponent of the input number into the bits 0-30
* of the 32-bit word:
*
* +---+----+---+-----------------------------+
* | S |EEEE|MMM|0000 0000 0000 0000 0000 0000|
* +---+----+---+-----------------------------+
* Bits 31 27-30 24-26 0-23
*/
const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF);
/*
 * Renorm shift is the number of bits to shift the mantissa left to make the
 * fp8 number normalized. If the initial number is normalized, at least one
 * of its high 5 bits (the sign, which is zero here, plus the 4-bit exponent)
 * is set. In this case renorm_shift == 0. If the number is denormalized,
 * renorm_shift > 0. Note that if we shift a denormalized nonsign by
 * renorm_shift, the unit bit of the mantissa shifts into the exponent,
 * turning the biased exponent into 1 and making the mantissa normalized
 * (i.e. without a leading 1).
*/
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
uint32_t renorm_shift = __clz(nonsign);
#elif defined(__SYCL_DEVICE_ONLY__)
// Note: zero is not a supported input into `__builtin_clz`
uint32_t renorm_shift =
nonsign != 0 ? __builtin_clz(nonsign) : sizeof(uint32_t) * CHAR_BIT;
#elif defined(_MSC_VER)
unsigned long nonsign_bsr;
_BitScanReverse(&nonsign_bsr, (unsigned long)nonsign);
uint32_t renorm_shift = (uint32_t)nonsign_bsr ^ 31;
#else
// Note: zero is not a supported input into `__builtin_clz`
uint32_t renorm_shift =
nonsign != 0 ? __builtin_clz(nonsign) : sizeof(uint32_t) * CHAR_BIT;
#endif
renorm_shift = renorm_shift > 4 ? renorm_shift - 4 : 0;
/*
 * Iff the fp8e4m3fn number has all exponent and mantissa bits set to 1,
 * the addition overflows into bit 31, and the subsequent shift turns the
 * high 9 bits into 1s. Thus inf_nan_mask == 0x7F800000 if the fp8e4m3fn
 * number is NaN, and 0x00000000 otherwise.
*/
const int32_t inf_nan_mask =
((int32_t)(nonsign + 0x01000000) >> 8) & INT32_C(0x7F800000);
/*
 * Iff nonsign is 0, the subtraction wraps around to 0xFFFFFFFF, turning bit
 * 31 into 1. Otherwise, bit 31 remains 0. The signed shift right by 31
 * broadcasts bit 31 into all bits of the zero_mask. Thus zero_mask ==
 * 0xFFFFFFFF if the fp8 number was zero (+0.0 or -0.0), and 0x00000000
 * otherwise.
*/
const int32_t zero_mask = (int32_t)(nonsign - 1) >> 31;
/*
* 1. Shift nonsign left by renorm_shift to normalize it (if the input
* was denormal)
* 2. Shift nonsign right by 4 so the exponent (4 bits originally)
* becomes an 8-bit field and 3-bit mantissa shifts into the 3 high
* bits of the 23-bit mantissa of IEEE single-precision number.
 * 3. Add 0x78 to the exponent (starting at bit 23) to compensate for the
 * difference in exponent bias (0x7F for the single-precision number minus
 * 0x07 for the fp8e4m3fn number).
* 4. Subtract renorm_shift from the exponent (starting at bit 23) to
* account for renormalization. As renorm_shift is less than 0x78, this
* can be combined with step 3.
* 5. Binary OR with inf_nan_mask to turn the exponent into 0xFF if the
* input was NaN or infinity.
* 6. Binary ANDNOT with zero_mask to turn the mantissa and exponent
* into zero if the input was zero.
* 7. Combine with the sign of the input number.
*/
uint32_t result = sign |
((((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) |
inf_nan_mask) &
~zero_mask);
return fp32_from_bits(result);
}
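// A minimal worked-example sketch for the decoder above (illustrative only,
// not part of the c10 API); each value follows from bias 7 and 3 mantissa
// bits:
inline bool fp8e4m3fn_decode_examples() {
  // 0x30 = 0 0110 000: exponent 6, bias 7 -> 1.0 * 2^-1
  bool ok = fp8e4m3fn_to_fp32_value(0x30) == 0.5f;
  // 0x01 = 0 0000 001: subnormal -> (1/8) * 2^(1-7) = 2^-9
  ok = ok && fp8e4m3fn_to_fp32_value(0x01) == 0.001953125f;
  // 0x7F = 0 1111 111: all exponent/mantissa bits set -> NaN (no infinity)
  ok = ok && std::isnan(fp8e4m3fn_to_fp32_value(0x7F));
  return ok;
}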
/*
 * Convert a 32-bit floating-point number in IEEE single-precision format to an
* 8-bit floating-point number in fp8 E4M3FN format, in bit representation.
*/
inline C10_HOST_DEVICE uint8_t fp8e4m3fn_from_fp32_value(float f) {
/*
* Binary representation of 480.0f, which is the first value
* not representable in fp8e4m3fn range:
* 0 1111 111 - fp8e4m3fn
* 0 10000111 11100000000000000000000 - fp32
*/
constexpr uint32_t fp8_max = UINT32_C(1087) << 20;
/*
* A mask for converting fp32 numbers lower than fp8e4m3fn normal range
* into denorm representation
* magic number: ((127 - 7) + (23 - 3) + 1)
*/
constexpr uint32_t denorm_mask = UINT32_C(141) << 23;
uint32_t f_bits = fp32_to_bits(f);
uint8_t result = 0u;
/*
* Extract the sign of the input number into the high bit of the 32-bit word:
*
* +---+----------------------------------+
* | S |0000000 00000000 00000000 00000000|
* +---+----------------------------------+
* Bits 31 0-31
*/
const uint32_t sign = f_bits & UINT32_C(0x80000000);
/*
* Set sign bit to 0
*/
f_bits ^= sign;
if (f_bits >= fp8_max) {
// NaN - all exponent and mantissa bits set to 1
result = 0x7f;
} else {
if (f_bits < (UINT32_C(121) << 23)) {
// Input number is smaller than 2^(-6), which is the smallest
// fp8e4m3fn normal number
f_bits =
fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask));
result = static_cast<uint8_t>(f_bits - denorm_mask);
} else {
// resulting mantissa is odd
uint8_t mant_odd = (f_bits >> 20) & 1;
// update exponent, rounding bias part 1
f_bits += ((uint32_t)(7 - 127) << 23) + 0x7FFFF;
// rounding bias part 2
f_bits += mant_odd;
// take the bits!
result = static_cast<uint8_t>(f_bits >> 20);
}
}
result |= static_cast<uint8_t>(sign >> 24);
return result;
}
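// A minimal worked-example sketch for the encoder above (illustrative only,
// not part of the c10 API), spelling out saturation-to-NaN and the
// round-to-nearest-even scheme implemented by the mant_odd bias:
inline bool fp8e4m3fn_encode_examples() {
  // 448.0f = 1.75 * 2^8 is exactly representable: 0 1111 110 = 0x7E (the max)
  bool ok = fp8e4m3fn_from_fp32_value(448.0f) == 0x7E;
  // 480.0f is the first unrepresentable magnitude; e4m3fn has no infinity,
  // so it maps to the NaN encoding 0x7F
  ok = ok && fp8e4m3fn_from_fp32_value(480.0f) == 0x7F;
  // 21.0f = 1.0101 * 2^4 lies exactly between 20.0 (1.010 * 2^4) and
  // 22.0 (1.011 * 2^4); the tie rounds to the even mantissa, giving 20.0
  ok = ok &&
      fp8e4m3fn_to_fp32_value(fp8e4m3fn_from_fp32_value(21.0f)) == 20.0f;
  return ok;
}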
} // namespace detail
struct alignas(1) Float8_e4m3fn {
uint8_t x;
struct from_bits_t {};
C10_HOST_DEVICE static constexpr from_bits_t from_bits() {
return from_bits_t();
}
Float8_e4m3fn() = default;
constexpr C10_HOST_DEVICE Float8_e4m3fn(uint8_t bits, from_bits_t)
: x(bits) {}
inline C10_HOST_DEVICE Float8_e4m3fn(float value);
inline C10_HOST_DEVICE operator float() const;
inline C10_HOST_DEVICE bool isnan() const;
};
C10_API inline std::ostream& operator<<(
std::ostream& out,
const Float8_e4m3fn& value) {
out << (float)value;
return out;
}
} // namespace c10
#include <c10/util/Float8_e4m3fn-inl.h> // IWYU pragma: keep

View File

@ -0,0 +1,279 @@
#pragma once
#include <c10/macros/Macros.h>
#include <c10/util/Float8_fnuz_cvt.h>
#include <cstring>
#include <limits>
C10_CLANG_DIAGNOSTIC_PUSH()
#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
#endif
namespace c10 {
/// Constructors
inline C10_HOST_DEVICE Float8_e4m3fnuz::Float8_e4m3fnuz(float value)
: x(detail::fp8e4m3fnuz_from_fp32_value(value)) {}
/// Implicit conversions
inline C10_HOST_DEVICE Float8_e4m3fnuz::operator float() const {
return detail::fp8_fnuz_to_fp32_value<4, 3>(x);
}
/// Special values helper
inline C10_HOST_DEVICE bool Float8_e4m3fnuz::isnan() const {
return x == 0b10000000;
}
/// Arithmetic
inline C10_HOST_DEVICE Float8_e4m3fnuz
operator+(const Float8_e4m3fnuz& a, const Float8_e4m3fnuz& b) {
return static_cast<float>(a) + static_cast<float>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fnuz
operator-(const Float8_e4m3fnuz& a, const Float8_e4m3fnuz& b) {
return static_cast<float>(a) - static_cast<float>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fnuz
operator*(const Float8_e4m3fnuz& a, const Float8_e4m3fnuz& b) {
return static_cast<float>(a) * static_cast<float>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(
const Float8_e4m3fnuz& a,
const Float8_e4m3fnuz& b) __ubsan_ignore_float_divide_by_zero__ {
return static_cast<float>(a) / static_cast<float>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(const Float8_e4m3fnuz& a) {
return -static_cast<float>(a);
}
inline C10_HOST_DEVICE Float8_e4m3fnuz& operator+=(
Float8_e4m3fnuz& a,
const Float8_e4m3fnuz& b) {
a = a + b;
return a;
}
inline C10_HOST_DEVICE Float8_e4m3fnuz& operator-=(
Float8_e4m3fnuz& a,
const Float8_e4m3fnuz& b) {
a = a - b;
return a;
}
inline C10_HOST_DEVICE Float8_e4m3fnuz& operator*=(
Float8_e4m3fnuz& a,
const Float8_e4m3fnuz& b) {
a = a * b;
return a;
}
inline C10_HOST_DEVICE Float8_e4m3fnuz& operator/=(
Float8_e4m3fnuz& a,
const Float8_e4m3fnuz& b) {
a = a / b;
return a;
}
/// Arithmetic with floats
inline C10_HOST_DEVICE float operator+(Float8_e4m3fnuz a, float b) {
return static_cast<float>(a) + b;
}
inline C10_HOST_DEVICE float operator-(Float8_e4m3fnuz a, float b) {
return static_cast<float>(a) - b;
}
inline C10_HOST_DEVICE float operator*(Float8_e4m3fnuz a, float b) {
return static_cast<float>(a) * b;
}
inline C10_HOST_DEVICE float operator/(Float8_e4m3fnuz a, float b)
__ubsan_ignore_float_divide_by_zero__ {
return static_cast<float>(a) / b;
}
inline C10_HOST_DEVICE float operator+(float a, Float8_e4m3fnuz b) {
return a + static_cast<float>(b);
}
inline C10_HOST_DEVICE float operator-(float a, Float8_e4m3fnuz b) {
return a - static_cast<float>(b);
}
inline C10_HOST_DEVICE float operator*(float a, Float8_e4m3fnuz b) {
return a * static_cast<float>(b);
}
inline C10_HOST_DEVICE float operator/(float a, Float8_e4m3fnuz b)
__ubsan_ignore_float_divide_by_zero__ {
return a / static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator+=(float& a, const Float8_e4m3fnuz& b) {
return a += static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator-=(float& a, const Float8_e4m3fnuz& b) {
return a -= static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator*=(float& a, const Float8_e4m3fnuz& b) {
return a *= static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator/=(float& a, const Float8_e4m3fnuz& b) {
return a /= static_cast<float>(b);
}
/// Arithmetic with doubles
inline C10_HOST_DEVICE double operator+(Float8_e4m3fnuz a, double b) {
return static_cast<double>(a) + b;
}
inline C10_HOST_DEVICE double operator-(Float8_e4m3fnuz a, double b) {
return static_cast<double>(a) - b;
}
inline C10_HOST_DEVICE double operator*(Float8_e4m3fnuz a, double b) {
return static_cast<double>(a) * b;
}
inline C10_HOST_DEVICE double operator/(Float8_e4m3fnuz a, double b)
__ubsan_ignore_float_divide_by_zero__ {
return static_cast<double>(a) / b;
}
inline C10_HOST_DEVICE double operator+(double a, Float8_e4m3fnuz b) {
return a + static_cast<double>(b);
}
inline C10_HOST_DEVICE double operator-(double a, Float8_e4m3fnuz b) {
return a - static_cast<double>(b);
}
inline C10_HOST_DEVICE double operator*(double a, Float8_e4m3fnuz b) {
return a * static_cast<double>(b);
}
inline C10_HOST_DEVICE double operator/(double a, Float8_e4m3fnuz b)
__ubsan_ignore_float_divide_by_zero__ {
return a / static_cast<double>(b);
}
/// Arithmetic with ints
inline C10_HOST_DEVICE Float8_e4m3fnuz operator+(Float8_e4m3fnuz a, int b) {
return a + static_cast<Float8_e4m3fnuz>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(Float8_e4m3fnuz a, int b) {
return a - static_cast<Float8_e4m3fnuz>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fnuz operator*(Float8_e4m3fnuz a, int b) {
return a * static_cast<Float8_e4m3fnuz>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(Float8_e4m3fnuz a, int b) {
return a / static_cast<Float8_e4m3fnuz>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fnuz operator+(int a, Float8_e4m3fnuz b) {
return static_cast<Float8_e4m3fnuz>(a) + b;
}
inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(int a, Float8_e4m3fnuz b) {
return static_cast<Float8_e4m3fnuz>(a) - b;
}
inline C10_HOST_DEVICE Float8_e4m3fnuz operator*(int a, Float8_e4m3fnuz b) {
return static_cast<Float8_e4m3fnuz>(a) * b;
}
inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(int a, Float8_e4m3fnuz b) {
return static_cast<Float8_e4m3fnuz>(a) / b;
}
//// Arithmetic with int64_t
inline C10_HOST_DEVICE Float8_e4m3fnuz operator+(Float8_e4m3fnuz a, int64_t b) {
return a + static_cast<Float8_e4m3fnuz>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(Float8_e4m3fnuz a, int64_t b) {
return a - static_cast<Float8_e4m3fnuz>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fnuz operator*(Float8_e4m3fnuz a, int64_t b) {
return a * static_cast<Float8_e4m3fnuz>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(Float8_e4m3fnuz a, int64_t b) {
return a / static_cast<Float8_e4m3fnuz>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fnuz operator+(int64_t a, Float8_e4m3fnuz b) {
return static_cast<Float8_e4m3fnuz>(a) + b;
}
inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(int64_t a, Float8_e4m3fnuz b) {
return static_cast<Float8_e4m3fnuz>(a) - b;
}
inline C10_HOST_DEVICE Float8_e4m3fnuz operator*(int64_t a, Float8_e4m3fnuz b) {
return static_cast<Float8_e4m3fnuz>(a) * b;
}
inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(int64_t a, Float8_e4m3fnuz b) {
return static_cast<Float8_e4m3fnuz>(a) / b;
}
/// NOTE: we do not define comparisons directly and instead rely on the implicit
/// conversion from c10::Float8_e4m3fnuz to float.
} // namespace c10
namespace std {
template <>
class numeric_limits<c10::Float8_e4m3fnuz> {
public:
static constexpr bool is_specialized = true;
static constexpr bool is_signed = true;
static constexpr bool is_integer = false;
static constexpr bool is_exact = false;
static constexpr bool has_infinity = false;
static constexpr bool has_quiet_NaN = true;
static constexpr bool has_signaling_NaN = false;
static constexpr auto has_denorm = true;
static constexpr auto has_denorm_loss = true;
static constexpr auto round_style = numeric_limits<float>::round_style;
static constexpr bool is_iec559 = false;
static constexpr bool is_bounded = true;
static constexpr bool is_modulo = false;
static constexpr int digits = 4;
static constexpr int digits10 = 0;
static constexpr int max_digits10 = 3;
static constexpr int radix = 2;
static constexpr int min_exponent = -6;
static constexpr int min_exponent10 = -1;
static constexpr int max_exponent = 8;
static constexpr int max_exponent10 = 2;
static constexpr auto traps = numeric_limits<float>::traps;
static constexpr auto tinyness_before = false;
static constexpr c10::Float8_e4m3fnuz min() {
return c10::Float8_e4m3fnuz(0x08, c10::Float8_e4m3fnuz::from_bits());
}
static constexpr c10::Float8_e4m3fnuz lowest() {
return c10::Float8_e4m3fnuz(0xFF, c10::Float8_e4m3fnuz::from_bits());
}
static constexpr c10::Float8_e4m3fnuz max() {
return c10::Float8_e4m3fnuz(0x7F, c10::Float8_e4m3fnuz::from_bits());
}
static constexpr c10::Float8_e4m3fnuz epsilon() {
return c10::Float8_e4m3fnuz(0x28, c10::Float8_e4m3fnuz::from_bits());
}
static constexpr c10::Float8_e4m3fnuz round_error() {
return c10::Float8_e4m3fnuz(0x38, c10::Float8_e4m3fnuz::from_bits());
}
static constexpr c10::Float8_e4m3fnuz infinity() {
// NaN (no infinities)
return c10::Float8_e4m3fnuz(0x80, c10::Float8_e4m3fnuz::from_bits());
}
static constexpr c10::Float8_e4m3fnuz quiet_NaN() {
return c10::Float8_e4m3fnuz(0x80, c10::Float8_e4m3fnuz::from_bits());
}
static constexpr c10::Float8_e4m3fnuz denorm_min() {
return c10::Float8_e4m3fnuz(0x01, c10::Float8_e4m3fnuz::from_bits());
}
};
} // namespace std
C10_CLANG_DIAGNOSTIC_POP()
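// A minimal sanity sketch for the fnuz limits above (illustrative only, not
// part of the c10 API): bias 8 shifts the range down relative to e4m3fn, and
// 0x80 is the single NaN encoding shared by infinity() and quiet_NaN():
inline bool float8_e4m3fnuz_limits_sketch() {
  using limits = std::numeric_limits<c10::Float8_e4m3fnuz>;
  return static_cast<float>(limits::max()) == 240.0f && // 0x7F
      static_cast<float>(limits::lowest()) == -240.0f && // 0xFF
      static_cast<float>(limits::min()) == 0.0078125f && // 0x08 = 2^-7
      limits::infinity().isnan(); // "infinity" is really the NaN byte 0x80
}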

View File

@ -0,0 +1,139 @@
#pragma once
/// Defines the Float8_e4m3fnuz type (8-bit floating-point) including
/// conversions to standard C types and basic arithmetic operations. Note that
/// arithmetic operations are implemented by converting to floating point and
/// performing the operation in float32.
/// Binary configuration remains the same as Float8_e4m3fn:
/// s eeee mmm
/// 1 sign bit
/// 4 exponent bits
/// 3 mantissa bits
/// The key differences versus Float8_e4m3fn are:
/// bias = 8
/// no infinities or negative zero
/// NaN only when sign bit is 1, rest all 0s
///
/// Implementation based on the paper https://arxiv.org/pdf/2206.02915.pdf and
/// the existing Float8_e4m3fn implementation.
#include <c10/macros/Export.h>
#include <c10/macros/Macros.h>
#include <c10/util/floating_point_utils.h>
#include <type_traits>
#if defined(__cplusplus)
#include <cstdint>
#elif !defined(__OPENCL_VERSION__)
#include <math.h>
#include <stdint.h>
#endif
#include <iosfwd>
#include <ostream>
namespace c10 {
namespace detail {
/*
 * Convert a 32-bit floating-point number in IEEE single-precision format to an
* 8-bit floating-point number in fp8 E4M3FNUZ format, in bit representation.
*/
inline C10_HOST_DEVICE uint8_t fp8e4m3fnuz_from_fp32_value(float f) {
/*
* Binary representation of 256.0f, which is the first value not representable
 * (i.e. the first value which would overflow into the sign bit, resulting in
* a NaN) in fp8e4m3fnuz range:
* 1 0000 000 - fp8e4m3fnuz
* 0 10000111 00000000000000000000000 - fp32
*/
constexpr uint32_t fnuz_max = UINT32_C(0x87) << 23;
/*
* A mask for converting fp32 numbers lower than fp8e4m3fnuz normal range
* into denorm representation
* magic number: ((127 - 8) + (23 - 3) + 1)
*/
constexpr uint32_t denorm_mask = UINT32_C(0x8C) << 23;
uint32_t f_bits = fp32_to_bits(f);
uint32_t result = 0u;
/*
* Extract the sign of the input number into the high bit of the 32-bit word:
*
* +---+----------------------------------+
* | S |0000000 00000000 00000000 00000000|
* +---+----------------------------------+
* Bits 31 0-31
*/
const uint32_t sign = f_bits & UINT32_C(0x80000000);
/*
* Set sign bit to 0
*/
f_bits ^= sign;
if (f_bits >= fnuz_max) {
// NaN -- sign bit set to 1, rest 0s.
return 0x80;
}
if (f_bits < (UINT32_C(0x78) << 23) /* 2^-7 in float32 */) {
// Input exponent is less than -7, the smallest e4m3fnuz exponent, so the
// number will become subnormal.
f_bits = fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask));
result = static_cast<uint8_t>(f_bits - denorm_mask);
if (result == 0) {
// fnuz types don't have negative zero.
return 0;
}
} else {
// resulting mantissa is odd
uint8_t mant_odd = (f_bits >> 20) & 1;
// update exponent, rounding bias part 1
f_bits += ((uint32_t)(8 - 127) << 23) + 0x7FFFF;
// rounding bias part 2
f_bits += mant_odd;
// take the bits!
result = static_cast<uint8_t>(f_bits >> 20);
}
result |= sign >> 24;
return result;
}
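// A minimal worked-example sketch for the encoder above (illustrative only,
// not part of the c10 API), highlighting the fnuz-specific behavior:
inline bool fp8e4m3fnuz_encode_examples() {
  // 240.0f = 1.875 * 2^7 is the largest finite value: 0 1111 111 = 0x7F
  bool ok = fp8e4m3fnuz_from_fp32_value(240.0f) == 0x7F;
  // 256.0f overflows into the sign bit, yielding the single NaN byte 0x80
  ok = ok && fp8e4m3fnuz_from_fp32_value(256.0f) == 0x80;
  // fnuz formats have no negative zero: -0.0f collapses to +0.0 (0x00)
  ok = ok && fp8e4m3fnuz_from_fp32_value(-0.0f) == 0x00;
  return ok;
}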
} // namespace detail
struct alignas(1) Float8_e4m3fnuz {
uint8_t x;
struct from_bits_t {};
C10_HOST_DEVICE static constexpr from_bits_t from_bits() {
return from_bits_t();
}
Float8_e4m3fnuz() = default;
constexpr C10_HOST_DEVICE Float8_e4m3fnuz(uint8_t bits, from_bits_t)
: x(bits) {}
inline C10_HOST_DEVICE Float8_e4m3fnuz(float value);
inline C10_HOST_DEVICE operator float() const;
inline C10_HOST_DEVICE bool isnan() const;
};
C10_API inline std::ostream& operator<<(
std::ostream& out,
const Float8_e4m3fnuz& value) {
out << (float)value;
return out;
}
} // namespace c10
#include <c10/util/Float8_e4m3fnuz-inl.h> // IWYU pragma: keep

View File

@ -0,0 +1,286 @@
#pragma once
#include <c10/macros/Macros.h>
#include <cstring>
#include <limits>
C10_CLANG_DIAGNOSTIC_PUSH()
#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
#endif
#define EXP_WIDTH_FP8 5
#define MAN_WIDTH_FP8 2
#define EXP_BIAS_FP8 15
namespace c10 {
/// Constructors
inline C10_HOST_DEVICE Float8_e5m2::Float8_e5m2(float value)
: x(detail::fp8e5m2_from_fp32_value(value)) {}
/// Implicit conversions
inline C10_HOST_DEVICE Float8_e5m2::operator float() const {
return detail::fp8e5m2_to_fp32_value(x);
}
/// Special values helpers
inline C10_HOST_DEVICE bool Float8_e5m2::isnan() const {
return (x & 0b01111111) > 0b01111100;
}
inline C10_HOST_DEVICE bool Float8_e5m2::isinf() const {
return (x & 0b01111111) == 0b01111100;
}
/// Arithmetic
inline C10_HOST_DEVICE Float8_e5m2
operator+(const Float8_e5m2& a, const Float8_e5m2& b) {
return static_cast<float>(a) + static_cast<float>(b);
}
inline C10_HOST_DEVICE Float8_e5m2
operator-(const Float8_e5m2& a, const Float8_e5m2& b) {
return static_cast<float>(a) - static_cast<float>(b);
}
inline C10_HOST_DEVICE Float8_e5m2
operator*(const Float8_e5m2& a, const Float8_e5m2& b) {
return static_cast<float>(a) * static_cast<float>(b);
}
inline C10_HOST_DEVICE Float8_e5m2 operator/(
const Float8_e5m2& a,
const Float8_e5m2& b) __ubsan_ignore_float_divide_by_zero__ {
return static_cast<float>(a) / static_cast<float>(b);
}
inline C10_HOST_DEVICE Float8_e5m2 operator-(const Float8_e5m2& a) {
return -static_cast<float>(a);
}
inline C10_HOST_DEVICE Float8_e5m2& operator+=(
Float8_e5m2& a,
const Float8_e5m2& b) {
a = a + b;
return a;
}
inline C10_HOST_DEVICE Float8_e5m2& operator-=(
Float8_e5m2& a,
const Float8_e5m2& b) {
a = a - b;
return a;
}
inline C10_HOST_DEVICE Float8_e5m2& operator*=(
Float8_e5m2& a,
const Float8_e5m2& b) {
a = a * b;
return a;
}
inline C10_HOST_DEVICE Float8_e5m2& operator/=(
Float8_e5m2& a,
const Float8_e5m2& b) {
a = a / b;
return a;
}
/// Arithmetic with floats
inline C10_HOST_DEVICE float operator+(Float8_e5m2 a, float b) {
return static_cast<float>(a) + b;
}
inline C10_HOST_DEVICE float operator-(Float8_e5m2 a, float b) {
return static_cast<float>(a) - b;
}
inline C10_HOST_DEVICE float operator*(Float8_e5m2 a, float b) {
return static_cast<float>(a) * b;
}
inline C10_HOST_DEVICE float operator/(Float8_e5m2 a, float b)
__ubsan_ignore_float_divide_by_zero__ {
return static_cast<float>(a) / b;
}
inline C10_HOST_DEVICE float operator+(float a, Float8_e5m2 b) {
return a + static_cast<float>(b);
}
inline C10_HOST_DEVICE float operator-(float a, Float8_e5m2 b) {
return a - static_cast<float>(b);
}
inline C10_HOST_DEVICE float operator*(float a, Float8_e5m2 b) {
return a * static_cast<float>(b);
}
inline C10_HOST_DEVICE float operator/(float a, Float8_e5m2 b)
__ubsan_ignore_float_divide_by_zero__ {
return a / static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator+=(float& a, const Float8_e5m2& b) {
return a += static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator-=(float& a, const Float8_e5m2& b) {
return a -= static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator*=(float& a, const Float8_e5m2& b) {
return a *= static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator/=(float& a, const Float8_e5m2& b) {
return a /= static_cast<float>(b);
}
/// Arithmetic with doubles
inline C10_HOST_DEVICE double operator+(Float8_e5m2 a, double b) {
return static_cast<double>(a) + b;
}
inline C10_HOST_DEVICE double operator-(Float8_e5m2 a, double b) {
return static_cast<double>(a) - b;
}
inline C10_HOST_DEVICE double operator*(Float8_e5m2 a, double b) {
return static_cast<double>(a) * b;
}
inline C10_HOST_DEVICE double operator/(Float8_e5m2 a, double b)
__ubsan_ignore_float_divide_by_zero__ {
return static_cast<double>(a) / b;
}
inline C10_HOST_DEVICE double operator+(double a, Float8_e5m2 b) {
return a + static_cast<double>(b);
}
inline C10_HOST_DEVICE double operator-(double a, Float8_e5m2 b) {
return a - static_cast<double>(b);
}
inline C10_HOST_DEVICE double operator*(double a, Float8_e5m2 b) {
return a * static_cast<double>(b);
}
inline C10_HOST_DEVICE double operator/(double a, Float8_e5m2 b)
__ubsan_ignore_float_divide_by_zero__ {
return a / static_cast<double>(b);
}
/// Arithmetic with ints
inline C10_HOST_DEVICE Float8_e5m2 operator+(Float8_e5m2 a, int b) {
return a + static_cast<Float8_e5m2>(b);
}
inline C10_HOST_DEVICE Float8_e5m2 operator-(Float8_e5m2 a, int b) {
return a - static_cast<Float8_e5m2>(b);
}
inline C10_HOST_DEVICE Float8_e5m2 operator*(Float8_e5m2 a, int b) {
return a * static_cast<Float8_e5m2>(b);
}
inline C10_HOST_DEVICE Float8_e5m2 operator/(Float8_e5m2 a, int b) {
return a / static_cast<Float8_e5m2>(b);
}
inline C10_HOST_DEVICE Float8_e5m2 operator+(int a, Float8_e5m2 b) {
return static_cast<Float8_e5m2>(a) + b;
}
inline C10_HOST_DEVICE Float8_e5m2 operator-(int a, Float8_e5m2 b) {
return static_cast<Float8_e5m2>(a) - b;
}
inline C10_HOST_DEVICE Float8_e5m2 operator*(int a, Float8_e5m2 b) {
return static_cast<Float8_e5m2>(a) * b;
}
inline C10_HOST_DEVICE Float8_e5m2 operator/(int a, Float8_e5m2 b) {
return static_cast<Float8_e5m2>(a) / b;
}
//// Arithmetic with int64_t
inline C10_HOST_DEVICE Float8_e5m2 operator+(Float8_e5m2 a, int64_t b) {
return a + static_cast<Float8_e5m2>(b);
}
inline C10_HOST_DEVICE Float8_e5m2 operator-(Float8_e5m2 a, int64_t b) {
return a - static_cast<Float8_e5m2>(b);
}
inline C10_HOST_DEVICE Float8_e5m2 operator*(Float8_e5m2 a, int64_t b) {
return a * static_cast<Float8_e5m2>(b);
}
inline C10_HOST_DEVICE Float8_e5m2 operator/(Float8_e5m2 a, int64_t b) {
return a / static_cast<Float8_e5m2>(b);
}
inline C10_HOST_DEVICE Float8_e5m2 operator+(int64_t a, Float8_e5m2 b) {
return static_cast<Float8_e5m2>(a) + b;
}
inline C10_HOST_DEVICE Float8_e5m2 operator-(int64_t a, Float8_e5m2 b) {
return static_cast<Float8_e5m2>(a) - b;
}
inline C10_HOST_DEVICE Float8_e5m2 operator*(int64_t a, Float8_e5m2 b) {
return static_cast<Float8_e5m2>(a) * b;
}
inline C10_HOST_DEVICE Float8_e5m2 operator/(int64_t a, Float8_e5m2 b) {
return static_cast<Float8_e5m2>(a) / b;
}
/// NOTE: we do not define comparisons directly and instead rely on the implicit
/// conversion from c10::Float8_e5m2 to float.
} // namespace c10
namespace std {
template <>
class numeric_limits<c10::Float8_e5m2> {
public:
static constexpr bool is_signed = true;
static constexpr bool is_integer = false;
static constexpr bool is_specialized = true;
static constexpr bool is_exact = false;
static constexpr bool has_infinity = true;
static constexpr bool has_quiet_NaN = true;
static constexpr bool has_signaling_NaN = false;
static constexpr auto has_denorm = true;
static constexpr auto has_denorm_loss = true;
static constexpr auto round_style = numeric_limits<float>::round_style;
static constexpr bool is_iec559 = false;
static constexpr bool is_bounded = true;
static constexpr bool is_modulo = false;
static constexpr int digits = 3;
static constexpr int digits10 = 0;
static constexpr int max_digits10 = 2;
static constexpr int radix = 2;
static constexpr int min_exponent = -13;
static constexpr int min_exponent10 = -4;
static constexpr int max_exponent = 16;
static constexpr int max_exponent10 = 4;
static constexpr auto traps = numeric_limits<float>::traps;
static constexpr auto tinyness_before =
numeric_limits<float>::tinyness_before;
static constexpr c10::Float8_e5m2 min() {
return c10::Float8_e5m2(0x4, c10::Float8_e5m2::from_bits());
}
static constexpr c10::Float8_e5m2 max() {
return c10::Float8_e5m2(0x7B, c10::Float8_e5m2::from_bits());
}
static constexpr c10::Float8_e5m2 lowest() {
return c10::Float8_e5m2(0xFB, c10::Float8_e5m2::from_bits());
}
static constexpr c10::Float8_e5m2 epsilon() {
return c10::Float8_e5m2(0x34, c10::Float8_e5m2::from_bits());
}
static constexpr c10::Float8_e5m2 round_error() {
return c10::Float8_e5m2(0x38, c10::Float8_e5m2::from_bits());
}
static constexpr c10::Float8_e5m2 infinity() {
return c10::Float8_e5m2(0x7C, c10::Float8_e5m2::from_bits());
}
static constexpr c10::Float8_e5m2 quiet_NaN() {
return c10::Float8_e5m2(0x7F, c10::Float8_e5m2::from_bits());
}
static constexpr c10::Float8_e5m2 denorm_min() {
return c10::Float8_e5m2(0x01, c10::Float8_e5m2::from_bits());
}
};
} // namespace std
C10_CLANG_DIAGNOSTIC_POP()
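// A minimal classification sketch (illustrative only, not part of the c10
// API): with 5 exponent bits, e5m2 keeps IEEE-style infinities (exponent all
// ones, mantissa zero) and NaNs (exponent all ones, mantissa nonzero):
inline bool float8_e5m2_classify_sketch() {
  c10::Float8_e5m2 pos_inf(0x7C, c10::Float8_e5m2::from_bits()); // 0 11111 00
  c10::Float8_e5m2 nan(0x7E, c10::Float8_e5m2::from_bits()); // 0 11111 10
  return pos_inf.isinf() && !pos_inf.isnan() && nan.isnan() && !nan.isinf() &&
      static_cast<float>(std::numeric_limits<c10::Float8_e5m2>::max()) ==
      57344.0f; // 0x7B = 1.75 * 2^15
}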

View File

@ -0,0 +1,148 @@
#pragma once
/// Defines the Float8_e5m2 type (8-bit floating-point) including conversions
/// to standard C types and basic arithmetic operations. Note that arithmetic
/// operations are implemented by converting to floating point and
/// performing the operation in float32.
/// Binary configuration:
/// s eeeee mm
/// 1 sign bit
/// 5 exponent bits
/// 2 mantissa bits
/// bias = 15
///
/// Implementation based on the paper https://arxiv.org/pdf/2209.05433.pdf
/// and inspired by Half implementation from pytorch/c10/util/Half.h
#include <c10/util/Half.h>
namespace c10 {
namespace detail {
/*
 * Convert an 8-bit floating-point number in fp8 E5M2 format, in bit
* representation, to a 32-bit floating-point number in IEEE single-precision
* format, in bit representation.
*
* @note The implementation doesn't use any floating-point operations.
*/
inline C10_HOST_DEVICE float fp8e5m2_to_fp32_value(uint8_t input) {
/*
 * Extend the fp8 E5M2 number to 16 bits and shift to the
 * upper part of the 16-bit word:
 * +---+-----+--+---------+
 * | S |EEEEE|MM|0000 0000|
 * +---+-----+--+---------+
 * Bits 15 10-14 8-9 0-7
 *
 * Because fp8 E5M2 uses the same exponent width (5 bits) and bias (15) as
 * IEEE half precision, the shifted pattern is a valid fp16 number with the
 * same value, so the existing fp16 -> fp32 conversion can be reused.
*
* S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0
* - zero bits.
*/
uint16_t half_representation = input;
half_representation <<= 8;
return fp16_ieee_to_fp32_value(half_representation);
}
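// A minimal worked-example sketch for the decoder above (illustrative only,
// not part of the c10 API): shifting left by 8 yields a valid fp16 pattern.
inline bool fp8e5m2_decode_examples() {
  // 0x3C = 0 01111 00 -> fp16 0x3C00, i.e. 1.0
  bool ok = fp8e5m2_to_fp32_value(0x3C) == 1.0f;
  // 0x7C = 0 11111 00 -> fp16 0x7C00, i.e. +infinity
  ok = ok && std::isinf(fp8e5m2_to_fp32_value(0x7C));
  return ok;
}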
/*
 * Convert a 32-bit floating-point number in IEEE single-precision format to an
* 8-bit floating-point number in fp8 E5M2 format, in bit representation.
*/
inline C10_HOST_DEVICE uint8_t fp8e5m2_from_fp32_value(float f) {
/*
* Binary representation of fp32 infinity
* 0 11111111 00000000000000000000000
*/
constexpr uint32_t fp32_inf = UINT32_C(255) << 23;
/*
* Binary representation of 65536.0f, which is the first value
* not representable in fp8e5m2 range:
* 0 11111 00 - fp8e5m2
* 0 10001111 00000000000000000000000 - fp32
*/
constexpr uint32_t fp8_max = UINT32_C(143) << 23;
/*
* A mask for converting fp32 numbers lower than fp8e5m2 normal range
* into denorm representation
* magic number: ((127 - 15) + (23 - 2) + 1)
*/
constexpr uint32_t denorm_mask = UINT32_C(134) << 23;
uint32_t f_bits = fp32_to_bits(f);
uint8_t result = 0u;
/*
* Extract the sign of the input number into the high bit of the 32-bit word:
*
* +---+----------------------------------+
* | S |0000000 00000000 00000000 00000000|
* +---+----------------------------------+
* Bits 31 0-31
*/
const uint32_t sign = f_bits & UINT32_C(0x80000000);
/*
* Set sign bit to 0
*/
f_bits ^= sign;
if (f_bits >= fp8_max) {
    // NaN inputs (f_bits > fp32_inf) map to 0x7F (all exponent and mantissa
    // bits set); other out-of-range magnitudes saturate to infinity (0x7C)
result = f_bits > fp32_inf ? UINT8_C(0x7F) : UINT8_C(0x7C);
} else {
if (f_bits < (UINT32_C(113) << 23)) {
// Input number is smaller than 2^(-14), which is the smallest
// fp8e5m2 normal number
f_bits =
fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask));
result = static_cast<uint8_t>(f_bits - denorm_mask);
} else {
// resulting mantissa is odd
uint32_t mant_odd = (f_bits >> 21) & 1;
// update exponent, rounding bias part 1
f_bits += ((uint32_t)(15 - 127) << 23) + 0xFFFFF;
// rounding bias part 2
f_bits += mant_odd;
// take the bits!
result = static_cast<uint8_t>(f_bits >> 21);
}
}
result |= static_cast<uint8_t>(sign >> 24);
return result;
}
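// A minimal worked-example sketch for the encoder above (illustrative only,
// not part of the c10 API): unlike e4m3fn, e5m2 keeps real infinities, so
// finite overflow saturates to infinity rather than NaN:
inline bool fp8e5m2_encode_examples() {
  // 57344.0f = 1.75 * 2^15 is the largest finite value: 0 11110 11 = 0x7B
  bool ok = fp8e5m2_from_fp32_value(57344.0f) == 0x7B;
  // 65536.0f is the first unrepresentable magnitude -> +infinity (0x7C)
  ok = ok && fp8e5m2_from_fp32_value(65536.0f) == 0x7C;
  // NaN inputs keep a nonzero mantissa -> 0x7F
  ok = ok && fp8e5m2_from_fp32_value(std::nanf("")) == 0x7F;
  return ok;
}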
} // namespace detail
struct alignas(1) Float8_e5m2 {
uint8_t x;
struct from_bits_t {};
C10_HOST_DEVICE static constexpr from_bits_t from_bits() {
return from_bits_t();
}
Float8_e5m2() = default;
constexpr C10_HOST_DEVICE Float8_e5m2(uint8_t bits, from_bits_t) : x(bits) {}
inline C10_HOST_DEVICE Float8_e5m2(float value);
inline C10_HOST_DEVICE operator float() const;
inline C10_HOST_DEVICE bool isnan() const;
inline C10_HOST_DEVICE bool isinf() const;
};
C10_API inline std::ostream& operator<<(
std::ostream& out,
const Float8_e5m2& value) {
out << (float)value;
return out;
}
} // namespace c10
#include <c10/util/Float8_e5m2-inl.h> // IWYU pragma: keep

View File

@ -0,0 +1,285 @@
#pragma once
#include <c10/macros/Macros.h>
#include <c10/util/Float8_fnuz_cvt.h>
#include <cstring>
#include <limits>
C10_CLANG_DIAGNOSTIC_PUSH()
#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
#endif
namespace c10 {
/// Constructors
inline C10_HOST_DEVICE Float8_e5m2fnuz::Float8_e5m2fnuz(float value)
: x(detail::fp8e5m2fnuz_from_fp32_value(value)) {}
/// Implicit conversions
inline C10_HOST_DEVICE Float8_e5m2fnuz::operator float() const {
return detail::fp8_fnuz_to_fp32_value<5, 2>(x);
}
/// Special values helpers
inline C10_HOST_DEVICE bool Float8_e5m2fnuz::isnan() const {
return x == 0b10000000;
}
inline C10_HOST_DEVICE bool Float8_e5m2fnuz::isinf() const {
return false;
}
/// Arithmetic
inline C10_HOST_DEVICE Float8_e5m2fnuz
operator+(const Float8_e5m2fnuz& a, const Float8_e5m2fnuz& b) {
return static_cast<float>(a) + static_cast<float>(b);
}
inline C10_HOST_DEVICE Float8_e5m2fnuz
operator-(const Float8_e5m2fnuz& a, const Float8_e5m2fnuz& b) {
return static_cast<float>(a) - static_cast<float>(b);
}
inline C10_HOST_DEVICE Float8_e5m2fnuz
operator*(const Float8_e5m2fnuz& a, const Float8_e5m2fnuz& b) {
return static_cast<float>(a) * static_cast<float>(b);
}
inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(
const Float8_e5m2fnuz& a,
const Float8_e5m2fnuz& b) __ubsan_ignore_float_divide_by_zero__ {
return static_cast<float>(a) / static_cast<float>(b);
}
inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(const Float8_e5m2fnuz& a) {
return -static_cast<float>(a);
}
inline C10_HOST_DEVICE Float8_e5m2fnuz& operator+=(
Float8_e5m2fnuz& a,
const Float8_e5m2fnuz& b) {
a = a + b;
return a;
}
inline C10_HOST_DEVICE Float8_e5m2fnuz& operator-=(
Float8_e5m2fnuz& a,
const Float8_e5m2fnuz& b) {
a = a - b;
return a;
}
inline C10_HOST_DEVICE Float8_e5m2fnuz& operator*=(
Float8_e5m2fnuz& a,
const Float8_e5m2fnuz& b) {
a = a * b;
return a;
}
inline C10_HOST_DEVICE Float8_e5m2fnuz& operator/=(
Float8_e5m2fnuz& a,
const Float8_e5m2fnuz& b) {
a = a / b;
return a;
}
/// Arithmetic with floats
inline C10_HOST_DEVICE float operator+(Float8_e5m2fnuz a, float b) {
return static_cast<float>(a) + b;
}
inline C10_HOST_DEVICE float operator-(Float8_e5m2fnuz a, float b) {
return static_cast<float>(a) - b;
}
inline C10_HOST_DEVICE float operator*(Float8_e5m2fnuz a, float b) {
return static_cast<float>(a) * b;
}
inline C10_HOST_DEVICE float operator/(Float8_e5m2fnuz a, float b)
__ubsan_ignore_float_divide_by_zero__ {
return static_cast<float>(a) / b;
}
inline C10_HOST_DEVICE float operator+(float a, Float8_e5m2fnuz b) {
return a + static_cast<float>(b);
}
inline C10_HOST_DEVICE float operator-(float a, Float8_e5m2fnuz b) {
return a - static_cast<float>(b);
}
inline C10_HOST_DEVICE float operator*(float a, Float8_e5m2fnuz b) {
return a * static_cast<float>(b);
}
inline C10_HOST_DEVICE float operator/(float a, Float8_e5m2fnuz b)
__ubsan_ignore_float_divide_by_zero__ {
return a / static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator+=(float& a, const Float8_e5m2fnuz& b) {
return a += static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator-=(float& a, const Float8_e5m2fnuz& b) {
return a -= static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator*=(float& a, const Float8_e5m2fnuz& b) {
return a *= static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator/=(float& a, const Float8_e5m2fnuz& b) {
return a /= static_cast<float>(b);
}
/// Arithmetic with doubles
inline C10_HOST_DEVICE double operator+(Float8_e5m2fnuz a, double b) {
return static_cast<double>(a) + b;
}
inline C10_HOST_DEVICE double operator-(Float8_e5m2fnuz a, double b) {
return static_cast<double>(a) - b;
}
inline C10_HOST_DEVICE double operator*(Float8_e5m2fnuz a, double b) {
return static_cast<double>(a) * b;
}
inline C10_HOST_DEVICE double operator/(Float8_e5m2fnuz a, double b)
__ubsan_ignore_float_divide_by_zero__ {
return static_cast<double>(a) / b;
}
inline C10_HOST_DEVICE double operator+(double a, Float8_e5m2fnuz b) {
return a + static_cast<double>(b);
}
inline C10_HOST_DEVICE double operator-(double a, Float8_e5m2fnuz b) {
return a - static_cast<double>(b);
}
inline C10_HOST_DEVICE double operator*(double a, Float8_e5m2fnuz b) {
return a * static_cast<double>(b);
}
inline C10_HOST_DEVICE double operator/(double a, Float8_e5m2fnuz b)
__ubsan_ignore_float_divide_by_zero__ {
return a / static_cast<double>(b);
}
/// Arithmetic with ints
inline C10_HOST_DEVICE Float8_e5m2fnuz operator+(Float8_e5m2fnuz a, int b) {
return a + static_cast<Float8_e5m2fnuz>(b);
}
inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(Float8_e5m2fnuz a, int b) {
return a - static_cast<Float8_e5m2fnuz>(b);
}
inline C10_HOST_DEVICE Float8_e5m2fnuz operator*(Float8_e5m2fnuz a, int b) {
return a * static_cast<Float8_e5m2fnuz>(b);
}
inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(Float8_e5m2fnuz a, int b) {
return a / static_cast<Float8_e5m2fnuz>(b);
}
inline C10_HOST_DEVICE Float8_e5m2fnuz operator+(int a, Float8_e5m2fnuz b) {
return static_cast<Float8_e5m2fnuz>(a) + b;
}
inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(int a, Float8_e5m2fnuz b) {
return static_cast<Float8_e5m2fnuz>(a) - b;
}
inline C10_HOST_DEVICE Float8_e5m2fnuz operator*(int a, Float8_e5m2fnuz b) {
return static_cast<Float8_e5m2fnuz>(a) * b;
}
inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(int a, Float8_e5m2fnuz b) {
return static_cast<Float8_e5m2fnuz>(a) / b;
}
//// Arithmetic with int64_t
inline C10_HOST_DEVICE Float8_e5m2fnuz operator+(Float8_e5m2fnuz a, int64_t b) {
return a + static_cast<Float8_e5m2fnuz>(b);
}
inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(Float8_e5m2fnuz a, int64_t b) {
return a - static_cast<Float8_e5m2fnuz>(b);
}
inline C10_HOST_DEVICE Float8_e5m2fnuz operator*(Float8_e5m2fnuz a, int64_t b) {
return a * static_cast<Float8_e5m2fnuz>(b);
}
inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(Float8_e5m2fnuz a, int64_t b) {
return a / static_cast<Float8_e5m2fnuz>(b);
}
inline C10_HOST_DEVICE Float8_e5m2fnuz operator+(int64_t a, Float8_e5m2fnuz b) {
return static_cast<Float8_e5m2fnuz>(a) + b;
}
inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(int64_t a, Float8_e5m2fnuz b) {
return static_cast<Float8_e5m2fnuz>(a) - b;
}
inline C10_HOST_DEVICE Float8_e5m2fnuz operator*(int64_t a, Float8_e5m2fnuz b) {
return static_cast<Float8_e5m2fnuz>(a) * b;
}
inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(int64_t a, Float8_e5m2fnuz b) {
return static_cast<Float8_e5m2fnuz>(a) / b;
}
/// NOTE: we do not define comparisons directly and instead rely on the implicit
/// conversion from c10::Float8_e5m2fnuz to float.
} // namespace c10
namespace std {
template <>
class numeric_limits<c10::Float8_e5m2fnuz> {
public:
static constexpr bool is_signed = true;
static constexpr bool is_integer = false;
static constexpr bool is_specialized = true;
static constexpr bool is_exact = false;
static constexpr bool has_infinity = false;
static constexpr bool has_quiet_NaN = true;
static constexpr bool has_signaling_NaN = false;
static constexpr auto has_denorm = true;
static constexpr auto has_denorm_loss = true;
static constexpr auto round_style = numeric_limits<float>::round_style;
static constexpr bool is_iec559 = false;
static constexpr bool is_bounded = true;
static constexpr bool is_modulo = false;
static constexpr int digits = 3;
static constexpr int digits10 = 0;
static constexpr int max_digits10 = 2;
static constexpr int radix = 2;
static constexpr int min_exponent = -14;
static constexpr int min_exponent10 = -4;
static constexpr int max_exponent = 16;
static constexpr int max_exponent10 = 4;
static constexpr auto traps = numeric_limits<float>::traps;
static constexpr auto tinyness_before =
numeric_limits<float>::tinyness_before;
static constexpr c10::Float8_e5m2fnuz min() {
return c10::Float8_e5m2fnuz(0x04, c10::Float8_e5m2fnuz::from_bits());
}
static constexpr c10::Float8_e5m2fnuz max() {
return c10::Float8_e5m2fnuz(0x7F, c10::Float8_e5m2fnuz::from_bits());
}
static constexpr c10::Float8_e5m2fnuz lowest() {
return c10::Float8_e5m2fnuz(0xFF, c10::Float8_e5m2fnuz::from_bits());
}
static constexpr c10::Float8_e5m2fnuz epsilon() {
return c10::Float8_e5m2fnuz(0x34, c10::Float8_e5m2fnuz::from_bits());
}
static constexpr c10::Float8_e5m2fnuz round_error() {
return c10::Float8_e5m2fnuz(0x38, c10::Float8_e5m2fnuz::from_bits());
}
static constexpr c10::Float8_e5m2fnuz infinity() {
return c10::Float8_e5m2fnuz(0x80, c10::Float8_e5m2fnuz::from_bits());
}
// TODO(future): we are mapping neg_zero to both inf and NaN, this is
// surprising and we should figure out what to do about it.
static constexpr c10::Float8_e5m2fnuz quiet_NaN() {
return c10::Float8_e5m2fnuz(0x80, c10::Float8_e5m2fnuz::from_bits());
}
static constexpr c10::Float8_e5m2fnuz denorm_min() {
return c10::Float8_e5m2fnuz(0x01, c10::Float8_e5m2fnuz::from_bits());
}
};
} // namespace std
C10_CLANG_DIAGNOSTIC_POP()
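// A minimal special-values sketch (illustrative only, not part of the c10
// API): per the TODO above, 0x80 is returned by both infinity() and
// quiet_NaN(), and isinf() is always false since fnuz formats have no
// infinities:
inline bool float8_e5m2fnuz_nan_sketch() {
  auto nan = std::numeric_limits<c10::Float8_e5m2fnuz>::quiet_NaN();
  return nan.isnan() && !nan.isinf() &&
      static_cast<float>(std::numeric_limits<c10::Float8_e5m2fnuz>::max()) ==
      57344.0f; // 0x7F = 1.75 * 2^(31 - 16)
}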

View File

@ -0,0 +1,138 @@
#pragma once
/// Defines the Float8_e5m2fnuz type (8-bit floating-point) including
/// conversions to standard C types and basic arithmetic operations. Note that
/// arithmetic operations are implemented by converting to floating point and
/// performing the operation in float32.
/// Binary configuration remains the same as e5m2:
/// s eeeee mm
/// 1 sign bit
/// 5 exponent bits
/// 2 mantissa bits
/// The key differences that e5m2fnuz brings are:
/// bias = 16
/// no infinities or negative zero
/// NaN only when sign bit is 1, rest all 0s
///
/// Implementation based on the paper https://arxiv.org/pdf/2206.02915.pdf and
/// the existing Float8_e4m3fn implementation.
#include <c10/macros/Macros.h>
#include <c10/util/TypeSafeSignMath.h>
#include <c10/util/floating_point_utils.h>
#if defined(__cplusplus)
#include <cstdint>
#elif !defined(__OPENCL_VERSION__)
#include <math.h>
#include <stdint.h>
#endif
#include <iosfwd>
#include <ostream>
namespace c10 {
namespace detail {
/*
 * Convert a 32-bit floating-point number in IEEE single-precision format to an
 * 8-bit floating-point number in fp8 E5M2FNUZ format, in bit representation.
*/
inline C10_HOST_DEVICE uint8_t fp8e5m2fnuz_from_fp32_value(float f) {
/*
* Binary representation of 65536.0f, which is the first value not
 * representable (i.e. the first value which would overflow into the sign
 * bit, resulting in a NaN) in fp8e5m2fnuz range:
* 1 00000 00 - fp8e5m2fnuz
* 0 10001111 00000000000000000000000 - fp32
*/
constexpr uint32_t fnuz_max = UINT32_C(0x8F) << 23;
/*
* A mask for converting fp32 numbers lower than fp8e5m2fnuz normal range
* into denormalized representation.
* magic number: ((127 - 16) + (23 - 2) + 1)
*/
constexpr uint32_t denorm_mask = UINT32_C(0x85) << 23;
uint32_t f_bits = fp32_to_bits(f);
uint32_t result = 0u;
/*
* Extract the sign of the input number into the high bit of the 32-bit word:
*
* +---+----------------------------------+
* | S |0000000 00000000 00000000 00000000|
* +---+----------------------------------+
* Bits 31 0-31
*/
const uint32_t sign = f_bits & UINT32_C(0x80000000);
/*
* Set sign bit to 0
*/
f_bits ^= sign;
if (f_bits >= fnuz_max) {
// NaN -- sign bit set to 1, rest 0s
return 0x80;
}
if (f_bits < (UINT32_C(0x70) << 23) /* 2^-15 in float32 */) {
// Input exponent is less than -15, the smallest e5m2fnuz exponent, so the
// number will become subnormal.
f_bits = fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask));
result = static_cast<uint8_t>(f_bits - denorm_mask);
if (result == 0) {
// fnuz types don't have negative zero.
return 0;
}
} else {
// resulting mantissa is odd
uint8_t mant_odd = (f_bits >> 21) & 1;
// update exponent, rounding bias part 1
f_bits += ((uint32_t)(16 - 127) << 23) + 0xFFFFF;
// rounding bias part 2
f_bits += mant_odd;
// take the bits!
result = static_cast<uint8_t>(f_bits >> 21);
}
result |= sign >> 24;
return result;
}
} // namespace detail
struct alignas(1) Float8_e5m2fnuz {
uint8_t x;
struct from_bits_t {};
C10_HOST_DEVICE static constexpr from_bits_t from_bits() {
return from_bits_t();
}
Float8_e5m2fnuz() = default;
constexpr C10_HOST_DEVICE Float8_e5m2fnuz(uint8_t bits, from_bits_t)
: x(bits) {}
inline C10_HOST_DEVICE Float8_e5m2fnuz(float value);
inline C10_HOST_DEVICE operator float() const;
inline C10_HOST_DEVICE bool isnan() const;
inline C10_HOST_DEVICE bool isinf() const;
};
C10_API inline std::ostream& operator<<(
std::ostream& out,
const Float8_e5m2fnuz& value) {
out << (float)value;
return out;
}
} // namespace c10
#include <c10/util/Float8_e5m2fnuz-inl.h> // IWYU pragma: keep

View File

@ -0,0 +1,64 @@
#pragma once
#include <c10/util/floating_point_utils.h>
#include <cstdint>
#if defined(SYCL_LANGUAGE_VERSION)
#include <sycl/sycl.hpp>
#endif
namespace c10::detail {
/*
 * Convert an 8-bit floating-point number in either f8 E4M3FNUZ or bf8 E5M2FNUZ
* format, in bit representation, to a 32-bit floating-point number.
*/
template <uint32_t we, uint32_t wm>
inline C10_HOST_DEVICE float fp8_fnuz_to_fp32_value(uint8_t x) {
static_assert((we == 4 && wm == 3) || (we == 5 && wm == 2));
constexpr uint32_t weo = 8;
constexpr uint32_t wmo = 23;
if (x == 0) {
return 0;
}
if (x == 0x80) {
constexpr uint32_t ifNaN = 0x7F800001;
return fp32_from_bits(ifNaN);
}
uint32_t mantissa = x & ((1 << wm) - 1);
uint32_t exponent = (x & 0x7F) >> wm;
// subnormal input
if (exponent == 0) {
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
uint32_t renorm_shift = __clz(mantissa);
#elif defined(__SYCL_DEVICE_ONLY__)
uint32_t renorm_shift = sycl::clz(mantissa);
#elif defined(_MSC_VER)
unsigned long nonsign_bsr;
_BitScanReverse(&nonsign_bsr, (unsigned long)mantissa);
uint32_t renorm_shift = (uint32_t)nonsign_bsr ^ 31;
#else
uint32_t renorm_shift = __builtin_clz(mantissa);
#endif
uint32_t sh = 1 + renorm_shift - (32 - wm);
mantissa <<= sh;
exponent += 1 - sh;
mantissa &= ((1 << wm) - 1);
}
const uint32_t exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1));
exponent += exp_low_cutoff - 1;
mantissa <<= wmo - wm;
uint32_t sign = x >> 7;
uint32_t retval = (sign << 31) | (exponent << 23) | mantissa;
return fp32_from_bits(retval);
}
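// A minimal worked-example sketch for the template above (illustrative only,
// not part of the c10 API), instantiated for both supported layouts:
inline bool fp8_fnuz_decode_examples() {
  // e4m3fnuz: 0x08 = 0 0001 000, exponent 1, bias 8 -> 2^-7
  bool ok = fp8_fnuz_to_fp32_value<4, 3>(0x08) == 0.0078125f;
  // e5m2fnuz: 0x04 = 0 00001 00, exponent 1, bias 16 -> 2^-15
  ok = ok && fp8_fnuz_to_fp32_value<5, 2>(0x04) == 0.000030517578125f;
  // 0x80 decodes to NaN in both layouts (checked via NaN self-inequality)
  float nan_val = fp8_fnuz_to_fp32_value<4, 3>(0x80);
  ok = ok && nan_val != nan_val;
  return ok;
}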
} // namespace c10::detail

View File

@ -0,0 +1,73 @@
//===- llvm/ADT/STLExtras.h - Useful STL related functions ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains some templates that are useful if you are working with the
// STL at all.
//
// No library is required when using these functions.
//
//===----------------------------------------------------------------------===//
// c10: modified from llvm::function_ref
// c10: added more SFINAE to enable use in overloaded functions
#pragma once
#include <cstdint>
#include <type_traits>
#include <utility>
namespace c10 {
/// An efficient, type-erasing, non-owning reference to a callable. This is
/// intended for use as the type of a function parameter that is not used
/// after the function in question returns.
///
/// This class does not own the callable, so it is not in general safe to store
/// a function_ref.
template <typename Fn>
class function_ref;
template <typename Ret, typename... Params>
class function_ref<Ret(Params...)> {
Ret (*callback)(intptr_t callable, Params... params) = nullptr;
intptr_t callable{};
template <typename Callable>
static Ret callback_fn(intptr_t callable, Params... params) {
return (*reinterpret_cast<Callable*>(callable))(
std::forward<Params>(params)...);
}
public:
function_ref() = default;
function_ref(std::nullptr_t) {}
template <typename Callable>
function_ref(
// NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
Callable&& callable,
std::enable_if_t<
!std::is_same_v<std::remove_reference_t<Callable>, function_ref>>* =
nullptr,
std::enable_if_t<std::is_convertible_v<
typename std::invoke_result_t<Callable, Params...>,
Ret>>* = nullptr)
: callback(callback_fn<std::remove_reference_t<Callable>>),
callable(reinterpret_cast<intptr_t>(&callable)) {}
Ret operator()(Params... params) const {
return callback(callable, std::forward<Params>(params)...);
}
operator bool() const {
return callback;
}
};
} // namespace c10
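// A minimal usage sketch (illustrative only, not part of the c10 API):
// function_ref type-erases any callable without owning it, so it is cheap to
// pass by value but must not outlive the callable it refers to.
inline int apply_twice(c10::function_ref<int(int)> f, int x) {
  return f(f(x));
}
// e.g. apply_twice([](int v) { return v + 1; }, 3) == 5; the temporary lambda
// stays alive for the duration of the call, which is exactly the supported
// lifetime.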

View File

@ -0,0 +1,48 @@
#pragma once
#include <memory>
#include <string_view>
#include <c10/macros/Macros.h>
#include <c10/util/SmallVector.h>
namespace c10::monitor {
namespace detail {
class GaugeImpl;
class GaugeBackendIf {
public:
virtual ~GaugeBackendIf() = default;
virtual void record(int64_t value) noexcept = 0;
};
class GaugeBackendFactoryIf {
public:
virtual ~GaugeBackendFactoryIf() = default;
// May return nullptr if the gauge will be ignored by the given backend.
virtual std::unique_ptr<GaugeBackendIf> create(
std::string_view key) noexcept = 0;
};
void C10_API registerGaugeBackend(std::unique_ptr<GaugeBackendFactoryIf>);
} // namespace detail
// A handle to a Gauge.
class C10_API GaugeHandle {
public:
explicit GaugeHandle(std::string_view key);
void record(int64_t value);
private:
detail::GaugeImpl& impl_;
};
} // namespace c10::monitor
#define STATIC_GAUGE(_key) \
[]() -> ::c10::monitor::GaugeHandle& { \
static ::c10::monitor::GaugeHandle handle(#_key); \
return handle; \
}()
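// A minimal usage sketch (illustrative only; "queue_depth" is an arbitrary
// example key, not one defined by c10): STATIC_GAUGE expands to a
// function-local static GaugeHandle, so the handle is constructed once and
// reused on every call.
inline void example_record_queue_depth(int64_t current_depth) {
  STATIC_GAUGE(queue_depth).record(current_depth);
}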

View File

@ -0,0 +1,350 @@
#pragma once
#include <c10/macros/Macros.h>
#include <c10/util/bit_cast.h>
#include <cstring>
#include <limits>
#ifdef __CUDACC__
#include <cuda_fp16.h>
#endif
#ifdef __HIPCC__
#include <hip/hip_fp16.h>
#endif
#if defined(CL_SYCL_LANGUAGE_VERSION)
#include <CL/sycl.hpp> // for SYCL 1.2.1
#elif defined(SYCL_LANGUAGE_VERSION)
#include <sycl/sycl.hpp> // for SYCL 2020
#endif
#if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && \
!defined(__APPLE__)
#include <ATen/cpu/vec/vec_half.h>
#endif
C10_CLANG_DIAGNOSTIC_PUSH()
#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
#endif
namespace c10 {
#if defined(__aarch64__) && !defined(__CUDACC__)
/// Constructors
inline Half::Half(float16_t value) : x(detail::fp16_to_bits(value)) {}
inline Half::operator float16_t() const {
return detail::fp16_from_bits(x);
}
#else
inline C10_HOST_DEVICE Half::Half(float value)
:
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
x(__half_as_short(__float2half(value)))
#elif defined(__SYCL_DEVICE_ONLY__)
x(c10::bit_cast<uint16_t>(sycl::half(value)))
#elif (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && \
!defined(__APPLE__)
x(at::vec::float2half_scalar(value))
#else
x(detail::fp16_ieee_from_fp32_value(value))
#endif
{
}
/// Implicit conversions
inline C10_HOST_DEVICE Half::operator float() const {
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
return __half2float(*reinterpret_cast<const __half*>(&x));
#elif defined(__SYCL_DEVICE_ONLY__)
return float(c10::bit_cast<sycl::half>(x));
#elif (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && \
!defined(__APPLE__)
return at::vec::half2float_scalar(x);
#elif defined(__aarch64__) && !defined(__CUDACC__)
return detail::native_fp16_to_fp32_value(x);
#else
return detail::fp16_ieee_to_fp32_value(x);
#endif
}
#endif /* !defined(__aarch64__) || defined(__CUDACC__) \
*/
#if defined(__CUDACC__) || defined(__HIPCC__)
inline C10_HOST_DEVICE Half::Half(const __half& value) {
x = *reinterpret_cast<const unsigned short*>(&value);
}
inline C10_HOST_DEVICE Half::operator __half() const {
return *reinterpret_cast<const __half*>(&x);
}
#endif
#ifdef SYCL_LANGUAGE_VERSION
inline C10_HOST_DEVICE Half::Half(const sycl::half& value) {
x = *reinterpret_cast<const unsigned short*>(&value);
}
inline C10_HOST_DEVICE Half::operator sycl::half() const {
return *reinterpret_cast<const sycl::half*>(&x);
}
#endif
// CUDA intrinsics
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 350)) || \
(defined(__clang__) && defined(__CUDA__))
inline __device__ Half __ldg(const Half* ptr) {
return __ldg(reinterpret_cast<const __half*>(ptr));
}
#endif
/// Arithmetic
inline C10_HOST_DEVICE Half operator+(const Half& a, const Half& b) {
return static_cast<float>(a) + static_cast<float>(b);
}
inline C10_HOST_DEVICE Half operator-(const Half& a, const Half& b) {
return static_cast<float>(a) - static_cast<float>(b);
}
inline C10_HOST_DEVICE Half operator*(const Half& a, const Half& b) {
return static_cast<float>(a) * static_cast<float>(b);
}
inline C10_HOST_DEVICE Half operator/(const Half& a, const Half& b)
__ubsan_ignore_float_divide_by_zero__ {
return static_cast<float>(a) / static_cast<float>(b);
}
inline C10_HOST_DEVICE Half operator-(const Half& a) {
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \
defined(__HIP_DEVICE_COMPILE__)
return __hneg(a);
#elif defined(__SYCL_DEVICE_ONLY__)
return -c10::bit_cast<sycl::half>(a);
#else
return -static_cast<float>(a);
#endif
}
inline C10_HOST_DEVICE Half& operator+=(Half& a, const Half& b) {
a = a + b;
return a;
}
inline C10_HOST_DEVICE Half& operator-=(Half& a, const Half& b) {
a = a - b;
return a;
}
inline C10_HOST_DEVICE Half& operator*=(Half& a, const Half& b) {
a = a * b;
return a;
}
inline C10_HOST_DEVICE Half& operator/=(Half& a, const Half& b) {
a = a / b;
return a;
}
/// Arithmetic with floats
inline C10_HOST_DEVICE float operator+(Half a, float b) {
return static_cast<float>(a) + b;
}
inline C10_HOST_DEVICE float operator-(Half a, float b) {
return static_cast<float>(a) - b;
}
inline C10_HOST_DEVICE float operator*(Half a, float b) {
return static_cast<float>(a) * b;
}
inline C10_HOST_DEVICE float operator/(Half a, float b)
__ubsan_ignore_float_divide_by_zero__ {
return static_cast<float>(a) / b;
}
inline C10_HOST_DEVICE float operator+(float a, Half b) {
return a + static_cast<float>(b);
}
inline C10_HOST_DEVICE float operator-(float a, Half b) {
return a - static_cast<float>(b);
}
inline C10_HOST_DEVICE float operator*(float a, Half b) {
return a * static_cast<float>(b);
}
inline C10_HOST_DEVICE float operator/(float a, Half b)
__ubsan_ignore_float_divide_by_zero__ {
return a / static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator+=(float& a, const Half& b) {
return a += static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator-=(float& a, const Half& b) {
return a -= static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator*=(float& a, const Half& b) {
return a *= static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator/=(float& a, const Half& b) {
return a /= static_cast<float>(b);
}
/// Arithmetic with doubles
inline C10_HOST_DEVICE double operator+(Half a, double b) {
return static_cast<double>(a) + b;
}
inline C10_HOST_DEVICE double operator-(Half a, double b) {
return static_cast<double>(a) - b;
}
inline C10_HOST_DEVICE double operator*(Half a, double b) {
return static_cast<double>(a) * b;
}
inline C10_HOST_DEVICE double operator/(Half a, double b)
__ubsan_ignore_float_divide_by_zero__ {
return static_cast<double>(a) / b;
}
inline C10_HOST_DEVICE double operator+(double a, Half b) {
return a + static_cast<double>(b);
}
inline C10_HOST_DEVICE double operator-(double a, Half b) {
return a - static_cast<double>(b);
}
inline C10_HOST_DEVICE double operator*(double a, Half b) {
return a * static_cast<double>(b);
}
inline C10_HOST_DEVICE double operator/(double a, Half b)
__ubsan_ignore_float_divide_by_zero__ {
return a / static_cast<double>(b);
}
/// Arithmetic with ints
inline C10_HOST_DEVICE Half operator+(Half a, int b) {
return a + static_cast<Half>(b);
}
inline C10_HOST_DEVICE Half operator-(Half a, int b) {
return a - static_cast<Half>(b);
}
inline C10_HOST_DEVICE Half operator*(Half a, int b) {
return a * static_cast<Half>(b);
}
inline C10_HOST_DEVICE Half operator/(Half a, int b) {
return a / static_cast<Half>(b);
}
inline C10_HOST_DEVICE Half operator+(int a, Half b) {
return static_cast<Half>(a) + b;
}
inline C10_HOST_DEVICE Half operator-(int a, Half b) {
return static_cast<Half>(a) - b;
}
inline C10_HOST_DEVICE Half operator*(int a, Half b) {
return static_cast<Half>(a) * b;
}
inline C10_HOST_DEVICE Half operator/(int a, Half b) {
return static_cast<Half>(a) / b;
}
//// Arithmetic with int64_t
inline C10_HOST_DEVICE Half operator+(Half a, int64_t b) {
return a + static_cast<Half>(b);
}
inline C10_HOST_DEVICE Half operator-(Half a, int64_t b) {
return a - static_cast<Half>(b);
}
inline C10_HOST_DEVICE Half operator*(Half a, int64_t b) {
return a * static_cast<Half>(b);
}
inline C10_HOST_DEVICE Half operator/(Half a, int64_t b) {
return a / static_cast<Half>(b);
}
inline C10_HOST_DEVICE Half operator+(int64_t a, Half b) {
return static_cast<Half>(a) + b;
}
inline C10_HOST_DEVICE Half operator-(int64_t a, Half b) {
return static_cast<Half>(a) - b;
}
inline C10_HOST_DEVICE Half operator*(int64_t a, Half b) {
return static_cast<Half>(a) * b;
}
inline C10_HOST_DEVICE Half operator/(int64_t a, Half b) {
return static_cast<Half>(a) / b;
}
/// NOTE: we do not define comparisons directly and instead rely on the implicit
/// conversion from c10::Half to float.
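/// A minimal sketch of what this enables (illustrative, not part of the API):
///
///   c10::Half a = 1.0f, b = 2.0f; // via the implicit float constructor
///   bool lt = a < b;              // both operands convert to float
///   float s = a + 0.5f;           // mixed Half/float arithmetic yields float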
} // namespace c10
namespace std {
template <>
class numeric_limits<c10::Half> {
public:
static constexpr bool is_specialized = true;
static constexpr bool is_signed = true;
static constexpr bool is_integer = false;
static constexpr bool is_exact = false;
static constexpr bool has_infinity = true;
static constexpr bool has_quiet_NaN = true;
static constexpr bool has_signaling_NaN = true;
static constexpr auto has_denorm = numeric_limits<float>::has_denorm;
static constexpr auto has_denorm_loss =
numeric_limits<float>::has_denorm_loss;
static constexpr auto round_style = numeric_limits<float>::round_style;
static constexpr bool is_iec559 = true;
static constexpr bool is_bounded = true;
static constexpr bool is_modulo = false;
static constexpr int digits = 11;
static constexpr int digits10 = 3;
static constexpr int max_digits10 = 5;
static constexpr int radix = 2;
static constexpr int min_exponent = -13;
static constexpr int min_exponent10 = -4;
static constexpr int max_exponent = 16;
static constexpr int max_exponent10 = 4;
static constexpr auto traps = numeric_limits<float>::traps;
static constexpr auto tinyness_before =
numeric_limits<float>::tinyness_before;
static constexpr c10::Half min() {
return c10::Half(0x0400, c10::Half::from_bits());
}
static constexpr c10::Half lowest() {
return c10::Half(0xFBFF, c10::Half::from_bits());
}
static constexpr c10::Half max() {
return c10::Half(0x7BFF, c10::Half::from_bits());
}
static constexpr c10::Half epsilon() {
return c10::Half(0x1400, c10::Half::from_bits());
}
static constexpr c10::Half round_error() {
return c10::Half(0x3800, c10::Half::from_bits());
}
static constexpr c10::Half infinity() {
return c10::Half(0x7C00, c10::Half::from_bits());
}
static constexpr c10::Half quiet_NaN() {
return c10::Half(0x7E00, c10::Half::from_bits());
}
static constexpr c10::Half signaling_NaN() {
return c10::Half(0x7D00, c10::Half::from_bits());
}
static constexpr c10::Half denorm_min() {
return c10::Half(0x0001, c10::Half::from_bits());
}
};
} // namespace std
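// A quick sanity sketch of the encodings above (IEEE binary16): epsilon() is
// 0x1400 == 2**-10 == 0.0009765625, max() is 0x7BFF == 65504, lowest() is
// -65504, and denorm_min() is 0x0001 == 2**-24 ~= 5.96e-8.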
C10_CLANG_DIAGNOSTIC_POP()

View File

@ -0,0 +1,535 @@
#pragma once
/// Defines the Half type (half-precision floating-point) including conversions
/// to standard C types and basic arithmetic operations. Note that arithmetic
/// operations are implemented by converting to floating point and
/// performing the operation in float32, instead of using CUDA half intrinsics.
/// Most uses of this type within ATen are memory bound, including the
/// element-wise kernels, and the half intrinsics aren't efficient on all GPUs.
/// If you are writing a compute bound kernel, you can use the CUDA half
/// intrinsics directly on the Half type from device code.
#include <c10/macros/Export.h>
#include <c10/macros/Macros.h>
#include <c10/util/TypeSafeSignMath.h>
#include <c10/util/bit_cast.h>
#include <c10/util/complex.h>
#include <c10/util/floating_point_utils.h>
#include <type_traits>
#if defined(__cplusplus)
#include <cmath>
#elif !defined(__OPENCL_VERSION__)
#include <math.h>
#endif
#ifdef _MSC_VER
#include <intrin.h>
#endif
#include <cstdint>
#include <cstring>
#include <iosfwd>
#include <limits>
#include <ostream>
#ifdef __CUDACC__
#include <cuda_fp16.h>
#endif
#ifdef __HIPCC__
#include <hip/hip_fp16.h>
#endif
#if defined(CL_SYCL_LANGUAGE_VERSION)
#include <CL/sycl.hpp> // for SYCL 1.2.1
#elif defined(SYCL_LANGUAGE_VERSION)
#include <sycl/sycl.hpp> // for SYCL 2020
#endif
#if defined(__aarch64__) && !defined(__CUDACC__)
#include <arm_neon.h>
#endif
namespace c10 {
namespace detail {
/*
* Convert a 16-bit floating-point number in IEEE half-precision format, in bit
* representation, to a 32-bit floating-point number in IEEE single-precision
* format, in bit representation.
*
* @note The implementation doesn't use any floating-point operations.
*/
inline uint32_t fp16_ieee_to_fp32_bits(uint16_t h) {
/*
* Extend the half-precision floating-point number to 32 bits and shift to the
* upper part of the 32-bit word:
* +---+-----+------------+-------------------+
* | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000|
* +---+-----+------------+-------------------+
* Bits 31 26-30 16-25 0-15
*
* S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0
* - zero bits.
*/
const uint32_t w = (uint32_t)h << 16;
/*
* Extract the sign of the input number into the high bit of the 32-bit word:
*
* +---+----------------------------------+
* | S |0000000 00000000 00000000 00000000|
* +---+----------------------------------+
* Bits 31 0-31
*/
const uint32_t sign = w & UINT32_C(0x80000000);
/*
* Extract mantissa and biased exponent of the input number into the bits 0-30
* of the 32-bit word:
*
* +---+-----+------------+-------------------+
* | 0 |EEEEE|MM MMMM MMMM|0000 0000 0000 0000|
* +---+-----+------------+-------------------+
* Bits 30 27-31 17-26 0-16
*/
const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF);
/*
 * Renorm shift is the number of bits to shift the mantissa left to make the
 * half-precision number normalized. If the initial number is normalized, at
 * least one of its high 6 bits (sign == 0 and 5-bit exponent) equals one; in
 * this case renorm_shift == 0. If the number is denormalized, renorm_shift > 0.
 * Note that if we shift a denormalized nonsign by renorm_shift, the unit bit
 * of the mantissa will shift into the exponent, turning the biased exponent
 * into 1 and making the mantissa normalized (i.e. without a leading 1).
*/
#ifdef _MSC_VER
unsigned long nonsign_bsr;
_BitScanReverse(&nonsign_bsr, (unsigned long)nonsign);
uint32_t renorm_shift = (uint32_t)nonsign_bsr ^ 31;
#else
uint32_t renorm_shift = __builtin_clz(nonsign);
#endif
renorm_shift = renorm_shift > 5 ? renorm_shift - 5 : 0;
/*
 * Iff the half-precision number has an exponent of 15, the addition overflows
 * it into bit 31, and the subsequent shift turns the high 9 bits
 * into 1. Thus inf_nan_mask == 0x7F800000 if the half-precision number
 * had an exponent of 15 (i.e. was NaN or infinity), and 0x00000000 otherwise.
*/
const int32_t inf_nan_mask =
((int32_t)(nonsign + 0x04000000) >> 8) & INT32_C(0x7F800000);
/*
* Iff nonsign is 0, it overflows into 0xFFFFFFFF, turning bit 31
* into 1. Otherwise, bit 31 remains 0. The signed shift right by 31
* broadcasts bit 31 into all bits of the zero_mask. Thus zero_mask ==
 * 0xFFFFFFFF if the half-precision number was zero (+0.0h or -0.0h), and
 * 0x00000000 otherwise.
*/
const int32_t zero_mask = (int32_t)(nonsign - 1) >> 31;
/*
* 1. Shift nonsign left by renorm_shift to normalize it (if the input
* was denormal)
* 2. Shift nonsign right by 3 so the exponent (5 bits originally)
* becomes an 8-bit field and 10-bit mantissa shifts into the 10 high
* bits of the 23-bit mantissa of IEEE single-precision number.
 * 3. Add 0x70 to the exponent (starting at bit 23) to compensate for the
 * difference in exponent bias (0x7F for the single-precision number less
 * 0xF for the half-precision number).
* 4. Subtract renorm_shift from the exponent (starting at bit 23) to
* account for renormalization. As renorm_shift is less than 0x70, this
* can be combined with step 3.
* 5. Binary OR with inf_nan_mask to turn the exponent into 0xFF if the
* input was NaN or infinity.
* 6. Binary ANDNOT with zero_mask to turn the mantissa and exponent
* into zero if the input was zero.
* 7. Combine with the sign of the input number.
*/
return sign |
((((nonsign << renorm_shift >> 3) + ((0x70 - renorm_shift) << 23)) |
inf_nan_mask) &
~zero_mask);
}
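/*
 * Worked example (a sketch, not part of the API): for h == 0x3C00 (1.0 in
 * binary16), w == 0x3C000000, sign == 0, nonsign == 0x3C000000, and
 * renorm_shift == 0, so the result is (0x3C000000 >> 3) + (0x70 << 23) ==
 * 0x07800000 + 0x38000000 == 0x3F800000, which is 1.0f in binary32.
 */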
/*
* Convert a 16-bit floating-point number in IEEE half-precision format, in bit
* representation, to a 32-bit floating-point number in IEEE single-precision
* format.
*
* @note The implementation relies on IEEE-like (no assumption about rounding
* mode and no operations on denormals) floating-point operations and bitcasts
* between integer and floating-point variables.
*/
C10_HOST_DEVICE inline float fp16_ieee_to_fp32_value(uint16_t h) {
/*
* Extend the half-precision floating-point number to 32 bits and shift to the
* upper part of the 32-bit word:
* +---+-----+------------+-------------------+
* | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000|
* +---+-----+------------+-------------------+
* Bits 31 26-30 16-25 0-15
*
* S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0
* - zero bits.
*/
const uint32_t w = (uint32_t)h << 16;
/*
* Extract the sign of the input number into the high bit of the 32-bit word:
*
* +---+----------------------------------+
* | S |0000000 00000000 00000000 00000000|
* +---+----------------------------------+
* Bits 31 0-31
*/
const uint32_t sign = w & UINT32_C(0x80000000);
/*
* Extract mantissa and biased exponent of the input number into the high bits
* of the 32-bit word:
*
* +-----+------------+---------------------+
* |EEEEE|MM MMMM MMMM|0 0000 0000 0000 0000|
* +-----+------------+---------------------+
* Bits 27-31 17-26 0-16
*/
const uint32_t two_w = w + w;
/*
* Shift mantissa and exponent into bits 23-28 and bits 13-22 so they become
* mantissa and exponent of a single-precision floating-point number:
*
* S|Exponent | Mantissa
* +-+---+-----+------------+----------------+
* |0|000|EEEEE|MM MMMM MMMM|0 0000 0000 0000|
* +-+---+-----+------------+----------------+
* Bits | 23-31 | 0-22
*
* Next, there are some adjustments to the exponent:
* - The exponent needs to be corrected by the difference in exponent bias
* between single-precision and half-precision formats (0x7F - 0xF = 0x70)
* - Inf and NaN values in the inputs should become Inf and NaN values after
* conversion to the single-precision number. Therefore, if the biased
* exponent of the half-precision input was 0x1F (max possible value), the
* biased exponent of the single-precision output must be 0xFF (max possible
* value). We do this correction in two steps:
* - First, we adjust the exponent by (0xFF - 0x1F) = 0xE0 (see exp_offset
* below) rather than by 0x70 suggested by the difference in the exponent bias
* (see above).
* - Then we multiply the single-precision result of exponent adjustment by
* 2**(-112) to reverse the effect of exponent adjustment by 0xE0 less the
* necessary exponent adjustment by 0x70 due to difference in exponent bias.
 * The floating-point multiplication hardware ensures that Inf and
 * NaN retain their value on at least partially IEEE754-compliant
* implementations.
*
 * Note that the above operations do not handle denormal inputs (where the
 * biased exponent == 0); those are dealt with separately below. They also
 * never operate on denormal values internally or produce denormal results.
*/
constexpr uint32_t exp_offset = UINT32_C(0xE0) << 23;
// const float exp_scale = 0x1.0p-112f;
constexpr uint32_t scale_bits = (uint32_t)15 << 23;
float exp_scale_val = 0;
std::memcpy(&exp_scale_val, &scale_bits, sizeof(exp_scale_val));
const float exp_scale = exp_scale_val;
const float normalized_value =
fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale;
/*
* Convert denormalized half-precision inputs into single-precision results
* (always normalized). Zero inputs are also handled here.
*
 * In a denormalized number the biased exponent is zero, and the mantissa has
 * non-zero bits. First, we shift the mantissa into bits 0-9 of the 32-bit word.
*
* zeros | mantissa
* +---------------------------+------------+
* |0000 0000 0000 0000 0000 00|MM MMMM MMMM|
* +---------------------------+------------+
* Bits 10-31 0-9
*
* Now, remember that denormalized half-precision numbers are represented as:
* FP16 = mantissa * 2**(-24).
 * The trick is to construct a normalized single-precision number with the
 * same mantissa as the half-precision input and with an exponent which would
 * scale the corresponding mantissa bits to 2**(-24). A normalized
 * single-precision floating-point number is represented as: FP32 = (1 +
 * mantissa * 2**(-23)) * 2**(exponent - 127). Therefore, when the biased
* exponent is 126, a unit change in the mantissa of the input denormalized
* half-precision number causes a change of the constructed single-precision
* number by 2**(-24), i.e. the same amount.
*
* The last step is to adjust the bias of the constructed single-precision
* number. When the input half-precision number is zero, the constructed
* single-precision number has the value of FP32 = 1 * 2**(126 - 127) =
* 2**(-1) = 0.5 Therefore, we need to subtract 0.5 from the constructed
* single-precision number to get the numerical equivalent of the input
* half-precision number.
*/
constexpr uint32_t magic_mask = UINT32_C(126) << 23;
constexpr float magic_bias = 0.5f;
const float denormalized_value =
fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias;
/*
 * - Choose either the result of converting the input as a normalized number
 * or as a denormalized number, depending on the input exponent. The variable
 * two_w contains the input exponent in bits 27-31; therefore, if it is
 * smaller than 2**27, the input is either a denormal number or zero.
* - Combine the result of conversion of exponent and mantissa with the sign
* of the input number.
*/
constexpr uint32_t denormalized_cutoff = UINT32_C(1) << 27;
const uint32_t result = sign |
(two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value)
: fp32_to_bits(normalized_value));
return fp32_from_bits(result);
}
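/*
 * Worked example for the denormal path (a sketch): for h == 0x0001, two_w ==
 * 0x00020000, so two_w >> 17 == 1 and fp32_from_bits(1 | (126 << 23)) is
 * 2**(-1) * (1 + 2**(-23)); subtracting the 0.5f magic bias leaves exactly
 * 2**(-24), the smallest positive binary16 denormal.
 */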
/*
* Convert a 32-bit floating-point number in IEEE single-precision format to a
* 16-bit floating-point number in IEEE half-precision format, in bit
* representation.
*
* @note The implementation relies on IEEE-like (no assumption about rounding
* mode and no operations on denormals) floating-point operations and bitcasts
* between integer and floating-point variables.
*/
inline uint16_t fp16_ieee_from_fp32_value(float f) {
// const float scale_to_inf = 0x1.0p+112f;
// const float scale_to_zero = 0x1.0p-110f;
constexpr uint32_t scale_to_inf_bits = (uint32_t)239 << 23;
constexpr uint32_t scale_to_zero_bits = (uint32_t)17 << 23;
float scale_to_inf_val = 0, scale_to_zero_val = 0;
std::memcpy(&scale_to_inf_val, &scale_to_inf_bits, sizeof(scale_to_inf_val));
std::memcpy(
&scale_to_zero_val, &scale_to_zero_bits, sizeof(scale_to_zero_val));
const float scale_to_inf = scale_to_inf_val;
const float scale_to_zero = scale_to_zero_val;
#if defined(_MSC_VER) && _MSC_VER == 1916
float base = ((signbit(f) != 0 ? -f : f) * scale_to_inf) * scale_to_zero;
#else
float base = (fabsf(f) * scale_to_inf) * scale_to_zero;
#endif
const uint32_t w = fp32_to_bits(f);
const uint32_t shl1_w = w + w;
const uint32_t sign = w & UINT32_C(0x80000000);
uint32_t bias = shl1_w & UINT32_C(0xFF000000);
if (bias < UINT32_C(0x71000000)) {
bias = UINT32_C(0x71000000);
}
base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base;
const uint32_t bits = fp32_to_bits(base);
const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00);
const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF);
const uint32_t nonsign = exp_bits + mantissa_bits;
return static_cast<uint16_t>(
(sign >> 16) |
(shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign));
}
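// Round-trip usage sketch (assuming the two conversion functions above):
//
//   uint16_t h = fp16_ieee_from_fp32_value(0.333251953125f); // representable
//   float f = fp16_ieee_to_fp32_value(h); // f == 0.333251953125f exactly
//
// Sufficiently large finite inputs overflow to +/-infinity (0x7C00 / 0xFC00),
// and NaN inputs map to the canonical half NaN 0x7E00.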
#if defined(__aarch64__) && !defined(__CUDACC__)
inline float16_t fp16_from_bits(uint16_t h) {
return c10::bit_cast<float16_t>(h);
}
inline uint16_t fp16_to_bits(float16_t f) {
return c10::bit_cast<uint16_t>(f);
}
// According to https://godbolt.org/z/frExdbsWG it would translate to single
// fcvt s0, h0
inline float native_fp16_to_fp32_value(uint16_t h) {
return static_cast<float>(fp16_from_bits(h));
}
inline uint16_t native_fp16_from_fp32_value(float f) {
return fp16_to_bits(static_cast<float16_t>(f));
}
#endif
} // namespace detail
struct alignas(2) Half {
unsigned short x;
struct from_bits_t {};
C10_HOST_DEVICE static constexpr from_bits_t from_bits() {
return from_bits_t();
}
// HIP wants __host__ __device__ tag, CUDA does not
#if defined(USE_ROCM)
C10_HOST_DEVICE Half() = default;
#else
Half() = default;
#endif
constexpr C10_HOST_DEVICE Half(unsigned short bits, from_bits_t) : x(bits) {}
#if defined(__aarch64__) && !defined(__CUDACC__)
inline Half(float16_t value);
inline operator float16_t() const;
#else
inline C10_HOST_DEVICE Half(float value);
inline C10_HOST_DEVICE operator float() const;
#endif
#if defined(__CUDACC__) || defined(__HIPCC__)
inline C10_HOST_DEVICE Half(const __half& value);
inline C10_HOST_DEVICE operator __half() const;
#endif
#ifdef SYCL_LANGUAGE_VERSION
inline C10_HOST_DEVICE Half(const sycl::half& value);
inline C10_HOST_DEVICE operator sycl::half() const;
#endif
};
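// Construction sketch: the from_bits tag builds a Half from a raw bit pattern
// with no float conversion involved, e.g.
//
//   constexpr c10::Half one(0x3C00, c10::Half::from_bits()); // 1.0
//   c10::Half h = 0.5f; // via the converting constructor (on non-aarch64)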
// TODO : move to complex.h
template <>
struct alignas(4) complex<Half> {
Half real_;
Half imag_;
// Constructors
complex() = default;
// Half constructor is not constexpr so the following constructor can't
// be constexpr
C10_HOST_DEVICE explicit inline complex(const Half& real, const Half& imag)
: real_(real), imag_(imag) {}
C10_HOST_DEVICE inline complex(const c10::complex<float>& value)
: real_(value.real()), imag_(value.imag()) {}
// Conversion operator
inline C10_HOST_DEVICE operator c10::complex<float>() const {
return {real_, imag_};
}
constexpr C10_HOST_DEVICE Half real() const {
return real_;
}
constexpr C10_HOST_DEVICE Half imag() const {
return imag_;
}
C10_HOST_DEVICE complex<Half>& operator+=(const complex<Half>& other) {
real_ = static_cast<float>(real_) + static_cast<float>(other.real_);
imag_ = static_cast<float>(imag_) + static_cast<float>(other.imag_);
return *this;
}
C10_HOST_DEVICE complex<Half>& operator-=(const complex<Half>& other) {
real_ = static_cast<float>(real_) - static_cast<float>(other.real_);
imag_ = static_cast<float>(imag_) - static_cast<float>(other.imag_);
return *this;
}
C10_HOST_DEVICE complex<Half>& operator*=(const complex<Half>& other) {
auto a = static_cast<float>(real_);
auto b = static_cast<float>(imag_);
auto c = static_cast<float>(other.real());
auto d = static_cast<float>(other.imag());
real_ = a * c - b * d;
imag_ = a * d + b * c;
return *this;
}
};
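// Multiplication is carried out in float and rounded back to Half per
// component, following (a + bi)(c + di) == (ac - bd) + (ad + bc)i. Sketch:
//
//   c10::complex<c10::Half> z(c10::Half(1.0f), c10::Half(2.0f));
//   z *= c10::complex<c10::Half>(c10::Half(3.0f), c10::Half(4.0f));
//   // z == -5 + 10i, i.e. (1*3 - 2*4) + (1*4 + 2*3)i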
// In some versions of MSVC, there will be a compiler error when building.
// C4146: unary minus operator applied to unsigned type, result still unsigned
// C4804: unsafe use of type 'bool' in operation
// These can be addressed by disabling the following warnings.
#ifdef _MSC_VER
#pragma warning(push)
#pragma warning(disable : 4146)
#pragma warning(disable : 4804)
#pragma warning(disable : 4018)
#endif
// The overflow checks may involve float to int conversion which may
// trigger precision loss warning. Re-enable the warning once the code
// is fixed. See T58053069.
C10_CLANG_DIAGNOSTIC_PUSH()
#if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion")
C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion")
#endif
// bool can be converted to any type.
// Without specializing on bool, in pytorch_linux_trusty_py2_7_9_build:
// `error: comparison of constant '255' with boolean expression is always false`
// for `f > limit::max()` below
template <typename To, typename From>
std::enable_if_t<std::is_same_v<From, bool>, bool> overflows(
From /*f*/,
bool strict_unsigned [[maybe_unused]] = false) {
return false;
}
// skip isnan and isinf check for integral types
template <typename To, typename From>
std::enable_if_t<std::is_integral_v<From> && !std::is_same_v<From, bool>, bool>
overflows(From f, bool strict_unsigned = false) {
using limit = std::numeric_limits<typename scalar_value_type<To>::type>;
if constexpr (!limit::is_signed && std::numeric_limits<From>::is_signed) {
// allow for negative numbers to wrap using two's complement arithmetic.
// For example, with uint8, this allows for `a - b` to be treated as
// `a + 255 * b`.
if (!strict_unsigned) {
return greater_than_max<To>(f) ||
(c10::is_negative(f) &&
-static_cast<uint64_t>(f) > static_cast<uint64_t>(limit::max()));
}
}
return c10::less_than_lowest<To>(f) || greater_than_max<To>(f);
}
template <typename To, typename From>
std::enable_if_t<std::is_floating_point_v<From>, bool> overflows(
From f,
bool strict_unsigned [[maybe_unused]] = false) {
using limit = std::numeric_limits<typename scalar_value_type<To>::type>;
if (limit::has_infinity && std::isinf(static_cast<double>(f))) {
return false;
}
if (!limit::has_quiet_NaN && (f != f)) {
return true;
}
return f < limit::lowest() || f > limit::max();
}
C10_CLANG_DIAGNOSTIC_POP()
#ifdef _MSC_VER
#pragma warning(pop)
#endif
template <typename To, typename From>
std::enable_if_t<is_complex<From>::value, bool> overflows(
From f,
bool strict_unsigned = false) {
// casts from complex to real are considered to overflow if the
// imaginary component is non-zero
if (!is_complex<To>::value && f.imag() != 0) {
return true;
}
// Check for overflow componentwise
// (Technically, the imag overflow check is guaranteed to be false
// when !is_complex<To>, but any optimizer worth its salt will be
// able to figure it out.)
return overflows<
typename scalar_value_type<To>::type,
typename From::value_type>(f.real(), strict_unsigned) ||
overflows<
typename scalar_value_type<To>::type,
typename From::value_type>(f.imag(), strict_unsigned);
}
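// Behavior sketch for the overloads above:
//
//   overflows<uint8_t>(-1);       // false: -1 may wrap to 255
//   overflows<uint8_t>(-1, true); // true: strict_unsigned forbids wrapping
//   overflows<int8_t>(128);       // true: 128 > int8_t's max of 127
//   overflows<c10::Half>(1.0e5f); // true: Half's largest finite value is 65504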
C10_API inline std::ostream& operator<<(std::ostream& out, const Half& value) {
out << (float)value;
return out;
}
} // namespace c10
#include <c10/util/Half-inl.h> // IWYU pragma: keep

View File

@ -0,0 +1,77 @@
#pragma once
#include <cstddef>
#include <functional>
#include <utility>
namespace c10 {
/**
* This template simplifies generation of simple classes that wrap an id
* in a typesafe way. Namely, you can use it to create a very lightweight
* type that only offers equality comparators and hashing. Example:
*
* struct MyIdType final : IdWrapper<MyIdType, uint32_t> {
* constexpr explicit MyIdType(uint32_t id): IdWrapper(id) {}
* };
*
* Then in the global top level namespace:
*
* C10_DEFINE_HASH_FOR_IDWRAPPER(MyIdType);
*
 * That's it: equality operators and hash functions are automatically defined
 * for you, provided the underlying type supports them.
*/
template <class ConcreteType, class UnderlyingType>
class IdWrapper {
public:
using underlying_type = UnderlyingType;
using concrete_type = ConcreteType;
protected:
constexpr explicit IdWrapper(underlying_type id) noexcept(
noexcept(underlying_type(std::declval<underlying_type>())))
: id_(id) {}
constexpr underlying_type underlyingId() const
noexcept(noexcept(underlying_type(std::declval<underlying_type>()))) {
return id_;
}
private:
friend size_t hash_value(const concrete_type& v) {
return std::hash<underlying_type>()(v.id_);
}
// TODO Making operator== noexcept if underlying type is noexcept equality
// comparable doesn't work with GCC 4.8.
// Fix this once we don't need GCC 4.8 anymore.
friend constexpr bool operator==(
const concrete_type& lhs,
const concrete_type& rhs) noexcept {
return lhs.id_ == rhs.id_;
}
// TODO Making operator!= noexcept if operator== is noexcept doesn't work with
// GCC 4.8.
// Fix this once we don't need GCC 4.8 anymore.
friend constexpr bool operator!=(
const concrete_type& lhs,
const concrete_type& rhs) noexcept {
return !(lhs == rhs);
}
underlying_type id_;
};
} // namespace c10
#define C10_DEFINE_HASH_FOR_IDWRAPPER(ClassName) \
namespace std { \
template <> \
struct hash<ClassName> { \
size_t operator()(ClassName x) const { \
return hash_value(x); \
} \
}; \
}

View File

@ -0,0 +1,120 @@
#pragma once
#include <atomic>
#include <utility>
namespace c10 {
/**
* Thread-safe lazy value with opportunistic concurrency: on concurrent first
* access, the factory may be called by multiple threads, but only one result is
* stored and its reference returned to all the callers.
*
* Value is heap-allocated; this optimizes for the case in which the value is
* never actually computed.
*/
template <class T>
class OptimisticLazy {
public:
OptimisticLazy() = default;
OptimisticLazy(const OptimisticLazy& other) {
if (T* value = other.value_.load(std::memory_order_acquire)) {
value_ = new T(*value);
}
}
OptimisticLazy(OptimisticLazy&& other) noexcept
: value_(other.value_.exchange(nullptr, std::memory_order_acq_rel)) {}
~OptimisticLazy() {
reset();
}
template <class Factory>
T& ensure(Factory&& factory) {
if (T* value = value_.load(std::memory_order_acquire)) {
return *value;
}
T* value = new T(factory());
T* old = nullptr;
if (!value_.compare_exchange_strong(
old, value, std::memory_order_release, std::memory_order_acquire)) {
delete value;
value = old;
}
return *value;
}
// The following methods are not thread-safe: they should not be called
// concurrently with any other method.
OptimisticLazy& operator=(const OptimisticLazy& other) {
*this = OptimisticLazy{other};
return *this;
}
OptimisticLazy& operator=(OptimisticLazy&& other) noexcept {
if (this != &other) {
reset();
value_.store(
other.value_.exchange(nullptr, std::memory_order_acquire),
std::memory_order_release);
}
return *this;
}
void reset() {
if (T* old = value_.load(std::memory_order_relaxed)) {
value_.store(nullptr, std::memory_order_relaxed);
delete old;
}
}
private:
std::atomic<T*> value_{nullptr};
};
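/**
 * Usage sketch (the vector payload is illustrative):
 *
 *   OptimisticLazy<std::vector<int>> cache;
 *   auto& v = cache.ensure([] { return std::vector<int>(1024, 0); });
 *
 * On a concurrent first access several threads may each run the factory, but
 * compare_exchange_strong keeps exactly one result; the losers delete theirs.
 */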
/**
* Interface for a value that is computed on first access.
*/
template <class T>
class LazyValue {
public:
virtual ~LazyValue() = default;
virtual const T& get() const = 0;
};
/**
* Convenience thread-safe LazyValue implementation with opportunistic
* concurrency.
*/
template <class T>
class OptimisticLazyValue : public LazyValue<T> {
public:
const T& get() const override {
return value_.ensure([this] { return compute(); });
}
private:
virtual T compute() const = 0;
mutable OptimisticLazy<T> value_;
};
/**
* Convenience immutable (thus thread-safe) LazyValue implementation for cases
* in which the value is not actually lazy.
*/
template <class T>
class PrecomputedLazyValue : public LazyValue<T> {
public:
PrecomputedLazyValue(T value) : value_(std::move(value)) {}
const T& get() const override {
return value_;
}
private:
T value_;
};
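/**
 * Sketch of a lazy subclass (HostName and queryHostName are illustrative):
 *
 *   class HostName : public OptimisticLazyValue<std::string> {
 *     std::string compute() const override { return queryHostName(); }
 *   };
 */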
} // namespace c10

View File

@ -0,0 +1,223 @@
#include <c10/macros/Macros.h>
#include <c10/util/Synchronized.h>
#include <array>
#include <atomic>
#include <mutex>
#include <thread>
namespace c10 {
namespace detail {
struct IncrementRAII final {
public:
explicit IncrementRAII(std::atomic<int32_t>* counter) : _counter(counter) {
_counter->fetch_add(1);
}
~IncrementRAII() {
_counter->fetch_sub(1);
}
private:
std::atomic<int32_t>* _counter;
C10_DISABLE_COPY_AND_ASSIGN(IncrementRAII);
};
} // namespace detail
// LeftRight wait-free readers synchronization primitive
// https://hal.archives-ouvertes.fr/hal-01207881/document
//
// LeftRight is quite easy to use (it can make an arbitrary
// data structure permit wait-free reads), but it has some
// particular performance characteristics you should be aware
// of if you're deciding to use it:
//
// - Reads still incur an atomic write (this is how LeftRight
// keeps track of how long it needs to keep around the old
// data structure)
//
// - Writes get executed twice, to keep both the left and right
// versions up to date. So if your write is expensive or
// nondeterministic, this is also an inappropriate structure
//
// LeftRight is used fairly rarely in PyTorch's codebase. If you
// are still not sure if you need it or not, consult your local
// C++ expert.
//
template <class T>
class LeftRight final {
public:
template <class... Args>
explicit LeftRight(const Args&... args)
: _counters{{{0}, {0}}},
_foregroundCounterIndex(0),
_foregroundDataIndex(0),
_data{{T{args...}, T{args...}}},
_writeMutex() {}
// Copying and moving would not be threadsafe.
// Needs more thought and careful design to make that work.
LeftRight(const LeftRight&) = delete;
LeftRight(LeftRight&&) noexcept = delete;
LeftRight& operator=(const LeftRight&) = delete;
LeftRight& operator=(LeftRight&&) noexcept = delete;
~LeftRight() {
// wait until any potentially running writers are finished
{ std::unique_lock<std::mutex> lock(_writeMutex); }
// wait until any potentially running readers are finished
while (_counters[0].load() != 0 || _counters[1].load() != 0) {
std::this_thread::yield();
}
}
template <typename F>
auto read(F&& readFunc) const {
detail::IncrementRAII _increment_counter(
&_counters[_foregroundCounterIndex.load()]);
return std::forward<F>(readFunc)(_data[_foregroundDataIndex.load()]);
}
// Throwing an exception in writeFunc is ok but causes the state to be either
// the old or the new state, depending on whether the first or the second call
// to writeFunc threw.
template <typename F>
auto write(F&& writeFunc) {
std::unique_lock<std::mutex> lock(_writeMutex);
return _write(std::forward<F>(writeFunc));
}
private:
template <class F>
auto _write(const F& writeFunc) {
/*
* Assume, A is in background and B in foreground. In simplified terms, we
* want to do the following:
* 1. Write to A (old background)
* 2. Switch A/B
* 3. Write to B (new background)
*
* More detailed algorithm (explanations on why this is important are below
* in code):
* 1. Write to A
* 2. Switch A/B data pointers
* 3. Wait until A counter is zero
* 4. Switch A/B counters
* 5. Wait until B counter is zero
* 6. Write to B
*/
auto localDataIndex = _foregroundDataIndex.load();
// 1. Write to A
_callWriteFuncOnBackgroundInstance(writeFunc, localDataIndex);
// 2. Switch A/B data pointers
localDataIndex = localDataIndex ^ 1;
_foregroundDataIndex = localDataIndex;
/*
* 3. Wait until A counter is zero
*
* In the previous write run, A was foreground and B was background.
* There was a time after switching _foregroundDataIndex (B to foreground)
* and before switching _foregroundCounterIndex, in which new readers could
* have read B but incremented A's counter.
*
* In this current run, we just switched _foregroundDataIndex (A back to
* foreground), but before writing to the new background B, we have to make
* sure A's counter was zero briefly, so all these old readers are gone.
*/
auto localCounterIndex = _foregroundCounterIndex.load();
_waitForBackgroundCounterToBeZero(localCounterIndex);
/*
* 4. Switch A/B counters
*
* Now that we know all readers on B are really gone, we can switch the
* counters and have new readers increment A's counter again, which is the
* correct counter since they're reading A.
*/
localCounterIndex = localCounterIndex ^ 1;
_foregroundCounterIndex = localCounterIndex;
/*
* 5. Wait until B counter is zero
*
   * This waits for all the readers on B that came in while both the data and
   * the counter for B were in the foreground, i.e. normal readers that arrived
   * outside of that brief gap between switching data and counter.
*/
_waitForBackgroundCounterToBeZero(localCounterIndex);
// 6. Write to B
return _callWriteFuncOnBackgroundInstance(writeFunc, localDataIndex);
}
template <class F>
auto _callWriteFuncOnBackgroundInstance(
const F& writeFunc,
uint8_t localDataIndex) {
try {
return writeFunc(_data[localDataIndex ^ 1]);
} catch (...) {
// recover invariant by copying from the foreground instance
_data[localDataIndex ^ 1] = _data[localDataIndex];
// rethrow
throw;
}
}
void _waitForBackgroundCounterToBeZero(uint8_t counterIndex) {
while (_counters[counterIndex ^ 1].load() != 0) {
std::this_thread::yield();
}
}
mutable std::array<std::atomic<int32_t>, 2> _counters;
std::atomic<uint8_t> _foregroundCounterIndex;
std::atomic<uint8_t> _foregroundDataIndex;
std::array<T, 2> _data;
std::mutex _writeMutex;
};
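// Usage sketch:
//
//   LeftRight<std::vector<int>> lr;
//   lr.write([](std::vector<int>& v) { v.push_back(1); return 0; });
//   int first = lr.read([](const std::vector<int>& v) { return v[0]; });
//
// The write functor runs twice (once per copy), so it must be deterministic
// and confine its effects to the instance it is handed.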
// RWSafeLeftRightWrapper is API compatible with LeftRight and uses a
// read-write lock to protect T (data).
template <class T>
class RWSafeLeftRightWrapper final {
public:
template <class... Args>
explicit RWSafeLeftRightWrapper(const Args&... args) : data_{args...} {}
// RWSafeLeftRightWrapper is not copyable or moveable since LeftRight
// is not copyable or moveable.
RWSafeLeftRightWrapper(const RWSafeLeftRightWrapper&) = delete;
RWSafeLeftRightWrapper(RWSafeLeftRightWrapper&&) noexcept = delete;
RWSafeLeftRightWrapper& operator=(const RWSafeLeftRightWrapper&) = delete;
RWSafeLeftRightWrapper& operator=(RWSafeLeftRightWrapper&&) noexcept = delete;
template <typename F>
// NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
auto read(F&& readFunc) const {
return data_.withLock(
[&readFunc](T const& data) { return std::forward<F>(readFunc)(data); });
}
template <typename F>
// NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
auto write(F&& writeFunc) {
return data_.withLock(
[&writeFunc](T& data) { return std::forward<F>(writeFunc)(data); });
}
private:
c10::Synchronized<T> data_;
};
} // namespace c10

View File

@ -0,0 +1,38 @@
#pragma once
#include <c10/macros/Macros.h>
#include <cstring>
namespace c10 {
namespace detail {
template <typename T>
struct LoadImpl {
C10_HOST_DEVICE static T apply(const void* src) {
return *reinterpret_cast<const T*>(src);
}
};
template <>
struct LoadImpl<bool> {
C10_HOST_DEVICE static bool apply(const void* src) {
static_assert(sizeof(bool) == sizeof(char));
// NOTE: [Loading boolean values]
// Protect against invalid boolean values by loading as a byte
// first, then converting to bool (see gh-54789).
return *reinterpret_cast<const unsigned char*>(src);
}
};
} // namespace detail
template <typename T>
C10_HOST_DEVICE T load(const void* src) {
return c10::detail::LoadImpl<T>::apply(src);
}
template <typename scalar_t>
C10_HOST_DEVICE scalar_t load(const scalar_t* src) {
return c10::detail::LoadImpl<scalar_t>::apply(src);
}
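// Usage sketch: loading a bool from a byte that may hold an out-of-range value
// (see gh-54789) goes through the byte-safe specialization above.
//
//   unsigned char byte = 0xFF;       // not a valid bool object representation
//   bool b = c10::load<bool>(&byte); // true, without reading an invalid bool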
} // namespace c10

View File

@ -0,0 +1,370 @@
#ifndef C10_UTIL_LOGGING_H_
#define C10_UTIL_LOGGING_H_
#include <climits>
#include <exception>
#include <functional>
#include <limits>
#include <sstream>
#include <c10/macros/Macros.h>
#include <c10/util/Backtrace.h>
#include <c10/util/Exception.h>
#include <c10/util/Flags.h>
#include <c10/util/StringUtil.h>
// CAFFE2_LOG_THRESHOLD is a compile time flag that would allow us to turn off
// logging at compile time so no logging message below that level is produced
// at all. The value should be between INT_MIN and CAFFE_FATAL.
#ifndef CAFFE2_LOG_THRESHOLD
// If we have not defined the compile time log threshold, we keep all the
// log cases.
#define CAFFE2_LOG_THRESHOLD INT_MIN
#endif // CAFFE2_LOG_THRESHOLD
// Below are different implementations for glog and non-glog cases.
#ifdef C10_USE_GLOG
#include <c10/util/logging_is_google_glog.h>
#else // !C10_USE_GLOG
#include <c10/util/logging_is_not_google_glog.h>
#endif // C10_USE_GLOG
C10_DECLARE_int(caffe2_log_level);
C10_DECLARE_bool(caffe2_use_fatal_for_enforce);
// Some versions of GLOG support a less-spammy version of LOG_EVERY_MS. If it's
// not available, just short-circuit to the always-working one.
// We define the C10_ name to avoid confusing other files
#ifdef LOG_EVERY_MS
#define C10_LOG_EVERY_MS(severity, ms) LOG_EVERY_MS(severity, ms)
#else
#define C10_LOG_EVERY_MS(severity, ms) LOG(severity)
#endif
// Same for LOG_FIRST_N
#ifdef LOG_FIRST_N
#define C10_LOG_FIRST_N(severity, n) LOG_FIRST_N(severity, n)
#else
#define C10_LOG_FIRST_N(severity, n) LOG(severity)
#endif
// Same for LOG_EVERY_N
#ifdef LOG_EVERY_N
#define C10_LOG_EVERY_N(severity, n) LOG_EVERY_N(severity, n)
#else
#define C10_LOG_EVERY_N(severity, n) LOG(severity)
#endif
namespace c10 {
using std::string;
// Functions that we use for initialization.
C10_API bool InitCaffeLogging(int* argc, char** argv);
C10_API void UpdateLoggingLevelsFromFlags();
[[noreturn]] C10_API void ThrowEnforceNotMet(
const char* file,
const int line,
const char* condition,
const std::string& msg,
const void* caller = nullptr);
[[noreturn]] C10_API void ThrowEnforceNotMet(
const char* file,
const int line,
const char* condition,
const char* msg,
const void* caller = nullptr);
[[noreturn]] C10_API inline void ThrowEnforceNotMet(
const char* file,
const int line,
const char* condition,
detail::CompileTimeEmptyString /*msg*/,
const void* caller = nullptr) {
ThrowEnforceNotMet(file, line, condition, "", caller);
}
[[noreturn]] C10_API void ThrowEnforceFiniteNotMet(
const char* file,
const int line,
const char* condition,
const std::string& msg,
const void* caller = nullptr);
[[noreturn]] C10_API void ThrowEnforceFiniteNotMet(
const char* file,
const int line,
const char* condition,
const char* msg,
const void* caller = nullptr);
[[noreturn]] C10_API inline void ThrowEnforceFiniteNotMet(
const char* file,
const int line,
const char* condition,
detail::CompileTimeEmptyString /*msg*/,
const void* caller = nullptr) {
ThrowEnforceFiniteNotMet(file, line, condition, "", caller);
}
constexpr bool IsUsingGoogleLogging() {
#ifdef C10_USE_GLOG
return true;
#else
return false;
#endif
}
/**
* A utility to allow one to show log info to stderr after the program starts.
*
* This is similar to calling GLOG's --logtostderr, or setting caffe2_log_level
 * to smaller than INFO. It is recommended that you use this only in a few
 * sparse cases, such as when you are writing a tutorial. Normally, use
* the commandline flags to set the log level.
*/
C10_API void ShowLogInfoToStderr();
C10_API void SetStackTraceFetcher(std::function<::c10::Backtrace()> fetcher);
/**
* Convenience function for non-lazy stack trace fetchers. The Backtrace
* overload should be preferred when stringifying the backtrace is expensive.
*/
C10_API void SetStackTraceFetcher(std::function<std::string()> fetcher);
using EnforceNotMet = ::c10::Error;
#define CAFFE_ENFORCE(condition, ...) \
do { \
if (C10_UNLIKELY(!(condition))) { \
::c10::ThrowEnforceNotMet( \
__FILE__, __LINE__, #condition, ::c10::str(__VA_ARGS__)); \
} \
} while (false)
#define CAFFE_ENFORCE_FINITE(condition, ...) \
do { \
if (C10_UNLIKELY(!(condition))) { \
::c10::ThrowEnforceFiniteNotMet( \
__FILE__, __LINE__, #condition, ::c10::str(__VA_ARGS__)); \
} \
} while (false)
#define CAFFE_ENFORCE_WITH_CALLER(condition, ...) \
do { \
if (C10_UNLIKELY(!(condition))) { \
::c10::ThrowEnforceNotMet( \
__FILE__, __LINE__, #condition, ::c10::str(__VA_ARGS__), this); \
} \
} while (false)
#define CAFFE_THROW(...) \
::c10::ThrowEnforceNotMet(__FILE__, __LINE__, "", ::c10::str(__VA_ARGS__))
/**
* Rich logging messages
*
* CAFFE_ENFORCE_THAT can be used with one of the "checker functions" that
 * capture input argument values and add them to the exception message. E.g.
 * `CAFFE_ENFORCE_THAT(Equals(foo(x), bar(y)), "Optional additional message")`
 * would evaluate both foo and bar only once and, if the results are not equal,
 * include them in the exception message.
*
* Some of the basic checker functions like Equals or Greater are already
 * defined below. Other headers might define customized checkers by adding
 * functions to the caffe2::enforce_detail namespace. For example:
*
* namespace caffe2 { namespace enforce_detail {
* inline EnforceFailMessage IsVector(const vector<int64_t>& shape) {
* if (shape.size() == 1) { return EnforceOK(); }
* return c10::str("Shape ", shape, " is not a vector");
* }
* }}
*
* With further usages like `CAFFE_ENFORCE_THAT(IsVector(Input(0).dims()))`
*
* Convenient wrappers for binary operations like CAFFE_ENFORCE_EQ are provided
* too. Please use them instead of TORCH_CHECK_EQ and friends for failures in
* user-provided input.
*/
namespace enforce_detail {
template <typename T1, typename T2>
std::string enforceFailMsgImpl(const T1& x, const T2& y) {
return c10::str(x, " vs ", y);
}
template <typename T1, typename T2, typename... Args>
std::string enforceFailMsgImpl(const T1& x, const T2& y, const Args&... args) {
return c10::str(x, " vs ", y, ". ", args...);
}
template <typename Pred, typename T1, typename T2, typename GetFailMsgFunc>
void enforceThatImpl(
Pred p,
const T1& lhs,
const T2& rhs,
const char* file,
int line,
const char* expr,
const void* caller,
GetFailMsgFunc getFailMsg) {
if (C10_UNLIKELY(!(p(lhs, rhs)))) {
::c10::ThrowEnforceNotMet(file, line, expr, getFailMsg(lhs, rhs), caller);
}
}
#define CAFFE_ENFORCE_THAT_IMPL(op, lhs, rhs, expr, ...) \
::c10::enforce_detail::enforceThatImpl( \
op, \
(lhs), \
(rhs), \
__FILE__, \
__LINE__, \
expr, \
nullptr, \
[&](const auto& arg1, const auto& arg2) { \
return ::c10::enforce_detail::enforceFailMsgImpl( \
arg1, arg2, ##__VA_ARGS__); \
})
#define CAFFE_ENFORCE_THAT_IMPL_WITH_CALLER(op, lhs, rhs, expr, ...) \
::c10::enforce_detail::enforceThatImpl( \
op, \
(lhs), \
(rhs), \
__FILE__, \
__LINE__, \
expr, \
this, \
[&](const auto& arg1, const auto& arg2) { \
return ::c10::enforce_detail::enforceFailMsgImpl( \
arg1, arg2, ##__VA_ARGS__); \
})
} // namespace enforce_detail
#define CAFFE_ENFORCE_THAT(cmp, op, lhs, rhs, ...) \
CAFFE_ENFORCE_THAT_IMPL(cmp, lhs, rhs, #lhs " " #op " " #rhs, ##__VA_ARGS__)
#define CAFFE_ENFORCE_BINARY_OP(cmp, op, x, y, ...) \
CAFFE_ENFORCE_THAT_IMPL(cmp, x, y, #x " " #op " " #y, ##__VA_ARGS__)
#define CAFFE_ENFORCE_EQ(x, y, ...) \
CAFFE_ENFORCE_BINARY_OP(std::equal_to<void>(), ==, x, y, ##__VA_ARGS__)
#define CAFFE_ENFORCE_NE(x, y, ...) \
CAFFE_ENFORCE_BINARY_OP(std::not_equal_to<void>(), !=, x, y, ##__VA_ARGS__)
#define CAFFE_ENFORCE_LE(x, y, ...) \
CAFFE_ENFORCE_BINARY_OP(std::less_equal<void>(), <=, x, y, ##__VA_ARGS__)
#define CAFFE_ENFORCE_LT(x, y, ...) \
CAFFE_ENFORCE_BINARY_OP(std::less<void>(), <, x, y, ##__VA_ARGS__)
#define CAFFE_ENFORCE_GE(x, y, ...) \
CAFFE_ENFORCE_BINARY_OP(std::greater_equal<void>(), >=, x, y, ##__VA_ARGS__)
#define CAFFE_ENFORCE_GT(x, y, ...) \
CAFFE_ENFORCE_BINARY_OP(std::greater<void>(), >, x, y, ##__VA_ARGS__)
#define CAFFE_ENFORCE_BINARY_OP_WITH_CALLER(cmp, op, x, y, ...) \
CAFFE_ENFORCE_THAT_IMPL_WITH_CALLER( \
cmp, x, y, #x " " #op " " #y, ##__VA_ARGS__)
#define CAFFE_ENFORCE_EQ_WITH_CALLER(x, y, ...) \
CAFFE_ENFORCE_BINARY_OP_WITH_CALLER( \
std::equal_to<void>(), ==, x, y, ##__VA_ARGS__)
#define CAFFE_ENFORCE_NE_WITH_CALLER(x, y, ...) \
CAFFE_ENFORCE_BINARY_OP_WITH_CALLER( \
std::not_equal_to<void>(), !=, x, y, ##__VA_ARGS__)
#define CAFFE_ENFORCE_LE_WITH_CALLER(x, y, ...) \
CAFFE_ENFORCE_BINARY_OP_WITH_CALLER( \
std::less_equal<void>(), <=, x, y, ##__VA_ARGS__)
#define CAFFE_ENFORCE_LT_WITH_CALLER(x, y, ...) \
CAFFE_ENFORCE_BINARY_OP_WITH_CALLER(std::less<void>(), <, x, y, ##__VA_ARGS__)
#define CAFFE_ENFORCE_GE_WITH_CALLER(x, y, ...) \
CAFFE_ENFORCE_BINARY_OP_WITH_CALLER( \
std::greater_equal<void>(), >=, x, y, ##__VA_ARGS__)
#define CAFFE_ENFORCE_GT_WITH_CALLER(x, y, ...) \
CAFFE_ENFORCE_BINARY_OP_WITH_CALLER( \
std::greater<void>(), >, x, y, ##__VA_ARGS__)
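// Usage sketch for the wrappers above (names are illustrative):
//
//   CAFFE_ENFORCE_EQ(lhs.size(), rhs.size(), "size mismatch in my_op");
//   CAFFE_ENFORCE_GT(batch_size, 0);
//
// On failure the operands are evaluated once and formatted as "lhs vs rhs"
// (plus any extra message) into the thrown EnforceNotMet.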
struct IValue;
class C10_API EventSampledHandler {
public:
virtual void log(
std::string_view model_id,
const std::vector<c10::IValue>& args) = 0;
virtual ~EventSampledHandler() = default;
};
#define C10_LOG_EVENT_SAMPLED(event, ...) \
static const std::unique_ptr<::c10::EventSampledHandler>& \
_##event##EventSampledHandler = ::c10::GetEventSampledHandler(#event); \
if (_##event##EventSampledHandler) { \
_##event##EventSampledHandler->log(__VA_ARGS__); \
}
// Must be called in the main thread before any other threads are spawned.
C10_API void InitEventSampledHandlers(
std::vector<
std::pair<std::string_view, std::unique_ptr<EventSampledHandler>>>);
C10_API const std::unique_ptr<EventSampledHandler>& GetEventSampledHandler(
std::string_view);
/**
* Very lightweight logging for the first time API usage. It's beneficial for
* tracking of individual functionality usage in larger applications.
*
 * To keep the logging lightweight, we utilize the static-variable trick:
 * LogAPIUsage will be invoked only once, and further invocations will
 * just do an atomic check.
*
* Example:
* // Logs caller info with an arbitrary text event, if there is a usage.
* C10_LOG_API_USAGE_ONCE("my_api");
*/
#define C10_LOG_API_USAGE_ONCE(...) \
C10_UNUSED static bool C10_ANONYMOUS_VARIABLE(logFlag) = \
::c10::detail::LogAPIUsageFakeReturn(__VA_ARGS__);
// API usage logging capabilities
C10_API void SetAPIUsageLogger(std::function<void(const std::string&)> logger);
C10_API void LogAPIUsage(const std::string& context);
C10_API void SetAPIUsageMetadataLogger(
std::function<void(
const std::string&,
const std::map<std::string, std::string>& metadata_map)> logger);
C10_API void LogAPIUsageMetadata(
const std::string& context,
const std::map<std::string, std::string>& metadata_map);
// PyTorch ddp usage logging capabilities
// DDPLoggingData holds data that can be logged in applications
// for analysis and debugging. Data structure is defined in
// c10 directory so that it can be easily imported by both c10
// and torch files.
struct DDPLoggingData {
// logging fields that are string types.
std::map<std::string, std::string> strs_map;
// logging fields that are int64_t types.
std::map<std::string, int64_t> ints_map;
};
C10_API void SetPyTorchDDPUsageLogger(
std::function<void(const DDPLoggingData&)> logger);
C10_API void LogPyTorchDDPUsage(const DDPLoggingData& ddpData);
namespace detail {
// Return value is needed to do the static variable initialization trick
C10_API bool LogAPIUsageFakeReturn(const std::string& context);
} // namespace detail
// Initializes the c10 logger.
C10_API void initLogging();
// Sets the rank, which will be included in log messages
C10_API void SetGlobalRank(int64_t rank);
} // namespace c10
#endif // C10_UTIL_LOGGING_H_

View File

@ -0,0 +1,142 @@
#pragma once
#include <c10/macros/Macros.h>
#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>
C10_CLANG_DIAGNOSTIC_PUSH()
#if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion")
C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion")
#endif
namespace c10 {
// TODO: Replace me with inline constexpr variable when C++17 becomes available
namespace detail {
template <typename T>
C10_HOST_DEVICE inline constexpr T e() {
return static_cast<T>(2.718281828459045235360287471352662);
}
template <typename T>
C10_HOST_DEVICE inline constexpr T euler() {
return static_cast<T>(0.577215664901532860606512090082402);
}
template <typename T>
C10_HOST_DEVICE inline constexpr T frac_1_pi() {
return static_cast<T>(0.318309886183790671537767526745028);
}
template <typename T>
C10_HOST_DEVICE inline constexpr T frac_1_sqrt_pi() {
return static_cast<T>(0.564189583547756286948079451560772);
}
template <typename T>
C10_HOST_DEVICE inline constexpr T frac_sqrt_2() {
return static_cast<T>(0.707106781186547524400844362104849);
}
template <typename T>
C10_HOST_DEVICE inline constexpr T frac_sqrt_3() {
return static_cast<T>(0.577350269189625764509148780501957);
}
template <typename T>
C10_HOST_DEVICE inline constexpr T golden_ratio() {
return static_cast<T>(1.618033988749894848204586834365638);
}
template <typename T>
C10_HOST_DEVICE inline constexpr T ln_10() {
return static_cast<T>(2.302585092994045684017991454684364);
}
template <typename T>
C10_HOST_DEVICE inline constexpr T ln_2() {
return static_cast<T>(0.693147180559945309417232121458176);
}
template <typename T>
C10_HOST_DEVICE inline constexpr T log_10_e() {
return static_cast<T>(0.434294481903251827651128918916605);
}
template <typename T>
C10_HOST_DEVICE inline constexpr T log_2_e() {
return static_cast<T>(1.442695040888963407359924681001892);
}
template <typename T>
C10_HOST_DEVICE inline constexpr T pi() {
return static_cast<T>(3.141592653589793238462643383279502);
}
template <typename T>
C10_HOST_DEVICE inline constexpr T sqrt_2() {
return static_cast<T>(1.414213562373095048801688724209698);
}
template <typename T>
C10_HOST_DEVICE inline constexpr T sqrt_3() {
return static_cast<T>(1.732050807568877293527446341505872);
}
template <>
C10_HOST_DEVICE inline constexpr BFloat16 pi<BFloat16>() {
// According to
// https://en.wikipedia.org/wiki/Bfloat16_floating-point_format#Special_values
// pi is encoded as 4049
return BFloat16(0x4049, BFloat16::from_bits());
}
template <>
C10_HOST_DEVICE inline constexpr Half pi<Half>() {
return Half(0x4248, Half::from_bits());
}
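// Encoding check (sketch): 0x4248 decodes as (1 + 584/1024) * 2**1 ==
// 3.140625, and 0x4049 as (1 + 73/128) * 2**1 == 3.140625; these are the
// closest Half and BFloat16 values to pi, respectively.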
} // namespace detail
template <typename T>
constexpr T e = c10::detail::e<T>();
template <typename T>
constexpr T euler = c10::detail::euler<T>();
template <typename T>
constexpr T frac_1_pi = c10::detail::frac_1_pi<T>();
template <typename T>
constexpr T frac_1_sqrt_pi = c10::detail::frac_1_sqrt_pi<T>();
template <typename T>
constexpr T frac_sqrt_2 = c10::detail::frac_sqrt_2<T>();
template <typename T>
constexpr T frac_sqrt_3 = c10::detail::frac_sqrt_3<T>();
template <typename T>
constexpr T golden_ratio = c10::detail::golden_ratio<T>();
template <typename T>
constexpr T ln_10 = c10::detail::ln_10<T>();
template <typename T>
constexpr T ln_2 = c10::detail::ln_2<T>();
template <typename T>
constexpr T log_10_e = c10::detail::log_10_e<T>();
template <typename T>
constexpr T log_2_e = c10::detail::log_2_e<T>();
template <typename T>
constexpr T pi = c10::detail::pi<T>();
template <typename T>
constexpr T sqrt_2 = c10::detail::sqrt_2<T>();
template <typename T>
constexpr T sqrt_3 = c10::detail::sqrt_3<T>();
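// Usage sketch:
//
//   float tau = 2.0f * c10::pi<float>;
//   c10::Half hpi = c10::pi<c10::Half>; // picks the bit-exact specialization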
} // namespace c10
C10_CLANG_DIAGNOSTIC_POP()

View File

@ -0,0 +1,237 @@
#pragma once
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <memory>
#include <type_traits>
#include <utility>
namespace c10 {
/// MaybeOwnedTraits<T> describes how to borrow from T. Here is how we
/// can implement borrowing from an arbitrary type T using a raw
/// pointer to const:
template <typename T>
struct MaybeOwnedTraitsGenericImpl {
using owned_type = T;
using borrow_type = const T*;
static borrow_type createBorrow(const owned_type& from) {
return &from;
}
static void assignBorrow(borrow_type& lhs, borrow_type rhs) {
lhs = rhs;
}
static void destroyBorrow(borrow_type& /*toDestroy*/) {}
static const owned_type& referenceFromBorrow(const borrow_type& borrow) {
return *borrow;
}
static const owned_type* pointerFromBorrow(const borrow_type& borrow) {
return borrow;
}
static bool debugBorrowIsValid(const borrow_type& borrow) {
return borrow != nullptr;
}
};
/// It is possible to eliminate the extra layer of indirection for
/// borrows for some types that we control. For examples, see
/// intrusive_ptr.h and TensorBody.h.
template <typename T>
struct MaybeOwnedTraits;
// Explicitly enable MaybeOwned<shared_ptr<T>>, rather than allowing
// MaybeOwned to be used for any type right away.
template <typename T>
struct MaybeOwnedTraits<std::shared_ptr<T>>
: public MaybeOwnedTraitsGenericImpl<std::shared_ptr<T>> {};
/// A smart pointer around either a borrowed or owned T. When
/// constructed with borrowed(), the caller MUST ensure that the
/// borrowed-from argument outlives this MaybeOwned<T>. Compare to
/// Rust's std::borrow::Cow
/// (https://doc.rust-lang.org/std/borrow/enum.Cow.html), but note
/// that it is probably not suitable for general use because C++ has
/// no borrow checking. Included here to support
/// Tensor::expect_contiguous.
template <typename T>
class MaybeOwned final {
using borrow_type = typename MaybeOwnedTraits<T>::borrow_type;
using owned_type = typename MaybeOwnedTraits<T>::owned_type;
bool isBorrowed_;
union {
borrow_type borrow_;
owned_type own_;
};
/// Don't use this; use borrowed() instead.
explicit MaybeOwned(const owned_type& t)
: isBorrowed_(true), borrow_(MaybeOwnedTraits<T>::createBorrow(t)) {}
/// Don't use this; use owned() instead.
explicit MaybeOwned(T&& t) noexcept(std::is_nothrow_move_constructible_v<T>)
: isBorrowed_(false), own_(std::move(t)) {}
/// Don't use this; use owned() instead.
template <class... Args>
explicit MaybeOwned(std::in_place_t, Args&&... args)
: isBorrowed_(false), own_(std::forward<Args>(args)...) {}
public:
explicit MaybeOwned() : isBorrowed_(true), borrow_() {}
// Copying a borrow yields another borrow of the original, as with a
// T*. Copying an owned T yields another owned T for safety: no
// chains of borrowing by default! (Note you could get that behavior
// with MaybeOwned<T>::borrowed(*rhs) if you wanted it.)
MaybeOwned(const MaybeOwned& rhs) : isBorrowed_(rhs.isBorrowed_) {
if (C10_LIKELY(rhs.isBorrowed_)) {
MaybeOwnedTraits<T>::assignBorrow(borrow_, rhs.borrow_);
} else {
new (&own_) T(rhs.own_);
}
}
MaybeOwned& operator=(const MaybeOwned& rhs) {
if (this == &rhs) {
return *this;
}
if (C10_UNLIKELY(!isBorrowed_)) {
if (rhs.isBorrowed_) {
own_.~T();
MaybeOwnedTraits<T>::assignBorrow(borrow_, rhs.borrow_);
isBorrowed_ = true;
} else {
own_ = rhs.own_;
}
} else {
if (C10_LIKELY(rhs.isBorrowed_)) {
MaybeOwnedTraits<T>::assignBorrow(borrow_, rhs.borrow_);
} else {
MaybeOwnedTraits<T>::destroyBorrow(borrow_);
new (&own_) T(rhs.own_);
isBorrowed_ = false;
}
}
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isBorrowed_ == rhs.isBorrowed_);
return *this;
}
MaybeOwned(MaybeOwned&& rhs) noexcept(
// NOLINTNEXTLINE(*-noexcept-move-*)
std::is_nothrow_move_constructible_v<T> &&
std::is_nothrow_move_assignable_v<borrow_type>)
: isBorrowed_(rhs.isBorrowed_) {
if (C10_LIKELY(rhs.isBorrowed_)) {
MaybeOwnedTraits<T>::assignBorrow(borrow_, rhs.borrow_);
} else {
new (&own_) T(std::move(rhs.own_));
}
}
MaybeOwned& operator=(MaybeOwned&& rhs) noexcept(
std::is_nothrow_move_assignable_v<T> &&
std::is_nothrow_move_assignable_v<borrow_type> &&
std::is_nothrow_move_constructible_v<T> &&
// NOLINTNEXTLINE(*-noexcept-move-*)
std::is_nothrow_destructible_v<T> &&
std::is_nothrow_destructible_v<borrow_type>) {
if (this == &rhs) {
return *this;
}
if (C10_UNLIKELY(!isBorrowed_)) {
if (rhs.isBorrowed_) {
own_.~T();
MaybeOwnedTraits<T>::assignBorrow(borrow_, rhs.borrow_);
isBorrowed_ = true;
} else {
own_ = std::move(rhs.own_);
}
} else {
if (C10_LIKELY(rhs.isBorrowed_)) {
MaybeOwnedTraits<T>::assignBorrow(borrow_, rhs.borrow_);
} else {
MaybeOwnedTraits<T>::destroyBorrow(borrow_);
new (&own_) T(std::move(rhs.own_));
isBorrowed_ = false;
}
}
return *this;
}
static MaybeOwned borrowed(const T& t) {
return MaybeOwned(t);
}
static MaybeOwned owned(T&& t) noexcept(
std::is_nothrow_move_constructible_v<T>) {
return MaybeOwned(std::move(t));
}
template <class... Args>
static MaybeOwned owned(std::in_place_t, Args&&... args) {
return MaybeOwned(std::in_place, std::forward<Args>(args)...);
}
~MaybeOwned() noexcept(
// NOLINTNEXTLINE(*-noexcept-destructor)
std::is_nothrow_destructible_v<T> &&
std::is_nothrow_destructible_v<borrow_type>) {
if (C10_UNLIKELY(!isBorrowed_)) {
own_.~T();
} else {
MaybeOwnedTraits<T>::destroyBorrow(borrow_);
}
}
// This is an implementation detail! You should know what you're doing
// if you are testing this. If you just want to guarantee ownership, move
// this into a T.
bool unsafeIsBorrowed() const {
return isBorrowed_;
}
const T& operator*() const& {
if (isBorrowed_) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
MaybeOwnedTraits<T>::debugBorrowIsValid(borrow_));
}
return C10_LIKELY(isBorrowed_)
? MaybeOwnedTraits<T>::referenceFromBorrow(borrow_)
: own_;
}
const T* operator->() const {
if (isBorrowed_) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
MaybeOwnedTraits<T>::debugBorrowIsValid(borrow_));
}
return C10_LIKELY(isBorrowed_)
? MaybeOwnedTraits<T>::pointerFromBorrow(borrow_)
: &own_;
}
// If borrowed, copy the underlying T. If owned, move from
// it. borrowed/owned state remains the same, and either we
// reference the same borrow as before or we are an owned moved-from
// T.
T operator*() && {
if (isBorrowed_) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
MaybeOwnedTraits<T>::debugBorrowIsValid(borrow_));
return MaybeOwnedTraits<T>::referenceFromBorrow(borrow_);
} else {
return std::move(own_);
}
}
};
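// Usage sketch, relying on the shared_ptr specialization enabled above:
//
//   std::shared_ptr<int> sp = std::make_shared<int>(7);
//   auto b = MaybeOwned<std::shared_ptr<int>>::borrowed(sp); // sp must outlive b
//   auto o = MaybeOwned<std::shared_ptr<int>>::owned(std::make_shared<int>(42));
//   int sum = **b + **o; // 49: operator* yields the shared_ptr, ** the int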
} // namespace c10

View File

@ -0,0 +1,224 @@
#pragma once
#include <c10/util/TypeList.h>
#include <type_traits>
namespace c10::guts {
/**
* Access information about result type or arguments from a function type.
* Example:
* using A = function_traits<int (float, double)>::return_type // A == int
* using A = function_traits<int (float, double)>::parameter_types::tuple_type
* // A == tuple<float, double>
*/
template <class Func>
struct function_traits {
static_assert(
!std::is_same_v<Func, Func>,
"In function_traits<Func>, Func must be a plain function type.");
};
template <class Result, class... Args>
struct function_traits<Result(Args...)> {
using func_type = Result(Args...);
using return_type = Result;
using parameter_types = typelist::typelist<Args...>;
static constexpr auto number_of_parameters = sizeof...(Args);
};
/**
* infer_function_traits: creates a `function_traits` type for a simple
* function (pointer) or functor (lambda/struct). Currently does not support
* class methods.
*/
template <typename Functor>
struct infer_function_traits {
using type = function_traits<
c10::guts::detail::strip_class_t<decltype(&Functor::operator())>>;
};
template <typename Result, typename... Args>
struct infer_function_traits<Result (*)(Args...)> {
using type = function_traits<Result(Args...)>;
};
template <typename Result, typename... Args>
struct infer_function_traits<Result(Args...)> {
using type = function_traits<Result(Args...)>;
};
template <typename T>
using infer_function_traits_t = typename infer_function_traits<T>::type;
/**
* make_function_traits: creates a `function_traits` type given a Return type
* and a typelist of Argument types
*
* Example:
* bool f(int, int);
*
* infer_function_traits_t<f> == make_function_traits_t<bool,
* typelist::typelist<int, int>>
*/
template <typename Result, typename ArgList>
struct make_function_traits {
static_assert(
false_t<ArgList>::value,
"In guts::make_function_traits<Result, TypeList>, the ArgList argument must be typelist<...>.");
};
template <typename Result, typename... Args>
struct make_function_traits<Result, typelist::typelist<Args...>> {
using type = function_traits<Result(Args...)>;
};
template <typename Result, typename ArgList>
using make_function_traits_t =
typename make_function_traits<Result, ArgList>::type;
/**
* make_offset_index_sequence<Start, N>
* Like make_index_sequence<N>, but starting from Start instead of 0.
*
* Example:
* make_offset_index_sequence<10, 3> == std::index_sequence<10, 11, 12>
*/
template <size_t Start, size_t N, size_t... Is>
struct make_offset_index_sequence_impl
: make_offset_index_sequence_impl<Start, N - 1, Start + N - 1, Is...> {
static_assert(
static_cast<int>(Start) >= 0,
"make_offset_index_sequence: Start < 0");
static_assert(static_cast<int>(N) >= 0, "make_offset_index_sequence: N < 0");
};
template <size_t Start, size_t... Is>
struct make_offset_index_sequence_impl<Start, 0, Is...> {
typedef std::index_sequence<Is...> type;
};
template <size_t Start, size_t N>
using make_offset_index_sequence =
typename make_offset_index_sequence_impl<Start, N>::type;
/**
* Use tuple_elements to extract a position-indexed subset of elements
* from the argument tuple into a result tuple.
*
* Example:
* std::tuple<int, const char*, double> t = std::make_tuple(0, "HEY", 2.0);
* std::tuple<int, double> result = tuple_elements(t, std::index_sequence<0,
* 2>());
*/
template <class Tuple, size_t... Is>
constexpr auto tuple_elements(Tuple t, std::index_sequence<Is...>) {
return std::tuple<std::tuple_element_t<Is, Tuple>...>(std::get<Is>(t)...);
}
/**
* Use tuple_take to extract the first or last n elements from the argument
* tuple into a result tuple.
*
* Example:
* std::tuple<int, const char*, double> t = std::make_tuple(0, "HEY", 2.0);
* std::tuple<int, const char*> first_two = tuple_take<decltype(t), 2>(t);
* std::tuple<const char*, double> last_two = tuple_take<decltype(t), -2>(t);
*/
template <class Tuple, int N, class Enable = void>
struct TupleTake {};
template <class Tuple, int N>
struct TupleTake<Tuple, N, std::enable_if_t<N >= 0, void>> {
static auto call(Tuple t) {
constexpr size_t size = std::tuple_size<Tuple>();
static_assert(N <= size, "tuple_take: N > size");
return tuple_elements(t, std::make_index_sequence<N>{});
}
};
template <class Tuple, int N>
struct TupleTake<Tuple, N, std::enable_if_t<(N < 0), void>> {
static auto call(Tuple t) {
constexpr size_t size = std::tuple_size<Tuple>();
static_assert(-N <= size, "tuple_take: -N > size");
return tuple_elements(t, make_offset_index_sequence<size + N, -N>{});
}
};
template <class Tuple, int N>
auto tuple_take(Tuple t) {
return TupleTake<Tuple, N>::call(t);
}
/**
* Use tuple_slice to extract a contiguous subtuple from the argument.
*
* Example:
* std::tuple<int, const char*, double, bool> t = std::make_tuple(0,
* "HEY", 2.0, false); std::tuple<int, const char*> middle_two =
* tuple_slice<decltype(t), 1, 2>(t);
*/
template <class Tuple, size_t Start, size_t N>
constexpr auto tuple_slice(Tuple t) {
constexpr size_t size = std::tuple_size<Tuple>();
static_assert(Start + N <= size, "tuple_slice: Start + N > size");
return tuple_elements(t, make_offset_index_sequence<Start, N>{});
}
/**
* Use tuple_map to run a mapping function over a tuple to get a new tuple.
*
* Example 1:
* auto result = tuple_map(std::tuple<int32_t, int32_t, int32_t>(3, 4, 5), []
* (int32_t a) -> int16_t {return a+1;});
* // result == std::tuple<int16_t, int16_t, int16_t>(4, 5, 6)
*
* Example 2:
* struct Mapper {
* std::string operator()(int32_t a) const {
* return std::to_string(a);
* }
* int64_t operator()(const std::string& a) const {
* return atoi(a.c_str());
* }
* };
* auto result = tuple_map(std::tuple<int32_t, std::string>(3, "4"),
* Mapper());
* // result == std::tuple<std::string, int64_t>("3", 4)
*
* Example 3:
* struct A final {
* int32_t func() {
* return 5;
* }
* };
* struct B final {
* std::string func() {
* return "5";
* }
* };
* auto result = tuple_map(std::make_tuple(A(), B()), [] (auto a) { return
* a.func(); });
* // result == std::tuple<int32_t, std::string>(5, "5");
*/
namespace detail {
template <class Mapper, class... Args, size_t... Indices>
auto tuple_map(
// NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved)
std::tuple<Args...>&& tuple,
const Mapper& mapper,
std::index_sequence<Indices...>) {
return std::tuple<decltype(mapper(std::forward<Args>(std::get<Indices>(
tuple))))...>(mapper(std::forward<Args>(std::get<Indices>(tuple)))...);
}
} // namespace detail
template <class Mapper, class... Args>
auto tuple_map(std::tuple<Args...>&& tuple, const Mapper& mapper) {
return detail::tuple_map(
std::move(tuple), mapper, std::index_sequence_for<Args...>());
}
} // namespace c10::guts

View File

@ -0,0 +1,54 @@
#pragma once
#include <c10/macros/Macros.h>
#include <string>
#include <vector>
/**
* This file provides a network flow implementation.
* https://en.wikipedia.org/wiki/Flow_network
*
* It aims to mirror some of the behavior of networkx, which is/was used by
* functorch partitioners for splitting the graph into a forward and backward
* graph.
*/
namespace c10 {
enum class C10_API_ENUM MinCutStatus {
SUCCESS = 0,
UNBOUNDED = 1,
OVERFLOW_INF = 2,
INVALID = 3,
};
struct MinCutResult {
MinCutStatus status;
int64_t max_flow;
std::vector<std::string> reachable;
std::vector<std::string> unreachable;
};
// Modeled after networkx implementation
class C10_API NetworkFlowGraph {
public:
// selected such that INF + INF is < INT64_MAX
constexpr static int64_t INF = (1LL << 62) - 1;
struct Edge {
std::string source, dest;
int64_t capacity;
};
MinCutStatus add_edge(
const std::string& source,
const std::string& dest,
int64_t capacity = 1);
MinCutResult minimum_cut(const std::string& s, const std::string& t) const;
std::vector<Edge> edges;
};
} // namespace c10

View File

@ -0,0 +1,48 @@
#ifndef C10_UTIL_OPTIONAL_H_
#define C10_UTIL_OPTIONAL_H_
#include <c10/macros/Macros.h>
#include <optional>
#include <type_traits>
// Macros.h is not needed, but it does namespace shenanigans that lots
// of downstream code seems to rely on. Feel free to remove it and fix
// up builds.
namespace c10 {
// NOLINTNEXTLINE(misc-unused-using-decls)
using std::bad_optional_access;
// NOLINTNEXTLINE(misc-unused-using-decls)
using std::make_optional;
// NOLINTNEXTLINE(misc-unused-using-decls)
using std::nullopt;
// NOLINTNEXTLINE(misc-unused-using-decls)
using std::nullopt_t;
// NOLINTNEXTLINE(misc-unused-using-decls)
using std::optional;
namespace detail_ {
// the call to convert<A>(b) has return type A and converts b to type A iff
// decltype(b) is implicitly convertible to A
template <class U>
constexpr U convert(U v) {
return v;
}
} // namespace detail_
template <class T, class F>
constexpr T value_or_else(const std::optional<T>& v, F&& func) {
static_assert(
std::is_convertible_v<typename std::invoke_result_t<F>, T>,
"func parameters must be a callable that returns a type convertible to the value stored in the optional");
return v.has_value() ? *v : detail_::convert<T>(std::forward<F>(func)());
}
template <class T, class F>
constexpr T value_or_else(std::optional<T>&& v, F&& func) {
  static_assert(
      std::is_convertible_v<typename std::invoke_result_t<F>, T>,
      "func must be a callable that returns a type convertible to the value stored in the optional");
  return v.has_value() ? std::move(*v)
                       : detail_::convert<T>(std::forward<F>(func)());
}
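// Illustrative usage sketch:
//   std::optional<int> x = std::nullopt;
//   int y = c10::value_or_else(x, [] { return 42; });  // y == 42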
} // namespace c10
#endif // C10_UTIL_OPTIONAL_H_

View File

@ -0,0 +1,236 @@
// This file defines OptionalArrayRef<T>, a class that has almost the same
// exact functionality as std::optional<ArrayRef<T>>, except that its
// converting constructor fixes a dangling pointer issue.
//
// The implicit converting constructor of std::optional<ArrayRef<T>> can
// cause the underlying ArrayRef<T> to store
// a dangling pointer. OptionalArrayRef<T> prevents this by wrapping
// a std::optional<ArrayRef<T>> and fixing the constructor implementation.
//
// See https://github.com/pytorch/pytorch/issues/63645 for more on this.
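//
// Illustrative usage sketch (hypothetical values):
//   std::vector<int64_t> sizes = {2, 3};
//   c10::OptionalIntArrayRef opt = sizes;  // non-owning view over `sizes`
//   if (opt.has_value()) {
//     int64_t first = (*opt)[0];
//   }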
#pragma once
#include <c10/util/ArrayRef.h>
#include <cstdint>
#include <initializer_list>
#include <optional>
#include <type_traits>
#include <utility>
namespace c10 {
template <typename T>
class OptionalArrayRef final {
public:
// Constructors
constexpr OptionalArrayRef() noexcept = default;
constexpr OptionalArrayRef(std::nullopt_t) noexcept {}
OptionalArrayRef(const OptionalArrayRef& other) = default;
OptionalArrayRef(OptionalArrayRef&& other) noexcept = default;
constexpr OptionalArrayRef(const std::optional<ArrayRef<T>>& other) noexcept
: wrapped_opt_array_ref(other) {}
constexpr OptionalArrayRef(std::optional<ArrayRef<T>>&& other) noexcept
: wrapped_opt_array_ref(std::move(other)) {}
constexpr OptionalArrayRef(const T& value) noexcept
: wrapped_opt_array_ref(value) {}
template <
typename U = ArrayRef<T>,
std::enable_if_t<
!std::is_same_v<std::decay_t<U>, OptionalArrayRef> &&
!std::is_same_v<std::decay_t<U>, std::in_place_t> &&
std::is_constructible_v<ArrayRef<T>, U&&> &&
std::is_convertible_v<U&&, ArrayRef<T>> &&
!std::is_convertible_v<U&&, T>,
bool> = false>
constexpr OptionalArrayRef(U&& value) noexcept(
std::is_nothrow_constructible_v<ArrayRef<T>, U&&>)
: wrapped_opt_array_ref(std::forward<U>(value)) {}
template <
typename U = ArrayRef<T>,
std::enable_if_t<
!std::is_same_v<std::decay_t<U>, OptionalArrayRef> &&
!std::is_same_v<std::decay_t<U>, std::in_place_t> &&
std::is_constructible_v<ArrayRef<T>, U&&> &&
!std::is_convertible_v<U&&, ArrayRef<T>>,
bool> = false>
constexpr explicit OptionalArrayRef(U&& value) noexcept(
std::is_nothrow_constructible_v<ArrayRef<T>, U&&>)
: wrapped_opt_array_ref(std::forward<U>(value)) {}
template <typename... Args>
constexpr explicit OptionalArrayRef(
std::in_place_t ip,
Args&&... args) noexcept
: wrapped_opt_array_ref(ip, std::forward<Args>(args)...) {}
template <typename U, typename... Args>
constexpr explicit OptionalArrayRef(
std::in_place_t ip,
std::initializer_list<U> il,
Args&&... args)
: wrapped_opt_array_ref(ip, il, std::forward<Args>(args)...) {}
constexpr OptionalArrayRef(const std::initializer_list<T>& Vec)
: wrapped_opt_array_ref(ArrayRef<T>(Vec)) {}
// Destructor
~OptionalArrayRef() = default;
// Assignment
constexpr OptionalArrayRef& operator=(std::nullopt_t) noexcept {
wrapped_opt_array_ref = std::nullopt;
return *this;
}
OptionalArrayRef& operator=(const OptionalArrayRef& other) = default;
OptionalArrayRef& operator=(OptionalArrayRef&& other) noexcept = default;
constexpr OptionalArrayRef& operator=(
const std::optional<ArrayRef<T>>& other) noexcept {
wrapped_opt_array_ref = other;
return *this;
}
constexpr OptionalArrayRef& operator=(
std::optional<ArrayRef<T>>&& other) noexcept {
wrapped_opt_array_ref = std::move(other);
return *this;
}
template <
typename U = ArrayRef<T>,
typename = std::enable_if_t<
!std::is_same_v<std::decay_t<U>, OptionalArrayRef> &&
std::is_constructible_v<ArrayRef<T>, U&&> &&
std::is_assignable_v<ArrayRef<T>&, U&&>>>
constexpr OptionalArrayRef& operator=(U&& value) noexcept(
std::is_nothrow_constructible_v<ArrayRef<T>, U&&> &&
std::is_nothrow_assignable_v<ArrayRef<T>&, U&&>) {
wrapped_opt_array_ref = std::forward<U>(value);
return *this;
}
// Observers
constexpr ArrayRef<T>* operator->() noexcept {
return &wrapped_opt_array_ref.value();
}
constexpr const ArrayRef<T>* operator->() const noexcept {
return &wrapped_opt_array_ref.value();
}
constexpr ArrayRef<T>& operator*() & noexcept {
return wrapped_opt_array_ref.value();
}
constexpr const ArrayRef<T>& operator*() const& noexcept {
return wrapped_opt_array_ref.value();
}
constexpr ArrayRef<T>&& operator*() && noexcept {
return std::move(wrapped_opt_array_ref.value());
}
constexpr const ArrayRef<T>&& operator*() const&& noexcept {
return std::move(wrapped_opt_array_ref.value());
}
constexpr explicit operator bool() const noexcept {
return wrapped_opt_array_ref.has_value();
}
constexpr bool has_value() const noexcept {
return wrapped_opt_array_ref.has_value();
}
constexpr ArrayRef<T>& value() & {
return wrapped_opt_array_ref.value();
}
constexpr const ArrayRef<T>& value() const& {
return wrapped_opt_array_ref.value();
}
constexpr ArrayRef<T>&& value() && {
return std::move(wrapped_opt_array_ref.value());
}
constexpr const ArrayRef<T>&& value() const&& {
return std::move(wrapped_opt_array_ref.value());
}
template <typename U>
constexpr std::
enable_if_t<std::is_convertible_v<U&&, ArrayRef<T>>, ArrayRef<T>>
value_or(U&& default_value) const& {
return wrapped_opt_array_ref.value_or(std::forward<U>(default_value));
}
template <typename U>
constexpr std::
enable_if_t<std::is_convertible_v<U&&, ArrayRef<T>>, ArrayRef<T>>
value_or(U&& default_value) && {
return wrapped_opt_array_ref.value_or(std::forward<U>(default_value));
}
// Modifiers
constexpr void swap(OptionalArrayRef& other) noexcept {
std::swap(wrapped_opt_array_ref, other.wrapped_opt_array_ref);
}
constexpr void reset() noexcept {
wrapped_opt_array_ref.reset();
}
template <typename... Args>
constexpr std::
enable_if_t<std::is_constructible_v<ArrayRef<T>, Args&&...>, ArrayRef<T>&>
emplace(Args&&... args) noexcept(
std::is_nothrow_constructible_v<ArrayRef<T>, Args&&...>) {
return wrapped_opt_array_ref.emplace(std::forward<Args>(args)...);
}
template <typename U, typename... Args>
constexpr ArrayRef<T>& emplace(
std::initializer_list<U> il,
Args&&... args) noexcept {
return wrapped_opt_array_ref.emplace(il, std::forward<Args>(args)...);
}
private:
std::optional<ArrayRef<T>> wrapped_opt_array_ref;
};
using OptionalIntArrayRef = OptionalArrayRef<int64_t>;
inline bool operator==(
const OptionalIntArrayRef& a1,
const IntArrayRef& other) {
if (!a1.has_value()) {
return false;
}
return a1.value() == other;
}
inline bool operator==(
const c10::IntArrayRef& a1,
const c10::OptionalIntArrayRef& a2) {
return a2 == a1;
}
} // namespace c10

View File

@ -0,0 +1,20 @@
#pragma once
#include <c10/macros/Macros.h>
namespace c10 {
// RAII thread local guard that tracks whether code is being executed in
// `at::parallel_for` or `at::parallel_reduce` loop function.
class C10_API ParallelGuard {
public:
static bool is_enabled();
ParallelGuard(bool state);
~ParallelGuard();
private:
bool previous_state_;
};
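// Illustrative usage sketch (as used from a parallel loop body):
//   {
//     c10::ParallelGuard guard(true);
//     // c10::ParallelGuard::is_enabled() returns true on this thread here
//   } // previous state restored by the destructor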
} // namespace c10

View File

@ -0,0 +1,326 @@
#ifndef C10_UTIL_REGISTRY_H_
#define C10_UTIL_REGISTRY_H_
/**
* Simple registry implementation that uses static variables to
* register object creators during program initialization time.
*/
// NB: This Registry works poorly when you have other namespaces.
// Make all macro invocations from inside the at namespace.
#include <cstdio>
#include <cstdlib>
#include <functional>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>
#include <c10/macros/Export.h>
#include <c10/macros/Macros.h>
#include <c10/util/Type.h>
namespace c10 {
template <typename KeyType>
inline std::string KeyStrRepr(const KeyType& /*key*/) {
return "[key type printing not supported]";
}
template <>
inline std::string KeyStrRepr(const std::string& key) {
return key;
}
enum RegistryPriority {
REGISTRY_FALLBACK = 1,
REGISTRY_DEFAULT = 2,
REGISTRY_PREFERRED = 3,
};
/**
* @brief A template class that allows one to register classes by keys.
*
* The keys are usually a std::string specifying the name, but can be anything
* that can be used in a std::map.
*
* You should most likely not use the Registry class explicitly, but use the
* helper macros below to declare specific registries as well as registering
* objects.
*/
template <class SrcType, class ObjectPtrType, class... Args>
class Registry {
public:
typedef std::function<ObjectPtrType(Args...)> Creator;
Registry(bool warning = true) : registry_(), priority_(), warning_(warning) {}
void Register(
const SrcType& key,
Creator creator,
const RegistryPriority priority = REGISTRY_DEFAULT) {
std::lock_guard<std::mutex> lock(register_mutex_);
// The if statement below is essentially the same as the following line:
// TORCH_CHECK_EQ(registry_.count(key), 0) << "Key " << key
// << " registered twice.";
// However, TORCH_CHECK_EQ depends on google logging, and since registration
// is carried out at static initialization time, we do not want to have an
// explicit dependency on glog's initialization function.
if (registry_.count(key) != 0) {
auto cur_priority = priority_[key];
if (priority > cur_priority) {
#ifdef DEBUG
std::string warn_msg =
"Overwriting already registered item for key " + KeyStrRepr(key);
fprintf(stderr, "%s\n", warn_msg.c_str());
#endif
registry_[key] = creator;
priority_[key] = priority;
} else if (priority == cur_priority) {
std::string err_msg =
"Key already registered with the same priority: " + KeyStrRepr(key);
fprintf(stderr, "%s\n", err_msg.c_str());
if (terminate_) {
std::exit(1);
} else {
throw std::runtime_error(err_msg);
}
} else if (warning_) {
std::string warn_msg =
"Higher priority item already registered, skipping registration of " +
KeyStrRepr(key);
fprintf(stderr, "%s\n", warn_msg.c_str());
}
} else {
registry_[key] = creator;
priority_[key] = priority;
}
}
void Register(
const SrcType& key,
Creator creator,
const std::string& help_msg,
const RegistryPriority priority = REGISTRY_DEFAULT) {
Register(key, creator, priority);
help_message_[key] = help_msg;
}
inline bool Has(const SrcType& key) {
return (registry_.count(key) != 0);
}
ObjectPtrType Create(const SrcType& key, Args... args) {
auto it = registry_.find(key);
if (it == registry_.end()) {
// Returns nullptr if the key is not registered.
return nullptr;
}
return it->second(args...);
}
/**
* Returns the keys currently registered as a std::vector.
*/
std::vector<SrcType> Keys() const {
std::vector<SrcType> keys;
keys.reserve(registry_.size());
for (const auto& it : registry_) {
keys.push_back(it.first);
}
return keys;
}
inline const std::unordered_map<SrcType, std::string>& HelpMessage() const {
return help_message_;
}
const char* HelpMessage(const SrcType& key) const {
auto it = help_message_.find(key);
if (it == help_message_.end()) {
return nullptr;
}
return it->second.c_str();
}
  // Used for testing: if terminate is unset, Registry throws instead of
  // calling std::exit.
void SetTerminate(bool terminate) {
terminate_ = terminate;
}
private:
std::unordered_map<SrcType, Creator> registry_;
std::unordered_map<SrcType, RegistryPriority> priority_;
bool terminate_{true};
const bool warning_;
std::unordered_map<SrcType, std::string> help_message_;
std::mutex register_mutex_;
C10_DISABLE_COPY_AND_ASSIGN(Registry);
};
template <class SrcType, class ObjectPtrType, class... Args>
class Registerer {
public:
explicit Registerer(
const SrcType& key,
Registry<SrcType, ObjectPtrType, Args...>* registry,
typename Registry<SrcType, ObjectPtrType, Args...>::Creator creator,
const std::string& help_msg = "") {
registry->Register(key, creator, help_msg);
}
explicit Registerer(
const SrcType& key,
const RegistryPriority priority,
Registry<SrcType, ObjectPtrType, Args...>* registry,
typename Registry<SrcType, ObjectPtrType, Args...>::Creator creator,
const std::string& help_msg = "") {
registry->Register(key, creator, help_msg, priority);
}
template <class DerivedType>
static ObjectPtrType DefaultCreator(Args... args) {
return ObjectPtrType(new DerivedType(args...));
}
};
/**
* C10_DECLARE_TYPED_REGISTRY is a macro that expands to a function
* declaration, as well as creating a convenient typename for its corresponding
* registerer.
*/
// Note on C10_IMPORT and C10_EXPORT below: we need to explicitly mark DECLARE
// as import and DEFINE as export, because these registry macros will be used
// in downstream shared libraries as well, and one cannot use *_API - the API
// macro will be defined on a per-shared-library basis. Semantically, when one
// declares a typed registry it is always going to be IMPORT, and when one
// defines a registry (which should happen ONLY ONCE and ONLY IN SOURCE FILE),
// the instantiation unit is always going to be exported.
//
// The only unique condition is when in the same file one does DECLARE and
// DEFINE - in Windows compilers, this generates a warning that dllimport and
// dllexport are mixed, but the warning is fine and linker will be properly
// exporting the symbol. Same thing happens in the gflags flag declaration and
// definition cases.
#define C10_DECLARE_TYPED_REGISTRY( \
RegistryName, SrcType, ObjectType, PtrType, ...) \
C10_API ::c10::Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>* \
RegistryName(); \
typedef ::c10::Registerer<SrcType, PtrType<ObjectType>, ##__VA_ARGS__> \
Registerer##RegistryName
#define TORCH_DECLARE_TYPED_REGISTRY( \
RegistryName, SrcType, ObjectType, PtrType, ...) \
TORCH_API ::c10::Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>* \
RegistryName(); \
typedef ::c10::Registerer<SrcType, PtrType<ObjectType>, ##__VA_ARGS__> \
Registerer##RegistryName
#define C10_DEFINE_TYPED_REGISTRY( \
RegistryName, SrcType, ObjectType, PtrType, ...) \
C10_EXPORT ::c10::Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>* \
RegistryName() { \
static ::c10::Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>* \
registry = new ::c10:: \
Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>(); \
return registry; \
}
#define C10_DEFINE_TYPED_REGISTRY_WITHOUT_WARNING( \
RegistryName, SrcType, ObjectType, PtrType, ...) \
C10_EXPORT ::c10::Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>* \
RegistryName() { \
static ::c10::Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>* \
registry = \
new ::c10::Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>( \
false); \
return registry; \
}
// Note(Yangqing): The __VA_ARGS__ below allows one to specify a templated
// creator with comma in its templated arguments.
#define C10_REGISTER_TYPED_CREATOR(RegistryName, key, ...) \
static Registerer##RegistryName C10_ANONYMOUS_VARIABLE(g_##RegistryName)( \
key, RegistryName(), ##__VA_ARGS__);
#define C10_REGISTER_TYPED_CREATOR_WITH_PRIORITY( \
RegistryName, key, priority, ...) \
static Registerer##RegistryName C10_ANONYMOUS_VARIABLE(g_##RegistryName)( \
key, priority, RegistryName(), ##__VA_ARGS__);
#define C10_REGISTER_TYPED_CLASS(RegistryName, key, ...) \
static Registerer##RegistryName C10_ANONYMOUS_VARIABLE(g_##RegistryName)( \
key, \
RegistryName(), \
Registerer##RegistryName::DefaultCreator<__VA_ARGS__>, \
::c10::demangle_type<__VA_ARGS__>());
#define C10_REGISTER_TYPED_CLASS_WITH_PRIORITY( \
RegistryName, key, priority, ...) \
static Registerer##RegistryName C10_ANONYMOUS_VARIABLE(g_##RegistryName)( \
key, \
priority, \
RegistryName(), \
Registerer##RegistryName::DefaultCreator<__VA_ARGS__>, \
::c10::demangle_type<__VA_ARGS__>());
// C10_DECLARE_REGISTRY and C10_DEFINE_REGISTRY are hard-wired to use
// std::string as the key type, because that is the most common use case.
#define C10_DECLARE_REGISTRY(RegistryName, ObjectType, ...) \
C10_DECLARE_TYPED_REGISTRY( \
RegistryName, std::string, ObjectType, std::unique_ptr, ##__VA_ARGS__)
#define TORCH_DECLARE_REGISTRY(RegistryName, ObjectType, ...) \
TORCH_DECLARE_TYPED_REGISTRY( \
RegistryName, std::string, ObjectType, std::unique_ptr, ##__VA_ARGS__)
#define C10_DEFINE_REGISTRY(RegistryName, ObjectType, ...) \
C10_DEFINE_TYPED_REGISTRY( \
RegistryName, std::string, ObjectType, std::unique_ptr, ##__VA_ARGS__)
#define C10_DEFINE_REGISTRY_WITHOUT_WARNING(RegistryName, ObjectType, ...) \
C10_DEFINE_TYPED_REGISTRY_WITHOUT_WARNING( \
RegistryName, std::string, ObjectType, std::unique_ptr, ##__VA_ARGS__)
#define C10_DECLARE_SHARED_REGISTRY(RegistryName, ObjectType, ...) \
C10_DECLARE_TYPED_REGISTRY( \
RegistryName, std::string, ObjectType, std::shared_ptr, ##__VA_ARGS__)
#define TORCH_DECLARE_SHARED_REGISTRY(RegistryName, ObjectType, ...) \
TORCH_DECLARE_TYPED_REGISTRY( \
RegistryName, std::string, ObjectType, std::shared_ptr, ##__VA_ARGS__)
#define C10_DEFINE_SHARED_REGISTRY(RegistryName, ObjectType, ...) \
C10_DEFINE_TYPED_REGISTRY( \
RegistryName, std::string, ObjectType, std::shared_ptr, ##__VA_ARGS__)
#define C10_DEFINE_SHARED_REGISTRY_WITHOUT_WARNING( \
RegistryName, ObjectType, ...) \
C10_DEFINE_TYPED_REGISTRY_WITHOUT_WARNING( \
RegistryName, std::string, ObjectType, std::shared_ptr, ##__VA_ARGS__)
// C10_REGISTER_CREATOR and C10_REGISTER_CLASS are hard-wired to use
// std::string as the key type, because that is the most common use case.
#define C10_REGISTER_CREATOR(RegistryName, key, ...) \
C10_REGISTER_TYPED_CREATOR(RegistryName, #key, __VA_ARGS__)
#define C10_REGISTER_CREATOR_WITH_PRIORITY(RegistryName, key, priority, ...) \
C10_REGISTER_TYPED_CREATOR_WITH_PRIORITY( \
RegistryName, #key, priority, __VA_ARGS__)
#define C10_REGISTER_CLASS(RegistryName, key, ...) \
C10_REGISTER_TYPED_CLASS(RegistryName, #key, __VA_ARGS__)
#define C10_REGISTER_CLASS_WITH_PRIORITY(RegistryName, key, priority, ...) \
C10_REGISTER_TYPED_CLASS_WITH_PRIORITY( \
RegistryName, #key, priority, __VA_ARGS__)
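// Illustrative end-to-end usage sketch (Widget and MyWidget are hypothetical):
//   // In a header:
//   C10_DECLARE_REGISTRY(WidgetRegistry, Widget, int);
//   // In exactly one source file:
//   C10_DEFINE_REGISTRY(WidgetRegistry, Widget, int);
//   C10_REGISTER_CLASS(WidgetRegistry, MyWidget, MyWidget);
//   // At runtime:
//   std::unique_ptr<Widget> w = WidgetRegistry()->Create("MyWidget", 42);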
} // namespace c10
#endif // C10_UTIL_REGISTRY_H_

View File

@ -0,0 +1,50 @@
#pragma once
#include <type_traits>
#include <utility>
namespace c10 {
/**
* Mostly copied from https://llvm.org/doxygen/ScopeExit_8h_source.html
*/
template <typename Callable>
class scope_exit {
Callable ExitFunction;
bool Engaged = true; // False once moved-from or release()d.
public:
template <typename Fp>
// NOLINTNEXTLINE(bugprone-forwarding-reference-overload)
explicit scope_exit(Fp&& F) : ExitFunction(std::forward<Fp>(F)) {}
scope_exit(scope_exit&& Rhs) noexcept
: ExitFunction(std::move(Rhs.ExitFunction)), Engaged(Rhs.Engaged) {
Rhs.release();
}
scope_exit(const scope_exit&) = delete;
scope_exit& operator=(scope_exit&&) = delete;
scope_exit& operator=(const scope_exit&) = delete;
void release() {
Engaged = false;
}
~scope_exit() {
if (Engaged) {
ExitFunction();
}
}
};
// Keeps the callable object that is passed in, and executes it at the
// destruction of the returned object (usually at the scope exit where the
// returned object is kept).
//
// Interface is specified by p0052r2.
template <typename Callable>
scope_exit<std::decay_t<Callable>> make_scope_exit(Callable&& F) {
return scope_exit<std::decay_t<Callable>>(std::forward<Callable>(F));
}
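// Illustrative usage sketch (the file handle is a hypothetical resource):
//   FILE* f = fopen("data.txt", "r");
//   auto cleanup = c10::make_scope_exit([&] {
//     if (f) {
//       fclose(f);
//     }
//   });
//   // ... use f; the lambda runs at scope exit unless cleanup.release() ...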
} // namespace c10

View File

@ -0,0 +1,87 @@
#pragma once
#include <array>
#include <cstddef>
#include <cstdint>
#include <type_traits>
/** Helper class for allocating temporary fixed size arrays with SBO.
*
* This is intentionally much simpler than SmallVector, to improve performance
* at the expense of many features:
* - No zero-initialization for numeric types
* - No resizing after construction
* - No copy/move
* - No non-trivial types
*/
namespace c10 {
template <typename T, size_t N>
class SmallBuffer {
static_assert(std::is_trivial_v<T>, "SmallBuffer is intended for POD types");
std::array<T, N> storage_;
size_t size_{};
T* data_{};
public:
SmallBuffer(size_t size) : size_(size) {
if (size > N) {
data_ = new T[size];
} else {
data_ = &storage_[0];
}
}
SmallBuffer(const SmallBuffer&) = delete;
SmallBuffer& operator=(const SmallBuffer&) = delete;
// A move constructor is needed so SmallBuffer can be returned from functions
SmallBuffer(SmallBuffer&& rhs) noexcept : size_{rhs.size_} {
rhs.size_ = 0;
if (size_ > N) {
data_ = rhs.data_;
rhs.data_ = nullptr;
} else {
storage_ = std::move(rhs.storage_);
data_ = &storage_[0];
}
}
SmallBuffer& operator=(SmallBuffer&&) = delete;
~SmallBuffer() {
if (size_ > N) {
delete[] data_;
}
}
T& operator[](size_t idx) {
return data()[idx];
}
const T& operator[](size_t idx) const {
return data()[idx];
}
T* data() {
return data_;
}
const T* data() const {
return data_;
}
size_t size() const {
return size_;
}
T* begin() {
return data_;
}
const T* begin() const {
return data_;
}
T* end() {
return data_ + size_;
}
const T* end() const {
return data_ + size_;
}
};
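// Illustrative usage sketch (`n` is a hypothetical runtime size; storage is
// on the stack for n <= 8 and on the heap otherwise):
//   c10::SmallBuffer<int64_t, 8> buf(n);
//   for (auto& v : buf) {
//     v = 0;
//   }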
} // namespace c10

File diff suppressed because it is too large

View File

@ -0,0 +1,211 @@
#ifndef C10_UTIL_STRINGUTIL_H_
#define C10_UTIL_STRINGUTIL_H_
#include <c10/macros/Macros.h>
#include <c10/util/string_utils.h>
#include <c10/util/string_view.h>
#include <cstddef>
#include <ostream>
#include <sstream>
#include <string>
C10_CLANG_DIAGNOSTIC_PUSH()
#if C10_CLANG_HAS_WARNING("-Wshorten-64-to-32")
C10_CLANG_DIAGNOSTIC_IGNORE("-Wshorten-64-to-32")
#endif
namespace c10 {
namespace detail {
// Obtains the base name from a full path.
C10_API std::string StripBasename(const std::string& full_path);
C10_API std::string ExcludeFileExtension(const std::string& full_path);
struct CompileTimeEmptyString {
operator const std::string&() const {
static const std::string empty_string_literal;
return empty_string_literal;
}
operator const char*() const {
return "";
}
};
template <typename T>
struct CanonicalizeStrTypes {
using type = const T&;
};
template <size_t N>
// NOLINTNEXTLINE(*c-arrays*)
struct CanonicalizeStrTypes<char[N]> {
using type = const char*;
};
inline std::ostream& _str(std::ostream& ss) {
return ss;
}
template <typename T>
inline std::ostream& _str(std::ostream& ss, const T& t) {
// NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
ss << t;
return ss;
}
// Overloads of _str for wide types; forces narrowing.
C10_API std::ostream& _str(std::ostream& ss, const wchar_t* wCStr);
C10_API std::ostream& _str(std::ostream& ss, const wchar_t& wChar);
C10_API std::ostream& _str(std::ostream& ss, const std::wstring& wString);
template <>
inline std::ostream& _str<CompileTimeEmptyString>(
std::ostream& ss,
const CompileTimeEmptyString&) {
return ss;
}
template <typename T, typename... Args>
inline std::ostream& _str(std::ostream& ss, const T& t, const Args&... args) {
return _str(_str(ss, t), args...);
}
template <typename... Args>
struct _str_wrapper final {
static std::string call(const Args&... args) {
std::ostringstream ss;
_str(ss, args...);
return ss.str();
}
};
// Specializations for already-a-string types.
template <>
struct _str_wrapper<std::string> final {
// return by reference to avoid the binary size of a string copy
static const std::string& call(const std::string& str) {
return str;
}
};
template <>
struct _str_wrapper<const char*> final {
static const char* call(const char* str) {
return str;
}
};
// For c10::str() with an empty argument list (which is common in our assert
// macros), we don't want to pay the binary size for constructing and
// destructing a stringstream or even constructing a string.
template <>
struct _str_wrapper<> final {
static CompileTimeEmptyString call() {
return CompileTimeEmptyString();
}
};
} // namespace detail
// Convert a list of string-like arguments into a single string.
template <typename... Args>
inline decltype(auto) str(const Args&... args) {
return detail::_str_wrapper<
typename detail::CanonicalizeStrTypes<Args>::type...>::call(args...);
}
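// Illustrative usage sketch:
//   std::string msg = c10::str("expected ", 2, " args, got ", 3);
//   // msg == "expected 2 args, got 3"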
template <class Container>
inline std::string Join(const std::string& delimiter, const Container& v) {
std::stringstream s;
  int64_t cnt = static_cast<int64_t>(v.size()) - 1;
for (auto i = v.begin(); i != v.end(); ++i, --cnt) {
s << (*i) << (cnt ? delimiter : "");
}
return s.str();
}
// Replace all occurrences of the "from" substring with the "to" string.
// Returns the number of replacements.
size_t C10_API
ReplaceAll(std::string& s, c10::string_view from, c10::string_view to);
/// Represents a location in source code (for debugging).
struct C10_API SourceLocation {
const char* function;
const char* file;
uint32_t line;
};
std::ostream& operator<<(std::ostream& out, const SourceLocation& loc);
// Unix isprint, but insensitive to locale
inline bool isPrint(char s) {
return s > 0x1f && s < 0x7f;
}
inline void printQuotedString(std::ostream& stmt, const string_view str) {
stmt << "\"";
for (auto s : str) {
switch (s) {
case '\\':
stmt << "\\\\";
break;
case '\'':
stmt << "\\'";
break;
case '\"':
stmt << "\\\"";
break;
case '\a':
stmt << "\\a";
break;
case '\b':
stmt << "\\b";
break;
case '\f':
stmt << "\\f";
break;
case '\n':
stmt << "\\n";
break;
case '\r':
stmt << "\\r";
break;
case '\t':
stmt << "\\t";
break;
case '\v':
stmt << "\\v";
break;
default:
if (isPrint(s)) {
stmt << s;
} else {
// C++ io has stateful formatting settings. Messing with
// them is probably worse than doing this manually.
// NOLINTNEXTLINE(*c-arrays*)
char buf[4] = "000";
// NOLINTNEXTLINE(*narrowing-conversions)
buf[2] += s % 8;
s /= 8;
// NOLINTNEXTLINE(*narrowing-conversions)
buf[1] += s % 8;
s /= 8;
// NOLINTNEXTLINE(*narrowing-conversions)
buf[0] += s;
stmt << "\\" << buf;
}
break;
}
}
stmt << "\"";
}
} // namespace c10
C10_CLANG_DIAGNOSTIC_POP()
#endif // C10_UTIL_STRINGUTIL_H_

View File

@ -0,0 +1,61 @@
#pragma once
#include <mutex>
namespace c10 {
/**
* A very simple Synchronization class for error-free use of data
* in a multi-threaded context. See folly/docs/Synchronized.md for
* the inspiration of this class.
*
* Full URL:
* https://github.com/facebook/folly/blob/main/folly/docs/Synchronized.md
*
* This class implements a small subset of the generic functionality
 * implemented by folly::Synchronized<T>. Specifically, only withLock<T>
* is implemented here since it's the smallest possible API that is
* able to cover a large surface area of functionality offered by
* folly::Synchronized<T>.
*/
template <typename T>
class Synchronized final {
mutable std::mutex mutex_;
T data_;
public:
Synchronized() = default;
Synchronized(T const& data) : data_(data) {}
Synchronized(T&& data) : data_(std::move(data)) {}
  // Don't permit copy construction, move construction, copy assignment,
  // or move assignment, since the underlying std::mutex isn't
  // copyable/movable.
  Synchronized(Synchronized const&) = delete;
  Synchronized(Synchronized&&) = delete;
  Synchronized& operator=(Synchronized const&) = delete;
  Synchronized& operator=(Synchronized&&) = delete;
/**
* To use, call withLock<T> with a callback that accepts T either
* by copy or by reference. Use the protected variable in the
* provided callback safely.
*/
template <typename CB>
auto withLock(CB&& cb) {
std::lock_guard<std::mutex> guard(this->mutex_);
return std::forward<CB>(cb)(this->data_);
}
/**
* To use, call withLock<T> with a callback that accepts T either
* by copy or by const reference. Use the protected variable in
* the provided callback safely.
*/
template <typename CB>
auto withLock(CB&& cb) const {
std::lock_guard<std::mutex> guard(this->mutex_);
return std::forward<CB>(cb)(this->data_);
}
};
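// Illustrative usage sketch:
//   c10::Synchronized<std::vector<int>> items;
//   items.withLock([](std::vector<int>& v) {
//     v.push_back(1);
//   });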
} // end namespace c10

View File

@ -0,0 +1,153 @@
#pragma once
#include <c10/macros/Macros.h>
/**
* Android versions with libgnustl incorrectly handle thread_local C++
* qualifier with composite types. NDK up to r17 version is affected.
*
* (A fix landed on Jun 4 2018:
* https://android-review.googlesource.com/c/toolchain/gcc/+/683601)
*
* In such cases, use c10::ThreadLocal<T> wrapper
* which is `pthread_*` based with smart pointer semantics.
*
* In addition, convenient macro C10_DEFINE_TLS_static is available.
* To define static TLS variable of type std::string, do the following
* ```
* C10_DEFINE_TLS_static(std::string, str_tls_);
* ///////
* {
* *str_tls_ = "abc";
 * assert(str_tls_->length() == 3);
* }
* ```
*
* (see c10/test/util/ThreadLocal_test.cpp for more examples)
*/
#if !defined(C10_PREFER_CUSTOM_THREAD_LOCAL_STORAGE)
#if defined(C10_ANDROID) && defined(__GLIBCXX__) && __GLIBCXX__ < 20180604
#define C10_PREFER_CUSTOM_THREAD_LOCAL_STORAGE
#endif // defined(C10_ANDROID) && defined(__GLIBCXX__) && __GLIBCXX__ < 20180604
#endif // !defined(C10_PREFER_CUSTOM_THREAD_LOCAL_STORAGE)
#if defined(C10_PREFER_CUSTOM_THREAD_LOCAL_STORAGE)
#include <c10/util/Exception.h>
#include <errno.h>
#include <pthread.h>
#include <memory>
namespace c10 {
/**
* @brief Temporary thread_local C++ qualifier replacement for Android
* based on `pthread_*`.
* To be used with composite types that provide default ctor.
*/
template <typename Type>
class ThreadLocal {
public:
ThreadLocal() {
pthread_key_create(
&key_, [](void* buf) { delete static_cast<Type*>(buf); });
}
~ThreadLocal() {
if (void* current = pthread_getspecific(key_)) {
delete static_cast<Type*>(current);
}
pthread_key_delete(key_);
}
ThreadLocal(const ThreadLocal&) = delete;
ThreadLocal& operator=(const ThreadLocal&) = delete;
Type& get() {
if (void* current = pthread_getspecific(key_)) {
return *static_cast<Type*>(current);
}
std::unique_ptr<Type> ptr = std::make_unique<Type>();
if (0 == pthread_setspecific(key_, ptr.get())) {
return *ptr.release();
}
int err = errno;
TORCH_INTERNAL_ASSERT(false, "pthread_setspecific() failed, errno = ", err);
}
Type& operator*() {
return get();
}
Type* operator->() {
return &get();
}
private:
pthread_key_t key_;
};
} // namespace c10
#define C10_DEFINE_TLS_static(Type, Name) static ::c10::ThreadLocal<Type> Name
#define C10_DECLARE_TLS_class_static(Class, Type, Name) \
static ::c10::ThreadLocal<Type> Name
#define C10_DEFINE_TLS_class_static(Class, Type, Name) \
::c10::ThreadLocal<Type> Class::Name
#else // defined(C10_PREFER_CUSTOM_THREAD_LOCAL_STORAGE)
namespace c10 {
/**
* @brief Default thread_local implementation for non-Android cases.
* To be used with composite types that provide default ctor.
*/
template <typename Type>
class ThreadLocal {
public:
using Accessor = Type* (*)();
explicit ThreadLocal(Accessor accessor) : accessor_(accessor) {}
ThreadLocal(const ThreadLocal&) = delete;
ThreadLocal& operator=(const ThreadLocal&) = delete;
Type& get() {
return *accessor_();
}
Type& operator*() {
return get();
}
Type* operator->() {
return &get();
}
private:
Accessor accessor_;
};
} // namespace c10
#define C10_DEFINE_TLS_static(Type, Name) \
static ::c10::ThreadLocal<Type> Name([]() { \
static thread_local Type var; \
return &var; \
})
#define C10_DECLARE_TLS_class_static(Class, Type, Name) \
static ::c10::ThreadLocal<Type> Name
#define C10_DEFINE_TLS_class_static(Class, Type, Name) \
::c10::ThreadLocal<Type> Class::Name([]() { \
static thread_local Type var; \
return &var; \
})
#endif // defined(C10_PREFER_CUSTOM_THREAD_LOCAL_STORAGE)

View File

@ -0,0 +1,83 @@
#pragma once
#include <c10/macros/Export.h>
#include <cstdint>
#include <memory>
namespace c10 {
enum class C10_API_ENUM DebugInfoKind : uint8_t {
PRODUCER_INFO = 0,
MOBILE_RUNTIME_INFO,
PROFILER_STATE,
INFERENCE_CONTEXT, // for inference usage
PARAM_COMMS_INFO,
TEST_INFO, // used only in tests
TEST_INFO_2, // used only in tests
};
class C10_API DebugInfoBase {
public:
DebugInfoBase() = default;
virtual ~DebugInfoBase() = default;
};
// Thread local debug information is propagated across the forward
// (including async fork tasks) and backward passes and is supposed
// to be utilized by the user's code to pass extra information from
// the higher layers (e.g. model id) down to the lower levels
// (e.g. to the operator observers used for debugging, logging,
// profiling, etc)
class C10_API ThreadLocalDebugInfo {
public:
static DebugInfoBase* get(DebugInfoKind kind);
// Get current ThreadLocalDebugInfo
static std::shared_ptr<ThreadLocalDebugInfo> current();
// Internal, use DebugInfoGuard/ThreadLocalStateGuard
static void _forceCurrentDebugInfo(
std::shared_ptr<ThreadLocalDebugInfo> info);
// Push debug info struct of a given kind
static void _push(DebugInfoKind kind, std::shared_ptr<DebugInfoBase> info);
// Pop debug info, throws in case the last pushed
// debug info is not of a given kind
static std::shared_ptr<DebugInfoBase> _pop(DebugInfoKind kind);
// Peek debug info, throws in case the last pushed debug info is not of the
// given kind
static std::shared_ptr<DebugInfoBase> _peek(DebugInfoKind kind);
private:
std::shared_ptr<DebugInfoBase> info_;
DebugInfoKind kind_;
std::shared_ptr<ThreadLocalDebugInfo> parent_info_;
friend class DebugInfoGuard;
};
// DebugInfoGuard is used to set debug information,
// ThreadLocalDebugInfo is semantically immutable, the values are set
// through the scope-based guard object.
// Nested DebugInfoGuard adds/overrides existing values in the scope,
// restoring the original values after exiting the scope.
// Users can access the values through the ThreadLocalDebugInfo::get() call;
class C10_API DebugInfoGuard {
public:
DebugInfoGuard(DebugInfoKind kind, std::shared_ptr<DebugInfoBase> info);
explicit DebugInfoGuard(std::shared_ptr<ThreadLocalDebugInfo> info);
~DebugInfoGuard();
DebugInfoGuard(const DebugInfoGuard&) = delete;
DebugInfoGuard(DebugInfoGuard&&) = delete;
private:
bool active_ = false;
std::shared_ptr<ThreadLocalDebugInfo> prev_info_ = nullptr;
};
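// Illustrative usage sketch (MyInfo is a hypothetical DebugInfoBase subclass):
//   {
//     c10::DebugInfoGuard guard(
//         c10::DebugInfoKind::TEST_INFO, std::make_shared<MyInfo>());
//     auto* info = c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO);
//   } // previous debug info restored here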
} // namespace c10

View File

@ -0,0 +1,30 @@
#ifndef C10_UTIL_TYPE_H_
#define C10_UTIL_TYPE_H_
#include <cstddef>
#include <string>
#ifdef __GXX_RTTI
#include <typeinfo>
#endif // __GXX_RTTI
#include <c10/macros/Macros.h>
namespace c10 {
/// Utility to demangle a C++ symbol name.
C10_API std::string demangle(const char* name);
/// Returns the printable name of the type.
template <typename T>
inline const char* demangle_type() {
#ifdef __GXX_RTTI
static const auto& name = *(new std::string(demangle(typeid(T).name())));
return name.c_str();
#else // __GXX_RTTI
return "(RTTI disabled, cannot show name)";
#endif // __GXX_RTTI
}
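// Illustrative usage sketch (exact output is compiler-dependent):
//   const char* name = c10::demangle_type<std::vector<int>>();
//   // e.g. "std::vector<int, std::allocator<int> >" when RTTI is enabled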
} // namespace c10
#endif // C10_UTIL_TYPE_H_

View File

@ -0,0 +1,195 @@
#pragma once
#include <c10/macros/Macros.h>
#include <c10/util/BFloat16.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e4m3fnuz.h>
#include <c10/util/Float8_e5m2.h>
#include <c10/util/Float8_e5m2fnuz.h>
#include <c10/util/Half.h>
#include <c10/util/complex.h>
#include <type_traits>
C10_CLANG_DIAGNOSTIC_PUSH()
#if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion")
C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion")
#endif
#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
#endif
namespace c10 {
template <typename dest_t, typename src_t>
struct needs_real {
constexpr static bool value =
(is_complex<src_t>::value && !is_complex<dest_t>::value);
};
template <bool, typename src_t>
struct maybe_real {
C10_HOST_DEVICE static inline src_t apply(src_t src) {
return src;
}
};
template <typename src_t>
struct maybe_real<true, src_t> {
C10_HOST_DEVICE static inline decltype(auto) apply(src_t src) {
return src.real();
}
};
template <bool, typename src_t>
struct maybe_bool {
C10_HOST_DEVICE static inline src_t apply(src_t src) {
return src;
}
};
template <typename src_t>
struct maybe_bool<true, src_t> {
C10_HOST_DEVICE static inline decltype(auto) apply(src_t src) {
    // Don't use the bool operator, so that this also compiles for ComplexHalf.
return src.real() || src.imag();
}
};
// Note: deliberately ignores undefined behavior, consistent with NumPy.
// PyTorch's type conversions can cause a variety of undefined behavior,
// including float to integral overflow and signed to unsigned integer overflow.
// Some of this undefined behavior is addressed below.
template <typename dest_t, typename src_t>
struct static_cast_with_inter_type {
C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline dest_t apply(
src_t src) {
constexpr bool real = needs_real<dest_t, src_t>::value;
auto r = maybe_real<real, src_t>::apply(src);
return static_cast<dest_t>(r);
}
};
// Partial template specialization for casting to bool.
// Need to handle complex types separately, as we don't
// simply want to cast the real part to bool.
template <typename src_t>
struct static_cast_with_inter_type<bool, src_t> {
C10_HOST_DEVICE static inline bool apply(src_t src) {
constexpr bool complex = needs_real<bool, src_t>::value;
return static_cast<bool>(maybe_bool<complex, src_t>::apply(src));
}
};
// Partial template specialization for casting to uint8.
// Note: Converting from negative float values to unsigned integer types is
// undefined behavior in C++, and current CPU and GPU compilers exhibit
// divergent behavior. Casting from negative float values to signed
// integer types and then to unsigned integer types is not undefined,
// however, so this cast improves the consistency of type conversions
// to uint8 across compilers.
// Further note: Type conversions across compilers still have other undefined
// and divergent behavior.
template <typename src_t>
struct static_cast_with_inter_type<uint8_t, src_t> {
C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline uint8_t apply(
src_t src) {
constexpr bool real = needs_real<uint8_t, src_t>::value;
return static_cast<uint8_t>(
static_cast<int64_t>(maybe_real<real, src_t>::apply(src)));
}
};
template <>
struct static_cast_with_inter_type<c10::complex<c10::Half>, c10::BFloat16> {
C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex<
c10::Half>
apply(c10::BFloat16 src) {
return static_cast<c10::complex<c10::Half>>(c10::complex<float>{src});
}
};
template <>
struct static_cast_with_inter_type<c10::complex<c10::Half>, c10::Float8_e5m2> {
C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex<
c10::Half>
apply(c10::Float8_e5m2 src) {
return static_cast<c10::complex<c10::Half>>(c10::complex<float>{src});
}
};
template <>
struct static_cast_with_inter_type<
c10::complex<c10::Half>,
c10::Float8_e5m2fnuz> {
C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex<
c10::Half>
apply(c10::Float8_e5m2fnuz src) {
return static_cast<c10::complex<c10::Half>>(c10::complex<float>{src});
}
};
template <>
struct static_cast_with_inter_type<
c10::complex<c10::Half>,
c10::Float8_e4m3fn> {
C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex<
c10::Half>
apply(c10::Float8_e4m3fn src) {
return static_cast<c10::complex<c10::Half>>(c10::complex<float>{src});
}
};
template <>
struct static_cast_with_inter_type<
c10::complex<c10::Half>,
c10::Float8_e4m3fnuz> {
C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex<
c10::Half>
apply(c10::Float8_e4m3fnuz src) {
return static_cast<c10::complex<c10::Half>>(c10::complex<float>{src});
}
};
template <>
struct static_cast_with_inter_type<c10::complex<c10::Half>, c10::Half> {
C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex<
c10::Half>
apply(c10::Half src) {
return static_cast<c10::complex<c10::Half>>(c10::complex<float>{src});
}
};
template <>
struct static_cast_with_inter_type<
c10::complex<c10::Half>,
c10::complex<double>> {
C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex<
c10::Half>
apply(c10::complex<double> src) {
return static_cast<c10::complex<c10::Half>>(
static_cast<c10::complex<float>>(src));
}
};
template <typename To, typename From>
C10_HOST_DEVICE To convert(From f) {
return static_cast_with_inter_type<To, From>::apply(f);
}
// Define separately to avoid being inlined and prevent code-size bloat
[[noreturn]] C10_API void report_overflow(const char* name);
template <typename To, typename From>
To checked_convert(From f, const char* name) {
// Converting to bool can't overflow so we exclude this case from checking.
if (!std::is_same_v<To, bool> && overflows<To, From>(f)) {
report_overflow(name);
}
return convert<To, From>(f);
}
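// Illustrative usage sketch:
//   uint8_t u = c10::convert<uint8_t>(-1.0f);  // goes through int64_t first
//   // A hypothetical overflowing conversion reports an error:
//   //   c10::checked_convert<int8_t>(300, "int8_t");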
} // namespace c10
C10_CLANG_DIAGNOSTIC_POP()
// Trigger tests for D25440771. TODO: Remove this line any time you want.

View File

@ -0,0 +1,196 @@
#pragma once
#include <c10/util/ConstexprCrc.h>
#include <c10/util/IdWrapper.h>
#include <c10/util/string_view.h>
#include <cstdint>
#include <ostream>
#include <stdexcept>
#include <string>
#include <type_traits>
namespace c10::util {
// TODO Make it work for more compilers
// Intel compiler works
#if defined(__INTEL_COMPILER)
#define C10_TYPENAME_SUPPORTS_CONSTEXPR 0
#define C10_TYPENAME_CONSTEXPR
// Clang works
#elif defined(__clang__)
// except for NVCC
#if defined(__CUDACC__)
#define C10_TYPENAME_SUPPORTS_CONSTEXPR 0
#define C10_TYPENAME_CONSTEXPR
#else
#define C10_TYPENAME_SUPPORTS_CONSTEXPR 1
#define C10_TYPENAME_CONSTEXPR constexpr
#endif
// Windows works
#elif defined(_MSC_VER)
// except for NVCC
#if defined(__CUDACC__)
#define C10_TYPENAME_SUPPORTS_CONSTEXPR 0
#define C10_TYPENAME_CONSTEXPR
#else
#define C10_TYPENAME_SUPPORTS_CONSTEXPR 1
#define C10_TYPENAME_CONSTEXPR constexpr
#endif
// GCC works
#elif defined(__GNUC__)
// except when gcc < 9
#if (__GNUC__ < 9) || defined(__CUDACC__)
#define C10_TYPENAME_SUPPORTS_CONSTEXPR 0
#define C10_TYPENAME_CONSTEXPR
#else
#define C10_TYPENAME_SUPPORTS_CONSTEXPR 1
#define C10_TYPENAME_CONSTEXPR constexpr
#endif
// some other compiler we don't know about
#else
#define C10_TYPENAME_SUPPORTS_CONSTEXPR 1
#define C10_TYPENAME_CONSTEXPR constexpr
#endif
struct type_index final : IdWrapper<type_index, uint64_t> {
constexpr explicit type_index(uint64_t checksum) : IdWrapper(checksum) {}
// Allow usage in std::map / std::set
// TODO Disallow this and rather use std::unordered_map/set everywhere
friend constexpr bool operator<(type_index lhs, type_index rhs) noexcept {
return lhs.underlyingId() < rhs.underlyingId();
}
friend std::ostream& operator<<(std::ostream& stream, type_index typeId) {
return stream << typeId.underlyingId();
}
};
namespace detail {
#if !defined(__clang__) && !defined(_MSC_VER) && defined(__GNUC__) && \
__GNUC__ < 5
// Getting __PRETTY_FUNCTION__ at compile time only works with GCC >= 5
#error "You're running a too old version of GCC. We need GCC 5 or later."
#endif
#if defined(__clang__) && __clang_major__ < 4
// Getting __PRETTY_FUNCTION__ at compile time only works with Clang >= 4
#error "You're running a too old version of Clang. We need Clang 4 or later."
#endif
inline constexpr string_view extract(
string_view prefix,
string_view suffix,
string_view str) {
#if !defined(__CUDA_ARCH__) // CUDA doesn't like std::logic_error in device code
return (!str.starts_with(prefix) || !str.ends_with(suffix))
? (throw std::logic_error("Invalid pattern"), string_view())
: str.substr(prefix.size(), str.size() - prefix.size() - suffix.size());
#else
return str.substr(prefix.size(), str.size() - prefix.size() - suffix.size());
#endif
}
template <typename T>
inline C10_TYPENAME_CONSTEXPR c10::string_view fully_qualified_type_name_impl() {
#if defined(_MSC_VER) && !defined(__clang__)
#if defined(__NVCC__)
return extract(
"c10::basic_string_view<char> c10::util::detail::fully_qualified_type_name_impl<",
">()",
__FUNCSIG__);
#else
return extract(
"class c10::basic_string_view<char> __cdecl c10::util::detail::fully_qualified_type_name_impl<",
">(void)",
__FUNCSIG__);
#endif
#elif defined(__clang__)
return extract(
"c10::string_view c10::util::detail::fully_qualified_type_name_impl() [T = ",
"]",
__PRETTY_FUNCTION__);
#elif defined(__GNUC__)
return extract(
#if C10_TYPENAME_SUPPORTS_CONSTEXPR
"constexpr c10::string_view c10::util::detail::fully_qualified_type_name_impl() [with T = ",
#else
"c10::string_view c10::util::detail::fully_qualified_type_name_impl() [with T = ",
#endif
"; c10::string_view = c10::basic_string_view<char>]",
__PRETTY_FUNCTION__);
#endif
}
#if !defined(__CUDA_ARCH__)
template <typename T>
inline constexpr uint64_t type_index_impl() {
// Idea: __PRETTY_FUNCTION__ (or __FUNCSIG__ on msvc) contains a qualified name
// of this function, including its template parameter, i.e. including the
// type we want an id for. We use this name and run crc64 on it to get a type
// id.
#if defined(_MSC_VER) && !defined(__clang__)
return crc64(__FUNCSIG__, sizeof(__FUNCSIG__)).checksum();
#elif defined(__clang__)
return crc64(__PRETTY_FUNCTION__, sizeof(__PRETTY_FUNCTION__)).checksum();
#elif defined(__GNUC__)
return crc64(__PRETTY_FUNCTION__, sizeof(__PRETTY_FUNCTION__)).checksum();
#endif
}
#endif
} // namespace detail
template <typename T>
inline constexpr type_index get_type_index() {
#if !defined(__CUDA_ARCH__)
// To enforce that this is really computed at compile time, we pass the
// type index through std::integral_constant.
return type_index{std::integral_constant<
uint64_t,
detail::type_index_impl<std::decay_t<T>>()>::value};
#else
// There's nothing in theory preventing us from running this on device code
// except for nvcc throwing a compiler error if we enable it.
return (abort(), type_index(0));
#endif
}
#if !defined(TORCH_PEDANTIC)
// Use precomputed hashsum for std::string
// Needed to work around ambiguity in class name resolution
// in __PRETTY_FUNCTION__ when the above-mentioned class is defined in an
// inline namespace. In a multi-ABI C++ library, `std::string` is an alias to
// `std::__cxx11::basic_string<char>` which depending on compiler flags can be
// resolved to `basic_string<char>` either in `std` namespace or in
// `std::__cxx11` one (`__cxx11` is an inline namespace)
template <>
inline constexpr type_index get_type_index<std::string>() {
// hashsum for std::basic_string<char>
return type_index{4193213214807308375ULL};
}
#endif
template <typename T>
inline C10_TYPENAME_CONSTEXPR string_view
get_fully_qualified_type_name() noexcept {
#if C10_TYPENAME_SUPPORTS_CONSTEXPR
constexpr
#else
static
#endif
string_view name = detail::fully_qualified_type_name_impl<T>();
return name;
}
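// Illustrative usage sketch:
//   constexpr auto idx = c10::util::get_type_index<int>();  // compile-time id
//   auto name = c10::util::get_fully_qualified_type_name<int>();  // "int"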
} // namespace c10::util
C10_DEFINE_HASH_FOR_IDWRAPPER(c10::util::type_index);

View File

@ -0,0 +1,515 @@
#pragma once
#include <c10/util/TypeTraits.h>
#include <algorithm>
#include <cstddef>
#include <tuple>
#include <type_traits>
#include <utility>
namespace c10::guts {
template <class... T>
struct false_t : std::false_type {};
template <template <class> class... T>
struct false_higher_t : std::false_type {};
namespace typelist {
/**
* Type holding a list of types for compile time type computations
*/
template <class... Items>
struct typelist final {
public:
typelist() = delete; // not for instantiation
};
/**
* Returns the number of types in a typelist
* Example:
* 3 == size<typelist<int, int, double>>::value
*/
template <class TypeList>
struct size final {
static_assert(
false_t<TypeList>::value,
"In typelist::size<T>, T must be typelist<...>.");
};
template <class... Types>
struct size<typelist<Types...>> final {
static constexpr size_t value = sizeof...(Types);
};
/**
* Transforms a list of types into a tuple holding these types.
* Example:
* std::tuple<int, string> == to_tuple_t<typelist<int, string>>
*/
template <class TypeList>
struct to_tuple final {
static_assert(
false_t<TypeList>::value,
"In typelist::to_tuple<T>, T must be typelist<...>.");
};
template <class... Types>
struct to_tuple<typelist<Types...>> final {
using type = std::tuple<Types...>;
};
template <class TypeList>
using to_tuple_t = typename to_tuple<TypeList>::type;
/**
* Creates a typelist containing the types of a given tuple.
* Example:
* typelist<int, string> == from_tuple_t<std::tuple<int, string>>
*/
template <class Tuple>
struct from_tuple final {
static_assert(
false_t<Tuple>::value,
"In typelist::from_tuple<T>, T must be std::tuple<...>.");
};
template <class... Types>
struct from_tuple<std::tuple<Types...>> final {
using type = typelist<Types...>;
};
template <class Tuple>
using from_tuple_t = typename from_tuple<Tuple>::type;
/**
* Concatenates multiple type lists.
* Example:
* typelist<int, string, int> == concat_t<typelist<int, string>,
* typelist<int>>
*/
template <class... TypeLists>
struct concat final {
static_assert(
false_t<TypeLists...>::value,
"In typelist::concat<T1, ...>, the T arguments each must be typelist<...>.");
};
template <class... Head1Types, class... Head2Types, class... TailLists>
struct concat<typelist<Head1Types...>, typelist<Head2Types...>, TailLists...>
final {
using type =
typename concat<typelist<Head1Types..., Head2Types...>, TailLists...>::
type;
};
template <class... HeadTypes>
struct concat<typelist<HeadTypes...>> final {
using type = typelist<HeadTypes...>;
};
template <>
struct concat<> final {
using type = typelist<>;
};
template <class... TypeLists>
using concat_t = typename concat<TypeLists...>::type;
/**
* Filters the types in a type list by a type trait.
* Examples:
* typelist<int&, const string&&> == filter_t<std::is_reference,
* typelist<void, string, int&, bool, const string&&, int>>
*/
template <template <class> class Condition, class TypeList>
struct filter final {
static_assert(
false_t<TypeList>::value,
"In typelist::filter<Condition, TypeList>, the TypeList argument must be typelist<...>.");
};
template <template <class> class Condition, class Head, class... Tail>
struct filter<Condition, typelist<Head, Tail...>> final {
static_assert(
is_type_condition<Condition>::value,
"In typelist::filter<Condition, TypeList>, the Condition argument must be a condition type trait, i.e. have a static constexpr bool ::value member.");
using type = std::conditional_t<
Condition<Head>::value,
concat_t<
typelist<Head>,
typename filter<Condition, typelist<Tail...>>::type>,
typename filter<Condition, typelist<Tail...>>::type>;
};
template <template <class> class Condition>
struct filter<Condition, typelist<>> final {
static_assert(
is_type_condition<Condition>::value,
"In typelist::filter<Condition, TypeList>, the Condition argument must be a condition type trait, i.e. have a static constexpr bool ::value member.");
using type = typelist<>;
};
template <template <class> class Condition, class TypeList>
using filter_t = typename filter<Condition, TypeList>::type;
/**
* Counts how many types in the list fulfill a type trait
* Examples:
* 2 == count_if<std::is_reference, typelist<void, string, int&, bool, const
* string&&, int>>
*/
template <template <class> class Condition, class TypeList>
struct count_if final {
static_assert(
is_type_condition<Condition>::value,
"In typelist::count_if<Condition, TypeList>, the Condition argument must be a condition type trait, i.e. have a static constexpr bool ::value member.");
static_assert(
is_instantiation_of<typelist, TypeList>::value,
"In typelist::count_if<Condition, TypeList>, the TypeList argument must be typelist<...>.");
// TODO Direct implementation might be faster
static constexpr size_t value = size<filter_t<Condition, TypeList>>::value;
};
/**
* Checks if a typelist contains a certain type.
* Examples:
* contains<typelist<int, string>, string> == true_type
* contains<typelist<int, string>, double> == false_type
*/
namespace detail {
template <class TypeList, class Type, class Enable = void>
struct contains {};
template <class Type>
struct contains<typelist<>, Type, void> : std::false_type {};
template <class Type, class Head, class... Tail>
struct contains<
typelist<Head, Tail...>,
Type,
std::enable_if_t<std::is_same_v<Head, Type>>> : std::true_type {};
template <class Type, class Head, class... Tail>
struct contains<
typelist<Head, Tail...>,
Type,
std::enable_if_t<!std::is_same_v<Head, Type>>>
: contains<typelist<Tail...>, Type> {};
} // namespace detail
template <class TypeList, class Type>
using contains = typename detail::contains<TypeList, Type>::type;
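// Illustrative compile-time checks mirroring the examples above:
static_assert(contains<typelist<int, double>, double>::value);
static_assert(!contains<typelist<int, double>, float>::value);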
/**
* Returns true iff the type trait is true for all types in the type list
* Examples:
* true == all<std::is_reference, typelist<int&, const float&&, const
* MyClass&>>::value false == all<std::is_reference, typelist<int&, const
* float&&, MyClass>>::value
*/
template <template <class> class Condition, class TypeList>
struct all {
static_assert(
false_t<TypeList>::value,
"In typelist::all<Condition, TypeList>, the TypeList argument must be typelist<...>.");
};
template <template <class> class Condition, class... Types>
struct all<Condition, typelist<Types...>>
: std::conjunction<Condition<Types>...> {
static_assert(
is_type_condition<Condition>::value,
"In typelist::all<Condition, TypeList>, the Condition argument must be a condition type trait, i.e. have a static constexpr bool ::value member.");
};
/**
* Returns true iff the type trait is true for any type in the type list
* Examples:
* true == true_for_any_type<std::is_reference, typelist<int, const
* float&&, const MyClass>>::value false ==
* true_for_any_type<std::is_reference, typelist<int, const float,
* MyClass>>::value
*/
template <template <class> class Condition, class TypeList>
struct true_for_any_type final {
static_assert(
false_t<TypeList>::value,
"In typelist::true_for_any_type<Condition, TypeList>, the TypeList argument must be typelist<...>.");
};
template <template <class> class Condition, class... Types>
struct true_for_any_type<Condition, typelist<Types...>> final
: std::disjunction<Condition<Types>...> {
static_assert(
is_type_condition<Condition>::value,
"In typelist::true_for_any_type<Condition, TypeList>, the Condition argument must be a condition type trait, i.e. have a static constexpr bool ::value member.");
};
/**
* Maps types of a type list using a type trait
* Example:
* typelist<int&, double&, string&> == map_t<std::add_lvalue_reference_t,
* typelist<int, double, string>>
*/
template <template <class> class Mapper, class TypeList>
struct map final {
static_assert(
false_t<TypeList>::value,
"In typelist::map<Mapper, TypeList>, the TypeList argument must be typelist<...>.");
};
template <template <class> class Mapper, class... Types>
struct map<Mapper, typelist<Types...>> final {
using type = typelist<Mapper<Types>...>;
};
template <template <class> class Mapper, class TypeList>
using map_t = typename map<Mapper, TypeList>::type;
/**
* Returns the first element of a type list.
* Example:
* int == head_t<typelist<int, string>>
*/
template <class TypeList>
struct head final {
static_assert(
false_t<TypeList>::value,
"In typelist::head<T>, the T argument must be typelist<...>.");
};
template <class Head, class... Tail>
struct head<typelist<Head, Tail...>> final {
using type = Head;
};
template <class TypeList>
using head_t = typename head<TypeList>::type;
/**
 * Returns the first element of a type list, or the specified default if the
 * type list is empty.
 * Example:
 *   int == head_with_default_t<bool, typelist<int, string>>
 *   bool == head_with_default_t<bool, typelist<>>
*/
template <class Default, class TypeList>
struct head_with_default final {
using type = Default;
};
template <class Default, class Head, class... Tail>
struct head_with_default<Default, typelist<Head, Tail...>> final {
using type = Head;
};
template <class Default, class TypeList>
using head_with_default_t = typename head_with_default<Default, TypeList>::type;
/**
* Returns the N-th element of a type list.
* Example:
* int == element_t<1, typelist<float, int, char>>
*/
/// Base template.
template <size_t Index, class TypeList>
struct element final {
static_assert(
false_t<TypeList>::value,
"In typelist::element<T>, the T argument must be typelist<...>.");
};
/// Successful case, we have reached the zero index and can "return" the head
/// type.
template <class Head, class... Tail>
struct element<0, typelist<Head, Tail...>> {
using type = Head;
};
/// Error case, we have an index but ran out of types! It will only be selected
/// if `Ts...` is actually empty!
template <size_t Index, class... Ts>
struct element<Index, typelist<Ts...>> {
static_assert(
Index < sizeof...(Ts),
"Index is out of bounds in typelist::element");
};
/// Shave off types until we hit the <0, Head, Tail...> or <Index> case.
template <size_t Index, class Head, class... Tail>
struct element<Index, typelist<Head, Tail...>>
: element<Index - 1, typelist<Tail...>> {};
/// Convenience alias.
template <size_t Index, class TypeList>
using element_t = typename element<Index, TypeList>::type;
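// Illustrative compile-time check mirroring the example above:
static_assert(std::is_same_v<int, element_t<1, typelist<float, int, char>>>);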
/**
* Returns the last element of a type list.
* Example:
* int == last_t<typelist<int, string>>
*/
template <class TypeList>
struct last final {
static_assert(
false_t<TypeList>::value,
"In typelist::last<T>, the T argument must be typelist<...>.");
};
template <class Head, class... Tail>
struct last<typelist<Head, Tail...>> final {
using type = typename last<typelist<Tail...>>::type;
};
template <class Head>
struct last<typelist<Head>> final {
using type = Head;
};
template <class TypeList>
using last_t = typename last<TypeList>::type;
static_assert(std::is_same_v<int, last_t<typelist<double, float, int>>>);
/**
* Take/drop a number of arguments from a typelist.
* Example:
* typelist<int, string> == take_t<typelist<int, string, bool>, 2>
* typelist<bool> == drop_t<typelist<int, string, bool>, 2>
*/
namespace detail {
template <class TypeList, size_t offset, class IndexSequence>
struct take_elements final {};
template <class TypeList, size_t offset, size_t... Indices>
struct take_elements<TypeList, offset, std::index_sequence<Indices...>> final {
using type = typelist<typename element<offset + Indices, TypeList>::type...>;
};
} // namespace detail
template <class TypeList, size_t num>
struct take final {
static_assert(
is_instantiation_of<typelist, TypeList>::value,
"In typelist::take<T, num>, the T argument must be typelist<...>.");
static_assert(
num <= size<TypeList>::value,
"Tried to typelist::take more elements than there are in the list");
using type = typename detail::
take_elements<TypeList, 0, std::make_index_sequence<num>>::type;
};
template <class TypeList, size_t num>
using take_t = typename take<TypeList, num>::type;
template <class TypeList, size_t num>
struct drop final {
static_assert(
is_instantiation_of<typelist, TypeList>::value,
"In typelist::drop<T, num>, the T argument must be typelist<...>.");
static_assert(
num <= size<TypeList>::value,
"Tried to typelist::drop more elements than there are in the list");
using type = typename detail::take_elements<
TypeList,
num,
std::make_index_sequence<size<TypeList>::value - num>>::type;
};
template <class TypeList, size_t num>
using drop_t = typename drop<TypeList, num>::type;
/**
* Like drop, but returns an empty list rather than an assertion error if `num`
* is larger than the size of the TypeList.
* Example:
* typelist<> == drop_if_nonempty_t<typelist<string, bool>, 2>
* typelist<> == drop_if_nonempty_t<typelist<int, string, bool>, 3>
*/
template <class TypeList, size_t num>
struct drop_if_nonempty final {
static_assert(
is_instantiation_of<typelist, TypeList>::value,
      "In typelist::drop_if_nonempty<T, num>, the T argument must be typelist<...>.");
using type = typename detail::take_elements<
TypeList,
std::min(num, size<TypeList>::value),
std::make_index_sequence<
size<TypeList>::value - std::min(num, size<TypeList>::value)>>::type;
};
template <class TypeList, size_t num>
using drop_if_nonempty_t = typename drop_if_nonempty<TypeList, num>::type;
/**
* Reverses a typelist.
* Example:
* typelist<int, string> == reverse_t<typelist<string, int>>
*/
template <class TypeList>
struct reverse final {
static_assert(
false_t<TypeList>::value,
"In typelist::reverse<T>, the T argument must be typelist<...>.");
};
template <class Head, class... Tail>
struct reverse<typelist<Head, Tail...>> final {
using type =
concat_t<typename reverse<typelist<Tail...>>::type, typelist<Head>>;
};
template <>
struct reverse<typelist<>> final {
using type = typelist<>;
};
template <class TypeList>
using reverse_t = typename reverse<TypeList>::type;
/**
* Find the index of the first type in a typelist fulfilling a type trait
* condition. Example:
*
* 2 == find_if<typelist<char, int, char&, int&>, std::is_reference>::value
*/
template <class TypeList, template <class> class Condition, class Enable = void>
struct find_if final {
static_assert(
false_t<TypeList>::value,
"In typelist::find_if<TypeList, Condition>, the TypeList argument must be typelist<...>.");
};
template <template <class> class Condition>
struct find_if<typelist<>, Condition, void> final {
static_assert(
false_higher_t<Condition>::value,
"In typelist::find_if<Type/List, Condition>, didn't find any type fulfilling the Condition.");
};
template <class Head, class... Tail, template <class> class Condition>
struct find_if<
typelist<Head, Tail...>,
Condition,
std::enable_if_t<Condition<Head>::value>>
final {
static constexpr size_t value = 0;
};
template <class Head, class... Tail, template <class> class Condition>
struct find_if<
typelist<Head, Tail...>,
Condition,
std::enable_if_t<!Condition<Head>::value>>
final {
static constexpr size_t value =
1 + find_if<typelist<Tail...>, Condition>::value;
};
/**
* Maps a list of types into a list of values.
* Examples:
* // Example 1
* auto sizes =
* map_types_to_values<typelist<int64_t, bool, uint32_t>>(
* [] (auto t) { return sizeof(decltype(t)::type); }
* );
* // sizes == std::tuple<size_t, size_t, size_t>{8, 1, 4}
*
* // Example 2
* auto shared_ptrs =
* map_types_to_values<typelist<int, double>>(
* [] (auto t) { return make_shared<typename decltype(t)::type>(); }
* );
* // shared_ptrs == std::tuple<shared_ptr<int>, shared_ptr<double>>()
*/
namespace detail {
template <class T>
struct type_ final {
using type = T;
};
template <class TypeList>
struct map_types_to_values final {
static_assert(
false_t<TypeList>::value,
"In typelist::map_types_to_values<T>, the T argument must be typelist<...>.");
};
template <class... Types>
struct map_types_to_values<typelist<Types...>> final {
template <class Func>
static auto call(Func&& func) {
return std::tuple{std::forward<Func>(func)(type_<Types>())...};
}
};
} // namespace detail
template <class TypeList, class Func>
decltype(auto) map_types_to_values(Func&& func) {
return detail::map_types_to_values<TypeList>::call(std::forward<Func>(func));
}
} // namespace typelist
} // namespace c10::guts

View File

@ -0,0 +1,140 @@
#pragma once
#include <c10/macros/Macros.h>
#include <limits>
#include <type_traits>
C10_CLANG_DIAGNOSTIC_PUSH()
#if C10_CLANG_HAS_WARNING("-Wstring-conversion")
C10_CLANG_DIAGNOSTIC_IGNORE("-Wstring-conversion")
#endif
#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
#endif
namespace c10 {
/// Returns false since we cannot have x < 0 if x is unsigned.
template <typename T>
inline constexpr bool is_negative(
const T& /*x*/,
std::true_type /*is_unsigned*/) {
return false;
}
/// Returns true if a signed variable x < 0
template <typename T>
inline constexpr bool is_negative(const T& x, std::false_type /*is_unsigned*/) {
return x < T(0);
}
/// Returns true if x < 0
/// NOTE: Will fail on an unsigned custom type
/// For the most part it's possible to fix this if
/// the custom type has a constexpr constructor.
/// However, notably, c10::Half does not :-(
template <typename T>
inline constexpr bool is_negative(const T& x) {
return is_negative(x, std::is_unsigned<T>());
}
/// Returns the sign of an unsigned variable x as 0, 1
template <typename T>
inline constexpr int signum(const T& x, std::true_type /*is_unsigned*/) {
return T(0) < x;
}
/// Returns the sign of a signed variable x as -1, 0, 1
template <typename T>
inline constexpr int signum(const T& x, std::false_type /*is_unsigned*/) {
return (T(0) < x) - (x < T(0));
}
/// Returns the sign of x as -1, 0, 1
/// NOTE: Will fail on an unsigned custom type
/// For the most part it's possible to fix this if
/// the custom type has a constexpr constructor.
/// However, notably, c10::Half does not :-(
template <typename T>
inline constexpr int signum(const T& x) {
return signum(x, std::is_unsigned<T>());
}
/// Returns true if a and b are not both negative
template <typename T, typename U>
inline constexpr bool signs_differ(const T& a, const U& b) {
return is_negative(a) != is_negative(b);
}
// Suppress sign compare warning when compiling with GCC
// as the latter does not account for the short-circuit rule before
// raising the warning, see https://godbolt.org/z/Tr3Msnz99
#ifdef __GNUC__
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wsign-compare"
#endif
/// Returns true if x is greater than the greatest value of the type Limit
template <typename Limit, typename T>
inline constexpr bool greater_than_max(const T& x) {
constexpr bool can_overflow =
std::numeric_limits<T>::digits > std::numeric_limits<Limit>::digits;
return can_overflow && x > std::numeric_limits<Limit>::max();
}
#ifdef __GNUC__
#pragma GCC diagnostic pop
#endif
/// Returns true if x < lowest(Limit). Standard comparison
template <typename Limit, typename T>
inline constexpr bool less_than_lowest(
const T& x,
std::false_type /*limit_is_unsigned*/,
std::false_type /*x_is_unsigned*/) {
return x < std::numeric_limits<Limit>::lowest();
}
/// Returns false since the limit is signed and therefore includes negative
/// values, but x cannot be negative because it is unsigned
template <typename Limit, typename T>
inline constexpr bool less_than_lowest(
const T& /*x*/,
std::false_type /*limit_is_unsigned*/,
std::true_type /*x_is_unsigned*/) {
return false;
}
/// Returns true if x < 0, where 0 is constructed from T.
/// Limit is unsigned, so its lowest value is zero
template <typename Limit, typename T>
inline constexpr bool less_than_lowest(
const T& x,
std::true_type /*limit_is_unsigned*/,
std::false_type /*x_is_unsigned*/) {
return x < T(0);
}
/// Returns false since both types are unsigned
template <typename Limit, typename T>
inline constexpr bool less_than_lowest(
const T& /*x*/,
std::true_type /*limit_is_unsigned*/,
std::true_type /*x_is_unsigned*/) {
return false;
}
/// Returns true if x is less than the lowest value of type Limit
/// NOTE: Will fail on an unsigned custom type
/// For the most part it's possible to fix this if
/// the custom type has a constexpr constructor.
/// However, notably, c10::Half does not :-(
template <typename Limit, typename T>
inline constexpr bool less_than_lowest(const T& x) {
return less_than_lowest<Limit>(
x, std::is_unsigned<Limit>(), std::is_unsigned<T>());
}
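// A minimal usage sketch (illustrative only): checking whether an int64_t
// value fits into int8_t before a narrowing cast.
//
//   int64_t v = 1000;
//   c10::greater_than_max<int8_t>(v);  // true:  1000 > 127
//   c10::less_than_lowest<int8_t>(v);  // false: 1000 >= -128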
} // namespace c10
C10_CLANG_DIAGNOSTIC_POP()

View File

@ -0,0 +1,151 @@
#pragma once
#include <functional>
#include <type_traits>
namespace c10::guts {
/**
* is_equality_comparable<T> is true_type iff the equality operator is defined
* for T.
*/
template <class T, class Enable = void>
struct is_equality_comparable : std::false_type {};
template <class T>
struct is_equality_comparable<
T,
std::void_t<decltype(std::declval<T&>() == std::declval<T&>())>>
: std::true_type {};
template <class T>
using is_equality_comparable_t = typename is_equality_comparable<T>::type;
/**
* is_hashable<T> is true_type iff std::hash is defined for T
*/
template <class T, class Enable = void>
struct is_hashable : std::false_type {};
template <class T>
struct is_hashable<T, std::void_t<decltype(std::hash<T>()(std::declval<T&>()))>>
: std::true_type {};
template <class T>
using is_hashable_t = typename is_hashable<T>::type;
/**
* is_function_type<T> is true_type iff T is a plain function type (i.e.
* "Result(Args...)")
*/
template <class T>
struct is_function_type : std::false_type {};
template <class Result, class... Args>
struct is_function_type<Result(Args...)> : std::true_type {};
template <class T>
using is_function_type_t = typename is_function_type<T>::type;
/**
* is_instantiation_of<T, I> is true_type iff I is a template instantiation of T
* (e.g. vector<int> is an instantiation of vector) Example:
* is_instantiation_of_t<vector, vector<int>> // true
* is_instantiation_of_t<pair, pair<int, string>> // true
* is_instantiation_of_t<vector, pair<int, string>> // false
*/
template <template <class...> class Template, class T>
struct is_instantiation_of : std::false_type {};
template <template <class...> class Template, class... Args>
struct is_instantiation_of<Template, Template<Args...>> : std::true_type {};
template <template <class...> class Template, class T>
using is_instantiation_of_t = typename is_instantiation_of<Template, T>::type;
namespace detail {
/**
* strip_class: helper to remove the class type from pointers to `operator()`.
*/
template <typename T>
struct strip_class {};
template <typename Class, typename Result, typename... Args>
struct strip_class<Result (Class::*)(Args...)> {
using type = Result(Args...);
};
template <typename Class, typename Result, typename... Args>
struct strip_class<Result (Class::*)(Args...) const> {
using type = Result(Args...);
};
template <typename T>
using strip_class_t = typename strip_class<T>::type;
} // namespace detail
/**
* Evaluates to true_type, iff the given class is a Functor
* (i.e. has a call operator with some set of arguments)
*/
template <class Functor, class Enable = void>
struct is_functor : std::false_type {};
template <class Functor>
struct is_functor<
Functor,
std::enable_if_t<is_function_type<
detail::strip_class_t<decltype(&Functor::operator())>>::value>>
: std::true_type {};
/**
 * is_stateless_lambda<T> is true iff the lambda type T is stateless
 * (i.e. does not have a closure).
 * Example:
 *  auto stateless_lambda = [] (int a) {return a;};
 *  is_stateless_lambda<decltype(stateless_lambda)> // true
 *  auto stateful_lambda = [&] (int a) {return a;};
 *  is_stateless_lambda<decltype(stateful_lambda)> // false
*/
namespace detail {
template <class LambdaType, class FuncType>
struct is_stateless_lambda__ final {
static_assert(
!std::is_same_v<LambdaType, LambdaType>,
"Base case shouldn't be hit");
};
// implementation idea: According to the C++ standard, stateless lambdas are
// convertible to function pointers
template <class LambdaType, class C, class Result, class... Args>
struct is_stateless_lambda__<LambdaType, Result (C::*)(Args...) const>
: std::is_convertible<LambdaType, Result (*)(Args...)> {};
template <class LambdaType, class C, class Result, class... Args>
struct is_stateless_lambda__<LambdaType, Result (C::*)(Args...)>
: std::is_convertible<LambdaType, Result (*)(Args...)> {};
// case where LambdaType is not even a functor
template <class LambdaType, class Enable = void>
struct is_stateless_lambda_ final : std::false_type {};
// case where LambdaType is a functor
template <class LambdaType>
struct is_stateless_lambda_<
LambdaType,
std::enable_if_t<is_functor<LambdaType>::value>>
: is_stateless_lambda__<LambdaType, decltype(&LambdaType::operator())> {};
} // namespace detail
template <class T>
using is_stateless_lambda = detail::is_stateless_lambda_<std::decay_t<T>>;
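// Illustrative usage (kept as a comment to avoid adding namespace-scope
// objects to this header):
//
//   auto stateless = [](int a) { return a; };
//   static_assert(is_stateless_lambda<decltype(stateless)>::value, "");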
/**
* is_type_condition<C> is true_type iff C<...> is a type trait representing a
* condition (i.e. has a constexpr static bool ::value member) Example:
* is_type_condition<std::is_reference> // true
*/
template <template <class> class C, class Enable = void>
struct is_type_condition : std::false_type {};
template <template <class> class C>
struct is_type_condition<
C,
std::enable_if_t<
std::is_same_v<bool, std::remove_cv_t<decltype(C<int>::value)>>>>
: std::true_type {};
/**
 * is_fundamental<T> is true_type iff T is a fundamental type (that is, an
 * arithmetic type, void, or nullptr_t). Example: is_fundamental<int> // true
 * We define it here to resolve a MSVC bug. See
* https://github.com/pytorch/pytorch/issues/30932 for details.
*/
template <class T>
struct is_fundamental : std::is_fundamental<T> {};
} // namespace c10::guts

View File

@ -0,0 +1,14 @@
#pragma once
#if defined(_WIN32)
#include <c10/util/Exception.h>
#include <c10/util/win32-headers.h>
#include <string>
#endif
namespace c10 {
#if defined(_WIN32)
C10_API std::wstring u8u16(const std::string& str);
C10_API std::string u16u8(const std::wstring& wstr);
#endif
} // namespace c10

View File

@ -0,0 +1,127 @@
#pragma once
#include <cstddef>
#include <memory>
#include <utility>
#include <c10/macros/Export.h>
#include <c10/macros/Macros.h>
namespace c10 {
using DeleterFnPtr = void (*)(void*);
namespace detail {
// Does not delete anything
C10_API void deleteNothing(void*);
// A detail::UniqueVoidPtr is an owning smart pointer like unique_ptr, but
// with three major differences:
//
// 1) It is specialized to void
//
// 2) It is specialized for a function pointer deleter
// void(void* ctx); i.e., the deleter doesn't take a
// reference to the data, just to a context pointer
// (erased as void*). In fact, internally, this pointer
// is implemented as having an owning reference to
// context, and a non-owning reference to data; this is why
// you release_context(), not release() (the conventional
// API for release() wouldn't give you enough information
// to properly dispose of the object later.)
//
// 3) The deleter is guaranteed to be called when the unique
// pointer is destructed and the context is non-null; this is different
// from std::unique_ptr where the deleter is not called if the
// data pointer is null.
//
// Some of the methods have slightly different types than std::unique_ptr
// to reflect this.
//
class UniqueVoidPtr {
private:
// Lifetime tied to ctx_
void* data_;
std::unique_ptr<void, DeleterFnPtr> ctx_;
public:
UniqueVoidPtr() : data_(nullptr), ctx_(nullptr, &deleteNothing) {}
explicit UniqueVoidPtr(void* data)
: data_(data), ctx_(nullptr, &deleteNothing) {}
UniqueVoidPtr(void* data, void* ctx, DeleterFnPtr ctx_deleter)
: data_(data), ctx_(ctx, ctx_deleter ? ctx_deleter : &deleteNothing) {}
void* operator->() const {
return data_;
}
void clear() {
ctx_ = nullptr;
data_ = nullptr;
}
void* get() const {
return data_;
}
void* get_context() const {
return ctx_.get();
}
void* release_context() {
return ctx_.release();
}
std::unique_ptr<void, DeleterFnPtr>&& move_context() {
return std::move(ctx_);
}
C10_NODISCARD bool compare_exchange_deleter(
DeleterFnPtr expected_deleter,
DeleterFnPtr new_deleter) {
if (get_deleter() != expected_deleter)
return false;
ctx_ = std::unique_ptr<void, DeleterFnPtr>(ctx_.release(), new_deleter);
return true;
}
template <typename T>
T* cast_context(DeleterFnPtr expected_deleter) const {
if (get_deleter() != expected_deleter)
return nullptr;
return static_cast<T*>(get_context());
}
operator bool() const {
return data_ || ctx_;
}
DeleterFnPtr get_deleter() const {
return ctx_.get_deleter();
}
};
// Note [How UniqueVoidPtr is implemented]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// UniqueVoidPtr solves a common problem for allocators of tensor data, which
// is that the data pointer (e.g., float*) which you are interested in, is not
// the same as the context pointer (e.g., DLManagedTensor) which you need
// to actually deallocate the data. Under a conventional deleter design, you
// have to store extra context in the deleter itself so that you can actually
// delete the right thing. Implementing this with standard C++ is somewhat
// error-prone: if you use a std::unique_ptr to manage tensors, the deleter will
// not be called if the data pointer is nullptr, which can cause a leak if the
// context pointer is non-null (and the deleter is responsible for freeing both
// the data pointer and the context pointer).
//
// So, in our reimplementation of unique_ptr, we just store the context
// directly in the unique pointer and attach the deleter to the context
// pointer itself. In simple cases, the context pointer is just the data
// pointer itself.
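//
// A minimal usage sketch (illustrative; `FreeBlock` and `deleteBlock` are
// hypothetical names, not part of c10):
//
//   struct FreeBlock { void* data; };
//   void deleteBlock(void* ctx) {
//     auto* blk = static_cast<FreeBlock*>(ctx);
//     std::free(blk->data);
//     delete blk;
//   }
//   auto* blk = new FreeBlock{std::malloc(16)};
//   detail::UniqueVoidPtr ptr(blk->data, blk, &deleteBlock);
//   // ptr.get() is the data pointer; destroying ptr runs deleteBlock(blk).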
inline bool operator==(const UniqueVoidPtr& sp, std::nullptr_t) noexcept {
return !sp;
}
inline bool operator==(std::nullptr_t, const UniqueVoidPtr& sp) noexcept {
return !sp;
}
inline bool operator!=(const UniqueVoidPtr& sp, std::nullptr_t) noexcept {
return sp;
}
inline bool operator!=(std::nullptr_t, const UniqueVoidPtr& sp) noexcept {
return sp;
}
} // namespace detail
} // namespace c10

View File

@ -0,0 +1,30 @@
#pragma once
#include <c10/macros/Macros.h>
#include <type_traits>
// Utility to guarantee complete unrolling of a loop where the bounds are known
// at compile time. Various pragmas achieve similar effects, but are not as
// portable across compilers.
// Example: c10::ForcedUnroll<4>{}(f); is equivalent to f(0); f(1); f(2); f(3);
namespace c10 {
template <int n>
struct ForcedUnroll {
template <typename Func, typename... Args>
C10_ALWAYS_INLINE void operator()(const Func& f, Args... args) const {
ForcedUnroll<n - 1>{}(f, args...);
f(std::integral_constant<int, n - 1>{}, args...);
}
};
template <>
struct ForcedUnroll<1> {
template <typename Func, typename... Args>
C10_ALWAYS_INLINE void operator()(const Func& f, Args... args) const {
f(std::integral_constant<int, 0>{}, args...);
}
};
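// A minimal usage sketch (illustrative only). The functor receives a
// std::integral_constant<int, I>, so the index is also usable as a
// compile-time constant inside the body:
//
//   int acc = 0;
//   ForcedUnroll<3>{}([&](auto idx) { acc += idx; });  // acc == 0 + 1 + 2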
} // namespace c10

View File

@ -0,0 +1,94 @@
#pragma once
#include <chrono>
#include <memory>
#include <string>
#include <string_view>
#include <c10/macros/Macros.h>
#include <c10/util/ScopeExit.h>
#include <c10/util/SmallVector.h>
namespace c10::monitor {
namespace detail {
class WaitCounterImpl;
class WaitCounterBackendIf {
public:
virtual ~WaitCounterBackendIf() = default;
virtual intptr_t start(
std::chrono::steady_clock::time_point now) noexcept = 0;
virtual void stop(
std::chrono::steady_clock::time_point now,
intptr_t ctx) noexcept = 0;
};
class WaitCounterBackendFactoryIf {
public:
virtual ~WaitCounterBackendFactoryIf() = default;
// May return nullptr.
// In this case the counter will be ignored by the given backend.
virtual std::unique_ptr<WaitCounterBackendIf> create(
std::string_view key) noexcept = 0;
};
C10_API void registerWaitCounterBackend(
std::unique_ptr<WaitCounterBackendFactoryIf>);
} // namespace detail
// A handle to a wait counter.
class C10_API WaitCounterHandle {
public:
explicit WaitCounterHandle(std::string_view key);
class WaitGuard {
public:
WaitGuard(WaitGuard&& other) noexcept
: handle_{std::exchange(other.handle_, {})},
ctxs_{std::move(other.ctxs_)} {}
WaitGuard(const WaitGuard&) = delete;
WaitGuard& operator=(const WaitGuard&) = delete;
WaitGuard& operator=(WaitGuard&&) = delete;
~WaitGuard() {
stop();
}
void stop() {
if (auto handle = std::exchange(handle_, nullptr)) {
handle->stop(std::move(ctxs_));
}
}
private:
WaitGuard(WaitCounterHandle& handle, SmallVector<intptr_t>&& ctxs)
: handle_{&handle}, ctxs_{std::move(ctxs)} {}
friend class WaitCounterHandle;
WaitCounterHandle* handle_;
SmallVector<intptr_t> ctxs_;
};
// Starts a waiter
WaitGuard start();
private:
// Stops the waiter. Each start() call should be matched by exactly one stop()
// call.
void stop(SmallVector<intptr_t>&& ctxs);
detail::WaitCounterImpl& impl_;
};
} // namespace c10::monitor
#define STATIC_WAIT_COUNTER(_key) \
[]() -> ::c10::monitor::WaitCounterHandle& { \
static ::c10::monitor::WaitCounterHandle handle(#_key); \
return handle; \
}()
#define STATIC_SCOPED_WAIT_COUNTER(_name) \
auto C10_ANONYMOUS_VARIABLE(SCOPE_GUARD) = STATIC_WAIT_COUNTER(_name).start();
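// A minimal usage sketch (illustrative; the counter key "pytorch.example_wait"
// is a made-up name):
//
//   void do_blocking_work() {
//     STATIC_SCOPED_WAIT_COUNTER(pytorch.example_wait);
//     // ... wait time of this scope is reported to registered backends ...
//   }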

View File

@ -0,0 +1,124 @@
// Copyright 2004-present Facebook. All Rights Reserved.
#pragma once
#include <c10/util/Exception.h>
#include <cstdint>
#include <functional>
#include <iterator>
#include <numeric>
#include <type_traits>
#include <utility>
namespace c10 {
/// Sum of a list of integers; accumulates into the int64_t datatype
template <
typename C,
std::enable_if_t<std::is_integral_v<typename C::value_type>, int> = 0>
inline int64_t sum_integers(const C& container) {
// std::accumulate infers return type from `init` type, so if the `init` type
// is not large enough to hold the result, computation can overflow. We use
// `int64_t` here to avoid this.
return std::accumulate(
container.begin(), container.end(), static_cast<int64_t>(0));
}
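// A minimal usage sketch (illustrative): two int32_t values whose running sum
// would overflow int32_t are still summed correctly into int64_t.
//
//   std::vector<int32_t> sizes = {2'000'000'000, 2'000'000'000};
//   int64_t total = c10::sum_integers(sizes);  // 4000000000, no overflow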
/// Sum of integer elements referred to by iterators; accumulates into the
/// int64_t datatype
template <
typename Iter,
std::enable_if_t<
std::is_integral_v<typename std::iterator_traits<Iter>::value_type>,
int> = 0>
inline int64_t sum_integers(Iter begin, Iter end) {
// std::accumulate infers return type from `init` type, so if the `init` type
// is not large enough to hold the result, computation can overflow. We use
// `int64_t` here to avoid this.
return std::accumulate(begin, end, static_cast<int64_t>(0));
}
/// Product of a list of integers; accumulates into the int64_t datatype
template <
typename C,
std::enable_if_t<std::is_integral_v<typename C::value_type>, int> = 0>
inline int64_t multiply_integers(const C& container) {
// std::accumulate infers return type from `init` type, so if the `init` type
// is not large enough to hold the result, computation can overflow. We use
// `int64_t` here to avoid this.
return std::accumulate(
container.begin(),
container.end(),
static_cast<int64_t>(1),
std::multiplies<>());
}
/// Product of integer elements referred to by iterators; accumulates into the
/// int64_t datatype
template <
typename Iter,
std::enable_if_t<
std::is_integral_v<typename std::iterator_traits<Iter>::value_type>,
int> = 0>
inline int64_t multiply_integers(Iter begin, Iter end) {
// std::accumulate infers return type from `init` type, so if the `init` type
// is not large enough to hold the result, computation can overflow. We use
// `int64_t` here to avoid this.
return std::accumulate(
begin, end, static_cast<int64_t>(1), std::multiplies<>());
}
/// Return product of all dimensions starting from k
/// Returns 1 if k>=dims.size()
template <
typename C,
std::enable_if_t<std::is_integral_v<typename C::value_type>, int> = 0>
inline int64_t numelements_from_dim(const int k, const C& dims) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(k >= 0);
if (k > static_cast<int>(dims.size())) {
return 1;
} else {
auto cbegin = dims.cbegin();
std::advance(cbegin, k);
return multiply_integers(cbegin, dims.cend());
}
}
/// Product of all dims up to k (not including dims[k])
/// Throws an error if k>dims.size()
template <
typename C,
std::enable_if_t<std::is_integral_v<typename C::value_type>, int> = 0>
inline int64_t numelements_to_dim(const int k, const C& dims) {
TORCH_INTERNAL_ASSERT(0 <= k);
TORCH_INTERNAL_ASSERT((unsigned)k <= dims.size());
auto cend = dims.cbegin();
std::advance(cend, k);
return multiply_integers(dims.cbegin(), cend);
}
/// Product of all dims between k and l (including dims[k] and excluding
/// dims[l]) k and l may be supplied in either order
template <
typename C,
std::enable_if_t<std::is_integral_v<typename C::value_type>, int> = 0>
inline int64_t numelements_between_dim(int k, int l, const C& dims) {
TORCH_INTERNAL_ASSERT(0 <= k);
TORCH_INTERNAL_ASSERT(0 <= l);
if (k > l) {
std::swap(k, l);
}
TORCH_INTERNAL_ASSERT((unsigned)l < dims.size());
auto cbegin = dims.cbegin();
auto cend = dims.cbegin();
std::advance(cbegin, k);
std::advance(cend, l);
return multiply_integers(cbegin, cend);
}
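// A minimal sketch of the three dim-product helpers above (illustrative):
//
//   std::vector<int64_t> dims = {2, 3, 4, 5};
//   c10::numelements_from_dim(2, dims);       // 4 * 5 == 20
//   c10::numelements_to_dim(2, dims);         // 2 * 3 == 6
//   c10::numelements_between_dim(1, 3, dims); // 3 * 4 == 12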
} // namespace c10

View File

@ -0,0 +1,31 @@
#pragma once
#include <cstring>
#include <type_traits>
namespace c10 {
// Implementations of std::bit_cast() from C++ 20.
//
// This is a less sketchy version of reinterpret_cast.
//
// See https://en.cppreference.com/w/cpp/numeric/bit_cast for more
// information as well as the source of our implementations.
template <class To, class From>
std::enable_if_t<
sizeof(To) == sizeof(From) && std::is_trivially_copyable_v<From> &&
std::is_trivially_copyable_v<To>,
To>
// constexpr support needs compiler magic
bit_cast(const From& src) noexcept {
static_assert(
std::is_trivially_constructible_v<To>,
"This implementation additionally requires "
"destination type to be trivially constructible");
To dst;
std::memcpy(&dst, &src, sizeof(To));
return dst;
}
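// A minimal usage sketch (illustrative): inspect the bit pattern of a float
// without the undefined behavior of reinterpret_cast or a union.
//
//   float f = 1.0f;
//   uint32_t bits = c10::bit_cast<uint32_t>(f);  // 0x3f800000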
} // namespace c10

View File

@ -0,0 +1,61 @@
#pragma once
#include <cstdint>
#include <c10/macros/Macros.h>
namespace c10 {
/**
* bits1x8 is an uninterpreted dtype of a tensor with 1 bit (packed to byte
* boundary), without any semantics defined.
*/
struct alignas(1) bits1x8 {
using underlying = uint8_t;
uint8_t val_;
bits1x8() = default;
C10_HOST_DEVICE explicit bits1x8(uint8_t val) : val_(val) {}
};
/**
* bits2x4 is an uninterpreted dtype of a tensor with 2 bits (packed to byte
* boundary), without any semantics defined.
*/
struct alignas(1) bits2x4 {
using underlying = uint8_t;
uint8_t val_;
bits2x4() = default;
C10_HOST_DEVICE explicit bits2x4(uint8_t val) : val_(val) {}
};
/**
* bits4x2 is an uninterpreted dtype of a tensor with 4 bits (packed to byte
* boundary), without any semantics defined.
*/
struct alignas(1) bits4x2 {
using underlying = uint8_t;
uint8_t val_;
bits4x2() = default;
C10_HOST_DEVICE explicit bits4x2(uint8_t val) : val_(val) {}
};
/**
* bits8 is an uninterpreted dtype of a tensor with 8 bits, without any
* semantics defined.
*/
struct alignas(1) bits8 {
uint8_t val_;
bits8() = default;
C10_HOST_DEVICE explicit bits8(uint8_t val) : val_(val) {}
};
/**
* bits16 is an uninterpreted dtype of a tensor with 16 bits, without any
* semantics defined.
*/
struct alignas(2) bits16 {
uint16_t val_;
bits16() = default;
C10_HOST_DEVICE explicit bits16(uint16_t val) : val_(val) {}
};
} // namespace c10

View File

@ -0,0 +1,618 @@
#pragma once
#include <complex>
#include <c10/macros/Macros.h>
#if defined(__CUDACC__) || defined(__HIPCC__)
#include <thrust/complex.h>
#endif
C10_CLANG_DIAGNOSTIC_PUSH()
#if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion")
C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion")
#endif
#if C10_CLANG_HAS_WARNING("-Wfloat-conversion")
C10_CLANG_DIAGNOSTIC_IGNORE("-Wfloat-conversion")
#endif
namespace c10 {
// c10::complex is an implementation of complex numbers that aims
// to work on all devices supported by PyTorch
//
// Most of the APIs duplicate std::complex
// Reference: https://en.cppreference.com/w/cpp/numeric/complex
//
// [NOTE: Complex Operator Unification]
// Operators currently use a mix of std::complex, thrust::complex, and
// c10::complex internally. The end state is that all operators will use
// c10::complex internally. Until then, there may be some hacks to support all
// variants.
//
//
// [Note on Constructors]
//
// The APIs of constructors are mostly copied from C++ standard:
// https://en.cppreference.com/w/cpp/numeric/complex/complex
//
// Since C++14, all constructors are constexpr in std::complex
//
// There are three types of constructors:
// - initializing from real and imag:
// `constexpr complex( const T& re = T(), const T& im = T() );`
// - implicitly-declared copy constructor
// - converting constructors
//
// Converting constructors:
// - std::complex defines converting constructor between float/double/long
// double,
// while we define converting constructor between float/double.
// - For these converting constructors, upcasting is implicit, downcasting is
// explicit.
// - We also define explicit casting from std::complex/thrust::complex
// - Note that the conversion from thrust is not constexpr, because
//    thrust does not define its constructors as constexpr
//
//
// [Operator =]
//
// The APIs of operator = are mostly copied from C++ standard:
// https://en.cppreference.com/w/cpp/numeric/complex/operator%3D
//
// Since C++20, all operator= are constexpr. Although we are not building with
// C++20, we also obey this behavior.
//
// There are three types of assign operator:
// - Assign a real value from the same scalar type
// - In std, this is templated as complex& operator=(const T& x)
// with specialization `complex& operator=(T x)` for float/double/long
//     double. Since we only support float and double, we will use `complex&
// operator=(T x)`
// - Copy assignment operator and converting assignment operator
// - There is no specialization of converting assignment operators; which type
//   is convertible depends solely on whether the scalar type is convertible
//
// In addition to the standard assignment, we also provide assignment operators
// with std and thrust
//
//
// [Casting operators]
//
// std::complex does not have casting operators. We define casting operators
// casting to std::complex and thrust::complex
//
//
// [Operator ""]
//
// std::complex has custom literals `i`, `if` and `il` defined in namespace
// `std::literals::complex_literals`. We define our own custom literals in the
// namespace `c10::complex_literals`. Our custom literals do not follow the
// same behavior as in std::complex; instead, we define _if and _id to construct
// float/double complex literals.
//
//
// [real() and imag()]
//
// In C++20, there are two overloads of each of these functions: one returns
// the real/imag part, the other sets it; both are constexpr. We follow
// this design.
//
//
// [Operator +=,-=,*=,/=]
//
// Since C++20, these operators become constexpr. In our implementation, they
// are also constexpr.
//
// There are two types of such operators: operating with a real number, or
// operating with another complex number. For the operating with a real number,
// the generic template form has argument type `const T &`, while the overload
// for float/double/long double has `T`. We will follow the same type as
// float/double/long double in std.
//
// [Unary operator +-]
//
// Since C++20, they are constexpr. We also make them constexpr
//
// [Binary operators +-*/]
//
// Each operator has three versions (taking + as example):
// - complex + complex
// - complex + real
// - real + complex
//
// [Operator ==, !=]
//
// Each operator has three versions (taking == as example):
// - complex == complex
// - complex == real
// - real == complex
//
// Some of them are removed in C++20, but we decided to keep them
//
// [Operator <<, >>]
//
// These are implemented by casting to std::complex
//
//
//
// TODO(@zasdfgbnm): c10::complex<c10::Half> is not currently supported,
// because:
// - lots of members and functions of c10::Half are not constexpr
// - thrust::complex only support float and double
template <typename T>
struct alignas(sizeof(T) * 2) complex {
using value_type = T;
T real_ = T(0);
T imag_ = T(0);
constexpr complex() = default;
C10_HOST_DEVICE constexpr complex(const T& re, const T& im = T())
: real_(re), imag_(im) {}
template <typename U>
explicit constexpr complex(const std::complex<U>& other)
: complex(other.real(), other.imag()) {}
#if defined(__CUDACC__) || defined(__HIPCC__)
template <typename U>
explicit C10_HOST_DEVICE complex(const thrust::complex<U>& other)
: real_(other.real()), imag_(other.imag()) {}
// NOTE can not be implemented as follow due to ROCm bug:
// explicit C10_HOST_DEVICE complex(const thrust::complex<U> &other):
// complex(other.real(), other.imag()) {}
#endif
// Use SFINAE to specialize casting constructor for c10::complex<float> and
// c10::complex<double>
template <typename U = T>
C10_HOST_DEVICE explicit constexpr complex(
const std::enable_if_t<std::is_same_v<U, float>, complex<double>>& other)
: real_(other.real_), imag_(other.imag_) {}
template <typename U = T>
C10_HOST_DEVICE constexpr complex(
const std::enable_if_t<std::is_same_v<U, double>, complex<float>>& other)
: real_(other.real_), imag_(other.imag_) {}
constexpr complex<T>& operator=(T re) {
real_ = re;
imag_ = 0;
return *this;
}
constexpr complex<T>& operator+=(T re) {
real_ += re;
return *this;
}
constexpr complex<T>& operator-=(T re) {
real_ -= re;
return *this;
}
constexpr complex<T>& operator*=(T re) {
real_ *= re;
imag_ *= re;
return *this;
}
constexpr complex<T>& operator/=(T re) {
real_ /= re;
imag_ /= re;
return *this;
}
template <typename U>
constexpr complex<T>& operator=(const complex<U>& rhs) {
real_ = rhs.real();
imag_ = rhs.imag();
return *this;
}
template <typename U>
constexpr complex<T>& operator+=(const complex<U>& rhs) {
real_ += rhs.real();
imag_ += rhs.imag();
return *this;
}
template <typename U>
constexpr complex<T>& operator-=(const complex<U>& rhs) {
real_ -= rhs.real();
imag_ -= rhs.imag();
return *this;
}
template <typename U>
constexpr complex<T>& operator*=(const complex<U>& rhs) {
// (a + bi) * (c + di) = (a*c - b*d) + (a * d + b * c) i
T a = real_;
T b = imag_;
U c = rhs.real();
U d = rhs.imag();
real_ = a * c - b * d;
imag_ = a * d + b * c;
return *this;
}
#ifdef __APPLE__
#define FORCE_INLINE_APPLE __attribute__((always_inline))
#else
#define FORCE_INLINE_APPLE
#endif
template <typename U>
constexpr FORCE_INLINE_APPLE complex<T>& operator/=(const complex<U>& rhs)
__ubsan_ignore_float_divide_by_zero__ {
// (a + bi) / (c + di) = (ac + bd)/(c^2 + d^2) + (bc - ad)/(c^2 + d^2) i
// the calculation below follows numpy's complex division
T a = real_;
T b = imag_;
U c = rhs.real();
U d = rhs.imag();
#if defined(__GNUC__) && !defined(__clang__)
// std::abs is already constexpr by gcc
auto abs_c = std::abs(c);
auto abs_d = std::abs(d);
#else
auto abs_c = c < 0 ? -c : c;
auto abs_d = d < 0 ? -d : d;
#endif
if (abs_c >= abs_d) {
if (abs_c == U(0) && abs_d == U(0)) {
/* divide by zeros should yield a complex inf or nan */
real_ = a / abs_c;
imag_ = b / abs_d;
} else {
auto rat = d / c;
auto scl = U(1.0) / (c + d * rat);
real_ = (a + b * rat) * scl;
imag_ = (b - a * rat) * scl;
}
} else {
auto rat = c / d;
auto scl = U(1.0) / (d + c * rat);
real_ = (a * rat + b) * scl;
imag_ = (b * rat - a) * scl;
}
return *this;
}
#undef FORCE_INLINE_APPLE
template <typename U>
constexpr complex<T>& operator=(const std::complex<U>& rhs) {
real_ = rhs.real();
imag_ = rhs.imag();
return *this;
}
#if defined(__CUDACC__) || defined(__HIPCC__)
template <typename U>
C10_HOST_DEVICE complex<T>& operator=(const thrust::complex<U>& rhs) {
real_ = rhs.real();
imag_ = rhs.imag();
return *this;
}
#endif
template <typename U>
explicit constexpr operator std::complex<U>() const {
return std::complex<U>(std::complex<T>(real(), imag()));
}
#if defined(__CUDACC__) || defined(__HIPCC__)
template <typename U>
C10_HOST_DEVICE explicit operator thrust::complex<U>() const {
return static_cast<thrust::complex<U>>(thrust::complex<T>(real(), imag()));
}
#endif
// consistent with NumPy behavior
explicit constexpr operator bool() const {
return real() || imag();
}
C10_HOST_DEVICE constexpr T real() const {
return real_;
}
constexpr void real(T value) {
real_ = value;
}
C10_HOST_DEVICE constexpr T imag() const {
return imag_;
}
constexpr void imag(T value) {
imag_ = value;
}
};
namespace complex_literals {
constexpr complex<float> operator""_if(long double imag) {
return complex<float>(0.0f, static_cast<float>(imag));
}
constexpr complex<double> operator""_id(long double imag) {
return complex<double>(0.0, static_cast<double>(imag));
}
constexpr complex<float> operator""_if(unsigned long long imag) {
return complex<float>(0.0f, static_cast<float>(imag));
}
constexpr complex<double> operator""_id(unsigned long long imag) {
return complex<double>(0.0, static_cast<double>(imag));
}
} // namespace complex_literals
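// A minimal usage sketch of the custom literals (illustrative):
//
//   using namespace c10::complex_literals;
//   constexpr auto z = 2.0_if;  // c10::complex<float>(0.0f, 2.0f)
//   constexpr auto w = 1.0_id;  // c10::complex<double>(0.0, 1.0)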
template <typename T>
constexpr complex<T> operator+(const complex<T>& val) {
return val;
}
template <typename T>
constexpr complex<T> operator-(const complex<T>& val) {
return complex<T>(-val.real(), -val.imag());
}
template <typename T>
constexpr complex<T> operator+(const complex<T>& lhs, const complex<T>& rhs) {
complex<T> result = lhs;
return result += rhs;
}
template <typename T>
constexpr complex<T> operator+(const complex<T>& lhs, const T& rhs) {
complex<T> result = lhs;
return result += rhs;
}
template <typename T>
constexpr complex<T> operator+(const T& lhs, const complex<T>& rhs) {
return complex<T>(lhs + rhs.real(), rhs.imag());
}
template <typename T>
constexpr complex<T> operator-(const complex<T>& lhs, const complex<T>& rhs) {
complex<T> result = lhs;
return result -= rhs;
}
template <typename T>
constexpr complex<T> operator-(const complex<T>& lhs, const T& rhs) {
complex<T> result = lhs;
return result -= rhs;
}
template <typename T>
constexpr complex<T> operator-(const T& lhs, const complex<T>& rhs) {
complex<T> result = -rhs;
return result += lhs;
}
template <typename T>
constexpr complex<T> operator*(const complex<T>& lhs, const complex<T>& rhs) {
complex<T> result = lhs;
return result *= rhs;
}
template <typename T>
constexpr complex<T> operator*(const complex<T>& lhs, const T& rhs) {
complex<T> result = lhs;
return result *= rhs;
}
template <typename T>
constexpr complex<T> operator*(const T& lhs, const complex<T>& rhs) {
complex<T> result = rhs;
return result *= lhs;
}
template <typename T>
constexpr complex<T> operator/(const complex<T>& lhs, const complex<T>& rhs) {
complex<T> result = lhs;
return result /= rhs;
}
template <typename T>
constexpr complex<T> operator/(const complex<T>& lhs, const T& rhs) {
complex<T> result = lhs;
return result /= rhs;
}
template <typename T>
constexpr complex<T> operator/(const T& lhs, const complex<T>& rhs) {
complex<T> result(lhs, T());
return result /= rhs;
}
// Define operators between integral scalars and c10::complex. std::complex does
// not support this when T is a floating-point number. This is useful because it
// saves a lot of "static_cast" when operating on a complex and an integer. This
// makes the code both less verbose and potentially more efficient.
#define COMPLEX_INTEGER_OP_TEMPLATE_CONDITION \
typename std::enable_if_t< \
std::is_floating_point_v<fT> && std::is_integral_v<iT>, \
int> = 0
template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
constexpr c10::complex<fT> operator+(const c10::complex<fT>& a, const iT& b) {
return a + static_cast<fT>(b);
}
template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
constexpr c10::complex<fT> operator+(const iT& a, const c10::complex<fT>& b) {
return static_cast<fT>(a) + b;
}
template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
constexpr c10::complex<fT> operator-(const c10::complex<fT>& a, const iT& b) {
return a - static_cast<fT>(b);
}
template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
constexpr c10::complex<fT> operator-(const iT& a, const c10::complex<fT>& b) {
return static_cast<fT>(a) - b;
}
template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
constexpr c10::complex<fT> operator*(const c10::complex<fT>& a, const iT& b) {
return a * static_cast<fT>(b);
}
template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
constexpr c10::complex<fT> operator*(const iT& a, const c10::complex<fT>& b) {
return static_cast<fT>(a) * b;
}
template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
constexpr c10::complex<fT> operator/(const c10::complex<fT>& a, const iT& b) {
return a / static_cast<fT>(b);
}
template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
constexpr c10::complex<fT> operator/(const iT& a, const c10::complex<fT>& b) {
return static_cast<fT>(a) / b;
}
#undef COMPLEX_INTEGER_OP_TEMPLATE_CONDITION
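// A minimal usage sketch (illustrative): mixed complex/integer arithmetic
// compiles without explicit casts, which std::complex would reject.
//
//   c10::complex<float> z(1.0f, 2.0f);
//   auto w = z * 3;  // c10::complex<float>(3.0f, 6.0f)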
template <typename T>
constexpr bool operator==(const complex<T>& lhs, const complex<T>& rhs) {
return (lhs.real() == rhs.real()) && (lhs.imag() == rhs.imag());
}
template <typename T>
constexpr bool operator==(const complex<T>& lhs, const T& rhs) {
return (lhs.real() == rhs) && (lhs.imag() == T());
}
template <typename T>
constexpr bool operator==(const T& lhs, const complex<T>& rhs) {
return (lhs == rhs.real()) && (T() == rhs.imag());
}
template <typename T>
constexpr bool operator!=(const complex<T>& lhs, const complex<T>& rhs) {
return !(lhs == rhs);
}
template <typename T>
constexpr bool operator!=(const complex<T>& lhs, const T& rhs) {
return !(lhs == rhs);
}
template <typename T>
constexpr bool operator!=(const T& lhs, const complex<T>& rhs) {
return !(lhs == rhs);
}
template <typename T, typename CharT, typename Traits>
std::basic_ostream<CharT, Traits>& operator<<(
std::basic_ostream<CharT, Traits>& os,
const complex<T>& x) {
return (os << static_cast<std::complex<T>>(x));
}
template <typename T, typename CharT, typename Traits>
std::basic_istream<CharT, Traits>& operator>>(
std::basic_istream<CharT, Traits>& is,
complex<T>& x) {
std::complex<T> tmp;
is >> tmp;
x = tmp;
return is;
}
} // namespace c10
// std functions
//
// The implementation of these functions also follows the design of C++20
namespace std {
template <typename T>
constexpr T real(const c10::complex<T>& z) {
return z.real();
}
template <typename T>
constexpr T imag(const c10::complex<T>& z) {
return z.imag();
}
template <typename T>
C10_HOST_DEVICE T abs(const c10::complex<T>& z) {
#if defined(__CUDACC__) || defined(__HIPCC__)
return thrust::abs(static_cast<thrust::complex<T>>(z));
#else
return std::abs(static_cast<std::complex<T>>(z));
#endif
}
#if defined(USE_ROCM)
#define ROCm_Bug(x)
#else
#define ROCm_Bug(x) x
#endif
template <typename T>
C10_HOST_DEVICE T arg(const c10::complex<T>& z) {
return ROCm_Bug(std)::atan2(std::imag(z), std::real(z));
}
#undef ROCm_Bug
template <typename T>
constexpr T norm(const c10::complex<T>& z) {
return z.real() * z.real() + z.imag() * z.imag();
}
// For std::conj, there are other versions of it:
// constexpr std::complex<float> conj( float z );
// template< class DoubleOrInteger >
// constexpr std::complex<double> conj( DoubleOrInteger z );
// constexpr std::complex<long double> conj( long double z );
// These are not implemented
// TODO(@zasdfgbnm): implement them as c10::conj
template <typename T>
constexpr c10::complex<T> conj(const c10::complex<T>& z) {
return c10::complex<T>(z.real(), -z.imag());
}
// Thrust does not have a complex --> complex version of thrust::proj,
// so this function is not implemented in c10 right now.
// TODO(@zasdfgbnm): implement it ourselves
// There is no c10 version of std::polar, because std::polar always
// returns std::complex. Use c10::polar instead;
} // namespace std
namespace c10 {
template <typename T>
C10_HOST_DEVICE complex<T> polar(const T& r, const T& theta = T()) {
#if defined(__CUDACC__) || defined(__HIPCC__)
return static_cast<complex<T>>(thrust::polar(r, theta));
#else
// std::polar() requires r >= 0, so spell out the explicit implementation to
// avoid a branch.
return complex<T>(r * std::cos(theta), r * std::sin(theta));
#endif
}
} // namespace c10
C10_CLANG_DIAGNOSTIC_POP()
#define C10_INTERNAL_INCLUDE_COMPLEX_REMAINING_H
// math functions are included in a separate file
#include <c10/util/complex_math.h> // IWYU pragma: keep
// utilities for complex types
#include <c10/util/complex_utils.h> // IWYU pragma: keep
#undef C10_INTERNAL_INCLUDE_COMPLEX_REMAINING_H

View File

@ -0,0 +1,406 @@
#if !defined(C10_INTERNAL_INCLUDE_COMPLEX_REMAINING_H)
#error \
"c10/util/complex_math.h is not meant to be individually included. Include c10/util/complex.h instead."
#endif
namespace c10_complex_math {
// Exponential functions
template <typename T>
C10_HOST_DEVICE inline c10::complex<T> exp(const c10::complex<T>& x) {
#if defined(__CUDACC__) || defined(__HIPCC__)
return static_cast<c10::complex<T>>(
thrust::exp(static_cast<thrust::complex<T>>(x)));
#else
return static_cast<c10::complex<T>>(
std::exp(static_cast<std::complex<T>>(x)));
#endif
}
template <typename T>
C10_HOST_DEVICE inline c10::complex<T> log(const c10::complex<T>& x) {
#if defined(__CUDACC__) || defined(__HIPCC__)
return static_cast<c10::complex<T>>(
thrust::log(static_cast<thrust::complex<T>>(x)));
#else
return static_cast<c10::complex<T>>(
std::log(static_cast<std::complex<T>>(x)));
#endif
}
template <typename T>
C10_HOST_DEVICE inline c10::complex<T> log10(const c10::complex<T>& x) {
#if defined(__CUDACC__) || defined(__HIPCC__)
return static_cast<c10::complex<T>>(
thrust::log10(static_cast<thrust::complex<T>>(x)));
#else
return static_cast<c10::complex<T>>(
std::log10(static_cast<std::complex<T>>(x)));
#endif
}
template <typename T>
C10_HOST_DEVICE inline c10::complex<T> log2(const c10::complex<T>& x) {
const c10::complex<T> log2 = c10::complex<T>(::log(2.0), 0.0);
return c10_complex_math::log(x) / log2;
}
// Power functions
//
#if defined(_LIBCPP_VERSION) || \
(defined(__GLIBCXX__) && !defined(_GLIBCXX11_USE_C99_COMPLEX))
namespace _detail {
C10_API c10::complex<float> sqrt(const c10::complex<float>& in);
C10_API c10::complex<double> sqrt(const c10::complex<double>& in);
C10_API c10::complex<float> acos(const c10::complex<float>& in);
C10_API c10::complex<double> acos(const c10::complex<double>& in);
} // namespace _detail
#endif
template <typename T>
C10_HOST_DEVICE inline c10::complex<T> sqrt(const c10::complex<T>& x) {
#if defined(__CUDACC__) || defined(__HIPCC__)
return static_cast<c10::complex<T>>(
thrust::sqrt(static_cast<thrust::complex<T>>(x)));
#elif !( \
defined(_LIBCPP_VERSION) || \
(defined(__GLIBCXX__) && !defined(_GLIBCXX11_USE_C99_COMPLEX)))
return static_cast<c10::complex<T>>(
std::sqrt(static_cast<std::complex<T>>(x)));
#else
return _detail::sqrt(x);
#endif
}
template <typename T>
C10_HOST_DEVICE inline c10::complex<T> pow(
const c10::complex<T>& x,
const c10::complex<T>& y) {
#if defined(__CUDACC__) || defined(__HIPCC__)
return static_cast<c10::complex<T>>(thrust::pow(
static_cast<thrust::complex<T>>(x), static_cast<thrust::complex<T>>(y)));
#else
return static_cast<c10::complex<T>>(std::pow(
static_cast<std::complex<T>>(x), static_cast<std::complex<T>>(y)));
#endif
}
template <typename T>
C10_HOST_DEVICE inline c10::complex<T> pow(
const c10::complex<T>& x,
const T& y) {
#if defined(__CUDACC__) || defined(__HIPCC__)
return static_cast<c10::complex<T>>(
thrust::pow(static_cast<thrust::complex<T>>(x), y));
#else
return static_cast<c10::complex<T>>(
std::pow(static_cast<std::complex<T>>(x), y));
#endif
}
template <typename T>
C10_HOST_DEVICE inline c10::complex<T> pow(
const T& x,
const c10::complex<T>& y) {
#if defined(__CUDACC__) || defined(__HIPCC__)
return static_cast<c10::complex<T>>(
thrust::pow(x, static_cast<thrust::complex<T>>(y)));
#else
return static_cast<c10::complex<T>>(
std::pow(x, static_cast<std::complex<T>>(y)));
#endif
}
template <typename T, typename U>
C10_HOST_DEVICE inline c10::complex<decltype(T() * U())> pow(
const c10::complex<T>& x,
const c10::complex<U>& y) {
#if defined(__CUDACC__) || defined(__HIPCC__)
return static_cast<c10::complex<T>>(thrust::pow(
static_cast<thrust::complex<T>>(x), static_cast<thrust::complex<T>>(y)));
#else
return static_cast<c10::complex<T>>(std::pow(
static_cast<std::complex<T>>(x), static_cast<std::complex<T>>(y)));
#endif
}
template <typename T, typename U>
C10_HOST_DEVICE inline c10::complex<decltype(T() * U())> pow(
const c10::complex<T>& x,
const U& y) {
#if defined(__CUDACC__) || defined(__HIPCC__)
return static_cast<c10::complex<T>>(
thrust::pow(static_cast<thrust::complex<T>>(x), y));
#else
return static_cast<c10::complex<T>>(
std::pow(static_cast<std::complex<T>>(x), y));
#endif
}
template <typename T, typename U>
C10_HOST_DEVICE inline c10::complex<decltype(T() * U())> pow(
const T& x,
const c10::complex<U>& y) {
#if defined(__CUDACC__) || defined(__HIPCC__)
return static_cast<c10::complex<T>>(
thrust::pow(x, static_cast<thrust::complex<T>>(y)));
#else
return static_cast<c10::complex<T>>(
std::pow(x, static_cast<std::complex<T>>(y)));
#endif
}
// Trigonometric functions
template <typename T>
C10_HOST_DEVICE inline c10::complex<T> sin(const c10::complex<T>& x) {
#if defined(__CUDACC__) || defined(__HIPCC__)
return static_cast<c10::complex<T>>(
thrust::sin(static_cast<thrust::complex<T>>(x)));
#else
return static_cast<c10::complex<T>>(
std::sin(static_cast<std::complex<T>>(x)));
#endif
}
template <typename T>
C10_HOST_DEVICE inline c10::complex<T> cos(const c10::complex<T>& x) {
#if defined(__CUDACC__) || defined(__HIPCC__)
return static_cast<c10::complex<T>>(
thrust::cos(static_cast<thrust::complex<T>>(x)));
#else
return static_cast<c10::complex<T>>(
std::cos(static_cast<std::complex<T>>(x)));
#endif
}
template <typename T>
C10_HOST_DEVICE inline c10::complex<T> tan(const c10::complex<T>& x) {
#if defined(__CUDACC__) || defined(__HIPCC__)
return static_cast<c10::complex<T>>(
thrust::tan(static_cast<thrust::complex<T>>(x)));
#else
return static_cast<c10::complex<T>>(
std::tan(static_cast<std::complex<T>>(x)));
#endif
}
template <typename T>
C10_HOST_DEVICE inline c10::complex<T> asin(const c10::complex<T>& x) {
#if defined(__CUDACC__) || defined(__HIPCC__)
return static_cast<c10::complex<T>>(
thrust::asin(static_cast<thrust::complex<T>>(x)));
#else
return static_cast<c10::complex<T>>(
std::asin(static_cast<std::complex<T>>(x)));
#endif
}
template <typename T>
C10_HOST_DEVICE inline c10::complex<T> acos(const c10::complex<T>& x) {
#if defined(__CUDACC__) || defined(__HIPCC__)
return static_cast<c10::complex<T>>(
thrust::acos(static_cast<thrust::complex<T>>(x)));
#elif !defined(_LIBCPP_VERSION)
return static_cast<c10::complex<T>>(
std::acos(static_cast<std::complex<T>>(x)));
#else
return _detail::acos(x);
#endif
}
template <typename T>
C10_HOST_DEVICE inline c10::complex<T> atan(const c10::complex<T>& x) {
#if defined(__CUDACC__) || defined(__HIPCC__)
return static_cast<c10::complex<T>>(
thrust::atan(static_cast<thrust::complex<T>>(x)));
#else
return static_cast<c10::complex<T>>(
std::atan(static_cast<std::complex<T>>(x)));
#endif
}
// Hyperbolic functions
template <typename T>
C10_HOST_DEVICE inline c10::complex<T> sinh(const c10::complex<T>& x) {
#if defined(__CUDACC__) || defined(__HIPCC__)
return static_cast<c10::complex<T>>(
thrust::sinh(static_cast<thrust::complex<T>>(x)));
#else
return static_cast<c10::complex<T>>(
std::sinh(static_cast<std::complex<T>>(x)));
#endif
}
template <typename T>
C10_HOST_DEVICE inline c10::complex<T> cosh(const c10::complex<T>& x) {
#if defined(__CUDACC__) || defined(__HIPCC__)
return static_cast<c10::complex<T>>(
thrust::cosh(static_cast<thrust::complex<T>>(x)));
#else
return static_cast<c10::complex<T>>(
std::cosh(static_cast<std::complex<T>>(x)));
#endif
}
template <typename T>
C10_HOST_DEVICE inline c10::complex<T> tanh(const c10::complex<T>& x) {
#if defined(__CUDACC__) || defined(__HIPCC__)
return static_cast<c10::complex<T>>(
thrust::tanh(static_cast<thrust::complex<T>>(x)));
#else
return static_cast<c10::complex<T>>(
std::tanh(static_cast<std::complex<T>>(x)));
#endif
}
template <typename T>
C10_HOST_DEVICE inline c10::complex<T> asinh(const c10::complex<T>& x) {
#if defined(__CUDACC__) || defined(__HIPCC__)
return static_cast<c10::complex<T>>(
thrust::asinh(static_cast<thrust::complex<T>>(x)));
#else
return static_cast<c10::complex<T>>(
std::asinh(static_cast<std::complex<T>>(x)));
#endif
}
template <typename T>
C10_HOST_DEVICE inline c10::complex<T> acosh(const c10::complex<T>& x) {
#if defined(__CUDACC__) || defined(__HIPCC__)
return static_cast<c10::complex<T>>(
thrust::acosh(static_cast<thrust::complex<T>>(x)));
#else
return static_cast<c10::complex<T>>(
std::acosh(static_cast<std::complex<T>>(x)));
#endif
}
template <typename T>
C10_HOST_DEVICE inline c10::complex<T> atanh(const c10::complex<T>& x) {
#if defined(__CUDACC__) || defined(__HIPCC__)
return static_cast<c10::complex<T>>(
thrust::atanh(static_cast<thrust::complex<T>>(x)));
#else
return static_cast<c10::complex<T>>(
std::atanh(static_cast<std::complex<T>>(x)));
#endif
}
template <typename T>
C10_HOST_DEVICE inline c10::complex<T> log1p(const c10::complex<T>& z) {
#if defined(__APPLE__) || defined(__MACOSX) || defined(__CUDACC__) || \
defined(__HIPCC__)
// For Mac, the new implementation yielded a high relative error. Falling back
// to the old version for now.
// See https://github.com/numpy/numpy/pull/22611#issuecomment-1667945354
// For CUDA we also use this one, as thrust::log(thrust::complex) takes
// *forever* to compile
// log1p(z) = log(1 + z)
// Let's define 1 + z = r * e ^ (i * a), then we have
// log(r * e ^ (i * a)) = log(r) + i * a
// With z = x + iy, the term r can be written as
// r = ((1 + x) ^ 2 + y ^ 2) ^ 0.5
// = (1 + x ^ 2 + 2 * x + y ^ 2) ^ 0.5
// So, log(r) is
// log(r) = 0.5 * log(1 + x ^ 2 + 2 * x + y ^ 2)
// = 0.5 * log1p(x * (x + 2) + y ^ 2)
// We only use this expression under certain conditions, to avoid the
// overflow and underflow that `(x * (x + 2) + y ^ 2)` can otherwise hit.
T x = z.real();
T y = z.imag();
T zabs = std::abs(z);
T theta = std::atan2(y, x + T(1));
if (zabs < 0.5) {
T r = x * (T(2) + x) + y * y;
if (r == 0) { // handle underflow
return {x, theta};
}
return {T(0.5) * std::log1p(r), theta};
} else {
T z0 = std::hypot(x + 1, y);
return {std::log(z0), theta};
}
#else
// CPU path
// Based on https://github.com/numpy/numpy/pull/22611#issuecomment-1667945354
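// The computed u = 1 + z is rounded, so log(u) alone loses accuracy for
// tiny |z|. Scaling by z / (u - 1), the ratio of the exact argument to
// the rounded one, recovers nearly full precision (the classic Kahan
// log1p correction).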
c10::complex<T> u = z + T(1);
if (u == T(1)) {
return z;
} else {
auto log_u = log(u);
if (u - T(1) == z) {
return log_u;
}
return log_u * (z / (u - T(1)));
}
#endif
}
template <typename T>
C10_HOST_DEVICE inline c10::complex<T> expm1(const c10::complex<T>& z) {
// expm1(z) = exp(z) - 1
// Define z = x + i * y
// f = e ^ (x + i * y) - 1
// = e ^ x * e ^ (i * y) - 1
// = (e ^ x * cos(y) - 1) + i * (e ^ x * sin(y))
// = (e ^ x - 1) * cos(y) - (1 - cos(y)) + i * e ^ x * sin(y)
// = expm1(x) * cos(y) - 2 * sin(y / 2) ^ 2 + i * e ^ x * sin(y)
T x = z.real();
T y = z.imag();
T a = std::sin(y / 2);
T er = std::expm1(x) * std::cos(y) - T(2) * a * a;
T ei = std::exp(x) * std::sin(y);
return {er, ei};
}
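// Usage (illustrative): near z == 0 these keep precision where the naive
// forms log(1 + z) and exp(z) - 1 would suffer cancellation.
//   c10::complex<double> z(1e-12, 1e-12);
//   auto a = c10_complex_math::log1p(z); // accurate log(1 + z)
//   auto b = c10_complex_math::expm1(z); // accurate exp(z) - 1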
} // namespace c10_complex_math
using c10_complex_math::acos;
using c10_complex_math::acosh;
using c10_complex_math::asin;
using c10_complex_math::asinh;
using c10_complex_math::atan;
using c10_complex_math::atanh;
using c10_complex_math::cos;
using c10_complex_math::cosh;
using c10_complex_math::exp;
using c10_complex_math::expm1;
using c10_complex_math::log;
using c10_complex_math::log10;
using c10_complex_math::log1p;
using c10_complex_math::log2;
using c10_complex_math::pow;
using c10_complex_math::sin;
using c10_complex_math::sinh;
using c10_complex_math::sqrt;
using c10_complex_math::tan;
using c10_complex_math::tanh;
namespace std {
using c10_complex_math::acos;
using c10_complex_math::acosh;
using c10_complex_math::asin;
using c10_complex_math::asinh;
using c10_complex_math::atan;
using c10_complex_math::atanh;
using c10_complex_math::cos;
using c10_complex_math::cosh;
using c10_complex_math::exp;
using c10_complex_math::expm1;
using c10_complex_math::log;
using c10_complex_math::log10;
using c10_complex_math::log1p;
using c10_complex_math::log2;
using c10_complex_math::pow;
using c10_complex_math::sin;
using c10_complex_math::sinh;
using c10_complex_math::sqrt;
using c10_complex_math::tan;
using c10_complex_math::tanh;
} // namespace std

View File

@ -0,0 +1,46 @@
#if !defined(C10_INTERNAL_INCLUDE_COMPLEX_REMAINING_H)
#error \
"c10/util/complex_utils.h is not meant to be individually included. Include c10/util/complex.h instead."
#endif
#include <limits>
namespace c10 {
template <typename T>
struct is_complex : public std::false_type {};
template <typename T>
struct is_complex<std::complex<T>> : public std::true_type {};
template <typename T>
struct is_complex<c10::complex<T>> : public std::true_type {};
// Extract double from std::complex<double>; is identity otherwise
// TODO: Write in more idiomatic C++17
template <typename T>
struct scalar_value_type {
using type = T;
};
template <typename T>
struct scalar_value_type<std::complex<T>> {
using type = T;
};
template <typename T>
struct scalar_value_type<c10::complex<T>> {
using type = T;
};
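// Example (illustrative): strip the complex wrapper in generic code.
//   static_assert(std::is_same_v<
//       c10::scalar_value_type<c10::complex<float>>::type,
//       float>);
//   static_assert(
//       std::is_same_v<c10::scalar_value_type<double>::type, double>);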
} // namespace c10
namespace std {
template <typename T>
class numeric_limits<c10::complex<T>> : public numeric_limits<T> {};
template <typename T>
bool isnan(const c10::complex<T>& v) {
return std::isnan(v.real()) || std::isnan(v.imag());
}
} // namespace std

View File

@ -0,0 +1,27 @@
#pragma once
#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>
namespace c10 {
// Note: Explicit implementation of copysign for Half and BFloat16
// is needed to work around a g++-7/8 crash on aarch64, but it also makes
// copysign faster for the half-precision types
template <typename T, typename U>
inline auto copysign(const T& a, const U& b) {
return std::copysign(a, b);
}
// Implement copysign for half precision floats using bit ops
// Sign is the most significant bit for both half and bfloat16 types
inline c10::Half copysign(c10::Half a, c10::Half b) {
return c10::Half((a.x & 0x7fff) | (b.x & 0x8000), c10::Half::from_bits());
}
inline c10::BFloat16 copysign(c10::BFloat16 a, c10::BFloat16 b) {
return c10::BFloat16(
(a.x & 0x7fff) | (b.x & 0x8000), c10::BFloat16::from_bits());
}
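// Example: Half 2.0 has bits 0x4000; OR-ing in the sign bit of -0.0
// (0x8000) gives 0xC000, i.e. -2.0:
//   c10::copysign(c10::Half(2.0f), c10::Half(-0.0f)) // == -2.0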
} // namespace c10

View File

@ -0,0 +1,41 @@
#pragma once
#include <c10/util/Exception.h>
#include <cstdlib>
#include <cstring>
#include <optional>
namespace c10::utils {
// Reads an environment variable and returns
// - std::optional<true>, if set equal to "1"
// - std::optional<false>, if set equal to "0"
// - nullopt, otherwise
//
// NB:
// Issues a warning if the value of the environment variable is not 0 or 1.
inline std::optional<bool> check_env(const char* name) {
#ifdef _MSC_VER
#pragma warning(push)
#pragma warning(disable : 4996)
#endif
auto envar = std::getenv(name);
#ifdef _MSC_VER
#pragma warning(pop)
#endif
if (envar) {
if (strcmp(envar, "0") == 0) {
return false;
}
if (strcmp(envar, "1") == 0) {
return true;
}
TORCH_WARN(
"Ignoring invalid value for boolean flag ",
name,
": ",
envar,
"valid values are 0 or 1.");
}
return std::nullopt;
}
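// Usage (sketch; TORCH_SOME_FLAG is a hypothetical variable name):
//   if (c10::utils::check_env("TORCH_SOME_FLAG") == true) {
//     // set to "1": feature explicitly enabled
//   }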
} // namespace c10::utils

File diff suppressed because it is too large

View File

@ -0,0 +1,40 @@
#pragma once
#include <c10/macros/Macros.h>
#include <cstdint>
namespace c10::detail {
C10_HOST_DEVICE inline float fp32_from_bits(uint32_t w) {
#if defined(__OPENCL_VERSION__)
return as_float(w);
#elif defined(__CUDA_ARCH__)
return __uint_as_float((unsigned int)w);
#elif defined(__INTEL_COMPILER)
return _castu32_f32(w);
#else
union {
uint32_t as_bits;
float as_value;
} fp32 = {w};
return fp32.as_value;
#endif
}
C10_HOST_DEVICE inline uint32_t fp32_to_bits(float f) {
#if defined(__OPENCL_VERSION__)
return as_uint(f);
#elif defined(__CUDA_ARCH__)
return (uint32_t)__float_as_uint(f);
#elif defined(__INTEL_COMPILER)
return _castf32_u32(f);
#else
union {
float as_value;
uint32_t as_bits;
} fp32 = {f};
return fp32.as_bits;
#endif
}
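// Example: IEEE-754 single-precision 1.0f has bit pattern 0x3f800000,
// so the two helpers round-trip exactly:
//   fp32_to_bits(1.0f) == 0x3f800000u
//   fp32_from_bits(0x3f800000u) == 1.0f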
} // namespace c10::detail

View File

@ -0,0 +1,72 @@
#pragma once
#include <c10/macros/Macros.h>
#include <c10/util/TypeSafeSignMath.h>
#include <cmath>
#if defined(__CUDA_ARCH__)
#include <c10/cuda/CUDAMathCompat.h>
#define C10_COMPAT_COPYSIGN c10::cuda::compat::copysign
#elif defined(__HIPCC__)
#include <c10/hip/HIPMathCompat.h>
#define C10_COMPAT_COPYSIGN c10::hip::compat::copysign
#else
#include <c10/util/copysign.h>
#define C10_COMPAT_COPYSIGN c10::copysign
#endif
// The functions in this file should be header-only as it is used under
// ABI-compatibility mode.
namespace c10 {
// NOTE: [Floor Division in Python]
// Python's __floordiv__ operator is more complicated than just floor(a / b).
// It aims to maintain the property: a == (a // b) * b + remainder(a, b)
// which can otherwise fail due to rounding errors in the remainder.
// So, instead it is calculated as: a // b = (a - remainder(a, b)) / b
// With some additional fix-ups added to the result.
//
// For reference, see CPython's implementation:
// https://github.com/python/cpython/blob/ace008c531dd685a30c1dd68f9b5ba35f20171cf/Objects/floatobject.c#L636
template <typename scalar_t>
inline C10_HOST_DEVICE scalar_t div_floor_floating(scalar_t a, scalar_t b)
__ubsan_ignore_float_divide_by_zero__ {
if (C10_UNLIKELY(b == 0)) {
// Divide by zero: return standard IEEE result
return a / b;
}
auto mod = std::fmod(a, b);
auto div = (a - mod) / b;
if ((mod != 0) && (b < 0) != (mod < 0)) {
div -= scalar_t(1);
}
scalar_t floordiv;
if (div != 0) {
floordiv = std::floor(div);
if (div - floordiv > scalar_t(0.5)) {
floordiv += scalar_t(1.0);
}
} else {
floordiv = C10_COMPAT_COPYSIGN(scalar_t(0), a / b);
}
return floordiv;
}
template <typename scalar_t>
inline C10_HOST_DEVICE scalar_t div_floor_integer(scalar_t a, scalar_t b) {
if (c10::signs_differ(a, b)) {
// Subtracts one from the results of truncation division if the
// divisor and dividend have different sign(bit)s and the remainder of
// the division is nonzero
const auto quot = a / b;
const auto rem = a % b;
return rem ? quot - 1 : quot;
}
return a / b;
}
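// Examples (matching Python's // operator):
//   div_floor_integer(-7, 2) == -4 (truncating division would give -3)
//   div_floor_floating(-7.0, 2.0) == -4.0
//   div_floor_floating(7.0, 2.0) == 3.0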
} // namespace c10

View File

@ -0,0 +1,379 @@
#pragma once
#include <c10/util/Exception.h>
#include <cstddef>
#include <functional>
#include <iomanip>
#include <ios>
#include <sstream>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>
#include <c10/util/ArrayRef.h>
#include <c10/util/complex.h>
namespace c10 {
// NOTE: hash_combine and SHA1 hashing is based on implementation from Boost
//
// Boost Software License - Version 1.0 - August 17th, 2003
//
// Permission is hereby granted, free of charge, to any person or organization
// obtaining a copy of the software and accompanying documentation covered by
// this license (the "Software") to use, reproduce, display, distribute,
// execute, and transmit the Software, and to prepare derivative works of the
// Software, and to permit third-parties to whom the Software is furnished to
// do so, all subject to the following:
//
// The copyright notices in the Software and this entire statement, including
// the above license grant, this restriction and the following disclaimer,
// must be included in all copies of the Software, in whole or in part, and
// all derivative works of the Software, unless such copies or derivative
// works are solely in the form of machine-executable object code generated by
// a source language processor.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT
// SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE
// FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE,
// ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
inline size_t hash_combine(size_t seed, size_t value) {
return seed ^ (value + 0x9e3779b9 + (seed << 6u) + (seed >> 2u));
}
// Creates the SHA1 hash of a string. A 160-bit hash.
// Based on the implementation in Boost (see notice above).
// Note that SHA1 hashes are no longer considered cryptographically
// secure, but are the standard hash for generating unique ids.
// Usage:
// // Let 'code' be a std::string
// c10::sha1 sha1_hash{code};
// const auto hash_code = sha1_hash.str();
// TODO: Compare vs OpenSSL and/or CryptoPP implementations
struct sha1 {
typedef unsigned int(digest_type)[5];
sha1(const std::string& s = "") {
if (!s.empty()) {
reset();
process_bytes(s.c_str(), s.size());
}
}
void reset() {
h_[0] = 0x67452301;
h_[1] = 0xEFCDAB89;
h_[2] = 0x98BADCFE;
h_[3] = 0x10325476;
h_[4] = 0xC3D2E1F0;
block_byte_index_ = 0;
bit_count_low = 0;
bit_count_high = 0;
}
std::string str() {
unsigned int digest[5];
get_digest(digest);
std::ostringstream buf;
for (unsigned int i : digest) {
buf << std::hex << std::setfill('0') << std::setw(8) << i;
}
return buf.str();
}
private:
unsigned int left_rotate(unsigned int x, std::size_t n) {
return (x << n) ^ (x >> (32 - n));
}
void process_block_impl() {
unsigned int w[80];
for (std::size_t i = 0; i < 16; ++i) {
w[i] = (block_[i * 4 + 0] << 24);
w[i] |= (block_[i * 4 + 1] << 16);
w[i] |= (block_[i * 4 + 2] << 8);
w[i] |= (block_[i * 4 + 3]);
}
for (std::size_t i = 16; i < 80; ++i) {
w[i] = left_rotate((w[i - 3] ^ w[i - 8] ^ w[i - 14] ^ w[i - 16]), 1);
}
unsigned int a = h_[0];
unsigned int b = h_[1];
unsigned int c = h_[2];
unsigned int d = h_[3];
unsigned int e = h_[4];
for (std::size_t i = 0; i < 80; ++i) {
unsigned int f = 0;
unsigned int k = 0;
if (i < 20) {
f = (b & c) | (~b & d);
k = 0x5A827999;
} else if (i < 40) {
f = b ^ c ^ d;
k = 0x6ED9EBA1;
} else if (i < 60) {
f = (b & c) | (b & d) | (c & d);
k = 0x8F1BBCDC;
} else {
f = b ^ c ^ d;
k = 0xCA62C1D6;
}
unsigned temp = left_rotate(a, 5) + f + e + k + w[i];
e = d;
d = c;
c = left_rotate(b, 30);
b = a;
a = temp;
}
h_[0] += a;
h_[1] += b;
h_[2] += c;
h_[3] += d;
h_[4] += e;
}
void process_byte_impl(unsigned char byte) {
block_[block_byte_index_++] = byte;
if (block_byte_index_ == 64) {
block_byte_index_ = 0;
process_block_impl();
}
}
void process_byte(unsigned char byte) {
process_byte_impl(byte);
// size_t max value = 0xFFFFFFFF
// if (bit_count_low + 8 >= 0x100000000) { // would overflow
// if (bit_count_low >= 0x100000000-8) {
if (bit_count_low < 0xFFFFFFF8) {
bit_count_low += 8;
} else {
bit_count_low = 0;
if (bit_count_high <= 0xFFFFFFFE) {
++bit_count_high;
} else {
TORCH_CHECK(false, "sha1 too many bytes");
}
}
}
void process_block(void const* bytes_begin, void const* bytes_end) {
unsigned char const* begin = static_cast<unsigned char const*>(bytes_begin);
unsigned char const* end = static_cast<unsigned char const*>(bytes_end);
for (; begin != end; ++begin) {
process_byte(*begin);
}
}
void process_bytes(void const* buffer, std::size_t byte_count) {
unsigned char const* b = static_cast<unsigned char const*>(buffer);
process_block(b, b + byte_count);
}
void get_digest(digest_type& digest) {
// append the bit '1' to the message
process_byte_impl(0x80);
// append k bits '0', where k is the minimum number >= 0
// such that the resulting message length is congruent to 56 (mod 64)
// check if there is enough space for padding and bit_count
if (block_byte_index_ > 56) {
// finish this block
while (block_byte_index_ != 0) {
process_byte_impl(0);
}
// one more block
while (block_byte_index_ < 56) {
process_byte_impl(0);
}
} else {
while (block_byte_index_ < 56) {
process_byte_impl(0);
}
}
// append length of message (before pre-processing)
// as a 64-bit big-endian integer
process_byte_impl(
static_cast<unsigned char>((bit_count_high >> 24) & 0xFF));
process_byte_impl(
static_cast<unsigned char>((bit_count_high >> 16) & 0xFF));
process_byte_impl(static_cast<unsigned char>((bit_count_high >> 8) & 0xFF));
process_byte_impl(static_cast<unsigned char>((bit_count_high) & 0xFF));
process_byte_impl(static_cast<unsigned char>((bit_count_low >> 24) & 0xFF));
process_byte_impl(static_cast<unsigned char>((bit_count_low >> 16) & 0xFF));
process_byte_impl(static_cast<unsigned char>((bit_count_low >> 8) & 0xFF));
process_byte_impl(static_cast<unsigned char>((bit_count_low) & 0xFF));
// get final digest
digest[0] = h_[0];
digest[1] = h_[1];
digest[2] = h_[2];
digest[3] = h_[3];
digest[4] = h_[4];
}
unsigned int h_[5]{};
unsigned char block_[64]{};
std::size_t block_byte_index_{};
std::size_t bit_count_low{};
std::size_t bit_count_high{};
};
constexpr uint64_t twang_mix64(uint64_t key) noexcept {
key = (~key) + (key << 21); // key *= (1 << 21) - 1; key -= 1;
key = key ^ (key >> 24);
key = key + (key << 3) + (key << 8); // key *= 1 + (1 << 3) + (1 << 8)
key = key ^ (key >> 14);
key = key + (key << 2) + (key << 4); // key *= 1 + (1 << 2) + (1 << 4)
key = key ^ (key >> 28);
key = key + (key << 31); // key *= 1 + (1 << 31)
return key;
}
////////////////////////////////////////////////////////////////////////////////
// c10::hash implementation
////////////////////////////////////////////////////////////////////////////////
namespace _hash_detail {
// Use template argument deduction to shorten calls to c10::hash
template <typename T>
size_t simple_get_hash(const T& o);
template <typename T, typename V>
using type_if_not_enum = std::enable_if_t<!std::is_enum_v<T>, V>;
// Use SFINAE to dispatch to std::hash if possible, cast enum types to int
// automatically, and fall back to T::hash otherwise. NOTE: C++14 added support
// for hashing enum types to the standard, and some compilers implement it even
// when C++14 flags aren't specified. This is why we have to disable this
// overload if T is an enum type (and use the one below in this case).
template <typename T>
auto dispatch_hash(const T& o)
-> decltype(std::hash<T>()(o), type_if_not_enum<T, size_t>()) {
return std::hash<T>()(o);
}
template <typename T>
std::enable_if_t<std::is_enum_v<T>, size_t> dispatch_hash(const T& o) {
using R = std::underlying_type_t<T>;
return std::hash<R>()(static_cast<R>(o));
}
template <typename T>
auto dispatch_hash(const T& o) -> decltype(T::hash(o), size_t()) {
return T::hash(o);
}
} // namespace _hash_detail
// Hasher struct
template <typename T>
struct hash {
size_t operator()(const T& o) const {
return _hash_detail::dispatch_hash(o);
};
};
// Specialization for std::tuple
template <typename... Types>
struct hash<std::tuple<Types...>> {
template <size_t idx, typename... Ts>
struct tuple_hash {
size_t operator()(const std::tuple<Ts...>& t) const {
return hash_combine(
_hash_detail::simple_get_hash(std::get<idx>(t)),
tuple_hash<idx - 1, Ts...>()(t));
}
};
template <typename... Ts>
struct tuple_hash<0, Ts...> {
size_t operator()(const std::tuple<Ts...>& t) const {
return _hash_detail::simple_get_hash(std::get<0>(t));
}
};
size_t operator()(const std::tuple<Types...>& t) const {
return tuple_hash<sizeof...(Types) - 1, Types...>()(t);
}
};
template <typename T1, typename T2>
struct hash<std::pair<T1, T2>> {
size_t operator()(const std::pair<T1, T2>& pair) const {
std::tuple<T1, T2> tuple = std::make_tuple(pair.first, pair.second);
return _hash_detail::simple_get_hash(tuple);
}
};
template <typename T>
struct hash<c10::ArrayRef<T>> {
size_t operator()(c10::ArrayRef<T> v) const {
size_t seed = 0;
for (const auto& elem : v) {
seed = hash_combine(seed, _hash_detail::simple_get_hash(elem));
}
return seed;
}
};
// Specialization for std::vector
template <typename T>
struct hash<std::vector<T>> {
size_t operator()(const std::vector<T>& v) const {
return hash<c10::ArrayRef<T>>()(v);
}
};
namespace _hash_detail {
template <typename T>
size_t simple_get_hash(const T& o) {
return c10::hash<T>()(o);
}
} // namespace _hash_detail
// Use this function to actually hash multiple things in one line.
// Dispatches to c10::hash, so it can hash containers.
// Example:
//
// static size_t hash(const MyStruct& s) {
// return get_hash(s.member1, s.member2, s.member3);
// }
template <typename... Types>
size_t get_hash(const Types&... args) {
return c10::hash<decltype(std::tie(args...))>()(std::tie(args...));
}
// Specialization for c10::complex
template <typename T>
struct hash<c10::complex<T>> {
size_t operator()(const c10::complex<T>& c) const {
return get_hash(c.real(), c.imag());
}
};
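// Example (sketch; MyKey is illustrative): a user type becomes hashable
// through the T::hash fallback in dispatch_hash above.
//   struct MyKey {
//     int a;
//     std::string b;
//     static size_t hash(const MyKey& k) {
//       return c10::get_hash(k.a, k.b);
//     }
//   };
//   // c10::hash<MyKey>()(key) now dispatches to MyKey::hash(key).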
} // namespace c10

View File

@ -0,0 +1,398 @@
// This file is based on the uint128 implementation of protobuf at
// https://github.com/protocolbuffers/protobuf/blob/1e88936fce10cf773cb72b44c6a7f48b38c7578b/src/google/protobuf/stubs/int128.h
//
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include <c10/macros/Export.h>
#include <cstdint>
#include <iosfwd>
namespace c10 {
struct uint128_pod;
// TODO(xiaofeng): Define GOOGLE_PROTOBUF_HAS_CONSTEXPR when constexpr is
// available.
#ifdef GOOGLE_PROTOBUF_HAS_CONSTEXPR
#define UINT128_CONSTEXPR constexpr
#else
#define UINT128_CONSTEXPR
#endif
class uint128;
inline uint128& operator<<=(uint128& self, int amount);
// An unsigned 128-bit integer type. Thread-compatible.
class C10_API uint128 {
public:
UINT128_CONSTEXPR uint128(); // Sets to 0, but don't rely on this behavior.
UINT128_CONSTEXPR uint128(uint64_t top, uint64_t bottom);
#ifndef SWIG
UINT128_CONSTEXPR uint128(int bottom);
UINT128_CONSTEXPR uint128(uint32_t bottom); // Top 96 bits = 0
#endif
UINT128_CONSTEXPR uint128(uint64_t bottom); // hi_ = 0
UINT128_CONSTEXPR uint128(const uint128_pod& val);
// Trivial copy constructor, assignment operator and destructor.
void Initialize(uint64_t top, uint64_t bottom);
// Arithmetic operators.
uint128& operator+=(const uint128& b);
uint128& operator-=(const uint128& b);
uint128& operator*=(const uint128& b);
// Long division/modulo for uint128.
uint128& operator/=(const uint128& b);
uint128& operator%=(const uint128& b);
uint128 operator++(int);
uint128 operator--(int);
// Make MSVC happy: operator<<= is used from DivModImpl, which is a
// static function, and the linker complained about a missing static
// version of this overload.
friend uint128& operator<<=(uint128&, int);
uint128& operator>>=(int);
uint128& operator&=(const uint128& b);
uint128& operator|=(const uint128& b);
uint128& operator^=(const uint128& b);
uint128& operator++();
uint128& operator--();
friend uint64_t Uint128Low64(const uint128& v);
friend uint64_t Uint128High64(const uint128& v);
// We add "std::" to avoid including all of port.h.
C10_API friend std::ostream& operator<<(std::ostream& o, const uint128& b);
private:
static void DivModImpl(
uint128 dividend,
uint128 divisor,
uint128* quotient_ret,
uint128* remainder_ret);
// Little-endian memory order optimizations can benefit from
// having lo_ first, hi_ last.
// See util/endian/endian.h and Load128/Store128 for storing a uint128.
uint64_t lo_;
uint64_t hi_;
// Not implemented, just declared for catching automatic type conversions.
uint128(uint8_t);
uint128(uint16_t);
uint128(float v);
uint128(double v);
};
// This is a POD form of uint128 which can be used for static variables which
// need to be operated on as uint128.
struct uint128_pod {
// Note: The ordering of fields is different than 'class uint128' but the
// same as its 2-arg constructor. This enables more obvious initialization
// of static instances, which is the primary reason for this struct in the
// first place. This does not seem to defeat any optimizations wrt
// operations involving this struct.
uint64_t hi;
uint64_t lo;
};
C10_API extern const uint128_pod kuint128max;
// allow uint128 to be logged
C10_API extern std::ostream& operator<<(std::ostream& o, const uint128& b);
// Methods to access low and high pieces of 128-bit value.
// Defined externally from uint128 to facilitate conversion
// to native 128-bit types when compilers support them.
inline uint64_t Uint128Low64(const uint128& v) {
return v.lo_;
}
inline uint64_t Uint128High64(const uint128& v) {
return v.hi_;
}
// TODO: perhaps it would be nice to have int128, a signed 128-bit type?
// --------------------------------------------------------------------------
// Implementation details follow
// --------------------------------------------------------------------------
inline bool operator==(const uint128& lhs, const uint128& rhs) {
return (
Uint128Low64(lhs) == Uint128Low64(rhs) &&
Uint128High64(lhs) == Uint128High64(rhs));
}
inline bool operator!=(const uint128& lhs, const uint128& rhs) {
return !(lhs == rhs);
}
C10_API inline UINT128_CONSTEXPR uint128::uint128() : lo_(0), hi_(0) {}
C10_API inline UINT128_CONSTEXPR uint128::uint128(uint64_t top, uint64_t bottom)
: lo_(bottom), hi_(top) {}
C10_API inline UINT128_CONSTEXPR uint128::uint128(const uint128_pod& v)
: lo_(v.lo), hi_(v.hi) {}
C10_API inline UINT128_CONSTEXPR uint128::uint128(uint64_t bottom)
: lo_(bottom), hi_(0) {}
#ifndef SWIG
C10_API inline UINT128_CONSTEXPR uint128::uint128(uint32_t bottom)
: lo_(bottom), hi_(0) {}
C10_API inline UINT128_CONSTEXPR uint128::uint128(int bottom)
: lo_(bottom), hi_(static_cast<int64_t>((bottom < 0) ? -1 : 0)) {}
#endif
#undef UINT128_CONSTEXPR
C10_API inline void uint128::Initialize(uint64_t top, uint64_t bottom) {
hi_ = top;
lo_ = bottom;
}
// Comparison operators.
#define CMP128(op) \
inline bool operator op(const uint128& lhs, const uint128& rhs) { \
return (Uint128High64(lhs) == Uint128High64(rhs)) \
? (Uint128Low64(lhs) op Uint128Low64(rhs)) \
: (Uint128High64(lhs) op Uint128High64(rhs)); \
}
CMP128(<)
CMP128(>)
CMP128(>=)
CMP128(<=)
#undef CMP128
// Unary operators
inline uint128 operator-(const uint128& val) {
const uint64_t hi_flip = ~Uint128High64(val);
const uint64_t lo_flip = ~Uint128Low64(val);
const uint64_t lo_add = lo_flip + 1;
if (lo_add < lo_flip) {
return uint128(hi_flip + 1, lo_add);
}
return uint128(hi_flip, lo_add);
}
inline bool operator!(const uint128& val) {
return !Uint128High64(val) && !Uint128Low64(val);
}
// Logical operators.
inline uint128 operator~(const uint128& val) {
return uint128(~Uint128High64(val), ~Uint128Low64(val));
}
#define LOGIC128(op) \
inline uint128 operator op(const uint128& lhs, const uint128& rhs) { \
return uint128( \
Uint128High64(lhs) op Uint128High64(rhs), \
Uint128Low64(lhs) op Uint128Low64(rhs)); \
}
LOGIC128(|)
LOGIC128(&)
LOGIC128(^)
#undef LOGIC128
#define LOGICASSIGN128(op) \
C10_API inline uint128& uint128::operator op(const uint128 & other) { \
hi_ op other.hi_; \
lo_ op other.lo_; \
return *this; \
}
LOGICASSIGN128(|=)
LOGICASSIGN128(&=)
LOGICASSIGN128(^=)
#undef LOGICASSIGN128
// Shift operators.
inline uint128 operator<<(const uint128& val, int amount) {
// uint64_t shifts of >= 64 are undefined, so we will need some
// special-casing.
if (amount < 64) {
if (amount == 0) {
return val;
}
uint64_t new_hi =
(Uint128High64(val) << amount) | (Uint128Low64(val) >> (64 - amount));
uint64_t new_lo = Uint128Low64(val) << amount;
return uint128(new_hi, new_lo);
} else if (amount < 128) {
return uint128(Uint128Low64(val) << (amount - 64), 0);
} else {
return uint128(0, 0);
}
}
inline uint128 operator>>(const uint128& val, int amount) {
// uint64_t shifts of >= 64 are undefined, so we will need some
// special-casing.
if (amount < 64) {
if (amount == 0) {
return val;
}
uint64_t new_hi = Uint128High64(val) >> amount;
uint64_t new_lo =
(Uint128Low64(val) >> amount) | (Uint128High64(val) << (64 - amount));
return uint128(new_hi, new_lo);
} else if (amount < 128) {
return uint128(0, Uint128High64(val) >> (amount - 64));
} else {
return uint128(0, 0);
}
}
inline uint128& operator<<=(uint128& self, int amount) {
// uint64_t shifts of >= 64 are undefined, so we will need some
// special-casing.
if (amount < 64) {
if (amount != 0) {
self.hi_ = (self.hi_ << amount) | (self.lo_ >> (64 - amount));
self.lo_ = self.lo_ << amount;
}
} else if (amount < 128) {
self.hi_ = self.lo_ << (amount - 64);
self.lo_ = 0;
} else {
self.hi_ = 0;
self.lo_ = 0;
}
return self;
}
C10_API inline uint128& uint128::operator>>=(int amount) {
// uint64_t shifts of >= 64 are undefined, so we will need some
// special-casing.
if (amount < 64) {
if (amount != 0) {
lo_ = (lo_ >> amount) | (hi_ << (64 - amount));
hi_ = hi_ >> amount;
}
} else if (amount < 128) {
lo_ = hi_ >> (amount - 64);
hi_ = 0;
} else {
lo_ = 0;
hi_ = 0;
}
return *this;
}
inline uint128 operator+(const uint128& lhs, const uint128& rhs) {
return uint128(lhs) += rhs;
}
inline uint128 operator-(const uint128& lhs, const uint128& rhs) {
return uint128(lhs) -= rhs;
}
inline uint128 operator*(const uint128& lhs, const uint128& rhs) {
return uint128(lhs) *= rhs;
}
inline uint128 operator/(const uint128& lhs, const uint128& rhs) {
return uint128(lhs) /= rhs;
}
inline uint128 operator%(const uint128& lhs, const uint128& rhs) {
return uint128(lhs) %= rhs;
}
C10_API inline uint128& uint128::operator+=(const uint128& b) {
hi_ += b.hi_;
uint64_t lolo = lo_ + b.lo_;
if (lolo < lo_)
++hi_;
lo_ = lolo;
return *this;
}
C10_API inline uint128& uint128::operator-=(const uint128& b) {
hi_ -= b.hi_;
if (b.lo_ > lo_)
--hi_;
lo_ -= b.lo_;
return *this;
}
C10_API inline uint128& uint128::operator*=(const uint128& b) {
uint64_t a96 = hi_ >> 32;
uint64_t a64 = hi_ & 0xffffffffu;
uint64_t a32 = lo_ >> 32;
uint64_t a00 = lo_ & 0xffffffffu;
uint64_t b96 = b.hi_ >> 32;
uint64_t b64 = b.hi_ & 0xffffffffu;
uint64_t b32 = b.lo_ >> 32;
uint64_t b00 = b.lo_ & 0xffffffffu;
// multiply [a96 .. a00] x [b96 .. b00]
// terms higher than c96 disappear off the high side
// terms c96 and c64 are safe to ignore carry bit
uint64_t c96 = a96 * b00 + a64 * b32 + a32 * b64 + a00 * b96;
uint64_t c64 = a64 * b00 + a32 * b32 + a00 * b64;
this->hi_ = (c96 << 32) + c64;
this->lo_ = 0;
// add terms after this one at a time to capture carry
*this += uint128(a32 * b00) << 32;
*this += uint128(a00 * b32) << 32;
*this += a00 * b00;
return *this;
}
C10_API inline uint128 uint128::operator++(int) {
uint128 tmp(*this);
*this += 1;
return tmp;
}
C10_API inline uint128 uint128::operator--(int) {
uint128 tmp(*this);
*this -= 1;
return tmp;
}
C10_API inline uint128& uint128::operator++() {
*this += 1;
return *this;
}
C10_API inline uint128& uint128::operator--() {
*this -= 1;
return *this;
}
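// Usage (illustrative):
//   c10::uint128 v(/*top=*/1, /*bottom=*/0); // 2^64
//   v >>= 1;                                 // 2^63
//   Uint128Low64(v);  // 0x8000000000000000, and Uint128High64(v) == 0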
} // namespace c10

File diff suppressed because it is too large

View File

@ -0,0 +1,124 @@
// Copyright 2004-present Facebook. All Rights Reserved.
#pragma once
#include <c10/util/Exception.h>
#include <c10/util/TypeSafeSignMath.h>
#include <algorithm>
#include <cstddef>
#include <iterator>
#include <type_traits>
namespace c10 {
namespace detail {
template <
typename I,
bool one_sided = false,
std::enable_if_t<std::is_integral_v<I>, int> = 0>
struct integer_iterator {
using iterator_category = std::input_iterator_tag;
using value_type = I;
using difference_type = std::ptrdiff_t;
using pointer = I*;
using reference = I&;
explicit integer_iterator(I value) : value(value) {}
I operator*() const {
return value;
}
I const* operator->() const {
return &value;
}
integer_iterator& operator++() {
++value;
return *this;
}
integer_iterator operator++(int) {
const auto copy = *this;
++*this;
return copy;
}
bool operator==(const integer_iterator& other) const {
if constexpr (one_sided) {
// Range-for loops' end test is `begin != end`, not `begin <
// end`. To handle `c10::irange(n)` where n < 0 (which should be
// empty), we just make `begin != end` fail whenever `end` is
// negative.
return is_negative(other.value) || value == other.value;
} else {
return value == other.value;
}
// Suppress "warning: missing return statement at end of non-void function"
// which Nvidia's Robert Crovella confirms is an NVCC compiler error
// here https://stackoverflow.com/a/64561686/752843 on 2020-10-27
// `__builtin_unreachable();` would be best here, but it's not
// available with all compilers. So we instead return an arbitrary
// value trusting that this line will, in fact, never be reached.
return false; // Horrible hack
}
bool operator!=(const integer_iterator& other) const {
return !(*this == other);
}
protected:
I value;
};
} // namespace detail
template <
typename I,
bool one_sided = false,
std::enable_if_t<std::is_integral_v<I>, bool> = true>
struct integer_range {
public:
integer_range(I begin, I end) : begin_(begin), end_(end) {}
using iterator = detail::integer_iterator<I, one_sided>;
iterator begin() const {
return begin_;
}
iterator end() const {
return end_;
}
private:
iterator begin_;
iterator end_;
};
/// Creates an integer range for the half-open interval [begin, end)
/// If end<=begin, then the range is empty.
/// The range has the type of the `end` integer; the `begin` integer is
/// cast to this type.
template <
typename Integer1,
typename Integer2,
std::enable_if_t<std::is_integral_v<Integer1>, bool> = true,
std::enable_if_t<std::is_integral_v<Integer2>, bool> = true>
integer_range<Integer2> irange(Integer1 begin, Integer2 end) {
// If end<=begin then the range is empty; we can achieve this effect by
// choosing the larger of {begin, end} as the loop terminator
return {
static_cast<Integer2>(begin),
std::max(static_cast<Integer2>(begin), end)};
}
/// Creates an integer range for the half-open interval [0, end)
/// If end<=0, then the range is empty.
template <
typename Integer,
std::enable_if_t<std::is_integral_v<Integer>, bool> = true>
integer_range<Integer, true> irange(Integer end) {
return {Integer(), end};
}
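// Usage (illustrative):
//   for (const auto i : c10::irange(3)) { /* i = 0, 1, 2 */ }
//   for (const auto i : c10::irange(2, 5)) { /* i = 2, 3, 4 */ }
//   // c10::irange(-1) is an empty range (negative one-sided end).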
} // namespace c10

View File

@ -0,0 +1,906 @@
//===-- llvm/Support/MathExtras.h - Useful math functions -------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains some functions that are useful for math stuff.
//
//===----------------------------------------------------------------------===//
#pragma once
#include <c10/util/bit_cast.h>
#include <algorithm>
#include <cassert>
#include <climits>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <limits>
#include <type_traits>
#ifdef __ANDROID_NDK__
#include <android/api-level.h>
#endif
#ifndef __has_builtin
#define __has_builtin(x) 0
#endif
#ifndef LLVM_GNUC_PREREQ
#if defined(__GNUC__) && defined(__GNUC_MINOR__) && defined(__GNUC_PATCHLEVEL__)
#define LLVM_GNUC_PREREQ(maj, min, patch) \
((__GNUC__ << 20) + (__GNUC_MINOR__ << 10) + __GNUC_PATCHLEVEL__ >= \
((maj) << 20) + ((min) << 10) + (patch))
#elif defined(__GNUC__) && defined(__GNUC_MINOR__)
#define LLVM_GNUC_PREREQ(maj, min, patch) \
((__GNUC__ << 20) + (__GNUC_MINOR__ << 10) >= ((maj) << 20) + ((min) << 10))
#else
#define LLVM_GNUC_PREREQ(maj, min, patch) 0
#endif
#endif
#ifdef _MSC_VER
// Declare these intrinsics manually rather including intrin.h. It's very
// expensive, and MathExtras.h is popular.
// #include <intrin.h>
extern "C" {
unsigned char _BitScanForward(unsigned long* _Index, unsigned long _Mask);
unsigned char _BitScanForward64(unsigned long* _Index, unsigned __int64 _Mask);
unsigned char _BitScanReverse(unsigned long* _Index, unsigned long _Mask);
unsigned char _BitScanReverse64(unsigned long* _Index, unsigned __int64 _Mask);
}
#endif
namespace c10::llvm {
/// The behavior an operation has on an input of 0.
enum ZeroBehavior {
/// The returned value is undefined.
ZB_Undefined,
/// The returned value is numeric_limits<T>::max()
ZB_Max,
/// The returned value is numeric_limits<T>::digits
ZB_Width
};
namespace detail {
template <typename T, std::size_t SizeOfT>
struct TrailingZerosCounter {
static std::size_t count(T Val, ZeroBehavior) {
if (!Val)
return std::numeric_limits<T>::digits;
if (Val & 0x1)
return 0;
// Bisection method.
std::size_t ZeroBits = 0;
T Shift = std::numeric_limits<T>::digits >> 1;
T Mask = std::numeric_limits<T>::max() >> Shift;
while (Shift) {
if ((Val & Mask) == 0) {
Val >>= Shift;
ZeroBits |= Shift;
}
Shift >>= 1;
Mask >>= Shift;
}
return ZeroBits;
}
};
#if (defined(__GNUC__) && __GNUC__ >= 4) || defined(_MSC_VER)
template <typename T>
struct TrailingZerosCounter<T, 4> {
static std::size_t count(T Val, ZeroBehavior ZB) {
if (ZB != ZB_Undefined && Val == 0)
return 32;
#if __has_builtin(__builtin_ctz) || LLVM_GNUC_PREREQ(4, 0, 0)
return __builtin_ctz(Val);
#elif defined(_MSC_VER)
unsigned long Index;
_BitScanForward(&Index, Val);
return Index;
#endif
}
};
#if !defined(_MSC_VER) || defined(_M_X64)
template <typename T>
struct TrailingZerosCounter<T, 8> {
static std::size_t count(T Val, ZeroBehavior ZB) {
if (ZB != ZB_Undefined && Val == 0)
return 64;
#if __has_builtin(__builtin_ctzll) || LLVM_GNUC_PREREQ(4, 0, 0)
return __builtin_ctzll(Val);
#elif defined(_MSC_VER)
unsigned long Index;
_BitScanForward64(&Index, Val);
return Index;
#endif
}
};
#endif
#endif
} // namespace detail
/// Count number of 0's from the least significant bit to the most
/// stopping at the first 1.
///
/// Only unsigned integral types are allowed.
///
/// \param ZB the behavior on an input of 0. Only ZB_Width and ZB_Undefined are
/// valid arguments.
template <typename T>
std::size_t countTrailingZeros(T Val, ZeroBehavior ZB = ZB_Width) {
static_assert(
std::numeric_limits<T>::is_integer && !std::numeric_limits<T>::is_signed,
"Only unsigned integral types are allowed.");
return llvm::detail::TrailingZerosCounter<T, sizeof(T)>::count(Val, ZB);
}
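// e.g. countTrailingZeros(uint32_t{8}) == 3, and with the default
// ZB_Width, countTrailingZeros(uint32_t{0}) == 32.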
namespace detail {
template <typename T, std::size_t SizeOfT>
struct LeadingZerosCounter {
static std::size_t count(T Val, ZeroBehavior) {
if (!Val)
return std::numeric_limits<T>::digits;
// Bisection method.
std::size_t ZeroBits = 0;
for (T Shift = std::numeric_limits<T>::digits >> 1; Shift; Shift >>= 1) {
T Tmp = Val >> Shift;
if (Tmp)
Val = Tmp;
else
ZeroBits |= Shift;
}
return ZeroBits;
}
};
#if (defined(__GNUC__) && __GNUC__ >= 4) || defined(_MSC_VER)
template <typename T>
struct LeadingZerosCounter<T, 4> {
static std::size_t count(T Val, ZeroBehavior ZB) {
if (ZB != ZB_Undefined && Val == 0)
return 32;
#if __has_builtin(__builtin_clz) || LLVM_GNUC_PREREQ(4, 0, 0)
return __builtin_clz(Val);
#elif defined(_MSC_VER)
unsigned long Index;
_BitScanReverse(&Index, Val);
return Index ^ 31;
#endif
}
};
#if !defined(_MSC_VER) || defined(_M_X64)
template <typename T>
struct LeadingZerosCounter<T, 8> {
static std::size_t count(T Val, ZeroBehavior ZB) {
if (ZB != ZB_Undefined && Val == 0)
return 64;
#if __has_builtin(__builtin_clzll) || LLVM_GNUC_PREREQ(4, 0, 0)
return __builtin_clzll(Val);
#elif defined(_MSC_VER)
unsigned long Index;
_BitScanReverse64(&Index, Val);
return Index ^ 63;
#endif
}
};
#endif
#endif
} // namespace detail
/// Count number of 0's from the most significant bit to the least
/// stopping at the first 1.
///
/// Only unsigned integral types are allowed.
///
/// \param ZB the behavior on an input of 0. Only ZB_Width and ZB_Undefined are
/// valid arguments.
template <typename T>
std::size_t countLeadingZeros(T Val, ZeroBehavior ZB = ZB_Width) {
static_assert(
std::numeric_limits<T>::is_integer && !std::numeric_limits<T>::is_signed,
"Only unsigned integral types are allowed.");
return llvm::detail::LeadingZerosCounter<T, sizeof(T)>::count(Val, ZB);
}
/// Get the index of the first set bit starting from the least
/// significant bit.
///
/// Only unsigned integral types are allowed.
///
/// \param ZB the behavior on an input of 0. Only ZB_Max and ZB_Undefined are
/// valid arguments.
template <typename T>
T findFirstSet(T Val, ZeroBehavior ZB = ZB_Max) {
if (ZB == ZB_Max && Val == 0)
return std::numeric_limits<T>::max();
return countTrailingZeros(Val, ZB_Undefined);
}
/// Create a bitmask with the N right-most bits set to 1, and all other
/// bits set to 0. Only unsigned types are allowed.
template <typename T>
T maskTrailingOnes(unsigned N) {
static_assert(std::is_unsigned_v<T>, "Invalid type!");
const unsigned Bits = CHAR_BIT * sizeof(T);
assert(N <= Bits && "Invalid bit index");
return N == 0 ? 0 : (T(-1) >> (Bits - N));
}
/// Create a bitmask with the N left-most bits set to 1, and all other
/// bits set to 0. Only unsigned types are allowed.
template <typename T>
T maskLeadingOnes(unsigned N) {
return ~maskTrailingOnes<T>(CHAR_BIT * sizeof(T) - N);
}
/// Create a bitmask with the N right-most bits set to 0, and all other
/// bits set to 1. Only unsigned types are allowed.
template <typename T>
T maskTrailingZeros(unsigned N) {
return maskLeadingOnes<T>(CHAR_BIT * sizeof(T) - N);
}
/// Create a bitmask with the N left-most bits set to 0, and all other
/// bits set to 1. Only unsigned types are allowed.
template <typename T>
T maskLeadingZeros(unsigned N) {
return maskTrailingOnes<T>(CHAR_BIT * sizeof(T) - N);
}
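// e.g. maskTrailingOnes<uint8_t>(3) == 0x07, maskLeadingOnes<uint8_t>(3)
// == 0xE0, maskTrailingZeros<uint8_t>(3) == 0xF8, and
// maskLeadingZeros<uint8_t>(3) == 0x1F.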
/// Get the index of the last set bit starting from the least
/// significant bit.
///
/// Only unsigned integral types are allowed.
///
/// \param ZB the behavior on an input of 0. Only ZB_Max and ZB_Undefined are
/// valid arguments.
template <typename T>
T findLastSet(T Val, ZeroBehavior ZB = ZB_Max) {
if (ZB == ZB_Max && Val == 0)
return std::numeric_limits<T>::max();
// Use ^ instead of - because both gcc and llvm can remove the associated ^
// in the __builtin_clz intrinsic on x86.
return countLeadingZeros(Val, ZB_Undefined) ^
(std::numeric_limits<T>::digits - 1);
}
/// Macro compressed bit reversal table for 256 bits.
///
/// http://graphics.stanford.edu/~seander/bithacks.html#BitReverseTable
// NOLINTNEXTLINE(*c-arrays*)
static constexpr unsigned char BitReverseTable256[256] = {
#define R2(n) n, n + 2 * 64, n + 1 * 64, n + 3 * 64
#define R4(n) R2(n), R2(n + 2 * 16), R2(n + 1 * 16), R2(n + 3 * 16)
#define R6(n) R4(n), R4(n + 2 * 4), R4(n + 1 * 4), R4(n + 3 * 4)
R6(0),
R6(2),
R6(1),
R6(3)
#undef R2
#undef R4
#undef R6
};
/// Reverse the bits in \p Val.
template <typename T>
T reverseBits(T Val) {
// NOLINTNEXTLINE(*c-arrays*)
unsigned char in[sizeof(Val)];
// NOLINTNEXTLINE(*c-arrays*)
unsigned char out[sizeof(Val)];
std::memcpy(in, &Val, sizeof(Val));
for (unsigned i = 0; i < sizeof(Val); ++i)
out[(sizeof(Val) - i) - 1] = BitReverseTable256[in[i]];
std::memcpy(&Val, out, sizeof(Val));
return Val;
}
// NOTE: The following support functions use the _32/_64 extensions instead of
// type overloading so that signed and unsigned integers can be used without
// ambiguity.
/// Return the high 32 bits of a 64 bit value.
constexpr inline uint32_t Hi_32(uint64_t Value) {
return static_cast<uint32_t>(Value >> 32);
}
/// Return the low 32 bits of a 64 bit value.
constexpr inline uint32_t Lo_32(uint64_t Value) {
return static_cast<uint32_t>(Value);
}
/// Make a 64-bit integer from a high / low pair of 32-bit integers.
constexpr inline uint64_t Make_64(uint32_t High, uint32_t Low) {
return ((uint64_t)High << 32) | (uint64_t)Low;
}
/// Checks if an integer fits into the given bit width.
template <unsigned N>
constexpr inline bool isInt(int64_t x) {
return N >= 64 ||
(-(INT64_C(1) << (N - 1)) <= x && x < (INT64_C(1) << (N - 1)));
}
// Template specializations to get better code for common cases.
template <>
constexpr inline bool isInt<8>(int64_t x) {
return static_cast<int8_t>(x) == x;
}
template <>
constexpr inline bool isInt<16>(int64_t x) {
return static_cast<int16_t>(x) == x;
}
template <>
constexpr inline bool isInt<32>(int64_t x) {
return static_cast<int32_t>(x) == x;
}
/// Checks if a signed integer is an N bit number shifted left by S.
template <unsigned N, unsigned S>
constexpr inline bool isShiftedInt(int64_t x) {
static_assert(
N > 0, "isShiftedInt<0> doesn't make sense (refers to a 0-bit number).");
static_assert(N + S <= 64, "isShiftedInt<N, S> with N + S > 64 is too wide.");
return isInt<N + S>(x) && (x % (UINT64_C(1) << S) == 0);
}
/// Checks if an unsigned integer fits into the given bit width.
///
/// This is written as two functions rather than as simply
///
/// return N >= 64 || X < (UINT64_C(1) << N);
///
/// to keep MSVC from (incorrectly) warning on isUInt<64> that we're shifting
/// left too many places.
template <unsigned N>
constexpr inline std::enable_if_t<(N < 64), bool> isUInt(uint64_t X) {
static_assert(N > 0, "isUInt<0> doesn't make sense");
return X < (UINT64_C(1) << (N));
}
template <unsigned N>
constexpr inline std::enable_if_t<N >= 64, bool> isUInt(uint64_t /*X*/) {
return true;
}
// Template specializations to get better code for common cases.
template <>
constexpr inline bool isUInt<8>(uint64_t x) {
return static_cast<uint8_t>(x) == x;
}
template <>
constexpr inline bool isUInt<16>(uint64_t x) {
return static_cast<uint16_t>(x) == x;
}
template <>
constexpr inline bool isUInt<32>(uint64_t x) {
return static_cast<uint32_t>(x) == x;
}
/// Checks if an unsigned integer is an N bit number shifted left by S.
template <unsigned N, unsigned S>
constexpr inline bool isShiftedUInt(uint64_t x) {
static_assert(
N > 0, "isShiftedUInt<0> doesn't make sense (refers to a 0-bit number)");
static_assert(
N + S <= 64, "isShiftedUInt<N, S> with N + S > 64 is too wide.");
// Per the two static_asserts above, S must be strictly less than 64. So
// 1 << S is not undefined behavior.
return isUInt<N + S>(x) && (x % (UINT64_C(1) << S) == 0);
}
/// Gets the maximum value for a N-bit unsigned integer.
inline uint64_t maxUIntN(uint64_t N) {
assert(N > 0 && N <= 64 && "integer width out of range");
// uint64_t(1) << 64 is undefined behavior, so we can't do
// (uint64_t(1) << N) - 1
// without checking first that N != 64. But this works and doesn't have a
// branch.
return UINT64_MAX >> (64 - N);
}
// Ignore the false warning "Arithmetic overflow" for MSVC
#ifdef _MSC_VER
#pragma warning(push)
#pragma warning(disable : 4146)
#endif
/// Gets the minimum value for a N-bit signed integer.
inline int64_t minIntN(int64_t N) {
assert(N > 0 && N <= 64 && "integer width out of range");
// NOLINTNEXTLINE(*-narrowing-conversions)
return -(UINT64_C(1) << (N - 1));
}
#ifdef _MSC_VER
#pragma warning(pop)
#endif
/// Gets the maximum value for a N-bit signed integer.
inline int64_t maxIntN(int64_t N) {
assert(N > 0 && N <= 64 && "integer width out of range");
// This relies on two's complement wraparound when N == 64, so we convert to
// int64_t only at the very end to avoid UB.
// NOLINTNEXTLINE(*-narrowing-conversions)
return (UINT64_C(1) << (N - 1)) - 1;
}
/// Checks if an unsigned integer fits into the given (dynamic) bit width.
inline bool isUIntN(unsigned N, uint64_t x) {
return N >= 64 || x <= maxUIntN(N);
}
/// Checks if a signed integer fits into the given (dynamic) bit width.
inline bool isIntN(unsigned N, int64_t x) {
return N >= 64 || (minIntN(N) <= x && x <= maxIntN(N));
}
/// Return true if the argument is a non-empty sequence of ones starting at the
/// least significant bit with the remainder zero (32 bit version).
/// Ex. isMask_32(0x0000FFFFU) == true.
constexpr inline bool isMask_32(uint32_t Value) {
return Value && ((Value + 1) & Value) == 0;
}
/// Return true if the argument is a non-empty sequence of ones starting at the
/// least significant bit with the remainder zero (64 bit version).
constexpr inline bool isMask_64(uint64_t Value) {
return Value && ((Value + 1) & Value) == 0;
}
/// Return true if the argument contains a non-empty sequence of ones with the
/// remainder zero (32 bit version.) Ex. isShiftedMask_32(0x0000FF00U) == true.
constexpr inline bool isShiftedMask_32(uint32_t Value) {
return Value && isMask_32((Value - 1) | Value);
}
/// Return true if the argument contains a non-empty sequence of ones with the
/// remainder zero (64 bit version.)
constexpr inline bool isShiftedMask_64(uint64_t Value) {
return Value && isMask_64((Value - 1) | Value);
}
/// Return true if the argument is a power of two > 0.
/// Ex. isPowerOf2_32(0x00100000U) == true (32 bit edition.)
constexpr inline bool isPowerOf2_32(uint32_t Value) {
return Value && !(Value & (Value - 1));
}
/// Return true if the argument is a power of two > 0 (64 bit edition.)
constexpr inline bool isPowerOf2_64(uint64_t Value) {
return Value && !(Value & (Value - 1));
}
/// Count the number of ones from the most significant bit to the first
/// zero bit.
///
/// Ex. countLeadingOnes(0xFF0FFF00) == 8.
/// Only unsigned integral types are allowed.
///
/// \param ZB the behavior on an input of all ones. Only ZB_Width and
/// ZB_Undefined are valid arguments.
template <typename T>
std::size_t countLeadingOnes(T Value, ZeroBehavior ZB = ZB_Width) {
static_assert(
std::numeric_limits<T>::is_integer && !std::numeric_limits<T>::is_signed,
"Only unsigned integral types are allowed.");
return countLeadingZeros<T>(~Value, ZB);
}
/// Count the number of ones from the least significant bit to the first
/// zero bit.
///
/// Ex. countTrailingOnes(0x00FF00FF) == 8.
/// Only unsigned integral types are allowed.
///
/// \param ZB the behavior on an input of all ones. Only ZB_Width and
/// ZB_Undefined are valid arguments.
template <typename T>
std::size_t countTrailingOnes(T Value, ZeroBehavior ZB = ZB_Width) {
static_assert(
std::numeric_limits<T>::is_integer && !std::numeric_limits<T>::is_signed,
"Only unsigned integral types are allowed.");
return countTrailingZeros<T>(~Value, ZB);
}
namespace detail {
template <typename T, std::size_t SizeOfT>
struct PopulationCounter {
static unsigned count(T Value) {
// Generic version, forward to 32 bits.
static_assert(SizeOfT <= 4, "Not implemented!");
#if defined(__GNUC__) && __GNUC__ >= 4
return __builtin_popcount(Value);
#else
uint32_t v = Value;
v = v - ((v >> 1) & 0x55555555);
v = (v & 0x33333333) + ((v >> 2) & 0x33333333);
return ((v + (v >> 4) & 0xF0F0F0F) * 0x1010101) >> 24;
#endif
}
};
template <typename T>
struct PopulationCounter<T, 8> {
static unsigned count(T Value) {
#if defined(__GNUC__) && __GNUC__ >= 4
return __builtin_popcountll(Value);
#else
uint64_t v = Value;
v = v - ((v >> 1) & 0x5555555555555555ULL);
v = (v & 0x3333333333333333ULL) + ((v >> 2) & 0x3333333333333333ULL);
v = (v + (v >> 4)) & 0x0F0F0F0F0F0F0F0FULL;
return unsigned((uint64_t)(v * 0x0101010101010101ULL) >> 56);
#endif
}
};
} // namespace detail
/// Count the number of set bits in a value.
/// Ex. countPopulation(0xF000F000) = 8
/// Returns 0 if the word is zero.
template <typename T>
inline unsigned countPopulation(T Value) {
static_assert(
std::numeric_limits<T>::is_integer && !std::numeric_limits<T>::is_signed,
"Only unsigned integral types are allowed.");
return detail::PopulationCounter<T, sizeof(T)>::count(Value);
}
/// Return the log base 2 of the specified value.
inline double Log2(double Value) {
#if defined(__ANDROID_API__) && __ANDROID_API__ < 18
return __builtin_log(Value) / __builtin_log(2.0);
#else
return log2(Value);
#endif
}
/// Return the floor log base 2 of the specified value, -1 if the value is zero.
/// (32 bit edition.)
/// Ex. Log2_32(32) == 5, Log2_32(1) == 0, Log2_32(0) == -1, Log2_32(6) == 2
inline unsigned Log2_32(uint32_t Value) {
return static_cast<unsigned>(31 - countLeadingZeros(Value));
}
/// Return the floor log base 2 of the specified value, -1 if the value is zero.
/// (64 bit edition.)
inline unsigned Log2_64(uint64_t Value) {
return static_cast<unsigned>(63 - countLeadingZeros(Value));
}
/// Return the ceil log base 2 of the specified value, 32 if the value is zero.
/// (32 bit edition).
/// Ex. Log2_32_Ceil(32) == 5, Log2_32_Ceil(1) == 0, Log2_32_Ceil(6) == 3
inline unsigned Log2_32_Ceil(uint32_t Value) {
return static_cast<unsigned>(32 - countLeadingZeros(Value - 1));
}
/// Return the ceil log base 2 of the specified value, 64 if the value is zero.
/// (64 bit edition.)
inline unsigned Log2_64_Ceil(uint64_t Value) {
return static_cast<unsigned>(64 - countLeadingZeros(Value - 1));
}
/// Return the greatest common divisor of the values using Euclid's algorithm.
inline uint64_t GreatestCommonDivisor64(uint64_t A, uint64_t B) {
while (B) {
uint64_t T = B;
B = A % B;
A = T;
}
return A;
}
/// This function takes a 64-bit integer and returns the bit equivalent double.
inline double BitsToDouble(uint64_t Bits) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
double D;
static_assert(sizeof(uint64_t) == sizeof(double), "Unexpected type sizes");
memcpy(&D, &Bits, sizeof(Bits));
return D;
}
/// This function takes a 32-bit integer and returns the bit equivalent float.
inline float BitsToFloat(uint32_t Bits) {
// TODO: Use std::bit_cast once C++20 becomes available.
return c10::bit_cast<float>(Bits);
}
/// This function takes a double and returns the bit equivalent 64-bit integer.
/// Note that copying doubles around changes the bits of NaNs on some hosts,
/// notably x86, so this routine cannot be used if these bits are needed.
inline uint64_t DoubleToBits(double Double) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
uint64_t Bits;
static_assert(sizeof(uint64_t) == sizeof(double), "Unexpected type sizes");
memcpy(&Bits, &Double, sizeof(Double));
return Bits;
}
/// This function takes a float and returns the bit equivalent 32-bit integer.
/// Note that copying floats around changes the bits of NaNs on some hosts,
/// notably x86, so this routine cannot be used if these bits are needed.
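/// Ex. FloatToBits(1.0f) == 0x3F800000; BitsToFloat(FloatToBits(x)) == x
/// round-trips exactly for non-NaN x.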
inline uint32_t FloatToBits(float Float) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
uint32_t Bits;
static_assert(sizeof(uint32_t) == sizeof(float), "Unexpected type sizes");
memcpy(&Bits, &Float, sizeof(Float));
return Bits;
}
/// A and B are either alignments or offsets. Return the minimum alignment that
/// may be assumed after adding the two together.
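/// Ex. MinAlign(8, 4) == 4, MinAlign(8, 12) == 4, MinAlign(16, 0) == 16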
constexpr inline uint64_t MinAlign(uint64_t A, uint64_t B) {
// The largest power of 2 that divides both A and B.
//
// Replace "-Value" by "1+~Value" in the following commented code to avoid
// MSVC warning C4146
// return (A | B) & -(A | B);
return (A | B) & (1 + ~(A | B));
}
/// Aligns \c Addr to \c Alignment bytes, rounding up.
///
/// Alignment should be a power of two. This method rounds up, so
/// alignAddr(7, 4) == 8 and alignAddr(8, 4) == 8.
inline uintptr_t alignAddr(const void* Addr, size_t Alignment) {
assert(
Alignment && isPowerOf2_64((uint64_t)Alignment) &&
"Alignment is not a power of two!");
assert((uintptr_t)Addr + Alignment - 1 >= (uintptr_t)Addr);
return (((uintptr_t)Addr + Alignment - 1) & ~(uintptr_t)(Alignment - 1));
}
/// Returns the necessary adjustment for aligning \c Ptr to \c Alignment
/// bytes, rounding up.
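/// Ex. alignmentAdjustment((void*)7, 4) == 1, alignmentAdjustment((void*)8, 4) == 0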
inline size_t alignmentAdjustment(const void* Ptr, size_t Alignment) {
return alignAddr(Ptr, Alignment) - (uintptr_t)Ptr;
}
/// Returns the next power of two (in 64-bits) that is strictly greater than A.
/// Returns zero on overflow.
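/// Ex. NextPowerOf2(3) == 4, NextPowerOf2(4) == 8, NextPowerOf2(0) == 1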
inline uint64_t NextPowerOf2(uint64_t A) {
A |= (A >> 1);
A |= (A >> 2);
A |= (A >> 4);
A |= (A >> 8);
A |= (A >> 16);
A |= (A >> 32);
return A + 1;
}
/// Returns the power of two which is less than or equal to the given value.
/// Essentially, it is a floor operation across the domain of powers of two.
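/// Ex. PowerOf2Floor(5) == 4, PowerOf2Floor(8) == 8, PowerOf2Floor(0) == 0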
inline uint64_t PowerOf2Floor(uint64_t A) {
if (!A)
return 0;
return 1ull << (63 - countLeadingZeros(A, ZB_Undefined));
}
/// Returns the power of two which is greater than or equal to the given value.
/// Essentially, it is a ceil operation across the domain of powers of two.
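/// Ex. PowerOf2Ceil(5) == 8, PowerOf2Ceil(8) == 8, PowerOf2Ceil(0) == 0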
inline uint64_t PowerOf2Ceil(uint64_t A) {
if (!A)
return 0;
return NextPowerOf2(A - 1);
}
/// Returns the next integer (mod 2**64) that is greater than or equal to
/// \p Value and is a multiple of \p Align. \p Align must be non-zero.
///
/// If non-zero \p Skew is specified, the return value will be a minimal
/// integer that is greater than or equal to \p Value and equal to
/// \p Align * N + \p Skew for some integer N. If \p Skew is larger than
/// \p Align, its value is adjusted to '\p Skew mod \p Align'.
///
/// Examples:
/// \code
/// alignTo(5, 8) = 8
/// alignTo(17, 8) = 24
/// alignTo(~0LL, 8) = 0
/// alignTo(321, 255) = 510
///
/// alignTo(5, 8, 7) = 7
/// alignTo(17, 8, 1) = 17
/// alignTo(~0LL, 8, 3) = 3
/// alignTo(321, 255, 42) = 552
/// \endcode
inline uint64_t alignTo(uint64_t Value, uint64_t Align, uint64_t Skew = 0) {
assert(Align != 0u && "Align can't be 0.");
Skew %= Align;
return (Value + Align - 1 - Skew) / Align * Align + Skew;
}
/// Returns the next integer (mod 2**64) that is greater than or equal to
/// \p Value and is a multiple of \c Align. \c Align must be non-zero.
template <uint64_t Align>
constexpr inline uint64_t alignTo(uint64_t Value) {
static_assert(Align != 0u, "Align must be non-zero");
return (Value + Align - 1) / Align * Align;
}
/// Returns the integer ceil(Numerator / Denominator).
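/// Ex. divideCeil(17, 8) == 3, divideCeil(16, 8) == 2, divideCeil(0, 8) == 0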
inline uint64_t divideCeil(uint64_t Numerator, uint64_t Denominator) {
return alignTo(Numerator, Denominator) / Denominator;
}
/// \c alignTo for contexts where a constant expression is required.
/// \sa alignTo
///
/// \todo FIXME: remove when \c constexpr becomes really \c constexpr
template <uint64_t Align>
struct AlignTo {
static_assert(Align != 0u, "Align must be non-zero");
template <uint64_t Value>
struct from_value {
static const uint64_t value = (Value + Align - 1) / Align * Align;
};
};
/// Returns the largest uint64_t less than or equal to \p Value that is
/// congruent to \p Skew modulo \p Align. \p Align must be non-zero.
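/// Ex. alignDown(17, 8) == 16, alignDown(17, 8, 1) == 17, alignDown(16, 8, 3) == 11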
inline uint64_t alignDown(uint64_t Value, uint64_t Align, uint64_t Skew = 0) {
assert(Align != 0u && "Align can't be 0.");
Skew %= Align;
return (Value - Skew) / Align * Align + Skew;
}
/// Returns the offset to the next integer (mod 2**64) that is greater than
/// or equal to \p Value and is a multiple of \p Align. \p Align must be
/// non-zero.
inline uint64_t OffsetToAlignment(uint64_t Value, uint64_t Align) {
return alignTo(Value, Align) - Value;
}
/// Sign-extend the number in the bottom B bits of X to a 32-bit integer.
/// Requires 0 < B <= 32.
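/// Ex. SignExtend32<4>(0xA) == -6, SignExtend32<4>(0x5) == 5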
template <unsigned B>
constexpr inline int32_t SignExtend32(uint32_t X) {
static_assert(B > 0, "Bit width can't be 0.");
static_assert(B <= 32, "Bit width out of range.");
return int32_t(X << (32 - B)) >> (32 - B);
}
/// Sign-extend the number in the bottom B bits of X to a 32-bit integer.
/// Requires 0 < B <= 32.
inline int32_t SignExtend32(uint32_t X, unsigned B) {
assert(B > 0 && "Bit width can't be 0.");
assert(B <= 32 && "Bit width out of range.");
return int32_t(X << (32 - B)) >> (32 - B);
}
/// Sign-extend the number in the bottom B bits of X to a 64-bit integer.
/// Requires 0 < B <= 64.
template <unsigned B>
constexpr inline int64_t SignExtend64(uint64_t x) {
static_assert(B > 0, "Bit width can't be 0.");
static_assert(B <= 64, "Bit width out of range.");
return int64_t(x << (64 - B)) >> (64 - B);
}
/// Sign-extend the number in the bottom B bits of X to a 64-bit integer.
/// Requires 0 < B <= 64.
inline int64_t SignExtend64(uint64_t X, unsigned B) {
assert(B > 0 && "Bit width can't be 0.");
assert(B <= 64 && "Bit width out of range.");
return int64_t(X << (64 - B)) >> (64 - B);
}
/// Subtract two unsigned integers, X and Y, of type T and return the absolute
/// value of the result.
template <typename T>
std::enable_if_t<std::is_unsigned_v<T>, T> AbsoluteDifference(T X, T Y) {
return std::max(X, Y) - std::min(X, Y);
}
/// Add two unsigned integers, X and Y, of type T. Clamp the result to the
/// maximum representable value of T on overflow. ResultOverflowed indicates if
/// the result is larger than the maximum representable value of type T.
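/// Ex. SaturatingAdd<uint8_t>(200, 100) == 255 (overflowed);
/// SaturatingAdd<uint8_t>(100, 100) == 200.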
template <typename T>
std::enable_if_t<std::is_unsigned_v<T>, T> SaturatingAdd(
T X,
T Y,
bool* ResultOverflowed = nullptr) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
bool Dummy;
bool& Overflowed = ResultOverflowed ? *ResultOverflowed : Dummy;
// Hacker's Delight, p. 29
T Z = X + Y;
Overflowed = (Z < X || Z < Y);
if (Overflowed)
return std::numeric_limits<T>::max();
else
return Z;
}
/// Multiply two unsigned integers, X and Y, of type T. Clamp the result to the
/// maximum representable value of T on overflow. ResultOverflowed indicates if
/// the result is larger than the maximum representable value of type T.
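/// Ex. SaturatingMultiply<uint8_t>(16, 16) == 255 (overflowed), while
/// SaturatingMultiply<uint8_t>(15, 17) == 255 exactly, with no overflow.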
template <typename T>
std::enable_if_t<std::is_unsigned_v<T>, T> SaturatingMultiply(
T X,
T Y,
bool* ResultOverflowed = nullptr) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
bool Dummy;
bool& Overflowed = ResultOverflowed ? *ResultOverflowed : Dummy;
// Hacker's Delight, p. 30 has a different algorithm, but we don't use that
// because it fails for uint16_t (where multiplication can have undefined
// behavior due to promotion to int), and requires a division in addition
// to the multiplication.
Overflowed = false;
// Log2(Z) would be either Log2Z or Log2Z + 1.
// Special case: if X or Y is 0, Log2_64 gives -1, and Log2Z
// will necessarily be less than Log2Max as desired.
int Log2Z = Log2_64(X) + Log2_64(Y);
const T Max = std::numeric_limits<T>::max();
int Log2Max = Log2_64(Max);
if (Log2Z < Log2Max) {
return X * Y;
}
if (Log2Z > Log2Max) {
Overflowed = true;
return Max;
}
// We're going to use the top bit, and maybe overflow one
// bit past it. Multiply all but the bottom bit then add
// that on at the end.
T Z = (X >> 1) * Y;
if (Z & ~(Max >> 1)) {
Overflowed = true;
return Max;
}
Z <<= 1;
if (X & 1)
return SaturatingAdd(Z, Y, ResultOverflowed);
return Z;
}
/// Multiply two unsigned integers, X and Y, and add the unsigned integer, A to
/// the product. Clamp the result to the maximum representable value of T on
/// overflow. ResultOverflowed indicates if the result is larger than the
/// maximum representable value of type T.
template <typename T>
std::enable_if_t<std::is_unsigned_v<T>, T> SaturatingMultiplyAdd(
T X,
T Y,
T A,
bool* ResultOverflowed = nullptr) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
bool Dummy;
bool& Overflowed = ResultOverflowed ? *ResultOverflowed : Dummy;
T Product = SaturatingMultiply(X, Y, &Overflowed);
if (Overflowed)
return Product;
return SaturatingAdd(A, Product, &Overflowed);
}
/// Use this rather than HUGE_VALF; the latter causes warnings on MSVC.
extern const float huge_valf;
} // namespace c10::llvm

View File

@ -0,0 +1,109 @@
#ifndef C10_UTIL_LOGGING_IS_GOOGLE_GLOG_H_
#define C10_UTIL_LOGGING_IS_GOOGLE_GLOG_H_
#include <map>
#include <set>
#include <vector>
#include <iomanip> // because some of the caffe2 code uses e.g. std::setw
// Using google glog. For glog 0.3.2 versions, stl_logging.h needs to be
// included before logging.h for stl_logging to actually take effect, because
// of the order in which the stream operator templates must be declared.
// In addition, we do not do stl logging in .cu files because nvcc does not
// like it. Some mobile platforms also do not like stl_logging, so we add
// fake overloads in that case as well.
#ifdef __CUDACC__
#include <cuda.h>
#endif
#if !defined(__CUDACC__) && !defined(C10_USE_MINIMAL_GLOG)
#include <glog/stl_logging.h>
// Old versions of glog don't declare this using declaration, so help
// them out. Fortunately, C++ won't complain if you declare the same
// using declaration multiple times.
namespace std {
using ::operator<<;
}
#else // !defined(__CUDACC__) && !defined(C10_USE_MINIMAL_GLOG)
// In the cudacc compiler scenario, we will simply ignore the container
// printout feature. Basically we need to register a fake overload for
// vector/string - here, we just ignore the entries in the logs.
namespace std {
#define INSTANTIATE_FOR_CONTAINER(container) \
template <class... Types> \
ostream& operator<<(ostream& out, const container<Types...>&) { \
return out; \
}
INSTANTIATE_FOR_CONTAINER(vector)
INSTANTIATE_FOR_CONTAINER(map)
INSTANTIATE_FOR_CONTAINER(set)
#undef INSTANTIATE_FOR_CONTAINER
} // namespace std
#endif
#include <glog/logging.h>
// Additional macros on top of glog
#define TORCH_CHECK_EQ(val1, val2) CHECK_EQ(val1, val2)
#define TORCH_CHECK_NE(val1, val2) CHECK_NE(val1, val2)
#define TORCH_CHECK_LE(val1, val2) CHECK_LE(val1, val2)
#define TORCH_CHECK_LT(val1, val2) CHECK_LT(val1, val2)
#define TORCH_CHECK_GE(val1, val2) CHECK_GE(val1, val2)
#define TORCH_CHECK_GT(val1, val2) CHECK_GT(val1, val2)
#ifndef NDEBUG
#define TORCH_DCHECK_EQ(val1, val2) DCHECK_EQ(val1, val2)
#define TORCH_DCHECK_NE(val1, val2) DCHECK_NE(val1, val2)
#define TORCH_DCHECK_LE(val1, val2) DCHECK_LE(val1, val2)
#define TORCH_DCHECK_LT(val1, val2) DCHECK_LT(val1, val2)
#define TORCH_DCHECK_GE(val1, val2) DCHECK_GE(val1, val2)
#define TORCH_DCHECK_GT(val1, val2) DCHECK_GT(val1, val2)
#else // !NDEBUG
// These versions generate no code in optimized mode.
#define TORCH_DCHECK_EQ(val1, val2) \
while (false) \
DCHECK_EQ(val1, val2)
#define TORCH_DCHECK_NE(val1, val2) \
while (false) \
DCHECK_NE(val1, val2)
#define TORCH_DCHECK_LE(val1, val2) \
while (false) \
DCHECK_LE(val1, val2)
#define TORCH_DCHECK_LT(val1, val2) \
while (false) \
DCHECK_LT(val1, val2)
#define TORCH_DCHECK_GE(val1, val2) \
while (false) \
DCHECK_GE(val1, val2)
#define TORCH_DCHECK_GT(val1, val2) \
while (false) \
DCHECK_GT(val1, val2)
#endif // NDEBUG
// Check that a pointer is not null.
#define TORCH_CHECK_NOTNULL(val) CHECK_NOTNULL(val)
#ifndef NDEBUG
// Debug only version of TORCH_CHECK_NOTNULL
#define TORCH_DCHECK_NOTNULL(val) DCHECK_NOTNULL(val)
#else // !NDEBUG
// Optimized version - generates no code.
#define TORCH_DCHECK_NOTNULL(val) \
while (false) \
DCHECK_NOTNULL(val)
#endif // NDEBUG
// Log with source location information override (to be used in generic
// warning/error handlers implemented as functions, not macros)
//
// Note, we don't respect GOOGLE_STRIP_LOG here for simplicity
#define LOG_AT_FILE_LINE(n, file, line) \
::google::LogMessage(file, line, ::google::GLOG_##n).stream()
#endif // C10_UTIL_LOGGING_IS_GOOGLE_GLOG_H_

View File

@ -0,0 +1,258 @@
#ifndef C10_UTIL_LOGGING_IS_NOT_GOOGLE_GLOG_H_
#define C10_UTIL_LOGGING_IS_NOT_GOOGLE_GLOG_H_
#include <chrono>
#include <climits>
#include <ctime>
#include <iomanip>
#include <map>
#include <ostream>
#include <set>
#include <sstream>
#include <string>
#include <vector>
#include <c10/util/Flags.h>
const char CAFFE2_SEVERITY_PREFIX[] = "FEWIV";
namespace c10 {
// Log severity level constants.
const int GLOG_FATAL = 3;
const int GLOG_ERROR = 2;
const int GLOG_WARNING = 1;
const int GLOG_INFO = 0;
class C10_API MessageLogger {
public:
MessageLogger(const char* file, int line, int severity);
~MessageLogger();
// Return the stream associated with the logger object.
std::stringstream& stream() {
return stream_;
}
private:
// When there is a fatal log, we simply abort.
void DealWithFatal() {
abort();
}
const char* tag_;
std::stringstream stream_;
int severity_;
};
// This class is used to explicitly ignore values in the conditional
// logging macros. This avoids compiler warnings like "value computed
// is not used" and "statement has no effect".
class C10_API LoggerVoidify {
public:
LoggerVoidify() = default;
// This has to be an operator with a precedence lower than << but
// higher than ?:
void operator&(const std::ostream& s [[maybe_unused]]) {}
};
// Log a message and terminate.
template <class T>
void LogMessageFatal(const char* file, int line, const T& message) {
MessageLogger(file, line, GLOG_FATAL).stream() << message;
}
// Helpers for TORCH_CHECK_NOTNULL(). Two are necessary to support both raw
// pointers and smart pointers.
template <typename T>
T& CheckNotNullCommon(const char* file, int line, const char* names, T& t) {
if (t == nullptr) {
LogMessageFatal(file, line, std::string(names));
}
return t;
}
template <typename T>
T* CheckNotNull(const char* file, int line, const char* names, T* t) {
return CheckNotNullCommon(file, line, names, t);
}
template <typename T>
T& CheckNotNull(const char* file, int line, const char* names, T& t) {
return CheckNotNullCommon(file, line, names, t);
}
} // namespace c10
// ---------------------- Logging Macro definitions --------------------------
static_assert(
CAFFE2_LOG_THRESHOLD <= ::c10::GLOG_FATAL,
"CAFFE2_LOG_THRESHOLD should at most be GLOG_FATAL.");
// If n is under the compile-time caffe log threshold, LOG(n) should not
// generate anything in optimized code.
#define LOG(n) \
if (::c10::GLOG_##n >= CAFFE2_LOG_THRESHOLD) \
::c10::MessageLogger(__FILE__, __LINE__, ::c10::GLOG_##n).stream()
#define VLOG(n) \
if (-n >= CAFFE2_LOG_THRESHOLD) \
::c10::MessageLogger(__FILE__, __LINE__, -n).stream()
#define LOG_IF(n, condition) \
if (::c10::GLOG_##n >= CAFFE2_LOG_THRESHOLD && (condition)) \
::c10::MessageLogger(__FILE__, __LINE__, ::c10::GLOG_##n).stream()
#define VLOG_IF(n, condition) \
if (-n >= CAFFE2_LOG_THRESHOLD && (condition)) \
::c10::MessageLogger(__FILE__, __LINE__, -n).stream()
#define VLOG_IS_ON(verboselevel) (CAFFE2_LOG_THRESHOLD <= -(verboselevel))
// Log with source location information override (to be used in generic
// warning/error handlers implemented as functions, not macros)
#define LOG_AT_FILE_LINE(n, file, line) \
if (::c10::GLOG_##n >= CAFFE2_LOG_THRESHOLD) \
::c10::MessageLogger(file, line, ::c10::GLOG_##n).stream()
// Log a fatal message only if the condition is not met; otherwise this
// evaluates to void.
#define FATAL_IF(condition) \
condition ? (void)0 \
: ::c10::LoggerVoidify() & \
::c10::MessageLogger(__FILE__, __LINE__, ::c10::GLOG_FATAL).stream()
// Check for a given boolean condition.
#define CHECK(condition) FATAL_IF(condition) << "Check failed: " #condition " "
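// For example, CHECK(x > 0) << "bad x: " << x; expands along the lines of
//
//   x > 0 ? (void)0
//         : LoggerVoidify() &
//             (MessageLogger(__FILE__, __LINE__, GLOG_FATAL).stream()
//              << "Check failed: x > 0 " << "bad x: " << x);
//
// operator<< binds tighter than operator&, which binds tighter than ?:, so
// the whole message chain is built first and then discarded into void,
// giving both branches of the conditional the same (void) type.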
#ifndef NDEBUG
// Debug only version of CHECK
#define DCHECK(condition) FATAL_IF(condition) << "Check failed: " #condition " "
#define DLOG(severity) LOG(severity)
#else // NDEBUG
// Optimized version - generates no code.
#define DCHECK(condition) \
while (false) \
CHECK(condition)
#define DLOG(n) \
true ? (void)0 \
: ::c10::LoggerVoidify() & \
::c10::MessageLogger(__FILE__, __LINE__, ::c10::GLOG_##n).stream()
#endif // NDEBUG
#define TORCH_CHECK_OP(val1, val2, op) \
FATAL_IF(((val1)op(val2))) << "Check failed: " #val1 " " #op " " #val2 " (" \
<< (val1) << " vs. " << (val2) << ") "
// TORCH_CHECK_OP macro definitions
#define TORCH_CHECK_EQ(val1, val2) TORCH_CHECK_OP(val1, val2, ==)
#define TORCH_CHECK_NE(val1, val2) TORCH_CHECK_OP(val1, val2, !=)
#define TORCH_CHECK_LE(val1, val2) TORCH_CHECK_OP(val1, val2, <=)
#define TORCH_CHECK_LT(val1, val2) TORCH_CHECK_OP(val1, val2, <)
#define TORCH_CHECK_GE(val1, val2) TORCH_CHECK_OP(val1, val2, >=)
#define TORCH_CHECK_GT(val1, val2) TORCH_CHECK_OP(val1, val2, >)
#ifndef NDEBUG
// Debug only versions of TORCH_CHECK_OP macros.
#define TORCH_DCHECK_EQ(val1, val2) TORCH_CHECK_OP(val1, val2, ==)
#define TORCH_DCHECK_NE(val1, val2) TORCH_CHECK_OP(val1, val2, !=)
#define TORCH_DCHECK_LE(val1, val2) TORCH_CHECK_OP(val1, val2, <=)
#define TORCH_DCHECK_LT(val1, val2) TORCH_CHECK_OP(val1, val2, <)
#define TORCH_DCHECK_GE(val1, val2) TORCH_CHECK_OP(val1, val2, >=)
#define TORCH_DCHECK_GT(val1, val2) TORCH_CHECK_OP(val1, val2, >)
#else // !NDEBUG
// These versions generate no code in optimized mode.
#define TORCH_DCHECK_EQ(val1, val2) \
while (false) \
TORCH_CHECK_OP(val1, val2, ==)
#define TORCH_DCHECK_NE(val1, val2) \
while (false) \
TORCH_CHECK_OP(val1, val2, !=)
#define TORCH_DCHECK_LE(val1, val2) \
while (false) \
TORCH_CHECK_OP(val1, val2, <=)
#define TORCH_DCHECK_LT(val1, val2) \
while (false) \
TORCH_CHECK_OP(val1, val2, <)
#define TORCH_DCHECK_GE(val1, val2) \
while (false) \
TORCH_CHECK_OP(val1, val2, >=)
#define TORCH_DCHECK_GT(val1, val2) \
while (false) \
TORCH_CHECK_OP(val1, val2, >)
#endif // NDEBUG
// Check that a pointer is not null.
#define TORCH_CHECK_NOTNULL(val) \
::c10::CheckNotNull( \
__FILE__, __LINE__, "Check failed: '" #val "' Must be non NULL", (val))
#ifndef NDEBUG
// Debug only version of TORCH_CHECK_NOTNULL
#define TORCH_DCHECK_NOTNULL(val) \
::c10::CheckNotNull( \
__FILE__, __LINE__, "Check failed: '" #val "' Must be non NULL", (val))
#else // !NDEBUG
// Optimized version - generates no code.
#define TORCH_DCHECK_NOTNULL(val) \
while (false) \
TORCH_CHECK_NOTNULL(val)
#endif // NDEBUG
// ---------------------- Support for std objects --------------------------
// These are adapted from glog to support a limited set of logging
// capabilities for STL objects.
namespace std {
// Forward declare these two, and define them after all the container streams
// operators so that we can recurse from pair -> container -> container -> pair
// properly.
template <class First, class Second>
std::ostream& operator<<(std::ostream& out, const std::pair<First, Second>& p);
} // namespace std
namespace c10 {
template <class Iter>
void PrintSequence(std::ostream& ss, Iter begin, Iter end);
} // namespace c10
namespace std {
#define INSTANTIATE_FOR_CONTAINER(container) \
template <class... Types> \
std::ostream& operator<<( \
std::ostream& out, const container<Types...>& seq) { \
c10::PrintSequence(out, seq.begin(), seq.end()); \
return out; \
}
INSTANTIATE_FOR_CONTAINER(std::vector)
INSTANTIATE_FOR_CONTAINER(std::map)
INSTANTIATE_FOR_CONTAINER(std::set)
#undef INSTANTIATE_FOR_CONTAINER
template <class First, class Second>
inline std::ostream& operator<<(
std::ostream& out,
const std::pair<First, Second>& p) {
out << '(' << p.first << ", " << p.second << ')';
return out;
}
inline std::ostream& operator<<(std::ostream& out, const std::nullptr_t&) {
out << "(null)";
return out;
}
} // namespace std
namespace c10 {
template <class Iter>
inline void PrintSequence(std::ostream& out, Iter begin, Iter end) {
// Output at most 100 elements -- appropriate if used for logging.
for (int i = 0; begin != end && i < 100; ++i, ++begin) {
if (i > 0)
out << ' ';
out << *begin;
}
if (begin != end) {
out << " ...";
}
}
} // namespace c10
#endif // C10_UTIL_LOGGING_IS_NOT_GOOGLE_GLOG_H_

View File

@ -0,0 +1,41 @@
#pragma once
#include <c10/macros/Export.h>
#include <c10/util/Flags.h>
#include <cstddef>
C10_DECLARE_bool(caffe2_cpu_numa_enabled);
namespace c10 {
/**
* Check whether NUMA is enabled
*/
C10_API bool IsNUMAEnabled();
/**
* Bind to a given NUMA node
*/
C10_API void NUMABind(int numa_node_id);
/**
* Get the NUMA id for a given pointer `ptr`
*/
C10_API int GetNUMANode(const void* ptr);
/**
* Get number of NUMA nodes
*/
C10_API int GetNumNUMANodes();
/**
* Move the memory pointed to by `ptr` of a given size to another NUMA node
*/
C10_API void NUMAMove(void* ptr, size_t size, int numa_node_id);
/**
* Get the current NUMA node id
*/
C10_API int GetCurrentNUMANode();
} // namespace c10

File diff suppressed because it is too large

View File

@ -0,0 +1,31 @@
#pragma once
#include <memory>
namespace c10 {
namespace detail {
template <class... Ts>
struct overloaded_t {};
template <class T0>
struct overloaded_t<T0> : T0 {
using T0::operator();
overloaded_t(T0 t0) : T0(std::move(t0)) {}
};
template <class T0, class... Ts>
struct overloaded_t<T0, Ts...> : T0, overloaded_t<Ts...> {
using T0::operator();
using overloaded_t<Ts...>::operator();
overloaded_t(T0 t0, Ts... ts)
: T0(std::move(t0)), overloaded_t<Ts...>(std::move(ts)...) {}
};
} // namespace detail
// Construct an overloaded callable combining multiple callables, e.g. lambdas
template <class... Ts>
detail::overloaded_t<Ts...> overloaded(Ts... ts) {
return {std::move(ts)...};
}
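//
// A minimal usage sketch (assuming the caller includes <variant> and
// <iostream>):
//
//   std::variant<int, std::string> v = 42;
//   std::visit(
//       c10::overloaded(
//           [](int i) { std::cout << "int: " << i << '\n'; },
//           [](const std::string& s) { std::cout << "string: " << s << '\n'; }),
//       v);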
} // namespace c10

View File

@ -0,0 +1,4 @@
#pragma once
struct _object;
using PyObject = _object;

View File

@ -0,0 +1,18 @@
#pragma once
#include <cstdint>
#include <c10/macros/Macros.h>
namespace c10 {
/**
* qint32 is for signed 32 bit quantized Tensors
*/
struct alignas(4) qint32 {
using underlying = int32_t;
int32_t val_;
qint32() = default;
C10_HOST_DEVICE explicit qint32(int32_t val) : val_(val) {}
};
} // namespace c10

View File

@ -0,0 +1,20 @@
#pragma once
#include <cstdint>
#include <c10/macros/Macros.h>
namespace c10 {
/**
* This is the data type for quantized Tensors. Right now we only have
* qint8, which is for 8 bit Tensors, and qint32, which is for 32 bit int
* Tensors; we might have 4 bit, 2 bit or 1 bit data types in the future.
*/
struct alignas(1) qint8 {
using underlying = int8_t;
int8_t val_;
qint8() = default;
C10_HOST_DEVICE explicit qint8(int8_t val) : val_(val) {}
};
} // namespace c10

View File

@ -0,0 +1,19 @@
#pragma once
#include <cstdint>
#include <c10/macros/Macros.h>
namespace c10 {
/**
* quint2x4 is for unsigned 2 bit quantized Tensors that are packed to byte
* boundary.
*/
struct alignas(1) quint2x4 {
using underlying = uint8_t;
uint8_t val_;
quint2x4() = default;
C10_HOST_DEVICE explicit quint2x4(uint8_t val) : val_(val) {}
};
} // namespace c10

View File

@ -0,0 +1,19 @@
#pragma once
#include <cstdint>
#include <c10/macros/Macros.h>
namespace c10 {
/**
* quint4x2 is for unsigned 4 bit quantized Tensors that are packed to byte
* boundary.
*/
struct alignas(1) quint4x2 {
using underlying = uint8_t;
uint8_t val_;
quint4x2() = default;
C10_HOST_DEVICE explicit quint4x2(uint8_t val) : val_(val) {}
};
} // namespace c10

View File

@ -0,0 +1,18 @@
#pragma once
#include <cstdint>
#include <c10/macros/Macros.h>
namespace c10 {
/**
* quint8 is for unsigned 8 bit quantized Tensors
*/
struct alignas(1) quint8 {
using underlying = uint8_t;
uint8_t val_;
quint8() = default;
C10_HOST_DEVICE explicit quint8(uint8_t val) : val_(val) {}
};
} // namespace c10

View File

@ -0,0 +1,90 @@
#pragma once
#include <c10/macros/Macros.h>
#include <cstdint>
// GCC has __builtin_mul_overflow from before it supported __has_builtin
#ifdef _MSC_VER
#define C10_HAS_BUILTIN_OVERFLOW() (0)
#include <c10/util/llvmMathExtras.h>
#include <intrin.h>
#else
#define C10_HAS_BUILTIN_OVERFLOW() (1)
#endif
namespace c10 {
C10_ALWAYS_INLINE bool add_overflows(uint64_t a, uint64_t b, uint64_t* out) {
#if C10_HAS_BUILTIN_OVERFLOW()
return __builtin_add_overflow(a, b, out);
#else
unsigned long long tmp;
#if defined(_M_IX86) || defined(_M_X64)
auto carry = _addcarry_u64(0, a, b, &tmp);
#else
tmp = a + b;
unsigned long long vector = (a & b) ^ ((a ^ b) & ~tmp);
auto carry = vector >> 63;
#endif
*out = tmp;
return carry;
#endif
}
C10_ALWAYS_INLINE bool mul_overflows(uint64_t a, uint64_t b, uint64_t* out) {
#if C10_HAS_BUILTIN_OVERFLOW()
return __builtin_mul_overflow(a, b, out);
#else
*out = a * b;
// This test isn't exact, but avoids doing integer division.
return (
(c10::llvm::countLeadingZeros(a) + c10::llvm::countLeadingZeros(b)) < 64);
#endif
}
C10_ALWAYS_INLINE bool mul_overflows(int64_t a, int64_t b, int64_t* out) {
#if C10_HAS_BUILTIN_OVERFLOW()
return __builtin_mul_overflow(a, b, out);
#else
volatile int64_t tmp = a * b;
*out = tmp;
if (a == 0 || b == 0) {
return false;
}
return !(a == tmp / b);
#endif
}
template <typename It>
bool safe_multiplies_u64(It first, It last, uint64_t* out) {
#if C10_HAS_BUILTIN_OVERFLOW()
uint64_t prod = 1;
bool overflow = false;
for (; first != last; ++first) {
overflow |= c10::mul_overflows(prod, *first, &prod);
}
*out = prod;
return overflow;
#else
uint64_t prod = 1;
uint64_t prod_log2 = 0;
bool is_zero = false;
for (; first != last; ++first) {
auto x = static_cast<uint64_t>(*first);
prod *= x;
// log2(0) isn't valid, so we need to track it specially
is_zero |= (x == 0);
prod_log2 += c10::llvm::Log2_64_Ceil(x);
}
*out = prod;
// This test isn't exact, but avoids doing integer division.
return !is_zero && (prod_log2 >= 64);
#endif
}
template <typename Container>
bool safe_multiplies_u64(const Container& c, uint64_t* out) {
return safe_multiplies_u64(c.begin(), c.end(), out);
}
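//
// A minimal usage sketch for overflow-checked element counting (the names
// `sizes` and `numel` are illustrative):
//
//   std::vector<uint64_t> sizes = {1ull << 32, 1ull << 32};
//   uint64_t numel = 0;
//   if (c10::safe_multiplies_u64(sizes, &numel)) {
//     // 2^64 does not fit in uint64_t; report the overflow.
//   }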
} // namespace c10

View File

@ -0,0 +1,110 @@
#pragma once
#include <atomic>
#include <condition_variable>
#include <csignal>
#include <cstdint>
#include <mutex>
#include <c10/macros/Export.h>
#if defined(__APPLE__)
#define C10_SUPPORTS_SIGNAL_HANDLER
#elif defined(__linux__) && !defined(C10_DISABLE_SIGNAL_HANDLERS)
#define C10_SUPPORTS_FATAL_SIGNAL_HANDLERS
#define C10_SUPPORTS_SIGNAL_HANDLER
#endif
#if defined(C10_SUPPORTS_FATAL_SIGNAL_HANDLERS)
#include <pthread.h>
#endif
namespace c10 {
class C10_API SignalHandler {
public:
enum class Action { NONE, STOP };
// Constructor. Specify what action to take when a signal is received.
SignalHandler(Action SIGINT_action, Action SIGHUP_action);
~SignalHandler();
Action CheckForSignals();
bool GotSIGINT();
bool GotSIGHUP();
Action SIGINT_action_;
Action SIGHUP_action_;
std::atomic<uint64_t> my_sigint_count_;
std::atomic<uint64_t> my_sighup_count_;
};
#if defined(C10_SUPPORTS_FATAL_SIGNAL_HANDLERS)
class C10_API FatalSignalHandler {
// This works by setting up certain fatal signal handlers. Previous fatal
// signal handlers will still be called when the signal is raised. Defaults
// to being off.
public:
C10_API void setPrintStackTracesOnFatalSignal(bool print);
C10_API bool printStackTracesOnFatalSignal();
static FatalSignalHandler& getInstance();
virtual ~FatalSignalHandler();
protected:
explicit FatalSignalHandler();
private:
void installFatalSignalHandlers();
void uninstallFatalSignalHandlers();
static void fatalSignalHandlerStatic(int signum);
void fatalSignalHandler(int signum);
virtual void fatalSignalHandlerPostProcess();
struct sigaction* getPreviousSigaction(int signum);
const char* getSignalName(int signum);
void callPreviousSignalHandler(
struct sigaction* action,
int signum,
siginfo_t* info,
void* ctx);
void stacktraceSignalHandler(bool needsLock);
static void stacktraceSignalHandlerStatic(
int signum,
siginfo_t* info,
void* ctx);
void stacktraceSignalHandler(int signum, siginfo_t* info, void* ctx);
// The mutex protects the bool.
std::mutex fatalSignalHandlersInstallationMutex;
bool fatalSignalHandlersInstalled;
// We need to hold a reference to call the previous SIGUSR2 handler in case
// we didn't signal it
struct sigaction previousSigusr2 {};
// Flag dictating whether the SIGUSR2 handler falls back to previous handlers
// or is intercepted in order to print a stack trace.
std::atomic<bool> fatalSignalReceived;
// Global state set when a fatal signal is received so that backtracing
// threads know why they're printing a stacktrace.
const char* fatalSignalName;
int fatalSignum = -1;
// This wait condition is used to wait for other threads to finish writing
// their stack trace when in fatal sig handler (we can't use pthread_join
// because there's no way to convert from a tid to a pthread_t).
std::condition_variable writingCond;
std::mutex writingMutex;
// used to indicate if the other thread responded to the signal
bool signalReceived;
struct signal_handler {
const char* name;
int signum;
struct sigaction previous;
};
// NOLINTNEXTLINE(*c-arrays*)
static signal_handler kSignalHandlers[];
};
#endif // defined(C10_SUPPORTS_FATAL_SIGNAL_HANDLERS)
} // namespace c10

View File

@ -0,0 +1,892 @@
//===- llvm/ADT/SparseBitVector.h - Efficient Sparse BitVector --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the SparseBitVector class. See the doxygen comment for
// SparseBitVector for more details on the algorithm used.
//
//===----------------------------------------------------------------------===//
#pragma once
#include <c10/macros/Macros.h>
#include <c10/util/llvmMathExtras.h>
#include <array>
#include <cassert>
#include <climits>
#include <iterator>
#include <list>
#include <ostream>
namespace c10 {
/// SparseBitVector is an implementation of a bitvector that is sparse by only
/// storing the elements that have non-zero bits set. In order to make this
/// fast for the most common cases, SparseBitVector is implemented as a linked
/// list of SparseBitVectorElements. We maintain a pointer to the last
/// SparseBitVectorElement accessed (in the form of a list iterator), in order
/// to make multiple in-order test/set constant time after the first one is
/// executed. Note that using vectors to store SparseBitVectorElement's does
/// not work out very well because it causes insertion in the middle to take
enormous amounts of time with a large number of bits. Other structures that
/// have better worst cases for insertion in the middle (various balanced trees,
/// etc) do not perform as well in practice as a linked list with this iterator
/// kept up to date. They are also significantly more memory intensive.
template <unsigned ElementSize = 128>
struct SparseBitVectorElement {
public:
using BitWord = unsigned long;
using size_type = unsigned;
enum {
BITWORD_SIZE = sizeof(BitWord) * CHAR_BIT,
BITWORDS_PER_ELEMENT = (ElementSize + BITWORD_SIZE - 1) / BITWORD_SIZE,
BITS_PER_ELEMENT = ElementSize
};
private:
// Index of Element in terms of where first bit starts.
unsigned ElementIndex;
std::array<BitWord, BITWORDS_PER_ELEMENT> Bits{};
SparseBitVectorElement() : ElementIndex(~0U) {}
public:
explicit SparseBitVectorElement(unsigned Idx) : ElementIndex(Idx) {}
// Comparison.
bool operator==(const SparseBitVectorElement& RHS) const {
if (ElementIndex != RHS.ElementIndex)
return false;
for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i)
if (Bits[i] != RHS.Bits[i])
return false;
return true;
}
bool operator!=(const SparseBitVectorElement& RHS) const {
return !(*this == RHS);
}
// Return the bits that make up word Idx in our element.
BitWord word(unsigned Idx) const {
assert(Idx < BITWORDS_PER_ELEMENT);
return Bits[Idx];
}
unsigned index() const {
return ElementIndex;
}
bool empty() const {
for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i)
if (Bits[i])
return false;
return true;
}
void set(unsigned Idx) {
Bits[Idx / BITWORD_SIZE] |= 1L << (Idx % BITWORD_SIZE);
}
bool test_and_set(unsigned Idx) {
bool old = test(Idx);
if (!old) {
set(Idx);
return true;
}
return false;
}
void reset(unsigned Idx) {
Bits[Idx / BITWORD_SIZE] &= ~(1L << (Idx % BITWORD_SIZE));
}
bool test(unsigned Idx) const {
return Bits[Idx / BITWORD_SIZE] & (1L << (Idx % BITWORD_SIZE));
}
size_type count() const {
unsigned NumBits = 0;
for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i)
NumBits += llvm::countPopulation(Bits[i]);
return NumBits;
}
/// find_first - Returns the index of the first set bit.
int find_first() const {
for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i)
if (Bits[i] != 0)
return i * BITWORD_SIZE + llvm::countTrailingZeros(Bits[i]);
throw std::runtime_error("Illegal empty element");
}
/// find_last - Returns the index of the last set bit.
int find_last() const {
for (unsigned I = 0; I < BITWORDS_PER_ELEMENT; ++I) {
unsigned Idx = BITWORDS_PER_ELEMENT - I - 1;
if (Bits[Idx] != 0)
return Idx * BITWORD_SIZE + BITWORD_SIZE -
llvm::countLeadingZeros(Bits[Idx]) - 1;
}
throw std::runtime_error("Illegal empty element");
}
/// find_next - Returns the index of the next set bit starting from the
/// "Curr" bit. Returns -1 if the next set bit is not found.
int find_next(unsigned Curr) const {
if (Curr >= BITS_PER_ELEMENT)
return -1;
unsigned WordPos = Curr / BITWORD_SIZE;
unsigned BitPos = Curr % BITWORD_SIZE;
BitWord Copy = Bits[WordPos];
assert(
WordPos <= BITWORDS_PER_ELEMENT && "Word Position outside of element");
// Mask off previous bits.
Copy &= ~0UL << BitPos;
if (Copy != 0)
return WordPos * BITWORD_SIZE + llvm::countTrailingZeros(Copy);
// Check subsequent words.
for (unsigned i = WordPos + 1; i < BITWORDS_PER_ELEMENT; ++i)
if (Bits[i] != 0)
return i * BITWORD_SIZE + llvm::countTrailingZeros(Bits[i]);
return -1;
}
// Union this element with RHS and return true if this one changed.
bool unionWith(const SparseBitVectorElement& RHS) {
bool changed = false;
for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) {
BitWord old = changed ? 0 : Bits[i];
Bits[i] |= RHS.Bits[i];
if (!changed && old != Bits[i])
changed = true;
}
return changed;
}
// Return true if we have any bits in common with RHS
bool intersects(const SparseBitVectorElement& RHS) const {
for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) {
if (RHS.Bits[i] & Bits[i])
return true;
}
return false;
}
// Intersect this Element with RHS and return true if this one changed.
// BecameZero is set to true if this element became all-zero bits.
bool intersectWith(const SparseBitVectorElement& RHS, bool& BecameZero) {
bool changed = false;
bool allzero = true;
for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) {
BitWord old = changed ? 0 : Bits[i];
Bits[i] &= RHS.Bits[i];
if (Bits[i] != 0)
allzero = false;
if (!changed && old != Bits[i])
changed = true;
}
BecameZero = allzero;
return changed;
}
// Intersect this Element with the complement of RHS and return true if this
// one changed. BecameZero is set to true if this element became all-zero
// bits.
bool intersectWithComplement(
const SparseBitVectorElement& RHS,
bool& BecameZero) {
bool changed = false;
bool allzero = true;
for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) {
BitWord old = changed ? 0 : Bits[i];
Bits[i] &= ~RHS.Bits[i];
if (Bits[i] != 0)
allzero = false;
if (!changed && old != Bits[i])
changed = true;
}
BecameZero = allzero;
return changed;
}
// Three argument version of intersectWithComplement that intersects
// RHS1 & ~RHS2 into this element
void intersectWithComplement(
const SparseBitVectorElement& RHS1,
const SparseBitVectorElement& RHS2,
bool& BecameZero) {
bool allzero = true;
for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) {
Bits[i] = RHS1.Bits[i] & ~RHS2.Bits[i];
if (Bits[i] != 0)
allzero = false;
}
BecameZero = allzero;
}
};
template <unsigned ElementSize = 128>
class SparseBitVector {
using ElementList = std::list<SparseBitVectorElement<ElementSize>>;
using ElementListIter = typename ElementList::iterator;
using ElementListConstIter = typename ElementList::const_iterator;
enum { BITWORD_SIZE = SparseBitVectorElement<ElementSize>::BITWORD_SIZE };
ElementList Elements;
// Pointer to our current Element. This has no visible effect on the external
// state of a SparseBitVector, it's just used to improve performance in the
// common case of testing/modifying bits with similar indices.
mutable ElementListIter CurrElementIter;
// This is like std::lower_bound, except we do linear searching from the
// current position.
ElementListIter FindLowerBoundImpl(unsigned ElementIndex) const {
// We cache a non-const iterator so we're forced to resort to const_cast to
// get the begin/end in the case where 'this' is const. To avoid duplication
// of code with the only difference being whether the const cast is present
// 'this' is always const in this particular function and we sort out the
// difference in FindLowerBound and FindLowerBoundConst.
ElementListIter Begin =
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
const_cast<SparseBitVector<ElementSize>*>(this)->Elements.begin();
ElementListIter End =
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
const_cast<SparseBitVector<ElementSize>*>(this)->Elements.end();
if (Elements.empty()) {
CurrElementIter = Begin;
return CurrElementIter;
}
// Make sure our current iterator is valid.
if (CurrElementIter == End)
--CurrElementIter;
// Search from our current iterator, either backwards or forwards,
// depending on what element we are looking for.
ElementListIter ElementIter = CurrElementIter;
if (CurrElementIter->index() == ElementIndex) {
return ElementIter;
} else if (CurrElementIter->index() > ElementIndex) {
while (ElementIter != Begin && ElementIter->index() > ElementIndex)
--ElementIter;
} else {
while (ElementIter != End && ElementIter->index() < ElementIndex)
++ElementIter;
}
CurrElementIter = ElementIter;
return ElementIter;
}
ElementListConstIter FindLowerBoundConst(unsigned ElementIndex) const {
return FindLowerBoundImpl(ElementIndex);
}
ElementListIter FindLowerBound(unsigned ElementIndex) {
return FindLowerBoundImpl(ElementIndex);
}
// Iterator to walk set bits in the bitmap. This iterator is a lot uglier
// than it otherwise would be, in order to be efficient.
class SparseBitVectorIterator {
private:
bool AtEnd{false};
const SparseBitVector<ElementSize>* BitVector = nullptr;
// Current element inside of bitmap.
ElementListConstIter Iter;
// Current bit number inside of our bitmap.
unsigned BitNumber{0};
// Current word number inside of our element.
unsigned WordNumber{0};
// Current bits from the element.
typename SparseBitVectorElement<ElementSize>::BitWord Bits{0};
// Move our iterator to the first non-zero bit in the bitmap.
void AdvanceToFirstNonZero() {
if (AtEnd)
return;
if (BitVector->Elements.empty()) {
AtEnd = true;
return;
}
Iter = BitVector->Elements.begin();
BitNumber = Iter->index() * ElementSize;
unsigned BitPos = Iter->find_first();
BitNumber += BitPos;
WordNumber = (BitNumber % ElementSize) / BITWORD_SIZE;
Bits = Iter->word(WordNumber);
Bits >>= BitPos % BITWORD_SIZE;
}
// Move our iterator to the next non-zero bit.
void AdvanceToNextNonZero() {
if (AtEnd)
return;
while (Bits && !(Bits & 1)) {
Bits >>= 1;
BitNumber += 1;
}
// See if we ran out of Bits in this word.
if (!Bits) {
int NextSetBitNumber = Iter->find_next(BitNumber % ElementSize);
// If we ran out of set bits in this element, move to next element.
if (NextSetBitNumber == -1 || (BitNumber % ElementSize == 0)) {
++Iter;
WordNumber = 0;
// We may run out of elements in the bitmap.
if (Iter == BitVector->Elements.end()) {
AtEnd = true;
return;
}
// Set up for next non-zero word in bitmap.
BitNumber = Iter->index() * ElementSize;
NextSetBitNumber = Iter->find_first();
BitNumber += NextSetBitNumber;
WordNumber = (BitNumber % ElementSize) / BITWORD_SIZE;
Bits = Iter->word(WordNumber);
Bits >>= NextSetBitNumber % BITWORD_SIZE;
} else {
WordNumber = (NextSetBitNumber % ElementSize) / BITWORD_SIZE;
Bits = Iter->word(WordNumber);
Bits >>= NextSetBitNumber % BITWORD_SIZE;
BitNumber = Iter->index() * ElementSize;
BitNumber += NextSetBitNumber;
}
}
}
public:
SparseBitVectorIterator() = default;
SparseBitVectorIterator(
const SparseBitVector<ElementSize>* RHS,
bool end = false)
: AtEnd(end),
BitVector(RHS),
Iter(BitVector->Elements.begin()),
WordNumber(~0) {
AdvanceToFirstNonZero();
}
// Preincrement.
inline SparseBitVectorIterator& operator++() {
++BitNumber;
Bits >>= 1;
AdvanceToNextNonZero();
return *this;
}
// Postincrement.
inline SparseBitVectorIterator operator++(int) {
SparseBitVectorIterator tmp = *this;
++*this;
return tmp;
}
// Return the current set bit number.
unsigned operator*() const {
return BitNumber;
}
bool operator==(const SparseBitVectorIterator& RHS) const {
// If they are both at the end, ignore the rest of the fields.
if (AtEnd && RHS.AtEnd)
return true;
// Otherwise they are the same if they have the same bit number and
// bitmap.
return AtEnd == RHS.AtEnd && RHS.BitNumber == BitNumber;
}
bool operator!=(const SparseBitVectorIterator& RHS) const {
return !(*this == RHS);
}
};
public:
using iterator = SparseBitVectorIterator;
SparseBitVector() : Elements(), CurrElementIter(Elements.begin()) {}
SparseBitVector(const SparseBitVector& RHS)
: Elements(RHS.Elements), CurrElementIter(Elements.begin()) {}
SparseBitVector(SparseBitVector&& RHS) noexcept
: Elements(std::move(RHS.Elements)), CurrElementIter(Elements.begin()) {}
// Clear.
void clear() {
Elements.clear();
}
// Assignment
SparseBitVector& operator=(const SparseBitVector& RHS) {
if (this == &RHS)
return *this;
Elements = RHS.Elements;
CurrElementIter = Elements.begin();
return *this;
}
SparseBitVector& operator=(SparseBitVector&& RHS) noexcept {
Elements = std::move(RHS.Elements);
CurrElementIter = Elements.begin();
return *this;
}
// Test, Reset, and Set a bit in the bitmap.
bool test(unsigned Idx) const {
if (Elements.empty())
return false;
unsigned ElementIndex = Idx / ElementSize;
ElementListConstIter ElementIter = FindLowerBoundConst(ElementIndex);
// If we can't find an element that is supposed to contain this bit, there
// is nothing more to do.
if (ElementIter == Elements.end() || ElementIter->index() != ElementIndex)
return false;
return ElementIter->test(Idx % ElementSize);
}
void reset(unsigned Idx) {
if (Elements.empty())
return;
unsigned ElementIndex = Idx / ElementSize;
ElementListIter ElementIter = FindLowerBound(ElementIndex);
// If we can't find an element that is supposed to contain this bit, there
// is nothing more to do.
if (ElementIter == Elements.end() || ElementIter->index() != ElementIndex)
return;
ElementIter->reset(Idx % ElementSize);
// When the element is zeroed out, delete it.
if (ElementIter->empty()) {
++CurrElementIter;
Elements.erase(ElementIter);
}
}
void set(unsigned Idx) {
unsigned ElementIndex = Idx / ElementSize;
ElementListIter ElementIter;
if (Elements.empty()) {
ElementIter = Elements.emplace(Elements.end(), ElementIndex);
} else {
ElementIter = FindLowerBound(ElementIndex);
if (ElementIter == Elements.end() ||
ElementIter->index() != ElementIndex) {
// We may have hit the beginning of our SparseBitVector, in which case,
// we may need to insert right after this element, which requires moving
// the current iterator forward one, because insert does insert before.
if (ElementIter != Elements.end() &&
ElementIter->index() < ElementIndex)
++ElementIter;
ElementIter = Elements.emplace(ElementIter, ElementIndex);
}
}
CurrElementIter = ElementIter;
ElementIter->set(Idx % ElementSize);
}
bool test_and_set(unsigned Idx) {
bool old = test(Idx);
if (!old) {
set(Idx);
return true;
}
return false;
}
bool operator!=(const SparseBitVector& RHS) const {
return !(*this == RHS);
}
bool operator==(const SparseBitVector& RHS) const {
ElementListConstIter Iter1 = Elements.begin();
ElementListConstIter Iter2 = RHS.Elements.begin();
for (; Iter1 != Elements.end() && Iter2 != RHS.Elements.end();
++Iter1, ++Iter2) {
if (*Iter1 != *Iter2)
return false;
}
return Iter1 == Elements.end() && Iter2 == RHS.Elements.end();
}
// Union our bitmap with the RHS and return true if we changed.
bool operator|=(const SparseBitVector& RHS) {
if (this == &RHS)
return false;
if (empty()) {
*this = RHS;
return true;
}
bool changed = false;
ElementListIter Iter1 = Elements.begin();
ElementListConstIter Iter2 = RHS.Elements.begin();
// If RHS is empty, we are done
if (RHS.Elements.empty())
return false;
while (Iter2 != RHS.Elements.end()) {
if (Iter1 == Elements.end() || Iter1->index() > Iter2->index()) {
Elements.insert(Iter1, *Iter2);
++Iter2;
changed = true;
} else if (Iter1->index() == Iter2->index()) {
changed |= Iter1->unionWith(*Iter2);
++Iter1;
++Iter2;
} else {
++Iter1;
}
}
CurrElementIter = Elements.begin();
return changed;
}
// Intersect our bitmap with the RHS and return true if ours changed.
bool operator-=(const SparseBitVector& RHS) {
return intersectWithComplement(RHS);
}
// Intersect our bitmap with the RHS and return true if ours changed.
bool operator&=(const SparseBitVector& RHS) {
if (this == &RHS)
return false;
bool changed = false;
ElementListIter Iter1 = Elements.begin();
ElementListConstIter Iter2 = RHS.Elements.begin();
// Check if both bitmaps are empty.
if (Elements.empty() && RHS.Elements.empty())
return false;
// Loop through, intersecting as we go, erasing elements when necessary.
while (Iter2 != RHS.Elements.end()) {
if (Iter1 == Elements.end()) {
CurrElementIter = Elements.begin();
return changed;
}
if (Iter1->index() > Iter2->index()) {
++Iter2;
} else if (Iter1->index() == Iter2->index()) {
bool BecameZero = false;
changed |= Iter1->intersectWith(*Iter2, BecameZero);
if (BecameZero) {
ElementListIter IterTmp = Iter1;
++Iter1;
Elements.erase(IterTmp);
} else {
++Iter1;
}
++Iter2;
} else {
ElementListIter IterTmp = Iter1;
++Iter1;
Elements.erase(IterTmp);
changed = true;
}
}
if (Iter1 != Elements.end()) {
Elements.erase(Iter1, Elements.end());
changed = true;
}
CurrElementIter = Elements.begin();
return changed;
}
// Intersect our bitmap with the complement of the RHS and return true
// if ours changed.
bool intersectWithComplement(const SparseBitVector& RHS) {
if (this == &RHS) {
if (!empty()) {
clear();
return true;
}
return false;
}
bool changed = false;
ElementListIter Iter1 = Elements.begin();
ElementListConstIter Iter2 = RHS.Elements.begin();
// If either our bitmap or RHS is empty, we are done
if (Elements.empty() || RHS.Elements.empty())
return false;
// Loop through, intersecting as we go, erasing elements when necessary.
while (Iter2 != RHS.Elements.end()) {
if (Iter1 == Elements.end()) {
CurrElementIter = Elements.begin();
return changed;
}
if (Iter1->index() > Iter2->index()) {
++Iter2;
} else if (Iter1->index() == Iter2->index()) {
bool BecameZero = false;
changed |= Iter1->intersectWithComplement(*Iter2, BecameZero);
if (BecameZero) {
ElementListIter IterTmp = Iter1;
++Iter1;
Elements.erase(IterTmp);
} else {
++Iter1;
}
++Iter2;
} else {
++Iter1;
}
}
CurrElementIter = Elements.begin();
return changed;
}
bool intersectWithComplement(const SparseBitVector<ElementSize>* RHS) const {
return intersectWithComplement(*RHS);
}
// Three argument version of intersectWithComplement.
// Result of RHS1 & ~RHS2 is stored into this bitmap.
void intersectWithComplement(
const SparseBitVector<ElementSize>& RHS1,
const SparseBitVector<ElementSize>& RHS2) {
if (this == &RHS1) {
intersectWithComplement(RHS2);
return;
} else if (this == &RHS2) {
SparseBitVector RHS2Copy(RHS2);
intersectWithComplement(RHS1, RHS2Copy);
return;
}
Elements.clear();
CurrElementIter = Elements.begin();
ElementListConstIter Iter1 = RHS1.Elements.begin();
ElementListConstIter Iter2 = RHS2.Elements.begin();
// If RHS1 is empty, we are done
// If RHS2 is empty, we still have to copy RHS1
if (RHS1.Elements.empty())
return;
// Loop through, intersecting as we go, erasing elements when necessary.
while (Iter2 != RHS2.Elements.end()) {
if (Iter1 == RHS1.Elements.end())
return;
if (Iter1->index() > Iter2->index()) {
++Iter2;
} else if (Iter1->index() == Iter2->index()) {
bool BecameZero = false;
Elements.emplace_back(Iter1->index());
Elements.back().intersectWithComplement(*Iter1, *Iter2, BecameZero);
if (BecameZero)
Elements.pop_back();
++Iter1;
++Iter2;
} else {
Elements.push_back(*Iter1++);
}
}
// copy the remaining elements
std::copy(Iter1, RHS1.Elements.end(), std::back_inserter(Elements));
}
void intersectWithComplement(
const SparseBitVector<ElementSize>* RHS1,
const SparseBitVector<ElementSize>* RHS2) {
intersectWithComplement(*RHS1, *RHS2);
}
bool intersects(const SparseBitVector<ElementSize>* RHS) const {
return intersects(*RHS);
}
// Return true if we share any bits in common with RHS
bool intersects(const SparseBitVector<ElementSize>& RHS) const {
ElementListConstIter Iter1 = Elements.begin();
ElementListConstIter Iter2 = RHS.Elements.begin();
// Check if both bitmaps are empty.
if (Elements.empty() && RHS.Elements.empty())
return false;
// Loop through, intersecting stopping when we hit bits in common.
while (Iter2 != RHS.Elements.end()) {
if (Iter1 == Elements.end())
return false;
if (Iter1->index() > Iter2->index()) {
++Iter2;
} else if (Iter1->index() == Iter2->index()) {
if (Iter1->intersects(*Iter2))
return true;
++Iter1;
++Iter2;
} else {
++Iter1;
}
}
return false;
}
// Return true iff all bits set in this SparseBitVector are
// also set in RHS.
bool contains(const SparseBitVector<ElementSize>& RHS) const {
SparseBitVector<ElementSize> Result(*this);
Result &= RHS;
return (Result == RHS);
}
// Return the first set bit in the bitmap. Return -1 if no bits are set.
int find_first() const {
if (Elements.empty())
return -1;
const SparseBitVectorElement<ElementSize>& First = *(Elements.begin());
return (First.index() * ElementSize) + First.find_first();
}
// Return the last set bit in the bitmap. Return -1 if no bits are set.
int find_last() const {
if (Elements.empty())
return -1;
const SparseBitVectorElement<ElementSize>& Last = *(Elements.rbegin());
return (Last.index() * ElementSize) + Last.find_last();
}
// Return true if the SparseBitVector is empty
bool empty() const {
return Elements.empty();
}
unsigned count() const {
unsigned BitCount = 0;
for (ElementListConstIter Iter = Elements.begin(); Iter != Elements.end();
++Iter)
BitCount += Iter->count();
return BitCount;
}
iterator begin() const {
return iterator(this);
}
iterator end() const {
return iterator(this, true);
}
};
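// A minimal usage sketch:
//
//   c10::SparseBitVector<> Vec;
//   Vec.set(5);
//   Vec.set(100000);          // distant indices cost one element each
//   bool b = Vec.test(5);     // true
//   for (unsigned Idx : Vec) {
//     // visits 5, then 100000, in ascending order
//   }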
// Convenience functions to allow Or and And without dereferencing in the user
// code.
template <unsigned ElementSize>
inline bool operator|=(
SparseBitVector<ElementSize>& LHS,
const SparseBitVector<ElementSize>* RHS) {
return LHS |= *RHS;
}
template <unsigned ElementSize>
inline bool operator|=(
SparseBitVector<ElementSize>* LHS,
const SparseBitVector<ElementSize>& RHS) {
return LHS->operator|=(RHS);
}
template <unsigned ElementSize>
inline bool operator&=(
SparseBitVector<ElementSize>* LHS,
const SparseBitVector<ElementSize>& RHS) {
return LHS->operator&=(RHS);
}
template <unsigned ElementSize>
inline bool operator&=(
SparseBitVector<ElementSize>& LHS,
const SparseBitVector<ElementSize>* RHS) {
return LHS &= *RHS;
}
// Convenience functions for infix union, intersection, difference operators.
template <unsigned ElementSize>
inline SparseBitVector<ElementSize> operator|(
const SparseBitVector<ElementSize>& LHS,
const SparseBitVector<ElementSize>& RHS) {
SparseBitVector<ElementSize> Result(LHS);
Result |= RHS;
return Result;
}
template <unsigned ElementSize>
inline SparseBitVector<ElementSize> operator&(
const SparseBitVector<ElementSize>& LHS,
const SparseBitVector<ElementSize>& RHS) {
SparseBitVector<ElementSize> Result(LHS);
Result &= RHS;
return Result;
}
template <unsigned ElementSize>
inline SparseBitVector<ElementSize> operator-(
const SparseBitVector<ElementSize>& LHS,
const SparseBitVector<ElementSize>& RHS) {
SparseBitVector<ElementSize> Result;
Result.intersectWithComplement(LHS, RHS);
return Result;
}
template <unsigned ElementSize>
std::ostream& operator<<(
std::ostream& stream,
const SparseBitVector<ElementSize>& vec) {
bool first = true;
stream << "{";
for (auto el : vec) {
if (first) {
first = false;
} else {
stream << ", ";
}
stream << el;
}
stream << "}";
return stream;
}
} // end namespace c10

View File

@ -0,0 +1,46 @@
#pragma once
#include <c10/util/Exception.h>
#include <c10/util/TypeSafeSignMath.h>
#include <cstddef>
#include <type_traits>
namespace c10 {
// Implementations of std::ssize() from C++20.
//
// This is useful in particular for avoiding -Werror=sign-compare
// issues.
//
// Use this with argument-dependent lookup, e.g.:
// using c10::ssize;
// auto size = ssize(container);
//
// As with the standard library version, containers are permitted to
// specialize this with a free function defined in the same namespace.
//
// See https://en.cppreference.com/w/cpp/iterator/size for more
// information as well as the source of our implementations.
//
// We augment the implementation by adding an assert() if an overflow
// would occur.
template <typename C>
constexpr auto ssize(const C& c) -> std::
common_type_t<std::ptrdiff_t, std::make_signed_t<decltype(c.size())>> {
using R = std::
common_type_t<std::ptrdiff_t, std::make_signed_t<decltype(c.size())>>;
// We expect this to be exceedingly rare to fire and don't wish to
// pay a performance hit in release mode.
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!greater_than_max<R>(c.size()));
return static_cast<R>(c.size());
}
template <typename T, std::ptrdiff_t N>
// NOLINTNEXTLINE(*-c-arrays)
constexpr auto ssize(const T (&array)[N]) noexcept -> std::ptrdiff_t {
return N;
}
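// A minimal usage sketch:
//
//   std::vector<float> v = {1.f, 2.f, 3.f};
//   using c10::ssize;
//   for (std::ptrdiff_t i = 0; i < ssize(v); ++i) {
//     // no -Wsign-compare warning: both sides are signed
//   }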
} // namespace c10

View File

@ -0,0 +1,34 @@
#pragma once
#if defined(__ELF__) && (defined(__x86_64__) || defined(__i386__)) && \
!(defined(TORCH_DISABLE_SDT) && TORCH_DISABLE_SDT)
#define TORCH_HAVE_SDT 1
#include <c10/util/static_tracepoint_elfx86.h>
#define TORCH_SDT(name, ...) \
TORCH_SDT_PROBE_N( \
pytorch, name, 0, TORCH_SDT_NARG(0, ##__VA_ARGS__), ##__VA_ARGS__)
// Use TORCH_SDT_DEFINE_SEMAPHORE(name) to define the semaphore
// as global variable before using the TORCH_SDT_WITH_SEMAPHORE macro
#define TORCH_SDT_WITH_SEMAPHORE(name, ...) \
TORCH_SDT_PROBE_N( \
pytorch, name, 1, TORCH_SDT_NARG(0, ##__VA_ARGS__), ##__VA_ARGS__)
#define TORCH_SDT_IS_ENABLED(name) (TORCH_SDT_SEMAPHORE(pytorch, name) > 0)
#else
#define TORCH_HAVE_SDT 0
#define TORCH_SDT(name, ...) \
do { \
} while (0)
#define TORCH_SDT_WITH_SEMAPHORE(name, ...) \
do { \
} while (0)
#define TORCH_SDT_IS_ENABLED(name) (false)
#define TORCH_SDT_DEFINE_SEMAPHORE(name)
#define TORCH_SDT_DECLARE_SEMAPHORE(name)
#endif
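// A minimal usage sketch (probe and variable names are illustrative):
//
//   TORCH_SDT_DEFINE_SEMAPHORE(my_probe)  // once, at namespace scope
//
//   void f(int value) {
//     if (TORCH_SDT_IS_ENABLED(my_probe)) {
//       TORCH_SDT_WITH_SEMAPHORE(my_probe, value);
//     }
//     TORCH_SDT(my_simple_probe, value);  // unconditional, no semaphore
//   }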

View File

@ -0,0 +1,132 @@
#pragma once
// clang-format off
// Default constraint for the probe arguments as operands.
#ifndef TORCH_SDT_ARG_CONSTRAINT
#define TORCH_SDT_ARG_CONSTRAINT "nor"
#endif
// Instruction to emit for the probe.
#define TORCH_SDT_NOP nop
// Note section properties.
#define TORCH_SDT_NOTE_NAME "stapsdt"
#define TORCH_SDT_NOTE_TYPE 3
// Semaphore variables are put in this section
#define TORCH_SDT_SEMAPHORE_SECTION ".probes"
// Size of address depending on platform.
#ifdef __LP64__
#define TORCH_SDT_ASM_ADDR .8byte
#else
#define TORCH_SDT_ASM_ADDR .4byte
#endif
// Assembler helper Macros.
#define TORCH_SDT_S(x) #x
#define TORCH_SDT_ASM_1(x) TORCH_SDT_S(x) "\n"
#define TORCH_SDT_ASM_2(a, b) TORCH_SDT_S(a) "," TORCH_SDT_S(b) "\n"
#define TORCH_SDT_ASM_3(a, b, c) TORCH_SDT_S(a) "," TORCH_SDT_S(b) "," \
TORCH_SDT_S(c) "\n"
#define TORCH_SDT_ASM_STRING(x) TORCH_SDT_ASM_1(.asciz TORCH_SDT_S(x))
// Helper to determine the size of an argument.
#define TORCH_SDT_IS_ARRAY_POINTER(x) ((__builtin_classify_type(x) == 14) || \
(__builtin_classify_type(x) == 5))
#define TORCH_SDT_ARGSIZE(x) (TORCH_SDT_IS_ARRAY_POINTER(x) \
? sizeof(void*) \
: sizeof(x))
// Format of each probe arguments as operand.
// Size of the argument tagged with TORCH_SDT_Sn, with "n" constraint.
// Value of the argument tagged with TORCH_SDT_An, with configured constraint.
#define TORCH_SDT_ARG(n, x) \
[TORCH_SDT_S##n] "n" ((size_t)TORCH_SDT_ARGSIZE(x)), \
[TORCH_SDT_A##n] TORCH_SDT_ARG_CONSTRAINT (x)
// Templates to append arguments as operands.
#define TORCH_SDT_OPERANDS_0() [__sdt_dummy] "g" (0)
#define TORCH_SDT_OPERANDS_1(_1) TORCH_SDT_ARG(1, _1)
#define TORCH_SDT_OPERANDS_2(_1, _2) \
TORCH_SDT_OPERANDS_1(_1), TORCH_SDT_ARG(2, _2)
#define TORCH_SDT_OPERANDS_3(_1, _2, _3) \
TORCH_SDT_OPERANDS_2(_1, _2), TORCH_SDT_ARG(3, _3)
#define TORCH_SDT_OPERANDS_4(_1, _2, _3, _4) \
TORCH_SDT_OPERANDS_3(_1, _2, _3), TORCH_SDT_ARG(4, _4)
#define TORCH_SDT_OPERANDS_5(_1, _2, _3, _4, _5) \
TORCH_SDT_OPERANDS_4(_1, _2, _3, _4), TORCH_SDT_ARG(5, _5)
#define TORCH_SDT_OPERANDS_6(_1, _2, _3, _4, _5, _6) \
TORCH_SDT_OPERANDS_5(_1, _2, _3, _4, _5), TORCH_SDT_ARG(6, _6)
#define TORCH_SDT_OPERANDS_7(_1, _2, _3, _4, _5, _6, _7) \
TORCH_SDT_OPERANDS_6(_1, _2, _3, _4, _5, _6), TORCH_SDT_ARG(7, _7)
#define TORCH_SDT_OPERANDS_8(_1, _2, _3, _4, _5, _6, _7, _8) \
TORCH_SDT_OPERANDS_7(_1, _2, _3, _4, _5, _6, _7), TORCH_SDT_ARG(8, _8)
#define TORCH_SDT_OPERANDS_9(_1, _2, _3, _4, _5, _6, _7, _8, _9) \
TORCH_SDT_OPERANDS_8(_1, _2, _3, _4, _5, _6, _7, _8), TORCH_SDT_ARG(9, _9)
// Templates to reference the arguments from operands in the note section.
#define TORCH_SDT_ARGFMT(no) %n[TORCH_SDT_S##no]@%[TORCH_SDT_A##no]
#define TORCH_SDT_ARG_TEMPLATE_0 /*No arguments*/
#define TORCH_SDT_ARG_TEMPLATE_1 TORCH_SDT_ARGFMT(1)
#define TORCH_SDT_ARG_TEMPLATE_2 TORCH_SDT_ARG_TEMPLATE_1 TORCH_SDT_ARGFMT(2)
#define TORCH_SDT_ARG_TEMPLATE_3 TORCH_SDT_ARG_TEMPLATE_2 TORCH_SDT_ARGFMT(3)
#define TORCH_SDT_ARG_TEMPLATE_4 TORCH_SDT_ARG_TEMPLATE_3 TORCH_SDT_ARGFMT(4)
#define TORCH_SDT_ARG_TEMPLATE_5 TORCH_SDT_ARG_TEMPLATE_4 TORCH_SDT_ARGFMT(5)
#define TORCH_SDT_ARG_TEMPLATE_6 TORCH_SDT_ARG_TEMPLATE_5 TORCH_SDT_ARGFMT(6)
#define TORCH_SDT_ARG_TEMPLATE_7 TORCH_SDT_ARG_TEMPLATE_6 TORCH_SDT_ARGFMT(7)
#define TORCH_SDT_ARG_TEMPLATE_8 TORCH_SDT_ARG_TEMPLATE_7 TORCH_SDT_ARGFMT(8)
#define TORCH_SDT_ARG_TEMPLATE_9 TORCH_SDT_ARG_TEMPLATE_8 TORCH_SDT_ARGFMT(9)
// Semaphore definition, declaration, and probe-note formats.
#define TORCH_SDT_SEMAPHORE(provider, name) \
torch_sdt_semaphore_##provider##_##name
#define TORCH_SDT_DEFINE_SEMAPHORE(name) \
extern "C" { \
volatile unsigned short TORCH_SDT_SEMAPHORE(pytorch, name) \
__attribute__((section(TORCH_SDT_SEMAPHORE_SECTION), used)) = 0; \
}
#define TORCH_SDT_DECLARE_SEMAPHORE(name) \
extern "C" volatile unsigned short TORCH_SDT_SEMAPHORE(pytorch, name)
#define TORCH_SDT_SEMAPHORE_NOTE_0(provider, name) \
  TORCH_SDT_ASM_1( TORCH_SDT_ASM_ADDR 0) /*No Semaphore*/
#define TORCH_SDT_SEMAPHORE_NOTE_1(provider, name) \
TORCH_SDT_ASM_1(TORCH_SDT_ASM_ADDR TORCH_SDT_SEMAPHORE(provider, name))
// Structure of the note section for the probe.
#define TORCH_SDT_NOTE_CONTENT(provider, name, has_semaphore, arg_template) \
TORCH_SDT_ASM_1(990: TORCH_SDT_NOP) \
TORCH_SDT_ASM_3( .pushsection .note.stapsdt,"","note") \
TORCH_SDT_ASM_1( .balign 4) \
TORCH_SDT_ASM_3( .4byte 992f-991f, 994f-993f, TORCH_SDT_NOTE_TYPE) \
TORCH_SDT_ASM_1(991: .asciz TORCH_SDT_NOTE_NAME) \
TORCH_SDT_ASM_1(992: .balign 4) \
TORCH_SDT_ASM_1(993: TORCH_SDT_ASM_ADDR 990b) \
TORCH_SDT_ASM_1( TORCH_SDT_ASM_ADDR 0) /*Reserved for Base Address*/ \
TORCH_SDT_SEMAPHORE_NOTE_##has_semaphore(provider, name) \
TORCH_SDT_ASM_STRING(provider) \
TORCH_SDT_ASM_STRING(name) \
TORCH_SDT_ASM_STRING(arg_template) \
TORCH_SDT_ASM_1(994: .balign 4) \
TORCH_SDT_ASM_1( .popsection)
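// The emitted note records, in order: the probe address (label 990), a slot
// reserved for the base address, the semaphore address (or 0), and three
// strings: provider, probe name, and the per-argument operand template.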
// Main probe macro.
#define TORCH_SDT_PROBE(provider, name, has_semaphore, n, arglist) \
__asm__ __volatile__ ( \
TORCH_SDT_NOTE_CONTENT( \
provider, name, has_semaphore, TORCH_SDT_ARG_TEMPLATE_##n) \
:: TORCH_SDT_OPERANDS_##n arglist \
  )
// Helper macros to handle variadic arguments.
#define TORCH_SDT_NARG_(_0, _1, _2, _3, _4, _5, _6, _7, _8, _9, N, ...) N
#define TORCH_SDT_NARG(...) \
TORCH_SDT_NARG_(__VA_ARGS__, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0)
#define TORCH_SDT_PROBE_N(provider, name, has_semaphore, N, ...) \
TORCH_SDT_PROBE(provider, name, has_semaphore, N, (__VA_ARGS__))
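To make the argument-counting step concrete, a small sketch (guarded so it only compiles where the SDT macros are available):

#include <c10/util/static_tracepoint.h>

#if TORCH_HAVE_SDT
// (0, a, b, c) shifts the 9..0 pad so that 3 lands in the N slot.
static_assert(TORCH_SDT_NARG(0, 1, 2, 3) == 3, "three probe arguments");
#endif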

View File

@ -0,0 +1,24 @@
#pragma once
#include <c10/util/ArrayRef.h>
#include <c10/util/DimVector.h>
#include <algorithm>
namespace c10 {
// Computes the contiguous strides of a tensor, given its sizes.
inline DimVector contiguous_strides(const IntArrayRef sizes) {
using Int = IntArrayRef::value_type;
const Int dims = static_cast<Int>(sizes.size());
// With this initialisation we get the cases dim == 0 and dim == 1 right.
DimVector strides(dims, 1);
for (auto i = dims - 2; i >= 0; --i) {
// Strides can't be 0 even if sizes are 0.
strides[i] = strides[i + 1] * std::max(sizes[i + 1], Int{1});
}
return strides;
}
} // namespace c10
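A short worked example of the loop above (an assumed usage sketch; the values follow directly from the code):

#include <c10/util/strides.h>
#include <cassert>

void strides_example() {
  // sizes {2, 3, 4}: strides[2] = 1, strides[1] = 1 * 4, strides[0] = 4 * 3.
  const auto s = c10::contiguous_strides({2, 3, 4});
  assert(s[0] == 12 && s[1] == 4 && s[2] == 1);
  // A zero-sized dimension is clamped to 1, keeping every stride nonzero.
  const auto z = c10::contiguous_strides({2, 0, 4});
  assert(z[0] == 4 && z[1] == 4 && z[2] == 1);
}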

View File

@ -0,0 +1,18 @@
#pragma once
#include <string>
namespace c10 {
// NOLINTNEXTLINE(misc-unused-using-decls)
using std::stod;
// NOLINTNEXTLINE(misc-unused-using-decls)
using std::stoi;
// NOLINTNEXTLINE(misc-unused-using-decls)
using std::stoll;
// NOLINTNEXTLINE(misc-unused-using-decls)
using std::stoull;
// NOLINTNEXTLINE(misc-unused-using-decls)
using std::to_string;
} // namespace c10
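These using-declarations re-export the standard conversions under the c10 namespace; a trivial sketch (names illustrative):

#include <c10/util/string_utils.h>

void convert_example() {
  const std::string s = c10::to_string(42); // "42"
  const int i = c10::stoi(s); // back to 42
  (void)i;
}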

Some files were not shown because too many files have changed in this diff.