Source code for gym_gridverse.envs.transition_functions

""" Functions to model dynamics """
import inspect
import warnings
from functools import partial
from typing import List, Optional, Sequence

import numpy.random as rnd
from typing_extensions import Protocol  # python3.7 compatibility

from gym_gridverse.action import Action
from gym_gridverse.envs.utils import get_next_position
from gym_gridverse.geometry import Orientation, get_manhattan_boundary
from gym_gridverse.grid_object import (
    Box,
    Door,
    Floor,
    Key,
    MovingObstacle,
    NoneGridObject,
    Telepod,
)
from gym_gridverse.rng import get_gv_rng_if_none
from gym_gridverse.state import State
from gym_gridverse.utils.custom import import_if_custom
from gym_gridverse.utils.fast_copy import fast_copy
from gym_gridverse.utils.functions import checkraise_kwargs, select_kwargs
from gym_gridverse.utils.protocols import (
    get_keyword_parameter,
    get_positional_parameters,
)
from gym_gridverse.utils.registry import FunctionRegistry


class TransitionFunction(Protocol):
    """Signature that all transition functions must follow

    A transition function updates ``state`` *in place* in response to
    ``action`` and returns nothing.  Stochastic transition functions draw
    their randomness from the keyword-only ``rng`` argument, which may be
    ``None``.
    """

    def __call__(
        self,
        state: State,
        action: Action,
        *,
        rng: Optional[rnd.Generator] = None,
    ) -> None:
        ...
class TransitionFunctionRegistry(FunctionRegistry):
    """Function registry specialised for transition functions.

    Provides the protocol-specific hooks used to validate functions against
    the `TransitionFunction` signature ``(state, action, *, rng=None) -> None``
    (presumably invoked by the base `FunctionRegistry` on registration).
    """

    def get_protocol_parameters(
        self, signature: inspect.Signature
    ) -> List[inspect.Parameter]:
        """Return the ``[state, action, rng]`` parameters of ``signature``.

        The first two positional parameters are taken as ``state`` and
        ``action``; ``rng`` is looked up by name.
        """
        state, action = get_positional_parameters(signature, 2)
        rng = get_keyword_parameter(signature, 'rng')
        return [state, action, rng]

    def check_signature(self, function: TransitionFunction):
        """Check that ``function`` can be used as a `TransitionFunction`.

        Raises:
            TypeError: if ``state`` or ``action`` cannot be passed
                positionally, or ``rng`` cannot be passed as a keyword.

        Annotations that are present but differ from the protocol (including
        the return annotation) only emit a warning via `warnings.warn`.
        """
        signature = inspect.signature(function)
        state, action, rng = self.get_protocol_parameters(signature)

        positional_kinds = [
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.POSITIONAL_ONLY,
        ]
        keyword_kinds = [
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        ]

        # checks first 2 arguments are positional ...
        if state.kind not in positional_kinds:
            raise TypeError(
                f'The first argument ({state.name}) '
                f'of a registered transition function ({function}) '
                'should be allowed to be a positional argument.'
            )

        if action.kind not in positional_kinds:
            raise TypeError(
                f'The second argument ({action.name}) '
                f'of a registered transition function ({function}) '
                'should be allowed to be a positional argument.'
            )

        # and `rng` is keyword
        if rng.kind not in keyword_kinds:
            raise TypeError(
                f'The `rng` argument ({rng.name}) '
                f'of a registered transition function ({function}) '
                'should be allowed to be a keyword argument.'
            )

        # checks if annotations, if given, are consistent
        if state.annotation not in [inspect.Parameter.empty, State]:
            warnings.warn(
                f'The first argument ({state.name}) '
                f'of a registered transition function ({function}) '
                f'has an annotation ({state.annotation}) '
                'which is not `State`.'
            )

        if action.annotation not in [inspect.Parameter.empty, Action]:
            warnings.warn(
                f'The second argument ({action.name}) '
                f'of a registered transition function ({function}) '
                f'has an annotation ({action.annotation}) '
                'which is not `Action`.'
            )

        if rng.annotation not in [
            inspect.Parameter.empty,
            Optional[rnd.Generator],
        ]:
            warnings.warn(
                f'The `rng` argument ({rng.name}) '
                f'of a registered transition function ({function}) '
                f'has an annotation ({rng.annotation}) '
                'which is not `Optional[rnd.Generator]`.'
            )

        if signature.return_annotation not in [inspect.Parameter.empty, None]:
            warnings.warn(
                f'The return type of a registered transition function ({function}) '
                f'has an annotation ({signature.return_annotation}) '
                'which is not `None`.'
            )
