I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

@@ -0,0 +1,485 @@
import atexit
import json
import os
import pathlib
import socket
import subprocess
import time
from collections import OrderedDict
from sys import platform
from typing import Optional
import numpy as np
from gymnasium import spaces
from godot_rl.core.utils import ActionSpaceProcessor, convert_macos_path
class GodotEnv:
MAJOR_VERSION = "0" # Versioning for the environment
MINOR_VERSION = "7"
DEFAULT_PORT = 11008 # Default port for communication with Godot Game
DEFAULT_TIMEOUT = 60 # Default socket timeout TODO
def __init__(
self,
env_path: str = None,
port: int = DEFAULT_PORT,
show_window: bool = False,
seed: int = 0,
framerate: Optional[int] = None,
action_repeat: Optional[int] = None,
speedup: Optional[int] = None,
convert_action_space: bool = False,
):
"""
Initialize a new instance of GodotEnv
Args:
env_path (str): path to the godot binary environment.
port (int): Port number for communication.
show_window (bool): flag to display Godot game window.
seed (int): seed to initialize the environment.
framerate (int): the framerate to run the Godot game at.
action_repeat (int): the number of frames to repeat an action for.
speedup (int): the factor to speedup game time by.
convert_action_space (bool): flag to convert action space.
"""
self.proc = None
if env_path is not None and env_path != "debug":
env_path = self._set_platform_suffix(env_path)
self.check_platform(env_path)
self._launch_env(env_path, port, show_window, framerate, seed, action_repeat, speedup)
else:
print("No game binary has been provided, please press PLAY in the Godot editor")
self.port = port
self.connection = self._start_server()
self.num_envs = None
self._handshake()
# Action and observation spaces for each in-game agent/env/AIController (used only for multi-agent case with Rllib for now)
self.action_spaces = []
self.observation_spaces = []
self._get_env_info()
# Single-agent observation space
self.observation_space = self.observation_spaces[0]
# sf2 requires a tuple action space
# Multiple agents' action space(s)
self.tuple_action_spaces = [
spaces.Tuple([v for _, v in action_space.items()]) for action_space in self.action_spaces
]
# Single agent action space processor using the action space(s) of the first agent
self.action_space_processor = ActionSpaceProcessor(self.tuple_action_spaces[0], convert_action_space)
# For multi-policy envs: the name of each agent's policy is set in the env itself (any AIController
# instance in training mode is treated as an agent). self.agent_policy_names is assigned in _get_env_info().
atexit.register(self._close)
def _set_platform_suffix(self, env_path: str) -> str:
"""
Set the platform suffix for the given environment path based on the platform.
Args:
env_path (str): The environment path.
Returns:
str: The environment path with the platform suffix.
"""
suffixes = {
"linux": ".x86_64",
"linux2": ".x86_64",
"darwin": ".app",
"win32": ".exe",
}
suffix = suffixes[platform]
return str(pathlib.Path(env_path).with_suffix(suffix))
def check_platform(self, filename: str):
"""
Check the platform and assert the file type
Args:
filename (str): Path of the file to check.
Raises:
AssertionError: If the file type does not match with the platform or file does not exist.
"""
if platform == "linux" or platform == "linux2":
# Linux
assert (
pathlib.Path(filename).suffix == ".x86_64"
), f"Incorrect file suffix for {filename=} {pathlib.Path(filename).suffix=}. Please provide a .x86_64 file"
elif platform == "darwin":
# OSX
assert (
pathlib.Path(filename).suffix == ".app"
), f"Incorrect file suffix for {filename=} {pathlib.Path(filename).suffix=}. Please provide a .app file"
elif platform == "win32":
# Windows...
assert (
pathlib.Path(filename).suffix == ".exe"
), f"Incorrect file suffix for {filename=} {pathlib.Path(filename).suffix=}. Please provide a .exe file"
else:
assert 0, f"unsupported platform: {platform}"
assert os.path.exists(filename)
def from_numpy(self, action, order_ij=False):
"""
Converts a batch of actions into a list of per-agent, JSON-serializable action dicts
Args:
action: The action to be converted.
order_ij (bool): If True, actions are indexed as action[agent][key]; otherwise as action[key][agent].
Returns:
list: The converted action.
"""
result = []
for agent_idx in range(self.num_envs):
env_action = {}
for j, k in enumerate(self.action_spaces[agent_idx].keys()):
if order_ij is True:
v = action[agent_idx][j]
else:
v = action[j][agent_idx]
if isinstance(v, np.ndarray):
env_action[k] = v.tolist()
else:
env_action[k] = int(v) # cannot serialize int32
result.append(env_action)
return result
def step(self, action, order_ij=False):
"""
Perform one step in the environment.
Args:
action: Action to be taken.
order_ij (bool): Order flag.
Returns:
tuple: Tuple containing observation, reward, done flag, termination flag, and info.
"""
self.step_send(action, order_ij=order_ij)
return self.step_recv()
def step_send(self, action, order_ij=False):
"""
Send the action to the Godot environment.
Args:
action: Action to be sent.
order_ij (bool): Order flag.
"""
action = self.action_space_processor.to_original_dist(action)
message = {
"type": "action",
"action": self.from_numpy(action, order_ij=order_ij),
}
self._send_as_json(message)
def step_recv(self):
"""
Receive the step response from the Godot environment.
Returns:
tuple: Tuple containing observation, reward, done flag, termination flag, and info.
"""
response = self._get_json_dict()
response["obs"] = self._process_obs(response["obs"])
return (
response["obs"],
response["reward"],
np.array(response["done"]).tolist(),
np.array(response["done"]).tolist(), # TODO update API to term, trunc
[{}] * len(response["done"]),
)
def _process_obs(self, response_obs: dict):
"""
Process observation data.
Args:
response_obs (dict): The response observation to be processed.
Returns:
dict: The processed observation data.
"""
for k in response_obs[0].keys():
if "2d" in k:
for sub in response_obs:
sub[k] = self._decode_2d_obs_from_string(sub[k], self.observation_space[k].shape)
return response_obs
def reset(self, seed=None):
"""
Reset the Godot environment.
Returns:
dict: The initial observation data.
"""
message = {
"type": "reset",
}
self._send_as_json(message)
response = self._get_json_dict()
response["obs"] = self._process_obs(response["obs"])
assert response["type"] == "reset"
obs = response["obs"]
return obs, [{}] * self.num_envs
def call(self, method):
message = {
"type": "call",
"method": method,
}
self._send_as_json(message)
response = self._get_json_dict()
return response["returns"]
def close(self):
message = {
"type": "close",
}
self._send_as_json(message)
print("close message sent")
time.sleep(1.0)
self.connection.close()
try:
atexit.unregister(self._close)
except Exception as e:
print("exception unregistering close method", e)
@property
def action_space(self):
"""
Returns a single action space.
"""
return self.action_space_processor.action_space
def _close(self):
print("exit was not clean, using atexit to close env")
self.close()
def _launch_env(self, env_path, port, show_window, framerate, seed, action_repeat, speedup):
# --fixed-fps {framerate}
path = convert_macos_path(env_path) if platform == "darwin" else env_path
launch_cmd = f"{path} --port={port} --env_seed={seed}"
if show_window is False:
launch_cmd += " --disable-render-loop --headless"
if framerate is not None:
launch_cmd += f" --fixed-fps {framerate}"
if action_repeat is not None:
launch_cmd += f" --action_repeat={action_repeat}"
if speedup is not None:
launch_cmd += f" --speedup={speedup}"
launch_cmd = launch_cmd.split(" ")
self.proc = subprocess.Popen(
launch_cmd,
start_new_session=True,
# shell=True,
)
def _start_server(self):
# Either launch an exported Godot project or connect to a running Godot game (e.g. started from the editor)
print(f"waiting for remote GODOT connection on port {self.port}")
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
# Bind the socket to the port; "localhost" was not working on a Windows VM, so the IP is used instead
server_address = ("127.0.0.1", self.port)
sock.bind(server_address)
# Listen for incoming connections
sock.listen(1)
sock.settimeout(GodotEnv.DEFAULT_TIMEOUT)
connection, client_address = sock.accept()
# connection.settimeout(GodotEnv.DEFAULT_TIMEOUT)
# connection.setblocking(False) TODO
print("connection established")
return connection
def _handshake(self):
message = {
"type": "handshake",
"major_version": GodotEnv.MAJOR_VERSION,
"minor_version": GodotEnv.MINOR_VERSION,
}
self._send_as_json(message)
def _get_env_info(self):
message = {"type": "env_info"}
self._send_as_json(message)
json_dict = self._get_json_dict()
assert json_dict["type"] == "env_info"
# Number of AIController instances in a single Godot env/process
self.num_envs = json_dict["n_agents"]
# Actions can be "single" for a single action head
# or "multi" for several output heads
print("action space", json_dict["action_space"])
# Compatibility with previous versions of the Godot plugin:
# A single action space will be received as a dict in previous versions,
# A list of dicts will be received from newer versions, defining the action_space for each agent (AIController)
if isinstance(json_dict["action_space"], dict):
json_dict["action_space"] = [json_dict["action_space"]] * self.num_envs
for agent_action_space in json_dict["action_space"]:
tmp_action_spaces = OrderedDict()
for k, v in agent_action_space.items():
if v["action_type"] == "discrete":
tmp_action_spaces[k] = spaces.Discrete(v["size"])
elif v["action_type"] == "continuous":
tmp_action_spaces[k] = spaces.Box(low=-1.0, high=1.0, shape=(v["size"],))
else:
print(f"action space {v['action_type']} is not supported")
assert 0, f"action space {v['action_type']} is not supported"
self.action_spaces.append(spaces.Dict(tmp_action_spaces))
print("observation space", json_dict["observation_space"])
# Compatibility with older versions of the Godot plugin:
# A single observation space will be received as a dict in previous versions,
# A list of dicts will be received from newer versions, defining the observation_space for each agent (AIController)
if isinstance(json_dict["observation_space"], dict):
json_dict["observation_space"] = [json_dict["observation_space"]] * self.num_envs
for agent_obs_space in json_dict["observation_space"]:
observation_spaces = {}
for k, v in agent_obs_space.items():
if v["space"] == "box":
if "2d" in k:
observation_spaces[k] = spaces.Box(
low=0,
high=255,
shape=v["size"],
dtype=np.uint8,
)
else:
observation_spaces[k] = spaces.Box(
low=-1.0,
high=1.0,
shape=v["size"],
dtype=np.float32,
)
elif v["space"] == "discrete":
observation_spaces[k] = spaces.Discrete(v["size"])
else:
print(f"observation space {v['space']} is not supported")
assert 0, f"observation space {v['space']} is not supported"
self.observation_spaces.append(spaces.Dict(observation_spaces))
# Gets policy names defined in AIControllers in Godot. If an older version of the plugin is used and no policy
# names are sent, "shared_policy" will be set for compatibility.
self.agent_policy_names = json_dict.get("agent_policy_names", ["shared_policy"] * self.num_envs)
@staticmethod
def _decode_2d_obs_from_string(
hex_string,
shape,
):
return np.frombuffer(bytes.fromhex(hex_string), dtype=np.uint8).reshape(shape)
def _send_as_json(self, dictionary):
message_json = json.dumps(dictionary)
self._send_string(message_json)
def _get_json_dict(self):
data = self._get_data()
return json.loads(data)
def _get_obs(self):
return self._get_data()
def _clear_socket(self):
self.connection.setblocking(False)
try:
while True:
data = self.connection.recv(4)
if not data:
break
except BlockingIOError:
pass
self.connection.setblocking(True)
def _get_data(self):
try:
# Receive the size (in bytes) of the remaining data to receive
string_size_bytes: bytearray = bytearray()
received_length: int = 0
# The first 4 bytes contain the length of the remaining data
length: int = 4
while received_length < length:
data = self.connection.recv(length - received_length)
received_length += len(data)
string_size_bytes.extend(data)
length = int.from_bytes(string_size_bytes, "little")
# Receive the rest of the data
string_bytes: bytearray = bytearray()
received_length = 0
while received_length < length:
data = self.connection.recv(length - received_length)
received_length += len(data)
string_bytes.extend(data)
string: str = string_bytes.decode()
return string
except socket.timeout as e:
print("env timed out", e)
return None
def _send_string(self, string):
message = len(string).to_bytes(4, "little") + bytes(string.encode())
self.connection.sendall(message)
def _send_action(self, action):
self._send_string(action)
def interactive():
env = GodotEnv()
print("observation space", env.observation_space)
print("action space", env.action_space)
obs = env.reset()
for i in range(1000):
action = [env.action_space.sample() for _ in range(env.num_envs)]
action = list(zip(*action))
obs, reward, term, trunc, info = env.step(action)
env.close()
if __name__ == "__main__":
interactive()

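The messages exchanged with Godot above use a simple framing scheme: each message is a JSON string prefixed by its length as a 4-byte little-endian integer, as implemented in _send_string and _get_data. A minimal standalone sketch of that framing (the helper names here are illustrative and not part of the module):

import json

def frame(message: dict) -> bytes:
    # Length prefix followed by the JSON payload, mirroring _send_string above
    # (string length equals byte length for ASCII-only JSON).
    message_json = json.dumps(message)
    return len(message_json).to_bytes(4, "little") + message_json.encode()

def unframe(data: bytes) -> dict:
    # Read the 4-byte little-endian length, then decode that many bytes of JSON.
    length = int.from_bytes(data[:4], "little")
    return json.loads(data[4 : 4 + length].decode())

handshake = {"type": "handshake", "major_version": "0", "minor_version": "7"}
assert unframe(frame(handshake)) == handshake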
@@ -0,0 +1,136 @@
import importlib
import re
import gymnasium as gym
import numpy as np
def lod_to_dol(lod):
return {k: [dic[k] for dic in lod] for k in lod[0]}
def dol_to_lod(dol):
return [dict(zip(dol, t)) for t in zip(*dol.values())]
def convert_macos_path(env_path):
"""
On macOS, the user is expected to provide an application.app file as env_path.
However, the actual binary is located at application.app/Contents/MacOS/application.
This helper function converts the path to the path of the actual binary.
Example input: ./Demo.app
Example output: ./Demo.app/Contents/MacOS/Demo
"""
filenames = re.findall(r"[^\/]+(?=\.)", env_path)
assert len(filenames) == 1, "An error occurred while converting the env path for macOS."
return env_path + "/Contents/MacOS/" + filenames[0]
class ActionSpaceProcessor:
# can convert tuple action dists to a single continuous action distribution
# eg (Box(a), Box(b)) -> Box(a+b)
# (Box(a), Discrete(2)) -> Box(a+2)
# etc
# does not yet work with discrete dists of n>2
def __init__(self, action_space: gym.spaces.Tuple, convert) -> None:
self._original_action_space = action_space
self._convert = convert
space_size = 0
if convert:
use_multi_discrete_spaces = False
multi_discrete_spaces = np.array([])
if isinstance(action_space, gym.spaces.Tuple):
if all(isinstance(space, gym.spaces.Discrete) for space in action_space.spaces):
use_multi_discrete_spaces = True
for space in action_space.spaces:
multi_discrete_spaces = np.append(multi_discrete_spaces, space.n)
else:
for space in action_space.spaces:
if isinstance(space, gym.spaces.Box):
assert len(space.shape) == 1
space_size += space.shape[0]
elif isinstance(space, gym.spaces.Discrete):
if space.n > 2:
# for now only binary actions are supported if you mix different spaces
raise NotImplementedError(
"Discrete actions with size larger than 2 "
"are currently not supported if used together with continuous actions."
)
space_size += 1
else:
raise NotImplementedError
elif isinstance(action_space, gym.spaces.Dict):
raise NotImplementedError
else:
assert isinstance(action_space, (gym.spaces.Box, gym.spaces.Discrete))
return
if use_multi_discrete_spaces:
self.converted_action_space = gym.spaces.MultiDiscrete(multi_discrete_spaces)
else:
self.converted_action_space = gym.spaces.Box(-1, 1, shape=[space_size])
@property
def action_space(self):
if not self._convert:
return self._original_action_space
return self.converted_action_space
def to_original_dist(self, action):
if not self._convert:
return action
original_action = []
counter = 0
# If only discrete actions are used in the environment:
# - SB3 will send int actions containing the discrete action,
# - CleanRL example script (continuous PPO) will only send float actions, which we convert to binary discrete,
# - If mixed actions are used, both will send float actions.
integer_actions: bool = action.dtype == np.int64
for space in self._original_action_space.spaces:
if isinstance(space, gym.spaces.Box):
assert len(space.shape) == 1
original_action.append(action[:, counter : counter + space.shape[0]])
counter += space.shape[0]
elif isinstance(space, gym.spaces.Discrete):
discrete_actions = None
if integer_actions:
discrete_actions = action[:, counter]
else:
if space.n > 2:
raise NotImplementedError(
"Discrete actions with size larger than "
"2 are currently not implemented for this algorithm."
)
# If the action is not an integer, convert it to a binary discrete action
discrete_actions = np.greater(action[:, counter], 0.0)
discrete_actions = discrete_actions.astype(np.float32)
original_action.append(discrete_actions)
counter += 1
else:
raise NotImplementedError
return original_action
def can_import(module_name):
return not cant_import(module_name)
def cant_import(module_name):
try:
importlib.import_module(module_name)
return False
except ImportError:
return True

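ActionSpaceProcessor flattens a tuple of per-key action spaces into one space that single-head algorithms can sample, and to_original_dist maps batched flat actions back to the original layout. A small sketch of the round trip, assuming an illustrative action space of one 2-dimensional Box plus one binary Discrete:

import numpy as np
from gymnasium import spaces
from godot_rl.core.utils import ActionSpaceProcessor

# Illustrative tuple action space: 2 continuous dims plus one binary discrete action.
original = spaces.Tuple([spaces.Box(-1.0, 1.0, shape=(2,)), spaces.Discrete(2)])
processor = ActionSpaceProcessor(original, True)
print(processor.action_space)  # Box(-1.0, 1.0, (3,), float32)

# A batch of 4 flat continuous actions, as a continuous-control algorithm would produce.
flat = np.random.uniform(-1.0, 1.0, size=(4, 3)).astype(np.float32)
box_part, discrete_part = processor.to_original_dist(flat)
print(box_part.shape)   # (4, 2): continuous part, passed through unchanged
print(discrete_part)    # (4,): 1.0 where the third dim is > 0, else 0.0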
@@ -0,0 +1,36 @@
# we download examples from github and we save them in the examples folder
import os
import shutil
from zipfile import ZipFile
import wget
BRANCHES = {"4": "main", "3": "godot3.5"}
BASE_URL = "https://github.com/edbeeching/godot_rl_agents_examples"
def download_examples():
# select branch
print("Select Godot version:")
for key in BRANCHES.keys():
print(f"{key} : {BRANCHES[key]}")
branch = input("Enter your choice: ")
BRANCH = BRANCHES[branch]
os.makedirs("examples", exist_ok=True)
URL = f"{BASE_URL}/archive/refs/heads/{BRANCH}.zip"
print(f"downloading examples from {URL}")
wget.download(URL, out="")
print()
print("unzipping")
with ZipFile(f"{BRANCH}.zip", "r") as zipObj:
# Extract all the contents of zip file in different directory
zipObj.extractall("examples/")
print("cleaning up")
os.remove(f"{BRANCH}.zip")
print("moving files")
for file in os.listdir(f"examples/godot_rl_agents_examples-{BRANCH}"):
shutil.move(f"examples/godot_rl_agents_examples-{BRANCH}/{file}", "examples")
os.rmdir(f"examples/godot_rl_agents_examples-{BRANCH}")

@@ -0,0 +1,57 @@
import os
from sys import platform
from zipfile import ZipFile
import wget
BASE_URL = "https://downloads.tuxfamily.org/godotengine/"
VERSIONS = {"3": "3.5.1", "4": "4.0"}
MOST_RECENT_VERSION = "rc5"
def get_version():
while True:
version = input("Which Godot version do you want to download (3 or 4)? ")
if version in VERSIONS:
return version
print("Invalid version. Please enter 3 or 4.")
def download_editor():
version = get_version()
VERSION = VERSIONS[version]
NEW_BASE_URL = f"{BASE_URL}{VERSION}/{version}/"
NAME = "stable"
if VERSION == "4.0":
NEW_BASE_URL = f"{BASE_URL}{VERSION}/{MOST_RECENT_VERSION}/"
NAME = MOST_RECENT_VERSION
LINUX_FILENAME = f"Godot_v{VERSION}-{NAME}_linux.x86_64.zip"
if VERSION == "4.0":
MAC_FILENAME = f"Godot_v{VERSION}-{NAME}_macos.universal.zip"
else:
MAC_FILENAME = f"Godot_v{VERSION}-{NAME}_osx.universal.64.zip"
WINDOWS_FILENAME = f"Godot_v{VERSION}-{NAME}_win64.exe.zip"
os.makedirs("editor", exist_ok=True)
FILENAME = ""
if platform == "linux" or platform == "linux2":
FILENAME = LINUX_FILENAME
elif platform == "darwin":
FILENAME = MAC_FILENAME
elif platform == "win32" or platform == "win64":
FILENAME = WINDOWS_FILENAME
else:
raise NotImplementedError
URL = f"{NEW_BASE_URL}{FILENAME}"
print(f"downloading editor {FILENAME} for platform: {platform}")
wget.download(URL, out="")
print()
print("unzipping")
with ZipFile(FILENAME, "r") as zipObj:
# Extract all the contents of zip file in different directory
zipObj.extractall("editor/")
print("cleaning up")
os.remove(FILENAME)

@@ -0,0 +1,37 @@
import argparse
import os
from huggingface_hub import Repository
def load_from_hf(dir_path: str, repo_id: str):
temp = repo_id.split("/")
repo_name = temp[1]
local_dir = os.path.join(dir_path, repo_name)
Repository(local_dir, repo_id, repo_type="dataset")
print(f"The repository {repo_id} has been cloned to {local_dir}")
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"-r",
"--hf_repository",
help="Repo id of the dataset / environment repo from the Hugging Face Hub in the form user_name/repo_name",
type=str,
)
parser.add_argument(
"-d",
"--example_dir",
help="Local destination of the repository. Will save repo to examples/repo_name",
type=str,
default="./examples",
)
args = parser.parse_args()
load_from_hf(args.example_dir, args.hf_repository)
if __name__ == "__main__":
main()

@@ -0,0 +1,114 @@
"""
This is the main entrypoint to the Godot RL Agents interface
Example usage is best found in the documentation:
https://github.com/edbeeching/godot_rl_agents/blob/main/docs/EXAMPLE_ENVIRONMENTS.md
Hyperparameters and training algorithm can be defined in a .yaml file, see ppo_test.yaml as an example.
Interactive Training:
With the Godot editor open, type gdrl in the terminal to launch training and
then press PLAY in the Godot editor. Training can be stopped with CTRL+C or
by pressing STOP in the editor.
Training with an exported executable:
gdrl --env_path path/to/exported/executable --config_file path/to/yaml/file
"""
import argparse
try:
from godot_rl.wrappers.ray_wrapper import rllib_training
except ImportError as e:
error_message = str(e)
def rllib_training(args, extras):
print("Import error importing rllib. If you have not installed the package, try: pip install godot-rl[rllib]")
print("Otherwise try fixing the error.", error_message)
try:
from godot_rl.wrappers.stable_baselines_wrapper import stable_baselines_training
except ImportError as e:
error_message = str(e)
def stable_baselines_training(args, extras):
print("Import error importing sb3. If you have not installed the package, try: pip install godot-rl[sb3]")
print("Otherwise try fixing the error.", error_message)
try:
from godot_rl.wrappers.sample_factory_wrapper import sample_factory_enjoy, sample_factory_training
except ImportError as e:
error_message = str(e)
def sample_factory_training(args, extras):
print(
"Import error importing sample-factory If you have not installed the package, try: pip install godot-rl[sf]"
)
print("Otherwise try fixing the error.", error_message)
def get_args():
parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument(
"--trainer", default="sb3", choices=["sb3", "sf", "rllib"], type=str, help="framework to use (rllib, sf, sb3)"
)
parser.add_argument("--env_path", default=None, type=str, help="Godot binary to use")
parser.add_argument(
"--config_file", default="ppo_test.yaml", type=str, help="The yaml config file [only for rllib]"
)
parser.add_argument("--restore", default=None, type=str, help="the location of a checkpoint to restore from")
parser.add_argument("--eval", default=False, action="store_true", help="whether to eval the model")
parser.add_argument("--speedup", default=1, type=int, help="whether to speed up the physics in the env")
parser.add_argument("--export", default=False, action="store_true", help="wheter to export the model")
parser.add_argument("--num_gpus", default=None, type=int, help="Number of GPUs to use [only for rllib]")
parser.add_argument(
"--experiment_dir",
default=None,
type=str,
help="The name of the the experiment directory, in which the tensorboard logs are getting stored",
)
parser.add_argument(
"--experiment_name",
default="experiment",
type=str,
help="The name of the the experiment, which will be displayed in tensborboard",
)
parser.add_argument("--viz", default=False, action="store_true", help="Whether to visualize one process")
parser.add_argument("--seed", default=0, type=int, help="seed of the experiment")
args, extras = parser.parse_known_args()
if args.experiment_dir is None:
args.experiment_dir = f"logs/{args.trainer}"
if args.trainer == "sf" and args.env_path is None:
print("WARNING: the sample-factory intergration is not designed to run in interactive mode, export you game")
return args, extras
def main():
args, extras = get_args()
if args.trainer == "rllib":
training_function = rllib_training
elif args.trainer == "sb3":
training_function = stable_baselines_training
elif args.trainer == "sf":
if args.eval:
training_function = sample_factory_enjoy
else:
training_function = sample_factory_training
else:
raise NotImplementedError
training_function(args, extras)
if __name__ == "__main__":
main()

@@ -0,0 +1,97 @@
from typing import Any, Optional
import gymnasium as gym
import numpy as np
from numpy import ndarray
from godot_rl.core.godot_env import GodotEnv
from godot_rl.core.utils import lod_to_dol
class CleanRLGodotEnv:
def __init__(self, env_path: Optional[str] = None, n_parallel: int = 1, seed: int = 0, **kwargs: object) -> None:
# If we are doing editor training, n_parallel must be 1
if env_path is None and n_parallel > 1:
raise ValueError("You must provide the path to a exported game executable if n_parallel > 1")
# Define the default port
port = kwargs.pop("port", GodotEnv.DEFAULT_PORT)
# Create a list of GodotEnv instances
self.envs = [
GodotEnv(env_path=env_path, convert_action_space=True, port=port + p, seed=seed + p, **kwargs)
for p in range(n_parallel)
]
# Store the number of parallel environments
self.n_parallel = n_parallel
def _check_valid_action_space(self) -> None:
# Check if the action space is a tuple space with multiple spaces
action_space = self.envs[0].action_space
if isinstance(action_space, gym.spaces.Tuple):
assert (
len(action_space.spaces) == 1
), f"sb3 supports a single action space, this env contains multiple spaces {action_space}"
def step(self, action: np.ndarray) -> tuple[ndarray, list[Any], list[Any], list[Any], list[Any]]:
# Initialize lists for collecting results
all_obs = []
all_rewards = []
all_term = []
all_trunc = []
all_info = []
# Get the number of environments
num_envs = self.envs[0].num_envs
# Send actions to each environment
for i in range(self.n_parallel):
self.envs[i].step_send(action[i * num_envs : (i + 1) * num_envs])
# Receive results from each environment
for i in range(self.n_parallel):
obs, reward, term, trunc, info = self.envs[i].step_recv()
all_obs.extend(obs)
all_rewards.extend(reward)
all_term.extend(term)
all_trunc.extend(trunc)
all_info.extend(info)
# Convert list of dictionaries to dictionary of lists
obs = lod_to_dol(all_obs)
# Return results
return np.stack(obs["obs"]), all_rewards, all_term, all_trunc, all_info
def reset(self, seed) -> tuple[ndarray, list[Any]]:
# Initialize lists for collecting results
all_obs = []
all_info = []
# Reset each environment
for i in range(self.n_parallel):
obs, info = self.envs[i].reset()
all_obs.extend(obs)
all_info.extend(info)
# Convert list of dictionaries to dictionary of lists
obs = lod_to_dol(all_obs)
return np.stack(obs["obs"]), all_info
@property
def single_observation_space(self):
return self.envs[0].observation_space["obs"]
@property
def single_action_space(self):
return self.envs[0].action_space
@property
def num_envs(self) -> int:
return self.envs[0].num_envs * self.n_parallel
def close(self) -> None:
# Close each environment
for env in self.envs:
env.close()

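CleanRLGodotEnv fans actions out over n_parallel Godot processes and returns observations stacked under the "obs" key, so a training loop sees one flat vector env. A hedged interaction sketch (the executable path, speedup, and step count are placeholders):

import numpy as np

env = CleanRLGodotEnv(env_path="path/to/game.x86_64", n_parallel=2, speedup=8)
obs, info = env.reset(seed=0)
for _ in range(10):
    # One flat action per agent across all parallel processes.
    actions = np.stack([env.single_action_space.sample() for _ in range(env.num_envs)])
    obs, rewards, terms, truncs, infos = env.step(actions)
env.close()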
@@ -0,0 +1,110 @@
import torch
from gymnasium.vector.utils import spaces
from stable_baselines3 import PPO
class OnnxableMultiInputPolicy(torch.nn.Module):
def __init__(
self,
obs_keys,
features_extractor,
mlp_extractor,
action_net,
value_net,
use_obs_array,
):
super().__init__()
self.obs_keys = obs_keys
self.features_extractor = features_extractor
self.mlp_extractor = mlp_extractor
self.action_net = action_net
self.value_net = value_net
self.use_obs_array = use_obs_array
def forward(self, obs, state_ins):
# NOTE: You may have to process (normalize) observations in the correct
# way before using this. See `common.preprocessing.preprocess_obs`
features = None
if self.use_obs_array:
features = self.features_extractor(obs)
else:
obs_dict = {k: v for k, v in zip(self.obs_keys, obs)}
features = self.features_extractor(obs_dict)
action_hidden, value_hidden = self.mlp_extractor(features)
return self.action_net(action_hidden), state_ins
def export_ppo_model_as_onnx(ppo: PPO, onnx_model_path: str, use_obs_array: bool = False):
ppo_policy = ppo.policy.to("cpu")
onnxable_model = OnnxableMultiInputPolicy(
["obs"],
ppo_policy.features_extractor,
ppo_policy.mlp_extractor,
ppo_policy.action_net,
ppo_policy.value_net,
use_obs_array,
)
if use_obs_array:
dummy_input = torch.unsqueeze(torch.tensor(ppo.observation_space.sample()), 0)
else:
dummy_input = dict(ppo.observation_space.sample())
for k, v in dummy_input.items():
dummy_input[k] = torch.from_numpy(v).unsqueeze(0)
dummy_input = [v for v in dummy_input.values()]
torch.onnx.export(
onnxable_model,
args=(dummy_input, torch.zeros(1).float()),
f=onnx_model_path,
opset_version=9,
input_names=["obs", "state_ins"],
output_names=["output", "state_outs"],
dynamic_axes={
"obs": {0: "batch_size"},
"state_ins": {0: "batch_size"}, # variable length axes
"output": {0: "batch_size"},
"state_outs": {0: "batch_size"},
},
)
# If the space is MultiDiscrete, we skip verifying as action output will have an expected mismatch
# (the output from onnx will be the action logits for each discrete action,
# while the output from sb3 will be a single int)
if not isinstance(ppo.action_space, spaces.MultiDiscrete):
verify_onnx_export(ppo, onnx_model_path, use_obs_array=use_obs_array)
def verify_onnx_export(ppo: PPO, onnx_model_path: str, num_tests=10, use_obs_array: bool = False):
import numpy as np
import onnx
import onnxruntime as ort
onnx_model = onnx.load(onnx_model_path)
onnx.checker.check_model(onnx_model)
sb3_model = ppo.policy.to("cpu")
ort_sess = ort.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"])
for i in range(num_tests):
obs = None
obs2 = None
if use_obs_array:
obs = np.expand_dims(ppo.observation_space.sample(), axis=0)
obs2 = torch.tensor(obs)
else:
obs = dict(ppo.observation_space.sample())
obs2 = {}
for k, v in obs.items():
obs2[k] = torch.from_numpy(v).unsqueeze(0)
obs = [v for v in obs.values()]
with torch.no_grad():
action_sb3, _, _ = sb3_model(obs2, deterministic=True)
action_onnx, state_outs = ort_sess.run(None, {"obs": obs, "state_ins": np.array([0.0], dtype=np.float32)})
assert np.allclose(action_sb3, action_onnx, atol=1e-5), "Mismatch in action output"
assert np.allclose(state_outs, np.array([0.0]), atol=1e-5), "Mismatch in state_outs output"

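export_ppo_model_as_onnx wraps a trained SB3 PPO policy in OnnxableMultiInputPolicy and exports it with a state_ins passthrough, then verifies the ONNX output against the SB3 policy (skipped for MultiDiscrete action spaces). A hedged sketch of exporting a briefly trained model; the env path is a placeholder and StableBaselinesGodotEnv is the wrapper defined later in this commit:

from stable_baselines3 import PPO
from godot_rl.wrappers.stable_baselines_wrapper import StableBaselinesGodotEnv

# Placeholder path to an exported Godot RL Agents game.
env = StableBaselinesGodotEnv(env_path="path/to/game.x86_64", speedup=8)
model = PPO("MultiInputPolicy", env, n_steps=32, verbose=1)
model.learn(2048)
env.close()

# export_ppo_model_as_onnx is the function defined above; the output file name is arbitrary.
export_ppo_model_as_onnx(model, "model.onnx")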
@@ -0,0 +1,147 @@
# PettingZoo wrapper for GDRL
# Multi-agent, where 1 agent corresponds to one AIController instance in Godot
# Based on https://pettingzoo.farama.org/content/environment_creation/#example-custom-parallel-environment
# https://github.com/Farama-Foundation/PettingZoo/?tab=License-1-ov-file#readme
# and adjusted to work with GodotRL and Rllib (made for and tested only with rllib for now)
import functools
from typing import Dict
import numpy as np
from pettingzoo import ParallelEnv
from godot_rl.core.godot_env import GodotEnv
def env(render_mode=None):
"""
The env function often wraps the environment in wrappers by default.
You can find full documentation for these methods
elsewhere in the developer documentation.
"""
# Not implemented
return env
class GDRLPettingZooEnv(ParallelEnv):
metadata = {"render_modes": ["human"], "name": "GDRLPettingZooEnv"}
def __init__(self, port=GodotEnv.DEFAULT_PORT, show_window=True, seed=0, config: Dict = {}):
"""
The init method takes in environment arguments and should define the following attributes:
- possible_agents
- render_mode
Note: as of v1.18.1, the action_spaces and observation_spaces attributes are deprecated.
Spaces should be defined in the action_space() and observation_space() methods.
If these methods are not overridden, spaces will be inferred from self.observation_spaces/action_spaces, raising a warning.
These attributes should not be changed after initialization.
"""
# Initialize the Godot Env which we will wrap
self.godot_env = GodotEnv(
env_path=config.get("env_path"),
show_window=config.get("show_window"),
action_repeat=config.get("action_repeat"),
speedup=config.get("speedup"),
convert_action_space=False,
seed=seed,
port=port,
)
self.render_mode = None # Controlled by the env
self.possible_agents = [agent_idx for agent_idx in range(self.godot_env.num_envs)]
self.agents = self.possible_agents[:]
# The policy names here are set on each AIController in Godot editor,
# used to map agents to policies for multi-policy training.
self.agent_policy_names = self.godot_env.agent_policy_names
# optional: a mapping between agent name and ID
self.agent_name_mapping = dict(zip(self.possible_agents, list(range(len(self.possible_agents)))))
self.observation_spaces = {
agent: self.godot_env.observation_spaces[agent_idx] for agent_idx, agent in enumerate(self.agents)
}
self.action_spaces = {
agent: self.godot_env.tuple_action_spaces[agent_idx] for agent_idx, agent in enumerate(self.agents)
}
# Observation space should be defined here.
# lru_cache allows observation and action spaces to be memoized, reducing clock cycles required to get each agent's space.
# If your spaces change over time, remove this line (disable caching).
@functools.lru_cache(maxsize=None)
def observation_space(self, agent):
return self.observation_spaces[agent]
# Action space should be defined here.
# If your spaces change over time, remove this line (disable caching).
@functools.lru_cache(maxsize=None)
def action_space(self, agent):
return self.action_spaces[agent]
def render(self):
"""
Renders the environment. In human mode, it can print to terminal, open
up a graphical window, or open up some other display that a human can see and understand.
"""
# Not implemented
def close(self):
"""
Close should release any graphical displays, subprocesses, network connections
or any other environment data which should not be kept around after the
user is no longer using the environment.
"""
self.godot_env.close()
def reset(self, seed=None, options=None):
"""
Reset needs to initialize the `agents` attribute and must set up the
environment so that render(), and step() can be called without issues.
Returns the observations for each agent
"""
godot_obs, godot_infos = self.godot_env.reset()
observations = {agent: godot_obs[agent_idx] for agent_idx, agent in enumerate(self.agents)}
infos = {agent: godot_infos[agent_idx] for agent_idx, agent in enumerate(self.agents)}
return observations, infos
def step(self, actions):
"""
step(action) takes in an action for each agent and should return the
- observations
- rewards
- terminations
- truncations
- infos
dicts where each dict looks like {agent_1: item_1, agent_2: item_2}
"""
# Once an agent (AIController) has done = true, it will not receive any more actions until all agents in the
# Godot env have done = true. For agents that received no actions, we will set zeros instead for
# compatibility.
godot_actions = [
actions[agent] if agent in actions else np.zeros_like(self.action_spaces[agent_idx].sample())
for agent_idx, agent in enumerate(self.agents)
]
godot_obs, godot_rewards, godot_dones, godot_truncations, godot_infos = self.godot_env.step(
godot_actions, order_ij=True
)
observations = {agent: godot_obs[agent] for agent in actions}
rewards = {agent: godot_rewards[agent] for agent in actions}
terminations = {agent: godot_dones[agent] for agent in actions}
# Truncations are not yet implemented in GDRL API
truncations = {agent: False for agent in actions}
# typically there won't be any information in the infos, but there must
# still be an entry for each agent
infos = {agent: godot_infos[agent] for agent in actions}
return observations, rewards, terminations, truncations, infos

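GDRLPettingZooEnv exposes the Godot env through the PettingZoo ParallelEnv API: agents are integer indices and reset/step exchange per-agent dicts. A short hedged loop (config values are placeholders; each per-agent action space is a gymnasium Tuple space):

env = GDRLPettingZooEnv(
    config={"env_path": "path/to/game.x86_64", "show_window": False, "action_repeat": None, "speedup": 8}
)
observations, infos = env.reset()
for _ in range(10):
    # One action per live agent, sampled from that agent's Tuple action space.
    actions = {agent: env.action_space(agent).sample() for agent in env.agents}
    observations, rewards, terminations, truncations, infos = env.step(actions)
env.close()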
@@ -0,0 +1,139 @@
import os
import pathlib
from typing import List, Optional, Tuple
import numpy as np
import ray
import yaml
from ray import tune
from ray.rllib.env.vector_env import VectorEnv
from ray.rllib.utils.typing import EnvActionType, EnvInfoDict, EnvObsType
from godot_rl.core.godot_env import GodotEnv
class RayVectorGodotEnv(VectorEnv):
def __init__(
self,
port=10008,
seed=0,
config=None,
) -> None:
self._env = GodotEnv(
env_path=config["env_path"],
port=port,
seed=seed,
show_window=config["show_window"],
action_repeat=config["action_repeat"],
speedup=config["speedup"],
)
super().__init__(
observation_space=self._env.observation_space,
action_space=self._env.action_space,
num_envs=self._env.num_envs,
)
def vector_reset(
self, *, seeds: Optional[List[int]] = None, options: Optional[List[dict]] = None
) -> List[EnvObsType]:
self.obs, info = self._env.reset()
return self.obs, info
def vector_step(
self, actions: List[EnvActionType]
) -> Tuple[List[EnvObsType], List[float], List[bool], List[EnvInfoDict]]:
actions = np.array(actions, dtype=np.dtype(object))
self.obs, reward, term, trunc, info = self._env.step(actions, order_ij=True)
return self.obs, reward, term, trunc, info
def get_unwrapped(self):
return [self._env]
def reset_at(
self,
index: Optional[int] = None,
*,
seed: Optional[int] = None,
options: Optional[dict] = None,
) -> EnvObsType:
# the env is reset automatically, no need to reset it
return self.obs[index], {}
def register_env():
tune.register_env(
"godot",
lambda c: RayVectorGodotEnv(
config=c,
port=c.worker_index + GodotEnv.DEFAULT_PORT + 10,
seed=c.worker_index + c["seed"],
),
)
# Refactored section: Commented onnx section was removed as it was re-implemented in rllib_example.py
def rllib_training(args, extras):
with open(args.config_file) as f:
exp = yaml.safe_load(f)
register_env()
exp["config"]["env_config"]["env_path"] = args.env_path
exp["config"]["env_config"]["seed"] = args.seed
if args.env_path is not None:
run_name = exp["algorithm"] + "/" + pathlib.Path(args.env_path).stem
else:
run_name = exp["algorithm"] + "/editor"
print("run_name", run_name)
if args.num_gpus is not None:
exp["config"]["num_gpus"] = args.num_gpus
if args.env_path is None:
print("SETTING WORKERS TO 1")
exp["config"]["num_workers"] = 1
checkpoint_freq = 10
exp["config"]["env_config"]["show_window"] = args.viz
exp["config"]["env_config"]["speedup"] = args.speedup
if args.eval or args.export:
checkpoint_freq = 0
exp["config"]["env_config"]["show_window"] = True
exp["config"]["env_config"]["framerate"] = None
exp["config"]["lr"] = 0.0
exp["config"]["num_sgd_iter"] = 1
exp["config"]["num_workers"] = 1
exp["config"]["train_batch_size"] = 8192
exp["config"]["sgd_minibatch_size"] = 128
exp["config"]["explore"] = False
exp["stop"]["training_iteration"] = 999999
print(exp)
ray.init(num_gpus=exp["config"]["num_gpus"] or 1)
if not args.export:
tune.run(
exp["algorithm"],
name=run_name,
config=exp["config"],
stop=exp["stop"],
verbose=3,
checkpoint_freq=checkpoint_freq,
checkpoint_at_end=not args.eval,
restore=args.restore,
storage_path=os.path.abspath(args.experiment_dir) or os.path.abspath("logs/rllib"),
trial_name_creator=lambda trial: (
f"{args.experiment_name}" if args.experiment_name else f"{trial.trainable_name}_{trial.trial_id}"
),
)
if args.export:
raise NotImplementedError("Use examples/rllib_example.py to export to onnx.")
# rllib_export(args.restore)
ray.shutdown()

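rllib_training expects the yaml file passed via --config_file to provide an algorithm name, a stop section, and an rllib config section whose env_config block it then overwrites with command-line values. A minimal illustrative config, inferred from the keys accessed above (all values are placeholders, not a recommended configuration):

import yaml

EXAMPLE_YAML = """
algorithm: PPO
stop:
  training_iteration: 200
config:
  env: godot
  num_workers: 4
  num_gpus: 0
  env_config:
    env_path: null      # overwritten from --env_path
    seed: 0             # overwritten from --seed
    show_window: false  # overwritten from --viz
    speedup: 8          # overwritten from --speedup
    action_repeat: null
"""
exp = yaml.safe_load(EXAMPLE_YAML)
assert {"algorithm", "config", "stop"} <= exp.keys()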
@@ -0,0 +1,192 @@
import argparse
from functools import partial
import numpy as np
from gymnasium import Env
from sample_factory.cfg.arguments import parse_full_cfg, parse_sf_args
from sample_factory.enjoy import enjoy
from sample_factory.envs.env_utils import register_env
from sample_factory.train import run_rl
from godot_rl.core.godot_env import GodotEnv
from godot_rl.core.utils import lod_to_dol
class SampleFactoryEnvWrapperBatched(GodotEnv, Env):
@property
def unwrapped(self):
return self
@property
def num_agents(self):
return self.num_envs
def reset(self, seed=None, options=None):
obs, info = super().reset(seed=seed)
obs = lod_to_dol(obs)
return {k: np.array(v) for k, v in obs.items()}, info
def step(self, action):
obs, reward, term, trunc, info = super().step(action, order_ij=False)
obs = lod_to_dol(obs)
return {k: np.array(v) for k, v in obs.items()}, np.array(reward), np.array(term), np.array(trunc) * 0, info
@staticmethod
def to_numpy(lod):
for d in lod:
for k, v in d.items():
d[k] = np.array(v)
return lod
def render():
return
class SampleFactoryEnvWrapperNonBatched(GodotEnv, Env):
@property
def unwrapped(self):
return self
@property
def num_agents(self):
return self.num_envs
def reset(self, seed=None, options=None):
obs, info = super().reset(seed=seed)
return self.to_numpy(obs), info
def step(self, action):
obs, reward, term, trunc, info = super().step(action, order_ij=True)
return self.to_numpy(obs), np.array(reward), np.array(term), np.array(trunc) * 0, info
@staticmethod
def to_numpy(lod):
for d in lod:
for k, v in d.items():
d[k] = np.array(v)
return lod
def render():
return
def make_godot_env_func(
env_path, full_env_name, cfg=None, env_config=None, render_mode=None, seed=0, speedup=1, viz=False
):
port = cfg.base_port
print("BASE PORT ", cfg.base_port)
show_window = False
_seed = seed
if env_config:
port += 1 + env_config.env_id
_seed += 1 + env_config.env_id
print("env id", env_config.env_id)
if viz: #
print("creating viz env")
show_window = env_config.env_id == 0
if cfg.batched_sampling:
env = SampleFactoryEnvWrapperBatched(
env_path=env_path, port=port, seed=_seed, show_window=show_window, speedup=speedup
)
else:
env = SampleFactoryEnvWrapperNonBatched(
env_path=env_path, port=port, seed=_seed, show_window=show_window, speedup=speedup
)
return env
def register_gdrl_env(args):
make_env = partial(make_godot_env_func, args.env_path, speedup=args.speedup, seed=args.seed, viz=args.viz)
register_env("gdrl", make_env)
def gdrl_override_defaults(_env, parser):
"""RL params specific to Atari envs."""
parser.set_defaults(
# let's set this to True by default so it's consistent with how we report results for other envs
# (i.e. VizDoom or DMLab). When running evaluations for reports or to compare with other frameworks we can
# set this to false in command line
summaries_use_frameskip=True,
use_record_episode_statistics=True,
gamma=0.99,
env_frameskip=1,
env_framestack=4,
num_workers=1,
num_envs_per_worker=2,
worker_num_splits=2,
env_agents=16,
train_for_env_steps=1000000,
nonlinearity="relu",
kl_loss_coeff=0.0,
use_rnn=False,
adaptive_stddev=True,
reward_scale=1.0,
with_vtrace=False,
recurrence=1,
batch_size=2048,
rollout=32,
max_grad_norm=0.5,
num_epochs=2,
num_batches_per_epoch=4,
ppo_clip_ratio=0.2,
value_loss_coeff=0.5,
exploration_loss="entropy",
exploration_loss_coeff=0.000,
learning_rate=0.00025,
lr_schedule="linear_decay",
shuffle_minibatches=False,
gae_lambda=0.95,
batched_sampling=False,
normalize_input=True,
normalize_returns=True,
serial_mode=False,
async_rl=True,
experiment_summaries_interval=3,
adam_eps=1e-5,
)
def add_gdrl_env_args(_env, p: argparse.ArgumentParser, evaluation=False):
if evaluation:
# apparently env.render(mode="human") is not supported anymore and we need to specify the render mode in
# the env actor
p.add_argument("--render_mode", default="human", type=str, help="")
p.add_argument("--base_port", default=GodotEnv.DEFAULT_PORT, type=int, help="")
p.add_argument(
"--env_agents",
default=2,
type=int,
help="Num agents in each envpool (if used)",
)
def parse_gdrl_args(args, argv=None, evaluation=False):
parser, partial_cfg = parse_sf_args(argv=argv, evaluation=evaluation)
add_gdrl_env_args(partial_cfg.env, parser, evaluation=evaluation)
gdrl_override_defaults(partial_cfg.env, parser)
final_cfg = parse_full_cfg(parser, argv)
final_cfg.train_dir = args.experiment_dir or "logs/sf"
final_cfg.experiment = args.experiment_name or final_cfg.experiment
return final_cfg
def sample_factory_training(args, extras):
register_gdrl_env(args)
cfg = parse_gdrl_args(args=args, argv=extras, evaluation=args.eval)
# cfg.base_port = random.randint(20000, 22000)
status = run_rl(cfg)
return status
def sample_factory_enjoy(args, extras):
register_gdrl_env(args)
cfg = parse_gdrl_args(args=args, argv=extras, evaluation=args.eval)
status = enjoy(cfg)
return status

View File

@ -0,0 +1,33 @@
from typing import Any, Dict, List, Tuple
import gymnasium as gym
import numpy as np
from godot_rl.wrappers.stable_baselines_wrapper import StableBaselinesGodotEnv
# A variant of the Stable Baselines Godot Env that only supports a single obs space from the dictionary - obs["obs"] by default.
# This provides some basic support for using envs that have a single obs space with policies other than MultiInputPolicy.
class SBGSingleObsEnv(StableBaselinesGodotEnv):
def __init__(self, obs_key="obs", *args, **kwargs) -> None:
self.obs_key = obs_key
super().__init__(*args, **kwargs)
def step(self, action: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray, List[Dict[str, Any]]]:
obs, rewards, term, info = super().step(action)
# Terminal obs info is needed for imitation learning
for idx, done in enumerate(term):
if done:
info[idx]["terminal_observation"] = obs[self.obs_key][idx]
return obs[self.obs_key], rewards, term, info
def reset(self) -> np.ndarray:
obs = super().reset()
return obs[self.obs_key]
@property
def observation_space(self) -> gym.Space:
return self.envs[0].observation_space[self.obs_key]

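Because SBGSingleObsEnv collapses the observation dict down to a single space, it can be used with SB3 policies other than MultiInputPolicy. A hedged sketch, assuming an exported game whose AIController exposes a single flat "obs" space (the path is a placeholder):

from stable_baselines3 import PPO
from stable_baselines3.common.vec_env.vec_monitor import VecMonitor

env = SBGSingleObsEnv(env_path="path/to/game.x86_64", show_window=False, speedup=8)
env = VecMonitor(env)
model = PPO("MlpPolicy", env, n_steps=32, verbose=1)
model.learn(10_000)
env.close()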
@@ -0,0 +1,176 @@
from typing import Any, Dict, List, Optional, Tuple
import gymnasium as gym
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env.base_vec_env import VecEnv
from stable_baselines3.common.vec_env.vec_monitor import VecMonitor
from godot_rl.core.godot_env import GodotEnv
from godot_rl.core.utils import can_import, lod_to_dol
class StableBaselinesGodotEnv(VecEnv):
def __init__(
self,
env_path: Optional[str] = None,
n_parallel: int = 1,
seed: int = 0,
**kwargs,
) -> None:
# If we are doing editor training, n_parallel must be 1
if env_path is None and n_parallel > 1:
raise ValueError("You must provide the path to a exported game executable if n_parallel > 1")
# Define the default port
port = kwargs.pop("port", GodotEnv.DEFAULT_PORT)
# Create a list of GodotEnv instances
self.envs = [
GodotEnv(
env_path=env_path,
convert_action_space=True,
port=port + p,
seed=seed + p,
**kwargs,
)
for p in range(n_parallel)
]
# Store the number of parallel environments
self.n_parallel = n_parallel
# Check the action space for validity
self._check_valid_action_space()
# Initialize the results holder
self.results = None
def _check_valid_action_space(self) -> None:
# Check if the action space is a tuple space with multiple spaces
action_space = self.envs[0].action_space
if isinstance(action_space, gym.spaces.Tuple):
assert (
len(action_space.spaces) == 1
), f"sb3 supports a single action space, this env contains multiple spaces {action_space}"
def step(self, action: np.ndarray) -> Tuple[Dict[str, np.ndarray], np.ndarray, np.ndarray, List[Dict[str, Any]]]:
# Initialize lists for collecting results
all_obs = []
all_rewards = []
all_term = []
all_trunc = []
all_info = []
# Get the number of environments
num_envs = self.envs[0].num_envs
# Send actions to each environment
for i in range(self.n_parallel):
self.envs[i].step_send(action[i * num_envs : (i + 1) * num_envs])
# Receive results from each environment
for i in range(self.n_parallel):
obs, reward, term, trunc, info = self.envs[i].step_recv()
all_obs.extend(obs)
all_rewards.extend(reward)
all_term.extend(term)
all_trunc.extend(trunc)
all_info.extend(info)
# Convert list of dictionaries to dictionary of lists
obs = lod_to_dol(all_obs)
# Return results
return (
{k: np.array(v) for k, v in obs.items()},
np.array(all_rewards, dtype=np.float32),
np.array(all_term),
all_info,
)
def reset(self) -> Dict[str, np.ndarray]:
# Initialize lists for collecting results
all_obs = []
all_info = []
# Reset each environment
for i in range(self.n_parallel):
obs, info = self.envs[i].reset()
all_obs.extend(obs)
all_info.extend(info)
# Convert list of dictionaries to dictionary of lists
obs = lod_to_dol(all_obs)
return {k: np.array(v) for k, v in obs.items()}
def close(self) -> None:
# Close each environment
for env in self.envs:
env.close()
@property
def observation_space(self) -> gym.Space:
return self.envs[0].observation_space
@property
def action_space(self) -> gym.Space:
# sb3 is not compatible with tuple/dict action spaces
return self.envs[0].action_space
@property
def num_envs(self) -> int:
return self.envs[0].num_envs * self.n_parallel
def env_is_wrapped(self, wrapper_class: type, indices: Optional[List[int]] = None) -> List[bool]:
# Return a list indicating that no environments are wrapped
return [False] * (self.envs[0].num_envs * self.n_parallel)
# Placeholder methods that should be implemented for a full VecEnv implementation
def env_method(self):
raise NotImplementedError()
def get_attr(self, attr_name: str, indices=None) -> List[Any]:
if attr_name == "render_mode":
return [None for _ in range(self.num_envs)]
raise AttributeError("get attr not fully implemented in godot-rl StableBaselinesWrapper")
def seed(self, seed=None):
raise NotImplementedError()
def set_attr(self):
raise NotImplementedError()
def step_async(self, actions: np.ndarray) -> None:
# Execute the step function asynchronously, not actually implemented in this setting
self.results = self.step(actions)
def step_wait(
self,
) -> Tuple[Dict[str, np.ndarray], np.ndarray, np.ndarray, List[Dict[str, Any]]]:
# Wait for the results from the asynchronous step
return self.results
def stable_baselines_training(args, extras, n_steps: int = 200000, **kwargs) -> None:
if can_import("ray"):
print("WARNING, stable baselines and ray[rllib] are not compatible")
# Initialize the custom environment
env = StableBaselinesGodotEnv(env_path=args.env_path, show_window=args.viz, speedup=args.speedup, **kwargs)
env = VecMonitor(env)
# Initialize the PPO model
model = PPO(
"MultiInputPolicy",
env,
ent_coef=0.0001,
verbose=2,
n_steps=32,
tensorboard_log=args.experiment_dir or "logs/sb3",
)
# Train the model
model.learn(n_steps, tb_log_name=args.experiment_name)
print("closing env")
env.close()