|
99 | 99 |
|
100 | 100 | if TYPE_CHECKING: |
101 | 101 | from ray.rllib.algorithms.algorithm import Algorithm |
| 102 | + from ray.rllib.connectors.connector_v2 import ConnectorV2 |
102 | 103 | from ray.rllib.core.learner import Learner |
103 | 104 | from ray.rllib.core.learner.learner_group import LearnerGroup |
| 105 | + from ray.rllib.core.rl_module.rl_module import RLModule |
104 | 106 | from ray.rllib.evaluation.episode import Episode as OldEpisode |
105 | 107 |
|
106 | 108 | logger = logging.getLogger(__name__) |
@@ -327,6 +329,8 @@ def __init__(self, algo_class=None): |
327 | 329 | self.num_envs_per_worker = 1 |
328 | 330 | self.create_env_on_local_worker = False |
329 | 331 | self.enable_connectors = True |
| 332 | + self._env_to_module_connector = None |
| 333 | + self._module_to_env_connector = None |
330 | 334 | # TODO (sven): Rename into `sample_timesteps` (or `sample_duration` |
331 | 335 | # and `sample_duration_unit` (replacing batch_mode), like we do it |
332 | 336 | # in the evaluation config). |
@@ -374,6 +378,7 @@ def __init__(self, algo_class=None): |
374 | 378 | except AttributeError: |
375 | 379 | pass |
376 | 380 |
|
| 381 | + self._learner_connector = None |
377 | 382 | self.optimizer = {} |
378 | 383 | self.max_requests_in_flight_per_sampler_worker = 2 |
379 | 384 | self._learner_class = None |
@@ -1152,6 +1157,121 @@ class directly. Note that this arg can also be specified via |
1152 | 1157 | logger_creator=self.logger_creator, |
1153 | 1158 | ) |
1154 | 1159 |
|
def build_env_to_module_connector(self, env):
    """Builds and returns the env-to-module connector pipeline for `env`.

    If a custom connector factory is configured in
    `self._env_to_module_connector`, it is called with `env` and its return
    value is interpreted as follows:

    - A single ConnectorV2 piece (not an EnvToModulePipeline) or a
      list/tuple of pieces: The piece(s) are placed into a new
      EnvToModulePipeline, to which RLlib's DefaultEnvToModule piece is
      appended.
    - An EnvToModulePipeline: Returned as-is (no default piece is appended).

    Without a configured factory, the returned pipeline only contains the
    DefaultEnvToModule piece.

    Args:
        env: The environment to build the pipeline for. Must expose the
            attributes `single_observation_space` and `single_action_space`
            (presumably a vectorized gymnasium env; verify against caller).

    Returns:
        The env-to-module connector pipeline.

    Raises:
        ValueError: If the configured factory returns anything other than a
            ConnectorV2 (piece or pipeline) or a list/tuple.
    """
    from ray.rllib.connectors.env_to_module import (
        DefaultEnvToModule,
        EnvToModulePipeline,
    )

    custom_connectors = []
    if self._env_to_module_connector is not None:
        val_ = self._env_to_module_connector(env)

        from ray.rllib.connectors.connector_v2 import ConnectorV2

        # Single (non-pipeline) connector piece.
        if isinstance(val_, ConnectorV2) and not isinstance(
            val_, EnvToModulePipeline
        ):
            custom_connectors = [val_]
        # Sequence of individual connector pieces.
        elif isinstance(val_, (list, tuple)):
            custom_connectors = list(val_)
        # Complete, user-built pipeline: Use as-is.
        elif isinstance(val_, EnvToModulePipeline):
            return val_
        # Unsupported return value: Fail early instead of handing a
        # non-connector object to the caller.
        else:
            raise ValueError(
                "`AlgorithmConfig.rollouts(env_to_module_connector=..)` must "
                "return a ConnectorV2 object (piece or pipeline) or a list "
                f"thereof! Your function returned {val_}."
            )

    # Wrap the custom pieces (if any) into a pipeline and add RLlib's default
    # env->module piece at the end. The default piece's input spaces are the
    # pipeline's current spaces.
    pipeline = EnvToModulePipeline(
        connectors=custom_connectors,
        input_observation_space=env.single_observation_space,
        input_action_space=env.single_action_space,
        env=env,
    )
    pipeline.append(
        DefaultEnvToModule(
            input_observation_space=pipeline.observation_space,
            input_action_space=pipeline.action_space,
            env=env,
        )
    )
    return pipeline
