Skip to content

policies

LocalPolicy

Local policy that directly queries action from policy, outputs zero delta action if policy is None.

Source code in OmniGibson/omnigibson/learning/policies.py
class LocalPolicy:
    """
    Local policy that directly queries action from policy,
        outputs zero delta action if policy is None.
    """

    def __init__(self, *args, action_dim: Optional[int] = None, **kwargs) -> None:
        self.policy = None  # To be set later
        self.action_dim = action_dim

    def act(self, obs: dict) -> th.Tensor:
        return self.forward(obs)

    def forward(self, obs: dict, *args, **kwargs) -> th.Tensor:
        """
        Directly return a zero action tensor of the specified action dimension.
        """
        if self.policy is not None:
            return self.policy.act(obs).detach().cpu()
        else:
            assert self.action_dim is not None
            return th.zeros(self.action_dim, dtype=th.float32)

    def reset(self) -> None:
        if self.policy is not None:
            self.policy.reset()

forward(obs, *args, **kwargs)

Directly return a zero action tensor of the specified action dimension.

Source code in OmniGibson/omnigibson/learning/policies.py
def forward(self, obs: dict, *args, **kwargs) -> th.Tensor:
    """
    Directly return a zero action tensor of the specified action dimension.
    """
    if self.policy is not None:
        return self.policy.act(obs).detach().cpu()
    else:
        assert self.action_dim is not None
        return th.zeros(self.action_dim, dtype=th.float32)

WebsocketPolicy

Websocket policy for controlling the robot over a websocket connection.

Source code in OmniGibson/omnigibson/learning/policies.py
class WebsocketPolicy:
    """
    Websocket policy for controlling the robot over a websocket connection.
    """

    def __init__(
        self,
        *args,
        host: Optional[str] = None,
        port: Optional[int] = None,
        allow_reconnect: bool = False,
        **kwargs,
    ) -> None:
        logging.info(f"Creating websocket client policy with host: {host}, port: {port}")
        self.last_action = None
        self.policy = None
        self._allow_reconnect = allow_reconnect
        if host is not None or port is not None:
            self.policy = WebsocketClientPolicy(host=host, port=port, allow_reconnect=allow_reconnect)

    def update_host(self, host: str, port: int) -> None:
        self.policy = WebsocketClientPolicy(host=host, port=port, allow_reconnect=self._allow_reconnect)

    def forward(self, obs: dict, *args, **kwargs) -> th.Tensor:
        if "need_new_action" in obs and not obs["need_new_action"] and self.last_action is not None:
            return self.last_action
        # convert observation to numpy
        obs = torch_to_numpy(obs)
        self.last_action = self.policy.act(obs).detach().cpu()
        return self.last_action

    def reset(self) -> None:
        if self.policy is not None:
            self.policy.reset()
        self.last_action = None