from typing import Iterable, Sequence, Tuple, Type
from gym_gridverse.action import Action
from gym_gridverse.geometry import Area, Orientation, Position, Shape
from gym_gridverse.grid_object import Color, GridObject, Hidden, NoneGridObject
from gym_gridverse.observation import Observation
from gym_gridverse.state import State
def _max_object_type(object_types: Iterable[Type[GridObject]]) -> int:
    """Return the largest type index among the given object classes.

    Args:
        object_types (`Iterable[Type[GridObject]]`): classes to inspect;
            each must provide a ``type_index()`` callable on the class.

    Returns:
        int: the maximum of ``type_index()`` over all provided classes.

    Raises:
        ValueError: if ``object_types`` is empty (raised by ``max``).
    """
    type_indices = [object_class.type_index() for object_class in object_types]
    return max(type_indices)
def _max_object_status(object_types: Iterable[Type[GridObject]]) -> int:
    """Return the largest ``num_states()`` value among the given classes.

    Args:
        object_types (`Iterable[Type[GridObject]]`): classes to inspect;
            each must provide a ``num_states()`` callable on the class.

    Returns:
        int: the maximum of ``num_states()`` over all provided classes.

    Raises:
        ValueError: if ``object_types`` is empty (raised by ``max``).
    """
    state_counts = map(lambda object_class: object_class.num_states(), object_types)
    return max(state_counts)
def _max_color_index(colors: Iterable[Color]) -> int:
"""Returns the highest color index of the provided colors
Args:
colors (`Iterable[Color]`):
Returns:
int:
"""
return max(color.value for color in colors)
class StateSpace:
    """Describes which states are valid for an environment.

    Attributes set by the constructor:
        grid_shape: the shape every state's grid must have.
        object_types: list of object classes allowed in the grid.
        colors: set of allowed colors; ``Color.NONE`` is always included.
    """

    def __init__(
        self,
        grid_shape: Shape,
        object_types: Sequence[Type[GridObject]],
        colors: Sequence[Color],
    ):
        """Store the grid shape and the allowed object types and colors.

        Args:
            grid_shape (`Shape`): shape of the grid of every state.
            object_types (`Sequence[Type[GridObject]]`): object classes
                that may appear in the grid.
            colors (`Sequence[Color]`): colors that objects may have.
        """
        self.grid_shape = grid_shape
        self.object_types = list(object_types)
        self.colors = set(colors) | {Color.NONE}
        # the agent may additionally hold "nothing", i.e. a NoneGridObject
        self._agent_object_types = set(object_types) | {NoneGridObject}

    def contains(self, state: State) -> bool:
        """True if the state satisfies the state-space

        Checks the grid shape, the object types present in the grid, that
        the agent position lies inside the grid area, that the orientation
        is an `Orientation`, and the type of the object held by the agent.
        Colors are not validated here.
        """
        # TODO: test
        return (
            state.grid.shape == self.grid_shape
            and state.grid.object_types().issubset(self.object_types)
            and state.grid.area.contains(state.agent.position)
            and isinstance(state.agent.orientation, Orientation)
            and type(state.agent.grid_object) in self._agent_object_types
        )

    @property
    def can_be_represented(self) -> bool:
        """True if every object type reports it can be represented in state"""
        # TODO: test
        return all(
            object_type.can_be_represented_in_state()
            for object_type in self.object_types
        )

    @property
    def agent_state_size(self) -> Tuple[int, int, int, int, int]:
        """Sizes of the agent state components

        Returns:
            Tuple[int, int, int, int, int]: grid height, grid width, max
            agent object type, max agent object status, max object color.
        """
        # TODO: test
        return (
            self.grid_shape.height,
            # second entry is the width (it used to repeat the height)
            self.grid_shape.width,
            self.max_agent_object_type,
            self.max_agent_object_status,
            self.max_object_color,
        )

    @property
    def agent_state_shape(self) -> int:
        """Number of components in `agent_state_size`"""
        # TODO: test
        return len(self.agent_state_size)

    @property
    def grid_state_shape(self) -> Shape:
        """Shape of the grid part of the state (same as `grid_shape`)"""
        # TODO: test
        return self.grid_shape

    @property
    def max_object_color(self) -> int:
        """Highest color value among the allowed colors"""
        return _max_color_index(self.colors)

    # Random getters you might be interested in

    @property
    def max_type_index(self) -> int:
        """Highest type index over grid objects and agent-held objects"""
        return max(self.max_grid_object_type, self.max_agent_object_type)

    @property
    def max_state_index(self) -> int:
        """Highest status value over grid objects and agent-held objects"""
        return max(self.max_grid_object_status, self.max_agent_object_status)

    @property
    def max_grid_object_type(self) -> int:
        """Highest type index among the grid object types"""
        return _max_object_type(self.object_types)

    @property
    def max_grid_object_status(self) -> int:
        """Largest `num_states()` among the grid object types"""
        return _max_object_status(self.object_types)

    @property
    def max_agent_object_type(self) -> int:
        """Highest type index among objects the agent may hold"""
        # NOTE: Add NoneGridObject as the default 'no' object the agent is
        # holding
        return _max_object_type(self.object_types + [NoneGridObject])

    @property
    def max_agent_object_status(self) -> int:
        """Largest `num_states()` among objects the agent may hold"""
        # TODO: test
        # NOTE: Add NoneGridObject as the default 'no' object the agent is
        # holding
        return _max_object_status(self.object_types + [NoneGridObject])
