from __future__ import annotations
import os
from dataclasses import dataclass, field
from typing import (
Generic,
Iterable,
Iterator,
List,
Optional,
Sequence,
TypeVar,
Union,
cast,
)
import imageio
import more_itertools as mitt
import numpy as np
from typing_extensions import TypedDict
from gym_gridverse.action import Action
from gym_gridverse.observation import Observation
from gym_gridverse.rendering import GridVerseViewer
from gym_gridverse.state import State
from gym_gridverse.utils.rl import make_return_computer
#: Alias for an image array (np.ndarray).
Image = np.ndarray

#: Constrained type variable for a recorded frame: exactly one of State,
#: Observation, or an image (np.ndarray).
FrameType = TypeVar('FrameType', State, Observation, np.ndarray)
@dataclass(frozen=True)
class Data(Generic[FrameType]):
    """Recording of an episode as states, observations, or images.

    ``frames`` must hold exactly one more element than ``actions`` and
    ``rewards``: the first frame has no preceding action or reward, and each
    later frame is paired with one action and one reward.
    """

    frames: Sequence[FrameType]
    actions: Sequence[Action]
    rewards: Sequence[float]
    discount: float

    def __post_init__(self):
        """Validate the frame/action/reward length invariant.

        Raises:
            ValueError: if ``len(frames) - 1`` differs from ``len(actions)``
                or from ``len(rewards)`` (this includes empty ``frames``).
        """
        num_steps = len(self.frames) - 1
        if num_steps != len(self.actions) or num_steps != len(self.rewards):
            raise ValueError('wrong lengths')

    @property
    def is_state_data(self) -> bool:
        """Whether the frames are State objects (judged by the first frame)."""
        first_frame = self.frames[0]
        return isinstance(first_frame, State)

    @property
    def is_observation_data(self) -> bool:
        """Whether the frames are Observation objects (judged by the first frame)."""
        first_frame = self.frames[0]
        return isinstance(first_frame, Observation)

    @property
    def is_image_data(self) -> bool:
        """Whether the frames are image arrays (judged by the first frame)."""
        first_frame = self.frames[0]
        return isinstance(first_frame, Image)
@dataclass(frozen=True)
class DataBuilder(Generic[FrameType]):
    """Builds a Data object incrementally.

    Usage: call ``append0`` once with the initial frame, then ``append`` once
    per step, then ``build``.  Only ``discount`` is an ``__init__`` argument;
    the three lists start empty.  ``frozen=True`` only prevents rebinding the
    attributes; the lists themselves are still appended to in place.
    """

    frames: List[FrameType] = field(init=False, default_factory=list)
    actions: List[Action] = field(init=False, default_factory=list)
    rewards: List[float] = field(init=False, default_factory=list)
    discount: float

    def append0(self, frame: FrameType):
        """Record the initial frame (the one with no preceding action).

        Raises:
            RuntimeError: if a frame has already been recorded.
        """
        if self.frames:
            raise RuntimeError('cannot call DataBuilder.append0 at this point')
        self.frames.append(frame)

    def append(self, frame: FrameType, action: Action, reward: float):
        """Record one step: the next frame with its action and reward.

        Raises:
            RuntimeError: if ``append0`` has not been called yet.
        """
        if not self.frames:
            raise RuntimeError('cannot call DataBuilder.append at this point')
        self.frames.append(frame)
        self.actions.append(action)
        self.rewards.append(reward)

    def build(self) -> Data[FrameType]:
        """Create a Data object from everything recorded so far.

        The lists are copied, so appending more steps to this builder later
        does not mutate (and invalidate) a Data object that was already built.

        Raises:
            ValueError: raised by Data's validation, e.g. if ``append0`` was
                never called.
        """
        return Data(
            list(self.frames),
            list(self.actions),
            list(self.rewards),
            self.discount,
        )
class HUD_Info(TypedDict):
    """Keyword arguments passed to the viewer's ``render`` call for one frame.

    In ``generate_images`` every entry is ``None`` for the initial frame.
    """

    # action paired with the rendered frame (None for the initial frame)
    action: Optional[Action]
    # reward paired with the rendered frame (None for the initial frame)
    reward: Optional[float]
    # return accumulated so far, as produced by make_return_computer
    ret: Optional[float]
    # True for the last frame of the recording, False otherwise (None initially)
    done: Optional[bool]
