Source code for gym_gridverse.utils.registry

import abc
import inspect
from collections import UserDict
from typing import Callable, List, Optional


class FunctionRegistry(UserDict, metaclass=abc.ABCMeta):
    """Abstract name -> function registry with signature checking.

    Registered functions are stored in the underlying :class:`UserDict`
    mapping (``self.data``), keyed by name.  Concrete subclasses define what
    the expected call protocol is by implementing
    :meth:`get_protocol_parameters` and :meth:`check_signature`.
    """

    @abc.abstractmethod
    def get_protocol_parameters(
        self, signature: inspect.Signature
    ) -> List[inspect.Parameter]:
        """Return the parameters of `signature` which belong to the protocol.

        Args:
            signature: signature of a candidate function.

        Returns:
            The parameters considered part of the protocol.
        """
        # explicit raise rather than `assert False`, which is stripped by -O
        raise NotImplementedError

    def get_nonprotocol_parameters(
        self, signature: inspect.Signature
    ) -> List[inspect.Parameter]:
        """Return the parameters of `signature` which are not in the protocol.

        Args:
            signature: signature of a candidate function.

        Returns:
            Every parameter of `signature`, in declaration order, which is
            not among those returned by :meth:`get_protocol_parameters`.
        """
        protocol_parameters = self.get_protocol_parameters(signature)
        return [
            parameter
            for parameter in signature.parameters.values()
            if parameter not in protocol_parameters
        ]

    @abc.abstractmethod
    def check_signature(self, function: Callable):
        """Check that the signature of `function` matches the protocol.

        Called by :meth:`register` before a function is stored.

        Args:
            function: the function about to be registered.
        """
        # explicit raise rather than `assert False`, which is stripped by -O
        raise NotImplementedError

    def register(self, function=None, *, name: Optional[str] = None):
        """Register a function in this registry.

        This method can be either called directly or used as a decorator.
        Before registration, the function signature is checked to make sure
        it matches the appropriate protocol, and the name is checked to avoid
        conflicts.  If `name` is not given, `function.__name__` is used.

        Usage::

            @registry.register
            def function_1(...):
                ...

            @registry.register(name='alt_name_2')
            def function_2(...):
                ...

            def function_3(...):
                ...

            registry.register(function_3)

            def function_4(...):
                ...

            registry.register(function_4, name='alt_name_4')

        Args:
            function: (`Callable, optional`) the function to register; when
                omitted, a decorator is returned instead.
            name: (`str, optional`) the name to register the function under.

        Returns:
            `function` itself when it is given (so that decorated names stay
            bound to the original function); otherwise a decorator which
            registers its argument under `name` and returns it.

        Raises:
            ValueError: if neither `function` nor `name` is given, or if the
                registry already contains `name`.
            TypeError: if `function` is given but is not callable.
        """
        # check inputs
        if function is None and name is None:
            raise ValueError('register() needs `function` or `name` (or both)')

        # no function yet:  used to create a decorator
        if function is None:

            def register_decorator(function):
                # return the function, otherwise the decorated name would be
                # rebound to None
                return self.register(function, name=name)

            return register_decorator

        # called directly, or used as a bare decorator
        if not callable(function):
            raise TypeError('registered value must be a Callable')

        self.check_signature(function)

        if name is None:
            name = function.__name__

        if name in self.data:
            raise ValueError(f'registry already contains name `{name}`')

        self.data[name] = function
        return function