Source code for gym_gridverse.gym

from __future__ import annotations

import time
from functools import partial
from typing import Callable, Dict, List, Optional

import gym
import numpy as np
import pkg_resources
from gym.utils import seeding

from gym_gridverse.envs.yaml.factory import factory_env_from_yaml
from gym_gridverse.outer_env import OuterEnv
from gym_gridverse.representations.observation_representations import (
    make_observation_representation,
)
from gym_gridverse.representations.spaces import Space, SpaceType
from gym_gridverse.representations.state_representations import (
    make_state_representation,
)


def outer_space_to_gym_space(space: Dict[str, Space]) -> gym.spaces.Space:
    return gym.spaces.Dict(
        {
            k: gym.spaces.Box(
                low=v.lower_bound,
                high=v.upper_bound,
                dtype=float if v.space_type is SpaceType.CONTINUOUS else int,
            )
            for k, v in space.items()
        }
    )


OuterEnvFactory = Callable[[], OuterEnv]


def from_factory(factory: OuterEnvFactory):
    return GymEnvironment(factory())


class GymEnvironment(gym.Env):
    metadata = {
        'render.modes': ['human', 'human_state', 'human_observation'],
        'video.frames_per_second': 50,
    }

    # NOTE accepting an environment instance as input is a bad idea because it
    # would need to be instantiated during gym registration
    def __init__(self, outer_env: OuterEnv):
        super().__init__()
        self.outer_env = outer_env

        self.state_space = (
            outer_space_to_gym_space(outer_env.state_representation.space)
            if outer_env.state_representation is not None
            else None
        )
        """Environment state space, if any."""

        self.action_space = gym.spaces.Discrete(
            outer_env.action_space.num_actions
        )
        """Environment action space."""

        self.observation_space = (
            outer_space_to_gym_space(
                outer_env.observation_representation.space
            )
            if outer_env.observation_representation is not None
            else None
        )
        """Environment observation space, if any."""

        self._state_viewer = None
        self._observation_viewer = None

    def seed(self, seed: Optional[int] = None) -> List[int]:
        actual_seed = seeding.create_seed(seed)
        self.outer_env.inner_env.set_seed(actual_seed)
        return [actual_seed]

    def set_state_representation(self, name: str):
        """Changes the state representation."""
        # TODO: test
        self.outer_env.state_representation = make_state_representation(
            name, self.outer_env.inner_env.state_space
        )
        self.state_space = outer_space_to_gym_space(
            self.outer_env.state_representation.space
        )

    def set_observation_representation(self, name: str):
        """Changes the observation representation."""
        # TODO: test
        self.outer_env.observation_representation = (
            make_observation_representation(
                name, self.outer_env.inner_env.observation_space
            )
        )
        self.observation_space = outer_space_to_gym_space(
            self.outer_env.observation_representation.space
        )

    @property
    def state(self) -> Dict[str, np.ndarray]:
        """Returns the representation of the current state."""
        return self.outer_env.state

    @property
    def observation(self) -> Dict[str, np.ndarray]:
        """Returns the representation of the current observation."""
        return self.outer_env.observation

    def reset(self) -> Dict[str, np.ndarray]:
        """Resets the state of the environment.

        Returns:
            Dict[str, numpy.ndarray]: initial observation
        """
        self.outer_env.reset()
        return self.observation

    def step(self, action: int):
        """Runs the environment dynamics for one timestep.

        Args:
            action (int): agent's action

        Returns:
            Tuple[Dict[str, numpy.ndarray], float, bool, Dict]:
                (observation, reward, terminal, info dictionary)
        """
        action_ = self.outer_env.action_space.int_to_action(action)
        reward, done = self.outer_env.step(action_)
        return self.observation, reward, done, {}

    def render(self, mode='human'):
        # TODO: test
        # only import rendering if actually rendering (avoid importing when
        # using library remotely using ssh on a display-less environment)
        from gym_gridverse.rendering import GridVerseViewer

        if mode not in [
            'human',
            'human_state',
            'human_observation',
            'rgb_array',
            'rgb_array_state',
            'rgb_array_observation',
        ]:
            super().render(mode)

        # not reset yet
        if self.outer_env.inner_env.state is None:
            return

        if mode in ['human', 'human_state']:
            if self._state_viewer is None:
                self._state_viewer = GridVerseViewer(
                    self.outer_env.inner_env.state_space.grid_shape,
                    caption='State',
                )
                # without sleep the first frame could be black
                time.sleep(0.05)

            self._state_viewer.render(self.outer_env.inner_env.state)

        if mode in ['human', 'human_observation']:
            if self._observation_viewer is None:
                self._observation_viewer = GridVerseViewer(
                    self.outer_env.inner_env.observation_space.grid_shape,
                    caption='Observation',
                )
                # without sleep the first frame could be black
                time.sleep(0.05)

            self._observation_viewer.render(
                self.outer_env.inner_env.observation
            )

        rgb_arrays = []

        if mode in ['rgb_array', 'rgb_array_state']:
            if self._state_viewer is None:
                self._state_viewer = GridVerseViewer(
                    self.outer_env.inner_env.state_space.grid_shape,
                    caption='State',
                )
                # without sleep the first frame could be black
                time.sleep(0.05)

            rgb_array_state = self._state_viewer.render(
                self.outer_env.inner_env.state,
                return_rgb_array=True,
            )
            rgb_arrays.append(rgb_array_state)

        if mode in ['rgb_array', 'rgb_array_observation']:
            if self._observation_viewer is None:
                self._observation_viewer = GridVerseViewer(
                    self.outer_env.inner_env.observation_space.grid_shape,
                    caption='Observation',
                )
                # without sleep the first frame could be black
                time.sleep(0.05)

            rgb_array_observation = self._observation_viewer.render(
                self.outer_env.inner_env.observation,
                return_rgb_array=True,
            )
            rgb_arrays.append(rgb_array_observation)

        if rgb_arrays:
            return (
                tuple(rgb_arrays) if len(rgb_arrays) > 1 else rgb_arrays[0]
            )

    def close(self):
        # TODO: test
        if self._state_viewer is not None:
            self._state_viewer.close()
            self._state_viewer = None

        if self._observation_viewer is not None:
            self._observation_viewer.close()
            self._observation_viewer = None


class GymStateWrapper(gym.Wrapper):
    """
    Gym Wrapper that replaces the standard observation representation with the
    state representation.

    Doesn't change the underlying environment and won't change rendering.
    """

    def __init__(self, env: GymEnvironment):
        # make sure we have a valid state representation
        if env.state_space is None:
            raise ValueError('GymEnvironment does not have a state space')

        super().__init__(env)
        self.observation_space = env.state_space

    @property
    def observation(self) -> Dict[str, np.ndarray]:
        return self.env.state

    def reset(self) -> Dict[str, np.ndarray]:
        """Resets the environment state.

        Returns:
            Dict[str, numpy.ndarray]: initial state
        """
        self.env.reset()
        return self.observation

    def step(self, action: int):
        """Performs an environment step.

        Args:
            action (int): agent's action

        Returns:
            Tuple[Dict[str, numpy.ndarray], float, bool, Dict]:
                (state, reward, terminal, info dictionary)
        """
        observation, reward, done, info = self.env.step(action)
        info['observation'] = observation
        return self.observation, reward, done, info

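
def _state_wrapper_example(gym_env: GymEnvironment) -> GymStateWrapper:
    # Illustrative helper sketch (hypothetical, not part of the public API):
    # shows how GymStateWrapper is meant to be used, i.e. on a GymEnvironment
    # that already exposes a state representation. Assumes 'default' is a
    # valid state representation name for make_state_representation.
    gym_env.set_state_representation('default')
    # reset/step on the wrapper now return the state representation, while the
    # original observation is still available through info['observation']
    return GymStateWrapper(gym_env)
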

STRING_TO_YAML_FILE: Dict[str, str] = {
    "GV-Crossing-5x5-v0": "gv_crossing.5x5.yaml",
    "GV-Crossing-7x7-v0": "gv_crossing.7x7.yaml",
    "GV-DynamicObstacles-5x5-v0": "gv_dynamic_obstacles.5x5.yaml",
    "GV-DynamicObstacles-7x7-v0": "gv_dynamic_obstacles.7x7.yaml",
    "GV-Empty-4x4-v0": "gv_empty.4x4.yaml",
    "GV-Empty-8x8-v0": "gv_empty.8x8.yaml",
    "GV-FourRooms-7x7-v0": "gv_four_rooms.7x7.yaml",
    "GV-FourRooms-9x9-v0": "gv_four_rooms.9x9.yaml",
    "GV-Keydoor-5x5-v0": "gv_keydoor.5x5.yaml",
    "GV-Keydoor-7x7-v0": "gv_keydoor.7x7.yaml",
    "GV-Keydoor-9x9-v0": "gv_keydoor.9x9.yaml",
    "GV-Memory-5x5-v0": "gv_memory.5x5.yaml",
    "GV-Memory-9x9-v0": "gv_memory.9x9.yaml",
    "GV-MemoryFourRooms-7x7-v0": "gv_memory_four_rooms.7x7.yaml",
    "GV-MemoryFourRooms-9x9-v0": "gv_memory_four_rooms.9x9.yaml",
    "GV-MemoryNineRooms-10x10-v0": "gv_memory_nine_rooms.10x10.yaml",
    "GV-MemoryNineRooms-13x13-v0": "gv_memory_nine_rooms.13x13.yaml",
    "GV-NineRooms-10x10-v0": "gv_nine_rooms.10x10.yaml",
    "GV-NineRooms-13x13-v0": "gv_nine_rooms.13x13.yaml",
    "GV-Teleport-5x5-v0": "gv_teleport.5x5.yaml",
    "GV-Teleport-7x7-v0": "gv_teleport.7x7.yaml",
}


def outer_env_factory(yaml_filename: str) -> OuterEnv:
    env = factory_env_from_yaml(yaml_filename)
    observation_representation = make_observation_representation(
        'default', env.observation_space
    )
    return OuterEnv(
        env,
        observation_representation=observation_representation,
    )


for key, yaml_filename in STRING_TO_YAML_FILE.items():
    yaml_filepath = pkg_resources.resource_filename(
        'gym_gridverse', f'registered_envs/{yaml_filename}'
    )
    factory = partial(outer_env_factory, yaml_filepath)
    # registering using factory to avoid allocation of outer envs
    gym.register(
        key,
        entry_point='gym_gridverse.gym:from_factory',
        kwargs={'factory': factory},
    )

env_ids = list(STRING_TO_YAML_FILE.keys())
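

if __name__ == '__main__':
    # Minimal usage sketch: instantiate one of the environments registered
    # above through the standard gym API and take a single random step.
    # 'GV-Empty-8x8-v0' is one of the ids in STRING_TO_YAML_FILE; any other
    # id from env_ids would work the same way.
    demo_env = gym.make('GV-Empty-8x8-v0')
    demo_observation = demo_env.reset()
    demo_observation, demo_reward, demo_done, demo_info = demo_env.step(
        demo_env.action_space.sample()
    )
    demo_env.close()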