upload eval_utils.py file

This commit is contained in:
yuchen-x
2025-09-16 21:48:15 +08:00
parent a457f2f049
commit 9c9942f1d4

View File

@@ -0,0 +1,77 @@
import torch
import warnings
import torchvision
import sys
import pyarrow as pa
import logging
from dataclasses import dataclass, field
from typing import Dict, Any, ClassVar, Deque, Mapping, Union
from datasets.features.features import register_feature
from torch.utils.tensorboard.writer import SummaryWriter
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
@dataclass
class VideoFrame:
"""
Provides a type for a dataset containing video frames.
Example:
```python
data_dict = [{"image": {"path": "videos/episode_0.mp4", "timestamp": 0.3}}]
features = {"image": VideoFrame()}
Dataset.from_dict(data_dict, features=Features(features))
```
"""
pa_type: ClassVar[Any] = pa.struct({
"path": pa.string(),
"timestamp": pa.float32()
})
_type: str = field(default="VideoFrame", init=False, repr=False)
def __call__(self):
return self.pa_type
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
"'register_feature' is experimental and might be subject to breaking changes in the future.",
category=UserWarning,
)
register_feature(VideoFrame, "VideoFrame")
def populate_queues(
queues: Dict[str, Deque[Any]],
batch: Mapping[str, Any]) -> Dict[str, Deque[Any]]:
for key in batch:
if key not in queues:
continue
if len(queues[key]) != queues[key].maxlen:
while len(queues[key]) != queues[key].maxlen:
queues[key].append(batch[key])
else:
queues[key].append(batch[key])
return queues
def log_to_tensorboard(
writer: SummaryWriter,
data: Union[torch.Tensor, Any],
tag: str,
fps: int = 10) -> None:
if isinstance(data, torch.Tensor) and data.dim() == 5:
video = data
n = video.shape[0]
video = video.permute(2, 0, 1, 3, 4)
frame_grids = [torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0) for framesheet in video]
grid = torch.stack(frame_grids, dim=0)
grid = (grid + 1.0) / 2.0
grid = grid.unsqueeze(dim=0)
writer.add_video(tag, grid, fps=fps)