class ActionSpace:
    """Describes which actions are valid and maps them to/from integers.

    The integer associated with an action is its position in `actions`.
    """

    def __init__(self, actions: Sequence[Action]):
        """Store the sequence of allowed actions.

        Args:
            actions (`Sequence[Action]`): the allowed actions; the sequence
                is stored as given (not copied).
        """
        self.actions = actions

    def contains(self, action: Action) -> bool:
        """True if the action satisfies the action-space"""
        return action in self.actions

    def int_to_action(self, action: int) -> Action:
        """Return the action stored at index `action`

        The index is passed straight to the underlying sequence, so for
        built-in sequences a negative index counts from the end and an
        out-of-range index raises `IndexError`.
        """
        return self.actions[action]

    def action_to_int(self, action: Action) -> int:
        """Return the index of `action` in the space

        Delegates to the sequence's ``index`` method; for built-in
        sequences this raises `ValueError` when the action is absent.
        """
        # TODO: test
        return self.actions.index(action)

    @property
    def num_actions(self) -> int:
        """Number of actions in the space"""
        return len(self.actions)
class ObservationSpace:
    """Describes which observations are valid for an environment.

    Attributes set by the constructor:
        grid_shape: the shape every observation's grid must have.
        object_types: list of object classes allowed in the grid.
        colors: set of allowed colors; ``Color.NONE`` is always included.
        area: the observable area in agent-relative coordinates.
        agent_position: the agent's position inside the observation grid.
    """

    def __init__(
        self,
        grid_shape: Shape,
        object_types: Sequence[Type[GridObject]],
        colors: Sequence[Color],
    ):
        """Store the observation shape and the allowed types and colors.

        Args:
            grid_shape (`Shape`): shape of the observed grid; its width
                must be odd.
            object_types (`Sequence[Type[GridObject]]`): object classes
                that may appear in the observation.
            colors (`Sequence[Color]`): colors that objects may have.

        Raises:
            ValueError: if ``grid_shape.width`` is even.
        """
        # TODO we should generalize this
        if grid_shape.width % 2 == 0:
            raise ValueError('shape should have an odd width')

        self.grid_shape = grid_shape
        self.object_types = list(object_types)
        self.colors = set(colors) | {Color.NONE}
        # observed cells may additionally be Hidden
        self._grid_object_types = set(object_types) | {Hidden}
        # the agent may additionally hold "nothing", i.e. a NoneGridObject
        self._agent_object_types = set(object_types) | {NoneGridObject}

        # TODO: eventually let this substitute the `grid_shape` input altogether
        # this area represents the observable area, with (0, 0) representing
        # the agent's position, when the agent is pointing N
        self.area = Area(
            (-self.grid_shape.height + 1, 0),
            (-(self.grid_shape.width // 2), self.grid_shape.width // 2),
        )

        # NOTE this position is expressed as non-negative indices into the
        # area: last row (height - 1) and middle column (width // 2).
        # NOTE(review): assumes `Position` takes (y, x) in that order and
        # that indices are counted from the area's first row/column -- confirm
        self.agent_position = Position(
            self.area.height - 1, self.area.width // 2
        )

    # TODO: We don't need to make assumptions about the agent position
    def contains(self, observation: Observation) -> bool:
        """True if the observation satisfies the observation-space

        Checks the grid shape, the object types and colors present in the
        grid, that the agent position indexes into the observable area, and
        the type and color of the object held by the agent. Every check is
        evaluated before the results are combined.
        """
        have_same_shape = observation.grid.shape == self.grid_shape

        y_in_grid = 0 <= observation.agent.position.y < self.area.height
        x_in_grid = 0 <= observation.agent.position.x < self.area.width

        agent_obj_type_in_space = (
            type(observation.agent.grid_object) in self._agent_object_types
        )
        grid_objs_in_space = observation.grid.object_types().issubset(
            self._grid_object_types
        )

        grid_objs_colors_in_space = {
            observation.grid[pos].color
            for pos in observation.grid.area.positions()
        }.issubset(self.colors)
        agent_obj_color_in_space = (
            observation.agent.grid_object.color in self.colors
        )

        res = [
            have_same_shape,
            grid_objs_in_space,
            grid_objs_colors_in_space,
            y_in_grid,
            x_in_grid,
            agent_obj_type_in_space,
            agent_obj_color_in_space,
        ]
        return all(res)

    @property
    def agent_state_size(self) -> Tuple[int, int, int, int, int]:
        """Sizes of the agent state components

        Returns:
            Tuple[int, int, int, int, int]: grid height, grid width, max
            agent object type, max agent object status, max object color.
        """
        # TODO: test
        return (
            self.grid_shape.height,
            self.grid_shape.width,
            self.max_agent_object_type,
            self.max_agent_object_status,
            self.max_object_color,
        )

    @property
    def agent_state_shape(self) -> int:
        """Number of components in `agent_state_size`"""
        # TODO: test
        return len(self.agent_state_size)

    @property
    def grid_state_shape(self) -> Shape:
        """Shape of the grid part of the observation (same as `grid_shape`)"""
        # TODO: test
        return self.grid_shape

    @property
    def max_object_color(self) -> int:
        """Highest color value among the allowed colors"""
        return _max_color_index(self.colors)

    # Random getters you might be interested in

    @property
    def max_type_index(self) -> int:
        """Highest type index over grid objects and agent-held objects"""
        return max(self.max_grid_object_type, self.max_agent_object_type)

    @property
    def max_state_index(self) -> int:
        """Highest status value over grid objects and agent-held objects"""
        return max(self.max_grid_object_status, self.max_agent_object_status)

    @property
    def max_grid_object_type(self) -> int:
        """Highest type index among observable grid object types"""
        # NOTE: Add Hidden as a potential object in any domain observation
        return _max_object_type(self.object_types + [Hidden])

    @property
    def max_grid_object_status(self) -> int:
        """Largest `num_states()` among observable grid object types"""
        # NOTE: Add Hidden as a potential object in any domain observation
        return _max_object_status(self.object_types + [Hidden])

    @property
    def max_agent_object_type(self) -> int:
        """Highest type index among objects the agent may hold"""
        # TODO: test
        # NOTE: Add NoneGridObject as the default 'no' object the agent is
        # holding
        return _max_object_type(self.object_types + [NoneGridObject])

    @property
    def max_agent_object_status(self) -> int:
        """Largest `num_states()` among objects the agent may hold"""
        # TODO: test
        # NOTE: Add NoneGridObject as the default 'no' object the agent is
        # holding
        return _max_object_status(self.object_types + [NoneGridObject])