diff --git a/scripts/evaluation/eval_utils.py b/scripts/evaluation/eval_utils.py new file mode 100644 index 0000000..366335a --- /dev/null +++ b/scripts/evaluation/eval_utils.py @@ -0,0 +1,77 @@ +import torch +import warnings +import torchvision +import sys +import pyarrow as pa +import logging + +from dataclasses import dataclass, field +from typing import Dict, Any, ClassVar, Deque, Mapping, Union +from datasets.features.features import register_feature +from torch.utils.tensorboard.writer import SummaryWriter + +logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) + + +@dataclass +class VideoFrame: + """ + Provides a type for a dataset containing video frames. + + Example: + + ```python + data_dict = [{"image": {"path": "videos/episode_0.mp4", "timestamp": 0.3}}] + features = {"image": VideoFrame()} + Dataset.from_dict(data_dict, features=Features(features)) + ``` + """ + + pa_type: ClassVar[Any] = pa.struct({ + "path": pa.string(), + "timestamp": pa.float32() + }) + _type: str = field(default="VideoFrame", init=False, repr=False) + + def __call__(self): + return self.pa_type + + +with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + "'register_feature' is experimental and might be subject to breaking changes in the future.", + category=UserWarning, + ) + register_feature(VideoFrame, "VideoFrame") + + +def populate_queues( + queues: Dict[str, Deque[Any]], + batch: Mapping[str, Any]) -> Dict[str, Deque[Any]]: + + for key in batch: + if key not in queues: + continue + if len(queues[key]) != queues[key].maxlen: + while len(queues[key]) != queues[key].maxlen: + queues[key].append(batch[key]) + else: + queues[key].append(batch[key]) + return queues + + +def log_to_tensorboard( + writer: SummaryWriter, + data: Union[torch.Tensor, Any], + tag: str, + fps: int = 10) -> None: + if isinstance(data, torch.Tensor) and data.dim() == 5: + video = data + n = video.shape[0] + video = video.permute(2, 0, 1, 3, 4) + frame_grids = [torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0) for framesheet in video] + grid = torch.stack(frame_grids, dim=0) + grid = (grid + 1.0) / 2.0 + grid = grid.unsqueeze(dim=0) + writer.add_video(tag, grid, fps=fps)