diff --git a/examples/06_pytorch_oxe_dataloader.py b/examples/06_pytorch_oxe_dataloader.py index 8804a442..43dde74a 100644 --- a/examples/06_pytorch_oxe_dataloader.py +++ b/examples/06_pytorch_oxe_dataloader.py @@ -40,7 +40,7 @@ def __len__(self): ] ) if hasattr(self._rlds_dataset, "sample_weights"): - lengths *= np.array(self._rlds_dataset.sample_weights) + lengths = np.array(self._rlds_dataset.sample_weights) * lengths total_len = lengths.sum() if self._is_train: return int(0.95 * total_len) diff --git a/octo/data/oxe/__init__.py b/octo/data/oxe/__init__.py index 2ec9555c..77a7d1a5 100755 --- a/octo/data/oxe/__init__.py +++ b/octo/data/oxe/__init__.py @@ -1,5 +1,6 @@ import copy import logging +import os from typing import Any, Dict, List, Sequence, Tuple, Union from octo.data.oxe.oxe_dataset_configs import ActionEncoding, OXE_DATASET_CONFIGS @@ -74,6 +75,12 @@ def make_oxe_dataset_kwargs( dataset_kwargs["standardize_fn"] = OXE_STANDARDIZATION_TRANSFORMS[name] + if "data_dir" in dataset_kwargs: + if dataset_kwargs["data_dir"][0] == "~": + dataset_kwargs["data_dir"] = os.path.expanduser("~") + dataset_kwargs["data_dir"][1:] + data_dir = dataset_kwargs["data_dir"] + del dataset_kwargs["data_dir"] + return {"name": name, "data_dir": data_dir, **dataset_kwargs} diff --git a/octo/data/oxe/oxe_standardization_transforms.py b/octo/data/oxe/oxe_standardization_transforms.py index f25eed28..43dbe421 100755 --- a/octo/data/oxe/oxe_standardization_transforms.py +++ b/octo/data/oxe/oxe_standardization_transforms.py @@ -32,7 +32,7 @@ def bridge_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: trajectory["action"][:, :6], binarize_gripper_actions(trajectory["action"][:, -1])[:, None], ], - axis=1, + axis=-1, ) trajectory = relabel_actions(trajectory) trajectory["observation"]["EEF_state"] = trajectory["observation"]["state"][:, :6]