from typing import Dict, Tuple
import numpy as np
from gym_gridverse.debugging import gv_debug
from gym_gridverse.grid_object import GridObject, NoneGridObject
from gym_gridverse.representations.representation import (
ArrayRepresentation,
StateRepresentation,
compact_grid_object_representation_convert,
compact_grid_object_representation_space,
default_grid_object_representation_convert,
default_grid_object_representation_space,
no_overlap_grid_object_representation_convert,
no_overlap_grid_object_representation_space,
)
from gym_gridverse.representations.spaces import Space
from gym_gridverse.spaces import StateSpace
from gym_gridverse.state import State
class ArrayStateRepresentation(ArrayRepresentation[State]):
    """Base class for array representations of a :class:`State`

    Keeps a reference to the state space, which is made available to
    subclasses as ``self.state_space``.
    """

    def __init__(self, state_space: StateSpace):
        # the only thing stored at this level is the state space itself
        self.state_space = state_space
class GridObjectStateRepresentation(ArrayRepresentation[GridObject]):
    """Base class for array representations of a single :class:`GridObject`

    Keeps a reference to the state space, which is made available to
    subclasses as ``self.state_space``.
    """

    def __init__(self, state_space: StateSpace):
        # the only thing stored at this level is the state space itself
        self.state_space = state_space
def make_state_representation(
    name: str,
    state_space: StateSpace,
) -> StateRepresentation:
    """Factory function for state representations

    Every supported name yields a :class:`DictStateRepresentation` with the
    fields ``grid``, ``agent_id_grid``, ``agent`` and ``item``.  The name only
    selects which grid-object representation is used by the ``grid`` and
    ``item`` fields.

    Args:
        name (str): name of the representation; one of ``'default'``,
            ``'no-overlap'`` or ``'compact'``
        state_space (StateSpace): inner-environment state space

    Returns:
        StateRepresentation: the composed dictionary representation

    Raises:
        ValueError: if ``name`` is not one of the supported names
    """
    # TODO: test
    grid_object_representation: GridObjectStateRepresentation

    # NOTE: only `grid` and `item` depend on the grid-object representation,
    # so that is the only thing which changes between the variants
    if name == 'default':
        grid_object_representation = DefaultGridObjectStateRepresentation(
            state_space
        )
    elif name == 'no-overlap':
        grid_object_representation = NoOverlapGridObjectStateRepresentation(
            state_space
        )
    elif name == 'compact':
        grid_object_representation = CompactGridObjectStateRepresentation(
            state_space
        )
    else:
        raise ValueError(f'invalid name {name}')

    representations = {
        'grid': GridStateRepresentation(
            state_space, grid_object_representation
        ),
        'agent_id_grid': AgentIDGridStateRepresentation(state_space),
        'agent': AgentStateRepresentation(state_space),
        'item': ItemStateRepresentation(
            state_space, grid_object_representation
        ),
    }
    return DictStateRepresentation(state_space, representations)
# representation composition
class DictStateRepresentation(StateRepresentation):
    """State representation composed of named array representations

    Both the space and the converted state are dictionaries which share the
    keys of the ``representations`` given at construction.
    """

    def __init__(
        self,
        state_space: StateSpace,
        representations: Dict[str, ArrayStateRepresentation],
    ):
        super().__init__(state_space)
        self.representations = representations

    @property
    def space(self) -> Dict[str, Space]:
        """Spaces of the individual fields, keyed by field name"""
        spaces: Dict[str, Space] = {}
        for field_name, field_representation in self.representations.items():
            spaces[field_name] = field_representation.space
        return spaces

    def convert(self, state: State) -> Dict[str, np.ndarray]:
        """Convert a state into one array per field

        Args:
            state (State): the state to convert

        Returns:
            Dict[str, np.ndarray]: converted arrays, keyed by field name

        Raises:
            ValueError: in debug mode, if the state space does not contain
                the state
        """
        # the containment check is only performed when debugging is enabled
        if gv_debug():
            if not self.state_space.contains(state):
                raise ValueError('state-space does not contain state')

        arrays: Dict[str, np.ndarray] = {}
        for field_name, field_representation in self.representations.items():
            arrays[field_name] = field_representation.convert(state)
        return arrays
# dict field representations
class GridStateRepresentation(ArrayStateRepresentation):
    """Representation of the whole grid, one grid-object array per cell

    The per-cell encoding is delegated to the given grid-object
    representation.
    """

    def __init__(
        self,
        state_space: StateSpace,
        grid_object_representation: GridObjectStateRepresentation,
    ):
        super().__init__(state_space)
        self.grid_object_representation = grid_object_representation

    @property
    def space(self) -> Space:
        """Grid-object space bounds tiled over the grid height and width"""
        cell_space = self.grid_object_representation.space
        grid_shape = self.state_space.grid_shape

        # repeat the per-cell bounds along two new leading axes
        repetitions = (grid_shape.height, grid_shape.width, 1)
        lower_bound = np.tile(cell_space.lower_bound, repetitions)
        upper_bound = np.tile(cell_space.upper_bound, repetitions)

        return Space(cell_space.space_type, lower_bound, upper_bound)

    def convert(self, state: State) -> np.ndarray:
        """Convert every cell of the state's grid into its array form

        Args:
            state (State): the state whose grid is converted

        Returns:
            np.ndarray: integer array indexed by row, column and then by the
            entries of the grid-object representation
        """
        convert_cell = self.grid_object_representation.convert

        rows = []
        for y in range(state.grid.shape.height):
            row = [
                convert_cell(state.grid[y, x])
                for x in range(state.grid.shape.width)
            ]
            rows.append(row)

        return np.array(rows, int)