def generate_images(
    data: Union[Data[State], Data[Observation], Data[Image]]
) -> Iterator[Image]:
    """Generate one image per frame of the input data.

    Image data is yielded unchanged.  State or observation data is rendered
    through a GridVerseViewer.  Each frame after the first is rendered with
    the action, reward, return so far, and a done flag (True only for the
    final frame) as HUD arguments.  The first frame is rendered with all HUD
    arguments set to None.

    The viewer is always closed, whether the generator is exhausted, raises,
    or is closed / garbage-collected before reaching the end.
    """
    if data.is_image_data:
        yield from data.frames
        return

    data = cast(Union[Data[State], Data[Observation]], data)

    # Both State and Observation expose a grid whose shape sizes the viewer.
    shape = data.frames[0].grid.shape
    viewer = GridVerseViewer(shape)
    try:
        # NOTE(review): presumably toggles the HUD overlay on — confirm.
        viewer.flip_hud()

        hud_info: HUD_Info = {
            'action': None,
            'reward': None,
            'ret': None,
            'done': None,
        }
        yield viewer.render(data.frames[0], return_rgb_array=True, **hud_info)

        return_computer = make_return_computer(data.discount)
        # mark_ends yields (is_first, is_last, item); is_last marks the final step.
        for _, is_last, (frame, action, reward) in mitt.mark_ends(
            zip(data.frames[1:], data.actions, data.rewards)
        ):
            hud_info = {
                'action': action,
                'reward': reward,
                'ret': return_computer(reward),
                'done': is_last,
            }
            frame = cast(Union[State, Observation], frame)
            yield viewer.render(frame, return_rgb_array=True, **hud_info)
    finally:
        # Runs on normal exhaustion, on exceptions, and on GeneratorExit when
        # the consumer stops iterating early.
        viewer.close()
def record(
    mode: str,
    images: Sequence[np.ndarray],
    *,
    filename: Optional[str] = None,
    filenames: Optional[Iterable[str]] = None,
    **kwargs,
):
    """Dispatch to the recording function selected by ``mode``.

    Args:
        mode: ``'images'`` (one file per image, via ``record_images``),
            ``'gif'`` (via ``record_gif``) or ``'mp4'`` (via ``record_mp4``).
        images: frames to record, in order.
        filename: output path; required for ``'gif'`` and ``'mp4'``.
        filenames: output paths paired with the images; required for
            ``'images'``.
        **kwargs: forwarded unchanged to the selected recording function.

    Raises:
        ValueError: if ``mode`` is unknown, or the path argument required by
            ``mode`` is missing.
    """
    if mode == 'images':
        if filenames is None:
            raise ValueError(f'invalid arguments for mode {mode}')
        record_images(filenames, images, **kwargs)
    elif mode == 'gif':
        if filename is None:
            raise ValueError(f'invalid arguments for mode {mode}')
        record_gif(filename, images, **kwargs)
    elif mode == 'mp4':
        if filename is None:
            raise ValueError(f'invalid arguments for mode {mode}')
        record_mp4(filename, images, **kwargs)
    else:
        # Fail loudly instead of silently recording nothing.
        raise ValueError(f'invalid mode {mode}')
def record_images(
    filenames: Iterable[str],
    images: Sequence[np.ndarray],
    **kwargs,
):
    """Write each image to its own file.

    Paths and images are paired in order and writing stops when either runs
    out, so ``filenames`` may be longer (even unbounded) than ``images``.
    If a write raises FileNotFoundError, the parent directory is created and
    the write is retried once.  Extra keyword arguments are accepted but not
    used.
    """
    for path, img in zip(filenames, images):
        print(f'creating {path}')
        try:
            imageio.imwrite(path, img)
        except FileNotFoundError:
            parent_dir = os.path.dirname(path)
            os.makedirs(parent_dir, exist_ok=True)
            imageio.imwrite(path, img)
def record_gif(
    filename: str,
    images: Sequence[np.ndarray],
    *,
    loop: int = 0,
    fps: float = 2.0,
    duration: Optional[float] = None,
    **kwargs,
):
    """Write all images into a single gif file.

    Args:
        filename: output path; if the write raises FileNotFoundError the
            parent directory is created and the write retried once.
        images: frames to encode, in order.
        loop: passed to imageio as the ``loop`` option.
        fps: passed to imageio as the ``fps`` option.
        duration: if given, divided evenly by the number of frames and passed
            to imageio as the ``duration`` option.
        **kwargs: accepted but ignored; only the options above are passed to
            imageio.
    """
    writer_options = dict(format='gif', subrectangles=True, loop=loop, fps=fps)
    if duration is not None:
        writer_options['duration'] = duration / len(images)

    print(f'creating {filename} ({len(images)} frames)')
    try:
        imageio.mimwrite(filename, images, **writer_options)
    except FileNotFoundError:
        parent_dir = os.path.dirname(filename)
        os.makedirs(parent_dir, exist_ok=True)
        imageio.mimwrite(filename, images, **writer_options)
def record_mp4(
    filename: str,
    images: Sequence[np.ndarray],
    *,
    fps: float = 2.0,
    duration: Optional[float] = None,
    **kwargs,
):
    """Write all images into a single mp4 file.

    Args:
        filename: output path; if the write raises FileNotFoundError the
            parent directory is created and the write retried once.
        images: frames to encode, in order.
        fps: frames per second, used when ``duration`` is not given.
        duration: if given, overrides ``fps`` with ``len(images) / duration``.
        **kwargs: accepted but ignored; only ``format`` and ``fps`` are passed
            to imageio.
    """
    effective_fps = fps if duration is None else len(images) / duration
    writer_options = {'format': 'mp4', 'fps': effective_fps}

    print(f'creating {filename} ({len(images)} frames)')
    try:
        imageio.mimwrite(filename, images, **writer_options)
    except FileNotFoundError:
        parent_dir = os.path.dirname(filename)
        os.makedirs(parent_dir, exist_ok=True)
        imageio.mimwrite(filename, images, **writer_options)