from typing import Dict, Tuple
import numpy as np
from gym_gridverse.debugging import gv_debug
from gym_gridverse.grid_object import GridObject, Hidden, NoneGridObject
from gym_gridverse.observation import Observation
from gym_gridverse.representations.representation import (
ArrayRepresentation,
ObservationRepresentation,
compact_grid_object_representation_convert,
compact_grid_object_representation_space,
default_grid_object_representation_convert,
default_grid_object_representation_space,
no_overlap_grid_object_representation_convert,
no_overlap_grid_object_representation_space,
)
from gym_gridverse.representations.spaces import Space
from gym_gridverse.spaces import ObservationSpace
class ArrayObservationRepresentation(ArrayRepresentation[Observation]):
    """Base class for representations which convert a whole observation.

    Stores the inner-environment observation space so that subclasses can
    derive their array spaces (grid shape, object types, colors) from it.
    """

    def __init__(self, observation_space: ObservationSpace):
        self.observation_space: ObservationSpace = observation_space
class GridObjectObservationRepresentation(ArrayRepresentation[GridObject]):
    """Base class for representations which convert a single grid-object.

    Stores the inner-environment observation space, from which subclasses
    read the object types and colors they need to encode.
    """

    def __init__(self, observation_space: ObservationSpace):
        self.observation_space: ObservationSpace = observation_space
def make_observation_representation(
    name: str,
    observation_space: ObservationSpace,
) -> ObservationRepresentation:
    """Factory function for observation representations

    Every representation produced here is a
    :class:`DictObservationRepresentation` with the keys ``'grid'``,
    ``'agent_id_grid'`` and ``'item'``.  The ``name`` only selects which
    grid-object representation is used by the ``'grid'`` and ``'item'``
    entries.

    Args:
        name (str): name of the representation; one of ``'default'``,
            ``'no-overlap'`` or ``'compact'``
        observation_space (ObservationSpace): inner-environment observation space

    Returns:
        ObservationRepresentation:

    Raises:
        ValueError: if ``name`` is not a valid representation name
    """
    # TODO: test
    grid_object_representation_types = {
        'default': DefaultGridObjectObservationRepresentation,
        'no-overlap': NoOverlapGridObjectObservationRepresentation,
        'compact': CompactGridObjectObservationRepresentation,
    }

    try:
        grid_object_representation_type = grid_object_representation_types[
            name
        ]
    except KeyError:
        raise ValueError(f'invalid name {name}') from None

    grid_object_representation: GridObjectObservationRepresentation = (
        grid_object_representation_type(observation_space)
    )

    # NOTE: only `grid` and `item` depend on the chosen grid-object
    # representation;  `agent_id_grid` is identical for every name
    representations = {
        'grid': GridObservationRepresentation(
            observation_space, grid_object_representation
        ),
        'agent_id_grid': AgentIDGridObservationRepresentation(
            observation_space
        ),
        'item': ItemObservationRepresentation(
            observation_space, grid_object_representation
        ),
    }
    return DictObservationRepresentation(observation_space, representations)
# representation composition
class DictObservationRepresentation(ObservationRepresentation):
    """Composes several array representations into a dictionary.

    Both :attr:`space` and the output of :meth:`convert` are dictionaries
    with the same keys as ``representations``.

    Args:
        observation_space (ObservationSpace): inner-environment observation space
        representations (Dict[str, ArrayObservationRepresentation]): named
            sub-representations
    """

    def __init__(
        self,
        observation_space: ObservationSpace,
        representations: Dict[str, ArrayObservationRepresentation],
    ):
        super().__init__(observation_space)
        self.representations = representations

    @property
    def space(self) -> Dict[str, Space]:
        """Space of each sub-representation, keyed like ``representations``."""
        return {
            name: sub_representation.space
            for name, sub_representation in self.representations.items()
        }

    def convert(self, observation: Observation) -> Dict[str, np.ndarray]:
        """Converts ``observation`` with every sub-representation.

        Raises:
            ValueError: in debug mode, if the observation is not contained in
                the observation space
        """
        if gv_debug():
            # the containment check only runs when debugging is enabled
            if not self.observation_space.contains(observation):
                raise ValueError('observation-space does not contain observation')

        return {
            name: sub_representation.convert(observation)
            for name, sub_representation in self.representations.items()
        }
