988 lines
42 KiB
Python
988 lines
42 KiB
Python
"""Policies: abstract base class and concrete implementations."""
|
|
|
|
import collections
|
|
import copy
|
|
import warnings
|
|
from abc import ABC, abstractmethod
|
|
from functools import partial
|
|
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
|
|
|
import numpy as np
|
|
import torch as th
|
|
from gymnasium import spaces
|
|
from torch import nn
|
|
|
|
from stable_baselines3.common.distributions import (
|
|
BernoulliDistribution,
|
|
CategoricalDistribution,
|
|
DiagGaussianDistribution,
|
|
Distribution,
|
|
MultiCategoricalDistribution,
|
|
StateDependentNoiseDistribution,
|
|
make_proba_distribution,
|
|
)
|
|
from stable_baselines3.common.preprocessing import get_action_dim, is_image_space, maybe_transpose, preprocess_obs
|
|
from stable_baselines3.common.torch_layers import (
|
|
BaseFeaturesExtractor,
|
|
CombinedExtractor,
|
|
FlattenExtractor,
|
|
MlpExtractor,
|
|
NatureCNN,
|
|
create_mlp,
|
|
)
|
|
from stable_baselines3.common.type_aliases import PyTorchObs, Schedule
|
|
from stable_baselines3.common.utils import get_device, is_vectorized_observation, obs_as_tensor
|
|
|
|
# Type variable bound to ``BaseModel``: it lets ``BaseModel.load`` be annotated as
# returning an instance of whichever class it is called on.
SelfBaseModel = TypeVar("SelfBaseModel", bound="BaseModel")
|
|
|
|
|
|
class BaseModel(nn.Module):
    """
    The base model object: makes predictions in response to observations.

    In the case of policies, the prediction is an action. In the case of critics, it is the
    estimated value of the observation.

    :param observation_space: The observation space of the environment
    :param action_space: The action space of the environment
    :param features_extractor_class: Features extractor to use.
    :param features_extractor_kwargs: Keyword arguments
        to pass to the features extractor. The dictionary is copied,
        the object given by the caller is never modified.
    :param features_extractor: Network to extract features
        (a CNN when using images, a nn.Flatten() layer otherwise)
    :param normalize_images: Whether to normalize images or not,
        dividing by 255.0 (True by default)
    :param optimizer_class: The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    """

    # Declared here for type checkers only; this class never assigns it.
    optimizer: th.optim.Optimizer

    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        features_extractor: Optional[BaseFeaturesExtractor] = None,
        normalize_images: bool = True,
        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
    ):
        super().__init__()

        if optimizer_kwargs is None:
            optimizer_kwargs = {}

        # Work on a shallow copy: the dictionary may be updated below and the
        # caller's object must not change as a side effect of building a model.
        features_extractor_kwargs = {} if features_extractor_kwargs is None else dict(features_extractor_kwargs)

        self.observation_space = observation_space
        self.action_space = action_space
        self.features_extractor = features_extractor
        self.normalize_images = normalize_images

        self.optimizer_class = optimizer_class
        self.optimizer_kwargs = optimizer_kwargs

        self.features_extractor_class = features_extractor_class
        self.features_extractor_kwargs = features_extractor_kwargs
        # Automatically deactivate dtype and bounds checks
        if not normalize_images and issubclass(features_extractor_class, (NatureCNN, CombinedExtractor)):
            self.features_extractor_kwargs["normalized_image"] = True

    def _update_features_extractor(
        self,
        net_kwargs: Dict[str, Any],
        features_extractor: Optional[BaseFeaturesExtractor] = None,
    ) -> Dict[str, Any]:
        """
        Update the network keyword arguments and create a new features extractor object if needed.
        If a ``features_extractor`` object is passed, then it will be shared.

        :param net_kwargs: the base network keyword arguments, without the ones
            related to features extractor. The dictionary is copied, not modified.
        :param features_extractor: a features extractor object.
            If None, a new object will be created.
        :return: The updated keyword arguments, with the additional
            ``features_extractor`` and ``features_dim`` entries
        """
        net_kwargs = net_kwargs.copy()
        if features_extractor is None:
            # The features extractor is not shared, create a new one
            features_extractor = self.make_features_extractor()
        net_kwargs.update(dict(features_extractor=features_extractor, features_dim=features_extractor.features_dim))
        return net_kwargs

    def make_features_extractor(self) -> BaseFeaturesExtractor:
        """
        Helper method to create a features extractor.

        :return: A new ``features_extractor_class`` instance built from the
            observation space and ``features_extractor_kwargs``
        """
        return self.features_extractor_class(self.observation_space, **self.features_extractor_kwargs)

    def extract_features(self, obs: PyTorchObs, features_extractor: BaseFeaturesExtractor) -> th.Tensor:
        """
        Preprocess the observation if needed and extract features.

        :param obs: Observation
        :param features_extractor: The features extractor to use.
        :return: The extracted features
        """
        preprocessed_obs = preprocess_obs(obs, self.observation_space, normalize_images=self.normalize_images)
        return features_extractor(preprocessed_obs)

    def _get_constructor_parameters(self) -> Dict[str, Any]:
        """
        Get data that need to be saved in order to re-create the model when loading it from disk.

        :return: The dictionary to pass as keyword arguments to the constructor
            when reconstructing this model.
        """
        return dict(
            observation_space=self.observation_space,
            action_space=self.action_space,
            # Passed to the constructor by child class
            # squash_output=self.squash_output,
            # features_extractor=self.features_extractor
            normalize_images=self.normalize_images,
        )

    @property
    def device(self) -> th.device:
        """
        Infer which device this policy lives on by inspecting its parameters.
        If it has no parameters, the 'cpu' device is used as a fallback.

        :return: The device of the first parameter, or the cpu device
        """
        # Only the first parameter is inspected.
        for param in self.parameters():
            return param.device
        return get_device("cpu")

    def save(self, path: str) -> None:
        """
        Save model to a given location.

        The file contains the ``state_dict`` of the model and the constructor
        parameters returned by ``_get_constructor_parameters()``.

        :param path: Where to save the model
        """
        th.save({"state_dict": self.state_dict(), "data": self._get_constructor_parameters()}, path)

    @classmethod
    def load(cls: Type[SelfBaseModel], path: str, device: Union[th.device, str] = "auto") -> SelfBaseModel:
        """
        Load model from path.

        .. warning::
            The file is loaded with ``weights_only=False``, so arbitrary
            pickled objects are deserialized: only load files you trust.

        :param path: Location of a file written by ``save()``
        :param device: Device on which the policy should be loaded.
        :return: The re-created model, with its weights loaded, on ``device``
        """
        device = get_device(device)
        # Note(antonin): we cannot use `weights_only=True` here because we need to allow
        # gymnasium imports for the policy to be loaded successfully
        saved_variables = th.load(path, map_location=device, weights_only=False)

        # Create policy object
        model = cls(**saved_variables["data"])
        # Load weights
        model.load_state_dict(saved_variables["state_dict"])
        model.to(device)
        return model

    def load_from_vector(self, vector: np.ndarray) -> None:
        """
        Load parameters from a 1D vector.

        :param vector: Flat parameter values, converted to a float tensor
            on the device of the model before being copied to the parameters
        """
        th.nn.utils.vector_to_parameters(th.as_tensor(vector, dtype=th.float, device=self.device), self.parameters())

    def parameters_to_vector(self) -> np.ndarray:
        """
        Convert the parameters to a 1D vector.

        :return: The flattened parameters, detached and copied to a numpy array
        """
        return th.nn.utils.parameters_to_vector(self.parameters()).detach().cpu().numpy()

    def set_training_mode(self, mode: bool) -> None:
        """
        Put the policy in either training or evaluation mode.

        This affects certain modules, such as batch normalisation and dropout.

        :param mode: if true, set to training mode, else set to evaluation mode
        """
        self.train(mode)

    def is_vectorized_observation(self, observation: Union[np.ndarray, Dict[str, np.ndarray]]) -> bool:
        """
        Check whether or not the observation is vectorized,
        apply transposition to image (so that they are channel-first) if needed.
        This is used in DQN when sampling random action (epsilon-greedy policy)

        :param observation: the input observation to check
        :return: whether the given observation is vectorized or not
        """
        vectorized_env = False
        if isinstance(observation, dict):
            assert isinstance(
                self.observation_space, spaces.Dict
            ), f"The observation provided is a dict but the obs space is {self.observation_space}"
            # A dict observation is vectorized as soon as one of its entries is
            for key, obs in observation.items():
                obs_space = self.observation_space.spaces[key]
                vectorized_env = vectorized_env or is_vectorized_observation(maybe_transpose(obs, obs_space), obs_space)
        else:
            vectorized_env = is_vectorized_observation(
                maybe_transpose(observation, self.observation_space), self.observation_space
            )
        return vectorized_env

    def obs_to_tensor(self, observation: Union[np.ndarray, Dict[str, np.ndarray]]) -> Tuple[PyTorchObs, bool]:
        """
        Convert an input observation to a PyTorch tensor that can be fed to a model.
        Includes sugar-coating to handle different observations (e.g. normalizing images).

        :param observation: the input observation
        :return: The observation as PyTorch tensor
            and whether the observation is vectorized or not
        """
        vectorized_env = False
        if isinstance(observation, dict):
            assert isinstance(
                self.observation_space, spaces.Dict
            ), f"The observation provided is a dict but the obs space is {self.observation_space}"
            # need to copy the dict as the dict in VecFrameStack will become a torch tensor
            observation = copy.deepcopy(observation)
            for key, obs in observation.items():
                obs_space = self.observation_space.spaces[key]
                if is_image_space(obs_space):
                    obs_ = maybe_transpose(obs, obs_space)
                else:
                    obs_ = np.array(obs)
                vectorized_env = vectorized_env or is_vectorized_observation(obs_, obs_space)
                # Add batch dimension if needed
                observation[key] = obs_.reshape((-1, *self.observation_space[key].shape))  # type: ignore[misc]

        elif is_image_space(self.observation_space):
            # Handle the different cases for images
            # as PyTorch use channel first format
            observation = maybe_transpose(observation, self.observation_space)

        else:
            observation = np.array(observation)

        if not isinstance(observation, dict):
            # Dict obs need to be handled separately
            vectorized_env = is_vectorized_observation(observation, self.observation_space)
            # Add batch dimension if needed
            observation = observation.reshape((-1, *self.observation_space.shape))  # type: ignore[misc]

        obs_tensor = obs_as_tensor(observation, self.device)
        return obs_tensor, vectorized_env
