Source code for gym_gridverse.envs.terminating_functions

import inspect
import warnings
from functools import partial
from typing import Callable, Iterator, List, Optional, Sequence, Type

import numpy.random as rnd
from typing_extensions import Protocol  # python3.7 compatibility

from gym_gridverse.action import Action
from gym_gridverse.envs.utils import get_next_position
from gym_gridverse.grid_object import Exit, GridObject, MovingObstacle, Wall
from gym_gridverse.state import State
from gym_gridverse.utils.custom import import_if_custom
from gym_gridverse.utils.functions import checkraise_kwargs, select_kwargs
from gym_gridverse.utils.protocols import (
    get_keyword_parameter,
    get_positional_parameters,
)
from gym_gridverse.utils.registry import FunctionRegistry


class TerminatingFunction(Protocol):
    """Callable protocol deciding whether a transition is terminal.

    Implementations receive the transition ``(state, action, next_state)``
    positionally and an optional random generator ``rng`` as a keyword-only
    argument, and return ``True`` when the transition ends the episode.
    """

    def __call__(
        self,
        state: State,
        action: Action,
        next_state: State,
        *,
        rng: Optional[rnd.Generator] = None,
    ) -> bool:
        ...
TerminatingReductionFunction = Callable[[Iterator[bool]], bool]
"""Signature for a boolean reduction function, i.e. a callable collapsing an
iterator of terminating outcomes into a single outcome (e.g. ``any``/``all``)"""
class TerminatingFunctionRegistry(FunctionRegistry):
    """Function registry specialized for terminating functions.

    Registered functions are validated against the ``TerminatingFunction``
    protocol: three positional parameters (``state``, ``action``,
    ``next_state``) followed by an ``rng`` keyword parameter.
    """

    def get_protocol_parameters(
        self, signature: inspect.Signature
    ) -> List[inspect.Parameter]:
        """Extract the protocol parameters from a function signature.

        Args:
            signature (`inspect.Signature`): signature of a candidate
                terminating function.

        Returns:
            List[inspect.Parameter]: the ``state``, ``action``,
            ``next_state`` and ``rng`` parameters, in this order.
        """
        state, action, next_state = get_positional_parameters(signature, 3)
        rng = get_keyword_parameter(signature, 'rng')
        return [state, action, next_state, rng]

    def check_signature(self, function: TerminatingFunction):
        """Validate that ``function`` follows the terminating-function protocol.

        Parameter kinds are checked first (hard errors), then annotations
        (soft warnings), in the order state, action, next_state, rng, return.

        Args:
            function (`TerminatingFunction`): candidate function to register.

        Raises:
            TypeError: if one of the first three parameters cannot be passed
                positionally, or if ``rng`` cannot be passed by keyword.
        """
        signature = inspect.signature(function)
        state, action, next_state, rng = self.get_protocol_parameters(signature)

        positional_kinds = [
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.POSITIONAL_ONLY,
        ]
        keyword_kinds = [
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        ]

        # (ordinal used in messages, parameter, expected type, type name)
        positional_specs = [
            ('first', state, State, 'State'),
            ('second', action, Action, 'Action'),
            ('third', next_state, State, 'State'),
        ]

        # checks first 3 arguments are positional
        for ordinal, parameter, _, _ in positional_specs:
            if parameter.kind not in positional_kinds:
                raise TypeError(
                    f'The {ordinal} argument ({parameter.name}) '
                    f'of a registered terminating function ({function}) '
                    'should be allowed to be a positional argument.'
                )

        # and `rng` is keyword
        if rng.kind not in keyword_kinds:
            raise TypeError(
                f'The `rng` argument ({rng.name}) '
                f'of a registered terminating function ({function}) '
                'should be allowed to be a keyword argument.'
            )

        # checks if annotations, if given, are consistent
        for ordinal, parameter, expected, expected_name in positional_specs:
            if parameter.annotation not in [inspect.Parameter.empty, expected]:
                warnings.warn(
                    f'The {ordinal} argument ({parameter.name}) '
                    f'of a registered terminating function ({function}) '
                    f'has an annotation ({parameter.annotation}) '
                    f'which is not `{expected_name}`.'
                )

        if rng.annotation not in [
            inspect.Parameter.empty,
            Optional[rnd.Generator],
        ]:
            warnings.warn(
                f'The `rng` argument ({rng.name}) '
                f'of a registered terminating function ({function}) '
                f'has an annotation ({rng.annotation}) '
                'which is not `Optional[rnd.Generator]`.'
            )

        if signature.return_annotation not in [inspect.Parameter.empty, bool]:
            warnings.warn(
                f'The return type of a registered terminating function ({function}) '
                f'has an annotation ({signature.return_annotation}) '
                'which is not `bool`.'
            )
