import abc
from typing import Dict, Generic, Set, Type, TypeVar
import numpy as np
from gym_gridverse.grid_object import Color, GridObject
from gym_gridverse.observation import Observation
from gym_gridverse.representations.spaces import Space
from gym_gridverse.spaces import ObservationSpace, StateSpace
from gym_gridverse.state import State
[docs]class StateRepresentation:
"""Converts a :py:class:`~gym_gridverse.state.State` into a dictionary of :py:class:`~numpy.ndarray`."""
def __init__(self, state_space: StateSpace):
if not state_space.can_be_represented:
raise ValueError(
'state space contains objects which cannot be represented in state'
)
self.state_space = state_space
@property
@abc.abstractmethod
def space(self) -> Dict[str, Space]:
"""returns representation space as as dictionary of numpy arrays"""
assert False
[docs] @abc.abstractmethod
def convert(self, state: State) -> Dict[str, np.ndarray]:
"""returns state representation as dictionary of numpy arrays"""
assert False
[docs]class ObservationRepresentation:
"""Converts a :py:class:`~gym_gridverse.observation.Observation` into a dictionary of :py:class:`~numpy.ndarray`."""
def __init__(self, observation_space: ObservationSpace):
self.observation_space = observation_space
@property
@abc.abstractmethod
def space(self) -> Dict[str, Space]:
"""returns representation space as as dictionary of numpy arrays"""
assert False
[docs] @abc.abstractmethod
def convert(self, observation: Observation) -> Dict[str, np.ndarray]:
"""returns observation representation as dictionary of numpy arrays"""
assert False
T = TypeVar('T', State, Observation, GridObject)
[docs]class ArrayRepresentation(Generic[T]):
@property
@abc.abstractmethod
def space(self) -> Space:
assert False
[docs] @abc.abstractmethod
def convert(self, obj: T) -> np.ndarray:
assert False
# grid-object representations
[docs]def default_grid_object_representation_space(
grid_object_types: Set[Type[GridObject]],
grid_object_colors: Set[Color],
) -> Space:
"""The default space of the representation
Returns a :py:class:`~gym_gridverse.representations.spaces.Space`
representing the space of a grid-object represented using type-index,
status-index, and color-index.
NOTE: used by
:class:`~gym_gridverse.representations.state_representations.DefaultGridObjectStateRepresentation`
and
:class:`~gym_gridverse.representations.observation_representations.DefaultGridObjectObservationRepresentation`,
refactored here because of DRY.
"""
max_agent_object_type_index = max(
grid_object_type.type_index() for grid_object_type in grid_object_types
)
# TODO minor bug: the max state index is -1 compared to the num-states
max_agent_object_state_index = max(
grid_object_type.num_states() for grid_object_type in grid_object_types
)
max_agent_object_color_index = max(
color.value for color in grid_object_colors
)
return Space.make_categorical_space(
np.array(
[
max_agent_object_type_index,
max_agent_object_state_index,
max_agent_object_color_index,
]
)
)
[docs]def default_grid_object_representation_convert(
grid_object: GridObject,
) -> np.ndarray:
"""The default conversion of a grid-object
Converts a grid-object into a 3-channel array of:
- object type index
- object state index
- object color index
NOTE: used by
:class:`~gym_gridverse.representations.state_representations.DefaultGridObjectStateRepresentation`
and
:class:`~gym_gridverse.representations.observation_representations.DefaultObservationGridObjectObservationRepresentation`,
refactored here because of DRY.
"""
return np.array(
[
grid_object.type_index(),
grid_object.state_index,
grid_object.color.value,
]
)
[docs]def no_overlap_grid_object_representation_space(
grid_object_types: Set[Type[GridObject]],
grid_object_colors: Set[Color],
) -> Space:
"""The no-overlap space of the representation
Returns a :py:class:`~gym_gridverse.representations.spaces.Space`
representing the space of a grid-object represented using type-index,
status-index, and color-index. Guarantees no overlap across channels,
meaning that each channel uses separate indices.
NOTE: used by
:class:`~gym_gridverse.representations.state_representations.NoOverlapGridObjectStateRepresentation`
and
:class:`~gym_gridverse.representations.observation_representations.NoOverlapGridObjectObservationRepresentation`,
refactored here because of DRY.
"""
max_agent_object_type_index = max(
grid_object_type.type_index() for grid_object_type in grid_object_types
)
# TODO minor bug: the max state index is -1 compared to the num-states
max_agent_object_state_index = max(
grid_object_type.num_states() for grid_object_type in grid_object_types
)
max_agent_object_color_index = max(
color.value for color in grid_object_colors
)
return Space.make_categorical_space(
np.array(
[
max_agent_object_type_index,
max_agent_object_type_index + max_agent_object_state_index + 1,
max_agent_object_type_index
+ max_agent_object_state_index
+ max_agent_object_color_index
+ 2,
]
)
)
[docs]def no_overlap_grid_object_representation_convert(
grid_object_types: Set[Type[GridObject]],
grid_object_colors: Set[Color],
grid_object: GridObject,
) -> np.ndarray:
"""The no-overlap conversion of a grid-object
Converts a :py:class:`~gym_gridverse.grid_object.GridObject` into a
3-channel array of type-index, status-index, and color-index. Guarantees
no overlap across channels, meaning that each channel uses separate
indices.
NOTE: used by
:class:`~gym_gridverse.representations.state_representations.NoOverlapGridObjectStateRepresentation`
and
:class:`~gym_gridverse.representations.observation_representations.NoOverlapGridObjectObservationRepresentation`,
refactored here because of DRY.
"""
max_agent_object_type_index = max(
grid_object_type.type_index() for grid_object_type in grid_object_types
)
# TODO minor bug: the max state index is -1 compared to the num-states
max_agent_object_state_index = max(
grid_object_type.num_states() for grid_object_type in grid_object_types
)
return np.array(
[
grid_object.type_index(),
max_agent_object_type_index + grid_object.state_index + 1,
max_agent_object_type_index
+ max_agent_object_state_index
+ grid_object.color.value
+ 2,
]
)
[docs]def compact_grid_object_representation_space(
grid_object_type_map: np.ndarray,
grid_object_state_map: np.ndarray,
grid_object_color_map: np.ndarray,
) -> Space:
"""The compact space of the representation
Returns a :py:class:`~gym_gridverse.representations.spaces.Space`
representing the space of a grid-object represented using type-index,
status-index, and color-index. Guarantees a compact no overlap
representation across channels, meaning that each channel uses separate
indices, and there are no gaps between the used indices.
NOTE: used by
:class:`~gym_gridverse.representations.state_representations.CompactGridObjectStateRepresentation`
and
:class:`~gym_gridverse.representations.observation_representations.CompactGridObjectObservationRepresentation`,
refactored here because of DRY.
"""
return Space.make_categorical_space(
np.array(
[
grid_object_type_map.max(),
grid_object_state_map.max(),
grid_object_color_map.max(),
]
)
)
[docs]def compact_grid_object_representation_convert(
grid_object_type_map: np.ndarray,
grid_object_state_map: np.ndarray,
grid_object_color_map: np.ndarray,
grid_object: GridObject,
) -> np.ndarray:
"""The no-overlap conversion of a grid-object
Converts a :py:class:`~gym_gridverse.grid_object.GridObject` into a
3-channel array of type-index, status-index, and color-index. Guarantees a
compact no overlap representation across channels, meaning that each
channel uses separate indices, and there are no gaps between the used
indices.
NOTE: used by
:class:`~gym_gridverse.representations.state_representations.CompactGridObjectStateRepresentation`
and
:class:`~gym_gridverse.representations.observation_representations.CompactGridObjectObservationRepresentation`,
refactored here because of DRY.
"""
i = grid_object.type_index()
j = grid_object.state_index
k = grid_object.color.value
return np.array(
[
grid_object_type_map[i],
grid_object_state_map[i, j],
grid_object_color_map[k],
]
)