Source code for gym_gridverse.envs.observation_functions

import inspect
import warnings
from functools import partial
from typing import List, Optional

import numpy.random as rnd
from typing_extensions import Protocol  # python3.7 compatibility

from gym_gridverse.agent import Agent
from gym_gridverse.envs.visibility_functions import (
    VisibilityFunction,
    visibility_function_registry,
)
from gym_gridverse.geometry import Area, Orientation, Position
from gym_gridverse.grid_object import Hidden
from gym_gridverse.observation import Observation
from gym_gridverse.state import State
from gym_gridverse.utils.custom import import_if_custom
from gym_gridverse.utils.functions import checkraise_kwargs, select_kwargs
from gym_gridverse.utils.protocols import (
    get_keyword_parameter,
    get_positional_parameters,
)
from gym_gridverse.utils.registry import FunctionRegistry


[docs]class ObservationFunction(Protocol):
[docs] def __call__( self, state: State, *, rng: Optional[rnd.Generator] = None ) -> Observation: ...
[docs]class ObservationFunctionRegistry(FunctionRegistry):
[docs] def get_protocol_parameters( self, signature: inspect.Signature ) -> List[inspect.Parameter]: (state,) = get_positional_parameters(signature, 1) rng = get_keyword_parameter(signature, 'rng') return [state, rng]
[docs] def check_signature(self, function: ObservationFunction): signature = inspect.signature(function) state, rng = self.get_protocol_parameters(signature) # checks first argument is positional ... if state.kind not in [ inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.POSITIONAL_ONLY, ]: raise TypeError( f'The first argument ({state.name}) ' f'of a registered observation function ({function}) ' 'should be allowed to be a positional argument.' ) # and `rng` is keyword if rng.kind not in [ inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY, ]: raise TypeError( f'The `rng` argument ({rng.name}) ' f'of a registered observation function ({function}) ' 'should be allowed to be a keyword argument.' ) # checks if annotations, if given, are consistent if state.annotation not in [inspect.Parameter.empty, State]: warnings.warn( f'The first argument ({state.name}) ' f'of a registered observation function ({function}) ' f'has an annotation ({state.annotation}) ' 'which is not `State`.' ) if rng.annotation not in [ inspect.Parameter.empty, Optional[rnd.Generator], ]: warnings.warn( f'The `rng` argument ({rng.name}) ' f'of a registered observation function ({function}) ' f'has an annotation ({rng.annotation}) ' 'which is not `Optional[rnd.Generator]`.' ) if signature.return_annotation not in [ inspect.Parameter.empty, Observation, ]: warnings.warn( f'The return type of a registered observation function ({function}) ' f'has an annotation ({signature.return_annotation}) ' 'which is not `Observation`.' )
observation_function_registry = ObservationFunctionRegistry() """Observation function registry"""
[docs]@observation_function_registry.register def from_visibility( state: State, *, area: Area, visibility_function: VisibilityFunction, rng: Optional[rnd.Generator] = None, ) -> Observation: pov_area = state.agent.transform * area pov_agent_position = Position(-area.ymin, -area.xmin) observation_grid = state.grid.subgrid(pov_area) * state.agent.orientation visibility = visibility_function( observation_grid, pov_agent_position, rng=rng ) if visibility.shape != (area.height, area.width): raise ValueError( f'incorrect visibility shape ({visibility.shape}), ' f'should be {(area.height, area.width)}' ) for pos in observation_grid.area.positions(): if not visibility[pos.y, pos.x]: observation_grid[pos] = Hidden() observation_agent = Agent( pov_agent_position, Orientation.F, state.agent.grid_object ) return Observation(observation_grid, observation_agent)
[docs]@observation_function_registry.register def fully_transparent( state: State, *, area: Area, rng: Optional[rnd.Generator] = None, ) -> Observation: return from_visibility( state, area=area, visibility_function=visibility_function_registry['fully_transparent'], rng=rng, )
[docs]@observation_function_registry.register def partially_occluded( state: State, *, area: Area, rng: Optional[rnd.Generator] = None, ) -> Observation: return from_visibility( state, area=area, visibility_function=visibility_function_registry['partially_occluded'], rng=rng, )
[docs]@observation_function_registry.register def raytracing( state: State, *, area: Area, rng: Optional[rnd.Generator] = None, ) -> Observation: return from_visibility( state, area=area, visibility_function=visibility_function_registry['raytracing'], rng=rng, )
[docs]@observation_function_registry.register def stochastic_raytracing( state: State, *, area: Area, rng: Optional[rnd.Generator] = None, ) -> Observation: return from_visibility( state, area=area, visibility_function=visibility_function_registry[ 'stochastic_raytracing' ], rng=rng, )
[docs]def factory(name: str, **kwargs) -> ObservationFunction: name = import_if_custom(name) try: function = observation_function_registry[name] except KeyError as error: raise ValueError(f'invalid observation function name {name}') from error signature = inspect.signature(function) required_keys = [ parameter.name for parameter in observation_function_registry.get_nonprotocol_parameters( signature ) if parameter.default is inspect.Parameter.empty ] optional_keys = [ parameter.name for parameter in observation_function_registry.get_nonprotocol_parameters( signature ) if parameter.default is not inspect.Parameter.empty ] checkraise_kwargs(kwargs, required_keys) kwargs = select_kwargs(kwargs, required_keys + optional_keys) return partial(function, **kwargs)