Skip to content

metrics_wrapper

MetricsWrapper

Bases: EnvironmentWrapper

Wrapper for running programmatic metric checks during env stepping

Source code in OmniGibson/omnigibson/envs/metrics_wrapper.py
class MetricsWrapper(EnvironmentWrapper):
    """
    Wrapper for running programmatic metric checks during env stepping.

    Metrics are registered by name via add_metric(). Every registered metric is reset after the
    wrapped environment is reset, is stepped after every environment step, and can be queried for
    its aggregated results via aggregate_metrics().
    """

    def __init__(self, env: Environment) -> None:
        """
        Args:
            env (Environment): The environment to wrap
        """
        # Registry of tracked metrics, keyed by the name passed to add_metric(). It is created
        # before delegating to the parent constructor.
        self.metrics = dict()

        # Run super init
        super().__init__(env=env)

    def add_metric(self, name: str, metric: MetricBase) -> None:
        """
        Adds a data metric to track

        Args:
            name (str): Name of the metric. This will be the name printed out when displaying the aggregated
                results. If a metric is already registered under this name, it is replaced.
            metric (MetricBase): Metric to add

        Raises:
            AssertionError: If @metric reports that it is not compatible with this environment
        """
        # Validate explicitly instead of with a bare `assert`, which Python strips when run with
        # optimizations (-O) and would then let incompatible metrics through silently.
        # AssertionError is kept so existing callers that catch it keep working.
        if not metric.is_compatible(self):
            raise AssertionError(f"Metric {metric.__class__.__name__} is not compatible with this environment!")
        self.metrics[name] = metric

    def remove_metric(self, name: str) -> None:
        """
        Removes a metric from the internally tracked ones

        Args:
            name (str): Name of the metric to remove

        Raises:
            KeyError: If no metric is registered under @name
        """
        self.metrics.pop(name)

    def reset(self):
        """
        Resets the wrapped environment, then resets every tracked metric.

        Returns:
            Whatever the parent wrapper's reset() returns, passed through unchanged
        """
        # Reset the environment first so the metrics observe its post-reset state
        ret = super().reset()

        for metric in self.metrics.values():
            metric.reset(self)

        return ret

    def aggregate_metrics(self, flatten: bool = True) -> dict:
        """
        Aggregates metrics information

        Args:
            flatten (bool): Whether to flatten the metrics dictionary (via recursively_generate_flat_dict) or not

        Returns:
            dict: Keyword-mapped aggregated metrics information, keyed by metric name
        """
        results = {name: metric.aggregate(self) for name, metric in self.metrics.items()}

        if flatten:
            results = recursively_generate_flat_dict(dic=results)

        return results

    def step(self, action: dict | Iterable, n_render_iterations: int = 1) -> tuple:
        """
        Steps the wrapped environment, then forwards the step results to every tracked metric.

        Args:
            action (dict | Iterable): Action to apply, passed through to the parent step()
            n_render_iterations (int): Passed through to the parent step()

        Returns:
            tuple: (obs, reward, terminated, truncated, info), exactly as returned by the parent step()
        """
        obs, reward, terminated, truncated, info = super().step(action, n_render_iterations=n_render_iterations)

        # NOTE(review): metrics receive the underlying `self.env` here, whereas is_compatible(),
        # reset() and aggregate() receive this wrapper (`self`). Confirm the asymmetry is intentional.
        for metric in self.metrics.values():
            metric.step(self.env, action, obs, reward, terminated, truncated, info)

        return obs, reward, terminated, truncated, info

__init__(env)

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `env` | `Environment` | The environment to wrap | *required* |
Source code in OmniGibson/omnigibson/envs/metrics_wrapper.py
def __init__(self, env: Environment) -> None:
    """
    Creates the wrapper around @env with an empty metric registry.

    Args:
        env (Environment): The environment to wrap
    """
    # Registry of tracked metrics, keyed by the name passed to add_metric(); it is created
    # before delegating to the parent constructor
    self.metrics = {}
    super().__init__(env=env)

add_metric(name, metric)

Adds a data metric to track

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `name` | `str` | Name of the metric. This will be the name printed out when displaying the aggregated results | *required* |
| `metric` | `MetricBase` | Metric to add | *required* |
Source code in OmniGibson/omnigibson/envs/metrics_wrapper.py
def add_metric(self, name: str, metric: MetricBase) -> None:
    """
    Adds a data metric to track

    Args:
        name (str): Name of the metric. This will be the name printed out when displaying the aggregated
            results. If a metric is already registered under this name, it is replaced.
        metric (MetricBase): Metric to add

    Raises:
        AssertionError: If @metric reports that it is not compatible with this environment
    """
    # Validate explicitly instead of with a bare `assert`, which Python strips when run with
    # optimizations (-O) and would then let incompatible metrics through silently.
    # AssertionError is kept so existing callers that catch it keep working.
    if not metric.is_compatible(self):
        raise AssertionError(f"Metric {metric.__class__.__name__} is not compatible with this environment!")
    self.metrics[name] = metric

aggregate_metrics(flatten=True)

Aggregates metrics information

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `flatten` | `bool` | Whether to flatten the metrics dictionary or not | `True` |

Returns:

| Type | Description |
|------|-------------|
| `dict` | Keyword-mapped aggregated metrics information |

Source code in OmniGibson/omnigibson/envs/metrics_wrapper.py
def aggregate_metrics(self, flatten: bool = True) -> dict:
    """
    Collects the aggregated result of every tracked metric.

    Args:
        flatten (bool): If True, the per-metric results are passed through
            recursively_generate_flat_dict before being returned

    Returns:
        dict: Mapping from metric name to that metric's aggregated result
            (flattened when @flatten is True)
    """
    per_metric = {label: tracked.aggregate(self) for label, tracked in self.metrics.items()}
    return recursively_generate_flat_dict(dic=per_metric) if flatten else per_metric

remove_metric(name)

Removes a metric from the internally tracked ones

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `name` | `str` | Name of the metric to remove | *required* |
Source code in OmniGibson/omnigibson/envs/metrics_wrapper.py
def remove_metric(self, name: str) -> None:
    """
    Stops tracking the metric registered under @name.

    Args:
        name (str): Name the metric was registered with

    Raises:
        KeyError: If no metric is registered under @name
    """
    del self.metrics[name]