Skip to content

object_state_base

AbsoluteObjectState

Bases: BaseObjectState

This class is used to track object states that are absolute, i.e., they do not require a second object to compute the value.

Source code in object_states/object_state_base.py
class AbsoluteObjectState(BaseObjectState):
    """
    Template for object states whose value depends only on the owning object, i.e. no second object is needed to
    compute it.

    Subclasses must implement _get_value() and _set_value(new_value); both are declared abstract here.
    """

    @abstractmethod
    def _get_value(self):
        # Concrete absolute states compute and return their value here
        raise NotImplementedError

    @abstractmethod
    def _set_value(self, new_value):
        # Concrete absolute states apply @new_value here
        raise NotImplementedError

    @classproperty
    def _do_not_register_classes(cls):
        # This is an abstract template, so exclude it from registration in addition to whatever the parent excludes
        excluded = super()._do_not_register_classes
        excluded.add("AbsoluteObjectState")
        return excluded

BaseObjectState

Bases: Serializable, Registerable, Recreatable, ABC

Base ObjectState class. Do NOT inherit from this class directly - use either AbsoluteObjectState or RelativeObjectState.

Source code in object_states/object_state_base.py
class BaseObjectState(Serializable, Registerable, Recreatable, ABC):
    """
    Base ObjectState class. Do NOT inherit from this class directly - use either AbsoluteObjectState or
    RelativeObjectState.

    get_value() caches each computed value per argument combination, together with the simulator timestep
    (og.sim.current_time_step_index) at which it was computed. A cached value is reused within that same timestep,
    and on later timesteps only when _cache_is_valid() reports it as still valid.
    """

    @staticmethod
    def get_dependencies():
        """
        Get the dependency states for this state, e.g. states that need to be explicitly enabled on the current object
        before the current state is usable. States listed here will be enabled for all objects that have this current
        state, and all dependency states will be processed on *all* objects prior to this state being processed on
        *any* object.

        Returns:
            list of str: List of strings corresponding to state keys.
        """
        return []

    @staticmethod
    def get_optional_dependencies():
        """
        Get states that should be processed prior to this state if they are already enabled. These states will not be
        enabled because of this state's dependency on them, but if they are already enabled for another reason (e.g.
        because of an ability or another state's dependency etc.), they will be processed on *all* objects prior to this
        state being processed on *any* object.

        Returns:
            list of str: List of strings corresponding to state keys.
        """
        return []

    def __init__(self, obj):
        """
        Args:
            obj: The object this state belongs to; stored as @self.obj
        """
        super().__init__()
        self.obj = obj
        self._initialized = False
        # The value cache and the has_changed() memo are only created by reset(), which initialize() calls
        self._cache = None
        self._changed = None
        self._simulator = None

    @property
    def stateful(self):
        """
        Returns:
            bool: True if this object has a state that can be directly dumped / loaded via dump_state() and
                load_state(), otherwise, returns False. Note that any sub object states that are NOT stateful do
                not need to implement any of _dump_state(), _load_state(), _serialize(), or _deserialize()!
        """
        # Default is whether state size > 0
        return self.state_size > 0

    @property
    def state_size(self):
        """
        Returns:
            int: Size of this state. The default of 0 makes @self.stateful False
        """
        return 0

    def reset(self):
        """
        Resets this object state. By default, it clears all internal caching data: both the value cache and the
        memoized has_changed() results
        """
        self._cache = OrderedDict()
        self._changed = OrderedDict()

    @property
    def cache(self):
        """
        Returns:
            OrderedDict: Dictionary mapping specific argument combinations from @self.get_value() to cached values and
                information stored for that specific combination
        """
        return self._cache

    def _update(self):
        """
        This function will be called once for every simulator step (via update()). Does nothing by default.
        """
        pass

    def _initialize(self):
        """
        Subclass hook for one-time setup. Called once by initialize(), after the simulator reference has been stored
        and the caches have been reset. Does nothing by default.
        """
        pass

    def initialize(self, simulator):
        """
        Initialize this object state

        Args:
            simulator: Simulator instance; stored as @self._simulator
        """
        assert not self._initialized, "State is already initialized."

        # Store simulator reference and create cache
        self._simulator = simulator
        self.reset()

        self._initialize()
        self._initialized = True

    def update(self):
        """
        Updates the object state: discards the memoized has_changed() results (the value cache is left untouched)
        and then runs the subclass hook _update()

        Returns:
            any: Whatever _update() returns (None by default)
        """
        assert self._initialized, "Cannot update uninitialized state."
        # Clear all the changed values
        self._changed = OrderedDict()
        return self._update()

    def clear_cache(self):
        """
        Clears the internal cache
        """
        # Clear all entries
        self._cache = OrderedDict()

    def update_cache(self, get_value_args):
        """
        Updates the internal cached value based on the evaluation of @self._get_value(*get_value_args)

        Args:
            get_value_args (tuple): Specific argument combinations (usually tuple of objects) passed into
                @self.get_value / @self._get_value
        """
        t = og.sim.current_time_step_index
        # Compute value and update cache
        val = self._get_value(*get_value_args)
        self._cache[get_value_args] = OrderedDict(value=val, info=self.cache_info(get_value_args=get_value_args), t=t)

    def cache_info(self, get_value_args):
        """
        Helper function to cache relevant information at the current timestep.
        Stores it under @self._cache[<KEY>]["info"]

        Args:
            get_value_args (tuple): Specific argument combinations (usually tuple of objects) passed into
                @self.get_value whose caching information should be computed

        Returns:
            OrderedDict: Any caching information to include at the current timestep when this state's value is computed
        """
        # Default is an empty dictionary
        return OrderedDict()

    def cache_is_valid(self, get_value_args):
        """
        Helper function to check whether the current cached value is valid or not at the current timestep.
        An entry computed at the current timestep is always valid; otherwise the result of _cache_is_valid() is
        returned (False by default).

        Args:
            get_value_args (tuple): Specific argument combinations (usually tuple of objects) passed into
                @self.get_value whose cached values should be validated

        Returns:
            bool: True if the cache is valid, else False
        """
        # An entry computed during the current timestep is always valid
        if self._cache[get_value_args]["t"] == og.sim.current_time_step_index:
            return True
        # Otherwise, defer to the subclass-specific check
        return self._cache_is_valid(get_value_args=get_value_args)

    def _cache_is_valid(self, get_value_args):
        """
        Helper function to check whether the current cached value is valid or not at the current timestep.
        Default is False. Subclasses should implement special logic otherwise.

        Args:
            get_value_args (tuple): Specific argument combinations (usually tuple of objects) passed into
                @self.get_value whose cached values should be validated

        Returns:
            bool: True if the cache is valid, else False
        """
        return False

    def has_changed(self, get_value_args, value, info, t):
        """
        A helper function to query whether this object state has changed between an arbitrary previous timestep @t
        (with corresponding cached value @value and cache information @info) and the current timestep.

        Note that this may require some non-trivial compute, so we leverage @t, in addition to @get_value_args,
        as a unique key into an internal dictionary, such that each specific @t results in the computation being
        conducted exactly once (until the next update()).
        This is done for performance reasons, so that multiple states relying on the same state dependency can all
        query whether that state has changed between the same timesteps with only a single computation.

        Args:
            get_value_args (tuple): Specific argument combinations (usually tuple of objects) passed into
                @self.get_value
            value (any): Cached value computed at timestep @t for this object state
            info (OrderedDict): Information calculated at timestep @t when computing this state's value
            t (int): Initial timestep to compare against. This should be an index of the steps taken,
                i.e. a value queried from og.sim.current_time_step_index at some point in time. It is assumed @value
                and @info were computed at this timestep

        Returns:
            bool: Whether this object state has changed between @t and the current timestep index for the specific
                @get_value_args
        """
        # Memo key: the reference timestep plus the specific argument combination
        history_key = (t, *get_value_args)
        # If t == the current timestep, then we obviously haven't changed so our value is False
        if t == og.sim.current_time_step_index:
            val = False
        # Otherwise, check if it already exists in our has changed dictionary; we return that value if so
        elif history_key in self._changed:
            val = self._changed[history_key]
        # Otherwise, we calculate the value and store it in our changed dictionary
        else:
            val = self._has_changed(get_value_args=get_value_args, value=value, info=info)
            self._changed[history_key] = val

        return val

    def _has_changed(self, get_value_args, value, info):
        """
        Checks whether the previous value evaluated at time @t has changed with the current timestep.
        By default, it returns True.

        Any custom checks should be overridden by subclass.

        Args:
            get_value_args (tuple): Specific argument combinations (usually tuple of objects) passed into
                @self.get_value
            value (any): Cached value computed at timestep @t for this object state
            info (OrderedDict): Information calculated at timestep @t when computing this state's value

        Returns:
            bool: Whether the value has changed between @value and @info and the corresponding value and info computed
                at the current timestep
        """
        return True

    def get_value(self, *args, **kwargs):
        """
        Get this state's value

        Returns:
            any: Object state value given input @args and @kwargs
        """
        assert self._initialized

        # Build the cache key. NOTE: kwargs *values* are appended after @args in the caller's order and the names are
        # dropped; update_cache() later passes this whole tuple positionally to _get_value(). Keyword arguments
        # therefore act like positional ones: they must be given in the order _get_value() declares them, and
        # get_value(x, y=1) shares a cache entry with get_value(x, 1).
        key = (*args, *kwargs.values())
        # We need to see if we need to update our cache -- we do so if and only if one of the following conditions are met:
        # (a) key is NOT in the cache
        # (b) Our cache is not valid
        if key not in self._cache or not self.cache_is_valid(get_value_args=key):
            # Update the cache
            self.update_cache(get_value_args=key)

        # Value is the cached value
        val = self._cache[key]["value"]

        return val

    def _get_value(self, *args, **kwargs):
        # Implemented by AbsoluteObjectState / RelativeObjectState subclasses
        raise NotImplementedError

    def set_value(self, *args, **kwargs):
        """
        Set this state's value

        Returns:
            bool: True if setting the value was successful, otherwise False
        """
        assert self._initialized
        # Clear cache because the state may be changed
        # NOTE(review): memoized has_changed() results in self._changed are kept until the next update(), so an
        # answer computed earlier in this step may not reflect this set -- confirm this is intended
        self.clear_cache()
        # Set the value
        val = self._set_value(*args, **kwargs)
        return val

    def _set_value(self, *args, **kwargs):
        # Implemented by AbsoluteObjectState / RelativeObjectState subclasses
        raise NotImplementedError

    def remove(self):
        """
        Any cleanup functionality to deploy when @self.obj is removed from the simulator
        """
        pass

    def dump_state(self, serialized=False):
        """
        Dump this state via the parent class implementation. Only initialized, stateful states may be dumped.

        Args:
            serialized (bool): Forwarded unchanged to the parent implementation

        Returns:
            any: Whatever the parent implementation returns
        """
        assert self._initialized
        assert self.stateful
        return super().dump_state(serialized=serialized)

    @classproperty
    def _do_not_register_classes(cls):
        # Don't register this class since it's an abstract template
        classes = super()._do_not_register_classes
        classes.add("BaseObjectState")
        return classes

    @classproperty
    def _cls_registry(cls):
        # Global registry. The module-level name is only read here, so no `global` declaration is needed
        return REGISTERED_OBJECT_STATES

