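"""Ray RLlib wrapper and training entry point for Godot RL Agents.

Wraps a single GodotEnv (which exposes num_envs parallel agents) as an RLlib
VectorEnv, registers it with Tune under the env id "godot", and provides
rllib_training() to launch a tune.run() experiment from a yaml config file.
"""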
import os
import pathlib
from typing import List, Optional, Tuple
import numpy as np
import ray
import yaml
from ray import tune
from ray.rllib.env.vector_env import VectorEnv
from ray.rllib.utils.typing import EnvActionType, EnvInfoDict, EnvObsType
from godot_rl.core.godot_env import GodotEnv


class RayVectorGodotEnv(VectorEnv):
    """Exposes one GodotEnv (which steps num_envs agents in parallel) as an RLlib VectorEnv."""

    def __init__(
        self,
        port=10008,
        seed=0,
        config=None,
    ) -> None:
        self._env = GodotEnv(
            env_path=config["env_path"],
            port=port,
            seed=seed,
            show_window=config["show_window"],
            action_repeat=config["action_repeat"],
            speedup=config["speedup"],
        )
        super().__init__(
            observation_space=self._env.observation_space,
            action_space=self._env.action_space,
            num_envs=self._env.num_envs,
        )

    def vector_reset(
        self, *, seeds: Optional[List[int]] = None, options: Optional[List[dict]] = None
    ) -> Tuple[List[EnvObsType], List[EnvInfoDict]]:
        # A single GodotEnv.reset() restarts all sub-envs at once.
        self.obs, info = self._env.reset()
        return self.obs, info

    def vector_step(
        self, actions: List[EnvActionType]
    ) -> Tuple[List[EnvObsType], List[float], List[bool], List[bool], List[EnvInfoDict]]:
        # An object-dtype array keeps each env's action intact (e.g. for tuple
        # or dict action spaces) rather than stacking them into one numeric array.
        actions = np.array(actions, dtype=np.dtype(object))
        self.obs, reward, term, trunc, info = self._env.step(actions, order_ij=True)
        return self.obs, reward, term, trunc, info

    def get_unwrapped(self):
        return [self._env]

    def reset_at(
        self,
        index: Optional[int] = None,
        *,
        seed: Optional[int] = None,
        options: Optional[dict] = None,
    ) -> Tuple[EnvObsType, EnvInfoDict]:
        # The sub-envs reset themselves on the Godot side automatically, so no
        # explicit reset is needed; return the last observation for this index.
        return self.obs[index], {}
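

# Usage sketch (hypothetical values, shown for illustration only): the config
# keys mirror exactly what RayVectorGodotEnv.__init__ reads above.
#
#   env = RayVectorGodotEnv(
#       port=GodotEnv.DEFAULT_PORT,
#       seed=0,
#       config={
#           "env_path": "envs/example.x86_64",  # assumed path
#           "show_window": False,
#           "action_repeat": 8,
#           "speedup": 8,
#       },
#   )
#   obs, infos = env.vector_reset()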


def register_env():
    tune.register_env(
        "godot",
        lambda c: RayVectorGodotEnv(
            config=c,
            port=c.worker_index + GodotEnv.DEFAULT_PORT + 10,
            seed=c.worker_index + c["seed"],
        ),
    )
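

# With the env registered under the id "godot", the yaml consumed by
# rllib_training() would look roughly like this sketch (the env_config keys
# match what RayVectorGodotEnv reads; the algorithm and numeric values here
# are illustrative assumptions, not taken from this file):
#
#   algorithm: PPO
#   stop:
#       training_iteration: 200
#   config:
#       env: godot
#       num_workers: 4
#       num_gpus: 0
#       env_config:
#           env_path: null        # overwritten from args.env_path
#           seed: 0               # overwritten from args.seed
#           show_window: false    # overwritten from args.viz
#           action_repeat: 8
#           speedup: 1            # overwritten from args.speedup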


# Refactored: the commented-out onnx export section was removed, as it was
# re-implemented in rllib_example.py.
def rllib_training(args, extras):
    with open(args.config_file) as f:
        exp = yaml.safe_load(f)
    register_env()

    exp["config"]["env_config"]["env_path"] = args.env_path
    exp["config"]["env_config"]["seed"] = args.seed

    if args.env_path is not None:
        run_name = exp["algorithm"] + "/" + pathlib.Path(args.env_path).stem
    else:
        run_name = exp["algorithm"] + "/editor"
    print("run_name", run_name)

    if args.num_gpus is not None:
        exp["config"]["num_gpus"] = args.num_gpus

    if args.env_path is None:
        # Only the editor instance is available, so a single worker must be used.
        print("SETTING WORKERS TO 1")
        exp["config"]["num_workers"] = 1

    checkpoint_freq = 10
    exp["config"]["env_config"]["show_window"] = args.viz
    exp["config"]["env_config"]["speedup"] = args.speedup

    if args.eval or args.export:
        # Evaluation/export mode: disable learning and run a single visible env.
        checkpoint_freq = 0
        exp["config"]["env_config"]["show_window"] = True
        exp["config"]["env_config"]["framerate"] = None
        exp["config"]["lr"] = 0.0
        exp["config"]["num_sgd_iter"] = 1
        exp["config"]["num_workers"] = 1
        exp["config"]["train_batch_size"] = 8192
        exp["config"]["sgd_minibatch_size"] = 128
        exp["config"]["explore"] = False
        exp["stop"]["training_iteration"] = 999999

    print(exp)

    # .get() avoids a KeyError when the yaml omits num_gpus and --num_gpus is unset.
    ray.init(num_gpus=exp["config"].get("num_gpus") or 1)

    if not args.export:
        tune.run(
            exp["algorithm"],
            name=run_name,
            config=exp["config"],
            stop=exp["stop"],
            verbose=3,
            checkpoint_freq=checkpoint_freq,
            checkpoint_at_end=not args.eval,
            restore=args.restore,
            # Fall back to logs/rllib when no experiment dir is given; the
            # original `os.path.abspath(args.experiment_dir) or ...` would
            # raise a TypeError whenever experiment_dir was None.
            storage_path=os.path.abspath(args.experiment_dir or "logs/rllib"),
            trial_name_creator=lambda trial: (
                f"{args.experiment_name}" if args.experiment_name else f"{trial.trainable_name}_{trial.trial_id}"
            ),
        )
    if args.export:
        raise NotImplementedError("Use examples/rllib_example.py to export to onnx.")
        # rllib_export(args.restore)
    ray.shutdown()
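

# Hypothetical entry point (a sketch, not part of the original file): the flag
# names below are inferred from the attributes rllib_training() reads off
# `args`; defaults and help strings are illustrative assumptions.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--config_file", type=str, required=True, help="path to the experiment yaml")
    parser.add_argument("--env_path", type=str, default=None, help="exported Godot binary (None = editor)")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--num_gpus", type=int, default=None)
    parser.add_argument("--viz", action="store_true", help="show the environment window while training")
    parser.add_argument("--speedup", type=int, default=1, help="physics speedup factor passed to Godot")
    parser.add_argument("--eval", action="store_true")
    parser.add_argument("--export", action="store_true")
    parser.add_argument("--restore", type=str, default=None, help="checkpoint to restore from")
    parser.add_argument("--experiment_dir", type=str, default=None)
    parser.add_argument("--experiment_name", type=str, default=None)
    args, extras = parser.parse_known_args()
    rllib_training(args, extras)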