Local policy that queries actions directly from the wrapped policy,
or outputs a zero (delta) action if the policy is None.
Source code in OmniGibson/omnigibson/eval/policies.py
class LocalPolicy:
    """
    Local policy that directly queries actions from a wrapped policy object,
    or outputs a zero (delta) action if no policy is set.

    The wrapped policy is assigned to ``self.policy`` after construction
    (it starts as ``None``). While it is ``None``, ``forward`` returns a zero
    tensor of length ``action_dim``. ``action_dim`` must therefore be provided
    either to the constructor or via ``set_action_dim`` before acting without
    a policy.
    """

    def __init__(self, *args, action_dim: Optional[int] = None, **kwargs) -> None:
        """
        Args:
            *args: Accepted and ignored by this class.
            action_dim: Length of the zero action returned when no policy is set.
            **kwargs: Accepted and ignored by this class.
        """
        self.policy = None  # To be set later
        self.action_dim = action_dim

    def set_action_dim(self, action_dim: int) -> None:
        """Set the length of the zero action returned when no policy is set."""
        self.action_dim = action_dim

    def act(self, obs: dict) -> th.Tensor:
        """Return the action for ``obs``; equivalent to ``self.forward(obs)``."""
        return self.forward(obs)

    def forward(self, obs: dict, *args, **kwargs) -> th.Tensor:
        """
        Compute an action for ``obs``.

        If a policy is set, return ``self.policy.act(obs)`` detached from the
        autograd graph and moved to CPU. Otherwise return a float32 zero tensor
        of shape ``(action_dim,)``.

        Args:
            obs: Observation, passed unchanged to the wrapped policy's ``act``.
            *args: Accepted and ignored.
            **kwargs: Accepted and ignored.

        Returns:
            th.Tensor: The wrapped policy's action (detached, on CPU), or zeros.

        Raises:
            ValueError: If no policy is set and ``action_dim`` is None.
        """
        if self.policy is not None:
            # Detach and move to CPU so callers get a plain CPU tensor
            # regardless of where the policy computed it.
            return self.policy.act(obs).detach().cpu()
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if self.action_dim is None:
            raise ValueError(
                "LocalPolicy has no policy set and action_dim is None; "
                "pass action_dim to the constructor or call set_action_dim()."
            )
        return th.zeros(self.action_dim, dtype=th.float32)

    def reset(self) -> None:
        """Reset the wrapped policy if one is set; no-op otherwise."""
        if self.policy is not None:
            self.policy.reset()
|
forward(obs, *args, **kwargs)
Return the wrapped policy's action for obs (detached and moved to CPU), or a zero action tensor of the specified action dimension if no policy is set.
Source code in OmniGibson/omnigibson/eval/policies.py
def forward(self, obs: dict, *args, **kwargs) -> th.Tensor:
    """
    Compute an action for ``obs``.

    If a policy is set, return ``self.policy.act(obs)`` detached from the
    autograd graph and moved to CPU. Otherwise return a float32 zero tensor
    of shape ``(action_dim,)``.

    Args:
        obs: Observation, passed unchanged to the wrapped policy's ``act``.
        *args: Accepted and ignored.
        **kwargs: Accepted and ignored.

    Returns:
        th.Tensor: The wrapped policy's action (detached, on CPU), or zeros.

    Raises:
        ValueError: If no policy is set and ``action_dim`` is None.
    """
    if self.policy is not None:
        # Detach and move to CPU so callers get a plain CPU tensor
        # regardless of where the policy computed it.
        return self.policy.act(obs).detach().cpu()
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if self.action_dim is None:
        raise ValueError(
            "LocalPolicy has no policy set and action_dim is None; "
            "pass action_dim to the constructor or call set_action_dim()."
        )
    return th.zeros(self.action_dim, dtype=th.float32)
|