class GridObservationRepresentation(ArrayObservationRepresentation):
    """Encodes every cell of the observation grid with a grid-object representation.

    Args:
        observation_space (ObservationSpace): inner-environment observation space
        grid_object_representation (GridObjectObservationRepresentation):
            representation applied to each grid cell
    """

    def __init__(
        self,
        observation_space: ObservationSpace,
        grid_object_representation: GridObjectObservationRepresentation,
    ):
        super().__init__(observation_space)
        self.grid_object_representation = grid_object_representation

    @property
    def space(self) -> Space:
        """Grid-object space tiled over the observation grid.

        The per-object bounds are tiled with reps ``(height, width, 1)``; for
        1-D per-object bounds (presumably of length C — confirm) this yields
        bounds of shape ``(height, width, C)``.
        """
        height = self.observation_space.grid_shape.height
        width = self.observation_space.grid_shape.width

        # `space` is a property which builds a new Space on each access;
        # evaluate it once instead of once per attribute
        grid_object_space = self.grid_object_representation.space

        return Space(
            grid_object_space.space_type,
            np.tile(grid_object_space.lower_bound, (height, width, 1)),
            np.tile(grid_object_space.upper_bound, (height, width, 1)),
        )

    def convert(self, observation: Observation) -> np.ndarray:
        """Converts every grid cell, row by row (y), then column by column (x)."""
        grid = observation.grid
        convert_grid_object = self.grid_object_representation.convert
        height = grid.shape.height
        width = grid.shape.width

        return np.array(
            [
                [convert_grid_object(grid[y, x]) for x in range(width)]
                for y in range(height)
            ],
            int,
        )
class ItemObservationRepresentation(ArrayObservationRepresentation):
    """Encodes the agent's grid-object (its item) with a grid-object representation.

    Args:
        observation_space (ObservationSpace): inner-environment observation space
        grid_object_representation (GridObjectObservationRepresentation):
            representation applied to ``observation.agent.grid_object``
    """

    def __init__(
        self,
        observation_space: ObservationSpace,
        grid_object_representation: GridObjectObservationRepresentation,
    ):
        super().__init__(observation_space)
        self.grid_object_representation = grid_object_representation

    @property
    def space(self) -> Space:
        """Same space as the underlying grid-object representation."""
        grid_object_representation = self.grid_object_representation
        return grid_object_representation.space

    def convert(self, observation: Observation) -> np.ndarray:
        """Converts the grid-object attached to the observation's agent."""
        agent_grid_object = observation.agent.grid_object
        return self.grid_object_representation.convert(agent_grid_object)
class AgentIDGridObservationRepresentation(ArrayObservationRepresentation):
    """Binary grid marking the position of the agent.

    The converted array is zero everywhere except for a single 1 at the
    agent's position.
    """

    @property
    def space(self) -> Space:
        """Discrete space with bounds 0 and 1, shaped ``(height, width)``.

        Raises:
            ValueError: if the observation grid has negative height or width
        """
        height = self.observation_space.grid_shape.height
        width = self.observation_space.grid_shape.width
        if height < 0 or width < 0:
            # report the two dimensions directly, not as a parenthesized tuple
            raise ValueError(f'negative height or width ({height}, {width})')

        shape = (height, width)
        return Space.make_discrete_space(
            np.zeros(shape, dtype=int),
            np.ones(shape, dtype=int),
        )

    def convert(self, observation: Observation) -> np.ndarray:
        """Returns a grid of zeros with a 1 at the agent's position."""
        agent_id_grid = np.zeros(observation.grid.shape.as_tuple, int)
        # assumes `position.yx` is a (y, x) index tuple — TODO confirm
        agent_id_grid[observation.agent.position.yx] = 1
        return agent_id_grid
# grid-object representations
class DefaultGridObjectObservationRepresentation(
    GridObjectObservationRepresentation
):
    """The default representation for a grid-object

    Simply returns the grid-object indices. See
    :func:`gym_gridverse.representations.representation.default_grid_object_representation_space`
    and
    :func:`gym_gridverse.representations.representation.default_grid_object_representation_convert`
    for more information.
    """

    def __init__(self, observation_space: ObservationSpace):
        super().__init__(observation_space)
        # Hidden and NoneGridObject are always included, in addition to the
        # object types declared by the observation space
        self._grid_object_types = set(self.observation_space.object_types) | {
            Hidden,
            NoneGridObject,
        }
        self._grid_object_colors = set(self.observation_space.colors)

    @property
    def space(self) -> Space:
        """Space spanned by the known grid-object types and colors."""
        return default_grid_object_representation_space(
            self._grid_object_types,
            self._grid_object_colors,
        )

    def convert(self, grid_object: GridObject) -> np.ndarray:
        """Converts ``grid_object`` to its raw index array."""
        return default_grid_object_representation_convert(grid_object)
