I am done

This commit is contained in:
2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

View File

@@ -0,0 +1,319 @@
#pragma once
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <utility>
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/macros/Export.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <c10/util/ThreadLocalDebugInfo.h>
#include <c10/util/UniqueVoidPtr.h>
namespace c10 {
// A DataPtr is a unique pointer (with an attached deleter and some
// context for the deleter) to some memory, which also records which
// device its data lives on.
//
// nullptr DataPtrs can still have a nontrivial device; this allows
// us to treat zero-size allocations uniformly with non-zero allocations.
//
class C10_API DataPtr {
private:
c10::detail::UniqueVoidPtr ptr_;
Device device_;
public:
// Choice of CPU here is arbitrary; if there's an "undefined" device
// we could use that too
DataPtr() : ptr_(), device_(DeviceType::CPU) {}
DataPtr(void* data, Device device) : ptr_(data), device_(device) {}
DataPtr(void* data, void* ctx, DeleterFnPtr ctx_deleter, Device device)
: ptr_(data, ctx, ctx_deleter), device_(device) {}
void* operator->() const {
return ptr_.get();
}
void clear() {
ptr_.clear();
}
void* get() const {
return ptr_.get();
}
void* mutable_get() {
return ptr_.get();
}
void* get_context() const {
return ptr_.get_context();
}
void* release_context() {
return ptr_.release_context();
}
std::unique_ptr<void, DeleterFnPtr>&& move_context() {
return ptr_.move_context();
}
operator bool() const {
return static_cast<bool>(ptr_);
}
template <typename T>
T* cast_context(DeleterFnPtr expected_deleter) const {
return ptr_.cast_context<T>(expected_deleter);
}
DeleterFnPtr get_deleter() const {
return ptr_.get_deleter();
}
/**
 * Compares the deleter in a DataPtr to expected_deleter.
 * If it matches, replaces the deleter with new_deleter
 * and returns true; otherwise, does nothing and returns
 * false.
*
* In general, it is not safe to unconditionally set the
* deleter on a DataPtr, because you don't know what
* the deleter is, and thus will have a hard time properly
* disposing of the deleter without storing the original
* deleter (this is difficult to do, because DeleterFnPtr
* is not a closure, and because the context on DataPtr is
* only a single word, you generally don't have enough
* space to store both the original deleter and its context).
* However, in some cases, you know /exactly/ what the deleter
* is, and you have a new deleter that manually wraps
* the old one. In this case, you can safely swap the deleter
* after asserting that the deleters line up.
*
* What are the requirements on new_deleter? It must still
* properly dispose of the void* pointer passed in as its argument,
* where void* is whatever the context of the original deleter
* is. So in general, you expect the new deleter to look something
* like this:
*
* [](void* ptr) {
* some_new_stuff(ptr);
* get_orig_allocator()->raw_deleter(ptr);
* }
*
 * Note that it won't work to close over the original
 * allocator; you don't have enough space to do that! Also,
 * it's unsafe to assume that the pointer passed to the deleter
 * is the data pointer itself; it might not be. Be sure to read
 * the source code of the Allocator in question to confirm this.
*/
C10_NODISCARD bool compare_exchange_deleter(
DeleterFnPtr expected_deleter,
DeleterFnPtr new_deleter) {
return ptr_.compare_exchange_deleter(expected_deleter, new_deleter);
}
Device device() const {
return device_;
}
// Unsafely mutates the device on a DataPtr. Under normal use,
// you should never actually need to call this function.
// We need this for the implementation of the hack detailed
// in Note [Masquerading as CUDA]
void unsafe_set_device(Device device) {
device_ = device;
}
};
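// Illustrative sketch (not part of the original header): constructing a
// DataPtr whose context is the data pointer itself, then swapping the deleter
// for a wrapper as described above. `my_free` and `my_wrapped_free` are
// hypothetical names.
//
//   void my_free(void* ctx) { ::free(ctx); }
//   void my_wrapped_free(void* ctx) {
//     /* extra bookkeeping on ctx */
//     my_free(ctx);
//   }
//
//   void* buf = ::malloc(16);
//   c10::DataPtr dp(buf, buf, &my_free, c10::Device(c10::DeviceType::CPU));
//   // Safe only because we know the current deleter is exactly &my_free.
//   bool swapped = dp.compare_exchange_deleter(&my_free, &my_wrapped_free);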
// NB: Device is NOT tested for here; a CUDA nullptr is as much a nullptr as a
// CPU nullptr
inline bool operator==(const DataPtr& dp, std::nullptr_t) noexcept {
return !dp;
}
inline bool operator==(std::nullptr_t, const DataPtr& dp) noexcept {
return !dp;
}
inline bool operator!=(const DataPtr& dp, std::nullptr_t) noexcept {
return dp;
}
inline bool operator!=(std::nullptr_t, const DataPtr& dp) noexcept {
return dp;
}
// Note [raw_allocate/raw_deallocate and Thrust]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Thrust's support for custom allocators requires us to write something
// like this:
//
// class ThrustAllocator {
// char* allocate(size_t);
// void deallocate(char*, size_t);
// };
//
// This is not good for our unique_ptr based allocator interface, as
// there is no way to get to the context when we free.
//
// However, in some cases the context is exactly the same as
// the data pointer. In this case, we can support the "raw"
// allocate and deallocate interface. This is what
// raw_deleter signifies. By default, it returns a nullptr, which means that
// the raw interface is not implemented. Be sure to implement it whenever
// possible, or the raw interface will be incorrectly reported as unsupported
// when it is actually available.
struct C10_API Allocator {
virtual ~Allocator() = default;
virtual DataPtr allocate(size_t n) = 0;
// Clones an allocation that came from this allocator.
//
// To perform the copy, this function calls `copy_data`, which
// must be implemented by derived classes.
//
// Note that this explicitly ignores any context that may have been
// attached to the input data.
//
// Requires: input data was allocated by the same allocator.
DataPtr clone(const void* data, std::size_t n);
// Checks if DataPtr has a simple context, not wrapped with any out of the
// ordinary contexts.
virtual bool is_simple_data_ptr(const DataPtr& data_ptr) const;
// If this returns a non-null pointer, it means that allocate()
// is guaranteed to return a unique_ptr with this deleter attached;
// it also means the raw_allocate and raw_deallocate APIs are safe to use.
// This function MUST always return the same deleter.
virtual DeleterFnPtr raw_deleter() const {
return nullptr;
}
void* raw_allocate(size_t n) {
auto dptr = allocate(n);
AT_ASSERT(dptr.get() == dptr.get_context());
return dptr.release_context();
}
void raw_deallocate(void* ptr) {
auto d = raw_deleter();
AT_ASSERT(d);
d(ptr);
}
// Copies data from one allocation to another.
// Pure virtual, so derived classes must define behavior.
// Derived class implementation can simply call `default_copy_data`
// to use `std::memcpy`.
//
// Requires: src and dest were allocated by this allocator
// Requires: src and dest both have length >= count
virtual void copy_data(void* dest, const void* src, std::size_t count)
const = 0;
protected:
// Uses `std::memcpy` to copy data.
// Child classes can use this as `copy_data` when an alternative copy
// API is not needed.
void default_copy_data(void* dest, const void* src, std::size_t count) const;
};
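// Minimal sketch of a concrete Allocator (illustrative only; the name
// `TrivialCPUAllocator` is hypothetical). Because data and context are the
// same pointer, raw_deleter() can be implemented, which makes raw_allocate()
// and raw_deallocate() usable as described in
// Note [raw_allocate/raw_deallocate and Thrust].
//
//   struct TrivialCPUAllocator final : c10::Allocator {
//     static void deleter(void* ptr) { ::free(ptr); }
//     c10::DataPtr allocate(size_t n) override {
//       void* data = n ? ::malloc(n) : nullptr;
//       return {data, data, &TrivialCPUAllocator::deleter,
//               c10::Device(c10::DeviceType::CPU)};
//     }
//     c10::DeleterFnPtr raw_deleter() const override {
//       return &TrivialCPUAllocator::deleter;
//     }
//     void copy_data(void* dest, const void* src, std::size_t count)
//         const override {
//       default_copy_data(dest, src, count); // plain memcpy is fine for CPU
//     }
//   };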
// This context is used to generate DataPtr which have arbitrary
// std::function deleters associated with them. In some user facing
// functions, we give a (user-friendly) interface for constructing
// tensors from external data which take an arbitrary std::function
// deleter. Grep for InefficientStdFunctionContext to find these
// occurrences.
//
// This context is inefficient because we have to do a dynamic
// allocation of the InefficientStdFunctionContext, on top of the dynamic
// allocation which is already implied by std::function itself.
struct C10_API InefficientStdFunctionContext {
void* ptr_;
std::function<void(void*)> deleter_;
InefficientStdFunctionContext(void* ptr, std::function<void(void*)> deleter)
: ptr_(ptr), deleter_(std::move(deleter)) {}
~InefficientStdFunctionContext() {
if (deleter_) {
deleter_(ptr_);
}
}
static DataPtr makeDataPtr(
void* ptr,
std::function<void(void*)> deleter,
Device device);
};
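// Illustrative sketch (hypothetical usage; `get_buffer` and `release_buffer`
// are placeholders): wrapping externally owned memory with an arbitrary
// std::function deleter.
//
//   void* external = get_buffer();
//   c10::DataPtr dp = c10::InefficientStdFunctionContext::makeDataPtr(
//       external,
//       [](void* p) { release_buffer(p); },
//       c10::Device(c10::DeviceType::CPU));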
/** Set the allocator for DeviceType `t`. The passed in allocator pointer is
* expected to have static lifetime; this function does NOT take ownership
* of the raw pointer. (The reason for this is to prevent existing pointers
* to an allocator of a particular device from being invalidated when
* SetAllocator is called.)
*
* Also note that this is not thread-safe, and we assume this function will
* only be called during initialization.
*
 * The 'priority' flag exists so that the default allocator, which is set
 * statically, can be overwritten. The default priority is 0, which is the
 * lowest. An allocator can only be overwritten by one registered with an
 * equal or higher priority.
*/
C10_API void SetAllocator(DeviceType t, Allocator* alloc, uint8_t priority = 0);
C10_API Allocator* GetAllocator(const DeviceType& t);
template <DeviceType t>
struct AllocatorRegisterer {
explicit AllocatorRegisterer(Allocator* alloc) {
SetAllocator(t, alloc);
}
};
#define REGISTER_ALLOCATOR(t, f) \
namespace { \
static c10::AllocatorRegisterer<t> g_allocator_d(f); \
}
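// Illustrative sketch of registration (hypothetical; `MyAllocator` is a
// placeholder type). The allocator must have static lifetime, since
// SetAllocator does not take ownership of the pointer.
//
//   static MyAllocator my_allocator;
//   REGISTER_ALLOCATOR(c10::DeviceType::CUDA, &my_allocator)
//
// Equivalently, SetAllocator can be called directly with a higher priority
// to replace a previously registered allocator:
//
//   c10::SetAllocator(c10::DeviceType::CUDA, &my_allocator, /*priority=*/1);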
// An interface for reporting thread local memory usage
// per device
struct C10_API MemoryReportingInfoBase : public c10::DebugInfoBase {
MemoryReportingInfoBase();
~MemoryReportingInfoBase() override = default;
/**
* alloc_size corresponds to the size of the ptr.
*
* total_allocated corresponds to total allocated memory.
*
* total_reserved corresponds to total size of memory pool, both used and
* unused, if applicable.
*/
virtual void reportMemoryUsage(
void* ptr,
int64_t alloc_size,
size_t total_allocated,
size_t total_reserved,
Device device) = 0;
virtual void reportOutOfMemory(
int64_t alloc_size,
size_t total_allocated,
size_t total_reserved,
Device device);
virtual bool memoryProfilingEnabled() const = 0;
};
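// Minimal sketch of a subclass (illustrative only; `LoggingMemoryReporter` is
// hypothetical). Such an object would presumably be installed through the
// ThreadLocalDebugInfo mechanism included above, so that
// reportMemoryUsageToProfiler()/memoryProfilingEnabled() can reach it.
//
//   struct LoggingMemoryReporter : c10::MemoryReportingInfoBase {
//     void reportMemoryUsage(
//         void* ptr,
//         int64_t alloc_size,
//         size_t total_allocated,
//         size_t total_reserved,
//         c10::Device device) override {
//       // e.g. record (ptr, alloc_size, device) somewhere
//     }
//     bool memoryProfilingEnabled() const override {
//       return true;
//     }
//   };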
C10_API bool memoryProfilingEnabled();
C10_API void reportMemoryUsageToProfiler(
void* ptr,
int64_t alloc_size,
size_t total_allocated,
size_t total_reserved,
Device device);
C10_API void reportOutOfMemoryToProfiler(
int64_t alloc_size,
size_t total_allocated,
size_t total_reserved,
Device device);
// used to hold traceback information in allocators
struct GatheredContext {
virtual ~GatheredContext() = default;
};
} // namespace c10

View File

@@ -0,0 +1,72 @@
#pragma once
#include <c10/macros/Export.h>
namespace c10 {
// Structure used to pack all the thread local boolean
// flags used by autograd
struct C10_API AutogradState {
static AutogradState& get_tls_state();
static void set_tls_state(AutogradState state);
AutogradState(
bool grad_mode,
bool inference_mode,
bool fw_grad_mode,
bool multithreading_enabled)
: grad_mode_(grad_mode),
inference_mode_(inference_mode),
fw_grad_mode_(fw_grad_mode),
multithreading_enabled_(multithreading_enabled),
view_replay_enabled_(false) {}
void set_grad_mode(bool enabled) {
grad_mode_ = enabled;
}
void set_fw_grad_mode(bool enabled) {
fw_grad_mode_ = enabled;
}
void set_inference_mode(bool enabled) {
inference_mode_ = enabled;
}
void set_multithreading_enabled(bool multithreading_enabled) {
multithreading_enabled_ = multithreading_enabled;
}
void set_view_replay_enabled(bool view_replay_enabled) {
view_replay_enabled_ = view_replay_enabled;
}
bool get_grad_mode() const {
return grad_mode_;
}
bool get_fw_grad_mode() const {
return fw_grad_mode_;
}
bool get_inference_mode() const {
return inference_mode_;
}
bool get_multithreading_enabled() const {
return multithreading_enabled_;
}
bool get_view_replay_enabled() const {
return view_replay_enabled_;
}
private:
bool grad_mode_ : 1;
bool inference_mode_ : 1;
bool fw_grad_mode_ : 1;
bool multithreading_enabled_ : 1;
bool view_replay_enabled_ : 1;
};
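// Illustrative sketch (hypothetical usage): toggling grad mode for the
// current thread through the TLS accessors; RAII guards elsewhere in PyTorch
// do essentially this on construction/destruction.
//
//   c10::AutogradState& state = c10::AutogradState::get_tls_state();
//   bool prev = state.get_grad_mode();
//   state.set_grad_mode(false);
//   // ... run code with grad disabled ...
//   state.set_grad_mode(prev);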
} // namespace c10

View File

@@ -0,0 +1,387 @@
#pragma once
#include <c10/core/DeviceType.h>
#include <c10/core/DispatchKey.h>
#include <c10/core/DispatchKeySet.h>
#include <c10/util/Exception.h>
#include <stdexcept>
namespace c10 {
/**
* This legacy enum class defines the set of backends supported by old school,
* code generated Type-based ATen. A "backend" in this sense roughly
* corresponds to the cartesian product of (device type, layout), but restricted
* only to combinations which we actually have kernels for. Backend does NOT
* include dtype.
*
* The reason we are sunsetting this enum class is because it doesn't allow for
* open registration; e.g., if you want to add SparseXLA, you'd have to
* edit this enum; you wouldn't be able to do it out of tree. DispatchKey is
* the replacement for Backend which supports open registration.
*
* NB: The concept of 'Backend' here disagrees with the notion of backend
* exposed to users in torch.backends. Backend here is something like "CPU"
* or "SparseCUDA"; backend in torch.backends is something like "MKL" or
* "CUDNN".
*/
enum class Backend {
CPU,
CUDA,
HIP,
VE,
FPGA,
IPU,
XPU,
SparseCPU,
SparseCUDA,
SparseCsrCPU,
SparseCsrCUDA,
SparseHIP,
SparseVE,
SparseXPU,
SparsePrivateUse1,
SparseCsrHIP,
SparseCsrVE,
SparseCsrXPU,
SparseCsrPrivateUse1,
MAIA,
XLA,
Vulkan,
Metal,
Meta,
QuantizedCPU,
QuantizedCUDA,
QuantizedXPU,
QuantizedPrivateUse1,
Undefined,
MkldnnCPU,
MPS,
HPU,
Lazy,
MTIA,
PrivateUse1,
NumOptions
};
inline Backend dispatchKeyToBackend(DispatchKey t) {
if (t == DispatchKey::CPU || t == DispatchKey::AutogradCPU) {
return Backend::CPU;
} else if (t == DispatchKey::CUDA || t == DispatchKey::AutogradCUDA) {
return Backend::CUDA;
} else if (t == DispatchKey::HIP) {
return Backend::HIP;
} else if (t == DispatchKey::VE) {
return Backend::VE;
} else if (t == DispatchKey::FPGA) {
return Backend::FPGA;
} else if (t == DispatchKey::MAIA) {
return Backend::MAIA;
} else if (t == DispatchKey::XLA || t == DispatchKey::AutogradXLA) {
return Backend::XLA;
} else if (t == DispatchKey::Lazy || t == DispatchKey::AutogradLazy) {
return Backend::Lazy;
} else if (t == DispatchKey::MPS || t == DispatchKey::AutogradMPS) {
return Backend::MPS;
} else if (t == DispatchKey::Vulkan) {
return Backend::Vulkan;
} else if (t == DispatchKey::Metal) {
return Backend::Metal;
} else if (t == DispatchKey::Meta) {
return Backend::Meta;
} else if (t == DispatchKey::SparseCPU) {
return Backend::SparseCPU;
} else if (t == DispatchKey::SparseCUDA) {
return Backend::SparseCUDA;
} else if (t == DispatchKey::SparseHIP) {
return Backend::SparseHIP;
} else if (t == DispatchKey::SparseVE) {
return Backend::SparseVE;
} else if (t == DispatchKey::SparsePrivateUse1) {
return Backend::SparsePrivateUse1;
} else if (t == DispatchKey::SparseCsrCPU) {
return Backend::SparseCsrCPU;
} else if (t == DispatchKey::SparseCsrCUDA) {
return Backend::SparseCsrCUDA;
} else if (t == DispatchKey::SparseCsrHIP) {
return Backend::SparseCsrHIP;
} else if (t == DispatchKey::SparseCsrVE) {
return Backend::SparseCsrVE;
} else if (t == DispatchKey::SparseCsrPrivateUse1) {
return Backend::SparseCsrPrivateUse1;
} else if (t == DispatchKey::MkldnnCPU) {
return Backend::MkldnnCPU;
} else if (t == DispatchKey::QuantizedCPU) {
return Backend::QuantizedCPU;
} else if (t == DispatchKey::QuantizedCUDA) {
return Backend::QuantizedCUDA;
} else if (t == DispatchKey::IPU || t == DispatchKey::AutogradIPU) {
return Backend::IPU;
} else if (t == DispatchKey::XPU || t == DispatchKey::AutogradXPU) {
return Backend::XPU;
} else if (t == DispatchKey::SparseXPU) {
return Backend::SparseXPU;
} else if (t == DispatchKey::SparseCsrXPU) {
return Backend::SparseCsrXPU;
} else if (t == DispatchKey::QuantizedXPU) {
return Backend::QuantizedXPU;
} else if (t == DispatchKey::QuantizedPrivateUse1) {
return Backend::QuantizedPrivateUse1;
} else if (t == DispatchKey::HPU || t == DispatchKey::AutogradHPU) {
return Backend::HPU;
} else if (t == DispatchKey::MTIA || t == DispatchKey::AutogradMTIA) {
return Backend::MTIA;
} else if (
t == DispatchKey::PrivateUse1 || t == DispatchKey::AutogradPrivateUse1) {
return Backend::PrivateUse1;
} else if (t == DispatchKey::Undefined) {
return Backend::Undefined;
} else {
TORCH_CHECK(false, "Unrecognized tensor type ID: ", t);
}
}
inline DispatchKey backendToDispatchKey(Backend b) {
switch (b) {
case Backend::CPU:
return DispatchKey::CPU;
case Backend::CUDA:
return DispatchKey::CUDA;
case Backend::HIP:
return DispatchKey::HIP;
case Backend::VE:
return DispatchKey::VE;
case Backend::FPGA:
return DispatchKey::FPGA;
case Backend::MAIA:
return DispatchKey::MAIA;
case Backend::XLA:
return DispatchKey::XLA;
case Backend::Lazy:
return DispatchKey::Lazy;
case Backend::IPU:
return DispatchKey::IPU;
case Backend::XPU:
return DispatchKey::XPU;
case Backend::SparseXPU:
return DispatchKey::SparseXPU;
case Backend::SparseCsrXPU:
return DispatchKey::SparseCsrXPU;
case Backend::SparseCPU:
return DispatchKey::SparseCPU;
case Backend::SparseCUDA:
return DispatchKey::SparseCUDA;
case Backend::SparseHIP:
return DispatchKey::SparseHIP;
case Backend::SparseVE:
return DispatchKey::SparseVE;
case Backend::SparsePrivateUse1:
return DispatchKey::SparsePrivateUse1;
case Backend::SparseCsrCPU:
return DispatchKey::SparseCsrCPU;
case Backend::SparseCsrCUDA:
return DispatchKey::SparseCsrCUDA;
case Backend::SparseCsrHIP:
return DispatchKey::SparseCsrHIP;
case Backend::SparseCsrVE:
return DispatchKey::SparseCsrVE;
case Backend::SparseCsrPrivateUse1:
return DispatchKey::SparseCsrPrivateUse1;
case Backend::MkldnnCPU:
return DispatchKey::MkldnnCPU;
case Backend::Vulkan:
return DispatchKey::Vulkan;
case Backend::Metal:
return DispatchKey::Metal;
case Backend::Meta:
return DispatchKey::Meta;
case Backend::QuantizedCPU:
return DispatchKey::QuantizedCPU;
case Backend::QuantizedCUDA:
return DispatchKey::QuantizedCUDA;
case Backend::QuantizedPrivateUse1:
return DispatchKey::QuantizedPrivateUse1;
case Backend::Undefined:
return DispatchKey::Undefined;
case Backend::MPS:
return DispatchKey::MPS;
case Backend::HPU:
return DispatchKey::HPU;
case Backend::MTIA:
return DispatchKey::MTIA;
case Backend::PrivateUse1:
return DispatchKey::PrivateUse1;
default:
throw std::runtime_error("Unknown backend");
}
}
inline DeviceType backendToDeviceType(Backend b) {
switch (b) {
case Backend::CPU:
case Backend::MkldnnCPU:
case Backend::SparseCPU:
case Backend::SparseCsrCPU:
case Backend::QuantizedCPU:
return DeviceType::CPU;
case Backend::CUDA:
case Backend::SparseCUDA:
case Backend::QuantizedCUDA:
case Backend::SparseCsrCUDA:
return DeviceType::CUDA;
case Backend::HIP:
return DeviceType::HIP;
case Backend::VE:
return DeviceType::VE;
case Backend::FPGA:
return DeviceType::FPGA;
case Backend::MAIA:
return DeviceType::MAIA;
case Backend::XLA:
return DeviceType::XLA;
case Backend::Lazy:
return DeviceType::Lazy;
case Backend::SparseHIP:
return DeviceType::HIP;
case Backend::SparseVE:
return DeviceType::VE;
case Backend::SparseCsrHIP:
return DeviceType::HIP;
case Backend::SparseCsrVE:
return DeviceType::VE;
case Backend::IPU:
return DeviceType::IPU;
case Backend::XPU:
case Backend::SparseXPU:
case Backend::SparseCsrXPU:
case Backend::QuantizedXPU:
return DeviceType::XPU;
case Backend::Vulkan:
return DeviceType::Vulkan;
case Backend::Metal:
return DeviceType::Metal;
case Backend::Meta:
return DeviceType::Meta;
case Backend::MPS:
return DeviceType::MPS;
case Backend::HPU:
return DeviceType::HPU;
case Backend::MTIA:
return DeviceType::MTIA;
case Backend::PrivateUse1:
case Backend::SparsePrivateUse1:
case Backend::SparseCsrPrivateUse1:
case Backend::QuantizedPrivateUse1:
return DeviceType::PrivateUse1;
case Backend::Undefined:
TORCH_CHECK(false, "Undefined backend is not a valid device type");
default:
TORCH_CHECK(false, "Unknown backend");
}
}
inline const char* toString(Backend b) {
switch (b) {
case Backend::CPU:
return "CPU";
case Backend::CUDA:
return "CUDA";
case Backend::HIP:
return "HIP";
case Backend::VE:
return "VE";
case Backend::FPGA:
return "FPGA";
case Backend::XPU:
return "XPU";
case Backend::IPU:
return "IPU";
case Backend::MAIA:
return "MAIA";
case Backend::XLA:
return "XLA";
case Backend::Lazy:
return "Lazy";
case Backend::MPS:
return "MPS";
case Backend::SparseCPU:
return "SparseCPU";
case Backend::SparseCUDA:
return "SparseCUDA";
case Backend::SparseHIP:
return "SparseHIP";
case Backend::SparseVE:
return "SparseVE";
case Backend::SparseXPU:
return "SparseXPU";
case Backend::SparsePrivateUse1:
return "SparsePrivateUse1";
case Backend::SparseCsrCPU:
return "SparseCsrCPU";
case Backend::SparseCsrCUDA:
return "SparseCsrCUDA";
case Backend::SparseCsrHIP:
return "SparseCsrHIP";
case Backend::SparseCsrVE:
return "SparseCsrVE";
case Backend::SparseCsrXPU:
return "SparseCsrXPU";
case Backend::SparseCsrPrivateUse1:
return "SparseCsrPrivateUse1";
case Backend::MkldnnCPU:
return "MkldnnCPU";
case Backend::Vulkan:
return "Vulkan";
case Backend::Metal:
return "Metal";
case Backend::Meta:
return "Meta";
case Backend::QuantizedCPU:
return "QuantizedCPU";
case Backend::QuantizedCUDA:
return "QuantizedCUDA";
case Backend::QuantizedXPU:
return "QuantizedXPU";
case Backend::QuantizedPrivateUse1:
return "QuantizedPrivateUse1";
case Backend::HPU:
return "HPU";
case Backend::MTIA:
return "MTIA";
case Backend::PrivateUse1:
return "PrivateUseOne";
default:
return "UNKNOWN_BACKEND";
}
}
inline bool isSparse(Backend b) {
switch (b) {
case Backend::SparseXPU:
case Backend::SparseCPU:
case Backend::SparseCUDA:
case Backend::SparseHIP:
case Backend::SparseVE:
case Backend::SparsePrivateUse1:
return true;
default:
return false;
}
}
inline bool isSparseCsr(Backend b) {
switch (b) {
case Backend::SparseCsrXPU:
case Backend::SparseCsrCPU:
case Backend::SparseCsrCUDA:
case Backend::SparseCsrHIP:
case Backend::SparseCsrVE:
case Backend::SparseCsrPrivateUse1:
return true;
default:
return false;
}
}
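// Illustrative sketch of the conversion helpers above (hypothetical usage):
//
//   c10::Backend b = c10::Backend::SparseCUDA;
//   c10::DispatchKey k = c10::backendToDispatchKey(b); // DispatchKey::SparseCUDA
//   c10::DeviceType d = c10::backendToDeviceType(b);   // DeviceType::CUDA
//   TORCH_INTERNAL_ASSERT(c10::dispatchKeyToBackend(k) == b);
//   TORCH_INTERNAL_ASSERT(c10::isSparse(b) && !c10::isSparseCsr(b));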
} // namespace c10

View File

@@ -0,0 +1,59 @@
#pragma once
#include <cstdint>
#include <cstring>
#include <mutex>
#include <unordered_map>
#include <c10/core/Allocator.h>
#include <c10/macros/Export.h>
#include <c10/util/Flags.h>
// TODO: rename to c10
C10_DECLARE_bool(caffe2_report_cpu_memory_usage);
namespace c10 {
using MemoryDeleter = void (*)(void*);
// A helper function that is basically doing nothing.
C10_API void NoDelete(void*);
// A simple class used to report C10's memory allocations, deallocations,
// and out-of-memory events to the profiler
class C10_API ProfiledCPUMemoryReporter {
public:
ProfiledCPUMemoryReporter() = default;
void New(void* ptr, size_t nbytes);
void OutOfMemory(size_t nbytes);
void Delete(void* ptr);
private:
std::mutex mutex_;
std::unordered_map<void*, size_t> size_table_;
size_t allocated_ = 0;
size_t log_cnt_ = 0;
};
C10_API ProfiledCPUMemoryReporter& profiledCPUMemoryReporter();
// Get the CPU Allocator.
C10_API at::Allocator* GetCPUAllocator();
// Sets the CPU allocator to the given allocator: the caller gives away the
// ownership of the pointer.
C10_API void SetCPUAllocator(at::Allocator* alloc, uint8_t priority = 0);
// Get the Default CPU Allocator
C10_API at::Allocator* GetDefaultCPUAllocator();
// Get the Default Mobile CPU Allocator
C10_API at::Allocator* GetDefaultMobileCPUAllocator();
// The CPUCachingAllocator is experimental and might disappear in the future.
// The only place that uses it is in StaticRuntime.
// Set the CPU Caching Allocator
C10_API void SetCPUCachingAllocator(Allocator* alloc, uint8_t priority = 0);
// Get the CPU Caching Allocator
C10_API Allocator* GetCPUCachingAllocator();
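// Illustrative sketch (hypothetical usage): allocating through the registered
// CPU allocator; the returned DataPtr frees the memory when it goes out of
// scope.
//
//   c10::Allocator* alloc = c10::GetCPUAllocator();
//   c10::DataPtr dp = alloc->allocate(1024); // 1024 bytes
//   float* data = static_cast<float*>(dp.mutable_get());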
} // namespace c10

View File

@@ -0,0 +1,131 @@
#pragma once
#include <c10/core/Allocator.h>
#include <c10/util/irange.h>
#include <array>
namespace c10::CachingDeviceAllocator {
struct Stat {
void increase(size_t amount) {
current += static_cast<int64_t>(amount);
peak = std::max(current, peak);
allocated += static_cast<int64_t>(amount);
}
void decrease(size_t amount) {
current -= static_cast<int64_t>(amount);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
current >= 0,
"Negative tracked stat in device allocator (likely logic error).");
freed += static_cast<int64_t>(amount);
}
void reset_accumulated() {
allocated = 0;
freed = 0;
}
void reset_peak() {
peak = current;
}
int64_t current = 0;
int64_t peak = 0;
int64_t allocated = 0;
int64_t freed = 0;
};
enum struct StatType : uint64_t {
AGGREGATE = 0,
SMALL_POOL = 1,
LARGE_POOL = 2,
NUM_TYPES = 3 // remember to update this whenever a new stat type is added
};
using StatArray = std::array<Stat, static_cast<size_t>(StatType::NUM_TYPES)>;
using StatTypes = std::array<bool, static_cast<size_t>(StatType::NUM_TYPES)>;
template <typename Func>
void for_each_selected_stat_type(const StatTypes& stat_types, Func f) {
for (const auto stat_type : c10::irange(stat_types.size())) {
if (stat_types[stat_type]) {
f(stat_type);
}
}
}
// Struct containing memory allocator summary statistics for a device.
struct DeviceStats {
// COUNT: allocations requested by client code
StatArray allocation;
// COUNT: number of allocated segments from device memory allocation.
StatArray segment;
// COUNT: number of active memory blocks (allocated or used by stream)
StatArray active;
// COUNT: number of inactive, split memory blocks (unallocated but can't be
// released via device memory deallocation)
StatArray inactive_split;
// SUM: bytes allocated by this memory allocator
StatArray allocated_bytes;
// SUM: bytes reserved by this memory allocator (both free and used)
StatArray reserved_bytes;
// SUM: bytes within active memory blocks
StatArray active_bytes;
// SUM: bytes within inactive, split memory blocks
StatArray inactive_split_bytes;
// SUM: bytes requested by client code
StatArray requested_bytes;
// COUNT: total number of failed calls to device malloc necessitating cache
// flushes.
int64_t num_alloc_retries = 0;
// COUNT: total number of OOMs (i.e. failed calls to device memory allocation
// after cache flush)
int64_t num_ooms = 0;
// COUNT: total number of oversize blocks allocated from pool
Stat oversize_allocations;
// COUNT: total number of oversize blocks requiring malloc
Stat oversize_segments;
// COUNT: total number of synchronize_and_free_events() calls
int64_t num_sync_all_streams = 0;
// COUNT: total number of device memory allocation calls. This includes both
// mapped and malloced memory.
int64_t num_device_alloc = 0;
// COUNT: total number of device memory deallocation calls. This includes both
// unmapped and freed memory.
int64_t num_device_free = 0;
// SIZE: maximum block size that is allowed to be split.
int64_t max_split_size = 0;
};
// Size pretty-printer
inline std::string format_size(uint64_t size) {
std::ostringstream os;
os.precision(2);
os << std::fixed;
if (size <= 1024) {
os << size << " bytes";
} else if (size <= 1048576) {
os << (static_cast<double>(size) / 1024.0);
os << " KiB";
} else if (size <= 1073741824ULL) {
os << static_cast<double>(size) / 1048576.0;
os << " MiB";
} else {
os << static_cast<double>(size) / 1073741824.0;
os << " GiB";
}
return os.str();
}
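// Illustrative sketch of Stat and format_size (hypothetical values):
//
//   Stat s;
//   s.increase(3 * 1048576); // current = peak = allocated = 3 MiB (in bytes)
//   s.decrease(1048576);     // current = 2 MiB, freed = 1 MiB
//   // format_size(2 * 1048576) yields "2.00 MiB"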
} // namespace c10::CachingDeviceAllocator

View File

@@ -0,0 +1,57 @@
#pragma once
#include <c10/util/TypeTraits.h>
#include <type_traits>
namespace c10 {
/**
 * Represents a function pointer as a C++ type.
 * This allows using the function pointer as a type in a template.
 * Calling it from inside the template then lets the compiler inline
 * the call, because it knows the function pointer at compile time.
*
* Example 1:
* int add(int a, int b) {return a + b;}
* using Add = TORCH_FN_TYPE(add);
* template<class Func> struct Executor {
* int execute(int a, int b) {
* return Func::func_ptr()(a, b);
* }
* };
* Executor<Add> executor;
* EXPECT_EQ(3, executor.execute(1, 2));
*
* Example 2:
* int add(int a, int b) {return a + b;}
* template<class Func> int execute(Func, int a, int b) {
* return Func::func_ptr()(a, b);
* }
* EXPECT_EQ(3, execute(TORCH_FN(add), 1, 2));
*/
template <class FuncType_, FuncType_* func_ptr_>
struct CompileTimeFunctionPointer final {
static_assert(
guts::is_function_type<FuncType_>::value,
"TORCH_FN can only wrap function types.");
using FuncType = FuncType_;
static constexpr FuncType* func_ptr() {
return func_ptr_;
}
};
template <class T>
struct is_compile_time_function_pointer : std::false_type {};
template <class FuncType, FuncType* func_ptr>
struct is_compile_time_function_pointer<
CompileTimeFunctionPointer<FuncType, func_ptr>> : std::true_type {};
} // namespace c10
#define TORCH_FN_TYPE(func) \
::c10::CompileTimeFunctionPointer< \
std::remove_pointer_t<std::remove_reference_t<decltype(func)>>, \
func>
#define TORCH_FN(func) TORCH_FN_TYPE(func)()
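// Illustrative sketch (hypothetical example, mirroring the doc comment above):
// the trait can be used to constrain templates to TORCH_FN-wrapped functions.
//
//   int add(int a, int b) { return a + b; }
//   static_assert(
//       c10::is_compile_time_function_pointer<TORCH_FN_TYPE(add)>::value, "");
//   static_assert(
//       !c10::is_compile_time_function_pointer<decltype(&add)>::value, "");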

View File

@@ -0,0 +1,110 @@
#pragma once
#include <c10/core/SymNodeImpl.h>
#include <c10/macros/Export.h>
#include <c10/util/Exception.h>
#include <cstdint>
#include <optional>
#include <string>
#include <variant>
namespace c10 {
// Unlike other SymNodeImpls, this one cannot be "dispatched" conventionally,
// as it typically needs to defer to another SymNodeImpl.
//
// It can represent either a bool or an int (floats are not supported yet);
// this is useful for representing otherwise unrepresentable large negative
// integer constants.
template <typename T>
class C10_API ConstantSymNodeImpl : public SymNodeImpl {
static_assert(
::std::is_same_v<T, int64_t> || ::std::is_same_v<T, bool>,
"ConstantSymNodeImpl can only accept int64_t or bool types");
public:
ConstantSymNodeImpl(T val) : value_(val) {}
bool is_int() override {
return is_int_();
}
bool is_bool() override {
return is_bool_();
}
bool is_float() override {
return false;
}
int64_t guard_int(
const char* file [[maybe_unused]],
int64_t line [[maybe_unused]]) override {
TORCH_CHECK(is_int(), "not an int");
return int_();
}
bool guard_bool(
const char* file [[maybe_unused]],
int64_t line [[maybe_unused]]) override {
TORCH_CHECK(is_bool(), "not a bool");
return bool_();
}
double guard_float(
const char* file [[maybe_unused]],
int64_t line [[maybe_unused]]) override {
TORCH_CHECK(false, "not a float");
}
int64_t int_() override {
TORCH_CHECK(is_int(), "not an int");
return ::std::get<int64_t>(value_);
}
bool bool_() override {
TORCH_CHECK(is_bool(), "not a bool");
return ::std::get<bool>(value_);
}
bool has_hint() override {
return true;
}
c10::SymNode eq(const c10::SymNode& other) override;
c10::SymNode ne(const c10::SymNode& other) override;
c10::SymNode ge(const c10::SymNode& other) override;
c10::SymNode le(const c10::SymNode& other) override;
c10::SymNode lt(const c10::SymNode& other) override;
c10::SymNode gt(const c10::SymNode& other) override;
c10::SymNode mul(const c10::SymNode& other) override;
::std::string str() override {
if constexpr (is_int_()) {
return ::std::to_string(::std::get<int64_t>(value_));
} else {
return ::std::get<bool>(value_) ? "true" : "false";
}
}
std::optional<int64_t> constant_int() override {
if constexpr (is_int_()) {
return ::std::get<int64_t>(value_);
} else {
return std::nullopt;
}
}
std::optional<bool> constant_bool() override {
if constexpr (is_bool_()) {
return ::std::get<bool>(value_);
} else {
return std::nullopt;
}
}
bool is_constant() override {
return true;
}
bool is_symbolic() override {
return false;
}
private:
::std::variant<int64_t, bool> value_;
static constexpr bool is_int_() {
return ::std::is_same_v<T, int64_t>;
}
static constexpr bool is_bool_() {
return ::std::is_same_v<T, bool>;
}
};
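// Illustrative sketch (hypothetical usage): wrapping a large constant as a
// SymNode so it can participate in symbolic comparisons.
//
//   c10::SymNode node = c10::make_intrusive<c10::ConstantSymNodeImpl<int64_t>>(
//       -(int64_t(1) << 40));
//   TORCH_INTERNAL_ASSERT(node->is_int() && !node->is_bool());
//   TORCH_INTERNAL_ASSERT(node->constant_int() == -(int64_t(1) << 40));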
} // namespace c10

View File

@@ -0,0 +1,129 @@
#pragma once
#include <c10/core/SymBool.h>
#include <c10/core/SymInt.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/SmallVector.h>
#include <c10/util/irange.h>
#include <algorithm>
#include <cstdint>
namespace c10 {
template <typename T>
bool _compute_contiguous(ArrayRef<T> sizes, ArrayRef<T> strides, T numel) {
bool is_contiguous = true;
if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(numel, 0))) {
return is_contiguous;
}
T z = 1;
// NB: make sure we do signed arithmetic
for (int64_t d = int64_t(sizes.size()) - 1; d >= 0; d--) {
const auto& size_d = sizes[d];
if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(size_d, 1))) {
if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(strides[d], z))) {
z *= size_d;
} else {
is_contiguous = false;
break;
}
}
}
return is_contiguous;
}
template <typename T>
bool _compute_channels_last_contiguous_2d(
ArrayRef<T> sizes,
ArrayRef<T> strides) {
// Please don't merge these cases; the constant array is used here to let
// the compiler fully unroll the loop, which gives better performance
switch (sizes.size()) {
case 4: {
T expected = 1;
for (auto& d : {1, 3, 2, 0}) {
const auto& size_d = sizes[d];
if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(size_d, 1))) {
if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(strides[d], expected))) {
return false;
}
expected *= size_d;
}
}
return true;
}
// NOLINTNEXTLINE(bugprone-branch-clone)
case 3:
// TODO dim == 3 case will be enabled once it is fully tested
return false;
default:
return false;
}
}
template <typename T>
bool _compute_channels_last_contiguous_3d(
ArrayRef<T> sizes,
ArrayRef<T> strides) {
// Please don't merge these cases; the constant array is used here to let
// the compiler fully unroll the loop, which gives better performance
switch (sizes.size()) {
case 5: {
T expected = 1;
for (auto& d : {1, 4, 3, 2, 0}) {
const auto& size_d = sizes[d];
if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(size_d, 1))) {
if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(strides[d], expected))) {
return false;
}
expected *= size_d;
}
}
return true;
}
// NOLINTNEXTLINE(bugprone-branch-clone)
case 4:
// TODO dim == 4 case will be enabled once it is fully tested
return false;
default:
return false;
}
}
template <typename T>
bool _compute_non_overlapping_and_dense(
ArrayRef<T> sizes,
ArrayRef<T> strides) {
auto dim = sizes.size();
if (dim == 1) {
return sizes[0] < 2 || strides[0] == 1;
}
SmallVector<int64_t, 5> perm;
perm.resize(dim);
for (const auto i : c10::irange(dim)) {
perm[i] = i;
}
// Sort by strides, leaving 0 and 1 sized dims at the end of the array
std::sort(perm.begin(), perm.end(), [&](int64_t a, int64_t b) {
if (sizes[a] < 2) {
return false;
} else if (sizes[b] < 2) {
return true;
}
return strides[a] < strides[b];
});
T require_stride = 1;
for (const auto i : c10::irange(dim)) {
const auto& size_perm_i = sizes[perm[i]];
if (size_perm_i < 2) {
return true;
}
if (strides[perm[i]] != require_stride) {
return false;
}
require_stride *= size_perm_i;
}
return true;
}
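// Illustrative sketch (hypothetical sizes/strides): a 2x3 row-major tensor has
// strides {3, 1}, so the helpers above report it as contiguous and dense.
//
//   int64_t sizes[] = {2, 3};
//   int64_t strides[] = {3, 1};
//   bool c = c10::_compute_contiguous<int64_t>(sizes, strides, /*numel=*/6);
//   bool d = c10::_compute_non_overlapping_and_dense<int64_t>(sizes, strides);
//   // both c and d are true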
} // namespace c10

View File

@@ -0,0 +1,48 @@
#pragma once
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/macros/Export.h>
#include <c10/macros/Macros.h>
#include <cstddef>
namespace c10 {
using CopyBytesFunction = void (*)(
size_t nbytes,
const void* src,
Device src_device,
void* dst,
Device dst_device);
struct C10_API _CopyBytesFunctionRegisterer {
_CopyBytesFunctionRegisterer(
DeviceType from,
DeviceType to,
CopyBytesFunction func_sync,
CopyBytesFunction func_async = nullptr);
};
#define REGISTER_COPY_BYTES_FUNCTION(from, to, ...) \
namespace { \
static _CopyBytesFunctionRegisterer C10_ANONYMOUS_VARIABLE( \
g_copy_function)(from, to, __VA_ARGS__); \
}
/*
 * WARNING: Implementations for this function are currently registered from
 * ATen and caffe2, not yet from c10. Don't use this unless ATen or caffe2
 * is present as well.
 * We can't move them yet, because the CUDA implementations aren't unified
 * between ATen and caffe2.
 * We're planning to move the implementations into c10/backend/xxx
 * to make c10 self-contained again.
*/
C10_API void CopyBytes(
size_t nbytes,
const void* src,
Device src_device,
void* dst,
Device dst_device,
bool async);
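// Illustrative sketch (hypothetical; the real registrations currently live in
// ATen/caffe2 as noted above): registering a synchronous CPU -> CPU copy.
//
//   static void my_cpu_copy(
//       size_t nbytes,
//       const void* src,
//       c10::Device /*src_device*/,
//       void* dst,
//       c10::Device /*dst_device*/) {
//     std::memcpy(dst, src, nbytes); // needs <cstring>
//   }
//   REGISTER_COPY_BYTES_FUNCTION(
//       c10::DeviceType::CPU, c10::DeviceType::CPU, my_cpu_copy)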
} // namespace c10

View File

@@ -0,0 +1,15 @@
#pragma once
#include <c10/core/ScalarType.h>
#include <c10/macros/Export.h>
namespace caffe2 {
class TypeMeta;
} // namespace caffe2
namespace c10 {
C10_API void set_default_dtype(caffe2::TypeMeta dtype);
C10_API const caffe2::TypeMeta get_default_dtype();
C10_API ScalarType get_default_dtype_as_scalartype();
C10_API const caffe2::TypeMeta get_default_complex_dtype();
} // namespace c10

View File

@@ -0,0 +1,45 @@
#pragma once
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/Layout.h>
#include <c10/core/ScalarType.h>
#include <c10/util/typeid.h>
namespace c10 {
struct TensorOptions;
/// Like TensorOptions, but all fields are guaranteed to be filled.
struct DefaultTensorOptions {
DefaultTensorOptions() = default;
caffe2::TypeMeta dtype() const noexcept {
return dtype_;
}
Device device() const noexcept {
return device_;
}
Layout layout() const noexcept {
return layout_;
}
bool requires_grad() const noexcept {
return requires_grad_;
}
// Defined in TensorOptions.h
inline DefaultTensorOptions& merge(const TensorOptions& options);
private:
caffe2::TypeMeta dtype_ = caffe2::TypeMeta::Make<float>(); // 64-bit
Device device_ = at::kCPU; // 32-bit
Layout layout_ = at::kStrided; // 8-bit
bool requires_grad_ = false; // 8-bit
};
inline const DefaultTensorOptions& getDefaultTensorOptions() {
static const auto options = DefaultTensorOptions();
return options;
}
} // namespace c10

View File

@@ -0,0 +1,216 @@
#pragma once
#include <c10/core/DeviceType.h>
#include <c10/macros/Export.h>
#include <c10/util/Exception.h>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iosfwd>
#include <string>
namespace c10 {
/// An index representing a specific device; e.g., the 1 in GPU 1.
/// A DeviceIndex is not independently meaningful without knowing
/// the DeviceType it is associated with; prefer to use Device rather
/// than DeviceIndex directly.
using DeviceIndex = int8_t;
/// Represents a compute device on which a tensor is located. A device is
/// uniquely identified by a type, which specifies the type of machine it is
/// (e.g. CPU or CUDA GPU), and a device index or ordinal, which identifies the
/// specific compute device when there is more than one of a certain type. The
/// device index is optional, and in its defaulted state represents (abstractly)
/// "the current device". Further, there are two constraints on the value of the
/// device index, if one is explicitly stored:
/// 1. A negative index represents the current device, a non-negative index
/// represents a specific, concrete device,
/// 2. When the device type is CPU, the device index must be zero.
struct C10_API Device final {
using Type = DeviceType;
/// Constructs a new `Device` from a `DeviceType` and an optional device
/// index.
/* implicit */ Device(DeviceType type, DeviceIndex index = -1)
: type_(type), index_(index) {
validate();
}
/// Constructs a `Device` from a string description, for convenience.
/// The string supplied must follow the following schema:
/// `(cpu|cuda)[:<device-index>]`
/// where `cpu` or `cuda` specifies the device type, and
/// `:<device-index>` optionally specifies a device index.
/* implicit */ Device(const std::string& device_string);
/// Returns true if the type and index of this `Device` matches that of
/// `other`.
bool operator==(const Device& other) const noexcept {
return this->type_ == other.type_ && this->index_ == other.index_;
}
/// Returns true if the type or index of this `Device` differs from that of
/// `other`.
bool operator!=(const Device& other) const noexcept {
return !(*this == other);
}
/// Sets the device index.
void set_index(DeviceIndex index) {
index_ = index;
}
/// Returns the type of device this is.
DeviceType type() const noexcept {
return type_;
}
/// Returns the optional index.
DeviceIndex index() const noexcept {
return index_;
}
/// Returns true if the device has a non-default index.
bool has_index() const noexcept {
return index_ != -1;
}
/// Return true if the device is of CUDA type.
bool is_cuda() const noexcept {
return type_ == DeviceType::CUDA;
}
/// Return true if the device is of PrivateUse1 type.
bool is_privateuseone() const noexcept {
return type_ == DeviceType::PrivateUse1;
}
/// Return true if the device is of MPS type.
bool is_mps() const noexcept {
return type_ == DeviceType::MPS;
}
/// Return true if the device is of HIP type.
bool is_hip() const noexcept {
return type_ == DeviceType::HIP;
}
/// Return true if the device is of VE type.
bool is_ve() const noexcept {
return type_ == DeviceType::VE;
}
/// Return true if the device is of XPU type.
bool is_xpu() const noexcept {
return type_ == DeviceType::XPU;
}
/// Return true if the device is of IPU type.
bool is_ipu() const noexcept {
return type_ == DeviceType::IPU;
}
/// Return true if the device is of XLA type.
bool is_xla() const noexcept {
return type_ == DeviceType::XLA;
}
/// Return true if the device is of MTIA type.
bool is_mtia() const noexcept {
return type_ == DeviceType::MTIA;
}
/// Return true if the device is of HPU type.
bool is_hpu() const noexcept {
return type_ == DeviceType::HPU;
}
/// Return true if the device is of Lazy type.
bool is_lazy() const noexcept {
return type_ == DeviceType::Lazy;
}
/// Return true if the device is of Vulkan type.
bool is_vulkan() const noexcept {
return type_ == DeviceType::Vulkan;
}
/// Return true if the device is of Metal type.
bool is_metal() const noexcept {
return type_ == DeviceType::Metal;
}
/// Return true if the device is of MAIA type.
bool is_maia() const noexcept {
return type_ == DeviceType::MAIA;
}
/// Return true if the device is of META type.
bool is_meta() const noexcept {
return type_ == DeviceType::Meta;
}
/// Return true if the device is of CPU type.
bool is_cpu() const noexcept {
return type_ == DeviceType::CPU;
}
/// Return true if the device supports arbitrary strides.
bool supports_as_strided() const noexcept {
return type_ != DeviceType::IPU && type_ != DeviceType::XLA &&
type_ != DeviceType::Lazy && type_ != DeviceType::MTIA;
}
/// Same string as returned from operator<<.
std::string str() const;
private:
DeviceType type_;
DeviceIndex index_ = -1;
void validate() {
// Removing these checks in release builds noticeably improves
// performance in micro-benchmarks.
// This is safe to do, because backends that use the DeviceIndex
// have a later check when we actually try to switch to that device.
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
index_ >= -1,
"Device index must be -1 or non-negative, got ",
static_cast<int>(index_));
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
!is_cpu() || index_ <= 0,
"CPU device index must be -1 or zero, got ",
static_cast<int>(index_));
}
};
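// Illustrative sketch (hypothetical usage) of the constructors documented
// above:
//
//   c10::Device a(c10::DeviceType::CUDA, 1); // explicit type + index
//   c10::Device b("cuda:1");                 // parsed from a string
//   TORCH_INTERNAL_ASSERT(a == b && a.has_index() && a.index() == 1);
//   c10::Device c("cpu");                    // index defaults to -1
//   TORCH_INTERNAL_ASSERT(c.is_cpu() && !c.has_index());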
C10_API std::ostream& operator<<(std::ostream& stream, const Device& device);
} // namespace c10
namespace std {
template <>
struct hash<c10::Device> {
size_t operator()(c10::Device d) const noexcept {
// Are you here because this static assert failed? Make sure you ensure
// that the bitmasking code below is updated accordingly!
static_assert(sizeof(c10::DeviceType) == 1, "DeviceType is not 8-bit");
static_assert(sizeof(c10::DeviceIndex) == 1, "DeviceIndex is not 8-bit");
// Note [Hazard when concatenating signed integers]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// We must first convert to a same-sized unsigned type, before promoting to
// the result type, to prevent sign extension when any of the values is -1.
// If sign extension occurs, you'll clobber all of the values in the MSB
// half of the resulting integer.
//
// Technically, by C/C++ integer promotion rules, we only need one of the
// uint32_t casts to the result type, but we put in both for explicitness's
// sake.
uint32_t bits = static_cast<uint32_t>(static_cast<uint8_t>(d.type()))
<< 16 |
static_cast<uint32_t>(static_cast<uint8_t>(d.index()));
return std::hash<uint32_t>{}(bits);
}
};
} // namespace std

View File

@@ -0,0 +1,28 @@
#include <c10/core/Allocator.h>
#include <c10/util/Exception.h>
#include <cstddef>
#include <cstdint>
#include <type_traits>
namespace c10 {
template <typename T>
class DeviceArray {
public:
DeviceArray(c10::Allocator& allocator, size_t size)
: data_ptr_(allocator.allocate(size * sizeof(T))) {
static_assert(std::is_trivial<T>::value, "T must be a trivial type");
TORCH_INTERNAL_ASSERT(
0 == (reinterpret_cast<intptr_t>(data_ptr_.get()) % alignof(T)),
"c10::DeviceArray: Allocated memory is not aligned for this data type");
}
T* get() {
return static_cast<T*>(data_ptr_.get());
}
private:
c10::DataPtr data_ptr_;
};
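// Illustrative sketch (hypothetical usage): a typed scratch buffer whose
// storage is owned by the underlying DataPtr and freed when the array goes
// out of scope.
//
//   c10::Allocator* alloc = c10::GetAllocator(c10::DeviceType::CPU);
//   c10::DeviceArray<float> scratch(*alloc, /*size=*/256); // 256 floats
//   float* p = scratch.get();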
} // namespace c10

View File

@@ -0,0 +1,199 @@
#pragma once
#include <c10/core/Device.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/core/impl/InlineDeviceGuard.h>
#include <c10/core/impl/VirtualGuardImpl.h>
#include <c10/util/Optional.h>
namespace c10 {
/// RAII guard that sets a certain default device in its constructor, and
/// changes it back to the device that was originally active upon destruction.
///
/// The device is always reset to the one that was active at the time of
/// construction of the guard. Even if you `set_device` after construction, the
/// destructor will still reset the device to the one that was active at
/// construction time.
///
/// This device guard does NOT have an uninitialized state; it is guaranteed
/// to reset a device on exit. If you are in a situation where you *might*
/// want to setup a guard (i.e., are looking for the moral equivalent
/// of std::optional<DeviceGuard>), see OptionalDeviceGuard.
class DeviceGuard {
public:
/// No default constructor; see Note [Omitted default constructor from RAII]
explicit DeviceGuard() = delete;
/// Set the current device to the passed Device.
explicit DeviceGuard(Device device) : guard_(device) {}
/// This constructor is for testing only.
explicit DeviceGuard(
Device device,
const impl::DeviceGuardImplInterface* impl)
: guard_(device, impl) {}
/// Copy is disallowed
DeviceGuard(const DeviceGuard&) = delete;
DeviceGuard& operator=(const DeviceGuard&) = delete;
/// Move is disallowed, as DeviceGuard does not have an uninitialized state,
/// which is required for moves on types with nontrivial destructors.
DeviceGuard(DeviceGuard&& other) = delete;
DeviceGuard& operator=(DeviceGuard&& other) = delete;
/// Sets the device to the given one. The specified device must be consistent
/// with the device type originally specified during guard construction.
///
/// TODO: The consistency check here is inconsistent with StreamGuard's
/// behavior with set_stream, where a stream on a different device than
/// the original one isn't an error; we just reset the stream and then
/// switch devices.
void reset_device(at::Device device) {
guard_.reset_device(device);
}
/// This method is for testing only.
void reset_device(
at::Device device,
const impl::DeviceGuardImplInterface* impl) {
guard_.reset_device(device, impl);
}
/// Sets the device index to the given one. The device type is inferred
/// from the original device type the guard was constructed with.
void set_index(DeviceIndex index) {
guard_.set_index(index);
}
/// Returns the device that was set at the time the guard was constructed.
Device original_device() const {
return guard_.original_device();
}
/// Returns the most recent device that was set using this device guard,
/// either from construction, or via set_device.
Device current_device() const {
return guard_.current_device();
}
private:
impl::InlineDeviceGuard<impl::VirtualGuardImpl> guard_;
};
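// Illustrative sketch (hypothetical usage): switch the current device for the
// duration of a scope; the original device is restored automatically.
//
//   {
//     c10::DeviceGuard guard(c10::Device(c10::DeviceType::CUDA, 1));
//     // ... work launched here sees CUDA device 1 as current ...
//   } // original device restored here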
/**
 * An OptionalDeviceGuard is an RAII class that sets a device to some value on
 * initialization, and resets the device to its original value on destruction.
 * Morally, an OptionalDeviceGuard is equivalent to std::optional<DeviceGuard>,
* but with extra constructors and methods as appropriate.
*
* Besides its obvious use (optionally applying a DeviceGuard),
* OptionalDeviceGuard is often also used for the following idiom:
*
* OptionalDeviceGuard g;
* for (const auto& t : tensors) {
* g.set_device(t.device());
* do_something_with(t);
* }
*
* This usage is marginally more efficient than constructing a DeviceGuard every
* iteration of the for loop, as it avoids an unnecessary device reset.
*
 * Unlike DeviceGuard, an OptionalDeviceGuard may be uninitialized. This occurs
* when you use the nullary constructor, or pass a nullopt to the constructor.
* Uninitialized OptionalDeviceGuards do *nothing*; they do not know what the
* original device was and they do not reset on destruction. This is why
* original_device() and current_device() return std::optional<Device> rather
* than Device (as they do in DeviceGuard), and also is why we didn't just
* provide OptionalDeviceGuard by default and hide DeviceGuard from users.
*
* The semantics of an OptionalDeviceGuard are exactly explained by thinking
* of it as an std::optional<DeviceGuard>. In particular, an initialized
* OptionalDeviceGuard doesn't restore device to its value at construction; it
* restores device to its value *at initialization*. So if you have the
* program:
*
* setDevice(1);
* OptionalDeviceGuard g;
* setDevice(2);
* g.reset_device(Device(DeviceType::CUDA, 3)); // initializes!
*
* On destruction, g will reset device to 2, rather than 1.
*
 * An uninitialized OptionalDeviceGuard is distinct from an (initialized)
* DeviceGuard whose original_device_ and current_device_ match, since the
* DeviceGuard will still reset the device to original_device_.
*/
class OptionalDeviceGuard {
public:
/// Create an uninitialized guard. Set the guard later using reset_device.
explicit OptionalDeviceGuard() = default;
/// Initialize the guard, setting the current device to the passed Device.
explicit OptionalDeviceGuard(Device device) : guard_(device) {}
/// Initialize the guard if a Device is passed; otherwise leave the
/// guard uninitialized.
explicit OptionalDeviceGuard(std::optional<Device> device) : guard_(device) {}
/// Constructor for testing only.
explicit OptionalDeviceGuard(
Device device,
const impl::DeviceGuardImplInterface* impl)
: guard_(device, impl) {}
/// Copy is disallowed
OptionalDeviceGuard(const OptionalDeviceGuard&) = delete;
OptionalDeviceGuard& operator=(const OptionalDeviceGuard&) = delete;
/// Move is disallowed
/// See Note [Explicit initialization of optional fields]
/// and Note [Move construction for RAII guards is tricky]
/// for rationale.
OptionalDeviceGuard(OptionalDeviceGuard&& other) = delete;
OptionalDeviceGuard& operator=(OptionalDeviceGuard&& other) = delete;
/// Sets the device to the given one. The specified device must be consistent
/// with the device type originally specified during guard construction.
void reset_device(at::Device device) {
guard_.reset_device(device);
}
/// For testing only
void reset_device(
at::Device device,
const impl::DeviceGuardImplInterface* impl) {
guard_.reset_device(device, impl);
}
/// Returns the device that was set at the time the guard was constructed.
std::optional<Device> original_device() const {
return guard_.original_device();
}
/// Returns the most recent device that was set using this device guard,
/// either from construction, or via reset_device.
std::optional<Device> current_device() const {
return guard_.current_device();
}
private:
impl::InlineOptionalDeviceGuard<impl::VirtualGuardImpl> guard_{};
};
// Note [Whither the DeviceGuard boilerplate]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Design note: in principle, we could avoid these wrappers using:
//
// using DeviceGuard = impl::InlineDeviceGuard<impl::VirtualGuardImpl>;
// using OptionalDeviceGuard =
// impl::InlineOptionalDeviceGuard<impl::VirtualGuardImpl>;
//
// But the error messages are worse, and our users can't just look at the
// header file to find out what's going on. Furthermore, for specializations
// like CUDAStreamGuard, it can be profitable to replace some interfaces with
// refined types (e.g., return CUDAStream instead of Stream). So, we eat
// the boilerplate and write out the API explicitly.
} // namespace c10

View File

@@ -0,0 +1,123 @@
#pragma once
// This is directly synchronized with caffe2/proto/caffe2.proto, but
// doesn't require me to figure out how to get Protobuf headers into
// ATen/core (which would require a lot more build system hacking.)
// If you modify me, keep me synchronized with that file.
#include <c10/macros/Export.h>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <ostream>
#include <string>
namespace c10 {
// This macro lists all device types that also have a BackendComponent
// and therefore participate in per-backend functionality dispatch keys.
// This is most backends, except PrivateUse2 and PrivateUse3.
#define C10_FORALL_BACKEND_DEVICE_TYPES(_, extra) \
_(CPU, extra) \
_(CUDA, extra) \
_(HIP, extra) \
_(XLA, extra) \
_(MPS, extra) \
_(IPU, extra) \
_(XPU, extra) \
_(HPU, extra) \
_(VE, extra) \
_(Lazy, extra) \
_(Meta, extra) \
_(MTIA, extra) \
_(PrivateUse1, extra)
enum class DeviceType : int8_t {
CPU = 0,
CUDA = 1, // CUDA.
MKLDNN = 2, // Reserved for explicit MKLDNN
OPENGL = 3, // OpenGL
OPENCL = 4, // OpenCL
IDEEP = 5, // IDEEP.
HIP = 6, // AMD HIP
FPGA = 7, // FPGA
MAIA = 8, // ONNX Runtime / Microsoft
XLA = 9, // XLA / TPU
Vulkan = 10, // Vulkan
Metal = 11, // Metal
XPU = 12, // XPU
MPS = 13, // MPS
Meta = 14, // Meta (tensors with no data)
HPU = 15, // HPU / HABANA
VE = 16, // SX-Aurora / NEC
Lazy = 17, // Lazy Tensors
IPU = 18, // Graphcore IPU
MTIA = 19, // Meta training and inference devices
PrivateUse1 = 20, // PrivateUse1 device
// NB: If you add more devices:
// - Change the implementations of DeviceTypeName and isValidDeviceType
// in DeviceType.cpp
// - Change the number below
COMPILE_TIME_MAX_DEVICE_TYPES = 21,
};
constexpr DeviceType kCPU = DeviceType::CPU;
constexpr DeviceType kCUDA = DeviceType::CUDA;
constexpr DeviceType kHIP = DeviceType::HIP;
constexpr DeviceType kFPGA = DeviceType::FPGA;
constexpr DeviceType kMAIA = DeviceType::MAIA;
constexpr DeviceType kXLA = DeviceType::XLA;
constexpr DeviceType kMPS = DeviceType::MPS;
constexpr DeviceType kMeta = DeviceType::Meta;
constexpr DeviceType kVulkan = DeviceType::Vulkan;
constexpr DeviceType kMetal = DeviceType::Metal;
constexpr DeviceType kXPU = DeviceType::XPU;
constexpr DeviceType kHPU = DeviceType::HPU;
constexpr DeviceType kVE = DeviceType::VE;
constexpr DeviceType kLazy = DeviceType::Lazy;
constexpr DeviceType kIPU = DeviceType::IPU;
constexpr DeviceType kMTIA = DeviceType::MTIA;
constexpr DeviceType kPrivateUse1 = DeviceType::PrivateUse1;
// define explicit int constant
constexpr int COMPILE_TIME_MAX_DEVICE_TYPES =
static_cast<int>(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES);
static_assert(
COMPILE_TIME_MAX_DEVICE_TYPES <= 21,
"Hey! You seem to be adding a lot of new DeviceTypes. The intent was "
"for this constant to reflect the actual number of DeviceTypes we support "
"in PyTorch; it's important that this number is not too large as we "
"use this to allocate stack arrays in some places in our code. If you "
"are indeed just adding the 20th device type, feel free to change "
"the check to 32; but if you are adding some sort of extensible device "
"types registration, please be aware that you are affecting code that "
"this number is small. Try auditing uses of this constant.");
C10_API std::string DeviceTypeName(DeviceType d, bool lower_case = false);
C10_API bool isValidDeviceType(DeviceType d);
C10_API std::ostream& operator<<(std::ostream& stream, DeviceType type);
C10_API void register_privateuse1_backend(const std::string& backend_name);
C10_API std::string get_privateuse1_backend(bool lower_case = true);
C10_API bool is_privateuse1_backend_registered();
} // namespace c10
namespace std {
template <>
struct hash<c10::DeviceType> {
std::size_t operator()(c10::DeviceType k) const {
return std::hash<int>()(static_cast<int>(k));
}
};
} // namespace std
namespace torch {
// NOLINTNEXTLINE(misc-unused-using-decls)
using c10::DeviceType;
} // namespace torch

View File

@@ -0,0 +1,747 @@
#pragma once
#include <c10/core/DeviceType.h>
#include <c10/macros/Export.h>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <ostream>
#include <string>
namespace c10 {
// Semantically, each value of BackendComponent identifies a "backend" for our
// dispatch. Some functionalities that we may dispatch to are allowed to
// register different handlers for each backend. The BackendComponent is then
// used to figure out which backend implementation to dispatch to.
// In implementation terms, the backend component identifies a specific "bit" in
// a DispatchKeySet. The bits in the DispatchKeySet are split between the bottom
// ~12 "BackendComponent" bits, while the remaining upper bits are assigned to
// functionalities. When we encounter a functionality bit that is known to be
// customizable per-backend, then we also look at the lower BackendComponent
// bits and take the highest bit to determine which backend's implementation to
// use.
// WARNING! If you add a new backend component to the end of this list,
// make sure you register it before Meta.
// Meta must be at the end so that meta key in tls triggers meta kernels.
// (But you shouldn't: private use keys should have higher precedence than all
// built-in keys)
// If you add a new (non-privateuse) backend here,
// make sure to add an Autograd<Backend> fallthrough kernel
// in aten/src/ATen/core/VariableFallbackKernel.cpp
#define C10_FORALL_BACKEND_COMPONENTS(_, extra) \
_(CPU, extra) \
_(CUDA, extra) \
_(HIP, extra) \
_(XLA, extra) \
_(MPS, extra) \
_(IPU, extra) \
_(XPU, extra) \
_(HPU, extra) \
_(VE, extra) \
_(Lazy, extra) \
_(MTIA, extra) \
_(PrivateUse1, extra) \
_(PrivateUse2, extra) \
_(PrivateUse3, extra) \
_(Meta, extra)
// WARNING! If we add a new per-backend functionality key that has higher
// priority than Autograd, then make sure you update EndOfRuntimeBackendKeys
#define C10_FORALL_FUNCTIONALITY_KEYS(_) \
_(Dense, ) \
_(Quantized, Quantized) \
_(Sparse, Sparse) \
_(SparseCsr, SparseCsr) \
_(NestedTensor, NestedTensor) \
_(AutogradFunctionality, Autograd)
enum class BackendComponent : uint8_t {
// A "backend" is colloquially used to refer to handlers for dispatch
// which actually implement the numerics of an operation in question.
//
// Due to the nature of the enum, these backends are specified in
// an ordered way, but for most backends this order is not semantically
// meaningful (e.g., it's valid to reorder these backends without changing
// semantics). The only situation when backend ordering is meaningful
// is when the backend participates in multiple dispatch with another
// backend; e.g., CPU and CUDA (cuda must have higher priority).
// These keys don't correspond to individual kernels.
// Instead, they represent the backends that are allowed to override specific
// pieces of functionality:
// - dense kernels (e.g. DispatchKey::CPU)
// - sparse kernels (e.g. DispatchKey::SparseCPU)
// - quantized kernels (e.g. DispatchKey::QuantizedCPU)
// - autograd kernels (e.g. DispatchKey::AutogradCPU)
// We reserve space in the runtime operator table for this full cross product
// of [backends in this enum] x [keys below that are explicitly marked as
// having per-backend functionality]
//
// A meta tensor is a tensor without any data associated with it. (They
// have also colloquially been referred to as tensors on the "null" device).
// A meta tensor can be used to dry run operators without actually doing any
// computation, e.g., add on two meta tensors would give you another meta
// tensor with the output shape and dtype, but wouldn't actually add anything.
InvalidBit = 0,
#define DEFINE_BACKEND_COMPONENT(n, _) n##Bit,
C10_FORALL_BACKEND_COMPONENTS(DEFINE_BACKEND_COMPONENT, unused)
#undef DEFINE_BACKEND_COMPONENT
// Define an alias to represent end of backend dispatch keys.
// If you add new backend keys after PrivateUse3, please also update it here.
EndOfBackendKeys = MetaBit,
};
// Semantically, a dispatch key identifies a possible "level" in our
// dispatch, for which a handler may be registered. Each handler corresponds
// to a type of functionality.
//
// In implementation terms, the dispatch key identifies a specific "bit" in a
// DispatchKeySet. Higher bit indexes get handled by dispatching first (because
// we "count leading zeros" when we extract the highest priority dispatch
// key.)
//
// Note [DispatchKey Classification]
// This enum actually contains several types of keys, which are explained
// in more detail further down:
// (1) non-customizable backends (e.g. FPGA)
// (2) non-customizable functionalities (e.g. Functionalize)
// (3) functionalities that are customizable per backend (e.g. Dense, Sparse,
//     AutogradFunctionality)
// (4) per-backend instances of customizable functionalities (e.g. CPU,
//     SparseCPU, AutogradCPU)
// (5) alias keys (e.g. CompositeImplicitAutograd)
//
// Of the categories above, it's important to note:
// (a) which keys are assigned individual bits in a DispatchKeySet
// (b) which keys are assigned individual slots in the runtime operator table
// ("Runtime keys")
//
// (1), (2) and (3) all get their own dedicated bits in the DispatchKeySet.
// (1), (2) and (4) all get their own dedicated slots in the runtime operator
// table.
// See Note [DispatchKeySet Internal Representation] for more details.
//
// NOTE: Keep the list in sync with `DispatchKey` in torchgen/model.py
enum class DispatchKey : uint16_t {
// ~~~~~~~~~~~~~~~~~~~~~~~~~~ UNDEFINED ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ //
// This is not a "real" functionality, but it exists to give us a "nullopt"
// element we can return for cases when a DispatchKeySet contains no elements.
// You can think a more semantically accurate definition of DispatchKey is:
//
// using DispatchKey = std::optional<RealDispatchKey>
//
// and Undefined == nullopt. We didn't actually represent
// it this way because std::optional<RealDispatchKey> would take two
// words, when DispatchKey fits in eight bits.
Undefined = 0,
// Define an alias for Undefined to represent CatchAll (long term
// this will get eliminated, but for now it's convenient)
CatchAll = Undefined,
// ~~~~~~~~~~~~~~~~~~~~~~~~~~ Functionality Keys ~~~~~~~~~~~~~~~~~~~~~~ //
// Every value in the enum (up to EndOfFunctionalityKeys)
// corresponds to an individual "functionality" that can be dispatched to.
// This is represented in the DispatchKeySet by assigning each of these enum
// values
// to each of the remaining (64 - len(BackendComponent)) bits.
//
// Most of these functionalities have a single handler assigned to them,
// making them "runtime keys" that map to a single slot in the runtime
// operator table.
//
// A few functionalities are allowed to be customizable per backend.
// See [Note: Per-Backend Functionality Dispatch Keys] for details.
// See [Note: Per-Backend Functionality Dispatch Keys]
Dense,
// Below are non-extensible backends.
// These are backends that currently don't have their own overrides for
// Autograd/Sparse/Quantized kernels,
// and we therefore don't waste space in the runtime operator table allocating
// space for them.
// If any of these backends ever need to customize, e.g., Autograd, then we'll
// need to add a DispatchKey::*Bit for them.
// TODO: put this in BackendComponents
FPGA, // Xilinx support lives out of tree at
// https://gitlab.com/pytorch-complex/vitis_kernels
// TODO: put this in BackendComponents
// MAIA backend lives out of tree
// - test/cpp_extensions/maia_extension.cpp
// - test/test_torch.py
// - aten/src/ATen/test/extension_backend_test.cpp
MAIA,
Vulkan, // TODO: put this in BackendComponents
Metal, // TODO: put this in BackendComponents
// See [Note: Per-Backend Functionality Dispatch Keys]
Quantized,
// This backend is to support custom RNGs; it lets you go
// to a different kernel if you pass in a generator that is not a
// traditional CPUGeneratorImpl/CUDAGeneratorImpl. To make use of this
// key:
// 1) set it as a second parameter of at::Generator constructor call in
// the user-defined PRNG class.
// 2) use it as a dispatch key while registering custom kernels
// (templatized kernels specialized for user-defined PRNG class)
// intended for out of tree use; tested by aten/src/ATen/test/rng_test.cpp
CustomRNGKeyId,
// TODO: Make Mkldnn a functionality key, so we can give it Meta
// support
// Here are backends which specify more specialized operators
// based on the layout of the tensor. Note that the sparse backends
// are one case where ordering matters: sparse multi-dispatches with
// the corresponding dense tensors, and must be handled before them.
MkldnnCPU, // registered at build/aten/src/ATen/RegisterMkldnnCPU.cpp
// NB: not to be confused with MKLDNN, which is Caffe2 only
// See [Note: Per-Backend Functionality Dispatch Keys]
Sparse,
SparseCsr,
NestedTensor,
// In some situations, it is not immediately obvious what the correct
// backend for function is, because the function in question doesn't
// have any "tensor" arguments. In this case, a BackendSelect function
// can be registered to implement the custom determination of the
// correct backend.
BackendSelect,
Python,
// Out-of-core key for Fake Tensor in torchdistx.
// See https://pytorch.org/torchdistx/latest/fake_tensor.html
// TODO: delete this in favor of Python-implemented fake tensor
Fake,
// See Note [Out-of-tree vmap+grad prototype]. The purpose of this key
// is to insert code after the "autograd subsystem" runs, so this key should
// be directly after ADInplaceOrView and all of the autograd keys.
FuncTorchDynamicLayerBackMode,
// Alias and mutation removal.
// If some backends want to opt into only alias removal or only mutation
// removal,
// we can consider adding separate keys dedicated to those individual passes.
// See Note [Functionalization Pass In Core] for details.
Functionalize,
// The named dispatch key is set for any tensors with named dimensions.
// Although we have a dispatch key for named tensors, for historical reasons,
// this dispatch key doesn't do any of the substantive functionality for named
// tensor (though, hypothetically, it could!) At the moment, it's just
// responsible for letting us give good error messages when operations
// don't support named tensors.
//
// NB: If you ever consider moving named tensor functionality into
// this dispatch key, note that it might be necessary add another dispatch
// key that triggers before composite operators, in case a composite operator
// has named dimension propagation that doesn't match that of its
// constituent parts.
// TODO: delete this once torchdim lands in functorch
Named,
// The Conjugate dispatch key is set for any tensors that need to perform
// conjugation
// This is implemented at a dispatch level right before any backends run
Conjugate,
// The Negative dispatch key is set for any tensors that need to perform
// negation
// This is implemented at a dispatch level right before any backends run
Negative,
ZeroTensor, // registered at build/aten/src/ATen/RegisterZeroTensor.cpp
// Note [ADInplaceOrView key]
// ADInplaceOrView key is used by inplace or view ops to register a kernel
// that does additional setup for future autograd computation.
//
// 1. For inplace ops this kernel does version bump
// 2. For view ops this kernel does `as_view` setup where we properly setup
// DifferentiableViewMeta on the view tensors.
//
// For other ops it's a fallthrough kernel since there's no extra
// work to do.
//
// Note [Dream: skip VariableType kernel when requires_grad=false]
//
// In an ideal world where we can skip VariableType kernel for inputs
// with requires_grad=false, instead of a fallthrough kernel, we'll
// register a kernel shown below to all functional ops as well:
// torch::Tensor my_functional_op(...) {
// {
// // Note for every op in VariableType, you need to go through
// // `AutoDispatchBelowADInplaceOrView` guard exactly once to add the
// // key to TLS excluded set. If you don't go through it at all,
// // inplace/view ops called through `at::` inside your backend
// // kernel will dispatch to ADInplaceOrView kernels and do a lot
// // of extra work.
// at::AutoDispatchBelowADInplaceOrView guard;
// at::redispatch::my_functional_op(...);
// }
// }
// But this work is currently blocked since it adds an extra dispatch
// for all ops and it's non-trivial overhead at the model level (a few
// percent). Thus our current approach takes advantage of the fact that every
// kernel goes through the VariableType kernel first and pulls the
// `at::AutoDispatchBelowADInplaceOrView` guard of functional ops
// up to the `VariableType` kernel. Thus we only add the extra dispatch
// to view/inplace ops to minimize its perf impact to real models.
ADInplaceOrView,
// Note [Alias Dispatch Key : Autograd]
// All backends are oblivious to autograd; autograd is handled as a
// layer which happens on top of all backends. It inspects the autograd
// metadata of all inputs, determines what autograd metadata should be
// constructed by the output, and otherwise defers to the backend to
// actually do the numeric computation. Autograd contains
// the bulk of this logic.
// Autograd is now an alias dispatch key which by default maps to all
// backend-specific autograd keys.
// Backend-specific keys allow backends to override the default kernel
// registered to the Autograd key as needed.
// For example, XLA wants to define autograd for einsum directly.
// Registering a custom autograd implementation at the XLA key won't work
// because we process Autograd before XLA. This key has higher priority and
// gets processed first. You generally should NOT redispatch after handling
// autograd here (since that would result in execution of the Autograd
// operator, which you're trying to skip). In AutogradXLA implementations,
// you are responsible for handling autograd yourself, or deferring to other
// operators which support autograd.
// Currently we only have backend-specific autograd keys for CPU/CUDA/XLA and
// reserved user-defined backends. All other in-tree backends share the
// AutogradOther key. We can add specific autograd key for those backends
// upon request.
AutogradOther,
// See [Note: Per-Backend Functionality Dispatch Keys]
AutogradFunctionality,
// NestedTensor is an example of something that isn't a "real backend"
// (because it mostly consists of redispatching kernels)
// but it would like to override autograd functionality in C++.
// We can handle cases like this by adding an extra functionality key
// exclusively for handling autograd for NestedTensor.
// lives out of tree at
// https://github.com/pytorch/nestedtensor
AutogradNestedTensor,
Tracer,
// TODO: make Autocast a functionality key
// Autocasting precedes VariableTypeId, to ensure casts are autograd-exposed
// and inputs are saved for backward in the post-autocast type.
AutocastCPU,
AutocastXPU,
AutocastIPU,
AutocastHPU,
AutocastXLA,
// AutocastXLA is only being used for TPUs. XLA GPUs continue to use
// AutocastCUDA.
AutocastMPS,
AutocastCUDA,
AutocastPrivateUse1,
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~ WRAPPERS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ //
// There are a number of alternative modes which may want to run before
// autograd; for example, error checking, tracing, profiling or vmap. They
// go here.
FuncTorchBatched, // See Note [Out-of-tree vmap+grad prototype]
// Dispatch key for BatchedTensorImpl wrapping a nested tensor.
BatchedNestedTensor,
FuncTorchVmapMode, // See Note [Out-of-tree vmap+grad prototype]
// This is the dispatch key for BatchedTensorImpl, which is used to implement
// batching rules for vmap.
Batched,
// When we are inside a vmap, all tensors dispatch on this key.
// See Note: [DispatchKey::VmapMode usage] for more details.
VmapMode,
FuncTorchGradWrapper, // See Note [Out-of-tree vmap+grad prototype]
// Out-of-core key for Deferred Module Initialization in torchdistx.
// See https://pytorch.org/torchdistx/latest/deferred_init.html
DeferredInit,
// Used by Python key logic to know the set of tls on entry to the dispatcher
// This kernel assumes it is the top-most non-functorch-related DispatchKey.
// If you add a key above, make sure to update the fallback implementation for
// this.
PythonTLSSnapshot,
// This key should be at the very top of the dispatcher
FuncTorchDynamicLayerFrontMode, // See Note [Out-of-tree vmap+grad prototype]
// TESTING: This is intended to be a generic testing tensor type id.
// Don't use it for anything real; its only acceptable use is within a single
// process test. Use it by creating a TensorImpl with this DispatchKey, and
// then registering operators to operate on this type id. See
// aten/src/ATen/core/dispatch/backend_fallback_test.cpp for a usage example.
TESTING_ONLY_GenericWrapper,
// TESTING: This is intended to be a generic testing tensor type id.
// Don't use it for anything real; its only acceptable use is within a single
// process test. Use it by toggling the mode on and off via
// TESTING_ONLY_tls_generic_mode_set_enabled and then registering operators
// to operate on this type id. See
// aten/src/ATen/core/dispatch/backend_fallback_test.cpp
// for a usage example
TESTING_ONLY_GenericMode,
// This key is used for pre-dispatch tracing in make_fx.
// It has lower priority than the PythonDispatcher key
// because we use the PythonDispatcher to intercept the key from python,
// and avoid having to implement it in C++.
PreDispatch,
// This is a bypass that allows you to skip running the C++ dispatcher
// entirely
PythonDispatcher,
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FIN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ //
EndOfFunctionalityKeys, // End of functionality keys.
// ~~~~~~~~~~~~~~ "Dense" Per-Backend Dispatch keys ~~~~~~~~~~~~~~~~~~~~ //
// Here are backends which you think of as traditionally specifying
// how to implement operations on some device.
#define DEFINE_PER_BACKEND_KEYS_FOR_BACKEND(n, prefix) prefix##n,
#define DEFINE_PER_BACKEND_KEYS(fullname, prefix) \
StartOf##fullname##Backends, \
C10_FORALL_BACKEND_COMPONENTS( \
DEFINE_PER_BACKEND_KEYS_FOR_BACKEND, prefix) \
EndOf##fullname##Backends = prefix##Meta,
C10_FORALL_FUNCTIONALITY_KEYS(DEFINE_PER_BACKEND_KEYS)
#undef DEFINE_PER_BACKEND_KEYS
#undef DEFINE_PER_BACKEND_KEYS_FOR_BACKEND
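// As a sketch, the expansion above produces one contiguous block of runtime
// keys per customizable functionality, ordered by backend (illustrative, the
// exact members follow the macros above):
//   Dense block:    StartOfDenseBackends, CPU, CUDA, HIP, ..., Meta,
//                   EndOfDenseBackends = Meta
//   Sparse block:   StartOfSparseBackends, SparseCPU, SparseCUDA, ...,
//                   EndOfSparseBackends = SparseMeta
//   Autograd block: StartOfAutogradFunctionalityBackends, AutogradCPU,
//                   AutogradCUDA, ..., EndOfAutogradFunctionalityBackends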
EndOfRuntimeBackendKeys = EndOfAutogradFunctionalityBackends,
// ~~~~~~~~~~~~~~~~~~~~~~ Alias Dispatch Keys ~~~~~~~~~~~~~~~~~~~~~~~~~~ //
// Note [Alias Dispatch Keys]
// Alias dispatch keys are synthetic dispatch keys which map to multiple
// runtime dispatch keys. Alias keys have a precedence order among themselves,
// but they are always lower precedence than runtime keys. If you register a
// kernel to an alias key, the kernel may be populated to the mapped runtime
// keys during dispatch table computation.
// If a runtime dispatch key has multiple kernels from alias keys, which
// kernel wins is done based on the precedence of alias keys (but runtime
// keys always have precedence over alias keys).
// Alias keys won't be directly called during runtime.
// See Note [Alias Dispatch Key : Autograd]
Autograd,
CompositeImplicitAutograd, // registered at
// build/aten/src/ATen/RegisterCompositeImplicitAutograd.cpp
// Note: The alias keyset for FuncTorchBatchedDecomposition is disjoint from
// all other alias keysets, and so precedence order doesn't matter
FuncTorchBatchedDecomposition, // registered at
// build/aten/src/ATen/RegisterFuncTorchBatchedDecomposition.cpp
// Note: The alias keyset for CompositeImplicitAutogradNestedTensor is
// disjoint from all other alias keysets
CompositeImplicitAutogradNestedTensor, // registered at
// build/aten/src/ATen/RegisterCompositeImplicitAutogradNestedTensor.cpp
CompositeExplicitAutograd, // registered at
// build/aten/src/ATen/RegisterCompositeExplicitAutograd.cpp
// See Note [CompositeExplicitAutogradNonFunctional Key]
CompositeExplicitAutogradNonFunctional, // registered at
// build/aten/src/ATen/RegisterCompositeExplicitAutograd.cpp
// Define an alias key to represent end of alias dispatch keys.
// If you add new alias keys after Autograd, please also update it here.
StartOfAliasKeys = Autograd,
EndOfAliasKeys = CompositeExplicitAutogradNonFunctional, //
// ~~~~~~~~~~~~~~~~~~~~~~~~~ BC ALIASES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ //
// The aliases exist for backwards compatibility reasons, they shouldn't
// be used
CPUTensorId = CPU,
CUDATensorId = CUDA,
DefaultBackend = CompositeExplicitAutograd,
PrivateUse1_PreAutograd = AutogradPrivateUse1,
PrivateUse2_PreAutograd = AutogradPrivateUse2,
PrivateUse3_PreAutograd = AutogradPrivateUse3,
Autocast = AutocastCUDA,
};
// Note [Private use DispatchKey]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Private use tensor IDs are preallocated tensor type IDs for use in user
// applications. Similar to private use fields in HTTP, they can be used
// by end users for experimental or private applications, without needing
// to "standardize" the tensor ID (which would be done by submitting a PR
// to PyTorch to add your type ID).
//
// Private use tensor IDs are appropriate to use if you want to experiment
// with adding a new tensor type (without having to patch PyTorch first) or
// have a private, non-distributed application that needs to make use of a
// new tensor type. Private use tensor IDs are NOT appropriate to use for
// libraries intended to be distributed to further users: please contact
// the PyTorch developers to get a type ID registered in this case.
//
// We provide two classes of private use tensor id: regular DispatchKeys
// and Autograd DispatchKeys. DispatchKeys serve the role of ordinary "backend"
// DispatchKeys; if you were adding support for a new type of accelerator, you
// would use a backend DispatchKey, and ideally automatically reuse
// AutogradOther definitions already defined in PyTorch. AutogradPrivateUse
// DispatchKeys serve as "wrapper" DispatchKeys: they are only necessary for
// tensors that compose multiple internal tensors, and for cases when the
// built-in autograd formulas for operators are not appropriate.
static_assert(
(static_cast<uint8_t>(BackendComponent::EndOfBackendKeys) +
static_cast<uint8_t>(DispatchKey::EndOfFunctionalityKeys)) <= 64,
"The BackendComponent and DispatchKey enums (below EndOfFunctionalityKeys)"
" both map to backend and functionality bits"
" into a 64-bit bitmask; you must have less than 64 total entries between them");
// Check if a DispatchKey is an alias mapping to other runtime keys.
constexpr bool isAliasDispatchKey(DispatchKey k) {
return k >= DispatchKey::StartOfAliasKeys && k <= DispatchKey::EndOfAliasKeys;
}
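// e.g. (a sketch): isAliasDispatchKey(DispatchKey::Autograd) is true, while
// isAliasDispatchKey(DispatchKey::AutogradCPU) is false, since AutogradCPU is
// one of the runtime keys that the Autograd alias maps to.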
// [Note: Per-Backend Functionality Dispatch Keys]
// Check if a DispatchKey is a per-backend functionality key
// Any functionalities that can be customized per-backend should be added here.
// These keys correspond to functionalities that can be customized individually
// per backend. While they only take up one bit in the `DispatchKeySet` bitset,
// they map to (# backends) slots in the operator table.
// Each of these keys also has a separate set of "runtime keys" in the dispatch
// key enum, per backend, which *do* map to the individual operator table slots.
// For example, the "Sparse" key maps to an individual bit in the
// DispatchKeySet, while `SparseCPU`, `SparseCUDA`, etc all map to individual
// slots in the runtime operator table.
constexpr bool isPerBackendFunctionalityKey(DispatchKey k) {
if (k == DispatchKey::Dense || k == DispatchKey::Quantized ||
k == DispatchKey::Sparse || k == DispatchKey::SparseCsr ||
k == DispatchKey::AutogradFunctionality ||
k == DispatchKey::NestedTensor) {
return true;
} else {
return false;
}
}
// Note that this includes Undefined in the total count.
// BUT EndOfFunctionalityKeys is its own (placeholder) key.
// e.g. Undefined=0, Dense=1, Sparse=2, EndOfFunctionalityKeys=3.
// In the above example, there are 3 total functionality keys.
constexpr uint8_t num_functionality_keys =
static_cast<uint8_t>(DispatchKey::EndOfFunctionalityKeys);
constexpr uint8_t num_backends =
static_cast<uint8_t>(BackendComponent::EndOfBackendKeys);
// Note [No More Than 16 Backends]
// Search for this note to find places in the code where the "no more than 16
// backends" invariant is baked in.
static_assert(
static_cast<uint8_t>(BackendComponent::EndOfBackendKeys) <= 16,
"BackendComponent currently only supports <= 16 backends. If we really need to extend this, \
there are a few places where this invariant is baked in");
constexpr uint8_t numPerBackendFunctionalityKeys() {
uint8_t count = 0;
for (uint8_t k = 0; k <= num_functionality_keys; ++k) {
if (isPerBackendFunctionalityKey(static_cast<DispatchKey>(k)))
++count;
}
return count;
}
#if defined(C10_MOBILE_TRIM_DISPATCH_KEYS)
// See [Note: Trimmed Mobile Dispatch Keys]
constexpr uint16_t num_runtime_entries = 8;
#else
constexpr uint16_t num_runtime_entries = num_functionality_keys +
(numPerBackendFunctionalityKeys() * (num_backends - 1));
#endif
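// As a rough worked example of the formula above: every functionality key
// gets one entry, and each per-backend functionality (Dense, Quantized,
// Sparse, SparseCsr, NestedTensor, AutogradFunctionality) gets an additional
// (num_backends - 1) entries, one per extra backend. The concrete totals
// depend on the enums above, so they are not spelled out here.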
// See Note [No More Than 16 Backends]
constexpr uint16_t full_backend_mask =
(static_cast<uint16_t>(1) << num_backends) - 1;
C10_API const char* toString(DispatchKey);
C10_API const char* toString(BackendComponent);
C10_API std::ostream& operator<<(std::ostream&, DispatchKey);
C10_API std::ostream& operator<<(std::ostream&, BackendComponent);
C10_API DispatchKey getAutogradKeyFromBackend(BackendComponent k);
// Parses a string into a dispatch key.
// If the string cannot be correctly parsed, throws an exception.
C10_API c10::DispatchKey parseDispatchKey(const std::string& k);
// These are some convenience identifiers for dispatch keys which are
// shorter to type than their long counterparts. Note that some of these
// dispatch keys directly correspond to DeviceType; and most APIs that
// accept DispatchKey also accept DeviceType; e.g.,
// torch::dispatch(torch::kCPU, ...) is also valid.
constexpr DispatchKey kAutograd = DispatchKey::Autograd;
// See Note [The Ordering of Per-Backend Dispatch Keys Matters!]
// This function relies on the invariant that the dispatch keys between
// StartOfDenseBackends and EndOfRuntimeBackendKeys are ordered by backend
// in the same order as `BackendComponent`.
constexpr BackendComponent toBackendComponent(DispatchKey k) {
if (k >= DispatchKey::StartOfDenseBackends &&
k <= DispatchKey::EndOfDenseBackends) {
return static_cast<BackendComponent>(
static_cast<uint8_t>(k) -
static_cast<uint8_t>(DispatchKey::StartOfDenseBackends));
} else if (
k >= DispatchKey::StartOfQuantizedBackends &&
k <= DispatchKey::EndOfQuantizedBackends) {
return static_cast<BackendComponent>(
static_cast<uint8_t>(k) -
static_cast<uint8_t>(DispatchKey::StartOfQuantizedBackends));
} else if (
k >= DispatchKey::StartOfSparseBackends &&
k <= DispatchKey::EndOfSparseBackends) {
return static_cast<BackendComponent>(
static_cast<uint8_t>(k) -
static_cast<uint8_t>(DispatchKey::StartOfSparseBackends));
} else if (
k >= DispatchKey::StartOfSparseCsrBackends &&
k <= DispatchKey::EndOfSparseCsrBackends) {
return static_cast<BackendComponent>(
static_cast<uint8_t>(k) -
static_cast<uint8_t>(DispatchKey::StartOfSparseCsrBackends));
} else if (
k >= DispatchKey::StartOfNestedTensorBackends &&
k <= DispatchKey::EndOfNestedTensorBackends) {
return static_cast<BackendComponent>(
static_cast<uint8_t>(k) -
static_cast<uint8_t>(DispatchKey::StartOfNestedTensorBackends));
} else if (
k >= DispatchKey::StartOfAutogradFunctionalityBackends &&
k <= DispatchKey::EndOfAutogradFunctionalityBackends) {
return static_cast<BackendComponent>(
static_cast<uint8_t>(k) -
static_cast<uint8_t>(
DispatchKey::StartOfAutogradFunctionalityBackends));
} else {
return BackendComponent::InvalidBit;
}
}
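// A minimal sketch of the mapping this computes (assuming the backend
// ordering declared in C10_FORALL_BACKEND_COMPONENTS):
//   toBackendComponent(DispatchKey::CPU)        == BackendComponent::CPUBit
//   toBackendComponent(DispatchKey::SparseCUDA) == BackendComponent::CUDABit
//   toBackendComponent(DispatchKey::Dense)      == BackendComponent::InvalidBit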
constexpr DispatchKey toFunctionalityKey(DispatchKey k) {
if (k <= DispatchKey::EndOfFunctionalityKeys) {
return k;
} else if (k <= DispatchKey::EndOfDenseBackends) {
return DispatchKey::Dense;
} else if (k <= DispatchKey::EndOfQuantizedBackends) {
return DispatchKey::Quantized;
} else if (k <= DispatchKey::EndOfSparseBackends) {
return DispatchKey::Sparse;
} else if (k <= DispatchKey::EndOfSparseCsrBackends) {
return DispatchKey::SparseCsr;
} else if (k <= DispatchKey::EndOfNestedTensorBackends) {
return DispatchKey::NestedTensor;
} else if (k <= DispatchKey::EndOfAutogradFunctionalityBackends) {
return DispatchKey::AutogradFunctionality;
} else {
return DispatchKey::Undefined;
}
}
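// e.g. (a sketch), following the block boundaries above:
//   toFunctionalityKey(DispatchKey::CPU)         == DispatchKey::Dense
//   toFunctionalityKey(DispatchKey::SparseCUDA)  == DispatchKey::Sparse
//   toFunctionalityKey(DispatchKey::AutogradXLA) == DispatchKey::AutogradFunctionality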
BackendComponent toBackendComponent(DeviceType device_type);
// Given (DispatchKey::Dense, BackendComponent::CUDABit), returns
// DispatchKey::CUDA.
// See Note [The Ordering of Per-Backend Dispatch Keys Matters!]
// This function relies on the invariant that the dispatch keys between
// StartOfDenseBackends and EndOfRuntimeBackendKeys are ordered by backend
// in the same order as `BackendComponent`.
constexpr DispatchKey toRuntimePerBackendFunctionalityKey(
DispatchKey functionality_k,
BackendComponent backend_k) {
if (functionality_k == DispatchKey::Dense) {
return static_cast<DispatchKey>(
static_cast<uint8_t>(DispatchKey::StartOfDenseBackends) +
static_cast<uint8_t>(backend_k));
}
if (functionality_k == DispatchKey::Sparse) {
return static_cast<DispatchKey>(
static_cast<uint8_t>(DispatchKey::StartOfSparseBackends) +
static_cast<uint8_t>(backend_k));
}
if (functionality_k == DispatchKey::SparseCsr) {
return static_cast<DispatchKey>(
static_cast<uint8_t>(DispatchKey::StartOfSparseCsrBackends) +
static_cast<uint8_t>(backend_k));
}
if (functionality_k == DispatchKey::Quantized) {
return static_cast<DispatchKey>(
static_cast<uint8_t>(DispatchKey::StartOfQuantizedBackends) +
static_cast<uint8_t>(backend_k));
}
if (functionality_k == DispatchKey::NestedTensor) {
return static_cast<DispatchKey>(
static_cast<uint8_t>(DispatchKey::StartOfNestedTensorBackends) +
static_cast<uint8_t>(backend_k));
}
if (functionality_k == DispatchKey::AutogradFunctionality) {
return static_cast<DispatchKey>(
static_cast<uint8_t>(
DispatchKey::StartOfAutogradFunctionalityBackends) +
static_cast<uint8_t>(backend_k));
}
return DispatchKey::Undefined;
}
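// A second sketch, complementing the Dense example in the comment above:
//   toRuntimePerBackendFunctionalityKey(
//       DispatchKey::AutogradFunctionality, BackendComponent::CPUBit)
//       == DispatchKey::AutogradCPU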
} // namespace c10
namespace torch {
// Expose the constant, but not the TYPE (DispatchKey is an implementation
// detail!)
// NOLINTNEXTLINE(misc-unused-using-decls)
using c10::kAutograd;
} // namespace torch
// NB: You really shouldn't use this instance; this enum is guaranteed
// to be pretty small so a regular array should be acceptable.
namespace std {
template <>
struct hash<c10::DispatchKey> {
typedef size_t result_type;
typedef c10::DispatchKey argument_type;
size_t operator()(c10::DispatchKey x) const {
return static_cast<size_t>(x);
}
};
} // namespace std

View File

@ -0,0 +1,949 @@
#pragma once
#include <c10/core/DispatchKey.h>
#include <c10/macros/Export.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <c10/util/Metaprogramming.h>
#include <c10/util/TypeList.h>
#include <c10/util/llvmMathExtras.h>
#include <array>
#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <iterator>
#include <ostream>
#include <string>
#include <type_traits>
namespace c10 {
struct FunctionalityOffsetAndMask {
// empty constructor shouldn't be used; only needed to initialize
// the array before populating it.
FunctionalityOffsetAndMask() = default;
FunctionalityOffsetAndMask(uint16_t offset, uint16_t mask)
: offset(offset), mask(mask) {}
// This needs to be big enough to cover the size of the operator table.
uint16_t offset{};
// See Note [No More Than 16 Backends]
// This mask needs to be big enough to mask all of the backend bits.
// We probably don't ever want to have more than 16 backend bits, so uint16_t
// should be enough.
uint16_t mask{};
};
static_assert(
c10::num_runtime_entries < 65536,
"The dispatcher currently only supports up to 2^16 runtime entries");
C10_API std::array<FunctionalityOffsetAndMask, num_functionality_keys>
initializeFunctionalityOffsetsAndMasks();
C10_ALWAYS_INLINE static const std::
array<FunctionalityOffsetAndMask, num_functionality_keys>&
offsetsAndMasks() {
static auto offsets_and_masks_ = initializeFunctionalityOffsetsAndMasks();
return offsets_and_masks_;
}
// A representation of a set of DispatchKeys. A DispatchKeySet contains both
// "functionality" bits and "backend bits", and every tensor holds its own
// DispatchKeySet. The Dispatcher implements multiple dispatch by grabbing the
// keyset on every input tensor, oring them together, and dispatching to a
// specific piece of functionality. The functionality bits are *ordered*. When
// multiple functionality bits are set, we use the highest priority
// functionality. Similarly, multiple backend bits can theoretically be set if
// you call an operator with multiple tensors from different devices (e.g. CPU
// and CUDA), although support for mixed device dispatch is limited (the only
// kernels that gracefully handle mixed device inputs for now are cuda kernels
// that take in a scalar cpu tensor).
// A representation of a set of DispatchKeys. A tensor may have multiple
// tensor type ids, e.g., a Variable tensor can also be a CPU tensor; the
// DispatchKeySet specifies what type ids apply. The internal representation is
// as a 64-bit bit set (this means only 64 tensor type ids are supported).
//
// As mentioned above, DispatchKeys are ordered; thus, we can ask questions like
// "what is the highest priority DispatchKey in the set"? (The set itself is
// not ordered; two sets with the same ids will always have the ids ordered in
// the same way.)
//
// Note [DispatchKeySet Internal Representation]
// Internally, dispatch keys are packed into 64-bit DispatchKeySet objects
// that get passed around at runtime.
// However, there isn't necessarily a 1-to-1 mapping between bits in the keyset
// and individual dispatch keys.
//
// First: why do we have this distinction, and why not map every dispatch key
// directly to a bit? This is mostly because we have several types of
// functionalities that different backends would like to customize. For example,
// we have:
// - "Dense": CPU, CUDA, XLA, ... (~12 keys)
// - "Sparse": SparseCPU, SparseCUDA, ...
// - "SparseCsr": SparseCsrCPU, SparseCsrCUDA, ...
// - "Quantized": QuantizedCPU, QuantizedCUDA, QuantizedXLA, ...
// - "Autograd": AutogradCPU, AutogradCUDA, Autograd XLA, ...
// The problem is that total number of keys grows quadratically with [#
// backends] x [# functionalities], making it very difficult to map each key
// directly to a bit in a bitset without dramatically increasing the size of the
// bitset over time.
//
// The two enums (BackendComponent and DispatchKey) can be divided roughly into
// 5 categories.
//
// (1) "Building block" keys
// (a) backends: Everything in the BackendComponent enum (e.g. CPUBit,
//     CUDABit)
// (b) functionalities: (per-backend) functionality-bit DispatchKeys
//     (e.g. AutogradFunctionality, SparseCsr, Sparse, Dense)
// (2) "Runtime" keys
// (a) "non-customizable backends" (e.g. FPGA)
// (b) "non-customizable functionalities" (e.g. Functionalize)
// (c) "per-backend instances of customizable functionalities" (e.g. CPU,
// SparseCPU, AutogradCPU)
// (3) "Alias" DispatchKeys (see Note [Alias Dispatch Keys])
//
// (1) Building block keys always correspond to individual bits in a
// DispatchKeySet. They can also be combined in a DispatchKeySet to form actual
// runtime keys. e.g.
// auto dense_cpu_ks = DispatchKeySet({DispatchKey::CPUBit,
// DispatchKey::Dense});
// // The keyset has the runtime dense-cpu key.
// dense_cpu_ks.has(DispatchKey::CPU);
// // And it contains the building block keys too.
// dense_cpu_ks.has(DispatchKey::CPUBit);
// dense_cpu_ks.has(DispatchKey::Dense);
//
// Not every backend and not every functionality counts as a "building block
// key". This is mostly to give us more levers to pull in the design space.
// Backend keys and functionality keys that count as "building blocks" will
// contribute to a full cross product of functionality that can be overridden.
//
// For example, right now we have at least 12 "backend" building
// blocks (CPU, CUDA, XLA, ...) and at least 5 "functionality"
// building blocks (Dense, Sparse, SparseCsr, Quantized,
// AutogradFunctionality, ...). These keys together allow every
// dispatcher operator to be customized in up to 12*5 different
// ways. Each of those requires a slot in the operator table of every
// dispatcher operator. Not every piece of functionality necessarily
// needs to be customizable per-backend, and not every backend
// necessarily needs to be able to customize every type of
// functionality.
//
//
// (2) Every runtime key corresponds directly to a slot in an operator's runtime
// dispatch table, and you can directly register kernels to a runtime dispatch
// key.
//
// For per-backend functionalities like "Dense" or "AutogradFunctionality",
// you can think of the corresponding runtime dispatch keys as "instances" of
// that functionality, per backend. E.g. "CPU", "CUDA", "XLA", etc. are all
// runtime instances of the "Dense" building block key.
// (2a) and (2b) are represented identically in the DispatchKeySet logic:
// - backend-agnostic functionalities (e.g. FuncTorchBatched) are NOT
// customizable per backend.
// In order to do so, we'd need to promote it to a per-backend functionality
// "building block" key.
// - non-customizable backends (e.g. FPGA) can NOT customize existing
// functionality like Sparse, Autograd, etc.
// In order to do so, we'd need to promote it to a backend "building block"
// key.
//
// In both cases, these keys directly correspond to runtime slots in the
// operator table.
//
//
// (3) "Alias" keys
// See Note [Alias Dispatch Keys]
//
// Final note: for anyone making future changes to the Dispatcher +
// DispatchKeySet internals, there's a closed PR with a basic
// python-implementation of the Dispatcher that might be useful in quickly
// testing out and validating changes. See it at
// https://github.com/pytorch/pytorch/pull/68743
// An undefined tensor is one with an empty tensor type set.
class DispatchKeySet final {
public:
enum Full { FULL };
enum FullAfter { FULL_AFTER };
enum Raw { RAW };
// NB: default constructor representation as zero is MANDATORY as
// use of DispatchKeySet in TLS requires this.
constexpr DispatchKeySet() = default;
constexpr DispatchKeySet(Full)
: repr_((1ULL << (num_backends + num_functionality_keys - 1)) - 1) {}
constexpr DispatchKeySet(FullAfter, DispatchKey t)
// LSB after t are OK, but not t itself.
// "functionalities" have a notion of ordering (e.g. Autograd > Sparse >
// Quantized > Dense). But backends don't really have an ordering.
// Therefore, we're enforcing that FullAfter can only be used on
// "functionality" keys.
: repr_(
(1ULL
<< (num_backends + static_cast<uint8_t>(toFunctionalityKey(t)) -
1)) -
1) {
*this = add(DispatchKey::PythonDispatcher);
}
// Public version of DispatchKeySet(uint64_t) API; external users
// must be explicit when they do this!
constexpr DispatchKeySet(Raw, uint64_t x) : repr_(x) {}
constexpr explicit DispatchKeySet(BackendComponent k) {
if (k == BackendComponent::InvalidBit) {
repr_ = 0;
} else {
repr_ = 1ULL << (static_cast<uint8_t>(k) - 1);
}
}
constexpr explicit DispatchKeySet(DispatchKey k) {
// NOLINTNEXTLINE(bugprone-branch-clone)
if (k == DispatchKey::Undefined) {
// Case 1: handle Undefined specifically
repr_ = 0;
} else if (k <= DispatchKey::EndOfFunctionalityKeys) {
// Case 2: handle "functionality-only" keys
// These keys have a functionality bit set, but no backend bits
// These can technically be either:
// - valid runtime keys (e.g. DispatchKey::AutogradOther,
// DispatchKey::FuncTorchBatched, etc)
// - "building block" keys that aren't actual runtime keys (e.g.
// DispatchKey::Dense or Sparse)
uint64_t functionality_val = 1ULL
<< (num_backends + static_cast<uint8_t>(k) - 1);
repr_ = functionality_val;
} else if (k <= DispatchKey::EndOfRuntimeBackendKeys) {
// Case 3: "runtime" keys that have a functionality bit AND a backend bit.
// First compute which bit to flip for the functionality.
auto functionality_k = toFunctionalityKey(k);
// The - 1 is because Undefined is technically a "functionality" that
// doesn't show up in the bitset. So e.g. Dense is technically the second
// functionality, but the lowest functionality bit.
uint64_t functionality_val = 1ULL
<< (num_backends + static_cast<uint8_t>(functionality_k) - 1);
// then compute which bit to flip for the backend
// Case 4a: handle the runtime instances of "per-backend functionality"
// keys For example, given DispatchKey::CPU, we should set:
// - the Dense functionality bit
// - the CPUBit backend bit
// first compute which bit to flip for the backend
auto backend_k = toBackendComponent(k);
uint64_t backend_val = backend_k == BackendComponent::InvalidBit
? 0
: 1ULL << (static_cast<uint8_t>(backend_k) - 1);
repr_ = functionality_val + backend_val;
} else {
// At this point, we should have covered every case except for alias keys.
// Technically it would be possible to add alias dispatch keys to a
// DispatchKeySet, but the semantics are a little confusing and this
// currently isn't needed anywhere.
repr_ = 0;
}
}
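// A rough sketch of what the cases above produce (bit positions follow the
// layout described in Note [DispatchKeySet Internal Representation]):
//   DispatchKeySet(DispatchKey::CPU)       // Dense functionality bit + CPUBit
//   DispatchKeySet(DispatchKey::FPGA)      // FPGA functionality bit only
//   DispatchKeySet(DispatchKey::Undefined) // empty set (repr_ == 0)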
constexpr uint64_t keys_to_repr(std::initializer_list<DispatchKey> ks) {
uint64_t repr = 0;
for (auto k : ks) {
repr |= DispatchKeySet(k).repr_;
}
return repr;
}
constexpr uint64_t backend_bits_to_repr(
std::initializer_list<BackendComponent> ks) {
uint64_t repr = 0;
for (auto k : ks) {
repr |= DispatchKeySet(k).repr_;
}
return repr;
}
explicit constexpr DispatchKeySet(std::initializer_list<DispatchKey> ks)
: repr_(keys_to_repr(ks)) {}
explicit constexpr DispatchKeySet(std::initializer_list<BackendComponent> ks)
// Note: for some reason, putting this logic directly in the constructor
// appears to fail to compile on CUDA 10.1.
// See an example internal failure at
// https://www.internalfb.com/intern/skycastle/run/76561193669136035/artifact/actionlog.76561193742069401.stderr
: repr_(backend_bits_to_repr(ks)) {}
// Test if a DispatchKey is in the set
inline bool has(DispatchKey t) const {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(t != DispatchKey::Undefined);
return has_all(DispatchKeySet(t));
}
constexpr bool has_backend(BackendComponent t) const {
return has_all(DispatchKeySet(t));
}
// Test if a DispatchKey is in the set
// Given a DispatchKeySet of functionality keys and (potentially) backend
// keys, tests if all of them are in the current set.
constexpr bool has_all(DispatchKeySet ks) const {
return static_cast<bool>((repr_ & ks.repr_) == ks.repr_);
}
// Given a DispatchKeySet of functionality keys and (potentially) backend
// keys, tests if any of them are in the current set. This could technically
// be pretty easily implemented using has(). It is strictly a perf
// optimization though. There are many places in the code base where we want
// to test for multiple functionality keys together. HOWEVER, runtime
// per-backend functionality keys aren't allowed to be used with this
// function, because you can end up with weird results. e.g.
// DispatchKeySet(DispatchKey::AutogradCPU).has_any(DispatchKeySet(DispatchKey::CPU))
// would return true.
inline bool has_any(DispatchKeySet ks) const {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
// Either there are no backend bits in the input keyset
((ks.repr_ & full_backend_mask) == 0) ||
// or there are no per-backend-functionality bits
// See [Note: Per-Backend Functionality Dispatch Keys]
((ks &
DispatchKeySet({
DispatchKey::Dense,
DispatchKey::Quantized,
DispatchKey::Sparse,
DispatchKey::SparseCsr,
DispatchKey::AutogradFunctionality,
})
.repr_) == 0));
return static_cast<bool>((repr_ & ks.repr_) != 0);
}
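// e.g. (a sketch): ks.has_any(DispatchKeySet(DispatchKey::AutocastCPU)) is
// fine, since AutocastCPU is a functionality-only key, but
// ks.has_any(DispatchKeySet(DispatchKey::CPU)) would trip the debug assert
// above, because CPU carries both a backend bit and a per-backend
// functionality bit.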
// Test if DispatchKeySet is a superset of ks.
bool isSupersetOf(DispatchKeySet ks) const {
return (repr_ & ks.repr_) == ks.repr_;
}
// Perform set union
constexpr DispatchKeySet operator|(DispatchKeySet other) const {
return DispatchKeySet(repr_ | other.repr_);
}
// Perform set intersection
constexpr DispatchKeySet operator&(DispatchKeySet other) const {
return DispatchKeySet(repr_ & other.repr_);
}
// Compute the set difference self - other,
// but ONLY for the functionality keys.
// Any backend bits set on self will remain unchanged.
// See Note [Removing keys from DispatchKeySet Only Affects Functionality
// Keys]
constexpr DispatchKeySet operator-(DispatchKeySet other) const {
return DispatchKeySet(repr_ & (full_backend_mask | ~other.repr_));
}
// Compute self ^ other
constexpr DispatchKeySet operator^(DispatchKeySet other) const {
return DispatchKeySet(repr_ ^ other.repr_);
}
bool operator==(DispatchKeySet other) const {
return repr_ == other.repr_;
}
bool operator!=(DispatchKeySet other) const {
return repr_ != other.repr_;
}
// Add a DispatchKey to the DispatchKey set. Does NOT mutate,
// returns the extended DispatchKeySet!
C10_NODISCARD constexpr DispatchKeySet add(DispatchKey t) const {
return *this | DispatchKeySet(t);
}
C10_NODISCARD constexpr DispatchKeySet add(DispatchKeySet ks) const {
return *this | ks;
}
// Remove a DispatchKey from the DispatchKey set.
// This is generally not an operation you should be doing
// (it's used to implement the printing overload, operator<<)
//
// Note [Removing keys from DispatchKeySet Only Affects Functionality Keys]
// Only functionality bits are allowed to be removed from a keyset.
// For now, we're only allowing removal of "functionality bits" from the
// keyset, which is specifically needed by the fallthrough key calculation
// logic. Why is removing backend bits problematic? Consider this example:
//
// DispatchKeySet([DispatchKey.CPU, DispatchKey.AutogradCUDA,
// DispatchKey.CUDA]).remove(DispatchKey.AutogradCUDA)
// DispatchKeySet([DispatchKey.CPU,
// DispatchKey.AutogradCUDA]).remove(DispatchKey.AutogradCUDA)
//
// What do we want to happen?
// Technically, we'd like it to be true that after removal,
// the first keyset still has the CUDA dispatch key while the second doesn't.
// Unfortunately there's no way to represent that, because the two keysets are
// represented the same way internally: functionality bits: Autograd, Dense
// backend bits: CPU, CUDA
//
// Instead, remove(DispatchKey.AutogradCPU) will only remove the "Autograd"
// bit from the bitset.
C10_NODISCARD constexpr DispatchKeySet remove(DispatchKey t) const {
return DispatchKeySet(
repr_ & ~(DispatchKeySet(t).repr_ & ~full_backend_mask));
}
// You're allowed to remove a backend bit from a DispatchKeySet,
// but you have to be explicit about it (remove_backend() instead of
// remove()).
constexpr DispatchKeySet remove_backend(BackendComponent b) const {
return DispatchKeySet(repr_ & ~(DispatchKeySet(b).repr_));
}
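// A short usage sketch of the two removal APIs (illustrative only):
//   auto ks = DispatchKeySet({DispatchKey::AutogradCPU, DispatchKey::CPU});
//   ks = ks.remove(DispatchKey::AutogradCPU);          // clears only the
//                                                      // Autograd functionality bit
//   ks = ks.remove_backend(BackendComponent::CPUBit);  // clears the CPU backend bit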
// Is the set empty? (AKA undefined tensor)
bool empty() const {
return repr_ == 0;
}
uint64_t raw_repr() {
return repr_;
}
DispatchKey highestFunctionalityKey() const {
auto functionality_idx = indexOfHighestBit();
// This means that none of the functionality bits were set.
if (functionality_idx < num_backends)
return DispatchKey::Undefined;
// The first num_backend bits in the keyset don't correspond to real
// dispatch keys.
return static_cast<DispatchKey>(functionality_idx - num_backends);
}
// This is similar to toBackendComponent(DispatchKey), but less restrictive.
// toBackendComponent() errors out if the key that it was passed has no
// backend bits, which is useful for error checking. We need a version of that
// here that can also handle "fake" backends like FPGA, because they need to
// map to the AutogradOther key. For those backends, we return
// BackendComponent::InvalidBit.
BackendComponent highestBackendKey() const {
// mask to mask out functionality bits
auto backend_idx =
DispatchKeySet(repr_ & full_backend_mask).indexOfHighestBit();
// all zeros across the backend bits means that no backend bits are set.
if (backend_idx == 0)
return BackendComponent::InvalidBit;
return static_cast<BackendComponent>(backend_idx);
}
// returns the DispatchKey of highest priority in the set.
DispatchKey highestPriorityTypeId() const {
auto functionality_k = highestFunctionalityKey();
if (isPerBackendFunctionalityKey(functionality_k)) {
return toRuntimePerBackendFunctionalityKey(
functionality_k, highestBackendKey());
}
return functionality_k;
}
// Returns the index of the most-significant bit in the keyset.
// This is used as part of the calculation into the operator table to get:
// - the highest "functionality" bit in the keyset.
// - the highest "backend" bit in the keyset.
uint8_t indexOfHighestBit() const {
return 64 - llvm::countLeadingZeros(repr_);
}
#if defined(C10_MOBILE_TRIM_DISPATCH_KEYS)
// [Note: Trimmed Mobile Dispatch Keys]
/**
* The method below maps the dispatch key in the enum DispatchKey to an
* integer index in the dispatchTable_ array in OperatorEntry. The array
* is trimmed for mobile to reduce peak memory usage since it's
* unnecessary to reserve additional space for dispatch keys that will
* never be used on mobile.
*/
int getDispatchTableIndexForDispatchKeySet() const {
auto dk = highestPriorityTypeId();
switch (dk) {
case DispatchKey::Undefined:
return 0;
case DispatchKey::CPU:
return 1;
case DispatchKey::QuantizedCPU:
return 2;
case DispatchKey::SparseCPU:
return 3;
case DispatchKey::BackendSelect:
return 4;
case DispatchKey::ADInplaceOrView:
return 5;
case DispatchKey::AutogradOther:
return 6;
case DispatchKey::AutogradCPU:
return 7;
default:
return -1;
}
}
#else
// returns the index in the operator table of the highest priority key in the
// keyset. Note that we could in theory implement this using
// highestPriorityTypeId(), but this code is very hot-path and we can do it
// faster without it.
int getDispatchTableIndexForDispatchKeySet() const {
auto functionality_idx =
DispatchKeySet(repr_ >> num_backends).indexOfHighestBit();
auto offset_and_mask = offsetsAndMasks()[functionality_idx];
// Mask the functionality bits out first, then right-shift by 1.
// right-shifting by 1 because everything is zero-indexed.
// E.g. 000001 (CPU) should give us an offset of 0, 000010 (CUDA) should
// give us an offset of 1, etc.
auto backend_idx =
DispatchKeySet((repr_ & offset_and_mask.mask) >> 1).indexOfHighestBit();
return offset_and_mask.offset + backend_idx;
}
#endif
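// A worked sketch of the non-mobile path above: for
// DispatchKeySet(DispatchKey::CUDA), the highest functionality bit is Dense,
// so we look up Dense's offset and backend mask; the CUDA backend bit then
// yields a backend index of 1 (CPU would be 0), and the final slot is Dense's
// offset + 1.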
// returns the "index" of the highest priority backend in the keyset.
// This is pretty similar to highestBackendKey(), but:
// - It's hotpath code (part of the runtime bitset calculation)
// - It returns an integer index, not an enum value
// - Everything is shifted to the right by 1.
// BackendComponent::InvalidBit is technically the lowest enum value,
// but it isn't included in the runtime table. So CPUBit = 1, CUDABit = 2,
// etc.
uint64_t getBackendIndex() const {
return DispatchKeySet((repr_ & full_backend_mask) >> 1).indexOfHighestBit();
}
private:
constexpr DispatchKeySet(uint64_t repr) : repr_(repr) {}
uint64_t repr_ = 0;
public:
// STL iterator for DispatchKeySet. Iterates through all runtime DispatchKeys
// in the set. The iterator is only invalidated by the destruction of the
// underlying DispatchKeySet as the iterator stores a pointer to the raw
// representation of the DispatchKeySet. Note: When we encounter a per-backend
// functionality (e.g. Dense or Sparse), we will iterate through EVERY backend
// in the keyset, for that functionality. For example, if the next
// functionality key to iterate over is Autograd, and the backend bits in the
// keyset correspond to [BackendComponent::CPUBit, BackendComponent::CUDABit],
// then the next two keys we return will be DispatchKey::AutogradCPU,
// DispatchKey::AutogradCUDA (CPU first because it has lower precedence than
// CUDA in DispatchKey.h).
class iterator {
public:
using self_type = iterator;
using iterator_category = std::input_iterator_tag;
using value_type = DispatchKey;
using difference_type = ptrdiff_t;
using reference = value_type&;
using pointer = value_type*;
// final mask value should mask out the entire keyset
static const uint8_t end_iter_mask_val =
num_backends + num_functionality_keys;
// final key value should be the last DispatchKey
static const uint8_t end_iter_key_val = num_functionality_keys;
// current_dispatchkey_idx_ will iterate through all functionality bits.
// current_backendcomponent_idx_ will iterate through all backend bits.
explicit iterator(
const uint64_t* data_ptr,
uint8_t next_functionality = num_backends,
uint8_t next_backend = 0)
: data_ptr_(data_ptr),
next_functionality_(next_functionality),
next_backend_(next_backend),
// These are in an invalid state at construction time, and set by the
// first increment call
current_dispatchkey_idx_(end_iter_key_val),
current_backendcomponent_idx_(end_iter_key_val) {
// Go to the first key in the set
TORCH_INTERNAL_ASSERT(
next_functionality_ >= num_backends,
"num_backends=",
static_cast<uint32_t>(num_backends),
"next_functionality_=",
static_cast<uint32_t>(next_functionality_));
++(*this);
}
C10_API self_type& operator++();
self_type operator++(int) {
self_type previous_iterator = *this;
++(*this);
return previous_iterator;
}
bool operator==(const self_type& rhs) const {
return next_functionality_ == rhs.next_functionality_ &&
current_dispatchkey_idx_ == rhs.current_dispatchkey_idx_ &&
next_backend_ == rhs.next_backend_ &&
current_backendcomponent_idx_ == rhs.current_backendcomponent_idx_;
}
bool operator!=(const self_type& rhs) const {
return next_functionality_ != rhs.next_functionality_ ||
current_dispatchkey_idx_ != rhs.current_dispatchkey_idx_ ||
next_backend_ != rhs.next_backend_ ||
current_backendcomponent_idx_ != rhs.current_backendcomponent_idx_;
}
DispatchKey operator*() const {
auto functionality_key =
static_cast<DispatchKey>(current_dispatchkey_idx_);
if (isPerBackendFunctionalityKey(functionality_key)) {
auto next_key = toRuntimePerBackendFunctionalityKey(
functionality_key,
static_cast<BackendComponent>(current_backendcomponent_idx_));
// We expect all of the Dense, Sparse, Quantized, and Autograd keys to
// be ordered the same way with respect to their backends
TORCH_INTERNAL_ASSERT(
toBackendComponent(next_key) ==
static_cast<BackendComponent>(current_backendcomponent_idx_),
"Tried to map functionality key ",
toString(functionality_key),
" and backend bit ",
toString(
static_cast<BackendComponent>(current_backendcomponent_idx_)),
" to a runtime key, but ended up with ",
toString(next_key),
". This can happen if the order of the backend dispatch keys in DispatchKey.h isn't consistent.",
" Please double check that enum for inconsistencies.");
return next_key;
} else {
return functionality_key;
}
}
private:
const uint64_t* data_ptr_;
uint8_t next_functionality_;
uint8_t next_backend_;
uint8_t current_dispatchkey_idx_;
uint8_t current_backendcomponent_idx_;
};
public:
// Returns iterator to the first key in the set. If no keys are in the
// set, then will return the end iterator.
iterator begin() const {
return iterator(&repr_);
}
// We do not need to iterate beyond EndOfFunctionalityKeys so we will treat
// this as the end iterator.
iterator end() const {
return iterator(&repr_, iterator::end_iter_mask_val);
}
};
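// A short iteration sketch (illustrative): a keyset with the Dense and
// AutogradFunctionality bits plus the CPU backend bit visits the runtime keys
// DispatchKey::CPU and then DispatchKey::AutogradCPU, consistent with the
// iterator comment above.
//
//   DispatchKeySet ks =
//       DispatchKeySet(
//           {DispatchKey::Dense, DispatchKey::AutogradFunctionality}) |
//       DispatchKeySet(BackendComponent::CPUBit);
//   for (DispatchKey k : ks) {
//     // k == DispatchKey::CPU on the first pass, then DispatchKey::AutogradCPU
//   }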
C10_API std::string toString(DispatchKeySet);
C10_API std::ostream& operator<<(std::ostream&, DispatchKeySet);
C10_API inline int getDispatchTableIndexForDispatchKey(DispatchKey k) {
return DispatchKeySet(k).getDispatchTableIndexForDispatchKeySet();
}
// Alias key DispatchKey::Autograd maps to
// (autograd_dispatch_keyset x full_backend_mask)
// NB: keys in this set also get associated with CompositeImplicitAutograd
//
// Note [autograd_dispatch_keyset Does Not Include Backend Bits]
// We don't want to include any backend bits (BackendComponent::CPUBit, etc)
// directly in autograd_dispatch_keyset.
// Why? keysets like autograd_dispatch_keyset are commonly used to remove
// autograd keys from a DispatchKeySet throughout the code base. However, you
// are only allowed to remove functionality bits from a keyset, not backend
// bits. See Note [Removing keys from DispatchKeySet Only Affects Functionality
// Keys] for details. To be consistent and avoid confusion, we're explicitly
// setting up autograd_dispatch_keyset to not have any backend bits.
constexpr DispatchKeySet autograd_dispatch_keyset = DispatchKeySet({
DispatchKey::AutogradFunctionality,
DispatchKey::AutogradOther,
DispatchKey::AutogradNestedTensor,
});
constexpr DispatchKeySet autocast_dispatch_keyset = DispatchKeySet({
DispatchKey::AutocastCPU,
DispatchKey::AutocastMPS,
DispatchKey::AutocastCUDA,
DispatchKey::AutocastXPU,
DispatchKey::AutocastIPU,
DispatchKey::AutocastHPU,
DispatchKey::AutocastXLA,
DispatchKey::AutocastPrivateUse1,
});
// See Note [TLS Initialization]
constexpr DispatchKeySet default_included_set = DispatchKeySet({
DispatchKey::BackendSelect,
DispatchKey::ADInplaceOrView,
});
constexpr DispatchKeySet default_excluded_set = DispatchKeySet({
DispatchKey::AutocastCPU,
DispatchKey::AutocastMPS,
DispatchKey::AutocastCUDA,
DispatchKey::AutocastXPU,
DispatchKey::AutocastIPU,
DispatchKey::AutocastHPU,
DispatchKey::AutocastXLA,
DispatchKey::AutocastPrivateUse1,
});
constexpr DispatchKeySet autograd_dispatch_keyset_with_ADInplaceOrView =
autograd_dispatch_keyset | DispatchKeySet(DispatchKey::ADInplaceOrView);
constexpr DispatchKeySet python_ks = DispatchKeySet({
DispatchKey::Python,
DispatchKey::PythonTLSSnapshot,
});
constexpr DispatchKeySet sparse_ks = DispatchKeySet(DispatchKey::Sparse);
constexpr DispatchKeySet sparse_csr_ks = DispatchKeySet(DispatchKey::SparseCsr);
constexpr DispatchKeySet mkldnn_ks = DispatchKeySet(DispatchKey::MkldnnCPU);
// backend dispatch keys that map to DispatchKey::AutogradOther
// NB: keys in this set also get associated with CompositeImplicitAutograd
constexpr DispatchKeySet autogradother_backends =
DispatchKeySet(
// HIP and VE aren't in this list: they now have their own backend bits
// which means that they can now have their own Autograd keys.
// Technically, HIP will now redispatch to its own custom AutogradHIP
// slot in the runtime table.
{DispatchKey::FPGA,
DispatchKey::MAIA,
DispatchKey::Vulkan,
DispatchKey::Metal,
DispatchKey::CustomRNGKeyId,
DispatchKey::MkldnnCPU,
// Sparse and Quantized backends also live here.
DispatchKey::Sparse,
DispatchKey::SparseCsr,
DispatchKey::Quantized})
// Including the backend bits because this keyset is used during op
// registration, which requires looping over all runtime autogradother
// backend keys.
| DispatchKeySet(DispatchKeySet::RAW, full_backend_mask);
// The set of dispatch keys that come after autograd
// n.b. this relies on the fact that AutogradOther is currently the lowest
// Autograd key
constexpr DispatchKeySet after_autograd_keyset =
DispatchKeySet(DispatchKeySet::FULL_AFTER, c10::DispatchKey::AutogradOther);
// The set of dispatch keys that come after ADInplaceOrView
constexpr DispatchKeySet after_ADInplaceOrView_keyset = DispatchKeySet(
DispatchKeySet::FULL_AFTER,
c10::DispatchKey::ADInplaceOrView);
// The set of dispatch keys that come after Functionalize
constexpr DispatchKeySet after_func_keyset =
DispatchKeySet(DispatchKeySet::FULL_AFTER, c10::DispatchKey::Functionalize)
.remove(
// NOTE: we also need to remove ADInplaceOrView from the keyset when
// redispatching after the func kernels. This is because we're not
// calling the same op; we originally called an inplace op, and now
// we aren't. The original key calculation figured out which keys
// were Fallthrough based on the inplace op. That means that it did
// not include the ADInPlaceOrView kernel as a fallthrough key.
// However, we WANT the ADInPlaceOrView kernel to be ignored now
// that we're calling an out-of-place op. Re-invoking
// Dispatcher::call would re-run the Fallthrough key calculation and
// get us that, but at::redispatch is more performant. We can get
// away with it by explicitly removing the key here.
c10::DispatchKey::ADInplaceOrView);
constexpr DispatchKeySet backend_bitset_mask =
DispatchKeySet(DispatchKeySet::RAW, (1ULL << num_backends) - 1);
constexpr auto inplace_or_view_ks =
DispatchKeySet(DispatchKey::ADInplaceOrView);
constexpr auto autograd_cpu_ks = DispatchKeySet(DispatchKey::AutogradCPU);
constexpr auto autograd_ipu_ks = DispatchKeySet(DispatchKey::AutogradIPU);
constexpr auto autograd_xpu_ks = DispatchKeySet(DispatchKey::AutogradXPU);
constexpr auto autograd_cuda_ks = DispatchKeySet(DispatchKey::AutogradCUDA);
constexpr auto autograd_xla_ks = DispatchKeySet(DispatchKey::AutogradXLA);
constexpr auto autograd_lazy_ks = DispatchKeySet(DispatchKey::AutogradLazy);
constexpr auto autograd_meta_ks = DispatchKeySet(DispatchKey::AutogradMeta);
constexpr auto autograd_mps_ks = DispatchKeySet(DispatchKey::AutogradMPS);
constexpr auto autograd_hpu_ks = DispatchKeySet(DispatchKey::AutogradHPU);
constexpr auto autograd_privateuse1_ks =
DispatchKeySet(DispatchKey::AutogradPrivateUse1);
constexpr auto autograd_privateuse2_ks =
DispatchKeySet(DispatchKey::AutogradPrivateUse2);
constexpr auto autograd_privateuse3_ks =
DispatchKeySet(DispatchKey::AutogradPrivateUse3);
constexpr auto autograd_other_ks = DispatchKeySet(DispatchKey::AutogradOther);
constexpr auto autograd_nested =
DispatchKeySet(DispatchKey::AutogradNestedTensor);
// keyset corresponding to functorch keys that have their own dedicated
// TensorImpl subclass.
constexpr auto functorch_transforms_ks = DispatchKeySet(
{DispatchKey::FuncTorchBatched,
DispatchKey::FuncTorchVmapMode,
DispatchKey::Batched,
DispatchKey::VmapMode,
DispatchKey::FuncTorchGradWrapper});
constexpr auto functorch_batched_ks =
DispatchKeySet({DispatchKey::FuncTorchBatched});
// This keyset has:
// (1) the functionality bits corresponding to backends (dense, sparse,
//     quantized), and
// (2) all of the backend bits set
constexpr DispatchKeySet backend_functionality_keys =
DispatchKeySet({
DispatchKey::Dense,
DispatchKey::Quantized,
DispatchKey::Sparse,
DispatchKey::SparseCsr,
}) |
DispatchKeySet(DispatchKeySet::RAW, full_backend_mask);
struct OpTableOffsetAndMask {
uint16_t offset;
uint16_t backend_mask;
};
static_assert(
num_backends <= 16,
"Right now we expect the number of backends not to exceed 16. In the (unlikely) event"
" that this changes, the size of OpTableOffsetAndMask::backend_mask needs to be increased too.");
// true if t is a backend dispatch key
C10_API bool isBackendDispatchKey(DispatchKey t);
// Resolve alias dispatch key to DispatchKeySet if applicable
C10_API DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t);
// Resolve alias dispatch key to DispatchKeySet if applicable,
// and check if k is a part of that set
C10_API bool runtimeDispatchKeySetHas(DispatchKey t, DispatchKey k);
// Returns a DispatchKeySet of all backend keys mapped to Autograd dispatch key
// t; the DispatchKeySet is empty if t is not an alias of DispatchKey::Autograd.
C10_API DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t);
// Returns a DispatchKeySet of autograd-related keys mapped to a backend.
// For a given backend key, use the associated autograd key.
// For non-backend keys, use AutogradOther as a default.
// Note: it's convenient and fast to return a default here rather than (say)
// returning an std::optional<DispatchKey>, or throwing. But it makes callers
// responsible for either a) enforcing the invariant that only backend keys
// be passed as arguments, or b) interpreting our return value carefully.
inline DispatchKeySet getAutogradRelatedKeySetFromBackend(BackendComponent t) {
switch (t) {
case BackendComponent::CPUBit:
return inplace_or_view_ks | autograd_cpu_ks;
case BackendComponent::IPUBit:
return inplace_or_view_ks | autograd_ipu_ks;
case BackendComponent::XPUBit:
return inplace_or_view_ks | autograd_xpu_ks;
case BackendComponent::CUDABit:
return inplace_or_view_ks | autograd_cuda_ks;
case BackendComponent::XLABit:
return inplace_or_view_ks | autograd_xla_ks;
case BackendComponent::LazyBit:
return inplace_or_view_ks | autograd_lazy_ks;
case BackendComponent::MetaBit:
return inplace_or_view_ks | autograd_meta_ks;
case BackendComponent::MPSBit:
return inplace_or_view_ks | autograd_mps_ks;
case BackendComponent::HPUBit:
return inplace_or_view_ks | autograd_hpu_ks;
case BackendComponent::PrivateUse1Bit:
return inplace_or_view_ks | autograd_privateuse1_ks;
case BackendComponent::PrivateUse2Bit:
return inplace_or_view_ks | autograd_privateuse2_ks;
case BackendComponent::PrivateUse3Bit:
return inplace_or_view_ks | autograd_privateuse3_ks;
default:
return inplace_or_view_ks | autograd_other_ks;
}
}
// Returns a DispatchKeySet of autocast-related keys mapped to a backend.
inline DispatchKeySet getAutocastRelatedKeySetFromBackend(BackendComponent t) {
constexpr auto autocast_cpu_ks = DispatchKeySet(DispatchKey::AutocastCPU);
constexpr auto autocast_xpu_ks = DispatchKeySet(DispatchKey::AutocastXPU);
constexpr auto autocast_ipu_ks = DispatchKeySet(DispatchKey::AutocastIPU);
constexpr auto autocast_hpu_ks = DispatchKeySet(DispatchKey::AutocastHPU);
constexpr auto autocast_cuda_ks = DispatchKeySet(DispatchKey::AutocastCUDA);
constexpr auto autocast_xla_ks = DispatchKeySet(DispatchKey::AutocastXLA);
constexpr auto autocast_privateuse1_ks =
DispatchKeySet(DispatchKey::AutocastPrivateUse1);
constexpr auto autocast_mps_ks = DispatchKeySet(DispatchKey::AutocastMPS);
switch (t) {
case BackendComponent::CPUBit:
return autocast_cpu_ks;
case BackendComponent::XPUBit:
return autocast_xpu_ks;
case BackendComponent::IPUBit:
return autocast_ipu_ks;
case BackendComponent::HPUBit:
return autocast_hpu_ks;
case BackendComponent::CUDABit:
return autocast_cuda_ks;
case BackendComponent::XLABit:
return autocast_xla_ks;
case BackendComponent::PrivateUse1Bit:
return autocast_privateuse1_ks;
case BackendComponent::MPSBit:
return autocast_mps_ks;
default:
return DispatchKeySet();
}
}
// returns the "backend" DispatchKey of highest priority in the set.
// This is basically like highestBackendKey(), except that we have some
// "functionality" bits that correspond to backends (Sparse, Quantized)
inline DispatchKey highestPriorityBackendTypeId(DispatchKeySet ks) {
return (ks & backend_functionality_keys).highestPriorityTypeId();
}
// This API exists because we have a use case for checking
// getRuntimeDispatchKeySet(alias).has(DispatchKey::Undefined)
// in OperatorEntry.cpp, but we disallow it in the has() API.
C10_API bool isIncludedInAlias(DispatchKey k, DispatchKey alias);
// Historically, every tensor only had a single DispatchKey, and it was always
// something like CPU, and there wasn't any of this business where TLS
// could cause the DispatchKey of a tensor to change. But we still have some
// legacy code that is still using DispatchKey for things like instanceof
// checks; if at all possible, refactor the code to stop using DispatchKey in
// those cases.
inline DispatchKey legacyExtractDispatchKey(DispatchKeySet s) {
// NB: If you add any extra keys that can be stored in TensorImpl on
// top of existing "backend" keys like CPU/CUDA, you need to add it
// here. At the moment, autograd keys and ADInplaceOrView key need this
// treatment.
return (s - autograd_dispatch_keyset_with_ADInplaceOrView -
autocast_dispatch_keyset -
DispatchKeySet(
{DispatchKey::Functionalize,
DispatchKey::PythonTLSSnapshot,
DispatchKey::FuncTorchGradWrapper,
DispatchKey::FuncTorchVmapMode,
DispatchKey::FuncTorchBatched,
DispatchKey::Python}))
.highestPriorityTypeId();
}
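// For example (illustrative): for a typical dense CPU tensor whose key set
// contains {CPU, AutogradCPU, ADInplaceOrView}, the subtraction above strips
// the autograd and ADInplaceOrView bits and the function returns
// DispatchKey::CPU.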
template <class T>
using is_not_DispatchKeySet = std::negation<std::is_same<DispatchKeySet, T>>;
// Given a function type, constructs a function_traits type that drops the first
// parameter type if the first parameter is of type DispatchKeySet. NB:
// DispatchKeySet is currently explicitly hidden from JIT (mainly to avoid
// pushing unnecessary arguments on the stack - see Note [ Plumbing Keys Through
// the Dispatcher] for details). If at any point in the future we need to expose
// this type to JIT, revisit the usage of this type alias.
template <class FuncType>
using remove_DispatchKeySet_arg_from_func = guts::make_function_traits_t<
typename guts::infer_function_traits_t<FuncType>::return_type,
typename std::conditional_t<
std::is_same_v<
DispatchKeySet,
typename guts::typelist::head_with_default_t<
void,
typename guts::infer_function_traits_t<
FuncType>::parameter_types>>,
guts::typelist::drop_if_nonempty_t<
typename guts::infer_function_traits_t<FuncType>::parameter_types,
1>,
typename guts::infer_function_traits_t<FuncType>::parameter_types>>;
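// An illustrative sketch (hypothetical signatures, for exposition only):
//   using WithKS = int(DispatchKeySet, double);
//   using WithoutKS = int(double);
//   // remove_DispatchKeySet_arg_from_func<WithKS> describes int(double),
//   // while remove_DispatchKeySet_arg_from_func<WithoutKS> leaves the
//   // parameter list untouched.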
} // namespace c10

View File

@ -0,0 +1,125 @@
#pragma once
#include <c10/core/ScalarType.h>
#include <c10/macros/Macros.h>
#include <c10/util/Load.h>
#include <c10/util/TypeCast.h>
namespace c10 {
// Dynamic type casting utils:
// - fetch_and_cast
// - cast_and_store
//
// fetch_and_cast fetches a value with the dynamic type specified by a
// ScalarType from a void pointer and casts it to a static type.
//
// cast_and_store casts a statically typed value into the dynamic type specified
// by a ScalarType, and stores it into a void pointer.
//
// NOTE:
//
// Dynamic casting allows us to support type promotion without blowing up
// the combination space: For example, without dynamic cast, in order to
// implement `add_` with type promotion, we would need something like
//
// AT_DISPATCH_ALL_TYPES(output.dtype(),
// AT_DISPATCH_ALL_TYPES(input1.dtype(),
// AT_DISPATCH_ALL_TYPES(input2.dtype(),
// [](arg0_t a, arg1_t b) -> out_t { return a + b; }
// )
// )
// )
//
// If we support N dtypes, the above code would generate the a+b kernel for
// all N * N * N supported type combinations, and the compilation time and
// binary size would become horrible.
//
// Dynamic casting might sound like a bad idea in terms of performance.
// Especially if you ever do it in a loop, you are going to do a billion tests.
// But in practice it is not as bad as it might look:
//
// - on CPU, this is a branch that always has the same outcome, therefore
// hopefully the branch predictor could do the job pretty well
// - on GPU, these branches will not diverge, so we could still have the same
// warp executing the same line of code
// - Most kernels, like `add`, are bandwidth bound, adding a few clock cycles to
// check an integer does not hurt the performance much because the ALUs would
// wait for load instructions anyway.
//
// For the discussion and benchmark, refer to:
// - https://github.com/pytorch/pytorch/pull/28343
// - https://github.com/pytorch/pytorch/pull/28344
// - https://github.com/pytorch/pytorch/pull/28345
//
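// An illustrative sketch of how these utilities are meant to be used
// (plain host code, for exposition only):
//   float src = 1.5f;
//   double dst = 0.0;
//   // Read a value whose dtype is only known at runtime...
//   double tmp = fetch_and_cast<double>(ScalarType::Float, &src);
//   // ...and write it out to a buffer of another runtime dtype.
//   cast_and_store<double>(ScalarType::Double, &dst, tmp);  // dst == 1.5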
#ifdef C10_HOST_DEVICE
#define ERROR_UNSUPPORTED_CAST CUDA_KERNEL_ASSERT(false);
#else
#define ERROR_UNSUPPORTED_CAST TORCH_CHECK(false, "Unexpected scalar type");
#endif
// Fetch a value with dynamic type src_type from ptr, and cast it to static type
// dest_t.
#define FETCH_AND_CAST_CASE(type, scalartype) \
case ScalarType::scalartype: \
return c10::convert<dest_t>(c10::load<type>(ptr));
template <typename dest_t>
C10_HOST_DEVICE inline dest_t fetch_and_cast(
const ScalarType src_type,
const void* ptr) {
switch (src_type) {
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(FETCH_AND_CAST_CASE)
FETCH_AND_CAST_CASE(uint16_t, UInt16)
FETCH_AND_CAST_CASE(uint32_t, UInt32)
FETCH_AND_CAST_CASE(uint64_t, UInt64)
default:
ERROR_UNSUPPORTED_CAST
}
return dest_t(0); // just to avoid compiler warning
}
// Cast a value with static type src_t into dynamic dest_type, and store it to
// ptr.
#define CAST_AND_STORE_CASE(type, scalartype) \
case ScalarType::scalartype: \
*(type*)ptr = c10::convert<type>(value); \
return;
template <typename src_t>
C10_HOST_DEVICE inline void cast_and_store(
const ScalarType dest_type,
void* ptr,
src_t value) {
switch (dest_type) {
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(CAST_AND_STORE_CASE)
CAST_AND_STORE_CASE(uint16_t, UInt16)
CAST_AND_STORE_CASE(uint32_t, UInt32)
CAST_AND_STORE_CASE(uint64_t, UInt64)
default:;
}
ERROR_UNSUPPORTED_CAST
}
#define DEFINE_UNCASTABLE(T, scalartype_) \
template <> \
C10_HOST_DEVICE inline T fetch_and_cast<T>( \
const ScalarType src_type, const void* ptr) { \
CUDA_KERNEL_ASSERT(ScalarType::scalartype_ == src_type); \
return c10::load<T>(ptr); \
} \
template <> \
C10_HOST_DEVICE inline void cast_and_store<T>( \
const ScalarType dest_type, void* ptr, T value) { \
CUDA_KERNEL_ASSERT(ScalarType::scalartype_ == dest_type); \
*(T*)ptr = value; \
}
AT_FORALL_QINT_TYPES(DEFINE_UNCASTABLE)
#undef FETCH_AND_CAST_CASE
#undef CAST_AND_STORE_CASE
#undef DEFINE_UNCASTABLE
#undef ERROR_UNSUPPORTED_CAST
} // namespace c10

View File

@ -0,0 +1,137 @@
#pragma once
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/Stream.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/core/impl/InlineEvent.h>
#include <c10/core/impl/VirtualGuardImpl.h>
namespace c10 {
/**
* A backend-generic movable, not copyable, not thread-safe event.
*
* The design of this event follows that of CUDA and HIP events. These events
* are recorded and waited on by streams and can be rerecorded to,
* each rerecording essentially creating a new version of the event.
* For example, if (in CPU time), stream X is asked to record E,
* stream Y waits on E, and stream X is asked to record E again, then Y will
* wait for X to finish the first call to record and not the second, because
* it's waiting on the first version of event E, not the second.
* Querying an event only returns the status of its most recent version.
*
* Backend-generic events are implemented by this class and
* impl::InlineEvent. In addition to these events there are also
* some backend-specific events, like ATen's CUDAEvent. Each of these
* classes has its own use.
*
* impl::InlineEvent<...> or a backend-specific event should be
* preferred when the backend is known at compile time and known to
* be compiled. Backend-specific events may have additional functionality.
*
* This Event should be used if a particular backend may not be available,
* or the backend required is not known at compile time.
*
* These generic events are built on top of DeviceGuardImpls, analogous
* to DeviceGuard and InlineDeviceGuard. The name "DeviceGuardImpls"
* is no longer entirely accurate, as these classes implement the
* backend-specific logic for a generic backend interface.
*
* See DeviceGuardImplInterface.h for a list of all supported flags.
*/
struct Event final {
// Constructors
Event() = delete;
Event(
const DeviceType _device_type,
const EventFlag _flag = EventFlag::PYTORCH_DEFAULT)
: impl_{_device_type, _flag} {}
// Copy constructor and copy assignment operator (deleted)
Event(const Event&) = delete;
Event& operator=(const Event&) = delete;
// Move constructor and move assignment operator
Event(Event&&) noexcept = default;
Event& operator=(Event&&) noexcept = default;
// Destructor
~Event() = default;
// Getters
Device device() const noexcept {
return Device(device_type(), device_index());
}
DeviceType device_type() const noexcept {
return impl_.device_type();
}
DeviceIndex device_index() const noexcept {
return impl_.device_index();
}
EventFlag flag() const noexcept {
return impl_.flag();
}
bool was_marked_for_recording() const noexcept {
return impl_.was_marked_for_recording();
}
/**
* Calls record() if and only if record() has never been called for this
* event. Note: because Event is not thread-safe, recordOnce() may call
* record() multiple times if called from multiple threads.
*/
void recordOnce(const Stream& stream) {
impl_.recordOnce(stream);
}
/**
* Increments the event's version and enqueues a job with this version
* in the stream's work queue. When the stream processes that job,
* it notifies all streams waiting on / blocked by that version of the
* event to continue and marks that version as recorded.
*/
void record(const Stream& stream) {
impl_.record(stream);
}
/**
* Does nothing if the event has not been scheduled to be recorded.
* If the event was previously enqueued to be recorded, a command
* to wait for the version of the event that exists at the time of this call
* is inserted in the stream's work queue.
* When the stream reaches this command it will stop processing
* additional commands until that version of the event is marked as recorded.
*/
void block(const Stream& stream) const {
impl_.block(stream);
}
/**
* Returns true if (and only if)
* (1) the event has never been scheduled to be recorded, or
* (2) the current version is marked as recorded.
* Returns false otherwise.
*/
bool query() const {
return impl_.query();
}
double elapsedTime(const Event& event) const {
return impl_.elapsedTime(event.impl_);
}
void* eventId() const {
return impl_.eventId();
}
void synchronize() const {
return impl_.synchronize();
}
private:
impl::InlineEvent<impl::VirtualGuardImpl> impl_;
};
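// An illustrative usage sketch (for exposition only; `producer` and
// `consumer` are assumed to be c10::Stream objects obtained elsewhere):
//   c10::Event e(c10::DeviceType::CUDA);
//   e.record(producer);    // enqueue version 1 of e on `producer`
//   e.block(consumer);     // `consumer` waits for version 1 to be recorded
//   bool done = e.query(); // true once version 1 is marked as recorded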
} // namespace c10

View File

@ -0,0 +1,110 @@
#pragma once
#include <cstdint>
#include <mutex>
#include <c10/core/Device.h>
#include <c10/core/DispatchKeySet.h>
#include <c10/core/TensorImpl.h>
#include <c10/macros/Export.h>
#include <c10/util/intrusive_ptr.h>
#include <c10/util/python_stub.h>
/**
* Note [Generator]
* ~~~~~~~~~~~~~~~~
* A Pseudo Random Number Generator (PRNG) is an engine that uses an algorithm
* to generate a seemingly random sequence of numbers that may later be used
* in creating a random distribution. Such an engine almost always maintains a
* state and requires a seed to start off the creation of random numbers. Users
* often find it beneficial to be able to explicitly create,
* retain, and destroy PRNG states and also be able to have control over the
* seed value.
*
* A Generator in ATen gives users the ability to read, write and modify a PRNG
* engine. For instance, it does so by letting users seed a PRNG engine, fork
* the state of the engine, etc.
*
* By default, there is one generator per device, and a device's generator is
* lazily created. A user can use the torch.Generator() API to create their own
* generator. Currently torch.Generator() can only create a CPUGeneratorImpl.
*/
/**
* Note [Acquire lock when using random generators]
* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
* Generator and its derived classes are NOT thread-safe. Please note that most
* of the places where we have inserted locking for generators are historically
* based, and we haven't actually checked that everything is truly thread safe
* (and it probably isn't). Please use the public mutex_ when using any methods
* from these classes, except for the read-only methods. You can learn about the
* usage by looking into the unittests (aten/src/ATen/cpu_generator_test.cpp)
* and other places where we have used lock_guard.
*
* TODO: Look into changing the threading semantics of Generators in ATen (e.g.,
* making them non-thread safe and instead making the generator state
* splittable, to accommodate forks into other threads).
*/
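// An illustrative sketch of the locking convention described above (for
// exposition only; `gen` is assumed to be an intrusive_ptr<GeneratorImpl>
// obtained elsewhere):
//   {
//     std::lock_guard<std::mutex> lock(gen->mutex_);
//     gen->set_current_seed(42);
//   }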
namespace c10 {
// The default seed is selected to be a large number
// with good distribution of 0s and 1s in bit representation
constexpr uint64_t default_rng_seed_val = 67280421310721;
struct C10_API GeneratorImpl : public c10::intrusive_ptr_target {
// Constructors
GeneratorImpl(Device device_in, DispatchKeySet key_set);
// Delete all copy and move assignment in favor of clone()
// method
GeneratorImpl(const GeneratorImpl& other) = delete;
GeneratorImpl(GeneratorImpl&& other) = delete;
GeneratorImpl& operator=(const GeneratorImpl& other) = delete;
~GeneratorImpl() override = default;
c10::intrusive_ptr<GeneratorImpl> clone() const;
// Common methods for all generators
virtual void set_current_seed(uint64_t seed) = 0;
virtual void set_offset(uint64_t offset) = 0;
virtual uint64_t get_offset() const = 0;
virtual uint64_t current_seed() const = 0;
virtual uint64_t seed() = 0;
virtual void set_state(const c10::TensorImpl& new_state) = 0;
virtual c10::intrusive_ptr<c10::TensorImpl> get_state() const = 0;
virtual void graphsafe_set_state(
const c10::intrusive_ptr<c10::GeneratorImpl>& new_state);
virtual c10::intrusive_ptr<c10::GeneratorImpl> graphsafe_get_state() const;
Device device() const;
// See Note [Acquire lock when using random generators]
std::mutex mutex_;
DispatchKeySet key_set() const {
return key_set_;
}
inline void set_pyobj(PyObject* pyobj) noexcept {
pyobj_ = pyobj;
}
inline PyObject* pyobj() const noexcept {
return pyobj_;
}
protected:
Device device_;
DispatchKeySet key_set_;
PyObject* pyobj_ = nullptr;
virtual GeneratorImpl* clone_impl() const = 0;
};
namespace detail {
C10_API uint64_t getNonDeterministicRandom(bool is_cuda = false);
} // namespace detail
} // namespace c10

View File

@ -0,0 +1,44 @@
#pragma once
#include <c10/core/AutogradState.h>
#include <c10/macros/Export.h>
namespace c10 {
struct C10_API GradMode {
static bool is_enabled();
static void set_enabled(bool enabled);
};
// A RAII, thread local (!) guard that enables or disables grad mode upon
// construction, and sets it back to the original value upon destruction.
struct C10_API AutoGradMode {
AutoGradMode(bool enabled) : prev_mode(GradMode::is_enabled()) {
GradMode::set_enabled(enabled);
}
~AutoGradMode() {
GradMode::set_enabled(prev_mode);
}
bool prev_mode;
};
// A RAII, thread local (!) guard that stops future operations from building
// gradients.
struct C10_API NoGradGuard : public AutoGradMode {
NoGradGuard() : AutoGradMode(/*enabled=*/false) {}
};
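// An illustrative usage sketch (for exposition only):
//   {
//     c10::NoGradGuard no_grad;  // grad mode is disabled within this scope
//     // ... run work that should not build a graph ...
//   }                            // the previous grad mode is restored here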
// A RAII, thread local (!) guard that enables or disables forward grad mode
// upon construction, and sets it back to the original value upon destruction.
struct C10_API AutoFwGradMode {
AutoFwGradMode(bool enabled)
: prev_mode(AutogradState::get_tls_state().get_fw_grad_mode()) {
AutogradState::get_tls_state().set_fw_grad_mode(enabled);
}
~AutoFwGradMode() {
AutogradState::get_tls_state().set_fw_grad_mode(prev_mode);
}
bool prev_mode;
};
} // namespace c10

View File

@ -0,0 +1,86 @@
#pragma once
#include <c10/core/AutogradState.h>
#include <c10/core/DispatchKey.h>
#include <c10/core/DispatchKeySet.h>
#include <c10/core/impl/LocalDispatchKeySet.h>
#include <c10/macros/Export.h>
namespace c10 {
// A RAII, thread local (!) guard that enables or disables inference mode upon
// construction, and sets it back to the original value upon destruction.
struct C10_API InferenceMode {
// Note [Expected TLS state in InferenceMode]:
// InferenceMode: ADInplaceOrView not in
// raw_local_dispatch_key_set.included(),
// Autograd in raw_local_dispatch_key_set.excluded()
// GradMode is disabled.
// NormalMode: ADInplaceOrView in raw_local_dispatch_key_set.included(),
// Autograd not in raw_local_dispatch_key_set.excluded()
// GradMode is enabled by default unless toggled manually
// through other APIs, e.g. NoGradGuard.
//
// Invariant:
// - ADInplaceOrView is never in the excluded set
// - Autograd is never in the included set
// - Setting InferenceMode will set GradMode accordingly, but not vice versa.
//
// 1. Why do we put ADInplaceOrView in the included set outside InferenceMode?
//
// Inplace update to inference tensor outside InferenceMode is not
// allowed. See Note [Inplace update inference tensor] for more details.
// Without going through ADInplaceOrView kernel, we cannot throw error
// for `inference_tensor.add_(1)` case.
//
// 2. Why not put ADInplaceOrView in the excluded set inside InferenceMode?
//
// For example:
// torch::Tensor a = torch::ones({1, 2, 3}).set_requires_grad(true);
// torch::Tensor k = a + 2;
// {
// c10::InferenceMode guard(true);
// k.add_(2);
// }
// `k.add_(2)` still needs to go through the ADInplaceOrView kernel so that it's
// prepared for future autograd.
//
// 3. Why does setting InferenceMode also set GradMode?
//
// This is required since InferenceMode is a faster and more restrictive
// version of NoGradGuard. All runtime checks using GradMode::is_enabled()
// are applicable to InferenceMode as well, e.g.
// `tensorTypeInCurrentExecutionContext` in interpreter.cpp.
InferenceMode(bool enabled = true)
: prev_mode(AutogradState::get_tls_state()),
prev_keyset(c10::impl::tls_local_dispatch_key_set()) {
// Enabling inference mode means disabling grad modes
// And disabling inference mode means enabling grad modes
AutogradState::set_tls_state(AutogradState(
/* grad_mode */ !enabled,
/* inference_mode */ enabled,
/* fw_grad_mode */ !enabled,
/* multithreading_enabled*/ !enabled));
DispatchKeySet included = enabled
? prev_keyset.included_.remove(c10::DispatchKey::ADInplaceOrView)
: prev_keyset.included_.add(c10::DispatchKey::ADInplaceOrView);
DispatchKeySet excluded = enabled
? (prev_keyset.excluded_ | c10::autograd_dispatch_keyset)
: (prev_keyset.excluded_ - c10::autograd_dispatch_keyset);
c10::impl::PODLocalDispatchKeySet cur_keyset{};
cur_keyset.set_included(included);
cur_keyset.set_excluded(excluded);
c10::impl::_force_tls_local_dispatch_key_set(cur_keyset);
}
~InferenceMode() {
AutogradState::set_tls_state(prev_mode);
c10::impl::_force_tls_local_dispatch_key_set(prev_keyset);
}
static bool is_enabled();
private:
AutogradState prev_mode;
c10::impl::LocalDispatchKeySet prev_keyset;
};
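// An illustrative usage sketch (for exposition only):
//   {
//     c10::InferenceMode guard;  // enabled == true by default
//     // tensors created here are inference tensors; autograd bookkeeping
//     // is skipped while the guard is alive
//   }                            // previous TLS state is restored here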
} // namespace c10

View File

@ -0,0 +1,78 @@
#pragma once
#include <c10/core/Backend.h>
#include <c10/util/Exception.h>
#include <cstdint>
#include <ostream>
namespace c10 {
enum class Layout : int8_t {
Strided,
Sparse,
SparseCsr,
Mkldnn,
SparseCsc,
SparseBsr,
SparseBsc,
Jagged,
NumOptions
};
constexpr auto kStrided = Layout::Strided;
constexpr auto kSparse = Layout::Sparse;
constexpr auto kSparseCsr = Layout::SparseCsr;
constexpr auto kMkldnn = Layout::Mkldnn;
constexpr auto kSparseCsc = Layout::SparseCsc;
constexpr auto kSparseBsr = Layout::SparseBsr;
constexpr auto kSparseBsc = Layout::SparseBsc;
constexpr auto kJagged = Layout::Jagged;
inline Layout layout_from_backend(Backend backend) {
switch (backend) {
case Backend::SparseCPU:
case Backend::SparseCUDA:
case Backend::SparseHIP:
case Backend::SparseVE:
case Backend::SparseXPU:
case Backend::SparsePrivateUse1:
return Layout::Sparse;
case Backend::MkldnnCPU:
return Layout::Mkldnn;
case Backend::SparseCsrCPU:
case Backend::SparseCsrCUDA:
case Backend::SparseCsrHIP:
case Backend::SparseCsrVE:
case Backend::SparseCsrXPU:
TORCH_CHECK(
false,
"Cannot map Backend SparseCsr(CPU|CUDA|HIP|VE|XPU) to a unique layout.");
default:
return Layout::Strided;
}
}
inline std::ostream& operator<<(std::ostream& stream, at::Layout layout) {
switch (layout) {
case at::kStrided:
return stream << "Strided";
case at::kSparse:
return stream << "Sparse";
case at::kSparseCsr:
return stream << "SparseCsr";
case at::kSparseCsc:
return stream << "SparseCsc";
case at::kSparseBsr:
return stream << "SparseBsr";
case at::kSparseBsc:
return stream << "SparseBsc";
case at::kMkldnn:
return stream << "Mkldnn";
case at::kJagged:
return stream << "Jagged";
default:
TORCH_CHECK(false, "Unknown layout");
}
}
} // namespace c10

View File

@ -0,0 +1,290 @@
#pragma once
#include <c10/util/ArrayRef.h>
#include <c10/util/Exception.h>
#include <cstdint>
#include <ostream>
#include <vector>
// Memory format is not a property of a Tensor. It is a way to tell an
// operator how the result should be organized in memory and nothing more. That
// means memory format should never be used as a return value of any tensor
// state interrogation function (internally or externally).
//
// Possible options are:
// Preserve:
// If any of the input tensors is in channels_last format, operator output
// should be in channels_last format
//
// Contiguous:
// Regardless of input tensors format, the output should be contiguous
// Tensor.
//
// ChannelsLast:
// Regardless of input tensors format, the output should be in channels_last
// format.
namespace c10 {
enum class MemoryFormat : int8_t {
Contiguous,
Preserve,
ChannelsLast,
ChannelsLast3d,
NumOptions
};
// If you are seeing this, it means that this call site was not checked to see
// whether the memory format could be preserved, and it was switched to the old
// default behaviour of contiguous.
#define LEGACY_CONTIGUOUS_MEMORY_FORMAT c10::get_contiguous_memory_format()
inline MemoryFormat get_contiguous_memory_format() {
return MemoryFormat::Contiguous;
}
inline std::ostream& operator<<(
std::ostream& stream,
at::MemoryFormat memory_format) {
switch (memory_format) {
case MemoryFormat::Preserve:
return stream << "Preserve";
case MemoryFormat::Contiguous:
return stream << "Contiguous";
case MemoryFormat::ChannelsLast:
return stream << "ChannelsLast";
case MemoryFormat::ChannelsLast3d:
return stream << "ChannelsLast3d";
default:
TORCH_CHECK(false, "Unknown memory format ", memory_format);
}
}
// Note: Hardcoded the channel last stride indices here to get better
// performance
template <typename T>
inline std::vector<T> get_channels_last_strides_2d(ArrayRef<T> sizes) {
std::vector<T> strides(sizes.size());
switch (sizes.size()) {
case 4:
strides[1] = 1;
strides[3] = sizes[1];
strides[2] = strides[3] * sizes[3];
strides[0] = strides[2] * sizes[2];
return strides;
case 3:
strides[0] = 1;
strides[2] = sizes[0];
strides[1] = strides[2] * sizes[2];
return strides;
default:
TORCH_INTERNAL_ASSERT(
false, "ChannelsLast2d doesn't support size ", sizes.size());
}
}
inline std::vector<int64_t> get_channels_last_strides_2d(IntArrayRef sizes) {
return get_channels_last_strides_2d<int64_t>(sizes);
}
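// For example (illustrative): for NCHW sizes {2, 3, 4, 5},
// get_channels_last_strides_2d returns {60, 1, 15, 3}, i.e. the C dimension
// becomes the fastest-moving one.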
template <typename T>
std::vector<T> get_channels_last_strides_3d(ArrayRef<T> sizes) {
std::vector<T> strides(sizes.size());
switch (sizes.size()) {
case 5:
strides[1] = 1;
strides[4] = sizes[1];
strides[3] = strides[4] * sizes[4];
strides[2] = strides[3] * sizes[3];
strides[0] = strides[2] * sizes[2];
return strides;
case 4:
strides[0] = 1;
strides[3] = sizes[0];
strides[2] = strides[3] * sizes[3];
strides[1] = strides[2] * sizes[2];
return strides;
default:
TORCH_INTERNAL_ASSERT(
false, "ChannelsLast3d doesn't support size ", sizes.size());
}
}
inline std::vector<int64_t> get_channels_last_strides_3d(IntArrayRef sizes) {
return get_channels_last_strides_3d<int64_t>(sizes);
}
// NOTE:
// Below are Helper functions for is_channels_last_strides_xd.
// 1. Please do not combine these helper functions; each helper function handles
// exactly one case of sizes + memory_format. By doing this, the strides indices
// form a constant array that we can access using constant index numbers, and
// the compiler will fully unroll the loop over strides indices to gain better
// performance.
// 2. There is no error checking in the helper functions; the caller ensures the
// correctness of the input.
// 3. All helper functions have similar comments; only the 1st helper function is
// commented here.
template <typename T>
inline bool is_channels_last_strides_2d_s4(
const ArrayRef<T> sizes,
const ArrayRef<T> strides) {
T min = 0;
// special case for trivial C dimension. default to NCHW
if (strides[1] == 0) {
return false;
}
// loop strides indices
for (auto& d : {1, 3, 2, 0}) {
if (sizes[d] == 0) {
return false;
}
if (strides[d] < min) {
return false;
}
// Fallback to NCHW as default layout for ambiguous cases
// This is the flaw of implicit memory_format from strides.
// N111 tensor with identical strides for size 1 dimension;
// Two cases could lead us here:
// a. N111 contiguous Tensor ([N,1,1,1]@[1,1,1,1])
// b. N11W contiguous Tensor sliced on the W-dimension.
// ([N,1,1,1]@[W,W,W,W])
if (d == 0 && min == strides[1]) {
return false;
}
// This is necessary to:
// 1. distinguish the memory_format of N1H1;
// [H, 1, 1, 1] channels_last stride
// [H, H, 1, 1] contiguous stride
// 2. permutation of 1C1W:
// [1, C, 1, H]@[HC, H, H, 1] transpose(1, 3)
// [1, H, 1, C]@[HC, 1, H, H] shouldn't be identified as channels_last
min = strides[d];
if (sizes[d] > 1) {
min *= sizes[d];
}
}
return true;
}
template <typename T>
inline bool is_channels_last_strides_3d_s5(
const ArrayRef<T> sizes,
const ArrayRef<T> strides) {
T min = 0;
if (strides[1] == 0) {
return false;
}
for (auto& d : {1, 4, 3, 2, 0}) {
if (sizes[d] == 0) {
return false;
}
if (strides[d] < min) {
return false;
}
if (d == 0 && min == strides[1]) {
return false;
}
min = strides[d];
if (sizes[d] > 1) {
min *= sizes[d];
}
}
return true;
}
// Note [Ambiguous is_channels_last_strides_xd]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// The flaw of carrying memory_format implicitly through strides is very hard
// to WAR properly; see issue #24090.
// Without the history of permutation, we can't infer the memory_format of a
// tensor from the snapshot of its size & stride
// e.g.
//
// 1. We can NOT specify the memory_format of N111 tensor through strides in a
// meaningful way;
//
// 2. Two paths that end up with identical size/stride:
// N11W contiguous tensor sliced at w-dimension becomes [N,1,1,1]@[W,W,W,W]
// NC11 channels_last tensor sliced at c-dimension becomes [N,1,1,1]@[C,C,C,C]
// So if we see a tensor [N,1,1,1]@[X,X,X,X], there's no way for us to infer
// the memory_format of the original tensor.
//
// Due to these limitations, our temporary WAR `is_channels_last_strides` makes
// a best effort to infer whether the original memory_format of a tensor is
// at::MemoryFormat::ChannelsLast. The two objectives of this function (ordered
// by their importance):
// 1. Ensure that normal shape manipulation does not accidentally change the
// MemoryFormat of an existing tensor.
// 2. Allow users to mark tensors as MemoryFormat::ChannelsLast;
//
// The function does so by checking the strides of the tensor, including the
// strides of size-1 dimensions, even though conventionally PyTorch imposes no
// restriction on trivial strides (strides of size-1 dimensions).
//
// Note that this approach is a compromise. We did not solve the problem
// completely. In many cases we will not be able to infer the correct memory
// format.
// The implementation of `is_channels_last_strides` serves these objectives:
// MemoryFormat::ChannelsLast has to be explicitly opted into (no accidental
// conversion); a best effort is made to maintain the ChannelsLast flag.
//
// Because this is not a bulletproof solution, through testing
// (aten/src/ATen/test/memory_format_test.cpp)
// a. we ensure that the common tasks are supported;
// b. we identify corner cases where the implementation compromises.
//
// By the time accumulated permutation is enabled to replace implicit
// memory_format through strides, we should be updating our tests and fix the
// issues in our tests.
//
// We use Channels Last 2d as an example above.
// This is a general problem for all the is_channels_last_strides_xd
// implementation. Please check the helper functions
// (is_channels_last_strides_*d_s*) for more details.
template <typename T>
inline bool is_channels_last_strides_2d(
const ArrayRef<T> sizes,
const ArrayRef<T> strides) {
switch (sizes.size()) {
case 4:
return is_channels_last_strides_2d_s4(sizes, strides);
// NOLINTNEXTLINE(bugprone-branch-clone)
case 3:
// TODO dim == 3 case will be enabled once it is fully tested
return false;
default:
return false;
}
}
template <typename T>
inline bool is_channels_last_strides_3d(
const ArrayRef<T> sizes,
const ArrayRef<T> strides) {
switch (sizes.size()) {
case 5:
return is_channels_last_strides_3d_s5(sizes, strides);
// NOLINTNEXTLINE(bugprone-branch-clone)
case 4:
// TODO dim == 4 case will be enabled once it is fully tested
return false;
default:
return false;
}
}
inline bool is_channels_last_strides_2d(
const IntArrayRef sizes,
const IntArrayRef strides) {
return is_channels_last_strides_2d<int64_t>(sizes, strides);
}
inline bool is_channels_last_strides_3d(
const IntArrayRef sizes,
const IntArrayRef strides) {
return is_channels_last_strides_3d<int64_t>(sizes, strides);
}
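// An illustrative sketch (for exposition only):
//   std::vector<int64_t> sizes = {2, 3, 4, 5};
//   std::vector<int64_t> cl_strides = {60, 1, 15, 3};     // channels_last
//   std::vector<int64_t> contig_strides = {60, 20, 5, 1}; // contiguous
//   is_channels_last_strides_2d(IntArrayRef(sizes), IntArrayRef(cl_strides));
//   // -> true
//   is_channels_last_strides_2d(IntArrayRef(sizes), IntArrayRef(contig_strides));
//   // -> false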
} // namespace c10

View File

@ -0,0 +1,31 @@
#pragma once
#include <c10/util/Exception.h>
namespace c10 {
template <typename T>
class OptionalRef {
public:
OptionalRef() : data_(nullptr) {}
OptionalRef(const T* data) : data_(data) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(data_);
}
OptionalRef(const T& data) : data_(&data) {}
bool has_value() const {
return data_ != nullptr;
}
const T& get() const {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(data_);
return *data_;
}
operator bool() const {
return has_value();
}
private:
const T* data_;
};
} // namespace c10

View File

@ -0,0 +1,76 @@
#pragma once
#include <c10/core/impl/PyInterpreter.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <c10/util/python_stub.h>
#include <atomic>
namespace c10 {
// A PyHandleCache represents a cached pointer from a C++ object to
// a Python object that represents that object analogously in Python.
// Upon a cache hit, the relevant object can be retrieved after a test
// and then a memory load. Two conditions must hold to be able to use this
// class:
//
// - This must truly be a cache; e.g., the caller must be able to produce
// the object some other way if the cache misses.
//
// - This must truly be a handle; e.g., the Python object referenced by
// this class must have static lifetime. This means we don't have to
// maintain strong ownership or deallocate the object when the C++ object
// dies. Static lifetime is a good idea in conjunction with the cache,
// since if you are producing a fresh object on miss you won't be
// maintaining object identity. If you need bidirectional ownership,
// you will want to factor out the pattern in TensorImpl with
// resurrection.
//
// This cache is expected to not improve perf under torchdeploy, as one
// interpreter will fill up the cache, and all the interpreters will be
// unable to use the slot. A potential improvement is to have multiple
// slots (one per interpreter), which will work in deployment scenarios
// where there is a stable, fixed number of interpreters. You can also store
// the relevant state in the Python library, rather than in the non-Python
// library (although in many cases, this is not convenient, as there may
// not be a way to conveniently index based on the object.)
class PyHandleCache {
public:
PyHandleCache() : pyinterpreter_(nullptr) {}
// Attempt to fetch the pointer from the cache, if the PyInterpreter
// matches. If it doesn't exist, or the cache entry is not valid,
// use slow_accessor to get the real pointer value and return that
// (possibly writing it to the cache, if the cache entry is
// available.)
template <typename F>
PyObject* ptr_or(impl::PyInterpreter* self_interpreter, F slow_accessor)
const {
// Note [Memory ordering on Python interpreter tag]
impl::PyInterpreter* interpreter =
pyinterpreter_.load(std::memory_order_acquire);
if (C10_LIKELY(interpreter == self_interpreter)) {
return data_;
} else if (interpreter == nullptr) {
auto* r = slow_accessor();
impl::PyInterpreter* expected = nullptr;
// attempt to claim this cache entry with the specified interpreter tag
if (pyinterpreter_.compare_exchange_strong(
expected, self_interpreter, std::memory_order_acq_rel)) {
data_ = r;
}
// This shouldn't be possible, as you should be GIL protected
TORCH_INTERNAL_ASSERT(expected != self_interpreter);
return r;
} else {
return slow_accessor();
}
}
private:
mutable std::atomic<impl::PyInterpreter*> pyinterpreter_;
mutable PyObject* data_{nullptr};
};
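// An illustrative usage sketch (for exposition only; `cache_`,
// `self_interpreter`, and `make_pyobj` are assumed to be provided by the
// surrounding class):
//   PyObject* obj = cache_.ptr_or(
//       self_interpreter, [&]() -> PyObject* { return make_pyobj(); });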
} // namespace c10

View File

@ -0,0 +1,46 @@
#pragma once
#include <c10/util/Exception.h>
#include <cstdint>
#include <string>
namespace c10 {
/**
* QEngine is an enum that is used to select the engine to run quantized ops.
* Keep this enum in sync with get_qengine_id() in
* torch/backends/quantized/__init__.py
*/
enum class QEngine : uint8_t {
NoQEngine = 0,
FBGEMM = 1,
QNNPACK = 2,
ONEDNN = 3,
X86 = 4,
};
constexpr auto kNoQEngine = QEngine::NoQEngine;
constexpr auto kFBGEMM = QEngine::FBGEMM;
constexpr auto kQNNPACK = QEngine::QNNPACK;
constexpr auto kONEDNN = QEngine::ONEDNN;
constexpr auto kX86 = QEngine::X86;
inline std::string toString(QEngine qengine) {
switch (qengine) {
case kNoQEngine:
return "NoQEngine";
case kFBGEMM:
return "FBGEMM";
case kQNNPACK:
return "QNNPACK";
case kONEDNN:
return "ONEDNN";
case kX86:
return "X86";
default:
TORCH_CHECK(
false, "Unrecognized Quantized Engine: ", static_cast<int>(qengine));
}
}
} // namespace c10

View File

@ -0,0 +1,50 @@
#pragma once
#include <c10/util/Exception.h>
#include <cstdint>
#include <string>
namespace c10 {
/**
* QScheme is an enum that specifies the type of quantization. This has a one
* to one correspondence with Quantizer.
* Please refer to ATen/quantized/Quantizer.h to see the Quantizer classes.
* Keep this file in sync with torch/nn/_qscheme.py
*/
enum class QScheme : uint8_t {
PER_TENSOR_AFFINE = 0,
PER_CHANNEL_AFFINE = 1,
PER_TENSOR_SYMMETRIC = 2,
PER_CHANNEL_SYMMETRIC = 3,
PER_CHANNEL_AFFINE_FLOAT_QPARAMS = 4,
COMPILE_TIME_NUM_QSCHEMES = 5,
};
constexpr auto kPerTensorAffine = QScheme::PER_TENSOR_AFFINE;
constexpr auto kPerChannelAffine = QScheme::PER_CHANNEL_AFFINE;
constexpr auto kPerTensorSymmetric = QScheme::PER_TENSOR_SYMMETRIC;
constexpr auto kPerChannelSymmetric = QScheme::PER_CHANNEL_SYMMETRIC;
constexpr auto kPerChannelAffineFloatQParams =
QScheme::PER_CHANNEL_AFFINE_FLOAT_QPARAMS;
constexpr int COMPILE_TIME_NUM_QSCHEMES =
static_cast<int>(QScheme::COMPILE_TIME_NUM_QSCHEMES);
inline std::string toString(QScheme qscheme) {
switch (qscheme) {
case kPerTensorAffine:
return "per_tensor_affine";
case kPerChannelAffine:
return "per_channel_affine";
case kPerTensorSymmetric:
return "per_tensor_symmetric";
case kPerChannelSymmetric:
return "per_channel_symmetric";
case kPerChannelAffineFloatQParams:
return "per_channel_affine_float_qparams";
default:
TORCH_CHECK(false, "Unrecognized qscheme: ", static_cast<int>(qscheme));
}
}
} // namespace c10

View File

@ -0,0 +1,52 @@
#pragma once
#include <c10/core/Storage.h>
#include <c10/macros/Export.h>
#include <c10/util/UniqueVoidPtr.h>
#include <atomic>
#include <memory>
namespace c10 {
// A RefcountedDeleterContext object is used as the `ctx` argument for DataPtr
// to implement a shared DataPtr. Normally, a DataPtr is unique, but we use
// this custom context and the `refcounted_deleter` function below to make the
// DataPtr act like a non-unique DataPtr. This context object holds onto an
// inner context and deleter function which handle the actual deletion of the
// data when the refcount reaches 0.
//
// This shared DataPtr feature is only used when storages are shared between
// multiple Python interpreters in MultiPy. Before storages had PyObject
// preservation, interpreters could just share the same StorageImpl instance.
// But now a StorageImpl can only be associated with one interpreter in order
// to properly manage a zombie PyObject. So we share storages across Python
// interpreters by creating a different StorageImpl instance for each one, but
// they all point to the same data.
struct C10_API RefcountedDeleterContext {
RefcountedDeleterContext(void* other_ctx, c10::DeleterFnPtr other_deleter)
: other_ctx(other_ctx, other_deleter), refcount(1) {}
std::unique_ptr<void, c10::DeleterFnPtr> other_ctx;
std::atomic_int refcount;
};
// `refcounted_deleter` is used as the `ctx_deleter` for DataPtr to implement
// a shared DataPtr.
//
// Warning: This should only be called on a pointer to
// a RefcountedDeleterContext that was allocated on the heap with `new`,
// because when the refcount reaches 0, the context is deleted with `delete`
C10_API void refcounted_deleter(void* ctx_);
// If the storage's DataPtr does not use `refcounted_deleter`, replace it with
// a DataPtr that does, so it can be shared between multiple StorageImpls
C10_API void maybeApplyRefcountedDeleter(const c10::Storage& storage);
// Create a new StorageImpl that points to the same data. If the original
// StorageImpl's DataPtr does not use `refcounted_deleter`, it will be replaced
// with one that does
C10_API c10::Storage newStorageImplFromRefcountedDataPtr(
const c10::Storage& storage);
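// An illustrative sketch (for exposition only; `storage` is assumed to be an
// existing c10::Storage):
//   c10::Storage shared = c10::newStorageImplFromRefcountedDataPtr(storage);
//   // `shared` has its own StorageImpl but points at the same data as
//   // `storage`, kept alive by the shared refcount.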
} // namespace c10

View File

@ -0,0 +1,118 @@
#pragma once
#include <c10/core/impl/PyInterpreter.h>
#include <c10/macros/Export.h>
#include <c10/util/python_stub.h>
#include <utility>
namespace c10 {
// This is a safe owning holder for a PyObject, akin to pybind11's
// py::object, with two major differences:
//
// - It is in c10/core; i.e., you can use this type in contexts where
// you do not have a libpython dependency
//
// - It is multi-interpreter safe (ala torchdeploy); when you fetch
// the underlying PyObject* you are required to specify what the current
// interpreter context is and we will check that you match it.
//
// It is INVALID to store a reference to a Tensor object in this way;
// you should just use TensorImpl directly in that case!
struct C10_API SafePyObject {
// Steals a reference to data
SafePyObject(PyObject* data, c10::impl::PyInterpreter* pyinterpreter)
: data_(data), pyinterpreter_(pyinterpreter) {}
SafePyObject(SafePyObject&& other) noexcept
: data_(std::exchange(other.data_, nullptr)),
pyinterpreter_(other.pyinterpreter_) {}
// For now it's not used, so we just disallow it.
SafePyObject& operator=(SafePyObject&&) = delete;
SafePyObject(SafePyObject const& other)
: data_(other.data_), pyinterpreter_(other.pyinterpreter_) {
if (data_ != nullptr) {
(*pyinterpreter_)->incref(data_);
}
}
SafePyObject& operator=(SafePyObject const& other) {
if (this == &other) {
return *this; // Handle self-assignment
}
if (other.data_ != nullptr) {
(*other.pyinterpreter_)->incref(other.data_);
}
if (data_ != nullptr) {
(*pyinterpreter_)->decref(data_, /*has_pyobj_slot*/ false);
}
data_ = other.data_;
pyinterpreter_ = other.pyinterpreter_;
return *this;
}
~SafePyObject() {
if (data_ != nullptr) {
(*pyinterpreter_)->decref(data_, /*has_pyobj_slot*/ false);
}
}
c10::impl::PyInterpreter& pyinterpreter() const {
return *pyinterpreter_;
}
PyObject* ptr(const c10::impl::PyInterpreter*) const;
// stop tracking the current object, and return it
PyObject* release() {
auto rv = data_;
data_ = nullptr;
return rv;
}
private:
PyObject* data_;
c10::impl::PyInterpreter* pyinterpreter_;
};
// A newtype wrapper around SafePyObject for type safety when a python object
// represents a specific type. Note that `T` is only used as a tag and isn't
// actually used for any true purpose.
template <typename T>
struct SafePyObjectT : private SafePyObject {
SafePyObjectT(PyObject* data, c10::impl::PyInterpreter* pyinterpreter)
: SafePyObject(data, pyinterpreter) {}
SafePyObjectT(SafePyObjectT&& other) noexcept : SafePyObject(other) {}
SafePyObjectT(SafePyObjectT const&) = delete;
SafePyObjectT& operator=(SafePyObjectT const&) = delete;
using SafePyObject::ptr;
using SafePyObject::pyinterpreter;
using SafePyObject::release;
};
// Like SafePyObject, but non-owning. Good for references to global PyObjects
// that will be leaked on interpreter exit. You get a copy constructor/assign
// this way.
struct C10_API SafePyHandle {
SafePyHandle() : data_(nullptr), pyinterpreter_(nullptr) {}
SafePyHandle(PyObject* data, c10::impl::PyInterpreter* pyinterpreter)
: data_(data), pyinterpreter_(pyinterpreter) {}
c10::impl::PyInterpreter& pyinterpreter() const {
return *pyinterpreter_;
}
PyObject* ptr(const c10::impl::PyInterpreter*) const;
void reset() {
data_ = nullptr;
pyinterpreter_ = nullptr;
}
operator bool() {
return data_;
}
private:
PyObject* data_;
c10::impl::PyInterpreter* pyinterpreter_;
};
} // namespace c10

View File

@ -0,0 +1,467 @@
#pragma once
#include <cstdint>
#include <stdexcept>
#include <type_traits>
#include <utility>
#include <c10/core/OptionalRef.h>
#include <c10/core/ScalarType.h>
#include <c10/core/SymBool.h>
#include <c10/core/SymFloat.h>
#include <c10/core/SymInt.h>
#include <c10/core/SymNodeImpl.h>
#include <c10/macros/Export.h>
#include <c10/macros/Macros.h>
#include <c10/util/Deprecated.h>
#include <c10/util/Exception.h>
#include <c10/util/Half.h>
#include <c10/util/TypeCast.h>
#include <c10/util/complex.h>
#include <c10/util/intrusive_ptr.h>
namespace c10 {
/**
* Scalar represents a 0-dimensional tensor which contains a single element.
* Unlike a tensor, numeric literals (in C++) are implicitly convertible to
* Scalar (which is why, for example, we provide both add(Tensor) and
* add(Scalar) overloads for many operations). It may also be used in
* circumstances where you statically know a tensor is 0-dim and single size,
* but don't know its type.
*/
class C10_API Scalar {
public:
Scalar() : Scalar(int64_t(0)) {}
void destroy() {
if (Tag::HAS_si == tag || Tag::HAS_sd == tag || Tag::HAS_sb == tag) {
raw::intrusive_ptr::decref(v.p);
v.p = nullptr;
}
}
~Scalar() {
destroy();
}
#define DEFINE_IMPLICIT_CTOR(type, name) \
Scalar(type vv) : Scalar(vv, true) {}
AT_FORALL_SCALAR_TYPES_AND7(
Half,
BFloat16,
Float8_e5m2,
Float8_e4m3fn,
Float8_e5m2fnuz,
Float8_e4m3fnuz,
ComplexHalf,
DEFINE_IMPLICIT_CTOR)
AT_FORALL_COMPLEX_TYPES(DEFINE_IMPLICIT_CTOR)
// Helper constructors to allow Scalar creation from long and long long types
// As std::is_same_v<long, long long> is false (except on Android), one needs to
// provide a constructor from either long or long long in addition to one from
// int64_t
#if defined(__APPLE__) || defined(__MACOSX)
static_assert(
std::is_same_v<long long, int64_t>,
"int64_t is the same as long long on MacOS");
Scalar(long vv) : Scalar(vv, true) {}
#endif
#if defined(_MSC_VER)
static_assert(
std::is_same_v<long long, int64_t>,
"int64_t is the same as long long on Windows");
Scalar(long vv) : Scalar(vv, true) {}
#endif
#if defined(__linux__) && !defined(__ANDROID__)
static_assert(
std::is_same_v<long, int64_t>,
"int64_t is the same as long on Linux");
Scalar(long long vv) : Scalar(vv, true) {}
#endif
Scalar(uint16_t vv) : Scalar(vv, true) {}
Scalar(uint32_t vv) : Scalar(vv, true) {}
Scalar(uint64_t vv) {
if (vv > static_cast<uint64_t>(INT64_MAX)) {
tag = Tag::HAS_u;
v.u = vv;
} else {
tag = Tag::HAS_i;
// NB: no need to use convert, we've already tested convertibility
v.i = static_cast<int64_t>(vv);
}
}
#undef DEFINE_IMPLICIT_CTOR
// Value* is both implicitly convertible to SymbolicVariable and bool which
// causes ambiguity error. Specialized constructor for bool resolves this
// problem.
template <
typename T,
typename std::enable_if_t<std::is_same_v<T, bool>, bool>* = nullptr>
Scalar(T vv) : tag(Tag::HAS_b) {
v.i = convert<int64_t, bool>(vv);
}
template <
typename T,
typename std::enable_if_t<std::is_same_v<T, c10::SymBool>, bool>* =
nullptr>
Scalar(T vv) : tag(Tag::HAS_sb) {
v.i = convert<int64_t, c10::SymBool>(vv);
}
#define DEFINE_ACCESSOR(type, name) \
type to##name() const { \
if (Tag::HAS_d == tag) { \
return checked_convert<type, double>(v.d, #type); \
} else if (Tag::HAS_z == tag) { \
return checked_convert<type, c10::complex<double>>(v.z, #type); \
} \
if (Tag::HAS_b == tag) { \
return checked_convert<type, bool>(v.i, #type); \
} else if (Tag::HAS_i == tag) { \
return checked_convert<type, int64_t>(v.i, #type); \
} else if (Tag::HAS_u == tag) { \
return checked_convert<type, uint64_t>(v.u, #type); \
} else if (Tag::HAS_si == tag) { \
return checked_convert<type, int64_t>( \
toSymInt().guard_int(__FILE__, __LINE__), #type); \
} else if (Tag::HAS_sd == tag) { \
return checked_convert<type, int64_t>( \
toSymFloat().guard_float(__FILE__, __LINE__), #type); \
} else if (Tag::HAS_sb == tag) { \
return checked_convert<type, int64_t>( \
toSymBool().guard_bool(__FILE__, __LINE__), #type); \
} \
TORCH_CHECK(false) \
}
// TODO: Support ComplexHalf accessor
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_ACCESSOR)
DEFINE_ACCESSOR(uint16_t, UInt16)
DEFINE_ACCESSOR(uint32_t, UInt32)
DEFINE_ACCESSOR(uint64_t, UInt64)
#undef DEFINE_ACCESSOR
SymInt toSymInt() const {
if (Tag::HAS_si == tag) {
return c10::SymInt(intrusive_ptr<SymNodeImpl>::reclaim_copy(
static_cast<SymNodeImpl*>(v.p)));
} else {
return toLong();
}
}
SymFloat toSymFloat() const {
if (Tag::HAS_sd == tag) {
return c10::SymFloat(intrusive_ptr<SymNodeImpl>::reclaim_copy(
static_cast<SymNodeImpl*>(v.p)));
} else {
return toDouble();
}
}
SymBool toSymBool() const {
if (Tag::HAS_sb == tag) {
return c10::SymBool(intrusive_ptr<SymNodeImpl>::reclaim_copy(
static_cast<SymNodeImpl*>(v.p)));
} else {
return toBool();
}
}
// also support scalar.to<int64_t>();
// Deleted for unsupported types, but specialized below for supported types
template <typename T>
T to() const = delete;
// audit uses of data_ptr
const void* data_ptr() const {
TORCH_INTERNAL_ASSERT(!isSymbolic());
return static_cast<const void*>(&v);
}
bool isFloatingPoint() const {
return Tag::HAS_d == tag || Tag::HAS_sd == tag;
}
C10_DEPRECATED_MESSAGE(
"isIntegral is deprecated. Please use the overload with 'includeBool' parameter instead.")
bool isIntegral() const {
return Tag::HAS_i == tag || Tag::HAS_si == tag || Tag::HAS_u == tag;
}
bool isIntegral(bool includeBool) const {
return Tag::HAS_i == tag || Tag::HAS_si == tag || Tag::HAS_u == tag ||
(includeBool && isBoolean());
}
bool isComplex() const {
return Tag::HAS_z == tag;
}
bool isBoolean() const {
return Tag::HAS_b == tag || Tag::HAS_sb == tag;
}
// you probably don't actually want these; they're mostly for testing
bool isSymInt() const {
return Tag::HAS_si == tag;
}
bool isSymFloat() const {
return Tag::HAS_sd == tag;
}
bool isSymBool() const {
return Tag::HAS_sb == tag;
}
bool isSymbolic() const {
return Tag::HAS_si == tag || Tag::HAS_sd == tag || Tag::HAS_sb == tag;
}
C10_ALWAYS_INLINE Scalar& operator=(Scalar&& other) noexcept {
if (&other == this) {
return *this;
}
destroy();
moveFrom(std::move(other));
return *this;
}
C10_ALWAYS_INLINE Scalar& operator=(const Scalar& other) {
if (&other == this) {
return *this;
}
*this = Scalar(other);
return *this;
}
Scalar operator-() const;
Scalar conj() const;
Scalar log() const;
template <
typename T,
typename std::enable_if_t<!c10::is_complex<T>::value, int> = 0>
bool equal(T num) const {
if (isComplex()) {
TORCH_INTERNAL_ASSERT(!isSymbolic());
auto val = v.z;
return (val.real() == num) && (val.imag() == T());
} else if (isFloatingPoint()) {
TORCH_CHECK(!isSymbolic(), "NYI SymFloat equality");
return v.d == num;
} else if (tag == Tag::HAS_i) {
if (overflows<T>(v.i, /* strict_unsigned */ true)) {
return false;
} else {
return static_cast<T>(v.i) == num;
}
} else if (tag == Tag::HAS_u) {
if (overflows<T>(v.u, /* strict_unsigned */ true)) {
return false;
} else {
return static_cast<T>(v.u) == num;
}
} else if (tag == Tag::HAS_si) {
TORCH_INTERNAL_ASSERT(false, "NYI SymInt equality");
} else if (isBoolean()) {
// a boolean scalar does not equal a non-boolean value
TORCH_INTERNAL_ASSERT(!isSymbolic());
return false;
} else {
TORCH_INTERNAL_ASSERT(false);
}
}
template <
typename T,
typename std::enable_if_t<c10::is_complex<T>::value, int> = 0>
bool equal(T num) const {
if (isComplex()) {
TORCH_INTERNAL_ASSERT(!isSymbolic());
return v.z == num;
} else if (isFloatingPoint()) {
TORCH_CHECK(!isSymbolic(), "NYI SymFloat equality");
return (v.d == num.real()) && (num.imag() == T());
} else if (tag == Tag::HAS_i) {
if (overflows<T>(v.i, /* strict_unsigned */ true)) {
return false;
} else {
return static_cast<T>(v.i) == num.real() && num.imag() == T();
}
} else if (tag == Tag::HAS_u) {
if (overflows<T>(v.u, /* strict_unsigned */ true)) {
return false;
} else {
return static_cast<T>(v.u) == num.real() && num.imag() == T();
}
} else if (tag == Tag::HAS_si) {
TORCH_INTERNAL_ASSERT(false, "NYI SymInt equality");
} else if (isBoolean()) {
// a boolean scalar does not equal a non-boolean value
TORCH_INTERNAL_ASSERT(!isSymbolic());
return false;
} else {
TORCH_INTERNAL_ASSERT(false);
}
}
bool equal(bool num) const {
if (isBoolean()) {
TORCH_INTERNAL_ASSERT(!isSymbolic());
return static_cast<bool>(v.i) == num;
} else {
return false;
}
}
ScalarType type() const {
if (isComplex()) {
return ScalarType::ComplexDouble;
} else if (isFloatingPoint()) {
return ScalarType::Double;
} else if (isIntegral(/*includeBool=*/false)) {
// Represent all integers as long, UNLESS it is unsigned and therefore
// unrepresentable as long
if (Tag::HAS_u == tag) {
return ScalarType::UInt64;
}
return ScalarType::Long;
} else if (isBoolean()) {
return ScalarType::Bool;
} else {
throw std::runtime_error("Unknown scalar type.");
}
}
Scalar(Scalar&& rhs) noexcept : tag(rhs.tag) {
moveFrom(std::move(rhs));
}
Scalar(const Scalar& rhs) : tag(rhs.tag), v(rhs.v) {
if (isSymbolic()) {
c10::raw::intrusive_ptr::incref(v.p);
}
}
Scalar(c10::SymInt si) {
if (auto m = si.maybe_as_int()) {
tag = Tag::HAS_i;
v.i = *m;
} else {
tag = Tag::HAS_si;
v.p = std::move(si).release();
}
}
Scalar(c10::SymFloat sd) {
if (sd.is_symbolic()) {
tag = Tag::HAS_sd;
v.p = std::move(sd).release();
} else {
tag = Tag::HAS_d;
v.d = sd.as_float_unchecked();
}
}
Scalar(c10::SymBool sb) {
if (auto m = sb.maybe_as_bool()) {
tag = Tag::HAS_b;
v.i = *m;
} else {
tag = Tag::HAS_sb;
v.p = std::move(sb).release();
}
}
// We can't set v in the initializer list using the
// syntax v{ .member = ... } because it doesn't work on MSVC
private:
enum class Tag { HAS_d, HAS_i, HAS_u, HAS_z, HAS_b, HAS_sd, HAS_si, HAS_sb };
// Note [Meaning of HAS_u]
// ~~~~~~~~~~~~~~~~~~~~~~~
// HAS_u is a bit special. On its face, it just means that we
// are holding an unsigned integer. However, we generally don't
// distinguish between different bit sizes in Scalar (e.g., we represent
// float as double), instead, it represents a mathematical notion
// of some quantity (integral versus floating point). So actually,
// HAS_u is used solely to represent unsigned integers that could
// not be represented as a signed integer. That means only uint64_t can
// potentially get this tag; smaller types like uint8_t fit into a regular
// int, and for BC reasons we keep them as int.
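//
// A small sketch of the resulting tagging behavior (illustrative only; it
// relies on the Scalar(c10::SymInt) constructor declared just above):
//
//   c10::Scalar s{c10::SymInt(7)};  // non-symbolic SymInt collapses to HAS_i
//   TORCH_CHECK(!s.isSymbolic());
//   TORCH_CHECK(s.type() == c10::ScalarType::Long);
//   // Only a uint64_t too large for int64_t would be stored under HAS_u and
//   // report ScalarType::UInt64 from type().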
// NB: assumes that self has already been cleared
// NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved)
C10_ALWAYS_INLINE void moveFrom(Scalar&& rhs) noexcept {
v = rhs.v;
tag = rhs.tag;
if (rhs.tag == Tag::HAS_si || rhs.tag == Tag::HAS_sd ||
rhs.tag == Tag::HAS_sb) {
// Move out of scalar
rhs.tag = Tag::HAS_i;
rhs.v.i = 0;
}
}
Tag tag;
union v_t {
double d{};
int64_t i;
// See Note [Meaning of HAS_u]
uint64_t u;
c10::complex<double> z;
c10::intrusive_ptr_target* p;
// NOLINTNEXTLINE(modernize-use-equals-default)
v_t() {} // default constructor
} v;
template <
typename T,
typename std::enable_if_t<
std::is_integral_v<T> && !std::is_same_v<T, bool>,
bool>* = nullptr>
Scalar(T vv, bool) : tag(Tag::HAS_i) {
v.i = convert<decltype(v.i), T>(vv);
}
template <
typename T,
typename std::enable_if_t<
!std::is_integral_v<T> && !c10::is_complex<T>::value,
bool>* = nullptr>
Scalar(T vv, bool) : tag(Tag::HAS_d) {
v.d = convert<decltype(v.d), T>(vv);
}
template <
typename T,
typename std::enable_if_t<c10::is_complex<T>::value, bool>* = nullptr>
Scalar(T vv, bool) : tag(Tag::HAS_z) {
v.z = convert<decltype(v.z), T>(vv);
}
};
using OptionalScalarRef = c10::OptionalRef<Scalar>;
// define the scalar.to<int64_t>() specializations
#define DEFINE_TO(T, name) \
template <> \
inline T Scalar::to<T>() const { \
return to##name(); \
}
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_TO)
DEFINE_TO(uint16_t, UInt16)
DEFINE_TO(uint32_t, UInt32)
DEFINE_TO(uint64_t, UInt64)
#undef DEFINE_TO
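// With the specializations generated above, the following is a sketch of the
// resulting dispatch (it assumes the usual Scalar constructors and to##name()
// accessors declared earlier in this header):
//
//   c10::Scalar pi(3.14);
//   double a = pi.to<double>();  // dispatches to toDouble()
//   float b = pi.to<float>();    // dispatches to toFloat()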
} // namespace c10

View File

@ -0,0 +1,573 @@
#pragma once
#include <c10/util/BFloat16.h>
#include <c10/util/Deprecated.h>
#include <c10/util/Exception.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e4m3fnuz.h>
#include <c10/util/Float8_e5m2.h>
#include <c10/util/Float8_e5m2fnuz.h>
#include <c10/util/Half.h>
#include <c10/util/bits.h>
#include <c10/util/complex.h>
#include <c10/util/qint32.h>
#include <c10/util/qint8.h>
#include <c10/util/quint2x4.h>
#include <c10/util/quint4x2.h>
#include <c10/util/quint8.h>
#include <array>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <ostream>
#include <type_traits>
#include <unordered_map>
namespace c10 {
// dummy struct for uint1 to uint7, actual functionality
// of these dtypes will be implemented in python with Tensor subclass
template <unsigned int N>
struct dummy_uint1_7_t {};
// For the macros below:
//
// For users: If you want to macro some code for all non-QInt scalar types
// (i.e. types with complete information), you probably want one of the
// AT_FORALL_SCALAR_TYPES / AT_FORALL_SCALAR_TYPES_AND macros below, which are
// designed to behave similarly to the Dispatch macros with the same name.
//
// For adding a new dtype: In the beginning, we had an idea that there was a
// list of all scalar types, and you could use AT_FORALL_SCALAR_TYPES to
// iterate over them. But over the years we added weird types which couldn't
// be handled uniformly everywhere, and so we ended up with a mishmash of
// helper macros, with most use sites making their own call about which
// dtypes they can or can't support. So if you want to add a new dtype,
// the preferred resolution is to find a dtype similar to what you want,
// grep for it and edit all the sites you find this way. If you need to add
// a completely new kind of dtype, you're going to have to laboriously audit
// all of the sites everywhere to figure out how it should work. Consulting
// some old PRs where we added new dtypes (check history of this file) can
// help give you an idea where to start.
// NB: Order matters for this macro; it is relied upon in
// _promoteTypesLookup and the serialization format.
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(_) \
_(uint8_t, Byte) /* 0 */ \
_(int8_t, Char) /* 1 */ \
_(int16_t, Short) /* 2 */ \
_(int, Int) /* 3 */ \
_(int64_t, Long) /* 4 */ \
_(at::Half, Half) /* 5 */ \
_(float, Float) /* 6 */ \
_(double, Double) /* 7 */ \
_(c10::complex<c10::Half>, ComplexHalf) /* 8 */ \
_(c10::complex<float>, ComplexFloat) /* 9 */ \
_(c10::complex<double>, ComplexDouble) /* 10 */ \
_(bool, Bool) /* 11 */ \
_(c10::qint8, QInt8) /* 12 */ \
_(c10::quint8, QUInt8) /* 13 */ \
_(c10::qint32, QInt32) /* 14 */ \
_(at::BFloat16, BFloat16) /* 15 */ \
_(c10::quint4x2, QUInt4x2) /* 16 */ \
_(c10::quint2x4, QUInt2x4) /* 17 */ \
_(c10::bits1x8, Bits1x8) /* 18 */ \
_(c10::bits2x4, Bits2x4) /* 19 */ \
_(c10::bits4x2, Bits4x2) /* 20 */ \
_(c10::bits8, Bits8) /* 21 */ \
_(c10::bits16, Bits16) /* 22 */ \
_(c10::Float8_e5m2, Float8_e5m2) /* 23 */ \
_(c10::Float8_e4m3fn, Float8_e4m3fn) /* 24 */ \
_(c10::Float8_e5m2fnuz, Float8_e5m2fnuz) /* 25 */ \
_(c10::Float8_e4m3fnuz, Float8_e4m3fnuz) /* 26 */ \
_(uint16_t, UInt16) /* 27 */ \
_(uint32_t, UInt32) /* 28 */ \
_(uint64_t, UInt64) /* 29 */ \
_(c10::dummy_uint1_7_t<1>, UInt1) /* 30 */ \
_(c10::dummy_uint1_7_t<2>, UInt2) /* 31 */ \
_(c10::dummy_uint1_7_t<3>, UInt3) /* 32 */ \
_(c10::dummy_uint1_7_t<4>, UInt4) /* 33 */ \
_(c10::dummy_uint1_7_t<5>, UInt5) /* 34 */ \
_(c10::dummy_uint1_7_t<6>, UInt6) /* 35 */ \
_(c10::dummy_uint1_7_t<7>, UInt7) /* 36 */
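// A sketch of the X-macro pattern these lists are consumed with (the callback
// name below is made up for illustration; real uses appear later in this file,
// e.g. the ScalarType enum and elementSize()):
//
//   #define EXAMPLE_CASE(cpp_type, name) \
//     case ScalarType::name:             \
//       return sizeof(cpp_type);
//
//   // expands to one `case` per dtype listed above, in the order given
//   AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(EXAMPLE_CASE)
//   #undef EXAMPLE_CASE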
// If you want to support ComplexHalf for real, add ComplexHalf
// into this macro (and change the name). But beware: convert()
// doesn't work for all the conversions you need...
//
// TODO: To add unsigned int types here, we must define accumulate type.
// But uint8 currently accumulates into int64, so we would have to make
// an inconsistent choice for the larger types. Difficult.
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF_F8NZ(_) \
_(uint8_t, Byte) \
_(int8_t, Char) \
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long) \
_(at::Half, Half) \
_(float, Float) \
_(double, Double) \
_(c10::complex<float>, ComplexFloat) \
_(c10::complex<double>, ComplexDouble) \
_(bool, Bool) \
_(at::BFloat16, BFloat16) \
_(at::Float8_e5m2, Float8_e5m2) \
_(at::Float8_e4m3fn, Float8_e4m3fn)
// This macro controls many of our C++ APIs, including constructors
// for Scalar as well as the data() and item() accessors on Tensor
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(_) \
_(uint8_t, Byte) \
_(int8_t, Char) \
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long) \
_(at::Half, Half) \
_(float, Float) \
_(double, Double) \
_(c10::complex<c10::Half>, ComplexHalf) \
_(c10::complex<float>, ComplexFloat) \
_(c10::complex<double>, ComplexDouble) \
_(bool, Bool) \
_(at::BFloat16, BFloat16) \
_(at::Float8_e5m2, Float8_e5m2) \
_(at::Float8_e4m3fn, Float8_e4m3fn) \
_(at::Float8_e5m2fnuz, Float8_e5m2fnuz) \
_(at::Float8_e4m3fnuz, Float8_e4m3fnuz)
enum class ScalarType : int8_t {
#define DEFINE_ST_ENUM_VAL_(_1, n) n,
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_ST_ENUM_VAL_)
#undef DEFINE_ST_ENUM_VAL_
Undefined,
NumOptions
};
constexpr uint16_t NumScalarTypes =
static_cast<uint16_t>(ScalarType::NumOptions);
namespace impl {
// These are used to map ScalarTypes to C++ types.
template <c10::ScalarType N>
struct ScalarTypeToCPPType;
#define SPECIALIZE_ScalarTypeToCPPType(cpp_type, scalar_type) \
template <> \
struct ScalarTypeToCPPType<c10::ScalarType::scalar_type> { \
using type = cpp_type; \
\
/* This is a workaround for the CUDA bug which prevents */ \
/* ::detail::ScalarTypeToCType<T>::type being used directly due to */ \
    /* ambiguous reference which can't be resolved. For some reason it */ \
/* can't pick between at::detail and at::cuda::detail. */ \
/* For repro example, please see: */ \
/* https://gist.github.com/izdeby/952ae7cf256ddb740a73776d39a7e7ba */ \
/* TODO: remove once the bug is fixed. */ \
static type t; \
};
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_ScalarTypeToCPPType)
#undef SPECIALIZE_ScalarTypeToCPPType
template <c10::ScalarType N>
using ScalarTypeToCPPTypeT = typename ScalarTypeToCPPType<N>::type;
} // namespace impl
template <typename T>
struct CppTypeToScalarType;
#define SPECIALIZE_CppTypeToScalarType(cpp_type, scalar_type) \
template <> \
struct CppTypeToScalarType<cpp_type> \
: std:: \
integral_constant<c10::ScalarType, c10::ScalarType::scalar_type> { \
};
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)
#undef SPECIALIZE_CppTypeToScalarType
// NB: despite their generic-sounding names, the macros that don't take _AND
// are mostly only used by tensorexpr
#define AT_FORALL_INT_TYPES(_) \
_(uint8_t, Byte) \
_(int8_t, Char) \
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long)
#define AT_FORALL_SCALAR_TYPES(_) \
_(uint8_t, Byte) \
_(int8_t, Char) \
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long) \
_(float, Float) \
_(double, Double)
// These macros often control how many template instantiations we create
// for kernels. It is typically inappropriate to add new dtypes here,
// instead, new types should be added to use sites on a case-by-case basis.
// We generally are not accepting new dtypes due to binary size concerns.
#define AT_FORALL_SCALAR_TYPES_AND(SCALARTYPE, _) \
_(uint8_t, Byte) \
_(int8_t, Char) \
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long) \
_(float, Float) \
_(double, Double) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE>::t), \
SCALARTYPE)
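// For example (a sketch), AT_FORALL_SCALAR_TYPES_AND(Half, CALLBACK) expands
// to the seven base pairs above plus one extra pair whose C++ type is
// recovered through ScalarTypeToCPPType, i.e. effectively:
//
//   CALLBACK(decltype(::c10::impl::ScalarTypeToCPPType<
//                ::c10::ScalarType::Half>::t),
//            Half)  // the C++ type here is at::Half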
#define AT_FORALL_SCALAR_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, _) \
_(uint8_t, Byte) \
_(int8_t, Char) \
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long) \
_(float, Float) \
_(double, Double) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE1>::t), \
SCALARTYPE1) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE2>::t), \
SCALARTYPE2)
#define AT_FORALL_SCALAR_TYPES_AND3(SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, _) \
_(uint8_t, Byte) \
_(int8_t, Char) \
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long) \
_(float, Float) \
_(double, Double) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE1>::t), \
SCALARTYPE1) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE2>::t), \
SCALARTYPE2) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE3>::t), \
SCALARTYPE3)
#define AT_FORALL_SCALAR_TYPES_AND7( \
SCALARTYPE1, \
SCALARTYPE2, \
SCALARTYPE3, \
SCALARTYPE4, \
SCALARTYPE5, \
SCALARTYPE6, \
SCALARTYPE7, \
_) \
_(uint8_t, Byte) \
_(int8_t, Char) \
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long) \
_(float, Float) \
_(double, Double) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE1>::t), \
SCALARTYPE1) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE2>::t), \
SCALARTYPE2) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE3>::t), \
SCALARTYPE3) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE4>::t), \
SCALARTYPE4) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE5>::t), \
SCALARTYPE5) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE6>::t), \
SCALARTYPE6) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE7>::t), \
SCALARTYPE7)
#define AT_FORALL_QINT_TYPES(_) \
_(c10::qint8, QInt8) \
_(c10::quint8, QUInt8) \
_(c10::qint32, QInt32) \
_(c10::quint4x2, QUInt4x2) \
_(c10::quint2x4, QUInt2x4)
#define AT_FORALL_COMPLEX_TYPES(_) \
_(c10::complex<float>, ComplexFloat) \
_(c10::complex<double>, ComplexDouble)
#define DEFINE_CONSTANT(_, name) \
constexpr ScalarType k##name = ScalarType::name;
// NOLINTNEXTLINE(clang-diagnostic-unused-const-variable)
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CONSTANT)
#undef DEFINE_CONSTANT
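// The macro above generates the familiar shorthand constants; for instance
// (a sketch):
//
//   static_assert(c10::kFloat == c10::ScalarType::Float);
//   static_assert(c10::kQInt8 == c10::ScalarType::QInt8);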
inline const char* toString(ScalarType t) {
#define DEFINE_CASE(_, name) \
case ScalarType::name: \
return #name;
switch (t) {
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CASE)
default:
return "UNKNOWN_SCALAR";
}
#undef DEFINE_CASE
}
inline size_t elementSize(ScalarType t) {
#define CASE_ELEMENTSIZE_CASE(ctype, name) \
case ScalarType::name: \
return sizeof(ctype);
switch (t) {
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(CASE_ELEMENTSIZE_CASE)
default:
TORCH_CHECK(false, "Unknown ScalarType");
}
#undef CASE_ELEMENTSIZE_CASE
}
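// Illustrative values for the two helpers above (a sketch that follows
// directly from the macro list):
//
//   // toString(ScalarType::Half) yields "Half"
//   // elementSize(ScalarType::ComplexFloat) == sizeof(c10::complex<float>)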
inline bool isIntegralType(ScalarType t, bool includeBool) {
bool isIntegral =
(t == ScalarType::Byte || t == ScalarType::Char || t == ScalarType::Int ||
t == ScalarType::Long || t == ScalarType::Short ||
t == ScalarType::UInt16 || t == ScalarType::UInt32 ||
t == ScalarType::UInt64);
return isIntegral || (includeBool && t == ScalarType::Bool);
}
C10_DEPRECATED_MESSAGE(
"isIntegralType is deprecated. Please use the overload with 'includeBool' parameter instead.")
inline bool isIntegralType(ScalarType t) {
return isIntegralType(t, /*includeBool=*/false);
}
inline bool isFloat8Type(ScalarType t) {
return t == ScalarType::Float8_e5m2 || t == ScalarType::Float8_e5m2fnuz ||
t == ScalarType::Float8_e4m3fn || t == ScalarType::Float8_e4m3fnuz;
}
inline bool isReducedFloatingType(ScalarType t) {
return t == ScalarType::Half || t == ScalarType::BFloat16 || isFloat8Type(t);
}
inline bool isFloatingType(ScalarType t) {
return t == ScalarType::Double || t == ScalarType::Float ||
isReducedFloatingType(t);
}
inline bool isComplexType(ScalarType t) {
return (
t == ScalarType::ComplexHalf || t == ScalarType::ComplexFloat ||
t == ScalarType::ComplexDouble);
}
inline bool isQIntType(ScalarType t) {
// Don't forget to extend this when adding new QInt types
return t == ScalarType::QInt8 || t == ScalarType::QUInt8 ||
t == ScalarType::QInt32 || t == ScalarType::QUInt4x2 ||
t == ScalarType::QUInt2x4;
}
inline bool isBitsType(ScalarType t) {
return t == ScalarType::Bits1x8 || t == ScalarType::Bits2x4 ||
t == ScalarType::Bits4x2 || t == ScalarType::Bits8 ||
t == ScalarType::Bits16;
}
inline bool isBarebonesUnsignedType(ScalarType t) {
return t == ScalarType::UInt1 || t == ScalarType::UInt2 ||
t == ScalarType::UInt3 || t == ScalarType::UInt4 ||
t == ScalarType::UInt5 || t == ScalarType::UInt6 ||
t == ScalarType::UInt7 || t == ScalarType::UInt16 ||
t == ScalarType::UInt32 || t == ScalarType::UInt64;
}
inline ScalarType toQIntType(ScalarType t) {
switch (t) {
case ScalarType::Byte:
return ScalarType::QUInt8;
case ScalarType::Char:
return ScalarType::QInt8;
case ScalarType::Int:
return ScalarType::QInt32;
default:
return t;
}
}
inline ScalarType toUnderlying(ScalarType t) {
switch (t) {
case ScalarType::QUInt8:
case ScalarType::QUInt4x2:
[[fallthrough]];
case ScalarType::QUInt2x4:
return ScalarType::Byte;
case ScalarType::QInt8:
return ScalarType::Char;
case ScalarType::QInt32:
return ScalarType::Int;
default:
return t;
}
}
inline bool isSignedType(ScalarType t) {
#define CASE_ISSIGNED(name) \
case ScalarType::name: \
return std::numeric_limits< \
::c10::impl::ScalarTypeToCPPTypeT<ScalarType::name>>::is_signed;
switch (t) {
case ScalarType::QInt8:
case ScalarType::QUInt8:
case ScalarType::QInt32:
case ScalarType::QUInt4x2:
case ScalarType::QUInt2x4:
TORCH_CHECK(false, "isSignedType not supported for quantized types");
case ScalarType::Bits1x8:
case ScalarType::Bits2x4:
case ScalarType::Bits4x2:
case ScalarType::Bits8:
case ScalarType::Bits16:
TORCH_CHECK(false, "Bits types are undefined");
CASE_ISSIGNED(UInt16);
CASE_ISSIGNED(UInt32);
CASE_ISSIGNED(UInt64);
CASE_ISSIGNED(BFloat16);
CASE_ISSIGNED(Float8_e5m2);
CASE_ISSIGNED(Float8_e5m2fnuz);
CASE_ISSIGNED(Float8_e4m3fn);
CASE_ISSIGNED(Float8_e4m3fnuz);
CASE_ISSIGNED(Byte);
CASE_ISSIGNED(Char);
CASE_ISSIGNED(Short);
CASE_ISSIGNED(Int);
CASE_ISSIGNED(Long);
CASE_ISSIGNED(Half);
CASE_ISSIGNED(Float);
CASE_ISSIGNED(Double);
CASE_ISSIGNED(ComplexHalf);
CASE_ISSIGNED(ComplexFloat);
CASE_ISSIGNED(ComplexDouble);
CASE_ISSIGNED(Bool);
case ScalarType::UInt1:
case ScalarType::UInt2:
case ScalarType::UInt3:
case ScalarType::UInt4:
case ScalarType::UInt5:
case ScalarType::UInt6:
case ScalarType::UInt7:
return true;
case ScalarType::Undefined:
case ScalarType::NumOptions:
break;
// Do not add default here, but rather define behavior of every new entry
// here. `-Wswitch-enum` would raise a warning in those cases.
}
TORCH_CHECK(false, "Unknown ScalarType ", t);
#undef CASE_ISSIGNED
}
inline bool isUnderlying(ScalarType type, ScalarType qtype) {
return type == toUnderlying(qtype);
}
inline ScalarType toRealValueType(ScalarType t) {
switch (t) {
case ScalarType::ComplexHalf:
return ScalarType::Half;
case ScalarType::ComplexFloat:
return ScalarType::Float;
case ScalarType::ComplexDouble:
return ScalarType::Double;
default:
return t;
}
}
inline ScalarType toComplexType(ScalarType t) {
switch (t) {
case ScalarType::BFloat16:
// BFloat16 has range equivalent to Float,
// so we map it to ComplexFloat.
return ScalarType::ComplexFloat;
case ScalarType::Half:
return ScalarType::ComplexHalf;
case ScalarType::Float:
return ScalarType::ComplexFloat;
case ScalarType::Double:
return ScalarType::ComplexDouble;
case ScalarType::ComplexHalf:
return ScalarType::ComplexHalf;
case ScalarType::ComplexFloat:
return ScalarType::ComplexFloat;
case ScalarType::ComplexDouble:
return ScalarType::ComplexDouble;
default:
TORCH_CHECK(false, "Unknown Complex ScalarType for ", t);
}
}
// see tensor_attributes.rst for detailed explanation and examples
// of casting rules.
inline bool canCast(const ScalarType from, const ScalarType to) {
// We disallow complex -> non complex, e.g., float_tensor *= complex is
// disallowed.
if (isComplexType(from) && !isComplexType(to)) {
return false;
}
// We disallow float -> integral, e.g., int_tensor *= float is disallowed.
if (isFloatingType(from) && isIntegralType(to, false)) {
return false;
}
// Treat bool as a distinct "category," to be consistent with type promotion
// rules (e.g. `bool_tensor + 5 -> int64_tensor`). If `5` was in the same
// category as `bool_tensor`, we would not promote. Differing categories
// implies `bool_tensor += 5` is disallowed.
//
// NB: numpy distinguishes "unsigned" as a category to get the desired
// `bool_tensor + 5 -> int64_tensor` behavior. We don't, because:
// * We don't want the performance hit of checking the runtime sign of
// Scalars.
// * `uint8_tensor + 5 -> int64_tensor` would be undesirable.
if (from != ScalarType::Bool && to == ScalarType::Bool) {
return false;
}
return true;
}
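// Illustrative consequences of the three rules above (a sketch, not an
// exhaustive table):
//
//   canCast(kComplexFloat, kFloat)  // false: complex -> non-complex
//   canCast(kFloat, kLong)          // false: floating -> integral
//   canCast(kLong, kBool)           // false: non-bool -> bool
//   canCast(kBool, kLong)           // true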
C10_API ScalarType promoteTypes(ScalarType a, ScalarType b);
inline std::ostream& operator<<(
std::ostream& stream,
at::ScalarType scalar_type) {
return stream << toString(scalar_type);
}
// Returns a pair of strings representing the names for each dtype.
// The returned pair is (name, legacy_name_if_applicable)
C10_API std::pair<std::string, std::string> getDtypeNames(
c10::ScalarType scalarType);
// Returns a map of string name to dtype.
C10_API const std::unordered_map<std::string, ScalarType>& getStringToDtypeMap();
} // namespace c10

View File

@ -0,0 +1,57 @@
#pragma once
#include <c10/core/ScalarType.h>
#include <c10/util/Optional.h>
#include <c10/util/typeid.h>
// these just expose TypeMeta/ScalarType bridge functions in c10
// TODO move to typeid.h (or codemod away) when TypeMeta et al
// are moved from caffe2 to c10 (see note at top of typeid.h)
namespace c10 {
/**
* convert ScalarType enum values to TypeMeta handles
*/
inline caffe2::TypeMeta scalarTypeToTypeMeta(ScalarType scalar_type) {
return caffe2::TypeMeta::fromScalarType(scalar_type);
}
/**
* convert TypeMeta handles to ScalarType enum values
*/
inline ScalarType typeMetaToScalarType(caffe2::TypeMeta dtype) {
return dtype.toScalarType();
}
/**
* typeMetaToScalarType(), lifted to optional
*/
inline std::optional<at::ScalarType> optTypeMetaToScalarType(
std::optional<caffe2::TypeMeta> type_meta) {
if (!type_meta.has_value()) {
return std::nullopt;
}
return type_meta->toScalarType();
}
/**
* convenience: equality across TypeMeta/ScalarType conversion
*/
inline bool operator==(ScalarType t, caffe2::TypeMeta m) {
return m.isScalarType(t);
}
inline bool operator==(caffe2::TypeMeta m, ScalarType t) {
return t == m;
}
inline bool operator!=(ScalarType t, caffe2::TypeMeta m) {
return !(t == m);
}
inline bool operator!=(caffe2::TypeMeta m, ScalarType t) {
return !(t == m);
}
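// Round-trip sketch for the bridge functions and operators above:
//
//   caffe2::TypeMeta m = scalarTypeToTypeMeta(ScalarType::Float);
//   TORCH_CHECK(typeMetaToScalarType(m) == ScalarType::Float);
//   TORCH_CHECK(m == ScalarType::Float);  // via the convenience operator==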
} // namespace c10

View File

@ -0,0 +1,272 @@
#pragma once
#include <c10/core/Allocator.h>
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/StorageImpl.h>
#include <c10/core/SymInt.h>
#include <c10/macros/Export.h>
#include <c10/util/Exception.h>
#include <c10/util/ExclusivelyOwned.h>
#include <c10/util/MaybeOwned.h>
#include <c10/util/UniqueVoidPtr.h>
#include <c10/util/intrusive_ptr.h>
#include <cstddef>
#include <utility>
namespace c10 {
struct Storage;
C10_API bool isSharedStorageAlias(
const Storage& storage0,
const Storage& storage1);
struct C10_API Storage {
public:
struct use_byte_size_t {};
struct unsafe_borrow_t {
explicit unsafe_borrow_t() = default;
};
Storage() = default;
Storage(c10::intrusive_ptr<StorageImpl> ptr)
: storage_impl_(std::move(ptr)) {}
// Allocates a memory buffer using the given allocator and creates a storage
// with it
Storage(
use_byte_size_t /*use_byte_size*/,
const SymInt& size_bytes,
Allocator* allocator = nullptr,
bool resizable = false)
: storage_impl_(c10::make_intrusive<StorageImpl>(
StorageImpl::use_byte_size_t(),
size_bytes,
allocator,
resizable)) {}
// Creates a storage with a pre-allocated memory buffer. The allocator is kept
// for potential future reallocations; it can be nullptr if the storage is
// non-resizable.
Storage(
use_byte_size_t /*use_byte_size*/,
size_t size_bytes,
at::DataPtr data_ptr,
at::Allocator* allocator = nullptr,
bool resizable = false)
: storage_impl_(c10::make_intrusive<StorageImpl>(
StorageImpl::use_byte_size_t(),
size_bytes,
std::move(data_ptr),
allocator,
resizable)) {}
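// A construction sketch (the CPU allocator choice is illustrative;
// GetAllocator is declared in Allocator.h, included above):
//
//   c10::Storage st(
//       c10::Storage::use_byte_size_t(),
//       /*size_bytes=*/16,
//       c10::GetAllocator(c10::DeviceType::CPU));
//   TORCH_INTERNAL_ASSERT(st.nbytes() == 16);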
protected:
explicit Storage(unsafe_borrow_t, const Storage& rhs)
: storage_impl_(c10::intrusive_ptr<c10::StorageImpl>::reclaim(
rhs.storage_impl_.get())) {}
friend MaybeOwnedTraits<Storage>;
public:
// Legacy constructor for partially initialized (dtype or memory) storages
// that can be temporarily created with Caffe2 APIs. See the note on top of
// TensorImpl.h for details.
static Storage create_legacy(at::Device device) {
auto allocator = GetAllocator(device.type());
return Storage(c10::make_intrusive<StorageImpl>(
StorageImpl::use_byte_size_t(),
0,
allocator->allocate(0), // materialize a non-default Device.
allocator,
true));
}
// Mimic create_legacy, but without requiring a newly-created StorageImpl.
void reset_legacy() {
TORCH_CHECK(resizable() && allocator());
set_nbytes(0);
set_data_ptr_noswap(allocator()->allocate(0));
}
// TODO: remove later
void set_nbytes(size_t size_bytes) const {
storage_impl_->set_nbytes(size_bytes);
}
void set_nbytes(c10::SymInt size_bytes) const {
storage_impl_->set_nbytes(std::move(size_bytes));
}
bool resizable() const {
return storage_impl_->resizable();
}
size_t nbytes() const {
return storage_impl_->nbytes();
}
SymInt sym_nbytes() const {
return storage_impl_->sym_nbytes();
}
// get() use here is to get const-correctness
const void* data() const {
return storage_impl_->data();
}
void* mutable_data() const {
return storage_impl_->mutable_data();
}
at::DataPtr& mutable_data_ptr() const {
return storage_impl_->mutable_data_ptr();
}
const at::DataPtr& data_ptr() const {
return storage_impl_->data_ptr();
}
// Returns the previous data_ptr
at::DataPtr set_data_ptr(at::DataPtr&& data_ptr) const {
return storage_impl_->set_data_ptr(std::move(data_ptr));
}
void set_data_ptr_noswap(at::DataPtr&& data_ptr) const {
return storage_impl_->set_data_ptr_noswap(std::move(data_ptr));
}
DeviceType device_type() const {
return storage_impl_->device_type();
}
at::Allocator* allocator() const {
return storage_impl_->allocator();
}
at::Device device() const {
return storage_impl_->device();
}
StorageImpl* unsafeReleaseStorageImpl() {
return storage_impl_.release();
}
StorageImpl* unsafeGetStorageImpl() const noexcept {
return storage_impl_.get();
}
c10::weak_intrusive_ptr<StorageImpl> getWeakStorageImpl() const {
return c10::weak_intrusive_ptr<StorageImpl>(storage_impl_);
}
operator bool() const {
return storage_impl_;
}
size_t use_count() const {
return storage_impl_.use_count();
}
inline bool unique() const {
return storage_impl_.unique();
}
bool is_alias_of(const Storage& other) const {
return (
storage_impl_ == other.storage_impl_ ||
isSharedStorageAlias(*this, other));
}
void UniqueStorageShareExternalPointer(
void* src,
size_t capacity,
DeleterFnPtr d = nullptr) {
if (!storage_impl_.unique()) {
TORCH_CHECK(
false,
"UniqueStorageShareExternalPointer can only be called when use_count == 1");
}
storage_impl_->UniqueStorageShareExternalPointer(src, capacity, d);
}
void UniqueStorageShareExternalPointer(
at::DataPtr&& data_ptr,
size_t capacity) {
if (!storage_impl_.unique()) {
TORCH_CHECK(
false,
"UniqueStorageShareExternalPointer can only be called when use_count == 1");
}
storage_impl_->UniqueStorageShareExternalPointer(
std::move(data_ptr), capacity);
}
protected:
c10::intrusive_ptr<StorageImpl> storage_impl_;
};
template <>
struct MaybeOwnedTraits<c10::Storage> {
using owned_type = c10::Storage;
using borrow_type = c10::Storage;
static borrow_type createBorrow(const owned_type& from) {
return borrow_type(borrow_type::unsafe_borrow_t{}, from);
}
static void assignBorrow(borrow_type& lhs, const borrow_type& rhs) {
lhs.unsafeReleaseStorageImpl();
lhs = borrow_type(borrow_type::unsafe_borrow_t{}, rhs);
}
static void destroyBorrow(borrow_type& toDestroy) {
toDestroy.unsafeReleaseStorageImpl(); // "leak" it, but it was already +0.
}
static const owned_type& referenceFromBorrow(const borrow_type& borrow) {
return borrow;
}
static const owned_type* pointerFromBorrow(const borrow_type& borrow) {
return &borrow;
}
static bool debugBorrowIsValid(const borrow_type& /*borrow*/) {
return true;
}
};
template <>
struct ExclusivelyOwnedTraits<c10::Storage> {
using repr_type = c10::Storage;
using pointer_type = c10::Storage*;
using const_pointer_type = const c10::Storage*;
static repr_type nullRepr() {
return c10::Storage();
}
template <class... Args>
static repr_type createInPlace(Args&&... args) {
return c10::Storage(std::forward<Args>(args)...);
}
static repr_type moveToRepr(c10::Storage&& x) {
return std::move(x);
}
static c10::Storage take(c10::Storage& x) {
return std::move(x);
}
static pointer_type getImpl(repr_type& x) {
return &x;
}
static const_pointer_type getImpl(const repr_type& x) {
return &x;
}
};
} // namespace c10

View File

@ -0,0 +1,330 @@
#pragma once
#include <c10/core/Allocator.h>
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/SymInt.h>
#include <c10/core/impl/COW.h>
#include <c10/core/impl/COWDeleter.h>
#include <c10/core/impl/PyObjectSlot.h>
#include <c10/macros/Export.h>
#include <c10/util/Exception.h>
#include <c10/util/UniqueVoidPtr.h>
#include <c10/util/intrusive_ptr.h>
#include <cstddef>
#include <utility>
namespace c10 {
C10_API void throwNullDataPtrError();
C10_API void warnDeprecatedDataPtr();
// A storage represents the underlying backing data buffer for a
// tensor. This concept was inherited from the original Torch7
// codebase; we'd kind of like to get rid of the concept
// (see https://github.com/pytorch/pytorch/issues/14797) but
// it's hard work and no one has gotten around to doing it.
//
// NB: storage is supposed to uniquely own a data pointer; e.g.,
// two non-null data pointers alias if and only if they are from
// the same storage. Technically you can violate this invariant
// (e.g., you can create a non-owning StorageImpl with at::from_blob)
// but a lot of things won't work correctly, including:
//
// - An ordinary deleter on such a storage is wrong, because normal deleters
// assume unique ownership, but if you have two storages at the same data,
// that implies there is some sort of shared ownership. So your deleter would
// have to actually be internally doing some sort of refcount thing
// - Deepcopy in Python side relies on storage equality and not data pointer
// equality; so if there are two separate storages pointing to the same data,
// the data will actually get duplicated in that case (one data ptr before,
// two data ptrs after)
// - Version counts won't work correctly, because we do all VC tracking at the
// level of storages (unless you explicitly disconnect the VC with detach);
// mutations through the aliasing data pointer are totally untracked
struct C10_API StorageImpl : public c10::intrusive_ptr_target {
public:
struct use_byte_size_t {};
StorageImpl(
use_byte_size_t /*use_byte_size*/,
SymInt size_bytes,
at::DataPtr data_ptr,
at::Allocator* allocator,
bool resizable)
: data_ptr_(std::move(data_ptr)),
size_bytes_(std::move(size_bytes)),
size_bytes_is_heap_allocated_(size_bytes_.is_heap_allocated()),
resizable_(resizable),
received_cuda_(false),
allocator_(allocator) {
if (resizable) {
TORCH_INTERNAL_ASSERT(
allocator_, "For resizable storage, allocator must be provided");
}
refresh_has_data_ptr_check();
}
StorageImpl(
use_byte_size_t /*use_byte_size*/,
const SymInt& size_bytes,
at::Allocator* allocator,
bool resizable)
: StorageImpl(
use_byte_size_t(),
size_bytes,
size_bytes.is_heap_allocated()
? allocator->allocate(0)
: allocator->allocate(size_bytes.as_int_unchecked()),
allocator,
resizable) {}
StorageImpl& operator=(StorageImpl&& other) = delete;
StorageImpl& operator=(const StorageImpl&) = delete;
StorageImpl() = delete;
StorageImpl(StorageImpl&& other) = delete;
StorageImpl(const StorageImpl&) = delete;
~StorageImpl() override = default;
void reset() {
data_ptr_.clear();
size_bytes_ = 0;
size_bytes_is_heap_allocated_ = false;
}
// Destructor doesn't call release_resources because it's
// unnecessary; don't forget to change that if needed!
void release_resources() override {
data_ptr_.clear();
}
size_t nbytes() const {
// OK to do this instead of maybe_as_int as nbytes is guaranteed positive
TORCH_CHECK(!size_bytes_is_heap_allocated_);
return size_bytes_.as_int_unchecked();
}
SymInt sym_nbytes() const {
return size_bytes_;
}
// TODO: remove later
void set_nbytes(size_t size_bytes) {
size_bytes_ = static_cast<int64_t>(size_bytes);
size_bytes_is_heap_allocated_ = false;
}
void set_nbytes(c10::SymInt size_bytes) {
size_bytes_ = std::move(size_bytes);
}
bool resizable() const {
return resizable_;
}
const at::DataPtr& data_ptr() const {
return data_ptr_;
}
at::DataPtr& mutable_data_ptr() {
if (C10_UNLIKELY(has_data_ptr_check_)) {
if (throw_on_mutable_data_ptr_) {
throwNullDataPtrError();
}
if (warn_deprecated_on_mutable_data_ptr_) {
warnDeprecatedDataPtr();
}
maybe_materialize_cow();
}
return data_ptr_;
}
// Returns the data_ptr. Bypasses all checks.
at::DataPtr& _mutable_data_ptr_no_checks() {
return data_ptr_;
}
// Returns the previous data_ptr
at::DataPtr set_data_ptr(at::DataPtr&& data_ptr) {
// We need to materialize the old COW DataPtr because it is
// being returned as mutable.
maybe_materialize_cow();
return set_data_ptr_no_materialize_cow(std::move(data_ptr));
}
void set_data_ptr_noswap(at::DataPtr&& data_ptr) {
data_ptr_ = std::move(data_ptr);
refresh_has_data_ptr_check();
}
const void* data() const {
return data_ptr_.get();
}
void* mutable_data() {
if (C10_UNLIKELY(has_data_ptr_check_)) {
if (throw_on_mutable_data_ptr_) {
throwNullDataPtrError();
}
if (warn_deprecated_on_mutable_data_ptr_) {
warnDeprecatedDataPtr();
}
maybe_materialize_cow();
}
return data_ptr_.mutable_get();
}
at::DeviceType device_type() const {
return data_ptr_.device().type();
}
at::Allocator* allocator() {
return allocator_;
}
const at::Allocator* allocator() const {
return allocator_;
}
// You generally shouldn't use this method, but it is occasionally
// useful if you want to override how a tensor will be reallocated,
// after it was already allocated (and its initial allocator was
// set)
void set_allocator(at::Allocator* allocator) {
allocator_ = allocator;
}
Device device() const {
return data_ptr_.device();
}
void set_resizable(bool resizable) {
if (resizable) {
// We need an allocator to be resizable
AT_ASSERT(allocator_);
}
resizable_ = resizable;
}
/**
* Can only be called when use_count is 1
*/
void UniqueStorageShareExternalPointer(
void* src,
size_t size_bytes,
DeleterFnPtr d = nullptr) {
UniqueStorageShareExternalPointer(
at::DataPtr(src, src, d, data_ptr_.device()), size_bytes);
}
/**
* Can only be called when use_count is 1
*/
void UniqueStorageShareExternalPointer(
at::DataPtr&& data_ptr,
size_t size_bytes) {
data_ptr_ = std::move(data_ptr);
size_bytes_ = static_cast<int64_t>(size_bytes);
size_bytes_is_heap_allocated_ = false;
allocator_ = nullptr;
resizable_ = false;
}
// This method can be used only after storage construction and cannot be used
// to modify storage status
void set_received_cuda(bool received_cuda) {
received_cuda_ = received_cuda;
}
bool received_cuda() {
return received_cuda_;
}
impl::PyObjectSlot* pyobj_slot() {
return &pyobj_slot_;
}
const impl::PyObjectSlot* pyobj_slot() const {
return &pyobj_slot_;
}
void set_throw_on_mutable_data_ptr() {
throw_on_mutable_data_ptr_ = true;
refresh_has_data_ptr_check();
}
void set_warn_deprecated_on_mutable_data_ptr() {
warn_deprecated_on_mutable_data_ptr_ = true;
refresh_has_data_ptr_check();
}
protected:
// materialize_cow_storage needs to call set_data_ptr_no_materialize_cow
friend void c10::impl::cow::materialize_cow_storage(StorageImpl& storage);
// Returns the previous data_ptr. If the old data_ptr was COW,
// this avoids materializing it
at::DataPtr set_data_ptr_no_materialize_cow(at::DataPtr&& data_ptr) {
at::DataPtr old_data_ptr(std::move(data_ptr_));
data_ptr_ = std::move(data_ptr);
refresh_has_data_ptr_check();
return old_data_ptr;
}
private:
void refresh_has_data_ptr_check() {
has_data_ptr_check_ = is_cow() || throw_on_mutable_data_ptr_ ||
warn_deprecated_on_mutable_data_ptr_;
}
inline bool is_cow() const {
return c10::impl::cow::is_cow_data_ptr(data_ptr_);
}
// Triggers a copy if this is a copy-on-write tensor.
void maybe_materialize_cow() {
if (is_cow()) {
impl::cow::materialize_cow_storage(*this);
}
}
DataPtr data_ptr_;
SymInt size_bytes_;
bool size_bytes_is_heap_allocated_;
bool resizable_;
// Identifies that the Storage was received from another process and doesn't
// have a process-local CUDA memory allocation
bool received_cuda_;
// All special checks in data/data_ptr calls are guarded behind this single
// boolean. This is for performance: .data/.data_ptr calls are commonly in the
// hot-path.
bool has_data_ptr_check_ = false;
// If we should throw when mutable_data_ptr() or mutable_data() is called.
bool throw_on_mutable_data_ptr_ = false;
// If we warn when mutable_data_ptr() or mutable_data() is called.
bool warn_deprecated_on_mutable_data_ptr_ = false;
Allocator* allocator_;
impl::PyObjectSlot pyobj_slot_;
};
// Declare StorageImpl create function pointer types.
using StorageImplCreateHelper = intrusive_ptr<StorageImpl> (*)(
StorageImpl::use_byte_size_t,
SymInt size_bytes,
DataPtr data_ptr,
Allocator* allocator,
bool resizable);
C10_API void SetStorageImplCreate(DeviceType t, StorageImplCreateHelper fptr);
C10_API StorageImplCreateHelper GetStorageImplCreate(DeviceType t);
C10_API c10::intrusive_ptr<c10::StorageImpl> make_storage_impl(
c10::StorageImpl::use_byte_size_t use_byte_size,
c10::SymInt size_bytes,
c10::DataPtr data_ptr,
c10::Allocator* allocator,
bool resizable,
std::optional<at::Device> device_opt);
} // namespace c10

View File

@ -0,0 +1,176 @@
#pragma once
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/macros/Export.h>
#include <c10/util/Exception.h>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <ostream>
namespace c10 {
/// An index representing a specific stream. A StreamId is not independently
/// meaningful without knowing the Device it is associated with; try to
/// use Stream rather than StreamId directly.
///
/// StreamIds are opaque; they are assigned by some DeviceType-specific
/// numbering system which is not visible to the user. HOWEVER, we
/// guarantee that StreamId 0 is always a valid stream, and corresponds
/// to some sort of "default" stream.
using StreamId = int64_t;
struct C10_API StreamData3 {
StreamId stream_id;
DeviceIndex device_index;
DeviceType device_type;
};
// NB: I decided not to call the above StreamIndex to avoid confusion with
// DeviceIndex. This way, you access device index with index(), and stream id
// with id()
/**
* A stream is a software mechanism used to synchronize launched kernels
* without requiring explicit synchronizations between kernels. The basic
* model is that every kernel launch is associated with a stream: every
* kernel on the same stream is implicitly synchronized so that if I launch
* kernels A and B on the same stream, A is guaranteed to finish before B
* launches. If I want B to run concurrently with A, I must schedule
* it on a different stream.
*
* The Stream class is a backend agnostic value class representing a stream
* which I may schedule a kernel on. Every stream is associated with a device,
* which is recorded in the Stream itself; this avoids confusion about which
* device a stream refers to.
*
* Streams are explicitly thread-safe, in the sense that it is OK to pass
* a Stream from one thread to another, and kernels queued from two different
* threads will still get serialized appropriately. (Of course, the
* time when the kernels get queued is undetermined unless you synchronize
* host side ;)
*
* Stream does NOT have a default constructor. Streams are for expert
* users; if you want to use Streams, we're going to assume you know
* how to deal with C++ template error messages if you try to
* resize() a vector of Streams.
*
* Known instances of streams in backends:
*
* - cudaStream_t (CUDA)
* - hipStream_t (HIP)
* - cl_command_queue (OpenCL) (NB: Caffe2's existing OpenCL integration
* does NOT support command queues.)
*
* Because this class is device agnostic, it cannot provide backend-specific
* functionality (e.g., get the cudaStream_t of a CUDA stream.) There are
* wrapper classes which provide this functionality, e.g., CUDAStream.
*/
class C10_API Stream final {
private:
Device device_;
StreamId id_;
public:
enum Unsafe { UNSAFE };
enum Default { DEFAULT };
/// Unsafely construct a stream from a Device and a StreamId. In
/// general, only specific implementations of streams for a
/// backend should manufacture Stream directly in this way; other users
/// should use the provided APIs to get a stream. In particular,
/// we don't require backends to give any guarantees about non-zero
/// StreamIds; they are welcome to allocate in whatever way they like.
explicit Stream(Unsafe, Device device, StreamId id)
: device_(device), id_(id) {}
/// Construct the default stream of a Device. The default stream is
/// NOT the same as the current stream; default stream is a fixed stream
/// that never changes, whereas the current stream may be changed by
/// StreamGuard.
explicit Stream(Default, Device device) : device_(device), id_(0) {}
bool operator==(const Stream& other) const noexcept {
return this->device_ == other.device_ && this->id_ == other.id_;
}
bool operator!=(const Stream& other) const noexcept {
return !(*this == other);
}
Device device() const noexcept {
return device_;
}
DeviceType device_type() const noexcept {
return device_.type();
}
DeviceIndex device_index() const noexcept {
return device_.index();
}
StreamId id() const noexcept {
return id_;
}
// Enqueues a wait instruction in the stream's work queue.
// This instruction is a no-op unless the event is marked
// for recording. In that case the stream stops processing
// until the event is recorded.
template <typename T>
void wait(const T& event) const {
event.block(*this);
}
// Return whether all asynchronous work previously enqueued on this stream
// has completed running on the device.
bool query() const;
// Wait (by blocking the calling thread) until all asynchronous work enqueued
// on this stream has completed running on the device.
void synchronize() const;
// The purpose of this function is to more conveniently permit binding
// of Stream to and from Python. Without packing, I have to set up a whole
// class with two fields (device and stream id); with packing I can just
// store a single uint64_t.
//
// The particular way we pack streams into a uint64_t is considered an
// implementation detail and should not be relied upon.
uint64_t hash() const noexcept {
// Concat these together into a 64-bit integer
uint64_t bits = static_cast<uint64_t>(device_type()) << 56 |
static_cast<uint64_t>(device_index()) << 48 |
// Remove the sign extension part of the 64-bit address because
// the id might be used to hold a pointer.
(static_cast<uint64_t>(id()) & ((1ull << 48) - 1));
return bits;
}
struct StreamData3 pack3() const {
return {id(), device_index(), device_type()};
}
static Stream unpack3(
StreamId stream_id,
DeviceIndex device_index,
DeviceType device_type) {
TORCH_CHECK(isValidDeviceType(device_type));
return Stream(UNSAFE, Device(device_type, device_index), stream_id);
}
// I decided NOT to provide setters on this class, because really,
// why would you change the device of a stream? Just construct
// it correctly from the beginning dude.
};
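// Packing round-trip sketch (the CPU device is illustrative):
//
//   c10::Stream s(c10::Stream::DEFAULT, c10::Device(c10::DeviceType::CPU));
//   c10::StreamData3 d = s.pack3();
//   TORCH_CHECK(
//       c10::Stream::unpack3(d.stream_id, d.device_index, d.device_type) == s);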
C10_API std::ostream& operator<<(std::ostream& stream, const Stream& s);
} // namespace c10
namespace std {
template <>
struct hash<c10::Stream> {
size_t operator()(c10::Stream s) const noexcept {
return std::hash<uint64_t>{}(s.hash());
}
};
} // namespace std

View File

@ -0,0 +1,170 @@
#pragma once
#include <c10/core/Device.h>
#include <c10/core/Stream.h>
#include <c10/core/impl/InlineStreamGuard.h>
#include <c10/core/impl/VirtualGuardImpl.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Optional.h>
namespace c10 {
/**
* A StreamGuard is an RAII class that changes the current device
* to the device corresponding to some stream, and changes the
* default stream on that device to be this stream.
*
* Use of StreamGuard is HIGHLY discouraged in operator definitions. In
* a single operator, you probably don't know enough about the global
* state of the world to profitably decide how to set streams. Let
* the caller handle this appropriately, and just use the current stream
* in your operator code.
*
* This StreamGuard does NOT have an uninitialized state; it is guaranteed
* to reset the stream and device on exit. If you are in a situation
* where you *might* want to setup a stream guard, see OptionalStreamGuard.
*/
struct StreamGuard {
/// No default constructor, see Note [Omitted default constructor from RAII]
explicit StreamGuard() = delete;
/// Set the current device to the device associated with the passed stream,
/// and set the current stream on that device to the passed stream.
explicit StreamGuard(Stream stream) : guard_(stream) {}
/// Copy is disallowed
StreamGuard(const StreamGuard&) = delete;
StreamGuard& operator=(const StreamGuard&) = delete;
/// Move is disallowed, as StreamGuard does not have an uninitialized state,
/// which is required for moves on types with nontrivial destructors.
StreamGuard(StreamGuard&& other) = delete;
StreamGuard& operator=(StreamGuard&& other) = delete;
/// Resets the currently set stream to the original stream and
/// the currently set device to the original device. Then,
/// set the current device to the device associated with the passed stream,
/// and set the current stream on that device to the passed stream.
///
/// NOTE: this implementation may skip some stream/device setting if
/// it can prove that it is unnecessary.
///
/// WARNING: reset_stream does NOT preserve previously set streams on
/// different devices. If you need to set streams on multiple devices,
/// use MultiStreamGuard instead.
void reset_stream(Stream stream) {
guard_.reset_stream(stream);
}
/// Returns the stream that was set at the time the guard was constructed.
Stream original_stream() const {
return guard_.original_stream();
}
/// Returns the most recent stream that was set using this device guard,
/// either from construction, or via set_stream.
Stream current_stream() const {
return guard_.current_stream();
}
/// Returns the most recent device that was set using this device guard,
/// either from construction, or via set_device/reset_device/set_index.
Device current_device() const {
return guard_.current_device();
}
/// Returns the device that was set at the most recent reset_stream(),
/// or otherwise the device at construction time.
Device original_device() const {
return guard_.original_device();
}
private:
c10::impl::InlineStreamGuard<impl::VirtualGuardImpl> guard_;
};
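// Usage sketch (some_stream is a placeholder for a stream obtained from a
// backend-specific API, which this header does not provide):
//
//   {
//     c10::StreamGuard guard(some_stream);
//     // kernels launched here target some_stream and its device
//   }  // original stream and device are restored on scope exit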
/**
* An OptionalStreamGuard is an RAII class that sets a device to some value on
* initialization, and resets the device to its original value on destruction.
* See OptionalDeviceGuard for more guidance on how to use this class.
*/
struct OptionalStreamGuard {
/// Create an uninitialized guard.
explicit OptionalStreamGuard() = default;
/// Set the current device to the device associated with the passed stream,
/// and set the current stream on that device to the passed stream.
explicit OptionalStreamGuard(Stream stream) : guard_(stream) {}
/// Set the current device to the device associated with the passed stream,
/// and set the current stream on that device to the passed stream,
/// if the passed stream is not nullopt.
explicit OptionalStreamGuard(std::optional<Stream> stream_opt)
: guard_(stream_opt) {}
/// Copy is disallowed
OptionalStreamGuard(const OptionalStreamGuard&) = delete;
OptionalStreamGuard& operator=(const OptionalStreamGuard&) = delete;
// See Note [Move construction for RAII guards is tricky]
OptionalStreamGuard(OptionalStreamGuard&& other) = delete;
// See Note [Move assignment for RAII guards is tricky]
OptionalStreamGuard& operator=(OptionalStreamGuard&& other) = delete;
/// Resets the currently set stream to the original stream and
/// the currently set device to the original device. Then,
/// set the current device to the device associated with the passed stream,
/// and set the current stream on that device to the passed stream.
/// Initializes the guard if it was not previously initialized.
void reset_stream(Stream stream) {
guard_.reset_stream(stream);
}
/// Returns the stream that was set at the time the guard was most recently
/// initialized, or nullopt if the guard is uninitialized.
std::optional<Stream> original_stream() const {
return guard_.original_stream();
}
/// Returns the most recent stream that was set using this stream guard,
/// either from construction, or via reset_stream, if the guard is
/// initialized, or nullopt if the guard is uninitialized.
std::optional<Stream> current_stream() const {
return guard_.current_stream();
}
/// Restore the original device and stream, resetting this guard to
/// uninitialized state.
void reset() {
guard_.reset();
}
private:
c10::impl::InlineOptionalStreamGuard<impl::VirtualGuardImpl> guard_{};
};
/**
* A MultiStreamGuard is an RAII class that sets the current streams of a set of
* devices all at once, and resets them to their original values on destruction.
*/
struct MultiStreamGuard {
/// Set the current streams to the passed streams on each of their respective
/// devices.
explicit MultiStreamGuard(ArrayRef<Stream> streams) : guard_(streams) {}
/// Copy is disallowed
MultiStreamGuard(const MultiStreamGuard&) = delete;
MultiStreamGuard& operator=(const MultiStreamGuard&) = delete;
// See Note [Move construction for RAII guards is tricky]
MultiStreamGuard(MultiStreamGuard&& other) = delete;
// See Note [Move assignment for RAII guards is tricky]
MultiStreamGuard& operator=(MultiStreamGuard&& other) = delete;
private:
c10::impl::InlineMultiStreamGuard<impl::VirtualGuardImpl> guard_;
};
} // namespace c10

View File

@ -0,0 +1,110 @@
#pragma once
#include <c10/core/SymNodeImpl.h>
#include <c10/macros/Export.h>
#include <c10/util/Exception.h>
#include <c10/util/intrusive_ptr.h>
#include <cstdint>
#include <optional>
#include <ostream>
#include <utility>
namespace c10 {
class C10_API SymBool {
public:
/*implicit*/ SymBool(bool b) : data_(b){};
SymBool(SymNode ptr) : data_(false), ptr_(std::move(ptr)) {
TORCH_CHECK(ptr_->is_bool());
};
SymBool() : data_(false) {}
SymNodeImpl* toSymNodeImplUnowned() const {
return ptr_.get();
}
SymNodeImpl* release() && {
return std::move(ptr_).release();
}
// Only valid if is_heap_allocated()
SymNode toSymNodeImpl() const;
// Guaranteed to return a SymNode, wrapping using base if necessary
SymNode wrap_node(const SymNode& base) const;
bool expect_bool() const {
std::optional<bool> c = maybe_as_bool();
TORCH_CHECK(c.has_value());
return *c;
}
SymBool sym_and(const SymBool&) const;
SymBool sym_or(const SymBool&) const;
SymBool sym_not() const;
SymBool operator&(const SymBool& other) const {
return sym_and(other);
}
SymBool operator|(const SymBool& other) const {
return sym_or(other);
}
SymBool operator~() const {
return sym_not();
}
// Insert a guard for the bool to be its concrete value, and then return
// that value. Note that C++ comparison operations default to returning
// bool, so it's not so common to have to call this
bool guard_bool(const char* file, int64_t line) const;
bool expect_true(const char* file, int64_t line) const;
bool guard_size_oblivious(const char* file, int64_t line) const;
bool has_hint() const;
bool as_bool_unchecked() const {
return data_;
}
std::optional<bool> maybe_as_bool() const {
if (!is_heap_allocated()) {
return std::make_optional(data_);
}
return toSymNodeImplUnowned()->constant_bool();
}
bool is_heap_allocated() const {
return ptr_;
}
private:
// TODO: optimize to union
bool data_;
SymNode ptr_;
};
C10_API std::ostream& operator<<(std::ostream& os, const SymBool& s);
#define TORCH_SYM_CHECK(cond, ...) \
TORCH_CHECK((cond).expect_true(__FILE__, __LINE__), __VA_ARGS__)
#define TORCH_SYM_INTERNAL_ASSERT(cond, ...) \
TORCH_INTERNAL_ASSERT((cond).expect_true(__FILE__, __LINE__), __VA_ARGS__)
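// Sketch of the intended use with a plain (non-symbolic) SymBool; symbolic
// nodes come from the tracing machinery rather than being constructed here:
//
//   c10::SymBool cond(true);
//   TORCH_SYM_CHECK(cond, "expected cond to hold");
//   bool b = cond.guard_bool(__FILE__, __LINE__);  // true for this constant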
inline bool guard_size_oblivious(
bool b,
const char* file [[maybe_unused]],
int64_t line [[maybe_unused]]) {
return b;
}
inline bool guard_size_oblivious(
const c10::SymBool& b,
const char* file,
int64_t line) {
return b.guard_size_oblivious(file, line);
}
#define TORCH_GUARD_SIZE_OBLIVIOUS(cond) \
c10::guard_size_oblivious((cond), __FILE__, __LINE__)
} // namespace c10

View File

@ -0,0 +1,113 @@
#pragma once
#include <c10/core/SymBool.h>
#include <c10/core/SymNodeImpl.h>
#include <c10/macros/Export.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <c10/util/intrusive_ptr.h>
#include <cstdint>
#include <limits>
#include <ostream>
#include <utility>
namespace c10 {
// NB: this is actually double precision; we're using the Python naming here
class C10_API SymFloat {
public:
/*implicit*/ SymFloat(double d) : data_(d){};
SymFloat(SymNode ptr)
: data_(std::numeric_limits<double>::quiet_NaN()), ptr_(std::move(ptr)) {
TORCH_CHECK(ptr_->is_float());
};
SymFloat() : data_(0.0) {}
SymNodeImpl* toSymNodeImplUnowned() const {
return ptr_.get();
}
SymNodeImpl* release() && {
return std::move(ptr_).release();
}
// Only valid if is_symbolic()
SymNode toSymNodeImpl() const;
// Guaranteed to return a SymNode, wrapping using base if necessary
SymNode wrap_node(const SymNode& base) const;
double expect_float() const {
TORCH_CHECK(!is_symbolic());
return data_;
}
SymFloat operator+(const SymFloat&) const;
SymFloat operator-(const SymFloat&) const;
SymFloat operator*(const SymFloat&) const;
SymFloat operator/(const SymFloat&) const;
SymBool sym_eq(const SymFloat&) const;
SymBool sym_ne(const SymFloat&) const;
SymBool sym_lt(const SymFloat&) const;
SymBool sym_le(const SymFloat&) const;
SymBool sym_gt(const SymFloat&) const;
SymBool sym_ge(const SymFloat&) const;
bool operator==(const SymFloat& o) const {
return sym_eq(o).guard_bool(__FILE__, __LINE__);
}
bool operator!=(const SymFloat& o) const {
return sym_ne(o).guard_bool(__FILE__, __LINE__);
}
bool operator<(const SymFloat& o) const {
return sym_lt(o).guard_bool(__FILE__, __LINE__);
}
bool operator<=(const SymFloat& o) const {
return sym_le(o).guard_bool(__FILE__, __LINE__);
}
bool operator>(const SymFloat& o) const {
return sym_gt(o).guard_bool(__FILE__, __LINE__);
}
bool operator>=(const SymFloat& o) const {
return sym_ge(o).guard_bool(__FILE__, __LINE__);
}
SymFloat min(const SymFloat& sci) const;
SymFloat max(const SymFloat& sci) const;
// Need guidance on where to put this code
SymFloat sqrt() const;
// Insert a guard for the float to be its concrete value, and then return
// that value. This operation always works, even if the float is symbolic,
// so long as we know what the underlying value is. Don't blindly put this
// everywhere; you can cause overspecialization of PyTorch programs with
// this method.
//
// It should be called as guard_float(__FILE__, __LINE__). The file and line
// number can be used to diagnose overspecialization.
double guard_float(const char* file, int64_t line) const;
bool has_hint() const;
// N.B. It's important to keep this definition in the header
// as we expect if checks to be folded for mobile builds
// where `is_symbolic` is always false
C10_ALWAYS_INLINE bool is_symbolic() const {
return ptr_;
}
double as_float_unchecked() const {
return data_;
}
private:
// TODO: optimize to union
double data_;
SymNode ptr_;
};
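// Value-semantics sketch with plain doubles (symbolic floats are produced by
// the tracing layer, not constructed by hand here):
//
//   c10::SymFloat x = 2.0;
//   c10::SymFloat y = x * x;
//   TORCH_CHECK(!y.is_symbolic());
//   TORCH_CHECK(y.guard_float(__FILE__, __LINE__) == 4.0);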
C10_API std::ostream& operator<<(std::ostream& os, const SymFloat& s);
} // namespace c10

View File

@ -0,0 +1,424 @@
#pragma once
#include <c10/core/SymBool.h>
#include <c10/core/SymNodeImpl.h>
#include <c10/macros/Export.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>
#include <cstdint>
#include <iterator>
#include <numeric>
#include <optional>
#include <ostream>
#include <type_traits>
namespace c10 {
class SymFloat;
// SymInt represents either a regular int64_t, or a symbolic integer
// (represented in a type erased way as SymNode). The intention is for SymInt
// to represent symbolic sizes that arise when doing shape computation in
// operator kernels. This allows for tracing through programs without baking in
// concrete sizes into kernel calls.
//
// SymInt has an API equivalent to int64_t. In particular, it is a value type.
// Internally, SymInt is represented in a clever packed way, so that it only
// occupies one word of space; but morally, it is a union between an int64_t
// and an intrusive pointer to SymNodeImpl.
//
// Invariant: the referenced SymNodeImpl is guaranteed to be a SymNode where
// is_int() returns true
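// Value-semantics sketch with plain integers (heap-allocated symbolic ints are
// produced by the tracing machinery, not constructed by hand here):
//
//   c10::SymInt a = 4;
//   c10::SymInt b = a * 3;
//   TORCH_CHECK(!b.is_heap_allocated());
//   TORCH_CHECK(b.expect_int() == 12);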
class C10_API SymInt {
public:
enum Unchecked {
UNCHECKED,
};
/*implicit*/ SymInt(int64_t d) : data_(d) {
if (is_heap_allocated()) {
// Large negative number, heap allocate it
promote_to_negative();
}
};
SymInt() : data_(0) {}
SymInt(SymNode n);
// unchecked c-tor accepting raw `data_`
// One appropriate use for this is when you are constructing a symint
// in a situation where you know it is non-negative (or, if it is negative,
// the negative value is -1; i.e., not user controlled)
SymInt(Unchecked, int64_t d) : data_(d) {}
// TODO: these implementations are not optimal because they allocate a
// temporary and then use the move constructor/assignment
SymInt(const SymInt& s) : data_(0) {
if (s.is_heap_allocated()) {
*this = SymInt(s.toSymNode());
} else {
data_ = s.data_;
}
}
SymInt(SymInt&& s) noexcept : data_(s.data_) {
s.data_ = 0;
}
SymInt& operator=(const SymInt& s) {
if (this != &s) {
if (s.is_heap_allocated()) {
*this = SymInt(s.toSymNode());
} else {
data_ = s.data_;
}
}
return *this;
}
SymInt& operator=(SymInt&& s) noexcept {
if (this != &s) {
release_(); // release the current SymNode if any
data_ = s.data_;
if (s.is_heap_allocated())
s.data_ = 0;
}
return *this;
}
SymNodeImpl* toSymNodeImplUnowned() const {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(is_heap_allocated());
uint64_t unextended_bits = static_cast<uint64_t>(data_) & ~MASK;
uint64_t sign_bit_mask = 1ULL << (62 - 1);
// https://stackoverflow.com/questions/42534749/signed-extension-from-24-bit-to-32-bit-in-c
uint64_t extended_bits = (unextended_bits ^ sign_bit_mask) - sign_bit_mask;
return static_cast<SymNodeImpl*>(
// NOLINTNEXTLINE(performance-no-int-to-ptr)
reinterpret_cast<void*>(static_cast<uintptr_t>(extended_bits)));
}
void release_() {
if (is_heap_allocated()) {
SymNode::reclaim(toSymNodeImplUnowned()); // steal
}
}
SymNodeImpl* release() && {
#ifndef C10_MOBILE
TORCH_INTERNAL_ASSERT(is_heap_allocated());
auto* r = toSymNodeImplUnowned();
data_ = 0; // transfer ownership
return r;
#else
TORCH_INTERNAL_ASSERT(false);
#endif
}
// Only valid if is_heap_allocated()
SymNode toSymNode() const;
// Guaranteed to return a SymNode, wrapping using base if necessary
SymNode wrap_node(const SymNode& base) const;
~SymInt() {
release_();
}
// Require the int to be non-symbolic; if it is symbolic, raise an
// error. This is safe to use for C++ code that doesn't work for symbolic
// shapes and that you don't have time to fix immediately, since if we
// try to trigger that path in C++ you'll appropriately get an error
int64_t expect_int() const {
if (auto r = maybe_as_int()) {
return *r;
}
TORCH_CHECK_ALWAYS_SHOW_CPP_STACKTRACE(
false, "when unpacking SymInt, expected int but got ", *this);
}
// Test if we have a hint for this int (e.g., guard_int would work).
// Most of the time this is true; it is only false when you have
// an unbacked SymInt.
bool has_hint() const;
// Insert a guard for the int to be its concrete value, and then return
// that value. This operation always works, even if the int is symbolic,
// so long as we know what the underlying value is (e.g., this won't work
// if you call it on the size of nonzero output). Don't blindly put this
// everywhere; you can cause overspecialization of PyTorch programs with
// this method.
//
// It should be called as guard_int(__FILE__, __LINE__). The file and line
// number can be used to diagnose overspecialization.
int64_t guard_int(const char* file, int64_t line) const;
// Insert a guard that this SymInt must be size-like, returning true if
// the integer actually is >= 0. Unlike manually performing a >= 0 test,
// if the SymInt in question is an unbacked SymInt (or, potentially in the
// future, if it contains unbacked SymInts), we will also treat the
// unbacked SymInt as statically testing >= 2 (which will prevent us from
// choking on, e.g., contiguity checks.)
bool expect_size(const char* file, int64_t line) const;
// Distinguish actual symbolic values from constants stored on the heap
bool is_symbolic() const {
return is_heap_allocated() &&
!toSymNodeImplUnowned()->constant_int().has_value();
}
// N.B. It's important to keep this definition in the header,
// as we expect `if` checks against it to be folded for mobile builds,
// where `is_heap_allocated` is always false, letting dead code paths be optimized away
C10_ALWAYS_INLINE bool is_heap_allocated() const {
#ifdef C10_MOBILE
return false;
#else
return !check_range(data_);
#endif
}
SymInt operator+(const SymInt& sci) const;
SymInt operator-(const SymInt& sci) const;
SymInt operator*(const SymInt& sci) const;
SymInt operator/(const SymInt& sci) const;
SymInt operator%(const SymInt& sci) const;
void operator*=(const SymInt& sci);
void operator+=(const SymInt& sci);
void operator/=(const SymInt& sci);
SymInt clone() const;
SymBool sym_eq(const SymInt&) const;
SymBool sym_ne(const SymInt&) const;
SymBool sym_lt(const SymInt&) const;
SymBool sym_le(const SymInt&) const;
SymBool sym_gt(const SymInt&) const;
SymBool sym_ge(const SymInt&) const;
bool operator==(const SymInt& o) const {
return sym_eq(o).guard_bool(__FILE__, __LINE__);
}
bool operator!=(const SymInt& o) const {
return sym_ne(o).guard_bool(__FILE__, __LINE__);
}
bool operator<(const SymInt& o) const {
return sym_lt(o).guard_bool(__FILE__, __LINE__);
}
bool operator<=(const SymInt& o) const {
return sym_le(o).guard_bool(__FILE__, __LINE__);
}
bool operator>(const SymInt& o) const {
return sym_gt(o).guard_bool(__FILE__, __LINE__);
}
bool operator>=(const SymInt& o) const {
return sym_ge(o).guard_bool(__FILE__, __LINE__);
}
SymInt min(const SymInt& sci) const;
SymInt max(const SymInt& sci) const;
// If both are symbolic, this checks if
// they share the same node.
// If both are not symbolic this just checks normal equality.
bool is_same(const SymInt& other) const;
operator SymFloat() const;
// Don't use this. Prefer maybe_as_int instead
int64_t as_int_unchecked() const {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!is_heap_allocated());
return data_;
}
std::optional<int64_t> maybe_as_int() const {
if (!is_heap_allocated()) {
return std::make_optional(data_);
}
auto* node = toSymNodeImplUnowned();
if (auto c = node->constant_int()) {
return c;
}
return node->maybe_as_int();
}
// Return whether the integer is directly coercible to a SymInt
// without requiring heap allocation. You don't need to use this
// to check if you can pass an integer to SymInt; this is guaranteed
// to work (it just might heap allocate!)
static bool check_range(int64_t i) {
return i > MAX_UNREPRESENTABLE_INT;
}
// Return the min representable integer as a SymInt without
// heap allocation. For quantities that count bytes (or larger),
// this is still much larger than you need, so you may consider
// using this as a more efficient version of MIN_INT
static constexpr int64_t min_representable_int() {
return MAX_UNREPRESENTABLE_INT + 1;
}
private:
void promote_to_negative();
// Constraints on the internal representation:
//
// - Should represent positive and small negative ints
// - No conversion necessary for operations on ints
// - Must represent valid 64-bit pointers
// - The is-symbolic test should be FAST (two arithmetic instructions is
// already too much).
// This code being a hotpath is based on Strobelight profiles of
// is_heap_allocated(). FB only: https://fburl.com/strobelight/5l50ncxd
// (you will need to change the time window).
//
// So, the scheme is to reserve large negative numbers (assuming
// two's complement):
//
// - 0b0.... means we are a positive int
// - 0b11... means we are a small negative int
// - 0b10... means we are a pointer. This means that
// [-2^63, -2^62-1] are not representable as ints.
// We don't actually need all of this space, since on x86_64
// the top 16 bits aren't used for anything
static constexpr uint64_t MASK = 1ULL << 63 | 1ULL << 62 | 1ULL << 61;
static constexpr uint64_t IS_SYM = 1ULL << 63 | 1ULL << 61;
// We must manually translate the bit pattern test into a greater
// than test because compiler doesn't figure it out:
// https://godbolt.org/z/356aferaW
static constexpr int64_t MAX_UNREPRESENTABLE_INT =
-1LL & static_cast<int64_t>(~(1ULL << 62));
int64_t data_;
};
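A usage sketch for the guard_int / maybe_as_int API above (illustrative only; `pick_block_size` is a hypothetical function, not part of c10):
#include <algorithm>
#include <cstdint>
#include <c10/core/SymInt.h>
// Hypothetical: choose a kernel launch size from a possibly-symbolic length.
int64_t pick_block_size(const c10::SymInt& numel) {
  if (auto concrete = numel.maybe_as_int()) {
    return std::min<int64_t>(*concrete, 256); // concrete value, no guard recorded
  }
  // Still symbolic: specialize on the underlying value (records a guard).
  return std::min<int64_t>(numel.guard_int(__FILE__, __LINE__), 256);
}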
/// Product of a list of SymInt; accumulates into the c10::SymInt expression
template <
typename C,
typename std::enable_if_t<
std::is_same_v<typename C::value_type, c10::SymInt>,
int> = 0>
inline c10::SymInt multiply_integers(const C& container) {
return std::accumulate(
container.begin(),
container.end(),
c10::SymInt(1),
[](const c10::SymInt& a, const c10::SymInt& b) { return a * b; });
}
template <
typename Iter,
typename = std::enable_if_t<std::is_same_v<
typename std::iterator_traits<Iter>::value_type,
c10::SymInt>>>
inline c10::SymInt multiply_integers(Iter begin, Iter end) {
return std::accumulate(
begin,
end,
c10::SymInt(1),
[](const c10::SymInt& a, const c10::SymInt& b) { return a * b; });
}
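For example (a sketch, not part of the header), this is the usual way to fold a sizes list into a SymInt element count; `numel_of` is a hypothetical name:
#include <vector>
#include <c10/core/SymInt.h>
c10::SymInt numel_of(const std::vector<c10::SymInt>& sizes) {
  // Product over the container overload declared above; an empty list yields 1.
  return c10::multiply_integers(sizes);
}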
#define DECLARE_SYMINT_OP_INTONLY(scalar_t, RetTy) \
C10_API RetTy operator%(const SymInt& a, scalar_t b); \
C10_API RetTy operator%(scalar_t a, const SymInt& b);
#define DECLARE_SYMINT_OP(scalar_t, RetTy) \
C10_API RetTy operator+(const SymInt& a, scalar_t b); \
C10_API RetTy operator-(const SymInt& a, scalar_t b); \
C10_API RetTy operator*(const SymInt& a, scalar_t b); \
C10_API RetTy operator/(const SymInt& a, scalar_t b); \
C10_API RetTy operator+(scalar_t a, const SymInt& b); \
C10_API RetTy operator-(scalar_t a, const SymInt& b); \
C10_API RetTy operator*(scalar_t a, const SymInt& b); \
C10_API RetTy operator/(scalar_t a, const SymInt& b); \
C10_API bool operator==(const SymInt& a, scalar_t b); \
C10_API bool operator!=(const SymInt& a, scalar_t b); \
C10_API bool operator<(const SymInt& a, scalar_t b); \
C10_API bool operator<=(const SymInt& a, scalar_t b); \
C10_API bool operator>(const SymInt& a, scalar_t b); \
C10_API bool operator>=(const SymInt& a, scalar_t b); \
C10_API bool operator==(scalar_t a, const SymInt& b); \
C10_API bool operator!=(scalar_t a, const SymInt& b); \
C10_API bool operator<(scalar_t a, const SymInt& b); \
C10_API bool operator<=(scalar_t a, const SymInt& b); \
C10_API bool operator>(scalar_t a, const SymInt& b); \
C10_API bool operator>=(scalar_t a, const SymInt& b);
DECLARE_SYMINT_OP_INTONLY(int64_t, SymInt)
DECLARE_SYMINT_OP_INTONLY(int32_t, SymInt)
DECLARE_SYMINT_OP_INTONLY(uint64_t, SymInt)
DECLARE_SYMINT_OP_INTONLY(uint32_t, SymInt)
DECLARE_SYMINT_OP(int64_t, SymInt)
DECLARE_SYMINT_OP(int32_t, SymInt) // make sure constants work
DECLARE_SYMINT_OP(uint64_t, SymInt)
DECLARE_SYMINT_OP(uint32_t, SymInt)
DECLARE_SYMINT_OP(double, SymFloat)
DECLARE_SYMINT_OP(float, SymFloat) // just for completeness
// On OSX size_t is different than uint64_t so we have to
// define it separately
#if defined(__APPLE__)
DECLARE_SYMINT_OP_INTONLY(size_t, SymInt)
DECLARE_SYMINT_OP(size_t, SymInt)
#endif
#undef DECLARE_SYMINT_OP
C10_API std::ostream& operator<<(std::ostream& os, const SymInt& s);
C10_API SymInt operator-(const SymInt& s);
inline bool sym_eq(int64_t a, int64_t b) {
return a == b;
}
inline SymBool sym_eq(const SymInt& a, const SymInt& b) {
return a.sym_eq(b);
}
inline bool sym_ne(int64_t a, int64_t b) {
return a != b;
}
inline SymBool sym_ne(const SymInt& a, const SymInt& b) {
return a.sym_ne(b);
}
inline bool sym_lt(int64_t a, int64_t b) {
return a < b;
}
inline SymBool sym_lt(const SymInt& a, const SymInt& b) {
return a.sym_lt(b);
}
inline bool sym_le(int64_t a, int64_t b) {
return a <= b;
}
inline SymBool sym_le(const SymInt& a, const SymInt& b) {
return a.sym_le(b);
}
inline bool sym_gt(int64_t a, int64_t b) {
return a > b;
}
inline SymBool sym_gt(const SymInt& a, const SymInt& b) {
return a.sym_gt(b);
}
inline bool sym_ge(int64_t a, int64_t b) {
return a >= b;
}
inline SymBool sym_ge(const SymInt& a, const SymInt& b) {
return a.sym_ge(b);
}
inline bool definitely_true(
const c10::SymBool& b,
const char* file,
int64_t line) {
return b.has_hint() && b.guard_bool(file, line);
}
} // namespace c10
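A sketch of how definitely_true is typically used to gate an optional fast path without forcing a guard on unbacked symbols (illustrative; `maybe_skip_copy` is hypothetical):
#include <c10/core/SymBool.h>
#include <c10/core/SymInt.h>
// Hypothetical: take the fast path only when we can cheaply prove equality.
bool maybe_skip_copy(const c10::SymInt& src_len, const c10::SymInt& dst_len) {
  // definitely_true returns false (rather than erroring) when there is no hint.
  return c10::definitely_true(c10::sym_eq(src_len, dst_len), __FILE__, __LINE__);
}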

View File

@ -0,0 +1,89 @@
#pragma once
#include <c10/core/SymInt.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/DimVector.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <cstdint>
#include <optional>
namespace c10 {
using SymIntArrayRef = ArrayRef<SymInt>;
inline at::IntArrayRef asIntArrayRefUnchecked(c10::SymIntArrayRef ar) {
return IntArrayRef(reinterpret_cast<const int64_t*>(ar.data()), ar.size());
}
// TODO: a SymIntArrayRef containing a heap allocated large negative integer
// can actually technically be converted to an IntArrayRef... but not with
// the non-owning API we have here. We can't reinterpret cast; we have to
// allocate another buffer and write the integers into it. If you need it,
// we can do it. But I don't think you need it.
inline std::optional<at::IntArrayRef> asIntArrayRefSlowOpt(
c10::SymIntArrayRef ar) {
for (const c10::SymInt& sci : ar) {
if (sci.is_heap_allocated()) {
return std::nullopt;
}
}
return {asIntArrayRefUnchecked(ar)};
}
inline at::IntArrayRef asIntArrayRefSlow(
c10::SymIntArrayRef ar,
const char* file,
int64_t line) {
for (const c10::SymInt& sci : ar) {
TORCH_CHECK(
!sci.is_heap_allocated(),
file,
":",
line,
": SymIntArrayRef expected to contain only concrete integers");
}
return asIntArrayRefUnchecked(ar);
}
// Even slower than asIntArrayRefSlow, as it forces an allocation for a
// destination int, BUT it is able to force specialization (it never errors)
inline c10::DimVector asIntArrayRefSlowAlloc(
c10::SymIntArrayRef ar,
const char* file,
int64_t line) {
c10::DimVector res(ar.size(), 0);
for (const auto i : c10::irange(ar.size())) {
res[i] = ar[i].guard_int(file, line);
}
return res;
}
#define C10_AS_INTARRAYREF_SLOW(a) c10::asIntArrayRefSlow(a, __FILE__, __LINE__)
#define C10_AS_INTARRAYREF_SLOW_ALLOC(a) \
c10::asIntArrayRefSlowAlloc(a, __FILE__, __LINE__)
// Prefer using a more semantic constructor, like
// fromIntArrayRefKnownNonNegative
inline SymIntArrayRef fromIntArrayRefUnchecked(IntArrayRef array_ref) {
return SymIntArrayRef(
reinterpret_cast<const SymInt*>(array_ref.data()), array_ref.size());
}
inline SymIntArrayRef fromIntArrayRefKnownNonNegative(IntArrayRef array_ref) {
return fromIntArrayRefUnchecked(array_ref);
}
inline SymIntArrayRef fromIntArrayRefSlow(IntArrayRef array_ref) {
for (int64_t i : array_ref) {
TORCH_CHECK(
SymInt::check_range(i),
"IntArrayRef contains an int that cannot be represented as a SymInt: ",
i);
}
return SymIntArrayRef(
reinterpret_cast<const SymInt*>(array_ref.data()), array_ref.size());
}
} // namespace c10
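A sketch of the intended call pattern for the conversions above (illustrative only; `legacy_kernel` and `call_legacy` are hypothetical):
#include <c10/core/SymIntArrayRef.h>
#include <c10/util/ArrayRef.h>
// Hypothetical legacy kernel that only understands concrete sizes.
void legacy_kernel(c10::IntArrayRef sizes) {
  (void)sizes; // real work elided
}
void call_legacy(c10::SymIntArrayRef sym_sizes) {
  // Errors with file/line context if any element is still symbolic.
  legacy_kernel(C10_AS_INTARRAYREF_SLOW(sym_sizes));
}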

View File

@ -0,0 +1,242 @@
#pragma once
#include <c10/macros/Export.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Exception.h>
#include <c10/util/intrusive_ptr.h>
#include <cstdint>
#include <optional>
#include <ostream>
#include <string>
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter")
namespace c10 {
class SymNodeImpl;
using SymNode = c10::intrusive_ptr<SymNodeImpl>;
// When you add a method, you also need to edit
// torch/csrc/jit/python/init.cpp
// torch/csrc/utils/python_symnode.h
// c10/core/ConstantSymNodeImpl.h
class C10_API SymNodeImpl : public c10::intrusive_ptr_target {
public:
~SymNodeImpl() override = default;
template <typename T>
c10::intrusive_ptr<T> dyn_cast() const {
return c10::intrusive_ptr<T>::reclaim_copy(dynamic_cast<T*>(this));
}
// these could be pure virtual when we implement LTC versions
virtual bool is_int() {
TORCH_CHECK(false, "NYI");
}
virtual bool is_bool() {
TORCH_CHECK(false, "NYI");
}
virtual bool is_float() {
TORCH_CHECK(false, "NYI");
}
virtual bool is_nested_int() const {
return false;
}
virtual SymNode add(const SymNode& other) {
TORCH_CHECK(false, "NYI");
}
virtual SymNode sub(const SymNode& other) {
TORCH_CHECK(false, "NYI");
}
virtual SymNode mul(const SymNode& other) {
TORCH_CHECK(false, "NYI");
}
// NB: legacy, prefer float_truediv or int_truediv
virtual SymNode truediv(const SymNode& other) {
TORCH_CHECK(false, "NYI");
}
virtual SymNode float_truediv(const SymNode& other) {
return truediv(other);
}
virtual SymNode int_truediv(const SymNode& other) {
return truediv(other);
}
// NB: legacy, prefer float_pow or pow_by_natural
virtual SymNode pow(const SymNode& other) {
TORCH_CHECK(false, "NYI");
}
virtual SymNode float_pow(const SymNode& other) {
return pow(other);
}
virtual SymNode pow_by_natural(const SymNode& other) {
return pow(other);
}
// NB: legacy, prefer int_floordiv
virtual SymNode floordiv(const SymNode& other) {
TORCH_CHECK(false, "NYI");
}
virtual SymNode int_floordiv(const SymNode& other) {
return floordiv(other);
}
virtual SymNode mod(const SymNode& other) {
TORCH_CHECK(false, "NYI");
}
virtual SymNode eq(const SymNode& other) {
TORCH_CHECK(false, "NYI");
}
virtual SymNode ne(const SymNode& other) {
TORCH_CHECK(false, "NYI");
}
virtual SymNode gt(const SymNode& other) {
TORCH_CHECK(false, "NYI");
}
virtual SymNode lt(const SymNode& other) {
TORCH_CHECK(false, "NYI");
}
virtual SymNode le(const SymNode& other) {
TORCH_CHECK(false, "NYI");
}
virtual SymNode ge(const SymNode& other) {
TORCH_CHECK(false, "NYI");
}
virtual SymNode ceil() {
TORCH_CHECK(false, "NYI");
}
virtual SymNode floor() {
TORCH_CHECK(false, "NYI");
}
virtual SymNode neg() {
TORCH_CHECK(false, "NYI");
};
virtual SymNode sym_min(const SymNode& other) {
TORCH_CHECK(false, "NYI");
};
virtual SymNode sym_max(const SymNode& other) {
TORCH_CHECK(false, "NYI");
};
virtual SymNode sym_or(const SymNode& other) {
TORCH_CHECK(false, "NYI");
};
virtual SymNode sym_and(const SymNode& other) {
TORCH_CHECK(false, "NYI");
};
virtual SymNode sym_not() {
TORCH_CHECK(false, "NYI");
};
virtual SymNode sym_ite(const SymNode& then_val, const SymNode& else_val) {
TORCH_CHECK(false, "NYI");
};
// NB: self is ignored here, only the arguments are used
virtual SymNode is_contiguous(
ArrayRef<SymNode> sizes,
ArrayRef<SymNode> strides) {
TORCH_CHECK(false, "NYI");
};
virtual SymNode is_channels_last_contiguous_2d(
ArrayRef<SymNode> sizes,
ArrayRef<SymNode> strides) {
TORCH_CHECK(false, "NYI");
};
virtual SymNode is_channels_last_contiguous_3d(
ArrayRef<SymNode> sizes,
ArrayRef<SymNode> strides) {
TORCH_CHECK(false, "NYI");
};
virtual SymNode is_channels_last_strides_2d(
ArrayRef<SymNode> sizes,
ArrayRef<SymNode> strides) {
TORCH_CHECK(false, "NYI");
};
virtual SymNode is_channels_last_strides_3d(
ArrayRef<SymNode> sizes,
ArrayRef<SymNode> strides) {
TORCH_CHECK(false, "NYI");
};
virtual SymNode is_non_overlapping_and_dense(
ArrayRef<SymNode> sizes,
ArrayRef<SymNode> strides) {
TORCH_CHECK(false, "NYI");
};
virtual SymNode clone() {
TORCH_CHECK(false, "NYI");
};
virtual SymNode sym_float() {
TORCH_CHECK(false, "NYI");
}
virtual SymNode wrap_int(int64_t num) {
TORCH_CHECK(false, "NYI");
};
virtual SymNode wrap_float(double num) {
TORCH_CHECK(false, "NYI");
};
virtual SymNode wrap_bool(bool num) {
TORCH_CHECK(false, "NYI");
};
virtual int64_t guard_int(const char* file, int64_t line) {
TORCH_CHECK(false, "NYI");
};
virtual bool guard_bool(const char* file, int64_t line) {
TORCH_CHECK(false, "NYI");
};
virtual double guard_float(const char* file, int64_t line) {
TORCH_CHECK(false, "NYI");
};
virtual bool guard_size_oblivious(const char* file, int64_t line) {
// No improvement for unbacked SymBools by default, replace this
// with a better implementation!
return guard_bool(file, line);
}
virtual bool expect_true(const char* file, int64_t line) {
// No improvement for unbacked SymBools by default, replace this
// with a better implementation!
return guard_bool(file, line);
};
virtual bool expect_size(const char* file, int64_t line) {
// No improvement for unbacked SymInts by default, replace this
// with a better implementation!
return ge(wrap_int(0))->guard_bool(file, line);
};
virtual int64_t int_() {
TORCH_CHECK(false, "NYI");
};
virtual bool bool_() {
TORCH_CHECK(false, "NYI");
};
virtual bool has_hint() {
TORCH_CHECK(false, "NYI");
};
virtual std::string str() {
TORCH_CHECK(false, "NYI");
};
virtual std::string _graph_repr() {
return str();
};
virtual std::optional<int64_t> nested_int() {
return std::nullopt;
}
virtual std::optional<int64_t> nested_int_coeff() {
return std::nullopt;
}
virtual std::optional<int64_t> constant_int() {
return std::nullopt;
}
virtual std::optional<bool> constant_bool() {
return std::nullopt;
}
virtual std::optional<int64_t> maybe_as_int() {
return std::nullopt;
}
virtual bool is_constant() {
return false;
}
virtual bool is_symbolic() {
return true;
}
std::ostream& operator<<(std::ostream& os) {
os << str();
return os;
}
};
} // namespace c10
C10_DIAGNOSTIC_POP()
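A minimal sketch of the subclassing pattern SymNodeImpl expects: override what you support and let everything else fall through to the NYI defaults. `ConstantIntNode` is a hypothetical example, distinct from c10's real ConstantSymNodeImpl.
#include <cstdint>
#include <optional>
#include <string>
#include <c10/core/SymNodeImpl.h>
// Hypothetical node wrapping a known constant integer.
class ConstantIntNode : public c10::SymNodeImpl {
 public:
  explicit ConstantIntNode(int64_t v) : value_(v) {}
  bool is_int() override { return true; }
  bool is_bool() override { return false; }
  bool is_float() override { return false; }
  int64_t int_() override { return value_; }
  std::optional<int64_t> constant_int() override { return value_; }
  std::optional<int64_t> maybe_as_int() override { return value_; }
  bool has_hint() override { return true; }
  std::string str() override { return std::to_string(value_); }
 private:
  int64_t value_;
};
Such a node would typically be created with c10::make_intrusive<ConstantIntNode>(v) and handed to the SymInt(SymNode) constructor.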

View File

@ -0,0 +1,216 @@
#pragma once
#include <c10/core/SymBool.h>
#include <c10/core/SymInt.h>
#include <c10/macros/Export.h>
#include <c10/macros/Macros.h>
#include <c10/util/DimVector.h>
#include <atomic>
#include <cstdint>
#include <mutex>
#include <utility>
namespace c10 {
class C10_API SymbolicShapeMeta {
public:
// Basic metadata from which other quantities are derived
SymDimVector sizes_ = {0};
SymDimVector strides_ = {1};
SymInt storage_offset_ = 0;
bool strides_valid_ = true; // e.g. for sparse where there are no strides
SymbolicShapeMeta() = default;
SymbolicShapeMeta(const SymbolicShapeMeta& other);
SymbolicShapeMeta& operator=(const SymbolicShapeMeta& other) = delete;
SymbolicShapeMeta& operator=(SymbolicShapeMeta&& other) = delete;
void refresh_numel() {
// Non-const, don't need to hold mutables_ lock
available_.fetch_and(~numel_avail);
numel_ = 1;
}
void refresh_contiguous() {
// Non-const, don't need to hold mutables_ lock
available_.fetch_and(numel_avail);
is_contiguous_ = false;
is_channels_last_contiguous_ = false;
is_channels_last_3d_contiguous_ = false;
is_channels_last_ = false;
is_channels_last_3d_ = false;
is_non_overlapping_and_dense_ = false;
}
int64_t dim() const {
return static_cast<int64_t>(sizes_.size());
}
// Accessors for derived quantities, computed lazily on first access
bool has_numel() const {
return available_.load() & numel_avail;
}
bool has_is_contiguous() const {
return available_.load() & is_contiguous_avail;
}
bool has_is_channels_last_contiguous() const {
return available_.load() & is_channels_last_contiguous_avail;
}
bool has_is_channels_last_3d_contiguous() const {
return available_.load() & is_channels_last_3d_contiguous_avail;
}
bool has_is_channels_last() const {
return available_.load() & is_channels_last_avail;
}
bool has_is_channels_last_3d() const {
return available_.load() & is_channels_last_3d_avail;
}
bool has_is_non_overlapping_and_dense() const {
return available_.load() & is_non_overlapping_and_dense_avail;
}
// Accessors to cached derived properties
// DO NOT call with mutables_ lock held
const SymInt& numel() const {
if (C10_UNLIKELY(!has_numel())) {
init_numel();
}
return numel_;
}
const SymBool& is_contiguous() const {
if (C10_UNLIKELY(!has_is_contiguous())) {
init_is_contiguous();
}
return is_contiguous_;
}
const SymBool& is_channels_last_contiguous() const {
if (C10_UNLIKELY(!has_is_channels_last_contiguous())) {
init_is_channels_last_contiguous();
}
return is_channels_last_contiguous_;
}
const SymBool& is_channels_last_3d_contiguous() const {
if (C10_UNLIKELY(!has_is_channels_last_3d_contiguous())) {
init_is_channels_last_3d_contiguous();
}
return is_channels_last_3d_contiguous_;
}
const SymBool& is_channels_last() const {
if (C10_UNLIKELY(!has_is_channels_last())) {
init_is_channels_last();
}
return is_channels_last_;
}
const SymBool& is_channels_last_3d() const {
if (C10_UNLIKELY(!has_is_channels_last_3d())) {
init_is_channels_last_3d();
}
return is_channels_last_3d_;
}
const SymBool& is_non_overlapping_and_dense() const {
if (C10_UNLIKELY(!has_is_non_overlapping_and_dense())) {
init_is_non_overlapping_and_dense();
}
return is_non_overlapping_and_dense_;
}
// Assumptions so we can short-circuit computation
// NOTE: Don't need to lock mutables_ since these aren't const
void assume_contiguous(SymBool val = true) {
is_contiguous_ = std::move(val);
available_.fetch_or(is_contiguous_avail);
}
void assume_channels_last_contiguous(SymBool val = true) {
is_channels_last_contiguous_ = std::move(val);
available_.fetch_or(is_channels_last_contiguous_avail);
}
void assume_channels_last_3d_contiguous(SymBool val = true) {
is_channels_last_3d_contiguous_ = std::move(val);
available_.fetch_or(is_channels_last_3d_contiguous_avail);
}
void assume_channels_last(SymBool val = true) {
is_channels_last_ = std::move(val);
available_.fetch_or(is_channels_last_avail);
}
void assume_channels_last_3d(SymBool val = true) {
is_channels_last_3d_ = std::move(val);
available_.fetch_or(is_channels_last_3d_avail);
}
void assume_non_overlapping_and_dense(SymBool val = true) {
is_non_overlapping_and_dense_ = std::move(val);
available_.fetch_or(is_non_overlapping_and_dense_avail);
}
private:
SymBool compute_contiguous() const;
SymBool compute_channels_last_contiguous_2d() const;
SymBool compute_channels_last_contiguous_3d() const;
SymBool compute_strides_like_channels_last_2d() const;
SymBool compute_strides_like_channels_last_3d() const;
SymBool compute_non_overlapping_and_dense() const;
// These are little wrappers over the real compute_ functions that
// can make use of other contiguity fields to short circuit.
// They need to be implemented separately for SymBool, as SymBool does
// not short circuit.
// TODO: should the SymBool cases avoid the short circuit? Need to reason
// if it's correct, and reason if the simpler expressions are better for
// analysis (maybe not!)
SymBool compute_channels_last_contiguous_3d_dim5() const;
SymBool compute_channels_last_2d_dim5() const;
SymBool compute_channels_last_3d_dim5() const;
SymBool compute_is_non_overlapping_and_dense_dim4() const;
SymBool compute_is_non_overlapping_and_dense_dim5() const;
SymBool compute_is_non_overlapping_and_dense_anydim() const;
void init_numel() const;
void init_is_contiguous() const;
void init_is_channels_last_contiguous() const;
void init_is_channels_last_3d_contiguous() const;
void init_is_channels_last() const;
void init_is_channels_last_3d() const;
void init_is_non_overlapping_and_dense() const;
// NOTE: These only set if !has_foo()
void set_numel(SymInt val) const;
void set_is_contiguous(SymBool val) const;
void set_is_channels_last_contiguous(SymBool val) const;
void set_is_channels_last_3d_contiguous(SymBool val) const;
void set_is_channels_last(SymBool val) const;
void set_is_channels_last_3d(SymBool val) const;
void set_is_non_overlapping_and_dense(SymBool val) const;
// Lazily initialized variables, with the corresponding available_ flag
// indicating whether the value has been initialized
mutable std::atomic<int> available_{0};
enum avail {
numel_avail = 1 << 0,
is_contiguous_avail = 1 << 1,
is_channels_last_contiguous_avail = 1 << 2,
is_channels_last_3d_contiguous_avail = 1 << 3,
is_channels_last_avail = 1 << 4,
is_channels_last_3d_avail = 1 << 5,
is_non_overlapping_and_dense_avail = 1 << 6,
};
// Mutex to prevent races when initializing the variable from const accessors
mutable std::mutex mutables_;
mutable SymInt numel_ = 1;
mutable SymBool is_contiguous_{true};
mutable SymBool is_channels_last_contiguous_{false};
mutable SymBool is_channels_last_3d_contiguous_{false};
mutable SymBool is_channels_last_{false};
mutable SymBool is_channels_last_3d_{false};
mutable SymBool is_non_overlapping_and_dense_{true};
};
} // namespace c10
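A sketch of how the lazy flags above are consumed (illustrative; the real consumer is TensorImpl, and the include path is assumed to be c10/core/SymbolicShapeMeta.h):
#include <c10/core/SymbolicShapeMeta.h>
void query_shape_meta(c10::SymbolicShapeMeta& meta) {
  // First access computes and caches; later accesses are cheap flag checks.
  const c10::SymInt& n = meta.numel();
  (void)n;
  // A caller that constructed the tensor contiguously can seed the flag and
  // skip the symbolic computation entirely.
  meta.assume_contiguous();
}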

File diff suppressed because it is too large

View File

@ -0,0 +1,787 @@
#pragma once
#include <c10/core/Backend.h>
#include <c10/core/DefaultDtype.h>
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/DispatchKey.h>
#include <c10/core/Layout.h>
#include <c10/core/MemoryFormat.h>
#include <c10/core/ScalarType.h>
#include <c10/core/ScalarTypeToTypeMeta.h>
#include <c10/macros/Export.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <optional>
#include <cstdint>
#include <iosfwd>
#include <string>
#include <type_traits>
#include <utility>
namespace c10 {
DispatchKey computeDispatchKey(
std::optional<ScalarType> dtype,
std::optional<Layout> layout,
std::optional<Device> device);
inline ScalarType dtype_or_default(std::optional<ScalarType> dtype) {
return value_or_else(dtype, [] { return get_default_dtype_as_scalartype(); });
}
inline caffe2::TypeMeta dtype_or_default(
std::optional<caffe2::TypeMeta> dtype) {
return value_or_else(dtype, [] { return get_default_dtype(); });
}
inline Layout layout_or_default(std::optional<Layout> layout) {
return layout.value_or(kStrided);
}
inline Device device_or_default(std::optional<Device> device) {
return value_or_else(device, [] { return Device(kCPU); });
}
inline bool pinned_memory_or_default(std::optional<bool> pinned_memory) {
return pinned_memory.value_or(false);
}
/// A class to encapsulate construction axes of a Tensor. TensorOptions was
/// designed to support the Python style API for specifying construction options
/// on factory functions, e.g.,
///
/// torch.zeros(2, 3, dtype=torch.int32)
///
/// Because C++ doesn't natively support keyword arguments, there must be
/// another way of specifying keyword-like arguments. TensorOptions is a
/// builder class which can be used to construct this "dictionary" of keyword
/// arguments: functions which support TensorOptions conventionally take this
/// argument optionally as their last argument.
///
/// WARNING: In PyTorch, there are `torch::` variants of factory functions,
/// e.g., torch::zeros for at::zeros. These return Variables (while the
/// stock ATen functions return plain Tensors). If you mix these functions
/// up, you WILL BE SAD.
///
/// Rather than use the constructor of this class directly, you should prefer to
/// use the constructor functions, and then chain setter methods on top of them.
///
/// at::device(at::kCUDA).dtype(kInt)
/// at::dtype(at::kInt)
///
/// Additionally, anywhere a TensorOptions is expected, you can directly
/// pass at::kCUDA / at::kInt, and it will implicitly convert to a
/// TensorOptions.
///
/// Here are some recommended ways to create a 2x2 tensor of zeros
/// with certain properties. These all *implicitly* make use of
/// TensorOptions, even if they don't mention the class explicitly:
///
/// at::zeros({2,2}, at::kCUDA);
/// at::zeros({2,2}, at::kLong);
/// at::zeros({2,2}, at::device(at::kCUDA).dtype(at::kLong));
/// at::zeros({2,2}, at::device({at::kCUDA, 1})); // place on device 1
/// at::zeros({2,2}, at::requires_grad());
///
/// NOTE [ TensorOptions Constructors ]
///
/// TensorOptions is like a dictionary with entries from the set:
/// {requires_grad, device, dtype, layout}, where each entry may be
/// unspecified (i.e., is optional). It is used to specify the properties of
/// tensors in many places both in C++ internal and API, e.g., tensor factory
/// methods like `at::empty({10}, options)`, tensor conversions like
/// `tensor.to(...)`, etc.
///
/// To provide a simple API that is consistent with Python, where one can do
/// `torch.empty(sizes, X)` with `X` being a `torch.device`, `torch.dtype`, or a
/// `torch.layout`, we want TensorOptions to be implicitly convertible from
/// `ScalarType dtype`, `Layout layout` and `Device device`. Therefore, we have
/// three implicit constructors from each of these three types.
///
/// This is sufficient for `ScalarType` and `Layout` as they are simple Enum
/// classes. However, `Device` is an ordinary class with implicit constructors
/// `Device(DeviceType, DeviceIndex = -1)` and `Device(std::string)` to be
/// consistent with Python API, where strings are treated as equivalent with a
/// `torch.device` object (e.g., "cuda:1" can be passed to everywhere a
/// `torch.device("cuda:1")` is accepted). To support the syntax
/// `at::empty({10}, {kCUDA, 1})` and `tensor.to(kCUDA)`, we need to make sure
/// that `TensorOptions` is implicitly constructible with any arguments that a
/// `Device` can be constructed from. So we have,
///
/// /* implicit */ TensorOptions(T&& device) : TensorOptions() {
/// this->set_device(device);
/// }
///
/// template <typename... Args,
/// typename = std::enable_if_t<std::is_constructible<Device,
/// Args&&...>::value>>
/// /* implicit */ TensorOptions(Args&&... args)
/// : TensorOptions(Device(std::forward<Args>(args)...)) {}
///
///
/// But this will be problematic. Consider this: `TensorOptions({kCUDA, 1})`.
/// Compiler will complain about ambiguity between the copy constructor and the
/// `Device` constructor because `{kCUDA, 1}` can be converted to both a
/// `TensorOptions` and a `Device`.
///
/// To get around this, we templatize the `Device` constructor. Since overload
/// resolution is done before template resolution, our problem is solved.
DispatchKey computeDispatchKey(
std::optional<ScalarType> dtype,
std::optional<Layout> layout,
std::optional<Device> device);
struct C10_API TensorOptions {
TensorOptions()
: requires_grad_(false),
pinned_memory_(false),
has_device_(false),
has_dtype_(false),
has_layout_(false),
has_requires_grad_(false),
has_pinned_memory_(false),
has_memory_format_(false) {}
/// Constructs a `TensorOptions` object with the given layout.
/* implicit */ TensorOptions(Layout layout) : TensorOptions() {
this->set_layout(layout);
}
/// Constructs a `TensorOptions` object with the given device.
/// See NOTE [ TensorOptions Constructors ] on why this is templatized.
template <
typename T,
typename = std::enable_if_t<std::is_same_v<std::decay_t<T>, Device>>>
/* implicit */ TensorOptions(T&& device) : TensorOptions() {
this->set_device(std::forward<T>(device));
}
/// Constructs a `TensorOptions` object from arguments allowed in `Device`
/// constructors.
///
/// See NOTE [ TensorOptions Constructors ].
///
/// NB: Ideally we only allow implicit constructors here. But there is no easy
/// way to detect them. So we have this one that allows explicit
/// constructors too.
template <
typename... Args,
typename = std::enable_if_t<std::is_constructible_v<Device, Args&&...>>>
/* implicit */ TensorOptions(Args&&... args)
: TensorOptions(Device(std::forward<Args>(args)...)) {}
/// Constructs a `TensorOptions` object with the given dtype.
/* implicit */ TensorOptions(caffe2::TypeMeta dtype) : TensorOptions() {
this->set_dtype(dtype);
}
/// legacy constructor to support ScalarType
/* implicit */ TensorOptions(ScalarType dtype) : TensorOptions() {
this->set_dtype(dtype);
}
/// Constructs a `TensorOptions` object with the given memory format.
/* implicit */ TensorOptions(MemoryFormat memory_format) : TensorOptions() {
set_memory_format(memory_format);
}
/// Return a copy of `TensorOptions` with `device` set to the given one, or
/// cleared if `device` is `nullopt`.
C10_NODISCARD TensorOptions
device(std::optional<Device> device) const noexcept {
TensorOptions r = *this;
r.set_device(device);
return r;
}
/// Return a copy of `TensorOptions` with `device` set to the given one.
/// (This overload ensures that variadic template std::optional constructor
/// for Device work correctly.)
template <typename... Args>
C10_NODISCARD TensorOptions device(Args&&... args) const noexcept {
return device(
std::optional<Device>(std::in_place, std::forward<Args>(args)...));
}
/// Return a copy of `TensorOptions`, but with device set to CUDA, and the
/// device index set to the given one.
///
/// TODO: This function encourages bad behavior (assuming CUDA is
/// the only device that matters). Get rid of it / rename it.
C10_NODISCARD TensorOptions
device_index(c10::DeviceIndex device_index) const noexcept {
return device(Device::Type::CUDA, device_index);
}
/// Return a copy of `TensorOptions` with `dtype` set to the given one.
C10_NODISCARD TensorOptions
dtype(std::optional<caffe2::TypeMeta> dtype) const noexcept {
TensorOptions r = *this;
r.set_dtype(dtype);
return r;
}
// legacy function to support ScalarType
C10_NODISCARD TensorOptions
dtype(std::optional<ScalarType> dtype) const noexcept {
TensorOptions r = *this;
r.set_dtype(dtype);
return r;
}
// Since dtype is taken...
template <typename T>
TensorOptions& dtype() {
dtype_ = caffe2::TypeMeta::Make<T>();
has_dtype_ = true;
return *this;
}
/// Sets the layout of the `TensorOptions`.
C10_NODISCARD TensorOptions
layout(std::optional<Layout> layout) const noexcept {
TensorOptions r = *this;
r.set_layout(layout);
return r;
}
/// Sets the `requires_grad` property of the `TensorOptions`.
C10_NODISCARD TensorOptions
requires_grad(std::optional<bool> requires_grad) const noexcept {
TensorOptions r = *this;
r.set_requires_grad(requires_grad);
return r;
}
/// Sets the `pinned_memory` property on the `TensorOptions`.
C10_NODISCARD TensorOptions
pinned_memory(std::optional<bool> pinned_memory) const noexcept {
TensorOptions r = *this;
r.set_pinned_memory(pinned_memory);
return r;
}
/// Sets the `memory_format` property on `TensorOptions`.
C10_NODISCARD TensorOptions
memory_format(std::optional<MemoryFormat> memory_format) const noexcept {
TensorOptions r = *this;
r.set_memory_format(memory_format);
return r;
}
/// Returns the device of the `TensorOptions`.
Device device() const noexcept {
return device_or_default(device_opt());
}
/// Returns whether the device is specified.
bool has_device() const noexcept {
return has_device_;
}
/// Returns the device of the `TensorOptions`, or `std::nullopt` if
/// device is not specified.
std::optional<Device> device_opt() const noexcept {
return has_device_ ? std::make_optional(device_) : std::nullopt;
}
/// Returns the device index of the `TensorOptions`.
c10::DeviceIndex device_index() const noexcept {
return device().index();
}
/// Returns the dtype of the `TensorOptions`.
caffe2::TypeMeta dtype() const noexcept {
return dtype_or_default(dtype_opt());
}
/// Returns whether the dtype is specified.
bool has_dtype() const noexcept {
return has_dtype_;
}
/// Returns the dtype of the `TensorOptions`, or `std::nullopt` if
/// device is not specified.
std::optional<caffe2::TypeMeta> dtype_opt() const noexcept {
return has_dtype_ ? std::make_optional(dtype_) : std::nullopt;
}
/// Returns the layout of the `TensorOptions`.
Layout layout() const noexcept {
return layout_or_default(layout_opt());
}
/// Returns whether the layout is specified.
bool has_layout() const noexcept {
return has_layout_;
}
/// Returns the layout of the `TensorOptions`, or `std::nullopt` if
/// layout is not specified.
std::optional<Layout> layout_opt() const noexcept {
return has_layout_ ? std::make_optional(layout_) : std::nullopt;
}
/// Returns the `requires_grad` property of the `TensorOptions`.
bool requires_grad() const noexcept {
return has_requires_grad_ ? requires_grad_ : false;
}
/// Returns whether the `requires_grad` is specified.
bool has_requires_grad() const noexcept {
return has_requires_grad_;
}
/// Returns the `requires_grad` property of the `TensorOptions`, or
/// `std::nullopt` if `requires_grad` is not specified.
std::optional<bool> requires_grad_opt() const noexcept {
return has_requires_grad_ ? std::make_optional(requires_grad_)
: std::nullopt;
}
/// Returns the `pinned_memory` property of the `TensorOptions`.
bool pinned_memory() const noexcept {
return pinned_memory_or_default(pinned_memory_opt());
}
/// Returns whether the `pinned_memory` is specified.
bool has_pinned_memory() const noexcept {
return has_pinned_memory_;
}
/// Returns if the layout is sparse
bool is_sparse() const {
return layout_ == c10::Layout::Sparse;
}
/// Returns if the layout is sparse CSR, deprecated, use
/// is_sparse_compressed() instead
bool is_sparse_csr() const {
return layout_ == c10::Layout::SparseCsr;
}
bool is_sparse_compressed() const {
return layout_ == c10::Layout::SparseCsr ||
layout_ == c10::Layout::SparseCsc ||
layout_ == c10::Layout::SparseBsr || layout_ == c10::Layout::SparseBsc;
}
// For compatibility with legacy tensor.type() comparisons
bool type_equal(const TensorOptions& other) const {
return computeDispatchKey() == other.computeDispatchKey() &&
typeMetaToScalarType(dtype_) == typeMetaToScalarType(other.dtype());
}
/// Returns the `pinned_memory` property of the `TensorOptions`, or
/// `std::nullopt` if `pinned_memory` is not specified.
std::optional<bool> pinned_memory_opt() const noexcept {
return has_pinned_memory_ ? std::make_optional(pinned_memory_)
: std::nullopt;
}
/// Returns whether the `memory_format` is specified
bool has_memory_format() const noexcept {
return has_memory_format_;
}
// NB: memory_format() getter is PURPOSELY not defined, as the default
// behavior of memory_format varies from function to function.
/// Returns the `memory_format` property of `TensorOptions`, or
/// `std::nullopt` if `memory_format` is not specified.
std::optional<MemoryFormat> memory_format_opt() const noexcept {
return has_memory_format_ ? std::make_optional(memory_format_)
: std::nullopt;
}
// Resolves the ATen backend specified by the current construction axes.
// TODO: Deprecate this
Backend backend() const {
return at::dispatchKeyToBackend(computeDispatchKey());
}
/// Return the right-biased merge of two TensorOptions. This has the
/// effect of overwriting settings from self with specified options
/// of options.
///
/// NB: This merging operation does NOT respect device merges.
/// For example, if you device({kCUDA, 1}).merge_in(kCUDA)
/// you will get kCUDA in the end! Functions like Tensor.new_empty
/// ensure the right device is selected anyway by way of a
/// device guard.
///
TensorOptions merge_in(TensorOptions options) const noexcept {
TensorOptions merged = *this;
if (options.has_device())
merged.set_device(options.device_opt());
if (options.has_dtype())
merged.set_dtype(options.dtype_opt());
if (options.has_layout())
merged.set_layout(options.layout_opt());
// NB: requires grad is right biased; not a logical AND/OR!
if (options.has_requires_grad())
merged.set_requires_grad(options.requires_grad_opt());
if (options.has_pinned_memory())
merged.set_pinned_memory(options.pinned_memory_opt());
if (options.has_memory_format())
merged.set_memory_format(options.memory_format_opt());
return merged;
}
// TODO remove after TensorOptions rationalization
TensorOptions merge_memory_format(
std::optional<MemoryFormat> optional_memory_format) const noexcept {
TensorOptions merged = *this;
if (optional_memory_format.has_value()) {
merged.set_memory_format(*optional_memory_format);
}
return merged;
}
// INVARIANT: computeDispatchKey returns only the subset of dispatch keys for
// which dispatchKeyToBackend is injective, if it is defined at all (for
// the most part, this just means that this function never returns an
// Autograd key)
DispatchKey computeDispatchKey() const {
return c10::computeDispatchKey(
optTypeMetaToScalarType(dtype_opt()), layout_opt(), device_opt());
}
private:
// These methods are currently private because I'm not sure if it's wise
// to actually publish them. They are methods because I need them in
// the constructor and the functional API implementation.
//
// If you really, really need it, you can make these public, but check if you
// couldn't just do what you need with the functional API. Similarly, these
// methods are not chainable, because if you wanted chaining, you probably
// want to use the functional API instead. (It's probably OK to make
// these chainable, because these functions are all explicitly annotated
// with a ref-qualifier, the trailing &, that makes them illegal to call
// on temporaries.)
/// Mutably set the device of `TensorOptions`.
void set_device(std::optional<Device> device) & noexcept {
if (device) {
device_ = *device;
has_device_ = true;
} else {
has_device_ = false;
}
}
/// Mutably set the dtype of `TensorOptions`.
void set_dtype(std::optional<caffe2::TypeMeta> dtype) & noexcept {
if (dtype) {
dtype_ = *dtype;
has_dtype_ = true;
} else {
has_dtype_ = false;
}
}
// legacy function to support ScalarType
void set_dtype(std::optional<ScalarType> dtype) & noexcept {
if (dtype) {
dtype_ = scalarTypeToTypeMeta(*dtype);
has_dtype_ = true;
} else {
has_dtype_ = false;
}
}
/// Mutably set the layout of `TensorOptions`.
void set_layout(std::optional<Layout> layout) & noexcept {
if (layout) {
layout_ = *layout;
has_layout_ = true;
} else {
has_layout_ = false;
}
}
/// Mutably set the `requires_grad` property of `TensorOptions`.
void set_requires_grad(std::optional<bool> requires_grad) & noexcept {
if (requires_grad) {
requires_grad_ = *requires_grad;
has_requires_grad_ = true;
} else {
has_requires_grad_ = false;
}
}
/// Mutably set the `pinned_memory` property of `TensorOptions`.
void set_pinned_memory(std::optional<bool> pinned_memory) & noexcept {
if (pinned_memory) {
pinned_memory_ = *pinned_memory;
has_pinned_memory_ = true;
} else {
has_pinned_memory_ = false;
}
}
/// Mutably set the `memory_Format` property of `TensorOptions`.
void set_memory_format(std::optional<MemoryFormat> memory_format) & noexcept {
if (memory_format) {
memory_format_ = *memory_format;
has_memory_format_ = true;
} else {
has_memory_format_ = false;
}
}
// WARNING: If you edit TensorOptions to add more options, you
// may need to adjust the implementation of Tensor::options.
// The criterion for whether or not Tensor::options must be adjusted
// is whether or not the new option you added should be preserved
// by functions such as empty_like(); if it should be preserved,
// you must adjust options().
//
// TODO: MemoryFormat is not implemented in this way
// NB: We didn't use std::optional here, because then we can't pack
// the has_***_ boolean fields.
Device device_ = at::kCPU; // 16-bit
caffe2::TypeMeta dtype_ = caffe2::TypeMeta::Make<float>(); // 16-bit
Layout layout_ = at::kStrided; // 8-bit
MemoryFormat memory_format_ = MemoryFormat::Contiguous; // 8-bit
// Bitmask required here to get this to fit inside 32 bits (or even 64 bits,
// for that matter)
bool requires_grad_ : 1;
bool pinned_memory_ : 1;
bool has_device_ : 1;
bool has_dtype_ : 1;
bool has_layout_ : 1;
bool has_requires_grad_ : 1;
bool has_pinned_memory_ : 1;
bool has_memory_format_ : 1;
};
// We should aspire to fit in one machine-size word; but a size greater than two
// words is too much. (We are doing terribly on 32-bit archs, where we require
// three machine size words to store tensor options. Eek!)
static_assert(
sizeof(TensorOptions) <= sizeof(int64_t) * 2,
"TensorOptions must fit in 128-bits");
/// Convenience function that returns a `TensorOptions` object with the `dtype`
/// set to the given one.
inline TensorOptions dtype(caffe2::TypeMeta dtype) {
return TensorOptions().dtype(dtype);
}
// legacy function to support ScalarType
inline TensorOptions dtype(ScalarType dtype) {
return TensorOptions().dtype(scalarTypeToTypeMeta(dtype));
}
/// Convenience function that returns a `TensorOptions` object with the `layout`
/// set to the given one.
inline TensorOptions layout(Layout layout) {
return TensorOptions().layout(layout);
}
/// Convenience function that returns a `TensorOptions` object with the `device`
/// set to the given one.
inline TensorOptions device(Device device) {
return TensorOptions().device(device);
}
/// Convenience function that returns a `TensorOptions` object with the
/// `device` set to CUDA and the `device_index` set to the given one.
inline TensorOptions device_index(c10::DeviceIndex device_index) {
return TensorOptions().device_index(device_index);
}
/// Convenience function that returns a `TensorOptions` object with the
/// `requires_grad` set to the given one.
inline TensorOptions requires_grad(bool requires_grad = true) {
return TensorOptions().requires_grad(requires_grad);
}
/// Convenience function that returns a `TensorOptions` object with the
/// `memory_format` set to the given one.
inline TensorOptions memory_format(MemoryFormat memory_format) {
return TensorOptions().memory_format(memory_format);
}
C10_API std::ostream& operator<<(
std::ostream& stream,
const TensorOptions& options);
template <typename T>
inline TensorOptions dtype() {
return dtype(caffe2::TypeMeta::Make<T>());
}
inline std::string toString(const TensorOptions& options) {
std::ostringstream stream;
stream << options;
return stream.str();
}
// This is intended to be a centralized location by which we can determine
// what an appropriate DispatchKey for a tensor is.
inline DispatchKey computeDispatchKey(
std::optional<ScalarType> dtype,
std::optional<Layout> layout,
std::optional<Device> device) {
const auto layout_ = layout_or_default(layout);
const auto device_ = device_or_default(device);
switch (layout_) {
case Layout::Jagged:
case Layout::Strided: {
const auto dtype_ = dtype_or_default(dtype);
switch (device_.type()) {
#define DO_CASE(device, _) \
case c10::DeviceType::device: { \
if (isQIntType(dtype_)) { \
return DispatchKey::Quantized##device; \
} \
return DispatchKey::device; \
}
C10_FORALL_BACKEND_DEVICE_TYPES(DO_CASE, unused)
#undef DO_CASE
case c10::DeviceType::FPGA:
return DispatchKey::FPGA;
case c10::DeviceType::MAIA:
return DispatchKey::MAIA;
case c10::DeviceType::Vulkan:
return DispatchKey::Vulkan;
case c10::DeviceType::Metal:
return DispatchKey::Metal;
case c10::DeviceType::MKLDNN:
case c10::DeviceType::OPENGL:
case c10::DeviceType::OPENCL:
case c10::DeviceType::IDEEP:
TORCH_INTERNAL_ASSERT(
0,
"This is a grandfathered Caffe2 device type ",
device_.type(),
", it shouldn't ever convert to a DispatchKey. File a bug describing what you were doing if you think this is in error.");
default:
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"Unsupported device type for dense layout: ",
device_.type());
}
}
case Layout::Sparse:
switch (device_.type()) {
#define DO_CASE(device, _) \
case c10::DeviceType::device: { \
return DispatchKey::Sparse##device; \
}
C10_FORALL_BACKEND_DEVICE_TYPES(DO_CASE, unused)
#undef DO_CASE
default:
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"Unsupported device type for sparse layout: ",
device_.type());
}
case Layout::Mkldnn:
switch (device_.type()) {
case c10::DeviceType::CPU:
return DispatchKey::MkldnnCPU;
default:
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"Unsupported device type for mkldnn layout: ",
device_.type());
}
case Layout::SparseCsr:
case Layout::SparseCsc:
case Layout::SparseBsr:
case Layout::SparseBsc:
switch (device_.type()) {
#define DO_CASE(device, _) \
case c10::DeviceType::device: { \
return DispatchKey::SparseCsr##device; \
}
C10_FORALL_BACKEND_DEVICE_TYPES(DO_CASE, unused)
#undef DO_CASE
default:
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"Unsupported device type for ",
layout_,
" layout: ",
device_.type());
}
default:
TORCH_CHECK(false, "Unsupported layout: ", layout_);
}
}
inline Layout dispatchKeyToLayout(DispatchKey dispatch_key) {
switch (dispatch_key) {
#define DO_CASE(bc, _) case DispatchKey::Sparse##bc:
C10_FORALL_BACKEND_COMPONENTS(DO_CASE, unused)
#undef DO_CASE
return Layout::Sparse;
#define DO_CASE(bc, _) case DispatchKey::SparseCsr##bc:
C10_FORALL_BACKEND_COMPONENTS(DO_CASE, unused)
#undef DO_CASE
TORCH_CHECK(
false, "Cannot map DispatchKey ", dispatch_key, " to a unique layout.");
case DispatchKey::MkldnnCPU:
return Layout::Mkldnn;
default:
return Layout::Strided;
}
}
inline c10::DeviceType dispatchKeyToDeviceType(DispatchKey dispatch_key) {
switch (dispatch_key) {
// stuff that's real
#define DO_CASE(suffix, prefix) \
case DispatchKey::prefix##suffix: \
return c10::DeviceType::suffix;
#define DO_CASES(_, prefix) C10_FORALL_BACKEND_DEVICE_TYPES(DO_CASE, prefix)
C10_FORALL_FUNCTIONALITY_KEYS(DO_CASES)
#undef DO_CASES
#undef DO_CASE
case DispatchKey::MkldnnCPU:
return c10::DeviceType::CPU;
case DispatchKey::Vulkan:
return c10::DeviceType::Vulkan;
case DispatchKey::MAIA:
return c10::DeviceType::MAIA;
default:
TORCH_CHECK(
false,
"DispatchKey ",
dispatch_key,
" doesn't correspond to a device");
}
}
inline TensorOptions dispatchKeyToTensorOptions(DispatchKey dispatch_key) {
return TensorOptions()
.layout(dispatchKeyToLayout(dispatch_key))
.device(dispatchKeyToDeviceType(dispatch_key));
}
namespace detail {
inline bool backend_supports_empty_operator(const TensorOptions& options) {
// Quantized backends don't support at::empty().
// They have separate operators like at::empty_quantized() that take in
// extra information about how to quantize the tensor.
return !isQIntType(typeMetaToScalarType(options.dtype()));
}
} // namespace detail
} // namespace c10
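A short sketch of the builder and right-biased merge pattern described above (illustrative only; `example_options` is a hypothetical helper):
#include <c10/core/TensorOptions.h>
c10::TensorOptions example_options() {
  // Chain setters; unspecified fields stay unset and fall back to defaults.
  auto base = c10::TensorOptions().dtype(c10::kFloat).device(c10::kCPU);
  auto extra = c10::TensorOptions().layout(c10::kStrided).requires_grad(true);
  // Right-biased merge: fields set in `extra` win, the rest come from `base`.
  return base.merge_in(extra);
}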

View File

@ -0,0 +1,49 @@
#pragma once
#include <c10/core/MemoryFormat.h>
#include <c10/core/SymIntArrayRef.h>
#include <c10/core/TensorImpl.h>
#include <c10/macros/Export.h>
#include <c10/util/ArrayRef.h>
#include <cstdint>
namespace c10 {
struct C10_API UndefinedTensorImpl final : public TensorImpl {
public:
// Without this, we get:
// error: identifier "at::UndefinedTensorImpl::_singleton" is undefined in
// device code
// (ostensibly because the constexpr tricks MSVC into trying to compile this
// function for device as well).
#ifdef _WIN32
static inline TensorImpl* singleton() {
return &getInstance();
}
#else
static constexpr inline TensorImpl* singleton() {
return &_singleton;
}
#endif
#ifdef DEBUG
bool has_storage() const override;
#endif
void set_storage_offset(int64_t offset) override;
protected:
bool is_contiguous_custom(MemoryFormat format) const override;
IntArrayRef strides_custom() const override;
SymIntArrayRef sym_strides_custom() const override;
private:
UndefinedTensorImpl();
#ifdef _WIN32
static UndefinedTensorImpl& getInstance();
#else
static UndefinedTensorImpl _singleton;
#endif
const char* tensorimpl_type_name() const override;
};
} // namespace c10

View File

@ -0,0 +1,48 @@
#pragma once
#include <c10/core/SymInt.h>
#include <c10/macros/Export.h>
#include <c10/macros/Macros.h>
#include <cstdint>
#include <utility>
namespace c10 {
namespace detail {
// This template can only be specialized at int64_t and c10::SymInt;
// you'll get linker errors otherwise
template <typename T>
C10_API T maybe_wrap_dim_slow(T dim, T dim_post_expr, bool wrap_scalar);
} // namespace detail
template <typename T>
T _maybe_wrap_dim(T dim, T dim_post_expr, bool wrap_scalar = true) {
// Inline the fast paths
if (C10_LIKELY(dim_post_expr * -1 <= dim && dim < dim_post_expr)) {
// For SymInts, we want an explicit control flow to trigger a guard, so we
// may as well branch too.
if (dim < 0) {
return dim + dim_post_expr;
}
return dim;
}
// Check edge-cases out-of-line (wrapping scalars and out-of-bounds errors)
return c10::detail::maybe_wrap_dim_slow<T>(
std::move(dim), std::move(dim_post_expr), wrap_scalar);
}
inline int64_t maybe_wrap_dim(
int64_t dim,
int64_t dim_post_expr,
bool wrap_scalar = true) {
return _maybe_wrap_dim(dim, dim_post_expr, wrap_scalar);
}
inline c10::SymInt maybe_wrap_dim(
c10::SymInt dim,
c10::SymInt dim_post_expr,
bool wrap_scalar = true) {
return _maybe_wrap_dim(std::move(dim), std::move(dim_post_expr), wrap_scalar);
}
} // namespace c10
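For example (a sketch, assuming the usual c10/core/WrapDimMinimal.h include path), wrapping a Python-style negative dimension for a 4-d tensor:
#include <cstdint>
#include <c10/core/WrapDimMinimal.h>
int64_t last_dim_of_4d() {
  // dim = -1 with 4 dimensions wraps to 3; out-of-range dims error out of line.
  return c10::maybe_wrap_dim(/*dim=*/-1, /*dim_post_expr=*/4);
}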

View File

@ -0,0 +1,21 @@
#pragma once
#include <cstddef>
namespace c10 {
#ifdef C10_MOBILE
// Use 16-byte alignment on mobile
// - ARM NEON AArch32 and AArch64
// - x86[-64] < AVX
constexpr size_t gAlignment = 16;
#else
// 64-byte alignment should be enough for computation up to AVX512.
constexpr size_t gAlignment = 64;
#endif
constexpr size_t gPagesize = 4096;
// Since the default THP (transparent huge page) size is 2MB, enable THP only
// for buffers of size 2MB or larger to avoid memory bloat
constexpr size_t gAlloc_threshold_thp = static_cast<size_t>(2) * 1024 * 1024;
} // namespace c10
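A trivial sketch of how these constants are typically consumed (illustrative; `is_default_aligned` is hypothetical and the include path is assumed):
#include <cstdint>
#include <c10/core/alignment.h>
// Hypothetical check: does an allocation satisfy the default c10 alignment?
inline bool is_default_aligned(const void* p) {
  return reinterpret_cast<std::uintptr_t>(p) % c10::gAlignment == 0;
}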

View File

@ -0,0 +1,32 @@
#pragma once
#include <c10/macros/Macros.h>
#include <c10/util/intrusive_ptr.h>
namespace c10 {
struct StorageImpl;
class DataPtr;
} // namespace c10
namespace c10::impl::cow {
// Creates a Copy-on-write (COW) clone of the given storage. This will also
// convert the given storage into a COW storage if it is not COW already.
//
// Converting the storage into a COW storage will not be successful if the
// storage's DataPtr has some context (`DataPtr::get_context()`) which is not
// equal to the data pointer (`DataPtr::get()`). In this case, a nullptr is
// returned.
C10_API c10::intrusive_ptr<StorageImpl> lazy_clone_storage(
StorageImpl& storage);
// Check if a storage has a simple DataPtr with no abnormal context
C10_API bool has_simple_data_ptr(const c10::StorageImpl& storage);
// Check if a DataPtr is COW
C10_API bool is_cow_data_ptr(const c10::DataPtr& data_ptr);
// Eagerly copies a COW storage's data, turning it into a non-COW storage.
C10_API void materialize_cow_storage(StorageImpl& storage);
} // namespace c10::impl::cow
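A sketch of the intended calling convention for lazy_clone_storage (illustrative; `try_lazy_clone` is hypothetical and the include paths are assumed):
#include <c10/core/StorageImpl.h>
#include <c10/core/impl/COW.h>
#include <c10/util/intrusive_ptr.h>
// Hypothetical: attempt a cheap COW clone; the caller handles the failure case.
c10::intrusive_ptr<c10::StorageImpl> try_lazy_clone(c10::StorageImpl& storage) {
  auto cloned = c10::impl::cow::lazy_clone_storage(storage);
  // A nullptr result means the storage's DataPtr carried a non-trivial context
  // and could not be converted to COW; fall back to an eager copy upstream.
  return cloned;
}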

View File

@ -0,0 +1,66 @@
#pragma once
#include <c10/macros/Export.h>
#include <c10/util/UniqueVoidPtr.h>
#include <atomic>
#include <cstdint>
#include <memory>
#include <shared_mutex>
#include <variant>
namespace c10::impl::cow {
// A COWDeleterContext object is used as the `ctx` argument for DataPtr
// to implement a Copy-on-write (COW) DataPtr.
class C10_API COWDeleterContext {
public:
// Creates an instance, holding the pair of data and original
// deleter.
//
// Note that the deleter will only be called in our destructor if
// the last reference to this goes away without getting
// materialized.
explicit COWDeleterContext(std::unique_ptr<void, DeleterFnPtr> data);
// Increments the current refcount.
void increment_refcount();
// See README.md in this directory to understand the locking
// strategy.
// Represents a reference to the context.
//
// This is returned by decrement_refcount to allow the caller to
// copy the data under the shared lock.
using NotLastReference = std::shared_lock<std::shared_mutex>;
// Represents the last reference to the context.
//
// This will be returned by decrement_refcount when it is the last
// reference remaining and after any pending copies have completed.
using LastReference = std::unique_ptr<void, DeleterFnPtr>;
// Decrements the refcount, returning a handle indicating what to
// do with it.
std::variant<NotLastReference, LastReference> decrement_refcount();
private:
// The destructor is hidden; this should only ever be used within
// UniqueVoidPtr, using cow::cow_deleter as the deleter.
~COWDeleterContext();
std::shared_mutex mutex_;
std::unique_ptr<void, DeleterFnPtr> data_;
std::atomic<std::int64_t> refcount_ = 1;
};
// `cow_deleter` is used as the `ctx_deleter` for DataPtr to implement a COW
// DataPtr.
//
// Warning: This should only be called on a pointer to a COWDeleterContext that
// was allocated on the heap with `new`, because when the refcount reaches 0,
// the context is deleted with `delete`.
C10_API void cow_deleter(void* ctx);
} // namespace c10::impl::cow
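A sketch of how a caller is expected to consume the variant returned by decrement_refcount (illustrative only; `release_reference` is hypothetical and the include path is assumed):
#include <variant>
#include <c10/core/impl/COWDeleter.h>
// Hypothetical: drop one reference, distinguishing the two outcomes.
void release_reference(c10::impl::cow::COWDeleterContext& ctx) {
  auto result = ctx.decrement_refcount();
  if (std::holds_alternative<
          c10::impl::cow::COWDeleterContext::NotLastReference>(result)) {
    // Other references remain; the shared lock inside the variant keeps the
    // data alive while it could be copied. Dropping it here releases the lock.
  } else {
    // Last reference: the unique_ptr alternative now owns the original data
    // and runs the original deleter when the variant goes out of scope.
  }
}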

View File

@ -0,0 +1,365 @@
#pragma once
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/Stream.h>
#include <c10/util/Exception.h>
// Just for C10_ANONYMOUS_VARIABLE
#include <c10/util/Registry.h>
#include <atomic>
namespace c10 {
// Forward declaration
class DataPtr;
/**
* Note [Flags defining the behavior of events]
*
* PYTORCH_DEFAULT and BACKEND_DEFAULT are valid for all backends. The
* BACKEND_DEFAULT is what a particular backend would select if no
* flags were given. PYTORCH_DEFAULT is the PyTorch's framework default
* choice for events on that backend, which may not be the same.
*
* The mapping of PYTORCH_DEFAULT and BACKEND_DEFAULT is done by each
* backend implementation.
*/
enum class EventFlag {
// Disable timing
PYTORCH_DEFAULT,
// Enable timing
BACKEND_DEFAULT,
// FOR TESTING ONLY
INVALID
};
namespace impl {
/**
* DeviceGuardImplInterface represents the virtual interface which provides
* the functionality needed to implement an RAII class for device and stream
* switching, via DeviceGuard. Every distinct device type, e.g., CUDA and HIP, is
* expected to implement and register an implementation of this interface.
* All classes which inherit from DeviceGuardImplInterface should be declared
* 'final'.
*
* This class exists because we provide a unified interface for performing
* device guards via DeviceGuard, but we cannot assume that we have actually
* compiled against the, e.g., CUDA library, which actually implements
* this guard functionality. In this case, a dynamic dispatch is required
* to cross the library boundary.
*
* If possible, you should directly use implementations of this interface;
* those uses will be devirtualized.
*/
struct C10_API DeviceGuardImplInterface {
DeviceGuardImplInterface() = default;
DeviceGuardImplInterface(const DeviceGuardImplInterface&) = default;
DeviceGuardImplInterface& operator=(const DeviceGuardImplInterface&) =
default;
DeviceGuardImplInterface(DeviceGuardImplInterface&&) noexcept = default;
DeviceGuardImplInterface& operator=(DeviceGuardImplInterface&&) noexcept =
default;
/**
* Return the type of device managed by this guard implementation.
*/
virtual DeviceType type() const = 0;
/**
* Set the current device to Device, and return the previous Device.
*/
virtual Device exchangeDevice(Device) const = 0;
// NB: Implementations of exchangeDevice can be a bit boilerplatey. You might
// consider replacing exchangeDevice with a non-virtual function with a baked
// in implementation; however, note that this will triple the number of
// virtual calls (when you implement exchangeDevice in a final subclass,
// the compiler gets to devirtualize everything; it won't do that if you don't
// define it in the subclass!) A common way to solve this problem is to use
// some sort of CRTP; however, we can't template DeviceGuardImplInterface since
// we really *do* need it to be virtual. A little boilerplate seems easiest
// to explain. (Another way around this problem is to provide inline
// functions that provide the default implementations, but this seems a little
// hard to explain. In any case, we're only going to have on the order of ten
// implementations of this anyway.)
/**
* Get the current device.
*/
virtual Device getDevice() const = 0;
/**
* Set the current device to Device.
*/
virtual void setDevice(Device) const = 0;
/**
* Set the current device to Device, without checking for errors
* (so, e.g., this can be called from a destructor).
*/
virtual void uncheckedSetDevice(Device) const noexcept = 0;
/**
* Get the current stream for a given device.
*/
virtual Stream getStream(Device) const noexcept = 0;
/**
* Get the default stream for a given device.
*/
virtual Stream getDefaultStream(Device) const {
TORCH_CHECK(false, "Backend doesn't support acquiring a default stream.")
}
/**
* Get a stream from the global pool for a given device.
*/
virtual Stream getStreamFromGlobalPool(Device, bool isHighPriority = false)
const {
(void)isHighPriority; // Suppress unused variable warning
TORCH_CHECK(false, "Backend doesn't support acquiring a stream from pool.")
}
/**
* Return a new stream for a given device and priority. The stream will be
* copied and shared around; the device backend should be able to correctly
* handle the lifetime of the stream.
*/
virtual Stream getNewStream(Device, int priority = 0) const {
(void)priority;
TORCH_CHECK(false, "Backend doesn't support create a new Stream.")
}
/**
* Set a stream to be the thread local current stream for its device.
* Return the previous stream for that device. You are NOT required
* to set the current device to match the device of this stream.
*/
virtual Stream exchangeStream(Stream) const noexcept = 0;
/**
* Destroys the given event.
*/
virtual void destroyEvent(void* /*event*/, const DeviceIndex /*device_index*/)
const noexcept {}
/**
* Increments the event's version and enqueues a job with this version
* in the stream's work queue. When the stream processes that job,
* it notifies all streams waiting on / blocked by that version of the
* event to continue and marks that version as recorded.
* */
virtual void record(
void** /*event*/,
const Stream& /*stream*/,
const DeviceIndex /*device_index*/,
const c10::EventFlag /*flag*/) const {
TORCH_CHECK(false, "Backend doesn't support events.");
}
/**
* Does nothing if the event has not been scheduled to be recorded.
* If the event was previously enqueued to be recorded, a command
* to wait for the version of the event that exists at the time of this call
* is inserted in the stream's work queue.
* When the stream reaches this command it will stop processing
* additional commands until that version of the event is marked as recorded.
*/
virtual void block(void* /*event*/, const Stream& /*stream*/) const {
TORCH_CHECK(false, "Backend doesn't support events.");
}
/**
* Returns true if (and only if)
* (1) the event has never been scheduled to be recorded, or
* (2) the current version is marked as recorded.
* Returns false otherwise.
*/
virtual bool queryEvent(void* /*event*/) const {
TORCH_CHECK(false, "Backend doesn't support events.");
}
/**
* Get the number of devices. WARNING: This is REQUIRED to not raise
* an exception. If there is some sort of problem, e.g., driver error,
* you should report that there are zero available devices.
*/
virtual DeviceIndex deviceCount() const noexcept = 0;
/**
* Return true if all the work previously enqueued on the stream for
* asynchronous execution has completed running on the device.
*/
virtual bool queryStream(const Stream& /*stream*/) const {
TORCH_CHECK(false, "Backend doesn't support querying streams.");
}
/**
* Wait (by blocking the calling thread) until all the work previously
* enqueued on the stream has completed running on the device.
*/
virtual void synchronizeStream(const Stream& /*stream*/) const {
TORCH_CHECK(false, "Backend doesn't support synchronizing streams.");
}
/**
* Wait (by blocking the calling thread) until all the work previously
* recorded on the event has completed running on the device.
*/
virtual void synchronizeEvent(void* /*event*/) const {
TORCH_CHECK(false, "Backend doesn't support synchronizing events.");
}
/**
* Ensure the caching allocator (if any) is aware that the given DataPtr is
* being used on the given stream, and that it should thus avoid recycling the
* DataPtr until all work on that stream is done.
*/
virtual void recordDataPtrOnStream(const c10::DataPtr&, const Stream&) const {
}
/**
* Fetch the elapsed time between two recorded events.
*/
virtual double elapsedTime(
void* /*event1*/,
void* /*event2*/,
const DeviceIndex /*device_index*/) const {
TORCH_CHECK(false, "Backend doesn't support elapsedTime.");
}
/**
* Intended use of this class is to leak the DeviceGuardImpl at program end.
* So you better not call the destructor, buster!
*/
virtual ~DeviceGuardImplInterface() = default;
};
// A no-op device guard impl that doesn't do anything interesting. Useful
// for devices that don't actually have a concept of device index. Prominent
// examples are CPU and Meta.
template <DeviceType D>
struct NoOpDeviceGuardImpl final : public DeviceGuardImplInterface {
NoOpDeviceGuardImpl() = default;
DeviceType type() const override {
return D;
}
Device exchangeDevice(Device) const override {
return Device(D, -1); // no-op
}
Device getDevice() const override {
return Device(D, -1);
}
void setDevice(Device) const override {
// no-op
}
void uncheckedSetDevice(Device) const noexcept override {
// no-op
}
Stream getStream(Device) const noexcept override {
// no-op
return Stream(Stream::DEFAULT, Device(D, -1));
}
Stream getNewStream(Device, int priority = 0) const override {
// no-op
(void)priority;
return Stream(Stream::DEFAULT, Device(D, -1));
}
// NB: These do NOT set the current device
Stream exchangeStream(Stream) const noexcept override {
// no-op
return Stream(Stream::DEFAULT, Device(D, -1));
}
DeviceIndex deviceCount() const noexcept override {
return 1;
}
// Event-related functions
void record(
void** /*event*/,
const Stream& /*stream*/,
const DeviceIndex /*device_index*/,
const EventFlag /*flag*/) const override {
TORCH_CHECK(false, D, " backend doesn't support events.");
}
void block(void* /*event*/, const Stream& /*stream*/) const override {
TORCH_CHECK(false, D, " backend doesn't support events.")
}
bool queryEvent(void* /*event*/) const override {
TORCH_CHECK(false, D, " backend doesn't support events.")
}
void destroyEvent(void* /*event*/, const DeviceIndex /*device_index*/)
const noexcept override {}
// Stream-related functions
bool queryStream(const Stream& /*stream*/) const override {
return true;
}
void synchronizeStream(const Stream& /*stream*/) const override {
// Don't wait for anything.
}
};
// The registry is NON-owning. Each stored pointer is std::atomic so
// that under all interleavings of registry calls the structure is
// race-free. This doesn't cost us anything on reads on x86. (An
// unsynchronized implementation probably is OK too, but I didn't want
// to prove that we never read from device_guard_impl_registry at the
// same time some registration is occurring. Shiver.)
//
// I'd like this registry to be valid even at program destruction time
// (in case someone uses a DeviceGuard in a destructor to do some cleanup
// in the CUDA API.) Since there are no direct accesses of the underlying
// owning objects which I can use to enforce initialization order (unlike
// in a Meyer singleton), it implies that you must *leak* objects when
// putting them in the registry. This is done by deleting the destructor
// on DeviceGuardImplInterface.
// NOLINTNEXTLINE(*c-arrays*)
extern C10_API std::atomic<const DeviceGuardImplInterface*>
device_guard_impl_registry[static_cast<size_t>(
DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES)];
// I can't conveniently use c10/util/Registry.h for the following reason:
// c10/util/Registry.h gives me a slow way of Create'ing a object of some
// interface from the registry, but no way of quickly accessing an already
// created object. I'll be banging on getDeviceGuardImpl every time we do a
// DeviceGuard, so I really don't want to be doing an unordered_map lookup.
// Better if the registration mechanism directly drops its implementation
// into device_guard_impl_registry.
class C10_API DeviceGuardImplRegistrar {
public:
DeviceGuardImplRegistrar(DeviceType, const DeviceGuardImplInterface*);
};
#define C10_REGISTER_GUARD_IMPL(DevType, DeviceGuardImpl) \
static ::c10::impl::DeviceGuardImplRegistrar C10_ANONYMOUS_VARIABLE( \
g_##DeviceType)(::c10::DeviceType::DevType, new DeviceGuardImpl());
inline const DeviceGuardImplInterface* getDeviceGuardImpl(DeviceType type) {
// Two adjacent int16_t fields DeviceType and DeviceIndex have field access
// miscompiled on NVCC. To work around this issue, we apply a mask to the
// DeviceType. First check if the DeviceType is 16-bit.
// FB employees can see
// https://fb.workplace.com/groups/llvm.gcc/permalink/4053565044692080/
// for more details
static_assert(sizeof(DeviceType) == 1, "DeviceType is not 8-bit");
auto p = device_guard_impl_registry[static_cast<size_t>(type) & 0xFF].load();
// This seems to be the first place where you make use of a device
// when you pass devices to factory functions. Give a nicer error
// message in this case.
TORCH_CHECK(p, "PyTorch is not linked with support for ", type, " devices");
return p;
}
inline bool hasDeviceGuardImpl(DeviceType type) {
return device_guard_impl_registry[static_cast<size_t>(type)].load();
}
} // namespace impl
} // namespace c10
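// ---------------------------------------------------------------------------
// Illustrative sketch (added for exposition; not part of the original file):
// how a backend would hook itself into the registry above. Registration is
// done once, in a single translation unit of the backend; the choice of
// PrivateUse1 and the use of NoOpDeviceGuardImpl here are only examples.
//
//   // In exactly one .cpp file of the backend:
//   C10_REGISTER_GUARD_IMPL(
//       PrivateUse1,
//       c10::impl::NoOpDeviceGuardImpl<c10::DeviceType::PrivateUse1>);
//
//   // Later, anywhere in the process:
//   if (c10::impl::hasDeviceGuardImpl(c10::DeviceType::PrivateUse1)) {
//     const auto* impl =
//         c10::impl::getDeviceGuardImpl(c10::DeviceType::PrivateUse1);
//     c10::Stream s =
//         impl->getStream(c10::Device(c10::DeviceType::PrivateUse1, -1));
//   }
// ---------------------------------------------------------------------------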

View File

@ -0,0 +1,102 @@
#pragma once
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <array>
namespace c10::impl {
// FakeGuardImpl is hardcoded to have eight devices. Not for
// any good reason, just to simplify code.
constexpr DeviceIndex kFakeGuardImplMaxDevices = 8;
/**
* A fake implementation of DeviceGuardImplInterface suitable for testing.
* The current device is modeled as a mutable field in the guard implementation
* class. See DeviceGuard_test.cpp for an example use.
*/
template <DeviceType T>
struct FakeGuardImpl final : public DeviceGuardImplInterface {
static constexpr DeviceType static_type = T;
// Runtime device type is not used
FakeGuardImpl(DeviceType) {}
FakeGuardImpl() = default;
DeviceType type() const override {
return T;
}
Device exchangeDevice(Device d) const override {
AT_ASSERT(d.type() == type());
AT_ASSERT(d.index() < kFakeGuardImplMaxDevices);
Device old_device = getDevice();
if (old_device.index() != d.index()) {
current_device_ = d.index();
}
return old_device;
}
Device getDevice() const override {
return Device(type(), current_device_);
}
void setDevice(Device d) const override {
AT_ASSERT(d.type() == type());
AT_ASSERT(d.index() >= 0);
AT_ASSERT(d.index() < kFakeGuardImplMaxDevices);
current_device_ = d.index();
}
void uncheckedSetDevice(Device d) const noexcept override {
current_device_ = d.index();
}
Stream getStream(Device d) const noexcept override {
return Stream(Stream::UNSAFE, d, current_streams_[d.index()]);
}
Stream exchangeStream(Stream s) const noexcept override {
auto old_id = current_streams_[s.device_index()];
current_streams_[s.device_index()] = s.id();
return Stream(Stream::UNSAFE, s.device(), old_id);
}
DeviceIndex deviceCount() const noexcept override {
return kFakeGuardImplMaxDevices;
}
// Event-related functions
void record(
void** event,
const Stream& stream,
const DeviceIndex device_index,
const EventFlag flag) const override {}
void block(void* event, const Stream& stream) const override {}
bool queryEvent(void* event) const override {
return true;
}
void destroyEvent(void* event, const DeviceIndex device_index)
const noexcept override {}
// Convenience methods for testing
static DeviceIndex getDeviceIndex() {
return current_device_;
}
static void setDeviceIndex(DeviceIndex i) {
AT_ASSERT(i >= 0);
AT_ASSERT(i < kFakeGuardImplMaxDevices);
current_device_ = i;
}
static StreamId getCurrentStreamIdFor(DeviceIndex i) {
return current_streams_.at(i);
}
static void resetStreams() {
current_streams_.fill(0);
}
private:
thread_local static DeviceIndex current_device_;
thread_local static std::array<StreamId, kFakeGuardImplMaxDevices>
current_streams_;
};
template <DeviceType T>
thread_local DeviceIndex FakeGuardImpl<T>::current_device_ = 0;
template <DeviceType T>
thread_local std::array<StreamId, kFakeGuardImplMaxDevices>
FakeGuardImpl<T>::current_streams_ = {0, 0, 0, 0, 0, 0, 0, 0};
} // namespace c10::impl
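// ---------------------------------------------------------------------------
// Illustrative sketch (added for exposition; not part of the original file):
// how a test might drive FakeGuardImpl directly. Device and stream state live
// in thread-local statics, so they can be inspected via the static helpers.
//
//   using TestGuardImpl = c10::impl::FakeGuardImpl<c10::DeviceType::CUDA>;
//
//   TestGuardImpl impl;
//   TestGuardImpl::setDeviceIndex(0);
//   impl.setDevice(c10::Device(c10::DeviceType::CUDA, 1));
//   AT_ASSERT(TestGuardImpl::getDeviceIndex() == 1);
//
//   // Streams are tracked per device as plain StreamIds.
//   c10::Stream s(
//       c10::Stream::UNSAFE, c10::Device(c10::DeviceType::CUDA, 1), 42);
//   impl.exchangeStream(s);
//   AT_ASSERT(TestGuardImpl::getCurrentStreamIdFor(1) == 42);
//   TestGuardImpl::resetStreams();
// ---------------------------------------------------------------------------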

View File

@ -0,0 +1,28 @@
#pragma once
#include <c10/core/impl/PyInterpreter.h>
namespace c10::impl {
struct C10_API GPUTrace {
// On the x86 architecture the atomic operations are lock-less.
static std::atomic<const PyInterpreter*> gpuTraceState;
// When PyTorch migrates to C++20, this should be changed to an atomic flag.
// Currently, the access to this variable is not synchronized, on the basis
// that it will only be flipped once and by the first interpreter that
// accesses it.
static bool haveState;
// This function will only register the first interpreter that tries to invoke
// it. For all subsequent interpreters it will be a no-op.
static void set_trace(const PyInterpreter*);
static const PyInterpreter* get_trace() {
if (!haveState)
return nullptr;
return gpuTraceState.load(std::memory_order_acquire);
}
};
} // namespace c10::impl
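// ---------------------------------------------------------------------------
// Illustrative sketch (added for exposition; not part of the original file):
// the intended hot-path pattern is to check get_trace() and only call into the
// registered interpreter's tracing hooks when a tracer exists. The trace_gpu_*
// calls refer to PyInterpreterVTable methods; `raw_stream` is a placeholder
// for a backend stream handle.
//
//   if (const c10::impl::PyInterpreter* interp =
//           c10::impl::GPUTrace::get_trace()) {
//     (*interp)->trace_gpu_stream_creation(
//         c10::DeviceType::CUDA, reinterpret_cast<uintptr_t>(raw_stream));
//   }
// ---------------------------------------------------------------------------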

View File

@ -0,0 +1,59 @@
#pragma once
#include <c10/macros/Export.h>
#include <atomic>
namespace c10::impl {
// This TLS controls whether or not we permanently associate PyObject
// with Tensor the first time it is allocated. When hermetic PyObject
// TLS is enabled (state is true), we DO NOT save PyObjects to Tensor,
// meaning you get a distinct PyObject whenever you execute the code in
// question.
struct C10_API HermeticPyObjectTLS {
static void set_state(bool state);
static bool get_state() {
// Hypothetical fastpath if torchdeploy/multipy isn't used. Per
// https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2020/p2055r0.pdf
// this qualifies relaxed access because it is a single-location data
// structure (only the boolean here).
//
// Forgetting about data races for a moment, is there a logical race?
//
// - Boolean only ever transitions from false to true. So the
// critical situation is when one interpreter is already running
// when a second interpreter switches haveState from false to true.
//
// - The first interpreter is indifferent to whether it sees
// haveState_ true or false; obviously false works (this is what the
// interpreter was previously using; more directly, the interpreter
// calls into itself as the handler, so being hermetic is not
// required), and true simply means serviced python operator calls will
// be hermetic; in these cases it is expected to be functionally
// equivalent.
//
// - The second interpreter MUST see haveState_ true (as its requests will
// be forwarded to the first interpreter), but it is assumed that there
// is a synchronization between the interpreter initialization, and
// when we actually perform operations, so it is guaranteed to see
// hasState true.
//
// QED.
//
// This fastpath is currently disabled so that we can more easily test that
// hermetic mode works correctly even on stock build of PyTorch.
if (false && !haveState_.load(std::memory_order_relaxed))
return false;
return get_tls_state();
}
// Call this from the multipy/torchdeploy top level
static void init_state();
private:
// This only flipped once from false to true during torchdeploy/multipy
// initialization, and never again.
static std::atomic<bool> haveState_;
static bool get_tls_state();
};
} // namespace c10::impl
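// ---------------------------------------------------------------------------
// Illustrative sketch (added for exposition; not part of the original file):
// code servicing a call on behalf of a secondary torchdeploy/multipy
// interpreter flips the TLS for the duration of the call and restores the
// previous value afterwards; the explicit save/restore shown here is only a
// sketch of that pattern.
//
//   bool prev = c10::impl::HermeticPyObjectTLS::get_state();
//   c10::impl::HermeticPyObjectTLS::set_state(true);
//   // ... run the serviced Python operator call hermetically ...
//   c10::impl::HermeticPyObjectTLS::set_state(prev);
// ---------------------------------------------------------------------------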

View File

@ -0,0 +1,429 @@
#pragma once
// This file provides implementations of InlineDeviceGuard and
// InlineOptionalDeviceGuard.
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/core/impl/VirtualGuardImpl.h>
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>
#include <type_traits>
#include <utility>
namespace c10::impl {
/**
* A DeviceGuard is an RAII class that sets a device to some value
* on construction, and resets the device to its original value on
* destruction.
*
* InlineDeviceGuard is a helper class for implementing DeviceGuards.
* It is templated over a DeviceGuardImpl (anything that implements
* DeviceGuardImplInterface). There are two primary ways to instantiate
* InlineDeviceGuard:
*
* - With a concrete implementation of DeviceGuardImpl, e.g., CUDAGuardImpl.
* This is the best way to use InlineDeviceGuard, as all calls are
* devirtualized, giving you code as efficient as straight line
* calls to cudaGetDevice/cudaSetDevice.
*
* - With VirtualGuardImpl, which does a virtual dispatch to a DeviceGuardImpl
* retrieved from a DeviceType registry. We have explicitly instantiated
* InlineDeviceGuard this way as c10::DeviceGuard.
*
* If you are in a hurry, you can use InlineDeviceGuard directly:
*
* using CUDAGuard = impl::InlineDeviceGuard<CUDAGuardImpl>;
*
* However, you can provide a better user experience if you explicitly write a
* wrapper class that itself contains the template instantiation:
*
* class CUDAGuard {
* public:
* // ... the API ...
* private:
* impl::InlineDeviceGuard<CUDAGuardImpl> guard_;
* }
*
* The wrapper class provides a good place to write documentation, and helps
* avoid weird template instantiation errors when a user incorrectly uses the
* class.
*
* If you need to test this class, consider instantiating it with FakeGuardImpl.
*/
template <typename T>
class InlineDeviceGuard {
public:
// Note [Omitted default constructor from RAII]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// In principle, we could add a default constructor to
// DeviceGuard which reads the current device and promises to
// restore to that device on exit. However, in most cases where you
// would have written this, you probably meant to actually just
// use OptionalDeviceGuard (since you don't actually need the
// restore to happen if you don't ever actually set the device).
// We remove the constructor here to encourage you to think about
// what you actually want to happen.
explicit InlineDeviceGuard() = delete;
/// Set the current device to the passed Device.
explicit InlineDeviceGuard(Device device)
: impl_(device.type()),
original_device_(
device.index() == -1 ? impl_.getDevice()
: impl_.exchangeDevice(device)),
current_device_(device.index() == -1 ? original_device_ : device) {}
/// Set the current device index to the passed DeviceIndex. (The
/// device type is inferred from the template parameter T).
template <
typename U = T,
typename =
typename std::enable_if_t<!std::is_same_v<U, VirtualGuardImpl>>>
explicit InlineDeviceGuard(DeviceIndex device_index)
: InlineDeviceGuard(Device(U::static_type, device_index)) {}
/// Construct an InlineDeviceGuard using VirtualGuardImpl with an explicit
/// DeviceGuardImplInterface pointer.
template <
typename U = T,
typename = typename std::enable_if_t<std::is_same_v<U, VirtualGuardImpl>>>
explicit InlineDeviceGuard(
Device device,
const DeviceGuardImplInterface* impl)
: impl_(
VirtualGuardImpl(impl ? impl : getDeviceGuardImpl(device.type()))),
original_device_(
device.index() == -1 ? impl_.getDevice()
: impl_.exchangeDevice(device)),
current_device_(device.index() == -1 ? original_device_ : device) {}
/// Copy is disallowed
InlineDeviceGuard(const InlineDeviceGuard<T>&) = delete;
InlineDeviceGuard<T>& operator=(const InlineDeviceGuard<T>&) = delete;
/// Move is disallowed, as DeviceGuard does not have an uninitialized state,
/// which is required for moves on types with nontrivial destructors.
InlineDeviceGuard(InlineDeviceGuard<T>&& other) = delete;
InlineDeviceGuard& operator=(InlineDeviceGuard<T>&& other) = delete;
~InlineDeviceGuard() {
impl_.uncheckedSetDevice(original_device_);
}
/// Sets the device to the given one.
template <
typename U = T,
typename std::enable_if_t<!std::is_same_v<U, VirtualGuardImpl>, int> = 0>
void set_device(at::Device device) {
AT_ASSERT(
(U::static_type == DeviceType::HIP && device.is_cuda()) ||
device.type() == U::static_type);
auto index = device.index();
if (index == -1)
return;
impl_.setDevice(device);
current_device_ = device;
}
/// Resets the currently set device to its original device, and then sets the
/// current device to the passed device. This is effectively equivalent to
/// set_device when a guard supports only a single device type.
template <typename U = T>
typename std::enable_if_t<!std::is_same_v<U, VirtualGuardImpl>> reset_device(
at::Device device) {
set_device(device);
}
/// Resets the currently set device to its original device, and then sets the
/// current device to the passed device (for a possibly different device
/// type).
///
/// This method is named reset_device to highlight the fact that previous
/// device settings from this guard are NOT preserved, even if the device
/// has a different device type. For example:
///
/// // CUDA device is 0
/// DeviceGuard g(Device(kCUDA, 1));
/// g.reset_device(Device(kHIP, 2));
/// // CUDA device is 0 (!!)
///
/// NOTE: this implementation may skip some device setting if it can prove
/// that it is unnecessary.
///
/// Optional argument is for testing only.
template <typename U = T>
typename std::enable_if_t<std::is_same_v<U, VirtualGuardImpl>> reset_device(
at::Device device,
const impl::DeviceGuardImplInterface* impl = nullptr) {
auto index = device.index();
if (index == -1)
return;
if (device.type() == original_device_.type()) {
AT_ASSERT(impl == nullptr || impl->type() == device.type());
impl_.setDevice(device);
current_device_ = device;
} else {
// Destruct and reconstruct the DeviceGuard in place
impl_.setDevice(original_device_);
impl_ = !impl ? VirtualGuardImpl(device.type()) : VirtualGuardImpl(impl);
original_device_ = impl_.exchangeDevice(device);
current_device_ = device;
}
}
/// Sets the device index to the given one. The device type is inferred
/// from the original device type.
void set_index(DeviceIndex index) {
reset_device(Device(original_device_.type(), index));
}
/// Returns the device that was set at the time the most recent
/// reset_device(), or otherwise the device at construction time.
Device original_device() const {
return original_device_;
}
/// Returns the most recent device that was set using this device guard,
/// either from construction, or via set_device/reset_device/set_index.
Device current_device() const {
return current_device_;
}
protected:
T impl_;
private:
Device original_device_;
Device current_device_;
};
/**
* A OptionalDeviceGuard is an RAII class that sets a device to some value on
* initialization, and resets the device to its original value on destruction.
*
* InlineOptionalDeviceGuard is a helper class for implementing
* OptionalDeviceGuards. See guidance in InlineDeviceGuard on how to
* use this. See OptionalDeviceGuard for user-oriented usage notes.
*/
template <typename T>
class InlineOptionalDeviceGuard {
public:
// Note [Explicit initialization of optional fields]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Explicit initialization of optional fields
// required to workaround an nvcc bug; see
// https://github.com/pytorch/pytorch/issues/12117
/// Creates an uninitialized OptionalDeviceGuard.
explicit InlineOptionalDeviceGuard()
: guard_() // See Note [Explicit initialization of optional fields]
{}
/// Set the current device to the passed Device, if it is not nullopt.
explicit InlineOptionalDeviceGuard(std::optional<Device> device_opt)
: guard_() { // See Note [Explicit initialization of optional fields]
if (device_opt.has_value()) {
guard_.emplace(device_opt.value());
}
}
/// Set the current device to the passed DeviceIndex, if it is not nullopt.
template <
typename U = T,
typename =
typename std::enable_if_t<!std::is_same_v<U, VirtualGuardImpl>>>
explicit InlineOptionalDeviceGuard(
std::optional<DeviceIndex> device_index_opt)
: guard_() { // See Note [Explicit initialization of optional fields]
if (device_index_opt.has_value()) {
guard_.emplace(device_index_opt.value());
}
}
/// All constructors of DeviceGuard are valid for OptionalDeviceGuard
/// and result in initialized OptionalDeviceGuard.
template <typename... Args>
explicit InlineOptionalDeviceGuard(Args&&... args)
: guard_(std::in_place, std::forward<Args>(args)...) {}
// TODO: Consider reading Tensor and TensorList constructors here, when
// Tensor moves to c10. (These are only valid on OptionalDeviceGuard,
// because a Tensor may be undefined, in which case we need an uninitialized
// tensor guard.)
// Note [Move construction for RAII guards is tricky]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// In principle, move construction is useful for terminating
// the lifetime of a `OptionalDeviceGuard` early; for example:
//
// // current device is d0
// OptionalDeviceGuard g1(d1);
// // current device is d1
// {
// OptionalDeviceGuard g2(std::move(g1));
// }
// // current device is d0!!
//
// However, it's difficult to implement the move constructor
// in a way that works in all situations. For example, consider
// the following example:
//
// OptionalDeviceGuard g1(d1);
// {
// OptionalDeviceGuard g2(d2);
// {
// OptionalDeviceGuard g3(std::move(g1)); // !!!
// }
// }
//
// What should the current device be while g3 in scope... and what
// should it be after it goes out of scope? What about g2?
// There don't seem to be satisfactory answers for these questions.
//
// It's in principle possible to raise an error when this occurs
// by doing some extra thread-local bookkeeping. But why bother?
// Just don't provide the constructor.
InlineOptionalDeviceGuard(InlineOptionalDeviceGuard<T>&& other) = delete;
// Note [Move assignment for RAII guards is tricky]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Move assignment is deleted, because you need to know which guard was
// defined "first", as that guard's original_device_ wins--with the current
// representation, we have no way of telling which is the case. (Move
// construction does not have this problem, as one guard is always
// uninitialized.)
//
// We can make this clear by way of a pair of examples:
//
// Example 1:
//
// // initial device is n0
// {
// CUDAGuard g1(n1);
// {
// CUDAGuard g2(n2);
// // current device should be n2
// g1 = std::move(g2);
// // current device should still be n2
// }
// // current device should still be n2
// }
// // current device should be n0
//
// Example 2 (flip the order of the two guards):
//
// // initial device is n0
// {
// CUDAGuard g2(n2);
// {
// CUDAGuard g1(n1);
// // current device should be n1
// g1 = std::move(g2);
// // current device should be n2
// }
// // current device should be n0 (since g2 has been vacated)
// }
//
// In both examples, we need g1 to restore to n0 after move assignment.
// However, in example 1, this is determined by the restore value of g1
// (prior to the move). In example 2, however, it is determined by the
// restore value of g2 (!!). We don't know which one should win, without having
// a way of telling which guard was allocated first.
//
// We could solve this with an extra thread-local variable. But no one is
// actually using move-assignment. So just get rid of it.
InlineOptionalDeviceGuard& operator=(InlineOptionalDeviceGuard&& other) =
delete;
/// Sets the device to the given one. Initializes OptionalDeviceGuard if it
/// is not already initialized.
template <
typename U = T,
typename =
typename std::enable_if_t<!std::is_same_v<U, VirtualGuardImpl>>>
void set_device(at::Device device) {
if (!guard_.has_value()) {
guard_.emplace(device);
} else {
guard_->set_device(device);
}
}
/// Resets the currently set device to its original device, and then sets the
/// current device to the passed device (for a possibly different device
/// type). Initializes OptionalDeviceGuard if it is not already initialized.
///
/// See notes on why this is called reset_device on InlineDeviceGuard.
///
/// Optional argument is for testing only.
template <
typename U = T,
typename = typename std::enable_if_t<std::is_same_v<U, VirtualGuardImpl>>>
void reset_device(
at::Device device,
const DeviceGuardImplInterface* impl = nullptr) {
if (!guard_.has_value()) {
guard_.emplace(device, impl);
} else {
guard_->reset_device(device, impl);
}
}
/// Resets the currently set device to its original device, and then sets the
/// current device to the passed device. Initializes the guard if it is
/// not already initialized. This is effectively equivalent to set_device
/// when a guard supports only a single device type.
template <
typename U = T,
typename =
typename std::enable_if_t<!std::is_same_v<U, VirtualGuardImpl>>>
void reset_device(at::Device device) {
if (!guard_.has_value()) {
guard_.emplace(device);
} else {
guard_->reset_device(device);
}
}
/// Sets the device index to the given one. The device type is statically
/// known.
template <
typename U = T,
typename =
typename std::enable_if_t<!std::is_same_v<U, VirtualGuardImpl>>>
void set_index(DeviceIndex index) {
if (!guard_.has_value()) {
guard_.emplace(index);
} else {
guard_->set_index(index);
}
}
/// Returns the device that was set immediately prior to initialization of
/// the guard, or nullopt if the guard is uninitialized.
std::optional<Device> original_device() const {
return guard_.has_value() ? std::make_optional(guard_->original_device())
: std::nullopt;
}
/// Returns the most recent device that was set using this device guard,
/// either from construction, or via set_device, if the guard is initialized,
/// or nullopt if the guard is uninitialized.
std::optional<Device> current_device() const {
return guard_.has_value() ? std::make_optional(guard_->current_device())
: std::nullopt;
}
/// Restore the original device, resetting this guard to uninitialized state.
void reset() {
guard_.reset();
}
private:
std::optional<InlineDeviceGuard<T>> guard_;
};
} // namespace c10::impl
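// ---------------------------------------------------------------------------
// Illustrative sketch (added for exposition; not part of the original file):
// as suggested above, InlineDeviceGuard can be exercised in tests by
// instantiating it with FakeGuardImpl (assuming
// <c10/core/impl/FakeGuardImpl.h> is included):
//
//   using TestGuard = c10::impl::InlineDeviceGuard<
//       c10::impl::FakeGuardImpl<c10::DeviceType::CUDA>>;
//
//   c10::impl::FakeGuardImpl<c10::DeviceType::CUDA>::setDeviceIndex(0);
//   {
//     TestGuard g(c10::Device(c10::DeviceType::CUDA, 1));
//     // The fake backend's current device is now 1 ...
//     g.set_index(2); // ... and can be moved again through the guard.
//   }
//   // On destruction, the guard restores device 0.
// ---------------------------------------------------------------------------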

View File

@ -0,0 +1,139 @@
#pragma once
#include <c10/core/DeviceType.h>
#include <c10/core/Stream.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/util/Exception.h>
namespace c10::impl {
template <typename T>
struct InlineEvent final {
InlineEvent() = delete;
InlineEvent(
const DeviceType _device_type,
const EventFlag _flag = EventFlag::PYTORCH_DEFAULT)
: backend_{_device_type}, device_type_{_device_type}, flag_{_flag} {}
// Copy constructor and copy assignment operator (deleted)
InlineEvent(const InlineEvent&) = delete;
InlineEvent& operator=(const InlineEvent&) = delete;
// Move constructor and move assignment operator
InlineEvent(InlineEvent&& other) noexcept
: event_(other.event_),
backend_(std::move(other.backend_)),
device_type_(other.device_type_),
device_index_(other.device_index_),
flag_(other.flag_),
was_marked_for_recording_(other.was_marked_for_recording_) {
other.event_ = nullptr;
}
InlineEvent& operator=(InlineEvent&& other) noexcept {
swap(other);
return *this;
}
void swap(InlineEvent& other) noexcept {
std::swap(event_, other.event_);
std::swap(backend_, other.backend_);
std::swap(device_type_, other.device_type_);
std::swap(device_index_, other.device_index_);
std::swap(flag_, other.flag_);
std::swap(was_marked_for_recording_, other.was_marked_for_recording_);
}
~InlineEvent() noexcept {
if (event_)
backend_.destroyEvent(event_, device_index_);
}
DeviceType device_type() const noexcept {
return device_type_;
}
DeviceIndex device_index() const noexcept {
return device_index_;
}
EventFlag flag() const noexcept {
return flag_;
}
bool was_marked_for_recording() const noexcept {
return was_marked_for_recording_;
}
void recordOnce(const Stream& stream) {
if (!was_marked_for_recording_)
record(stream);
}
void record(const Stream& stream) {
TORCH_CHECK(
stream.device_type() == device_type_,
"Event device type ",
DeviceTypeName(device_type_),
" does not match recording stream's device type ",
DeviceTypeName(stream.device_type()),
".");
backend_.record(&event_, stream, device_index_, flag_);
was_marked_for_recording_ = true;
device_index_ = stream.device_index();
}
void block(const Stream& stream) const {
if (!was_marked_for_recording_)
return;
TORCH_CHECK(
stream.device_type() == device_type_,
"Event device type ",
DeviceTypeName(device_type_),
" does not match blocking stream's device type ",
DeviceTypeName(stream.device_type()),
".");
backend_.block(event_, stream);
}
bool query() const {
if (!was_marked_for_recording_)
return true;
return backend_.queryEvent(event_);
}
void* eventId() const {
return event_;
}
double elapsedTime(const InlineEvent& other) const {
TORCH_CHECK(
other.was_marked_for_recording(),
"other was not marked for recording.");
TORCH_CHECK(
was_marked_for_recording(), "self was not marked for recording.");
TORCH_CHECK(
other.device_type() == device_type_,
"Event device type ",
DeviceTypeName(device_type_),
" does not match other's device type ",
DeviceTypeName(other.device_type()),
".");
return backend_.elapsedTime(event_, other.event_, device_index_);
}
void synchronize() const {
if (!was_marked_for_recording_)
return;
backend_.synchronizeEvent(event_);
}
private:
void* event_ = nullptr;
T backend_;
DeviceType device_type_;
DeviceIndex device_index_ = -1;
EventFlag flag_ = EventFlag::PYTORCH_DEFAULT;
bool was_marked_for_recording_ = false;
};
} // namespace c10::impl

View File

@ -0,0 +1,256 @@
#pragma once
#include <c10/core/impl/InlineDeviceGuard.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/irange.h>
namespace c10::impl {
/**
* A StreamGuard is an RAII class that changes the current device
* to the device corresponding to some stream, and changes the
* default stream on that device to be this stream.
*
* InlineStreamGuard is a helper class for implementing StreamGuards.
* See InlineDeviceGuard for guidance on how to use this class.
*/
template <typename T>
class InlineStreamGuard : private InlineDeviceGuard<T> {
public:
/// No default constructor, see Note [Omitted default constructor from RAII]
explicit InlineStreamGuard() = delete;
/// Set the current device to the device associated with the passed stream,
/// and set the current stream on that device to the passed stream.
explicit InlineStreamGuard(Stream stream)
: InlineDeviceGuard<T>(stream.device()),
original_stream_of_original_device_(
this->impl_.getStream(original_device())),
original_stream_of_current_device_(this->impl_.exchangeStream(stream)),
current_stream_(stream) {}
/// This constructor exists purely for testing
template <
typename U = T,
typename = typename std::enable_if_t<std::is_same_v<U, VirtualGuardImpl>>>
explicit InlineStreamGuard(
Stream stream,
const DeviceGuardImplInterface* impl)
: InlineDeviceGuard<T>(
stream.device(),
impl ? impl : getDeviceGuardImpl(stream.device_type())),
original_stream_of_original_device_(
this->impl_.getStream(original_device())),
original_stream_of_current_device_(this->impl_.exchangeStream(stream)),
current_stream_(stream) {}
/// Copy is disallowed
InlineStreamGuard(const InlineStreamGuard<T>&) = delete;
InlineStreamGuard<T>& operator=(const InlineStreamGuard<T>&) = delete;
/// Move is disallowed, as StreamGuard does not have an uninitialized state,
/// which is required for moves on types with nontrivial destructors.
InlineStreamGuard(InlineStreamGuard<T>&& other) = delete;
InlineStreamGuard& operator=(InlineStreamGuard<T>&& other) = delete;
~InlineStreamGuard() {
this->impl_.exchangeStream(original_stream_of_current_device_);
}
/// Resets the currently set stream to the original stream and
/// the currently set device to the original device. Then,
/// set the current device to the device associated with the passed stream,
/// and set the current stream on that device to the passed stream.
///
/// NOTE: this implementation may skip some stream/device setting if
/// it can prove that it is unnecessary.
///
/// WARNING: reset_stream does NOT preserve previously set streams on
/// different devices. If you need to set streams on multiple devices
/// use MultiStreamGuard instead.
void reset_stream(Stream stream) {
// TODO: make a version that takes an impl argument. Unfortunately,
// that will require SFINAE because impl is only valid for the
// VirtualGuardImpl specialization.
if (stream.device() == this->current_device()) {
this->impl_.exchangeStream(stream);
current_stream_ = stream;
} else {
// Destruct and reconstruct the StreamGuard in-place
this->impl_.exchangeStream(original_stream_of_current_device_);
this->reset_device(stream.device());
original_stream_of_current_device_ = this->impl_.exchangeStream(stream);
current_stream_ = stream;
}
}
// It's not clear if set_device should also reset the current stream
// if the device is unchanged; therefore, we don't provide it.
// The situation is somewhat clearer with reset_device, but it's still
// a pretty weird thing to do, so we haven't added this either.
/// Returns the stream of the original device prior to this guard. Subtly,
/// the stream returned here is the original stream of the *original*
/// device; i.e., it's the stream that your computation *would* have
/// been put on, if it hadn't been for this meddling stream guard.
/// This is usually what you want.
Stream original_stream() const {
return original_stream_of_original_device_;
}
/// Returns the most recent stream that was set using this device guard,
/// either from construction, or via set_stream.
Stream current_stream() const {
return current_stream_;
}
/// Returns the most recent device that was set using this device guard,
/// either from construction, or via set_device/reset_device/set_index.
Device current_device() const {
return InlineDeviceGuard<T>::current_device();
}
/// Returns the device that was set at the most recent reset_stream(),
/// or otherwise the device at construction time.
Device original_device() const {
return InlineDeviceGuard<T>::original_device();
}
private:
Stream
original_stream_of_original_device_; // what the user probably cares about
Stream original_stream_of_current_device_; // what we need to restore
Stream current_stream_;
};
/**
* An OptionalStreamGuard is an RAII class that sets a device to some value on
* initialization, and resets the device to its original value on destruction.
* See InlineOptionalDeviceGuard for more guidance on how to use this class.
*/
template <typename T>
class InlineOptionalStreamGuard {
public:
/// Creates an uninitialized stream guard.
explicit InlineOptionalStreamGuard()
: guard_() // See Note [Explicit initialization of optional fields]
{}
/// Set the current device to the device associated with the passed stream,
/// and set the current stream on that device to the passed stream,
/// if the passed stream is not nullopt.
explicit InlineOptionalStreamGuard(std::optional<Stream> stream_opt)
: guard_() {
if (stream_opt.has_value()) {
guard_.emplace(stream_opt.value());
}
}
/// All constructors of StreamGuard are valid for OptionalStreamGuard
template <typename... Args>
explicit InlineOptionalStreamGuard(Args&&... args)
: guard_(std::in_place, std::forward<Args>(args)...) {}
// See Note [Move construction for RAII guards is tricky]
InlineOptionalStreamGuard(InlineOptionalStreamGuard<T>&& other) = delete;
// See Note [Move assignment for RAII guards is tricky]
InlineOptionalStreamGuard& operator=(InlineOptionalStreamGuard&& other) =
delete;
/// Resets the currently set stream to the original stream and
/// the currently set device to the original device. Then,
/// set the current device to the device associated with the passed stream,
/// and set the current stream on that device to the passed stream.
/// Initializes the OptionalStreamGuard if it was not previously initialized.
void reset_stream(Stream stream) {
if (guard_.has_value()) {
guard_->reset_stream(stream);
} else {
guard_.emplace(stream);
}
}
/// Returns the stream that was set at the time the guard was most recently
/// initialized, or nullopt if the guard is uninitialized.
std::optional<Stream> original_stream() const {
return guard_.has_value() ? std::make_optional(guard_->original_stream())
: std::nullopt;
}
/// Returns the most recent stream that was set using this stream guard,
/// either from construction, or via reset_stream, if the guard is
/// initialized, or nullopt if the guard is uninitialized.
std::optional<Stream> current_stream() const {
return guard_.has_value() ? std::make_optional(guard_->current_stream())
: std::nullopt;
}
/// Restore the original device and stream, resetting this guard to
/// uninitialized state.
void reset() {
guard_.reset();
}
private:
std::optional<InlineStreamGuard<T>> guard_;
};
template <typename T>
class InlineMultiStreamGuard {
public:
/// Calls `set_stream` on each of the streams in the list.
/// This may be useful if you need to set different streams
/// for different devices.
explicit InlineMultiStreamGuard(ArrayRef<Stream> streams) {
if (!streams.empty()) {
impl_.emplace(getDeviceTypeOfStreams(streams));
original_streams_.reserve(streams.size());
for (const Stream& s : streams) {
original_streams_.emplace_back(this->impl_->exchangeStream(s));
}
}
}
/// Copy is disallowed
InlineMultiStreamGuard(const InlineMultiStreamGuard&) = delete;
InlineMultiStreamGuard<T>& operator=(const InlineMultiStreamGuard&) = delete;
/// Move is disallowed, as StreamGuard does not have an uninitialized state,
/// which is required for moves on types with nontrivial destructors.
InlineMultiStreamGuard(InlineMultiStreamGuard&& other) = delete;
InlineMultiStreamGuard& operator=(InlineMultiStreamGuard&& other) = delete;
~InlineMultiStreamGuard() noexcept {
if (this->impl_.has_value()) {
for (const Stream& s : original_streams_) {
this->impl_->exchangeStream(s);
}
}
}
protected:
std::optional<T> impl_;
private:
/// The original streams that were active on all devices.
std::vector<Stream> original_streams_;
static DeviceType getDeviceTypeOfStreams(ArrayRef<Stream> streams) {
TORCH_INTERNAL_ASSERT(!streams.empty());
DeviceType type = streams[0].device_type();
for (const auto idx : c10::irange(1, streams.size())) {
TORCH_CHECK_VALUE(
streams[idx].device_type() == type,
"Streams have a mix of device types: stream 0 is on ",
streams[0].device(),
" while stream ",
idx,
" is on device ",
streams[idx].device());
}
return type;
}
};
} // namespace c10::impl
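// ---------------------------------------------------------------------------
// Illustrative sketch (added for exposition; not part of the original file):
// an InlineStreamGuard swaps both the current device and the current stream,
// and restores them on destruction. With the fake backend from
// FakeGuardImpl.h:
//
//   using TestStreamGuard = c10::impl::InlineStreamGuard<
//       c10::impl::FakeGuardImpl<c10::DeviceType::CUDA>>;
//
//   c10::Stream s(
//       c10::Stream::UNSAFE, c10::Device(c10::DeviceType::CUDA, 1), 7);
//   {
//     TestStreamGuard g(s);
//     // g.current_device() == Device(CUDA, 1), g.current_stream() == s;
//     // g.original_stream() is the stream work would otherwise have used.
//   }
//   // Destruction restores the previous stream and device.
// ---------------------------------------------------------------------------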

View File

@ -0,0 +1,164 @@
#pragma once
#include <c10/core/DispatchKeySet.h>
#include <c10/macros/Export.h>
// TLS management for DispatchKeySet (the "local" DispatchKeySet(s))
//
// This manages two thread-local DispatchKeySets:
//
// - The included type set, which adds a tensor type for consideration
// in dispatch. (For example, you might add Profiling to
// the included type set to turn on profiling on all tensor operations.)
//
// - The excluded type set, which disqualifies a tensor type from dispatch.
// (For example, after redispatching on variable, we disqualify
// Autograd so we don't attempt to handle variable again.)
// (Exclusion wins over inclusion.)
//
// NB: Originally, I implemented the excluded type set as storing the inverted
// set, but TLS is defined to be zero-initialized, so this doesn't actually work
// (if it's inverted, you want the set to be -1 initialized).
namespace c10::impl {
// POD version of LocalDispatchKeySet. Declared here just so that
// we can put it in the guards.
// This struct encapsulates special handling for TLS initialization
// in set_included()/included() API so that they reflect the truth.
// If you want to create a PODLocalDispatchKeySet with non-zero state,
// use set_included() instead of the default constructor.
struct C10_API PODLocalDispatchKeySet {
uint64_t included_;
uint64_t excluded_;
// See Note [TLS Initialization]
DispatchKeySet included() const {
return DispatchKeySet(DispatchKeySet::RAW, included_) ^
c10::default_included_set;
}
DispatchKeySet excluded() const {
return DispatchKeySet(DispatchKeySet::RAW, excluded_) ^
c10::default_excluded_set;
}
void set_included(DispatchKeySet x) {
included_ = (x ^ c10::default_included_set).raw_repr();
}
void set_excluded(DispatchKeySet x) {
excluded_ = (x ^ c10::default_excluded_set).raw_repr();
}
};
static_assert(
std::is_trivial_v<PODLocalDispatchKeySet>,
"PODLocalDispatchKeySet must be a POD type.");
struct C10_API LocalDispatchKeySet {
/* implicit */ LocalDispatchKeySet(PODLocalDispatchKeySet x)
: included_(x.included()), excluded_(x.excluded()) {}
DispatchKeySet included_;
DispatchKeySet excluded_;
};
// thread_local variables cannot be C10_API on Windows.
// Inlining this seems to break AutoDispatchBelowAutograd on Android.
#if defined(_MSC_VER) || defined(C10_ANDROID) || defined(C10_IPHONE)
C10_API LocalDispatchKeySet tls_local_dispatch_key_set();
#else // defined(_MSC_VER) || defined(C10_ANDROID) || defined(C10_IPHONE)
extern C10_API thread_local PODLocalDispatchKeySet raw_local_dispatch_key_set;
inline C10_API LocalDispatchKeySet tls_local_dispatch_key_set() {
// Don't let people fiddle with the thread_local directly just
// because they include this header.
return raw_local_dispatch_key_set;
}
#endif // defined(_MSC_VER) || defined(C10_ANDROID) || defined(C10_IPHONE)
// Internal, use ThreadLocalStateGuard
C10_API void _force_tls_local_dispatch_key_set(LocalDispatchKeySet key_set);
// RAII API for manipulating the thread-local dispatch state.
class C10_API IncludeDispatchKeyGuard {
public:
IncludeDispatchKeyGuard(DispatchKeySet);
IncludeDispatchKeyGuard(DispatchKey k)
: IncludeDispatchKeyGuard(DispatchKeySet(k)) {}
IncludeDispatchKeyGuard(const IncludeDispatchKeyGuard&) = delete;
IncludeDispatchKeyGuard operator=(const IncludeDispatchKeyGuard&) = delete;
IncludeDispatchKeyGuard(IncludeDispatchKeyGuard&&) = delete;
IncludeDispatchKeyGuard operator=(IncludeDispatchKeyGuard&&) = delete;
~IncludeDispatchKeyGuard();
private:
// A little micro-optimization to save us from tls_get_addr call
// on destruction
PODLocalDispatchKeySet* tls_;
DispatchKeySet include_;
};
class C10_API ExcludeDispatchKeyGuard {
public:
ExcludeDispatchKeyGuard(DispatchKeySet);
ExcludeDispatchKeyGuard(DispatchKey k)
: ExcludeDispatchKeyGuard(DispatchKeySet(k)) {}
ExcludeDispatchKeyGuard(const ExcludeDispatchKeyGuard&) = delete;
ExcludeDispatchKeyGuard operator=(const ExcludeDispatchKeyGuard&) = delete;
ExcludeDispatchKeyGuard(ExcludeDispatchKeyGuard&&) = delete;
ExcludeDispatchKeyGuard operator=(ExcludeDispatchKeyGuard&&) = delete;
~ExcludeDispatchKeyGuard();
private:
// A little micro-optimization to save us from tls_get_addr call
// on destruction
PODLocalDispatchKeySet* tls_;
DispatchKeySet exclude_;
};
struct C10_API ForceDispatchKeyGuard {
public:
ForceDispatchKeyGuard()
: saved_keyset_(c10::impl::tls_local_dispatch_key_set()) {}
ForceDispatchKeyGuard(c10::impl::LocalDispatchKeySet key_set)
: ForceDispatchKeyGuard() {
c10::impl::_force_tls_local_dispatch_key_set(key_set);
}
ForceDispatchKeyGuard(
c10::DispatchKeySet include,
c10::DispatchKeySet exclude)
: ForceDispatchKeyGuard() {
auto updated_set = saved_keyset_;
updated_set.included_ = include;
updated_set.excluded_ = exclude;
c10::impl::_force_tls_local_dispatch_key_set(updated_set);
}
~ForceDispatchKeyGuard() {
c10::impl::_force_tls_local_dispatch_key_set(saved_keyset_);
}
private:
c10::impl::LocalDispatchKeySet saved_keyset_;
};
// Non-RAII API for manipulating the thread-local dispatch state.
// Please prefer the RAII API. The non-RAII API may be useful when
// the included/excluded state of a given DispatchKey must span
// many calls from the Python to the C++, so you cannot conveniently
// use an RAII guard.
//
// Example use case: a Python context manager that includes a certain
// DispatchKey, to ensure ops running under the context manager dispatch
// through that DispatchKey's registered overrides.
//
// The non-RAII API is less efficient than the RAII guards because both the
// getter and setter will do a tls_getaddr lookup (the RAII struct only needs
// one!)
C10_API bool tls_is_dispatch_key_excluded(DispatchKey x);
C10_API void tls_set_dispatch_key_excluded(DispatchKey x, bool desired_state);
C10_API bool tls_is_dispatch_key_included(DispatchKey x);
C10_API void tls_set_dispatch_key_included(DispatchKey x, bool desired_state);
C10_API bool tls_is_dispatch_keyset_excluded(DispatchKeySet ks);
C10_API bool tls_is_dispatch_keyset_included(DispatchKeySet ks);
} // namespace c10::impl
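// ---------------------------------------------------------------------------
// Illustrative sketch (added for exposition; not part of the original file):
// the RAII guards temporarily add keys to the thread-local included/excluded
// sets and restore the previous TLS state on destruction. The key chosen here
// is only an example:
//
//   {
//     c10::impl::ExcludeDispatchKeyGuard no_autograd_cpu(
//         c10::DispatchKey::AutogradCPU);
//     // Within this scope, dispatch skips AutogradCPU:
//     bool excluded = c10::impl::tls_is_dispatch_key_excluded(
//         c10::DispatchKey::AutogradCPU); // true
//   }
//   // Outside the scope, the previous thread-local state is restored.
// ---------------------------------------------------------------------------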

View File

@ -0,0 +1,263 @@
#pragma once
#include <c10/core/Device.h>
#include <c10/core/DispatchKeySet.h>
#include <c10/core/Layout.h>
#include <c10/core/MemoryFormat.h>
#include <c10/core/SymIntArrayRef.h>
#include <c10/macros/Export.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/intrusive_ptr.h>
#include <c10/util/python_stub.h>
#include <string>
#include <vector>
// Forward declarations
namespace c10 {
struct IValue;
class OperatorHandle;
struct TensorImpl;
} // namespace c10
namespace torch::jit {
using Stack = std::vector<c10::IValue>;
}
// Actual implementation
namespace c10::impl {
struct C10_API PyInterpreter;
// Note [Python interpreter tag]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Traditionally, PyTorch is layered such that our Python library
// (libtorch_python) references our pure C++ library (libtorch) as the
// natural order of things. However, sometimes this natural order is
// subverted: C++ objects refer to Python objects (for example, we
// store a PyObject* pointer on TensorImpl so that converting from a
// C++ Tensor to a Python Tensor is just a memory dereference).
//
// These unusual orderings must be treated with care. To start, you need to
// virtualize the destructor so that the PyObject can be decref'ed on
// destruction (because the C++ object itself doesn't know anything about
// Python--remember, layering!). This process itself is fraught, since
// acquiring the GIL could lead to deadlocks if someone is blocking on you
// while holding the GIL. Furthermore, if the C++ objects outlive the
// interpreter (which can happen if you stash them in a static global
// variable defined in libtorch), you may attempt to decref the object when
// the Python interpreter has already been shutdown.
//
// BUT WAIT, IT GETS WORSE. With torchdeploy, there may be multiple Python
// interpreters in a single process. If a C++ object is accessible from
// multiple interpreters, we must take care not to accidentally use a
// PyObject from one interpreter with another interpreter.
//
// To prevent these mixups, we introduce a PyInterpreter "tag" (object with
// a vtable), which specifies a specific Python interpreter.
//
// - Any given object can be associated with AT MOST one Python interpreter.
// We represent the interpreter tag as a memory address to an instance of
// a virtual class that is allocated once per interpreter (this is so that
// we can request the interpreter to perform operations for us, if
// necessary).
//
// - It can be recorded with a PyObject (PyInterpreterObject) so that
// we know what interpreter the object is associated with, and we can
// raise an error if you try to use the PyObject from the wrong
// interpreter context.
//
// - It contains a vtable that can be used to perform various Python
// operations from ordinary C++ code that ordinarily wouldn't be accessible
// from libtorch.
//
// A simple use case is when a C++ object must be associated with a PyObject.
// However, for TensorImpl, we lazily allocate a PyObject the first time the
// object passes into Python. The invariants for this situation are more
// subtle:
//
// - A given TensorImpl's interpreter tag can only go from uninitialized to
// tagged; once tagged, this is a quiescent state (once tagged to an
// interpreter, ALWAYS tagged to that interpreter)
//
// - A thread may mutate the PyObject field of a TensorImpl if and only if it
// holds the GIL for the interpreter tagged on the TensorImpl. (If the
// TensorImpl is not tagged, it must first atomically claim its tag before it
// can validly write)
//
// WARNING: This class has to be written very carefully, because it may be
// possible for a Tensor to have a reference to an interpreter corresponding to
// a shared library that has ALREADY BEEN UNLOADED. This makes blindly calling
// virtual methods very dangerous, because the vtable may be garbage at that
// point (on a good day, you might get "pure virtual method called").
//
// The idea to solve this problem is that we always leak PyInterpreters (so they
// always stay live even after dlclose), and make sure we can disarm their
// virtual methods by indirecting through a separate PyInterpreterVTable
// object. This can be replaced with a no-op vtable from libc10.so, which
// is guaranteed to stick around until the bitter end.
//
// NB: The downside with representing PyInterpreter tags as full objects is that
// it takes an extra word on TensorImpl. If tags were instead just integer
// indices, on 64-bit architectures we could pack the tag and PyObject together
// into a single atomic word. On 32-bit architectures we could simply say that
// only one Python interpreter is supported (erroring if a nontrivial
// interpreter tag is attempted to be set).
//
// The difficulty with this scheme is we need to maintain an out-of-line table
// to get at the PyInterpreters so that we can do virtual method calls on them,
// and registration/deregistration to this table must be done in a thread safe
// manner. This can be easily done if the number of possible PyInterpreters is
// small enough (e.g., 8-bit integer) by simply preallocating an array of
// sufficient size to hold all possible interpreters. Surely 128 threads is
// more than enough for anyone!
//
// I decided not to do this technique at the moment, because the extra word
// added by the PyInterpreter tag takes us to 24 words, which means that we
// still fit inside three eight word cache lines. If you need to penny pinch
// another word consider doing this!
struct C10_API PyInterpreterVTable {
virtual ~PyInterpreterVTable() = default;
// Report the name of this interpreter
virtual std::string name() const = 0;
// Run Py_INCREF on a PyObject.
virtual void incref(PyObject* pyobj) const = 0;
// Run Py_DECREF on a PyObject. We DO NOT assume the GIL is held on call
// See NOTE [PyInterpreter::decref takes a `has_pyobj_slot` arg]
virtual void decref(PyObject* pyobj, bool has_pyobj_slot) const = 0;
// Perform a detach by deferring to the __torch_dispatch__ implementation of
// detach, which will also arrange for the PyObject to get copied in this
// situation
virtual c10::intrusive_ptr<TensorImpl> detach(
const TensorImpl* self) const = 0;
// Invoke the Python boxed fallback dispatch to go back into Python
virtual void dispatch(const c10::OperatorHandle& op, torch::jit::Stack* stack)
const = 0;
virtual void reportErrorCallback(PyObject* callback, DispatchKey key)
const = 0;
// This is only invoked in the multipy/torchdeploy situation from
// pythonOpRegistrationTrampoline; this lets us get to the Python
// interpreter to actually find the appropriate Python op registration
// entry to call.
virtual void python_op_registration_trampoline(
const c10::OperatorHandle& op,
c10::DispatchKey,
c10::DispatchKeySet keyset,
torch::jit::Stack* stack,
bool with_keyset,
bool with_op) const = 0;
virtual void throw_abstract_impl_not_imported_error(
std::string opname,
const char* pymodule,
const char* context) const = 0;
// Invoke the Python dispatcher to handle this call
virtual void python_dispatcher(
const c10::OperatorHandle& op,
c10::DispatchKeySet,
torch::jit::Stack* stack) const = 0;
virtual bool is_contiguous(const TensorImpl* self, at::MemoryFormat)
const = 0;
virtual bool is_strides_like(const TensorImpl* self, at::MemoryFormat)
const = 0;
virtual bool is_non_overlapping_and_dense(const TensorImpl* self) const = 0;
virtual c10::Device device(const TensorImpl* self) const = 0;
virtual int64_t dim(const TensorImpl* self) const = 0;
virtual c10::IntArrayRef strides(const TensorImpl* self) const = 0;
virtual c10::IntArrayRef sizes(const TensorImpl* self) const = 0;
virtual c10::SymIntArrayRef sym_sizes(const TensorImpl* self) const = 0;
virtual c10::Layout layout(const TensorImpl* self) const = 0;
virtual int64_t numel(const TensorImpl* self) const = 0;
virtual c10::SymInt sym_numel(const TensorImpl* self) const = 0;
virtual c10::SymIntArrayRef sym_strides(const TensorImpl* self) const = 0;
virtual c10::SymInt sym_storage_offset(const TensorImpl* self) const = 0;
virtual void trace_gpu_event_creation(
c10::DeviceType device_type,
uintptr_t event) const = 0;
virtual void trace_gpu_event_deletion(
c10::DeviceType device_type,
uintptr_t event) const = 0;
virtual void trace_gpu_event_record(
c10::DeviceType device_type,
uintptr_t event,
uintptr_t stream) const = 0;
virtual void trace_gpu_event_wait(
c10::DeviceType device_type,
uintptr_t event,
uintptr_t stream) const = 0;
virtual void trace_gpu_memory_allocation(
c10::DeviceType device_type,
uintptr_t ptr) const = 0;
virtual void trace_gpu_memory_deallocation(
c10::DeviceType device_type,
uintptr_t ptr) const = 0;
virtual void trace_gpu_stream_creation(
c10::DeviceType device_type,
uintptr_t stream) const = 0;
virtual void trace_gpu_device_synchronization(
c10::DeviceType device_type) const = 0;
virtual void trace_gpu_stream_synchronization(
c10::DeviceType device_type,
uintptr_t stream) const = 0;
virtual void trace_gpu_event_synchronization(
c10::DeviceType device_type,
uintptr_t event) const = 0;
virtual void reset_backward_hooks(const TensorImpl* self) const = 0;
};
struct C10_API PyInterpreter {
const PyInterpreterVTable* vtable_;
PyInterpreter(const PyInterpreterVTable* vtable) : vtable_(vtable) {}
const PyInterpreterVTable& operator*() const noexcept {
return *vtable_;
}
const PyInterpreterVTable* operator->() const noexcept {
return vtable_;
}
// Disarm this PyInterpreter, making all of its methods noops.
// The vtable pointer is not an atomic at the moment, which means
// a disarm() invocation that is concurrent with active destructors
// is not thread safe and will trigger TSAN. My hope is that this
// situation never actually happens; tensor destruction should
// quiesce when a dlclose happens, and any long lived tensors whose
// destructors would be disarmed here only begin the destruction process
// on process shutdown (long after the dlclose has occurred).
void disarm() noexcept;
};
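// A minimal usage sketch (illustrative only; `MyVTable` and `obj` are
// hypothetical): all Python operations funnel through the vtable indirection,
// and disarm() swaps in the no-op vtable so the same calls become safe no-ops
// after the owning shared library goes away.
//
//   static MyVTable my_vtable;  // some subclass of PyInterpreterVTable
//   static c10::impl::PyInterpreter interp(&my_vtable);
//   std::string who = interp->name();  // virtual call through the vtable
//   interp->incref(obj);
//   interp.disarm();  // subsequent vtable calls are neutralized
//
// The PyInterpreter object itself is deliberately leaked, so pointers to it
// stored in TensorImpls remain valid even after dlclose.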
// PyInterpreterStatus describes the state of a TensorImpl's interpreter tag,
// relative to the thread currently holding the GIL.
enum class PyInterpreterStatus {
// We just allocated the Tensor, it hasn't escaped to other threads,
// we know that it definitely hasn't been tagged to be associated
// with an interpreter.
DEFINITELY_UNINITIALIZED,
// We queried the interpreter field and it looked uninitialized. But
// another thread may have raced with us to tag it with some other
// interpreter id. So we will have to do a CEX to make sure we can
// actually nab it.
MAYBE_UNINITIALIZED,
// We queried the interpreter field and it was tagged to belong to us.
// This means we have sole write access (as we hold the GIL for this
// interpreter)
TAGGED_BY_US,
// Someone else tagged this. We can't use this TensorImpl from Python.
TAGGED_BY_OTHER,
};
} // namespace c10::impl

View File

@ -0,0 +1,190 @@
#pragma once
#include <c10/core/impl/HermeticPyObjectTLS.h>
#include <c10/core/impl/PyInterpreter.h>
#include <c10/util/python_stub.h>
#include <optional>
#include <atomic>
namespace c10::impl {
struct C10_API PyObjectSlot {
public:
PyObjectSlot();
~PyObjectSlot();
void maybe_destroy_pyobj();
// Associate the TensorImpl with the specified PyObject, and, if necessary,
// also tag the interpreter.
//
// NB: This lives in a header so that we can inline away the switch on status
//
// NB: THIS FUNCTION CAN RAISE AN EXCEPTION. Make sure to clean up after
// PyObject if necessary!
void init_pyobj(
PyInterpreter* self_interpreter,
PyObject* pyobj,
PyInterpreterStatus status) {
impl::PyInterpreter* expected = nullptr;
switch (status) {
case impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED:
// caller guarantees there is no multithreaded access; since there is
// no data race, it is OK to do a relaxed store
pyobj_interpreter_.store(self_interpreter, std::memory_order_relaxed);
break;
case impl::PyInterpreterStatus::TAGGED_BY_US:
// no tagging is necessary, the tag is already correct
break;
case impl::PyInterpreterStatus::MAYBE_UNINITIALIZED:
// attempt to claim this TensorImpl with the specified interpreter
// tag
if (pyobj_interpreter_.compare_exchange_strong(
expected, self_interpreter, std::memory_order_acq_rel)) {
break;
}
// test if, actually, it was already tagged by us! This can't happen
// because of a race, but it can happen if the caller conservatively
// passed MAYBE_UNINITIALIZED (because they didn't pre-check the tag)
// when the tensor was in fact already tagged by our interpreter
if (expected == self_interpreter) {
break;
}
// fallthrough, we lost the race. We are guaranteed not to lose the
// race with ourself, as calls to init_pyobj with the same interpreter
// ID must be sequentialized by the GIL
[[fallthrough]];
case impl::PyInterpreterStatus::TAGGED_BY_OTHER:
TORCH_CHECK(
false,
"cannot allocate PyObject for Tensor on interpreter ",
self_interpreter,
" that has already been used by another torch deploy interpreter ",
pyobj_interpreter_.load());
}
// we are the ONLY thread that can have gotten to this point. It is not
// possible to conflict with another interpreter claiming the tag, as access
// is protected by the GIL
// NB: owns_pyobj tag is initially false
pyobj_ = pyobj;
}
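// A minimal caller-side sketch (accessor names such as `pyobj_slot()` and
// `getPyInterpreter()` are illustrative, not guaranteed by this header):
// right after constructing a fresh TensorImpl and its Python wrapper `obj`,
// the strongest status can be used because nothing has escaped yet:
//
//   impl->pyobj_slot()->init_pyobj(
//       getPyInterpreter(), obj,
//       c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
//
// If the tensor may already be visible to other threads, pass
// MAYBE_UNINITIALIZED instead so the tag is claimed via compare-exchange.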
// Query the PyObject interpreter. This may return null if there is no
// interpreter. This is racy!
PyInterpreter* pyobj_interpreter();
PyObject* _unchecked_untagged_pyobj() const;
// Test the interpreter tag. If tagged for the current interpreter, return
// a non-nullopt (but possibly null) PyObject. If (possibly) untagged,
// returns a nullopt. If it is definitely invalid, raises an error.
//
// If `ignore_hermetic_tls` is false and this function is called from a
// hermetic context (ie, `HermeticPyObjectTLS::get_state()` is true), then
// nullopt is returned. If `ignore_hermetic_tls` is true, then the hermetic
// context is ignored, allowing you to check the interpreter tag of a
// nonhermetic PyObject from within a hermetic context. This is necessary
// because there are some cases where the deallocator function of a
// nonhermetic PyObject is called from within a hermetic context, so it must
// be properly treated as a nonhermetic PyObject.
//
// NB: this lives in header so that we can avoid actually creating the
// std::optional
std::optional<PyObject*> check_pyobj(
PyInterpreter* self_interpreter,
bool ignore_hermetic_tls = false) const {
// Note [Memory ordering on Python interpreter tag]
impl::PyInterpreter* interpreter =
pyobj_interpreter_.load(std::memory_order_acquire);
if (interpreter == nullptr) {
// NB: This never returns DEFINITELY_UNINITIALIZED because there is
// always the possibility that another thread races to initialize
// after we query here. The only time when we can conclude a tensor
// is definitely uninitialized is when we have just allocated it and
// it cannot have escaped to other threads yet
return std::nullopt;
} else if (interpreter == self_interpreter) {
// NB: pyobj_ could still be null!
if (!ignore_hermetic_tls && c10::impl::HermeticPyObjectTLS::get_state()) {
return std::nullopt;
} else {
return std::make_optional(_unchecked_untagged_pyobj());
}
} else {
TORCH_CHECK(
false,
"cannot access PyObject for Tensor on interpreter ",
(*self_interpreter)->name(),
" that has already been used by another torch deploy interpreter ",
(*pyobj_interpreter_.load())->name());
}
}
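// A minimal usage sketch: nullopt means "no PyObject visible to this
// interpreter here" (untagged, or hermetic context), while a present but
// null pointer means "tagged for us, but no wrapper allocated yet":
//
//   auto maybe_obj = slot->check_pyobj(self_interpreter);
//   if (maybe_obj.has_value() && *maybe_obj != nullptr) {
//     PyObject* obj = *maybe_obj;  // usable under this interpreter's GIL
//   }
//
// (`slot` and `self_interpreter` are placeholders for a PyObjectSlot* and
// the calling interpreter.)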
// Clear the PyObject field for an interpreter, in situations where we
// statically know the tensor is tagged with our interpreter.
void unchecked_clear_pyobj(PyInterpreter* interpreter);
PyInterpreter& load_pyobj_interpreter() const;
// Check if the PyObjectSlot's interpreter is the same as the specified
// interpreter
bool check_interpreter(PyInterpreter* interpreter);
// Check if the PyObjectSlot is holding a PyObject, owned or non-owned
bool has_pyobj_nonhermetic();
bool owns_pyobj();
void set_owns_pyobj(bool b);
private:
// This field contains the interpreter tag for this object. See
// Note [Python interpreter tag] for general context
//
// Note [Memory ordering on Python interpreter tag]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// What memory_order do we need when accessing this atomic? We don't
// need a single total modification order (as provided by
// memory_order_seq_cst) as pyobj_interpreter_ is monotonic: it can only
// transition from nullptr to a non-null interpreter pointer and never changes
// afterwards.
// Because there is only one modification, it trivially already has a total
// modification order (e.g., we don't need fences or locked instructions on
// x86)
//
// In fact, one could make a reasonable argument that relaxed reads are OK,
// due to the presence of external locking (GIL) to ensure that interactions
// with other data structures are still correctly synchronized, so that
// we fall in the "Single-Location Data Structures" case as described in
// http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2020/p2055r0.pdf
// However, on x86, it doesn't matter if I use acquire or relaxed on the load
// as I get the same assembly in both cases. So I just use the more
// conservative acquire (which will impede compiler optimizations but I don't
// care)
std::atomic<PyInterpreter*> pyobj_interpreter_;
// This field contains a reference to a PyObject representing this Tensor.
// If pyobj is nullptr, when we transfer Tensor to Python, we allocate a new
// PyObject for it and set this field. This field does not have to be
// protected by an atomic as it is only allowed to be accessed when you hold
// the GIL, or during destruction of the tensor.
//
// When a PyObject dies, you are obligated to clear this field
// (otherwise, you will try to use-after-free the pyobj); this currently
// occurs in THPVariable_clear in torch/csrc/autograd/python_variable.cpp
//
// NB: Ordinarily, this should not be a strong reference, as if the
// PyObject owns the Tensor, this would create a reference cycle.
// However, sometimes this ownership flips. To track who owns
// who, this has a single pointer tag indicating whether or not the
// C++ object owns the PyObject (the common case, zero, means PyObject
// owns the C++ object); see _unchecked_untagged_pyobj for raw access
// or check_pyobj for checked access. See references to PyObject
// resurrection in torch/csrc/autograd/python_variable.cpp
PyObject* pyobj_;
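// Illustrative sketch only: one plausible encoding, consistent with the
// "single pointer tag" described above, keeps the ownership bit in the low
// bit of the stored pointer, in which case untagging amounts to
//
//   reinterpret_cast<PyObject*>(reinterpret_cast<uintptr_t>(pyobj_) & ~uintptr_t(1));
//
// See _unchecked_untagged_pyobj for the actual accessor.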
};
} // namespace c10::impl

View File

@ -0,0 +1,24 @@
#pragma once
#include <c10/core/impl/PyInterpreter.h>
#include <c10/macros/Export.h>
namespace c10::impl {
struct C10_API PythonDispatcherTLS {
static void set_state(PyInterpreter* state);
static PyInterpreter* get_state();
static void reset_state();
};
struct C10_API DisablePythonDispatcher {
DisablePythonDispatcher() : old_(PythonDispatcherTLS::get_state()) {
PythonDispatcherTLS::set_state({});
}
~DisablePythonDispatcher() {
PythonDispatcherTLS::set_state(old_);
}
PyInterpreter* old_;
};
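// Typical RAII usage sketch: temporarily disable the Python dispatcher for a
// scope; the previous interpreter (if any) is restored on scope exit:
//
//   {
//     c10::impl::DisablePythonDispatcher guard;
//     // ... dispatch work that must bypass the Python dispatcher ...
//   }  // destructor re-installs the prior state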
} // namespace c10::impl

View File

@ -0,0 +1,315 @@
#pragma once
#include <algorithm>
#include <cstdint>
#include <c10/macros/Macros.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/SmallVector.h>
#define C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE 5
namespace c10::impl {
// Packed container for TensorImpl sizes and strides.
// This design improves on the previous approach of using a pair of
// c10::SmallVector<int64_t, 5> by specializing for the operations we
// actually use and enforcing that the number of sizes is the same as
// the number of strides. The memory layout is as follows:
//
// 1 size_t for the size
// 5 eightbytes of inline sizes and 5 eightbytes of inline strides, OR pointer
// to out-of-line array
class C10_API SizesAndStrides {
public:
// TODO: different iterator types for sizes & strides to prevent
// mixing the two accidentally.
using sizes_iterator = int64_t*;
using sizes_const_iterator = const int64_t*;
using strides_iterator = int64_t*;
using strides_const_iterator = const int64_t*;
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
SizesAndStrides() {
size_at_unchecked(0) = 0;
stride_at_unchecked(0) = 1;
}
~SizesAndStrides() {
if (C10_UNLIKELY(!isInline())) {
// NOLINTNEXTLINE(cppcoreguidelines-no-malloc)
free(outOfLineStorage_);
}
}
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
SizesAndStrides(const SizesAndStrides& rhs) : size_(rhs.size_) {
if (C10_LIKELY(rhs.isInline())) {
copyDataInline(rhs);
} else {
allocateOutOfLineStorage(size_);
copyDataOutline(rhs);
}
}
SizesAndStrides& operator=(const SizesAndStrides& rhs) {
if (this == &rhs) {
return *this;
}
if (C10_LIKELY(rhs.isInline())) {
if (C10_UNLIKELY(!isInline())) {
// NOLINTNEXTLINE(cppcoreguidelines-no-malloc)
free(outOfLineStorage_);
}
copyDataInline(rhs);
} else {
if (isInline()) {
allocateOutOfLineStorage(rhs.size_);
} else {
resizeOutOfLineStorage(rhs.size_);
}
copyDataOutline(rhs);
}
size_ = rhs.size_;
return *this;
}
// Move from rhs. rhs.size() == 0 afterwards.
SizesAndStrides(SizesAndStrides&& rhs) noexcept : size_(rhs.size_) {
if (C10_LIKELY(isInline())) {
memcpy(inlineStorage_, rhs.inlineStorage_, sizeof(inlineStorage_));
} else {
outOfLineStorage_ = rhs.outOfLineStorage_;
rhs.outOfLineStorage_ = nullptr;
}
rhs.size_ = 0;
}
// Move from rhs. rhs.size() == 0 afterwards.
SizesAndStrides& operator=(SizesAndStrides&& rhs) noexcept {
if (this == &rhs) {
return *this;
}
if (C10_LIKELY(rhs.isInline())) {
if (C10_UNLIKELY(!isInline())) {
// NOLINTNEXTLINE(cppcoreguidelines-no-malloc)
free(outOfLineStorage_);
}
copyDataInline(rhs);
} else {
// They're out-of-line. We're going to steal their storage.
if (!isInline()) {
// NOLINTNEXTLINE(cppcoreguidelines-no-malloc)
free(outOfLineStorage_);
}
outOfLineStorage_ = rhs.outOfLineStorage_;
rhs.outOfLineStorage_ = nullptr;
}
size_ = rhs.size_;
rhs.size_ = 0;
return *this;
}
size_t size() const noexcept {
return size_;
}
const int64_t* sizes_data() const noexcept {
if (C10_LIKELY(isInline())) {
return &inlineStorage_[0];
} else {
return &outOfLineStorage_[0];
}
}
int64_t* sizes_data() noexcept {
if (C10_LIKELY(isInline())) {
return &inlineStorage_[0];
} else {
return &outOfLineStorage_[0];
}
}
sizes_const_iterator sizes_begin() const noexcept {
return sizes_data();
}
sizes_iterator sizes_begin() noexcept {
return sizes_data();
}
sizes_const_iterator sizes_end() const noexcept {
return sizes_begin() + size();
}
sizes_iterator sizes_end() noexcept {
return sizes_begin() + size();
}
IntArrayRef sizes_arrayref() const noexcept {
return IntArrayRef{sizes_data(), size()};
}
void set_sizes(IntArrayRef newSizes) {
resize(newSizes.size());
std::copy(newSizes.begin(), newSizes.end(), sizes_begin());
}
void set_strides(IntArrayRef strides) {
TORCH_INTERNAL_ASSERT(strides.size() == size());
std::copy(strides.begin(), strides.end(), strides_begin());
}
const int64_t* strides_data() const noexcept {
if (C10_LIKELY(isInline())) {
return &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE];
} else {
return &outOfLineStorage_[size()];
}
}
int64_t* strides_data() noexcept {
if (C10_LIKELY(isInline())) {
return &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE];
} else {
return &outOfLineStorage_[size()];
}
}
strides_const_iterator strides_begin() const noexcept {
if (C10_LIKELY(isInline())) {
return &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE];
} else {
return &outOfLineStorage_[size()];
}
}
strides_iterator strides_begin() noexcept {
if (C10_LIKELY(isInline())) {
return &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE];
} else {
return &outOfLineStorage_[size()];
}
}
strides_const_iterator strides_end() const noexcept {
return strides_begin() + size();
}
strides_iterator strides_end() noexcept {
return strides_begin() + size();
}
IntArrayRef strides_arrayref() const noexcept {
return IntArrayRef{strides_data(), size()};
}
// Size accessors.
int64_t size_at(size_t idx) const noexcept {
assert(idx < size());
return sizes_data()[idx];
}
int64_t& size_at(size_t idx) noexcept {
assert(idx < size());
return sizes_data()[idx];
}
int64_t size_at_unchecked(size_t idx) const noexcept {
return sizes_data()[idx];
}
int64_t& size_at_unchecked(size_t idx) noexcept {
return sizes_data()[idx];
}
// Stride accessors.
int64_t stride_at(size_t idx) const noexcept {
assert(idx < size());
return strides_data()[idx];
}
int64_t& stride_at(size_t idx) noexcept {
assert(idx < size());
return strides_data()[idx];
}
int64_t stride_at_unchecked(size_t idx) const noexcept {
return strides_data()[idx];
}
int64_t& stride_at_unchecked(size_t idx) noexcept {
return strides_data()[idx];
}
void resize(size_t newSize) {
const auto oldSize = size();
if (newSize == oldSize) {
return;
}
if (C10_LIKELY(
newSize <= C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE && isInline())) {
if (oldSize < newSize) {
const auto bytesToZero =
(newSize - oldSize) * sizeof(inlineStorage_[0]);
memset(&inlineStorage_[oldSize], 0, bytesToZero);
memset(
&inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE + oldSize],
0,
bytesToZero);
}
size_ = newSize;
} else {
resizeSlowPath(newSize, oldSize);
}
}
void resizeSlowPath(size_t newSize, size_t oldSize);
private:
bool isInline() const noexcept {
return size_ <= C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE;
}
void copyDataInline(const SizesAndStrides& rhs) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.isInline());
memcpy(inlineStorage_, rhs.inlineStorage_, sizeof(inlineStorage_));
}
static size_t storageBytes(size_t size) noexcept {
return size * 2 * sizeof(int64_t);
}
void allocateOutOfLineStorage(size_t size) {
// NOLINTNEXTLINE(cppcoreguidelines-no-malloc)
outOfLineStorage_ = static_cast<int64_t*>(malloc(storageBytes(size)));
TORCH_CHECK(
outOfLineStorage_,
"Could not allocate memory for Tensor SizesAndStrides!");
}
void resizeOutOfLineStorage(size_t newSize) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!isInline());
outOfLineStorage_ = static_cast<int64_t*>(
// NOLINTNEXTLINE(cppcoreguidelines-no-malloc)
realloc(outOfLineStorage_, storageBytes(newSize)));
TORCH_CHECK(
outOfLineStorage_,
"Could not allocate memory for Tensor SizesAndStrides!");
}
void copyDataOutline(const SizesAndStrides& rhs) noexcept {
memcpy(outOfLineStorage_, rhs.outOfLineStorage_, storageBytes(rhs.size_));
}
size_t size_{1};
union {
int64_t* outOfLineStorage_;
// NOLINTNEXTLINE(*c-array*)
int64_t inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE * 2]{};
};
};
} // namespace c10::impl
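// Usage sketch (illustrative): sizes and strides always share one length,
// stay inline up to C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE dimensions, and
// spill to a single malloc'd buffer beyond that:
//
//   c10::impl::SizesAndStrides ss;
//   ss.set_sizes({2, 3, 4});  // rank 3; newly added strides are zeroed
//   ss.stride_at(0) = 12;
//   ss.stride_at(1) = 4;
//   ss.stride_at(2) = 1;
//   ss.resize(6);             // rank 6 moves the data out of line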

View File

@ -0,0 +1,67 @@
#pragma once
#include <c10/core/SafePyObject.h>
#include <c10/macros/Export.h>
namespace c10::impl {
enum class TorchDispatchModeKey : int8_t {
FAKE,
PROXY,
FUNCTIONAL,
NUM_MODE_KEYS
};
using PyObject_TorchDispatchMode = SafePyObjectT<TorchDispatchModeKey>;
struct C10_API TorchDispatchModeTLS {
// This API is NOT invariant safe.
// It must not take in an infra mode that uses a TorchDispatchModeKey.
// If you're pushing an infra mode onto the stack, we expect
// you to use set_mode instead.
static void push_non_infra_mode_onto_stack(
std::shared_ptr<PyObject_TorchDispatchMode> mode);
// Pops the top mode of the stack,
// giving precedence to user modes before attempting to pop
// any infra modes
static const std::shared_ptr<PyObject_TorchDispatchMode> pop_stack();
// Returns the highest-priority infra mode on the stack,
// along with its mode key.
static const std::
tuple<std::shared_ptr<PyObject_TorchDispatchMode>, TorchDispatchModeKey>
pop_highest_infra_mode();
static const std::shared_ptr<PyObject_TorchDispatchMode>& get_stack_at(
int64_t idx);
static int64_t stack_len();
static const std::optional<std::shared_ptr<PyObject_TorchDispatchMode>>
get_mode(TorchDispatchModeKey mode_key);
static const std::optional<std::shared_ptr<PyObject_TorchDispatchMode>>
unset_mode(TorchDispatchModeKey mode_key);
static void set_mode(
const std::shared_ptr<PyObject_TorchDispatchMode>& mode,
TorchDispatchModeKey mode_key);
static const TorchDispatchModeTLS& get_state();
static void set_state(TorchDispatchModeTLS state);
static bool any_modes_set(bool skip_infra_modes = false);
private:
std::vector<std::shared_ptr<PyObject_TorchDispatchMode>> stack_;
// Users are allowed to push multiple ProxyTorchDispatchMode objects onto the
// stack
// However, we only allow a single FakeTensorMode onto the stack at a time
// (Pushing additional FakeTensorModes onto the stack is a no-op)
std::array<
std::optional<std::shared_ptr<PyObject_TorchDispatchMode>>,
static_cast<size_t>(TorchDispatchModeKey::NUM_MODE_KEYS)>
infra_modes_;
};
C10_API bool dispatch_mode_enabled();
C10_API std::string to_string(TorchDispatchModeKey mode_key);
} // namespace c10::impl
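// Illustrative query sketch: user modes and infra modes are tracked
// separately; dispatch_mode_enabled() is the cheap "anything set?" check:
//
//   if (c10::impl::dispatch_mode_enabled()) {
//     int64_t n = c10::impl::TorchDispatchModeTLS::stack_len();
//     auto fake = c10::impl::TorchDispatchModeTLS::get_mode(
//         c10::impl::TorchDispatchModeKey::FAKE);
//     // `fake` is nullopt unless a FakeTensorMode is currently installed
//   }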

View File

@ -0,0 +1,103 @@
#pragma once
#include <c10/core/impl/DeviceGuardImplInterface.h>
namespace c10::impl {
/**
* An implementation of DeviceGuardImplInterface which delegates
* to virtual dispatch on the DeviceGuardImpl registry.
*/
class VirtualGuardImpl final : public DeviceGuardImplInterface {
public:
VirtualGuardImpl(DeviceType device_type)
: impl_(getDeviceGuardImpl(device_type)) {}
// This constructor exists purely for testing
VirtualGuardImpl(const DeviceGuardImplInterface* impl) : impl_(impl) {}
// Copying and moving is OK!
VirtualGuardImpl(const VirtualGuardImpl&) = default;
VirtualGuardImpl& operator=(const VirtualGuardImpl&) = default;
VirtualGuardImpl(VirtualGuardImpl&&) noexcept = default;
VirtualGuardImpl& operator=(VirtualGuardImpl&&) noexcept = default;
DeviceType type() const override {
return impl_->type();
}
Device exchangeDevice(Device d) const override {
return impl_->exchangeDevice(d);
}
Device getDevice() const override {
return impl_->getDevice();
}
void setDevice(Device d) const override {
impl_->setDevice(d);
}
void uncheckedSetDevice(Device d) const noexcept override {
impl_->uncheckedSetDevice(d);
}
Stream getStream(Device d) const noexcept override {
return impl_->getStream(d);
}
Stream getNewStream(Device d, int priority = 0) const override {
return impl_->getNewStream(d, priority);
}
Stream getDefaultStream(Device d) const override {
return impl_->getDefaultStream(d);
}
Stream getStreamFromGlobalPool(Device d, bool isHighPriority = false)
const override {
return impl_->getStreamFromGlobalPool(d, isHighPriority);
}
Stream exchangeStream(Stream s) const noexcept override {
return impl_->exchangeStream(s);
}
DeviceIndex deviceCount() const noexcept override {
return impl_->deviceCount();
}
// Event functions
void record(
void** event,
const Stream& stream,
const DeviceIndex device_index,
const EventFlag flag) const override {
impl_->record(event, stream, device_index, flag);
}
void block(void* event, const Stream& stream) const override {
impl_->block(event, stream);
}
bool queryEvent(void* event) const override {
return impl_->queryEvent(event);
}
void destroyEvent(void* event, const DeviceIndex device_index)
const noexcept override {
impl_->destroyEvent(event, device_index);
}
bool queryStream(const Stream& stream) const override {
return impl_->queryStream(stream);
}
void synchronizeStream(const Stream& stream) const override {
impl_->synchronizeStream(stream);
}
void recordDataPtrOnStream(const c10::DataPtr& data_ptr, const Stream& stream)
const override {
impl_->recordDataPtrOnStream(data_ptr, stream);
}
double elapsedTime(void* event1, void* event2, const DeviceIndex device_index)
const override {
return impl_->elapsedTime(event1, event2, device_index);
}
void synchronizeEvent(void* event) const override {
return impl_->synchronizeEvent(event);
}
private:
const DeviceGuardImplInterface* impl_ = nullptr;
};
} // namespace c10::impl
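// Usage sketch: VirtualGuardImpl lets backend-agnostic guard code work with
// whichever DeviceGuardImplInterface is registered for a device type:
//
//   c10::impl::VirtualGuardImpl impl(c10::DeviceType::CUDA);
//   c10::Device prev =
//       impl.exchangeDevice(c10::Device(c10::DeviceType::CUDA, 1));
//   // ... work on device 1 ...
//   impl.setDevice(prev);  // restore the previous device
//
// This is essentially what c10::DeviceGuard builds on when it is constructed
// from a runtime Device rather than a compile-time backend.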

View File

@ -0,0 +1,12 @@
#pragma once
#include <c10/macros/Export.h>
#include <cstddef>
namespace c10 {
C10_API void* alloc_cpu(size_t nbytes);
C10_API void free_cpu(void* data);
} // namespace c10
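// Usage sketch: these are the raw, untyped CPU allocation primitives; every
// alloc_cpu must be paired with a free_cpu by the caller:
//
//   void* buf = c10::alloc_cpu(64 * sizeof(float));
//   // ... fill / read buf ...
//   c10::free_cpu(buf);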

View File

@ -0,0 +1,120 @@
#pragma once
#include <atomic>
#include <condition_variable>
#include <cstddef>
#include <functional>
#include <mutex>
#include <queue>
#include <thread>
#include <utility>
#include <vector>
#include <c10/macros/Export.h>
#include <c10/util/Registry.h>
#include <c10/util/numa.h>
#include <c10/util/thread_name.h>
namespace c10 {
class C10_API TaskThreadPoolBase {
public:
virtual void run(std::function<void()> func) = 0;
virtual size_t size() const = 0;
/**
* The number of available (i.e. idle) threads in this thread pool.
*/
virtual size_t numAvailable() const = 0;
/**
* Check if the current thread is from the thread pool.
*/
virtual bool inThreadPool() const = 0;
virtual ~TaskThreadPoolBase() noexcept = default;
static size_t defaultNumThreads();
};
class C10_API ThreadPool : public c10::TaskThreadPoolBase {
protected:
struct task_element_t {
bool run_with_id;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
const std::function<void()> no_id;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
const std::function<void(std::size_t)> with_id;
explicit task_element_t(std::function<void()> f)
: run_with_id(false), no_id(std::move(f)), with_id(nullptr) {}
explicit task_element_t(std::function<void(std::size_t)> f)
: run_with_id(true), no_id(nullptr), with_id(std::move(f)) {}
};
std::queue<task_element_t> tasks_;
std::vector<std::thread> threads_;
mutable std::mutex mutex_;
std::condition_variable condition_;
std::condition_variable completed_;
std::atomic_bool running_;
bool complete_;
std::size_t available_;
std::size_t total_;
int numa_node_id_;
public:
ThreadPool() = delete;
explicit ThreadPool(
int pool_size,
int numa_node_id = -1,
const std::function<void()>& init_thread = nullptr);
~ThreadPool() override;
size_t size() const override;
size_t numAvailable() const override;
bool inThreadPool() const override;
void run(std::function<void()> func) override;
template <typename Task>
void runTaskWithID(Task task) {
std::unique_lock<std::mutex> lock(mutex_);
// Set task and signal condition variable so that a worker thread will
// wake up and use the task.
tasks_.emplace(static_cast<std::function<void(std::size_t)>>(task));
complete_ = false;
condition_.notify_one();
}
/// @brief Wait for queue to be empty
void waitWorkComplete();
private:
// @brief Entry point for pool threads.
void main_loop(std::size_t index);
};
class C10_API TaskThreadPool : public c10::ThreadPool {
public:
explicit TaskThreadPool(int pool_size, int numa_node_id = -1)
: ThreadPool(pool_size, numa_node_id, [numa_node_id]() {
setThreadName("CaffeTaskThread");
NUMABind(numa_node_id);
}) {}
};
C10_DECLARE_SHARED_REGISTRY(
ThreadPoolRegistry,
TaskThreadPoolBase,
int,
int,
bool);
} // namespace c10
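// Usage sketch (illustrative): a TaskThreadPool names its workers and binds
// them to the given NUMA node; run() enqueues work and waitWorkComplete()
// blocks until the queue drains:
//
//   c10::TaskThreadPool pool(/*pool_size=*/4);
//   for (int i = 0; i < 8; ++i) {
//     pool.run([i] { /* process work item i */ });
//   }
//   pool.waitWorkComplete();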

View File

@ -0,0 +1,31 @@
#ifdef THRUST_DEVICE_LOWER_BOUND_WORKS
#include <thrust/binary_search.h>
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/functional.h>
#endif
namespace c10::cuda {
#ifdef THRUST_DEVICE_LOWER_BOUND_WORKS
template <typename Iter, typename Scalar>
__forceinline__ __device__ Iter
lower_bound(Iter start, Iter end, Scalar value) {
return thrust::lower_bound(thrust::device, start, end, value);
}
#else
// thrust::lower_bound is broken on device, see
// https://github.com/NVIDIA/thrust/issues/1734 Implementation inspired by
// https://github.com/pytorch/pytorch/blob/805120ab572efef66425c9f595d9c6c464383336/aten/src/ATen/native/cuda/Bucketization.cu#L28
template <typename Iter, typename Scalar>
__device__ Iter lower_bound(Iter start, Iter end, Scalar value) {
while (start < end) {
auto mid = start + ((end - start) >> 1);
if (*mid < value) {
start = mid + 1;
} else {
end = mid;
}
}
return end;
}
#endif // THRUST_DEVICE_LOWER_BOUND_WORKS
} // namespace c10::cuda
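// Usage sketch inside a kernel (illustrative), given a sorted device array of
// bucket boundaries:
//
//   __global__ void bucketize_one(
//       const int64_t* boundaries, int64_t n, int64_t value, int64_t* out) {
//     const int64_t* pos =
//         c10::cuda::lower_bound(boundaries, boundaries + n, value);
//     *out = pos - boundaries;  // index of the first boundary >= value
//   }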

View File

@ -0,0 +1,124 @@
#pragma once
#include <c10/cuda/CUDAMacros.h>
#include <c10/util/Exception.h>
#include <atomic>
#include <cstddef>
#include <cstdlib>
#include <mutex>
#include <string>
#include <vector>
namespace c10::cuda::CUDACachingAllocator {
// Environment config parser
class C10_CUDA_API CUDAAllocatorConfig {
public:
static size_t max_split_size() {
return instance().m_max_split_size;
}
static double garbage_collection_threshold() {
return instance().m_garbage_collection_threshold;
}
static bool expandable_segments() {
#ifndef PYTORCH_C10_DRIVER_API_SUPPORTED
if (instance().m_expandable_segments) {
TORCH_WARN_ONCE("expandable_segments not supported on this platform")
}
return false;
#else
return instance().m_expandable_segments;
#endif
}
static bool release_lock_on_cudamalloc() {
return instance().m_release_lock_on_cudamalloc;
}
/** Pinned memory allocator settings */
static bool pinned_use_cuda_host_register() {
return instance().m_pinned_use_cuda_host_register;
}
static size_t pinned_num_register_threads() {
return instance().m_pinned_num_register_threads;
}
static size_t pinned_max_register_threads() {
// Based on the benchmark results, we see better allocation performance
// with 8 threads. However, on future systems we may need more threads,
// so we limit this to 128 threads.
return 128;
}
// This is used to round up allocation sizes to the nearest power-of-2 division.
// More description below in the function roundup_power2_next_division.
// As an example, if we want 4 divisions between powers of 2, this can be done
// using the env variable: PYTORCH_CUDA_ALLOC_CONF=roundup_power2_divisions:4
static size_t roundup_power2_divisions(size_t size);
static std::vector<size_t> roundup_power2_divisions() {
return instance().m_roundup_power2_divisions;
}
static std::string last_allocator_settings() {
std::lock_guard<std::mutex> lock(
instance().m_last_allocator_settings_mutex);
return instance().m_last_allocator_settings;
}
static CUDAAllocatorConfig& instance() {
static CUDAAllocatorConfig* s_instance = ([]() {
auto inst = new CUDAAllocatorConfig();
const char* env = getenv("PYTORCH_CUDA_ALLOC_CONF");
inst->parseArgs(env);
return inst;
})();
return *s_instance;
}
void parseArgs(const char* env);
private:
CUDAAllocatorConfig();
static void lexArgs(const char* env, std::vector<std::string>& config);
static void consumeToken(
const std::vector<std::string>& config,
size_t i,
const char c);
size_t parseMaxSplitSize(const std::vector<std::string>& config, size_t i);
size_t parseGarbageCollectionThreshold(
const std::vector<std::string>& config,
size_t i);
size_t parseRoundUpPower2Divisions(
const std::vector<std::string>& config,
size_t i);
size_t parseAllocatorConfig(
const std::vector<std::string>& config,
size_t i,
bool& used_cudaMallocAsync);
size_t parsePinnedUseCudaHostRegister(
const std::vector<std::string>& config,
size_t i);
size_t parsePinnedNumRegisterThreads(
const std::vector<std::string>& config,
size_t i);
std::atomic<size_t> m_max_split_size;
std::vector<size_t> m_roundup_power2_divisions;
std::atomic<double> m_garbage_collection_threshold;
std::atomic<size_t> m_pinned_num_register_threads;
std::atomic<bool> m_expandable_segments;
std::atomic<bool> m_release_lock_on_cudamalloc;
std::atomic<bool> m_pinned_use_cuda_host_register;
std::string m_last_allocator_settings;
std::mutex m_last_allocator_settings_mutex;
};
// General caching allocator utilities
C10_CUDA_API void setAllocatorSettings(const std::string& env);
} // namespace c10::cuda::CUDACachingAllocator
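// Usage sketch: the config is parsed from PYTORCH_CUDA_ALLOC_CONF the first
// time instance() is reached; afterwards the accessors are cheap reads, e.g.
// with PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128,roundup_power2_divisions:4:
//
//   using c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig;
//   size_t max_split = CUDAAllocatorConfig::max_split_size();
//   double gc_threshold = CUDAAllocatorConfig::garbage_collection_threshold();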

View File

@ -0,0 +1,499 @@
#pragma once
#include <c10/core/CachingDeviceAllocator.h>
#include <c10/cuda/CUDAGraphsC10Utils.h>
#include <c10/cuda/CUDAMacros.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/util/ApproximateClock.h>
#include <c10/util/Exception.h>
#include <c10/util/Registry.h>
#include <array>
#include <atomic>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
namespace c10 {
// The caching allocator will execute every registered callback if it is unable
// to find a block inside the already allocated area.
class C10_CUDA_API FreeMemoryCallback {
public:
virtual ~FreeMemoryCallback() = default;
virtual bool Execute() = 0;
};
C10_DECLARE_REGISTRY(FreeCudaMemoryCallbacksRegistry, FreeMemoryCallback);
#define REGISTER_FREE_MEMORY_CALLBACK(name, ...) \
C10_REGISTER_CLASS(FreeCudaMemoryCallbacksRegistry, name, __VA_ARGS__);
} // namespace c10
//
// TODO: Turn this into an honest to goodness class. I briefly attempted to do
// this, but it was a bit irritating to figure out how to also correctly
// apply pimpl pattern so I didn't have to leak any internal implementation
// details in the header (CUDACachingAllocator could be made a pimpl, but
// you also need to appropriately define a class which is a subclass
// of Allocator. Not impossible, but required a bit more surgery than
// I wanted to do at the time.)
//
// Why is this using a namespace rather than old-style THCCachingAllocator_
// prefix? Mostly because it made the HIPify rules easier to write; _ is
// not counted as a word boundary, so you would otherwise have to list each
// of these functions.
namespace c10::cuda::CUDACachingAllocator {
// Preserved only for BC reasons
// NOLINTNEXTLINE(misc-unused-using-decls)
using c10::CachingDeviceAllocator::DeviceStats;
extern const size_t kLargeBuffer;
typedef std::shared_ptr<GatheredContext> (*CreateContextFn)();
// Struct containing info about an allocation block (i.e. a fractional part of
// a cudaMalloc).
struct BlockInfo {
size_t size = 0;
size_t requested_size = 0;
int32_t gc_counter = 0;
bool allocated = false;
bool active = false;
std::shared_ptr<GatheredContext>
context_when_allocated; // per-watcher context
};
// Struct containing info about a memory segment (i.e. one contiguous cudaMalloc).
struct SegmentInfo {
c10::DeviceIndex device = 0;
size_t address = 0;
size_t total_size = 0;
size_t requested_size = 0; // unrounded, actually requested size
size_t allocated_size = 0;
size_t active_size = 0;
cudaStream_t stream = nullptr;
bool is_large = false;
bool is_expandable = false;
MempoolId_t owner_private_pool_id = {0, 0};
std::vector<BlockInfo> blocks;
std::shared_ptr<GatheredContext> context_when_allocated;
};
struct AllocatorState {
virtual ~AllocatorState() = default;
};
union trace_time_ {
time_t t_;
approx_time_t approx_t_;
};
struct TraceEntry {
enum Action {
ALLOC, // API made to the caching allocator for new memory
FREE_REQUESTED, // API call made to the caching allocator to free memory
FREE_COMPLETED, // The allocator might have to delay a free because
// it is still in use on another stream via record_stream
// This event is generated when a free actually completes.
SEGMENT_ALLOC, // a call to cudaMalloc to get more memory from the OS
SEGMENT_FREE, // a call to cudaFree to return memory to the OS (e.g. to
// defragment or empty_caches)
SEGMENT_MAP, // a call to cuMemMap (used with expandable_segments)
SEGMENT_UNMAP, // unmap part of a segment (used with expandable segments)
SNAPSHOT, // a call to snapshot, used to correlate memory snapshots to trace
// events
OOM // the allocator threw an OutOfMemoryError (addr_ is the amount of free
// bytes reported by cuda)
};
TraceEntry(
Action action,
c10::DeviceIndex device,
size_t addr,
size_t size,
cudaStream_t stream,
approx_time_t time,
std::shared_ptr<GatheredContext> context = nullptr)
: action_(action),
device_(device),
addr_(addr),
context_(std::move(context)),
stream_(stream),
size_(size) {
time_.approx_t_ = time;
}
Action action_;
c10::DeviceIndex device_;
size_t addr_; // for OOM, this is the amount of free bytes reported by cuda
std::shared_ptr<GatheredContext> context_;
cudaStream_t stream_{};
size_t size_;
trace_time_ time_{};
};
// Calls made by record_function will save annotations
struct AnnotationEntry {
AnnotationEntry(c10::DeviceIndex device, approx_time_t time)
: device_(device) {
time_.approx_t_ = time;
}
void recordUserMetadata(const std::string& name, std::string value) {
metadata_[name] = std::move(value);
}
c10::DeviceIndex device_;
trace_time_ time_{};
std::unordered_map<std::string, std::string> metadata_;
};
struct AllocatorConfigInfo {
double garbage_collection_threshold;
size_t max_split_size;
size_t pinned_num_register_threads;
bool expandable_segments;
bool release_lock_on_malloc;
bool pinned_use_host_register;
std::string last_allocator_settings;
std::vector<size_t> roundup_power2_divisions;
};
struct SnapshotInfo {
std::vector<SegmentInfo> segments;
std::vector<std::vector<TraceEntry>> device_traces;
std::vector<AnnotationEntry> external_annotations;
AllocatorConfigInfo config_metadata;
};
// returns the pointers freed in the pool
// and the pointers allocated. Note: a pointer
// may appear in both freed and allocated
struct CheckpointDelta {
std::vector<void*> ptrs_freed;
std::vector<at::DataPtr> dataptrs_allocd;
};
enum struct RecordContext {
NEVER = 0,
STATE = 1, // only keep stacks for active allocations
ALLOC = 2, // additionally keep stacks for allocations in the trace history
ALL = 3, // additionally record stacks for when something is freed
};
using OutOfMemoryObserver = std::function<void(
int64_t device,
size_t allocated,
size_t device_total,
size_t device_free)>;
using AllocatorTraceTracker = std::function<void(const TraceEntry&)>;
struct ShareableHandle {
ptrdiff_t offset;
std::string handle;
};
class CUDAAllocator : public Allocator {
public:
virtual void* raw_alloc(size_t nbytes) = 0;
virtual void* raw_alloc_with_stream(size_t nbytes, cudaStream_t stream) = 0;
virtual void raw_delete(void* ptr) = 0;
virtual void init(int device_count) = 0;
virtual bool initialized() = 0;
virtual void setMemoryFraction(double fraction, c10::DeviceIndex device) = 0;
virtual void emptyCache() = 0;
virtual void cacheInfo(c10::DeviceIndex device, size_t* largestBlock) = 0;
virtual void* getBaseAllocation(void* ptr, size_t* size) = 0;
virtual void recordStream(const DataPtr&, CUDAStream stream) = 0;
virtual c10::CachingDeviceAllocator::DeviceStats getDeviceStats(
c10::DeviceIndex device) = 0;
virtual void resetAccumulatedStats(c10::DeviceIndex device) = 0;
virtual void resetPeakStats(c10::DeviceIndex device) = 0;
virtual SnapshotInfo snapshot() = 0;
virtual void beginAllocateToPool(
c10::DeviceIndex device,
MempoolId_t mempool_id,
std::function<bool(cudaStream_t)> filter) = 0;
virtual void endAllocateToPool(
c10::DeviceIndex device,
MempoolId_t mempool_id) = 0;
virtual void releasePool(c10::DeviceIndex device, MempoolId_t mempool_id) = 0;
// returns true if the allocated blocks are equal to expected live allocations
virtual bool checkPoolLiveAllocations(
c10::DeviceIndex device,
MempoolId_t mempool_id,
const std::unordered_set<void*>& expected_live_allocations) {
TORCH_CHECK(
false,
name(),
" does not yet support checkPoolLiveAllocations. "
"If you need it, please file an issue describing your use case.");
}
virtual ShareableHandle shareIpcHandle(void* ptr) = 0;
virtual std::shared_ptr<void> getIpcDevPtr(std::string handle) = 0;
virtual bool isHistoryEnabled() {
TORCH_CHECK(
false,
name(),
" does not yet support recordHistory. "
"If you need it, please file an issue describing your use case.");
}
virtual void recordHistory(
bool enabled,
CreateContextFn context_recorder,
size_t alloc_trace_max_entries,
RecordContext when) = 0;
virtual void recordAnnotation(
const std::vector<std::pair<std::string, std::string>>& md){};
virtual void attachOutOfMemoryObserver(OutOfMemoryObserver observer) = 0;
// Attached AllocatorTraceTracker callbacks will be called while the
// per-device allocator lock is held. Any additional locks taken from within
// the callback must be proven to always have the lock order that never
// triggers a deadlock. In particular, Python's GIL may be held when
// calling the allocator so it is unsafe to try to acquire the GIL in this
// callback.
virtual void attachAllocatorTraceTracker(AllocatorTraceTracker tracker) = 0;
virtual void enablePeerAccess(
c10::DeviceIndex dev,
c10::DeviceIndex dev_to_access) = 0;
// memory not allocated from cudaMalloc cannot be copied
// across devices using cudaMemcpyAsync if peer-to-peer access is disabled;
// instead it requires cudaMemcpyPeerAsync.
// With P2P enabled, all combinations work.
// With P2P disabled:
//                         cudaMalloc   cudaMallocAsync/cuMemMap
//   cudaMemcpyPeerAsync   works        works
//   cudaMemcpyAsync       works        error
// This function chooses to use the peer version of memcpy, if required,
// based on where the allocator placed dst/src.
virtual cudaError_t memcpyAsync(
void* dst,
int dstDevice,
const void* src,
int srcDevice,
size_t count,
cudaStream_t stream,
bool p2p_enabled) = 0;
virtual std::shared_ptr<AllocatorState> getCheckpointState(
c10::DeviceIndex device,
MempoolId_t id) = 0;
virtual CheckpointDelta setCheckpointPoolState(
c10::DeviceIndex device,
std::shared_ptr<AllocatorState> pps) = 0;
virtual std::string name() = 0;
};
// Allocator object, statically initialized
// See BackendInitializer in CUDACachingAllocator.cpp.
// Atomic loads on x86 are just normal loads
// (atomic stores are different), so reading this value
// is no different from loading a pointer.
C10_CUDA_API extern std::atomic<CUDAAllocator*> allocator;
inline CUDAAllocator* get() {
return allocator.load();
}
// Called directly by clients.
inline void* raw_alloc(size_t nbytes) {
return get()->raw_alloc(nbytes);
}
inline void* raw_alloc_with_stream(size_t nbytes, cudaStream_t stream) {
return get()->raw_alloc_with_stream(nbytes, stream);
}
inline void raw_delete(void* ptr) {
return get()->raw_delete(ptr);
}
inline void init(int device_count) {
return get()->init(device_count);
}
inline void setMemoryFraction(double fraction, c10::DeviceIndex device) {
return get()->setMemoryFraction(fraction, device);
}
inline void emptyCache() {
return get()->emptyCache();
}
inline void cacheInfo(c10::DeviceIndex device, size_t* largestBlock) {
return get()->cacheInfo(device, largestBlock);
}
inline void* getBaseAllocation(void* ptr, size_t* size) {
return get()->getBaseAllocation(ptr, size);
}
inline void recordStream(const DataPtr& dataPtr, CUDAStream stream) {
return get()->recordStream(dataPtr, stream);
}
inline c10::CachingDeviceAllocator::DeviceStats getDeviceStats(
c10::DeviceIndex device) {
return get()->getDeviceStats(device);
}
inline void resetAccumulatedStats(c10::DeviceIndex device) {
return get()->resetAccumulatedStats(device);
}
inline void resetPeakStats(c10::DeviceIndex device) {
return get()->resetPeakStats(device);
}
inline SnapshotInfo snapshot() {
return get()->snapshot();
}
inline std::shared_ptr<AllocatorState> getCheckpointState(
c10::DeviceIndex device,
MempoolId_t id) {
return get()->getCheckpointState(device, id);
}
inline CheckpointDelta setCheckpointPoolState(
c10::DeviceIndex device,
std::shared_ptr<AllocatorState> pps) {
return get()->setCheckpointPoolState(device, std::move(pps));
}
// CUDAGraph interactions
inline void beginAllocateToPool(
c10::DeviceIndex device,
MempoolId_t mempool_id,
std::function<bool(cudaStream_t)> filter) {
get()->beginAllocateToPool(device, mempool_id, std::move(filter));
}
inline void endAllocateToPool(c10::DeviceIndex device, MempoolId_t mempool_id) {
get()->endAllocateToPool(device, mempool_id);
}
inline void recordHistory(
bool enabled,
CreateContextFn context_recorder,
size_t alloc_trace_max_entries,
RecordContext when) {
return get()->recordHistory(
enabled, context_recorder, alloc_trace_max_entries, when);
}
inline void recordAnnotation(
const std::vector<std::pair<std::string, std::string>>& md) {
return get()->recordAnnotation(md);
}
inline bool isHistoryEnabled() {
return get()->isHistoryEnabled();
}
inline bool checkPoolLiveAllocations(
c10::DeviceIndex device,
MempoolId_t mempool_id,
const std::unordered_set<void*>& expected_live_allocations) {
return get()->checkPoolLiveAllocations(
device, mempool_id, expected_live_allocations);
}
inline void attachOutOfMemoryObserver(OutOfMemoryObserver observer) {
return get()->attachOutOfMemoryObserver(std::move(observer));
}
inline void attachAllocatorTraceTracker(AllocatorTraceTracker tracker) {
return get()->attachAllocatorTraceTracker(std::move(tracker));
}
inline void releasePool(c10::DeviceIndex device, MempoolId_t mempool_id) {
return get()->releasePool(device, mempool_id);
}
// Not part of CUDA_ALLOCATOR_BACKEND_INTERFACE
inline std::shared_ptr<void> getIpcDevPtr(std::string handle) {
return get()->getIpcDevPtr(std::move(handle));
}
inline ShareableHandle shareIpcHandle(void* ptr) {
return get()->shareIpcHandle(ptr);
}
inline std::string name() {
return get()->name();
}
inline cudaError_t memcpyAsync(
void* dst,
int dstDevice,
const void* src,
int srcDevice,
size_t count,
cudaStream_t stream,
bool p2p_enabled) {
return get()->memcpyAsync(
dst, dstDevice, src, srcDevice, count, stream, p2p_enabled);
}
inline void enablePeerAccess(
c10::DeviceIndex dev,
c10::DeviceIndex dev_to_access) {
return get()->enablePeerAccess(dev, dev_to_access);
}
} // namespace c10::cuda::CUDACachingAllocator
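// Usage sketch: tensor allocations normally go through the Allocator
// interface, but clients needing raw device memory from the cache can use the
// free functions above directly (the pointer is tied to the current stream):
//
//   void* p = c10::cuda::CUDACachingAllocator::raw_alloc(1 << 20);
//   // ... launch kernels that consume p ...
//   c10::cuda::CUDACachingAllocator::raw_delete(p);
//   c10::cuda::CUDACachingAllocator::emptyCache();  // release cached blocks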
namespace c10::cuda {
// MemPool represents a pool of memory in a caching allocator. Currently,
// it's just the ID of the pool object maintained in the CUDACachingAllocator.
//
// An allocator pointer can be passed to the MemPool to define how the
// allocations should be done in the pool. For example: using a different
// system allocator such as ncclMemAlloc.
struct C10_CUDA_API MemPool {
MemPool(
CUDACachingAllocator::CUDAAllocator* allocator = nullptr,
bool is_user_created = true);
MempoolId_t id();
CUDACachingAllocator::CUDAAllocator* allocator();
private:
static std::atomic<CaptureId_t> uid_;
static std::atomic<CaptureId_t> uuid_;
CUDACachingAllocator::CUDAAllocator* allocator_;
bool is_user_created_;
MempoolId_t id_;
};
// MemPoolContext holds the currently active pool and stashes the previous
// pool. On deletion it makes the previous pool active.
struct C10_CUDA_API MemPoolContext {
MemPoolContext(MemPool* mempool);
~MemPoolContext();
// getActiveMemPool() can be used to get the currently active pool.
// For instance: in CUDACachingAllocator, we can route allocations
// to a user provided allocator, by doing:
//
// auto active_pool = MemPoolContext::getActiveMemPool();
// if (active_pool && active_pool->allocator()) {
// ptr = active_pool->allocator()->raw_alloc(size);
// }
//
static MemPool* getActiveMemPool();
private:
MemPool* prev_mempool_;
};
} // namespace c10::cuda
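// Usage sketch (illustrative): MemPoolContext only establishes which pool is
// "active"; actually capturing allocations into the pool additionally goes
// through the allocator's beginAllocateToPool/endAllocateToPool hooks above:
//
//   c10::cuda::MemPool pool;                 // optionally pass a custom CUDAAllocator*
//   {
//     c10::cuda::MemPoolContext ctx(&pool);  // `pool` becomes the active pool
//     // ... code that consults MemPoolContext::getActiveMemPool() ...
//   }                                        // previous active pool restored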

View File

@ -0,0 +1,96 @@
#pragma once
#include <c10/cuda/CUDAException.h>
#include <c10/macros/Macros.h>
namespace c10::cuda {
#ifdef TORCH_USE_CUDA_DSA
// Copy string from `src` to `dst`
static __device__ void dstrcpy(char* dst, const char* src) {
int i = 0;
// Copy string from source to destination, ensuring that it
// isn't longer than `C10_CUDA_DSA_MAX_STR_LEN-1`
while (*src != '\0' && i++ < C10_CUDA_DSA_MAX_STR_LEN - 1) {
*dst++ = *src++;
}
*dst = '\0';
}
static __device__ void dsa_add_new_assertion_failure(
DeviceAssertionsData* assertions_data,
const char* assertion_msg,
const char* filename,
const char* function_name,
const int line_number,
const uint32_t caller,
const dim3 block_id,
const dim3 thread_id) {
// `assertions_data` may be nullptr if device-side assertion checking
// is disabled at run-time. If it is disabled at compile time this
// function will never be called
if (!assertions_data) {
return;
}
// Atomically increment so other threads can fail at the same time
// Note that incrementing this means that the CPU can observe that
// a failure has happened and can begin to respond before we've
// written information about that failure out to the buffer.
const auto nid = atomicAdd(&(assertions_data->assertion_count), 1);
if (nid >= C10_CUDA_DSA_ASSERTION_COUNT) {
// At this point we've run out of assertion buffer space.
// We could print a message about this, but that'd get
// spammy if a lot of threads did it, so we just silently
// ignore any other assertion failures. In most cases the
// failures will all probably be analogous anyway.
return;
}
// Write information about the assertion failure to memory.
// Note that this occurs only after the `assertion_count`
// increment broadcasts that there's been a problem.
auto& self = assertions_data->assertions[nid];
dstrcpy(self.assertion_msg, assertion_msg);
dstrcpy(self.filename, filename);
dstrcpy(self.function_name, function_name);
self.line_number = line_number;
self.caller = caller;
self.block_id[0] = block_id.x;
self.block_id[1] = block_id.y;
self.block_id[2] = block_id.z;
self.thread_id[0] = thread_id.x;
self.thread_id[1] = thread_id.y;
self.thread_id[2] = thread_id.z;
}
// Emulates a kernel assertion. The assertion won't stop the kernel's progress,
// so you should assume everything the kernel produces is garbage if there's an
// assertion failure.
// NOTE: This assumes that `assertions_data` and `assertion_caller_id` are
// arguments of the kernel and therefore accessible.
#define CUDA_KERNEL_ASSERT2(condition) \
do { \
if (C10_UNLIKELY(!(condition))) { \
/* Has an atomic element so threads can fail at the same time */ \
c10::cuda::dsa_add_new_assertion_failure( \
assertions_data, \
C10_STRINGIZE(condition), \
__FILE__, \
__FUNCTION__, \
__LINE__, \
assertion_caller_id, \
blockIdx, \
threadIdx); \
/* Now that the kernel has failed we early exit the kernel, but */ \
/* otherwise keep going and rely on the host to check UVM and */ \
/* determine we've had a problem */ \
return; \
} \
} while (false)
#else
#define CUDA_KERNEL_ASSERT2(condition) assert(condition)
#endif
} // namespace c10::cuda

View File

@ -0,0 +1,164 @@
#pragma once
#include <c10/cuda/CUDAMacros.h>
#include <cstdint>
#include <memory>
#include <mutex>
#include <string>
#include <utility>
#include <vector>
#ifdef USE_CUDA
#define TORCH_USE_CUDA_DSA
#endif
/// Number of assertion failure messages we can store. If this is too small
/// threads will fail silently.
constexpr int C10_CUDA_DSA_ASSERTION_COUNT = 10;
constexpr int C10_CUDA_DSA_MAX_STR_LEN = 512;
namespace c10::cuda {
/// Holds information about any device-side assertions that fail.
/// Held in managed memory and accessed by both the CPU and the GPU.
struct DeviceAssertionData {
/// Stringification of the assertion
// NOLINTNEXTLINE(*-c-arrays)
char assertion_msg[C10_CUDA_DSA_MAX_STR_LEN]{};
/// File the assertion was in
// NOLINTNEXTLINE(*-c-arrays)
char filename[C10_CUDA_DSA_MAX_STR_LEN]{};
/// Name of the function the assertion was in
// NOLINTNEXTLINE(*-c-arrays)
char function_name[C10_CUDA_DSA_MAX_STR_LEN]{};
/// Line number the assertion was at
int line_number{};
/// Number uniquely identifying the kernel launch that triggered the assertion
uint32_t caller{};
/// block_id of the thread that failed the assertion
// NOLINTNEXTLINE(*-c-arrays)
int32_t block_id[3]{};
/// thread_id of the thread that failed the assertion
// NOLINTNEXTLINE(*-c-arrays)
int32_t thread_id[3]{};
};
/// Used to hold assertions generated by the device
/// Held in managed memory and accessed by both the CPU and the GPU.
struct DeviceAssertionsData {
/// Total number of assertions found; a subset of these will be recorded
/// in `assertions`
int32_t assertion_count{};
/// An array of assertions that will be written to in a race-free manner
// NOLINTNEXTLINE(*-c-arrays)
DeviceAssertionData assertions[C10_CUDA_DSA_ASSERTION_COUNT]{};
};
/// Used to hold info about kernel launches so that we can run kernels
/// asynchronously and still associate launches with device-side
/// assertion failures
struct CUDAKernelLaunchInfo {
/// Filename of the code where the kernel was launched from
const char* launch_filename;
/// Function from which the kernel was launched
const char* launch_function;
/// Line number of where the code was launched from
uint32_t launch_linenum;
/// Backtrace of where the kernel was launched from, only populated if
/// CUDAKernelLaunchRegistry::gather_launch_stacktrace is True
std::string launch_stacktrace;
/// Kernel that was launched
const char* kernel_name;
/// Device the kernel was launched on
int device;
/// Stream the kernel was launched on
int32_t stream;
/// A number that uniquely identifies the kernel launch
uint64_t generation_number;
};
/// Circular buffer used to hold information about kernel launches;
/// this is later used to reconstruct how a device-side kernel assertion failure
/// occurred. CUDAKernelLaunchRegistry is used as a singleton.
class C10_CUDA_API CUDAKernelLaunchRegistry {
private:
/// Assume that this is the max number of kernel launches that might ever be
/// enqueued across all streams on a single device
static constexpr int max_kernel_launches = 1024;
/// How many kernel launch infos we've inserted. Because it only ever
/// increases, it ensures the circular queue doesn't provide stale information,
/// and it also marks where we are inserting into the queue.
#ifdef TORCH_USE_CUDA_DSA
uint64_t generation_number = 0;
#endif
/// Shared mutex between writer and accessor to ensure multi-threaded safety.
mutable std::mutex read_write_mutex;
/// Used to prevent race conditions in GPU memory allocation
mutable std::mutex gpu_alloc_mutex;
/// Pointer to managed memory keeping track of device-side assertions. There
/// is one entry for each possible device the process might work with. Unused
/// entries are nullptrs. We could also use an unordered_set here, but this
/// vector design will be faster and the wasted memory is small since we
/// expect the number of GPUs per node will always be small
std::vector<
std::unique_ptr<DeviceAssertionsData, void (*)(DeviceAssertionsData*)>>
uvm_assertions;
/// A single circular buffer holds information about every kernel launch the
/// process makes across all devices.
std::vector<CUDAKernelLaunchInfo> kernel_launches;
bool check_env_for_enable_launch_stacktracing() const;
bool check_env_for_dsa_enabled() const;
public:
CUDAKernelLaunchRegistry();
/// Register a new kernel launch and obtain a generation number back to be
/// passed to the kernel
uint32_t insert(
const char* launch_filename,
const char* launch_function,
const uint32_t launch_linenum,
const char* kernel_name,
const int32_t stream_id);
/// Get copies of the kernel launch registry and each device's assertion
/// failure buffer so they can be inspected without raising race conditions
std::
pair<std::vector<DeviceAssertionsData>, std::vector<CUDAKernelLaunchInfo>>
snapshot() const;
/// Get a pointer to the current device's assertion failure buffer. If no such
/// buffer exists then one is created. This means that the first kernel launch
/// made on each device will be slightly slower because memory allocations are
/// required
DeviceAssertionsData* get_uvm_assertions_ptr_for_current_device();
/// Gets the global singleton of the registry
static CUDAKernelLaunchRegistry& get_singleton_ref();
/// If not all devices support DSA, we disable it
const bool do_all_devices_support_managed_memory = false;
/// Whether or not to gather stack traces when launching kernels
bool gather_launch_stacktrace = false;
/// Whether or not host-side DSA is enabled or disabled at run-time
/// Note: Device-side code cannot be enabled/disabled at run-time
bool enabled_at_runtime = false;
/// Whether or not a device has indicated a failure
bool has_failed() const;
#ifdef TORCH_USE_CUDA_DSA
const bool enabled_at_compile_time = true;
#else
const bool enabled_at_compile_time = false;
#endif
};
std::string c10_retrieve_device_side_assertion_info();
} // namespace c10::cuda
// Each kernel launched with TORCH_DSA_KERNEL_LAUNCH
// requires the same input arguments. We introduce the following macro to
// standardize these.
#define TORCH_DSA_KERNEL_ARGS \
[[maybe_unused]] c10::cuda::DeviceAssertionsData *const assertions_data, \
[[maybe_unused]] uint32_t assertion_caller_id
// This macro can be used to pass the DSA arguments onward to another
// function
#define TORCH_DSA_KERNEL_ARGS_PASS assertions_data, assertion_caller_id
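// Illustrative sketch of how these macros compose (the kernel and helper
// below are hypothetical, not part of this header; it assumes the DSA kernel
// assertion macro from c10/macros/Macros.h is in scope):
//
//   __device__ void check_positive(int v, TORCH_DSA_KERNEL_ARGS) {
//     CUDA_KERNEL_ASSERT2(v > 0);
//   }
//
//   __global__ void my_kernel(const int* data, TORCH_DSA_KERNEL_ARGS) {
//     check_positive(data[threadIdx.x], TORCH_DSA_KERNEL_ARGS_PASS);
//   }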

View File

@ -0,0 +1,100 @@
#pragma once
#include <c10/cuda/CUDADeviceAssertionHost.h>
#include <c10/cuda/CUDAMacros.h>
#include <c10/cuda/CUDAMiscFunctions.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <cuda.h>
// Note [CHECK macro]
// ~~~~~~~~~~~~~~~~~~
// This is a macro so that AT_ERROR can get accurate __LINE__
// and __FILE__ information. We could split this into a short
// macro and a function implementation if we pass along __LINE__
// and __FILE__, but no one has found this worth doing.
// Used to denote errors from CUDA framework.
// This needs to be declared here instead of in util/Exception.h for proper conversion
// during hipify.
namespace c10 {
class C10_CUDA_API CUDAError : public c10::Error {
using Error::Error;
};
} // namespace c10
#define C10_CUDA_CHECK(EXPR) \
do { \
const cudaError_t __err = EXPR; \
c10::cuda::c10_cuda_check_implementation( \
static_cast<int32_t>(__err), \
__FILE__, \
__func__, /* Line number data type not well-defined between \
compilers, so we perform an explicit cast */ \
static_cast<uint32_t>(__LINE__), \
true); \
} while (0)
#define C10_CUDA_CHECK_WARN(EXPR) \
do { \
const cudaError_t __err = EXPR; \
if (C10_UNLIKELY(__err != cudaSuccess)) { \
auto error_unused C10_UNUSED = cudaGetLastError(); \
(void)error_unused; \
TORCH_WARN("CUDA warning: ", cudaGetErrorString(__err)); \
} \
} while (0)
// Indicates that a CUDA error is handled in a non-standard way
#define C10_CUDA_ERROR_HANDLED(EXPR) EXPR
// Intentionally ignore a CUDA error
#define C10_CUDA_IGNORE_ERROR(EXPR) \
do { \
const cudaError_t __err = EXPR; \
if (C10_UNLIKELY(__err != cudaSuccess)) { \
cudaError_t error_unused C10_UNUSED = cudaGetLastError(); \
(void)error_unused; \
} \
} while (0)
// Clear the last CUDA error
#define C10_CUDA_CLEAR_ERROR() \
do { \
cudaError_t error_unused C10_UNUSED = cudaGetLastError(); \
(void)error_unused; \
} while (0)
// This should be used directly after every kernel launch to ensure
// the launch happened correctly and provide an early, close-to-source
// diagnostic if it didn't.
#define C10_CUDA_KERNEL_LAUNCH_CHECK() C10_CUDA_CHECK(cudaGetLastError())
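// Typical usage sketch (the kernel, buffer, and stream below are hypothetical):
//
//   float* buf = nullptr;
//   C10_CUDA_CHECK(cudaMalloc(&buf, 1024 * sizeof(float)));
//   my_kernel<<<4, 256, 0, stream>>>(buf); // stream: some cudaStream_t
//   C10_CUDA_KERNEL_LAUNCH_CHECK();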
/// Launches a CUDA kernel, appending to it all the information needed to handle
/// device-side assertion failures. Checks that the launch was successful.
#define TORCH_DSA_KERNEL_LAUNCH( \
kernel, blocks, threads, shared_mem, stream, ...) \
do { \
auto& launch_registry = \
c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref(); \
kernel<<<blocks, threads, shared_mem, stream>>>( \
__VA_ARGS__, \
launch_registry.get_uvm_assertions_ptr_for_current_device(), \
launch_registry.insert( \
__FILE__, __FUNCTION__, __LINE__, #kernel, stream.id())); \
C10_CUDA_KERNEL_LAUNCH_CHECK(); \
} while (0)
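// Host-side sketch (hypothetical names): launching a kernel whose trailing
// parameters are TORCH_DSA_KERNEL_ARGS, so the macro can append the
// bookkeeping arguments and then check the launch:
//
//   auto stream = c10::cuda::getCurrentCUDAStream();
//   TORCH_DSA_KERNEL_LAUNCH(
//       my_kernel, /*blocks=*/4, /*threads=*/256, /*shared_mem=*/0, stream, data);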
namespace c10::cuda {
/// In the event of a CUDA failure, formats a nice error message about that
/// failure and also checks for device-side assertion failures
C10_CUDA_API void c10_cuda_check_implementation(
const int32_t err,
const char* filename,
const char* function_name,
const int line_number,
const bool include_device_assertions);
} // namespace c10::cuda

View File

@ -0,0 +1,116 @@
#pragma once
// This header provides C++ wrappers around commonly used CUDA API functions.
// The benefit of using C++ here is that we can raise an exception in the
// event of an error, rather than explicitly pass around error codes. This
// leads to more natural APIs.
//
// The naming convention used here matches the naming convention of torch.cuda
#include <c10/core/Device.h>
#include <c10/core/impl/GPUTrace.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAMacros.h>
#include <cuda_runtime_api.h>
namespace c10::cuda {
// NB: In the past, we were inconsistent about whether this reported an error
// when there were driver problems. Based on experience
// interacting with users, it seems that people basically ~never want this
// function to fail; it should just return zero if things are not working.
// Oblige them.
// It might still log a warning the first time a user invokes it.
C10_CUDA_API DeviceIndex device_count() noexcept;
// Version of device_count that throws if no devices are detected
C10_CUDA_API DeviceIndex device_count_ensure_non_zero();
C10_CUDA_API DeviceIndex current_device();
C10_CUDA_API void set_device(DeviceIndex device);
C10_CUDA_API void device_synchronize();
C10_CUDA_API void warn_or_error_on_sync();
// Raw CUDA device management functions
C10_CUDA_API cudaError_t GetDeviceCount(int* dev_count);
C10_CUDA_API cudaError_t GetDevice(DeviceIndex* device);
C10_CUDA_API cudaError_t SetDevice(DeviceIndex device);
C10_CUDA_API cudaError_t MaybeSetDevice(DeviceIndex device);
C10_CUDA_API DeviceIndex ExchangeDevice(DeviceIndex device);
C10_CUDA_API DeviceIndex MaybeExchangeDevice(DeviceIndex device);
C10_CUDA_API void SetTargetDevice();
enum class SyncDebugMode { L_DISABLED = 0, L_WARN, L_ERROR };
// this is a holder for c10 global state (similar to at GlobalContext)
// currently it's used to store cuda synchronization warning state,
// but can be expanded to hold other related global state, e.g. to
// record stream usage
class WarningState {
public:
void set_sync_debug_mode(SyncDebugMode l) {
sync_debug_mode = l;
}
SyncDebugMode get_sync_debug_mode() {
return sync_debug_mode;
}
private:
SyncDebugMode sync_debug_mode = SyncDebugMode::L_DISABLED;
};
C10_CUDA_API __inline__ WarningState& warning_state() {
static WarningState warning_state_;
return warning_state_;
}
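// For example (sketch), a debugging session could ask for a warning on every
// implicit synchronization:
//
//   c10::cuda::warning_state().set_sync_debug_mode(
//       c10::cuda::SyncDebugMode::L_WARN);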
// the subsequent functions are defined in the header because for performance
// reasons we want them to be inline
C10_CUDA_API void __inline__ memcpy_and_sync(
void* dst,
const void* src,
int64_t nbytes,
cudaMemcpyKind kind,
cudaStream_t stream) {
if (C10_UNLIKELY(
warning_state().get_sync_debug_mode() != SyncDebugMode::L_DISABLED)) {
warn_or_error_on_sync();
}
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_stream_synchronization(
c10::kCUDA, reinterpret_cast<uintptr_t>(stream));
}
#if defined(TORCH_HIP_VERSION) && (TORCH_HIP_VERSION >= 301)
C10_CUDA_CHECK(hipMemcpyWithStream(dst, src, nbytes, kind, stream));
#else
C10_CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, kind, stream));
C10_CUDA_CHECK(cudaStreamSynchronize(stream));
#endif
}
C10_CUDA_API void __inline__ stream_synchronize(cudaStream_t stream) {
if (C10_UNLIKELY(
warning_state().get_sync_debug_mode() != SyncDebugMode::L_DISABLED)) {
warn_or_error_on_sync();
}
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_stream_synchronization(
c10::kCUDA, reinterpret_cast<uintptr_t>(stream));
}
C10_CUDA_CHECK(cudaStreamSynchronize(stream));
}
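// Illustrative usage (sketch; the buffers and byte count are hypothetical):
//
//   // Copy device results back to the host and block until the copy is done,
//   // going through the sync-debug bookkeeping above.
//   c10::cuda::memcpy_and_sync(
//       host_dst, device_src, nbytes, cudaMemcpyDeviceToHost,
//       c10::cuda::getCurrentCUDAStream());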
C10_CUDA_API bool hasPrimaryContext(DeviceIndex device_index);
C10_CUDA_API std::optional<DeviceIndex> getDeviceIndexWithPrimaryContext();
} // namespace c10::cuda

View File

@ -0,0 +1,77 @@
#pragma once
#include <c10/cuda/CUDAStream.h>
#include <iostream>
#include <utility>
// CUDA Graphs utils used by c10 and aten.
// aten/cuda/CUDAGraphsUtils.cuh adds utils used by aten only.
namespace c10::cuda {
using CaptureId_t = unsigned long long;
// first is set if the instance is created by CUDAGraph::capture_begin.
// second is set if the instance is created by at::cuda::graph_pool_handle.
using MempoolId_t = std::pair<CaptureId_t, CaptureId_t>;
// RAII guard for "cudaStreamCaptureMode", a thread-local value
// that controls the error-checking strictness of a capture.
struct C10_CUDA_API CUDAStreamCaptureModeGuard {
CUDAStreamCaptureModeGuard(cudaStreamCaptureMode desired)
: strictness_(desired) {
C10_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&strictness_));
}
~CUDAStreamCaptureModeGuard() {
C10_CUDA_CHECK_WARN(cudaThreadExchangeStreamCaptureMode(&strictness_));
}
private:
cudaStreamCaptureMode strictness_;
};
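// Sketch of intended usage: relax capture-mode error checking for a scope,
// e.g. while another thread is capturing a graph:
//
//   {
//     CUDAStreamCaptureModeGuard guard(cudaStreamCaptureModeRelaxed);
//     // ... CUDA work that the thread-local default mode would reject ...
//   } // previous capture mode restored here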
// Protects against enum cudaStreamCaptureStatus implementation changes.
// Some compilers seem not to like static_assert without the messages.
static_assert(
int(cudaStreamCaptureStatus::cudaStreamCaptureStatusNone) == 0,
"unexpected int(cudaStreamCaptureStatusNone) value");
static_assert(
int(cudaStreamCaptureStatus::cudaStreamCaptureStatusActive) == 1,
"unexpected int(cudaStreamCaptureStatusActive) value");
static_assert(
int(cudaStreamCaptureStatus::cudaStreamCaptureStatusInvalidated) == 2,
"unexpected int(cudaStreamCaptureStatusInvalidated) value");
enum class CaptureStatus : int {
None = int(cudaStreamCaptureStatus::cudaStreamCaptureStatusNone),
Active = int(cudaStreamCaptureStatus::cudaStreamCaptureStatusActive),
Invalidated = int(cudaStreamCaptureStatus::cudaStreamCaptureStatusInvalidated)
};
inline std::ostream& operator<<(std::ostream& os, CaptureStatus status) {
switch (status) {
case CaptureStatus::None:
os << "cudaStreamCaptureStatusNone";
break;
case CaptureStatus::Active:
os << "cudaStreamCaptureStatusActive";
break;
case CaptureStatus::Invalidated:
os << "cudaStreamCaptureStatusInvalidated";
break;
default:
TORCH_INTERNAL_ASSERT(
false, "Unknown CUDA graph CaptureStatus", int(status));
}
return os;
}
// Use this version where you're sure a CUDA context exists already.
inline CaptureStatus currentStreamCaptureStatusMayInitCtx() {
cudaStreamCaptureStatus is_capturing{cudaStreamCaptureStatusNone};
C10_CUDA_CHECK(
cudaStreamIsCapturing(c10::cuda::getCurrentCUDAStream(), &is_capturing));
return CaptureStatus(is_capturing);
}
} // namespace c10::cuda

View File

@ -0,0 +1,301 @@
#pragma once
#include <c10/core/DeviceType.h>
#include <c10/core/impl/InlineDeviceGuard.h>
#include <c10/core/impl/InlineStreamGuard.h>
#include <c10/cuda/CUDAMacros.h>
#include <c10/cuda/impl/CUDAGuardImpl.h>
namespace c10::cuda {
// This code is kind of boilerplatey. See Note [Whither the DeviceGuard
// boilerplate]
/// A variant of DeviceGuard that is specialized for CUDA. It accepts
/// integer indices (interpreting them as CUDA devices) and is a little
/// more efficient than DeviceGuard (it compiles to straight line
/// cudaSetDevice/cudaGetDevice calls); however, it can only be used
/// from code that links against CUDA directly.
struct CUDAGuard {
/// No default constructor; see Note [Omitted default constructor from RAII]
explicit CUDAGuard() = delete;
/// Set the current CUDA device to the passed device index.
explicit CUDAGuard(DeviceIndex device_index) : guard_(device_index) {}
/// Sets the current CUDA device to the passed device. Errors if the passed
/// device is not a CUDA device.
explicit CUDAGuard(Device device) : guard_(device) {}
// Copy is not allowed
CUDAGuard(const CUDAGuard&) = delete;
CUDAGuard& operator=(const CUDAGuard&) = delete;
// Move is not allowed (there is no uninitialized state)
CUDAGuard(CUDAGuard&& other) = delete;
CUDAGuard& operator=(CUDAGuard&& other) = delete;
/// Sets the CUDA device to the given device. Errors if the given device
/// is not a CUDA device.
void set_device(Device device) {
guard_.set_device(device);
}
/// Sets the CUDA device to the given device. Errors if the given device
/// is not a CUDA device. (This method is provided for uniformity with
/// DeviceGuard).
void reset_device(Device device) {
guard_.reset_device(device);
}
/// Sets the CUDA device to the given device index.
void set_index(DeviceIndex device_index) {
guard_.set_index(device_index);
}
/// Returns the device that was set upon construction of the guard
Device original_device() const {
return guard_.original_device();
}
/// Returns the last device that was set via `set_device`, if any, otherwise
/// the device passed during construction.
Device current_device() const {
return guard_.current_device();
}
private:
/// The guard for the current device.
c10::impl::InlineDeviceGuard<impl::CUDAGuardImpl> guard_;
};
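// Example usage (sketch):
//
//   {
//     CUDAGuard guard(1); // switch the current device to CUDA device 1
//     // ... allocations and kernel launches now target device 1 ...
//   } // original device restored when the guard is destroyed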
/// A variant of OptionalDeviceGuard that is specialized for CUDA. See
/// CUDAGuard for when you can use this.
struct OptionalCUDAGuard {
/// Create an uninitialized OptionalCUDAGuard.
explicit OptionalCUDAGuard() : guard_() {}
/// Set the current CUDA device to the passed Device, if it is not nullopt.
explicit OptionalCUDAGuard(std::optional<Device> device_opt)
: guard_(device_opt) {}
/// Set the current CUDA device to the passed device index, if it is not
/// nullopt
explicit OptionalCUDAGuard(std::optional<DeviceIndex> device_index_opt)
: guard_(device_index_opt) {}
// Copy is not allowed
OptionalCUDAGuard(const OptionalCUDAGuard&) = delete;
OptionalCUDAGuard& operator=(const OptionalCUDAGuard&) = delete;
// See Note [Move construction for RAII guards is tricky]
OptionalCUDAGuard(OptionalCUDAGuard&& other) = delete;
// See Note [Move assignment for RAII guards is tricky]
OptionalCUDAGuard& operator=(OptionalCUDAGuard&& other) = delete;
/// Sets the CUDA device to the given device, initializing the guard if it
/// is not already initialized. Errors if the given device is not a CUDA
/// device.
void set_device(Device device) {
guard_.set_device(device);
}
/// Sets the CUDA device to the given device, initializing the guard if it is
/// not already initialized. Errors if the given device is not a CUDA device.
/// (This method is provided for uniformity with OptionalDeviceGuard).
void reset_device(Device device) {
guard_.reset_device(device);
}
/// Sets the CUDA device to the given device index, initializing the guard if
/// it is not already initialized.
void set_index(DeviceIndex device_index) {
guard_.set_index(device_index);
}
/// Returns the device that was set immediately prior to initialization of the
/// guard, or nullopt if the guard is uninitialized.
std::optional<Device> original_device() const {
return guard_.original_device();
}
/// Returns the most recent device that was set using this device guard,
/// either from construction, or via set_device, if the guard is initialized,
/// or nullopt if the guard is uninitialized.
std::optional<Device> current_device() const {
return guard_.current_device();
}
/// Restore the original CUDA device, resetting this guard to uninitialized
/// state.
void reset() {
guard_.reset();
}
private:
c10::impl::InlineOptionalDeviceGuard<impl::CUDAGuardImpl> guard_;
};
/// A variant of StreamGuard that is specialized for CUDA. See CUDAGuard
/// for when you can use this.
struct CUDAStreamGuard {
/// No default constructor, see Note [Omitted default constructor from RAII]
explicit CUDAStreamGuard() = delete;
/// Set the current CUDA device to the device associated with the passed
/// stream, and set the current CUDA stream on that device to the passed
/// stream. Errors if the Stream is not a CUDA stream.
explicit CUDAStreamGuard(Stream stream) : guard_(stream) {}
/// Copy is disallowed
CUDAStreamGuard(const CUDAStreamGuard&) = delete;
CUDAStreamGuard& operator=(const CUDAStreamGuard&) = delete;
/// Move is disallowed, as CUDAStreamGuard does not have an uninitialized
/// state, which is required for moves on types with nontrivial destructors.
CUDAStreamGuard(CUDAStreamGuard&& other) = delete;
CUDAStreamGuard& operator=(CUDAStreamGuard&& other) = delete;
/// Resets the currently set stream to the original stream and
/// the currently set device to the original device. Then,
/// set the current device to the device associated with the passed stream,
/// and set the current stream on that device to the passed stream.
/// Errors if the stream passed is not a CUDA stream.
///
/// NOTE: this implementation may skip some stream/device setting if
/// it can prove that it is unnecessary.
///
/// WARNING: reset_stream does NOT preserve previously set streams on
/// different devices. If you need to set streams on multiple devices
/// on CUDA, use CUDAMultiStreamGuard instead.
void reset_stream(Stream stream) {
guard_.reset_stream(stream);
}
/// Returns the CUDA stream that was set at the time the guard was
/// constructed.
CUDAStream original_stream() const {
return CUDAStream(CUDAStream::UNCHECKED, guard_.original_stream());
}
/// Returns the most recent CUDA stream that was set using this device guard,
/// either from construction, or via set_stream.
CUDAStream current_stream() const {
return CUDAStream(CUDAStream::UNCHECKED, guard_.current_stream());
}
/// Returns the most recent CUDA device that was set using this device guard,
/// either from construction, or via set_device/reset_device/set_index.
Device current_device() const {
return guard_.current_device();
}
/// Returns the CUDA device that was set at the most recent reset_stream(),
/// or otherwise the device at construction time.
Device original_device() const {
return guard_.original_device();
}
private:
c10::impl::InlineStreamGuard<impl::CUDAGuardImpl> guard_;
};
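// Example usage (sketch; the pool stream below is illustrative):
//
//   CUDAStream s = getStreamFromPool();
//   {
//     CUDAStreamGuard guard(s); // also switches to s's device
//     // ... work enqueued here runs on s ...
//   } // original stream and device restored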
/// A variant of OptionalStreamGuard that is specialized for CUDA. See
/// CUDAGuard for when you can use this.
struct OptionalCUDAStreamGuard {
/// Create an uninitialized guard.
explicit OptionalCUDAStreamGuard() : guard_() {}
/// Set the current CUDA device to the device associated with the passed
/// stream, and set the current CUDA stream on that device to the passed
/// stream. Errors if the Stream is not a CUDA stream.
explicit OptionalCUDAStreamGuard(Stream stream) : guard_(stream) {}
/// Set the current device to the device associated with the passed stream,
/// and set the current stream on that device to the passed stream,
/// if the passed stream is not nullopt.
explicit OptionalCUDAStreamGuard(std::optional<Stream> stream_opt)
: guard_(stream_opt) {}
/// Copy is disallowed
OptionalCUDAStreamGuard(const OptionalCUDAStreamGuard&) = delete;
OptionalCUDAStreamGuard& operator=(const OptionalCUDAStreamGuard&) = delete;
// See Note [Move construction for RAII guards is tricky]
OptionalCUDAStreamGuard(OptionalCUDAStreamGuard&& other) = delete;
// See Note [Move assignment for RAII guards is tricky]
OptionalCUDAStreamGuard& operator=(OptionalCUDAStreamGuard&& other) = delete;
/// Resets the currently set CUDA stream to the original stream and
/// the currently set device to the original device. Then,
/// set the current device to the device associated with the passed stream,
/// and set the current stream on that device to the passed stream.
/// Initializes the guard if it was not previously initialized.
void reset_stream(Stream stream) {
guard_.reset_stream(stream);
}
/// Returns the CUDA stream that was set at the time the guard was most
/// recently initialized, or nullopt if the guard is uninitialized.
std::optional<CUDAStream> original_stream() const {
auto r = guard_.original_stream();
if (r.has_value()) {
return std::make_optional(CUDAStream(CUDAStream::UNCHECKED, r.value()));
} else {
return std::nullopt;
}
}
/// Returns the most recent CUDA stream that was set using this stream guard,
/// either from construction, or via reset_stream, if the guard is
/// initialized, or nullopt if the guard is uninitialized.
std::optional<CUDAStream> current_stream() const {
auto r = guard_.current_stream();
if (r.has_value()) {
return std::make_optional(CUDAStream(CUDAStream::UNCHECKED, r.value()));
} else {
return std::nullopt;
}
}
/// Restore the original CUDA device and stream, resetting this guard to
/// uninitialized state.
void reset() {
guard_.reset();
}
private:
c10::impl::InlineOptionalStreamGuard<impl::CUDAGuardImpl> guard_;
};
/// A variant of MultiStreamGuard that is specialized for CUDA.
struct CUDAMultiStreamGuard {
explicit CUDAMultiStreamGuard(ArrayRef<CUDAStream> streams)
: guard_(unwrapStreams(streams)) {}
/// Copy is disallowed
CUDAMultiStreamGuard(const CUDAMultiStreamGuard&) = delete;
CUDAMultiStreamGuard& operator=(const CUDAMultiStreamGuard&) = delete;
// See Note [Move construction for RAII guards is tricky]
CUDAMultiStreamGuard(CUDAMultiStreamGuard&& other) = delete;
// See Note [Move assignment for RAII guards is tricky]
CUDAMultiStreamGuard& operator=(CUDAMultiStreamGuard&& other) = delete;
private:
c10::impl::InlineMultiStreamGuard<impl::CUDAGuardImpl> guard_;
static std::vector<Stream> unwrapStreams(ArrayRef<CUDAStream> cudaStreams) {
std::vector<Stream> streams;
streams.reserve(cudaStreams.size());
for (const CUDAStream& cudaStream : cudaStreams) {
streams.push_back(cudaStream);
}
return streams;
}
};
} // namespace c10::cuda

View File

@ -0,0 +1,51 @@
#pragma once
#ifndef C10_USING_CUSTOM_GENERATED_MACROS
// We have not yet modified the AMD HIP build to generate this file so
// we add an extra option to specifically ignore it.
#ifndef C10_CUDA_NO_CMAKE_CONFIGURE_FILE
#include <c10/cuda/impl/cuda_cmake_macros.h>
#endif // C10_CUDA_NO_CMAKE_CONFIGURE_FILE
#endif
// See c10/macros/Export.h for a detailed explanation of what the function
// of these macros are. We need one set of macros for every separate library
// we build.
#ifdef _WIN32
#if defined(C10_CUDA_BUILD_SHARED_LIBS)
#define C10_CUDA_EXPORT __declspec(dllexport)
#define C10_CUDA_IMPORT __declspec(dllimport)
#else
#define C10_CUDA_EXPORT
#define C10_CUDA_IMPORT
#endif
#else // _WIN32
#if defined(__GNUC__)
#define C10_CUDA_EXPORT __attribute__((__visibility__("default")))
#else // defined(__GNUC__)
#define C10_CUDA_EXPORT
#endif // defined(__GNUC__)
#define C10_CUDA_IMPORT C10_CUDA_EXPORT
#endif // _WIN32
// This one is being used by libc10_cuda.so
#ifdef C10_CUDA_BUILD_MAIN_LIB
#define C10_CUDA_API C10_CUDA_EXPORT
#else
#define C10_CUDA_API C10_CUDA_IMPORT
#endif
/**
 * The maximum number of GPUs that we recognize. Increasing this beyond the
 * initial limit of 16 broke Caffe2 testing, hence the ifdef guards.
 * This value cannot be more than 128 because our DeviceIndex is a uint8_t.
 */
#ifdef FBCODE_CAFFE2
// fbcode depends on this value being 16
#define C10_COMPILE_TIME_MAX_GPUS 16
#else
#define C10_COMPILE_TIME_MAX_GPUS 120
#endif

View File

@ -0,0 +1,152 @@
#pragma once
/* This file defines math functions compatible across different gpu
* platforms (currently CUDA and HIP).
*/
#if defined(__CUDACC__) || defined(__HIPCC__)
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#ifdef __HIPCC__
#define __MATH_FUNCTIONS_DECL__ inline C10_DEVICE
#else /* __HIPCC__ */
#ifdef __CUDACC_RTC__
#define __MATH_FUNCTIONS_DECL__ C10_HOST_DEVICE
#else /* __CUDACC_RTC__ */
#define __MATH_FUNCTIONS_DECL__ inline C10_HOST_DEVICE
#endif /* __CUDACC_RTC__ */
#endif /* __HIPCC__ */
namespace c10::cuda::compat {
__MATH_FUNCTIONS_DECL__ float abs(float x) {
return ::fabsf(x);
}
__MATH_FUNCTIONS_DECL__ double abs(double x) {
return ::fabs(x);
}
__MATH_FUNCTIONS_DECL__ float exp(float x) {
return ::expf(x);
}
__MATH_FUNCTIONS_DECL__ double exp(double x) {
return ::exp(x);
}
__MATH_FUNCTIONS_DECL__ float ceil(float x) {
return ::ceilf(x);
}
__MATH_FUNCTIONS_DECL__ double ceil(double x) {
return ::ceil(x);
}
__MATH_FUNCTIONS_DECL__ float copysign(float x, float y) {
#if defined(__CUDA_ARCH__) || defined(__HIPCC__)
return ::copysignf(x, y);
#else
// std::copysign gets ICE/Segfaults with gcc 7.5/8 on arm64
// (e.g. Jetson), see PyTorch PR #51834
// This host function needs to be here for the compiler but is never used
TORCH_INTERNAL_ASSERT(
false, "CUDAMathCompat copysign should not run on the CPU");
#endif
}
__MATH_FUNCTIONS_DECL__ double copysign(double x, double y) {
#if defined(__CUDA_ARCH__) || defined(__HIPCC__)
return ::copysign(x, y);
#else
// see above
TORCH_INTERNAL_ASSERT(
false, "CUDAMathCompat copysign should not run on the CPU");
#endif
}
__MATH_FUNCTIONS_DECL__ float floor(float x) {
return ::floorf(x);
}
__MATH_FUNCTIONS_DECL__ double floor(double x) {
return ::floor(x);
}
__MATH_FUNCTIONS_DECL__ float log(float x) {
return ::logf(x);
}
__MATH_FUNCTIONS_DECL__ double log(double x) {
return ::log(x);
}
__MATH_FUNCTIONS_DECL__ float log1p(float x) {
return ::log1pf(x);
}
__MATH_FUNCTIONS_DECL__ double log1p(double x) {
return ::log1p(x);
}
__MATH_FUNCTIONS_DECL__ float max(float x, float y) {
return ::fmaxf(x, y);
}
__MATH_FUNCTIONS_DECL__ double max(double x, double y) {
return ::fmax(x, y);
}
__MATH_FUNCTIONS_DECL__ float min(float x, float y) {
return ::fminf(x, y);
}
__MATH_FUNCTIONS_DECL__ double min(double x, double y) {
return ::fmin(x, y);
}
__MATH_FUNCTIONS_DECL__ float pow(float x, float y) {
return ::powf(x, y);
}
__MATH_FUNCTIONS_DECL__ double pow(double x, double y) {
return ::pow(x, y);
}
__MATH_FUNCTIONS_DECL__ void sincos(float x, float* sptr, float* cptr) {
return ::sincosf(x, sptr, cptr);
}
__MATH_FUNCTIONS_DECL__ void sincos(double x, double* sptr, double* cptr) {
return ::sincos(x, sptr, cptr);
}
__MATH_FUNCTIONS_DECL__ float sqrt(float x) {
return ::sqrtf(x);
}
__MATH_FUNCTIONS_DECL__ double sqrt(double x) {
return ::sqrt(x);
}
__MATH_FUNCTIONS_DECL__ float rsqrt(float x) {
return ::rsqrtf(x);
}
__MATH_FUNCTIONS_DECL__ double rsqrt(double x) {
return ::rsqrt(x);
}
__MATH_FUNCTIONS_DECL__ float tan(float x) {
return ::tanf(x);
}
__MATH_FUNCTIONS_DECL__ double tan(double x) {
return ::tan(x);
}
__MATH_FUNCTIONS_DECL__ float tanh(float x) {
return ::tanhf(x);
}
__MATH_FUNCTIONS_DECL__ double tanh(double x) {
return ::tanh(x);
}
__MATH_FUNCTIONS_DECL__ float normcdf(float x) {
return ::normcdff(x);
}
__MATH_FUNCTIONS_DECL__ double normcdf(double x) {
return ::normcdf(x);
}
} // namespace c10::cuda::compat
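// Illustrative device-side usage (sketch; the kernel below is hypothetical):
// these wrappers let the same source build for both CUDA and HIP.
//
//   __global__ void rsqrt_kernel(float* out, const float* in, int n) {
//     int i = blockIdx.x * blockDim.x + threadIdx.x;
//     if (i < n) {
//       out[i] = c10::cuda::compat::rsqrt(in[i]);
//     }
//   }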
#endif

View File

@ -0,0 +1,12 @@
#pragma once
// This file exists to avoid a circular dependency between CUDAFunctions.h and
// CUDAException.h
#include <c10/cuda/CUDAMacros.h>
#include <mutex>
namespace c10::cuda {
C10_CUDA_API const char* get_cuda_check_suffix() noexcept;
C10_CUDA_API std::mutex* getFreeMutex();
} // namespace c10::cuda

View File

@ -0,0 +1,268 @@
#pragma once
#include <cuda_runtime_api.h>
#include <c10/core/DeviceGuard.h>
#include <c10/core/Stream.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/util/Exception.h>
/*
* Stream pool note.
*
* A CUDAStream is an abstraction of an actual cuStream on the GPU. CUDAStreams
* are backed by cuStreams, but they use several pools to minimize the costs
* associated with creating, retaining, and destroying cuStreams.
*
* There are three pools per device, and a device's pools are lazily created.
*
* The first pool contains only the default stream. When the default stream
* is requested it's returned.
*
* The second pool is the "low priority" or "default priority" streams. In
* HIP builds there is no distinction between streams in this pool and streams
* in the third pool (below). There are 32 of these streams per device, and
* when a stream is requested one of these streams is returned round-robin.
* That is, the first stream requested is at index 0, the second at index 1...
* to index 31, then index 0 again.
*
* This means that if 33 low priority streams are requested, the first and
* last streams requested are actually the same stream (under the covers)
* and kernels enqueued on them cannot run concurrently.
*
* The third pool is the "high priority" streams. The third pool acts like
* the second pool except the streams are created with a higher priority.
*
* These pools suggest that stream users should prefer many short-lived streams,
* as the cost of acquiring and releasing streams is effectively zero. If
* many longer-lived streams are required in performance critical scenarios
* then the functionality here may need to be extended to allow, for example,
* "reserving" a subset of the pool so that other streams do not accidentally
* overlap the performance critical streams.
*
* Note: although the notion of "current stream for device" is thread local
* (every OS thread has a separate current stream, as one might expect),
* the stream pool is global across all threads; stream 0 is always stream 0
* no matter which thread you use it on. Multiple threads can synchronize
* on the same stream. Although the CUDA documentation is not very clear
* on the matter, streams are thread safe; e.g., it is safe to enqueue
* a kernel on the same stream from two different threads.
*/
namespace c10::cuda {
static constexpr int max_compile_time_stream_priorities = 4;
// Value object representing a CUDA stream. This is just a wrapper
// around c10::Stream, but it comes with a little extra CUDA-specific
// functionality (conversion to cudaStream_t), and a guarantee that
// the wrapped c10::Stream really is a CUDA stream.
class C10_CUDA_API CUDAStream {
public:
enum Unchecked { UNCHECKED };
/// Construct a CUDAStream from a Stream. This construction is checked,
/// and will raise an error if the Stream is not, in fact, a CUDA stream.
explicit CUDAStream(Stream stream) : stream_(stream) {
TORCH_CHECK(stream_.device_type() == DeviceType::CUDA);
}
/// Construct a CUDAStream from a Stream with no error checking.
/// This constructor uses the "named" constructor idiom, and can
/// be invoked as: CUDAStream(CUDAStream::UNCHECKED, stream)
explicit CUDAStream(Unchecked, Stream stream) : stream_(stream) {}
bool operator==(const CUDAStream& other) const noexcept {
return unwrap() == other.unwrap();
}
bool operator!=(const CUDAStream& other) const noexcept {
return unwrap() != other.unwrap();
}
/// Implicit conversion to cudaStream_t.
operator cudaStream_t() const {
return stream();
}
/// Implicit conversion to Stream (a.k.a., forget that the stream is a
/// CUDA stream).
operator Stream() const {
return unwrap();
}
/// Used to avoid baking in device type explicitly to Python-side API.
DeviceType device_type() const {
return DeviceType::CUDA;
}
/// Get the CUDA device index that this stream is associated with.
DeviceIndex device_index() const {
return stream_.device_index();
}
/// Get the full Device that this stream is associated with. The Device
/// is guaranteed to be a CUDA device.
Device device() const {
return Device(DeviceType::CUDA, device_index());
}
/// Return the stream ID corresponding to this particular stream.
StreamId id() const {
return stream_.id();
}
bool query() const {
DeviceGuard guard{stream_.device()};
cudaError_t err = C10_CUDA_ERROR_HANDLED(cudaStreamQuery(stream()));
if (err == cudaSuccess) {
return true;
} else if (err != cudaErrorNotReady) {
C10_CUDA_CHECK(err);
} else {
// ignore and clear the error if not ready
(void)cudaGetLastError();
}
return false;
}
void synchronize() const {
DeviceGuard guard{stream_.device()};
c10::cuda::stream_synchronize(stream());
}
int priority() const {
DeviceGuard guard{stream_.device()};
int priority = 0;
C10_CUDA_CHECK(cudaStreamGetPriority(stream(), &priority));
return priority;
}
/// Explicit conversion to cudaStream_t.
cudaStream_t stream() const;
/// Explicit conversion to Stream.
Stream unwrap() const {
return stream_;
}
/// Reversibly pack a CUDAStream into a struct representation.
/// Previously the stream's data was packed into a single int64_t,
/// as it was assumed the fields would not require more than
/// 64 bits of storage in total.
/// See https://github.com/pytorch/pytorch/issues/75854
/// for more information regarding newer platforms that may violate
/// this assumption.
///
/// The CUDAStream can be unpacked using unpack().
struct c10::StreamData3 pack3() const {
return stream_.pack3();
}
  // Unpack a CUDAStream from the 3 fields generated by pack3().
static CUDAStream unpack3(
StreamId stream_id,
DeviceIndex device_index,
DeviceType device_type) {
return CUDAStream(Stream::unpack3(stream_id, device_index, device_type));
}
static std::tuple<int, int> priority_range() {
// Note: this returns the range of priority **supported by PyTorch**, not
// the range of priority **supported by CUDA**. The former is a subset of
// the latter.
int least_priority = 0, greatest_priority = 0;
C10_CUDA_CHECK(
cudaDeviceGetStreamPriorityRange(&least_priority, &greatest_priority));
#ifdef USE_ROCM
// See Note [HIP stream priorities]
TORCH_INTERNAL_ASSERT(
least_priority == 1, "Unexpected HIP stream priority range");
least_priority = 0;
#else
TORCH_INTERNAL_ASSERT(
least_priority == 0, "Unexpected CUDA stream priority range");
#endif
TORCH_INTERNAL_ASSERT(
greatest_priority <= -1, "Unexpected CUDA stream priority range");
greatest_priority = std::max(
-c10::cuda::max_compile_time_stream_priorities + 1, greatest_priority);
return std::make_tuple(least_priority, greatest_priority);
}
// Deleted for now; use CUDAEvent::block instead
// void synchronize_with(const CUDAEvent& event) const;
private:
Stream stream_;
};
/**
* Get a new stream from the CUDA stream pool. You can think of this
* as "creating" a new stream, but no such creation actually happens;
* instead, streams are preallocated from the pool and returned in a
* round-robin fashion.
*
* You can request a stream from the high priority pool by setting
* isHighPriority to true, or a stream for a specific device by setting device
 * (defaulting to the current device).
*/
C10_API CUDAStream
getStreamFromPool(const bool isHighPriority = false, DeviceIndex device = -1);
// no default priority to disambiguate overloads
C10_API CUDAStream
getStreamFromPool(const int priority, DeviceIndex device = -1);
/**
 * Get a CUDAStream from an externally allocated one.
*
* This is mainly for interoperability with different libraries where we
* want to operate on a non-torch allocated stream for data exchange or similar
* purposes
*/
C10_API CUDAStream
getStreamFromExternal(cudaStream_t ext_stream, DeviceIndex device_index);
/**
* Get the default CUDA stream, for the passed CUDA device, or for the
* current device if no device index is passed. The default stream is
* where most computation occurs when you aren't explicitly using
* streams.
*/
C10_API CUDAStream getDefaultCUDAStream(DeviceIndex device_index = -1);
/**
* Get the current CUDA stream, for the passed CUDA device, or for the
* current device if no device index is passed. The current CUDA stream
* will usually be the default CUDA stream for the device, but it may
* be different if someone called 'setCurrentCUDAStream' or used 'StreamGuard'
* or 'CUDAStreamGuard'.
*/
C10_API CUDAStream getCurrentCUDAStream(DeviceIndex device_index = -1);
/**
* Set the current stream on the device of the passed in stream to be
* the passed in stream. Yes, you read that right: this function
* has *nothing* to do with the current device: it toggles the current
* stream of the device of the passed stream.
*
* Confused? Avoid using this function; prefer using 'CUDAStreamGuard' instead
* (which will switch both your current device and current stream in the way you
* expect, and reset it back to its original state afterwards).
*/
C10_API void setCurrentCUDAStream(CUDAStream stream);
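// Putting these together (sketch; error handling omitted):
//
//   CUDAStream s = getStreamFromPool(/*isHighPriority=*/true);
//   setCurrentCUDAStream(s); // the current stream on s's device is now s
//   // ... enqueue work on the current stream ...
//   s.synchronize();         // wait for that work to finish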
C10_API std::ostream& operator<<(std::ostream& stream, const CUDAStream& s);
} // namespace c10::cuda
namespace std {
template <>
struct hash<c10::cuda::CUDAStream> {
size_t operator()(c10::cuda::CUDAStream s) const noexcept {
return std::hash<c10::Stream>{}(s.unwrap());
}
};
} // namespace std

View File

@ -0,0 +1,63 @@
#pragma once
#include <cuda.h>
#define NVML_NO_UNVERSIONED_FUNC_DEFS
#include <nvml.h>
#define C10_CUDA_DRIVER_CHECK(EXPR) \
do { \
CUresult __err = EXPR; \
if (__err != CUDA_SUCCESS) { \
const char* err_str; \
CUresult get_error_str_err C10_UNUSED = \
c10::cuda::DriverAPI::get()->cuGetErrorString_(__err, &err_str); \
if (get_error_str_err != CUDA_SUCCESS) { \
AT_ERROR("CUDA driver error: unknown error"); \
} else { \
AT_ERROR("CUDA driver error: ", err_str); \
} \
} \
} while (0)
#define C10_LIBCUDA_DRIVER_API(_) \
_(cuDeviceGetAttribute) \
_(cuMemAddressReserve) \
_(cuMemRelease) \
_(cuMemMap) \
_(cuMemAddressFree) \
_(cuMemSetAccess) \
_(cuMemUnmap) \
_(cuMemCreate) \
_(cuMemGetAllocationGranularity) \
_(cuMemExportToShareableHandle) \
_(cuMemImportFromShareableHandle) \
_(cuGetErrorString)
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12030)
#define C10_LIBCUDA_DRIVER_API_12030(_) \
_(cuMulticastAddDevice) \
_(cuMulticastBindMem) \
_(cuMulticastCreate)
#else
#define C10_LIBCUDA_DRIVER_API_12030(_)
#endif
#define C10_NVML_DRIVER_API(_) \
_(nvmlInit_v2) \
_(nvmlDeviceGetHandleByPciBusId_v2) \
_(nvmlDeviceGetNvLinkRemoteDeviceType) \
_(nvmlDeviceGetNvLinkRemotePciInfo_v2) \
_(nvmlDeviceGetComputeRunningProcesses)
namespace c10::cuda {
struct DriverAPI {
#define CREATE_MEMBER(name) decltype(&name) name##_;
C10_LIBCUDA_DRIVER_API(CREATE_MEMBER)
C10_LIBCUDA_DRIVER_API_12030(CREATE_MEMBER)
C10_NVML_DRIVER_API(CREATE_MEMBER)
#undef CREATE_MEMBER
static DriverAPI* get();
static void* get_nvml_handle();
};
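// Sketch of how the dynamically resolved table is meant to be used (the
// attribute query below is illustrative):
//
//   int vmm_supported = 0;
//   C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuDeviceGetAttribute_(
//       &vmm_supported,
//       CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED,
//       /*device=*/0));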
} // namespace c10::cuda

View File

@ -0,0 +1,249 @@
#pragma once
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/core/impl/GPUTrace.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/Stream.h>
#include <c10/core/impl/PyInterpreter.h>
#include <cuda_runtime_api.h>
#include <cstdint>
#include <optional>
namespace c10::cuda::impl {
struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface {
static constexpr DeviceType static_type = DeviceType::CUDA;
CUDAGuardImpl() = default;
explicit CUDAGuardImpl(DeviceType t) {
TORCH_INTERNAL_ASSERT(t == DeviceType::CUDA);
}
DeviceType type() const override {
return DeviceType::CUDA;
}
Device exchangeDevice(Device d) const override {
TORCH_INTERNAL_ASSERT(d.is_cuda());
auto old_device_index = c10::cuda::ExchangeDevice(d.index());
return Device(DeviceType::CUDA, old_device_index);
}
Device getDevice() const override {
DeviceIndex device = 0;
C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
return Device(DeviceType::CUDA, device);
}
std::optional<Device> uncheckedGetDevice() const noexcept {
DeviceIndex device{-1};
const auto err = C10_CUDA_ERROR_HANDLED(c10::cuda::GetDevice(&device));
C10_CUDA_CHECK_WARN(err);
if (err != cudaSuccess) {
return std::nullopt;
}
return Device(DeviceType::CUDA, device);
}
void setDevice(Device d) const override {
TORCH_INTERNAL_ASSERT(d.is_cuda());
C10_CUDA_CHECK(c10::cuda::SetDevice(d.index()));
}
void uncheckedSetDevice(Device d) const noexcept override {
C10_CUDA_CHECK_WARN(c10::cuda::MaybeSetDevice(d.index()));
}
Stream getStream(Device d) const noexcept override {
return getCurrentCUDAStream(d.index()).unwrap();
}
Stream getDefaultStream(Device d) const override {
return getDefaultCUDAStream(d.index());
}
Stream getNewStream(Device d, int priority = 0) const override {
return getStreamFromPool(priority, d.index());
}
Stream getStreamFromGlobalPool(Device d, bool isHighPriority = false)
const override {
return getStreamFromPool(isHighPriority, d.index());
}
// NB: These do NOT set the current device
Stream exchangeStream(Stream s) const noexcept override {
CUDAStream cs(s);
auto old_stream = getCurrentCUDAStream(s.device().index());
setCurrentCUDAStream(cs);
return old_stream.unwrap();
}
DeviceIndex deviceCount() const noexcept override {
return device_count();
}
// Event-related functions
void createEvent(cudaEvent_t* cuda_event, const EventFlag flag) const {
// Maps PyTorch's Event::Flag to CUDA flag
auto cuda_flag = cudaEventDefault;
switch (flag) {
case EventFlag::PYTORCH_DEFAULT:
cuda_flag = cudaEventDisableTiming;
break;
case EventFlag::BACKEND_DEFAULT:
cuda_flag = cudaEventDefault;
break;
default:
TORCH_CHECK(false, "CUDA event received unknown flag");
}
C10_CUDA_CHECK(cudaEventCreateWithFlags(cuda_event, cuda_flag));
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_creation(
c10::kCUDA, reinterpret_cast<uintptr_t>(cuda_event));
}
}
void destroyEvent(void* event, const DeviceIndex device_index)
const noexcept override {
if (!event)
return;
auto cuda_event = static_cast<cudaEvent_t>(event);
DeviceIndex orig_device{-1};
C10_CUDA_CHECK_WARN(c10::cuda::GetDevice(&orig_device));
C10_CUDA_CHECK_WARN(c10::cuda::SetDevice(device_index));
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_deletion(
c10::kCUDA, reinterpret_cast<uintptr_t>(cuda_event));
}
C10_CUDA_CHECK_WARN(cudaEventDestroy(cuda_event));
C10_CUDA_CHECK_WARN(c10::cuda::SetDevice(orig_device));
}
void record(
void** event,
const Stream& stream,
const DeviceIndex device_index,
const EventFlag flag) const override {
TORCH_CHECK(
device_index == -1 || device_index == stream.device_index(),
"Event device index ",
device_index,
" does not match recording stream's device index ",
stream.device_index(),
".");
cudaEvent_t cuda_event = static_cast<cudaEvent_t>(*event);
CUDAStream cuda_stream{stream};
// Moves to stream's device to record
const auto orig_device = getDevice();
setDevice(stream.device());
// Creates the event (lazily)
if (!cuda_event)
createEvent(&cuda_event, flag);
C10_CUDA_CHECK(cudaEventRecord(cuda_event, cuda_stream));
// Makes the void* point to the (possibly just allocated) CUDA event
*event = cuda_event;
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_record(
c10::kCUDA,
reinterpret_cast<uintptr_t>(cuda_event),
reinterpret_cast<uintptr_t>(cuda_stream.stream()));
}
// Resets device
setDevice(orig_device);
}
void block(void* event, const Stream& stream) const override {
if (!event)
return;
cudaEvent_t cuda_event = static_cast<cudaEvent_t>(event);
CUDAStream cuda_stream{stream};
const auto orig_device = getDevice();
setDevice(stream.device());
C10_CUDA_CHECK(cudaStreamWaitEvent(
cuda_stream,
cuda_event,
/*flags (must be zero)=*/0));
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_wait(
c10::kCUDA,
reinterpret_cast<uintptr_t>(cuda_event),
reinterpret_cast<uintptr_t>(cuda_stream.stream()));
}
setDevice(orig_device);
}
// May be called from any device
bool queryEvent(void* event) const override {
if (!event)
return true;
cudaEvent_t cuda_event = static_cast<cudaEvent_t>(event);
// Note: cudaEventQuery can be safely called from any device
const cudaError_t err = C10_CUDA_ERROR_HANDLED(cudaEventQuery(cuda_event));
if (err != cudaErrorNotReady) {
C10_CUDA_CHECK(err);
} else {
// ignore and clear the error if not ready
(void)cudaGetLastError();
}
return (err == cudaSuccess);
}
// Stream-related functions
bool queryStream(const Stream& stream) const override {
CUDAStream cuda_stream{stream};
return cuda_stream.query();
}
void synchronizeStream(const Stream& stream) const override {
CUDAStream cuda_stream{stream};
cuda_stream.synchronize();
}
void synchronizeEvent(void* event) const override {
if (!event)
return;
cudaEvent_t cuda_event = static_cast<cudaEvent_t>(event);
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_synchronization(
c10::kCUDA, reinterpret_cast<uintptr_t>(cuda_event));
}
// Note: cudaEventSynchronize can be safely called from any device
C10_CUDA_CHECK(cudaEventSynchronize(cuda_event));
}
void recordDataPtrOnStream(const c10::DataPtr& data_ptr, const Stream& stream)
const override {
CUDAStream cuda_stream{stream};
CUDACachingAllocator::recordStream(data_ptr, cuda_stream);
}
double elapsedTime(void* event1, void* event2, const DeviceIndex device_index)
const override {
TORCH_CHECK(
event1 && event2,
"Both events must be recorded before calculating elapsed time.");
// Even though cudaEventElapsedTime can be safely called from any device, if
// the current device is not initialized, it will create a new cuda context,
// which will consume a lot of memory.
DeviceIndex orig_device{-1};
C10_CUDA_CHECK(c10::cuda::GetDevice(&orig_device));
C10_CUDA_CHECK(c10::cuda::SetDevice(device_index));
cudaEvent_t cuda_event1 = static_cast<cudaEvent_t>(event1);
cudaEvent_t cuda_event2 = static_cast<cudaEvent_t>(event2);
float time_ms = 0;
// raise cudaErrorNotReady if either event is recorded but not yet completed
C10_CUDA_CHECK(cudaEventElapsedTime(&time_ms, cuda_event1, cuda_event2));
C10_CUDA_CHECK(c10::cuda::SetDevice(orig_device));
return static_cast<double>(time_ms);
}
};
} // namespace c10::cuda::impl

View File

@ -0,0 +1,9 @@
#pragma once
#include <c10/cuda/CUDAMacros.h>
namespace c10::cuda::impl {
C10_CUDA_API int c10_cuda_test();
}

View File

@ -0,0 +1,160 @@
#ifndef C10_MACROS_EXPORT_H_
#define C10_MACROS_EXPORT_H_
/* Header file to define the common scaffolding for exported symbols.
*
* Export is by itself a quite tricky situation to deal with, and if you are
* hitting this file, make sure you start with the background here:
* - Linux: https://gcc.gnu.org/wiki/Visibility
* - Windows:
* https://docs.microsoft.com/en-us/cpp/cpp/dllexport-dllimport?view=vs-2017
*
* Do NOT include this file directly. Instead, use c10/macros/Macros.h
*/
// You do not need to edit this part of file unless you are changing the core
// pytorch export abstractions.
//
// This part defines the C10 core export and import macros. This is controlled
// by whether we are building shared libraries or not, which is determined
// during build time and codified in c10/core/cmake_macros.h.
// When the library is built as a shared lib, EXPORT and IMPORT will contain
// visibility attributes. If it is being built as a static lib, then EXPORT
// and IMPORT basically have no effect.
// As a rule of thumb, you should almost NEVER mix static and shared builds for
// libraries that depend on c10. AKA, if c10 is built as a static library, we
// recommend everything dependent on c10 to be built statically. If c10 is built
// as a shared library, everything dependent on it should be built as shared. In
// the PyTorch project, all native libraries shall use the macro
// C10_BUILD_SHARED_LIB to check whether pytorch is building shared or static
// libraries.
// For build systems that do not directly depend on CMake and directly build
// from the source directory (such as Buck), one may not have a cmake_macros.h
// file at all. In this case, the build system is responsible for providing
// correct macro definitions corresponding to the cmake_macros.h.in file.
//
// In such scenarios, one should define the macro
// C10_USING_CUSTOM_GENERATED_MACROS
// to inform this header that it does not need to include the cmake_macros.h
// file.
#ifndef C10_USING_CUSTOM_GENERATED_MACROS
#include <c10/macros/cmake_macros.h>
#endif // C10_USING_CUSTOM_GENERATED_MACROS
#ifdef _WIN32
#define C10_HIDDEN
#if defined(C10_BUILD_SHARED_LIBS)
#define C10_EXPORT __declspec(dllexport)
#define C10_IMPORT __declspec(dllimport)
#else
#define C10_EXPORT
#define C10_IMPORT
#endif
#else // _WIN32
#if defined(__GNUC__)
#define C10_EXPORT __attribute__((__visibility__("default")))
#define C10_HIDDEN __attribute__((__visibility__("hidden")))
#else // defined(__GNUC__)
#define C10_EXPORT
#define C10_HIDDEN
#endif // defined(__GNUC__)
#define C10_IMPORT C10_EXPORT
#endif // _WIN32
#ifdef NO_EXPORT
#undef C10_EXPORT
#define C10_EXPORT
#endif
// Definition of an adaptive XX_API macro, that depends on whether you are
// building the library itself or not, routes to XX_EXPORT and XX_IMPORT.
// Basically, you will need to do this for each shared library that you are
// building, and the instruction is as follows: assuming that you are building
// a library called libawesome.so. You should:
// (1) for your cmake target (usually done by "add_library(awesome, ...)"),
// define a macro called AWESOME_BUILD_MAIN_LIB using
// target_compile_options.
// (2) define the AWESOME_API macro similar to the one below.
// And in the source file of your awesome library, use AWESOME_API to
// annotate public symbols.
// Here, for the C10 library, we will define the macro C10_API for both import
// and export.
// This one is being used by libc10.so
#ifdef C10_BUILD_MAIN_LIB
#define C10_API C10_EXPORT
#else
#define C10_API C10_IMPORT
#endif
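// For the hypothetical libawesome.so described above, the matching macro
// would look like this (sketch):
//
//   #ifdef AWESOME_BUILD_MAIN_LIB
//   #define AWESOME_API C10_EXPORT
//   #else
//   #define AWESOME_API C10_IMPORT
//   #endif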
// This one is being used by libtorch.so
#ifdef CAFFE2_BUILD_MAIN_LIB
#define TORCH_API C10_EXPORT
#else
#define TORCH_API C10_IMPORT
#endif
// You may be wondering: Whose brilliant idea was it to split torch_cuda into
// two pieces with confusing names?
// Once upon a time, there _was_ only TORCH_CUDA_API. All was happy until we
// tried to compile PyTorch for CUDA 11.1, which ran into relocation marker
// issues when linking big binaries.
// (https://github.com/pytorch/pytorch/issues/39968) We had two choices:
// (1) Stop supporting so many GPU architectures
// (2) Do something else
// We chose #2 and decided to split the behemoth that was torch_cuda into two
// smaller libraries, one with most of the core kernel functions (torch_cuda_cu)
// and the other that had..well..everything else (torch_cuda_cpp). The idea was
// this: instead of linking our static libraries (like the hefty
// libcudnn_static.a) with another huge library, torch_cuda, and run into pesky
// relocation marker issues, we could link our static libraries to a smaller
// part of torch_cuda (torch_cuda_cpp) and avoid the issues.
// libtorch_cuda_cu.so
#ifdef TORCH_CUDA_CU_BUILD_MAIN_LIB
#define TORCH_CUDA_CU_API C10_EXPORT
#elif defined(BUILD_SPLIT_CUDA)
#define TORCH_CUDA_CU_API C10_IMPORT
#endif
// libtorch_cuda_cpp.so
#ifdef TORCH_CUDA_CPP_BUILD_MAIN_LIB
#define TORCH_CUDA_CPP_API C10_EXPORT
#elif defined(BUILD_SPLIT_CUDA)
#define TORCH_CUDA_CPP_API C10_IMPORT
#endif
// libtorch_cuda.so (where torch_cuda_cu and torch_cuda_cpp are a part of the
// same api)
#ifdef TORCH_CUDA_BUILD_MAIN_LIB
#define TORCH_CUDA_CPP_API C10_EXPORT
#define TORCH_CUDA_CU_API C10_EXPORT
#elif !defined(BUILD_SPLIT_CUDA)
#define TORCH_CUDA_CPP_API C10_IMPORT
#define TORCH_CUDA_CU_API C10_IMPORT
#endif
#if defined(TORCH_HIP_BUILD_MAIN_LIB)
#define TORCH_HIP_API C10_EXPORT
#else
#define TORCH_HIP_API C10_IMPORT
#endif
#if defined(TORCH_XPU_BUILD_MAIN_LIB)
#define TORCH_XPU_API C10_EXPORT
#else
#define TORCH_XPU_API C10_IMPORT
#endif
// Enums only need to be exported on windows for non-CUDA files
#if defined(_WIN32) && defined(__CUDACC__)
#define C10_API_ENUM C10_API
#else
#define C10_API_ENUM
#endif
#endif // C10_MACROS_EXPORT_H_

View File

@ -0,0 +1,581 @@
#ifndef C10_MACROS_MACROS_H_
#define C10_MACROS_MACROS_H_
#include <cassert>
/* Main entry for c10/macros.
*
* In your code, include c10/macros/Macros.h directly, instead of individual
* files in this folder.
*/
// For build systems that do not directly depend on CMake and directly build
// from the source directory (such as Buck), one may not have a cmake_macros.h
// file at all. In this case, the build system is responsible for providing
// correct macro definitions corresponding to the cmake_macros.h.in file.
//
// In such scenarios, one should define the macro
// C10_USING_CUSTOM_GENERATED_MACROS
// to inform this header that it does not need to include the cmake_macros.h
// file.
#ifndef C10_USING_CUSTOM_GENERATED_MACROS
#include <c10/macros/cmake_macros.h>
#endif // C10_USING_CUSTOM_GENERATED_MACROS
#include <c10/macros/Export.h>
#if defined(__clang__)
#define __ubsan_ignore_float_divide_by_zero__ \
__attribute__((no_sanitize("float-divide-by-zero")))
#define __ubsan_ignore_undefined__ __attribute__((no_sanitize("undefined")))
#define __ubsan_ignore_signed_int_overflow__ \
__attribute__((no_sanitize("signed-integer-overflow")))
#define __ubsan_ignore_pointer_overflow__ \
__attribute__((no_sanitize("pointer-overflow")))
#define __ubsan_ignore_function__ __attribute__((no_sanitize("function")))
#else
#define __ubsan_ignore_float_divide_by_zero__
#define __ubsan_ignore_undefined__
#define __ubsan_ignore_signed_int_overflow__
#define __ubsan_ignore_pointer_overflow__
#define __ubsan_ignore_function__
#endif
// Detect address sanitizer as some stuff doesn't work with it
#undef C10_ASAN_ENABLED
// for clang
#if defined(__has_feature)
#if ((__has_feature(address_sanitizer)))
#define C10_ASAN_ENABLED 1
#endif
#endif
// for gcc
#if defined(__SANITIZE_ADDRESS__)
#if __SANITIZE_ADDRESS__
#if !defined(C10_ASAN_ENABLED)
#define C10_ASAN_ENABLED 1
#endif
#endif
#endif
#if !defined(C10_ASAN_ENABLED)
#define C10_ASAN_ENABLED 0
#endif
// Detect undefined-behavior sanitizer (UBSAN)
#undef C10_UBSAN_ENABLED
// for clang or gcc >= 14
// NB: gcc 14 adds support for Clang's __has_feature
// https://gcc.gnu.org/gcc-14/changes.html
// gcc < 14 doesn't have a macro for UBSAN
// (e.g. __SANITIZE_UNDEFINED__ does not exist in gcc)
// https://github.com/google/sanitizers/issues/765
#if defined(__has_feature)
#if ((__has_feature(undefined_behavior_sanitizer)))
#define C10_UBSAN_ENABLED 1
#endif
#endif
#if !defined(C10_UBSAN_ENABLED)
#define C10_UBSAN_ENABLED 0
#endif
// Disable the copy and assignment operator for a class. Note that this will
// disable the usage of the class in std containers.
#define C10_DISABLE_COPY_AND_ASSIGN(classname) \
classname(const classname&) = delete; \
classname& operator=(const classname&) = delete
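// For example (sketch), a class that must be neither copied nor assigned:
//
//   class Registry {
//    public:
//     Registry() = default;
//     C10_DISABLE_COPY_AND_ASSIGN(Registry);
//   };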
#define C10_CONCATENATE_IMPL(s1, s2) s1##s2
#define C10_CONCATENATE(s1, s2) C10_CONCATENATE_IMPL(s1, s2)
#define C10_MACRO_EXPAND(args) args
#define C10_STRINGIZE_IMPL(x) #x
#define C10_STRINGIZE(x) C10_STRINGIZE_IMPL(x)
/**
* C10_ANONYMOUS_VARIABLE(str) introduces a new identifier which starts with
* str and ends with a unique number.
*/
#ifdef __COUNTER__
#define C10_UID __COUNTER__
#define C10_ANONYMOUS_VARIABLE(str) C10_CONCATENATE(str, __COUNTER__)
#else
#define C10_UID __LINE__
#define C10_ANONYMOUS_VARIABLE(str) C10_CONCATENATE(str, __LINE__)
#endif
#ifdef __has_cpp_attribute
#define C10_HAS_CPP_ATTRIBUTE(x) __has_cpp_attribute(x)
#else
#define C10_HAS_CPP_ATTRIBUTE(x) (0)
#endif
/// C10_NODISCARD - Warn if a type or return value is discarded.
// Technically, we should check if __cplusplus > 201402L here, because
// [[nodiscard]] is only defined in C++17. However, some compilers
// we care about don't advertise being C++17 (e.g., clang), but
// support the attribute anyway. In fact, this is not just a good idea,
// it's the law: clang::warn_unused_result doesn't work on nvcc + clang
// and the best workaround for this case is to use [[nodiscard]]
// instead; see https://github.com/pytorch/pytorch/issues/13118
//
// Note to future editors: if you have noticed that a compiler is
// misbehaving (e.g., it advertises support, but the support doesn't
// actually work, or it is emitting warnings). Some compilers which
// are strict about the matter include MSVC, which will complain:
//
// error C2429: attribute 'nodiscard' requires compiler flag '/std:c++latest'
//
// Exhibits:
// - MSVC 19.14: https://godbolt.org/z/Dzd7gn (requires /std:c++latest)
// - Clang 8.0.0: https://godbolt.org/z/3PYL4Z (always advertises support)
// - gcc 8.3: https://godbolt.org/z/4tLMQS (always advertises support)
#if C10_HAS_CPP_ATTRIBUTE(nodiscard)
#define C10_NODISCARD [[nodiscard]]
// Workaround for llvm.org/PR23435, since clang 3.6 and below emit a spurious
// error when __has_cpp_attribute is given a scoped attribute in C mode.
#elif __cplusplus && C10_HAS_CPP_ATTRIBUTE(clang::warn_unused_result)
// TODO: It's possible this is still triggering
// https://github.com/pytorch/pytorch/issues/13118 on Windows; if it is, better
// fix it.
#define C10_NODISCARD [[clang::warn_unused_result]]
#else
#define C10_NODISCARD
#endif
// suppress an unused variable.
#if defined(_MSC_VER) && !defined(__clang__)
#define C10_UNUSED __pragma(warning(suppress : 4100 4101))
#else
#define C10_UNUSED __attribute__((__unused__))
#endif //_MSC_VER
#if !defined(__has_attribute)
#define __has_attribute(x) 0
#endif
// Direct port of LLVM_ATTRIBUTE_USED.
#if __has_attribute(used)
#define C10_USED __attribute__((__used__))
#else
#define C10_USED
#endif
#define C10_RESTRICT __restrict
// Simply define the namespace, in case a dependent library wants to refer to
// the c10 namespace but not any nontrivial files.
namespace c10 {}
namespace c10::cuda {}
namespace c10::hip {}
namespace c10::xpu {}
// Since C10 is the core library for caffe2 (and aten), we will simply reroute
// all abstractions defined in c10 to be available in caffe2 as well.
// This is only for backwards compatibility. Please use the symbols from the
// c10 namespace where possible.
namespace caffe2 {
using namespace c10;
}
namespace at {
using namespace c10;
}
namespace at::cuda {
using namespace c10::cuda;
} // namespace at::cuda
// WARNING!!! THIS IS A GIANT HACK!!!
// This line means you cannot simultaneously include c10/hip
// and c10/cuda and then use them from the at::cuda namespace.
// This is true in practice, because HIPIFY works inplace on
// files in ATen/cuda, so it assumes that c10::hip is available
// from at::cuda. This namespace makes that happen. When
// HIPIFY is no longer out-of-place, we can switch the cuda
// here to hip and everyone is happy.
namespace at::cuda {
using namespace c10::hip;
} // namespace at::cuda
namespace at::xpu {
using namespace c10::xpu;
} // namespace at::xpu
// C10_LIKELY/C10_UNLIKELY
//
// These macros provide parentheses, so you can use these macros as:
//
// if C10_LIKELY(some_expr) {
// ...
// }
//
// NB: static_cast to boolean is mandatory in C++, because __builtin_expect
// takes a long argument, which means you may trigger the wrong conversion
// without it.
//
#if defined(__GNUC__) || defined(__ICL) || defined(__clang__)
#define C10_LIKELY(expr) (__builtin_expect(static_cast<bool>(expr), 1))
#define C10_UNLIKELY(expr) (__builtin_expect(static_cast<bool>(expr), 0))
#else
#define C10_LIKELY(expr) (expr)
#define C10_UNLIKELY(expr) (expr)
#endif
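// Usage sketch (illustrative, not part of the original header); safe_div is
// a made-up example:
//
//   int safe_div(int a, int b) {
//     if (C10_UNLIKELY(b == 0)) {
//       return 0; // rare path: hint the compiler to favor the division below
//     }
//     return a / b;
//   }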
/// C10_NOINLINE - Functions whose declaration is annotated with this will not
/// be inlined.
#ifdef __GNUC__
#define C10_NOINLINE __attribute__((noinline))
#elif _MSC_VER
#define C10_NOINLINE __declspec(noinline)
#else
#define C10_NOINLINE
#endif
#if defined(_MSC_VER)
#define C10_ALWAYS_INLINE __forceinline
#elif __has_attribute(always_inline) || defined(__GNUC__)
#define C10_ALWAYS_INLINE __attribute__((__always_inline__)) inline
#else
#define C10_ALWAYS_INLINE inline
#endif
#if defined(_MSC_VER)
#define C10_ATTR_VISIBILITY_HIDDEN
#elif defined(__GNUC__)
#define C10_ATTR_VISIBILITY_HIDDEN __attribute__((__visibility__("hidden")))
#else
#define C10_ATTR_VISIBILITY_HIDDEN
#endif
#define C10_ERASE C10_ALWAYS_INLINE C10_ATTR_VISIBILITY_HIDDEN
#include <cstdint>
#ifdef __HIPCC__
// Unlike CUDA, HIP requires a HIP header to be included for __host__ to work.
// We do this #include here so that C10_HOST_DEVICE and friends will Just Work.
// See https://github.com/ROCm-Developer-Tools/HIP/issues/441
#include <hip/hip_runtime.h>
#endif
#if defined(__CUDACC__) || defined(__HIPCC__)
// Designates functions callable from the host (CPU) and the device (GPU)
#define C10_HOST_DEVICE __host__ __device__
#define C10_DEVICE __device__
#define C10_HOST __host__
// constants from
// (https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications)
// The maximum number of threads per multiprocessor is 1024 for Turing
// architecture (7.5), 1536 for Geforce Ampere (8.6)/Jetson Orin (8.7), and
// 2048 for all other architectures. You'll get warnings if you exceed these
// constants. Hence, the following macros adjust the input values from the user
// to resolve potential warnings.
#if __CUDA_ARCH__ == 750
constexpr uint32_t CUDA_MAX_THREADS_PER_SM = 1024;
#elif __CUDA_ARCH__ == 860 || __CUDA_ARCH__ == 870 || __CUDA_ARCH__ == 890
constexpr uint32_t CUDA_MAX_THREADS_PER_SM = 1536;
#else
constexpr uint32_t CUDA_MAX_THREADS_PER_SM = 2048;
#endif
// CUDA_MAX_THREADS_PER_BLOCK is currently the same for all architectures
constexpr uint32_t CUDA_MAX_THREADS_PER_BLOCK = 1024;
// CUDA_THREADS_PER_BLOCK_FALLBACK is the "canonical fallback" choice of block
// size. 256 is a good number for this fallback and should give good occupancy
// and versatility across all architectures.
constexpr uint32_t CUDA_THREADS_PER_BLOCK_FALLBACK = 256;
// NOTE: if you are thinking of constexpr-ifying the inputs to launch bounds, it
// turns out that although __launch_bounds__ can take constexpr, it
// can't take a constexpr that has anything to do with templates.
// Currently we use launch_bounds that depend on template arguments in
// Loops.cuh, Reduce.cuh and LossCTC.cuh. Hence, C10_MAX_THREADS_PER_BLOCK
// and C10_MIN_BLOCKS_PER_SM are kept as macros.
// Suppose you were planning to write __launch_bounds__(a, b), based on your
// performance tuning on a modern GPU. Instead, you should write
// __launch_bounds__(C10_MAX_THREADS_PER_BLOCK(a), C10_MIN_BLOCKS_PER_SM(a, b)),
// which will also properly respect limits on old architectures.
#define C10_MAX_THREADS_PER_BLOCK(val) \
(((val) <= CUDA_MAX_THREADS_PER_BLOCK) ? (val) \
: CUDA_THREADS_PER_BLOCK_FALLBACK)
#define C10_MIN_BLOCKS_PER_SM(threads_per_block, blocks_per_sm) \
((((threads_per_block) * (blocks_per_sm) <= CUDA_MAX_THREADS_PER_SM) \
? (blocks_per_sm) \
: ((CUDA_MAX_THREADS_PER_SM + (threads_per_block)-1) / \
(threads_per_block))))
// C10_LAUNCH_BOUNDS is analogous to __launch_bounds__
#define C10_LAUNCH_BOUNDS_0 \
__launch_bounds__( \
256, 4) // default launch bounds that should give good occupancy and
// versatility across all architectures.
#define C10_LAUNCH_BOUNDS_1(max_threads_per_block) \
__launch_bounds__((C10_MAX_THREADS_PER_BLOCK((max_threads_per_block))))
#define C10_LAUNCH_BOUNDS_2(max_threads_per_block, min_blocks_per_sm) \
__launch_bounds__( \
(C10_MAX_THREADS_PER_BLOCK((max_threads_per_block))), \
(C10_MIN_BLOCKS_PER_SM((max_threads_per_block), (min_blocks_per_sm))))
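// Illustrative sketch (not part of the original header): a kernel tuned for
// 512 threads per block and 4 blocks per SM would be declared as below (the
// kernel name is hypothetical). On a Turing device CUDA_MAX_THREADS_PER_SM
// is 1024, so 512 * 4 = 2048 exceeds the limit and C10_MIN_BLOCKS_PER_SM
// clamps the second argument to (1024 + 512 - 1) / 512 = 2; on architectures
// with 2048 threads per SM the requested 4 is kept.
//
//   C10_LAUNCH_BOUNDS_2(512, 4)
//   __global__ void my_elementwise_kernel(
//       const float* in, float* out, int64_t n) {
//     int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
//     if (i < n) {
//       out[i] = in[i] * 2.0f;
//     }
//   }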
#else
#define C10_HOST_DEVICE
#define C10_HOST
#define C10_DEVICE
#endif
#if defined(USE_ROCM)
#define C10_HIP_HOST_DEVICE __host__ __device__
#else
#define C10_HIP_HOST_DEVICE
#endif
#if defined(USE_ROCM)
#define C10_WARP_SIZE warpSize // = 64 or 32 (Defined in hip_runtime.h)
#else
#define C10_WARP_SIZE 32
#endif
#if defined(_MSC_VER) && _MSC_VER <= 1900
#define __func__ __FUNCTION__
#endif
// CUDA_KERNEL_ASSERT checks the assertion
// even when NDEBUG is defined. This is useful for important assertions in CUDA
// code that would otherwise be suppressed when building Release.
#if defined(__ANDROID__) || defined(__APPLE__) || defined(__FreeBSD__)
// Those platforms do not support assert()
#define CUDA_KERNEL_ASSERT(cond)
#define CUDA_KERNEL_ASSERT_MSG(cond, msg)
#define SYCL_KERNEL_ASSERT(cond)
#elif defined(_MSC_VER)
#if defined(NDEBUG)
extern "C" {
C10_IMPORT
#if defined(__SYCL_DEVICE_ONLY__)
extern SYCL_EXTERNAL void _wassert(
const wchar_t* wexpr,
const wchar_t* wfile,
unsigned line);
#else
#if defined(__CUDA_ARCH__)
__host__ __device__
#endif // __CUDA_ARCH__
void
_wassert(wchar_t const* _Message, wchar_t const* _File, unsigned _Line);
#endif // __SYCL_DEVICE_ONLY__
}
#endif // NDEBUG
#define CUDA_KERNEL_ASSERT(cond) \
if (C10_UNLIKELY(!(cond))) { \
(void)(_wassert( \
_CRT_WIDE(#cond), \
_CRT_WIDE(__FILE__), \
static_cast<unsigned>(__LINE__)), \
0); \
}
// TODO: This doesn't assert the message because I (chilli) couldn't figure out
// a nice way to convert a char* to a wchar_t*
#define CUDA_KERNEL_ASSERT_MSG(cond, msg) \
if (C10_UNLIKELY(!(cond))) { \
(void)(_wassert( \
_CRT_WIDE(#cond), \
_CRT_WIDE(__FILE__), \
static_cast<unsigned>(__LINE__)), \
0); \
}
#define SYCL_KERNEL_ASSERT(cond) \
if (C10_UNLIKELY(!(cond))) { \
(void)(_wassert( \
_CRT_WIDE(#cond), \
_CRT_WIDE(__FILE__), \
static_cast<unsigned>(__LINE__)), \
0); \
}
#else // __APPLE__, _MSC_VER
#if defined(NDEBUG)
extern "C" {
#if defined(__SYCL_DEVICE_ONLY__)
extern SYCL_EXTERNAL void __assert_fail(
const char* expr,
const char* file,
unsigned int line,
const char* func);
#else // __SYCL_DEVICE_ONLY__
#if (defined(__CUDA_ARCH__) && !(defined(__clang__) && defined(__CUDA__)))
// CUDA provides an __assert_fail function that is common to both device-
// and host-side code.
__host__ __device__
#endif
// This forward declaration matches the declaration of __assert_fail
// exactly as it appears in glibc, in case parts of the program are compiled
// with different NDEBUG settings. Otherwise we might get an 'ambiguous
// declaration' error. Note: on ROCm, this declaration serves for host-side
// compilation.
void
__assert_fail(
const char* assertion,
const char* file,
unsigned int line,
const char* function) noexcept __attribute__((__noreturn__));
#endif // __SYCL_DEVICE_ONLY__
}
#endif // NDEBUG
// ROCm disables kernel assert by default.
#if !defined(C10_USE_ROCM_KERNEL_ASSERT) and defined(USE_ROCM)
#define CUDA_KERNEL_ASSERT(cond)
#define CUDA_KERNEL_ASSERT_MSG(cond, msg)
#define SYCL_KERNEL_ASSERT(cond)
#else
#define CUDA_KERNEL_ASSERT(cond) \
if (C10_UNLIKELY(!(cond))) { \
__assert_fail( \
#cond, __FILE__, static_cast<unsigned int>(__LINE__), __func__); \
}
#define CUDA_KERNEL_ASSERT_MSG(cond, msg) \
if (C10_UNLIKELY(!(cond))) { \
__assert_fail( \
msg, __FILE__, static_cast<unsigned int>(__LINE__), __func__); \
}
#define SYCL_KERNEL_ASSERT(cond) \
if (C10_UNLIKELY(!(cond))) { \
__assert_fail( \
#cond, __FILE__, static_cast<unsigned int>(__LINE__), __func__); \
}
#endif // C10_USE_ROCM_KERNEL_ASSERT and USE_ROCM
#endif // __APPLE__
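// Usage sketch (illustrative, not part of the original header): inside
// device code the assertion fires even in Release builds, unlike a plain
// assert(). The gather_kernel below is hypothetical.
//
//   __global__ void gather_kernel(
//       const int64_t* idx, const float* src, float* dst,
//       int64_t n, int64_t src_len) {
//     int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
//     if (i >= n) {
//       return;
//     }
//     CUDA_KERNEL_ASSERT(idx[i] >= 0 && idx[i] < src_len && "index out of bounds");
//     dst[i] = src[idx[i]];
//   }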
#ifdef __APPLE__
#include <TargetConditionals.h>
#endif
#if defined(__ANDROID__)
#define C10_ANDROID 1
#define C10_MOBILE 1
#elif ( \
defined(__APPLE__) && \
(TARGET_IPHONE_SIMULATOR || TARGET_OS_SIMULATOR || TARGET_OS_IPHONE))
#define C10_IOS 1
#define C10_MOBILE 1
#endif // ANDROID / IOS
#if defined(C10_MOBILE) && C10_MOBILE
#define C10_ALWAYS_INLINE_UNLESS_MOBILE inline
#else
#define C10_ALWAYS_INLINE_UNLESS_MOBILE C10_ALWAYS_INLINE
#endif
#if defined(__CUDA_ARCH__)
#if defined(_MSC_VER) && defined(__CUDACC__)
#define CONSTEXPR_EXCEPT_WIN_CUDA const
#define C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA __host__
// Note [static constexpr char* members for windows NVCC]
// The Windows NVCC compiler doesn't handle static constexpr class members,
// although it's fixed in a later version.
// (see
// https://developercommunity.visualstudio.com/t/intellisense-error-c11-static-constexpr-member-ini/245425)
//
// If we want to ensure that our field is static under all builds, then we need
// to work around it specifically for Windows NVCC by making it (a) const, (b)
// defined outside of the class definition. We need to define it outside of the
// class definition because of the C++ standard; char* is not an integral type
// (see
// https://stackoverflow.com/questions/24278473/intellisense-a-member-of-type-const-char-const-cannot-have-an-in-class-in)
//
// So instead of this:
// struct Foo {
// static constexpr const char* name = "foo";
// }
// In Windows NVCC, we end up with this:
// struct Foo {
// static const char* name;
// }
// const char* Foo::name = "foo";
//
// This gives us a small perf hit for any code that wants to access these field
// members, but right now it isn't used in any perf-critical code paths.
#define STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(field, val) \
static const char* field;
#define STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA(cls, field, val) \
const char* cls::field = val;
#else
#define CONSTEXPR_EXCEPT_WIN_CUDA constexpr
#define C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA __host__
#define STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(field, val) \
static constexpr const char* field = val;
#define STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA(cls, field, val)
#endif
#else
#if defined(_MSC_VER) && defined(__CUDACC__)
#define CONSTEXPR_EXCEPT_WIN_CUDA const
#define C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA
#define STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(field, val) \
static const char* field;
#define STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA(cls, field, val) \
const char* cls::field = val;
#else
#define CONSTEXPR_EXCEPT_WIN_CUDA constexpr
#define C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA constexpr
#define STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(field, val) \
static constexpr const char* field = val;
#define STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA(cls, field, val)
#endif
#endif
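// Usage sketch (illustrative, not part of the original header): the two
// macros are used as a pair, one inside the class and one in a single
// translation unit (the out-of-line one expands to nothing except on
// Windows NVCC). MyOp is a hypothetical class.
//
//   // In the header:
//   struct MyOp {
//     STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "my::op")
//   };
//   // In exactly one .cpp file:
//   STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA(MyOp, name, "my::op")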
#ifndef HAS_DEMANGLE
#if defined(__ANDROID__) || defined(_WIN32) || defined(__EMSCRIPTEN__)
#define HAS_DEMANGLE 0
#elif defined(__APPLE__) && \
(TARGET_IPHONE_SIMULATOR || TARGET_OS_SIMULATOR || TARGET_OS_IPHONE)
#define HAS_DEMANGLE 0
#else
#define HAS_DEMANGLE 1
#endif
#endif // HAS_DEMANGLE
#define _C10_PRAGMA__(string) _Pragma(#string)
#define _C10_PRAGMA_(string) _C10_PRAGMA__(string)
#ifdef __clang__
#define C10_CLANG_DIAGNOSTIC_PUSH() _Pragma("clang diagnostic push")
#define C10_CLANG_DIAGNOSTIC_POP() _Pragma("clang diagnostic pop")
#define C10_CLANG_DIAGNOSTIC_IGNORE(flag) \
_C10_PRAGMA_(clang diagnostic ignored flag)
#define C10_CLANG_HAS_WARNING(flag) __has_warning(flag)
#else
#define C10_CLANG_DIAGNOSTIC_PUSH()
#define C10_CLANG_DIAGNOSTIC_POP()
#define C10_CLANG_DIAGNOSTIC_IGNORE(flag)
#define C10_CLANG_HAS_WARNING(flag) 0
#endif
#ifdef __clang__
#define C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED(warning) \
_C10_PRAGMA_(clang diagnostic push) \
_C10_PRAGMA_(clang diagnostic ignored "-Wunknown-warning-option") \
_C10_PRAGMA_(clang diagnostic ignored warning)
#define C10_DIAGNOSTIC_POP() _C10_PRAGMA_(clang diagnostic pop)
#elif __GNUC__
#define C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED(warning) \
_C10_PRAGMA_(GCC diagnostic push) \
_C10_PRAGMA_(GCC diagnostic ignored "-Wpragmas") \
_C10_PRAGMA_(GCC diagnostic ignored warning)
#define C10_DIAGNOSTIC_POP() _C10_PRAGMA_(GCC diagnostic pop)
#else
#define C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED(warning)
#define C10_DIAGNOSTIC_POP()
#endif
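// Usage sketch (illustrative, not part of the original header): the clang
// macros are used as a push / ignore / pop sandwich around code that
// intentionally triggers a warning, and the compiler-agnostic variant also
// swallows warning names the compiler does not know about.
//
//   C10_CLANG_DIAGNOSTIC_PUSH()
//   #if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
//   C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
//   #endif
//   // ... code that mixes ints and floats ...
//   C10_CLANG_DIAGNOSTIC_POP()
//
//   C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-variable")
//   // ... code that sets a variable only used in some configurations ...
//   C10_DIAGNOSTIC_POP()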
#endif // C10_MACROS_MACROS_H_


@ -0,0 +1,14 @@
#ifndef C10_MACROS_CMAKE_MACROS_H_
#define C10_MACROS_CMAKE_MACROS_H_
// Automatically generated header file for the C10 library.
// Do not include this file directly. Instead, include c10/macros/Macros.h.
#define C10_BUILD_SHARED_LIBS
/* #undef C10_USE_GLOG */
/* #undef C10_USE_GFLAGS */
/* #undef C10_USE_NUMA */
/* #undef C10_USE_MSVC_STATIC_RUNTIME */
/* #undef C10_USE_ROCM_KERNEL_ASSERT */
#endif // C10_MACROS_CMAKE_MACROS_H_


@ -0,0 +1,81 @@
#include <c10/macros/Macros.h>
#include <c10/util/Backtrace.h>
#include <c10/util/env.h>
#include <cstdlib>
#include <exception>
#include <iostream>
#include <mutex>
#include <optional>
namespace c10 {
class AbortHandlerHelper {
public:
static AbortHandlerHelper& getInstance() {
#ifdef _WIN32
thread_local
#endif // _WIN32
static AbortHandlerHelper instance;
return instance;
}
void set(std::terminate_handler handler) {
std::lock_guard<std::mutex> lk(mutex);
if (!inited) {
prev = std::set_terminate(handler);
curr = std::get_terminate();
inited = true;
}
}
std::terminate_handler getPrev() const {
return prev;
}
private:
std::terminate_handler prev = nullptr;
std::terminate_handler curr = nullptr;
bool inited = false;
std::mutex mutex;
AbortHandlerHelper() = default;
~AbortHandlerHelper() {
// Only restore the handler if we are the current one
if (inited && curr == std::get_terminate()) {
std::set_terminate(prev);
}
}
public:
AbortHandlerHelper(AbortHandlerHelper const&) = delete;
void operator=(AbortHandlerHelper const&) = delete;
};
namespace detail {
C10_ALWAYS_INLINE void terminate_handler() {
std::cout << "Unhandled exception caught in c10/util/AbortHandler.h" << '\n';
auto backtrace = get_backtrace();
std::cout << backtrace << '\n' << std::flush;
auto prev_handler = AbortHandlerHelper::getInstance().getPrev();
if (prev_handler) {
prev_handler();
} else {
std::abort();
}
}
} // namespace detail
C10_ALWAYS_INLINE void set_terminate_handler() {
bool use_custom_terminate = false;
// On Windows it is enabled by default based on
// https://github.com/pytorch/pytorch/pull/50320#issuecomment-763147062
#ifdef _WIN32
use_custom_terminate = true;
#endif // _WIN32
auto result = c10::utils::check_env("TORCH_CUSTOM_TERMINATE");
if (result != std::nullopt) {
use_custom_terminate = result.value();
}
if (use_custom_terminate) {
AbortHandlerHelper::getInstance().set(detail::terminate_handler);
}
}
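// Usage sketch (illustrative, not part of the original file): the handler is
// installed by calling set_terminate_handler() once, typically at startup.
// It is on by default on Windows; elsewhere it is opted into by setting
// TORCH_CUSTOM_TERMINATE=1 in the environment. After installation, an
// uncaught exception prints a backtrace before aborting.
//
//   int main() {
//     c10::set_terminate_handler();
//     throw std::runtime_error("boom"); // backtrace is printed, then abort
//   }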
} // namespace c10


@ -0,0 +1,176 @@
//===--- AlignOf.h - Portable calculation of type alignment -----*- C++ -*-===//
//
// The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This file defines the AlignedCharArray and AlignedCharArrayUnion classes.
//
//===----------------------------------------------------------------------===//
// ATen: modified from llvm::AlignOf
// replaced LLVM_ALIGNAS with alignas
#pragma once
#include <cstddef>
namespace c10 {
/// \struct AlignedCharArray
/// \brief Helper for building an aligned character array type.
///
/// This template is used to explicitly build up a collection of aligned
/// character array types. We have to build these up using a macro and explicit
/// specialization to cope with MSVC (at least till 2015) where only an
/// integer literal can be used to specify an alignment constraint. Once built
/// up here, we can then begin to indirect between these using normal C++
/// template parameters.
// MSVC requires special handling here.
#ifndef _MSC_VER
template <size_t Alignment, size_t Size>
struct AlignedCharArray {
// NOLINTNEXTLINE(*c-arrays)
alignas(Alignment) char buffer[Size];
};
#else // _MSC_VER
/// \brief Create a type with an aligned char buffer.
template <size_t Alignment, size_t Size>
struct AlignedCharArray;
// We provide special variations of this template for the most common
// alignments because __declspec(align(...)) doesn't actually work when it is
// a member of a by-value function argument in MSVC, even if the alignment
// request is something reasonably like 8-byte or 16-byte. Note that we can't
// even include the declspec with the union that forces the alignment because
// MSVC warns on the existence of the declspec despite the union member forcing
// proper alignment.
template <size_t Size>
struct AlignedCharArray<1, Size> {
union {
char aligned;
char buffer[Size];
};
};
template <size_t Size>
struct AlignedCharArray<2, Size> {
union {
short aligned;
char buffer[Size];
};
};
template <size_t Size>
struct AlignedCharArray<4, Size> {
union {
int aligned;
char buffer[Size];
};
};
template <size_t Size>
struct AlignedCharArray<8, Size> {
union {
double aligned;
char buffer[Size];
};
};
// The rest of these are provided with a __declspec(align(...)) and we simply
// can't pass them by-value as function arguments on MSVC.
#define AT_ALIGNEDCHARARRAY_TEMPLATE_ALIGNMENT(x) \
template <size_t Size> \
struct AlignedCharArray<x, Size> { \
__declspec(align(x)) char buffer[Size]; \
};
AT_ALIGNEDCHARARRAY_TEMPLATE_ALIGNMENT(16)
AT_ALIGNEDCHARARRAY_TEMPLATE_ALIGNMENT(32)
AT_ALIGNEDCHARARRAY_TEMPLATE_ALIGNMENT(64)
AT_ALIGNEDCHARARRAY_TEMPLATE_ALIGNMENT(128)
#undef AT_ALIGNEDCHARARRAY_TEMPLATE_ALIGNMENT
#endif // _MSC_VER
namespace detail {
template <
typename T1,
typename T2 = char,
typename T3 = char,
typename T4 = char,
typename T5 = char,
typename T6 = char,
typename T7 = char,
typename T8 = char,
typename T9 = char,
typename T10 = char>
class AlignerImpl {
T1 t1;
T2 t2;
T3 t3;
T4 t4;
T5 t5;
T6 t6;
T7 t7;
T8 t8;
T9 t9;
T10 t10;
public:
AlignerImpl() = delete;
};
template <
typename T1,
typename T2 = char,
typename T3 = char,
typename T4 = char,
typename T5 = char,
typename T6 = char,
typename T7 = char,
typename T8 = char,
typename T9 = char,
typename T10 = char>
union SizerImpl {
// NOLINTNEXTLINE(*c-arrays)
char arr1[sizeof(T1)], arr2[sizeof(T2)], arr3[sizeof(T3)], arr4[sizeof(T4)],
arr5[sizeof(T5)], arr6[sizeof(T6)], arr7[sizeof(T7)], arr8[sizeof(T8)],
arr9[sizeof(T9)], arr10[sizeof(T10)];
};
} // end namespace detail
/// \brief This union template exposes a suitably aligned and sized character
/// array member which can hold elements of any of up to ten types.
///
/// These types may be arrays, structs, or any other types. The goal is to
/// expose a char array buffer member which can be used as suitable storage for
/// a placement new of any of these types. Support for more than ten types can
/// be added at the cost of more boilerplate.
template <
typename T1,
typename T2 = char,
typename T3 = char,
typename T4 = char,
typename T5 = char,
typename T6 = char,
typename T7 = char,
typename T8 = char,
typename T9 = char,
typename T10 = char>
struct AlignedCharArrayUnion
: AlignedCharArray<
alignof(detail::AlignerImpl<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10>),
sizeof(::c10::detail::
SizerImpl<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10>)> {};
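// Usage sketch (illustrative, not part of the original header): the union is
// typically used as raw storage for a placement new of one of a fixed set of
// types (requires <new> for placement new).
//
//   c10::AlignedCharArrayUnion<int64_t, double, void*> storage;
//   // `buffer` is sized and aligned for any of the three types:
//   double* p = new (storage.buffer) double(3.14);
//   *p += 1.0;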
} // end namespace c10


@ -0,0 +1,115 @@
// Copyright 2023-present Facebook. All Rights Reserved.
#pragma once
#include <c10/macros/Export.h>
#include <array>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <ctime>
#include <functional>
#include <type_traits>
#if defined(C10_IOS) && defined(C10_MOBILE)
#include <sys/time.h> // for gettimeofday()
#endif
#if defined(__i386__) || defined(__x86_64__) || defined(__amd64__)
#define C10_RDTSC
#if defined(_MSC_VER)
#include <intrin.h>
#elif defined(__CUDACC__) || defined(__HIPCC__)
#undef C10_RDTSC
#elif defined(__clang__)
// `__rdtsc` is available by default.
// NB: This has to be first, because Clang will also define `__GNUC__`
#elif defined(__GNUC__)
#include <x86intrin.h>
#else
#undef C10_RDTSC
#endif
#endif
namespace c10 {
using time_t = int64_t;
using steady_clock_t = std::conditional_t<
std::chrono::high_resolution_clock::is_steady,
std::chrono::high_resolution_clock,
std::chrono::steady_clock>;
inline time_t getTimeSinceEpoch() {
auto now = std::chrono::system_clock::now().time_since_epoch();
return std::chrono::duration_cast<std::chrono::nanoseconds>(now).count();
}
inline time_t getTime(bool allow_monotonic = false) {
#if defined(C10_IOS) && defined(C10_MOBILE)
// clock_gettime is only available on iOS 10.0 or newer. Unlike OS X, iOS
// can't rely on CLOCK_REALTIME, since it is defined whether or not
// clock_gettime is actually implemented.
struct timeval now;
gettimeofday(&now, NULL);
return static_cast<time_t>(now.tv_sec) * 1000000000 +
static_cast<time_t>(now.tv_usec) * 1000;
#elif defined(_WIN32) || defined(__MACH__)
return std::chrono::duration_cast<std::chrono::nanoseconds>(
steady_clock_t::now().time_since_epoch())
.count();
#else
// clock_gettime is *much* faster than std::chrono implementation on Linux
struct timespec t {};
auto mode = CLOCK_REALTIME;
if (allow_monotonic) {
mode = CLOCK_MONOTONIC;
}
clock_gettime(mode, &t);
return static_cast<time_t>(t.tv_sec) * 1000000000 +
static_cast<time_t>(t.tv_nsec);
#endif
}
// We often do not need to capture true wall times. If a fast mechanism such
// as TSC is available we can use that instead and convert back to epoch time
// during post processing. This greatly reduces the clock's contribution to
// profiling.
// http://btorpey.github.io/blog/2014/02/18/clock-sources-in-linux/
// https://quick-bench.com/q/r8opkkGZSJMu9wM_XTbDouq-0Io
// TODO: We should use
// `https://github.com/google/benchmark/blob/main/src/cycleclock.h`
inline auto getApproximateTime() {
#if defined(C10_RDTSC)
return static_cast<uint64_t>(__rdtsc());
#else
return getTime();
#endif
}
using approx_time_t = decltype(getApproximateTime());
static_assert(
std::is_same_v<approx_time_t, int64_t> ||
std::is_same_v<approx_time_t, uint64_t>,
"Expected either int64_t (`getTime`) or uint64_t (some TSC reads).");
// Convert `getApproximateTime` results to nanoseconds since the Unix epoch.
class C10_API ApproximateClockToUnixTimeConverter final {
public:
ApproximateClockToUnixTimeConverter();
std::function<time_t(approx_time_t)> makeConverter();
struct UnixAndApproximateTimePair {
time_t t_;
approx_time_t approx_t_;
};
static UnixAndApproximateTimePair measurePair();
private:
static constexpr size_t replicates = 1001;
using time_pairs = std::array<UnixAndApproximateTimePair, replicates>;
time_pairs measurePairs();
time_pairs start_times_;
};
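// Usage sketch (an assumption based on the interface above, not taken from
// the original header): record cheap approximate timestamps on the hot path
// and convert them to epoch nanoseconds once the measurement is over.
//
//   c10::ApproximateClockToUnixTimeConverter converter;
//   std::vector<c10::approx_time_t> stamps;
//   stamps.push_back(c10::getApproximateTime()); // cheap (TSC if available)
//   // ... workload ...
//   stamps.push_back(c10::getApproximateTime());
//   auto to_epoch_ns = converter.makeConverter();
//   c10::time_t t0 = to_epoch_ns(stamps[0]);
//   c10::time_t t1 = to_epoch_ns(stamps[1]);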
} // namespace c10


@ -0,0 +1,18 @@
#pragma once
#include <array>
#include <utility>
namespace c10 {
// This helper function creates a constexpr std::array
// from a compile-time list of values, without requiring you to explicitly
// write out the length.
//
// See also https://stackoverflow.com/a/26351760/23845
template <typename V, typename... T>
inline constexpr auto array_of(T&&... t) -> std::array<V, sizeof...(T)> {
return {{std::forward<T>(t)...}};
}
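// Usage sketch (illustrative, not part of the original header): the array
// length is deduced from the number of arguments and the result is usable in
// constexpr contexts.
//
//   constexpr auto dims = c10::array_of<int64_t>(2, 3, 4); // std::array<int64_t, 3>
//   static_assert(dims.size() == 3, "length deduced from argument count");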
} // namespace c10


@ -0,0 +1,380 @@
//===--- ArrayRef.h - Array Reference Wrapper -------------------*- C++ -*-===//
//
// The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
// ATen: modified from llvm::ArrayRef.
// removed llvm-specific functionality
// removed some implicit const -> non-const conversions that rely on
// complicated std::enable_if meta-programming
// removed a bunch of slice variants for simplicity...
#pragma once
#include <c10/macros/Macros.h>
#include <c10/util/Deprecated.h>
#include <c10/util/Exception.h>
#include <c10/util/SmallVector.h>
#include <array>
#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <iterator>
#include <ostream>
#include <type_traits>
#include <vector>
namespace c10 {
/// ArrayRef - Represent a constant reference to an array (0 or more elements
/// consecutively in memory), i.e. a start pointer and a length. It allows
/// various APIs to take consecutive elements easily and conveniently.
///
/// This class does not own the underlying data, it is expected to be used in
/// situations where the data resides in some other buffer, whose lifetime
/// extends past that of the ArrayRef. For this reason, it is not in general
/// safe to store an ArrayRef.
///
/// This is intended to be trivially copyable, so it should be passed by
/// value.
template <typename T>
class ArrayRef final {
public:
using iterator = const T*;
using const_iterator = const T*;
using size_type = size_t;
using value_type = T;
using reverse_iterator = std::reverse_iterator<iterator>;
private:
/// The start of the array, in an external buffer.
const T* Data;
/// The number of elements.
size_type Length;
void debugCheckNullptrInvariant() {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
Data != nullptr || Length == 0,
"created ArrayRef with nullptr and non-zero length! std::optional relies on this being illegal");
}
public:
/// @name Constructors
/// @{
/// Construct an empty ArrayRef.
/* implicit */ constexpr ArrayRef() : Data(nullptr), Length(0) {}
/// Construct an ArrayRef from a single element.
// TODO Make this explicit
constexpr ArrayRef(const T& OneElt) : Data(&OneElt), Length(1) {}
/// Construct an ArrayRef from a pointer and length.
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA ArrayRef(const T* data, size_t length)
: Data(data), Length(length) {
debugCheckNullptrInvariant();
}
/// Construct an ArrayRef from a range.
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA ArrayRef(const T* begin, const T* end)
: Data(begin), Length(end - begin) {
debugCheckNullptrInvariant();
}
/// Construct an ArrayRef from a SmallVector. This is templated in order to
/// avoid instantiating SmallVectorTemplateCommon<T> whenever we
/// copy-construct an ArrayRef.
template <typename U>
/* implicit */ ArrayRef(const SmallVectorTemplateCommon<T, U>& Vec)
: Data(Vec.data()), Length(Vec.size()) {
debugCheckNullptrInvariant();
}
template <
typename Container,
typename = std::enable_if_t<std::is_same_v<
std::remove_const_t<decltype(std::declval<Container>().data())>,
T*>>>
/* implicit */ ArrayRef(const Container& container)
: Data(container.data()), Length(container.size()) {
debugCheckNullptrInvariant();
}
/// Construct an ArrayRef from a std::vector.
// The enable_if stuff here makes sure that this isn't used for
// std::vector<bool>, because ArrayRef can't work on a std::vector<bool>
// bitfield.
template <typename A>
/* implicit */ ArrayRef(const std::vector<T, A>& Vec)
: Data(Vec.data()), Length(Vec.size()) {
static_assert(
!std::is_same<T, bool>::value,
"ArrayRef<bool> cannot be constructed from a std::vector<bool> bitfield.");
}
/// Construct an ArrayRef from a std::array
template <size_t N>
/* implicit */ constexpr ArrayRef(const std::array<T, N>& Arr)
: Data(Arr.data()), Length(N) {}
/// Construct an ArrayRef from a C array.
template <size_t N>
// NOLINTNEXTLINE(*c-arrays*)
/* implicit */ constexpr ArrayRef(const T (&Arr)[N]) : Data(Arr), Length(N) {}
/// Construct an ArrayRef from a std::initializer_list.
/* implicit */ constexpr ArrayRef(const std::initializer_list<T>& Vec)
: Data(
std::begin(Vec) == std::end(Vec) ? static_cast<T*>(nullptr)
: std::begin(Vec)),
Length(Vec.size()) {}
/// @}
/// @name Simple Operations
/// @{
constexpr iterator begin() const {
return Data;
}
constexpr iterator end() const {
return Data + Length;
}
// These are actually the same as iterator, since ArrayRef only
// gives you const iterators.
constexpr const_iterator cbegin() const {
return Data;
}
constexpr const_iterator cend() const {
return Data + Length;
}
constexpr reverse_iterator rbegin() const {
return reverse_iterator(end());
}
constexpr reverse_iterator rend() const {
return reverse_iterator(begin());
}
/// empty - Check if the array is empty.
constexpr bool empty() const {
return Length == 0;
}
constexpr const T* data() const {
return Data;
}
/// size - Get the array size.
constexpr size_t size() const {
return Length;
}
/// front - Get the first element.
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA const T& front() const {
TORCH_CHECK(
!empty(), "ArrayRef: attempted to access front() of empty list");
return Data[0];
}
/// back - Get the last element.
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA const T& back() const {
TORCH_CHECK(!empty(), "ArrayRef: attempted to access back() of empty list");
return Data[Length - 1];
}
/// equals - Check for element-wise equality.
constexpr bool equals(ArrayRef RHS) const {
return Length == RHS.Length && std::equal(begin(), end(), RHS.begin());
}
/// slice(n, m) - Take M elements of the array starting at element N
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA ArrayRef<T> slice(size_t N, size_t M)
const {
TORCH_CHECK(
N + M <= size(),
"ArrayRef: invalid slice, N = ",
N,
"; M = ",
M,
"; size = ",
size());
return ArrayRef<T>(data() + N, M);
}
/// slice(n) - Chop off the first N elements of the array.
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA ArrayRef<T> slice(size_t N) const {
TORCH_CHECK(
N <= size(), "ArrayRef: invalid slice, N = ", N, "; size = ", size());
return slice(N, size() - N);
}
/// @}
/// @name Operator Overloads
/// @{
constexpr const T& operator[](size_t Index) const {
return Data[Index];
}
/// Vector compatibility
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA const T& at(size_t Index) const {
TORCH_CHECK(
Index < Length,
"ArrayRef: invalid index Index = ",
Index,
"; Length = ",
Length);
return Data[Index];
}
/// Disallow accidental assignment from a temporary.
///
/// The declaration here is extra complicated so that "arrayRef = {}"
/// continues to select the move assignment operator.
template <typename U>
std::enable_if_t<std::is_same_v<U, T>, ArrayRef<T>>& operator=(
// NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
U&& Temporary) = delete;
/// Disallow accidental assignment from a temporary.
///
/// The declaration here is extra complicated so that "arrayRef = {}"
/// continues to select the move assignment operator.
template <typename U>
std::enable_if_t<std::is_same_v<U, T>, ArrayRef<T>>& operator=(
std::initializer_list<U>) = delete;
/// @}
/// @name Expensive Operations
/// @{
std::vector<T> vec() const {
return std::vector<T>(Data, Data + Length);
}
/// @}
};
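// Usage sketch (illustrative, not part of the original header): ArrayRef is a
// cheap, non-owning view and is passed by value; the hypothetical sum() below
// accepts vectors, std::arrays, C arrays and braced lists alike.
//
//   int64_t sum(c10::ArrayRef<int64_t> xs) {
//     int64_t total = 0;
//     for (auto x : xs) {
//       total += x;
//     }
//     return total;
//   }
//
//   std::vector<int64_t> v = {2, 3, 4};
//   sum(v);                                     // views the vector, no copy
//   sum({1, 2, 3});                             // from an initializer_list
//   sum(c10::ArrayRef<int64_t>(v).slice(1));    // views {3, 4}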
template <typename T>
std::ostream& operator<<(std::ostream& out, ArrayRef<T> list) {
int i = 0;
out << "[";
for (const auto& e : list) {
if (i++ > 0)
out << ", ";
out << e;
}
out << "]";
return out;
}
/// @name ArrayRef Convenience constructors
/// @{
/// Construct an ArrayRef from a single element.
template <typename T>
ArrayRef<T> makeArrayRef(const T& OneElt) {
return OneElt;
}
/// Construct an ArrayRef from a pointer and length.
template <typename T>
ArrayRef<T> makeArrayRef(const T* data, size_t length) {
return ArrayRef<T>(data, length);
}
/// Construct an ArrayRef from a range.
template <typename T>
ArrayRef<T> makeArrayRef(const T* begin, const T* end) {
return ArrayRef<T>(begin, end);
}
/// Construct an ArrayRef from a SmallVector.
template <typename T>
ArrayRef<T> makeArrayRef(const SmallVectorImpl<T>& Vec) {
return Vec;
}
/// Construct an ArrayRef from a SmallVector.
template <typename T, unsigned N>
ArrayRef<T> makeArrayRef(const SmallVector<T, N>& Vec) {
return Vec;
}
/// Construct an ArrayRef from a std::vector.
template <typename T>
ArrayRef<T> makeArrayRef(const std::vector<T>& Vec) {
return Vec;
}
/// Construct an ArrayRef from a std::array.
template <typename T, std::size_t N>
ArrayRef<T> makeArrayRef(const std::array<T, N>& Arr) {
return Arr;
}
/// Construct an ArrayRef from an ArrayRef (no-op) (const)
template <typename T>
ArrayRef<T> makeArrayRef(const ArrayRef<T>& Vec) {
return Vec;
}
/// Construct an ArrayRef from an ArrayRef (no-op)
template <typename T>
ArrayRef<T>& makeArrayRef(ArrayRef<T>& Vec) {
return Vec;
}
/// Construct an ArrayRef from a C array.
template <typename T, size_t N>
// NOLINTNEXTLINE(*c-arrays*)
ArrayRef<T> makeArrayRef(const T (&Arr)[N]) {
return ArrayRef<T>(Arr);
}
// WARNING: Template instantiation will NOT be willing to do implicit
// conversions to get you to a c10::ArrayRef, which is why we need so
// many overloads.
template <typename T>
bool operator==(c10::ArrayRef<T> a1, c10::ArrayRef<T> a2) {
return a1.equals(a2);
}
template <typename T>
bool operator!=(c10::ArrayRef<T> a1, c10::ArrayRef<T> a2) {
return !a1.equals(a2);
}
template <typename T>
bool operator==(const std::vector<T>& a1, c10::ArrayRef<T> a2) {
return c10::ArrayRef<T>(a1).equals(a2);
}
template <typename T>
bool operator!=(const std::vector<T>& a1, c10::ArrayRef<T> a2) {
return !c10::ArrayRef<T>(a1).equals(a2);
}
template <typename T>
bool operator==(c10::ArrayRef<T> a1, const std::vector<T>& a2) {
return a1.equals(c10::ArrayRef<T>(a2));
}
template <typename T>
bool operator!=(c10::ArrayRef<T> a1, const std::vector<T>& a2) {
return !a1.equals(c10::ArrayRef<T>(a2));
}
using IntArrayRef = ArrayRef<int64_t>;
// This alias is deprecated because it doesn't make ownership
// semantics obvious. Use IntArrayRef instead!
C10_DEFINE_DEPRECATED_USING(IntList, ArrayRef<int64_t>)
} // namespace c10


@ -0,0 +1,361 @@
#pragma once
#include <c10/macros/Macros.h>
#include <c10/util/bit_cast.h>
#include <limits>
C10_CLANG_DIAGNOSTIC_PUSH()
#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
#endif
#if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
#if defined(CL_SYCL_LANGUAGE_VERSION)
#include <CL/sycl.hpp> // for SYCL 1.2.1
#else
#include <sycl/sycl.hpp> // for SYCL 2020
#endif
#include <ext/oneapi/bfloat16.hpp>
#endif
namespace c10 {
/// Constructors
inline C10_HOST_DEVICE BFloat16::BFloat16(float value)
:
#if defined(__CUDACC__) && !defined(USE_ROCM) && defined(__CUDA_ARCH__) && \
__CUDA_ARCH__ >= 800
x(__bfloat16_as_ushort(__float2bfloat16(value)))
#elif defined(__SYCL_DEVICE_ONLY__) && \
defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
x(c10::bit_cast<uint16_t>(sycl::ext::oneapi::bfloat16(value)))
#else
// Round to nearest, ties to even (RNE) by default
x(detail::round_to_nearest_even(value))
#endif
{
}
/// Implicit conversions
inline C10_HOST_DEVICE BFloat16::operator float() const {
#if defined(__CUDACC__) && !defined(USE_ROCM)
return __bfloat162float(*reinterpret_cast<const __nv_bfloat16*>(&x));
#elif defined(__SYCL_DEVICE_ONLY__) && \
defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
return float(*reinterpret_cast<const sycl::ext::oneapi::bfloat16*>(&x));
#else
return detail::f32_from_bits(x);
#endif
}
#if defined(__CUDACC__) && !defined(USE_ROCM)
inline C10_HOST_DEVICE BFloat16::BFloat16(const __nv_bfloat16& value) {
x = *reinterpret_cast<const unsigned short*>(&value);
}
inline C10_HOST_DEVICE BFloat16::operator __nv_bfloat16() const {
return *reinterpret_cast<const __nv_bfloat16*>(&x);
}
#endif
#if defined(__HIPCC__) && defined(USE_ROCM)
// 6.2.0 introduced __hip_bfloat16_raw
#if defined(__BF16_HOST_DEVICE__)
inline C10_HOST_DEVICE BFloat16::BFloat16(const __hip_bfloat16& value) {
x = __hip_bfloat16_raw(value).x;
}
inline C10_HOST_DEVICE BFloat16::operator __hip_bfloat16() const {
return __hip_bfloat16(__hip_bfloat16_raw{x});
}
#else // !defined(__BF16_HOST_DEVICE__)
inline C10_HOST_DEVICE BFloat16::BFloat16(const __hip_bfloat16& value) {
x = value.data;
}
inline C10_HOST_DEVICE BFloat16::operator __hip_bfloat16() const {
return __hip_bfloat16{x};
}
#endif // !defined(__BF16_HOST_DEVICE__)
#endif // defined(__HIPCC__) && defined(USE_ROCM)
#if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
inline C10_HOST_DEVICE BFloat16::BFloat16(
const sycl::ext::oneapi::bfloat16& value) {
x = *reinterpret_cast<const unsigned short*>(&value);
}
inline C10_HOST_DEVICE BFloat16::operator sycl::ext::oneapi::bfloat16() const {
return *reinterpret_cast<const sycl::ext::oneapi::bfloat16*>(&x);
}
#endif
// CUDA intrinsics
#if defined(__CUDACC__) || defined(__HIPCC__)
inline C10_DEVICE BFloat16 __ldg(const BFloat16* ptr) {
#if !defined(USE_ROCM) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
return __ldg(reinterpret_cast<const __nv_bfloat16*>(ptr));
#else
return *ptr;
#endif
}
#endif
/// Arithmetic
inline C10_HOST_DEVICE BFloat16
operator+(const BFloat16& a, const BFloat16& b) {
return static_cast<float>(a) + static_cast<float>(b);
}
inline C10_HOST_DEVICE BFloat16
operator-(const BFloat16& a, const BFloat16& b) {
return static_cast<float>(a) - static_cast<float>(b);
}
inline C10_HOST_DEVICE BFloat16
operator*(const BFloat16& a, const BFloat16& b) {
return static_cast<float>(a) * static_cast<float>(b);
}
inline C10_HOST_DEVICE BFloat16 operator/(const BFloat16& a, const BFloat16& b)
__ubsan_ignore_float_divide_by_zero__ {
return static_cast<float>(a) / static_cast<float>(b);
}
inline C10_HOST_DEVICE BFloat16 operator-(const BFloat16& a) {
return -static_cast<float>(a);
}
inline C10_HOST_DEVICE BFloat16& operator+=(BFloat16& a, const BFloat16& b) {
a = a + b;
return a;
}
inline C10_HOST_DEVICE BFloat16& operator-=(BFloat16& a, const BFloat16& b) {
a = a - b;
return a;
}
inline C10_HOST_DEVICE BFloat16& operator*=(BFloat16& a, const BFloat16& b) {
a = a * b;
return a;
}
inline C10_HOST_DEVICE BFloat16& operator/=(BFloat16& a, const BFloat16& b) {
a = a / b;
return a;
}
inline C10_HOST_DEVICE BFloat16& operator|(BFloat16& a, const BFloat16& b) {
a.x = a.x | b.x;
return a;
}
inline C10_HOST_DEVICE BFloat16& operator^(BFloat16& a, const BFloat16& b) {
a.x = a.x ^ b.x;
return a;
}
inline C10_HOST_DEVICE BFloat16& operator&(BFloat16& a, const BFloat16& b) {
a.x = a.x & b.x;
return a;
}
/// Arithmetic with floats
inline C10_HOST_DEVICE float operator+(BFloat16 a, float b) {
return static_cast<float>(a) + b;
}
inline C10_HOST_DEVICE float operator-(BFloat16 a, float b) {
return static_cast<float>(a) - b;
}
inline C10_HOST_DEVICE float operator*(BFloat16 a, float b) {
return static_cast<float>(a) * b;
}
inline C10_HOST_DEVICE float operator/(BFloat16 a, float b) {
return static_cast<float>(a) / b;
}
inline C10_HOST_DEVICE float operator+(float a, BFloat16 b) {
return a + static_cast<float>(b);
}
inline C10_HOST_DEVICE float operator-(float a, BFloat16 b) {
return a - static_cast<float>(b);
}
inline C10_HOST_DEVICE float operator*(float a, BFloat16 b) {
return a * static_cast<float>(b);
}
inline C10_HOST_DEVICE float operator/(float a, BFloat16 b) {
return a / static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator+=(float& a, const BFloat16& b) {
return a += static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator-=(float& a, const BFloat16& b) {
return a -= static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator*=(float& a, const BFloat16& b) {
return a *= static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator/=(float& a, const BFloat16& b) {
return a /= static_cast<float>(b);
}
/// Arithmetic with doubles
inline C10_HOST_DEVICE double operator+(BFloat16 a, double b) {
return static_cast<double>(a) + b;
}
inline C10_HOST_DEVICE double operator-(BFloat16 a, double b) {
return static_cast<double>(a) - b;
}
inline C10_HOST_DEVICE double operator*(BFloat16 a, double b) {
return static_cast<double>(a) * b;
}
inline C10_HOST_DEVICE double operator/(BFloat16 a, double b) {
return static_cast<double>(a) / b;
}
inline C10_HOST_DEVICE double operator+(double a, BFloat16 b) {
return a + static_cast<double>(b);
}
inline C10_HOST_DEVICE double operator-(double a, BFloat16 b) {
return a - static_cast<double>(b);
}
inline C10_HOST_DEVICE double operator*(double a, BFloat16 b) {
return a * static_cast<double>(b);
}
inline C10_HOST_DEVICE double operator/(double a, BFloat16 b) {
return a / static_cast<double>(b);
}
/// Arithmetic with ints
inline C10_HOST_DEVICE BFloat16 operator+(BFloat16 a, int b) {
return a + static_cast<BFloat16>(b);
}
inline C10_HOST_DEVICE BFloat16 operator-(BFloat16 a, int b) {
return a - static_cast<BFloat16>(b);
}
inline C10_HOST_DEVICE BFloat16 operator*(BFloat16 a, int b) {
return a * static_cast<BFloat16>(b);
}
inline C10_HOST_DEVICE BFloat16 operator/(BFloat16 a, int b) {
return a / static_cast<BFloat16>(b);
}
inline C10_HOST_DEVICE BFloat16 operator+(int a, BFloat16 b) {
return static_cast<BFloat16>(a) + b;
}
inline C10_HOST_DEVICE BFloat16 operator-(int a, BFloat16 b) {
return static_cast<BFloat16>(a) - b;
}
inline C10_HOST_DEVICE BFloat16 operator*(int a, BFloat16 b) {
return static_cast<BFloat16>(a) * b;
}
inline C10_HOST_DEVICE BFloat16 operator/(int a, BFloat16 b) {
return static_cast<BFloat16>(a) / b;
}
//// Arithmetic with int64_t
inline C10_HOST_DEVICE BFloat16 operator+(BFloat16 a, int64_t b) {
return a + static_cast<BFloat16>(b);
}
inline C10_HOST_DEVICE BFloat16 operator-(BFloat16 a, int64_t b) {
return a - static_cast<BFloat16>(b);
}
inline C10_HOST_DEVICE BFloat16 operator*(BFloat16 a, int64_t b) {
return a * static_cast<BFloat16>(b);
}
inline C10_HOST_DEVICE BFloat16 operator/(BFloat16 a, int64_t b) {
return a / static_cast<BFloat16>(b);
}
inline C10_HOST_DEVICE BFloat16 operator+(int64_t a, BFloat16 b) {
return static_cast<BFloat16>(a) + b;
}
inline C10_HOST_DEVICE BFloat16 operator-(int64_t a, BFloat16 b) {
return static_cast<BFloat16>(a) - b;
}
inline C10_HOST_DEVICE BFloat16 operator*(int64_t a, BFloat16 b) {
return static_cast<BFloat16>(a) * b;
}
inline C10_HOST_DEVICE BFloat16 operator/(int64_t a, BFloat16 b) {
return static_cast<BFloat16>(a) / b;
}
// Overloading < and > operators, because std::max and std::min use them.
inline C10_HOST_DEVICE bool operator>(BFloat16& lhs, BFloat16& rhs) {
return float(lhs) > float(rhs);
}
inline C10_HOST_DEVICE bool operator<(BFloat16& lhs, BFloat16& rhs) {
return float(lhs) < float(rhs);
}
} // namespace c10
namespace std {
template <>
class numeric_limits<c10::BFloat16> {
public:
static constexpr bool is_signed = true;
static constexpr bool is_specialized = true;
static constexpr bool is_integer = false;
static constexpr bool is_exact = false;
static constexpr bool has_infinity = true;
static constexpr bool has_quiet_NaN = true;
static constexpr bool has_signaling_NaN = true;
static constexpr auto has_denorm = numeric_limits<float>::has_denorm;
static constexpr auto has_denorm_loss =
numeric_limits<float>::has_denorm_loss;
static constexpr auto round_style = numeric_limits<float>::round_style;
static constexpr bool is_iec559 = false;
static constexpr bool is_bounded = true;
static constexpr bool is_modulo = false;
static constexpr int digits = 8;
static constexpr int digits10 = 2;
static constexpr int max_digits10 = 4;
static constexpr int radix = 2;
static constexpr int min_exponent = -125;
static constexpr int min_exponent10 = -37;
static constexpr int max_exponent = 128;
static constexpr int max_exponent10 = 38;
static constexpr auto traps = numeric_limits<float>::traps;
static constexpr auto tinyness_before =
numeric_limits<float>::tinyness_before;
static constexpr c10::BFloat16 min() {
return c10::BFloat16(0x0080, c10::BFloat16::from_bits());
}
static constexpr c10::BFloat16 lowest() {
return c10::BFloat16(0xFF7F, c10::BFloat16::from_bits());
}
static constexpr c10::BFloat16 max() {
return c10::BFloat16(0x7F7F, c10::BFloat16::from_bits());
}
static constexpr c10::BFloat16 epsilon() {
return c10::BFloat16(0x3C00, c10::BFloat16::from_bits());
}
static constexpr c10::BFloat16 round_error() {
return c10::BFloat16(0x3F00, c10::BFloat16::from_bits());
}
static constexpr c10::BFloat16 infinity() {
return c10::BFloat16(0x7F80, c10::BFloat16::from_bits());
}
static constexpr c10::BFloat16 quiet_NaN() {
return c10::BFloat16(0x7FC0, c10::BFloat16::from_bits());
}
static constexpr c10::BFloat16 signaling_NaN() {
return c10::BFloat16(0x7F80, c10::BFloat16::from_bits());
}
static constexpr c10::BFloat16 denorm_min() {
return c10::BFloat16(0x0001, c10::BFloat16::from_bits());
}
};
} // namespace std
C10_CLANG_DIAGNOSTIC_POP()


@ -0,0 +1,292 @@
#pragma once
#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>
C10_CLANG_DIAGNOSTIC_PUSH()
#if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion")
C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion")
#endif
namespace std {
template <typename T>
struct is_reduced_floating_point
: std::integral_constant<
bool,
std::is_same_v<T, c10::Half> || std::is_same_v<T, c10::BFloat16>> {};
template <typename T>
constexpr bool is_reduced_floating_point_v =
is_reduced_floating_point<T>::value;
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T acos(T a) {
return std::acos(float(a));
}
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T asin(T a) {
return std::asin(float(a));
}
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T atan(T a) {
return std::atan(float(a));
}
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T atanh(T a) {
return std::atanh(float(a));
}
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T erf(T a) {
return std::erf(float(a));
}
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T erfc(T a) {
return std::erfc(float(a));
}
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T exp(T a) {
return std::exp(float(a));
}
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T expm1(T a) {
return std::expm1(float(a));
}
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline bool isfinite(T a) {
return std::isfinite(float(a));
}
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T log(T a) {
return std::log(float(a));
}
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T log10(T a) {
return std::log10(float(a));
}
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T log1p(T a) {
return std::log1p(float(a));
}
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T log2(T a) {
return std::log2(float(a));
}
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T ceil(T a) {
return std::ceil(float(a));
}
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T cos(T a) {
return std::cos(float(a));
}
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T floor(T a) {
return std::floor(float(a));
}
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T nearbyint(T a) {
return std::nearbyint(float(a));
}
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T sin(T a) {
return std::sin(float(a));
}
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T tan(T a) {
return std::tan(float(a));
}
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T sinh(T a) {
return std::sinh(float(a));
}
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T cosh(T a) {
return std::cosh(float(a));
}
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T tanh(T a) {
return std::tanh(float(a));
}
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T trunc(T a) {
return std::trunc(float(a));
}
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T lgamma(T a) {
return std::lgamma(float(a));
}
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T sqrt(T a) {
return std::sqrt(float(a));
}
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T rsqrt(T a) {
return 1.0 / std::sqrt(float(a));
}
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T abs(T a) {
return std::abs(float(a));
}
#if defined(_MSC_VER) && defined(__CUDACC__)
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T pow(T a, double b) {
return std::pow(float(a), float(b));
}
#else
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T pow(T a, double b) {
return std::pow(float(a), b);
}
#endif
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T pow(T a, T b) {
return std::pow(float(a), float(b));
}
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T fmod(T a, T b) {
return std::fmod(float(a), float(b));
}
/*
The following function is inspired from the implementation in `musl`
Link to License: https://git.musl-libc.org/cgit/musl/tree/COPYRIGHT
----------------------------------------------------------------------
Copyright © 2005-2020 Rich Felker, et al.
Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
----------------------------------------------------------------------
*/
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
C10_HOST_DEVICE inline T nextafter(T from, T to) {
// Reference:
// https://git.musl-libc.org/cgit/musl/tree/src/math/nextafter.c
using int_repr_t = uint16_t;
constexpr uint8_t bits = 16;
union {
T f;
int_repr_t i;
} ufrom = {from}, uto = {to};
// get a mask to get the sign bit i.e. MSB
int_repr_t sign_mask = int_repr_t{1} << (bits - 1);
// short-circuit: if either is NaN, return NaN
if (from != from || to != to) {
return from + to;
}
// short-circuit: if they are exactly the same.
if (ufrom.i == uto.i) {
return from;
}
// mask the sign-bit to zero i.e. positive
// equivalent to abs(x)
int_repr_t abs_from = ufrom.i & ~sign_mask;
int_repr_t abs_to = uto.i & ~sign_mask;
if (abs_from == 0) {
// if both are zero but with different sign,
// preserve the sign of `to`.
if (abs_to == 0) {
return to;
}
// smallest subnormal with sign of `to`.
ufrom.i = (uto.i & sign_mask) | int_repr_t{1};
return ufrom.f;
}
// if abs(from) > abs(to) or sign(from) != sign(to)
if (abs_from > abs_to || ((ufrom.i ^ uto.i) & sign_mask)) {
ufrom.i--;
} else {
ufrom.i++;
}
return ufrom.f;
}
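// Illustrative note (not part of the original header): with a 7-bit mantissa
// the step between neighbouring bfloat16 values grows with magnitude.
// Starting from 1.0 (bits 0x3F80) the next value towards +inf is 0x3F81,
// i.e. 1 + 2^-7 = 1.0078125:
//
//   c10::BFloat16 one(1.0f), two(2.0f);
//   float step = float(std::nextafter(one, two)) - float(one); // 0.0078125f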
} // namespace std
C10_CLANG_DIAGNOSTIC_POP()


@ -0,0 +1,133 @@
#pragma once
// Defines the bfloat16 type (brain floating-point). This representation uses
// 1 bit for the sign, 8 bits for the exponent and 7 bits for the mantissa.
#include <c10/macros/Macros.h>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <iosfwd>
#include <ostream>
#if defined(__CUDACC__) && !defined(USE_ROCM)
#include <cuda_bf16.h>
#endif
#if defined(__HIPCC__) && defined(USE_ROCM)
#include <hip/hip_bf16.h>
#endif
#if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
#if defined(CL_SYCL_LANGUAGE_VERSION)
#include <CL/sycl.hpp> // for SYCL 1.2.1
#else
#include <sycl/sycl.hpp> // for SYCL 2020
#endif
#include <ext/oneapi/bfloat16.hpp>
#endif
namespace c10 {
namespace detail {
inline C10_HOST_DEVICE float f32_from_bits(uint16_t src) {
float res = 0;
uint32_t tmp = src;
tmp <<= 16;
#if defined(USE_ROCM)
float* tempRes;
// We should be using memcpy in order to respect the strict aliasing rule
// but it fails in the HIP environment.
tempRes = reinterpret_cast<float*>(&tmp);
res = *tempRes;
#else
std::memcpy(&res, &tmp, sizeof(tmp));
#endif
return res;
}
inline C10_HOST_DEVICE uint16_t bits_from_f32(float src) {
uint32_t res = 0;
#if defined(USE_ROCM)
// We should be using memcpy in order to respect the strict aliasing rule
// but it fails in the HIP environment.
uint32_t* tempRes = reinterpret_cast<uint32_t*>(&src);
res = *tempRes;
#else
std::memcpy(&res, &src, sizeof(res));
#endif
return res >> 16;
}
inline C10_HOST_DEVICE uint16_t round_to_nearest_even(float src) {
#if defined(USE_ROCM)
if (src != src) {
#elif defined(_MSC_VER)
if (isnan(src)) {
#else
if (std::isnan(src)) {
#endif
return UINT16_C(0x7FC0);
} else {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
union {
uint32_t U32; // NOLINT(facebook-hte-BadMemberName)
float F32; // NOLINT(facebook-hte-BadMemberName)
};
F32 = src;
uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF);
return static_cast<uint16_t>((U32 + rounding_bias) >> 16);
}
}
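// Worked example (illustrative, not part of the original header): for
// src = 1.00390625f the bits are U32 = 0x3F808000 and the discarded low half
// is exactly 0x8000, i.e. a tie. The kept bit (U32 >> 16) & 1 is 0, so
// rounding_bias = 0x7FFF and (0x3F808000 + 0x7FFF) >> 16 = 0x3F80: the tie
// rounds down to the even value 1.0. For src = 1.01171875f
// (U32 = 0x3F818000) the kept bit is 1, rounding_bias = 0x8000 and the
// result is 0x3F82: the tie rounds up to the even value 1.015625.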
} // namespace detail
struct alignas(2) BFloat16 {
uint16_t x;
// HIP wants __host__ __device__ tag, CUDA does not
#if defined(USE_ROCM)
C10_HOST_DEVICE BFloat16() = default;
#else
BFloat16() = default;
#endif
struct from_bits_t {};
static constexpr C10_HOST_DEVICE from_bits_t from_bits() {
return from_bits_t();
}
constexpr C10_HOST_DEVICE BFloat16(unsigned short bits, from_bits_t)
: x(bits) {}
/* implicit */ inline C10_HOST_DEVICE BFloat16(float value);
inline C10_HOST_DEVICE operator float() const;
#if defined(__CUDACC__) && !defined(USE_ROCM)
inline C10_HOST_DEVICE BFloat16(const __nv_bfloat16& value);
explicit inline C10_HOST_DEVICE operator __nv_bfloat16() const;
#endif
#if defined(__HIPCC__) && defined(USE_ROCM)
inline C10_HOST_DEVICE BFloat16(const __hip_bfloat16& value);
explicit inline C10_HOST_DEVICE operator __hip_bfloat16() const;
#endif
#if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
inline C10_HOST_DEVICE BFloat16(const sycl::ext::oneapi::bfloat16& value);
explicit inline C10_HOST_DEVICE operator sycl::ext::oneapi::bfloat16() const;
#endif
};
C10_API inline std::ostream& operator<<(
std::ostream& out,
const BFloat16& value) {
out << (float)value;
return out;
}
} // namespace c10
#include <c10/util/BFloat16-inl.h> // IWYU pragma: keep


@ -0,0 +1,31 @@
#ifndef C10_UTIL_BACKTRACE_H_
#define C10_UTIL_BACKTRACE_H_
#include <cstddef>
#include <memory>
#include <string>
#include <typeinfo>
#include <c10/macros/Macros.h>
#include <c10/util/Lazy.h>
namespace c10 {
// Symbolizing the backtrace can be expensive; pass it around as a lazy string
// so it is symbolized only if actually needed.
using Backtrace = std::shared_ptr<const LazyValue<std::string>>;
// DEPRECATED: Prefer get_lazy_backtrace().
C10_API std::string get_backtrace(
size_t frames_to_skip = 0,
size_t maximum_number_of_frames = 64,
bool skip_python_frames = true);
C10_API Backtrace get_lazy_backtrace(
size_t frames_to_skip = 0,
size_t maximum_number_of_frames = 64,
bool skip_python_frames = true);
} // namespace c10
#endif // C10_UTIL_BACKTRACE_H_


@ -0,0 +1,116 @@
#pragma once
#include <cstddef>
#if defined(_MSC_VER)
#include <intrin.h>
#endif
namespace c10::utils {
/**
* This is a simple bitset class with sizeof(long long int) bits.
* You can set bits, unset bits, query bits by index,
* and query for the first set bit.
* Before using this class, please also take a look at std::bitset,
* which has more functionality and is more generic. It is probably
* a better fit for your use case. The sole reason for c10::utils::bitset
* to exist is that std::bitset misses a find_first_set() method.
*/
struct bitset final {
private:
#if defined(_MSC_VER)
// MSVC's _BitScanForward64 expects int64_t
using bitset_type = int64_t;
#else
// POSIX ffsll expects long long int
using bitset_type = long long int;
#endif
public:
static constexpr size_t NUM_BITS() {
return 8 * sizeof(bitset_type);
}
constexpr bitset() noexcept = default;
constexpr bitset(const bitset&) noexcept = default;
constexpr bitset(bitset&&) noexcept = default;
// There is an issue with gcc 5.3.0 when defining a defaulted function as
// constexpr; see https://gcc.gnu.org/bugzilla/show_bug.cgi?id=68754.
bitset& operator=(const bitset&) noexcept = default;
bitset& operator=(bitset&&) noexcept = default;
constexpr void set(size_t index) noexcept {
bitset_ |= (static_cast<long long int>(1) << index);
}
constexpr void unset(size_t index) noexcept {
bitset_ &= ~(static_cast<long long int>(1) << index);
}
constexpr bool get(size_t index) const noexcept {
return bitset_ & (static_cast<long long int>(1) << index);
}
constexpr bool is_entirely_unset() const noexcept {
return 0 == bitset_;
}
// Call the given functor with the index of each bit that is set
template <class Func>
void for_each_set_bit(Func&& func) const {
bitset cur = *this;
size_t index = cur.find_first_set();
while (0 != index) {
// -1 because find_first_set() is one-indexed, while func expects zero-indexed positions.
index -= 1;
func(index);
cur.unset(index);
index = cur.find_first_set();
}
}
private:
// Return the index of the first set bit. The returned index is one-indexed
// (i.e. if the very first bit is set, this function returns '1'), and a
// return of '0' means that there was no bit set.
size_t find_first_set() const {
#if defined(_MSC_VER) && (defined(_M_X64) || defined(_M_ARM64))
unsigned long result;
bool has_bits_set = (0 != _BitScanForward64(&result, bitset_));
if (!has_bits_set) {
return 0;
}
return result + 1;
#elif defined(_MSC_VER) && defined(_M_IX86)
unsigned long result;
if (static_cast<uint32_t>(bitset_) != 0) {
bool has_bits_set =
(0 != _BitScanForward(&result, static_cast<uint32_t>(bitset_)));
if (!has_bits_set) {
return 0;
}
return result + 1;
} else {
bool has_bits_set =
(0 != _BitScanForward(&result, static_cast<uint32_t>(bitset_ >> 32)));
      if (!has_bits_set) {
        // Both halves are zero, so per the documented contract there is no set bit.
        return 0;
      }
return result + 33;
}
#else
return __builtin_ffsll(bitset_);
#endif
}
friend bool operator==(bitset lhs, bitset rhs) noexcept {
return lhs.bitset_ == rhs.bitset_;
}
bitset_type bitset_{0};
};
inline bool operator!=(bitset lhs, bitset rhs) noexcept {
return !(lhs == rhs);
}
} // namespace c10::utils
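
A short usage sketch for the bitset above; the include path is an assumption about where this header lives in the tree:

```cpp
#include <cstddef>
#include <cstdio>

#include <c10/util/Bitset.h> // assumed include path for the header above

int main() {
  c10::utils::bitset bits;
  bits.set(3);
  bits.set(10);
  bits.unset(3);
  // Prints "set bit: 10"; the callback receives zero-indexed positions.
  bits.for_each_set_bit(
      [](std::size_t index) { std::printf("set bit: %zu\n", index); });
  return bits.is_entirely_unset() ? 1 : 0;
}
```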

View File

@ -0,0 +1,142 @@
#pragma once
#ifndef C10_UTIL_CPP17_H_
#define C10_UTIL_CPP17_H_
#include <c10/macros/Macros.h>
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
#if !defined(__clang__) && !defined(_MSC_VER) && defined(__GNUC__) && \
__GNUC__ < 9
#error \
"You're trying to build PyTorch with a too old version of GCC. We need GCC 9 or later."
#endif
#if defined(__clang__) && __clang_major__ < 9
#error \
"You're trying to build PyTorch with a too old version of Clang. We need Clang 9 or later."
#endif
#if (defined(_MSC_VER) && (!defined(_MSVC_LANG) || _MSVC_LANG < 201703L)) || \
(!defined(_MSC_VER) && __cplusplus < 201703L)
#error You need C++17 to compile PyTorch
#endif
#if defined(_WIN32) && (defined(min) || defined(max))
#error Macro clash with min and max -- define NOMINMAX when compiling your program on Windows
#endif
/*
* This header adds some polyfills with C++17 functionality
*/
namespace c10 {
// std::is_pod is deprecated in C++20; std::is_standard_layout and
// std::is_trivial were introduced in C++11, and std::conjunction was
// introduced in C++17.
template <typename T>
using is_pod = std::conjunction<std::is_standard_layout<T>, std::is_trivial<T>>;
template <typename T>
constexpr bool is_pod_v = is_pod<T>::value;
namespace guts {
template <typename Base, typename Child, typename... Args>
std::enable_if_t<
!std::is_array_v<Base> && !std::is_array_v<Child> &&
std::is_base_of_v<Base, Child>,
std::unique_ptr<Base>>
make_unique_base(Args&&... args) {
return std::unique_ptr<Base>(new Child(std::forward<Args>(args)...));
}
#if defined(__cpp_lib_apply) && !defined(__CUDA_ARCH__) && !defined(__HIP__)
template <class F, class Tuple>
C10_HOST_DEVICE inline constexpr decltype(auto) apply(F&& f, Tuple&& t) {
return std::apply(std::forward<F>(f), std::forward<Tuple>(t));
}
#else
// Implementation from http://en.cppreference.com/w/cpp/utility/apply (but
// modified)
// TODO This is an incomplete implementation of std::apply, not working for
// member functions.
namespace detail {
template <class F, class Tuple, std::size_t... INDEX>
#if defined(_MSC_VER)
// MSVC has a problem with the decltype() return type, but it also doesn't need
// it
C10_HOST_DEVICE constexpr auto apply_impl(
F&& f,
Tuple&& t,
std::index_sequence<INDEX...>)
#else
// GCC/Clang need the decltype() return type
C10_HOST_DEVICE constexpr decltype(auto) apply_impl(
F&& f,
Tuple&& t,
std::index_sequence<INDEX...>)
#endif
{
return std::forward<F>(f)(std::get<INDEX>(std::forward<Tuple>(t))...);
}
} // namespace detail
template <class F, class Tuple>
C10_HOST_DEVICE constexpr decltype(auto) apply(F&& f, Tuple&& t) {
return detail::apply_impl(
std::forward<F>(f),
std::forward<Tuple>(t),
std::make_index_sequence<
std::tuple_size<std::remove_reference_t<Tuple>>::value>{});
}
#endif
template <typename Functor, typename... Args>
std::enable_if_t<
std::is_member_pointer_v<std::decay_t<Functor>>,
typename std::invoke_result_t<Functor, Args...>>
invoke(Functor&& f, Args&&... args) {
return std::mem_fn(std::forward<Functor>(f))(std::forward<Args>(args)...);
}
template <typename Functor, typename... Args>
std::enable_if_t<
!std::is_member_pointer_v<std::decay_t<Functor>>,
typename std::invoke_result_t<Functor, Args...>>
invoke(Functor&& f, Args&&... args) {
return std::forward<Functor>(f)(std::forward<Args>(args)...);
}
namespace detail {
struct _identity final {
template <class T>
using type_identity = T;
template <class T>
decltype(auto) operator()(T&& arg) {
return std::forward<T>(arg);
}
};
template <class Func, class Enable = void>
struct function_takes_identity_argument : std::false_type {};
template <class Func>
struct function_takes_identity_argument<
Func,
std::void_t<decltype(std::declval<Func>()(_identity()))>> : std::true_type {
};
} // namespace detail
} // namespace guts
} // namespace c10
#endif // C10_UTIL_CPP17_H_
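
A brief usage sketch for the `guts::apply` and `guts::invoke` polyfills above (the struct and values are illustrative):

```cpp
#include <tuple>

#include <c10/util/C++17.h>

struct Adder {
  int bias;
  int add(int x) const {
    return x + bias;
  }
};

int main() {
  // guts::apply unpacks a tuple into a call, like std::apply.
  int sum = c10::guts::apply(
      [](int a, int b) { return a + b; }, std::make_tuple(2, 3));
  // guts::invoke dispatches member pointers and plain callables uniformly.
  Adder adder{10};
  int result = c10::guts::invoke(&Adder::add, adder, sum);
  return result == 15 ? 0 : 1;
}
```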

View File

@ -0,0 +1,67 @@
#pragma once
#include <atomic>
#include <mutex>
#include <utility>
#include <c10/macros/Macros.h>
#include <c10/util/C++17.h>
namespace c10 {
// Custom c10 call_once implementation to avoid the deadlock in std::call_once.
// The implementation here is a simplified version of folly's and likely has a
// much higher memory footprint.
template <typename Flag, typename F, typename... Args>
inline void call_once(Flag& flag, F&& f, Args&&... args) {
if (C10_LIKELY(flag.test_once())) {
return;
}
flag.call_once_slow(std::forward<F>(f), std::forward<Args>(args)...);
}
class once_flag {
public:
#ifndef _WIN32
// running into build error on MSVC. Can't seem to get a repro locally so I'm
// just avoiding constexpr
//
// C:/actions-runner/_work/pytorch/pytorch\c10/util/CallOnce.h(26): error:
// defaulted default constructor cannot be constexpr because the
// corresponding implicitly declared default constructor would not be
// constexpr 1 error detected in the compilation of
// "C:/actions-runner/_work/pytorch/pytorch/aten/src/ATen/cuda/cub.cu".
constexpr
#endif
once_flag() noexcept = default;
once_flag(const once_flag&) = delete;
once_flag& operator=(const once_flag&) = delete;
private:
template <typename Flag, typename F, typename... Args>
friend void call_once(Flag& flag, F&& f, Args&&... args);
template <typename F, typename... Args>
void call_once_slow(F&& f, Args&&... args) {
std::lock_guard<std::mutex> guard(mutex_);
if (init_.load(std::memory_order_relaxed)) {
return;
}
c10::guts::invoke(std::forward<F>(f), std::forward<Args>(args)...);
init_.store(true, std::memory_order_release);
}
bool test_once() {
return init_.load(std::memory_order_acquire);
}
void reset_once() {
init_.store(false, std::memory_order_release);
}
private:
std::mutex mutex_;
std::atomic<bool> init_{false};
};
} // namespace c10
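
A usage sketch for `c10::call_once` above; the function and flag names are illustrative, and the include path is an assumption:

```cpp
#include <iostream>

#include <c10/util/CallOnce.h> // assumed include path for the header above

c10::once_flag init_flag;

void init_backend(int device) {
  std::cout << "initializing device " << device << '\n';
}

int main() {
  // Only the first call runs init_backend; later calls are a cheap acquire-load.
  c10::call_once(init_flag, init_backend, 0);
  c10::call_once(init_flag, init_backend, 1); // no-op
}
```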

View File

@ -0,0 +1,130 @@
#pragma once
#include <c10/util/IdWrapper.h>
#include <c10/util/string_view.h>
#include <cstddef>
#include <cstdint>
namespace c10::util {
namespace detail {
// NOLINTNEXTLINE(*c-arrays*)
constexpr uint64_t crc64_table[] = {
0x0000000000000000, 0x7ad870c830358979, 0xf5b0e190606b12f2,
0x8f689158505e9b8b, 0xc038e5739841b68f, 0xbae095bba8743ff6,
0x358804e3f82aa47d, 0x4f50742bc81f2d04, 0xab28ecb46814fe75,
0xd1f09c7c5821770c, 0x5e980d24087fec87, 0x24407dec384a65fe,
0x6b1009c7f05548fa, 0x11c8790fc060c183, 0x9ea0e857903e5a08,
0xe478989fa00bd371, 0x7d08ff3b88be6f81, 0x07d08ff3b88be6f8,
0x88b81eabe8d57d73, 0xf2606e63d8e0f40a, 0xbd301a4810ffd90e,
0xc7e86a8020ca5077, 0x4880fbd87094cbfc, 0x32588b1040a14285,
0xd620138fe0aa91f4, 0xacf86347d09f188d, 0x2390f21f80c18306,
0x594882d7b0f40a7f, 0x1618f6fc78eb277b, 0x6cc0863448deae02,
0xe3a8176c18803589, 0x997067a428b5bcf0, 0xfa11fe77117cdf02,
0x80c98ebf2149567b, 0x0fa11fe77117cdf0, 0x75796f2f41224489,
0x3a291b04893d698d, 0x40f16bccb908e0f4, 0xcf99fa94e9567b7f,
0xb5418a5cd963f206, 0x513912c379682177, 0x2be1620b495da80e,
0xa489f35319033385, 0xde51839b2936bafc, 0x9101f7b0e12997f8,
0xebd98778d11c1e81, 0x64b116208142850a, 0x1e6966e8b1770c73,
0x8719014c99c2b083, 0xfdc17184a9f739fa, 0x72a9e0dcf9a9a271,
0x08719014c99c2b08, 0x4721e43f0183060c, 0x3df994f731b68f75,
0xb29105af61e814fe, 0xc849756751dd9d87, 0x2c31edf8f1d64ef6,
0x56e99d30c1e3c78f, 0xd9810c6891bd5c04, 0xa3597ca0a188d57d,
0xec09088b6997f879, 0x96d1784359a27100, 0x19b9e91b09fcea8b,
0x636199d339c963f2, 0xdf7adabd7a6e2d6f, 0xa5a2aa754a5ba416,
0x2aca3b2d1a053f9d, 0x50124be52a30b6e4, 0x1f423fcee22f9be0,
0x659a4f06d21a1299, 0xeaf2de5e82448912, 0x902aae96b271006b,
0x74523609127ad31a, 0x0e8a46c1224f5a63, 0x81e2d7997211c1e8,
0xfb3aa75142244891, 0xb46ad37a8a3b6595, 0xceb2a3b2ba0eecec,
0x41da32eaea507767, 0x3b024222da65fe1e, 0xa2722586f2d042ee,
0xd8aa554ec2e5cb97, 0x57c2c41692bb501c, 0x2d1ab4dea28ed965,
0x624ac0f56a91f461, 0x1892b03d5aa47d18, 0x97fa21650afae693,
0xed2251ad3acf6fea, 0x095ac9329ac4bc9b, 0x7382b9faaaf135e2,
0xfcea28a2faafae69, 0x8632586aca9a2710, 0xc9622c4102850a14,
0xb3ba5c8932b0836d, 0x3cd2cdd162ee18e6, 0x460abd1952db919f,
0x256b24ca6b12f26d, 0x5fb354025b277b14, 0xd0dbc55a0b79e09f,
0xaa03b5923b4c69e6, 0xe553c1b9f35344e2, 0x9f8bb171c366cd9b,
0x10e3202993385610, 0x6a3b50e1a30ddf69, 0x8e43c87e03060c18,
0xf49bb8b633338561, 0x7bf329ee636d1eea, 0x012b592653589793,
0x4e7b2d0d9b47ba97, 0x34a35dc5ab7233ee, 0xbbcbcc9dfb2ca865,
0xc113bc55cb19211c, 0x5863dbf1e3ac9dec, 0x22bbab39d3991495,
0xadd33a6183c78f1e, 0xd70b4aa9b3f20667, 0x985b3e827bed2b63,
0xe2834e4a4bd8a21a, 0x6debdf121b863991, 0x1733afda2bb3b0e8,
0xf34b37458bb86399, 0x8993478dbb8deae0, 0x06fbd6d5ebd3716b,
0x7c23a61ddbe6f812, 0x3373d23613f9d516, 0x49aba2fe23cc5c6f,
0xc6c333a67392c7e4, 0xbc1b436e43a74e9d, 0x95ac9329ac4bc9b5,
0xef74e3e19c7e40cc, 0x601c72b9cc20db47, 0x1ac40271fc15523e,
0x5594765a340a7f3a, 0x2f4c0692043ff643, 0xa02497ca54616dc8,
0xdafce7026454e4b1, 0x3e847f9dc45f37c0, 0x445c0f55f46abeb9,
0xcb349e0da4342532, 0xb1eceec59401ac4b, 0xfebc9aee5c1e814f,
0x8464ea266c2b0836, 0x0b0c7b7e3c7593bd, 0x71d40bb60c401ac4,
0xe8a46c1224f5a634, 0x927c1cda14c02f4d, 0x1d148d82449eb4c6,
0x67ccfd4a74ab3dbf, 0x289c8961bcb410bb, 0x5244f9a98c8199c2,
0xdd2c68f1dcdf0249, 0xa7f41839ecea8b30, 0x438c80a64ce15841,
0x3954f06e7cd4d138, 0xb63c61362c8a4ab3, 0xcce411fe1cbfc3ca,
0x83b465d5d4a0eece, 0xf96c151de49567b7, 0x76048445b4cbfc3c,
0x0cdcf48d84fe7545, 0x6fbd6d5ebd3716b7, 0x15651d968d029fce,
0x9a0d8ccedd5c0445, 0xe0d5fc06ed698d3c, 0xaf85882d2576a038,
0xd55df8e515432941, 0x5a3569bd451db2ca, 0x20ed197575283bb3,
0xc49581ead523e8c2, 0xbe4df122e51661bb, 0x3125607ab548fa30,
0x4bfd10b2857d7349, 0x04ad64994d625e4d, 0x7e7514517d57d734,
0xf11d85092d094cbf, 0x8bc5f5c11d3cc5c6, 0x12b5926535897936,
0x686de2ad05bcf04f, 0xe70573f555e26bc4, 0x9ddd033d65d7e2bd,
0xd28d7716adc8cfb9, 0xa85507de9dfd46c0, 0x273d9686cda3dd4b,
0x5de5e64efd965432, 0xb99d7ed15d9d8743, 0xc3450e196da80e3a,
0x4c2d9f413df695b1, 0x36f5ef890dc31cc8, 0x79a59ba2c5dc31cc,
0x037deb6af5e9b8b5, 0x8c157a32a5b7233e, 0xf6cd0afa9582aa47,
0x4ad64994d625e4da, 0x300e395ce6106da3, 0xbf66a804b64ef628,
0xc5bed8cc867b7f51, 0x8aeeace74e645255, 0xf036dc2f7e51db2c,
0x7f5e4d772e0f40a7, 0x05863dbf1e3ac9de, 0xe1fea520be311aaf,
0x9b26d5e88e0493d6, 0x144e44b0de5a085d, 0x6e963478ee6f8124,
0x21c640532670ac20, 0x5b1e309b16452559, 0xd476a1c3461bbed2,
0xaeaed10b762e37ab, 0x37deb6af5e9b8b5b, 0x4d06c6676eae0222,
0xc26e573f3ef099a9, 0xb8b627f70ec510d0, 0xf7e653dcc6da3dd4,
0x8d3e2314f6efb4ad, 0x0256b24ca6b12f26, 0x788ec2849684a65f,
0x9cf65a1b368f752e, 0xe62e2ad306bafc57, 0x6946bb8b56e467dc,
0x139ecb4366d1eea5, 0x5ccebf68aecec3a1, 0x2616cfa09efb4ad8,
0xa97e5ef8cea5d153, 0xd3a62e30fe90582a, 0xb0c7b7e3c7593bd8,
0xca1fc72bf76cb2a1, 0x45775673a732292a, 0x3faf26bb9707a053,
0x70ff52905f188d57, 0x0a2722586f2d042e, 0x854fb3003f739fa5,
0xff97c3c80f4616dc, 0x1bef5b57af4dc5ad, 0x61372b9f9f784cd4,
0xee5fbac7cf26d75f, 0x9487ca0fff135e26, 0xdbd7be24370c7322,
0xa10fceec0739fa5b, 0x2e675fb4576761d0, 0x54bf2f7c6752e8a9,
0xcdcf48d84fe75459, 0xb71738107fd2dd20, 0x387fa9482f8c46ab,
0x42a7d9801fb9cfd2, 0x0df7adabd7a6e2d6, 0x772fdd63e7936baf,
0xf8474c3bb7cdf024, 0x829f3cf387f8795d, 0x66e7a46c27f3aa2c,
0x1c3fd4a417c62355, 0x935745fc4798b8de, 0xe98f353477ad31a7,
0xa6df411fbfb21ca3, 0xdc0731d78f8795da, 0x536fa08fdfd90e51,
0x29b7d047efec8728,
};
inline C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA uint64_t
crc64impl(uint64_t accumulator, const char* data, size_t size) {
for (size_t i = 0; i < size; ++i) {
accumulator =
crc64_table[(accumulator ^ data[i]) & 0xFF] ^ (accumulator >> 8);
}
return accumulator;
}
} // namespace detail
struct crc64_t final : IdWrapper<crc64_t, uint64_t> {
constexpr crc64_t(uint64_t checksum) : IdWrapper(checksum) {}
constexpr uint64_t checksum() const {
return this->underlyingId();
}
};
// CRC64 with Jones coefficients and an init value of 0.
inline C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA crc64_t
crc64(const char* str, size_t size) {
return crc64_t{detail::crc64impl(0, str, size)};
}
inline C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA crc64_t crc64(c10::string_view str) {
return crc64(str.data(), str.size());
}
} // namespace c10::util
// Allow usage of crc64_t in std::unordered_set
C10_DEFINE_HASH_FOR_IDWRAPPER(c10::util::crc64_t);
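
A usage sketch for `crc64` above; the include path is an assumption:

```cpp
#include <iostream>

#include <c10/util/hash.h> // assumed location of the crc64 utilities above

int main() {
  // C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA means this can also be a constexpr
  // computation on most toolchains.
  auto id = c10::util::crc64("conv2d", 6);
  std::cout << std::hex << id.checksum() << '\n';
}
```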

View File

@ -0,0 +1,48 @@
#pragma once
#include <c10/macros/Export.h>
#include <c10/util/Exception.h>
/// This file provides some simple utilities for detecting common deadlocks in
/// PyTorch. For now, we focus exclusively on detecting Python GIL deadlocks,
/// as the GIL is a wide ranging lock that is taken out in many situations.
/// The basic strategy is before performing an operation that may block, you
/// can use TORCH_ASSERT_NO_GIL_WITHOUT_PYTHON_DEP() to assert that the GIL is
/// not held. This macro is to be used in contexts where no static dependency
/// on Python is available (we will handle indirecting a virtual call for you).
///
/// If the GIL is held by a torchdeploy interpreter, we always report false.
/// If you are in a context where Python bindings are available, it's better
/// to directly assert on PyGILState_Check (as it avoids a vcall and also
/// works correctly with torchdeploy.)
#define TORCH_ASSERT_NO_GIL_WITHOUT_PYTHON_DEP() \
TORCH_INTERNAL_ASSERT( \
!c10::impl::check_python_gil(), \
"Holding GIL before a blocking operation! Please release the GIL before blocking, or see https://github.com/pytorch/pytorch/issues/56297 for how to release the GIL for destructors of objects")
namespace c10::impl {
C10_API bool check_python_gil();
struct C10_API PythonGILHooks {
virtual ~PythonGILHooks() = default;
// Returns true if we hold the GIL. If not linked against Python we
// always return false.
virtual bool check_python_gil() const = 0;
};
C10_API void SetPythonGILHooks(PythonGILHooks* factory);
// DO NOT call this registerer from a torch deploy instance! You will clobber
// other registrations
struct C10_API PythonGILHooksRegisterer {
explicit PythonGILHooksRegisterer(PythonGILHooks* factory) {
SetPythonGILHooks(factory);
}
~PythonGILHooksRegisterer() {
SetPythonGILHooks(nullptr);
}
};
} // namespace c10::impl
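
A sketch of how a Python-aware translation unit might register the hook above; `RealPythonGILHooks` and the include path are illustrative assumptions, and this is a registration TU rather than a standalone executable:

```cpp
#include <Python.h>

#include <c10/util/DeadlockDetection.h> // assumed include path for the header above

namespace {

struct RealPythonGILHooks final : c10::impl::PythonGILHooks {
  bool check_python_gil() const override {
    // PyGILState_Check is only meaningful once the interpreter is up.
    return Py_IsInitialized() && PyGILState_Check();
  }
};

RealPythonGILHooks hooks;
// Registering makes TORCH_ASSERT_NO_GIL_WITHOUT_PYTHON_DEP() able to see the GIL.
c10::impl::PythonGILHooksRegisterer registerer{&hooks};

} // namespace
```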

View File

@ -0,0 +1,102 @@
#pragma once
/**
* This file provides portable macros for marking declarations
* as deprecated. You should generally use C10_DEPRECATED,
* except when marking 'using' declarations as deprecated,
* in which case you should use C10_DEFINE_DEPRECATED_USING
* (due to portability concerns).
*/
// Sample usage:
//
// C10_DEPRECATED void bad_func();
// struct C10_DEPRECATED BadStruct {
// ...
// };
// NB: __cplusplus doesn't work for MSVC, so for now MSVC always uses
// the "__declspec(deprecated)" implementation and not the C++14
// "[[deprecated]]" attribute. We tried enabling "[[deprecated]]" for C++14 on
// MSVC, but ran into issues with some older MSVC versions.
#if (defined(__cplusplus) && __cplusplus >= 201402L)
#define C10_DEPRECATED [[deprecated]]
#define C10_DEPRECATED_MESSAGE(message) [[deprecated(message)]]
#elif defined(__GNUC__)
#define C10_DEPRECATED __attribute__((deprecated))
// TODO Is there some way to implement this?
#define C10_DEPRECATED_MESSAGE(message) __attribute__((deprecated))
#elif defined(_MSC_VER)
#define C10_DEPRECATED __declspec(deprecated)
#define C10_DEPRECATED_MESSAGE(message) __declspec(deprecated(message))
#else
#warning "You need to implement C10_DEPRECATED for this compiler"
#define C10_DEPRECATED
#endif
// Sample usage:
//
// C10_DEFINE_DEPRECATED_USING(BadType, int)
//
// which is the portable version of
//
// using BadType [[deprecated]] = int;
// Technically the [[deprecated]] syntax is from the C++14 standard, but it
// works in many compilers.
#if defined(__has_cpp_attribute)
#if __has_cpp_attribute(deprecated) && !defined(__CUDACC__)
#define C10_DEFINE_DEPRECATED_USING(TypeName, TypeThingy) \
using TypeName [[deprecated]] = TypeThingy;
#endif
#endif
#if defined(_MSC_VER)
#if defined(__CUDACC__)
// neither [[deprecated]] nor __declspec(deprecated) work on nvcc on Windows;
// you get the error:
//
// error: attribute does not apply to any entity
//
// So we just turn the macro off in this case.
#if defined(C10_DEFINE_DEPRECATED_USING)
#undef C10_DEFINE_DEPRECATED_USING
#endif
#define C10_DEFINE_DEPRECATED_USING(TypeName, TypeThingy) \
using TypeName = TypeThingy;
#else
// [[deprecated]] does work on Windows without nvcc, though MSVC doesn't support
// `__has_cpp_attribute`. When C++14 is supported we use [[deprecated]];
// otherwise __declspec(deprecated) is used as the alternative.
#ifndef C10_DEFINE_DEPRECATED_USING
#if defined(_MSVC_LANG) && _MSVC_LANG >= 201402L
#define C10_DEFINE_DEPRECATED_USING(TypeName, TypeThingy) \
using TypeName [[deprecated]] = TypeThingy;
#else
#define C10_DEFINE_DEPRECATED_USING(TypeName, TypeThingy) \
using TypeName = __declspec(deprecated) TypeThingy;
#endif
#endif
#endif
#endif
#if !defined(C10_DEFINE_DEPRECATED_USING) && defined(__GNUC__)
// nvcc has a bug where it doesn't understand __attribute__((deprecated))
// declarations even when the host compiler supports it. We'll only use this gcc
// attribute when not compiling with CUDA, and when using a GCC compiler that
// doesn't support the C++14 syntax we checked for above (available in
// __GNUC__ >= 5).
#if !defined(__CUDACC__)
#define C10_DEFINE_DEPRECATED_USING(TypeName, TypeThingy) \
using TypeName __attribute__((deprecated)) = TypeThingy;
#else
// using cuda + gcc < 5, neither deprecated syntax is available so turning off.
#define C10_DEFINE_DEPRECATED_USING(TypeName, TypeThingy) \
using TypeName = TypeThingy;
#endif
#endif
#if !defined(C10_DEFINE_DEPRECATED_USING)
#warning "You need to implement C10_DEFINE_DEPRECATED_USING for this compiler"
#define C10_DEFINE_DEPRECATED_USING
#endif
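
A compile-time usage sketch of the macros above; the function and type names are illustrative:

```cpp
#include <c10/util/Deprecated.h>

C10_DEPRECATED_MESSAGE("use new_api() instead")
int old_api() {
  return 1;
}

int new_api() {
  return 2;
}

// Portable deprecated type alias; expands to `using OldIndex [[deprecated]] = int;`
// on compilers that support the attribute.
C10_DEFINE_DEPRECATED_USING(OldIndex, int)

int main() {
  OldIndex i = new_api(); // using OldIndex emits a deprecation warning where supported
  return i == 2 ? 0 : 1;
}
```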