|
|
|
|
|
|
class BasePolicy(BaseModel, ABC):
    """
    Abstract base class of every policy.

    It accepts the same arguments as `BaseModel` and forwards them unchanged;
    the only addition is documented below.

    :param args: positional arguments passed through to `BaseModel`.
    :param kwargs: keyword arguments passed through to `BaseModel`.
    :param squash_output: For continuous actions, whether the output is squashed
        or not using a ``tanh()`` function.
    """

    features_extractor: BaseFeaturesExtractor

    def __init__(self, *args, squash_output: bool = False, **kwargs):
        super().__init__(*args, **kwargs)
        self._squash_output = squash_output

    @staticmethod
    def _dummy_schedule(progress_remaining: float) -> float:
        """
        Placeholder schedule that always returns zero (useful for pickling policy).

        :param progress_remaining: ignored
        :return: always ``0.0``
        """
        del progress_remaining
        return 0.0

    @property
    def squash_output(self) -> bool:
        """Whether the output of the policy is squashed (read-only)."""
        return self._squash_output

    @staticmethod
    def init_weights(module: nn.Module, gain: float = 1) -> None:
        """
        Orthogonal initialization (used in PPO and A2C).

        Only ``nn.Linear`` and ``nn.Conv2d`` modules are modified: their weights
        are initialized orthogonally and their bias, when present, is set to zero.

        :param module: Module to initialize
        :param gain: Scaling factor passed to the orthogonal initialization
        """
        if not isinstance(module, (nn.Linear, nn.Conv2d)):
            return
        nn.init.orthogonal_(module.weight, gain=gain)
        if module.bias is not None:
            module.bias.data.fill_(0.0)

    @abstractmethod
    def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor:
        """
        Get the action according to the policy for a given observation.

        This method is abstract: concrete policies must provide it.
        It is called by ``predict()`` with the observation already converted
        to a tensor.

        :param observation: The observation, as PyTorch tensor(s)
        :param deterministic: Whether to use stochastic or deterministic actions
        :return: Taken action according to the policy
        """

    def predict(
        self,
        observation: Union[np.ndarray, Dict[str, np.ndarray]],
        state: Optional[Tuple[np.ndarray, ...]] = None,
        episode_start: Optional[np.ndarray] = None,
        deterministic: bool = False,
    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
        """
        Get the policy action from an observation (and optional hidden state).
        Includes sugar-coating to handle different observations (e.g. normalizing images).

        :param observation: the input observation
        :param state: The last hidden states (can be None, used in recurrent policies).
            It is returned unchanged by this implementation.
        :param episode_start: The last masks (can be None, used in recurrent policies),
            this corresponds to the beginning of episodes,
            where the hidden states of the RNN must be reset.
            It is not used by this implementation.
        :param deterministic: Whether or not to return deterministic actions.
        :return: the model's action and the next hidden state
            (used in recurrent policies)
        """
        # Inference only: evaluation mode changes the behavior of batch norm / dropout
        self.set_training_mode(False)

        # A (obs, info) tuple is what the Gym API returns on reset, which is a
        # common mistake when mixing it with the VecEnv API.
        # Tuple obs are not supported by SB3, so we can safely do that check
        looks_like_gym_reset = (
            isinstance(observation, tuple) and len(observation) == 2 and isinstance(observation[1], dict)
        )
        if looks_like_gym_reset:
            raise ValueError(
                "You have passed a tuple to the predict() function instead of a Numpy array or a Dict. "
                "You are probably mixing Gym API with SB3 VecEnv API: `obs, info = env.reset()` (Gym) "
                "vs `obs = vec_env.reset()` (SB3 VecEnv). "
                "See related issue https://github.com/DLR-RM/stable-baselines3/issues/1694 "
                "and documentation for more information: https://stable-baselines3.readthedocs.io/en/master/guide/vec_envs.html#vecenv-api-vs-gym-api"
            )

        obs_tensor, vectorized_env = self.obs_to_tensor(observation)

        with th.no_grad():
            action_tensor = self._predict(obs_tensor, deterministic=deterministic)
        # Back to numpy, with one row per environment and the shape of the action space
        batched_shape = (-1, *self.action_space.shape)  # type: ignore[misc]
        actions = action_tensor.cpu().numpy().reshape(batched_shape)

        if isinstance(self.action_space, spaces.Box):
            # Squashed outputs are rescaled to the proper domain.
            # Otherwise actions could be on arbitrary scale, so they are clipped to avoid
            # out of bound error (e.g. if sampling from a Gaussian distribution)
            actions = (
                self.unscale_action(actions)  # type: ignore[assignment, arg-type]
                if self.squash_output
                else np.clip(actions, self.action_space.low, self.action_space.high)  # type: ignore[assignment, arg-type]
            )

        if vectorized_env:
            return actions, state  # type: ignore[return-value]

        # Single observation: drop the batch dimension that was added
        assert isinstance(actions, np.ndarray)
        actions = actions.squeeze(axis=0)
        return actions, state  # type: ignore[return-value]

    def scale_action(self, action: np.ndarray) -> np.ndarray:
        """
        Rescale the action from [low, high] to [-1, 1]
        (no need for symmetric action space)

        :param action: Action to scale
        :return: Scaled action
        """
        assert isinstance(
            self.action_space, spaces.Box
        ), f"Trying to scale an action using an action space that is not a Box(): {self.action_space}"
        low = self.action_space.low
        high = self.action_space.high
        # Position of the action inside the bounds, in [0, 1]
        fraction = (action - low) / (high - low)
        return 2.0 * fraction - 1.0

    def unscale_action(self, scaled_action: np.ndarray) -> np.ndarray:
        """
        Rescale the action from [-1, 1] to [low, high]
        (no need for symmetric action space)

        :param scaled_action: Action to un-scale
        :return: Action expressed in the bounds of the action space
        """
        assert isinstance(
            self.action_space, spaces.Box
        ), f"Trying to unscale an action using an action space that is not a Box(): {self.action_space}"
        low = self.action_space.low
        high = self.action_space.high
        # Position of the action inside the bounds, in [0, 1]
        fraction = 0.5 * (scaled_action + 1.0)
        return low + (fraction * (high - low))
|
|
|
|
|
|
class ActorCriticPolicy(BasePolicy):
    """
    Policy class for actor-critic algorithms (has both policy and value prediction).
    Used by A2C, PPO and the likes.

    :param observation_space: Observation space
    :param action_space: Action space
    :param lr_schedule: Learning rate schedule (could be constant)
    :param net_arch: The specification of the policy and value networks.
    :param activation_fn: Activation function
    :param ortho_init: Whether to use or not orthogonal initialization
    :param use_sde: Whether to use State Dependent Exploration or not
    :param log_std_init: Initial value for the log standard deviation
    :param full_std: Whether to use (n_features x n_actions) parameters
        for the std instead of only (n_features,) when using gSDE
    :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
        a positive standard deviation (cf paper). It allows to keep variance
        above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
    :param squash_output: Whether to squash the output using a tanh function,
        this allows to ensure boundaries when using gSDE.
    :param features_extractor_class: Features extractor to use.
    :param features_extractor_kwargs: Keyword arguments
        to pass to the features extractor.
    :param share_features_extractor: If True, the features extractor is shared between the policy and value networks.
    :param normalize_images: Whether to normalize images or not,
        dividing by 255.0 (True by default)
    :param optimizer_class: The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    """

    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        lr_schedule: Schedule,
        net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
        activation_fn: Type[nn.Module] = nn.Tanh,
        ortho_init: bool = True,
        use_sde: bool = False,
        log_std_init: float = 0.0,
        full_std: bool = True,
        use_expln: bool = False,
        squash_output: bool = False,
        features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        share_features_extractor: bool = True,
        normalize_images: bool = True,
        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
    ):
        if optimizer_kwargs is None:
            optimizer_kwargs = {}
            # Small values to avoid NaN in Adam optimizer
            # (only a default: a dictionary given by the user is left untouched)
            if optimizer_class == th.optim.Adam:
                optimizer_kwargs["eps"] = 1e-5

        super().__init__(
            observation_space,
            action_space,
            features_extractor_class,
            features_extractor_kwargs,
            optimizer_class=optimizer_class,
            optimizer_kwargs=optimizer_kwargs,
            squash_output=squash_output,
            normalize_images=normalize_images,
        )

        # Legacy format: a list wrapping the dict, only the dict is kept
        if isinstance(net_arch, list) and len(net_arch) > 0 and isinstance(net_arch[0], dict):
            warnings.warn(
                (
                    "As shared layers in the mlp_extractor are removed since SB3 v1.8.0, "
                    "you should now pass directly a dictionary and not a list "
                    "(net_arch=dict(pi=..., vf=...) instead of net_arch=[dict(pi=..., vf=...)])"
                ),
            )
            net_arch = net_arch[0]

        # Default network architecture, from stable-baselines
        if net_arch is None:
            if features_extractor_class == NatureCNN:
                net_arch = []
            else:
                net_arch = dict(pi=[64, 64], vf=[64, 64])

        self.net_arch = net_arch
        self.activation_fn = activation_fn
        self.ortho_init = ortho_init

        self.share_features_extractor = share_features_extractor
        self.features_extractor = self.make_features_extractor()
        self.features_dim = self.features_extractor.features_dim
        if self.share_features_extractor:
            # Actor and critic use the very same module
            self.pi_features_extractor = self.features_extractor
            self.vf_features_extractor = self.features_extractor
        else:
            # The actor keeps the default extractor, the critic gets its own
            self.pi_features_extractor = self.features_extractor
            self.vf_features_extractor = self.make_features_extractor()

        self.log_std_init = log_std_init
        dist_kwargs = None

        assert not (squash_output and not use_sde), "squash_output=True is only available when using gSDE (use_sde=True)"
        # Keyword arguments for gSDE distribution
        if use_sde:
            dist_kwargs = {
                "full_std": full_std,
                "squash_output": squash_output,
                "use_expln": use_expln,
                "learn_features": False,
            }

        self.use_sde = use_sde
        self.dist_kwargs = dist_kwargs

        # Action distribution
        self.action_dist = make_proba_distribution(action_space, use_sde=use_sde, dist_kwargs=dist_kwargs)

        self._build(lr_schedule)

    def _get_constructor_parameters(self) -> Dict[str, Any]:
        """
        Get data that need to be saved in order to re-create the policy when loading it from disk.

        ``share_features_extractor`` is part of the saved data, so that a policy
        with separate actor and critic features extractors is re-created with
        separate extractors.

        :return: The dictionary to pass as keyword arguments to the constructor
            when reconstructing this policy.
        """
        data = super()._get_constructor_parameters()

        # When gSDE is not used there are no distribution kwargs:
        # the corresponding entries are then saved as None
        default_none_kwargs = self.dist_kwargs or collections.defaultdict(lambda: None)  # type: ignore[arg-type, return-value]

        data.update(
            dict(
                net_arch=self.net_arch,
                activation_fn=self.activation_fn,
                use_sde=self.use_sde,
                log_std_init=self.log_std_init,
                squash_output=default_none_kwargs["squash_output"],
                full_std=default_none_kwargs["full_std"],
                use_expln=default_none_kwargs["use_expln"],
                lr_schedule=self._dummy_schedule,  # dummy lr schedule, not needed for loading policy alone
                ortho_init=self.ortho_init,
                optimizer_class=self.optimizer_class,
                optimizer_kwargs=self.optimizer_kwargs,
                features_extractor_class=self.features_extractor_class,
                features_extractor_kwargs=self.features_extractor_kwargs,
                share_features_extractor=self.share_features_extractor,
            )
        )
        return data

    def reset_noise(self, n_envs: int = 1) -> None:
        """
        Sample new weights for the exploration matrix.

        :param n_envs: Passed as ``batch_size`` when sampling the weights
        """
        assert isinstance(self.action_dist, StateDependentNoiseDistribution), "reset_noise() is only available when using gSDE"
        self.action_dist.sample_weights(self.log_std, batch_size=n_envs)

    def _build_mlp_extractor(self) -> None:
        """
        Create the policy and value networks.
        Part of the layers can be shared.
        """
        # Note: If net_arch is None and some features extractor is used,
        #       net_arch here is an empty list and mlp_extractor does not
        #       really contain any layers (acts like an identity module).
        self.mlp_extractor = MlpExtractor(
            self.features_dim,
            net_arch=self.net_arch,
            activation_fn=self.activation_fn,
            device=self.device,
        )

    def _build(self, lr_schedule: Schedule) -> None:
        """
        Create the networks and the optimizer.

        :param lr_schedule: Learning rate schedule
            lr_schedule(1) is the initial learning rate
        """
        self._build_mlp_extractor()

        latent_dim_pi = self.mlp_extractor.latent_dim_pi

        if isinstance(self.action_dist, DiagGaussianDistribution):
            self.action_net, self.log_std = self.action_dist.proba_distribution_net(
                latent_dim=latent_dim_pi, log_std_init=self.log_std_init
            )
        elif isinstance(self.action_dist, StateDependentNoiseDistribution):
            self.action_net, self.log_std = self.action_dist.proba_distribution_net(
                latent_dim=latent_dim_pi, latent_sde_dim=latent_dim_pi, log_std_init=self.log_std_init
            )
        elif isinstance(self.action_dist, (CategoricalDistribution, MultiCategoricalDistribution, BernoulliDistribution)):
            self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi)
        else:
            raise NotImplementedError(f"Unsupported distribution '{self.action_dist}'.")

        self.value_net = nn.Linear(self.mlp_extractor.latent_dim_vf, 1)
        # Init weights: use orthogonal initialization
        # with small initial weight for the output
        if self.ortho_init:
            # TODO: check for features_extractor
            # Values from stable-baselines.
            # features_extractor/mlp values are
            # originally from openai/baselines (default gains/init_scales).
            module_gains = {
                self.features_extractor: np.sqrt(2),
                self.mlp_extractor: np.sqrt(2),
                self.action_net: 0.01,
                self.value_net: 1,
            }
            if not self.share_features_extractor:
                # Note(antonin): this is to keep SB3 results
                # consistent, see GH#1148
                del module_gains[self.features_extractor]
                module_gains[self.pi_features_extractor] = np.sqrt(2)
                module_gains[self.vf_features_extractor] = np.sqrt(2)

            for module, gain in module_gains.items():
                module.apply(partial(self.init_weights, gain=gain))

        # Setup optimizer with initial learning rate
        self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)  # type: ignore[call-arg]

    def forward(self, obs: th.Tensor, deterministic: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
        """
        Forward pass in all the networks (actor and critic)

        :param obs: Observation
        :param deterministic: Whether to sample or use deterministic actions
        :return: action, value and log probability of the action
        """
        # Preprocess the observation if needed
        features = self.extract_features(obs)
        if self.share_features_extractor:
            latent_pi, latent_vf = self.mlp_extractor(features)
        else:
            # One set of features per network when the extractor is not shared
            pi_features, vf_features = features
            latent_pi = self.mlp_extractor.forward_actor(pi_features)
            latent_vf = self.mlp_extractor.forward_critic(vf_features)
        # Evaluate the values for the given observations
        values = self.value_net(latent_vf)
        distribution = self._get_action_dist_from_latent(latent_pi)
        actions = distribution.get_actions(deterministic=deterministic)
        log_prob = distribution.log_prob(actions)
        actions = actions.reshape((-1, *self.action_space.shape))  # type: ignore[misc]
        return actions, values, log_prob

    def extract_features(  # type: ignore[override]
        self, obs: PyTorchObs, features_extractor: Optional[BaseFeaturesExtractor] = None
    ) -> Union[th.Tensor, Tuple[th.Tensor, th.Tensor]]:
        """
        Preprocess the observation if needed and extract features.

        :param obs: Observation
        :param features_extractor: The features extractor to use. If None, then ``self.features_extractor`` is used.
            It is ignored (with a warning) when the features extractor is not shared.
        :return: The extracted features. If features extractor is not shared, returns a tuple with the
            features for the actor and the features for the critic.
        """
        if self.share_features_extractor:
            return super().extract_features(obs, self.features_extractor if features_extractor is None else features_extractor)
        else:
            if features_extractor is not None:
                warnings.warn(
                    "Provided features_extractor will be ignored because the features extractor is not shared.",
                    UserWarning,
                )

            pi_features = super().extract_features(obs, self.pi_features_extractor)
            vf_features = super().extract_features(obs, self.vf_features_extractor)
            return pi_features, vf_features

    def _get_action_dist_from_latent(self, latent_pi: th.Tensor) -> Distribution:
        """
        Retrieve action distribution given the latent codes.

        :param latent_pi: Latent code for the actor
        :return: Action distribution
        """
        mean_actions = self.action_net(latent_pi)

        if isinstance(self.action_dist, DiagGaussianDistribution):
            return self.action_dist.proba_distribution(mean_actions, self.log_std)
        elif isinstance(self.action_dist, CategoricalDistribution):
            # Here mean_actions are the logits before the softmax
            return self.action_dist.proba_distribution(action_logits=mean_actions)
        elif isinstance(self.action_dist, MultiCategoricalDistribution):
            # Here mean_actions are the flattened logits
            return self.action_dist.proba_distribution(action_logits=mean_actions)
        elif isinstance(self.action_dist, BernoulliDistribution):
            # Here mean_actions are the logits (before rounding to get the binary actions)
            return self.action_dist.proba_distribution(action_logits=mean_actions)
        elif isinstance(self.action_dist, StateDependentNoiseDistribution):
            return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_pi)
        else:
            raise ValueError("Invalid action distribution")

    def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor:
        """
        Get the action according to the policy for a given observation.

        :param observation: The observation, as PyTorch tensor(s)
        :param deterministic: Whether to use stochastic or deterministic actions
        :return: Taken action according to the policy
        """
        return self.get_distribution(observation).get_actions(deterministic=deterministic)

    def evaluate_actions(self, obs: PyTorchObs, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Optional[th.Tensor]]:
        """
        Evaluate actions according to the current policy,
        given the observations.

        :param obs: Observation
        :param actions: Actions
        :return: estimated value, log likelihood of taking those actions
            and entropy of the action distribution.
        """
        # Preprocess the observation if needed
        features = self.extract_features(obs)
        if self.share_features_extractor:
            latent_pi, latent_vf = self.mlp_extractor(features)
        else:
            # One set of features per network when the extractor is not shared
            pi_features, vf_features = features
            latent_pi = self.mlp_extractor.forward_actor(pi_features)
            latent_vf = self.mlp_extractor.forward_critic(vf_features)
        distribution = self._get_action_dist_from_latent(latent_pi)
        log_prob = distribution.log_prob(actions)
        values = self.value_net(latent_vf)
        entropy = distribution.entropy()
        return values, log_prob, entropy

    def get_distribution(self, obs: PyTorchObs) -> Distribution:
        """
        Get the current policy distribution given the observations.

        Only the actor part of the networks is evaluated.

        :param obs: Observation
        :return: the action distribution.
        """
        features = super().extract_features(obs, self.pi_features_extractor)
        latent_pi = self.mlp_extractor.forward_actor(features)
        return self._get_action_dist_from_latent(latent_pi)

    def predict_values(self, obs: PyTorchObs) -> th.Tensor:
        """
        Get the estimated values according to the current policy given the observations.

        Only the critic part of the networks is evaluated.

        :param obs: Observation
        :return: the estimated values.
        """
        features = super().extract_features(obs, self.vf_features_extractor)
        latent_vf = self.mlp_extractor.forward_critic(features)
        return self.value_net(latent_vf)