cache property

Returns:

Name Type Description
OrderedDict

Dictionary mapping specific argument combinations from @self.get_value() to cached values and information stored for that specific combination

stateful property

Returns:

Name Type Description
bool

True if this object has a state that can be directly dumped / loaded via dump_state() and load_state(), otherwise, returns False. Note that any sub object states that are NOT stateful do not need to implement any of _dump_state(), _load_state(), _serialize(), or _deserialize()!

cache_info(get_value_args)

Helper function to cache relevant information at the current timestep. Stores it under @self._cache[<KEY>]["info"]

Parameters:

Name Type Description Default
get_value_args tuple

Specific argument combinations (usually tuple of objects) passed into @self.get_value whose caching information should be computed

required

Returns:

Name Type Description
OrderedDict

Any caching information to include at the current timestep when this state's value is computed

Source code in object_states/object_state_base.py
def cache_info(self, get_value_args):
    """
    Compute extra information to store alongside a freshly computed value; update_cache() places the result under
    @self._cache[<KEY>]["info"].

    The base implementation records nothing.

    Args:
        get_value_args (tuple): Argument combination (usually a tuple of objects) passed to @self.get_value for
            which caching information should be computed

    Returns:
        OrderedDict: Caching information for the current timestep (empty by default)
    """
    info = OrderedDict()
    return info