# Module-level registry of transition functions.  Functions decorated with
# ``@transition_function_registry.register`` are added to it, and `factory`
# looks them up by name via ``transition_function_registry[name]``.
transition_function_registry = TransitionFunctionRegistry()
"""Transition function registry"""
@transition_function_registry.register
def chain(
    state: State,
    action: Action,
    *,
    transition_functions: Sequence[TransitionFunction],
    rng: Optional[rnd.Generator] = None,
) -> None:
    """Apply several transition functions one after another

    Each function is called in sequence on the same (progressively mutated)
    ``state``, with the same ``action`` and the same ``rng``.

    Args:
        state (`State`): state to update in place
        action (`Action`): action taken by the agent
        transition_functions (`Sequence[TransitionFunction]`): transition
            functions to apply, in order
        rng (`Generator, optional`): forwarded unchanged to every function

    Returns:
        None
    """
    for step_function in transition_functions:
        step_function(state, action, rng=rng)
@transition_function_registry.register
def move_agent(
    state: State,
    action: Action,
    *,
    rng: Optional[rnd.Generator] = None,
) -> None:
    """Applies translation to agent (e.g. up/down/left/right)

    Leaves the state unaffected if any other action was taken instead, if the
    target position lies outside the grid, or if the object at the target
    position blocks movement.

    Args:
        state (`State`): current state, modified in place
        action (`Action`): action taken by the agent
        rng (`Generator, optional`): unused

    Returns:
        None:
    """
    if not action.is_move():
        return

    next_position = get_next_position(
        state.agent.position,
        state.agent.orientation,
        action,
    )

    # Explicit bounds check, consistent with `actuate_door`/`actuate_box`,
    # instead of relying on grid indexing to raise IndexError (which would
    # not protect against out-of-area positions if indexing wraps around).
    if not state.grid.area.contains(next_position):
        return

    if not state.grid[next_position].blocks_movement:
        state.agent.position = next_position
@transition_function_registry.register
def turn_agent(
    state: State,
    action: Action,
    *,
    rng: Optional[rnd.Generator] = None,
) -> None:
    """Turns agent according to action (e.g. turn left/right)

    The agent's orientation is multiplied by the rotation registered for the
    action in ``_action_orientations``.  Actions without a registered rotation
    leave the state unaffected.

    Args:
        state (`State`): current state, modified in place
        action (`Action`): action taken by the agent
        rng (`Generator, optional`): unused

    Returns:
        None:
    """
    rotation = _action_orientations.get(action)
    if rotation is None:
        return

    state.agent.orientation *= rotation