class ItemStateRepresentation(ArrayStateRepresentation):
    """Representation of the grid-object stored on the agent

    Applies the given grid-object representation to
    ``state.agent.grid_object`` (presumably the item which the agent is
    holding).
    """

    def __init__(
        self,
        state_space: StateSpace,
        grid_object_representation: GridObjectStateRepresentation,
    ):
        super().__init__(state_space)
        self.grid_object_representation = grid_object_representation

    @property
    def space(self) -> Space:
        """Same space as the underlying grid-object representation"""
        item_space = self.grid_object_representation.space
        return item_space

    def convert(self, state: State) -> np.ndarray:
        """Convert the agent's grid-object into its array form

        Args:
            state (State): the state whose agent is inspected

        Returns:
            np.ndarray: array representation of ``state.agent.grid_object``
        """
        item = state.agent.grid_object
        return self.grid_object_representation.convert(item)
class AgentIDGridStateRepresentation(ArrayStateRepresentation):
    """Binary grid which marks the position of the agent

    The converted array is zero everywhere, except for a one at the agent
    position.
    """

    @property
    def space(self) -> Space:
        """Discrete space with bounds 0 and 1 for every cell of the grid

        Raises:
            ValueError: if the grid height or width is negative
        """
        grid_shape = self.state_space.grid_shape
        height, width = grid_shape.height, grid_shape.width

        if height < 0 or width < 0:
            raise ValueError(f'negative height or width ({height, width})')

        shape = (height, width)
        lower_bound = np.zeros(shape, dtype=int)
        upper_bound = np.ones(shape, dtype=int)
        return Space.make_discrete_space(lower_bound, upper_bound)

    def convert(self, state: State) -> np.ndarray:
        """Build the agent-position indicator grid

        Args:
            state (State): the state whose agent position is marked

        Returns:
            np.ndarray: integer array shaped like the grid, with a single 1
            at ``state.agent.position.yx``
        """
        agent_mask = np.zeros(state.grid.shape.as_tuple, int)
        # `yx` is used directly as a multi-dimensional index into the grid
        agent_mask[state.agent.position.yx] = 1
        return agent_mask
class AgentStateRepresentation(ArrayStateRepresentation):
    """Representation of the agent position and orientation

    The converted array has 6 entries:  the first 2 are the y and x
    coordinates of the agent normalized between -1 and 1, and the last 4 are
    a one-hot encoding of the agent orientation.
    """

    @property
    def space(self) -> Space:
        """Continuous space bounding the 6 entries of the converted array"""
        # 2 (first) entries for the normalized y and x coordinates
        # 4 (last) entries for a one-hot encoding of the orientation
        return Space.make_continuous_space(
            np.array([-1.0, -1.0, 0.0, 0.0, 0.0, 0.0]),
            np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0]),
        )

    @staticmethod
    def _normalize_coordinate(index: int, size: int) -> float:
        """Map an index in ``0..size-1`` linearly onto the interval [-1, 1]

        Args:
            index (int): coordinate along one axis of the grid
            size (int): number of cells along that axis

        Returns:
            float: -1 for index ``0``, 1 for index ``size - 1``;  0 if the
            axis has a single cell
        """
        if size == 1:
            # a single cell has no extent to normalize over (the formula
            # below would divide by zero), so use the center of the interval
            return 0.0

        return (2 * index - size + 1) / (size - 1)

    def convert(self, state: State) -> np.ndarray:
        """Convert the agent of a state into its array form

        Args:
            state (State): the state whose agent is converted

        Returns:
            np.ndarray: float array with 6 entries, see the class docstring
        """
        agent_array = np.zeros(6)

        # normalized between -1 and 1
        agent_array[0] = self._normalize_coordinate(
            state.agent.position.y, state.grid.shape.height
        )
        agent_array[1] = self._normalize_coordinate(
            state.agent.position.x, state.grid.shape.width
        )

        # NOTE: assumes the orientation value is an integer between 0 and 3
        agent_array[2 + state.agent.orientation.value] = 1

        return agent_array