cache_is_valid(get_value_args)

Helper function to check whether the current cached value is valid or not at the current timestep. Default is False unless we're at the current timestep.

Parameters:

Name Type Description Default
get_value_args tuple

Specific argument combinations (usually tuple of objects) passed into @self.get_value whose cached values should be validated

required

Returns:

Name Type Description
bool

True if the cache is valid, else False

Source code in object_states/object_state_base.py
def cache_is_valid(self, get_value_args):
    """
    Decide whether the cached entry for @get_value_args may still be used at the current timestep.

    An entry computed during the current timestep is always considered valid. Older entries are judged by
    _cache_is_valid(), which returns False unless a subclass overrides it.

    Args:
        get_value_args (tuple): Argument combination (usually a tuple of objects) passed to @self.get_value whose
            cached entry should be checked

    Returns:
        bool: True if the cached entry can be reused, else False
    """
    cached_t = self._cache[get_value_args]["t"]
    if cached_t == og.sim.current_time_step_index:
        # Computed during this very timestep, so it cannot be stale
        return True
    # Older entry: defer to the (subclass-overridable) staleness check
    return self._cache_is_valid(get_value_args=get_value_args)

clear_cache()

Clears the internal cache

Source code in object_states/object_state_base.py
def clear_cache(self):
    """
    Drop every cached value by replacing @self._cache with a fresh, empty OrderedDict.
    """
    fresh_cache = OrderedDict()
    self._cache = fresh_cache

