
Commit bd555a0

[RLlib] New ConnectorV2 API #3: Introduce actual ConnectorV2 API. (#41074) (#41212)
1 parent e27ffa0 commit bd555a0

36 files changed (+1911, -71 lines)

rllib/BUILD

Lines changed: 16 additions & 1 deletion
@@ -747,7 +747,7 @@ py_test(
 
 
 # --------------------------------------------------------------------
-# Connector tests
+# Connector(V1) tests
 # rllib/connector/
 #
 # Tag: connector
@@ -774,6 +774,21 @@ py_test(
     srcs = ["connectors/tests/test_agent.py"]
 )
 
+# --------------------------------------------------------------------
+# ConnectorV2 tests
+# rllib/connector/
+#
+# Tag: connector_v2
+# --------------------------------------------------------------------
+
+# TODO (sven): Add these tests in a separate PR.
+# py_test(
+#     name = "connectors/tests/test_connector_v2",
+#     tags = ["team:rllib", "connector_v2"],
+#     size = "small",
+#     srcs = ["connectors/tests/test_connector_v2.py"]
+# )
+
 # --------------------------------------------------------------------
 # Env tests
 # rllib/env/

rllib/algorithms/algorithm.py

Lines changed: 9 additions & 9 deletions
@@ -564,6 +564,11 @@ def setup(self, config: AlgorithmConfig) -> None:
             config_obj.env = self._env_id
         self.config = config_obj
 
+        self._uses_new_env_runners = (
+            self.config.env_runner_cls is not None
+            and not issubclass(self.config.env_runner_cls, RolloutWorker)
+        )
+
         # Set Algorithm's seed after we have - if necessary - enabled
         # tf eager-execution.
         update_global_seed_if_necessary(self.config.framework_str, self.config.seed)
@@ -751,13 +756,12 @@ def setup(self, config: AlgorithmConfig) -> None:
             )
 
             # Only when using RolloutWorkers: Update also the worker set's
-            # `should_module_be_updated_fn` (analogous to is_policy_to_train).
+            # `is_policy_to_train` (analogous to LearnerGroup's
+            # `should_module_be_updated_fn`).
             # Note that with the new EnvRunner API in combination with the new stack,
             # this information only needs to be kept in the LearnerGroup and not on the
             # EnvRunners anymore.
-            if self.config.env_runner_cls is None or issubclass(
-                self.config.env_runner_cls, RolloutWorker
-            ):
+            if not self._uses_new_env_runners:
                 update_fn = self.learner_group.should_module_be_updated_fn
                 self.workers.foreach_worker(
                     lambda w: w.set_is_policy_to_train(update_fn),
@@ -3030,11 +3034,7 @@ def _run_one_evaluation(
         """
         eval_func_to_use = (
             self._evaluate_async_with_env_runner
-            if (
-                self.config.enable_async_evaluation
-                and self.config.env_runner_cls is not None
-                and not issubclass(self.config.env_runner_cls, RolloutWorker)
-            )
+            if (self.config.enable_async_evaluation and self._uses_new_env_runners)
             else self._evaluate_async
             if self.config.enable_async_evaluation
             else self.evaluate
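
The three hunks above serve one refactor: the "are we on the new EnvRunner stack?" check is computed once in `setup()` and cached as `self._uses_new_env_runners`, which then replaces the repeated `issubclass(..., RolloutWorker)` tests. A minimal sketch of the cached condition follows; the stand-in classes are illustrative only and not part of the commit:

# Sketch only: `OldStackRolloutWorker` and `NewStackEnvRunner` stand in for
# ray.rllib.evaluation.RolloutWorker and a new-stack EnvRunner subclass.
class OldStackRolloutWorker:
    pass

class NewStackEnvRunner:
    pass

def uses_new_env_runners(env_runner_cls) -> bool:
    # Same condition the commit caches into `Algorithm._uses_new_env_runners`.
    return env_runner_cls is not None and not issubclass(
        env_runner_cls, OldStackRolloutWorker
    )

assert uses_new_env_runners(NewStackEnvRunner)
assert not uses_new_env_runners(OldStackRolloutWorker)
assert not uses_new_env_runners(None)  # env_runner_cls unset -> old stack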

rllib/algorithms/algorithm_config.py

Lines changed: 143 additions & 0 deletions
@@ -99,8 +99,10 @@
 
 if TYPE_CHECKING:
     from ray.rllib.algorithms.algorithm import Algorithm
+    from ray.rllib.connectors.connector_v2 import ConnectorV2
     from ray.rllib.core.learner import Learner
     from ray.rllib.core.learner.learner_group import LearnerGroup
+    from ray.rllib.core.rl_module.rl_module import RLModule
     from ray.rllib.evaluation.episode import Episode as OldEpisode
 
 logger = logging.getLogger(__name__)
@@ -327,6 +329,8 @@ def __init__(self, algo_class=None):
         self.num_envs_per_worker = 1
         self.create_env_on_local_worker = False
         self.enable_connectors = True
+        self._env_to_module_connector = None
+        self._module_to_env_connector = None
         # TODO (sven): Rename into `sample_timesteps` (or `sample_duration`
         # and `sample_duration_unit` (replacing batch_mode), like we do it
         # in the evaluation config).
@@ -374,6 +378,7 @@ def __init__(self, algo_class=None):
         except AttributeError:
             pass
 
+        self._learner_connector = None
         self.optimizer = {}
         self.max_requests_in_flight_per_sampler_worker = 2
         self._learner_class = None
@@ -1152,6 +1157,121 @@ class directly. Note that this arg can also be specified via
             logger_creator=self.logger_creator,
         )
 
+    def build_env_to_module_connector(self, env):
+        from ray.rllib.connectors.env_to_module import (
+            EnvToModulePipeline,
+            DefaultEnvToModule,
+        )
+
+        custom_connectors = []
+        # Create an env-to-module connector pipeline (including RLlib's default
+        # env->module connector piece) and return it.
+        if self._env_to_module_connector is not None:
+            val_ = self._env_to_module_connector(env)
+
+            from ray.rllib.connectors.connector_v2 import ConnectorV2
+
+            if isinstance(val_, ConnectorV2) and not isinstance(
+                val_, EnvToModulePipeline
+            ):
+                custom_connectors = [val_]
+            elif isinstance(val_, (list, tuple)):
+                custom_connectors = list(val_)
+            else:
+                return val_
+
+        pipeline = EnvToModulePipeline(
+            connectors=custom_connectors,
+            input_observation_space=env.single_observation_space,
+            input_action_space=env.single_action_space,
+            env=env,
+        )
+        pipeline.append(
+            DefaultEnvToModule(
+                input_observation_space=pipeline.observation_space,
+                input_action_space=pipeline.action_space,
+                env=env,
+            )
+        )
+        return pipeline
+
+    def build_module_to_env_connector(self, env):
+
+        from ray.rllib.connectors.module_to_env import (
+            DefaultModuleToEnv,
+            ModuleToEnvPipeline,
+        )
+
+        custom_connectors = []
+        # Create a module-to-env connector pipeline (including RLlib's default
+        # module->env connector piece) and return it.
+        if self._module_to_env_connector is not None:
+            val_ = self._module_to_env_connector(env)
+
+            from ray.rllib.connectors.connector_v2 import ConnectorV2
+
+            if isinstance(val_, ConnectorV2) and not isinstance(
+                val_, ModuleToEnvPipeline
+            ):
+                custom_connectors = [val_]
+            elif isinstance(val_, (list, tuple)):
+                custom_connectors = list(val_)
+            else:
+                return val_
+
+        pipeline = ModuleToEnvPipeline(
+            connectors=custom_connectors,
+            input_observation_space=env.single_observation_space,
+            input_action_space=env.single_action_space,
+            env=env,
+        )
+        pipeline.append(
+            DefaultModuleToEnv(
+                input_observation_space=pipeline.observation_space,
+                input_action_space=pipeline.action_space,
+                env=env,
+                normalize_actions=self.normalize_actions,
+                clip_actions=self.clip_actions,
+            )
+        )
+        return pipeline
+
+    def build_learner_connector(self, input_observation_space, input_action_space):
+        from ray.rllib.connectors.learner import (
+            DefaultLearnerConnector,
+            LearnerConnectorPipeline,
+        )
+
+        custom_connectors = []
+        # Create a learner connector pipeline (including RLlib's default
+        # learner connector piece) and return it.
+        if self._learner_connector is not None:
+            val_ = self._learner_connector(input_observation_space, input_action_space)
+
+            from ray.rllib.connectors.connector_v2 import ConnectorV2
+
+            if isinstance(val_, ConnectorV2) and not isinstance(
+                val_, LearnerConnectorPipeline
+            ):
+                custom_connectors = [val_]
+            elif isinstance(val_, (list, tuple)):
+                custom_connectors = list(val_)
+            else:
+                return val_
+
+        pipeline = LearnerConnectorPipeline(
+            connectors=custom_connectors,
+            input_observation_space=input_observation_space,
+            input_action_space=input_action_space,
+        )
+        pipeline.append(
+            DefaultLearnerConnector(
+                input_observation_space=pipeline.observation_space,
+                input_action_space=pipeline.action_space,
+            )
+        )
+        return pipeline
+
     def build_learner_group(
         self,
         *,
@@ -1605,6 +1725,12 @@ def rollouts(
         create_env_on_local_worker: Optional[bool] = NotProvided,
         sample_collector: Optional[Type[SampleCollector]] = NotProvided,
         enable_connectors: Optional[bool] = NotProvided,
+        env_to_module_connector: Optional[
+            Callable[[EnvType], "ConnectorV2"]
+        ] = NotProvided,
+        module_to_env_connector: Optional[
+            Callable[[EnvType, "RLModule"], "ConnectorV2"]
+        ] = NotProvided,
         use_worker_filter_stats: Optional[bool] = NotProvided,
         update_worker_filter_stats: Optional[bool] = NotProvided,
         rollout_fragment_length: Optional[Union[int, str]] = NotProvided,
@@ -1650,6 +1776,11 @@ def rollouts(
             enable_connectors: Use connector based environment runner, so that all
                 preprocessing of obs and postprocessing of actions are done in agent
                 and action connectors.
+            env_to_module_connector: A callable taking an Env as input arg and returning
+                an env-to-module ConnectorV2 (might be a pipeline) object.
+            module_to_env_connector: A callable taking an Env and an RLModule as input
+                args and returning a module-to-env ConnectorV2 (might be a pipeline)
+                object.
             use_worker_filter_stats: Whether to use the workers in the WorkerSet to
                 update the central filters (held by the local worker). If False, stats
                 from the workers will not be used and discarded.
@@ -1737,6 +1868,10 @@ def rollouts(
             self.create_env_on_local_worker = create_env_on_local_worker
         if enable_connectors is not NotProvided:
             self.enable_connectors = enable_connectors
+        if env_to_module_connector is not NotProvided:
+            self._env_to_module_connector = env_to_module_connector
+        if module_to_env_connector is not NotProvided:
+            self._module_to_env_connector = module_to_env_connector
         if use_worker_filter_stats is not NotProvided:
             self.use_worker_filter_stats = use_worker_filter_stats
         if update_worker_filter_stats is not NotProvided:
@@ -1855,6 +1990,9 @@ def training(
         optimizer: Optional[dict] = NotProvided,
         max_requests_in_flight_per_sampler_worker: Optional[int] = NotProvided,
         learner_class: Optional[Type["Learner"]] = NotProvided,
+        learner_connector: Optional[
+            Callable[["RLModule"], "ConnectorV2"]
+        ] = NotProvided,
         # Deprecated arg.
         _enable_learner_api: Optional[bool] = NotProvided,
     ) -> "AlgorithmConfig":
@@ -1916,6 +2054,9 @@ def training(
                 in your experiment of timesteps.
             learner_class: The `Learner` class to use for (distributed) updating of the
                 RLModule. Only used when `_enable_new_api_stack=True`.
+            learner_connector: A callable taking an env observation space and an env
+                action space as inputs and returning a learner ConnectorV2 (might be
+                a pipeline) object.
 
         Returns:
             This updated AlgorithmConfig object.
@@ -1960,6 +2101,8 @@ def training(
             )
         if learner_class is not NotProvided:
             self._learner_class = learner_class
+        if learner_connector is not NotProvided:
+            self._learner_connector = learner_connector
 
         return self
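
The new `rollouts(env_to_module_connector=...)`, `rollouts(module_to_env_connector=...)`, and `training(learner_connector=...)` arguments take callables. The `build_*_connector()` methods above wrap whatever a callable returns (a single ConnectorV2 piece or a list/tuple of pieces) into the matching pipeline class and append RLlib's default piece at the end; a callable may also return a ready-made pipeline, which is used as-is. A minimal usage sketch of the new config hooks; `MyFrameStackConnector` is a hypothetical placeholder for a user-defined ConnectorV2 subclass and is not part of this commit:

from ray.rllib.algorithms.ppo import PPOConfig

class MyFrameStackConnector:
    # Placeholder for a user-defined piece; a real one would subclass
    # ray.rllib.connectors.connector_v2.ConnectorV2.
    def __init__(self, num_frames=4):
        self.num_frames = num_frames

config = (
    PPOConfig()
    .environment(env="CartPole-v1")
    # Callable receiving the env; may return one ConnectorV2 piece, a list of
    # pieces, or a complete EnvToModulePipeline.
    .rollouts(env_to_module_connector=lambda env: MyFrameStackConnector(num_frames=4))
    # Callable receiving the env observation- and action-space.
    .training(
        learner_connector=lambda obs_space, act_space: [MyFrameStackConnector()]
    )
)

Here the callables are only stored on the config; the actual pipelines are assembled later via `config.build_env_to_module_connector(env)` and `config.build_learner_connector(obs_space, act_space)`.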

rllib/algorithms/impala/impala.py

Lines changed: 3 additions & 4 deletions
@@ -86,18 +86,17 @@ class ImpalaConfig(AlgorithmConfig):
 
         # Update the config object.
         config = config.training(
-            lr=tune.grid_search([0.0001, ]), grad_clip=20.0
+            lr=tune.grid_search([0.0001, 0.0002]), grad_clip=20.0
         )
         config = config.resources(num_gpus=0)
         config = config.rollouts(num_rollout_workers=1)
         # Set the config object's env.
         config = config.environment(env="CartPole-v1")
-        # Use to_dict() to get the old-style python config dict
-        # when running with tune.
+        # Run with tune.
         tune.Tuner(
             "IMPALA",
+            param_space=config,
             run_config=air.RunConfig(stop={"training_iteration": 1}),
-            param_space=config.to_dict(),
         ).fit()
 
     .. testoutput::

rllib/algorithms/pg/pg.py

Lines changed: 2 additions & 3 deletions
@@ -30,12 +30,11 @@ class PGConfig(AlgorithmConfig):
         >>> config = config.training(lr=tune.grid_search([0.001, 0.0001]))
         >>> # Set the config object's env.
         >>> config = config.environment(env="CartPole-v1")
-        >>> # Use to_dict() to get the old-style python config dict
-        >>> # when running with tune.
+        >>> # Run with tune.
         >>> tune.Tuner(  # doctest: +SKIP
         ...     "PG",
         ...     run_config=air.RunConfig(stop={"episode_reward_mean": 200}),
-        ...     param_space=config.to_dict(),
+        ...     param_space=config,
         ... ).fit()
         """
 
rllib/algorithms/ppo/ppo.py

Lines changed: 10 additions & 10 deletions
@@ -253,13 +253,10 @@ def training(
         # Pass kwargs onto super's `training()` method.
         super().training(**kwargs)
 
-        # TODO (sven): Move to generic AlgorithmConfig.
-        if lr_schedule is not NotProvided:
-            self.lr_schedule = lr_schedule
         if use_critic is not NotProvided:
             self.use_critic = use_critic
-            # TODO (Kourosh) This is experimental. Set learner_hps parameters as
-            # well. Don't forget to remove .use_critic from algorithm config.
+            # TODO (Kourosh) This is experimental.
+            # Don't forget to remove .use_critic from algorithm config.
         if use_gae is not NotProvided:
             self.use_gae = use_gae
         if lambda_ is not NotProvided:
@@ -280,15 +277,19 @@ def training(
             self.vf_loss_coeff = vf_loss_coeff
         if entropy_coeff is not NotProvided:
             self.entropy_coeff = entropy_coeff
-        if entropy_coeff_schedule is not NotProvided:
-            self.entropy_coeff_schedule = entropy_coeff_schedule
         if clip_param is not NotProvided:
             self.clip_param = clip_param
         if vf_clip_param is not NotProvided:
             self.vf_clip_param = vf_clip_param
         if grad_clip is not NotProvided:
             self.grad_clip = grad_clip
 
+        # TODO (sven): Remove these once new API stack is only option for PPO.
+        if lr_schedule is not NotProvided:
+            self.lr_schedule = lr_schedule
+        if entropy_coeff_schedule is not NotProvided:
+            self.entropy_coeff_schedule = entropy_coeff_schedule
+
         return self
 
     @override(AlgorithmConfig)
@@ -312,8 +313,8 @@ def validate(self) -> None:
             raise ValueError(
                 f"`sgd_minibatch_size` ({self.sgd_minibatch_size}) must be <= "
                 f"`train_batch_size` ({self.train_batch_size}). In PPO, the train batch"
-                f" is be split into {self.sgd_minibatch_size} chunks, each of which is "
-                f"iterated over (used for updating the policy) {self.num_sgd_iter} "
+                f" will be split into {self.sgd_minibatch_size} chunks, each of which "
+                f"is iterated over (used for updating the policy) {self.num_sgd_iter} "
                 "times."
             )
 
@@ -476,7 +477,6 @@ def training_step(self) -> ResultDict:
             self.workers.local_worker().set_weights(weights)
 
         if self.config._enable_new_api_stack:
-
             kl_dict = {}
             if self.config.use_kl_loss:
                 for pid in policies_to_update:
0 commit comments