第一次完整测例跑完
This commit is contained in:
60
src/unifolm_wma/data/utils.py
Normal file
60
src/unifolm_wma/data/utils.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import torch
|
||||
|
||||
from huggingface_hub import hf_hub_download, snapshot_download
|
||||
from typing import Dict, List, Union
|
||||
from pathlib import Path
|
||||
from safetensors.torch import load_file
|
||||
|
||||
def unflatten_dict(d, sep="/"):
|
||||
outdict = {}
|
||||
for key, value in d.items():
|
||||
parts = key.split(sep)
|
||||
d = outdict
|
||||
for part in parts[:-1]:
|
||||
if part not in d:
|
||||
d[part] = {}
|
||||
d = d[part]
|
||||
d[parts[-1]] = value
|
||||
return outdict
|
||||
|
||||
|
||||
def load_episode_data_index(repo_id, version, root) -> Dict[str, torch.Tensor]:
|
||||
"""episode_data_index contains the range of indices for each episode
|
||||
|
||||
Example:
|
||||
```python
|
||||
from_id = episode_data_index["from"][episode_id].item()
|
||||
to_id = episode_data_index["to"][episode_id].item()
|
||||
episode_frames = [dataset[i] for i in range(from_id, to_id)]
|
||||
```
|
||||
"""
|
||||
if root is not None:
|
||||
path = Path(
|
||||
root) / repo_id / "meta_data" / "episode_data_index.safetensors"
|
||||
else:
|
||||
path = hf_hub_download(repo_id,
|
||||
"meta_data/episode_data_index.safetensors",
|
||||
repo_type="dataset",
|
||||
revision=version)
|
||||
|
||||
return load_file(path)
|
||||
|
||||
|
||||
def load_stats(repo_id, version, root) -> Dict[str, Dict[str, torch.Tensor]]:
|
||||
"""stats contains the statistics per modality computed over the full dataset, such as max, min, mean, std
|
||||
|
||||
Example:
|
||||
```python
|
||||
normalized_action = (action - stats["action"]["mean"]) / stats["action"]["std"]
|
||||
```
|
||||
"""
|
||||
if root is not None:
|
||||
path = Path(root) / repo_id / "meta_data" / "stats.safetensors"
|
||||
else:
|
||||
path = hf_hub_download(repo_id,
|
||||
"meta_data/stats.safetensors",
|
||||
repo_type="dataset",
|
||||
revision=version)
|
||||
|
||||
stats = load_file(path)
|
||||
return unflatten_dict(stats)
|
||||
Reference in New Issue
Block a user