
lerobot_utils

DepthVideoReader

Bases: VideoReader

Adapted from torchvision.io.VideoReader to support gray16le decoding for depth

Source code in OmniGibson/omnigibson/learning/utils/lerobot_utils.py
class DepthVideoReader(VideoReader):
    """
    Adapted from torchvision.io.VideoReader to support gray16le decoding for depth
    """

    def __next__(self) -> Dict[str, Any]:
        """Decodes and returns the next frame of the current stream.
        Frames are encoded as a dict with mandatory
        data and pts fields, where data is a tensor, and pts is a
        presentation timestamp of the frame expressed in seconds
        as a float.

        Returns:
            (dict): a dictionary containing the decoded frame (``data``)
            and its corresponding timestamp (``pts``) in seconds

        """
        try:
            frame = next(self._c)
            pts = float(frame.pts * frame.time_base)
            if "video" in self.pyav_stream:
                frame = th.as_tensor(
                    dequantize_depth(
                        frame.reformat(format="gray16le").to_ndarray(),
                        min_depth=MIN_DEPTH,
                        max_depth=MAX_DEPTH,
                        shift=DEPTH_SHIFT,
                    )
                )
            elif "audio" in self.pyav_stream:
                frame = th.as_tensor(frame.to_ndarray()).permute(1, 0)
            else:
                frame = None
        except av.error.EOFError:
            raise StopIteration

        if frame is None or frame.numel() == 0:
            raise StopIteration

        return {"data": frame, "pts": pts}

__next__()

Decodes and returns the next frame of the current stream. Frames are encoded as a dict with mandatory data and pts fields, where data is a tensor, and pts is a presentation timestamp of the frame expressed in seconds as a float.

Returns:

Type Description
Dict[str, Any]

a dictionary containing the decoded frame (data) and its corresponding timestamp (pts) in seconds

Source code in OmniGibson/omnigibson/learning/utils/lerobot_utils.py
def __next__(self) -> Dict[str, Any]:
    """Decodes and returns the next frame of the current stream.
    Frames are encoded as a dict with mandatory
    data and pts fields, where data is a tensor, and pts is a
    presentation timestamp of the frame expressed in seconds
    as a float.

    Returns:
        (dict): a dictionary containing the decoded frame (``data``)
        and its corresponding timestamp (``pts``) in seconds

    """
    try:
        frame = next(self._c)
        pts = float(frame.pts * frame.time_base)
        if "video" in self.pyav_stream:
            frame = th.as_tensor(
                dequantize_depth(
                    frame.reformat(format="gray16le").to_ndarray(),
                    min_depth=MIN_DEPTH,
                    max_depth=MAX_DEPTH,
                    shift=DEPTH_SHIFT,
                )
            )
        elif "audio" in self.pyav_stream:
            frame = th.as_tensor(frame.to_ndarray()).permute(1, 0)
        else:
            frame = None
    except av.error.EOFError:
        raise StopIteration

    if frame is None or frame.numel() == 0:
        raise StopIteration

    return {"data": frame, "pts": pts}
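
The reader is used like torchvision's VideoReader: seek, then iterate. Below is a minimal usage sketch with a hypothetical depth-video path, assuming the module-level MIN_DEPTH, MAX_DEPTH, and DEPTH_SHIFT constants are configured:

from omnigibson.learning.utils.lerobot_utils import DepthVideoReader

# Hypothetical path; each yielded dict holds "data" (a dequantized depth tensor)
# and "pts" (the presentation timestamp in seconds).
reader = DepthVideoReader("episode_0/depth.mp4", "video")
reader.seek(0.0, keyframes_only=True)  # pyav only supports keyframe-accurate seeks
for frame in reader:
    depth = frame["data"]      # th.Tensor of dequantized depth values
    timestamp = frame["pts"]   # float, seconds
reader.container.close()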

aggregate_feature_stats(stats_ft_list)

Aggregates stats for a single feature.

Source code in OmniGibson/omnigibson/learning/utils/lerobot_utils.py
def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
    """Aggregates stats for a single feature."""
    means = np.stack([s["mean"] for s in stats_ft_list])
    variances = np.stack([s["std"] ** 2 for s in stats_ft_list])
    counts = np.stack([s["count"] for s in stats_ft_list])
    q01 = np.stack([s["q01"] for s in stats_ft_list])
    q99 = np.stack([s["q99"] for s in stats_ft_list])
    total_count = counts.sum(axis=0)

    # Prepare weighted mean by matching number of dimensions
    while counts.ndim < means.ndim:
        counts = np.expand_dims(counts, axis=-1)

    # Compute the weighted mean
    weighted_means = means * counts
    total_mean = weighted_means.sum(axis=0) / total_count

    # Compute the variance using the parallel algorithm
    delta_means = means - total_mean
    weighted_variances = (variances + delta_means**2) * counts
    total_variance = weighted_variances.sum(axis=0) / total_count

    # Compute weighted quantiles
    weighted_q01 = np.percentile(q01, 1, axis=0)
    weighted_q99 = np.percentile(q99, 99, axis=0)

    return {
        "min": np.min(np.stack([s["min"] for s in stats_ft_list]), axis=0),
        "max": np.max(np.stack([s["max"] for s in stats_ft_list]), axis=0),
        "mean": total_mean,
        "std": np.sqrt(total_variance),
        "q01": weighted_q01,
        "q99": weighted_q99,
        "count": total_count,
    }
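
A sketch of combining per-dataset stats for one feature, using made-up numbers; each entry must provide the keys the function reads (min, max, mean, std, q01, q99, count):

import numpy as np
from omnigibson.learning.utils.lerobot_utils import aggregate_feature_stats

# Two hypothetical per-dataset stats for a 2-dimensional feature.
stats_a = {
    "min": np.array([0.0, -1.0]), "max": np.array([1.0, 1.0]),
    "mean": np.array([0.5, 0.0]), "std": np.array([0.1, 0.2]),
    "q01": np.array([0.01, -0.9]), "q99": np.array([0.99, 0.9]),
    "count": np.array([1000]),
}
stats_b = {
    "min": np.array([-0.5, -1.0]), "max": np.array([2.0, 1.0]),
    "mean": np.array([1.0, 0.1]), "std": np.array([0.2, 0.2]),
    "q01": np.array([-0.4, -0.9]), "q99": np.array([1.9, 0.9]),
    "count": np.array([3000]),
}
combined = aggregate_feature_stats([stats_a, stats_b])
# combined["mean"] is the count-weighted mean, combined["std"] follows the
# parallel-variance formula, and q01/q99 are percentiles over the stacked per-dataset quantiles.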

aggregate_stats(stats_list)

Aggregate stats from multiple compute_stats outputs into a single set of stats.

The final stats will have the union of all data keys from each of the stats dicts.

For instance:
- new_min = min(min_dataset_0, min_dataset_1, ...)
- new_max = max(max_dataset_0, max_dataset_1, ...)
- new_mean = (mean of all data, weighted by counts)
- new_std = (std of all data)

Source code in OmniGibson/omnigibson/learning/utils/lerobot_utils.py
def aggregate_stats(stats_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
    """Aggregate stats from multiple compute_stats outputs into a single set of stats.

    The final stats will have the union of all data keys from each of the stats dicts.

    For instance:
    - new_min = min(min_dataset_0, min_dataset_1, ...)
    - new_max = max(max_dataset_0, max_dataset_1, ...)
    - new_mean = (mean of all data, weighted by counts)
    - new_std = (std of all data)
    """

    _assert_type_and_shape(stats_list)

    data_keys = {key for stats in stats_list for key in stats}
    aggregated_stats = {key: {} for key in data_keys}

    for key in data_keys:
        stats_with_key = [stats[key] for stats in stats_list if key in stats]
        aggregated_stats[key] = aggregate_feature_stats(stats_with_key)

    return aggregated_stats
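
Continuing the sketch above, full stats dicts are keyed by feature name and the aggregate keeps the union of keys; a feature present in only one dataset is aggregated from that dataset alone:

from omnigibson.learning.utils.lerobot_utils import aggregate_stats

# Hypothetical per-dataset stats; "action" only appears in the second dataset.
dataset_0_stats = {"observation.state": stats_a}
dataset_1_stats = {"observation.state": stats_b, "action": stats_b}
merged = aggregate_stats([dataset_0_stats, dataset_1_stats])
sorted(merged)  # ['action', 'observation.state']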

decode_video_frames_torchvision(video_path, timestamps, tolerance_s, log_loaded_timestamps=False, backend=None)

Adapted from lerobot's decode_video_frames_torchvision to handle depth decoding

Source code in OmniGibson/omnigibson/learning/utils/lerobot_utils.py
def decode_video_frames_torchvision(
    video_path: Path | str,
    timestamps: list[float],
    tolerance_s: float,
    log_loaded_timestamps: bool = False,
    backend: str | None = None,
) -> th.Tensor:
    """
    Adapted from lerobot's decode_video_frames_torchvision to handle depth decoding
    """
    video_path = str(video_path)

    # set backend
    keyframes_only = False
    if "depth" in video_path:
        backend = "pyav"
    torchvision.set_video_backend(backend)
    if backend == "pyav":
        keyframes_only = True  # pyav doesn't support accurate seek

    # set a video stream reader
    # TODO(rcadene): also load audio stream at the same time
    if "depth" in video_path:
        reader = DepthVideoReader(video_path, "video")
    else:
        reader = VideoReader(video_path, "video")

    # set the first and last requested timestamps
    # Note: previous timestamps are usually loaded, since we need to access the previous key frame
    first_ts = min(timestamps) - 5  # a little backward to account for timestamp mismatch
    last_ts = max(timestamps)

    # access closest key frame of the first requested frame
    # Note: closest key frame timestamp is usually smaller than `first_ts` (e.g. key frame can be the first frame of the video)
    # for details on what `seek` is doing see: https://pyav.basswood-io.com/docs/stable/api/container.html?highlight=inputcontainer#av.container.InputContainer.seek
    reader.seek(first_ts, keyframes_only=keyframes_only)

    # load all frames until last requested frame
    loaded_frames = []
    loaded_ts = []
    for frame in reader:
        current_ts = frame["pts"]
        if log_loaded_timestamps:
            logging.info(f"frame loaded at timestamp={current_ts:.4f}")
        loaded_frames.append(frame["data"])
        loaded_ts.append(current_ts)
        if current_ts >= last_ts:
            break

    reader.container.close()

    reader = None

    query_ts = th.tensor(timestamps)
    loaded_ts = th.tensor(loaded_ts)

    # compute distances between each query timestamp and timestamps of all loaded frames
    dist = th.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
    min_, argmin_ = dist.min(1)

    is_within_tol = min_ < tolerance_s
    assert is_within_tol.all(), (
        f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
        "It means that the closest frame that can be loaded from the video is too far away in time."
        "This might be due to synchronization issues with timestamps during data collection."
        "To be safe, we advise to ignore this item during training."
        f"\nqueried timestamps: {query_ts}"
        f"\nloaded timestamps: {loaded_ts}"
        f"\nvideo: {video_path}"
        f"\nbackend: {backend}"
    )

    # get closest frames to the query timestamps
    closest_frames = th.stack([loaded_frames[idx] for idx in argmin_])
    closest_ts = loaded_ts[argmin_]

    if log_loaded_timestamps:
        logging.info(f"{closest_ts=}")

    # convert to the pytorch format which is float32 in [0,1] range (and channel first)
    closest_frames = closest_frames.type(th.float32)
    if "depth" not in video_path:
        closest_frames = closest_frames / 255

    assert len(timestamps) == len(closest_frames)
    return closest_frames
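
A minimal usage sketch with hypothetical paths and timestamps; a "depth" substring in the path is what routes decoding through DepthVideoReader with the pyav backend, and depth frames are returned dequantized rather than rescaled to [0, 1]:

from omnigibson.learning.utils.lerobot_utils import decode_video_frames_torchvision

# Hypothetical RGB video at ~30 fps; tolerance_s bounds the allowed gap between
# each requested timestamp and the closest decoded frame.
frames = decode_video_frames_torchvision(
    video_path="episode_0/rgb.mp4",
    timestamps=[0.0, 1 / 30, 2 / 30],
    tolerance_s=1 / 60,
    backend="pyav",
)
# frames: float32 tensor with one frame per requested timestamp, in [0, 1] for RGB videos.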

hf_transform_to_torch(items_dict)

Adapted from lerobot.datasets.utils.hf_transform_to_torch; preserves float64 for timestamps to avoid precision issues.

Original docstring: Get a transform function that converts items from a Hugging Face dataset (pyarrow) to torch tensors. Importantly, images are converted from PIL, which corresponds to a channel-last representation (h w c) of uint8 type, to a torch image representation with channel first (c h w) of float32 type in range [0,1].

Source code in OmniGibson/omnigibson/learning/utils/lerobot_utils.py
def hf_transform_to_torch(items_dict: dict[th.Tensor | None]):
    """
    Adapted from lerobot.datasets.utils.hf_transform_to_torch
    Preserve float64 for timestamp to avoid precision issues
    Below is the original docstring:
    Get a transform function that converts items from a Hugging Face dataset (pyarrow)
    to torch tensors. Importantly, images are converted from PIL, which corresponds to
    a channel last representation (h w c) of uint8 type, to a torch image representation
    with channel first (c h w) of float32 type in range [0,1].
    """
    for key in items_dict:
        if key == "timestamp":
            items_dict[key] = [x if isinstance(x, str) else th.tensor(x, dtype=th.float64) for x in items_dict[key]]
        else:
            first_item = items_dict[key][0]
            if isinstance(first_item, PILImage.Image):
                to_tensor = transforms.ToTensor()
                items_dict[key] = [to_tensor(img) for img in items_dict[key]]
            elif first_item is None:
                pass
            else:
                items_dict[key] = [x if isinstance(x, str) else th.tensor(x) for x in items_dict[key]]
    return items_dict
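
A sketch of applying the transform to a toy batch (it is typically installed as the transform on a Hugging Face dataset); timestamps stay float64, PIL images become channel-first float32 tensors in [0, 1], and other numeric fields become plain tensors:

import numpy as np
from PIL import Image
from omnigibson.learning.utils.lerobot_utils import hf_transform_to_torch

# Hypothetical channel-last batch as it would come out of a pyarrow-backed dataset.
batch = {
    "timestamp": [0.0, 1 / 30],
    "observation.image": [Image.new("RGB", (64, 64)) for _ in range(2)],
    "action": [np.zeros(7, dtype=np.float32), np.ones(7, dtype=np.float32)],
}
batch = hf_transform_to_torch(batch)
batch["timestamp"][0].dtype          # torch.float64
batch["observation.image"][0].shape  # torch.Size([3, 64, 64]), float32 in [0, 1]
batch["action"][0].dtype             # torch.float32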