|
|
|
|
|
|
class ActorCriticCnnPolicy(ActorCriticPolicy):
    """
    Actor-critic policy (policy and value prediction) whose default features
    extractor is ``NatureCNN``. Used by A2C, PPO and the likes.

    Every argument is forwarded unchanged to ``ActorCriticPolicy``.

    :param observation_space: Observation space
    :param action_space: Action space
    :param lr_schedule: Learning rate schedule (could be constant)
    :param net_arch: The specification of the policy and value networks.
    :param activation_fn: Activation function
    :param ortho_init: Whether to use or not orthogonal initialization
    :param use_sde: Whether to use State Dependent Exploration or not
    :param log_std_init: Initial value for the log standard deviation
    :param full_std: Whether to use (n_features x n_actions) parameters
        for the std instead of only (n_features,) when using gSDE
    :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
        a positive standard deviation (cf paper). It allows to keep variance
        above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
    :param squash_output: Whether to squash the output using a tanh function,
        this allows to ensure boundaries when using gSDE.
    :param features_extractor_class: Features extractor to use.
    :param features_extractor_kwargs: Keyword arguments
        to pass to the features extractor.
    :param share_features_extractor: If True, the features extractor is shared between the policy and value networks.
    :param normalize_images: Whether to normalize images or not,
        dividing by 255.0 (True by default)
    :param optimizer_class: The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    """

    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        lr_schedule: Schedule,
        net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
        activation_fn: Type[nn.Module] = nn.Tanh,
        ortho_init: bool = True,
        use_sde: bool = False,
        log_std_init: float = 0.0,
        full_std: bool = True,
        use_expln: bool = False,
        squash_output: bool = False,
        features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        share_features_extractor: bool = True,
        normalize_images: bool = True,
        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
    ):
        # Nothing is built here: the parent constructor does all the work,
        # only the default features extractor differs.
        super().__init__(
            observation_space=observation_space,
            action_space=action_space,
            lr_schedule=lr_schedule,
            net_arch=net_arch,
            activation_fn=activation_fn,
            ortho_init=ortho_init,
            use_sde=use_sde,
            log_std_init=log_std_init,
            full_std=full_std,
            use_expln=use_expln,
            squash_output=squash_output,
            features_extractor_class=features_extractor_class,
            features_extractor_kwargs=features_extractor_kwargs,
            share_features_extractor=share_features_extractor,
            normalize_images=normalize_images,
            optimizer_class=optimizer_class,
            optimizer_kwargs=optimizer_kwargs,
        )