get_dependencies() staticmethod

Get the dependency states for this state, e.g. states that need to be explicitly enabled on the current object before the current state is usable. States listed here will be enabled for all objects that have this current state, and all dependency states will be processed on all objects prior to this state being processed on any object.

Returns:

Type Description

list of str: List of strings corresponding to state keys.

Source code in object_states/object_state_base.py
@staticmethod
def get_dependencies():
    """
    List the states that must be explicitly enabled on the current object before this state can be used.

    Every object that has this state also gets these dependency states enabled. All dependency states are
    processed on *all* objects before this state is processed on *any* object.

    Returns:
        list of str: State keys of the required states (none by default)
    """
    required_states = list()
    return required_states

get_optional_dependencies() staticmethod

Get states that should be processed prior to this state if they are already enabled. These states will not be enabled because of this state's dependency on them, but if they are already enabled for another reason (e.g. because of an ability or another state's dependency etc.), they will be processed on all objects prior to this state being processed on any object.

Returns:

Type Description

list of str: List of strings corresponding to state keys.

Source code in object_states/object_state_base.py
@staticmethod
def get_optional_dependencies():
    """
    List states that should be processed before this one, but only if they are already enabled.

    This state does not cause these states to be enabled. If another reason has enabled them (e.g. an ability
    or another state's dependency), they are processed on *all* objects before this state is processed on
    *any* object.

    Returns:
        list of str: State keys of the optional predecessor states (none by default)
    """
    optional_states = list()
    return optional_states

get_value(*args, **kwargs)

Get this state's value

Returns:

