from __future__ import annotations
from typing import List, Set, Tuple, Type, Union, cast
from .geometry import Area, Orientation, Position, Shape
from .grid_object import Floor, GridObject, GridObjectFactory, Hidden
class Grid:
    """A two-dimensional grid of objects.

    A container of :py:class:`~gym_gridverse.grid_object.GridObject`.  This is
    typically used to represent either the global state of the environment, or
    a partial agent view.  This is basically a two-dimensional array, with some
    additional methods which help in querying and manipulating its objects.

    Cells are addressed as ``(y, x)``, either through a
    :py:class:`~gym_gridverse.geometry.Position` or a plain tuple.
    """

    def __init__(self, objects: List[List[GridObject]]):
        """Constructs a grid from the given grid-objects

        The nested list is stored as-is (it is not copied); ``shape`` and
        ``area`` are derived from the number of rows and the length of the
        first row.

        Args:
            objects (List[List[~gym_gridverse.grid_object.GridObject]]): grid of GridObjects, indexed as ``objects[y][x]``

        Raises:
            IndexError: if ``objects`` has no rows
        """
        self.objects = objects
        self.shape = Shape(len(objects), len(objects[0]))
        self.area = Area((0, self.shape.height - 1), (0, self.shape.width - 1))

    @staticmethod
    def from_shape(
        shape: Union[Shape, Tuple[int, int]],
        *,
        factory: GridObjectFactory = Floor,
    ) -> Grid:
        """Constructs a grid with the given shape, with objects generated from the factory.

        The factory is called once per cell, so every cell holds its own
        object.

        Args:
            shape (Union[~gym_gridverse.geometry.Shape, Tuple[int, int]]): either a Shape, or a ``(height, width)`` tuple
            factory (~gym_gridverse.grid_object.GridObjectFactory): zero-argument callable producing the object for each cell

        Returns:
            Grid: The grid of the appropriate size, with generated objects
        """
        try:
            shape = cast(Shape, shape)
            height, width = shape.height, shape.width
        except AttributeError:
            # not a Shape: treat it as a plain (height, width) pair
            shape = cast(Tuple[int, int], shape)
            height, width = shape

        objects = [[factory() for _ in range(width)] for _ in range(height)]
        return Grid(objects)

    @staticmethod
    def _unpack_yx(
        position: Union[Position, Tuple[int, int]]
    ) -> Tuple[int, int]:
        """Returns the ``(y, x)`` coordinates of a Position or a plain tuple."""
        try:
            y, x = cast(Position, position).yx
        except AttributeError:
            # not a Position: treat it as a plain (y, x) pair
            y, x = cast(Tuple[int, int], position)
        return y, x

    def __eq__(self, other) -> bool:
        """Grids are equal if they have the same shape and equal objects in every cell."""
        try:
            return self.shape == other.shape and all(
                self[position] == other[position]
                for position in self.area.positions()
            )
        except AttributeError:
            # `other` does not look like a grid; let python try the reflected
            # comparison instead
            return NotImplemented

    def object_types(self) -> Set[Type[GridObject]]:
        """Returns the set of object types in the grid

        Returns:
            Set[Type[GridObject]]: the concrete type of every object currently in the grid
        """
        return {type(self[position]) for position in self.area.positions()}

    def get(
        self,
        position: Union[Position, Tuple[int, int]],
        *,
        factory: GridObjectFactory,
    ) -> GridObject:
        """Gets the grid object in the position, or generates one from the factory.

        Any position outside of the grid yields a factory-generated object,
        including positions with negative coordinates (which plain list
        indexing would otherwise silently wrap around to the opposite edge).

        Args:
            position (Union[~gym_gridverse.geometry.Position, Tuple[int, int]]): the position to look up
            factory (~gym_gridverse.grid_object.GridObjectFactory): zero-argument callable, used only if the position is outside of the grid

        Returns:
            GridObject: the object at the position, or a newly generated one
        """
        y, x = self._unpack_yx(position)

        # explicit check, because negative indices do not raise IndexError
        if not (0 <= y < self.shape.height and 0 <= x < self.shape.width):
            return factory()

        try:
            return self[position]
        except IndexError:
            return factory()

    def __getitem__(
        self, position: Union[Position, Tuple[int, int]]
    ) -> GridObject:
        """Returns the object at the given position.

        NOTE: indexing is delegated to the underlying nested lists, hence
        negative coordinates wrap around rather than raising; use
        :py:meth:`get` for a lookup which treats them as out of bounds.

        Raises:
            IndexError: if the position is beyond the end of the grid
        """
        y, x = self._unpack_yx(position)
        return self.objects[y][x]

    def __setitem__(
        self, position: Union[Position, Tuple[int, int]], obj: GridObject
    ):
        """Replaces the object at the given position.

        Raises:
            TypeError: if ``obj`` is not a GridObject
            IndexError: if the position is beyond the end of the grid
        """
        y, x = self._unpack_yx(position)

        if not isinstance(obj, GridObject):
            raise TypeError('grid can only contain grid objects')

        self.objects[y][x] = obj

    def swap(self, p: Position, q: Position):
        """Swaps the grid objects at two positions.

        Args:
            p (~gym_gridverse.geometry.Position): first position
            q (~gym_gridverse.geometry.Position): second position
        """
        self[p], self[q] = self[q], self[p]

    def subgrid(self, area: Area) -> Grid:
        """Returns subgrid slice at given area.

        Cells included in the area but outside of the grid are represented as
        Hidden objects.  Cells inside of the grid are the same object
        instances held by this grid (they are not copied).

        Args:
            area (~gym_gridverse.geometry.Area): The area to be sliced

        Returns:
            Grid: New instance, sliced appropriately
        """
        # loop invariants: the bounds of this grid
        height, width = self.area.height, self.area.width

        return Grid(
            [
                [
                    self.objects[y][x]
                    if 0 <= y < height and 0 <= x < width
                    else Hidden()
                    for x in area.x_coordinates()
                ]
                for y in area.y_coordinates()
            ]
        )

    def __mul__(self, other: Orientation) -> Grid:
        """returns grid transformed according to given orientation.

        NOTE:  this product follows rigid body transform conventions, whereby
        the orientation represents a transform between frames to apply on the
        grid, e.g.,

        ::

            orientation * grid = expected
            -----------   ----   --------

                          ABC    CFI
            RIGHT       * DEF  = BEH
                          GHI    ADG

                          AB     BD
            RIGHT       * CD   = AC

        The same result is returned whichever side the orientation is on
        (``__rmul__`` is an alias of ``__mul__``).

        Args:
            other (~gym_gridverse.geometry.Orientation): The rotation orientation

        Returns:
            Grid: New instance rotated appropriately, or ``NotImplemented`` if ``other`` is not a known orientation
        """
        try:
            rotation_function = _grid_rotation_functions[other]
        except (KeyError, TypeError):
            # KeyError: not a known orientation;  TypeError: unhashable operand.
            # Either way this operand is unsupported, let python handle it
            return NotImplemented

        return Grid(rotation_function(self.objects))

    __rmul__ = __mul__

    def __hash__(self):
        # the nested lists are not hashable, so hash a tuple-of-tuples snapshot;
        # NOTE: the hash changes whenever the content of the grid changes
        return hash(tuple(map(tuple, self.objects)))

    def __repr__(self):
        return f'<{self.__class__.__name__} {self.shape.height}x{self.shape.width} objects={self.objects}>'
def _rotate_matrix_forward(data):
return data
def _rotate_matrix_right(data):
return [list(row) for row in zip(*data[::-1])]
def _rotate_matrix_left(data):
return [list(row) for row in zip(*data)][::-1]
def _rotate_matrix_backward(data):
return [d[::-1] for d in data[::-1]]
# for Grid.__mul__
# Lookup table from orientation to the matrix rotation applied to the nested
# object lists of a grid.
# NOTE: R is paired with the *left* (counter-clockwise) matrix rotation and L
# with the *right* (clockwise) one.  This agrees with the example in the
# Grid.__mul__ docstring (RIGHT * ABC/DEF/GHI = CFI/BEH/ADG), where the
# orientation is described as a transform between frames.
_grid_rotation_functions = {
Orientation.F: _rotate_matrix_forward,
Orientation.R: _rotate_matrix_left,
Orientation.B: _rotate_matrix_backward,
Orientation.L: _rotate_matrix_right,
}