|
|
|
|
|
|
class MultiInputActorCriticPolicy(ActorCriticPolicy):
    """
    Multi-input policy class for actor-critic algorithms (has both policy and value prediction).
    Used by A2C, PPO and the likes.

    This subclass adds no behavior of its own: every constructor argument is
    forwarded unchanged to ``ActorCriticPolicy``. The default features
    extractor declared here is ``CombinedExtractor``.

    :param observation_space: Observation space (Dict)
    :param action_space: Action space
    :param lr_schedule: Learning rate schedule (could be constant)
    :param net_arch: The specification of the policy and value networks.
    :param activation_fn: Activation function
    :param ortho_init: Whether to use or not orthogonal initialization
    :param use_sde: Whether to use State Dependent Exploration or not
    :param log_std_init: Initial value for the log standard deviation
    :param full_std: Whether to use (n_features x n_actions) parameters
        for the std instead of only (n_features,) when using gSDE
    :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
        a positive standard deviation (cf paper). It allows to keep variance
        above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
    :param squash_output: Whether to squash the output using a tanh function,
        this allows to ensure boundaries when using gSDE.
    :param features_extractor_class: Features extractor to use (``CombinedExtractor`` by default).
    :param features_extractor_kwargs: Keyword arguments
        to pass to the features extractor.
    :param share_features_extractor: If True, the features extractor is shared between the policy and value networks.
    :param normalize_images: Whether to normalize images or not,
        dividing by 255.0 (True by default)
    :param optimizer_class: The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    """

    def __init__(
        self,
        observation_space: spaces.Dict,
        action_space: spaces.Space,
        lr_schedule: Schedule,
        net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
        activation_fn: Type[nn.Module] = nn.Tanh,
        ortho_init: bool = True,
        use_sde: bool = False,
        log_std_init: float = 0.0,
        full_std: bool = True,
        use_expln: bool = False,
        squash_output: bool = False,
        features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        share_features_extractor: bool = True,
        normalize_images: bool = True,
        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        # Arguments are forwarded positionally, so their order here must stay
        # aligned with the parameter order of the parent constructor.
        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            net_arch,
            activation_fn,
            ortho_init,
            use_sde,
            log_std_init,
            full_std,
            use_expln,
            squash_output,
            features_extractor_class,
            features_extractor_kwargs,
            share_features_extractor,
            normalize_images,
            optimizer_class,
            optimizer_kwargs,
        )
