61 lines
1.9 KiB
Python
61 lines
1.9 KiB
Python
import torch
|
|
|
|
from huggingface_hub import hf_hub_download, snapshot_download
|
|
from typing import Dict, List, Union
|
|
from pathlib import Path
|
|
from safetensors.torch import load_file
|
|
|
|
def unflatten_dict(d, sep="/"):
    """Expand a flat dict with delimited keys into a nested dict.

    Each key is split on ``sep``; every part except the last names a nested
    dict (created on demand) and the last part names the entry that receives
    the value. For example ``{"a/b": 1, "a/c": 2}`` becomes
    ``{"a": {"b": 1, "c": 2}}``.

    Args:
        d: Mapping whose keys are strings, optionally containing ``sep``.
        sep: Delimiter separating the nesting levels inside each key.

    Returns:
        A new nested dict; ``d`` itself is not modified.
    """
    outdict = {}
    for flat_key, value in d.items():
        *parents, leaf = flat_key.split(sep)
        # Walk down from the root, creating intermediate levels as needed.
        # A dedicated cursor variable is used so the input mapping `d` is
        # never rebound while it is being iterated.
        node = outdict
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return outdict
|
|
|
|
|
|
def load_episode_data_index(repo_id, version, root) -> Dict[str, torch.Tensor]:
    """Load ``episode_data_index``, the range of frame indices for each episode.

    When ``root`` is not None the file is read from
    ``<root>/<repo_id>/meta_data/episode_data_index.safetensors`` and
    ``version`` is not used. Otherwise the file is fetched with
    ``hf_hub_download`` from the dataset repo ``repo_id`` at revision
    ``version``.

    Example:
    ```python
    from_id = episode_data_index["from"][episode_id].item()
    to_id = episode_data_index["to"][episode_id].item()
    episode_frames = [dataset[i] for i in range(from_id, to_id)]
    ```
    """
    if root is None:
        # No local root given: resolve the file through the hub.
        index_file = hf_hub_download(
            repo_id,
            "meta_data/episode_data_index.safetensors",
            repo_type="dataset",
            revision=version,
        )
    else:
        index_file = Path(root) / repo_id / "meta_data" / "episode_data_index.safetensors"

    return load_file(index_file)
|
|
|
|
|
|
def load_stats(repo_id, version, root) -> Dict[str, Dict[str, torch.Tensor]]:
    """Load the dataset statistics as a nested dict.

    Per the original description, stats holds the statistics per modality
    computed over the full dataset, such as max, min, mean, std.

    When ``root`` is not None the file is read from
    ``<root>/<repo_id>/meta_data/stats.safetensors`` and ``version`` is not
    used. Otherwise the file is fetched with ``hf_hub_download`` from the
    dataset repo ``repo_id`` at revision ``version``. The flat mapping
    returned by ``load_file`` is then expanded with ``unflatten_dict``, which
    splits keys on ``"/"``.

    Example:
    ```python
    normalized_action = (action - stats["action"]["mean"]) / stats["action"]["std"]
    ```
    """
    if root is None:
        # No local root given: resolve the file through the hub.
        stats_file = hf_hub_download(
            repo_id,
            "meta_data/stats.safetensors",
            repo_type="dataset",
            revision=version,
        )
    else:
        stats_file = Path(root) / repo_id / "meta_data" / "stats.safetensors"

    flat_stats = load_file(stats_file)
    return unflatten_dict(flat_stats)
|