terminating_function_registry = TerminatingFunctionRegistry()
"""Module-wide registry of terminating functions"""
@terminating_function_registry.register
def reduce(
    state: State,
    action: Action,
    next_state: State,
    *,
    terminating_functions: Sequence[TerminatingFunction],
    reduction: TerminatingReductionFunction,
    rng: Optional[rnd.Generator] = None,
) -> bool:
    """Combine several terminating functions into one boolean value.

    The outcomes are produced lazily, so ``reduction`` may short-circuit
    (as ``any``/``all`` do) and skip the remaining terminating functions.

    Args:
        state (`State`):
        action (`Action`):
        next_state (`State`):
        terminating_functions (`Sequence[TerminatingFunction]`): functions
            evaluated on the same transition.
        reduction (`TerminatingReductionFunction`): combines the outcomes.
        rng (`Optional[rnd.Generator]`): forwarded to every function.

    Returns:
        bool: result of ``reduction`` over the individual outcomes
    """
    # TODO: test
    outcomes = (
        terminating(state, action, next_state, rng=rng)
        for terminating in terminating_functions
    )
    return reduction(outcomes)
@terminating_function_registry.register
def reduce_any(
    state: State,
    action: Action,
    next_state: State,
    *,
    terminating_functions: Sequence[TerminatingFunction],
    rng: Optional[rnd.Generator] = None,
) -> bool:
    """Terminate when at least one of the input functions terminates.

    Evaluation stops at the first function returning ``True``; an empty
    ``terminating_functions`` yields ``False``.

    Args:
        state (`State`):
        action (`Action`):
        next_state (`State`):
        terminating_functions (`Sequence[TerminatingFunction]`):
        rng (`Optional[rnd.Generator]`): forwarded to every function.

    Returns:
        bool: logical OR of the input terminating functions
    """
    # TODO: test
    return any(
        terminating(state, action, next_state, rng=rng)
        for terminating in terminating_functions
    )
@terminating_function_registry.register
def reduce_all(
    state: State,
    action: Action,
    next_state: State,
    *,
    terminating_functions: Sequence[TerminatingFunction],
    rng: Optional[rnd.Generator] = None,
) -> bool:
    """Terminate only when every one of the input functions terminates.

    Evaluation stops at the first function returning ``False``; an empty
    ``terminating_functions`` yields ``True`` (vacuous truth of ``all``).

    Args:
        state (`State`):
        action (`Action`):
        next_state (`State`):
        terminating_functions (`Sequence[TerminatingFunction]`):
        rng (`Optional[rnd.Generator]`): forwarded to every function.

    Returns:
        bool: logical AND of the input terminating functions
    """
    # TODO: test
    return all(
        terminating(state, action, next_state, rng=rng)
        for terminating in terminating_functions
    )