# Lookup table used by `turn_agent`: the rotation applied to the agent's
# orientation for each turning action.  Actions missing from this table do
# not turn the agent.
_action_orientations = {
    Action.TURN_LEFT: Orientation.L,
    Action.TURN_RIGHT: Orientation.R,
}
@transition_function_registry.register
def pickndrop(
    state: State,
    action: Action,
    *,
    rng: Optional[rnd.Generator] = None,
) -> None:
    """Implements the effect of the pickup and drop action

    Pickup applies to the item *in front* of the agent.  There are multiple
    scenarios:

    * The position in front of the agent is outside the grid -> No effect
    * The object in front is neither a floor nor holdable -> No effect
    * The object in front is a floor (not pick-up-able):
        * The agent is not holding any object -> No effective change (the
          floor is replaced by a fresh `Floor`)
        * The agent is holding an object -> drop current object
    * The object in front is holdable (pick-up-able):
        * The agent is not holding any object -> Pick up, put floor instead
        * The agent is holding an object -> Swap items

    Args:
        state (`State`): current state, modified in place
        action (`Action`): action taken by the agent
        rng (`Generator, optional`): unused

    Returns:
        None:
    """
    if action is not Action.PICK_N_DROP:
        return

    position_front = state.agent.front()
    # Facing outside the grid: nothing to pick up and nowhere to drop
    # (mirrors the bounds check in `actuate_door`/`actuate_box`).
    if not state.grid.area.contains(position_front):
        return

    obj_front = state.grid[position_front]
    can_be_dropped = isinstance(obj_front, Floor) or obj_front.holdable
    if not can_be_dropped:
        return

    obj_held = state.agent.grid_object
    # Whatever the agent holds goes in front of it; if it holds nothing, the
    # front cell becomes floor (it either was floor, or we are picking up).
    state.grid[position_front] = (
        obj_held if not isinstance(obj_held, NoneGridObject) else Floor()
    )
    # Take the front object if it is holdable, otherwise end up empty-handed.
    state.agent.grid_object = (
        obj_front if obj_front.holdable else NoneGridObject()
    )
@transition_function_registry.register
def move_obstacles(
    state: State,
    action: Action,
    *,
    rng: Optional[rnd.Generator] = None,
) -> None:
    """Moves moving obstacles randomly

    Each MovingObstacle is swapped with a randomly chosen neighbouring
    (Manhattan distance 1) Floor cell inside the grid.  Obstacles without
    such a neighbour stay where they are.

    Args:
        state (`State`): current state, modified in place
        action (`Action`): action taken by agent (ignored)
        rng (`Generator, optional`): source of randomness; resolved through
            `get_gv_rng_if_none` when None
    """
    rng = get_gv_rng_if_none(rng)
    grid = state.grid

    # Snapshot obstacle positions before moving anything, so an obstacle that
    # lands on a not-yet-scanned cell is not moved twice in one step.
    obstacle_positions = [
        pos for pos in grid.area.positions() if isinstance(grid[pos], MovingObstacle)
    ]

    for origin in obstacle_positions:
        candidates = [
            neighbour
            for neighbour in get_manhattan_boundary(origin, distance=1)
            if grid.area.contains(neighbour) and isinstance(grid[neighbour], Floor)
        ]
        if not candidates:
            # boxed in: rng.choice(0) would raise, so skip without drawing
            continue

        index = rng.choice(len(candidates))
        grid.swap(origin, candidates[index])
@transition_function_registry.register
def actuate_door(
    state: State,
    action: Action,
    *,
    rng: Optional[rnd.Generator] = None,
) -> None:
    """Attempts to open the door in front of the agent

    Only reacts to `Action.ACTUATE`, and only when the position in front of
    the agent lies inside the grid and holds a `Door`.

    When not holding a key of the door's color:
        `open` or `closed` -> `open`
        `locked` -> `locked`

    When holding a key of the door's color:
        any state -> `open`
    """
    if action is not Action.ACTUATE:
        return

    front = state.agent.front()
    if not state.grid.area.contains(front):
        return

    door = state.grid[front]
    if not isinstance(door, Door) or door.is_open:
        return

    held = state.agent.grid_object
    if not door.is_locked or (isinstance(held, Key) and held.color == door.color):
        door.state = Door.Status.OPEN