Name Type Description
any

Object state value given input @args and @kwargs

Source code in object_states/object_state_base.py
def get_value(self, *args, **kwargs):
    """
    Return this state's value for the given arguments, recomputing it only when needed.

    The cache key is the positional arguments followed by the keyword-argument *values* in the caller's order
    (the keyword names are not part of the key). The entry is refreshed via update_cache() when the key is
    missing from the cache or cache_is_valid() rejects it.

    Returns:
        any: Object state value given input @args and @kwargs
    """
    assert self._initialized

    cache_key = args + tuple(kwargs.values())
    # `not in` short-circuits, so cache_is_valid() is only consulted for keys that already have an entry
    needs_refresh = cache_key not in self._cache or not self.cache_is_valid(get_value_args=cache_key)
    if needs_refresh:
        self.update_cache(get_value_args=cache_key)

    return self._cache[cache_key]["value"]

has_changed(get_value_args, value, info, t)

A helper function to query whether this object state has changed between an arbitrary previous timestep @t (with corresponding cached value @value and cache information @info) and the current timestep.

Note that this may require some non-trivial compute, so we leverage @t, in addition to @get_value_args, as a unique key into an internal dictionary, such that each specific @t results in the computation being conducted exactly once. This is done for performance reasons, so that multiple states relying on the same state dependency can all query whether that state has changed between the same timesteps with only a single computation.

Parameters:

Name Type Description Default
get_value_args tuple

Specific argument combinations (usually tuple of objects) passed into @self.get_value

required
value any

Cached value computed at timestep @t for this object state

required
info OrderedDict

Information calculated at timestep @t when computing this state's value

required
t int

Initial timestep to compare against. This should be an index of the steps taken, i.e. a value queried from og.sim.current_time_step_index at some point in time. It is assumed @value and @info were computed at this timestep

required

Returns:

Name Type Description
bool

Whether this object state has changed between @t and the current timestep index for the specific @get_value_args

Source code in object_states/object_state_base.py
def has_changed(self, get_value_args, value, info, t):
    """
    Report whether this state's value for @get_value_args differs between an earlier timestep @t and the
    current timestep.

    @value and @info are the cached value and cache information that were computed at @t. The actual comparison
    is delegated to _has_changed(). Its answer is memoized in @self._changed under the key (t, *get_value_args).
    Several dependent states asking about the same timestep and arguments therefore trigger only one
    computation. The memo is discarded whenever update() or reset() replaces @self._changed.

    Args:
        get_value_args (tuple): Argument combination (usually a tuple of objects) passed to @self.get_value
        value (any): Cached value computed at timestep @t for this object state
        info (OrderedDict): Information calculated at timestep @t when computing this state's value
        t (int): Reference timestep index, i.e. a value previously read from og.sim.current_time_step_index at
            which @value and @info were computed

    Returns:
        bool: Whether this object state has changed between @t and the current timestep index for the specific
            @get_value_args
    """
    memo_key = (t, *get_value_args)

    # Nothing can have changed if the reference timestep is the current one
    if t == og.sim.current_time_step_index:
        return False

    # Reuse an answer already computed for this (t, args) pair
    if memo_key in self._changed:
        return self._changed[memo_key]

    # First query for this pair: compute it once and remember the result
    result = self._has_changed(get_value_args=get_value_args, value=value, info=info)
    self._changed[memo_key] = result
    return result

initialize(simulator)

Initialize this object state

Source code in object_states/object_state_base.py
def initialize(self, simulator):
    """
    Prepare this object state for use. This may only be done once.

    Stores @simulator, rebuilds the internal caches via reset(), runs the subclass hook _initialize(), and
    finally marks the state as initialized.

    Args:
        simulator: Simulator instance; stored as @self._simulator
    """
    assert not self._initialized, "State is already initialized."

    self._simulator = simulator
    # Start from empty value / has-changed caches
    self.reset()

    # Subclass-specific setup runs before the state is flagged as ready
    self._initialize()
    self._initialized = True

