第一次完整测例跑完

This commit is contained in:
2026-01-18 00:30:10 +08:00
parent ca15cc593b
commit 25c6fc04db
180 changed files with 29305 additions and 0 deletions

View File

@@ -0,0 +1,77 @@
import torch
import warnings
import torchvision
import sys
import pyarrow as pa
import logging
from dataclasses import dataclass, field
from typing import Dict, Any, ClassVar, Deque, Mapping, Union
from datasets.features.features import register_feature
from torch.utils.tensorboard.writer import SummaryWriter
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
@dataclass
class VideoFrame:
"""
Provides a type for a dataset containing video frames.
Example:
```python
data_dict = [{"image": {"path": "videos/episode_0.mp4", "timestamp": 0.3}}]
features = {"image": VideoFrame()}
Dataset.from_dict(data_dict, features=Features(features))
```
"""
pa_type: ClassVar[Any] = pa.struct({
"path": pa.string(),
"timestamp": pa.float32()
})
_type: str = field(default="VideoFrame", init=False, repr=False)
def __call__(self):
return self.pa_type
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
"'register_feature' is experimental and might be subject to breaking changes in the future.",
category=UserWarning,
)
register_feature(VideoFrame, "VideoFrame")
def populate_queues(
queues: Dict[str, Deque[Any]],
batch: Mapping[str, Any]) -> Dict[str, Deque[Any]]:
for key in batch:
if key not in queues:
continue
if len(queues[key]) != queues[key].maxlen:
while len(queues[key]) != queues[key].maxlen:
queues[key].append(batch[key])
else:
queues[key].append(batch[key])
return queues
def log_to_tensorboard(
writer: SummaryWriter,
data: Union[torch.Tensor, Any],
tag: str,
fps: int = 10) -> None:
if isinstance(data, torch.Tensor) and data.dim() == 5:
video = data
n = video.shape[0]
video = video.permute(2, 0, 1, 3, 4)
frame_grids = [torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0) for framesheet in video]
grid = torch.stack(frame_grids, dim=0)
grid = (grid + 1.0) / 2.0
grid = grid.unsqueeze(dim=0)
writer.add_video(tag, grid, fps=fps)