@transition_function_registry.register
def actuate_box(
    state: State,
    action: Action,
    *,
    rng: Optional[rnd.Generator] = None,
) -> None:
    """Attempts to open the box in front of the agent

    When the agent performs `Action.ACTUATE` while facing a `Box` that lies
    inside the grid, the box is replaced by its content.  In every other case
    the state is left unaffected.

    Args:
        state (`State`): current state, modified in place
        action (`Action`): action taken by the agent
        rng (`Generator, optional`): unused

    Returns:
        None:
    """
    if action is not Action.ACTUATE:
        return

    position = state.agent.front()
    if not state.grid.area.contains(position):
        return

    box = state.grid[position]
    if isinstance(box, Box):
        state.grid[position] = box.content
@transition_function_registry.register
def teleport(
    state: State,
    action: Action,
    *,
    rng: Optional[rnd.Generator] = None,
) -> None:
    """Teleports the agent if positioned on the telepod

    When the agent stands on a `Telepod`, it is moved to a randomly chosen
    *other* telepod of the same color.  If no such telepod exists, the agent
    stays where it is.  The action is not inspected.

    Args:
        state (`State`): current state, modified in place
        action (`Action`): action taken by agent (ignored)
        rng (`Generator, optional`): source of randomness; resolved through
            `get_gv_rng_if_none` when None

    Returns:
        None:
    """
    rng = get_gv_rng_if_none(rng)

    telepod = state.grid[state.agent.position]
    if not isinstance(telepod, Telepod):
        return

    destinations = [
        position
        for position in state.grid.area.positions()
        if position != state.agent.position
        and isinstance(state.grid[position], Telepod)
        and state.grid[position].color == telepod.color
    ]
    # A telepod without a same-colored partner leads nowhere; without this
    # guard rng.choice(0) raises ValueError and crashes the transition.
    if not destinations:
        return

    i = rng.choice(len(destinations))
    state.agent.position = destinations[i]
def factory(name: str, **kwargs) -> TransitionFunction:
    """Build a transition function from its registered name.

    ``name`` is first passed through `import_if_custom`, then looked up in
    ``transition_function_registry``.  The function's non-protocol parameters
    are split into required ones (no default) and optional ones (with a
    default).  ``kwargs`` is checked against the required names with
    `checkraise_kwargs` and restricted to the known names with `select_kwargs`
    (presumably raising on missing keys / dropping unknown ones — see those
    helpers).  The result is bound with `functools.partial`.

    Raises:
        ValueError: if ``name`` is not a registered transition function.
    """
    name = import_if_custom(name)

    try:
        function = transition_function_registry[name]
    except KeyError as error:
        raise ValueError(f'invalid transition function name {name}') from error

    # materialise once; the original queried the registry twice
    extra_parameters = list(
        transition_function_registry.get_nonprotocol_parameters(
            inspect.signature(function)
        )
    )
    required_keys = [
        param.name
        for param in extra_parameters
        if param.default is inspect.Parameter.empty
    ]
    optional_keys = [
        param.name
        for param in extra_parameters
        if param.default is not inspect.Parameter.empty
    ]

    checkraise_kwargs(kwargs, required_keys)
    kwargs = select_kwargs(kwargs, required_keys + optional_keys)
    return partial(function, **kwargs)
def transition_with_copy(
    transition_function: TransitionFunction,
    state: State,
    action: Action,
    *,
    rng: Optional[rnd.Generator] = None,
) -> State:
    """Out-of-place wrapper around an (in-place) transition function.

    NOTE: This is *not* a transition function (transition functions are
    in-place by definition).  The transition is applied to a copy produced by
    `fast_copy`, and that copy is returned; the input ``state`` is only
    untouched insofar as `fast_copy` yields an independent copy.

    Args:
        transition_function (`TransitionFunction`): transition to apply
        state (`State`): state to copy before transitioning
        action (`action`): action passed to the transition function
        rng (`Generator, optional`): forwarded to the transition function

    Returns:
        State: the transitioned copy
    """
    successor = fast_copy(state)
    transition_function(successor, action, rng=rng)
    return successor