upload eval_utils.py file
This commit is contained in:
77
scripts/evaluation/eval_utils.py
Normal file
77
scripts/evaluation/eval_utils.py
Normal file
@@ -0,0 +1,77 @@
|
||||
import torch
|
||||
import warnings
|
||||
import torchvision
|
||||
import sys
|
||||
import pyarrow as pa
|
||||
import logging
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Any, ClassVar, Deque, Mapping, Union
|
||||
from datasets.features.features import register_feature
|
||||
from torch.utils.tensorboard.writer import SummaryWriter
|
||||
|
||||
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoFrame:
|
||||
"""
|
||||
Provides a type for a dataset containing video frames.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
data_dict = [{"image": {"path": "videos/episode_0.mp4", "timestamp": 0.3}}]
|
||||
features = {"image": VideoFrame()}
|
||||
Dataset.from_dict(data_dict, features=Features(features))
|
||||
```
|
||||
"""
|
||||
|
||||
pa_type: ClassVar[Any] = pa.struct({
|
||||
"path": pa.string(),
|
||||
"timestamp": pa.float32()
|
||||
})
|
||||
_type: str = field(default="VideoFrame", init=False, repr=False)
|
||||
|
||||
def __call__(self):
|
||||
return self.pa_type
|
||||
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
"'register_feature' is experimental and might be subject to breaking changes in the future.",
|
||||
category=UserWarning,
|
||||
)
|
||||
register_feature(VideoFrame, "VideoFrame")
|
||||
|
||||
|
||||
def populate_queues(
|
||||
queues: Dict[str, Deque[Any]],
|
||||
batch: Mapping[str, Any]) -> Dict[str, Deque[Any]]:
|
||||
|
||||
for key in batch:
|
||||
if key not in queues:
|
||||
continue
|
||||
if len(queues[key]) != queues[key].maxlen:
|
||||
while len(queues[key]) != queues[key].maxlen:
|
||||
queues[key].append(batch[key])
|
||||
else:
|
||||
queues[key].append(batch[key])
|
||||
return queues
|
||||
|
||||
|
||||
def log_to_tensorboard(
|
||||
writer: SummaryWriter,
|
||||
data: Union[torch.Tensor, Any],
|
||||
tag: str,
|
||||
fps: int = 10) -> None:
|
||||
if isinstance(data, torch.Tensor) and data.dim() == 5:
|
||||
video = data
|
||||
n = video.shape[0]
|
||||
video = video.permute(2, 0, 1, 3, 4)
|
||||
frame_grids = [torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0) for framesheet in video]
|
||||
grid = torch.stack(frame_grids, dim=0)
|
||||
grid = (grid + 1.0) / 2.0
|
||||
grid = grid.unsqueeze(dim=0)
|
||||
writer.add_video(tag, grid, fps=fps)
|
||||
Reference in New Issue
Block a user