| 1197 | + |
def build_module_to_env_connector(self, env):
    """Builds and returns the module-to-env connector pipeline for `env`.

    If a custom connector factory is configured in
    `self._module_to_env_connector`, it is called with `env` (only) and its
    return value is interpreted as follows:

    - A single ConnectorV2 piece (not a ModuleToEnvPipeline) or a
      list/tuple of pieces: The piece(s) are placed into a new
      ModuleToEnvPipeline, to which RLlib's DefaultModuleToEnv piece is
      appended.
    - A ModuleToEnvPipeline: Returned as-is (no default piece is appended).

    Without a configured factory, the returned pipeline only contains the
    DefaultModuleToEnv piece.

    Args:
        env: The environment to build the pipeline for. Must expose the
            attributes `single_observation_space` and `single_action_space`
            (presumably a vectorized gymnasium env; verify against caller).

    Returns:
        The module-to-env connector pipeline.

    Raises:
        ValueError: If the configured factory returns anything other than a
            ConnectorV2 (piece or pipeline) or a list/tuple.
    """
    from ray.rllib.connectors.module_to_env import (
        DefaultModuleToEnv,
        ModuleToEnvPipeline,
    )

    custom_connectors = []
    if self._module_to_env_connector is not None:
        val_ = self._module_to_env_connector(env)

        from ray.rllib.connectors.connector_v2 import ConnectorV2

        # Single (non-pipeline) connector piece.
        if isinstance(val_, ConnectorV2) and not isinstance(
            val_, ModuleToEnvPipeline
        ):
            custom_connectors = [val_]
        # Sequence of individual connector pieces.
        elif isinstance(val_, (list, tuple)):
            custom_connectors = list(val_)
        # Complete, user-built pipeline: Use as-is.
        elif isinstance(val_, ModuleToEnvPipeline):
            return val_
        # Unsupported return value: Fail early instead of handing a
        # non-connector object to the caller.
        else:
            raise ValueError(
                "`AlgorithmConfig.rollouts(module_to_env_connector=..)` must "
                "return a ConnectorV2 object (piece or pipeline) or a list "
                f"thereof! Your function returned {val_}."
            )

    # Wrap the custom pieces (if any) into a pipeline and add RLlib's default
    # module->env piece at the end. The default piece receives this config's
    # action normalization/clipping settings.
    pipeline = ModuleToEnvPipeline(
        connectors=custom_connectors,
        input_observation_space=env.single_observation_space,
        input_action_space=env.single_action_space,
        env=env,
    )
    pipeline.append(
        DefaultModuleToEnv(
            input_observation_space=pipeline.observation_space,
            input_action_space=pipeline.action_space,
            env=env,
            normalize_actions=self.normalize_actions,
            clip_actions=self.clip_actions,
        )
    )
    return pipeline
| 1238 | + |
def build_learner_connector(self, input_observation_space, input_action_space):
    """Builds and returns the learner connector pipeline.

    If a custom connector factory is configured in `self._learner_connector`,
    it is called with `input_observation_space` and `input_action_space` and
    its return value is interpreted as follows:

    - A single ConnectorV2 piece (not a LearnerConnectorPipeline) or a
      list/tuple of pieces: The piece(s) are placed into a new
      LearnerConnectorPipeline, to which RLlib's DefaultLearnerConnector
      piece is appended.
    - A LearnerConnectorPipeline: Returned as-is (no default piece is
      appended).

    Without a configured factory, the returned pipeline only contains the
    DefaultLearnerConnector piece.

    Args:
        input_observation_space: The observation space passed to the custom
            factory and used as the pipeline's input observation space.
        input_action_space: The action space passed to the custom factory
            and used as the pipeline's input action space.

    Returns:
        The learner connector pipeline.

    Raises:
        ValueError: If the configured factory returns anything other than a
            ConnectorV2 (piece or pipeline) or a list/tuple.
    """
    from ray.rllib.connectors.learner import (
        DefaultLearnerConnector,
        LearnerConnectorPipeline,
    )

    custom_connectors = []
    if self._learner_connector is not None:
        val_ = self._learner_connector(input_observation_space, input_action_space)

        from ray.rllib.connectors.connector_v2 import ConnectorV2

        # Single (non-pipeline) connector piece.
        if isinstance(val_, ConnectorV2) and not isinstance(
            val_, LearnerConnectorPipeline
        ):
            custom_connectors = [val_]
        # Sequence of individual connector pieces.
        elif isinstance(val_, (list, tuple)):
            custom_connectors = list(val_)
        # Complete, user-built pipeline: Use as-is.
        elif isinstance(val_, LearnerConnectorPipeline):
            return val_
        # Unsupported return value: Fail early instead of handing a
        # non-connector object to the caller.
        else:
            raise ValueError(
                "`AlgorithmConfig.training(learner_connector=..)` must return "
                "a ConnectorV2 object (piece or pipeline) or a list thereof! "
                f"Your function returned {val_}."
            )

    # Wrap the custom pieces (if any) into a pipeline and add RLlib's default
    # learner piece at the end. The default piece's input spaces are the
    # pipeline's current spaces.
    pipeline = LearnerConnectorPipeline(
        connectors=custom_connectors,
        input_observation_space=input_observation_space,
        input_action_space=input_action_space,
    )
    pipeline.append(
        DefaultLearnerConnector(
            input_observation_space=pipeline.observation_space,
            input_action_space=pipeline.action_space,
        )
    )
    return pipeline