# grid-object representations
class DefaultGridObjectStateRepresentation(GridObjectStateRepresentation):
    """The default representation for a grid-object

    Simply returns the grid-object indices. See
    :func:`gym_gridverse.representations.representation.default_grid_object_representation_space`
    and
    :func:`gym_gridverse.representations.representation.default_grid_object_representation_convert`
    for more information.
    """

    def __init__(self, state_space: StateSpace):
        super().__init__(state_space)

        # the types of the state space, plus the type of the "no object"
        grid_object_types = set(self.state_space.object_types)
        grid_object_types.add(NoneGridObject)

        self._grid_object_types = grid_object_types
        self._grid_object_colors = set(self.state_space.colors)

    @property
    def space(self) -> Space:
        """Space built from the known grid-object types and colors"""
        types, colors = self._grid_object_types, self._grid_object_colors
        return default_grid_object_representation_space(types, colors)

    def convert(self, grid_object: GridObject) -> np.ndarray:
        """Convert a grid-object into its default array form

        Args:
            grid_object (GridObject): the grid-object to convert

        Returns:
            np.ndarray: array representation of the grid-object
        """
        array = default_grid_object_representation_convert(grid_object)
        return array
class NoOverlapGridObjectStateRepresentation(GridObjectStateRepresentation):
    """The no-overlap representation for a grid-object

    Guarantees that each channel uses separate indices. See
    :func:`gym_gridverse.representations.representation.no_overlap_grid_object_representation_space`
    and
    :func:`gym_gridverse.representations.representation.no_overlap_grid_object_representation_convert`
    for more information.
    """

    def __init__(self, state_space: StateSpace):
        super().__init__(state_space)

        # the types of the state space, plus the type of the "no object"
        grid_object_types = set(self.state_space.object_types)
        grid_object_types.add(NoneGridObject)

        self._grid_object_types = grid_object_types
        self._grid_object_colors = set(self.state_space.colors)

    @property
    def space(self) -> Space:
        """Space built from the known grid-object types and colors"""
        types, colors = self._grid_object_types, self._grid_object_colors
        return no_overlap_grid_object_representation_space(types, colors)

    def convert(self, grid_object: GridObject) -> np.ndarray:
        """Convert a grid-object into its no-overlap array form

        Args:
            grid_object (GridObject): the grid-object to convert

        Returns:
            np.ndarray: array representation of the grid-object
        """
        types, colors = self._grid_object_types, self._grid_object_colors
        return no_overlap_grid_object_representation_convert(
            types, colors, grid_object
        )
class CompactGridObjectStateRepresentation(GridObjectStateRepresentation):
    """The compact representation for a grid-object

    Guarantees that each channel uses separate indices, and removes empty gaps
    between indices. See
    :func:`gym_gridverse.representations.representation.compact_grid_object_representation_space`
    and
    :func:`gym_gridverse.representations.representation.compact_grid_object_representation_convert`
    for more information.

    Compact indices are assigned consecutively, starting from 0:  first one
    per grid-object type, then one per (type, state) pair, and finally one per
    color.  Types are visited in increasing order of their type index, such
    that the same state space always results in the same assignment.
    """

    def __init__(self, state_space: StateSpace):
        super().__init__(state_space)

        # TODO eventually fix this at the space-level
        max_type_index = self.state_space.max_type_index
        max_state_index = self.state_space.max_state_index
        max_color_index = self.state_space.max_object_color

        # lookup tables from the original indices to the compact indices;
        # entries which are never assigned below remain -1
        self._grid_object_type_map = -np.ones((max_type_index + 1,), int)
        self._grid_object_status_map = -np.ones(
            (max_type_index + 1, max_state_index + 1), int
        )
        self._grid_object_color_map = -np.ones((max_color_index + 1,), int)

        # NOTE: the iteration order of a set of classes is not reproducible
        # across runs;  sorting by type index makes the assignment of compact
        # indices deterministic
        grid_object_types = sorted(
            set(self.state_space.object_types) | {NoneGridObject},
            key=lambda grid_object_type: grid_object_type.type_index(),
        )

        compact_index = 0

        # one compact index per grid-object type
        for grid_object_type in grid_object_types:
            i = grid_object_type.type_index()
            self._grid_object_type_map[i] = compact_index
            compact_index += 1

        # one compact index per (type, state) pair
        for grid_object_type in grid_object_types:
            i = grid_object_type.type_index()
            for j in range(grid_object_type.num_states()):
                self._grid_object_status_map[i, j] = compact_index
                compact_index += 1

        # one compact index per color
        for color in self.state_space.colors:
            k = color.value
            self._grid_object_color_map[k] = compact_index
            compact_index += 1

    @property
    def space(self) -> Space:
        """Space built from the type, status and color lookup tables"""
        return compact_grid_object_representation_space(
            self._grid_object_type_map,
            self._grid_object_status_map,
            self._grid_object_color_map,
        )

    def convert(self, grid_object: GridObject) -> np.ndarray:
        """Convert a grid-object into its compact array form

        Args:
            grid_object (GridObject): the grid-object to convert

        Returns:
            np.ndarray: array representation of the grid-object
        """
        return compact_grid_object_representation_convert(
            self._grid_object_type_map,
            self._grid_object_status_map,
            self._grid_object_color_map,
            grid_object,
        )