Source code for gym_gridverse.rng
from typing import List, Optional, Sequence, TypeVar
import numpy.random as rnd
# library-level generator, used if one is not provided (e.g. by environment)
_gv_rng: Optional[rnd.Generator] = None
[docs]def make_rng(seed: Optional[int] = None) -> rnd.Generator:
"""make a new rng object"""
return rnd.default_rng(seed)
[docs]def reset_gv_rng(seed: Optional[int] = None) -> rnd.Generator:
"""reset the gym-gridverse module rng"""
global _gv_rng
_gv_rng = make_rng(seed)
return _gv_rng
[docs]def get_gv_rng() -> rnd.Generator:
"""get (and reset if necessary) gym-gridverse module rng"""
return reset_gv_rng() if _gv_rng is None else _gv_rng
[docs]def get_gv_rng_if_none(rng: Optional[rnd.Generator]) -> rnd.Generator:
"""get gym-gridverse module rng if input is None"""
return get_gv_rng() if rng is None else rng
# auxiliary methods solve typing issues associated with rng sampling
T = TypeVar('T')
"""generic type"""
[docs]def choice(rng: rnd.Generator, data: Sequence[T]) -> T:
"""randomly chooses one element from the input data"""
i = rng.choice(len(data))
return data[i]
[docs]def choices(
rng: rnd.Generator, data: Sequence[T], *, size: int, **kwargs
) -> List[T]:
"""randomly chooses multiple elements from the input data"""
indices = rng.choice(len(data), size=size, **kwargs)
return [data[i] for i in indices]
[docs]def shuffle(rng: rnd.Generator, data: Sequence[T]) -> List[T]:
"""randomly shuffles the data"""
indices = list(range(len(data)))
# NOTE: faster than rng.choice
rng.shuffle(indices)
return [data[i] for i in indices]