78 lines
2.1 KiB
Python
78 lines
2.1 KiB
Python
import torch
|
|
import warnings
|
|
import torchvision
|
|
import sys
|
|
import pyarrow as pa
|
|
import logging
|
|
|
|
from dataclasses import dataclass, field
|
|
from typing import Dict, Any, ClassVar, Deque, Mapping, Union
|
|
from datasets.features.features import register_feature
|
|
from torch.utils.tensorboard.writer import SummaryWriter
|
|
|
|
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
|
|
|
|
|
|
@dataclass
|
|
class VideoFrame:
|
|
"""
|
|
Provides a type for a dataset containing video frames.
|
|
|
|
Example:
|
|
|
|
```python
|
|
data_dict = [{"image": {"path": "videos/episode_0.mp4", "timestamp": 0.3}}]
|
|
features = {"image": VideoFrame()}
|
|
Dataset.from_dict(data_dict, features=Features(features))
|
|
```
|
|
"""
|
|
|
|
pa_type: ClassVar[Any] = pa.struct({
|
|
"path": pa.string(),
|
|
"timestamp": pa.float32()
|
|
})
|
|
_type: str = field(default="VideoFrame", init=False, repr=False)
|
|
|
|
def __call__(self):
|
|
return self.pa_type
|
|
|
|
|
|
with warnings.catch_warnings():
|
|
warnings.filterwarnings(
|
|
"ignore",
|
|
"'register_feature' is experimental and might be subject to breaking changes in the future.",
|
|
category=UserWarning,
|
|
)
|
|
register_feature(VideoFrame, "VideoFrame")
|
|
|
|
|
|
def populate_queues(
|
|
queues: Dict[str, Deque[Any]],
|
|
batch: Mapping[str, Any]) -> Dict[str, Deque[Any]]:
|
|
|
|
for key in batch:
|
|
if key not in queues:
|
|
continue
|
|
if len(queues[key]) != queues[key].maxlen:
|
|
while len(queues[key]) != queues[key].maxlen:
|
|
queues[key].append(batch[key])
|
|
else:
|
|
queues[key].append(batch[key])
|
|
return queues
|
|
|
|
|
|
def log_to_tensorboard(
|
|
writer: SummaryWriter,
|
|
data: Union[torch.Tensor, Any],
|
|
tag: str,
|
|
fps: int = 10) -> None:
|
|
if isinstance(data, torch.Tensor) and data.dim() == 5:
|
|
video = data
|
|
n = video.shape[0]
|
|
video = video.permute(2, 0, 1, 3, 4)
|
|
frame_grids = [torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0) for framesheet in video]
|
|
grid = torch.stack(frame_grids, dim=0)
|
|
grid = (grid + 1.0) / 2.0
|
|
grid = grid.unsqueeze(dim=0)
|
|
writer.add_video(tag, grid, fps=fps)
|