Source code for gym_gridverse.representations.spaces

from __future__ import annotations

import enum
from typing import Tuple, cast

import numpy as np


class SpaceType(enum.Enum):
    """Kinds of bounded space supported by :class:`Space`.

    The kind determines which array dtypes are accepted (see
    :func:`is_dtype_compatible`): categorical and discrete spaces take
    integer arrays, continuous spaces take floating arrays.
    """

    # integer dtype
    CATEGORICAL = 0
    # integer dtype
    DISCRETE = 1
    # floating dtype
    CONTINUOUS = 2
def is_dtype_integer(x: np.ndarray) -> bool:
    """tells whether the dtype of an array is an integer type

    Args:
        x (numpy.ndarray): array whose dtype is inspected

    Returns:
        bool: True if `x.dtype` is a sub-dtype of `numpy.integer`
    """
    dtype = x.dtype
    return np.issubdtype(dtype, np.integer)
def is_dtype_floating(x: np.ndarray) -> bool:
    """tells whether the dtype of an array is a floating type

    Args:
        x (numpy.ndarray): array whose dtype is inspected

    Returns:
        bool: True if `x.dtype` is a sub-dtype of `numpy.floating`
    """
    dtype = x.dtype
    return np.issubdtype(dtype, np.floating)
def is_dtype_compatible(x: np.ndarray, space_type: SpaceType) -> bool:
    """tells whether the dtype of an array suits a given kind of space

    Categorical and discrete spaces require an integer dtype; continuous
    spaces require a floating dtype.

    Args:
        x (numpy.ndarray): array whose dtype is inspected
        space_type (SpaceType): kind of space the array is meant for

    Returns:
        bool: True if the dtype of `x` matches what `space_type` requires

    Raises:
        ValueError: if `space_type` is not one of the `SpaceType` members
    """
    # TODO: test
    if space_type is SpaceType.CONTINUOUS:
        return is_dtype_floating(x)

    # both integer-backed kinds share the same dtype requirement
    integer_backed = (
        space_type is SpaceType.CATEGORICAL
        or space_type is SpaceType.DISCRETE
    )
    if integer_backed:
        return is_dtype_integer(x)

    raise ValueError(f'invalid SpaceType {space_type}')
class Space:
    """A bounded space of numpy arrays.

    A space is described by its kind (`space_type`) and by two arrays of
    identical shape holding the inclusive element-wise lower and upper
    bounds.  Because the class defines `__eq__` without `__hash__`,
    instances are not hashable.
    """

    def __init__(
        self,
        space_type: SpaceType,
        lower_bound: np.ndarray,
        upper_bound: np.ndarray,
    ):
        """initializes a bounded space

        `lower_bound` and `upper_bound` must have the same shape, and a dtype
        compatible with the space_type.  Each element of `lower_bound` must be
        lower or equal to the corresponding element of `upper_bound`.

        The bound arrays are stored as given (not copied).

        Args:
            space_type (SpaceType): kind of the space
            lower_bound (numpy.ndarray): inclusive element-wise lower bounds
            upper_bound (numpy.ndarray): inclusive element-wise upper bounds

        Raises:
            ValueError: if a bound has a dtype incompatible with
                `space_type`, if the bounds have different shapes, or if any
                lower bound exceeds the corresponding upper bound
        """
        if not is_dtype_compatible(lower_bound, space_type):
            raise ValueError(
                f'incompatible lower bound dtype {lower_bound.dtype}'
            )

        if not is_dtype_compatible(upper_bound, space_type):
            raise ValueError(
                f'incompatible upper bound dtype {upper_bound.dtype}'
            )

        if lower_bound.shape != upper_bound.shape:
            raise ValueError(
                f'incompatible bound shapes {lower_bound.shape} {upper_bound.shape}'
            )

        if np.any(lower_bound > upper_bound):
            raise ValueError('incompatible bound values')

        self.space_type = space_type
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound

    @property
    def shape(self) -> Tuple[int, ...]:
        """shape shared by the bounds, and by every array in the space"""
        return self.lower_bound.shape

    def contains(self, x: np.ndarray) -> bool:
        """checks if array is of the correct shape and type, and within the
        space bounds

        Args:
            x (numpy.ndarray): candidate array

        Returns:
            bool: True if `x` has the shape of the space, a dtype compatible
            with its type, and every element lies within the inclusive bounds
        """
        # `np.all` yields a `numpy.bool_`; convert explicitly so that the
        # method really returns the builtin `bool` promised by its signature
        # (`typing.cast` would only affect static type checking).
        return (
            x.shape == self.shape
            and is_dtype_compatible(x, self.space_type)
            and bool(np.all(self.lower_bound <= x))
            and bool(np.all(x <= self.upper_bound))
        )

    def __eq__(self, other):
        """two spaces are equal if they have the same type and equal bounds

        Returns `NotImplemented` when `other` is not a `Space`, leaving the
        comparison to Python's fallback rules.
        """
        if isinstance(other, Space):
            return (
                self.space_type == other.space_type
                and np.array_equal(self.lower_bound, other.lower_bound)
                and np.array_equal(self.upper_bound, other.upper_bound)
            )

        return NotImplemented

    @staticmethod
    def make_categorical_space(upper_bound: np.ndarray) -> Space:
        """initializes a bounded categorical space

        The lower bound is an array of zeros shaped like `upper_bound`.
        `upper_bound` must have an integer dtype.  Each element of
        `upper_bound` must be non-negative.

        Args:
            upper_bound (numpy.ndarray): inclusive element-wise upper bounds

        Returns:
            Space: categorical space bounded by zero and `upper_bound`

        Raises:
            ValueError: if `upper_bound` does not have an integer dtype or
                contains a negative element
        """
        lower_bound = np.zeros_like(upper_bound)
        return Space(
            SpaceType.CATEGORICAL,
            lower_bound,
            upper_bound,
        )

    @staticmethod
    def make_discrete_space(
        lower_bound: np.ndarray,
        upper_bound: np.ndarray,
    ) -> Space:
        """initializes a bounded discrete space

        `lower_bound` and `upper_bound` must have the same shape, and an
        integer dtype.  Each element of `lower_bound` must be lower or equal to
        the corresponding element of `upper_bound`.

        Args:
            lower_bound (numpy.ndarray): inclusive element-wise lower bounds
            upper_bound (numpy.ndarray): inclusive element-wise upper bounds

        Returns:
            Space: discrete space with the given bounds

        Raises:
            ValueError: if the bounds violate the requirements above
        """
        return Space(
            SpaceType.DISCRETE,
            lower_bound,
            upper_bound,
        )

    @staticmethod
    def make_continuous_space(
        lower_bound: np.ndarray,
        upper_bound: np.ndarray,
    ) -> Space:
        """initializes a bounded continuous space

        `lower_bound` and `upper_bound` must have the same shape, and a
        floating dtype.  Each element of `lower_bound` must be lower or equal
        to the corresponding element of `upper_bound`.

        Args:
            lower_bound (numpy.ndarray): inclusive element-wise lower bounds
            upper_bound (numpy.ndarray): inclusive element-wise upper bounds

        Returns:
            Space: continuous space with the given bounds

        Raises:
            ValueError: if the bounds violate the requirements above
        """
        return Space(
            SpaceType.CONTINUOUS,
            lower_bound,
            upper_bound,
        )