I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions


@@ -0,0 +1,97 @@
from typing import Any, Optional
import gymnasium as gym
import numpy as np
from numpy import ndarray
from godot_rl.core.godot_env import GodotEnv
from godot_rl.core.utils import lod_to_dol
class CleanRLGodotEnv:
def __init__(self, env_path: Optional[str] = None, n_parallel: int = 1, seed: int = 0, **kwargs: object) -> None:
# If we are doing editor training, n_parallel must be 1
if env_path is None and n_parallel > 1:
raise ValueError("You must provide the path to an exported game executable if n_parallel > 1")
# Define the default port
port = kwargs.pop("port", GodotEnv.DEFAULT_PORT)
# Create a list of GodotEnv instances
self.envs = [
GodotEnv(env_path=env_path, convert_action_space=True, port=port + p, seed=seed + p, **kwargs)
for p in range(n_parallel)
]
# Store the number of parallel environments
self.n_parallel = n_parallel
# Check the action space for validity
self._check_valid_action_space()
def _check_valid_action_space(self) -> None:
# Check if the action space is a tuple space with multiple spaces
action_space = self.envs[0].action_space
if isinstance(action_space, gym.spaces.Tuple):
assert (
len(action_space.spaces) == 1
), f"sb3 supports a single action space, this env contains multiple spaces {action_space}"
def step(self, action: np.ndarray) -> tuple[ndarray, list[Any], list[Any], list[Any], list[Any]]:
# Initialize lists for collecting results
all_obs = []
all_rewards = []
all_term = []
all_trunc = []
all_info = []
# Get the number of environments
num_envs = self.envs[0].num_envs
# Send actions to each environment
for i in range(self.n_parallel):
self.envs[i].step_send(action[i * num_envs : (i + 1) * num_envs])
# Receive results from each environment
for i in range(self.n_parallel):
obs, reward, term, trunc, info = self.envs[i].step_recv()
all_obs.extend(obs)
all_rewards.extend(reward)
all_term.extend(term)
all_trunc.extend(trunc)
all_info.extend(info)
# Convert list of dictionaries to dictionary of lists
obs = lod_to_dol(all_obs)
# Return results
return np.stack(obs["obs"]), all_rewards, all_term, all_trunc, all_info
def reset(self, seed=None) -> tuple[ndarray, list[Any]]:  # seed accepted for API compatibility; per-env seeds are set in __init__
# Initialize lists for collecting results
all_obs = []
all_info = []
# Reset each environment
for i in range(self.n_parallel):
obs, info = self.envs[i].reset()
all_obs.extend(obs)
all_info.extend(info)
# Convert list of dictionaries to dictionary of lists
obs = lod_to_dol(all_obs)
return np.stack(obs["obs"]), all_info
@property
def single_observation_space(self):
return self.envs[0].observation_space["obs"]
@property
def single_action_space(self):
return self.envs[0].action_space
@property
def num_envs(self) -> int:
return self.envs[0].num_envs * self.n_parallel
def close(self) -> None:
# Close each environment
for env in self.envs:
env.close()
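
# --- Illustrative usage sketch ---
# A minimal example of driving the wrapper above with random actions. The
# executable path is hypothetical, and the loop assumes convert_action_space=True
# (set in __init__) yields a single flat action space per agent.
if __name__ == "__main__":
    env = CleanRLGodotEnv(env_path="path/to/game.x86_64", n_parallel=2, seed=0)
    obs, info = env.reset(seed=0)
    for _ in range(10):
        # Sample one action per agent across all parallel instances
        actions = np.array([env.single_action_space.sample() for _ in range(env.num_envs)])
        obs, rewards, terms, truncs, infos = env.step(actions)
    env.close()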


@@ -0,0 +1,110 @@
import torch
from gymnasium.vector.utils import spaces
from stable_baselines3 import PPO
class OnnxableMultiInputPolicy(torch.nn.Module):
def __init__(
self,
obs_keys,
features_extractor,
mlp_extractor,
action_net,
value_net,
use_obs_array,
):
super().__init__()
self.obs_keys = obs_keys
self.features_extractor = features_extractor
self.mlp_extractor = mlp_extractor
self.action_net = action_net
self.value_net = value_net
self.use_obs_array = use_obs_array
def forward(self, obs, state_ins):
# NOTE: You may have to process (normalize) observations in the correct
# way before using this. See `common.preprocessing.preprocess_obs`.
features = None
if self.use_obs_array:
features = self.features_extractor(obs)
else:
obs_dict = {k: v for k, v in zip(self.obs_keys, obs)}
features = self.features_extractor(obs_dict)
action_hidden, value_hidden = self.mlp_extractor(features)
return self.action_net(action_hidden), state_ins
def export_ppo_model_as_onnx(ppo: PPO, onnx_model_path: str, use_obs_array: bool = False):
ppo_policy = ppo.policy.to("cpu")
onnxable_model = OnnxableMultiInputPolicy(
["obs"],
ppo_policy.features_extractor,
ppo_policy.mlp_extractor,
ppo_policy.action_net,
ppo_policy.value_net,
use_obs_array,
)
if use_obs_array:
dummy_input = torch.unsqueeze(torch.tensor(ppo.observation_space.sample()), 0)
else:
dummy_input = dict(ppo.observation_space.sample())
for k, v in dummy_input.items():
dummy_input[k] = torch.from_numpy(v).unsqueeze(0)
dummy_input = [v for v in dummy_input.values()]
torch.onnx.export(
onnxable_model,
args=(dummy_input, torch.zeros(1).float()),
f=onnx_model_path,
opset_version=9,
input_names=["obs", "state_ins"],
output_names=["output", "state_outs"],
dynamic_axes={
"obs": {0: "batch_size"},
"state_ins": {0: "batch_size"}, # variable length axes
"output": {0: "batch_size"},
"state_outs": {0: "batch_size"},
},
)
# If the space is MultiDiscrete, we skip verifying as action output will have an expected mismatch
# (the output from onnx will be the action logits for each discrete action,
# while the output from sb3 will be a single int)
if not isinstance(ppo.action_space, spaces.MultiDiscrete):
verify_onnx_export(ppo, onnx_model_path, use_obs_array=use_obs_array)
def verify_onnx_export(ppo: PPO, onnx_model_path: str, num_tests=10, use_obs_array: bool = False):
import numpy as np
import onnx
import onnxruntime as ort
onnx_model = onnx.load(onnx_model_path)
onnx.checker.check_model(onnx_model)
sb3_model = ppo.policy.to("cpu")
ort_sess = ort.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"])
for i in range(num_tests):
obs = None
obs2 = None
if use_obs_array:
obs = np.expand_dims(ppo.observation_space.sample(), axis=0)
obs2 = torch.tensor(obs)
else:
obs = dict(ppo.observation_space.sample())
obs2 = {}
for k, v in obs.items():
obs2[k] = torch.from_numpy(v).unsqueeze(0)
obs = [v for v in obs.values()]
with torch.no_grad():
action_sb3, _, _ = sb3_model(obs2, deterministic=True)
action_onnx, state_outs = ort_sess.run(None, {"obs": obs, "state_ins": np.array([0.0], dtype=np.float32)})
assert np.allclose(action_sb3, action_onnx, atol=1e-5), "Mismatch in action output"
assert np.allclose(state_outs, np.array([0.0]), atol=1e-5), "Mismatch in state_outs output"
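
# --- Illustrative usage sketch ---
# A hedged example of exporting and verifying a small SB3 PPO policy with the
# helpers above. Pendulum-v1 is only a stand-in environment for demonstration;
# any Box-observation, Box-action env would work the same way.
if __name__ == "__main__":
    import gymnasium as gym

    model = PPO("MlpPolicy", gym.make("Pendulum-v1"), n_steps=64, verbose=0)
    model.learn(total_timesteps=256)
    # Box observations are passed as a single array, hence use_obs_array=True
    export_ppo_model_as_onnx(model, "ppo_pendulum.onnx", use_obs_array=True)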


@@ -0,0 +1,147 @@
# PettingZoo wrapper for GDRL
# Multi-agent, where 1 agent corresponds to one AIController instance in Godot
# Based on https://pettingzoo.farama.org/content/environment_creation/#example-custom-parallel-environment
# https://github.com/Farama-Foundation/PettingZoo/?tab=License-1-ov-file#readme
# and adjusted to work with GodotRL and Rllib (made for and tested only with rllib for now)
import functools
from typing import Dict
import numpy as np
from pettingzoo import ParallelEnv
from godot_rl.core.godot_env import GodotEnv
def env(render_mode=None):
"""
The env function often wraps the environment in wrappers by default.
You can find full documentation for these methods
elsewhere in the developer documentation.
"""
# Not implemented: this factory is a placeholder and is not used by the wrapper below
return env
class GDRLPettingZooEnv(ParallelEnv):
metadata = {"render_modes": ["human"], "name": "GDRLPettingZooEnv"}
def __init__(self, port=GodotEnv.DEFAULT_PORT, show_window=True, seed=0, config: Dict = {}):
"""
The init method takes in environment arguments and should define the following attributes:
- possible_agents
- render_mode
Note: as of v1.18.1, the action_spaces and observation_spaces attributes are deprecated.
Spaces should be defined in the action_space() and observation_space() methods.
If these methods are not overridden, spaces will be inferred from self.observation_spaces/action_spaces, raising a warning.
These attributes should not be changed after initialization.
"""
# Initialize the Godot Env which we will wrap
self.godot_env = GodotEnv(
env_path=config.get("env_path"),
show_window=config.get("show_window"),
action_repeat=config.get("action_repeat"),
speedup=config.get("speedup"),
convert_action_space=False,
seed=seed,
port=port,
)
self.render_mode = None # Controlled by the env
self.possible_agents = [agent_idx for agent_idx in range(self.godot_env.num_envs)]
self.agents = self.possible_agents[:]
# The policy names here are set on each AIController in Godot editor,
# used to map agents to policies for multi-policy training.
self.agent_policy_names = self.godot_env.agent_policy_names
# optional: a mapping between agent name and ID
self.agent_name_mapping = dict(zip(self.possible_agents, list(range(len(self.possible_agents)))))
self.observation_spaces = {
agent: self.godot_env.observation_spaces[agent_idx] for agent_idx, agent in enumerate(self.agents)
}
self.action_spaces = {
agent: self.godot_env.tuple_action_spaces[agent_idx] for agent_idx, agent in enumerate(self.agents)
}
# Observation space should be defined here.
# lru_cache allows observation and action spaces to be memoized, reducing clock cycles required to get each agent's space.
# If your spaces change over time, remove this line (disable caching).
@functools.lru_cache(maxsize=None)
def observation_space(self, agent):
return self.observation_spaces[agent]
# Action space should be defined here.
# If your spaces change over time, remove this line (disable caching).
@functools.lru_cache(maxsize=None)
def action_space(self, agent):
return self.action_spaces[agent]
def render(self):
"""
Renders the environment. In human mode, it can print to terminal, open
up a graphical window, or open up some other display that a human can see and understand.
"""
# Not implemented
def close(self):
"""
Close should release any graphical displays, subprocesses, network connections
or any other environment data which should not be kept around after the
user is no longer using the environment.
"""
self.godot_env.close()
def reset(self, seed=None, options=None):
"""
Reset needs to initialize the `agents` attribute and must set up the
environment so that render(), and step() can be called without issues.
Returns the observations for each agent
"""
godot_obs, godot_infos = self.godot_env.reset()
observations = {agent: godot_obs[agent_idx] for agent_idx, agent in enumerate(self.agents)}
infos = {agent: godot_infos[agent_idx] for agent_idx, agent in enumerate(self.agents)}
return observations, infos
def step(self, actions):
"""
step(action) takes in an action for each agent and should return the
- observations
- rewards
- terminations
- truncations
- infos
dicts where each dict looks like {agent_1: item_1, agent_2: item_2}
"""
# Once an agent (AIController) has done = true, it will not receive any more actions until all agents in the
# Godot env have done = true. For agents that received no actions, we will set zeros instead for
# compatibility.
godot_actions = [
actions[agent] if agent in actions else np.zeros_like(self.action_spaces[agent_idx].sample())
for agent_idx, agent in enumerate(self.agents)
]
godot_obs, godot_rewards, godot_dones, godot_truncations, godot_infos = self.godot_env.step(
godot_actions, order_ij=True
)
observations = {agent: godot_obs[agent] for agent in actions}
rewards = {agent: godot_rewards[agent] for agent in actions}
terminations = {agent: godot_dones[agent] for agent in actions}
# Truncations are not yet implemented in GDRL API
truncations = {agent: False for agent in actions}
# typically there won't be any information in the infos, but there must
# still be an entry for each agent
infos = {agent: godot_infos[agent] for agent in actions}
return observations, rewards, terminations, truncations, infos
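
# --- Illustrative usage sketch ---
# A hedged example of stepping the parallel env above with random actions. The
# executable path is hypothetical; the config keys mirror what __init__ reads.
if __name__ == "__main__":
    env = GDRLPettingZooEnv(
        config={"env_path": "path/to/game.x86_64", "show_window": False, "action_repeat": 1, "speedup": 1}
    )
    observations, infos = env.reset()
    for _ in range(10):
        # One action per agent, sampled from that agent's own (tuple) action space
        actions = {agent: env.action_space(agent).sample() for agent in env.agents}
        observations, rewards, terminations, truncations, infos = env.step(actions)
    env.close()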


@@ -0,0 +1,139 @@
import os
import pathlib
from typing import List, Optional, Tuple
import numpy as np
import ray
import yaml
from ray import tune
from ray.rllib.env.vector_env import VectorEnv
from ray.rllib.utils.typing import EnvActionType, EnvInfoDict, EnvObsType
from godot_rl.core.godot_env import GodotEnv
class RayVectorGodotEnv(VectorEnv):
def __init__(
self,
port=10008,
seed=0,
config=None,
) -> None:
self._env = GodotEnv(
env_path=config["env_path"],
port=port,
seed=seed,
show_window=config["show_window"],
action_repeat=config["action_repeat"],
speedup=config["speedup"],
)
super().__init__(
observation_space=self._env.observation_space,
action_space=self._env.action_space,
num_envs=self._env.num_envs,
)
def vector_reset(
self, *, seeds: Optional[List[int]] = None, options: Optional[List[dict]] = None
) -> List[EnvObsType]:
self.obs, info = self._env.reset()
return self.obs, info
def vector_step(
self, actions: List[EnvActionType]
) -> Tuple[List[EnvObsType], List[float], List[bool], List[EnvInfoDict]]:
actions = np.array(actions, dtype=np.dtype(object))
self.obs, reward, term, trunc, info = self._env.step(actions, order_ij=True)
return self.obs, reward, term, trunc, info
def get_unwrapped(self):
return [self._env]
def reset_at(
self,
index: Optional[int] = None,
*,
seed: Optional[int] = None,
options: Optional[dict] = None,
) -> EnvObsType:
# the env is reset automatically, no need to reset it
return self.obs[index], {}
def register_env():
tune.register_env(
"godot",
lambda c: RayVectorGodotEnv(
config=c,
port=c.worker_index + GodotEnv.DEFAULT_PORT + 10,
seed=c.worker_index + c["seed"],
),
)
# Note: the previously commented-out ONNX export section was removed; it has been re-implemented in rllib_example.py
def rllib_training(args, extras):
with open(args.config_file) as f:
exp = yaml.safe_load(f)
register_env()
exp["config"]["env_config"]["env_path"] = args.env_path
exp["config"]["env_config"]["seed"] = args.seed
if args.env_path is not None:
run_name = exp["algorithm"] + "/" + pathlib.Path(args.env_path).stem
else:
run_name = exp["algorithm"] + "/editor"
print("run_name", run_name)
if args.num_gpus is not None:
exp["config"]["num_gpus"] = args.num_gpus
if args.env_path is None:
print("SETTING WORKERS TO 1")
exp["config"]["num_workers"] = 1
checkpoint_freq = 10
exp["config"]["env_config"]["show_window"] = args.viz
exp["config"]["env_config"]["speedup"] = args.speedup
if args.eval or args.export:
checkpoint_freq = 0
exp["config"]["env_config"]["show_window"] = True
exp["config"]["env_config"]["framerate"] = None
exp["config"]["lr"] = 0.0
exp["config"]["num_sgd_iter"] = 1
exp["config"]["num_workers"] = 1
exp["config"]["train_batch_size"] = 8192
exp["config"]["sgd_minibatch_size"] = 128
exp["config"]["explore"] = False
exp["stop"]["training_iteration"] = 999999
print(exp)
ray.init(num_gpus=exp["config"]["num_gpus"] or 1)
if not args.export:
tune.run(
exp["algorithm"],
name=run_name,
config=exp["config"],
stop=exp["stop"],
verbose=3,
checkpoint_freq=checkpoint_freq,
checkpoint_at_end=not args.eval,
restore=args.restore,
storage_path=os.path.abspath(args.experiment_dir or "logs/rllib"),
trial_name_creator=lambda trial: (
f"{args.experiment_name}" if args.experiment_name else f"{trial.trainable_name}_{trial.trial_id}"
),
)
if args.export:
raise NotImplementedError("Use examples/rllib_example.py to export to onnx.")
# rllib_export(args.restore)
ray.shutdown()
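
# --- Illustrative config sketch ---
# rllib_training() expects a YAML file with roughly the structure below (shown as
# the equivalent Python dict). The keys are inferred from how the function reads
# exp["algorithm"], exp["stop"] and exp["config"]; the exact values are assumptions.
example_exp = {
    "algorithm": "PPO",
    "stop": {"training_iteration": 200},
    "config": {
        "env": "godot",  # the name registered by register_env() above
        "num_gpus": 0,
        "num_workers": 4,
        "env_config": {
            # env_path, seed, show_window and speedup are overwritten by rllib_training()
            "env_path": None,
            "action_repeat": 1,
            "speedup": 1,
        },
    },
}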


@@ -0,0 +1,192 @@
import argparse
from functools import partial
import numpy as np
from gymnasium import Env
from sample_factory.cfg.arguments import parse_full_cfg, parse_sf_args
from sample_factory.enjoy import enjoy
from sample_factory.envs.env_utils import register_env
from sample_factory.train import run_rl
from godot_rl.core.godot_env import GodotEnv
from godot_rl.core.utils import lod_to_dol
class SampleFactoryEnvWrapperBatched(GodotEnv, Env):
@property
def unwrapped(self):
return self
@property
def num_agents(self):
return self.num_envs
def reset(self, seed=None, options=None):
obs, info = super().reset(seed=seed)
obs = lod_to_dol(obs)
return {k: np.array(v) for k, v in obs.items()}, info
def step(self, action):
obs, reward, term, trunc, info = super().step(action, order_ij=False)
obs = lod_to_dol(obs)
return {k: np.array(v) for k, v in obs.items()}, np.array(reward), np.array(term), np.array(trunc) * 0, info
@staticmethod
def to_numpy(lod):
for d in lod:
for k, v in d.items():
d[k] = np.array(v)
return lod
def render(self):
return
class SampleFactoryEnvWrapperNonBatched(GodotEnv, Env):
@property
def unwrapped(self):
return self
@property
def num_agents(self):
return self.num_envs
def reset(self, seed=None, options=None):
obs, info = super().reset(seed=seed)
return self.to_numpy(obs), info
def step(self, action):
obs, reward, term, trunc, info = super().step(action, order_ij=True)
return self.to_numpy(obs), np.array(reward), np.array(term), np.array(trunc) * 0, info
@staticmethod
def to_numpy(lod):
for d in lod:
for k, v in d.items():
d[k] = np.array(v)
return lod
def render(self):
return
def make_godot_env_func(
env_path, full_env_name, cfg=None, env_config=None, render_mode=None, seed=0, speedup=1, viz=False
):
port = cfg.base_port
print("BASE PORT ", cfg.base_port)
show_window = False
_seed = seed
if env_config:
port += 1 + env_config.env_id
_seed += 1 + env_config.env_id
print("env id", env_config.env_id)
if viz:
print("creating viz env")
show_window = env_config.env_id == 0
if cfg.batched_sampling:
env = SampleFactoryEnvWrapperBatched(
env_path=env_path, port=port, seed=_seed, show_window=show_window, speedup=speedup
)
else:
env = SampleFactoryEnvWrapperNonBatched(
env_path=env_path, port=port, seed=_seed, show_window=show_window, speedup=speedup
)
return env
def register_gdrl_env(args):
make_env = partial(make_godot_env_func, args.env_path, speedup=args.speedup, seed=args.seed, viz=args.viz)
register_env("gdrl", make_env)
def gdrl_override_defaults(_env, parser):
"""RL params specific to Atari envs."""
parser.set_defaults(
# let's set this to True by default so it's consistent with how we report results for other envs
# (i.e. VizDoom or DMLab). When running evaluations for reports or to compare with other frameworks we can
# set this to false in command line
summaries_use_frameskip=True,
use_record_episode_statistics=True,
gamma=0.99,
env_frameskip=1,
env_framestack=4,
num_workers=1,
num_envs_per_worker=2,
worker_num_splits=2,
env_agents=16,
train_for_env_steps=1000000,
nonlinearity="relu",
kl_loss_coeff=0.0,
use_rnn=False,
adaptive_stddev=True,
reward_scale=1.0,
with_vtrace=False,
recurrence=1,
batch_size=2048,
rollout=32,
max_grad_norm=0.5,
num_epochs=2,
num_batches_per_epoch=4,
ppo_clip_ratio=0.2,
value_loss_coeff=0.5,
exploration_loss="entropy",
exploration_loss_coeff=0.000,
learning_rate=0.00025,
lr_schedule="linear_decay",
shuffle_minibatches=False,
gae_lambda=0.95,
batched_sampling=False,
normalize_input=True,
normalize_returns=True,
serial_mode=False,
async_rl=True,
experiment_summaries_interval=3,
adam_eps=1e-5,
)
def add_gdrl_env_args(_env, p: argparse.ArgumentParser, evaluation=False):
if evaluation:
# apparently env.render(mode="human") is not supported anymore and we need to specify the render mode in
# the env actor
p.add_argument("--render_mode", default="human", type=str, help="")
p.add_argument("--base_port", default=GodotEnv.DEFAULT_PORT, type=int, help="")
p.add_argument(
"--env_agents",
default=2,
type=int,
help="Num agents in each envpool (if used)",
)
def parse_gdrl_args(args, argv=None, evaluation=False):
parser, partial_cfg = parse_sf_args(argv=argv, evaluation=evaluation)
add_gdrl_env_args(partial_cfg.env, parser, evaluation=evaluation)
gdrl_override_defaults(partial_cfg.env, parser)
final_cfg = parse_full_cfg(parser, argv)
final_cfg.train_dir = args.experiment_dir or "logs/sf"
final_cfg.experiment = args.experiment_name or final_cfg.experiment
return final_cfg
def sample_factory_training(args, extras):
register_gdrl_env(args)
cfg = parse_gdrl_args(args=args, argv=extras, evaluation=args.eval)
# cfg.base_port = random.randint(20000, 22000)
status = run_rl(cfg)
return status
def sample_factory_enjoy(args, extras):
register_gdrl_env(args)
cfg = parse_gdrl_args(args=args, argv=extras, evaluation=args.eval)
status = enjoy(cfg)
return status
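
# --- Illustrative usage sketch ---
# A hedged example of launching training through the helpers above. The Namespace
# fields mirror the attributes read by register_gdrl_env() and parse_gdrl_args();
# the executable path and the extra CLI overrides are assumptions.
if __name__ == "__main__":
    args = argparse.Namespace(
        env_path="path/to/game.x86_64",
        speedup=8,
        seed=0,
        viz=False,
        eval=False,
        experiment_dir=None,
        experiment_name="gdrl_sf_example",
    )
    extras = ["--env=gdrl", "--train_for_env_steps=100000"]
    sample_factory_training(args, extras)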


@@ -0,0 +1,33 @@
from typing import Any, Dict, List, Tuple
import gymnasium as gym
import numpy as np
from godot_rl.wrappers.stable_baselines_wrapper import StableBaselinesGodotEnv
# A variant of the Stable Baselines Godot Env that only supports a single obs space from the dictionary - obs["obs"] by default.
# This provides some basic support for using envs that have a single obs space with policies other than MultiInputPolicy.
class SBGSingleObsEnv(StableBaselinesGodotEnv):
def __init__(self, obs_key="obs", *args, **kwargs) -> None:
self.obs_key = obs_key
super().__init__(*args, **kwargs)
def step(self, action: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray, List[Dict[str, Any]]]:
obs, rewards, term, info = super().step(action)
# Terminal obs info is needed for imitation learning
for idx, done in enumerate(term):
if done:
info[idx]["terminal_observation"] = obs[self.obs_key][idx]
return obs[self.obs_key], rewards, term, info
def reset(self) -> np.ndarray:
obs = super().reset()
return obs[self.obs_key]
@property
def observation_space(self) -> gym.Space:
return self.envs[0].observation_space[self.obs_key]
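
# --- Illustrative usage sketch ---
# Because this wrapper exposes a single observation space rather than a Dict, it can
# be used with SB3 policies such as MlpPolicy. The executable path is hypothetical.
if __name__ == "__main__":
    from stable_baselines3 import PPO

    env = SBGSingleObsEnv(env_path="path/to/game.x86_64", show_window=False, speedup=8)
    model = PPO("MlpPolicy", env, n_steps=32, verbose=1)
    model.learn(total_timesteps=10_000)
    env.close()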


@@ -0,0 +1,176 @@
from typing import Any, Dict, List, Optional, Tuple
import gymnasium as gym
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env.base_vec_env import VecEnv
from stable_baselines3.common.vec_env.vec_monitor import VecMonitor
from godot_rl.core.godot_env import GodotEnv
from godot_rl.core.utils import can_import, lod_to_dol
class StableBaselinesGodotEnv(VecEnv):
def __init__(
self,
env_path: Optional[str] = None,
n_parallel: int = 1,
seed: int = 0,
**kwargs,
) -> None:
# If we are doing editor training, n_parallel must be 1
if env_path is None and n_parallel > 1:
raise ValueError("You must provide the path to an exported game executable if n_parallel > 1")
# Define the default port
port = kwargs.pop("port", GodotEnv.DEFAULT_PORT)
# Create a list of GodotEnv instances
self.envs = [
GodotEnv(
env_path=env_path,
convert_action_space=True,
port=port + p,
seed=seed + p,
**kwargs,
)
for p in range(n_parallel)
]
# Store the number of parallel environments
self.n_parallel = n_parallel
# Check the action space for validity
self._check_valid_action_space()
# Initialize the results holder
self.results = None
def _check_valid_action_space(self) -> None:
# Check if the action space is a tuple space with multiple spaces
action_space = self.envs[0].action_space
if isinstance(action_space, gym.spaces.Tuple):
assert (
len(action_space.spaces) == 1
), f"sb3 supports a single action space, this env contains multiple spaces {action_space}"
def step(self, action: np.ndarray) -> Tuple[Dict[str, np.ndarray], np.ndarray, np.ndarray, List[Dict[str, Any]]]:
# Initialize lists for collecting results
all_obs = []
all_rewards = []
all_term = []
all_trunc = []
all_info = []
# Get the number of environments
num_envs = self.envs[0].num_envs
# Send actions to each environment
for i in range(self.n_parallel):
self.envs[i].step_send(action[i * num_envs : (i + 1) * num_envs])
# Receive results from each environment
for i in range(self.n_parallel):
obs, reward, term, trunc, info = self.envs[i].step_recv()
all_obs.extend(obs)
all_rewards.extend(reward)
all_term.extend(term)
all_trunc.extend(trunc)
all_info.extend(info)
# Convert list of dictionaries to dictionary of lists
obs = lod_to_dol(all_obs)
# Return results
return (
{k: np.array(v) for k, v in obs.items()},
np.array(all_rewards, dtype=np.float32),
np.array(all_term),
all_info,
)
def reset(self) -> Dict[str, np.ndarray]:
# Initialize lists for collecting results
all_obs = []
all_info = []
# Reset each environment
for i in range(self.n_parallel):
obs, info = self.envs[i].reset()
all_obs.extend(obs)
all_info.extend(info)
# Convert list of dictionaries to dictionary of lists
obs = lod_to_dol(all_obs)
return {k: np.array(v) for k, v in obs.items()}
def close(self) -> None:
# Close each environment
for env in self.envs:
env.close()
@property
def observation_space(self) -> gym.Space:
return self.envs[0].observation_space
@property
def action_space(self) -> gym.Space:
# sb3 is not compatible with tuple/dict action spaces
return self.envs[0].action_space
@property
def num_envs(self) -> int:
return self.envs[0].num_envs * self.n_parallel
def env_is_wrapped(self, wrapper_class: type, indices: Optional[List[int]] = None) -> List[bool]:
# Return a list indicating that no environments are wrapped
return [False] * (self.envs[0].num_envs * self.n_parallel)
# Placeholder methods that should be implemented for a full VecEnv implementation
def env_method(self):
raise NotImplementedError()
def get_attr(self, attr_name: str, indices=None) -> List[Any]:
if attr_name == "render_mode":
return [None for _ in range(self.num_envs)]
raise AttributeError("get attr not fully implemented in godot-rl StableBaselinesWrapper")
def seed(self, seed=None):
raise NotImplementedError()
def set_attr(self):
raise NotImplementedError()
def step_async(self, actions: np.ndarray) -> None:
# Required by the VecEnv API; the step is executed synchronously here and the result is stored for step_wait
self.results = self.step(actions)
def step_wait(
self,
) -> Tuple[Dict[str, np.ndarray], np.ndarray, np.ndarray, List[Dict[str, Any]]]:
# Wait for the results from the asynchronous step
return self.results
def stable_baselines_training(args, extras, n_steps: int = 200000, **kwargs) -> None:
if can_import("ray"):
print("WARNING, stable baselines and ray[rllib] are not compatible")
# Initialize the custom environment
env = StableBaselinesGodotEnv(env_path=args.env_path, show_window=args.viz, speedup=args.speedup, **kwargs)
env = VecMonitor(env)
# Initialize the PPO model
model = PPO(
"MultiInputPolicy",
env,
ent_coef=0.0001,
verbose=2,
n_steps=32,
tensorboard_log=args.experiment_dir or "logs/sb3",
)
# Train the model
model.learn(n_steps, tb_log_name=args.experiment_name)
print("closing env")
env.close()
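
# --- Illustrative usage sketch ---
# A hedged example of calling the training entry point above from a script rather
# than the CLI. The Namespace fields mirror the attributes the function reads;
# the executable path is hypothetical.
if __name__ == "__main__":
    import argparse

    args = argparse.Namespace(
        env_path="path/to/game.x86_64",
        viz=False,
        speedup=8,
        experiment_dir=None,
        experiment_name="gdrl_sb3_example",
    )
    stable_baselines_training(args, extras=[], n_steps=50_000)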