class NoOverlapGridObjectObservationRepresentation(
    GridObjectObservationRepresentation
):
    """The no-overlap representation for a grid-object

    Guarantees that each channel uses separate indices. See
    :func:`gym_gridverse.representations.representation.no_overlap_grid_object_representation_space`
    and
    :func:`gym_gridverse.representations.representation.no_overlap_grid_object_representation_convert`
    for more information.
    """

    def __init__(self, observation_space: ObservationSpace):
        super().__init__(observation_space)
        # Hidden and NoneGridObject are always included, in addition to the
        # object types declared by the observation space
        self._grid_object_types = set(self.observation_space.object_types) | {
            Hidden,
            NoneGridObject,
        }
        self._grid_object_colors = set(self.observation_space.colors)

    @property
    def space(self) -> Space:
        """Space spanned by the known grid-object types and colors."""
        return no_overlap_grid_object_representation_space(
            self._grid_object_types,
            self._grid_object_colors,
        )

    def convert(self, grid_object: GridObject) -> np.ndarray:
        """Converts ``grid_object`` using the known types and colors."""
        return no_overlap_grid_object_representation_convert(
            self._grid_object_types,
            self._grid_object_colors,
            grid_object,
        )
class CompactGridObjectObservationRepresentation(
    GridObjectObservationRepresentation
):
    """The compact representation for a grid-object

    Guarantees that each channel uses separate indices, and removes empty gaps
    between indices. See
    :func:`gym_gridverse.representations.representation.compact_grid_object_representation_space`
    and
    :func:`gym_gridverse.representations.representation.compact_grid_object_representation_convert`
    for more information.

    Compact indices are assigned in a deterministic order (grid-object types
    sorted by type index, colors sorted by color value).  The same observation
    space therefore always produces the same mapping, in every process.
    """

    def __init__(self, observation_space: ObservationSpace):
        super().__init__(observation_space)

        # TODO eventually fix this at the space-level
        max_type_index = self.observation_space.max_type_index
        max_state_index = self.observation_space.max_state_index
        max_color_index = self.observation_space.max_object_color

        # lookup tables from raw indices to compact indices;  entries left at
        # -1 correspond to raw indices not used by this observation space
        self._grid_object_type_map = -np.ones((max_type_index + 1,), int)
        self._grid_object_status_map = -np.ones(
            (max_type_index + 1, max_state_index + 1), int
        )
        self._grid_object_color_map = -np.ones((max_color_index + 1,), int)

        # NOTE: iteration order of a set of classes (id-based hash, unless a
        # metaclass overrides it) or of enum members (hash derived from the
        # randomized string hash; colors presumably are Enum members) can
        # change between interpreter runs.  Sorting keeps the compact mapping
        # reproducible, e.g. across training and evaluation processes.
        grid_object_types = sorted(
            set(self.observation_space.object_types) | {Hidden, NoneGridObject},
            key=lambda grid_object_type: grid_object_type.type_index(),
        )
        colors = sorted(
            set(self.observation_space.colors),
            key=lambda color: color.value,
        )

        # compact indices are contiguous: first all types, then all
        # (type, state) pairs, then all colors, so no two channels overlap
        compact_index = 0
        for grid_object_type in grid_object_types:
            self._grid_object_type_map[grid_object_type.type_index()] = (
                compact_index
            )
            compact_index += 1

        for grid_object_type in grid_object_types:
            type_index = grid_object_type.type_index()
            for state_index in range(grid_object_type.num_states()):
                self._grid_object_status_map[type_index, state_index] = (
                    compact_index
                )
                compact_index += 1

        for color in colors:
            self._grid_object_color_map[color.value] = compact_index
            compact_index += 1

    @property
    def space(self) -> Space:
        """Space spanned by the compact type, status and color maps."""
        return compact_grid_object_representation_space(
            self._grid_object_type_map,
            self._grid_object_status_map,
            self._grid_object_color_map,
        )

    def convert(self, grid_object: GridObject) -> np.ndarray:
        """Converts ``grid_object`` to its compact index array."""
        return compact_grid_object_representation_convert(
            self._grid_object_type_map,
            self._grid_object_status_map,
            self._grid_object_color_map,
            grid_object,
        )