Source code for gym_gridverse.envs.reset_functions

import inspect
import itertools as itt
import warnings
from functools import partial
from typing import List, Optional, Set, Tuple, Type

import more_itertools as mitt
import numpy as np
import numpy.random as rnd
from typing_extensions import Protocol  # python3.7 compatibility

from gym_gridverse.agent import Agent
from gym_gridverse.design import (
    draw_area,
    draw_line_horizontal,
    draw_line_vertical,
    draw_room_grid,
    draw_wall_boundary,
)
from gym_gridverse.geometry import Orientation, Position, Shape
from gym_gridverse.grid import Grid
from gym_gridverse.grid_object import (
    Beacon,
    Color,
    Door,
    Exit,
    Floor,
    GridObject,
    Key,
    MovingObstacle,
    Telepod,
    Wall,
)
from gym_gridverse.rng import choice, choices, get_gv_rng_if_none, shuffle
from gym_gridverse.state import State
from gym_gridverse.utils.custom import import_if_custom
from gym_gridverse.utils.functions import checkraise_kwargs, select_kwargs
from gym_gridverse.utils.protocols import get_keyword_parameter
from gym_gridverse.utils.registry import FunctionRegistry


[docs]class ResetFunction(Protocol): """Signature that all reset functions must follow"""
[docs] def __call__(self, *, rng: Optional[rnd.Generator] = None) -> State: ...
[docs]class ResetFunctionRegistry(FunctionRegistry):
[docs] def get_protocol_parameters( self, signature: inspect.Signature ) -> List[inspect.Parameter]: rng = get_keyword_parameter(signature, 'rng') return [rng]
[docs] def check_signature(self, function: ResetFunction): signature = inspect.signature(function) (rng,) = self.get_protocol_parameters(signature) # checks `rng` is keyword if rng.kind not in [ inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY, ]: raise TypeError( f'The `rng` argument ({rng.name}) ' f'of a registered reward function ({function}) ' 'should be allowed to be a keyword argument.' ) # checks if annotations, if given, are consistent if rng.annotation not in [ inspect.Parameter.empty, Optional[rnd.Generator], ]: warnings.warn( f'The `rng` argument ({rng.name}) ' f'of a registered reward function ({function}) ' f'has an annotation ({rng.annotation}) ' 'which is not `Optional[rnd.Generator]`.' ) if signature.return_annotation not in [inspect.Parameter.empty, State]: warnings.warn( f'The return type of a registered reset function ({function}) ' f'has an annotation ({signature.return_annotation}) ' 'which is not `State`.' )
reset_function_registry = ResetFunctionRegistry() """Reset function registry"""
[docs]@reset_function_registry.register def empty( shape: Shape, random_agent: bool = False, random_exit: bool = False, *, rng: Optional[rnd.Generator] = None, ) -> State: """An empty environment""" # TODO: test if shape.height < 4 or shape.width < 4: raise ValueError('height and width need to be at least 4') rng = get_gv_rng_if_none(rng) # TODO: test creation (e.g. count number of walls, exits, check held item) grid = Grid.from_shape((shape.height, shape.width)) draw_wall_boundary(grid) exit_y: int exit_x: int if random_exit: exit_y = rng.integers(1, shape.height - 2, endpoint=True) exit_x = rng.integers(1, shape.width - 2, endpoint=True) else: exit_y = shape.height - 2 exit_x = shape.width - 2 grid[exit_y, exit_x] = Exit() if random_agent: positions = [ position for position in grid.area.positions() if isinstance(grid[position], Floor) ] agent_position = choice(rng, positions) agent_orientation = choice(rng, list(Orientation)) else: agent_position = Position(1, 1) agent_orientation = Orientation.R agent = Agent(agent_position, agent_orientation) return State(grid, agent)
[docs]@reset_function_registry.register def rooms( shape: Shape, layout: Tuple[int, int], *, rng: Optional[rnd.Generator] = None, ) -> State: rng = get_gv_rng_if_none(rng) # TODO: test creation (e.g. count number of walls, exits, check held item) layout_height, layout_width = layout y_splits = np.linspace( 0, shape.height - 1, num=layout_height + 1, dtype=int, ) if len(y_splits) != len(set(y_splits)): raise ValueError( f'insufficient height ({shape.height}) for layout ({layout})' ) x_splits = np.linspace( 0, shape.width - 1, num=layout_width + 1, dtype=int, ) if len(x_splits) != len(set(x_splits)): raise ValueError( f'insufficient width ({shape.width}) for layout ({layout})' ) grid = Grid.from_shape((shape.height, shape.width)) draw_room_grid(grid, y_splits, x_splits, Wall) # passages in horizontal walls for y in y_splits[1:-1]: for x_from, x_to in mitt.pairwise(x_splits): x = rng.integers(x_from + 1, x_to) grid[y, x] = Floor() # passages in vertical walls for y_from, y_to in mitt.pairwise(y_splits): for x in x_splits[1:-1]: y = rng.integers(y_from + 1, y_to) grid[y, x] = Floor() # sample agent and exit positions positions = [ position for position in grid.area.positions() if isinstance(grid[position], Floor) ] agent_position, exit_position = choices( rng, positions, size=2, replace=False, ) agent_orientation = choice(rng, list(Orientation)) grid[exit_position] = Exit() agent = Agent(agent_position, agent_orientation) return State(grid, agent)
[docs]@reset_function_registry.register def dynamic_obstacles( shape: Shape, num_obstacles: int, random_agent: bool = False, *, rng: Optional[rnd.Generator] = None, ) -> State: """An environment with dynamically moving obstacles Args: shape (`Shape`): shape of grid num_obstacles (`int`): number of dynamic obstacles random_agent (`bool, optional`): position of agent, in corner if False rng: (`Generator, optional`) Returns: State: """ rng = get_gv_rng_if_none(rng) state = empty(shape, random_agent, rng=rng) vacant_positions = [ position for position in state.grid.area.positions() if isinstance(state.grid[position], Floor) and position != state.agent.position ] try: sample_positions = choices( rng, vacant_positions, size=num_obstacles, replace=False ) except ValueError as e: raise ValueError( f'Too many obstacles ({num_obstacles}) and not enough ' f'vacant positions ({len(vacant_positions)})' ) from e for pos in sample_positions: assert isinstance(state.grid[pos], Floor) state.grid[pos] = MovingObstacle() return state
[docs]@reset_function_registry.register def keydoor(shape: Shape, *, rng: Optional[rnd.Generator] = None) -> State: """An environment with a key and a door Creates a height x width (including outer walls) grid with a random column of walls. The agent and a yellow key are randomly dropped left of the column, while the exit is placed in the bottom right. For example:: ######### # @# # # D # #K # G# ######### Args: shape (`Shape`): rng: (`Generator, optional`) Returns: State: """ if shape.height < 3 or shape.width < 5 or shape == Shape(3, 5): raise ValueError(f'Shape must larger than (3, 5), given {shape}') rng = get_gv_rng_if_none(rng) state = empty(shape) assert isinstance(state.grid[shape.height - 2, shape.width - 2], Exit) # Generate vertical splitting wall x_wall = rng.integers(2, shape.width - 3, endpoint=True) line_wall = draw_line_vertical( state.grid, range(1, shape.height - 1), x_wall, Wall ) # Place yellow, locked door pos_wall = choice(rng, line_wall) state.grid[pos_wall] = Door(Door.Status.LOCKED, Color.YELLOW) # Place yellow key left of wall # XXX: potential general function y_key = rng.integers(1, shape.height - 2, endpoint=True) x_key = rng.integers(1, x_wall - 1, endpoint=True) state.grid[y_key, x_key] = Key(Color.YELLOW) # Place agent left of wall # XXX: potential general function y_agent = rng.integers(1, shape.height - 2, endpoint=True) x_agent = rng.integers(1, x_wall - 1, endpoint=True) state.agent.position = Position(y_agent, x_agent) state.agent.orientation = choice(rng, list(Orientation)) return state
[docs]@reset_function_registry.register def crossing( shape: Shape, num_rivers: int, object_type: Type[GridObject], *, rng: Optional[rnd.Generator] = None, ) -> State: """An environment with "rivers" to be crosses Creates a height x width (including wall) grid with random rows/columns of objects called "rivers". The agent needs to navigate river openings to reach the exit. For example:: ######### #@ # # #### #### # # # ## ###### # # # # # # #E# ######### Args: shape (`Shape`): shape (odd height and width) of grid num_rivers (`int`): number of `rivers` object_type (`Type[GridObject]`): river's object type rng: (`Generator, optional`) Returns: State: """ if shape.height < 5 or shape.height % 2 == 0: raise ValueError(f'height ({shape.height}) must be odd and >= 5') if shape.width < 5 or shape.width % 2 == 0: raise ValueError(f'width ({shape.width}) must be odd and >= 5') if num_rivers <= 0: raise ValueError(f'number of rivers ({num_rivers}) must be positive') rng = get_gv_rng_if_none(rng) state = empty(shape) assert isinstance(state.grid[shape.height - 2, shape.width - 2], Exit) # token `horizontal` and `vertical` objects h, v = object(), object() # all rivers specified by orientation and position rivers = list( itt.chain( ((h, i) for i in range(2, shape.height - 2, 2)), ((v, j) for j in range(2, shape.width - 2, 2)), ) ) # sample subset of random rivers rivers = shuffle(rng, rivers) rivers = rivers[:num_rivers] # create horizontal rivers without crossings rivers_h = sorted([pos for direction, pos in rivers if direction is h]) for y in rivers_h: draw_line_horizontal( state.grid, y, range(1, shape.width - 1), object_type ) # create vertical rivers without crossings rivers_v = sorted([pos for direction, pos in rivers if direction is v]) for x in rivers_v: draw_line_vertical( state.grid, range(1, shape.height - 1), x, object_type ) # sample path to exit path = [h] * len(rivers_v) + [v] * len(rivers_h) path = shuffle(rng, path) # create crossing limits_h = ( [0] + rivers_h + [shape.height - 1] ) # horizontal river boundaries limits_v = [0] + rivers_v + [shape.width - 1] # vertical river boundaries room_i, room_j = 0, 0 # coordinates of current "room" for step_direction in path: if step_direction is h: i = rng.integers(limits_h[room_i] + 1, limits_h[room_i + 1]) j = limits_v[room_j + 1] room_j += 1 elif step_direction is v: i = limits_h[room_i + 1] j = rng.integers(limits_v[room_j] + 1, limits_v[room_j + 1]) room_i += 1 else: assert False state.grid[i, j] = Floor() # Place agent on top left state.agent.position = Position(1, 1) state.agent.orientation = Orientation.R return state
[docs]@reset_function_registry.register def teleport(shape: Shape, *, rng: Optional[rnd.Generator] = None) -> State: rng = get_gv_rng_if_none(rng) state = empty(shape) assert isinstance(state.grid[shape.height - 2, shape.width - 2], Exit) # Place agent on top left state.agent.position = Position(1, 1) state.agent.orientation = choice(rng, [Orientation.R, Orientation.B]) num_telepods = 2 telepods = [Telepod(Color.RED) for _ in range(num_telepods)] positions = rng.choice( [ position for position in state.grid.area.positions() if isinstance(state.grid[position], Floor) and position != state.agent.position ], size=num_telepods, replace=False, ) for position, telepod in zip(positions, telepods): state.grid[position] = telepod # Place agent on top left state.agent.position = Position(1, 1) state.agent.orientation = choice(rng, [Orientation.R, Orientation.B]) return state
[docs]@reset_function_registry.register def memory( shape: Shape, colors: Set[Color], *, rng: Optional[rnd.Generator] = None, ) -> State: if shape.height < 5: raise ValueError(f'height ({shape.height}) must be >= 5') if shape.width < 5 or shape.width % 2 == 0: raise ValueError(f'width ({shape.width}) must be odd and >= 5') if Color.NONE in colors: raise ValueError(f'colors ({colors}) must not include Colors.NONE') if len(colors) < 2: raise ValueError(f'colors ({colors}) must have at least 2 colors') rng = get_gv_rng_if_none(rng) grid = Grid.from_shape((shape.height, shape.width)) draw_area(grid, grid.area, Wall, fill=True) draw_line_horizontal(grid, 1, range(2, shape.width - 2), Floor) draw_line_horizontal( grid, shape.height - 2, range(2, shape.width - 2), Floor ) draw_line_vertical( grid, range(2, shape.height - 2), shape.width // 2, Floor ) color_good, color_bad = choices(rng, list(colors), size=2, replace=False) x_exit_good, x_exit_bad = choices( rng, [1, shape.width - 2], size=2, replace=False ) grid[1, x_exit_good] = Exit(color_good) grid[1, x_exit_bad] = Exit(color_bad) grid[shape.height - 2, 1] = Beacon(color_good) grid[shape.height - 2, shape.width - 2] = Beacon(color_good) agent_position = Position(shape.height // 2, shape.width // 2) agent_orientation = Orientation.F agent = Agent(agent_position, agent_orientation) return State(grid, agent)
[docs]@reset_function_registry.register def memory_rooms( shape: Shape, layout: Tuple[int, int], colors: Set[Color], num_beacons: int, num_exits: int, *, rng: Optional[rnd.Generator] = None, ) -> State: if Color.NONE in colors: raise ValueError(f'colors ({colors}) must not include NONE') if len(colors) < 2: raise ValueError(f'colors ({colors}) must have at least 2 colors') if num_beacons < 1: raise ValueError(f'num_beacons ({num_beacons}) must be positive') if num_exits < 2: raise ValueError(f'num_exits ({num_exits}) must be >= 2') rng = get_gv_rng_if_none(rng) layout_height, layout_width = layout y_splits = np.linspace( 0, shape.height - 1, num=layout_height + 1, dtype=int, ) if len(y_splits) != len(set(y_splits)): raise ValueError( f'insufficient shape.height ({shape.height}) for layout ({layout})' ) x_splits = np.linspace( 0, shape.width - 1, num=layout_width + 1, dtype=int, ) if len(x_splits) != len(set(x_splits)): raise ValueError( f'insufficient shape.width ({shape.width}) for layout ({layout})' ) grid = Grid.from_shape((shape.height, shape.width)) draw_room_grid(grid, y_splits, x_splits, Wall) # passages in horizontal walls for y in y_splits[1:-1]: for x_from, x_to in mitt.pairwise(x_splits): x = rng.integers(x_from + 1, x_to) grid[y, x] = Floor() # passages in vertical walls for y_from, y_to in mitt.pairwise(y_splits): for x in x_splits[1:-1]: y = rng.integers(y_from + 1, y_to) grid[y, x] = Floor() # sample agent, beacon, and exit positions positions = [ position for position in grid.area.positions() if isinstance(grid[position], Floor) ] positions = choices( rng, positions, size=1 + num_beacons + num_exits, replace=False, ) agent_position = positions[0] agent_orientation = choice(rng, list(Orientation)) agent = Agent(agent_position, agent_orientation) sample_colors = choices(rng, list(colors), size=num_exits, replace=False) good_color = sample_colors[0] beacon_positions = positions[1 : 1 + num_beacons] for beacon_position in beacon_positions: grid[beacon_position] = Beacon(good_color) exit_positions = positions[1 + num_beacons :] for exit_position, exit_color in zip(exit_positions, sample_colors): grid[exit_position] = Exit(exit_color) return State(grid, agent)
[docs]def factory(name: str, **kwargs) -> ResetFunction: name = import_if_custom(name) try: function = reset_function_registry[name] except KeyError as error: raise ValueError(f'invalid reset function name {name}') from error signature = inspect.signature(function) required_keys = [ parameter.name for parameter in reset_function_registry.get_nonprotocol_parameters( signature ) if parameter.default is inspect.Parameter.empty ] optional_keys = [ parameter.name for parameter in reset_function_registry.get_nonprotocol_parameters( signature ) if parameter.default is not inspect.Parameter.empty ] checkraise_kwargs(kwargs, required_keys) kwargs = select_kwargs(kwargs, required_keys + optional_keys) return partial(function, **kwargs)