| 1274 | + |
1155 | 1275 | def build_learner_group( |
1156 | 1276 | self, |
1157 | 1277 | *, |
@@ -1605,6 +1725,12 @@ def rollouts( |
1605 | 1725 | create_env_on_local_worker: Optional[bool] = NotProvided, |
1606 | 1726 | sample_collector: Optional[Type[SampleCollector]] = NotProvided, |
1607 | 1727 | enable_connectors: Optional[bool] = NotProvided, |
| 1728 | + env_to_module_connector: Optional[ |
| 1729 | + Callable[[EnvType], "ConnectorV2"] |
| 1730 | + ] = NotProvided, |
| 1731 | + module_to_env_connector: Optional[ |
| 1732 | + Callable[[EnvType, "RLModule"], "ConnectorV2"] |
| 1733 | + ] = NotProvided, |
1608 | 1734 | use_worker_filter_stats: Optional[bool] = NotProvided, |
1609 | 1735 | update_worker_filter_stats: Optional[bool] = NotProvided, |
1610 | 1736 | rollout_fragment_length: Optional[Union[int, str]] = NotProvided, |
@@ -1650,6 +1776,11 @@ def rollouts( |
1650 | 1776 | enable_connectors: Use connector based environment runner, so that all |
1651 | 1777 | preprocessing of obs and postprocessing of actions are done in agent |
1652 | 1778 | and action connectors. |
| 1779 | + env_to_module_connector: A callable taking an Env as input arg and returning |
| 1780 | + an env-to-module ConnectorV2 (might be a pipeline) object. |
| 1781 | + module_to_env_connector: A callable taking an Env as input arg and |
| 1782 | + returning a module-to-env ConnectorV2 (might be a pipeline) |
| 1783 | + object. |
1653 | 1784 | use_worker_filter_stats: Whether to use the workers in the WorkerSet to |
1654 | 1785 | update the central filters (held by the local worker). If False, stats |
1655 | 1786 | from the workers will not be used and discarded. |
@@ -1737,6 +1868,10 @@ def rollouts( |
1737 | 1868 | self.create_env_on_local_worker = create_env_on_local_worker |
1738 | 1869 | if enable_connectors is not NotProvided: |
1739 | 1870 | self.enable_connectors = enable_connectors |
| 1871 | + if env_to_module_connector is not NotProvided: |
| 1872 | + self._env_to_module_connector = env_to_module_connector |
| 1873 | + if module_to_env_connector is not NotProvided: |
| 1874 | + self._module_to_env_connector = module_to_env_connector |
1740 | 1875 | if use_worker_filter_stats is not NotProvided: |
1741 | 1876 | self.use_worker_filter_stats = use_worker_filter_stats |
1742 | 1877 | if update_worker_filter_stats is not NotProvided: |
@@ -1855,6 +1990,9 @@ def training( |
1855 | 1990 | optimizer: Optional[dict] = NotProvided, |
1856 | 1991 | max_requests_in_flight_per_sampler_worker: Optional[int] = NotProvided, |
1857 | 1992 | learner_class: Optional[Type["Learner"]] = NotProvided, |
| 1993 | + learner_connector: Optional[ |
| 1994 | + Callable[["RLModule"], "ConnectorV2"] |
| 1995 | + ] = NotProvided, |
1858 | 1996 | # Deprecated arg. |
1859 | 1997 | _enable_learner_api: Optional[bool] = NotProvided, |
1860 | 1998 | ) -> "AlgorithmConfig": |
@@ -1916,6 +2054,9 @@ def training( |
1916 | 2054 | in your experiment of timesteps. |
1917 | 2055 | learner_class: The `Learner` class to use for (distributed) updating of the |
1918 | 2056 | RLModule. Only used when `_enable_new_api_stack=True`. |
| 2057 | + learner_connector: A callable taking an env observation space and an env |
| 2058 | + action space as inputs and returning a learner ConnectorV2 (might be |
| 2059 | + a pipeline) object. |
1919 | 2060 |
|
1920 | 2061 | Returns: |
1921 | 2062 | This updated AlgorithmConfig object. |
@@ -1960,6 +2101,8 @@ def training( |
1960 | 2101 | ) |
1961 | 2102 | if learner_class is not NotProvided: |
1962 | 2103 | self._learner_class = learner_class |
| 2104 | + if learner_connector is not NotProvided: |
| 2105 | + self._learner_connector = learner_connector |
1963 | 2106 |
|
1964 | 2107 | return self |
1965 | 2108 |
|
|
0 commit comments