|
|
|
|
|
|
class ContinuousCritic(BaseModel):
    """
    Critic network(s) for DDPG/SAC/TD3.
    It represents the action-state value function (Q-value function).
    Compared to A2C/PPO critics, this one represents the Q-value
    and takes the continuous action as input. It is concatenated with the state
    and then fed to the network which outputs a single value: Q(s, a).
    For more recent algorithms like SAC/TD3, multiple networks
    are created to give different estimates.

    By default, it creates two critic networks used to reduce overestimation
    thanks to clipped Q-learning (cf TD3 paper).

    :param observation_space: Observation space
    :param action_space: Action space
    :param net_arch: Network architecture
    :param features_extractor: Network to extract features
        (a CNN when using images, a nn.Flatten() layer otherwise)
    :param features_dim: Number of features
    :param activation_fn: Activation function
    :param normalize_images: Whether to normalize images or not,
        dividing by 255.0 (True by default)
    :param n_critics: Number of critic networks to create.
    :param share_features_extractor: Whether the features extractor is shared or not
        between the actor and the critic (this saves computation time)
    """

    features_extractor: BaseFeaturesExtractor

    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Box,
        net_arch: List[int],
        features_extractor: BaseFeaturesExtractor,
        features_dim: int,
        activation_fn: Type[nn.Module] = nn.ReLU,
        normalize_images: bool = True,
        n_critics: int = 2,
        share_features_extractor: bool = True,
    ) -> None:
        super().__init__(
            observation_space,
            action_space,
            features_extractor=features_extractor,
            normalize_images=normalize_images,
        )

        # Every Q-network takes the extracted features concatenated with the action.
        q_input_dim = features_dim + get_action_dim(self.action_space)

        self.share_features_extractor = share_features_extractor
        self.n_critics = n_critics
        self.q_networks: List[nn.Module] = []
        for critic_idx in range(n_critics):
            critic_net = nn.Sequential(*create_mlp(q_input_dim, 1, net_arch, activation_fn))
            # ``q_networks`` is a plain Python list, so each network is also
            # registered as a named submodule to have its parameters tracked.
            self.add_module(f"qf{critic_idx}", critic_net)
            self.q_networks.append(critic_net)

    def forward(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, ...]:
        """
        Compute the output of every critic network for the given inputs.

        :param obs: Observation
        :param actions: Actions, concatenated with the extracted features along dim 1
        :return: A tuple holding one output tensor per critic network
        """
        # Learn the features extractor using the policy loss only
        # when the features_extractor is shared with the actor
        track_features_grad = not self.share_features_extractor
        with th.set_grad_enabled(track_features_grad):
            features = self.extract_features(obs, self.features_extractor)
        critic_input = th.cat([features, actions], dim=1)
        return tuple(critic(critic_input) for critic in self.q_networks)

    def q1_forward(self, obs: th.Tensor, actions: th.Tensor) -> th.Tensor:
        """
        Only predict the Q-value using the first network.
        This allows to reduce computation when all the estimates are not needed
        (e.g. when updating the policy in TD3).

        :param obs: Observation
        :param actions: Actions, concatenated with the extracted features along dim 1
        :return: Output of the first critic network
        """
        # Feature extraction never records gradients here, regardless of
        # ``share_features_extractor``.
        with th.no_grad():
            features = self.extract_features(obs, self.features_extractor)
        first_critic = self.q_networks[0]
        return first_critic(th.cat([features, actions], dim=1))
|