
datas

BehaviorIterableDataset

Bases: IterableDataset

BehaviorIterableDataset is an IterableDataset designed for loading and streaming demonstration data for behavior tasks. It supports multi-modal observations (including low-dimensional proprioception, visual data, and point clouds), action chunking, temporal context windows, and distributed data loading for scalable training.

Key Features:

- Loads demonstration data from disk, supporting multiple tasks and robots.
- Preloads low-dimensional data (actions, proprioception, task info) into memory for efficient access.
- Supports visual observation types: RGB, depth, segmentation, and point clouds, with multi-view camera support.
- Handles action chunking for sequence prediction tasks, with optional prediction horizon and masking.
- Supports temporal downsampling of data for variable frame rates.
- Provides deterministic shuffling and partitioning for distributed and multi-worker training.
- Normalizes observations and actions to standardized ranges for learning.
- Optionally loads and normalizes privileged task information.
- Implements efficient chunked streaming of data for training with context windows.
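The snippet below is a minimal usage sketch: it discovers the demo keys for a single task and streams training samples with RGB observations. The data path, task name, and camera configuration are illustrative assumptions rather than values shipped with the library.

from omnigibson.learning.datas.iterable_dataset import BehaviorIterableDataset

data_path = "/path/to/behavior-data"  # hypothetical local data root
demo_keys = BehaviorIterableDataset.get_all_demo_keys(data_path, task_names=["turning_on_radio"])  # hypothetical task name

dataset = BehaviorIterableDataset(
    data_path=data_path,
    demo_keys=demo_keys,
    robot_type="R1Pro",
    obs_window_size=2,
    ctx_len=16,
    use_action_chunks=True,
    action_prediction_horizon=16,
    downsample_factor=3,  # 30 Hz demos -> 10 Hz samples
    visual_obs_types=["rgb"],
    # camera id and resolution below are placeholders; match them to your data
    multi_view_cameras={"head": {"name": "head", "resolution": [224, 224]}},
    seed=42,
    shuffle=True,
)

for sample in dataset:
    obs, actions, masks = sample["obs"], sample["actions"], sample["masks"]
    break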

Source code in OmniGibson/omnigibson/learning/datas/iterable_dataset.py
class BehaviorIterableDataset(IterableDataset):
    """
    BehaviorIterableDataset is an IterableDataset designed for loading and streaming demonstration data for behavior tasks.
    It supports multi-modal observations (including low-dimensional proprioception, visual data, and point clouds), action chunking,
    temporal context windows, and distributed data loading for scalable training.
    Key Features:
        - Loads demonstration data from disk, supporting multiple tasks and robots.
        - Preloads low-dimensional data (actions, proprioception, task info) into memory for efficient access.
        - Supports visual observation types: RGB, depth, segmentation, and point clouds, with multi-view camera support.
        - Handles action chunking for sequence prediction tasks, with optional prediction horizon and masking.
        - Supports temporal downsampling of data for variable frame rates.
        - Provides deterministic shuffling and partitioning for distributed and multi-worker training.
        - Normalizes observations and actions to standardized ranges for learning.
        - Optionally loads and normalizes privileged task information.
        - Implements efficient chunked streaming of data for training with context windows.
    """

    @classmethod
    def get_all_demo_keys(cls, data_path: str, task_names: List[str]) -> List[Any]:
        assert os.path.exists(data_path), "Data path does not exist!"
        task_dir_names = [f"task-{TASK_NAMES_TO_INDICES[name]:04d}" for name in task_names]
        demo_keys = sorted(
            [
                file_name.split(".")[0].split("_")[-1]
                for task_dir_name in task_dir_names
                for file_name in os.listdir(f"{data_path}/2025-challenge-demos/data/{task_dir_name}")
                if file_name.endswith(".parquet")
            ]
        )
        return demo_keys

    def __init__(
        self,
        *args,
        data_path: str,
        demo_keys: List[Any],
        robot_type: str = "R1Pro",
        obs_window_size: int,
        ctx_len: int,
        use_action_chunks: bool = False,
        action_prediction_horizon: Optional[int] = None,
        downsample_factor: int = 1,
        visual_obs_types: List[str],
        multi_view_cameras: Optional[Dict[str, Any]] = None,
        use_task_info: bool = False,
        task_info_range: Optional[ListConfig] = None,
        seed: int = 42,
        shuffle: bool = True,
        **kwargs,
    ) -> None:
        """
        Initialize the BehaviorIterableDataset.
        Args:
            data_path (str): Path to the data directory.
            demo_keys (List[Any]): List of demo keys.
            robot_type (str): Type of the robot. Default is "R1Pro".
            obs_window_size (int): Size of the observation window.
            ctx_len (int): Context length.
            use_action_chunks (bool): Whether to use action chunks.
                Action will be from (T, A) to (T, L_pred_horizon, A)
            action_prediction_horizon (Optional[int]): Horizon of the action prediction.
                Must not be None if use_action_chunks is True.
            downsample_factor (int): Downsample factor for the data (with uniform temporal subsampling).
                Note that the original data is at 30Hz, so if factor=3 then data will be at 10Hz.
                Default is 1 (no downsampling), must be >= 1.
            visual_obs_types (List[str]): List of visual observation types to load.
                Valid options are: "rgb", "depth_linear", "seg_instance_id", "pcd".
            multi_view_cameras (Optional[Dict[str, Any]]): Dict of id-camera pairs to load obs from.
            use_task_info (bool): Whether to load privileged task information.
            task_info_range (Optional[ListConfig]): Range of the task information (for normalization).
            seed (int): Random seed.
            shuffle (bool): Whether to shuffle the dataset.
        """
        super().__init__()
        self._data_path = data_path
        self._demo_keys = demo_keys
        self._robot_type = robot_type
        self._obs_window_size = obs_window_size
        self._ctx_len = ctx_len
        self._use_action_chunks = use_action_chunks
        self._action_prediction_horizon = action_prediction_horizon
        assert (
            self._action_prediction_horizon is not None if self._use_action_chunks else True
        ), "action_prediction_horizon must be provided if use_action_chunks is True!"
        self._downsample_factor = downsample_factor
        assert self._downsample_factor >= 1, "downsample_factor must be >= 1!"
        self._use_task_info = use_task_info
        self._task_info_range = (
            th.tensor(OmegaConf.to_container(task_info_range)) if task_info_range is not None else None
        )
        self._seed = seed
        self._shuffle = shuffle
        self._epoch = 0

        assert set(visual_obs_types).issubset(
            {"rgb", "depth_linear", "seg_instance_id", "pcd"}
        ), "visual_obs_types must be a subset of {'rgb', 'depth_linear', 'seg_instance_id', 'pcd'}!"
        self._visual_obs_types = set(visual_obs_types)

        self._multi_view_cameras = multi_view_cameras

        self._demo_indices = list(range(len(self._demo_keys)))
        # Preload low dim into memory
        self._all_demos = [self._preload_demo(demo_key) for demo_key in self._demo_keys]
        # get demo lengths (N_chunks)
        self._demo_lengths = []
        for demo in self._all_demos:
            L = get_batch_size(demo, strict=True)
            assert L >= self._obs_window_size >= 1
            self._demo_lengths.append(L - self._obs_window_size + 1)
        logger.info(f"Dataset chunk length: {sum(self._demo_lengths)}")

    @property
    def epoch(self):
        return self._epoch

    @epoch.setter
    def epoch(self, epoch: int):
        self._epoch = epoch
        if self._shuffle:
            # deterministically shuffle the demos
            g = th.Generator()
            g.manual_seed(epoch + self._seed)
            self._demo_indices = th.randperm(len(self._demo_keys), generator=g).tolist()

    def __iter__(self) -> Generator[Dict[str, Any], None, None]:
        global_worker_id, total_global_workers = self._get_global_worker_id()
        demo_lengths_shuffled = [self._demo_lengths[i] for i in self._demo_indices]
        start_demo_id, start_demo_idx, end_demo_id, end_demo_idx = sequential_sum_balanced_partitioning(
            demo_lengths_shuffled, total_global_workers, global_worker_id
        )
        for demo_idx, demo_ptr in enumerate(self._demo_indices[start_demo_id : end_demo_id + 1]):
            start_idx = start_demo_idx if demo_idx == 0 else 0
            end_idx = end_demo_idx if demo_idx == end_demo_id - start_demo_id else self._demo_lengths[demo_ptr]
            yield from self.get_streamed_data(demo_ptr, start_idx, end_idx)

    def get_streamed_data(self, demo_ptr: int, start_idx: int, end_idx: int) -> Generator[Dict[str, Any], None, None]:
        task_id = int(self._demo_keys[demo_ptr]) // 10000
        chunk_generator = self._chunk_demo(demo_ptr, start_idx, end_idx)
        # Initialize obs loaders
        obs_loaders = dict()
        for obs_type in self._visual_obs_types:
            if obs_type == "pcd":
                # pcd_generator
                f_pcd = h5py.File(
                    f"{self._data_path}/pcd_vid/task-{task_id:04d}/episode_{self._demo_keys[demo_ptr]}.hdf5",
                    "r",
                    swmr=True,
                    libver="latest",
                )
                # Create a generator that yields sliding windows of point clouds
                pcd_data = f_pcd["data/demo_0/robot_r1::fused_pcd"]
                pcd_generator = self._h5_window_generator(pcd_data, start_idx, end_idx)
            else:
                # create per-camera obs loaders with the appropriate start/end indices
                for camera_id in self._multi_view_cameras.keys():
                    camera_name = self._multi_view_cameras[camera_id]["name"]
                    stride = 1
                    kwargs = {}
                    if obs_type == "seg_instance_id":
                        with open(
                            f"{self._data_path}/2025-challenge-demos/meta/episodes/task-{task_id:04d}/episode_{self._demo_keys[demo_ptr]}.json",
                            "r",
                        ) as f:
                            kwargs["id_list"] = th.tensor(
                                json.load(f)[f"{ROBOT_CAMERA_NAMES['R1Pro'][camera_id]}::unique_ins_ids"]
                            )
                    obs_loaders[f"{camera_name}::{obs_type}"] = iter(
                        OBS_LOADER_MAP[obs_type](
                            data_path=f"{self._data_path}/2025-challenge-demos",
                            task_id=task_id,
                            camera_id=camera_id,
                            demo_id=self._demo_keys[demo_ptr],
                            batch_size=self._obs_window_size,
                            stride=stride,
                            start_idx=start_idx * stride * self._downsample_factor,
                            end_idx=((end_idx - 1) * stride + self._obs_window_size) * self._downsample_factor,
                            output_size=tuple(self._multi_view_cameras[camera_id]["resolution"]),
                            **kwargs,
                        )
                    )
        for _ in range(start_idx, end_idx):
            data, mask = next(chunk_generator)
            # load visual obs
            for obs_type in self._visual_obs_types:
                if obs_type == "pcd":
                    # stream the next point cloud window from the hdf5 file
                    data["obs"]["pcd"] = next(pcd_generator)
                else:
                    for camera in self._multi_view_cameras.values():
                        data["obs"][f"{camera['name']}::{obs_type}"] = next(
                            obs_loaders[f"{camera['name']}::{obs_type}"]
                        )
            data["masks"] = mask
            yield data
        for obs_type in self._visual_obs_types:
            if obs_type == "pcd":
                f_pcd.close()
            else:
                for camera in self._multi_view_cameras.values():
                    obs_loaders[f"{camera['name']}::{obs_type}"].close()

    def _preload_demo(self, demo_key: Any) -> Dict[str, Any]:
        """
        Preload a single demo into memory. Currently it loads action, proprio, and optionally task info.
        Args:
            demo_key (Any): Key of the demo to preload.
        Returns:
            demo (dict): Preloaded demo.
        """
        demo = dict()
        demo["obs"] = {"qpos": dict(), "eef": dict()}
        # load low_dim data
        action_dict = dict()
        low_dim_data = self._extract_low_dim_data(demo_key)
        for key, data in low_dim_data.items():
            if key == "proprio":
                # normalize proprioception
                if "base_qvel" in PROPRIOCEPTION_INDICES[self._robot_type]:
                    demo["obs"]["odom"] = {
                        "base_velocity": 2
                        * (
                            data[..., PROPRIOCEPTION_INDICES[self._robot_type]["base_qvel"]]
                            - JOINT_RANGE[self._robot_type]["base"][0]
                        )
                        / (JOINT_RANGE[self._robot_type]["base"][1] - JOINT_RANGE[self._robot_type]["base"][0])
                        - 1.0
                    }
                for key in PROPRIO_QPOS_INDICES[self._robot_type]:
                    if "gripper" in key:
                        # rectify gripper actions to {-1, 1}
                        demo["obs"]["qpos"][key] = th.mean(
                            data[..., PROPRIO_QPOS_INDICES[self._robot_type][key]], dim=-1, keepdim=True
                        )
                        demo["obs"]["qpos"][key] = th.where(
                            demo["obs"]["qpos"][key]
                            > (JOINT_RANGE[self._robot_type][key][0] + JOINT_RANGE[self._robot_type][key][1]) / 2,
                            1.0,
                            -1.0,
                        )
                    else:
                        # normalize the qpos to [-1, 1]
                        demo["obs"]["qpos"][key] = (
                            2
                            * (
                                data[..., PROPRIO_QPOS_INDICES[self._robot_type][key]]
                                - JOINT_RANGE[self._robot_type][key][0]
                            )
                            / (JOINT_RANGE[self._robot_type][key][1] - JOINT_RANGE[self._robot_type][key][0])
                            - 1.0
                        )
                for key in EEF_POSITION_RANGE[self._robot_type]:
                    demo["obs"]["eef"][f"{key}_pos"] = (
                        2
                        * (
                            data[..., PROPRIOCEPTION_INDICES[self._robot_type][f"eef_{key}_pos"]]
                            - EEF_POSITION_RANGE[self._robot_type][key][0]
                        )
                        / (EEF_POSITION_RANGE[self._robot_type][key][1] - EEF_POSITION_RANGE[self._robot_type][key][0])
                        - 1.0
                    )
                    # don't normalize the eef orientation
                    demo["obs"]["eef"][f"{key}_quat"] = data[
                        ..., PROPRIOCEPTION_INDICES[self._robot_type][f"eef_{key}_quat"]
                    ]
            elif key == "action":
                # Note that we need to take the action at the timestamp before the next observation
                # First pad the action array so that it is divisible by the downsample factor
                if data.shape[0] % self._downsample_factor != 0:
                    pad_size = self._downsample_factor - (data.shape[0] % self._downsample_factor)
                    # pad with the last action
                    data = th.cat([data, data[-1:].repeat(pad_size, 1)], dim=0)
                # Now downsample the action array
                data = data[self._downsample_factor - 1 :: self._downsample_factor]
                for key, indices in ACTION_QPOS_INDICES[self._robot_type].items():
                    action_dict[key] = data[:, indices]
                    # action normalization
                    if "gripper" not in key:  # Gripper actions are already normalized to [-1, 1]
                        action_dict[key] = (
                            2
                            * (action_dict[key] - JOINT_RANGE[self._robot_type][key][0])
                            / (JOINT_RANGE[self._robot_type][key][1] - JOINT_RANGE[self._robot_type][key][0])
                            - 1.0
                        )
                if self._use_action_chunks:
                    # make actions from (T, A) to (T, L_pred_horizon, A)
                    # need to construct a mask
                    action_chunks = []
                    action_chunk_masks = []
                    action_structure = deepcopy(any_slice(action_dict, np.s_[0:1]))  # (1, A)
                    for t in range(get_batch_size(action_dict, strict=True)):
                        action_chunk = any_slice(action_dict, np.s_[t : t + self._action_prediction_horizon])
                        action_chunk_size = get_batch_size(action_chunk, strict=True)
                        pad_size = self._action_prediction_horizon - action_chunk_size
                        mask = any_concat(
                            [
                                th.ones((action_chunk_size,), dtype=th.bool),
                                th.zeros((pad_size,), dtype=th.bool),
                            ],
                            dim=0,
                        )  # (L_pred_horizon,)
                        action_chunk = any_concat(
                            [
                                action_chunk,
                            ]
                            + [any_ones_like(action_structure)] * pad_size,
                            dim=0,
                        )  # (L_pred_horizon, A)
                        action_chunks.append(action_chunk)
                        action_chunk_masks.append(mask)
                    action_chunks = any_stack(action_chunks, dim=0)  # (T, L_pred_horizon, A)
                    action_chunk_masks = th.stack(action_chunk_masks, dim=0)  # (T, L_pred_horizon)
                    demo["actions"] = action_chunks
                    demo["action_masks"] = action_chunk_masks
                else:
                    demo["actions"] = action_dict
            elif key == "task":
                if self._task_info_range is not None:
                    # Normalize task info to [-1, 1]
                    demo["obs"]["task"] = (
                        2 * (data - self._task_info_range[0]) / (self._task_info_range[1] - self._task_info_range[0])
                        - 1.0
                    )
                else:
                    # If no range is provided, just use the raw data
                    demo["obs"]["task"] = data
            else:
                # For other keys, just store the data as is
                demo["obs"][key] = data
        return demo

    def _extract_low_dim_data(self, demo_key: Any) -> Dict[str, th.Tensor]:
        task_id = int(demo_key) // 10000
        df = pd.read_parquet(
            os.path.join(
                self._data_path, "2025-challenge-demos", "data", f"task-{task_id:04d}", f"episode_{demo_key}.parquet"
            )
        )
        ret = {
            "proprio": th.from_numpy(
                np.array(df["observation.state"][:: self._downsample_factor].tolist(), dtype=np.float32)
            ),
            "action": th.from_numpy(np.array(df["action"].tolist(), dtype=np.float32)),
            "cam_rel_poses": th.from_numpy(
                np.array(df["observation.cam_rel_poses"][:: self._downsample_factor].tolist(), dtype=np.float32)
            ),
        }
        if self._use_task_info:
            ret["task"] = th.from_numpy(
                np.array(df["observation.task_info"][:: self._downsample_factor].tolist(), dtype=np.float32)
            )
        return ret

    def _chunk_demo(self, demo_ptr: int, start_idx: int, end_idx: int) -> Generator[Tuple[dict, th.Tensor], None, None]:
        demo = self._all_demos[demo_ptr]
        # split obs into chunks
        for chunk_idx in range(start_idx, end_idx):
            data, mask = [], []
            s = np.s_[chunk_idx : chunk_idx + self._obs_window_size]
            data = dict()
            for k in demo:
                if k == "actions":
                    data[k] = any_slice(demo[k], np.s_[chunk_idx : chunk_idx + self._ctx_len])
                    action_chunk_size = get_batch_size(data[k], strict=True)
                    pad_size = self._ctx_len - action_chunk_size
                    if self._use_action_chunks:
                        assert pad_size == 0, "pad_size should be 0 if use_action_chunks is True!"
                        mask = demo["action_masks"][chunk_idx : chunk_idx + self._ctx_len]
                    else:
                        # pad action chunks to equal length of ctx_len
                        data[k] = any_concat(
                            [
                                data[k],
                            ]
                            + [any_ones_like(any_slice(data[k], np.s_[0:1]))] * pad_size,
                            dim=0,
                        )
                        mask = th.cat(
                            [
                                th.ones((action_chunk_size,), dtype=th.bool),
                                th.zeros((pad_size,), dtype=th.bool),
                            ],
                            dim=0,
                        )
                elif k != "action_masks":
                    data[k] = any_slice(demo[k], s)
                else:
                    # action_masks has already been processed
                    pass
            yield data, mask

    def _get_global_worker_id(self):
        worker_info = get_worker_info()
        worker_id = worker_info.id if worker_info is not None else 0
        if dist.is_initialized():
            rank = dist.get_rank()
            world_size = dist.get_world_size()
            num_workers = worker_info.num_workers if worker_info else 1
            global_worker_id = rank * num_workers + worker_id
            total_global_workers = world_size * num_workers
        else:
            global_worker_id = worker_id
            total_global_workers = worker_info.num_workers if worker_info else 1
        return global_worker_id, total_global_workers

    def _h5_window_generator(self, df: h5py.Dataset, start_idx: int, end_idx: int) -> Generator[th.Tensor, None, None]:
        for i in range(start_idx, end_idx):
            yield th.from_numpy(
                df[
                    i * self._downsample_factor : (i + self._obs_window_size)
                    * self._downsample_factor : self._downsample_factor
                ]
            )

__init__(*args, data_path, demo_keys, robot_type='R1Pro', obs_window_size, ctx_len, use_action_chunks=False, action_prediction_horizon=None, downsample_factor=1, visual_obs_types, multi_view_cameras=None, use_task_info=False, task_info_range=None, seed=42, shuffle=True, **kwargs)

Initialize the BehaviorIterableDataset.

Args:
    data_path (str): Path to the data directory.
    demo_keys (List[Any]): List of demo keys.
    robot_type (str): Type of the robot. Default is "R1Pro".
    obs_window_size (int): Size of the observation window.
    ctx_len (int): Context length.
    use_action_chunks (bool): Whether to use action chunks. Actions go from shape (T, A) to (T, L_pred_horizon, A).
    action_prediction_horizon (Optional[int]): Horizon of the action prediction. Must not be None if use_action_chunks is True.
    downsample_factor (int): Downsample factor for the data (uniform temporal subsampling). The original data is recorded at 30 Hz, so a factor of 3 yields 10 Hz. Default is 1 (no downsampling); must be >= 1.
    visual_obs_types (List[str]): List of visual observation types to load. Valid options are "rgb", "depth_linear", "seg_instance_id", and "pcd".
    multi_view_cameras (Optional[Dict[str, Any]]): Dict of id-camera pairs to load observations from.
    use_task_info (bool): Whether to load privileged task information.
    task_info_range (Optional[ListConfig]): Range of the task information (for normalization).
    seed (int): Random seed.
    shuffle (bool): Whether to shuffle the dataset (see the usage sketch below).
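As a hedged sketch of multi-worker training, the dataset can be wrapped in a standard PyTorch DataLoader; setting the epoch property before each pass triggers the deterministic reshuffle of demo order when shuffle=True. Here `dataset` is assumed to be a BehaviorIterableDataset constructed as in the earlier example, and the batch size, worker count, and epoch count are arbitrary choices.

from torch.utils.data import DataLoader

# shuffle must stay False (the default) for an IterableDataset; shuffling is handled inside the dataset
loader = DataLoader(dataset, batch_size=32, num_workers=4)

for epoch in range(10):
    dataset.epoch = epoch  # deterministic per-epoch reshuffle of demo order (seeded with seed + epoch)
    for batch in loader:
        pass  # training step goes here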

Source code in OmniGibson/omnigibson/learning/datas/iterable_dataset.py
def __init__(
    self,
    *args,
    data_path: str,
    demo_keys: List[Any],
    robot_type: str = "R1Pro",
    obs_window_size: int,
    ctx_len: int,
    use_action_chunks: bool = False,
    action_prediction_horizon: Optional[int] = None,
    downsample_factor: int = 1,
    visual_obs_types: List[str],
    multi_view_cameras: Optional[Dict[str, Any]] = None,
    use_task_info: bool = False,
    task_info_range: Optional[ListConfig] = None,
    seed: int = 42,
    shuffle: bool = True,
    **kwargs,
) -> None:
    """
    Initialize the BehaviorIterableDataset.
    Args:
        data_path (str): Path to the data directory.
        demo_keys (List[Any]): List of demo keys.
        robot_type (str): Type of the robot. Default is "R1Pro".
        obs_window_size (int): Size of the observation window.
        ctx_len (int): Context length.
        use_action_chunks (bool): Whether to use action chunks.
            Action will be from (T, A) to (T, L_pred_horizon, A)
        action_prediction_horizon (Optional[int]): Horizon of the action prediction.
            Must not be None if use_action_chunks is True.
        downsample_factor (int): Downsample factor for the data (with uniform temporal subsampling).
            Note that the original data is at 30Hz, so if factor=3 then data will be at 10Hz.
            Default is 1 (no downsampling), must be >= 1.
        visual_obs_types (List[str]): List of visual observation types to load.
            Valid options are: "rgb", "depth_linear", "seg_instance_id", "pcd".
        multi_view_cameras (Optional[Dict[str, Any]]): Dict of id-camera pairs to load obs from.
        use_task_info (bool): Whether to load privileged task information.
        task_info_range (Optional[ListConfig]): Range of the task information (for normalization).
        seed (int): Random seed.
        shuffle (bool): Whether to shuffle the dataset.
    """
    super().__init__()
    self._data_path = data_path
    self._demo_keys = demo_keys
    self._robot_type = robot_type
    self._obs_window_size = obs_window_size
    self._ctx_len = ctx_len
    self._use_action_chunks = use_action_chunks
    self._action_prediction_horizon = action_prediction_horizon
    assert (
        self._action_prediction_horizon is not None if self._use_action_chunks else True
    ), "action_prediction_horizon must be provided if use_action_chunks is True!"
    self._downsample_factor = downsample_factor
    assert self._downsample_factor >= 1, "downsample_factor must be >= 1!"
    self._use_task_info = use_task_info
    self._task_info_range = (
        th.tensor(OmegaConf.to_container(task_info_range)) if task_info_range is not None else None
    )
    self._seed = seed
    self._shuffle = shuffle
    self._epoch = 0

    assert set(visual_obs_types).issubset(
        {"rgb", "depth_linear", "seg_instance_id", "pcd"}
    ), "visual_obs_types must be a subset of {'rgb', 'depth_linear', 'seg_instance_id', 'pcd'}!"
    self._visual_obs_types = set(visual_obs_types)

    self._multi_view_cameras = multi_view_cameras

    self._demo_indices = list(range(len(self._demo_keys)))
    # Preload low dim into memory
    self._all_demos = [self._preload_demo(demo_key) for demo_key in self._demo_keys]
    # get demo lengths (N_chunks)
    self._demo_lengths = []
    for demo in self._all_demos:
        L = get_batch_size(demo, strict=True)
        assert L >= self._obs_window_size >= 1
        self._demo_lengths.append(L - self._obs_window_size + 1)
    logger.info(f"Dataset chunk length: {sum(self._demo_lengths)}")

BehaviorLeRobotDataset

Bases: LeRobotDataset

BehaviorLeRobotDataset is a customized dataset class for loading and managing LeRobot datasets, with additional filtering and loading options tailored for the BEHAVIOR-1K benchmark. This class extends LeRobotDataset and introduces the following customizations:

- Task-based filtering: load only episodes corresponding to specific tasks.
- Modality and camera selection: load only specified modalities (e.g., "rgb", "depth", "seg_instance_id") and cameras (e.g., "left_wrist", "right_wrist", "head").
- Ability to download and use additional annotation and metainfo files.
- Local-only mode: optionally restrict dataset usage to local files, disabling downloads.
- Optional batch streaming using keyframes for faster access.

These customizations allow for more efficient and targeted dataset usage in the context of B1K tasks.
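The snippet below is a minimal construction sketch. The repository id, cache root, and task list are illustrative assumptions; only the modality and camera names come from the documented valid sets.

from omnigibson.learning.datas.lerobot_dataset import BehaviorLeRobotDataset

dataset = BehaviorLeRobotDataset(
    repo_id="your-org/2025-challenge-demos",  # hypothetical HuggingFace dataset repo id
    root="~/lerobot_cache",                   # hypothetical local cache directory
    tasks=["turning_on_radio"],               # hypothetical subset of B1K task names; None loads all tasks
    modalities=["rgb", "depth"],              # subset of ["rgb", "depth", "seg_instance_id"]
    cameras=["head"],                         # subset of ["head", "left_wrist", "right_wrist"]
    chunk_streaming_using_keyframe=True,      # stream GOP-aligned chunks instead of random frame access
    shuffle=True,
    check_timestamp_sync=False,               # skip the sync check for faster loading
)

item = dataset[0]  # in chunk-streaming mode the index is ignored and frames are served chunk by chunk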

Source code in OmniGibson/omnigibson/learning/datas/lerobot_dataset.py
class BehaviorLeRobotDataset(LeRobotDataset):
    """
    BehaviorLeRobotDataset is a customized dataset class for loading and managing LeRobot datasets,
    with additional filtering and loading options tailored for the BEHAVIOR-1K benchmark.
    This class extends LeRobotDataset and introduces the following customizations:
        - Task-based filtering: Load only episodes corresponding to specific tasks.
        - Modality and camera selection: Load only specified modalities (e.g., "rgb", "depth", "seg_instance_id")
          and cameras (e.g., "left_wrist", "right_wrist", "head").
        - Ability to download and use additional annotation and metainfo files.
        - Local-only mode: Optionally restrict dataset usage to local files, disabling downloads.
        - Optional batch streaming using keyframe for faster access.
    These customizations allow for more efficient and targeted dataset usage in the context of B1K tasks
    """

    def __init__(
        self,
        repo_id: str,
        root: str | Path | None = None,
        episodes: list[int] | None = None,
        image_transforms: Callable | None = None,
        delta_timestamps: dict[list[float]] | None = None,
        tolerance_s: float = 1e-4,
        revision: str | None = None,
        force_cache_sync: bool = False,
        download_videos: bool = True,
        video_backend: str | None = "pyav",
        batch_encoding_size: int = 1,
        # === Customized arguments for BehaviorLeRobotDataset ===
        tasks: Iterable[str] = None,
        modalities: Iterable[str] = None,
        cameras: Iterable[str] = None,
        local_only: bool = False,
        check_timestamp_sync: bool = True,
        chunk_streaming_using_keyframe: bool = True,
        shuffle: bool = True,
        seed: int = 42,
    ):
        """
        Custom args:
            episodes (List[int]): list of episodes to use PER TASK.
                NOTE: This is different from the actual episode indices in the dataset.
                Rather, this is meant to be used for train/val split, or loading a specific amount of partial data.
                If set to None, all episodes will be loaded for a given task.
            tasks (List[str]): list of task names to load. If None, all tasks will be loaded.
            modalities (List[str]): list of modality names to load. If None, all modalities will be loaded.
                must be a subset of ["rgb", "depth", "seg_instance_id"]
            cameras (List[str]): list of camera names to load. If None, all cameras will be loaded.
                must be a subset of ["left_wrist", "right_wrist", "head"]
            local_only (bool): whether to only use local data (not download from HuggingFace).
                NOTE: set this to False and force_cache_sync to True if you want to force re-syncing the local cache with the remote dataset.
                For more details, please refer to the `force_cache_sync` argument in the base class.
            check_timestamp_sync (bool): whether to check timestamp synchronization between different modalities and the state/action data.
                While it is set to True in the original LeRobotDataset and is set to True here by default, it can be set to False to skip the check for faster loading.
                This will especially save time if you are loading the complete challenge demo dataset.
            chunk_streaming_using_keyframe (bool): whether to use chunk streaming mode for loading the dataset using keyframes.
                When this is enabled, the dataset will pseudo-randomly load data in chunks based on keyframes, allowing for faster access to the data.
                NOTE: As B1K challenge demos has GOP size of 250 frames for efficient storage, it is STRONGLY recommended to set this to True if you don't need true frame-level random access.
                When this is enabled, it is recommended to set shuffle to True for better randomness in chunk selection.
                We also enforce that segmentation instance ID videos can only be loaded in chunk_streaming_using_keyframe mode for faster access.
            shuffle (bool): whether to shuffle the chunks after loading. This ONLY applies in chunk streaming mode. Recommended to be set to True for better randomness in chunk selection.
            seed (int): random seed for shuffling chunks.
        """
        Dataset.__init__(self)
        self.repo_id = repo_id
        self.root = Path(os.path.expanduser(str(root))) if root else HF_LEROBOT_HOME / repo_id
        self.image_transforms = image_transforms
        self.delta_timestamps = delta_timestamps
        self.tolerance_s = tolerance_s
        self.revision = revision if revision else CODEBASE_VERSION
        self.video_backend = video_backend if video_backend else get_safe_default_codec()
        self.delta_indices = None
        self.batch_encoding_size = batch_encoding_size
        self.episodes_since_last_encoding = 0

        # Unused attributes
        self.image_writer = None
        self.episode_buffer = None

        self.root.mkdir(exist_ok=True, parents=True)

        # ========== Customizations ==========
        self.seed = seed
        if modalities is None:
            modalities = ["rgb", "depth", "seg_instance_id"]
        if "seg_instance_id" in modalities:
            assert chunk_streaming_using_keyframe, "For the sake of data loading speed, please use chunk_streaming_using_keyframe=True when loading segmentation instance ID videos."
        if "depth" in modalities:
            assert self.video_backend == "pyav", (
                "Depth videos can only be decoded with the 'pyav' backend. "
                "Please set video_backend='pyav' when initializing the dataset."
            )
        if cameras is None:
            cameras = ["head", "left_wrist", "right_wrist"]
        self.task_names = set(tasks) if tasks is not None else set(TASK_NAMES_TO_INDICES.keys())
        self.task_indices = [TASK_NAMES_TO_INDICES[task] for task in self.task_names]
        # Load metadata
        self.meta = BehaviorLerobotDatasetMetadata(
            repo_id=self.repo_id,
            root=self.root,
            revision=self.revision,
            force_cache_sync=force_cache_sync,
            tasks=self.task_names,
            modalities=modalities,
            cameras=cameras,
        )
        # overwrite episode based on task
        all_episodes = load_jsonlines(self.root / EPISODES_PATH)
        # get the episodes grouped by task
        epi_by_task = defaultdict(list)
        for item in all_episodes:
            if item["episode_index"] // 1e4 in self.meta.tasks:
                epi_by_task[item["episode_index"] // 1e4].append(item["episode_index"])
        # sort and cherrypick episodes within each task
        for task_id, ep_indices in epi_by_task.items():
            epi_by_task[task_id] = sorted(ep_indices)
            if episodes is not None:
                epi_by_task[task_id] = [epi_by_task[task_id][i] for i in episodes if i < len(epi_by_task[task_id])]
        # now put episodes back together
        self.episodes = sorted([ep for eps in epi_by_task.values() for ep in eps])
        # handle streaming mode and shuffling of episodes
        self._chunk_streaming_using_keyframe = chunk_streaming_using_keyframe
        if self._chunk_streaming_using_keyframe:
            if not shuffle:
                logger.warning(
                    "chunk_streaming_using_keyframe mode is enabled but shuffle is set to False. This may lead to less randomness in chunk selection."
                )
            self.chunks = self._get_keyframe_chunk_indices()
            # Now, we randomly permute the episodes if shuffle is True
            if shuffle:
                self.current_streaming_chunk_idx = None
                self.current_streaming_frame_idx = None
            else:
                self.current_streaming_chunk_idx = 0
                self.current_streaming_frame_idx = self.chunks[self.current_streaming_chunk_idx][0]
            self.obs_loaders = dict()
            self._should_obs_loaders_reload = True
        # record the positional index of each episode index within self.episodes
        self.episode_data_index_pos = {ep_idx: i for i, ep_idx in enumerate(self.episodes)}
        logger.info(f"Total episodes: {len(self.episodes)}")
        # ====================================

        if self.episodes is not None and self.meta._version >= packaging.version.parse("v2.1"):
            episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes]
            self.stats = aggregate_stats(episodes_stats)

        # Load actual data
        try:
            if force_cache_sync:
                raise FileNotFoundError
            for fpath in self.get_episodes_file_paths():
                assert (self.root / fpath).is_file(), f"Missing file: {self.root / fpath}"
            self.hf_dataset = self.load_hf_dataset()
        except (AssertionError, FileNotFoundError, NotADirectoryError) as e:
            if local_only:
                raise e
            self.revision = get_safe_version(self.repo_id, self.revision)
            self.download_episodes(download_videos)
            self.hf_dataset = self.load_hf_dataset()

        self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)

        # Check timestamps
        if check_timestamp_sync:
            timestamps = th.stack(self.hf_dataset["timestamp"]).numpy()
            episode_indices = th.stack(self.hf_dataset["episode_index"]).numpy()
            ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()}
            check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s)

        # Setup delta_indices
        if self.delta_timestamps is not None:
            check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
            self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)

    def get_episodes_file_paths(self) -> list[str]:
        """
        Overwrite the original method to use the episodes indices instead of range(self.meta.total_episodes)
        """
        episodes = self.episodes if self.episodes is not None else list(self.meta.episodes.keys())
        fpaths = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in episodes]
        # append metainfo and language annotations
        fpaths += [str(self.meta.get_metainfo_path(ep_idx)) for ep_idx in episodes]
        # TODO: add this back once we have all the language annotations
        # fpaths += [str(self.meta.get_annotation_path(ep_idx)) for ep_idx in episodes]
        if len(self.meta.video_keys) > 0:
            video_files = [
                str(self.meta.get_video_file_path(ep_idx, vid_key))
                for vid_key in self.meta.video_keys
                for ep_idx in episodes
            ]
            fpaths += video_files

        return fpaths

    def download_episodes(self, download_videos: bool = True) -> None:
        """
        Overwrite base method to allow more flexible pattern matching.
        Here, we do coarse filtering based on tasks, cameras, and modalities.
        We do this instead of filename patterns to speed up pattern checking and download speed.
        """
        allow_patterns = []
        if set(self.task_indices) != set(TASK_NAMES_TO_INDICES.values()):
            for task in self.task_indices:
                allow_patterns.append(f"**/task-{task:04d}/**")
        if len(self.meta.modalities) != 3:
            for modality in self.meta.modalities:
                if len(self.meta.camera_names) != 3:
                    for camera in self.meta.camera_names:
                        allow_patterns.append(f"**/observation.images.{modality}.{camera}/**")
                else:
                    allow_patterns.append(f"**/observation.images.{modality}.*/**")
        elif len(self.meta.camera_names) != 3:
            for camera in self.meta.camera_names:
                allow_patterns.append(f"**/observation.images.*.{camera}/**")
        ignore_patterns = []
        if not download_videos:
            ignore_patterns.append("videos/")
        if set(self.task_indices) != set(TASK_NAMES_TO_INDICES.values()):
            for task in set(TASK_NAMES_TO_INDICES.values()).difference(self.task_indices):
                ignore_patterns.append(f"**/task-{task:04d}/**")

        allow_patterns = None if allow_patterns == [] else allow_patterns
        ignore_patterns = None if ignore_patterns == [] else ignore_patterns
        self.pull_from_repo(allow_patterns=allow_patterns, ignore_patterns=ignore_patterns)

    def pull_from_repo(
        self,
        allow_patterns: list[str] | str | None = None,
        ignore_patterns: list[str] | str | None = None,
    ) -> None:
        """
        Overwrite base class to increase max workers to num of CPUs - 2
        """
        logger.info(f"Pulling dataset {self.repo_id} from HuggingFace hub...")
        snapshot_download(
            self.repo_id,
            repo_type="dataset",
            revision=self.revision,
            local_dir=self.root,
            allow_patterns=allow_patterns,
            ignore_patterns=ignore_patterns,
            max_workers=os.cpu_count() - 2,
        )

    def load_hf_dataset(self) -> datasets.Dataset:
        """hf_dataset contains all the observations, states, actions, rewards, etc."""
        if self.episodes is None:
            path = str(self.root / "data")
            hf_dataset = load_dataset("parquet", data_dir=path, split="train")
        else:
            files = [str(self.root / self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
            hf_dataset = load_dataset("parquet", data_files=files, split="train")

        hf_dataset.set_transform(hf_transform_to_torch)
        return hf_dataset

    def __getitem__(self, idx) -> dict:
        if not self._chunk_streaming_using_keyframe:
            return super().__getitem__(idx)
        # Streaming mode: we will load the episode at the current streaming index, and then increment the index for next call
        # Randomize chunk index on first call
        if self.current_streaming_chunk_idx is None:
            worker_info = get_worker_info()
            worker_id = 0 if worker_info is None else worker_info.id
            rng = np.random.default_rng(self.seed + worker_id)
            rng.shuffle(self.chunks)
            self.current_streaming_chunk_idx = rng.integers(0, len(self.chunks)).item()
            self.current_streaming_frame_idx = self.chunks[self.current_streaming_chunk_idx][0]
        # Current chunk iterated, move to next chunk
        if self.current_streaming_frame_idx >= self.chunks[self.current_streaming_chunk_idx][1]:
            self.current_streaming_chunk_idx += 1
            # All data iterated, restart from beginning
            if self.current_streaming_chunk_idx >= len(self.chunks):
                self.current_streaming_chunk_idx = 0
            self.current_streaming_frame_idx = self.chunks[self.current_streaming_chunk_idx][0]
            self._should_obs_loaders_reload = True
        item = self.hf_dataset[self.current_streaming_frame_idx]
        ep_idx = item["episode_index"].item()

        if self._should_obs_loaders_reload:
            for loader in self.obs_loaders.values():
                loader.close()
            self.obs_loaders = dict()
            # reload video loaders for new episode
            self.current_streaming_episode_idx = ep_idx
            for vid_key in self.meta.video_keys:
                kwargs = {}
                task_id = item["task_index"].item()
                if "seg_instance_id" in vid_key:
                    # load id list
                    with open(
                        self.root / "meta/episodes" / f"task-{task_id:04d}" / f"episode_{ep_idx:08d}.json",
                        "r",
                    ) as f:
                        kwargs["id_list"] = th.tensor(
                            json.load(f)[f"{ROBOT_CAMERA_NAMES['R1Pro'][vid_key.split('.')[-1]]}::unique_ins_ids"]
                        )
                self.obs_loaders[vid_key] = iter(
                    OBS_LOADER_MAP[vid_key.split(".")[2]](
                        data_path=self.root,
                        task_id=task_id,
                        camera_id=vid_key.split(".")[-1],
                        demo_id=f"{ep_idx:08d}",
                        start_idx=self.chunks[self.current_streaming_chunk_idx][2],
                        start_idx_is_keyframe=False,  # TODO (Wensi): Change this to True after figuring out the correct keyframe indices
                        batch_size=1,
                        stride=1,
                        **kwargs,
                    )
                )
            self._should_obs_loaders_reload = False

        query_indices = None
        if self.delta_indices is not None:
            query_indices, padding = self._get_query_indices(self.current_streaming_frame_idx, ep_idx)
            query_result = self._query_hf_dataset(query_indices)
            item = {**item, **padding}
            for key, val in query_result.items():
                item[key] = val

        # load visual observations
        for key in self.meta.video_keys:
            item[key] = next(self.obs_loaders[key])[0]

        if self.image_transforms is not None:
            image_keys = self.meta.camera_keys
            for cam in image_keys:
                item[cam] = self.image_transforms(item[cam])

        # Add task as a string
        task_idx = item["task_index"].item()
        item["task"] = self.meta.tasks[task_idx]
        self.current_streaming_frame_idx += 1

        return item

    def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]:
        ep_idx = self.episode_data_index_pos[ep_idx]
        ep_start = self.episode_data_index["from"][ep_idx]
        ep_end = self.episode_data_index["to"][ep_idx]
        query_indices = {
            key: [max(ep_start.item(), min(ep_end.item() - 1, idx + delta)) for delta in delta_idx]
            for key, delta_idx in self.delta_indices.items()
        }
        padding = {  # Pad values outside of current episode range
            f"{key}_is_pad": th.BoolTensor(
                [(idx + delta < ep_start.item()) | (idx + delta >= ep_end.item()) for delta in delta_idx]
            )
            for key, delta_idx in self.delta_indices.items()
        }
        return query_indices, padding

    def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, th.Tensor]:
        """Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
        in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a
        Segmentation Fault. This probably happens because a memory reference to the video loader is created in
        the main process and a subprocess fails to access it.
        """
        item = {}
        for vid_key, query_ts in query_timestamps.items():
            video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
            frames = decode_video_frames(video_path, query_ts, self.tolerance_s, self.video_backend)
            item[vid_key] = frames.squeeze(0)

        return item

    def _get_keyframe_chunk_indices(self, chunk_size=250) -> List[Tuple[int, int, int]]:
        """
        Divide each episode into chunks of data based on GOP of the data (here for B1K, GOP size is 250 frames).
        Args:
            chunk_size (int): size of each chunk in number of frames. Default is 250 for B1K. Should be the GOP size of the video data.
        Returns:
            List of tuples, where each tuple contains (start_index, end_index, local_start_index) for each chunk.
        """
        episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in self.meta.episodes.items()}
        episode_lengths = [episode_lengths[ep_idx] for ep_idx in self.episodes]
        chunks = []
        offset = 0
        for L in episode_lengths:
            local_starts = list(range(0, L, chunk_size))
            local_ends = local_starts[1:] + [L]
            for ls, le in zip(local_starts, local_ends):
                chunks.append((offset + ls, offset + le, ls))
            offset += L
        return chunks

__init__(repo_id, root=None, episodes=None, image_transforms=None, delta_timestamps=None, tolerance_s=0.0001, revision=None, force_cache_sync=False, download_videos=True, video_backend='pyav', batch_encoding_size=1, tasks=None, modalities=None, cameras=None, local_only=False, check_timestamp_sync=True, chunk_streaming_using_keyframe=True, shuffle=True, seed=42)

Custom args

episodes (List[int]): List of episodes to use PER TASK. NOTE: this is different from the actual episode indices in the dataset; rather, it is meant for train/val splits or for loading a specific amount of partial data. If set to None, all episodes will be loaded for a given task.

tasks (List[str]): List of task names to load. If None, all tasks will be loaded.

modalities (List[str]): List of modality names to load. If None, all modalities will be loaded. Must be a subset of ["rgb", "depth", "seg_instance_id"].

cameras (List[str]): List of camera names to load. If None, all cameras will be loaded. Must be a subset of ["left_wrist", "right_wrist", "head"].

local_only (bool): Whether to only use local data (no downloads from HuggingFace). NOTE: set this to False and force_cache_sync to True to force re-syncing the local cache with the remote dataset. For more details, refer to the force_cache_sync argument in the base class.

check_timestamp_sync (bool): Whether to check timestamp synchronization between the different modalities and the state/action data. It defaults to True here, as in the original LeRobotDataset, but can be set to False to skip the check for faster loading, which is especially worthwhile when loading the complete challenge demo dataset.

chunk_streaming_using_keyframe (bool): Whether to stream the dataset in chunks based on keyframes. When enabled, the dataset pseudo-randomly loads data in keyframe-aligned chunks, allowing much faster access. NOTE: since the B1K challenge demos have a GOP size of 250 frames for efficient storage, it is STRONGLY recommended to set this to True unless you need true frame-level random access. When enabled, it is also recommended to set shuffle to True for better randomness in chunk selection. Segmentation instance ID videos can only be loaded in chunk_streaming_using_keyframe mode. A sketch of the chunking arithmetic is given after this list.

shuffle (bool): Whether to shuffle the chunks after loading. This ONLY applies in chunk streaming mode. Recommended to be True for better randomness in chunk selection.

seed (int): Random seed for shuffling chunks.
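As a hedged illustration of the keyframe chunking described above, the helper below reproduces the arithmetic of _get_keyframe_chunk_indices: each episode is cut into windows of the GOP size (250 frames for the B1K demos), and each chunk is recorded as (global start, global end, episode-local start). The episode lengths used here are made-up numbers.

from typing import List, Tuple

def keyframe_chunks(episode_lengths: List[int], chunk_size: int = 250) -> List[Tuple[int, int, int]]:
    chunks, offset = [], 0
    for length in episode_lengths:
        local_starts = list(range(0, length, chunk_size))
        local_ends = local_starts[1:] + [length]
        for ls, le in zip(local_starts, local_ends):
            chunks.append((offset + ls, offset + le, ls))
        offset += length
    return chunks

print(keyframe_chunks([600, 300]))
# [(0, 250, 0), (250, 500, 250), (500, 600, 500), (600, 850, 0), (850, 900, 250)]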

Source code in OmniGibson/omnigibson/learning/datas/lerobot_dataset.py
def __init__(
    self,
    repo_id: str,
    root: str | Path | None = None,
    episodes: list[int] | None = None,
    image_transforms: Callable | None = None,
    delta_timestamps: dict[list[float]] | None = None,
    tolerance_s: float = 1e-4,
    revision: str | None = None,
    force_cache_sync: bool = False,
    download_videos: bool = True,
    video_backend: str | None = "pyav",
    batch_encoding_size: int = 1,
    # === Customized arguments for BehaviorLeRobotDataset ===
    tasks: Iterable[str] = None,
    modalities: Iterable[str] = None,
    cameras: Iterable[str] = None,
    local_only: bool = False,
    check_timestamp_sync: bool = True,
    chunk_streaming_using_keyframe: bool = True,
    shuffle: bool = True,
    seed: int = 42,
):
    """
    Custom args:
        episodes (List[int]): list of episodes to use PER TASK.
            NOTE: This is different from the actual episode indices in the dataset.
            Rather, this is meant to be used for train/val split, or loading a specific amount of partial data.
            If set to None, all episodes will be loaded for a given task.
        tasks (List[str]): list of task names to load. If None, all tasks will be loaded.
        modalities (List[str]): list of modality names to load. If None, all modalities will be loaded.
            must be a subset of ["rgb", "depth", "seg_instance_id"]
        cameras (List[str]): list of camera names to load. If None, all cameras will be loaded.
            must be a subset of ["left_wrist", "right_wrist", "head"]
        local_only (bool): whether to only use local data (not download from HuggingFace).
            NOTE: set this to False and force_cache_sync to True if you want to force re-syncing the local cache with the remote dataset.
            For more details, please refer to the `force_cache_sync` argument in the base class.
        check_timestamp_sync (bool): whether to check timestamp synchronization between different modalities and the state/action data.
            While it is set to True in the original LeRobotDataset and is set to True here by default, it can be set to False to skip the check for faster loading.
            This will especially save time if you are loading the complete challenge demo dataset.
        chunk_streaming_using_keyframe (bool): whether to use chunk streaming mode for loading the dataset using keyframes.
            When this is enabled, the dataset will pseudo-randomly load data in chunks based on keyframes, allowing for faster access to the data.
            NOTE: As B1K challenge demos has GOP size of 250 frames for efficient storage, it is STRONGLY recommended to set this to True if you don't need true frame-level random access.
            When this is enabled, it is recommended to set shuffle to True for better randomness in chunk selection.
            We also enforce that segmentation instance ID videos can only be loaded in chunk_streaming_using_keyframe mode for faster access.
        shuffle (bool): whether to shuffle the chunks after loading. This ONLY applies in chunk streaming mode. Recommended to be set to True for better randomness in chunk selection.
        seed (int): random seed for shuffling chunks.
    """
    Dataset.__init__(self)
    self.repo_id = repo_id
    self.root = Path(os.path.expanduser(str(root))) if root else HF_LEROBOT_HOME / repo_id
    self.image_transforms = image_transforms
    self.delta_timestamps = delta_timestamps
    self.tolerance_s = tolerance_s
    self.revision = revision if revision else CODEBASE_VERSION
    self.video_backend = video_backend if video_backend else get_safe_default_codec()
    self.delta_indices = None
    self.batch_encoding_size = batch_encoding_size
    self.episodes_since_last_encoding = 0

    # Unused attributes
    self.image_writer = None
    self.episode_buffer = None

    self.root.mkdir(exist_ok=True, parents=True)

    # ========== Customizations ==========
    self.seed = seed
    if modalities is None:
        modalities = ["rgb", "depth", "seg_instance_id"]
    if "seg_instance_id" in modalities:
        assert chunk_streaming_using_keyframe, "For the sake of data loading speed, please use chunk_streaming_using_keyframe=True when loading segmentation instance ID videos."
    if "depth" in modalities:
        assert self.video_backend == "pyav", (
            "Depth videos can only be decoded with the 'pyav' backend. "
            "Please set video_backend='pyav' when initializing the dataset."
        )
    if cameras is None:
        cameras = ["head", "left_wrist", "right_wrist"]
    self.task_names = set(tasks) if tasks is not None else set(TASK_NAMES_TO_INDICES.keys())
    self.task_indices = [TASK_NAMES_TO_INDICES[task] for task in self.task_names]
    # Load metadata
    self.meta = BehaviorLerobotDatasetMetadata(
        repo_id=self.repo_id,
        root=self.root,
        revision=self.revision,
        force_cache_sync=force_cache_sync,
        tasks=self.task_names,
        modalities=modalities,
        cameras=cameras,
    )
    # overwrite episode based on task
    all_episodes = load_jsonlines(self.root / EPISODES_PATH)
    # get the episodes grouped by task
    epi_by_task = defaultdict(list)
    for item in all_episodes:
        if item["episode_index"] // 1e4 in self.meta.tasks:
            epi_by_task[item["episode_index"] // 1e4].append(item["episode_index"])
    # sort and cherrypick episodes within each task
    for task_id, ep_indices in epi_by_task.items():
        epi_by_task[task_id] = sorted(ep_indices)
        if episodes is not None:
            epi_by_task[task_id] = [epi_by_task[task_id][i] for i in episodes if i < len(epi_by_task[task_id])]
    # now put episodes back together
    self.episodes = sorted([ep for eps in epi_by_task.values() for ep in eps])
    # handle streaming mode and shuffling of episodes
    self._chunk_streaming_using_keyframe = chunk_streaming_using_keyframe
    if self._chunk_streaming_using_keyframe:
        if not shuffle:
            logger.warning(
                "chunk_streaming_using_keyframe mode is enabled but shuffle is set to False. This may lead to less randomness in chunk selection."
            )
        self.chunks = self._get_keyframe_chunk_indices()
        # If shuffling, defer chunk/frame selection until iteration; otherwise start streaming from the first chunk
        if shuffle:
            self.current_streaming_chunk_idx = None
            self.current_streaming_frame_idx = None
        else:
            self.current_streaming_chunk_idx = 0
            self.current_streaming_frame_idx = self.chunks[self.current_streaming_chunk_idx][0]
        self.obs_loaders = dict()
        self._should_obs_loaders_reload = True
    # record the positional index of each episode index within self.episodes
    self.episode_data_index_pos = {ep_idx: i for i, ep_idx in enumerate(self.episodes)}
    logger.info(f"Total episodes: {len(self.episodes)}")
    # ====================================

    if self.episodes is not None and self.meta._version >= packaging.version.parse("v2.1"):
        episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes]
        self.stats = aggregate_stats(episodes_stats)

    # Load actual data
    try:
        if force_cache_sync:
            raise FileNotFoundError
        for fpath in self.get_episodes_file_paths():
            assert (self.root / fpath).is_file(), f"Missing file: {self.root / fpath}"
        self.hf_dataset = self.load_hf_dataset()
    except (AssertionError, FileNotFoundError, NotADirectoryError) as e:
        if local_only:
            raise e
        self.revision = get_safe_version(self.repo_id, self.revision)
        self.download_episodes(download_videos)
        self.hf_dataset = self.load_hf_dataset()

    self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)

    # Check timestamps
    if check_timestamp_sync:
        timestamps = th.stack(self.hf_dataset["timestamp"]).numpy()
        episode_indices = th.stack(self.hf_dataset["episode_index"]).numpy()
        ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()}
        check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s)

    # Setup delta_indices
    if self.delta_timestamps is not None:
        check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
        self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
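
A minimal construction sketch (not part of the source): the repo id and task name below are placeholders, and only arguments documented above are set; adjust them to the dataset you actually use.

from omnigibson.learning.datas.lerobot_dataset import BehaviorLeRobotDataset

dataset = BehaviorLeRobotDataset(
    repo_id="your-org/behavior-challenge-demos",  # placeholder repo id
    tasks=["turning_on_radio"],                   # hypothetical task name
    modalities=["rgb", "depth"],                  # subset of ["rgb", "depth", "seg_instance_id"]
    cameras=["head", "left_wrist"],               # subset of ["left_wrist", "right_wrist", "head"]
    episodes=[0, 1, 2],                           # first three episodes of each selected task
    video_backend="pyav",                         # required because "depth" is requested
    chunk_streaming_using_keyframe=True,          # recommended: demos are stored with a 250-frame GOP
    shuffle=True,
    seed=42,
)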

download_episodes(download_videos=True)

Overwrite base method to allow more flexible pattern matching. Here, we do coarse filtering based on tasks, cameras, and modalities rather than exact filename patterns, which speeds up both pattern checking and downloading.

Source code in OmniGibson/omnigibson/learning/datas/lerobot_dataset.py
def download_episodes(self, download_videos: bool = True) -> None:
    """
    Overwrite base method to allow more flexible pattern matching.
    Here, we do coarse filtering based on tasks, cameras, and modalities
    rather than exact filename patterns, which speeds up both pattern checking and downloading.
    """
    allow_patterns = []
    if set(self.task_indices) != set(TASK_NAMES_TO_INDICES.values()):
        for task in self.task_indices:
            allow_patterns.append(f"**/task-{task:04d}/**")
    if len(self.meta.modalities) != 3:
        for modality in self.meta.modalities:
            if len(self.meta.camera_names) != 3:
                for camera in self.meta.camera_names:
                    allow_patterns.append(f"**/observation.images.{modality}.{camera}/**")
            else:
                allow_patterns.append(f"**/observation.images.{modality}.*/**")
    elif len(self.meta.camera_names) != 3:
        for camera in self.meta.camera_names:
            allow_patterns.append(f"**/observation.images.*.{camera}/**")
    ignore_patterns = []
    if not download_videos:
        ignore_patterns.append("videos/")
    if set(self.task_indices) != set(TASK_NAMES_TO_INDICES.values()):
        for task in set(TASK_NAMES_TO_INDICES.values()).difference(self.task_indices):
            ignore_patterns.append(f"**/task-{task:04d}/**")

    allow_patterns = None if allow_patterns == [] else allow_patterns
    ignore_patterns = None if ignore_patterns == [] else ignore_patterns
    self.pull_from_repo(allow_patterns=allow_patterns, ignore_patterns=ignore_patterns)
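
To illustrate the filtering (not from the source): with all tasks selected, modalities limited to rgb, and cameras limited to head and left_wrist, the method builds roughly the following filters before delegating to pull_from_repo.

allow_patterns = [
    "**/observation.images.rgb.head/**",
    "**/observation.images.rgb.left_wrist/**",
]
ignore_patterns = None  # all tasks are kept and videos are wanted, so nothing is excluded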

get_episodes_file_paths()

Overwrite the original method to use the episode indices instead of range(self.meta.total_episodes)

Source code in OmniGibson/omnigibson/learning/datas/lerobot_dataset.py
def get_episodes_file_paths(self) -> list[str]:
    """
    Overwrite the original method to use the episode indices instead of range(self.meta.total_episodes)
    """
    episodes = self.episodes if self.episodes is not None else list(self.meta.episodes.keys())
    fpaths = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in episodes]
    # append metainfo and language annotations
    fpaths += [str(self.meta.get_metainfo_path(ep_idx)) for ep_idx in episodes]
    # TODO: add this back once we have all the language annotations
    # fpaths += [str(self.meta.get_annotation_path(ep_idx)) for ep_idx in episodes]
    if len(self.meta.video_keys) > 0:
        video_files = [
            str(self.meta.get_video_file_path(ep_idx, vid_key))
            for vid_key in self.meta.video_keys
            for ep_idx in episodes
        ]
        fpaths += video_files

    return fpaths
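
A small sketch of how these paths are used by the local integrity check in __init__, assuming the dataset instance from the construction sketch above: every returned path must already exist under the dataset root, otherwise the missing episodes are downloaded.

missing = [
    fpath
    for fpath in dataset.get_episodes_file_paths()
    if not (dataset.root / fpath).is_file()
]
if missing:
    print(f"{len(missing)} files missing locally, e.g. {missing[0]}")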

load_hf_dataset()

hf_dataset contains all the observations, states, actions, rewards, etc.

Source code in OmniGibson/omnigibson/learning/datas/lerobot_dataset.py
def load_hf_dataset(self) -> datasets.Dataset:
    """hf_dataset contains all the observations, states, actions, rewards, etc."""
    if self.episodes is None:
        path = str(self.root / "data")
        hf_dataset = load_dataset("parquet", data_dir=path, split="train")
    else:
        files = [str(self.root / self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
        hf_dataset = load_dataset("parquet", data_files=files, split="train")

    hf_dataset.set_transform(hf_transform_to_torch)
    return hf_dataset
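
Once loaded, individual frames can be read straight from the underlying Hugging Face dataset; hf_transform_to_torch converts values to torch tensors. A small usage sketch, limited to keys referenced in __init__:

frame = dataset.hf_dataset[0]
print(frame["episode_index"], frame["timestamp"])  # tensors used by the timestamp-sync check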

pull_from_repo(allow_patterns=None, ignore_patterns=None)

Overwrite base class to increase max_workers to the number of CPUs minus 2

Source code in OmniGibson/omnigibson/learning/datas/lerobot_dataset.py
def pull_from_repo(
    self,
    allow_patterns: list[str] | str | None = None,
    ignore_patterns: list[str] | str | None = None,
) -> None:
    """
    Overwrite base class to increase max_workers to the number of CPUs minus 2
    """
    logger.info(f"Pulling dataset {self.repo_id} from HuggingFace hub...")
    snapshot_download(
        self.repo_id,
        repo_type="dataset",
        revision=self.revision,
        local_dir=self.root,
        allow_patterns=allow_patterns,
        ignore_patterns=ignore_patterns,
        max_workers=os.cpu_count() - 2,
    )
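
For example, pulling only the metadata folder (mirroring what BehaviorLerobotDatasetMetadata does on a cache miss) looks like this:

dataset.pull_from_repo(allow_patterns="meta/**", ignore_patterns="meta/episodes/**")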

BehaviorLerobotDatasetMetadata

Bases: LeRobotDatasetMetadata

Source code in OmniGibson/omnigibson/learning/datas/lerobot_dataset.py
class BehaviorLerobotDatasetMetadata(LeRobotDatasetMetadata):
    """
    BehaviorLerobotDatasetMetadata extends LeRobotDatasetMetadata with the following customizations:
        1. Restricts the set of allowed modalities to {"rgb", "depth", "seg_instance_id"}.
        2. Restricts the set of allowed camera names to those defined in ROBOT_CAMERA_NAMES["R1Pro"].
        3. Provides a filtered view of dataset features, including only those corresponding to the selected modalities and camera names.
    """

    def __init__(
        self,
        repo_id: str,
        root: str | Path | None = None,
        revision: str | None = None,
        force_cache_sync: bool = False,
        # === Customized arguments for BehaviorLeRobotDataset ===
        tasks: Iterable[str] | None = None,
        modalities: Iterable[str] | None = None,
        cameras: Iterable[str] | None = None,
    ):
        # ========== Customizations ==========
        self.task_name_candidates = set(tasks) if tasks is not None else set(TASK_NAMES_TO_INDICES.keys())
        self.modalities = set(modalities)
        self.camera_names = set(cameras)
        assert self.modalities.issubset(
            {"rgb", "depth", "seg_instance_id"}
        ), f"Modalities must be a subset of ['rgb', 'depth', 'seg_instance_id'], but got {self.modalities}"
        assert self.camera_names.issubset(
            ROBOT_CAMERA_NAMES["R1Pro"]
        ), f"Camera names must be a subset of {ROBOT_CAMERA_NAMES['R1Pro']}, but got {self.camera_names}"
        # ===================================

        self.repo_id = repo_id
        self.revision = revision if revision else CODEBASE_VERSION
        self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id

        try:
            if force_cache_sync:
                raise FileNotFoundError
            self.load_metadata()
        except (FileNotFoundError, NotADirectoryError):
            if is_valid_version(self.revision):
                self.revision = get_safe_version(self.repo_id, self.revision)

            (self.root / "meta").mkdir(exist_ok=True, parents=True)
            self.pull_from_repo(allow_patterns="meta/**", ignore_patterns="meta/episodes/**")
            self.load_metadata()

    def load_metadata(self):
        self.info = load_info(self.root)
        check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
        self.tasks, self.task_to_task_index, self.task_names = self.load_tasks(self.root)
        # filter based on self.task_name_candidates
        valid_task_indices = [idx for idx, name in self.task_names.items() if name in self.task_name_candidates]
        self.task_names = set([self.task_names[idx] for idx in valid_task_indices])
        self.tasks = {idx: self.tasks[idx] for idx in valid_task_indices}
        self.task_to_task_index = {v: k for k, v in self.tasks.items()}

        self.episodes = self.load_episodes(self.root)
        if self._version < packaging.version.parse("v2.1"):
            self.stats = self.load_stats(self.root)
            self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes)
        else:
            self.episodes_stats = self.load_episodes_stats(self.root)
            self.stats = aggregate_stats(list(self.episodes_stats.values()))
        logger.info(f"Loaded metadata for {len(self.episodes)} episodes.")

    def load_tasks(self, local_dir: Path) -> tuple[dict, dict, dict]:
        tasks = load_jsonlines(local_dir / TASKS_PATH)
        task_names = {item["task_index"]: item["task_name"] for item in sorted(tasks, key=lambda x: x["task_index"])}
        tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
        task_to_task_index = {task: task_index for task_index, task in tasks.items()}
        return tasks, task_to_task_index, task_names

    def load_episodes(self, local_dir: Path) -> dict:
        episodes = load_jsonlines(local_dir / EPISODES_PATH)
        return {
            item["episode_index"]: item
            for item in sorted(episodes, key=lambda x: x["episode_index"])
            if item["episode_index"] // 1e4 in self.tasks
        }

    def load_stats(self, local_dir: Path) -> dict[str, dict[str, np.ndarray]]:
        if not (local_dir / STATS_PATH).exists():
            return None
        stats = load_json(local_dir / STATS_PATH)
        return cast_stats_to_numpy(stats)

    def load_episodes_stats(self, local_dir: Path) -> dict:
        episodes_stats = load_jsonlines(local_dir / EPISODES_STATS_PATH)
        return {
            item["episode_index"]: cast_stats_to_numpy(item["stats"])
            for item in sorted(episodes_stats, key=lambda x: x["episode_index"])
            if item["episode_index"] in self.episodes
        }

    def get_annotation_path(self, ep_index: int) -> Path:
        ep_chunk = self.get_episode_chunk(ep_index)
        fpath = self.annotation_path.format(episode_chunk=ep_chunk, episode_index=ep_index)
        return Path(fpath)

    def get_metainfo_path(self, ep_index: int) -> Path:
        ep_chunk = self.get_episode_chunk(ep_index)
        fpath = self.metainfo_path.format(episode_chunk=ep_chunk, episode_index=ep_index)
        return Path(fpath)

    @property
    def annotation_path(self) -> str | None:
        """Formattable string for the annotation files."""
        return self.info["annotation_path"]

    @property
    def metainfo_path(self) -> str | None:
        """Formattable string for the metainfo files."""
        return self.info["metainfo_path"]

    @property
    def features(self) -> dict[str, dict]:
        """All features contained in the dataset."""
        features = dict()
        # keep only the features for the selected modalities and cameras
        for name in self.info["features"].keys():
            if (
                name.startswith("observation.images.")
                and name.split(".")[-1] in self.camera_names
                and name.split(".")[-2] in self.modalities
            ):
                features[name] = self.info["features"][name]
        return features
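
A standalone metadata-loading sketch (the repo id is a placeholder): only features whose modality and camera suffixes match the selections appear in features, i.e. keys of the form observation.images.<modality>.<camera>.

from omnigibson.learning.datas.lerobot_dataset import BehaviorLerobotDatasetMetadata

meta = BehaviorLerobotDatasetMetadata(
    repo_id="your-org/behavior-challenge-demos",  # placeholder repo id
    modalities=["rgb"],
    cameras=["head"],
)
print(list(meta.features))  # e.g. only keys like "observation.images.rgb.head"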

annotation_path property

Formattable string for the annotation files.

features property

All features contained in the dataset.

metainfo_path property

Formattable string for the metainfo files.