remove()

Any cleanup functionality to deploy when @self.obj is removed from the simulator

Source code in object_states/object_state_base.py
def remove(self):
    """
    Cleanup hook invoked when @self.obj is removed from the simulator. The base implementation does nothing.
    """
    return None

reset()

Resets this object state. By default, it clears all internal caching data

Source code in object_states/object_state_base.py
def reset(self):
    """
    Reset this object state. By default, this discards all internal caching data: the value cache and the
    memoized has-changed results both become fresh, empty OrderedDicts.
    """
    self._cache, self._changed = OrderedDict(), OrderedDict()

set_value(*args, **kwargs)

Set this state's value

Returns:

Name Type Description
bool

True if setting the value was successful, otherwise False

Source code in object_states/object_state_base.py
def set_value(self, *args, **kwargs):
    """
    Set this state's value.

    The value cache is cleared first, because the state is about to change. All arguments are then forwarded
    unchanged to the subclass hook _set_value().

    Returns:
        bool: True if setting the value was successful, otherwise False
    """
    assert self._initialized
    # Cached values may no longer hold once the state is modified
    self.clear_cache()
    return self._set_value(*args, **kwargs)

update()

Updates the object state, possibly clearing internal cached information

Source code in object_states/object_state_base.py
def update(self):
    """
    Advance this object state. The memoized has-changed results are discarded; the value cache is left as is.
    The subclass hook _update() then runs.

    Returns:
        any: Whatever _update() returns
    """
    assert self._initialized, "Cannot update uninitialized state."
    # Has-changed answers are only memoized until the next update
    self._changed = OrderedDict()
    result = self._update()
    return result

update_cache(get_value_args)

Updates the internal cached value based on the evaluation of @self._get_value(*get_value_args)

Parameters:

Name Type Description Default
get_value_args tuple

Specific argument combinations (usually tuple of objects) passed into @self.get_value / @self._get_value

required
Source code in object_states/object_state_base.py
def update_cache(self, get_value_args):
    """
    Recompute @self._get_value(*get_value_args) and store it as the cache entry for @get_value_args.

    The entry is an OrderedDict with keys "value" (the freshly computed value), "info" (from cache_info()) and
    "t" (the simulator timestep index read before the value was computed).

    Args:
        get_value_args (tuple): Argument combination (usually a tuple of objects) passed to
            @self.get_value / @self._get_value
    """
    timestep = og.sim.current_time_step_index
    fresh_value = self._get_value(*get_value_args)
    entry = OrderedDict(value=fresh_value, info=self.cache_info(get_value_args=get_value_args), t=timestep)
    self._cache[get_value_args] = entry

BooleanState

This class is a mixin used to indicate that a state has a boolean value.

Source code in object_states/object_state_base.py
class BooleanState:
    """
    Marker mixin. A state class that also inherits from BooleanState indicates that its value is a boolean.
    It adds no attributes or behavior of its own.
    """

RelativeObjectState

Bases: BaseObjectState

This class is used to track object states that are relative, i.e., they require two objects to compute a value. Note that subclasses will typically compute values on-the-fly.

Source code in object_states/object_state_base.py
class RelativeObjectState(BaseObjectState):
    """
    Template for object states whose value relates the owning object to a second object, i.e. two objects are
    needed to compute it. Note that subclasses will typically compute values on-the-fly.

    Subclasses must implement _get_value(other) and _set_value(other, new_value); both are declared abstract here.
    """

    @abstractmethod
    def _get_value(self, other):
        # Concrete relative states compute their value with respect to @other here
        raise NotImplementedError

    @abstractmethod
    def _set_value(self, other, new_value):
        # Concrete relative states apply @new_value with respect to @other here
        raise NotImplementedError

    @classproperty
    def _do_not_register_classes(cls):
        # This is an abstract template, so exclude it from registration in addition to whatever the parent excludes
        excluded = super()._do_not_register_classes
        excluded.add("RelativeObjectState")
        return excluded