Source code for gym_gridverse.envs.gridworld

from typing import Optional, Tuple

import numpy.random as rnd

from gym_gridverse.action import Action
from gym_gridverse.debugging import gv_debug
from gym_gridverse.envs import InnerEnv
from gym_gridverse.envs.observation_functions import ObservationFunction
from gym_gridverse.envs.reset_functions import ResetFunction
from gym_gridverse.envs.reward_functions import RewardFunction
from gym_gridverse.envs.terminating_functions import TerminatingFunction
from gym_gridverse.envs.transition_functions import (
    TransitionFunction,
    transition_with_copy,
)
from gym_gridverse.observation import Observation
from gym_gridverse.rng import make_rng
from gym_gridverse.spaces import ActionSpace, ObservationSpace, StateSpace
from gym_gridverse.state import State


class GridWorld(InnerEnv):
    """Implementation of the InnerEnv interface."""
    def __init__(
        self,
        state_space: StateSpace,
        action_space: ActionSpace,
        observation_space: ObservationSpace,
        reset_function: ResetFunction,
        transition_function: TransitionFunction,
        observation_function: ObservationFunction,
        reward_function: RewardFunction,
        termination_function: TerminatingFunction,
    ):
        """Initializes a GridWorld from the given components.

        Args:
            state_space (StateSpace):
            action_space (ActionSpace):
            observation_space (ObservationSpace):
            reset_function (ResetFunction):
            transition_function (TransitionFunction):
            observation_function (ObservationFunction):
            reward_function (RewardFunction):
            termination_function (TerminatingFunction):
        """

        # TODO: maybe add a parameter to avoid calls to `contain` everywhere
        # (or maybe a global setting)

        self._reset_function = reset_function
        self._transition_function = transition_function
        self._observation_function = observation_function
        self._reward_function = reward_function
        self._termination_function = termination_function

        self._rng: Optional[rnd.Generator] = None

        super().__init__(state_space, action_space, observation_space)
    def set_seed(self, seed: Optional[int] = None):
        self._rng = make_rng(seed)
    def functional_reset(self) -> State:
        state = self._reset_function(rng=self._rng)

        if gv_debug() and not self.state_space.contains(state):
            raise ValueError('state does not satisfy state_space')

        return state
    def functional_step(
        self, state: State, action: Action
    ) -> Tuple[State, float, bool]:

        if gv_debug() and not self.state_space.contains(state):
            raise ValueError('state does not satisfy state_space')

        if not self.action_space.contains(action):
            raise ValueError(f'action {action} does not satisfy action_space')

        next_state = transition_with_copy(
            self._transition_function,
            state,
            action,
            rng=self._rng,
        )

        if gv_debug() and not self.state_space.contains(next_state):
            raise ValueError('next_state does not satisfy state_space')

        reward = self._reward_function(state, action, next_state)
        terminal = self._termination_function(state, action, next_state)

        return (next_state, reward, terminal)
    def functional_observation(self, state: State) -> Observation:
        observation = self._observation_function(state, rng=self._rng)

        if gv_debug() and not self.observation_space.contains(observation):
            raise ValueError('observation does not satisfy observation_space')

        return observation
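
# --- Illustrative usage sketch (not part of the gym_gridverse module) ---
# A minimal sketch of how the functional interface above composes into one
# episode: seed the rng, sample an initial state, then repeatedly observe,
# select an action, and step until the termination function fires.
# `run_episode` and `policy` are hypothetical names introduced here for
# illustration only; the policy is assumed to be any callable that maps an
# Observation to an Action, and `env` is an already-constructed GridWorld.


def run_episode(env: GridWorld, policy, seed: Optional[int] = None) -> float:
    env.set_seed(seed)                    # seeds the internal rng used below
    state = env.functional_reset()        # sample an initial state

    undiscounted_return = 0.0
    terminal = False
    while not terminal:
        observation = env.functional_observation(state)  # observe current state
        action = policy(observation)                      # caller-provided policy
        state, reward, terminal = env.functional_step(state, action)
        undiscounted_return += reward

    return undiscounted_return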