I am done

commit 40e2a747cf
parent 720dc28c09
2024-10-30 22:14:35 +01:00
36901 changed files with 5011519 additions and 0 deletions


@@ -0,0 +1,154 @@
#!/usr/bin/python3
# mypy: allow-untyped-defs
import importlib
import logging
import os
import sys
import tempfile
from typing import Optional

import torch
from torch.distributed.nn.jit.templates.remote_module_template import (
get_remote_module_template,
)

logger = logging.getLogger(__name__)

_FILE_PREFIX = "_remote_module_"
_TEMP_DIR = tempfile.TemporaryDirectory()
INSTANTIATED_TEMPLATE_DIR_PATH = _TEMP_DIR.name
logger.info("Created a temporary directory at %s", INSTANTIATED_TEMPLATE_DIR_PATH)
sys.path.append(INSTANTIATED_TEMPLATE_DIR_PATH)


def get_arg_return_types_from_interface(module_interface):
assert getattr(
module_interface, "__torch_script_interface__", False
), "Expect a TorchScript class interface decorated by @torch.jit.interface."
qualified_name = torch._jit_internal._qualified_name(module_interface)
cu = torch.jit._state._python_cu
module_interface_c = cu.get_interface(qualified_name)
assert (
"forward" in module_interface_c.getMethodNames()
), f"Expect forward in interface methods, while it has {module_interface_c.getMethodNames()}"
method_schema = module_interface_c.getMethod("forward")
arg_str_list = []
arg_type_str_list = []
assert method_schema is not None
for argument in method_schema.arguments:
arg_str_list.append(argument.name)
if argument.has_default_value():
default_value_str = f" = {argument.default_value}"
else:
default_value_str = ""
arg_type_str = f"{argument.name}: {argument.type}{default_value_str}"
arg_type_str_list.append(arg_type_str)
arg_str_list = arg_str_list[1:] # Remove "self".
args_str = ", ".join(arg_str_list)
arg_type_str_list = arg_type_str_list[1:] # Remove "self".
arg_types_str = ", ".join(arg_type_str_list)
assert len(method_schema.returns) == 1
argument = method_schema.returns[0]
return_type_str = str(argument.type)
return args_str, arg_types_str, return_type_str
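

# Illustrative sketch (added for exposition; ``MyModuleInterface`` is a
# hypothetical name, not part of this module). For a module-level interface
# such as
#
#     @torch.jit.interface
#     class MyModuleInterface:
#         def forward(self, tensor: Tensor, scale: float = 1.0) -> Tensor:
#             ...
#
# the helper above would return roughly:
#
#     args_str        == "tensor, scale"
#     arg_types_str   == "tensor: Tensor, scale: float = 1.0"
#     return_type_str == "Tensor"
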
def _write(out_path, text):
old_text: Optional[str]
try:
with open(out_path) as f:
old_text = f.read()
except OSError:
old_text = None
if old_text != text:
with open(out_path, "w") as f:
logger.info("Writing %s", out_path)
f.write(text)
else:
logger.info("Skipped writing %s", out_path)


def _do_instantiate_remote_module_template(
generated_module_name, str_dict, enable_moving_cpu_tensors_to_cuda
):
generated_code_text = get_remote_module_template(
enable_moving_cpu_tensors_to_cuda
).format(**str_dict)
out_path = os.path.join(
INSTANTIATED_TEMPLATE_DIR_PATH, f"{generated_module_name}.py"
)
_write(out_path, generated_code_text)
    # From the importlib docs:
    # > If you are dynamically importing a module that was created since
    # > the interpreter began execution (e.g., created a Python source file),
    # > you may need to call invalidate_caches() in order for the new module
    # > to be noticed by the import system.
importlib.invalidate_caches()
    generated_module = importlib.import_module(generated_module_name)
return generated_module
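

# Minimal sketch of the write-then-import pattern used above (hypothetical
# module name, for illustration only). The temporary directory is already on
# ``sys.path``, so the generated file is importable by its bare name:
#
#     _write(os.path.join(INSTANTIATED_TEMPLATE_DIR_PATH, "_remote_module_demo.py"), "x = 1\n")
#     importlib.invalidate_caches()  # needed for files created after interpreter start
#     demo = importlib.import_module("_remote_module_demo")
#     assert demo.x == 1
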
def instantiate_scriptable_remote_module_template(
module_interface_cls, enable_moving_cpu_tensors_to_cuda=True
):
if not getattr(module_interface_cls, "__torch_script_interface__", False):
raise ValueError(
f"module_interface_cls {module_interface_cls} must be a type object decorated by "
"@torch.jit.interface"
)
# Generate the template instance name.
module_interface_cls_name = torch._jit_internal._qualified_name(
module_interface_cls
).replace(".", "_")
generated_module_name = f"{_FILE_PREFIX}{module_interface_cls_name}"
# Generate type annotation strs.
assign_module_interface_cls_str = (
f"from {module_interface_cls.__module__} import "
f"{module_interface_cls.__name__} as module_interface_cls"
)
args_str, arg_types_str, return_type_str = get_arg_return_types_from_interface(
module_interface_cls
)
kwargs_str = ""
arrow_and_return_type_str = f" -> {return_type_str}"
arrow_and_future_return_type_str = f" -> Future[{return_type_str}]"
str_dict = dict(
assign_module_interface_cls=assign_module_interface_cls_str,
arg_types=arg_types_str,
arrow_and_return_type=arrow_and_return_type_str,
arrow_and_future_return_type=arrow_and_future_return_type_str,
args=args_str,
kwargs=kwargs_str,
jit_script_decorator="@torch.jit.script",
)
return _do_instantiate_remote_module_template(
generated_module_name, str_dict, enable_moving_cpu_tensors_to_cuda
)
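

# Hypothetical usage sketch (not part of the original file): given a
# module-level interface
#
#     @torch.jit.interface
#     class MyModuleInterface:
#         def forward(self, tensor: Tensor) -> Tensor:
#             ...
#
# the call below writes ``_remote_module_<qualified_interface_name>.py`` into
# the temporary directory and imports it; the generated ``_remote_forward`` is
# compiled by the injected ``@torch.jit.script`` decorator:
#
#     generated = instantiate_scriptable_remote_module_template(MyModuleInterface)
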
def instantiate_non_scriptable_remote_module_template():
generated_module_name = f"{_FILE_PREFIX}non_scriptable"
str_dict = dict(
assign_module_interface_cls="module_interface_cls = None",
args="*args",
kwargs="**kwargs",
arg_types="*args, **kwargs",
arrow_and_return_type="",
arrow_and_future_return_type="",
jit_script_decorator="",
)
    # For the non-scriptable template, always enable moving CPU tensors to a CUDA
    # device, because the extra handling is not subject to TorchScript's syntax
    # limitations.
return _do_instantiate_remote_module_template(generated_module_name, str_dict, True)
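

# Minimal runnable sketch (added for illustration; assumes a torch build with
# distributed RPC available). Instantiating the non-scriptable template only
# writes and imports the generated file; no RPC agent needs to be initialized:
if __name__ == "__main__":
    generated = instantiate_non_scriptable_remote_module_template()
    print(generated.__name__)
    print([fn.__name__ for fn in generated._generated_methods])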


@@ -0,0 +1,108 @@
#!/usr/bin/python3
# mypy: allow-untyped-defs


def get_remote_module_template(enable_moving_cpu_tensors_to_cuda: bool):
return _TEMPLATE_PREFIX + (
_REMOTE_FORWARD_TEMPLATE_ENABLE_MOVING_CPU_TENSORS_TO_CUDA
if enable_moving_cpu_tensors_to_cuda
else _REMOTE_FORWARD_TEMPLATE
)
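

# The returned string is a ``str.format`` template: the instantiator fills the
# ``{assign_module_interface_cls}``, ``{arg_types}``, ``{args}``, ``{kwargs}``,
# ``{arrow_and_return_type}``, ``{arrow_and_future_return_type}``, and
# ``{jit_script_decorator}`` placeholders, e.g.:
#
#     code = get_remote_module_template(True).format(**str_dict)
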
_TEMPLATE_PREFIX = """from typing import *
import torch
import torch.distributed.rpc as rpc
from torch import Tensor
from torch._jit_internal import Future
from torch.distributed.rpc import RRef
from typing import Tuple # pyre-ignore: unused import
{assign_module_interface_cls}
def forward_async(self, {arg_types}){arrow_and_future_return_type}:
args = (self.module_rref, self.device, self.is_device_map_set, {args})
kwargs = {{{kwargs}}}
return rpc.rpc_async(
self.module_rref.owner(),
_remote_forward,
args,
kwargs,
)
def forward(self, {arg_types}){arrow_and_return_type}:
args = (self.module_rref, self.device, self.is_device_map_set, {args})
kwargs = {{{kwargs}}}
ret_fut = rpc.rpc_async(
self.module_rref.owner(),
_remote_forward,
args,
kwargs,
)
return ret_fut.wait()
_generated_methods = [
forward_async,
forward,
]
{jit_script_decorator}
"""
# This template may cause a typing error (a mismatch between ``Tuple[()]`` and
# ``Tuple[Any]``), even when the code is only instantiated and never executed.
# Therefore, the handling that moves CPU tensors to a CUDA device is only
# included when it is actually needed.
# TODO: Merge these two templates once TorchScript syntax is improved.
_REMOTE_FORWARD_TEMPLATE_ENABLE_MOVING_CPU_TENSORS_TO_CUDA = """
def _remote_forward(
module_rref: RRef[module_interface_cls], device: str, is_device_map_set: bool, {arg_types}){arrow_and_return_type}:
module = module_rref.local_value()
device = torch.device(device)
if device.type != "cuda":
return module.forward({args}, {kwargs})
    # If the module is placed on a CUDA device,
    # move any CPU tensor in args or kwargs to that device.
    # Since TorchScript does not support generator expressions,
    # tuple concatenation is used instead of
    # ``tuple(i.to(device) if isinstance(i, Tensor) else i for i in args)``.
args = ({args},)
out_args: Tuple[()] = ()
for arg in args:
arg = (arg.to(device),) if isinstance(arg, Tensor) else (arg,)
out_args = out_args + arg
kwargs = {{{kwargs}}}
for k, v in kwargs.items():
if isinstance(v, Tensor):
kwargs[k] = kwargs[k].to(device)
if is_device_map_set:
return module.forward(*out_args, {kwargs})
    # If the device map is empty, only CPU tensors can be sent over the wire,
    # so any GPU tensor in the output has to be moved to CPU first.
    # Since TorchScript does not support generator expressions,
    # tuple concatenation is used instead of
    # ``tuple(i.cpu() if isinstance(i, Tensor) else i for i in module.forward(*out_args, {kwargs}))``.
ret: Tuple[()] = ()
for i in module.forward(*out_args, {kwargs}):
i = (i.cpu(),) if isinstance(i, Tensor) else (i,)
ret = ret + i
return ret
"""
_REMOTE_FORWARD_TEMPLATE = """
def _remote_forward(
module_rref: RRef[module_interface_cls], device: str, is_device_map_set: bool, {arg_types}){arrow_and_return_type}:
module = module_rref.local_value()
return module.forward({args}, {kwargs})
"""