@terminating_function_registry.register
def overlap(
    state: State,
    action: Action,
    next_state: State,
    *,
    object_type: Type[GridObject],
    rng: Optional[rnd.Generator] = None,
) -> bool:
    """Terminate when the agent ends up on an object of a given type.

    Only ``next_state`` is inspected; ``state``, ``action`` and ``rng`` are
    accepted to satisfy the terminating-function protocol.

    Args:
        state (`State`):
        action (`Action`):
        next_state (`State`):
        object_type (`Type[GridObject]`): type of object to look for.

    Returns:
        bool: True if the cell under the agent in next_state is an
        instance of ``object_type``
    """
    agent_position = next_state.agent.position
    occupant = next_state.grid[agent_position]
    return isinstance(occupant, object_type)
@terminating_function_registry.register
def reach_exit(
    state: State,
    action: Action,
    next_state: State,
    *,
    rng: Optional[rnd.Generator] = None,
) -> bool:
    """Terminate when the agent stands on an ``Exit`` after the transition.

    Args:
        state (`State`):
        action (`Action`):
        next_state (`State`):

    Returns:
        bool: True if the agent in next_state occupies an Exit cell
    """
    return overlap(
        state,
        action,
        next_state,
        object_type=Exit,
        rng=rng,
    )
@terminating_function_registry.register
def bump_moving_obstacle(
    state: State,
    action: Action,
    next_state: State,
    *,
    rng: Optional[rnd.Generator] = None,
) -> bool:
    """Terminate when the agent shares its cell with a ``MovingObstacle``.

    Args:
        state (`State`):
        action (`Action`):
        next_state (`State`):

    Returns:
        bool: True if the agent in next_state occupies a MovingObstacle cell
    """
    # TODO: test
    return overlap(
        state,
        action,
        next_state,
        object_type=MovingObstacle,
        rng=rng,
    )
@terminating_function_registry.register
def bump_into_wall(
    state: State,
    action: Action,
    next_state: State,
    *,
    rng: Optional[rnd.Generator] = None,
) -> bool:
    """Terminate when the agent tries to walk into a ``Wall``.

    The intended destination is computed from the agent's position and
    orientation in ``state`` together with ``action``; ``next_state`` is
    not consulted.

    Args:
        state (`State`):
        action (`Action`):
        next_state (`State`):

    Returns:
        bool: True if the intended destination lies inside the grid of
        ``state`` and holds a Wall there
    """
    agent = state.agent
    destination = get_next_position(agent.position, agent.orientation, action)

    # destinations outside the grid cannot hold a wall
    if not state.grid.area.contains(destination):
        return False

    return isinstance(state.grid[destination], Wall)
def factory(name: str, **kwargs) -> TerminatingFunction:
    """Build a terminating function from its registered name.

    The registered function's non-protocol keyword parameters are bound from
    ``kwargs`` via :func:`functools.partial`, so the result follows the
    ``TerminatingFunction`` protocol.

    Args:
        name (`str`): registered (or custom, see ``import_if_custom``) name.
        **kwargs: values for the function's non-protocol parameters; keys
            that are not parameters of the function are filtered out by
            ``select_kwargs`` (NOTE(review): inferred from its use here).

    Returns:
        TerminatingFunction: the registered function with kwargs bound.

    Raises:
        ValueError: if ``name`` is not a registered terminating function.
        Presumably also whatever ``checkraise_kwargs`` raises when a required
        keyword is missing — TODO confirm exception type.
    """
    name = import_if_custom(name)

    try:
        function = terminating_function_registry[name]
    except KeyError as error:
        raise ValueError(f'invalid terminating function name {name}') from error

    signature = inspect.signature(function)
    # introspect once (and materialize, in case an iterator is returned)
    nonprotocol_parameters = list(
        terminating_function_registry.get_nonprotocol_parameters(signature)
    )
    required_keys = [
        parameter.name
        for parameter in nonprotocol_parameters
        if parameter.default is inspect.Parameter.empty
    ]
    optional_keys = [
        parameter.name
        for parameter in nonprotocol_parameters
        if parameter.default is not inspect.Parameter.empty
    ]

    checkraise_kwargs(kwargs, required_keys)
    kwargs = select_kwargs(kwargs, required_keys + optional_keys)
    return partial(function, **kwargs)